From 7432b58f82e84408fe0828b7f8cec5d99f212944 Mon Sep 17 00:00:00 2001 From: 99 Date: Thu, 5 Mar 2026 14:31:28 +0800 Subject: [PATCH 1/5] refactor(dify_graph): introduce run_context and delegate child engine creation (#32964) --- api/.importlinter | 13 -- api/core/app/apps/pipeline/pipeline_runner.py | 16 +- api/core/app/apps/workflow_app_runner.py | 27 ++-- api/core/app/entities/app_invoke_entities.py | 66 +++++++- api/core/app/workflow/layers/llm_quota.py | 3 +- api/core/workflow/node_factory.py | 18 ++- api/core/workflow/workflow_entry.py | 92 ++++++++++-- api/dify_graph/entities/graph_init_params.py | 8 +- api/dify_graph/enums.py | 33 ---- api/dify_graph/graph_engine/graph_engine.py | 27 +++- api/dify_graph/nodes/agent/agent_node.py | 32 ++-- api/dify_graph/nodes/base/node.py | 59 +++++++- .../nodes/datasource/datasource_node.py | 9 +- api/dify_graph/nodes/http_request/node.py | 7 +- .../nodes/human_input/human_input_node.py | 32 ++-- .../nodes/iteration/iteration_node.py | 53 ++----- .../knowledge_index/knowledge_index_node.py | 6 +- .../knowledge_retrieval_node.py | 22 +-- api/dify_graph/nodes/llm/node.py | 9 +- api/dify_graph/nodes/loop/loop_node.py | 34 +---- .../parameter_extractor_node.py | 2 +- .../question_classifier_node.py | 7 +- api/dify_graph/nodes/tool/tool_node.py | 17 ++- api/dify_graph/nodes/trigger_webhook/node.py | 3 +- api/dify_graph/runtime/__init__.py | 10 +- api/dify_graph/runtime/graph_runtime_state.py | 52 +++++++ api/services/workflow_app_service.py | 11 +- api/services/workflow_service.py | 16 +- .../workflow/nodes/test_code.py | 12 +- .../workflow/nodes/test_http.py | 18 +-- .../workflow/nodes/test_llm.py | 12 +- .../nodes/test_parameter_extractor.py | 12 +- .../workflow/nodes/test_template_transform.py | 12 +- .../workflow/nodes/test_tool.py | 12 +- .../test_human_input_resume_node_execution.py | 8 +- .../core/app/apps/test_pause_resume.py | 8 +- .../graph/test_graph_skip_validation.py | 14 +- .../workflow/graph/test_graph_validation.py | 14 +- .../graph_engine/layers/test_llm_quota.py | 5 + .../graph_engine/test_auto_mock_system.py | 23 ++- .../graph_engine/test_command_system.py | 47 +++--- .../graph_engine/test_graph_state_snapshot.py | 7 +- .../test_human_input_pause_multi_branch.py | 8 +- .../test_human_input_pause_single_branch.py | 8 +- .../graph_engine/test_if_else_streaming.py | 9 +- .../test_mock_iteration_simple.py | 53 ++++--- .../workflow/graph_engine/test_mock_nodes.py | 12 +- .../test_mock_nodes_template_code.py | 141 +++++++++++------- .../workflow/graph_engine/test_mock_simple.py | 36 +++-- .../test_parallel_human_input_join_resume.py | 8 +- ...rallel_human_input_pause_missing_finish.py | 8 +- .../test_parallel_streaming_workflow.py | 16 +- .../test_pause_deferred_ready_nodes.py | 8 +- .../graph_engine/test_pause_resume_state.py | 8 +- .../graph_engine/test_table_runner.py | 71 +++++++-- .../core/workflow/nodes/answer/test_answer.py | 12 +- .../nodes/datasource/test_datasource_node.py | 15 +- .../http_request/test_http_request_node.py | 12 +- .../nodes/human_input/test_entities.py | 57 ++++--- .../test_human_input_form_filled_event.py | 34 +++-- .../test_iteration_child_engine_errors.py | 100 +++++++++++++ .../test_knowledge_index_node.py | 12 +- .../test_knowledge_retrieval_node.py | 12 +- .../workflow/nodes/list_operator/node_spec.py | 112 ++++++-------- .../core/workflow/nodes/llm/test_node.py | 10 +- .../template_transform_node_spec.py | 13 +- .../core/workflow/nodes/test_base_node.py | 36 ++--- .../nodes/test_document_extractor_node.py | 11 +- .../core/workflow/nodes/test_if_else.py | 51 ++++--- .../core/workflow/nodes/test_list_operator.py | 19 ++- .../nodes/test_start_node_json_object.py | 8 +- .../workflow/nodes/tool/test_tool_node.py | 8 +- .../v1/test_variable_assigner_v1.py | 46 +++--- .../v2/test_variable_assigner_v2.py | 74 +++++---- .../webhook/test_webhook_file_conversion.py | 21 +-- .../nodes/webhook/test_webhook_node.py | 21 +-- .../test_workflow_entry_redis_channel.py | 3 +- api/tests/workflow_test_utils.py | 53 +++++++ 78 files changed, 1281 insertions(+), 733 deletions(-) create mode 100644 api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py create mode 100644 api/tests/workflow_test_utils.py diff --git a/api/.importlinter b/api/.importlinter index 10faeb448a..e4536b1f10 100644 --- a/api/.importlinter +++ b/api/.importlinter @@ -28,17 +28,8 @@ ignore_imports = dify_graph.nodes.iteration.iteration_node -> dify_graph.graph_events dify_graph.nodes.loop.loop_node -> dify_graph.graph_events - dify_graph.nodes.iteration.iteration_node -> core.workflow.node_factory - dify_graph.nodes.loop.loop_node -> core.workflow.node_factory - dify_graph.nodes.iteration.iteration_node -> core.app.workflow.layers.llm_quota - dify_graph.nodes.loop.loop_node -> core.app.workflow.layers.llm_quota - dify_graph.nodes.iteration.iteration_node -> dify_graph.graph_engine - dify_graph.nodes.iteration.iteration_node -> dify_graph.graph - dify_graph.nodes.iteration.iteration_node -> dify_graph.graph_engine.command_channels dify_graph.nodes.loop.loop_node -> dify_graph.graph_engine - dify_graph.nodes.loop.loop_node -> dify_graph.graph - dify_graph.nodes.loop.loop_node -> dify_graph.graph_engine.command_channels # TODO(QuantumGhost): fix the import violation later dify_graph.entities.pause_reason -> dify_graph.nodes.human_input.entities @@ -101,12 +92,9 @@ forbidden_modules = core.trigger core.variables ignore_imports = - dify_graph.nodes.loop.loop_node -> core.workflow.node_factory dify_graph.nodes.agent.agent_node -> core.model_manager dify_graph.nodes.agent.agent_node -> core.provider_manager dify_graph.nodes.agent.agent_node -> core.tools.tool_manager - dify_graph.nodes.iteration.iteration_node -> core.workflow.node_factory - dify_graph.nodes.iteration.iteration_node -> core.app.workflow.layers.llm_quota dify_graph.nodes.llm.llm_utils -> core.model_manager dify_graph.nodes.llm.protocols -> core.model_manager dify_graph.nodes.llm.llm_utils -> dify_graph.model_runtime.model_providers.__base.large_language_model @@ -151,7 +139,6 @@ ignore_imports = dify_graph.nodes.llm.node -> extensions.ext_database dify_graph.nodes.tool.tool_node -> extensions.ext_database dify_graph.nodes.agent.agent_node -> models - dify_graph.nodes.loop.loop_node -> core.app.workflow.layers.llm_quota dify_graph.nodes.llm.node -> models.model dify_graph.nodes.agent.agent_node -> services dify_graph.nodes.tool.tool_node -> services diff --git a/api/core/app/apps/pipeline/pipeline_runner.py b/api/core/app/apps/pipeline/pipeline_runner.py index 748edb7956..4222aae809 100644 --- a/api/core/app/apps/pipeline/pipeline_runner.py +++ b/api/core/app/apps/pipeline/pipeline_runner.py @@ -8,12 +8,14 @@ from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner from core.app.entities.app_invoke_entities import ( InvokeFrom, RagPipelineGenerateEntity, + UserFrom, + build_dify_run_context, ) from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer from core.workflow.node_factory import DifyNodeFactory from core.workflow.workflow_entry import WorkflowEntry from dify_graph.entities.graph_init_params import GraphInitParams -from dify_graph.enums import UserFrom, WorkflowType +from dify_graph.enums import WorkflowType from dify_graph.graph import Graph from dify_graph.graph_events import GraphEngineEvent, GraphRunFailedEvent from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository @@ -256,13 +258,15 @@ class PipelineRunner(WorkflowBasedAppRunner): # init graph # Create required parameters for Graph.init graph_init_params = GraphInitParams( - tenant_id=workflow.tenant_id, - app_id=self._app_id, workflow_id=workflow.id, graph_config=graph_config, - user_id=self.application_generate_entity.user_id, - user_from=user_from, - invoke_from=invoke_from, + run_context=build_dify_run_context( + tenant_id=workflow.tenant_id, + app_id=self._app_id, + user_id=self.application_generate_entity.user_id, + user_from=user_from, + invoke_from=invoke_from, + ), call_depth=0, ) diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py index c5a00aa4ff..7ef6ff7cc2 100644 --- a/api/core/app/apps/workflow_app_runner.py +++ b/api/core/app/apps/workflow_app_runner.py @@ -4,7 +4,7 @@ from collections.abc import Mapping, Sequence from typing import Any, cast from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom -from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context from core.app.entities.queue_entities import ( AppQueueEvent, QueueAgentLogEvent, @@ -33,7 +33,6 @@ from core.workflow.node_factory import DifyNodeFactory from core.workflow.workflow_entry import WorkflowEntry from dify_graph.entities import GraphInitParams from dify_graph.entities.pause_reason import HumanInputRequired -from dify_graph.enums import UserFrom from dify_graph.graph import Graph from dify_graph.graph_engine.layers.base import GraphEngineLayer from dify_graph.graph_events import ( @@ -119,13 +118,15 @@ class WorkflowBasedAppRunner: # Create required parameters for Graph.init graph_init_params = GraphInitParams( - tenant_id=tenant_id or "", - app_id=self._app_id, workflow_id=workflow_id, graph_config=graph_config, - user_id=user_id, - user_from=user_from, - invoke_from=invoke_from, + run_context=build_dify_run_context( + tenant_id=tenant_id or "", + app_id=self._app_id, + user_id=user_id, + user_from=user_from, + invoke_from=invoke_from, + ), call_depth=0, ) @@ -267,13 +268,15 @@ class WorkflowBasedAppRunner: # Create required parameters for Graph.init graph_init_params = GraphInitParams( - tenant_id=workflow.tenant_id, - app_id=self._app_id, workflow_id=workflow.id, graph_config=graph_config, - user_id="", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, + run_context=build_dify_run_context( + tenant_id=workflow.tenant_id, + app_id=self._app_id, + user_id="", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + ), call_depth=0, ) diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py index 6ecca84425..ecbb1cf2f3 100644 --- a/api/core/app/entities/app_invoke_entities.py +++ b/api/core/app/entities/app_invoke_entities.py @@ -1,4 +1,5 @@ from collections.abc import Mapping, Sequence +from enum import StrEnum from typing import TYPE_CHECKING, Any, Optional from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator @@ -6,7 +7,7 @@ from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validat from constants import UUID_NIL from core.app.app_config.entities import EasyUIBasedAppConfig, WorkflowUIBasedAppConfig from core.entities.provider_configuration import ProviderModelBundle -from dify_graph.enums import InvokeFrom +from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY from dify_graph.file import File, FileUploadConfig from dify_graph.model_runtime.entities.model_entities import AIModelEntity @@ -14,6 +15,69 @@ if TYPE_CHECKING: from core.ops.ops_trace_manager import TraceQueueManager +class UserFrom(StrEnum): + ACCOUNT = "account" + END_USER = "end-user" + + +class InvokeFrom(StrEnum): + SERVICE_API = "service-api" + WEB_APP = "web-app" + TRIGGER = "trigger" + EXPLORE = "explore" + DEBUGGER = "debugger" + PUBLISHED_PIPELINE = "published" + VALIDATION = "validation" + + @classmethod + def value_of(cls, value: str) -> "InvokeFrom": + return cls(value) + + def to_source(self) -> str: + source_mapping = { + InvokeFrom.WEB_APP: "web_app", + InvokeFrom.DEBUGGER: "dev", + InvokeFrom.EXPLORE: "explore_app", + InvokeFrom.TRIGGER: "trigger", + InvokeFrom.SERVICE_API: "api", + } + return source_mapping.get(self, "dev") + + +class DifyRunContext(BaseModel): + tenant_id: str + app_id: str + user_id: str + user_from: UserFrom + invoke_from: InvokeFrom + + +def build_dify_run_context( + *, + tenant_id: str, + app_id: str, + user_id: str, + user_from: UserFrom, + invoke_from: InvokeFrom, + extra_context: Mapping[str, Any] | None = None, +) -> dict[str, Any]: + """ + Build graph run_context with the reserved Dify runtime payload. + + `extra_context` can carry user-defined context keys. The reserved `_dify` + payload is always overwritten by this function to keep one canonical source. + """ + run_context = dict(extra_context) if extra_context else {} + run_context[DIFY_RUN_CONTEXT_KEY] = DifyRunContext( + tenant_id=tenant_id, + app_id=app_id, + user_id=user_id, + user_from=user_from, + invoke_from=invoke_from, + ) + return run_context + + class ModelConfigWithCredentialsEntity(BaseModel): """ Model Config With Credentials Entity. diff --git a/api/core/app/workflow/layers/llm_quota.py b/api/core/app/workflow/layers/llm_quota.py index f6c80d25c5..2e930a1f58 100644 --- a/api/core/app/workflow/layers/llm_quota.py +++ b/api/core/app/workflow/layers/llm_quota.py @@ -75,8 +75,9 @@ class LLMQuotaLayer(GraphEngineLayer): return try: + dify_ctx = node.require_dify_context() deduct_llm_quota( - tenant_id=node.tenant_id, + tenant_id=dify_ctx.tenant_id, model_instance=model_instance, usage=result_event.node_run_result.llm_usage, ) diff --git a/api/core/workflow/node_factory.py b/api/core/workflow/node_factory.py index 714b0ca3d0..4cbee08a65 100644 --- a/api/core/workflow/node_factory.py +++ b/api/core/workflow/node_factory.py @@ -6,6 +6,7 @@ from sqlalchemy.orm import Session from typing_extensions import override from configs import dify_config +from core.app.entities.app_invoke_entities import DifyRunContext from core.app.llm.model_access import build_dify_model_access from core.datasource.datasource_manager import DatasourceManager from core.helper.code_executor.code_executor import ( @@ -22,6 +23,7 @@ from core.rag.summary_index.summary_index import SummaryIndex from core.repositories.human_input_repository import HumanInputFormRepositoryImpl from core.tools.tool_file_manager import ToolFileManager from dify_graph.entities.graph_config import NodeConfigDict +from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY from dify_graph.enums import NodeType, SystemVariableKey from dify_graph.file.file_manager import file_manager from dify_graph.graph.graph import NodeFactory @@ -110,6 +112,7 @@ class DifyNodeFactory(NodeFactory): ) -> None: self.graph_init_params = graph_init_params self.graph_runtime_state = graph_runtime_state + self._dify_context = self._resolve_dify_context(graph_init_params.run_context) self._code_executor: WorkflowCodeExecutor = DefaultWorkflowCodeExecutor() self._code_limits = CodeNodeLimits( max_string_length=dify_config.CODE_MAX_STRING_LENGTH, @@ -141,7 +144,16 @@ class DifyNodeFactory(NodeFactory): ssrf_default_max_retries=dify_config.SSRF_DEFAULT_MAX_RETRIES, ) - self._llm_credentials_provider, self._llm_model_factory = build_dify_model_access(graph_init_params.tenant_id) + self._llm_credentials_provider, self._llm_model_factory = build_dify_model_access(self._dify_context.tenant_id) + + @staticmethod + def _resolve_dify_context(run_context: Mapping[str, Any]) -> DifyRunContext: + raw_ctx = run_context.get(DIFY_RUN_CONTEXT_KEY) + if raw_ctx is None: + raise ValueError(f"run_context missing required key: {DIFY_RUN_CONTEXT_KEY}") + if isinstance(raw_ctx, DifyRunContext): + return raw_ctx + return DifyRunContext.model_validate(raw_ctx) @override def create_node(self, node_config: NodeConfigDict) -> Node: @@ -213,7 +225,7 @@ class DifyNodeFactory(NodeFactory): config=node_config, graph_init_params=self.graph_init_params, graph_runtime_state=self.graph_runtime_state, - form_repository=HumanInputFormRepositoryImpl(tenant_id=self.graph_init_params.tenant_id), + form_repository=HumanInputFormRepositoryImpl(tenant_id=self._dify_context.tenant_id), ) if node_type == NodeType.KNOWLEDGE_INDEX: @@ -356,7 +368,7 @@ class DifyNodeFactory(NodeFactory): ) return fetch_memory( conversation_id=conversation_id, - app_id=self.graph_init_params.app_id, + app_id=self._dify_context.app_id, node_data_memory=node_memory, model_instance=model_instance, ) diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index 6210a81c4e..c259e7ac08 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -5,26 +5,26 @@ from typing import Any, cast from configs import dify_config from core.app.apps.exc import GenerateTaskStoppedError -from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context from core.app.workflow.layers.llm_quota import LLMQuotaLayer from core.app.workflow.layers.observability import ObservabilityLayer from core.workflow.node_factory import DifyNodeFactory from dify_graph.constants import ENVIRONMENT_VARIABLE_NODE_ID from dify_graph.entities import GraphInitParams from dify_graph.entities.graph_config import NodeConfigData, NodeConfigDict -from dify_graph.enums import UserFrom from dify_graph.errors import WorkflowNodeRunFailedError from dify_graph.file.models import File from dify_graph.graph import Graph from dify_graph.graph_engine import GraphEngine, GraphEngineConfig from dify_graph.graph_engine.command_channels import InMemoryChannel from dify_graph.graph_engine.layers import DebugLoggingLayer, ExecutionLimitsLayer +from dify_graph.graph_engine.layers.base import GraphEngineLayer from dify_graph.graph_engine.protocols.command_channel import CommandChannel from dify_graph.graph_events import GraphEngineEvent, GraphNodeEventBase, GraphRunFailedEvent from dify_graph.nodes import NodeType from dify_graph.nodes.base.node import Node from dify_graph.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING -from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.runtime import ChildGraphNotFoundError, GraphRuntimeState, VariablePool from dify_graph.system_variable import SystemVariable from dify_graph.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool from extensions.otel.runtime import is_instrument_flag_enabled @@ -34,6 +34,66 @@ from models.workflow import Workflow logger = logging.getLogger(__name__) +class _WorkflowChildEngineBuilder: + @staticmethod + def _has_node_id(graph_config: Mapping[str, Any], node_id: str) -> bool | None: + """ + Return whether `graph_config["nodes"]` contains the given node id. + + Returns `None` when the nodes payload shape is unexpected, so graph-level + validation can surface the original configuration error. + """ + nodes = graph_config.get("nodes") + if not isinstance(nodes, list): + return None + + for node in nodes: + if not isinstance(node, Mapping): + return None + current_id = node.get("id") + if isinstance(current_id, str) and current_id == node_id: + return True + return False + + def build_child_engine( + self, + *, + workflow_id: str, + graph_init_params: GraphInitParams, + graph_runtime_state: GraphRuntimeState, + graph_config: Mapping[str, Any], + root_node_id: str, + layers: Sequence[object] = (), + ) -> GraphEngine: + node_factory = DifyNodeFactory( + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + ) + + has_root_node = self._has_node_id(graph_config=graph_config, node_id=root_node_id) + if has_root_node is False: + raise ChildGraphNotFoundError(f"child graph root node '{root_node_id}' not found") + + child_graph = Graph.init( + graph_config=graph_config, + node_factory=node_factory, + root_node_id=root_node_id, + ) + + child_engine = GraphEngine( + workflow_id=workflow_id, + graph=child_graph, + graph_runtime_state=graph_runtime_state, + command_channel=InMemoryChannel(), + config=GraphEngineConfig(), + child_engine_builder=self, + ) + child_engine.layer(LLMQuotaLayer()) + for layer in layers: + child_engine.layer(cast(GraphEngineLayer, layer)) + return child_engine + + class WorkflowEntry: def __init__( self, @@ -77,6 +137,7 @@ class WorkflowEntry: command_channel = InMemoryChannel() self.command_channel = command_channel + self._child_engine_builder = _WorkflowChildEngineBuilder() self.graph_engine = GraphEngine( workflow_id=workflow_id, graph=graph, @@ -88,6 +149,7 @@ class WorkflowEntry: scale_up_threshold=dify_config.GRAPH_ENGINE_SCALE_UP_THRESHOLD, scale_down_idle_time=dify_config.GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME, ), + child_engine_builder=self._child_engine_builder, ) # Add debug logging layer when in debug mode @@ -154,13 +216,15 @@ class WorkflowEntry: # init graph init params and runtime state graph_init_params = GraphInitParams( - tenant_id=workflow.tenant_id, - app_id=workflow.app_id, workflow_id=workflow.id, graph_config=workflow.graph_dict, - user_id=user_id, - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, + run_context=build_dify_run_context( + tenant_id=workflow.tenant_id, + app_id=workflow.app_id, + user_id=user_id, + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + ), call_depth=0, ) graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) @@ -293,13 +357,15 @@ class WorkflowEntry: # init graph init params and runtime state graph_init_params = GraphInitParams( - tenant_id=tenant_id, - app_id="", workflow_id="", graph_config=graph_dict, - user_id=user_id, - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, + run_context=build_dify_run_context( + tenant_id=tenant_id, + app_id="", + user_id=user_id, + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + ), call_depth=0, ) graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) diff --git a/api/dify_graph/entities/graph_init_params.py b/api/dify_graph/entities/graph_init_params.py index 3712842aaf..f785d58a52 100644 --- a/api/dify_graph/entities/graph_init_params.py +++ b/api/dify_graph/entities/graph_init_params.py @@ -3,7 +3,7 @@ from typing import Any from pydantic import BaseModel, Field -from dify_graph.enums import InvokeFrom, UserFrom +DIFY_RUN_CONTEXT_KEY = "_dify" class GraphInitParams(BaseModel): @@ -18,11 +18,7 @@ class GraphInitParams(BaseModel): """ # init params - tenant_id: str = Field(..., description="tenant / workspace id") - app_id: str = Field(..., description="app id") workflow_id: str = Field(..., description="workflow id") graph_config: Mapping[str, Any] = Field(..., description="graph config") - user_id: str = Field(..., description="user id") - user_from: UserFrom = Field(..., description="user from, account or end-user") - invoke_from: InvokeFrom = Field(..., description="invoke from, service-api, web-app, explore or debugger") + run_context: Mapping[str, Any] = Field(..., description="runtime context") call_depth: int = Field(..., description="call depth") diff --git a/api/dify_graph/enums.py b/api/dify_graph/enums.py index 6c0593945e..bb3b13e8c6 100644 --- a/api/dify_graph/enums.py +++ b/api/dify_graph/enums.py @@ -33,39 +33,6 @@ class SystemVariableKey(StrEnum): INVOKE_FROM = "invoke_from" -class UserFrom(StrEnum): - ACCOUNT = "account" - END_USER = "end-user" - - -class InvokeFrom(StrEnum): - SERVICE_API = "service-api" - WEB_APP = "web-app" - TRIGGER = "trigger" - EXPLORE = "explore" - DEBUGGER = "debugger" - PUBLISHED_PIPELINE = "published" - VALIDATION = "validation" - - @classmethod - def value_of(cls, value: str) -> "InvokeFrom": - return cls(value) - - def to_source(self) -> str: - """Get source of invoke from. - - :return: source - """ - source_mapping = { - InvokeFrom.WEB_APP: "web_app", - InvokeFrom.DEBUGGER: "dev", - InvokeFrom.EXPLORE: "explore_app", - InvokeFrom.TRIGGER: "trigger", - InvokeFrom.SERVICE_API: "api", - } - return source_mapping.get(self, "dev") - - class NodeType(StrEnum): START = "start" END = "end" diff --git a/api/dify_graph/graph_engine/graph_engine.py b/api/dify_graph/graph_engine/graph_engine.py index 772e607328..ea98a46b06 100644 --- a/api/dify_graph/graph_engine/graph_engine.py +++ b/api/dify_graph/graph_engine/graph_engine.py @@ -9,7 +9,7 @@ from __future__ import annotations import logging import queue -from collections.abc import Generator +from collections.abc import Generator, Mapping from typing import TYPE_CHECKING, cast, final from dify_graph.context import capture_current_context @@ -27,6 +27,7 @@ from dify_graph.graph_events import ( GraphRunSucceededEvent, ) from dify_graph.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper +from dify_graph.runtime.graph_runtime_state import ChildGraphEngineBuilderProtocol if TYPE_CHECKING: # pragma: no cover - used only for static analysis from dify_graph.runtime.graph_runtime_state import GraphProtocol @@ -49,6 +50,7 @@ from .protocols.command_channel import CommandChannel from .worker_management import WorkerPool if TYPE_CHECKING: + from dify_graph.entities import GraphInitParams from dify_graph.graph_engine.domain.graph_execution import GraphExecution from dify_graph.graph_engine.response_coordinator import ResponseStreamCoordinator @@ -74,6 +76,7 @@ class GraphEngine: graph_runtime_state: GraphRuntimeState, command_channel: CommandChannel, config: GraphEngineConfig = _DEFAULT_CONFIG, + child_engine_builder: ChildGraphEngineBuilderProtocol | None = None, ) -> None: """Initialize the graph engine with all subsystems and dependencies.""" @@ -83,6 +86,9 @@ class GraphEngine: self._graph_runtime_state.configure(graph=cast("GraphProtocol", graph)) self._command_channel = command_channel self._config = config + self._child_engine_builder = child_engine_builder + if child_engine_builder is not None: + self._graph_runtime_state.bind_child_engine_builder(child_engine_builder) # Graph execution tracks the overall execution state self._graph_execution = cast("GraphExecution", self._graph_runtime_state.graph_execution) @@ -214,6 +220,25 @@ class GraphEngine: self._bind_layer_context(layer) return self + def create_child_engine( + self, + *, + workflow_id: str, + graph_init_params: GraphInitParams, + graph_runtime_state: GraphRuntimeState, + graph_config: dict[str, object] | Mapping[str, object], + root_node_id: str, + layers: list[GraphEngineLayer] | tuple[GraphEngineLayer, ...] = (), + ) -> GraphEngine: + return self._graph_runtime_state.create_child_engine( + workflow_id=workflow_id, + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + graph_config=graph_config, + root_node_id=root_node_id, + layers=layers, + ) + def run(self) -> Generator[GraphEngineEvent, None, None]: """ Execute the graph using the modular architecture. diff --git a/api/dify_graph/nodes/agent/agent_node.py b/api/dify_graph/nodes/agent/agent_node.py index f55871718f..d770f7afd1 100644 --- a/api/dify_graph/nodes/agent/agent_node.py +++ b/api/dify_graph/nodes/agent/agent_node.py @@ -80,9 +80,11 @@ class AgentNode(Node[AgentNodeData]): def _run(self) -> Generator[NodeEventBase, None, None]: from core.plugin.impl.exc import PluginDaemonClientSideError + dify_ctx = self.require_dify_context() + try: strategy = get_plugin_agent_strategy( - tenant_id=self.tenant_id, + tenant_id=dify_ctx.tenant_id, agent_strategy_provider_name=self.node_data.agent_strategy_provider_name, agent_strategy_name=self.node_data.agent_strategy_name, ) @@ -120,8 +122,8 @@ class AgentNode(Node[AgentNodeData]): try: message_stream = strategy.invoke( params=parameters, - user_id=self.user_id, - app_id=self.app_id, + user_id=dify_ctx.user_id, + app_id=dify_ctx.app_id, conversation_id=conversation_id.text if conversation_id else None, credentials=credentials, ) @@ -144,8 +146,8 @@ class AgentNode(Node[AgentNodeData]): "agent_strategy": self.node_data.agent_strategy_name, }, parameters_for_log=parameters_for_log, - user_id=self.user_id, - tenant_id=self.tenant_id, + user_id=dify_ctx.user_id, + tenant_id=dify_ctx.tenant_id, node_type=self.node_type, node_id=self._node_id, node_execution_id=self.id, @@ -283,8 +285,13 @@ class AgentNode(Node[AgentNodeData]): runtime_variable_pool: VariablePool | None = None if node_data.version != "1" or node_data.tool_node_version is not None: runtime_variable_pool = variable_pool + dify_ctx = self.require_dify_context() tool_runtime = ToolManager.get_agent_tool_runtime( - self.tenant_id, self.app_id, entity, self.invoke_from, runtime_variable_pool + dify_ctx.tenant_id, + dify_ctx.app_id, + entity, + dify_ctx.invoke_from, + runtime_variable_pool, ) if tool_runtime.entity.description: tool_runtime.entity.description.llm = ( @@ -396,7 +403,8 @@ class AgentNode(Node[AgentNodeData]): from core.plugin.impl.plugin import PluginInstaller manager = PluginInstaller() - plugins = manager.list_plugins(self.tenant_id) + dify_ctx = self.require_dify_context() + plugins = manager.list_plugins(dify_ctx.tenant_id) try: current_plugin = next( plugin @@ -417,8 +425,11 @@ class AgentNode(Node[AgentNodeData]): return None conversation_id = conversation_id_variable.value + dify_ctx = self.require_dify_context() with Session(db.engine, expire_on_commit=False) as session: - stmt = select(Conversation).where(Conversation.app_id == self.app_id, Conversation.id == conversation_id) + stmt = select(Conversation).where( + Conversation.app_id == dify_ctx.app_id, Conversation.id == conversation_id + ) conversation = session.scalar(stmt) if not conversation: @@ -429,9 +440,10 @@ class AgentNode(Node[AgentNodeData]): return memory def _fetch_model(self, value: dict[str, Any]) -> tuple[ModelInstance, AIModelEntity | None]: + dify_ctx = self.require_dify_context() provider_manager = ProviderManager() provider_model_bundle = provider_manager.get_provider_model_bundle( - tenant_id=self.tenant_id, provider=value.get("provider", ""), model_type=ModelType.LLM + tenant_id=dify_ctx.tenant_id, provider=value.get("provider", ""), model_type=ModelType.LLM ) model_name = value.get("model", "") model_credentials = provider_model_bundle.configuration.get_current_credentials( @@ -440,7 +452,7 @@ class AgentNode(Node[AgentNodeData]): provider_name = provider_model_bundle.configuration.provider.provider model_type_instance = provider_model_bundle.model_type_instance model_instance = ModelManager().get_model_instance( - tenant_id=self.tenant_id, + tenant_id=dify_ctx.tenant_id, provider=provider_name, model_type=ModelType(value.get("model_type", "")), model=model_name, diff --git a/api/dify_graph/nodes/base/node.py b/api/dify_graph/nodes/base/node.py index 8eaf0b16b3..1f99a0a6e2 100644 --- a/api/dify_graph/nodes/base/node.py +++ b/api/dify_graph/nodes/base/node.py @@ -8,10 +8,11 @@ from abc import abstractmethod from collections.abc import Generator, Mapping, Sequence from functools import singledispatchmethod from types import MappingProxyType -from typing import Any, ClassVar, Generic, TypeVar, cast, get_args, get_origin +from typing import Any, ClassVar, Generic, Protocol, TypeVar, cast, get_args, get_origin from uuid import uuid4 from dify_graph.entities import AgentNodeStrategyInit, GraphInitParams +from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY from dify_graph.enums import ( ErrorStrategy, NodeExecutionType, @@ -64,10 +65,28 @@ from libs.datetime_utils import naive_utc_now from .entities import BaseNodeData, RetryConfig NodeDataT = TypeVar("NodeDataT", bound=BaseNodeData) +_MISSING_RUN_CONTEXT_VALUE = object() logger = logging.getLogger(__name__) +class DifyRunContextProtocol(Protocol): + tenant_id: str + app_id: str + user_id: str + user_from: Any + invoke_from: Any + + +class _MappingDifyRunContext: + def __init__(self, mapping: Mapping[str, Any]) -> None: + self.tenant_id = str(mapping["tenant_id"]) + self.app_id = str(mapping["app_id"]) + self.user_id = str(mapping["user_id"]) + self.user_from = mapping["user_from"] + self.invoke_from = mapping["invoke_from"] + + class Node(Generic[NodeDataT]): """BaseNode serves as the foundational class for all node implementations. @@ -227,14 +246,10 @@ class Node(Generic[NodeDataT]): graph_runtime_state: GraphRuntimeState, ) -> None: self._graph_init_params = graph_init_params + self._run_context = MappingProxyType(dict(graph_init_params.run_context)) self.id = id - self.tenant_id = graph_init_params.tenant_id - self.app_id = graph_init_params.app_id self.workflow_id = graph_init_params.workflow_id self.graph_config = graph_init_params.graph_config - self.user_id = graph_init_params.user_id - self.user_from = graph_init_params.user_from - self.invoke_from = graph_init_params.invoke_from self.workflow_call_depth = graph_init_params.call_depth self.graph_runtime_state = graph_runtime_state self.state: NodeState = NodeState.UNKNOWN # node execution state @@ -263,6 +278,38 @@ class Node(Generic[NodeDataT]): def graph_init_params(self) -> GraphInitParams: return self._graph_init_params + @property + def run_context(self) -> Mapping[str, Any]: + return self._run_context + + def get_run_context_value(self, key: str, default: Any = None) -> Any: + return self._run_context.get(key, default) + + def require_run_context_value(self, key: str) -> Any: + value = self.get_run_context_value(key, _MISSING_RUN_CONTEXT_VALUE) + if value is _MISSING_RUN_CONTEXT_VALUE: + raise ValueError(f"run_context missing required key: {key}") + return value + + def require_dify_context(self) -> DifyRunContextProtocol: + raw_ctx = self.require_run_context_value(DIFY_RUN_CONTEXT_KEY) + if raw_ctx is None: + raise ValueError(f"run_context missing required key: {DIFY_RUN_CONTEXT_KEY}") + + if isinstance(raw_ctx, Mapping): + missing_keys = [ + key for key in ("tenant_id", "app_id", "user_id", "user_from", "invoke_from") if key not in raw_ctx + ] + if missing_keys: + raise ValueError(f"dify context missing required keys: {', '.join(missing_keys)}") + return _MappingDifyRunContext(raw_ctx) + + for attr in ("tenant_id", "app_id", "user_id", "user_from", "invoke_from"): + if not hasattr(raw_ctx, attr): + raise TypeError(f"invalid dify context object, missing attribute: {attr}") + + return cast(DifyRunContextProtocol, raw_ctx) + @property def execution_id(self) -> str: return self._node_execution_id diff --git a/api/dify_graph/nodes/datasource/datasource_node.py b/api/dify_graph/nodes/datasource/datasource_node.py index 802d34d2d0..b97394744e 100644 --- a/api/dify_graph/nodes/datasource/datasource_node.py +++ b/api/dify_graph/nodes/datasource/datasource_node.py @@ -52,6 +52,7 @@ class DatasourceNode(Node[DatasourceNodeData]): Run the datasource node """ + dify_ctx = self.require_dify_context() node_data = self.node_data variable_pool = self.graph_runtime_state.variable_pool datasource_type_segment = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_TYPE]) @@ -75,7 +76,7 @@ class DatasourceNode(Node[DatasourceNodeData]): datasource_info["icon"] = self.datasource_manager.get_icon_url( provider_id=provider_id, datasource_name=node_data.datasource_name or "", - tenant_id=self.tenant_id, + tenant_id=dify_ctx.tenant_id, datasource_type=datasource_type.value, ) @@ -104,11 +105,11 @@ class DatasourceNode(Node[DatasourceNodeData]): yield from self.datasource_manager.stream_node_events( node_id=self._node_id, - user_id=self.user_id, + user_id=dify_ctx.user_id, datasource_name=node_data.datasource_name or "", datasource_type=datasource_type.value, provider_id=provider_id, - tenant_id=self.tenant_id, + tenant_id=dify_ctx.tenant_id, provider=node_data.provider_name, plugin_id=node_data.plugin_id, credential_id=credential_id, @@ -136,7 +137,7 @@ class DatasourceNode(Node[DatasourceNodeData]): raise DatasourceNodeError("File is not exist") file_info = self.datasource_manager.get_upload_file_by_id( - file_id=related_id, tenant_id=self.tenant_id + file_id=related_id, tenant_id=dify_ctx.tenant_id ) variable_pool.add([self._node_id, "file"], file_info) # variable_pool.add([self.node_id, "file"], file_info.to_dict()) diff --git a/api/dify_graph/nodes/http_request/node.py b/api/dify_graph/nodes/http_request/node.py index ae0faa8a56..2e48d5502a 100644 --- a/api/dify_graph/nodes/http_request/node.py +++ b/api/dify_graph/nodes/http_request/node.py @@ -212,6 +212,7 @@ class HttpRequestNode(Node[HttpRequestNodeData]): """ Extract files from response by checking both Content-Type header and URL """ + dify_ctx = self.require_dify_context() files: list[File] = [] is_file = response.is_file content_type = response.content_type @@ -236,8 +237,8 @@ class HttpRequestNode(Node[HttpRequestNodeData]): tool_file_manager = self._tool_file_manager_factory() tool_file = tool_file_manager.create_file_by_raw( - user_id=self.user_id, - tenant_id=self.tenant_id, + user_id=dify_ctx.user_id, + tenant_id=dify_ctx.tenant_id, conversation_id=None, file_binary=content, mimetype=mime_type, @@ -249,7 +250,7 @@ class HttpRequestNode(Node[HttpRequestNodeData]): } file = file_factory.build_from_mapping( mapping=mapping, - tenant_id=self.tenant_id, + tenant_id=dify_ctx.tenant_id, ) files.append(file) diff --git a/api/dify_graph/nodes/human_input/human_input_node.py b/api/dify_graph/nodes/human_input/human_input_node.py index e54650898d..03c2d17b1d 100644 --- a/api/dify_graph/nodes/human_input/human_input_node.py +++ b/api/dify_graph/nodes/human_input/human_input_node.py @@ -4,7 +4,7 @@ from collections.abc import Generator, Mapping, Sequence from typing import TYPE_CHECKING, Any from dify_graph.entities.pause_reason import HumanInputRequired -from dify_graph.enums import InvokeFrom, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus +from dify_graph.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus from dify_graph.node_events import ( HumanInputFormFilledEvent, HumanInputFormTimeoutEvent, @@ -31,6 +31,8 @@ if TYPE_CHECKING: _SELECTED_BRANCH_KEY = "selected_branch" +_INVOKE_FROM_DEBUGGER = "debugger" +_INVOKE_FROM_EXPLORE = "explore" logger = logging.getLogger(__name__) @@ -155,30 +157,39 @@ class HumanInputNode(Node[HumanInputNodeData]): return resolved_defaults def _should_require_console_recipient(self) -> bool: - if self.invoke_from == InvokeFrom.DEBUGGER: + invoke_from = self._invoke_from_value() + if invoke_from == _INVOKE_FROM_DEBUGGER: return True - if self.invoke_from == InvokeFrom.EXPLORE: + if invoke_from == _INVOKE_FROM_EXPLORE: return self._node_data.is_webapp_enabled() return False def _display_in_ui(self) -> bool: - if self.invoke_from == InvokeFrom.DEBUGGER: + if self._invoke_from_value() == _INVOKE_FROM_DEBUGGER: return True return self._node_data.is_webapp_enabled() def _effective_delivery_methods(self) -> Sequence[DeliveryChannelConfig]: + dify_ctx = self.require_dify_context() + invoke_from = self._invoke_from_value() enabled_methods = [method for method in self._node_data.delivery_methods if method.enabled] - if self.invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.EXPLORE}: + if invoke_from in {_INVOKE_FROM_DEBUGGER, _INVOKE_FROM_EXPLORE}: enabled_methods = [method for method in enabled_methods if method.type != DeliveryMethodType.WEBAPP] return [ apply_debug_email_recipient( method, - enabled=self.invoke_from == InvokeFrom.DEBUGGER, - user_id=self.user_id or "", + enabled=invoke_from == _INVOKE_FROM_DEBUGGER, + user_id=dify_ctx.user_id, ) for method in enabled_methods ] + def _invoke_from_value(self) -> str: + invoke_from = self.require_dify_context().invoke_from + if isinstance(invoke_from, str): + return invoke_from + return str(getattr(invoke_from, "value", invoke_from)) + def _human_input_required_event(self, form_entity: HumanInputFormEntity) -> HumanInputRequired: node_data = self._node_data resolved_default_values = self.resolve_default_values() @@ -212,10 +223,11 @@ class HumanInputNode(Node[HumanInputNodeData]): """ repo = self._form_repository form = repo.get_form(self._workflow_execution_id, self.id) + dify_ctx = self.require_dify_context() if form is None: display_in_ui = self._display_in_ui() params = FormCreateParams( - app_id=self.app_id, + app_id=dify_ctx.app_id, workflow_execution_id=self._workflow_execution_id, node_id=self.id, form_config=self._node_data, @@ -225,7 +237,9 @@ class HumanInputNode(Node[HumanInputNodeData]): resolved_default_values=self.resolve_default_values(), console_recipient_required=self._should_require_console_recipient(), console_creator_account_id=( - self.user_id if self.invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.EXPLORE} else None + dify_ctx.user_id + if self._invoke_from_value() in {_INVOKE_FROM_DEBUGGER, _INVOKE_FROM_EXPLORE} + else None ), backstage_recipient_required=True, ) diff --git a/api/dify_graph/nodes/iteration/iteration_node.py b/api/dify_graph/nodes/iteration/iteration_node.py index ed3634fa91..6d26cbfce4 100644 --- a/api/dify_graph/nodes/iteration/iteration_node.py +++ b/api/dify_graph/nodes/iteration/iteration_node.py @@ -587,24 +587,14 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): return def _create_graph_engine(self, index: int, item: object): - # Import dependencies - from core.app.workflow.layers.llm_quota import LLMQuotaLayer - from core.workflow.node_factory import DifyNodeFactory from dify_graph.entities import GraphInitParams - from dify_graph.graph import Graph - from dify_graph.graph_engine import GraphEngine, GraphEngineConfig - from dify_graph.graph_engine.command_channels import InMemoryChannel - from dify_graph.runtime import GraphRuntimeState + from dify_graph.runtime import ChildGraphNotFoundError, GraphRuntimeState - # Create GraphInitParams from node attributes + # Create GraphInitParams for child graph execution. graph_init_params = GraphInitParams( - tenant_id=self.tenant_id, - app_id=self.app_id, workflow_id=self.workflow_id, graph_config=self.graph_config, - user_id=self.user_id, - user_from=self.user_from, - invoke_from=self.invoke_from, + run_context=self.run_context, call_depth=self.workflow_call_depth, ) # Create a deep copy of the variable pool for each iteration @@ -621,28 +611,17 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): total_tokens=0, node_run_steps=0, ) + root_node_id = self.node_data.start_node_id + if root_node_id is None: + raise StartNodeIdNotFoundError(f"field start_node_id in iteration {self._node_id} not found") - # Create a new node factory with the new GraphRuntimeState - node_factory = DifyNodeFactory( - graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state_copy - ) - - # Initialize the iteration graph with the new node factory - iteration_graph = Graph.init( - graph_config=self.graph_config, node_factory=node_factory, root_node_id=self.node_data.start_node_id - ) - - if not iteration_graph: - raise IterationGraphNotFoundError("iteration graph not found") - - # Create a new GraphEngine for this iteration - graph_engine = GraphEngine( - workflow_id=self.workflow_id, - graph=iteration_graph, - graph_runtime_state=graph_runtime_state_copy, - command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs - config=GraphEngineConfig(), - ) - graph_engine.layer(LLMQuotaLayer()) - - return graph_engine + try: + return self.graph_runtime_state.create_child_engine( + workflow_id=self.workflow_id, + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state_copy, + graph_config=self.graph_config, + root_node_id=root_node_id, + ) + except ChildGraphNotFoundError as exc: + raise IterationGraphNotFoundError("iteration graph not found") from exc diff --git a/api/dify_graph/nodes/knowledge_index/knowledge_index_node.py b/api/dify_graph/nodes/knowledge_index/knowledge_index_node.py index daf97d6ca9..eeb4f3c229 100644 --- a/api/dify_graph/nodes/knowledge_index/knowledge_index_node.py +++ b/api/dify_graph/nodes/knowledge_index/knowledge_index_node.py @@ -3,7 +3,7 @@ from collections.abc import Mapping from typing import TYPE_CHECKING, Any from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from dify_graph.enums import InvokeFrom, NodeExecutionType, NodeType, SystemVariableKey +from dify_graph.enums import NodeExecutionType, NodeType, SystemVariableKey from dify_graph.node_events import NodeRunResult from dify_graph.nodes.base.node import Node from dify_graph.nodes.base.template import Template @@ -20,6 +20,7 @@ if TYPE_CHECKING: from dify_graph.runtime import GraphRuntimeState logger = logging.getLogger(__name__) +_INVOKE_FROM_DEBUGGER = "debugger" class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]): @@ -58,7 +59,8 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]): if not variable: raise KnowledgeIndexNodeError("Index chunk variable is required.") invoke_from = variable_pool.get(["sys", SystemVariableKey.INVOKE_FROM]) - is_preview = invoke_from.value == InvokeFrom.DEBUGGER if invoke_from else False + invoke_from_value = str(invoke_from.value) if invoke_from else None + is_preview = invoke_from_value == _INVOKE_FROM_DEBUGGER chunks = variable.value variables = {"chunks": chunks} diff --git a/api/dify_graph/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/dify_graph/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 97c013812e..d84dda42d6 100644 --- a/api/dify_graph/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/dify_graph/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -66,9 +66,10 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD self._rag_retrieval = rag_retrieval if llm_file_saver is None: + dify_ctx = self.require_dify_context() llm_file_saver = FileSaverImpl( - user_id=graph_init_params.user_id, - tenant_id=graph_init_params.tenant_id, + user_id=dify_ctx.user_id, + tenant_id=dify_ctx.tenant_id, ) self._llm_file_saver = llm_file_saver @@ -160,6 +161,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD def _fetch_dataset_retriever( self, node_data: KnowledgeRetrievalNodeData, variables: dict[str, Any] ) -> tuple[list[Source], LLMUsage]: + dify_ctx = self.require_dify_context() dataset_ids = node_data.dataset_ids query = variables.get("query") attachments = variables.get("attachments") @@ -176,10 +178,10 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD model = node_data.single_retrieval_config.model retrieval_resource_list = self._rag_retrieval.knowledge_retrieval( request=KnowledgeRetrievalRequest( - tenant_id=self.tenant_id, - user_id=self.user_id, - app_id=self.app_id, - user_from=self.user_from.value, + tenant_id=dify_ctx.tenant_id, + user_id=dify_ctx.user_id, + app_id=dify_ctx.app_id, + user_from=dify_ctx.user_from.value, dataset_ids=dataset_ids, retrieval_mode=DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE.value, completion_params=model.completion_params, @@ -229,10 +231,10 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD retrieval_resource_list = self._rag_retrieval.knowledge_retrieval( request=KnowledgeRetrievalRequest( - app_id=self.app_id, - tenant_id=self.tenant_id, - user_id=self.user_id, - user_from=self.user_from.value, + app_id=dify_ctx.app_id, + tenant_id=dify_ctx.tenant_id, + user_id=dify_ctx.user_id, + user_from=dify_ctx.user_from.value, dataset_ids=dataset_ids, query=query, retrieval_mode=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE.value, diff --git a/api/dify_graph/nodes/llm/node.py b/api/dify_graph/nodes/llm/node.py index 65b92b3bcc..c7697a0972 100644 --- a/api/dify_graph/nodes/llm/node.py +++ b/api/dify_graph/nodes/llm/node.py @@ -145,9 +145,10 @@ class LLMNode(Node[LLMNodeData]): self._memory = memory if llm_file_saver is None: + dify_ctx = self.require_dify_context() llm_file_saver = FileSaverImpl( - user_id=graph_init_params.user_id, - tenant_id=graph_init_params.tenant_id, + user_id=dify_ctx.user_id, + tenant_id=dify_ctx.tenant_id, ) self._llm_file_saver = llm_file_saver @@ -242,7 +243,7 @@ class LLMNode(Node[LLMNodeData]): model_instance=model_instance, prompt_messages=prompt_messages, stop=stop, - user_id=self.user_id, + user_id=self.require_dify_context().user_id, structured_output_enabled=self.node_data.structured_output_enabled, structured_output=self.node_data.structured_output, file_saver=self._llm_file_saver, @@ -702,7 +703,7 @@ class LLMNode(Node[LLMNodeData]): filename=upload_file.name, extension="." + upload_file.extension, mime_type=upload_file.mime_type, - tenant_id=self.tenant_id, + tenant_id=self.require_dify_context().tenant_id, type=FileType.IMAGE, transfer_method=FileTransferMethod.LOCAL_FILE, remote_url=upload_file.source_url, diff --git a/api/dify_graph/nodes/loop/loop_node.py b/api/dify_graph/nodes/loop/loop_node.py index 93a9b4d7eb..8279f0fc66 100644 --- a/api/dify_graph/nodes/loop/loop_node.py +++ b/api/dify_graph/nodes/loop/loop_node.py @@ -412,24 +412,14 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]): return build_segment_with_type(var_type, value) def _create_graph_engine(self, start_at: datetime, root_node_id: str): - # Import dependencies - from core.app.workflow.layers.llm_quota import LLMQuotaLayer - from core.workflow.node_factory import DifyNodeFactory from dify_graph.entities import GraphInitParams - from dify_graph.graph import Graph - from dify_graph.graph_engine import GraphEngine, GraphEngineConfig - from dify_graph.graph_engine.command_channels import InMemoryChannel from dify_graph.runtime import GraphRuntimeState - # Create GraphInitParams from node attributes + # Create GraphInitParams for child graph execution. graph_init_params = GraphInitParams( - tenant_id=self.tenant_id, - app_id=self.app_id, workflow_id=self.workflow_id, graph_config=self.graph_config, - user_id=self.user_id, - user_from=self.user_from, - invoke_from=self.invoke_from, + run_context=self.run_context, call_depth=self.workflow_call_depth, ) @@ -439,22 +429,10 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]): start_at=start_at.timestamp(), ) - # Create a new node factory with the new GraphRuntimeState - node_factory = DifyNodeFactory( - graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state_copy - ) - - # Initialize the loop graph with the new node factory - loop_graph = Graph.init(graph_config=self.graph_config, node_factory=node_factory, root_node_id=root_node_id) - - # Create a new GraphEngine for this iteration - graph_engine = GraphEngine( + return self.graph_runtime_state.create_child_engine( workflow_id=self.workflow_id, - graph=loop_graph, + graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state_copy, - command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs - config=GraphEngineConfig(), + graph_config=self.graph_config, + root_node_id=root_node_id, ) - graph_engine.layer(LLMQuotaLayer()) - - return graph_engine diff --git a/api/dify_graph/nodes/parameter_extractor/parameter_extractor_node.py b/api/dify_graph/nodes/parameter_extractor/parameter_extractor_node.py index a9b21d83b1..1325a6a09a 100644 --- a/api/dify_graph/nodes/parameter_extractor/parameter_extractor_node.py +++ b/api/dify_graph/nodes/parameter_extractor/parameter_extractor_node.py @@ -297,7 +297,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): tools=tools, stop=list(stop), stream=False, - user=self.user_id, + user=self.require_dify_context().user_id, ) # handle invoke result diff --git a/api/dify_graph/nodes/question_classifier/question_classifier_node.py b/api/dify_graph/nodes/question_classifier/question_classifier_node.py index 03ddf9ab5f..97535d832d 100644 --- a/api/dify_graph/nodes/question_classifier/question_classifier_node.py +++ b/api/dify_graph/nodes/question_classifier/question_classifier_node.py @@ -86,9 +86,10 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): self._memory = memory if llm_file_saver is None: + dify_ctx = self.require_dify_context() llm_file_saver = FileSaverImpl( - user_id=graph_init_params.user_id, - tenant_id=graph_init_params.tenant_id, + user_id=dify_ctx.user_id, + tenant_id=dify_ctx.tenant_id, ) self._llm_file_saver = llm_file_saver @@ -160,7 +161,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): model_instance=model_instance, prompt_messages=prompt_messages, stop=stop, - user_id=self.user_id, + user_id=self.require_dify_context().user_id, structured_output_enabled=False, structured_output=None, file_saver=self._llm_file_saver, diff --git a/api/dify_graph/nodes/tool/tool_node.py b/api/dify_graph/nodes/tool/tool_node.py index eee065c311..57fb946559 100644 --- a/api/dify_graph/nodes/tool/tool_node.py +++ b/api/dify_graph/nodes/tool/tool_node.py @@ -56,6 +56,8 @@ class ToolNode(Node[ToolNodeData]): """ from core.plugin.impl.exc import PluginDaemonClientSideError, PluginInvokeError + dify_ctx = self.require_dify_context() + # fetch tool icon tool_info = { "provider_type": self.node_data.provider_type.value, @@ -75,7 +77,12 @@ class ToolNode(Node[ToolNodeData]): if self.node_data.version != "1" or self.node_data.tool_node_version is not None: variable_pool = self.graph_runtime_state.variable_pool tool_runtime = ToolManager.get_workflow_tool_runtime( - self.tenant_id, self.app_id, self._node_id, self.node_data, self.invoke_from, variable_pool + dify_ctx.tenant_id, + dify_ctx.app_id, + self._node_id, + self.node_data, + dify_ctx.invoke_from, + variable_pool, ) except ToolNodeError as e: yield StreamCompletedEvent( @@ -109,10 +116,10 @@ class ToolNode(Node[ToolNodeData]): message_stream = ToolEngine.generic_invoke( tool=tool_runtime, tool_parameters=parameters, - user_id=self.user_id, + user_id=dify_ctx.user_id, workflow_tool_callback=DifyWorkflowCallbackHandler(), workflow_call_depth=self.workflow_call_depth, - app_id=self.app_id, + app_id=dify_ctx.app_id, conversation_id=conversation_id.text if conversation_id else None, ) except ToolNodeError as e: @@ -133,8 +140,8 @@ class ToolNode(Node[ToolNodeData]): messages=message_stream, tool_info=tool_info, parameters_for_log=parameters_for_log, - user_id=self.user_id, - tenant_id=self.tenant_id, + user_id=dify_ctx.user_id, + tenant_id=dify_ctx.tenant_id, node_id=self._node_id, tool_runtime=tool_runtime, ) diff --git a/api/dify_graph/nodes/trigger_webhook/node.py b/api/dify_graph/nodes/trigger_webhook/node.py index 1b8167e799..e466541908 100644 --- a/api/dify_graph/nodes/trigger_webhook/node.py +++ b/api/dify_graph/nodes/trigger_webhook/node.py @@ -69,6 +69,7 @@ class TriggerWebhookNode(Node[WebhookData]): ) def generate_file_var(self, param_name: str, file: dict): + dify_ctx = self.require_dify_context() related_id = file.get("related_id") transfer_method_value = file.get("transfer_method") if transfer_method_value: @@ -84,7 +85,7 @@ class TriggerWebhookNode(Node[WebhookData]): try: file_obj = file_factory.build_from_mapping( mapping=file, - tenant_id=self.tenant_id, + tenant_id=dify_ctx.tenant_id, ) file_segment = build_segment_with_type(SegmentType.FILE, file_obj) return FileVariable(name=param_name, value=file_segment.value, selector=[self.id, param_name]) diff --git a/api/dify_graph/runtime/__init__.py b/api/dify_graph/runtime/__init__.py index 10014c7182..adca07e59a 100644 --- a/api/dify_graph/runtime/__init__.py +++ b/api/dify_graph/runtime/__init__.py @@ -1,9 +1,17 @@ -from .graph_runtime_state import GraphRuntimeState +from .graph_runtime_state import ( + ChildEngineBuilderNotConfiguredError, + ChildEngineError, + ChildGraphNotFoundError, + GraphRuntimeState, +) from .graph_runtime_state_protocol import ReadOnlyGraphRuntimeState, ReadOnlyVariablePool from .read_only_wrappers import ReadOnlyGraphRuntimeStateWrapper, ReadOnlyVariablePoolWrapper from .variable_pool import VariablePool, VariableValue __all__ = [ + "ChildEngineBuilderNotConfiguredError", + "ChildEngineError", + "ChildGraphNotFoundError", "GraphRuntimeState", "ReadOnlyGraphRuntimeState", "ReadOnlyGraphRuntimeStateWrapper", diff --git a/api/dify_graph/runtime/graph_runtime_state.py b/api/dify_graph/runtime/graph_runtime_state.py index 6b88dd683c..41acc6db35 100644 --- a/api/dify_graph/runtime/graph_runtime_state.py +++ b/api/dify_graph/runtime/graph_runtime_state.py @@ -15,6 +15,7 @@ from dify_graph.model_runtime.entities.llm_entities import LLMUsage from dify_graph.runtime.variable_pool import VariablePool if TYPE_CHECKING: + from dify_graph.entities import GraphInitParams from dify_graph.entities.pause_reason import PauseReason @@ -135,6 +136,31 @@ class GraphProtocol(Protocol): def get_outgoing_edges(self, node_id: str) -> Sequence[EdgeProtocol]: ... +class ChildGraphEngineBuilderProtocol(Protocol): + def build_child_engine( + self, + *, + workflow_id: str, + graph_init_params: GraphInitParams, + graph_runtime_state: GraphRuntimeState, + graph_config: Mapping[str, Any], + root_node_id: str, + layers: Sequence[object] = (), + ) -> Any: ... + + +class ChildEngineError(ValueError): + """Base error type for child-engine creation failures.""" + + +class ChildEngineBuilderNotConfiguredError(ChildEngineError): + """Raised when child-engine creation is requested without a bound builder.""" + + +class ChildGraphNotFoundError(ChildEngineError): + """Raised when the requested child graph entry point cannot be resolved.""" + + class _GraphStateSnapshot(BaseModel): """Serializable graph state snapshot for node/edge states.""" @@ -209,6 +235,7 @@ class GraphRuntimeState: self._pending_graph_execution_workflow_id: str | None = None self._paused_nodes: set[str] = set() self._deferred_nodes: set[str] = set() + self._child_engine_builder: ChildGraphEngineBuilderProtocol | None = None # Node and edges states needed to be restored into # graph object. @@ -250,6 +277,31 @@ class GraphRuntimeState: if self._graph is not None: _ = self.response_coordinator + def bind_child_engine_builder(self, builder: ChildGraphEngineBuilderProtocol) -> None: + self._child_engine_builder = builder + + def create_child_engine( + self, + *, + workflow_id: str, + graph_init_params: GraphInitParams, + graph_runtime_state: GraphRuntimeState, + graph_config: Mapping[str, Any], + root_node_id: str, + layers: Sequence[object] = (), + ) -> Any: + if self._child_engine_builder is None: + raise ChildEngineBuilderNotConfiguredError("Child engine builder is not configured.") + + return self._child_engine_builder.build_child_engine( + workflow_id=workflow_id, + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + graph_config=graph_config, + root_node_id=root_node_id, + layers=layers, + ) + # ------------------------------------------------------------------ # Primary collaborators # ------------------------------------------------------------------ diff --git a/api/services/workflow_app_service.py b/api/services/workflow_app_service.py index 44bc64ec11..7147fe1eab 100644 --- a/api/services/workflow_app_service.py +++ b/api/services/workflow_app_service.py @@ -7,7 +7,7 @@ from sqlalchemy import and_, func, or_, select from sqlalchemy.orm import Session from dify_graph.enums import WorkflowExecutionStatus -from models import Account, App, EndUser, WorkflowAppLog, WorkflowArchiveLog, WorkflowRun +from models import Account, App, EndUser, TenantAccountJoin, WorkflowAppLog, WorkflowArchiveLog, WorkflowRun from models.enums import AppTriggerType, CreatorUserRole from models.trigger import WorkflowTriggerLog from services.plugin.plugin_service import PluginService @@ -132,7 +132,14 @@ class WorkflowAppService: ), ) if created_by_account: - account = session.scalar(select(Account).where(Account.email == created_by_account)) + account = session.scalar( + select(Account) + .join(TenantAccountJoin, TenantAccountJoin.account_id == Account.id) + .where( + Account.email == created_by_account, + TenantAccountJoin.tenant_id == app_model.tenant_id, + ) + ) if not account: raise ValueError(f"Account not found: {created_by_account}") diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 21bc95136e..6d462b60b9 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -11,13 +11,13 @@ from sqlalchemy.orm import Session, sessionmaker from configs import dify_config from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager -from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context from core.repositories import DifyCoreRepositoryFactory from core.repositories.human_input_repository import HumanInputFormRepositoryImpl from core.workflow.workflow_entry import WorkflowEntry from dify_graph.entities import GraphInitParams, WorkflowNodeExecution from dify_graph.entities.pause_reason import HumanInputRequired -from dify_graph.enums import ErrorStrategy, UserFrom, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from dify_graph.enums import ErrorStrategy, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from dify_graph.errors import WorkflowNodeRunFailedError from dify_graph.file import File from dify_graph.graph_events import GraphNodeEventBase, NodeRunFailedEvent, NodeRunSucceededEvent @@ -1063,13 +1063,15 @@ class WorkflowService: variable_pool: VariablePool, ) -> HumanInputNode: graph_init_params = GraphInitParams( - tenant_id=workflow.tenant_id, - app_id=workflow.app_id, workflow_id=workflow.id, graph_config=workflow.graph_dict, - user_id=account.id, - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, + run_context=build_dify_run_context( + tenant_id=workflow.tenant_id, + app_id=workflow.app_id, + user_id=account.id, + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + ), call_depth=0, ) graph_runtime_state = GraphRuntimeState( diff --git a/api/tests/integration_tests/workflow/nodes/test_code.py b/api/tests/integration_tests/workflow/nodes/test_code.py index 9971e357d2..f8b7f95493 100644 --- a/api/tests/integration_tests/workflow/nodes/test_code.py +++ b/api/tests/integration_tests/workflow/nodes/test_code.py @@ -4,10 +4,9 @@ import uuid import pytest from configs import dify_config -from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.workflow.node_factory import DifyNodeFactory -from dify_graph.entities import GraphInitParams -from dify_graph.enums import UserFrom, WorkflowNodeExecutionStatus +from dify_graph.enums import WorkflowNodeExecutionStatus from dify_graph.graph import Graph from dify_graph.node_events import NodeRunResult from dify_graph.nodes.code.code_node import CodeNode @@ -15,6 +14,7 @@ from dify_graph.nodes.code.limits import CodeNodeLimits from dify_graph.runtime import GraphRuntimeState, VariablePool from dify_graph.system_variable import SystemVariable from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock +from tests.workflow_test_utils import build_test_graph_init_params CODE_MAX_STRING_LENGTH = dify_config.CODE_MAX_STRING_LENGTH @@ -31,11 +31,11 @@ def init_code_node(code_config: dict): "nodes": [{"data": {"type": "start", "title": "Start"}, "id": "start"}, code_config], } - init_params = GraphInitParams( - tenant_id="1", - app_id="1", + init_params = build_test_graph_init_params( workflow_id="1", graph_config=graph_config, + tenant_id="1", + app_id="1", user_id="1", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, diff --git a/api/tests/integration_tests/workflow/nodes/test_http.py b/api/tests/integration_tests/workflow/nodes/test_http.py index 6e7b3a573a..f691113511 100644 --- a/api/tests/integration_tests/workflow/nodes/test_http.py +++ b/api/tests/integration_tests/workflow/nodes/test_http.py @@ -5,18 +5,18 @@ from urllib.parse import urlencode import pytest from configs import dify_config -from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.helper.ssrf_proxy import ssrf_proxy from core.tools.tool_file_manager import ToolFileManager from core.workflow.node_factory import DifyNodeFactory -from dify_graph.entities import GraphInitParams -from dify_graph.enums import UserFrom, WorkflowNodeExecutionStatus +from dify_graph.enums import WorkflowNodeExecutionStatus from dify_graph.file.file_manager import file_manager from dify_graph.graph import Graph from dify_graph.nodes.http_request import HttpRequestNode, HttpRequestNodeConfig from dify_graph.runtime import GraphRuntimeState, VariablePool from dify_graph.system_variable import SystemVariable from tests.integration_tests.workflow.nodes.__mock.http import setup_http_mock +from tests.workflow_test_utils import build_test_graph_init_params HTTP_REQUEST_CONFIG = HttpRequestNodeConfig( max_connect_timeout=dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT, @@ -41,11 +41,11 @@ def init_http_node(config: dict): "nodes": [{"data": {"type": "start", "title": "Start"}, "id": "start"}, config], } - init_params = GraphInitParams( - tenant_id="1", - app_id="1", + init_params = build_test_graph_init_params( workflow_id="1", graph_config=graph_config, + tenant_id="1", + app_id="1", user_id="1", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, @@ -685,11 +685,11 @@ def test_nested_object_variable_selector(setup_http_mock): ], } - init_params = GraphInitParams( - tenant_id="1", - app_id="1", + init_params = build_test_graph_init_params( workflow_id="1", graph_config=graph_config, + tenant_id="1", + app_id="1", user_id="1", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, diff --git a/api/tests/integration_tests/workflow/nodes/test_llm.py b/api/tests/integration_tests/workflow/nodes/test_llm.py index cc83f0ea16..b4779ebcdd 100644 --- a/api/tests/integration_tests/workflow/nodes/test_llm.py +++ b/api/tests/integration_tests/workflow/nodes/test_llm.py @@ -4,17 +4,17 @@ import uuid from collections.abc import Generator from unittest.mock import MagicMock, patch -from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.llm_generator.output_parser.structured_output import _parse_structured_output from core.model_manager import ModelInstance -from dify_graph.entities import GraphInitParams -from dify_graph.enums import UserFrom, WorkflowNodeExecutionStatus +from dify_graph.enums import WorkflowNodeExecutionStatus from dify_graph.node_events import StreamCompletedEvent from dify_graph.nodes.llm.node import LLMNode from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory from dify_graph.runtime import GraphRuntimeState, VariablePool from dify_graph.system_variable import SystemVariable from extensions.ext_database import db +from tests.workflow_test_utils import build_test_graph_init_params """FOR MOCK FIXTURES, DO NOT REMOVE""" @@ -37,11 +37,11 @@ def init_llm_node(config: dict) -> LLMNode: workflow_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056d" user_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056e" - init_params = GraphInitParams( - tenant_id=tenant_id, - app_id=app_id, + init_params = build_test_graph_init_params( workflow_id=workflow_id, graph_config=graph_config, + tenant_id=tenant_id, + app_id=app_id, user_id=user_id, user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, diff --git a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py index 7310c40c50..62d9af0196 100644 --- a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py +++ b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py @@ -3,10 +3,9 @@ import time import uuid from unittest.mock import MagicMock -from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.model_manager import ModelInstance -from dify_graph.entities import GraphInitParams -from dify_graph.enums import UserFrom, WorkflowNodeExecutionStatus +from dify_graph.enums import WorkflowNodeExecutionStatus from dify_graph.model_runtime.entities import AssistantPromptMessage, UserPromptMessage from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory from dify_graph.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode @@ -14,6 +13,7 @@ from dify_graph.runtime import GraphRuntimeState, VariablePool from dify_graph.system_variable import SystemVariable from extensions.ext_database import db from tests.integration_tests.workflow.nodes.__mock.model import get_mocked_fetch_model_instance +from tests.workflow_test_utils import build_test_graph_init_params """FOR MOCK FIXTURES, DO NOT REMOVE""" from tests.integration_tests.model_runtime.__mock.plugin_daemon import setup_model_mock @@ -43,11 +43,11 @@ def init_parameter_extractor_node(config: dict, memory=None): "nodes": [{"data": {"type": "start", "title": "Start"}, "id": "start"}, config], } - init_params = GraphInitParams( - tenant_id="1", - app_id="1", + init_params = build_test_graph_init_params( workflow_id="1", graph_config=graph_config, + tenant_id="1", + app_id="1", user_id="1", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, diff --git a/api/tests/integration_tests/workflow/nodes/test_template_transform.py b/api/tests/integration_tests/workflow/nodes/test_template_transform.py index 9b4274c667..970e2cae00 100644 --- a/api/tests/integration_tests/workflow/nodes/test_template_transform.py +++ b/api/tests/integration_tests/workflow/nodes/test_template_transform.py @@ -1,15 +1,15 @@ import time import uuid -from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.workflow.node_factory import DifyNodeFactory -from dify_graph.entities import GraphInitParams -from dify_graph.enums import UserFrom, WorkflowNodeExecutionStatus +from dify_graph.enums import WorkflowNodeExecutionStatus from dify_graph.graph import Graph from dify_graph.nodes.template_transform.template_renderer import TemplateRenderError from dify_graph.nodes.template_transform.template_transform_node import TemplateTransformNode from dify_graph.runtime import GraphRuntimeState, VariablePool from dify_graph.system_variable import SystemVariable +from tests.workflow_test_utils import build_test_graph_init_params class _SimpleJinja2Renderer: @@ -53,11 +53,11 @@ def test_execute_template_transform(): "nodes": [{"data": {"type": "start", "title": "Start"}, "id": "start"}, config], } - init_params = GraphInitParams( - tenant_id="1", - app_id="1", + init_params = build_test_graph_init_params( workflow_id="1", graph_config=graph_config, + tenant_id="1", + app_id="1", user_id="1", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, diff --git a/api/tests/integration_tests/workflow/nodes/test_tool.py b/api/tests/integration_tests/workflow/nodes/test_tool.py index fdc690e4cb..f70bf46979 100644 --- a/api/tests/integration_tests/workflow/nodes/test_tool.py +++ b/api/tests/integration_tests/workflow/nodes/test_tool.py @@ -2,16 +2,16 @@ import time import uuid from unittest.mock import MagicMock -from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.tools.utils.configuration import ToolParameterConfigurationManager from core.workflow.node_factory import DifyNodeFactory -from dify_graph.entities import GraphInitParams -from dify_graph.enums import UserFrom, WorkflowNodeExecutionStatus +from dify_graph.enums import WorkflowNodeExecutionStatus from dify_graph.graph import Graph from dify_graph.node_events import StreamCompletedEvent from dify_graph.nodes.tool.tool_node import ToolNode from dify_graph.runtime import GraphRuntimeState, VariablePool from dify_graph.system_variable import SystemVariable +from tests.workflow_test_utils import build_test_graph_init_params def init_tool_node(config: dict): @@ -26,11 +26,11 @@ def init_tool_node(config: dict): "nodes": [{"data": {"type": "start", "title": "Start"}, "id": "start"}, config], } - init_params = GraphInitParams( - tenant_id="1", - app_id="1", + init_params = build_test_graph_init_params( workflow_id="1", graph_config=graph_config, + tenant_id="1", + app_id="1", user_id="1", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, diff --git a/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py b/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py index c4c9d62c84..9733735df3 100644 --- a/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py +++ b/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py @@ -12,7 +12,6 @@ from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerat from core.app.workflow.layers import PersistenceWorkflowInfo, WorkflowPersistenceLayer from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository -from dify_graph.entities import GraphInitParams from dify_graph.enums import WorkflowType from dify_graph.graph import Graph from dify_graph.graph_engine.command_channels.in_memory_channel import InMemoryChannel @@ -33,6 +32,7 @@ from models.account import Tenant, TenantAccountJoin, TenantAccountRole from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom from models.model import App, AppMode, IconType from models.workflow import Workflow, WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom, WorkflowRun +from tests.workflow_test_utils import build_test_graph_init_params def _mock_form_repository_without_submission() -> HumanInputFormRepository: @@ -87,11 +87,11 @@ def _build_graph( form_repository: HumanInputFormRepository, ) -> Graph: graph_config: dict[str, object] = {"nodes": [], "edges": []} - params = GraphInitParams( - tenant_id=tenant_id, - app_id=app_id, + params = build_test_graph_init_params( workflow_id=workflow_id, graph_config=graph_config, + tenant_id=tenant_id, + app_id=app_id, user_id=user_id, user_from="account", invoke_from="debugger", diff --git a/api/tests/unit_tests/core/app/apps/test_pause_resume.py b/api/tests/unit_tests/core/app/apps/test_pause_resume.py index dd722c52b2..44af89601c 100644 --- a/api/tests/unit_tests/core/app/apps/test_pause_resume.py +++ b/api/tests/unit_tests/core/app/apps/test_pause_resume.py @@ -13,7 +13,6 @@ from core.app.apps.advanced_chat import app_generator as adv_app_gen_module from core.app.apps.workflow import app_generator as wf_app_gen_module from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.node_factory import DifyNodeFactory -from dify_graph.entities import GraphInitParams from dify_graph.entities.pause_reason import SchedulingPause from dify_graph.entities.workflow_start_reason import WorkflowStartReason from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus @@ -34,6 +33,7 @@ from dify_graph.nodes.end.entities import EndNodeData from dify_graph.nodes.start.entities import StartNodeData from dify_graph.runtime import GraphRuntimeState, VariablePool from dify_graph.system_variable import SystemVariable +from tests.workflow_test_utils import build_test_graph_init_params if "core.ops.ops_trace_manager" not in sys.modules: ops_stub = ModuleType("core.ops.ops_trace_manager") @@ -142,11 +142,11 @@ def _build_graph_config(*, pause_on: str | None) -> dict[str, object]: def _build_graph(runtime_state: GraphRuntimeState, *, pause_on: str | None) -> Graph: graph_config = _build_graph_config(pause_on=pause_on) - params = GraphInitParams( - tenant_id="tenant", - app_id="app", + params = build_test_graph_init_params( workflow_id="workflow", graph_config=graph_config, + tenant_id="tenant", + app_id="app", user_id="user", user_from="account", invoke_from="service-api", diff --git a/api/tests/unit_tests/core/workflow/graph/test_graph_skip_validation.py b/api/tests/unit_tests/core/workflow/graph/test_graph_skip_validation.py index d7d2258df8..b93f18c5bd 100644 --- a/api/tests/unit_tests/core/workflow/graph/test_graph_skip_validation.py +++ b/api/tests/unit_tests/core/workflow/graph/test_graph_skip_validation.py @@ -4,15 +4,13 @@ from typing import Any import pytest -from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.node_factory import DifyNodeFactory -from dify_graph.entities import GraphInitParams -from dify_graph.enums import UserFrom from dify_graph.graph import Graph from dify_graph.graph.validation import GraphValidationError from dify_graph.nodes import NodeType from dify_graph.runtime import GraphRuntimeState, VariablePool from dify_graph.system_variable import SystemVariable +from tests.workflow_test_utils import build_test_graph_init_params def _build_iteration_graph(node_id: str) -> dict[str, Any]: @@ -53,14 +51,14 @@ def _build_loop_graph(node_id: str) -> dict[str, Any]: def _make_factory(graph_config: dict[str, Any]) -> DifyNodeFactory: - graph_init_params = GraphInitParams( - tenant_id="tenant", - app_id="app", + graph_init_params = build_test_graph_init_params( workflow_id="workflow", graph_config=graph_config, + tenant_id="tenant", + app_id="app", user_id="user", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, + user_from="account", + invoke_from="debugger", call_depth=0, ) graph_runtime_state = GraphRuntimeState( diff --git a/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py b/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py index d3ef971e6a..b98d56147e 100644 --- a/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py +++ b/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py @@ -6,15 +6,15 @@ from dataclasses import dataclass import pytest -from core.app.entities.app_invoke_entities import InvokeFrom from dify_graph.entities import GraphInitParams -from dify_graph.enums import ErrorStrategy, NodeExecutionType, NodeType, UserFrom +from dify_graph.enums import ErrorStrategy, NodeExecutionType, NodeType from dify_graph.graph import Graph from dify_graph.graph.validation import GraphValidationError from dify_graph.nodes.base.entities import BaseNodeData from dify_graph.nodes.base.node import Node from dify_graph.runtime import GraphRuntimeState, VariablePool from dify_graph.system_variable import SystemVariable +from tests.workflow_test_utils import build_test_graph_init_params class _TestNodeData(BaseNodeData): @@ -91,14 +91,14 @@ class _SimpleNodeFactory: @pytest.fixture def graph_init_dependencies() -> tuple[_SimpleNodeFactory, dict[str, object]]: graph_config: dict[str, object] = {"edges": [], "nodes": []} - init_params = GraphInitParams( - tenant_id="tenant", - app_id="app", + init_params = build_test_graph_init_params( workflow_id="workflow", graph_config=graph_config, + tenant_id="tenant", + app_id="app", user_id="user", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.SERVICE_API, + user_from="account", + invoke_from="service-api", call_depth=0, ) variable_pool = VariablePool(system_variables=SystemVariable(user_id="user", files=[]), user_inputs={}) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py index 352e270fe4..819fd67f9d 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py @@ -32,6 +32,7 @@ def test_deduct_quota_called_for_successful_llm_node() -> None: node.execution_id = "execution-id" node.node_type = NodeType.LLM node.tenant_id = "tenant-id" + node.require_dify_context.return_value.tenant_id = "tenant-id" node.model_instance = object() result_event = _build_succeeded_event() @@ -52,6 +53,7 @@ def test_deduct_quota_called_for_question_classifier_node() -> None: node.execution_id = "execution-id" node.node_type = NodeType.QUESTION_CLASSIFIER node.tenant_id = "tenant-id" + node.require_dify_context.return_value.tenant_id = "tenant-id" node.model_instance = object() result_event = _build_succeeded_event() @@ -72,6 +74,7 @@ def test_non_llm_node_is_ignored() -> None: node.execution_id = "execution-id" node.node_type = NodeType.START node.tenant_id = "tenant-id" + node.require_dify_context.return_value.tenant_id = "tenant-id" node._model_instance = object() result_event = _build_succeeded_event() @@ -88,6 +91,7 @@ def test_quota_error_is_handled_in_layer() -> None: node.execution_id = "execution-id" node.node_type = NodeType.LLM node.tenant_id = "tenant-id" + node.require_dify_context.return_value.tenant_id = "tenant-id" node.model_instance = object() result_event = _build_succeeded_event() @@ -109,6 +113,7 @@ def test_quota_deduction_exceeded_aborts_workflow_immediately() -> None: node.execution_id = "execution-id" node.node_type = NodeType.LLM node.tenant_id = "tenant-id" + node.require_dify_context.return_value.tenant_id = "tenant-id" node.model_instance = object() node.graph_runtime_state = MagicMock() node.graph_runtime_state.stop_event = stop_event diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_auto_mock_system.py b/api/tests/unit_tests/core/workflow/graph_engine/test_auto_mock_system.py index 5196af277e..f886ae1c2b 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_auto_mock_system.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_auto_mock_system.py @@ -8,6 +8,7 @@ for workflows containing nodes that require third-party services. import pytest from dify_graph.enums import NodeType +from tests.workflow_test_utils import build_test_graph_init_params from .test_mock_config import MockConfig, MockConfigBuilder, NodeMockConfig from .test_table_runner import TableTestRunner, WorkflowTestCase @@ -199,22 +200,19 @@ def test_mock_config_builder(): def test_mock_factory_node_type_detection(): """Test that MockNodeFactory correctly identifies nodes to mock.""" - from core.app.entities.app_invoke_entities import InvokeFrom - from dify_graph.entities import GraphInitParams - from dify_graph.enums import UserFrom + from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from dify_graph.runtime import GraphRuntimeState, VariablePool from .test_mock_factory import MockNodeFactory - graph_init_params = GraphInitParams( - tenant_id="test", - app_id="test", + graph_init_params = build_test_graph_init_params( workflow_id="test", graph_config={}, + tenant_id="test", + app_id="test", user_id="test", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.SERVICE_API, - call_depth=0, ) graph_runtime_state = GraphRuntimeState( variable_pool=VariablePool(environment_variables=[], conversation_variables=[], user_inputs={}), @@ -309,9 +307,7 @@ def test_workflow_without_auto_mock(): def test_register_custom_mock_node(): """Test registering a custom mock implementation for a node type.""" - from core.app.entities.app_invoke_entities import InvokeFrom - from dify_graph.entities import GraphInitParams - from dify_graph.enums import UserFrom + from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from dify_graph.nodes.template_transform import TemplateTransformNode from dify_graph.runtime import GraphRuntimeState, VariablePool @@ -323,15 +319,14 @@ def test_register_custom_mock_node(): # Custom mock implementation pass - graph_init_params = GraphInitParams( - tenant_id="test", - app_id="test", + graph_init_params = build_test_graph_init_params( workflow_id="test", graph_config={}, + tenant_id="test", + app_id="test", user_id="test", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.SERVICE_API, - call_depth=0, ) graph_runtime_state = GraphRuntimeState( variable_pool=VariablePool(environment_variables=[], conversation_variables=[], user_inputs={}), diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py b/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py index 09fc412e7f..765c4deba3 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py @@ -3,10 +3,9 @@ import time from unittest.mock import MagicMock -from core.app.entities.app_invoke_entities import InvokeFrom -from dify_graph.entities.graph_init_params import GraphInitParams +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY, GraphInitParams from dify_graph.entities.pause_reason import SchedulingPause -from dify_graph.enums import UserFrom from dify_graph.graph import Graph from dify_graph.graph_engine import GraphEngine, GraphEngineConfig from dify_graph.graph_engine.command_channels import InMemoryChannel @@ -41,13 +40,17 @@ def test_abort_command(): id="start", config={"id": "start", "data": {"title": "start", "variables": []}}, graph_init_params=GraphInitParams( - tenant_id="test_tenant", - app_id="test_app", workflow_id="test_workflow", graph_config={}, - user_id="test_user", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test_tenant", + "app_id": "test_app", + "user_id": "test_user", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.DEBUGGER, + } + }, call_depth=0, ), graph_runtime_state=shared_runtime_state, @@ -151,13 +154,17 @@ def test_pause_command(): id="start", config={"id": "start", "data": {"title": "start", "variables": []}}, graph_init_params=GraphInitParams( - tenant_id="test_tenant", - app_id="test_app", workflow_id="test_workflow", graph_config={}, - user_id="test_user", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test_tenant", + "app_id": "test_app", + "user_id": "test_user", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.DEBUGGER, + } + }, call_depth=0, ), graph_runtime_state=shared_runtime_state, @@ -207,13 +214,17 @@ def test_update_variables_command_updates_pool(): id="start", config={"id": "start", "data": {"title": "start", "variables": []}}, graph_init_params=GraphInitParams( - tenant_id="test_tenant", - app_id="test_app", workflow_id="test_workflow", graph_config={}, - user_id="test_user", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test_tenant", + "app_id": "test_app", + "user_id": "test_user", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.DEBUGGER, + } + }, call_depth=0, ), graph_runtime_state=shared_runtime_state, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_state_snapshot.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_state_snapshot.py index 84e033156d..d54f0be190 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_state_snapshot.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_state_snapshot.py @@ -21,6 +21,7 @@ from dify_graph.nodes.start.entities import StartNodeData from dify_graph.nodes.start.start_node import StartNode from dify_graph.runtime import GraphRuntimeState, VariablePool from dify_graph.system_variable import SystemVariable +from tests.workflow_test_utils import build_test_graph_init_params from .test_mock_config import MockConfig from .test_mock_nodes import MockLLMNode @@ -73,11 +74,11 @@ def _build_llm_node( def _build_graph(runtime_state: GraphRuntimeState) -> Graph: graph_config: dict[str, object] = {"nodes": [], "edges": []} - graph_init_params = GraphInitParams( - tenant_id="tenant", - app_id="app", + graph_init_params = build_test_graph_init_params( workflow_id="workflow", graph_config=graph_config, + tenant_id="tenant", + app_id="app", user_id="user", user_from="account", invoke_from="debugger", diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py b/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py index 695e99c1cf..538f53c603 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py @@ -4,7 +4,6 @@ from collections.abc import Iterable from unittest import mock from unittest.mock import MagicMock -from dify_graph.entities import GraphInitParams from dify_graph.graph import Graph from dify_graph.graph_events import ( GraphRunPausedEvent, @@ -35,6 +34,7 @@ from dify_graph.repositories.human_input_form_repository import HumanInputFormEn from dify_graph.runtime import GraphRuntimeState, VariablePool from dify_graph.system_variable import SystemVariable from libs.datetime_utils import naive_utc_now +from tests.workflow_test_utils import build_test_graph_init_params from .test_mock_config import MockConfig from .test_mock_nodes import MockLLMNode @@ -47,11 +47,11 @@ def _build_branching_graph( graph_runtime_state: GraphRuntimeState | None = None, ) -> tuple[Graph, GraphRuntimeState]: graph_config: dict[str, object] = {"nodes": [], "edges": []} - graph_init_params = GraphInitParams( - tenant_id="tenant", - app_id="app", + graph_init_params = build_test_graph_init_params( workflow_id="workflow", graph_config=graph_config, + tenant_id="tenant", + app_id="app", user_id="user", user_from="account", invoke_from="debugger", diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_single_branch.py b/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_single_branch.py index 0275062c41..36bba6deb6 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_single_branch.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_single_branch.py @@ -3,7 +3,6 @@ import time from unittest import mock from unittest.mock import MagicMock -from dify_graph.entities import GraphInitParams from dify_graph.graph import Graph from dify_graph.graph_events import ( GraphRunPausedEvent, @@ -34,6 +33,7 @@ from dify_graph.repositories.human_input_form_repository import HumanInputFormEn from dify_graph.runtime import GraphRuntimeState, VariablePool from dify_graph.system_variable import SystemVariable from libs.datetime_utils import naive_utc_now +from tests.workflow_test_utils import build_test_graph_init_params from .test_mock_config import MockConfig from .test_mock_nodes import MockLLMNode @@ -46,11 +46,11 @@ def _build_llm_human_llm_graph( graph_runtime_state: GraphRuntimeState | None = None, ) -> tuple[Graph, GraphRuntimeState]: graph_config: dict[str, object] = {"nodes": [], "edges": []} - graph_init_params = GraphInitParams( - tenant_id="tenant", - app_id="app", + graph_init_params = build_test_graph_init_params( workflow_id="workflow", graph_config=graph_config, + tenant_id="tenant", + app_id="app", user_id="user", user_from="account", invoke_from="debugger", diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_if_else_streaming.py b/api/tests/unit_tests/core/workflow/graph_engine/test_if_else_streaming.py index fbcb8d7155..8da179c15e 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_if_else_streaming.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_if_else_streaming.py @@ -1,7 +1,6 @@ import time from unittest import mock -from dify_graph.entities import GraphInitParams from dify_graph.graph import Graph from dify_graph.graph_events import ( GraphRunStartedEvent, @@ -29,6 +28,7 @@ from dify_graph.nodes.start.start_node import StartNode from dify_graph.runtime import GraphRuntimeState, VariablePool from dify_graph.system_variable import SystemVariable from dify_graph.utils.condition.entities import Condition +from tests.workflow_test_utils import build_test_graph_init_params from .test_mock_config import MockConfig from .test_mock_nodes import MockLLMNode @@ -37,15 +37,10 @@ from .test_table_runner import TableTestRunner, WorkflowTestCase def _build_if_else_graph(branch_value: str, mock_config: MockConfig) -> tuple[Graph, GraphRuntimeState]: graph_config: dict[str, object] = {"nodes": [], "edges": []} - graph_init_params = GraphInitParams( - tenant_id="tenant", - app_id="app", - workflow_id="workflow", + graph_init_params = build_test_graph_init_params( graph_config=graph_config, - user_id="user", user_from="account", invoke_from="debugger", - call_depth=0, ) variable_pool = VariablePool( diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_iteration_simple.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_iteration_simple.py index da666ce987..eb449e6d75 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_iteration_simple.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_iteration_simple.py @@ -5,6 +5,8 @@ Simple test to verify MockNodeFactory works with iteration nodes. import sys from pathlib import Path +from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY + # Add api directory to path api_dir = Path(__file__).parent.parent.parent.parent.parent.parent sys.path.insert(0, str(api_dir)) @@ -16,20 +18,23 @@ from tests.unit_tests.core.workflow.graph_engine.test_mock_factory import MockNo def test_mock_factory_registers_iteration_node(): """Test that MockNodeFactory has iteration node registered.""" - from core.app.entities.app_invoke_entities import InvokeFrom + from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from dify_graph.entities import GraphInitParams - from dify_graph.enums import UserFrom from dify_graph.runtime import GraphRuntimeState, VariablePool # Create a MockNodeFactory instance graph_init_params = GraphInitParams( - tenant_id="test", - app_id="test", workflow_id="test", graph_config={"nodes": [], "edges": []}, - user_id="test", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.SERVICE_API, + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test", + "app_id": "test", + "user_id": "test", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.SERVICE_API, + } + }, call_depth=0, ) graph_runtime_state = GraphRuntimeState( @@ -65,9 +70,8 @@ def test_mock_factory_registers_iteration_node(): def test_mock_iteration_node_preserves_config(): """Test that MockIterationNode preserves mock configuration.""" - from core.app.entities.app_invoke_entities import InvokeFrom + from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from dify_graph.entities import GraphInitParams - from dify_graph.enums import UserFrom from dify_graph.runtime import GraphRuntimeState, VariablePool from tests.unit_tests.core.workflow.graph_engine.test_mock_nodes import MockIterationNode @@ -76,13 +80,17 @@ def test_mock_iteration_node_preserves_config(): # Create minimal graph init params graph_init_params = GraphInitParams( - tenant_id="test", - app_id="test", workflow_id="test", graph_config={"nodes": [], "edges": []}, - user_id="test", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.SERVICE_API, + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test", + "app_id": "test", + "user_id": "test", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.SERVICE_API, + } + }, call_depth=0, ) @@ -127,9 +135,8 @@ def test_mock_iteration_node_preserves_config(): def test_mock_loop_node_preserves_config(): """Test that MockLoopNode preserves mock configuration.""" - from core.app.entities.app_invoke_entities import InvokeFrom + from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from dify_graph.entities import GraphInitParams - from dify_graph.enums import UserFrom from dify_graph.runtime import GraphRuntimeState, VariablePool from tests.unit_tests.core.workflow.graph_engine.test_mock_nodes import MockLoopNode @@ -138,13 +145,17 @@ def test_mock_loop_node_preserves_config(): # Create minimal graph init params graph_init_params = GraphInitParams( - tenant_id="test", - app_id="test", workflow_id="test", graph_config={"nodes": [], "edges": []}, - user_id="test", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.SERVICE_API, + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test", + "app_id": "test", + "user_id": "test", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.SERVICE_API, + } + }, call_depth=0, ) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py index 22afbb4909..3f458e9de9 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py @@ -603,13 +603,9 @@ class MockIterationNode(MockNodeMixin, IterationNode): # Create GraphInitParams from node attributes graph_init_params = GraphInitParams( - tenant_id=self.tenant_id, - app_id=self.app_id, workflow_id=self.workflow_id, graph_config=self.graph_config, - user_id=self.user_id, - user_from=self.user_from.value, - invoke_from=self.invoke_from.value, + run_context=self.run_context, call_depth=self.workflow_call_depth, ) @@ -679,13 +675,9 @@ class MockLoopNode(MockNodeMixin, LoopNode): # Create GraphInitParams from node attributes graph_init_params = GraphInitParams( - tenant_id=self.tenant_id, - app_id=self.app_id, workflow_id=self.workflow_id, graph_config=self.graph_config, - user_id=self.user_id, - user_from=self.user_from.value, - invoke_from=self.invoke_from.value, + run_context=self.run_context, call_depth=self.workflow_call_depth, ) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes_template_code.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes_template_code.py index 0942d15073..1550dca402 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes_template_code.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes_template_code.py @@ -6,6 +6,7 @@ to ensure they work correctly with the TableTestRunner. """ from configs import dify_config +from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus from dify_graph.nodes.code.limits import CodeNodeLimits from tests.unit_tests.core.workflow.graph_engine.test_mock_config import MockConfig, MockConfigBuilder, NodeMockConfig @@ -44,13 +45,17 @@ class TestMockTemplateTransformNode: # Create test parameters graph_init_params = GraphInitParams( - tenant_id="test_tenant", - app_id="test_app", workflow_id="test_workflow", graph_config={}, - user_id="test_user", - user_from="account", - invoke_from="debugger", + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test_tenant", + "app_id": "test_app", + "user_id": "test_user", + "user_from": "account", + "invoke_from": "debugger", + } + }, call_depth=0, ) @@ -103,13 +108,17 @@ class TestMockTemplateTransformNode: # Create test parameters graph_init_params = GraphInitParams( - tenant_id="test_tenant", - app_id="test_app", workflow_id="test_workflow", graph_config={}, - user_id="test_user", - user_from="account", - invoke_from="debugger", + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test_tenant", + "app_id": "test_app", + "user_id": "test_user", + "user_from": "account", + "invoke_from": "debugger", + } + }, call_depth=0, ) @@ -163,13 +172,17 @@ class TestMockTemplateTransformNode: # Create test parameters graph_init_params = GraphInitParams( - tenant_id="test_tenant", - app_id="test_app", workflow_id="test_workflow", graph_config={}, - user_id="test_user", - user_from="account", - invoke_from="debugger", + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test_tenant", + "app_id": "test_app", + "user_id": "test_user", + "user_from": "account", + "invoke_from": "debugger", + } + }, call_depth=0, ) @@ -221,13 +234,17 @@ class TestMockTemplateTransformNode: # Create test parameters graph_init_params = GraphInitParams( - tenant_id="test_tenant", - app_id="test_app", workflow_id="test_workflow", graph_config={}, - user_id="test_user", - user_from="account", - invoke_from="debugger", + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test_tenant", + "app_id": "test_app", + "user_id": "test_user", + "user_from": "account", + "invoke_from": "debugger", + } + }, call_depth=0, ) @@ -286,13 +303,17 @@ class TestMockCodeNode: # Create test parameters graph_init_params = GraphInitParams( - tenant_id="test_tenant", - app_id="test_app", workflow_id="test_workflow", graph_config={}, - user_id="test_user", - user_from="account", - invoke_from="debugger", + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test_tenant", + "app_id": "test_app", + "user_id": "test_user", + "user_from": "account", + "invoke_from": "debugger", + } + }, call_depth=0, ) @@ -348,13 +369,17 @@ class TestMockCodeNode: # Create test parameters graph_init_params = GraphInitParams( - tenant_id="test_tenant", - app_id="test_app", workflow_id="test_workflow", graph_config={}, - user_id="test_user", - user_from="account", - invoke_from="debugger", + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test_tenant", + "app_id": "test_app", + "user_id": "test_user", + "user_from": "account", + "invoke_from": "debugger", + } + }, call_depth=0, ) @@ -418,13 +443,17 @@ class TestMockCodeNode: # Create test parameters graph_init_params = GraphInitParams( - tenant_id="test_tenant", - app_id="test_app", workflow_id="test_workflow", graph_config={}, - user_id="test_user", - user_from="account", - invoke_from="debugger", + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test_tenant", + "app_id": "test_app", + "user_id": "test_user", + "user_from": "account", + "invoke_from": "debugger", + } + }, call_depth=0, ) @@ -490,13 +519,17 @@ class TestMockNodeFactory: # Create test parameters graph_init_params = GraphInitParams( - tenant_id="test_tenant", - app_id="test_app", workflow_id="test_workflow", graph_config={}, - user_id="test_user", - user_from="account", - invoke_from="debugger", + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test_tenant", + "app_id": "test_app", + "user_id": "test_user", + "user_from": "account", + "invoke_from": "debugger", + } + }, call_depth=0, ) @@ -531,13 +564,17 @@ class TestMockNodeFactory: # Create test parameters graph_init_params = GraphInitParams( - tenant_id="test_tenant", - app_id="test_app", workflow_id="test_workflow", graph_config={}, - user_id="test_user", - user_from="account", - invoke_from="debugger", + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test_tenant", + "app_id": "test_app", + "user_id": "test_user", + "user_from": "account", + "invoke_from": "debugger", + } + }, call_depth=0, ) @@ -582,13 +619,17 @@ class TestMockNodeFactory: # Create test parameters graph_init_params = GraphInitParams( - tenant_id="test_tenant", - app_id="test_app", workflow_id="test_workflow", graph_config={}, - user_id="test_user", - user_from="account", - invoke_from="debugger", + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test_tenant", + "app_id": "test_app", + "user_id": "test_user", + "user_from": "account", + "invoke_from": "debugger", + } + }, call_depth=0, ) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_simple.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_simple.py index 2376423738..84d1444585 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_simple.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_simple.py @@ -5,6 +5,8 @@ Simple test to validate the auto-mock system without external dependencies. import sys from pathlib import Path +from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY + # Add api directory to path api_dir = Path(__file__).parent.parent.parent.parent.parent.parent sys.path.insert(0, str(api_dir)) @@ -101,21 +103,24 @@ def test_node_mock_config(): def test_mock_factory_detection(): """Test MockNodeFactory node type detection.""" - from core.app.entities.app_invoke_entities import InvokeFrom + from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from dify_graph.entities import GraphInitParams - from dify_graph.enums import UserFrom from dify_graph.runtime import GraphRuntimeState, VariablePool print("Testing MockNodeFactory detection...") graph_init_params = GraphInitParams( - tenant_id="test", - app_id="test", workflow_id="test", graph_config={}, - user_id="test", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.SERVICE_API, + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test", + "app_id": "test", + "user_id": "test", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.SERVICE_API, + } + }, call_depth=0, ) graph_runtime_state = GraphRuntimeState( @@ -154,21 +159,24 @@ def test_mock_factory_detection(): def test_mock_factory_registration(): """Test registering and unregistering mock node types.""" - from core.app.entities.app_invoke_entities import InvokeFrom + from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from dify_graph.entities import GraphInitParams - from dify_graph.enums import UserFrom from dify_graph.runtime import GraphRuntimeState, VariablePool print("Testing MockNodeFactory registration...") graph_init_params = GraphInitParams( - tenant_id="test", - app_id="test", workflow_id="test", graph_config={}, - user_id="test", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.SERVICE_API, + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test", + "app_id": "test", + "user_id": "test", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.SERVICE_API, + } + }, call_depth=0, ) graph_runtime_state = GraphRuntimeState( diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py index e269263bde..e681b39cc7 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py @@ -4,7 +4,6 @@ from dataclasses import dataclass from datetime import datetime, timedelta from typing import Any, Protocol -from dify_graph.entities import GraphInitParams from dify_graph.entities.workflow_start_reason import WorkflowStartReason from dify_graph.graph import Graph from dify_graph.graph_engine.command_channels.in_memory_channel import InMemoryChannel @@ -32,6 +31,7 @@ from dify_graph.repositories.human_input_form_repository import ( from dify_graph.runtime import GraphRuntimeState, VariablePool from dify_graph.system_variable import SystemVariable from libs.datetime_utils import naive_utc_now +from tests.workflow_test_utils import build_test_graph_init_params class PauseStateStore(Protocol): @@ -126,11 +126,11 @@ def _build_runtime_state() -> GraphRuntimeState: def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepository) -> Graph: graph_config: dict[str, object] = {"nodes": [], "edges": []} - graph_init_params = GraphInitParams( - tenant_id="tenant", - app_id="app", + graph_init_params = build_test_graph_init_params( workflow_id="workflow", graph_config=graph_config, + tenant_id="tenant", + app_id="app", user_id="user", user_from="account", invoke_from="debugger", diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_pause_missing_finish.py b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_pause_missing_finish.py index 910292b52c..60167c0441 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_pause_missing_finish.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_pause_missing_finish.py @@ -4,7 +4,6 @@ from dataclasses import dataclass from datetime import datetime, timedelta from typing import Any -from dify_graph.entities import GraphInitParams from dify_graph.entities.workflow_start_reason import WorkflowStartReason from dify_graph.graph import Graph from dify_graph.graph_engine.command_channels.in_memory_channel import InMemoryChannel @@ -39,6 +38,7 @@ from dify_graph.repositories.human_input_form_repository import ( from dify_graph.runtime import GraphRuntimeState, VariablePool from dify_graph.system_variable import SystemVariable from libs.datetime_utils import naive_utc_now +from tests.workflow_test_utils import build_test_graph_init_params from .test_mock_config import MockConfig, NodeMockConfig from .test_mock_nodes import MockLLMNode @@ -129,11 +129,11 @@ def _build_runtime_state() -> GraphRuntimeState: def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepository, mock_config: MockConfig) -> Graph: graph_config: dict[str, object] = {"nodes": [], "edges": []} - graph_init_params = GraphInitParams( - tenant_id="tenant", - app_id="app", + graph_init_params = build_test_graph_init_params( workflow_id="workflow", graph_config=graph_config, + tenant_id="tenant", + app_id="app", user_id="user", user_from="account", invoke_from="debugger", diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_streaming_workflow.py b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_streaming_workflow.py index 8e6b30896f..0ac9d6618d 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_streaming_workflow.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_streaming_workflow.py @@ -12,11 +12,10 @@ import time from unittest.mock import MagicMock, patch from uuid import uuid4 -from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.model_manager import ModelInstance from core.workflow.node_factory import DifyNodeFactory -from dify_graph.entities import GraphInitParams -from dify_graph.enums import NodeType, UserFrom, WorkflowNodeExecutionStatus +from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus from dify_graph.graph import Graph from dify_graph.graph_engine import GraphEngine, GraphEngineConfig from dify_graph.graph_engine.command_channels import InMemoryChannel @@ -30,6 +29,7 @@ from dify_graph.node_events import NodeRunResult, StreamCompletedEvent from dify_graph.nodes.llm.node import LLMNode from dify_graph.runtime import GraphRuntimeState, VariablePool from dify_graph.system_variable import SystemVariable +from tests.workflow_test_utils import build_test_graph_init_params from .test_table_runner import TableTestRunner @@ -86,11 +86,11 @@ def test_parallel_streaming_workflow(): graph_config = workflow_config.get("graph", {}) # Create graph initialization parameters - init_params = GraphInitParams( - tenant_id="test_tenant", - app_id="test_app", + init_params = build_test_graph_init_params( workflow_id="test_workflow", graph_config=graph_config, + tenant_id="test_tenant", + app_id="test_app", user_id="test_user", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.WEB_APP, @@ -99,8 +99,8 @@ def test_parallel_streaming_workflow(): # Create variable pool with system variables system_variables = SystemVariable( - user_id=init_params.user_id, - app_id=init_params.app_id, + user_id="test_user", + app_id="test_app", workflow_id=init_params.workflow_id, files=[], query="Tell me about yourself", # User query diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_pause_deferred_ready_nodes.py b/api/tests/unit_tests/core/workflow/graph_engine/test_pause_deferred_ready_nodes.py index e5a9a29a1f..7328ce443f 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_pause_deferred_ready_nodes.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_pause_deferred_ready_nodes.py @@ -4,7 +4,6 @@ from dataclasses import dataclass from datetime import datetime, timedelta from typing import Any -from dify_graph.entities import GraphInitParams from dify_graph.entities.workflow_start_reason import WorkflowStartReason from dify_graph.graph import Graph from dify_graph.graph_engine.command_channels.in_memory_channel import InMemoryChannel @@ -40,6 +39,7 @@ from dify_graph.repositories.human_input_form_repository import ( from dify_graph.runtime import GraphRuntimeState, VariablePool from dify_graph.system_variable import SystemVariable from libs.datetime_utils import naive_utc_now +from tests.workflow_test_utils import build_test_graph_init_params from .test_mock_config import MockConfig, NodeMockConfig from .test_mock_nodes import MockLLMNode @@ -121,11 +121,11 @@ def _build_runtime_state() -> GraphRuntimeState: def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepository, mock_config: MockConfig) -> Graph: graph_config: dict[str, object] = {"nodes": [], "edges": []} - graph_init_params = GraphInitParams( - tenant_id="tenant", - app_id="app", + graph_init_params = build_test_graph_init_params( workflow_id="workflow", graph_config=graph_config, + tenant_id="tenant", + app_id="app", user_id="user", user_from="account", invoke_from="debugger", diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_pause_resume_state.py b/api/tests/unit_tests/core/workflow/graph_engine/test_pause_resume_state.py index 183d589e2b..15a7de3c52 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_pause_resume_state.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_pause_resume_state.py @@ -3,7 +3,6 @@ import time from typing import Any from unittest.mock import MagicMock -from dify_graph.entities import GraphInitParams from dify_graph.entities.workflow_start_reason import WorkflowStartReason from dify_graph.graph import Graph from dify_graph.graph_engine.command_channels.in_memory_channel import InMemoryChannel @@ -30,6 +29,7 @@ from dify_graph.repositories.human_input_form_repository import ( from dify_graph.runtime import GraphRuntimeState, VariablePool from dify_graph.system_variable import SystemVariable from libs.datetime_utils import naive_utc_now +from tests.workflow_test_utils import build_test_graph_init_params def _build_runtime_state() -> GraphRuntimeState: @@ -79,11 +79,11 @@ def _build_human_input_graph( form_repository: HumanInputFormRepository, ) -> Graph: graph_config: dict[str, object] = {"nodes": [], "edges": []} - params = GraphInitParams( - tenant_id="tenant", - app_id="app", + params = build_test_graph_init_params( workflow_id="workflow", graph_config=graph_config, + tenant_id="tenant", + app_id="app", user_id="user", user_from="account", invoke_from="service-api", diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py b/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py index ee36c976f0..767a8f60ce 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py @@ -12,19 +12,21 @@ This module provides a robust table-driven testing framework with support for: import logging import time -from collections.abc import Callable, Sequence +from collections.abc import Callable, Mapping, Sequence from concurrent.futures import ThreadPoolExecutor, as_completed from dataclasses import dataclass, field from functools import lru_cache from pathlib import Path -from typing import Any +from typing import Any, cast +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.tools.utils.yaml_utils import _load_yaml_file from core.workflow.node_factory import DifyNodeFactory -from dify_graph.entities.graph_init_params import GraphInitParams +from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY, GraphInitParams from dify_graph.graph import Graph from dify_graph.graph_engine import GraphEngine, GraphEngineConfig from dify_graph.graph_engine.command_channels import InMemoryChannel +from dify_graph.graph_engine.layers.base import GraphEngineLayer from dify_graph.graph_events import ( GraphEngineEvent, GraphRunStartedEvent, @@ -48,6 +50,47 @@ from .test_mock_factory import MockNodeFactory logger = logging.getLogger(__name__) +class _TableTestChildEngineBuilder: + def __init__(self, *, use_mock_factory: bool, mock_config: MockConfig | None) -> None: + self._use_mock_factory = use_mock_factory + self._mock_config = mock_config + + def build_child_engine( + self, + *, + workflow_id: str, + graph_init_params: GraphInitParams, + graph_runtime_state: GraphRuntimeState, + graph_config: Mapping[str, Any], + root_node_id: str, + layers: Sequence[object] = (), + ) -> GraphEngine: + if self._use_mock_factory: + node_factory = MockNodeFactory( + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + mock_config=self._mock_config, + ) + else: + node_factory = DifyNodeFactory(graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state) + + child_graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=root_node_id) + if not child_graph: + raise ValueError("child graph not found") + + child_engine = GraphEngine( + workflow_id=workflow_id, + graph=child_graph, + graph_runtime_state=graph_runtime_state, + command_channel=InMemoryChannel(), + config=GraphEngineConfig(), + child_engine_builder=self, + ) + for layer in layers: + child_engine.layer(cast(GraphEngineLayer, layer)) + return child_engine + + @dataclass class WorkflowTestCase: """Represents a single test case for table-driven testing.""" @@ -149,19 +192,23 @@ class WorkflowRunner: raise ValueError("Fixture missing workflow.graph configuration") graph_init_params = GraphInitParams( - tenant_id="test_tenant", - app_id="test_app", workflow_id="test_workflow", graph_config=graph_config, - user_id="test_user", - user_from="account", - invoke_from="debugger", # Set to debugger to avoid conversation_id requirement + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test_tenant", + "app_id": "test_app", + "user_id": "test_user", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.DEBUGGER, # Set to debugger to avoid conversation_id requirement + } + }, call_depth=0, ) system_variables = SystemVariable( - user_id=graph_init_params.user_id, - app_id=graph_init_params.app_id, + user_id="test_user", + app_id="test_app", workflow_id=graph_init_params.workflow_id, files=[], query=query, @@ -315,6 +362,10 @@ class TableTestRunner: scale_up_threshold=self.graph_engine_scale_up_threshold, scale_down_idle_time=self.graph_engine_scale_down_idle_time, ), + child_engine_builder=_TableTestChildEngineBuilder( + use_mock_factory=test_case.use_auto_mock, + mock_config=test_case.mock_config, + ), ) # Execute and collect events diff --git a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py index b1351c9fc3..f0d80af1ed 100644 --- a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py +++ b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py @@ -2,15 +2,15 @@ import time import uuid from unittest.mock import MagicMock -from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.workflow.node_factory import DifyNodeFactory -from dify_graph.entities import GraphInitParams -from dify_graph.enums import UserFrom, WorkflowNodeExecutionStatus +from dify_graph.enums import WorkflowNodeExecutionStatus from dify_graph.graph import Graph from dify_graph.nodes.answer.answer_node import AnswerNode from dify_graph.runtime import GraphRuntimeState, VariablePool from dify_graph.system_variable import SystemVariable from extensions.ext_database import db +from tests.workflow_test_utils import build_test_graph_init_params def test_execute_answer(): @@ -35,11 +35,11 @@ def test_execute_answer(): ], } - init_params = GraphInitParams( - tenant_id="1", - app_id="1", + init_params = build_test_graph_init_params( workflow_id="1", graph_config=graph_config, + tenant_id="1", + app_id="1", user_id="1", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, diff --git a/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py b/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py index b100c7c02c..db096b1aed 100644 --- a/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py @@ -1,3 +1,4 @@ +from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus from dify_graph.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent from dify_graph.nodes.datasource.datasource_node import DatasourceNode @@ -28,13 +29,17 @@ class _GraphState: class _GraphParams: - tenant_id = "t1" - app_id = "app-1" workflow_id = "wf-1" graph_config = {} - user_id = "u1" - user_from = "account" - invoke_from = "debugger" + run_context = { + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "t1", + "app_id": "app-1", + "user_id": "u1", + "user_from": "account", + "invoke_from": "debugger", + } + } call_depth = 0 diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py index 38e68dcdc9..5e34bf1d94 100644 --- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py @@ -4,16 +4,16 @@ from typing import Any import httpx import pytest -from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.helper.ssrf_proxy import ssrf_proxy from core.tools.tool_file_manager import ToolFileManager -from dify_graph.entities import GraphInitParams -from dify_graph.enums import UserFrom, WorkflowNodeExecutionStatus +from dify_graph.enums import WorkflowNodeExecutionStatus from dify_graph.file.file_manager import file_manager from dify_graph.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, HttpRequestNode, HttpRequestNodeConfig from dify_graph.nodes.http_request.entities import HttpRequestNodeTimeout, Response from dify_graph.runtime import GraphRuntimeState, VariablePool from dify_graph.system_variable import SystemVariable +from tests.workflow_test_utils import build_test_graph_init_params HTTP_REQUEST_CONFIG = HttpRequestNodeConfig( max_connect_timeout=10, @@ -98,11 +98,11 @@ def _build_http_node( ], "edges": [], } - graph_init_params = GraphInitParams( - tenant_id="tenant", - app_id="app", + graph_init_params = build_test_graph_init_params( workflow_id="workflow", graph_config=graph_config, + tenant_id="tenant", + app_id="app", user_id="user", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py index 3b2a81ccef..55aa62a1c0 100644 --- a/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py +++ b/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py @@ -9,6 +9,7 @@ import pytest from pydantic import ValidationError from dify_graph.entities import GraphInitParams +from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY from dify_graph.node_events import PauseRequestedEvent from dify_graph.node_events.node import StreamCompletedEvent from dify_graph.nodes.human_input.entities import ( @@ -314,13 +315,17 @@ class TestHumanInputNodeVariableResolution: variable_pool.add(("start", "name"), "Jane Doe") runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0) graph_init_params = GraphInitParams( - tenant_id="tenant", - app_id="app", workflow_id="workflow", graph_config={"nodes": [], "edges": []}, - user_id="user", - user_from="account", - invoke_from="debugger", + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "tenant", + "app_id": "app", + "user_id": "user", + "user_from": "account", + "invoke_from": "debugger", + } + }, call_depth=0, ) @@ -384,13 +389,17 @@ class TestHumanInputNodeVariableResolution: ) runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0) graph_init_params = GraphInitParams( - tenant_id="tenant", - app_id="app", workflow_id="workflow", graph_config={"nodes": [], "edges": []}, - user_id="user", - user_from="account", - invoke_from="debugger", + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "tenant", + "app_id": "app", + "user_id": "user", + "user_from": "account", + "invoke_from": "debugger", + } + }, call_depth=0, ) @@ -439,13 +448,17 @@ class TestHumanInputNodeVariableResolution: ) runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0) graph_init_params = GraphInitParams( - tenant_id="tenant", - app_id="app", workflow_id="workflow", graph_config={"nodes": [], "edges": []}, - user_id="user-123", - user_from="account", - invoke_from="debugger", + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "tenant", + "app_id": "app", + "user_id": "user-123", + "user_from": "account", + "invoke_from": "debugger", + } + }, call_depth=0, ) @@ -550,13 +563,17 @@ class TestHumanInputNodeRenderedContent: ) runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0) graph_init_params = GraphInitParams( - tenant_id="tenant", - app_id="app", workflow_id="workflow", graph_config={"nodes": [], "edges": []}, - user_id="user", - user_from="account", - invoke_from="debugger", + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "tenant", + "app_id": "app", + "user_id": "user", + "user_from": "account", + "invoke_from": "debugger", + } + }, call_depth=0, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py index 46b4f1ed37..1fea19e795 100644 --- a/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py +++ b/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py @@ -1,9 +1,9 @@ import datetime from types import SimpleNamespace -from core.app.entities.app_invoke_entities import InvokeFrom -from dify_graph.entities.graph_init_params import GraphInitParams -from dify_graph.enums import NodeType, UserFrom +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY, GraphInitParams +from dify_graph.enums import NodeType from dify_graph.graph_events import ( NodeRunHumanInputFormFilledEvent, NodeRunHumanInputFormTimeoutEvent, @@ -31,13 +31,17 @@ def _build_node(form_content: str = "Please enter your name:\n\n{{#$output.name# start_at=0.0, ) graph_init_params = GraphInitParams( - tenant_id="tenant", - app_id="app", workflow_id="workflow", graph_config={"nodes": [], "edges": []}, - user_id="user", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.SERVICE_API, + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "tenant", + "app_id": "app", + "user_id": "user", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.SERVICE_API, + } + }, call_depth=0, ) @@ -91,13 +95,17 @@ def _build_timeout_node() -> HumanInputNode: start_at=0.0, ) graph_init_params = GraphInitParams( - tenant_id="tenant", - app_id="app", workflow_id="workflow", graph_config={"nodes": [], "edges": []}, - user_id="user", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.SERVICE_API, + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "tenant", + "app_id": "app", + "user_id": "user", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.SERVICE_API, + } + }, call_depth=0, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py new file mode 100644 index 0000000000..2eb4feef5f --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py @@ -0,0 +1,100 @@ +from collections.abc import Mapping, Sequence +from typing import Any + +import pytest + +from dify_graph.entities import GraphInitParams +from dify_graph.nodes.iteration.exc import IterationGraphNotFoundError +from dify_graph.nodes.iteration.iteration_node import IterationNode +from dify_graph.runtime import ( + ChildEngineBuilderNotConfiguredError, + ChildGraphNotFoundError, + GraphRuntimeState, + VariablePool, +) +from dify_graph.system_variable import SystemVariable +from tests.workflow_test_utils import build_test_graph_init_params + + +class _MissingGraphBuilder: + def build_child_engine( + self, + *, + workflow_id: str, + graph_init_params: GraphInitParams, + graph_runtime_state: GraphRuntimeState, + graph_config: Mapping[str, Any], + root_node_id: str, + layers: Sequence[object] = (), + ) -> object: + raise ChildGraphNotFoundError(f"child graph root node '{root_node_id}' not found") + + +def _build_runtime_state() -> GraphRuntimeState: + return GraphRuntimeState( + variable_pool=VariablePool(system_variables=SystemVariable.default(), user_inputs={}), + start_at=0.0, + ) + + +def _build_iteration_node( + *, + graph_config: Mapping[str, Any], + runtime_state: GraphRuntimeState, + start_node_id: str, +) -> IterationNode: + init_params = build_test_graph_init_params(graph_config=graph_config) + return IterationNode( + id="iteration-node", + config={ + "id": "iteration-node", + "data": { + "type": "iteration", + "title": "Iteration", + "iterator_selector": ["start", "items"], + "output_selector": ["iteration-node", "output"], + "start_node_id": start_node_id, + }, + }, + graph_init_params=init_params, + graph_runtime_state=runtime_state, + ) + + +def test_graph_runtime_state_raises_specific_error_when_child_builder_is_missing(): + runtime_state = _build_runtime_state() + graph_init_params = build_test_graph_init_params() + + with pytest.raises(ChildEngineBuilderNotConfiguredError): + runtime_state.create_child_engine( + workflow_id="workflow", + graph_init_params=graph_init_params, + graph_runtime_state=_build_runtime_state(), + graph_config={}, + root_node_id="root", + ) + + +def test_iteration_node_only_translates_child_graph_not_found_error(): + runtime_state = _build_runtime_state() + runtime_state.bind_child_engine_builder(_MissingGraphBuilder()) + node = _build_iteration_node( + graph_config={"nodes": [{"id": "present-node"}], "edges": []}, + runtime_state=runtime_state, + start_node_id="missing-node", + ) + + with pytest.raises(IterationGraphNotFoundError): + node._create_graph_engine(index=0, item="item") + + +def test_iteration_node_propagates_non_graph_not_found_errors(): + runtime_state = _build_runtime_state() + node = _build_iteration_node( + graph_config={"nodes": [{"id": "start-node"}], "edges": []}, + runtime_state=runtime_state, + start_node_id="start-node", + ) + + with pytest.raises(ChildEngineBuilderNotConfiguredError): + node._create_graph_engine(index=0, item="item") diff --git a/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py b/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py index ffb9c0a43f..8116fc8b3c 100644 --- a/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py @@ -4,9 +4,8 @@ from unittest.mock import Mock import pytest -from core.app.entities.app_invoke_entities import InvokeFrom -from dify_graph.entities import GraphInitParams -from dify_graph.enums import SystemVariableKey, UserFrom, WorkflowNodeExecutionStatus +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from dify_graph.enums import SystemVariableKey, WorkflowNodeExecutionStatus from dify_graph.nodes.knowledge_index.entities import KnowledgeIndexNodeData from dify_graph.nodes.knowledge_index.exc import KnowledgeIndexNodeError from dify_graph.nodes.knowledge_index.knowledge_index_node import KnowledgeIndexNode @@ -15,16 +14,17 @@ from dify_graph.repositories.summary_index_service_protocol import SummaryIndexS from dify_graph.runtime import GraphRuntimeState, VariablePool from dify_graph.system_variable import SystemVariable from dify_graph.variables.segments import StringSegment +from tests.workflow_test_utils import build_test_graph_init_params @pytest.fixture def mock_graph_init_params(): """Create mock GraphInitParams.""" - return GraphInitParams( - tenant_id=str(uuid.uuid4()), - app_id=str(uuid.uuid4()), + return build_test_graph_init_params( workflow_id=str(uuid.uuid4()), graph_config={}, + tenant_id=str(uuid.uuid4()), + app_id=str(uuid.uuid4()), user_id=str(uuid.uuid4()), user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, diff --git a/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py b/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py index 45a4316ead..e194d66ee3 100644 --- a/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py @@ -4,9 +4,8 @@ from unittest.mock import Mock import pytest -from core.app.entities.app_invoke_entities import InvokeFrom -from dify_graph.entities import GraphInitParams -from dify_graph.enums import UserFrom, WorkflowNodeExecutionStatus +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from dify_graph.enums import WorkflowNodeExecutionStatus from dify_graph.model_runtime.entities.llm_entities import LLMUsage from dify_graph.nodes.knowledge_retrieval.entities import ( KnowledgeRetrievalNodeData, @@ -20,16 +19,17 @@ from dify_graph.repositories.rag_retrieval_protocol import RAGRetrievalProtocol, from dify_graph.runtime import GraphRuntimeState, VariablePool from dify_graph.system_variable import SystemVariable from dify_graph.variables import StringSegment +from tests.workflow_test_utils import build_test_graph_init_params @pytest.fixture def mock_graph_init_params(): """Create mock GraphInitParams.""" - return GraphInitParams( - tenant_id=str(uuid.uuid4()), - app_id=str(uuid.uuid4()), + return build_test_graph_init_params( workflow_id=str(uuid.uuid4()), graph_config={}, + tenant_id=str(uuid.uuid4()), + app_id=str(uuid.uuid4()), user_id=str(uuid.uuid4()), user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, diff --git a/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py b/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py index 65228df517..25760ba352 100644 --- a/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py +++ b/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py @@ -1,14 +1,13 @@ from unittest.mock import MagicMock import pytest -from dify_graph.graph_engine.entities.graph import Graph -from dify_graph.graph_engine.entities.graph_init_params import GraphInitParams -from dify_graph.graph_engine.entities.graph_runtime_state import GraphRuntimeState +from dify_graph.entities import GraphInitParams +from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus from dify_graph.nodes.list_operator.node import ListOperatorNode +from dify_graph.runtime import GraphRuntimeState from dify_graph.variables import ArrayNumberSegment, ArrayStringSegment -from models.workflow import WorkflowType class TestListOperatorNode: @@ -22,43 +21,40 @@ class TestListOperatorNode: mock_state.variable_pool = mock_variable_pool return mock_state - @pytest.fixture - def mock_graph(self): - """Create mock Graph.""" - return MagicMock(spec=Graph) - @pytest.fixture def graph_init_params(self): """Create GraphInitParams fixture.""" return GraphInitParams( - tenant_id="test", - app_id="test", - workflow_type=WorkflowType.WORKFLOW, workflow_id="test", graph_config={}, - user_id="test", - user_from="test", - invoke_from="test", + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test", + "app_id": "test", + "user_id": "test", + "user_from": "test", + "invoke_from": "test", + } + }, call_depth=0, ) @pytest.fixture - def list_operator_node_factory(self, graph_init_params, mock_graph, mock_graph_runtime_state): + def list_operator_node_factory(self, graph_init_params, mock_graph_runtime_state): """Factory fixture for creating ListOperatorNode instances.""" def _create_node(config, mock_variable): mock_graph_runtime_state.variable_pool.get.return_value = mock_variable return ListOperatorNode( id="test", - config=config, + config={"id": "test", "data": config}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, ) return _create_node - def test_node_initialization(self, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_node_initialization(self, mock_graph_runtime_state, graph_init_params): """Test node initializes correctly.""" config = { "title": "List Operator", @@ -70,9 +66,8 @@ class TestListOperatorNode: node = ListOperatorNode( id="test", - config=config, + config={"id": "test", "data": config}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, ) @@ -101,7 +96,7 @@ class TestListOperatorNode: assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.outputs["result"].value == ["apple", "banana", "cherry"] - def test_run_with_empty_array(self, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_run_with_empty_array(self, mock_graph_runtime_state, graph_init_params): """Test with empty array.""" config = { "title": "Test", @@ -116,9 +111,8 @@ class TestListOperatorNode: node = ListOperatorNode( id="test", - config=config, + config={"id": "test", "data": config}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, ) @@ -129,7 +123,7 @@ class TestListOperatorNode: assert result.outputs["first_record"] is None assert result.outputs["last_record"] is None - def test_run_with_filter_contains(self, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_run_with_filter_contains(self, mock_graph_runtime_state, graph_init_params): """Test filter with contains condition.""" config = { "title": "Test", @@ -148,9 +142,8 @@ class TestListOperatorNode: node = ListOperatorNode( id="test", - config=config, + config={"id": "test", "data": config}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, ) @@ -159,7 +152,7 @@ class TestListOperatorNode: assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.outputs["result"].value == ["apple", "pineapple"] - def test_run_with_filter_not_contains(self, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_run_with_filter_not_contains(self, mock_graph_runtime_state, graph_init_params): """Test filter with not contains condition.""" config = { "title": "Test", @@ -178,9 +171,8 @@ class TestListOperatorNode: node = ListOperatorNode( id="test", - config=config, + config={"id": "test", "data": config}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, ) @@ -189,7 +181,7 @@ class TestListOperatorNode: assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.outputs["result"].value == ["banana", "cherry"] - def test_run_with_number_filter_greater_than(self, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_run_with_number_filter_greater_than(self, mock_graph_runtime_state, graph_init_params): """Test filter with greater than condition on numbers.""" config = { "title": "Test", @@ -208,9 +200,8 @@ class TestListOperatorNode: node = ListOperatorNode( id="test", - config=config, + config={"id": "test", "data": config}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, ) @@ -219,7 +210,7 @@ class TestListOperatorNode: assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.outputs["result"].value == [7, 9, 11] - def test_run_with_order_ascending(self, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_run_with_order_ascending(self, mock_graph_runtime_state, graph_init_params): """Test ordering in ascending order.""" config = { "title": "Test", @@ -237,9 +228,8 @@ class TestListOperatorNode: node = ListOperatorNode( id="test", - config=config, + config={"id": "test", "data": config}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, ) @@ -248,7 +238,7 @@ class TestListOperatorNode: assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.outputs["result"].value == ["apple", "banana", "cherry"] - def test_run_with_order_descending(self, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_run_with_order_descending(self, mock_graph_runtime_state, graph_init_params): """Test ordering in descending order.""" config = { "title": "Test", @@ -266,9 +256,8 @@ class TestListOperatorNode: node = ListOperatorNode( id="test", - config=config, + config={"id": "test", "data": config}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, ) @@ -277,7 +266,7 @@ class TestListOperatorNode: assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.outputs["result"].value == ["cherry", "banana", "apple"] - def test_run_with_limit(self, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_run_with_limit(self, mock_graph_runtime_state, graph_init_params): """Test with limit enabled.""" config = { "title": "Test", @@ -295,9 +284,8 @@ class TestListOperatorNode: node = ListOperatorNode( id="test", - config=config, + config={"id": "test", "data": config}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, ) @@ -306,7 +294,7 @@ class TestListOperatorNode: assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.outputs["result"].value == ["apple", "banana"] - def test_run_with_filter_order_and_limit(self, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_run_with_filter_order_and_limit(self, mock_graph_runtime_state, graph_init_params): """Test with filter, order, and limit combined.""" config = { "title": "Test", @@ -331,9 +319,8 @@ class TestListOperatorNode: node = ListOperatorNode( id="test", - config=config, + config={"id": "test", "data": config}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, ) @@ -342,7 +329,7 @@ class TestListOperatorNode: assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.outputs["result"].value == [9, 8, 7] - def test_run_with_variable_not_found(self, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_run_with_variable_not_found(self, mock_graph_runtime_state, graph_init_params): """Test when variable is not found.""" config = { "title": "Test", @@ -356,9 +343,8 @@ class TestListOperatorNode: node = ListOperatorNode( id="test", - config=config, + config={"id": "test", "data": config}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, ) @@ -367,7 +353,7 @@ class TestListOperatorNode: assert result.status == WorkflowNodeExecutionStatus.FAILED assert "Variable not found" in result.error - def test_run_with_first_and_last_record(self, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_run_with_first_and_last_record(self, mock_graph_runtime_state, graph_init_params): """Test first_record and last_record outputs.""" config = { "title": "Test", @@ -382,9 +368,8 @@ class TestListOperatorNode: node = ListOperatorNode( id="test", - config=config, + config={"id": "test", "data": config}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, ) @@ -394,7 +379,7 @@ class TestListOperatorNode: assert result.outputs["first_record"] == "first" assert result.outputs["last_record"] == "last" - def test_run_with_filter_startswith(self, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_run_with_filter_startswith(self, mock_graph_runtime_state, graph_init_params): """Test filter with startswith condition.""" config = { "title": "Test", @@ -413,9 +398,8 @@ class TestListOperatorNode: node = ListOperatorNode( id="test", - config=config, + config={"id": "test", "data": config}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, ) @@ -424,7 +408,7 @@ class TestListOperatorNode: assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.outputs["result"].value == ["apple", "application"] - def test_run_with_filter_endswith(self, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_run_with_filter_endswith(self, mock_graph_runtime_state, graph_init_params): """Test filter with endswith condition.""" config = { "title": "Test", @@ -443,9 +427,8 @@ class TestListOperatorNode: node = ListOperatorNode( id="test", - config=config, + config={"id": "test", "data": config}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, ) @@ -454,7 +437,7 @@ class TestListOperatorNode: assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.outputs["result"].value == ["apple", "pineapple", "table"] - def test_run_with_number_filter_equals(self, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_run_with_number_filter_equals(self, mock_graph_runtime_state, graph_init_params): """Test number filter with equals condition.""" config = { "title": "Test", @@ -473,9 +456,8 @@ class TestListOperatorNode: node = ListOperatorNode( id="test", - config=config, + config={"id": "test", "data": config}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, ) @@ -484,7 +466,7 @@ class TestListOperatorNode: assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.outputs["result"].value == [5, 5] - def test_run_with_number_filter_not_equals(self, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_run_with_number_filter_not_equals(self, mock_graph_runtime_state, graph_init_params): """Test number filter with not equals condition.""" config = { "title": "Test", @@ -503,9 +485,8 @@ class TestListOperatorNode: node = ListOperatorNode( id="test", - config=config, + config={"id": "test", "data": config}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, ) @@ -514,7 +495,7 @@ class TestListOperatorNode: assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.outputs["result"].value == [1, 3, 7, 9] - def test_run_with_number_order_ascending(self, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_run_with_number_order_ascending(self, mock_graph_runtime_state, graph_init_params): """Test number ordering in ascending order.""" config = { "title": "Test", @@ -532,9 +513,8 @@ class TestListOperatorNode: node = ListOperatorNode( id="test", - config=config, + config={"id": "test", "data": config}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py index b822f1fbe4..90308facc3 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py @@ -5,14 +5,13 @@ from unittest import mock import pytest -from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity +from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity, UserFrom from core.app.llm.model_access import DifyCredentialsProvider, DifyModelFactory, fetch_model_config from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle from core.entities.provider_entities import CustomConfiguration, SystemConfiguration from core.model_manager import ModelInstance from core.prompt.entities.advanced_prompt_entities import MemoryConfig from dify_graph.entities import GraphInitParams -from dify_graph.enums import UserFrom from dify_graph.file import File, FileTransferMethod, FileType from dify_graph.model_runtime.entities.common_entities import I18nObject from dify_graph.model_runtime.entities.message_entities import ( @@ -41,6 +40,7 @@ from dify_graph.runtime import GraphRuntimeState, VariablePool from dify_graph.system_variable import SystemVariable from dify_graph.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment from models.provider import ProviderType +from tests.workflow_test_utils import build_test_graph_init_params class MockTokenBufferMemory: @@ -76,11 +76,11 @@ def llm_node_data() -> LLMNodeData: @pytest.fixture def graph_init_params() -> GraphInitParams: - return GraphInitParams( - tenant_id="1", - app_id="1", + return build_test_graph_init_params( workflow_id="1", graph_config={}, + tenant_id="1", + app_id="1", user_id="1", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.SERVICE_API, diff --git a/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py b/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py index 48d76d9b9b..6831626f58 100644 --- a/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py +++ b/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py @@ -2,15 +2,13 @@ from unittest.mock import MagicMock import pytest -from core.app.entities.app_invoke_entities import InvokeFrom -from dify_graph.entities import GraphInitParams +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from dify_graph.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus from dify_graph.graph import Graph from dify_graph.nodes.template_transform.template_renderer import TemplateRenderError from dify_graph.nodes.template_transform.template_transform_node import TemplateTransformNode from dify_graph.runtime import GraphRuntimeState -from models.enums import UserFrom -from models.workflow import WorkflowType +from tests.workflow_test_utils import build_test_graph_init_params class TestTemplateTransformNode: @@ -32,12 +30,11 @@ class TestTemplateTransformNode: @pytest.fixture def graph_init_params(self): """Create a mock GraphInitParams.""" - return GraphInitParams( - tenant_id="test_tenant", - app_id="test_app", - workflow_type=WorkflowType.WORKFLOW, + return build_test_graph_init_params( workflow_id="test_workflow", graph_config={}, + tenant_id="test_tenant", + app_id="test_app", user_id="test_user", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, diff --git a/api/tests/unit_tests/core/workflow/nodes/test_base_node.py b/api/tests/unit_tests/core/workflow/nodes/test_base_node.py index ff9247059b..44abf430c0 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_base_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_base_node.py @@ -2,13 +2,14 @@ from collections.abc import Mapping import pytest -from core.app.entities.app_invoke_entities import InvokeFrom as LegacyInvokeFrom +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from dify_graph.entities import GraphInitParams -from dify_graph.enums import InvokeFrom, NodeType, UserFrom +from dify_graph.enums import NodeType from dify_graph.nodes.base.entities import BaseNodeData from dify_graph.nodes.base.node import Node from dify_graph.runtime import GraphRuntimeState, VariablePool from dify_graph.system_variable import SystemVariable +from tests.workflow_test_utils import build_test_graph_init_params class _SampleNodeData(BaseNodeData): @@ -27,15 +28,10 @@ class _SampleNode(Node[_SampleNodeData]): def _build_context(graph_config: Mapping[str, object]) -> tuple[GraphInitParams, GraphRuntimeState]: - init_params = GraphInitParams( - tenant_id="tenant", - app_id="app", - workflow_id="workflow", + init_params = build_test_graph_init_params( graph_config=graph_config, - user_id="user", user_from="account", invoke_from="debugger", - call_depth=0, ) runtime_state = GraphRuntimeState( variable_pool=VariablePool(system_variables=SystemVariable(user_id="user", files=[]), user_inputs={}), @@ -57,21 +53,17 @@ def test_node_hydrates_data_during_initialization(): assert node.node_data.foo == "bar" assert node.title == "Sample" - assert node.user_from == UserFrom.ACCOUNT - assert node.invoke_from == InvokeFrom.DEBUGGER + dify_ctx = node.require_dify_context() + assert dify_ctx.user_from == "account" + assert dify_ctx.invoke_from == "debugger" -def test_node_normalizes_legacy_invoke_from_enum(): +def test_node_accepts_invoke_from_enum(): graph_config: dict[str, object] = {} - init_params = GraphInitParams( - tenant_id="tenant", - app_id="app", - workflow_id="workflow", + init_params = build_test_graph_init_params( graph_config=graph_config, - user_id="user", user_from=UserFrom.ACCOUNT, - invoke_from=LegacyInvokeFrom.DEBUGGER, - call_depth=0, + invoke_from=InvokeFrom.DEBUGGER, ) runtime_state = GraphRuntimeState( variable_pool=VariablePool(system_variables=SystemVariable(user_id="user", files=[]), user_inputs={}), @@ -85,8 +77,12 @@ def test_node_normalizes_legacy_invoke_from_enum(): graph_runtime_state=runtime_state, ) - assert node.user_from == UserFrom.ACCOUNT - assert node.invoke_from == InvokeFrom.DEBUGGER + dify_ctx = node.require_dify_context() + assert dify_ctx.user_from == UserFrom.ACCOUNT + assert dify_ctx.invoke_from == InvokeFrom.DEBUGGER + assert node.get_run_context_value("missing") is None + with pytest.raises(ValueError): + node.require_run_context_value("missing") def test_missing_generic_argument_raises_type_error(): diff --git a/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py b/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py index dff84b580a..5e20b1e12f 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py @@ -5,9 +5,9 @@ import pandas as pd import pytest from docx.oxml.text.paragraph import CT_P -from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from dify_graph.entities import GraphInitParams -from dify_graph.enums import NodeType, UserFrom, WorkflowNodeExecutionStatus +from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus from dify_graph.file import File, FileTransferMethod from dify_graph.node_events import NodeRunResult from dify_graph.nodes.document_extractor import DocumentExtractorNode, DocumentExtractorNodeData @@ -20,15 +20,16 @@ from dify_graph.nodes.document_extractor.node import ( from dify_graph.variables import ArrayFileSegment from dify_graph.variables.segments import ArrayStringSegment from dify_graph.variables.variables import StringVariable +from tests.workflow_test_utils import build_test_graph_init_params @pytest.fixture def graph_init_params() -> GraphInitParams: - return GraphInitParams( - tenant_id="test_tenant", - app_id="test_app", + return build_test_graph_init_params( workflow_id="test_workflow", graph_config={}, + tenant_id="test_tenant", + app_id="test_app", user_id="test_user", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, diff --git a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py index 23b96e7b25..041bd66d03 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py @@ -4,10 +4,10 @@ from unittest.mock import MagicMock, Mock import pytest -from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.workflow.node_factory import DifyNodeFactory -from dify_graph.entities import GraphInitParams -from dify_graph.enums import UserFrom, WorkflowNodeExecutionStatus +from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY +from dify_graph.enums import WorkflowNodeExecutionStatus from dify_graph.file import File, FileTransferMethod, FileType from dify_graph.graph import Graph from dify_graph.nodes.if_else.entities import IfElseNodeData @@ -17,16 +17,17 @@ from dify_graph.system_variable import SystemVariable from dify_graph.utils.condition.entities import Condition, SubCondition, SubVariableCondition from dify_graph.variables import ArrayFileSegment from extensions.ext_database import db +from tests.workflow_test_utils import build_test_graph_init_params def test_execute_if_else_result_true(): graph_config = {"edges": [], "nodes": [{"data": {"type": "start", "title": "Start"}, "id": "start"}]} - init_params = GraphInitParams( - tenant_id="1", - app_id="1", + init_params = build_test_graph_init_params( workflow_id="1", graph_config=graph_config, + tenant_id="1", + app_id="1", user_id="1", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, @@ -128,11 +129,11 @@ def test_execute_if_else_result_false(): # Create a simple graph for IfElse node testing graph_config = {"edges": [], "nodes": [{"data": {"type": "start", "title": "Start"}, "id": "start"}]} - init_params = GraphInitParams( - tenant_id="1", - app_id="1", + init_params = build_test_graph_init_params( workflow_id="1", graph_config=graph_config, + tenant_id="1", + app_id="1", user_id="1", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, @@ -229,14 +230,18 @@ def test_array_file_contains_file_name(): # Create properly configured mock for graph_init_params graph_init_params = Mock() - graph_init_params.tenant_id = "test_tenant" - graph_init_params.app_id = "test_app" graph_init_params.workflow_id = "test_workflow" graph_init_params.graph_config = {} - graph_init_params.user_id = "test_user" - graph_init_params.user_from = UserFrom.ACCOUNT - graph_init_params.invoke_from = InvokeFrom.SERVICE_API graph_init_params.call_depth = 0 + graph_init_params.run_context = { + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test_tenant", + "app_id": "test_app", + "user_id": "test_user", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.SERVICE_API, + } + } node = IfElseNode( id=str(uuid.uuid4()), @@ -298,11 +303,11 @@ def test_execute_if_else_boolean_conditions(condition: Condition): """Test IfElseNode with boolean conditions using various operators""" graph_config = {"edges": [], "nodes": [{"data": {"type": "start", "title": "Start"}, "id": "start"}]} - init_params = GraphInitParams( - tenant_id="1", - app_id="1", + init_params = build_test_graph_init_params( workflow_id="1", graph_config=graph_config, + tenant_id="1", + app_id="1", user_id="1", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, @@ -353,11 +358,11 @@ def test_execute_if_else_boolean_false_conditions(): """Test IfElseNode with boolean conditions that should evaluate to false""" graph_config = {"edges": [], "nodes": [{"data": {"type": "start", "title": "Start"}, "id": "start"}]} - init_params = GraphInitParams( - tenant_id="1", - app_id="1", + init_params = build_test_graph_init_params( workflow_id="1", graph_config=graph_config, + tenant_id="1", + app_id="1", user_id="1", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, @@ -422,11 +427,11 @@ def test_execute_if_else_boolean_cases_structure(): """Test IfElseNode with boolean conditions using the new cases structure""" graph_config = {"edges": [], "nodes": [{"data": {"type": "start", "title": "Start"}, "id": "start"}]} - init_params = GraphInitParams( - tenant_id="1", - app_id="1", + init_params = build_test_graph_init_params( workflow_id="1", graph_config=graph_config, + tenant_id="1", + app_id="1", user_id="1", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, diff --git a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py index 4c43f63c74..6ca72b64b2 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py @@ -2,8 +2,9 @@ from unittest.mock import MagicMock import pytest -from core.app.entities.app_invoke_entities import InvokeFrom -from dify_graph.enums import UserFrom, WorkflowNodeExecutionStatus +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY +from dify_graph.enums import WorkflowNodeExecutionStatus from dify_graph.file import File, FileTransferMethod, FileType from dify_graph.nodes.list_operator.entities import ( ExtractConfig, @@ -41,14 +42,18 @@ def list_operator_node(): } # Create properly configured mock for graph_init_params graph_init_params = MagicMock() - graph_init_params.tenant_id = "test_tenant" - graph_init_params.app_id = "test_app" graph_init_params.workflow_id = "test_workflow" graph_init_params.graph_config = {} - graph_init_params.user_id = "test_user" - graph_init_params.user_from = UserFrom.ACCOUNT - graph_init_params.invoke_from = InvokeFrom.SERVICE_API graph_init_params.call_depth = 0 + graph_init_params.run_context = { + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test_tenant", + "app_id": "test_app", + "user_id": "test_user", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.SERVICE_API, + } + } node = ListOperatorNode( id="test_node_id", diff --git a/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py b/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py index 5fd1e33768..b8f0e25e91 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py @@ -4,12 +4,12 @@ import time import pytest from pydantic import ValidationError as PydanticValidationError -from dify_graph.entities import GraphInitParams from dify_graph.nodes.start.entities import StartNodeData from dify_graph.nodes.start.start_node import StartNode from dify_graph.runtime import GraphRuntimeState, VariablePool from dify_graph.system_variable import SystemVariable from dify_graph.variables.input_entities import VariableEntity, VariableEntityType +from tests.workflow_test_utils import build_test_graph_init_params def make_start_node(user_inputs, variables): @@ -32,11 +32,11 @@ def make_start_node(user_inputs, variables): return StartNode( id="start", config=config, - graph_init_params=GraphInitParams( - tenant_id="tenant", - app_id="app", + graph_init_params=build_test_graph_init_params( workflow_id="wf", graph_config={}, + tenant_id="tenant", + app_id="app", user_id="u", user_from="account", invoke_from="debugger", diff --git a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py index 3d88baa272..11554169e1 100644 --- a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py @@ -10,13 +10,13 @@ import pytest from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.utils.message_transformer import ToolFileMessageTransformer -from dify_graph.entities import GraphInitParams from dify_graph.file import File, FileTransferMethod, FileType from dify_graph.model_runtime.entities.llm_entities import LLMUsage from dify_graph.node_events import StreamChunkEvent, StreamCompletedEvent from dify_graph.runtime import GraphRuntimeState, VariablePool from dify_graph.system_variable import SystemVariable from dify_graph.variables.segments import ArrayFileSegment +from tests.workflow_test_utils import build_test_graph_init_params if TYPE_CHECKING: # pragma: no cover - imported for type checking only from dify_graph.nodes.tool.tool_node import ToolNode @@ -54,11 +54,11 @@ def tool_node(monkeypatch) -> ToolNode: "edges": [], } - init_params = GraphInitParams( - tenant_id="tenant-id", - app_id="app-id", + init_params = build_test_graph_init_params( workflow_id="workflow-id", graph_config=graph_config, + tenant_id="tenant-id", + app_id="app-id", user_id="user-id", user_from="account", invoke_from="debugger", diff --git a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py index d5a593ffab..2cd3a38fa6 100644 --- a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py +++ b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py @@ -2,10 +2,10 @@ import time import uuid from uuid import uuid4 -from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.workflow.node_factory import DifyNodeFactory from dify_graph.entities import GraphInitParams -from dify_graph.enums import UserFrom +from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY from dify_graph.graph import Graph from dify_graph.graph_events.node import NodeRunSucceededEvent from dify_graph.nodes.variable_assigner.common import helpers as common_helpers @@ -43,13 +43,17 @@ def test_overwrite_string_variable(): } init_params = GraphInitParams( - tenant_id="1", - app_id="1", workflow_id="1", graph_config=graph_config, - user_id="1", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "1", + "app_id": "1", + "user_id": "1", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.DEBUGGER, + } + }, call_depth=0, ) @@ -141,13 +145,17 @@ def test_append_variable_to_array(): } init_params = GraphInitParams( - tenant_id="1", - app_id="1", workflow_id="1", graph_config=graph_config, - user_id="1", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "1", + "app_id": "1", + "user_id": "1", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.DEBUGGER, + } + }, call_depth=0, ) @@ -236,13 +244,17 @@ def test_clear_array(): } init_params = GraphInitParams( - tenant_id="1", - app_id="1", workflow_id="1", graph_config=graph_config, - user_id="1", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "1", + "app_id": "1", + "user_id": "1", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.DEBUGGER, + } + }, call_depth=0, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py index ef816b9ddc..5b285c2681 100644 --- a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py +++ b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py @@ -2,10 +2,10 @@ import time import uuid from uuid import uuid4 -from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.workflow.node_factory import DifyNodeFactory from dify_graph.entities import GraphInitParams -from dify_graph.enums import UserFrom +from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY from dify_graph.graph import Graph from dify_graph.nodes.variable_assigner.v2 import VariableAssignerNode from dify_graph.nodes.variable_assigner.v2.enums import InputType, Operation @@ -85,13 +85,17 @@ def test_remove_first_from_array(): } init_params = GraphInitParams( - tenant_id="1", - app_id="1", workflow_id="1", graph_config=graph_config, - user_id="1", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "1", + "app_id": "1", + "user_id": "1", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.DEBUGGER, + } + }, call_depth=0, ) @@ -169,13 +173,17 @@ def test_remove_last_from_array(): } init_params = GraphInitParams( - tenant_id="1", - app_id="1", workflow_id="1", graph_config=graph_config, - user_id="1", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "1", + "app_id": "1", + "user_id": "1", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.DEBUGGER, + } + }, call_depth=0, ) @@ -250,13 +258,17 @@ def test_remove_first_from_empty_array(): } init_params = GraphInitParams( - tenant_id="1", - app_id="1", workflow_id="1", graph_config=graph_config, - user_id="1", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "1", + "app_id": "1", + "user_id": "1", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.DEBUGGER, + } + }, call_depth=0, ) @@ -331,13 +343,17 @@ def test_remove_last_from_empty_array(): } init_params = GraphInitParams( - tenant_id="1", - app_id="1", workflow_id="1", graph_config=graph_config, - user_id="1", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "1", + "app_id": "1", + "user_id": "1", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.DEBUGGER, + } + }, call_depth=0, ) @@ -404,13 +420,17 @@ def test_node_factory_creates_variable_assigner_node(): } init_params = GraphInitParams( - tenant_id="1", - app_id="1", workflow_id="1", graph_config=graph_config, - user_id="1", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "1", + "app_id": "1", + "user_id": "1", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.DEBUGGER, + } + }, call_depth=0, ) variable_pool = VariablePool( diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py index 5c89ba7d34..c750e74182 100644 --- a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py +++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py @@ -8,10 +8,9 @@ when passing files to downstream LLM nodes. from unittest.mock import Mock, patch -from core.app.entities.app_invoke_entities import InvokeFrom -from dify_graph.entities.graph_init_params import GraphInitParams +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY, GraphInitParams from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from dify_graph.enums import UserFrom from dify_graph.nodes.trigger_webhook.entities import ( ContentType, Method, @@ -22,7 +21,6 @@ from dify_graph.nodes.trigger_webhook.node import TriggerWebhookNode from dify_graph.runtime.graph_runtime_state import GraphRuntimeState from dify_graph.runtime.variable_pool import VariablePool from dify_graph.system_variable import SystemVariable -from models.workflow import WorkflowType def create_webhook_node( @@ -37,14 +35,17 @@ def create_webhook_node( } graph_init_params = GraphInitParams( - tenant_id=tenant_id, - app_id="test-app", - workflow_type=WorkflowType.WORKFLOW, workflow_id="test-workflow", graph_config={}, - user_id="test-user", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.SERVICE_API, + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": tenant_id, + "app_id": "test-app", + "user_id": "test-user", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.SERVICE_API, + } + }, call_depth=0, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py index 066ec5542d..df13bbb92f 100644 --- a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py @@ -2,10 +2,9 @@ from unittest.mock import patch import pytest -from core.app.entities.app_invoke_entities import InvokeFrom -from dify_graph.entities.graph_init_params import GraphInitParams +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY, GraphInitParams from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from dify_graph.enums import UserFrom from dify_graph.file import File, FileTransferMethod, FileType from dify_graph.nodes.trigger_webhook.entities import ( ContentType, @@ -19,7 +18,6 @@ from dify_graph.runtime.graph_runtime_state import GraphRuntimeState from dify_graph.runtime.variable_pool import VariablePool from dify_graph.system_variable import SystemVariable from dify_graph.variables import FileVariable, StringVariable -from models.workflow import WorkflowType def create_webhook_node(webhook_data: WebhookData, variable_pool: VariablePool) -> TriggerWebhookNode: @@ -30,14 +28,17 @@ def create_webhook_node(webhook_data: WebhookData, variable_pool: VariablePool) } graph_init_params = GraphInitParams( - tenant_id="1", - app_id="1", - workflow_type=WorkflowType.WORKFLOW, workflow_id="1", graph_config={}, - user_id="1", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.SERVICE_API, + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "1", + "app_id": "1", + "user_id": "1", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.SERVICE_API, + } + }, call_depth=0, ) runtime_state = GraphRuntimeState( diff --git a/api/tests/unit_tests/core/workflow/test_workflow_entry_redis_channel.py b/api/tests/unit_tests/core/workflow/test_workflow_entry_redis_channel.py index 31644edcd8..9969c953e8 100644 --- a/api/tests/unit_tests/core/workflow/test_workflow_entry_redis_channel.py +++ b/api/tests/unit_tests/core/workflow/test_workflow_entry_redis_channel.py @@ -2,9 +2,8 @@ from unittest.mock import MagicMock, patch -from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.workflow.workflow_entry import WorkflowEntry -from dify_graph.enums import UserFrom from dify_graph.graph_engine.command_channels.redis_channel import RedisChannel from dify_graph.runtime import GraphRuntimeState, VariablePool diff --git a/api/tests/workflow_test_utils.py b/api/tests/workflow_test_utils.py new file mode 100644 index 0000000000..1f0bf8ef37 --- /dev/null +++ b/api/tests/workflow_test_utils.py @@ -0,0 +1,53 @@ +from collections.abc import Mapping +from typing import Any + +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context +from dify_graph.entities.graph_init_params import GraphInitParams + + +def build_test_run_context( + *, + tenant_id: str = "tenant", + app_id: str = "app", + user_id: str = "user", + user_from: UserFrom | str = UserFrom.ACCOUNT, + invoke_from: InvokeFrom | str = InvokeFrom.DEBUGGER, + extra_context: Mapping[str, Any] | None = None, +) -> dict[str, Any]: + normalized_user_from = user_from if isinstance(user_from, UserFrom) else UserFrom(user_from) + normalized_invoke_from = invoke_from if isinstance(invoke_from, InvokeFrom) else InvokeFrom(invoke_from) + return build_dify_run_context( + tenant_id=tenant_id, + app_id=app_id, + user_id=user_id, + user_from=normalized_user_from, + invoke_from=normalized_invoke_from, + extra_context=extra_context, + ) + + +def build_test_graph_init_params( + *, + workflow_id: str = "workflow", + graph_config: Mapping[str, Any] | None = None, + call_depth: int = 0, + tenant_id: str = "tenant", + app_id: str = "app", + user_id: str = "user", + user_from: UserFrom | str = UserFrom.ACCOUNT, + invoke_from: InvokeFrom | str = InvokeFrom.DEBUGGER, + extra_context: Mapping[str, Any] | None = None, +) -> GraphInitParams: + return GraphInitParams( + workflow_id=workflow_id, + graph_config=graph_config or {}, + run_context=build_test_run_context( + tenant_id=tenant_id, + app_id=app_id, + user_id=user_id, + user_from=user_from, + invoke_from=invoke_from, + extra_context=extra_context, + ), + call_depth=call_depth, + ) From 1819b87a562e98632a6a3263dd4d170ba99ef9ef Mon Sep 17 00:00:00 2001 From: Coding On Star <447357187@qq.com> Date: Thu, 5 Mar 2026 14:34:07 +0800 Subject: [PATCH 2/5] test(workflow): add validation tests for workflow and node component rendering part 3 (#33012) Co-authored-by: CodingOnStar --- .../components/__tests__/index.spec.tsx | 8 +- .../__tests__/workflow-test-env.spec.tsx | 136 ++++++++++++ .../workflow/__tests__/workflow-test-env.tsx | 199 ++++++++++++++---- 3 files changed, 301 insertions(+), 42 deletions(-) create mode 100644 web/app/components/workflow/__tests__/workflow-test-env.spec.tsx diff --git a/web/app/components/rag-pipeline/components/__tests__/index.spec.tsx b/web/app/components/rag-pipeline/components/__tests__/index.spec.tsx index 99318b07b3..0d3b638bab 100644 --- a/web/app/components/rag-pipeline/components/__tests__/index.spec.tsx +++ b/web/app/components/rag-pipeline/components/__tests__/index.spec.tsx @@ -1,5 +1,5 @@ import type { EnvironmentVariable } from '@/app/components/workflow/types' -import { fireEvent, render, screen, waitFor } from '@testing-library/react' +import { act, fireEvent, render, screen, waitFor } from '@testing-library/react' import { createMockProviderContextValue } from '@/__mocks__/provider-context' import Conversion from '../conversion' @@ -9,6 +9,12 @@ import PublishToast from '../publish-toast' import RagPipelineChildren from '../rag-pipeline-children' import PipelineScreenShot from '../screenshot' +afterEach(async () => { + await act(async () => { + await new Promise(resolve => setTimeout(resolve, 0)) + }) +}) + const mockPush = vi.fn() vi.mock('next/navigation', () => ({ useParams: () => ({ datasetId: 'test-dataset-id' }), diff --git a/web/app/components/workflow/__tests__/workflow-test-env.spec.tsx b/web/app/components/workflow/__tests__/workflow-test-env.spec.tsx new file mode 100644 index 0000000000..d9a4efa12e --- /dev/null +++ b/web/app/components/workflow/__tests__/workflow-test-env.spec.tsx @@ -0,0 +1,136 @@ +/** + * Validation tests for renderWorkflowComponent and renderNodeComponent. + */ +import type { Shape } from '../store/workflow' +import { act, screen } from '@testing-library/react' +import * as React from 'react' +import { FlowType } from '@/types/common' +import { useHooksStore } from '../hooks-store/store' +import { useStore, useWorkflowStore } from '../store/workflow' +import { renderNodeComponent, renderWorkflowComponent } from './workflow-test-env' + +// --------------------------------------------------------------------------- +// Test components that read from workflow contexts +// --------------------------------------------------------------------------- + +function StoreReader() { + const showConfirm = useStore(s => s.showConfirm) + return React.createElement('div', { 'data-testid': 'store-reader' }, showConfirm ? 'has-confirm' : 'no-confirm') +} + +function StoreWriter() { + const store = useWorkflowStore() + return React.createElement( + 'button', + { + 'data-testid': 'store-writer', + 'onClick': () => store.setState({ showConfirm: { title: 'Test', onConfirm: () => {} } } as Partial), + }, + 'Write', + ) +} + +function HooksStoreReader() { + const flowId = useHooksStore(s => s.configsMap?.flowId ?? 'none') + return React.createElement('div', { 'data-testid': 'hooks-reader' }, flowId) +} + +function NodeRenderer(props: { id: string, data: { title: string }, selected?: boolean }) { + return React.createElement( + 'div', + { 'data-testid': 'node-render' }, + `${props.id}:${props.data.title}:${props.selected ? 'sel' : 'nosel'}`, + ) +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +describe('renderWorkflowComponent', () => { + it('should provide WorkflowContext with default store', () => { + renderWorkflowComponent(React.createElement(StoreReader)) + expect(screen.getByTestId('store-reader')).toHaveTextContent('no-confirm') + }) + + it('should apply initialStoreState', () => { + renderWorkflowComponent(React.createElement(StoreReader), { + initialStoreState: { showConfirm: { title: 'Hey', onConfirm: () => {} } }, + }) + expect(screen.getByTestId('store-reader')).toHaveTextContent('has-confirm') + }) + + it('should return a live store that components can mutate', () => { + const { store } = renderWorkflowComponent( + React.createElement(React.Fragment, null, React.createElement(StoreReader), React.createElement(StoreWriter)), + ) + + expect(store.getState().showConfirm).toBeUndefined() + + act(() => { + screen.getByTestId('store-writer').click() + }) + + expect(store.getState().showConfirm).toBeDefined() + expect(screen.getByTestId('store-reader')).toHaveTextContent('has-confirm') + }) + + it('should provide HooksStoreContext when hooksStoreProps given', () => { + renderWorkflowComponent(React.createElement(HooksStoreReader), { + hooksStoreProps: { configsMap: { flowId: 'test-123', flowType: FlowType.appFlow, fileSettings: {} } }, + }) + expect(screen.getByTestId('hooks-reader')).toHaveTextContent('test-123') + }) + + it('should throw when HooksStoreContext is not provided', () => { + const consoleSpy = vi.spyOn(console, 'error').mockImplementation(() => {}) + try { + expect(() => { + renderWorkflowComponent(React.createElement(HooksStoreReader)) + }).toThrow('Missing HooksStoreContext.Provider') + } + finally { + consoleSpy.mockRestore() + } + }) + + it('should forward extra render options (container)', () => { + const container = document.createElement('section') + document.body.appendChild(container) + + try { + renderWorkflowComponent(React.createElement(StoreReader), { container }) + expect(container.querySelector('[data-testid="store-reader"]')).toBeTruthy() + } + finally { + document.body.removeChild(container) + } + }) +}) + +describe('renderNodeComponent', () => { + it('should render node with default id and selected=false', () => { + renderNodeComponent(NodeRenderer, { title: 'Hello' }) + expect(screen.getByTestId('node-render')).toHaveTextContent('test-node-1:Hello:nosel') + }) + + it('should accept custom nodeId and selected', () => { + renderNodeComponent(NodeRenderer, { title: 'World' }, { + nodeId: 'custom-42', + selected: true, + }) + expect(screen.getByTestId('node-render')).toHaveTextContent('custom-42:World:sel') + }) + + it('should provide WorkflowContext to node components', () => { + function NodeWithStore(props: { id: string, data: Record }) { + const controlMode = useStore(s => s.controlMode) + return React.createElement('div', { 'data-testid': 'node-store' }, `${props.id}:${controlMode}`) + } + + renderNodeComponent(NodeWithStore, {}, { + initialStoreState: { controlMode: 'hand' as Shape['controlMode'] }, + }) + expect(screen.getByTestId('node-store')).toHaveTextContent('test-node-1:hand') + }) +}) diff --git a/web/app/components/workflow/__tests__/workflow-test-env.tsx b/web/app/components/workflow/__tests__/workflow-test-env.tsx index 6109d8a7f4..00d6829964 100644 --- a/web/app/components/workflow/__tests__/workflow-test-env.tsx +++ b/web/app/components/workflow/__tests__/workflow-test-env.tsx @@ -1,7 +1,7 @@ /** * Workflow test environment — composable providers + render helpers. * - * ## Quick start + * ## Quick start (hook) * * ```ts * import { resetReactFlowMockState, rfState } from '../../__tests__/reactflow-mock-state' @@ -29,13 +29,43 @@ * expect(rfState.setNodes).toHaveBeenCalled() * }) * ``` + * + * ## Quick start (component) + * + * ```ts + * import { renderWorkflowComponent } from '../../__tests__/workflow-test-env' + * + * it('renders correctly', () => { + * const { getByText, store } = renderWorkflowComponent( + * , + * { initialStoreState: { showConfirm: undefined } }, + * ) + * expect(getByText('value')).toBeInTheDocument() + * expect(store.getState().showConfirm).toBeUndefined() + * }) + * ``` + * + * ## Quick start (node component) + * + * ```ts + * import { renderNodeComponent } from '../../__tests__/workflow-test-env' + * + * it('renders node', () => { + * const { getByText, store } = renderNodeComponent( + * MyNodeComponent, + * { type: BlockEnum.Code, title: 'My Node', desc: '' }, + * { nodeId: 'n-1', initialStoreState: { ... } }, + * ) + * expect(getByText('My Node')).toBeInTheDocument() + * }) + * ``` */ -import type { RenderHookOptions, RenderHookResult } from '@testing-library/react' +import type { RenderHookOptions, RenderHookResult, RenderOptions, RenderResult } from '@testing-library/react' import type { Shape as HooksStoreShape } from '../hooks-store/store' import type { Shape } from '../store/workflow' import type { Edge, Node, WorkflowRunningData } from '../types' import type { WorkflowHistoryStoreApi } from '../workflow-history-store' -import { renderHook } from '@testing-library/react' +import { render, renderHook } from '@testing-library/react' import isDeepEqual from 'fast-deep-equal' import * as React from 'react' import { temporal } from 'zundo' @@ -83,11 +113,14 @@ export function createTestWorkflowStore(initialState?: Partial): Workflow } export function createTestHooksStore(props?: Partial): HooksStore { - return createHooksStore(props ?? {}) + const store = createHooksStore(props ?? {}) + if (props) + store.setState(props) + return store } // --------------------------------------------------------------------------- -// renderWorkflowHook — composable hook renderer +// Shared provider options & wrapper factory // --------------------------------------------------------------------------- type HistoryStoreConfig = { @@ -95,17 +128,68 @@ type HistoryStoreConfig = { edges?: Edge[] } -type WorkflowTestOptions

= Omit, 'wrapper'> & { +type WorkflowProviderOptions = { initialStoreState?: Partial hooksStoreProps?: Partial historyStore?: HistoryStoreConfig } -type WorkflowTestResult = RenderHookResult & { +type StoreInstances = { store: WorkflowStore hooksStore?: HooksStore } +function createStoresFromOptions(options: WorkflowProviderOptions): StoreInstances { + const store = createTestWorkflowStore(options.initialStoreState) + const hooksStore = options.hooksStoreProps !== undefined + ? createTestHooksStore(options.hooksStoreProps) + : undefined + return { store, hooksStore } +} + +function createWorkflowWrapper( + stores: StoreInstances, + historyConfig?: HistoryStoreConfig, +) { + const historyCtxValue = historyConfig + ? createTestHistoryStoreContext(historyConfig) + : undefined + + return ({ children }: { children: React.ReactNode }) => { + let inner: React.ReactNode = children + + if (historyCtxValue) { + inner = React.createElement( + WorkflowHistoryStoreContext.Provider, + { value: historyCtxValue }, + inner, + ) + } + + if (stores.hooksStore) { + inner = React.createElement( + HooksStoreContext.Provider, + { value: stores.hooksStore }, + inner, + ) + } + + return React.createElement( + WorkflowContext.Provider, + { value: stores.store }, + inner, + ) + } +} + +// --------------------------------------------------------------------------- +// renderWorkflowHook — composable hook renderer +// --------------------------------------------------------------------------- + +type WorkflowHookTestOptions

= Omit, 'wrapper'> & WorkflowProviderOptions + +type WorkflowHookTestResult = RenderHookResult & StoreInstances + /** * Renders a hook inside composable workflow providers. * @@ -116,44 +200,77 @@ type WorkflowTestResult = RenderHookResult & { */ export function renderWorkflowHook( hook: (props: P) => R, - options?: WorkflowTestOptions

, -): WorkflowTestResult { + options?: WorkflowHookTestOptions

, +): WorkflowHookTestResult { const { initialStoreState, hooksStoreProps, historyStore: historyConfig, ...rest } = options ?? {} - const store = createTestWorkflowStore(initialStoreState) - const hooksStore = hooksStoreProps !== undefined - ? createTestHooksStore(hooksStoreProps) - : undefined - - const wrapper = ({ children }: { children: React.ReactNode }) => { - let inner: React.ReactNode = children - - if (historyConfig) { - const historyCtxValue = createTestHistoryStoreContext(historyConfig) - inner = React.createElement( - WorkflowHistoryStoreContext.Provider, - { value: historyCtxValue }, - inner, - ) - } - - if (hooksStore) { - inner = React.createElement( - HooksStoreContext.Provider, - { value: hooksStore }, - inner, - ) - } - - return React.createElement( - WorkflowContext.Provider, - { value: store }, - inner, - ) - } + const stores = createStoresFromOptions({ initialStoreState, hooksStoreProps }) + const wrapper = createWorkflowWrapper(stores, historyConfig) const renderResult = renderHook(hook, { wrapper, ...rest }) - return { ...renderResult, store, hooksStore } + return { ...renderResult, ...stores } +} + +// --------------------------------------------------------------------------- +// renderWorkflowComponent — composable component renderer +// --------------------------------------------------------------------------- + +type WorkflowComponentTestOptions = Omit & WorkflowProviderOptions + +type WorkflowComponentTestResult = RenderResult & StoreInstances + +/** + * Renders a React element inside composable workflow providers. + * + * Provides the same context layers as `renderWorkflowHook`: + * - **Always**: `WorkflowContext` (real zustand store) + * - **hooksStoreProps**: `HooksStoreContext` (real zustand store) + * - **historyStore**: `WorkflowHistoryStoreContext` (real zundo temporal store) + */ +export function renderWorkflowComponent( + ui: React.ReactElement, + options?: WorkflowComponentTestOptions, +): WorkflowComponentTestResult { + const { initialStoreState, hooksStoreProps, historyStore: historyConfig, ...renderOptions } = options ?? {} + + const stores = createStoresFromOptions({ initialStoreState, hooksStoreProps }) + const wrapper = createWorkflowWrapper(stores, historyConfig) + + const renderResult = render(ui, { wrapper, ...renderOptions }) + return { ...renderResult, ...stores } +} + +// --------------------------------------------------------------------------- +// renderNodeComponent — convenience wrapper for node components +// --------------------------------------------------------------------------- + +type NodeComponentProps> = { + id: string + data: T + selected?: boolean +} + +type NodeTestOptions = WorkflowComponentTestOptions & { + nodeId?: string + selected?: boolean +} + +/** + * Renders a workflow node component inside composable workflow providers. + * + * Automatically provides `id`, `data`, and `selected` props that + * ReactFlow would normally inject into custom node components. + */ +export function renderNodeComponent>( + Component: React.ComponentType>, + data: T, + options?: NodeTestOptions, +): WorkflowComponentTestResult { + const { nodeId = 'test-node-1', selected = false, ...rest } = options ?? {} + return renderWorkflowComponent( + React.createElement(Component, { id: nodeId, data, selected }), + rest, + ) } // --------------------------------------------------------------------------- From f3c840a60ea139203b7e50b189a69f23f904b38b Mon Sep 17 00:00:00 2001 From: zxhlyh Date: Thu, 5 Mar 2026 15:08:37 +0800 Subject: [PATCH 3/5] fix: workflow canvas sync (#33030) --- .../__tests__/use-nodes-sync-draft.spec.ts | 111 ++++++++++++++++++ .../hooks/__tests__/use-workflow-init.spec.ts | 107 +++++++++++++++++ .../use-workflow-refresh-draft.spec.ts | 80 +++++++++++++ .../hooks/use-nodes-sync-draft.ts | 2 +- .../workflow-app/hooks/use-workflow-init.ts | 1 + .../hooks/use-workflow-refresh-draft.ts | 14 ++- 6 files changed, 308 insertions(+), 7 deletions(-) create mode 100644 web/app/components/workflow-app/hooks/__tests__/use-nodes-sync-draft.spec.ts create mode 100644 web/app/components/workflow-app/hooks/__tests__/use-workflow-init.spec.ts create mode 100644 web/app/components/workflow-app/hooks/__tests__/use-workflow-refresh-draft.spec.ts diff --git a/web/app/components/workflow-app/hooks/__tests__/use-nodes-sync-draft.spec.ts b/web/app/components/workflow-app/hooks/__tests__/use-nodes-sync-draft.spec.ts new file mode 100644 index 0000000000..d35e6e3612 --- /dev/null +++ b/web/app/components/workflow-app/hooks/__tests__/use-nodes-sync-draft.spec.ts @@ -0,0 +1,111 @@ +import { act, renderHook } from '@testing-library/react' +import { beforeEach, describe, expect, it, vi } from 'vitest' + +import { useNodesSyncDraft } from '../use-nodes-sync-draft' + +const mockGetNodes = vi.fn() +vi.mock('reactflow', () => ({ + useStoreApi: () => ({ getState: () => ({ getNodes: mockGetNodes, edges: [], transform: [0, 0, 1] }) }), +})) + +vi.mock('@/app/components/workflow/store', () => ({ + useWorkflowStore: () => ({ + getState: () => ({ + appId: 'app-1', + isWorkflowDataLoaded: true, + syncWorkflowDraftHash: 'hash-123', + environmentVariables: [], + conversationVariables: [], + setSyncWorkflowDraftHash: vi.fn(), + setDraftUpdatedAt: vi.fn(), + }), + }), +})) + +vi.mock('@/app/components/base/features/hooks', () => ({ + useFeaturesStore: () => ({ + getState: () => ({ + features: { + opening: { enabled: false, opening_statement: '', suggested_questions: [] }, + suggested: {}, + text2speech: {}, + speech2text: {}, + citation: {}, + moderation: {}, + file: {}, + }, + }), + }), +})) + +vi.mock('@/app/components/workflow/hooks/use-workflow', () => ({ + useNodesReadOnly: () => ({ getNodesReadOnly: () => false }), +})) + +vi.mock('@/app/components/workflow/hooks/use-serial-async-callback', () => ({ + useSerialAsyncCallback: (fn: (...args: unknown[]) => Promise, checkFn: () => boolean) => + (...args: unknown[]) => { + if (!checkFn()) + return fn(...args) + }, +})) + +const mockSyncWorkflowDraft = vi.fn() +vi.mock('@/service/workflow', () => ({ + syncWorkflowDraft: (p: unknown) => mockSyncWorkflowDraft(p), +})) + +vi.mock('@/service/fetch', () => ({ postWithKeepalive: vi.fn() })) +vi.mock('@/config', () => ({ API_PREFIX: '/api' })) + +const mockHandleRefreshWorkflowDraft = vi.fn() +vi.mock('@/app/components/workflow-app/hooks', () => ({ + useWorkflowRefreshDraft: () => ({ handleRefreshWorkflowDraft: mockHandleRefreshWorkflowDraft }), +})) + +describe('useNodesSyncDraft — handleRefreshWorkflowDraft(true) on 409', () => { + beforeEach(() => { + vi.clearAllMocks() + mockGetNodes.mockReturnValue([{ id: 'n1', position: { x: 0, y: 0 }, data: { type: 'start' } }]) + mockSyncWorkflowDraft.mockResolvedValue({ hash: 'new', updated_at: 1 }) + }) + + it('should call handleRefreshWorkflowDraft(true) — not updating canvas — on draft_workflow_not_sync', async () => { + const error = { json: vi.fn().mockResolvedValue({ code: 'draft_workflow_not_sync' }), bodyUsed: false } + mockSyncWorkflowDraft.mockRejectedValue(error) + + const { result } = renderHook(() => useNodesSyncDraft()) + await act(async () => { + await result.current.doSyncWorkflowDraft(false) + }) + await new Promise(r => setTimeout(r, 0)) + + expect(mockHandleRefreshWorkflowDraft).toHaveBeenCalledWith(true) + }) + + it('should NOT refresh when notRefreshWhenSyncError=true', async () => { + const error = { json: vi.fn().mockResolvedValue({ code: 'draft_workflow_not_sync' }), bodyUsed: false } + mockSyncWorkflowDraft.mockRejectedValue(error) + + const { result } = renderHook(() => useNodesSyncDraft()) + await act(async () => { + await result.current.doSyncWorkflowDraft(true) + }) + await new Promise(r => setTimeout(r, 0)) + + expect(mockHandleRefreshWorkflowDraft).not.toHaveBeenCalled() + }) + + it('should NOT refresh for a different error code', async () => { + const error = { json: vi.fn().mockResolvedValue({ code: 'other_error' }), bodyUsed: false } + mockSyncWorkflowDraft.mockRejectedValue(error) + + const { result } = renderHook(() => useNodesSyncDraft()) + await act(async () => { + await result.current.doSyncWorkflowDraft(false) + }) + await new Promise(r => setTimeout(r, 0)) + + expect(mockHandleRefreshWorkflowDraft).not.toHaveBeenCalled() + }) +}) diff --git a/web/app/components/workflow-app/hooks/__tests__/use-workflow-init.spec.ts b/web/app/components/workflow-app/hooks/__tests__/use-workflow-init.spec.ts new file mode 100644 index 0000000000..42e4b593ed --- /dev/null +++ b/web/app/components/workflow-app/hooks/__tests__/use-workflow-init.spec.ts @@ -0,0 +1,107 @@ +import { renderHook, waitFor } from '@testing-library/react' +import { beforeEach, describe, expect, it, vi } from 'vitest' + +import { useWorkflowInit } from '../use-workflow-init' + +const mockSetSyncWorkflowDraftHash = vi.fn() +const mockSetDraftUpdatedAt = vi.fn() +const mockSetToolPublished = vi.fn() +const mockSetPublishedAt = vi.fn() +const mockSetLastPublishedHasUserInput = vi.fn() +const mockSetFileUploadConfig = vi.fn() +const mockWorkflowStoreSetState = vi.fn() +const mockWorkflowStoreGetState = vi.fn() + +vi.mock('@/app/components/workflow/store', () => ({ + useStore: (selector: (state: { setSyncWorkflowDraftHash: ReturnType }) => T): T => + selector({ setSyncWorkflowDraftHash: mockSetSyncWorkflowDraftHash }), + useWorkflowStore: () => ({ + setState: mockWorkflowStoreSetState, + getState: mockWorkflowStoreGetState, + }), +})) + +vi.mock('@/app/components/app/store', () => ({ + useStore: (selector: (state: { appDetail: { id: string, name: string, mode: string } }) => T): T => + selector({ appDetail: { id: 'app-1', name: 'Test', mode: 'workflow' } }), +})) + +vi.mock('../use-workflow-template', () => ({ + useWorkflowTemplate: () => ({ nodes: [], edges: [] }), +})) + +vi.mock('@/service/use-workflow', () => ({ + useWorkflowConfig: () => ({ data: null, isLoading: false }), +})) + +const mockFetchWorkflowDraft = vi.fn() +const mockSyncWorkflowDraft = vi.fn() + +vi.mock('@/service/workflow', () => ({ + fetchWorkflowDraft: (...args: unknown[]) => mockFetchWorkflowDraft(...args), + syncWorkflowDraft: (...args: unknown[]) => mockSyncWorkflowDraft(...args), + fetchNodesDefaultConfigs: () => Promise.resolve([]), + fetchPublishedWorkflow: () => Promise.resolve({ created_at: 0, graph: { nodes: [], edges: [] } }), +})) + +const notExistError = () => ({ + json: vi.fn().mockResolvedValue({ code: 'draft_workflow_not_exist' }), + bodyUsed: false, +}) + +const draftResponse = { + id: 'draft-id', + graph: { nodes: [], edges: [] }, + hash: 'server-hash', + created_at: 0, + created_by: { id: '', name: '', email: '' }, + updated_at: 1, + updated_by: { id: '', name: '', email: '' }, + tool_published: false, + environment_variables: [], + conversation_variables: [], + version: '1', + marked_name: '', + marked_comment: '', +} + +describe('useWorkflowInit — hash fix (draft_workflow_not_exist)', () => { + beforeEach(() => { + vi.clearAllMocks() + mockWorkflowStoreGetState.mockReturnValue({ + setDraftUpdatedAt: mockSetDraftUpdatedAt, + setToolPublished: mockSetToolPublished, + setPublishedAt: mockSetPublishedAt, + setLastPublishedHasUserInput: mockSetLastPublishedHasUserInput, + setFileUploadConfig: mockSetFileUploadConfig, + }) + mockFetchWorkflowDraft + .mockRejectedValueOnce(notExistError()) + .mockResolvedValueOnce(draftResponse) + mockSyncWorkflowDraft.mockResolvedValue({ hash: 'new-hash', updated_at: 1 }) + }) + + it('should call setSyncWorkflowDraftHash with hash returned by syncWorkflowDraft', async () => { + renderHook(() => useWorkflowInit()) + await waitFor(() => expect(mockSetSyncWorkflowDraftHash).toHaveBeenCalledWith('new-hash')) + }) + + it('should store hash BEFORE making the recursive fetchWorkflowDraft call', async () => { + const order: string[] = [] + mockSetSyncWorkflowDraftHash.mockImplementation((h: string) => order.push(`hash:${h}`)) + mockFetchWorkflowDraft + .mockReset() + .mockRejectedValueOnce(notExistError()) + .mockImplementationOnce(async () => { + order.push('fetch:2') + return draftResponse + }) + mockSyncWorkflowDraft.mockResolvedValue({ hash: 'new-hash', updated_at: 1 }) + + renderHook(() => useWorkflowInit()) + + await waitFor(() => expect(order).toContain('fetch:2')) + expect(order).toContain('hash:new-hash') + expect(order.indexOf('hash:new-hash')).toBeLessThan(order.indexOf('fetch:2')) + }) +}) diff --git a/web/app/components/workflow-app/hooks/__tests__/use-workflow-refresh-draft.spec.ts b/web/app/components/workflow-app/hooks/__tests__/use-workflow-refresh-draft.spec.ts new file mode 100644 index 0000000000..2fd06e587b --- /dev/null +++ b/web/app/components/workflow-app/hooks/__tests__/use-workflow-refresh-draft.spec.ts @@ -0,0 +1,80 @@ +import { act, renderHook } from '@testing-library/react' +import { beforeEach, describe, expect, it, vi } from 'vitest' + +import { useWorkflowRefreshDraft } from '../use-workflow-refresh-draft' + +const mockHandleUpdateWorkflowCanvas = vi.fn() +const mockSetSyncWorkflowDraftHash = vi.fn() + +vi.mock('@/app/components/workflow/store', () => ({ + useWorkflowStore: () => ({ + getState: () => ({ + appId: 'app-1', + isWorkflowDataLoaded: true, + debouncedSyncWorkflowDraft: undefined, + setSyncWorkflowDraftHash: mockSetSyncWorkflowDraftHash, + setIsSyncingWorkflowDraft: vi.fn(), + setEnvironmentVariables: vi.fn(), + setEnvSecrets: vi.fn(), + setConversationVariables: vi.fn(), + setIsWorkflowDataLoaded: vi.fn(), + }), + }), +})) + +vi.mock('@/app/components/workflow/hooks', () => ({ + useWorkflowUpdate: () => ({ handleUpdateWorkflowCanvas: mockHandleUpdateWorkflowCanvas }), +})) + +const mockFetchWorkflowDraft = vi.fn() +vi.mock('@/service/workflow', () => ({ + fetchWorkflowDraft: (...args: unknown[]) => mockFetchWorkflowDraft(...args), +})) + +const draftResponse = { + hash: 'server-hash', + graph: { nodes: [{ id: 'n1' }], edges: [], viewport: { x: 1, y: 2, zoom: 1 } }, + environment_variables: [], + conversation_variables: [], +} + +describe('useWorkflowRefreshDraft — notUpdateCanvas parameter', () => { + beforeEach(() => { + vi.clearAllMocks() + mockFetchWorkflowDraft.mockResolvedValue(draftResponse) + }) + + it('should update canvas by default (notUpdateCanvas omitted)', async () => { + const { result } = renderHook(() => useWorkflowRefreshDraft()) + await act(async () => { + result.current.handleRefreshWorkflowDraft() + }) + expect(mockHandleUpdateWorkflowCanvas).toHaveBeenCalledTimes(1) + }) + + it('should update canvas when notUpdateCanvas=false', async () => { + const { result } = renderHook(() => useWorkflowRefreshDraft()) + await act(async () => { + result.current.handleRefreshWorkflowDraft(false) + }) + expect(mockHandleUpdateWorkflowCanvas).toHaveBeenCalledTimes(1) + }) + + it('should NOT update canvas when notUpdateCanvas=true', async () => { + // This is the key change: when called from a 409 error during editing, + // canvas must not be overwritten with server state. + const { result } = renderHook(() => useWorkflowRefreshDraft()) + await act(async () => { + result.current.handleRefreshWorkflowDraft(true) + }) + expect(mockHandleUpdateWorkflowCanvas).not.toHaveBeenCalled() + }) + + it('should still update hash even when notUpdateCanvas=true', async () => { + const { result } = renderHook(() => useWorkflowRefreshDraft()) + await act(async () => { + result.current.handleRefreshWorkflowDraft(true) + }) + expect(mockSetSyncWorkflowDraftHash).toHaveBeenCalledWith('server-hash') + }) +}) diff --git a/web/app/components/workflow-app/hooks/use-nodes-sync-draft.ts b/web/app/components/workflow-app/hooks/use-nodes-sync-draft.ts index 5dc0741324..4f9e529d92 100644 --- a/web/app/components/workflow-app/hooks/use-nodes-sync-draft.ts +++ b/web/app/components/workflow-app/hooks/use-nodes-sync-draft.ts @@ -132,7 +132,7 @@ export const useNodesSyncDraft = () => { if (error && error.json && !error.bodyUsed) { error.json().then((err: any) => { if (err.code === 'draft_workflow_not_sync' && !notRefreshWhenSyncError) - handleRefreshWorkflowDraft() + handleRefreshWorkflowDraft(true) }) } callback?.onError?.() diff --git a/web/app/components/workflow-app/hooks/use-workflow-init.ts b/web/app/components/workflow-app/hooks/use-workflow-init.ts index 8e976937b5..00bff2919f 100644 --- a/web/app/components/workflow-app/hooks/use-workflow-init.ts +++ b/web/app/components/workflow-app/hooks/use-workflow-init.ts @@ -100,6 +100,7 @@ export const useWorkflowInit = () => { }, }).then((res) => { workflowStore.getState().setDraftUpdatedAt(res.updated_at) + setSyncWorkflowDraftHash(res.hash) handleGetInitialWorkflowData() }) } diff --git a/web/app/components/workflow-app/hooks/use-workflow-refresh-draft.ts b/web/app/components/workflow-app/hooks/use-workflow-refresh-draft.ts index fa4a44d894..a7283c0078 100644 --- a/web/app/components/workflow-app/hooks/use-workflow-refresh-draft.ts +++ b/web/app/components/workflow-app/hooks/use-workflow-refresh-draft.ts @@ -8,7 +8,7 @@ export const useWorkflowRefreshDraft = () => { const workflowStore = useWorkflowStore() const { handleUpdateWorkflowCanvas } = useWorkflowUpdate() - const handleRefreshWorkflowDraft = useCallback(() => { + const handleRefreshWorkflowDraft = useCallback((notUpdateCanvas?: boolean) => { const { appId, setSyncWorkflowDraftHash, @@ -31,12 +31,14 @@ export const useWorkflowRefreshDraft = () => { fetchWorkflowDraft(`/apps/${appId}/workflows/draft`) .then((response) => { // Ensure we have a valid workflow structure with viewport - const workflowData: WorkflowDataUpdater = { - nodes: response.graph?.nodes || [], - edges: response.graph?.edges || [], - viewport: response.graph?.viewport || { x: 0, y: 0, zoom: 1 }, + if (!notUpdateCanvas) { + const workflowData: WorkflowDataUpdater = { + nodes: response.graph?.nodes || [], + edges: response.graph?.edges || [], + viewport: response.graph?.viewport || { x: 0, y: 0, zoom: 1 }, + } + handleUpdateWorkflowCanvas(workflowData) } - handleUpdateWorkflowCanvas(workflowData) setSyncWorkflowDraftHash(response.hash) setEnvSecrets((response.environment_variables || []).filter(env => env.value_type === 'secret').reduce((acc, env) => { acc[env.id] = env.value From f487b680f57e59da1a9aa5dafb285568ca62302f Mon Sep 17 00:00:00 2001 From: Stephen Zhou Date: Thu, 5 Mar 2026 15:54:56 +0800 Subject: [PATCH 4/5] refactor: spilt context for better hmr (#33033) --- .../dsl-export-import-flow.test.ts | 2 +- .../[appId]/overview/card-view.tsx | 2 +- web/app/(commonLayout)/layout.tsx | 6 +- .../account-page/AvatarWithEdit.tsx | 2 +- .../account-page/email-change-modal.tsx | 2 +- .../(commonLayout)/account-page/index.tsx | 2 +- web/app/account/(commonLayout)/layout.tsx | 6 +- .../__tests__/use-app-info-actions.spec.ts | 2 +- .../app-info/use-app-info-actions.ts | 2 +- .../csv-uploader.spec.tsx | 2 +- .../csv-uploader.tsx | 2 +- .../config-prompt/advanced-prompt-input.tsx | 2 +- .../config-prompt/simple-prompt-input.tsx | 2 +- .../config/agent/prompt-editor.tsx | 2 +- .../settings-modal/index.spec.tsx | 2 +- .../dataset-config/settings-modal/index.tsx | 2 +- .../context-provider.tsx | 28 ++++++ .../context.spec.tsx | 6 +- .../{context.tsx => context.ts} | 26 +---- .../debug/debug-with-multiple-model/index.tsx | 6 +- .../debug-with-single-model/index.spec.tsx | 2 +- .../app/configuration/debug/index.tsx | 2 +- .../components/app/configuration/index.tsx | 5 +- .../tools/external-data-tool-modal.tsx | 2 +- .../app/configuration/tools/index.tsx | 2 +- .../app/create-app-modal/index.spec.tsx | 2 +- .../components/app/create-app-modal/index.tsx | 2 +- .../app/create-from-dsl-modal/index.tsx | 2 +- .../app/create-from-dsl-modal/uploader.tsx | 2 +- web/app/components/app/log/list.tsx | 2 +- .../app/overview/settings/index.spec.tsx | 16 ++-- .../app/overview/settings/index.tsx | 2 +- .../app/switch-app-modal/index.spec.tsx | 2 +- .../components/app/switch-app-modal/index.tsx | 2 +- web/app/components/apps/app-card.tsx | 3 +- .../agent-log-modal/__tests__/detail.spec.tsx | 2 +- .../agent-log-modal/__tests__/index.spec.tsx | 2 +- .../base/agent-log-modal/detail.tsx | 2 +- .../{context.tsx => context.ts} | 0 .../base/chat/chat-with-history/hooks.tsx | 2 +- .../base/chat/chat/__tests__/context.spec.tsx | 3 +- .../chat/chat/__tests__/question.spec.tsx | 2 +- .../chat-input-area/__tests__/index.spec.tsx | 12 +-- .../base/chat/chat/chat-input-area/index.tsx | 2 +- .../base/chat/chat/check-input-forms-hooks.ts | 2 +- .../{context.tsx => context-provider.tsx} | 30 +----- web/app/components/base/chat/chat/context.ts | 30 ++++++ web/app/components/base/chat/chat/hooks.ts | 2 +- web/app/components/base/chat/chat/index.tsx | 2 +- .../{context.tsx => context.ts} | 0 .../base/chat/embedded-chatbot/hooks.tsx | 2 +- .../inputs-form/__tests__/content.spec.tsx | 2 +- .../moderation-setting-modal.spec.tsx | 2 +- .../moderation/moderation-setting-modal.tsx | 2 +- .../file-uploader/__tests__/hooks.spec.ts | 2 +- .../components/base/file-uploader/hooks.ts | 2 +- .../__tests__/use-check-validated.spec.ts | 2 +- .../base/form/hooks/use-check-validated.ts | 2 +- .../image-uploader/__tests__/hooks.spec.ts | 2 +- .../components/base/image-uploader/hooks.ts | 2 +- .../markdown-blocks/__tests__/button.spec.tsx | 2 +- .../__tests__/think-block.spec.tsx | 2 +- .../markdown-blocks/think-block.stories.tsx | 2 +- .../__tests__/index.spec.tsx | 3 +- .../radio/context/{index.tsx => index.ts} | 0 .../base/tag-input/__tests__/index.spec.tsx | 2 +- web/app/components/base/tag-input/index.tsx | 2 +- .../tag-management/__tests__/panel.spec.tsx | 2 +- .../__tests__/selector.spec.tsx | 2 +- .../components/base/tag-management/index.tsx | 2 +- .../components/base/tag-management/panel.tsx | 2 +- .../base/tag-management/tag-item-editor.tsx | 2 +- .../text-generation/__tests__/hooks.spec.ts | 2 +- .../components/base/text-generation/hooks.ts | 2 +- .../base/toast/__tests__/index.spec.tsx | 3 +- web/app/components/base/toast/context.ts | 23 +++++ .../components/base/toast/index.stories.tsx | 3 +- web/app/components/base/toast/index.tsx | 27 ++---- .../__tests__/index.spec.tsx | 4 +- .../custom/custom-web-app-brand/index.tsx | 2 +- .../__tests__/uploader.spec.tsx | 2 +- .../hooks/use-dsl-import.ts | 2 +- .../create-from-dsl-modal/uploader.tsx | 2 +- .../empty-dataset-creation-modal/index.tsx | 2 +- .../hooks/__tests__/use-file-upload.spec.tsx | 2 +- .../file-uploader/hooks/use-file-upload.ts | 2 +- .../components/__tests__/operations.spec.tsx | 2 +- .../documents/components/operations.tsx | 2 +- .../__tests__/use-local-file-upload.spec.tsx | 4 +- .../__tests__/csv-uploader.spec.tsx | 2 +- .../detail/batch-modal/csv-uploader.tsx | 2 +- .../detail/completed/__tests__/index.spec.tsx | 2 +- .../__tests__/regeneration-modal.spec.tsx | 3 +- .../__tests__/use-child-segment-data.spec.ts | 2 +- .../__tests__/use-segment-list-data.spec.ts | 2 +- .../completed/hooks/use-child-segment-data.ts | 2 +- .../completed/hooks/use-segment-list-data.ts | 2 +- .../detail/completed/new-child-segment.tsx | 2 +- .../documents/detail/embedding/index.tsx | 2 +- .../__tests__/use-metadata-state.spec.ts | 2 +- .../metadata/hooks/use-metadata-state.ts | 2 +- .../datasets/documents/detail/new-segment.tsx | 2 +- .../datasets/documents/status-item/index.tsx | 2 +- .../__tests__/index.spec.tsx | 2 +- .../external-api/external-api-modal/index.tsx | 2 +- .../connector/__tests__/index.spec.tsx | 2 +- .../connector/index.tsx | 2 +- .../workplace-selector/index.spec.tsx | 2 +- .../workplace-selector/index.tsx | 2 +- .../api-based-extension-page/modal.spec.tsx | 4 +- .../api-based-extension-page/modal.tsx | 2 +- .../account-setting/language-page/index.tsx | 2 +- .../edit-workspace-modal/index.spec.tsx | 2 +- .../edit-workspace-modal/index.tsx | 2 +- .../members-page/invite-modal/index.spec.tsx | 2 +- .../members-page/invite-modal/index.tsx | 2 +- .../members-page/operation/index.spec.tsx | 2 +- .../members-page/operation/index.tsx | 2 +- .../transfer-ownership-modal/index.spec.tsx | 2 +- .../transfer-ownership-modal/index.tsx | 2 +- .../model-auth/hooks/use-auth.spec.tsx | 2 +- .../model-auth/hooks/use-auth.ts | 2 +- .../credential-panel.spec.tsx | 2 +- .../provider-added-card/credential-panel.tsx | 2 +- .../model-load-balancing-modal.spec.tsx | 2 +- .../model-load-balancing-modal.tsx | 2 +- .../system-model-selector/index.spec.tsx | 2 +- .../system-model-selector/index.tsx | 2 +- .../plugin-page/SerpapiPlugin.spec.tsx | 4 +- .../plugin-page/SerpapiPlugin.tsx | 2 +- .../plugin-page/index.spec.tsx | 2 +- web/app/components/header/index.spec.tsx | 2 +- web/app/components/header/index.tsx | 2 +- .../__tests__/authorized-in-node.spec.tsx | 2 +- .../__tests__/plugin-auth-in-agent.spec.tsx | 2 +- .../__tests__/api-key-modal.spec.tsx | 2 +- .../__tests__/authorize-components.spec.tsx | 2 +- .../__tests__/oauth-client-settings.spec.tsx | 2 +- .../plugin-auth/authorize/api-key-modal.tsx | 2 +- .../authorize/oauth-client-settings.tsx | 2 +- .../authorized/__tests__/index.spec.tsx | 2 +- .../plugins/plugin-auth/authorized/index.tsx | 2 +- .../__tests__/use-plugin-auth-action.spec.ts | 2 +- .../hooks/use-plugin-auth-action.ts | 2 +- .../edit/__tests__/apikey-edit-modal.spec.tsx | 11 ++- .../edit/__tests__/manual-edit-modal.spec.tsx | 11 ++- .../edit/__tests__/oauth-edit-modal.spec.tsx | 11 ++- .../__tests__/tool-credentials-form.spec.tsx | 3 + .../plugin-page/__tests__/context.spec.tsx | 3 +- .../{context.tsx => context-provider.tsx} | 46 +-------- .../components/plugins/plugin-page/context.ts | 46 +++++++++ .../components/plugins/plugin-page/index.tsx | 6 +- .../__tests__/update-dsl-modal.spec.tsx | 2 +- .../__tests__/index.spec.tsx | 2 +- .../publisher/__tests__/index.spec.tsx | 2 +- .../publisher/__tests__/popup.spec.tsx | 2 +- .../rag-pipeline-header/publisher/popup.tsx | 2 +- .../hooks/__tests__/index.spec.ts | 2 +- .../hooks/__tests__/use-DSL.spec.ts | 2 +- .../__tests__/use-update-dsl-modal.spec.ts | 2 +- .../components/rag-pipeline/hooks/use-DSL.ts | 2 +- .../hooks/use-update-dsl-modal.ts | 2 +- .../__tests__/features-trigger.spec.tsx | 2 +- .../workflow-header/features-trigger.tsx | 2 +- .../components/workflow-app/hooks/use-DSL.ts | 2 +- .../components/workflow/header/run-mode.tsx | 2 +- .../hooks/__tests__/use-checklist.spec.ts | 2 +- .../workflow/hooks/use-checklist.ts | 2 +- .../plugins/link-editor-plugin/hooks.ts | 2 +- .../components/object-value-item.tsx | 2 +- .../components/variable-modal.tsx | 2 +- .../workflow/panel/debug-and-preview/hooks.ts | 2 +- .../panel/env-panel/variable-modal.tsx | 2 +- web/app/components/workflow/run/index.tsx | 2 +- .../components/workflow/update-dsl-modal.tsx | 2 +- .../education-apply/education-apply-page.tsx | 2 +- ...tasets-context.tsx => datasets-context.ts} | 0 web/context/event-emitter-provider.tsx | 22 +++++ .../{event-emitter.tsx => event-emitter.ts} | 18 +--- web/context/mitt-context-provider.tsx | 19 ++++ .../{mitt-context.tsx => mitt-context.ts} | 14 +-- ...context.tsx => modal-context-provider.tsx} | 93 ++---------------- web/context/modal-context.test.tsx | 2 +- web/context/modal-context.ts | 92 ++++++++++++++++++ ...text.tsx => provider-context-provider.tsx} | 96 +------------------ web/context/provider-context.ts | 90 +++++++++++++++++ web/context/workspace-context-provider.tsx | 24 +++++ web/context/workspace-context.ts | 16 ++++ web/context/workspace-context.tsx | 36 ------- web/eslint-suppressions.json | 60 ++---------- web/hooks/use-import-dsl.ts | 2 +- 191 files changed, 645 insertions(+), 615 deletions(-) create mode 100644 web/app/components/app/configuration/debug/debug-with-multiple-model/context-provider.tsx rename web/app/components/app/configuration/debug/debug-with-multiple-model/{context.tsx => context.ts} (50%) rename web/app/components/base/chat/chat-with-history/{context.tsx => context.ts} (100%) rename web/app/components/base/chat/chat/{context.tsx => context-provider.tsx} (58%) create mode 100644 web/app/components/base/chat/chat/context.ts rename web/app/components/base/chat/embedded-chatbot/{context.tsx => context.ts} (100%) rename web/app/components/base/radio/context/{index.tsx => index.ts} (100%) create mode 100644 web/app/components/base/toast/context.ts rename web/app/components/plugins/plugin-page/{context.tsx => context-provider.tsx} (58%) create mode 100644 web/app/components/plugins/plugin-page/context.ts rename web/context/{datasets-context.tsx => datasets-context.ts} (100%) create mode 100644 web/context/event-emitter-provider.tsx rename web/context/{event-emitter.tsx => event-emitter.ts} (53%) create mode 100644 web/context/mitt-context-provider.tsx rename web/context/{mitt-context.tsx => mitt-context.ts} (66%) rename web/context/{modal-context.tsx => modal-context-provider.tsx} (81%) create mode 100644 web/context/modal-context.ts rename web/context/{provider-context.tsx => provider-context-provider.tsx} (70%) create mode 100644 web/context/provider-context.ts create mode 100644 web/context/workspace-context-provider.tsx create mode 100644 web/context/workspace-context.ts delete mode 100644 web/context/workspace-context.tsx diff --git a/web/__tests__/rag-pipeline/dsl-export-import-flow.test.ts b/web/__tests__/rag-pipeline/dsl-export-import-flow.test.ts index 578552840d..dc5ab3fc86 100644 --- a/web/__tests__/rag-pipeline/dsl-export-import-flow.test.ts +++ b/web/__tests__/rag-pipeline/dsl-export-import-flow.test.ts @@ -19,7 +19,7 @@ vi.mock('react-i18next', () => ({ }), })) -vi.mock('@/app/components/base/toast', () => ({ +vi.mock('@/app/components/base/toast/context', () => ({ useToastContext: () => ({ notify: mockNotify }), })) diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/card-view.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/card-view.tsx index f07b2932c9..8c1df8d63d 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/card-view.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/card-view.tsx @@ -12,7 +12,7 @@ import AppCard from '@/app/components/app/overview/app-card' import TriggerCard from '@/app/components/app/overview/trigger-card' import { useStore as useAppStore } from '@/app/components/app/store' import Loading from '@/app/components/base/loading' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import MCPServiceCard from '@/app/components/tools/mcp/mcp-service-card' import { isTriggerNode } from '@/app/components/workflow/types' import { NEED_REFRESH_APP_LIST_KEY } from '@/config' diff --git a/web/app/(commonLayout)/layout.tsx b/web/app/(commonLayout)/layout.tsx index bd067fde6a..db2786f6cf 100644 --- a/web/app/(commonLayout)/layout.tsx +++ b/web/app/(commonLayout)/layout.tsx @@ -9,9 +9,9 @@ import Header from '@/app/components/header' import HeaderWrapper from '@/app/components/header/header-wrapper' import ReadmePanel from '@/app/components/plugins/readme-panel' import { AppContextProvider } from '@/context/app-context-provider' -import { EventEmitterContextProvider } from '@/context/event-emitter' -import { ModalContextProvider } from '@/context/modal-context' -import { ProviderContextProvider } from '@/context/provider-context' +import { EventEmitterContextProvider } from '@/context/event-emitter-provider' +import { ModalContextProvider } from '@/context/modal-context-provider' +import { ProviderContextProvider } from '@/context/provider-context-provider' import PartnerStack from '../components/billing/partner-stack' import Splash from '../components/splash' import RoleRouteGuard from './role-route-guard' diff --git a/web/app/account/(commonLayout)/account-page/AvatarWithEdit.tsx b/web/app/account/(commonLayout)/account-page/AvatarWithEdit.tsx index 49b59704b1..0a17822187 100644 --- a/web/app/account/(commonLayout)/account-page/AvatarWithEdit.tsx +++ b/web/app/account/(commonLayout)/account-page/AvatarWithEdit.tsx @@ -16,7 +16,7 @@ import Button from '@/app/components/base/button' import Divider from '@/app/components/base/divider' import { useLocalFileUploader } from '@/app/components/base/image-uploader/hooks' import Modal from '@/app/components/base/modal' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import { DISABLE_UPLOAD_IMAGE_AS_ICON } from '@/config' import { updateUserProfile } from '@/service/common' diff --git a/web/app/account/(commonLayout)/account-page/email-change-modal.tsx b/web/app/account/(commonLayout)/account-page/email-change-modal.tsx index 87ca6a689c..463c27294a 100644 --- a/web/app/account/(commonLayout)/account-page/email-change-modal.tsx +++ b/web/app/account/(commonLayout)/account-page/email-change-modal.tsx @@ -9,7 +9,7 @@ import { useContext } from 'use-context-selector' import Button from '@/app/components/base/button' import Input from '@/app/components/base/input' import Modal from '@/app/components/base/modal' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import { checkEmailExisted, resetEmail, diff --git a/web/app/account/(commonLayout)/account-page/index.tsx b/web/app/account/(commonLayout)/account-page/index.tsx index f01efc002c..908ef9c2e8 100644 --- a/web/app/account/(commonLayout)/account-page/index.tsx +++ b/web/app/account/(commonLayout)/account-page/index.tsx @@ -12,7 +12,7 @@ import Button from '@/app/components/base/button' import Input from '@/app/components/base/input' import Modal from '@/app/components/base/modal' import PremiumBadge from '@/app/components/base/premium-badge' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import Collapse from '@/app/components/header/account-setting/collapse' import { IS_CE_EDITION, validPassword } from '@/config' import { useAppContext } from '@/context/app-context' diff --git a/web/app/account/(commonLayout)/layout.tsx b/web/app/account/(commonLayout)/layout.tsx index 47fb47b02b..8fdbd8a238 100644 --- a/web/app/account/(commonLayout)/layout.tsx +++ b/web/app/account/(commonLayout)/layout.tsx @@ -5,9 +5,9 @@ import AmplitudeProvider from '@/app/components/base/amplitude' import GA, { GaType } from '@/app/components/base/ga' import HeaderWrapper from '@/app/components/header/header-wrapper' import { AppContextProvider } from '@/context/app-context-provider' -import { EventEmitterContextProvider } from '@/context/event-emitter' -import { ModalContextProvider } from '@/context/modal-context' -import { ProviderContextProvider } from '@/context/provider-context' +import { EventEmitterContextProvider } from '@/context/event-emitter-provider' +import { ModalContextProvider } from '@/context/modal-context-provider' +import { ProviderContextProvider } from '@/context/provider-context-provider' import Header from './header' const Layout = ({ children }: { children: ReactNode }) => { diff --git a/web/app/components/app-sidebar/app-info/__tests__/use-app-info-actions.spec.ts b/web/app/components/app-sidebar/app-info/__tests__/use-app-info-actions.spec.ts index e5966ed972..6104e2b641 100644 --- a/web/app/components/app-sidebar/app-info/__tests__/use-app-info-actions.spec.ts +++ b/web/app/components/app-sidebar/app-info/__tests__/use-app-info-actions.spec.ts @@ -42,7 +42,7 @@ vi.mock('@/app/components/app/store', () => ({ }), })) -vi.mock('@/app/components/base/toast', () => ({ +vi.mock('@/app/components/base/toast/context', () => ({ ToastContext: {}, })) diff --git a/web/app/components/app-sidebar/app-info/use-app-info-actions.ts b/web/app/components/app-sidebar/app-info/use-app-info-actions.ts index 13880acbed..800f21de44 100644 --- a/web/app/components/app-sidebar/app-info/use-app-info-actions.ts +++ b/web/app/components/app-sidebar/app-info/use-app-info-actions.ts @@ -6,7 +6,7 @@ import { useCallback, useState } from 'react' import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' import { useStore as useAppStore } from '@/app/components/app/store' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import { NEED_REFRESH_APP_LIST_KEY } from '@/config' import { useProviderContext } from '@/context/provider-context' import { copyApp, deleteApp, exportAppConfig, updateAppInfo } from '@/service/apps' diff --git a/web/app/components/app/annotation/batch-add-annotation-modal/csv-uploader.spec.tsx b/web/app/components/app/annotation/batch-add-annotation-modal/csv-uploader.spec.tsx index 6a67ba3207..55f5ee0564 100644 --- a/web/app/components/app/annotation/batch-add-annotation-modal/csv-uploader.spec.tsx +++ b/web/app/components/app/annotation/batch-add-annotation-modal/csv-uploader.spec.tsx @@ -1,7 +1,7 @@ import type { Props } from './csv-uploader' import { fireEvent, render, screen, waitFor } from '@testing-library/react' import * as React from 'react' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import CSVUploader from './csv-uploader' describe('CSVUploader', () => { diff --git a/web/app/components/app/annotation/batch-add-annotation-modal/csv-uploader.tsx b/web/app/components/app/annotation/batch-add-annotation-modal/csv-uploader.tsx index 6d5eb1ef95..118eaea58e 100644 --- a/web/app/components/app/annotation/batch-add-annotation-modal/csv-uploader.tsx +++ b/web/app/components/app/annotation/batch-add-annotation-modal/csv-uploader.tsx @@ -7,7 +7,7 @@ import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' import Button from '@/app/components/base/button' import { Csv as CSVIcon } from '@/app/components/base/icons/src/public/files' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import { cn } from '@/utils/classnames' export type Props = { diff --git a/web/app/components/app/configuration/config-prompt/advanced-prompt-input.tsx b/web/app/components/app/configuration/config-prompt/advanced-prompt-input.tsx index d0e9eb586c..9625204d81 100644 --- a/web/app/components/app/configuration/config-prompt/advanced-prompt-input.tsx +++ b/web/app/components/app/configuration/config-prompt/advanced-prompt-input.tsx @@ -20,7 +20,7 @@ import { } from '@/app/components/base/icons/src/vender/line/files' import PromptEditor from '@/app/components/base/prompt-editor' import { INSERT_VARIABLE_VALUE_BLOCK_COMMAND } from '@/app/components/base/prompt-editor/plugins/variable-block' -import { useToastContext } from '@/app/components/base/toast' +import { useToastContext } from '@/app/components/base/toast/context' import Tooltip from '@/app/components/base/tooltip' import ConfigContext from '@/context/debug-configuration' import { useEventEmitterContextContext } from '@/context/event-emitter' diff --git a/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx b/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx index bc94f87838..c33d55873d 100644 --- a/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx +++ b/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx @@ -17,7 +17,7 @@ import { useFeaturesStore } from '@/app/components/base/features/hooks' import PromptEditor from '@/app/components/base/prompt-editor' import { PROMPT_EDITOR_UPDATE_VALUE_BY_EVENT_EMITTER } from '@/app/components/base/prompt-editor/plugins/update-block' import { INSERT_VARIABLE_VALUE_BLOCK_COMMAND } from '@/app/components/base/prompt-editor/plugins/variable-block' -import { useToastContext } from '@/app/components/base/toast' +import { useToastContext } from '@/app/components/base/toast/context' import Tooltip from '@/app/components/base/tooltip' import ConfigContext from '@/context/debug-configuration' import { useEventEmitterContextContext } from '@/context/event-emitter' diff --git a/web/app/components/app/configuration/config/agent/prompt-editor.tsx b/web/app/components/app/configuration/config/agent/prompt-editor.tsx index e9e3b60859..9f1f04ba3c 100644 --- a/web/app/components/app/configuration/config/agent/prompt-editor.tsx +++ b/web/app/components/app/configuration/config/agent/prompt-editor.tsx @@ -12,7 +12,7 @@ import { CopyCheck, } from '@/app/components/base/icons/src/vender/line/files' import PromptEditor from '@/app/components/base/prompt-editor' -import { useToastContext } from '@/app/components/base/toast' +import { useToastContext } from '@/app/components/base/toast/context' import ConfigContext from '@/context/debug-configuration' import { useModalContext } from '@/context/modal-context' import { cn } from '@/utils/classnames' diff --git a/web/app/components/app/configuration/dataset-config/settings-modal/index.spec.tsx b/web/app/components/app/configuration/dataset-config/settings-modal/index.spec.tsx index c4fdfb7553..ea70725ea8 100644 --- a/web/app/components/app/configuration/dataset-config/settings-modal/index.spec.tsx +++ b/web/app/components/app/configuration/dataset-config/settings-modal/index.spec.tsx @@ -3,7 +3,7 @@ import type { DataSet } from '@/models/datasets' import type { RetrievalConfig } from '@/types/app' import { render, screen, waitFor } from '@testing-library/react' import userEvent from '@testing-library/user-event' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import { IndexingType } from '@/app/components/datasets/create/step-two' import { ACCOUNT_SETTING_TAB } from '@/app/components/header/account-setting/constants' import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' diff --git a/web/app/components/app/configuration/dataset-config/settings-modal/index.tsx b/web/app/components/app/configuration/dataset-config/settings-modal/index.tsx index e910702b66..c9c8d080f2 100644 --- a/web/app/components/app/configuration/dataset-config/settings-modal/index.tsx +++ b/web/app/components/app/configuration/dataset-config/settings-modal/index.tsx @@ -9,7 +9,7 @@ import { useTranslation } from 'react-i18next' import Button from '@/app/components/base/button' import Input from '@/app/components/base/input' import Textarea from '@/app/components/base/textarea' -import { useToastContext } from '@/app/components/base/toast' +import { useToastContext } from '@/app/components/base/toast/context' import { isReRankModelSelected } from '@/app/components/datasets/common/check-rerank-model' import { IndexingType } from '@/app/components/datasets/create/step-two' import IndexMethod from '@/app/components/datasets/settings/index-method' diff --git a/web/app/components/app/configuration/debug/debug-with-multiple-model/context-provider.tsx b/web/app/components/app/configuration/debug/debug-with-multiple-model/context-provider.tsx new file mode 100644 index 0000000000..74aed2d1e2 --- /dev/null +++ b/web/app/components/app/configuration/debug/debug-with-multiple-model/context-provider.tsx @@ -0,0 +1,28 @@ +'use client' + +import type { ReactNode } from 'react' +import type { DebugWithMultipleModelContextType } from './context' +import { DebugWithMultipleModelContext } from './context' + +type DebugWithMultipleModelContextProviderProps = { + children: ReactNode +} & DebugWithMultipleModelContextType +export const DebugWithMultipleModelContextProvider = ({ + children, + onMultipleModelConfigsChange, + multipleModelConfigs, + onDebugWithMultipleModelChange, + checkCanSend, +}: DebugWithMultipleModelContextProviderProps) => { + return ( + + {children} + + ) +} diff --git a/web/app/components/app/configuration/debug/debug-with-multiple-model/context.spec.tsx b/web/app/components/app/configuration/debug/debug-with-multiple-model/context.spec.tsx index e26fcec607..989285f812 100644 --- a/web/app/components/app/configuration/debug/debug-with-multiple-model/context.spec.tsx +++ b/web/app/components/app/configuration/debug/debug-with-multiple-model/context.spec.tsx @@ -1,10 +1,8 @@ import type { ModelAndParameter } from '../types' import type { DebugWithMultipleModelContextType } from './context' import { render, screen } from '@testing-library/react' -import { - DebugWithMultipleModelContextProvider, - useDebugWithMultipleModelContext, -} from './context' +import { useDebugWithMultipleModelContext } from './context' +import { DebugWithMultipleModelContextProvider } from './context-provider' const createModelAndParameter = (overrides: Partial = {}): ModelAndParameter => ({ id: 'model-1', diff --git a/web/app/components/app/configuration/debug/debug-with-multiple-model/context.tsx b/web/app/components/app/configuration/debug/debug-with-multiple-model/context.ts similarity index 50% rename from web/app/components/app/configuration/debug/debug-with-multiple-model/context.tsx rename to web/app/components/app/configuration/debug/debug-with-multiple-model/context.ts index 38f803f8ab..e3ad06f1b9 100644 --- a/web/app/components/app/configuration/debug/debug-with-multiple-model/context.tsx +++ b/web/app/components/app/configuration/debug/debug-with-multiple-model/context.ts @@ -10,7 +10,8 @@ export type DebugWithMultipleModelContextType = { onDebugWithMultipleModelChange: (singleModelConfig: ModelAndParameter) => void checkCanSend?: () => boolean } -const DebugWithMultipleModelContext = createContext({ + +export const DebugWithMultipleModelContext = createContext({ multipleModelConfigs: [], onMultipleModelConfigsChange: noop, onDebugWithMultipleModelChange: noop, @@ -18,27 +19,4 @@ const DebugWithMultipleModelContext = createContext useContext(DebugWithMultipleModelContext) -type DebugWithMultipleModelContextProviderProps = { - children: React.ReactNode -} & DebugWithMultipleModelContextType -export const DebugWithMultipleModelContextProvider = ({ - children, - onMultipleModelConfigsChange, - multipleModelConfigs, - onDebugWithMultipleModelChange, - checkCanSend, -}: DebugWithMultipleModelContextProviderProps) => { - return ( - - {children} - - ) -} - export default DebugWithMultipleModelContext diff --git a/web/app/components/app/configuration/debug/debug-with-multiple-model/index.tsx b/web/app/components/app/configuration/debug/debug-with-multiple-model/index.tsx index c73eb54329..f98e8c1f06 100644 --- a/web/app/components/app/configuration/debug/debug-with-multiple-model/index.tsx +++ b/web/app/components/app/configuration/debug/debug-with-multiple-model/index.tsx @@ -14,10 +14,8 @@ import { useDebugConfigurationContext } from '@/context/debug-configuration' import { useEventEmitterContextContext } from '@/context/event-emitter' import { AppModeEnum } from '@/types/app' import { APP_CHAT_WITH_MULTIPLE_MODEL } from '../types' -import { - DebugWithMultipleModelContextProvider, - useDebugWithMultipleModelContext, -} from './context' +import { useDebugWithMultipleModelContext } from './context' +import { DebugWithMultipleModelContextProvider } from './context-provider' import DebugItem from './debug-item' const DebugWithMultipleModel = () => { diff --git a/web/app/components/app/configuration/debug/debug-with-single-model/index.spec.tsx b/web/app/components/app/configuration/debug/debug-with-single-model/index.spec.tsx index 08bdd2bfcb..48141d0045 100644 --- a/web/app/components/app/configuration/debug/debug-with-single-model/index.spec.tsx +++ b/web/app/components/app/configuration/debug/debug-with-single-model/index.spec.tsx @@ -387,7 +387,7 @@ vi.mock('@/context/event-emitter', () => ({ })) // Mock toast context -vi.mock('@/app/components/base/toast', () => ({ +vi.mock('@/app/components/base/toast/context', () => ({ useToastContext: vi.fn(() => ({ notify: vi.fn(), })), diff --git a/web/app/components/app/configuration/debug/index.tsx b/web/app/components/app/configuration/debug/index.tsx index 14a00e85c7..1a6f9e264f 100644 --- a/web/app/components/app/configuration/debug/index.tsx +++ b/web/app/components/app/configuration/debug/index.tsx @@ -29,7 +29,7 @@ import Button from '@/app/components/base/button' import { useFeatures, useFeaturesStore } from '@/app/components/base/features/hooks' import { RefreshCcw01 } from '@/app/components/base/icons/src/vender/line/arrows' import PromptLogModal from '@/app/components/base/prompt-log-modal' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import TooltipPlus from '@/app/components/base/tooltip' import { ModelFeatureEnum, ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' import { useDefaultModel } from '@/app/components/header/account-setting/model-provider-page/hooks' diff --git a/web/app/components/app/configuration/index.tsx b/web/app/components/app/configuration/index.tsx index 16cf9454ca..a39eb3c63e 100644 --- a/web/app/components/app/configuration/index.tsx +++ b/web/app/components/app/configuration/index.tsx @@ -49,7 +49,8 @@ import { FeaturesProvider } from '@/app/components/base/features' import NewFeaturePanel from '@/app/components/base/features/new-feature-panel' import Loading from '@/app/components/base/loading' import { FILE_EXTS } from '@/app/components/base/prompt-editor/constants' -import Toast, { ToastContext } from '@/app/components/base/toast' +import Toast from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import { ACCOUNT_SETTING_TAB } from '@/app/components/header/account-setting/constants' import { ModelFeatureEnum, ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' import { @@ -66,7 +67,7 @@ import { SupportUploadFileTypes } from '@/app/components/workflow/types' import { ANNOTATION_DEFAULT, DATASET_DEFAULT, DEFAULT_AGENT_SETTING, DEFAULT_CHAT_PROMPT_CONFIG, DEFAULT_COMPLETION_PROMPT_CONFIG } from '@/config' import { useAppContext } from '@/context/app-context' import ConfigContext from '@/context/debug-configuration' -import { MittProvider } from '@/context/mitt-context' +import { MittProvider } from '@/context/mitt-context-provider' import { useModalContext } from '@/context/modal-context' import { useProviderContext } from '@/context/provider-context' import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints' diff --git a/web/app/components/app/configuration/tools/external-data-tool-modal.tsx b/web/app/components/app/configuration/tools/external-data-tool-modal.tsx index fece5598e1..35ba59a5fb 100644 --- a/web/app/components/app/configuration/tools/external-data-tool-modal.tsx +++ b/web/app/components/app/configuration/tools/external-data-tool-modal.tsx @@ -13,7 +13,7 @@ import FormGeneration from '@/app/components/base/features/new-feature-panel/mod import { BookOpen01 } from '@/app/components/base/icons/src/vender/line/education' import Modal from '@/app/components/base/modal' import { SimpleSelect } from '@/app/components/base/select' -import { useToastContext } from '@/app/components/base/toast' +import { useToastContext } from '@/app/components/base/toast/context' import ApiBasedExtensionSelector from '@/app/components/header/account-setting/api-based-extension-page/selector' import { useDocLink, useLocale } from '@/context/i18n' import { LanguagesSupported } from '@/i18n-config/language' diff --git a/web/app/components/app/configuration/tools/index.tsx b/web/app/components/app/configuration/tools/index.tsx index d2873b0be3..f348a7718d 100644 --- a/web/app/components/app/configuration/tools/index.tsx +++ b/web/app/components/app/configuration/tools/index.tsx @@ -15,7 +15,7 @@ import { } from '@/app/components/base/icons/src/vender/line/general' import { Tool03 } from '@/app/components/base/icons/src/vender/solid/general' import Switch from '@/app/components/base/switch' -import { useToastContext } from '@/app/components/base/toast' +import { useToastContext } from '@/app/components/base/toast/context' import Tooltip from '@/app/components/base/tooltip' import ConfigContext from '@/context/debug-configuration' import { useModalContext } from '@/context/modal-context' diff --git a/web/app/components/app/create-app-modal/index.spec.tsx b/web/app/components/app/create-app-modal/index.spec.tsx index 8c368df62c..a9adb17582 100644 --- a/web/app/components/app/create-app-modal/index.spec.tsx +++ b/web/app/components/app/create-app-modal/index.spec.tsx @@ -4,7 +4,7 @@ import { useRouter } from 'next/navigation' import { afterAll, beforeEach, describe, expect, it, vi } from 'vitest' import { trackEvent } from '@/app/components/base/amplitude' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import { NEED_REFRESH_APP_LIST_KEY } from '@/config' import { useAppContext } from '@/context/app-context' import { useProviderContext } from '@/context/provider-context' diff --git a/web/app/components/app/create-app-modal/index.tsx b/web/app/components/app/create-app-modal/index.tsx index 66c7bce80c..16ca4bdaff 100644 --- a/web/app/components/app/create-app-modal/index.tsx +++ b/web/app/components/app/create-app-modal/index.tsx @@ -17,7 +17,7 @@ import FullScreenModal from '@/app/components/base/fullscreen-modal' import { BubbleTextMod, ChatBot, ListSparkle, Logic } from '@/app/components/base/icons/src/vender/solid/communication' import Input from '@/app/components/base/input' import Textarea from '@/app/components/base/textarea' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import AppsFull from '@/app/components/billing/apps-full-in-dialog' import { NEED_REFRESH_APP_LIST_KEY } from '@/config' import { useAppContext } from '@/context/app-context' diff --git a/web/app/components/app/create-from-dsl-modal/index.tsx b/web/app/components/app/create-from-dsl-modal/index.tsx index 04d8b1e754..a542ef059f 100644 --- a/web/app/components/app/create-from-dsl-modal/index.tsx +++ b/web/app/components/app/create-from-dsl-modal/index.tsx @@ -12,7 +12,7 @@ import { trackEvent } from '@/app/components/base/amplitude' import Button from '@/app/components/base/button' import Input from '@/app/components/base/input' import Modal from '@/app/components/base/modal' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import AppsFull from '@/app/components/billing/apps-full-in-dialog' import { usePluginDependencies } from '@/app/components/workflow/plugin-dependency/hooks' import { NEED_REFRESH_APP_LIST_KEY } from '@/config' diff --git a/web/app/components/app/create-from-dsl-modal/uploader.tsx b/web/app/components/app/create-from-dsl-modal/uploader.tsx index 778a2c1420..6b7ca4db61 100644 --- a/web/app/components/app/create-from-dsl-modal/uploader.tsx +++ b/web/app/components/app/create-from-dsl-modal/uploader.tsx @@ -10,7 +10,7 @@ import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' import ActionButton from '@/app/components/base/action-button' import { Yaml as YamlIcon } from '@/app/components/base/icons/src/public/files' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import { cn } from '@/utils/classnames' import { formatFileSize } from '@/utils/format' diff --git a/web/app/components/app/log/list.tsx b/web/app/components/app/log/list.tsx index 40519dcb36..4eb733a02c 100644 --- a/web/app/components/app/log/list.tsx +++ b/web/app/components/app/log/list.tsx @@ -31,7 +31,7 @@ import Drawer from '@/app/components/base/drawer' import { getProcessedFilesFromResponse } from '@/app/components/base/file-uploader/utils' import Loading from '@/app/components/base/loading' import MessageLogModal from '@/app/components/base/message-log-modal' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import Tooltip from '@/app/components/base/tooltip' import { addFileInfos, sortAgentSorts } from '@/app/components/tools/utils' import { WorkflowContextProvider } from '@/app/components/workflow/context' diff --git a/web/app/components/app/overview/settings/index.spec.tsx b/web/app/components/app/overview/settings/index.spec.tsx index c9cbe0b724..d98e02ad57 100644 --- a/web/app/components/app/overview/settings/index.spec.tsx +++ b/web/app/components/app/overview/settings/index.spec.tsx @@ -59,16 +59,12 @@ vi.mock('@/context/modal-context', () => ({ useModalContext: () => buildModalContext(), })) -vi.mock('@/app/components/base/toast', async () => { - const actual = await vi.importActual('@/app/components/base/toast') - return { - ...actual, - useToastContext: () => ({ - notify: mockNotify, - close: vi.fn(), - }), - } -}) +vi.mock('@/app/components/base/toast/context', () => ({ + useToastContext: () => ({ + notify: mockNotify, + close: vi.fn(), + }), +})) vi.mock('@/context/i18n', async () => { const actual = await vi.importActual('@/context/i18n') diff --git a/web/app/components/app/overview/settings/index.tsx b/web/app/components/app/overview/settings/index.tsx index 040703f41c..92bfdc5d31 100644 --- a/web/app/components/app/overview/settings/index.tsx +++ b/web/app/components/app/overview/settings/index.tsx @@ -20,7 +20,7 @@ import PremiumBadge from '@/app/components/base/premium-badge' import { SimpleSelect } from '@/app/components/base/select' import Switch from '@/app/components/base/switch' import Textarea from '@/app/components/base/textarea' -import { useToastContext } from '@/app/components/base/toast' +import { useToastContext } from '@/app/components/base/toast/context' import Tooltip from '@/app/components/base/tooltip' import { ACCOUNT_SETTING_TAB } from '@/app/components/header/account-setting/constants' import { useModalContext } from '@/context/modal-context' diff --git a/web/app/components/app/switch-app-modal/index.spec.tsx b/web/app/components/app/switch-app-modal/index.spec.tsx index f0400f8b75..c905d79b31 100644 --- a/web/app/components/app/switch-app-modal/index.spec.tsx +++ b/web/app/components/app/switch-app-modal/index.spec.tsx @@ -3,7 +3,7 @@ import { render, screen, waitFor } from '@testing-library/react' import userEvent from '@testing-library/user-event' import * as React from 'react' import { useStore as useAppStore } from '@/app/components/app/store' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import { Plan } from '@/app/components/billing/type' import { NEED_REFRESH_APP_LIST_KEY } from '@/config' import { AppModeEnum } from '@/types/app' diff --git a/web/app/components/app/switch-app-modal/index.tsx b/web/app/components/app/switch-app-modal/index.tsx index 30d7877ed0..8caa07c187 100644 --- a/web/app/components/app/switch-app-modal/index.tsx +++ b/web/app/components/app/switch-app-modal/index.tsx @@ -15,7 +15,7 @@ import Confirm from '@/app/components/base/confirm' import { AlertTriangle } from '@/app/components/base/icons/src/vender/solid/alertsAndFeedback' import Input from '@/app/components/base/input' import Modal from '@/app/components/base/modal' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import AppsFull from '@/app/components/billing/apps-full-in-dialog' import { NEED_REFRESH_APP_LIST_KEY } from '@/config' import { useAppContext } from '@/context/app-context' diff --git a/web/app/components/apps/app-card.tsx b/web/app/components/apps/app-card.tsx index 8dc5c82e13..471b3420d1 100644 --- a/web/app/components/apps/app-card.tsx +++ b/web/app/components/apps/app-card.tsx @@ -18,7 +18,8 @@ import AppIcon from '@/app/components/base/app-icon' import Divider from '@/app/components/base/divider' import CustomPopover from '@/app/components/base/popover' import TagSelector from '@/app/components/base/tag-management/selector' -import Toast, { ToastContext } from '@/app/components/base/toast' +import Toast from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import Tooltip from '@/app/components/base/tooltip' import { AlertDialog, diff --git a/web/app/components/base/agent-log-modal/__tests__/detail.spec.tsx b/web/app/components/base/agent-log-modal/__tests__/detail.spec.tsx index c77f144da2..47d854e028 100644 --- a/web/app/components/base/agent-log-modal/__tests__/detail.spec.tsx +++ b/web/app/components/base/agent-log-modal/__tests__/detail.spec.tsx @@ -2,7 +2,7 @@ import type { ComponentProps } from 'react' import type { IChatItem } from '@/app/components/base/chat/chat/type' import type { AgentLogDetailResponse } from '@/models/log' import { fireEvent, render, screen, waitFor } from '@testing-library/react' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import { fetchAgentLogDetail } from '@/service/log' import AgentLogDetail from '../detail' diff --git a/web/app/components/base/agent-log-modal/__tests__/index.spec.tsx b/web/app/components/base/agent-log-modal/__tests__/index.spec.tsx index 6b59e90c77..6437ae5b43 100644 --- a/web/app/components/base/agent-log-modal/__tests__/index.spec.tsx +++ b/web/app/components/base/agent-log-modal/__tests__/index.spec.tsx @@ -1,7 +1,7 @@ import type { IChatItem } from '@/app/components/base/chat/chat/type' import { fireEvent, render, screen, waitFor } from '@testing-library/react' import { useClickAway } from 'ahooks' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import { fetchAgentLogDetail } from '@/service/log' import AgentLogModal from '../index' diff --git a/web/app/components/base/agent-log-modal/detail.tsx b/web/app/components/base/agent-log-modal/detail.tsx index 36b502e9a5..21ed0be7e8 100644 --- a/web/app/components/base/agent-log-modal/detail.tsx +++ b/web/app/components/base/agent-log-modal/detail.tsx @@ -10,7 +10,7 @@ import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' import { useStore as useAppStore } from '@/app/components/app/store' import Loading from '@/app/components/base/loading' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import { fetchAgentLogDetail } from '@/service/log' import { cn } from '@/utils/classnames' import ResultPanel from './result' diff --git a/web/app/components/base/chat/chat-with-history/context.tsx b/web/app/components/base/chat/chat-with-history/context.ts similarity index 100% rename from web/app/components/base/chat/chat-with-history/context.tsx rename to web/app/components/base/chat/chat-with-history/context.ts diff --git a/web/app/components/base/chat/chat-with-history/hooks.tsx b/web/app/components/base/chat/chat-with-history/hooks.tsx index da344a9789..23936111ce 100644 --- a/web/app/components/base/chat/chat-with-history/hooks.tsx +++ b/web/app/components/base/chat/chat-with-history/hooks.tsx @@ -23,7 +23,7 @@ import { } from 'react' import { useTranslation } from 'react-i18next' import { getProcessedFilesFromResponse } from '@/app/components/base/file-uploader/utils' -import { useToastContext } from '@/app/components/base/toast' +import { useToastContext } from '@/app/components/base/toast/context' import { InputVarType } from '@/app/components/workflow/types' import { useWebAppStore } from '@/context/web-app-context' import { useAppFavicon } from '@/hooks/use-app-favicon' diff --git a/web/app/components/base/chat/chat/__tests__/context.spec.tsx b/web/app/components/base/chat/chat/__tests__/context.spec.tsx index fd00156e59..aeba073a7b 100644 --- a/web/app/components/base/chat/chat/__tests__/context.spec.tsx +++ b/web/app/components/base/chat/chat/__tests__/context.spec.tsx @@ -3,7 +3,8 @@ import type { ChatContextValue } from '../context' import { render, screen } from '@testing-library/react' import userEvent from '@testing-library/user-event' import { vi } from 'vitest' -import { ChatContextProvider, useChatContext } from '../context' +import { useChatContext } from '../context' +import { ChatContextProvider } from '../context-provider' const TestConsumer = () => { const context = useChatContext() diff --git a/web/app/components/base/chat/chat/__tests__/question.spec.tsx b/web/app/components/base/chat/chat/__tests__/question.spec.tsx index 1d0584805b..7c717b6e31 100644 --- a/web/app/components/base/chat/chat/__tests__/question.spec.tsx +++ b/web/app/components/base/chat/chat/__tests__/question.spec.tsx @@ -9,7 +9,7 @@ import { vi } from 'vitest' import Toast from '../../../toast' import { ThemeBuilder } from '../../embedded-chatbot/theme/theme-context' -import { ChatContextProvider } from '../context' +import { ChatContextProvider } from '../context-provider' import Question from '../question' // Global Mocks diff --git a/web/app/components/base/chat/chat/chat-input-area/__tests__/index.spec.tsx b/web/app/components/base/chat/chat/chat-input-area/__tests__/index.spec.tsx index a6d4570fcb..03f3c673ce 100644 --- a/web/app/components/base/chat/chat/chat-input-area/__tests__/index.spec.tsx +++ b/web/app/components/base/chat/chat/chat-input-area/__tests__/index.spec.tsx @@ -91,15 +91,9 @@ vi.mock('@/app/components/base/features/hooks', () => ({ // --------------------------------------------------------------------------- // Toast context // --------------------------------------------------------------------------- -vi.mock('@/app/components/base/toast', async () => { - const actual = await vi.importActual( - '@/app/components/base/toast', - ) - return { - ...actual, - useToastContext: () => ({ notify: mockNotify }), - } -}) +vi.mock('@/app/components/base/toast/context', () => ({ + useToastContext: () => ({ notify: mockNotify, close: vi.fn() }), +})) // --------------------------------------------------------------------------- // Internal layout hook – controls single/multi-line textarea mode diff --git a/web/app/components/base/chat/chat/chat-input-area/index.tsx b/web/app/components/base/chat/chat/chat-input-area/index.tsx index 9de52cb18c..170cccaaca 100644 --- a/web/app/components/base/chat/chat/chat-input-area/index.tsx +++ b/web/app/components/base/chat/chat/chat-input-area/index.tsx @@ -22,7 +22,7 @@ import { FileContextProvider, useFileStore, } from '@/app/components/base/file-uploader/store' -import { useToastContext } from '@/app/components/base/toast' +import { useToastContext } from '@/app/components/base/toast/context' import VoiceInput from '@/app/components/base/voice-input' import { TransferMethod } from '@/types/app' import { cn } from '@/utils/classnames' diff --git a/web/app/components/base/chat/chat/check-input-forms-hooks.ts b/web/app/components/base/chat/chat/check-input-forms-hooks.ts index 2da57b289e..842e89070b 100644 --- a/web/app/components/base/chat/chat/check-input-forms-hooks.ts +++ b/web/app/components/base/chat/chat/check-input-forms-hooks.ts @@ -1,7 +1,7 @@ import type { InputForm } from './type' import { useCallback } from 'react' import { useTranslation } from 'react-i18next' -import { useToastContext } from '@/app/components/base/toast' +import { useToastContext } from '@/app/components/base/toast/context' import { InputVarType } from '@/app/components/workflow/types' import { TransferMethod } from '@/types/app' diff --git a/web/app/components/base/chat/chat/context.tsx b/web/app/components/base/chat/chat/context-provider.tsx similarity index 58% rename from web/app/components/base/chat/chat/context.tsx rename to web/app/components/base/chat/chat/context-provider.tsx index 7843665ad7..02503521e5 100644 --- a/web/app/components/base/chat/chat/context.tsx +++ b/web/app/components/base/chat/chat/context-provider.tsx @@ -1,30 +1,8 @@ 'use client' import type { ReactNode } from 'react' -import type { ChatProps } from './index' -import { createContext, useContext } from 'use-context-selector' - -export type ChatContextValue = Pick & { - readonly?: boolean - } - -const ChatContext = createContext({ - chatList: [], - readonly: false, -}) +import type { ChatContextValue } from './context' +import { ChatContext } from './context' type ChatContextProviderProps = { children: ReactNode @@ -71,7 +49,3 @@ export const ChatContextProvider = ({ ) } - -export const useChatContext = () => useContext(ChatContext) - -export default ChatContext diff --git a/web/app/components/base/chat/chat/context.ts b/web/app/components/base/chat/chat/context.ts new file mode 100644 index 0000000000..ff0bd26336 --- /dev/null +++ b/web/app/components/base/chat/chat/context.ts @@ -0,0 +1,30 @@ +'use client' + +import type { ChatProps } from './index' +import { createContext, useContext } from 'use-context-selector' + +export type ChatContextValue = Pick & { + readonly?: boolean + } + +export const ChatContext = createContext({ + chatList: [], + readonly: false, +}) + +export const useChatContext = () => useContext(ChatContext) + +export default ChatContext diff --git a/web/app/components/base/chat/chat/hooks.ts b/web/app/components/base/chat/chat/hooks.ts index 3e01fa5ade..4828ef2a47 100644 --- a/web/app/components/base/chat/chat/hooks.ts +++ b/web/app/components/base/chat/chat/hooks.ts @@ -30,7 +30,7 @@ import { getProcessedFiles, getProcessedFilesFromResponse, } from '@/app/components/base/file-uploader/utils' -import { useToastContext } from '@/app/components/base/toast' +import { useToastContext } from '@/app/components/base/toast/context' import { NodeRunningStatus, WorkflowRunningStatus } from '@/app/components/workflow/types' import useTimestamp from '@/hooks/use-timestamp' import { diff --git a/web/app/components/base/chat/chat/index.tsx b/web/app/components/base/chat/chat/index.tsx index a77911d895..69c064e3e2 100644 --- a/web/app/components/base/chat/chat/index.tsx +++ b/web/app/components/base/chat/chat/index.tsx @@ -30,7 +30,7 @@ import PromptLogModal from '@/app/components/base/prompt-log-modal' import { cn } from '@/utils/classnames' import Answer from './answer' import ChatInputArea from './chat-input-area' -import { ChatContextProvider } from './context' +import { ChatContextProvider } from './context-provider' import Question from './question' import TryToAsk from './try-to-ask' diff --git a/web/app/components/base/chat/embedded-chatbot/context.tsx b/web/app/components/base/chat/embedded-chatbot/context.ts similarity index 100% rename from web/app/components/base/chat/embedded-chatbot/context.tsx rename to web/app/components/base/chat/embedded-chatbot/context.ts diff --git a/web/app/components/base/chat/embedded-chatbot/hooks.tsx b/web/app/components/base/chat/embedded-chatbot/hooks.tsx index bffee78792..da142a69ec 100644 --- a/web/app/components/base/chat/embedded-chatbot/hooks.tsx +++ b/web/app/components/base/chat/embedded-chatbot/hooks.tsx @@ -21,7 +21,7 @@ import { useState, } from 'react' import { useTranslation } from 'react-i18next' -import { useToastContext } from '@/app/components/base/toast' +import { useToastContext } from '@/app/components/base/toast/context' import { addFileInfos, sortAgentSorts } from '@/app/components/tools/utils' import { InputVarType } from '@/app/components/workflow/types' import { useWebAppStore } from '@/context/web-app-context' diff --git a/web/app/components/base/chat/embedded-chatbot/inputs-form/__tests__/content.spec.tsx b/web/app/components/base/chat/embedded-chatbot/inputs-form/__tests__/content.spec.tsx index 08c9a035e7..aad2d3d09b 100644 --- a/web/app/components/base/chat/embedded-chatbot/inputs-form/__tests__/content.spec.tsx +++ b/web/app/components/base/chat/embedded-chatbot/inputs-form/__tests__/content.spec.tsx @@ -16,7 +16,7 @@ vi.mock('next/navigation', () => ({ useSearchParams: () => new URLSearchParams(), })) -vi.mock('@/app/components/base/toast', () => ({ +vi.mock('@/app/components/base/toast/context', () => ({ useToastContext: () => ({ notify: vi.fn() }), })) diff --git a/web/app/components/base/features/new-feature-panel/moderation/__tests__/moderation-setting-modal.spec.tsx b/web/app/components/base/features/new-feature-panel/moderation/__tests__/moderation-setting-modal.spec.tsx index 3c690635da..88f74d2686 100644 --- a/web/app/components/base/features/new-feature-panel/moderation/__tests__/moderation-setting-modal.spec.tsx +++ b/web/app/components/base/features/new-feature-panel/moderation/__tests__/moderation-setting-modal.spec.tsx @@ -3,7 +3,7 @@ import { fireEvent, render, screen } from '@testing-library/react' import ModerationSettingModal from '../moderation-setting-modal' const mockNotify = vi.fn() -vi.mock('@/app/components/base/toast', () => ({ +vi.mock('@/app/components/base/toast/context', () => ({ useToastContext: () => ({ notify: mockNotify }), })) diff --git a/web/app/components/base/features/new-feature-panel/moderation/moderation-setting-modal.tsx b/web/app/components/base/features/new-feature-panel/moderation/moderation-setting-modal.tsx index c68abfd7b1..4c0682d182 100644 --- a/web/app/components/base/features/new-feature-panel/moderation/moderation-setting-modal.tsx +++ b/web/app/components/base/features/new-feature-panel/moderation/moderation-setting-modal.tsx @@ -7,7 +7,7 @@ import { useTranslation } from 'react-i18next' import Button from '@/app/components/base/button' import Divider from '@/app/components/base/divider' import Modal from '@/app/components/base/modal' -import { useToastContext } from '@/app/components/base/toast' +import { useToastContext } from '@/app/components/base/toast/context' import ApiBasedExtensionSelector from '@/app/components/header/account-setting/api-based-extension-page/selector' import { ACCOUNT_SETTING_TAB } from '@/app/components/header/account-setting/constants' import { CustomConfigurationStatusEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' diff --git a/web/app/components/base/file-uploader/__tests__/hooks.spec.ts b/web/app/components/base/file-uploader/__tests__/hooks.spec.ts index 00c64224aa..8343974967 100644 --- a/web/app/components/base/file-uploader/__tests__/hooks.spec.ts +++ b/web/app/components/base/file-uploader/__tests__/hooks.spec.ts @@ -11,7 +11,7 @@ vi.mock('next/navigation', () => ({ })) // Exception: hook requires toast context that isn't available without a provider wrapper -vi.mock('@/app/components/base/toast', () => ({ +vi.mock('@/app/components/base/toast/context', () => ({ useToastContext: () => ({ notify: mockNotify, }), diff --git a/web/app/components/base/file-uploader/hooks.ts b/web/app/components/base/file-uploader/hooks.ts index 14e46548d8..4aab60175c 100644 --- a/web/app/components/base/file-uploader/hooks.ts +++ b/web/app/components/base/file-uploader/hooks.ts @@ -18,7 +18,7 @@ import { MAX_FILE_UPLOAD_LIMIT, VIDEO_SIZE_LIMIT, } from '@/app/components/base/file-uploader/constants' -import { useToastContext } from '@/app/components/base/toast' +import { useToastContext } from '@/app/components/base/toast/context' import { SupportUploadFileTypes } from '@/app/components/workflow/types' import { uploadRemoteFileInfo } from '@/service/common' import { TransferMethod } from '@/types/app' diff --git a/web/app/components/base/form/hooks/__tests__/use-check-validated.spec.ts b/web/app/components/base/form/hooks/__tests__/use-check-validated.spec.ts index b9d60594f7..28eb5bd5ed 100644 --- a/web/app/components/base/form/hooks/__tests__/use-check-validated.spec.ts +++ b/web/app/components/base/form/hooks/__tests__/use-check-validated.spec.ts @@ -5,7 +5,7 @@ import { useCheckValidated } from '../use-check-validated' const mockNotify = vi.fn() -vi.mock('@/app/components/base/toast', () => ({ +vi.mock('@/app/components/base/toast/context', () => ({ useToastContext: () => ({ notify: mockNotify, }), diff --git a/web/app/components/base/form/hooks/use-check-validated.ts b/web/app/components/base/form/hooks/use-check-validated.ts index 7ed6164bb2..d186996035 100644 --- a/web/app/components/base/form/hooks/use-check-validated.ts +++ b/web/app/components/base/form/hooks/use-check-validated.ts @@ -1,7 +1,7 @@ import type { AnyFormApi } from '@tanstack/react-form' import type { FormSchema } from '@/app/components/base/form/types' import { useCallback } from 'react' -import { useToastContext } from '@/app/components/base/toast' +import { useToastContext } from '@/app/components/base/toast/context' export const useCheckValidated = (form: AnyFormApi, FormSchemas: FormSchema[]) => { const { notify } = useToastContext() diff --git a/web/app/components/base/image-uploader/__tests__/hooks.spec.ts b/web/app/components/base/image-uploader/__tests__/hooks.spec.ts index 4d150830d0..f79ea98081 100644 --- a/web/app/components/base/image-uploader/__tests__/hooks.spec.ts +++ b/web/app/components/base/image-uploader/__tests__/hooks.spec.ts @@ -5,7 +5,7 @@ import { Resolution, TransferMethod } from '@/types/app' import { useClipboardUploader, useDraggableUploader, useImageFiles, useLocalFileUploader } from '../hooks' const mockNotify = vi.fn() -vi.mock('@/app/components/base/toast', () => ({ +vi.mock('@/app/components/base/toast/context', () => ({ useToastContext: () => ({ notify: mockNotify }), })) diff --git a/web/app/components/base/image-uploader/hooks.ts b/web/app/components/base/image-uploader/hooks.ts index cd309a1f7b..03cf0feeca 100644 --- a/web/app/components/base/image-uploader/hooks.ts +++ b/web/app/components/base/image-uploader/hooks.ts @@ -3,7 +3,7 @@ import type { ImageFile, VisionSettings } from '@/types/app' import { useParams } from 'next/navigation' import { useCallback, useMemo, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' -import { useToastContext } from '@/app/components/base/toast' +import { useToastContext } from '@/app/components/base/toast/context' import { ALLOW_FILE_EXTENSIONS, TransferMethod } from '@/types/app' import { getImageUploadErrorMessage, imageUpload } from './utils' diff --git a/web/app/components/base/markdown-blocks/__tests__/button.spec.tsx b/web/app/components/base/markdown-blocks/__tests__/button.spec.tsx index 305896f4f1..95ed788db3 100644 --- a/web/app/components/base/markdown-blocks/__tests__/button.spec.tsx +++ b/web/app/components/base/markdown-blocks/__tests__/button.spec.tsx @@ -5,7 +5,7 @@ import userEvent from '@testing-library/user-event' // markdown-button.spec.tsx import * as React from 'react' import { beforeEach, describe, expect, it, vi } from 'vitest' -import { ChatContextProvider } from '@/app/components/base/chat/chat/context' +import { ChatContextProvider } from '@/app/components/base/chat/chat/context-provider' import MarkdownButton from '../button' diff --git a/web/app/components/base/markdown-blocks/__tests__/think-block.spec.tsx b/web/app/components/base/markdown-blocks/__tests__/think-block.spec.tsx index 2cd31f9a49..e8b956cbbf 100644 --- a/web/app/components/base/markdown-blocks/__tests__/think-block.spec.tsx +++ b/web/app/components/base/markdown-blocks/__tests__/think-block.spec.tsx @@ -1,6 +1,6 @@ import { act, render, screen } from '@testing-library/react' import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' -import { ChatContextProvider } from '@/app/components/base/chat/chat/context' +import { ChatContextProvider } from '@/app/components/base/chat/chat/context-provider' import ThinkBlock from '../think-block' // Mock react-i18next diff --git a/web/app/components/base/markdown-blocks/think-block.stories.tsx b/web/app/components/base/markdown-blocks/think-block.stories.tsx index 23713fb263..7c3f809ee7 100644 --- a/web/app/components/base/markdown-blocks/think-block.stories.tsx +++ b/web/app/components/base/markdown-blocks/think-block.stories.tsx @@ -1,6 +1,6 @@ import type { Meta, StoryObj } from '@storybook/nextjs-vite' import { useState } from 'react' -import { ChatContextProvider } from '@/app/components/base/chat/chat/context' +import { ChatContextProvider } from '@/app/components/base/chat/chat/context-provider' import ThinkBlock from './think-block' const THOUGHT_TEXT = ` diff --git a/web/app/components/base/prompt-editor/plugins/component-picker-block/__tests__/index.spec.tsx b/web/app/components/base/prompt-editor/plugins/component-picker-block/__tests__/index.spec.tsx index 2deec561e9..6cc6c3a67f 100644 --- a/web/app/components/base/prompt-editor/plugins/component-picker-block/__tests__/index.spec.tsx +++ b/web/app/components/base/prompt-editor/plugins/component-picker-block/__tests__/index.spec.tsx @@ -28,7 +28,8 @@ import { import * as React from 'react' import { GeneratorType } from '@/app/components/app/configuration/config/automatic/types' import { VarType } from '@/app/components/workflow/types' -import { EventEmitterContextProvider, useEventEmitterContextContext } from '@/context/event-emitter' +import { useEventEmitterContextContext } from '@/context/event-emitter' +import { EventEmitterContextProvider } from '@/context/event-emitter-provider' import { INSERT_CONTEXT_BLOCK_COMMAND } from '../../context-block' import { INSERT_CURRENT_BLOCK_COMMAND } from '../../current-block' import { INSERT_ERROR_MESSAGE_BLOCK_COMMAND } from '../../error-message-block' diff --git a/web/app/components/base/radio/context/index.tsx b/web/app/components/base/radio/context/index.ts similarity index 100% rename from web/app/components/base/radio/context/index.tsx rename to web/app/components/base/radio/context/index.ts diff --git a/web/app/components/base/tag-input/__tests__/index.spec.tsx b/web/app/components/base/tag-input/__tests__/index.spec.tsx index b091d9cd03..f07399f8af 100644 --- a/web/app/components/base/tag-input/__tests__/index.spec.tsx +++ b/web/app/components/base/tag-input/__tests__/index.spec.tsx @@ -5,7 +5,7 @@ import TagInput from '../index' const mockNotify = vi.fn() -vi.mock('@/app/components/base/toast', () => ({ +vi.mock('@/app/components/base/toast/context', () => ({ useToastContext: () => ({ notify: mockNotify, }), diff --git a/web/app/components/base/tag-input/index.tsx b/web/app/components/base/tag-input/index.tsx index 1c49b026fb..377e68abe4 100644 --- a/web/app/components/base/tag-input/index.tsx +++ b/web/app/components/base/tag-input/index.tsx @@ -2,7 +2,7 @@ import type { ChangeEvent, FC, KeyboardEvent } from 'react' import { useCallback, useState } from 'react' import AutosizeInput from 'react-18-input-autosize' import { useTranslation } from 'react-i18next' -import { useToastContext } from '@/app/components/base/toast' +import { useToastContext } from '@/app/components/base/toast/context' import { cn } from '@/utils/classnames' type TagInputProps = { diff --git a/web/app/components/base/tag-management/__tests__/panel.spec.tsx b/web/app/components/base/tag-management/__tests__/panel.spec.tsx index c91c72e583..cd9e37e286 100644 --- a/web/app/components/base/tag-management/__tests__/panel.spec.tsx +++ b/web/app/components/base/tag-management/__tests__/panel.spec.tsx @@ -3,7 +3,7 @@ import { render, screen, waitFor, within } from '@testing-library/react' import userEvent from '@testing-library/user-event' import * as React from 'react' import { act } from 'react' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import Panel from '../panel' import { useStore as useTagStore } from '../store' diff --git a/web/app/components/base/tag-management/__tests__/selector.spec.tsx b/web/app/components/base/tag-management/__tests__/selector.spec.tsx index dc58ca37e6..43f17a1e8c 100644 --- a/web/app/components/base/tag-management/__tests__/selector.spec.tsx +++ b/web/app/components/base/tag-management/__tests__/selector.spec.tsx @@ -3,7 +3,7 @@ import { render, screen, waitFor, within } from '@testing-library/react' import userEvent from '@testing-library/user-event' import * as React from 'react' import { act } from 'react' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import TagSelector from '../selector' import { useStore as useTagStore } from '../store' diff --git a/web/app/components/base/tag-management/index.tsx b/web/app/components/base/tag-management/index.tsx index e9ce85ecc0..b7682bcdad 100644 --- a/web/app/components/base/tag-management/index.tsx +++ b/web/app/components/base/tag-management/index.tsx @@ -4,7 +4,7 @@ import { useEffect, useState } from 'react' import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' import Modal from '@/app/components/base/modal' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import { createTag, fetchTagList, diff --git a/web/app/components/base/tag-management/panel.tsx b/web/app/components/base/tag-management/panel.tsx index 4174ba0476..cebed74f3b 100644 --- a/web/app/components/base/tag-management/panel.tsx +++ b/web/app/components/base/tag-management/panel.tsx @@ -10,7 +10,7 @@ import { useContext } from 'use-context-selector' import Checkbox from '@/app/components/base/checkbox' import Divider from '@/app/components/base/divider' import Input from '@/app/components/base/input' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import { bindTag, createTag, unBindTag } from '@/service/tag' import { useStore as useTagStore } from './store' diff --git a/web/app/components/base/tag-management/tag-item-editor.tsx b/web/app/components/base/tag-management/tag-item-editor.tsx index a37e42dcce..3cff335f58 100644 --- a/web/app/components/base/tag-management/tag-item-editor.tsx +++ b/web/app/components/base/tag-management/tag-item-editor.tsx @@ -6,7 +6,7 @@ import { useState } from 'react' import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' import Confirm from '@/app/components/base/confirm' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import Tooltip from '@/app/components/base/tooltip' import { deleteTag, diff --git a/web/app/components/base/text-generation/__tests__/hooks.spec.ts b/web/app/components/base/text-generation/__tests__/hooks.spec.ts index cab06f1c8a..a5b5578158 100644 --- a/web/app/components/base/text-generation/__tests__/hooks.spec.ts +++ b/web/app/components/base/text-generation/__tests__/hooks.spec.ts @@ -5,7 +5,7 @@ import { useTextGeneration } from '../hooks' const mockNotify = vi.fn() const mockSsePost = vi.fn<(url: string, fetchOptions: { body: Record }, otherOptions: IOtherOptions) => void>() -vi.mock('@/app/components/base/toast', () => ({ +vi.mock('@/app/components/base/toast/context', () => ({ useToastContext: () => ({ notify: mockNotify, }), diff --git a/web/app/components/base/text-generation/hooks.ts b/web/app/components/base/text-generation/hooks.ts index c5d008956b..4314a81925 100644 --- a/web/app/components/base/text-generation/hooks.ts +++ b/web/app/components/base/text-generation/hooks.ts @@ -1,6 +1,6 @@ import { useState } from 'react' import { useTranslation } from 'react-i18next' -import { useToastContext } from '@/app/components/base/toast' +import { useToastContext } from '@/app/components/base/toast/context' import { ssePost } from '@/service/base' export const useTextGeneration = () => { diff --git a/web/app/components/base/toast/__tests__/index.spec.tsx b/web/app/components/base/toast/__tests__/index.spec.tsx index f526290fa1..2f5fa49823 100644 --- a/web/app/components/base/toast/__tests__/index.spec.tsx +++ b/web/app/components/base/toast/__tests__/index.spec.tsx @@ -2,7 +2,8 @@ import type { ReactNode } from 'react' import { act, render, screen, waitFor } from '@testing-library/react' import { noop } from 'es-toolkit/function' import * as React from 'react' -import Toast, { ToastProvider, useToastContext } from '..' +import Toast, { ToastProvider } from '..' +import { useToastContext } from '../context' const TestComponent = () => { const { notify, close } = useToastContext() diff --git a/web/app/components/base/toast/context.ts b/web/app/components/base/toast/context.ts new file mode 100644 index 0000000000..ddd8f91336 --- /dev/null +++ b/web/app/components/base/toast/context.ts @@ -0,0 +1,23 @@ +'use client' + +import type { ReactNode } from 'react' +import { createContext, useContext } from 'use-context-selector' + +export type IToastProps = { + type?: 'success' | 'error' | 'warning' | 'info' + size?: 'md' | 'sm' + duration?: number + message: string + children?: ReactNode + onClose?: () => void + className?: string + customComponent?: ReactNode +} + +type IToastContext = { + notify: (props: IToastProps) => void + close: () => void +} + +export const ToastContext = createContext({} as IToastContext) +export const useToastContext = () => useContext(ToastContext) diff --git a/web/app/components/base/toast/index.stories.tsx b/web/app/components/base/toast/index.stories.tsx index 4ab9138070..40d6fecfc2 100644 --- a/web/app/components/base/toast/index.stories.tsx +++ b/web/app/components/base/toast/index.stories.tsx @@ -1,6 +1,7 @@ import type { Meta, StoryObj } from '@storybook/nextjs-vite' import { useCallback } from 'react' -import Toast, { ToastProvider, useToastContext } from '.' +import Toast, { ToastProvider } from '.' +import { useToastContext } from './context' const ToastControls = () => { const { notify } = useToastContext() diff --git a/web/app/components/base/toast/index.tsx b/web/app/components/base/toast/index.tsx index 1b9ae4eedb..a70a0db06c 100644 --- a/web/app/components/base/toast/index.tsx +++ b/web/app/components/base/toast/index.tsx @@ -1,5 +1,6 @@ 'use client' import type { ReactNode } from 'react' +import type { IToastProps } from './context' import { RiAlertFill, RiCheckboxCircleFill, @@ -11,31 +12,13 @@ import { noop } from 'es-toolkit/function' import * as React from 'react' import { useEffect, useState } from 'react' import { createRoot } from 'react-dom/client' -import { createContext, useContext } from 'use-context-selector' import ActionButton from '@/app/components/base/action-button' import { cn } from '@/utils/classnames' - -export type IToastProps = { - type?: 'success' | 'error' | 'warning' | 'info' - size?: 'md' | 'sm' - duration?: number - message: string - children?: ReactNode - onClose?: () => void - className?: string - customComponent?: ReactNode -} -type IToastContext = { - notify: (props: IToastProps) => void - close: () => void -} +import { ToastContext, useToastContext } from './context' export type ToastHandle = { clear?: VoidFunction } - -export const ToastContext = createContext({} as IToastContext) -export const useToastContext = () => useContext(ToastContext) const Toast = ({ type = 'info', size = 'md', @@ -77,11 +60,11 @@ const Toast = ({

-
{message}
+
{message}
{customComponent}
{!!children && ( -
+
{children}
)} @@ -183,3 +166,5 @@ Toast.notify = ({ } export default Toast + +export type { IToastProps } from './context' diff --git a/web/app/components/custom/custom-web-app-brand/__tests__/index.spec.tsx b/web/app/components/custom/custom-web-app-brand/__tests__/index.spec.tsx index 2ceb45235c..1d17a2ae0f 100644 --- a/web/app/components/custom/custom-web-app-brand/__tests__/index.spec.tsx +++ b/web/app/components/custom/custom-web-app-brand/__tests__/index.spec.tsx @@ -1,7 +1,7 @@ import { fireEvent, render, screen, waitFor } from '@testing-library/react' import { beforeEach, describe, expect, it, vi } from 'vitest' import { getImageUploadErrorMessage, imageUpload } from '@/app/components/base/image-uploader/utils' -import { useToastContext } from '@/app/components/base/toast' +import { useToastContext } from '@/app/components/base/toast/context' import { Plan } from '@/app/components/billing/type' import { useAppContext } from '@/context/app-context' import { useGlobalPublicStore } from '@/context/global-public-context' @@ -9,7 +9,7 @@ import { useProviderContext } from '@/context/provider-context' import { updateCurrentWorkspace } from '@/service/common' import CustomWebAppBrand from '../index' -vi.mock('@/app/components/base/toast', () => ({ +vi.mock('@/app/components/base/toast/context', () => ({ useToastContext: vi.fn(), })) vi.mock('@/service/common', () => ({ diff --git a/web/app/components/custom/custom-web-app-brand/index.tsx b/web/app/components/custom/custom-web-app-brand/index.tsx index 438e69894d..e6f9a3837b 100644 --- a/web/app/components/custom/custom-web-app-brand/index.tsx +++ b/web/app/components/custom/custom-web-app-brand/index.tsx @@ -16,7 +16,7 @@ import { BubbleTextMod } from '@/app/components/base/icons/src/vender/solid/comm import { getImageUploadErrorMessage, imageUpload } from '@/app/components/base/image-uploader/utils' import DifyLogo from '@/app/components/base/logo/dify-logo' import Switch from '@/app/components/base/switch' -import { useToastContext } from '@/app/components/base/toast' +import { useToastContext } from '@/app/components/base/toast/context' import { Plan } from '@/app/components/billing/type' import { useAppContext } from '@/context/app-context' import { useGlobalPublicStore } from '@/context/global-public-context' diff --git a/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/__tests__/uploader.spec.tsx b/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/__tests__/uploader.spec.tsx index db67c91bc0..e47a876fe8 100644 --- a/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/__tests__/uploader.spec.tsx +++ b/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/__tests__/uploader.spec.tsx @@ -4,7 +4,7 @@ import { beforeEach, describe, expect, it, vi } from 'vitest' import Uploader from '../uploader' const mockNotify = vi.fn() -vi.mock('@/app/components/base/toast', () => ({ +vi.mock('@/app/components/base/toast/context', () => ({ ToastContext: { Provider: ({ children }: { children: React.ReactNode }) => children, Consumer: ({ children }: { children: (value: { notify: typeof mockNotify }) => React.ReactNode }) => children({ notify: mockNotify }), diff --git a/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/hooks/use-dsl-import.ts b/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/hooks/use-dsl-import.ts index 87e55ea740..c839fad3a2 100644 --- a/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/hooks/use-dsl-import.ts +++ b/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/hooks/use-dsl-import.ts @@ -4,7 +4,7 @@ import { useRouter } from 'next/navigation' import { useCallback, useMemo, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import { usePluginDependencies } from '@/app/components/workflow/plugin-dependency/hooks' import { DSLImportMode, diff --git a/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/uploader.tsx b/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/uploader.tsx index 3fa940c60d..faf168c73a 100644 --- a/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/uploader.tsx +++ b/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/uploader.tsx @@ -10,7 +10,7 @@ import { useEffect, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' import ActionButton from '@/app/components/base/action-button' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import { cn } from '@/utils/classnames' import { formatFileSize } from '@/utils/format' diff --git a/web/app/components/datasets/create/empty-dataset-creation-modal/index.tsx b/web/app/components/datasets/create/empty-dataset-creation-modal/index.tsx index a9f55c8b84..0a4064de2a 100644 --- a/web/app/components/datasets/create/empty-dataset-creation-modal/index.tsx +++ b/web/app/components/datasets/create/empty-dataset-creation-modal/index.tsx @@ -8,7 +8,7 @@ import { trackEvent } from '@/app/components/base/amplitude' import Button from '@/app/components/base/button' import Input from '@/app/components/base/input' import Modal from '@/app/components/base/modal' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import { createEmptyDataset } from '@/service/datasets' import { useInvalidDatasetList } from '@/service/knowledge/use-dataset' diff --git a/web/app/components/datasets/create/file-uploader/hooks/__tests__/use-file-upload.spec.tsx b/web/app/components/datasets/create/file-uploader/hooks/__tests__/use-file-upload.spec.tsx index b5d1a96554..80331afe2a 100644 --- a/web/app/components/datasets/create/file-uploader/hooks/__tests__/use-file-upload.spec.tsx +++ b/web/app/components/datasets/create/file-uploader/hooks/__tests__/use-file-upload.spec.tsx @@ -2,7 +2,7 @@ import type { ReactNode } from 'react' import type { CustomFile, FileItem } from '@/models/datasets' import { act, render, renderHook, waitFor } from '@testing-library/react' import { beforeEach, describe, expect, it, vi } from 'vitest' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import { PROGRESS_COMPLETE, PROGRESS_ERROR, PROGRESS_NOT_STARTED } from '../../constants' // Import after mocks diff --git a/web/app/components/datasets/create/file-uploader/hooks/use-file-upload.ts b/web/app/components/datasets/create/file-uploader/hooks/use-file-upload.ts index e097bab755..ada60fac1a 100644 --- a/web/app/components/datasets/create/file-uploader/hooks/use-file-upload.ts +++ b/web/app/components/datasets/create/file-uploader/hooks/use-file-upload.ts @@ -5,7 +5,7 @@ import { useCallback, useEffect, useMemo, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' import { getFileUploadErrorMessage } from '@/app/components/base/file-uploader/utils' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import { IS_CE_EDITION } from '@/config' import { useLocale } from '@/context/i18n' import { LanguagesSupported } from '@/i18n-config/language' diff --git a/web/app/components/datasets/documents/components/__tests__/operations.spec.tsx b/web/app/components/datasets/documents/components/__tests__/operations.spec.tsx index 5aae8dda73..b988b1aeab 100644 --- a/web/app/components/datasets/documents/components/__tests__/operations.spec.tsx +++ b/web/app/components/datasets/documents/components/__tests__/operations.spec.tsx @@ -10,7 +10,7 @@ vi.mock('next/navigation', () => ({ })) const mockNotify = vi.fn() -vi.mock('@/app/components/base/toast', () => ({ +vi.mock('@/app/components/base/toast/context', () => ({ ToastContext: { Provider: ({ children }: { children: React.ReactNode }) => children, }, diff --git a/web/app/components/datasets/documents/components/operations.tsx b/web/app/components/datasets/documents/components/operations.tsx index 15c89a9b26..84e16c7c48 100644 --- a/web/app/components/datasets/documents/components/operations.tsx +++ b/web/app/components/datasets/documents/components/operations.tsx @@ -24,7 +24,7 @@ import Divider from '@/app/components/base/divider' import { SearchLinesSparkle } from '@/app/components/base/icons/src/vender/knowledge' import CustomPopover from '@/app/components/base/popover' import Switch from '@/app/components/base/switch' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import Tooltip from '@/app/components/base/tooltip' import { IS_CE_EDITION } from '@/config' import { DataSourceType, DocumentActionType } from '@/models/datasets' diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/hooks/__tests__/use-local-file-upload.spec.tsx b/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/hooks/__tests__/use-local-file-upload.spec.tsx index bc9ce04beb..efd1f2a483 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/hooks/__tests__/use-local-file-upload.spec.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/hooks/__tests__/use-local-file-upload.spec.tsx @@ -9,7 +9,7 @@ const mockNotify = vi.fn() const mockClose = vi.fn() // Mock ToastContext with factory function -vi.mock('@/app/components/base/toast', async () => { +vi.mock('@/app/components/base/toast/context', async () => { const { createContext, useContext } = await import('use-context-selector') const context = createContext({ notify: mockNotify, close: mockClose }) return { @@ -87,7 +87,7 @@ vi.mock('@/service/base', () => ({ // Import after all mocks are set up const { useLocalFileUpload } = await import('../use-local-file-upload') -const { ToastContext } = await import('@/app/components/base/toast') +const { ToastContext } = await import('@/app/components/base/toast/context') const createWrapper = () => { return ({ children }: { children: ReactNode }) => ( diff --git a/web/app/components/datasets/documents/detail/batch-modal/__tests__/csv-uploader.spec.tsx b/web/app/components/datasets/documents/detail/batch-modal/__tests__/csv-uploader.spec.tsx index 7fb1de7cf9..6876753714 100644 --- a/web/app/components/datasets/documents/detail/batch-modal/__tests__/csv-uploader.spec.tsx +++ b/web/app/components/datasets/documents/detail/batch-modal/__tests__/csv-uploader.spec.tsx @@ -25,7 +25,7 @@ vi.mock('@/hooks/use-theme', () => ({ })) const mockNotify = vi.fn() -vi.mock('@/app/components/base/toast', () => ({ +vi.mock('@/app/components/base/toast/context', () => ({ ToastContext: { Provider: ({ children }: { children: ReactNode }) => children, Consumer: ({ children }: { children: (ctx: { notify: typeof mockNotify }) => ReactNode }) => children({ notify: mockNotify }), diff --git a/web/app/components/datasets/documents/detail/batch-modal/csv-uploader.tsx b/web/app/components/datasets/documents/detail/batch-modal/csv-uploader.tsx index f3a86e910d..63efc766b7 100644 --- a/web/app/components/datasets/documents/detail/batch-modal/csv-uploader.tsx +++ b/web/app/components/datasets/documents/detail/batch-modal/csv-uploader.tsx @@ -12,7 +12,7 @@ import Button from '@/app/components/base/button' import { getFileUploadErrorMessage } from '@/app/components/base/file-uploader/utils' import { Csv as CSVIcon } from '@/app/components/base/icons/src/public/files' import SimplePieChart from '@/app/components/base/simple-pie-chart' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import useTheme from '@/hooks/use-theme' import { upload } from '@/service/base' import { useFileUploadConfig } from '@/service/use-common' diff --git a/web/app/components/datasets/documents/detail/completed/__tests__/index.spec.tsx b/web/app/components/datasets/documents/detail/completed/__tests__/index.spec.tsx index 5802fb8b82..59ecbf5f25 100644 --- a/web/app/components/datasets/documents/detail/completed/__tests__/index.spec.tsx +++ b/web/app/components/datasets/documents/detail/completed/__tests__/index.spec.tsx @@ -65,7 +65,7 @@ vi.mock('../../context', () => ({ }, })) -vi.mock('@/app/components/base/toast', () => ({ +vi.mock('@/app/components/base/toast/context', () => ({ ToastContext: { Provider: ({ children }: { children: React.ReactNode }) => children, Consumer: () => null }, useToastContext: () => ({ notify: mockNotify }), })) diff --git a/web/app/components/datasets/documents/detail/completed/common/__tests__/regeneration-modal.spec.tsx b/web/app/components/datasets/documents/detail/completed/common/__tests__/regeneration-modal.spec.tsx index 719e2867b7..cc7f1aafa4 100644 --- a/web/app/components/datasets/documents/detail/completed/common/__tests__/regeneration-modal.spec.tsx +++ b/web/app/components/datasets/documents/detail/completed/common/__tests__/regeneration-modal.spec.tsx @@ -1,7 +1,8 @@ import type { ReactNode } from 'react' import { act, fireEvent, render, screen, waitFor } from '@testing-library/react' import { beforeEach, describe, expect, it, vi } from 'vitest' -import { EventEmitterContextProvider, useEventEmitterContextContext } from '@/context/event-emitter' +import { useEventEmitterContextContext } from '@/context/event-emitter' +import { EventEmitterContextProvider } from '@/context/event-emitter-provider' import RegenerationModal from '../regeneration-modal' // Store emit function for triggering events in tests diff --git a/web/app/components/datasets/documents/detail/completed/hooks/__tests__/use-child-segment-data.spec.ts b/web/app/components/datasets/documents/detail/completed/hooks/__tests__/use-child-segment-data.spec.ts index 83918a3f30..4cfb4d5927 100644 --- a/web/app/components/datasets/documents/detail/completed/hooks/__tests__/use-child-segment-data.spec.ts +++ b/web/app/components/datasets/documents/detail/completed/hooks/__tests__/use-child-segment-data.spec.ts @@ -59,7 +59,7 @@ vi.mock('../../../context', () => ({ }, })) -vi.mock('@/app/components/base/toast', () => ({ +vi.mock('@/app/components/base/toast/context', () => ({ useToastContext: () => ({ notify: mockNotify }), })) diff --git a/web/app/components/datasets/documents/detail/completed/hooks/__tests__/use-segment-list-data.spec.ts b/web/app/components/datasets/documents/detail/completed/hooks/__tests__/use-segment-list-data.spec.ts index aef2053298..f54c00e3e7 100644 --- a/web/app/components/datasets/documents/detail/completed/hooks/__tests__/use-segment-list-data.spec.ts +++ b/web/app/components/datasets/documents/detail/completed/hooks/__tests__/use-segment-list-data.spec.ts @@ -92,7 +92,7 @@ vi.mock('../../../context', () => ({ }, })) -vi.mock('@/app/components/base/toast', () => ({ +vi.mock('@/app/components/base/toast/context', () => ({ useToastContext: () => ({ notify: mockNotify }), })) diff --git a/web/app/components/datasets/documents/detail/completed/hooks/use-child-segment-data.ts b/web/app/components/datasets/documents/detail/completed/hooks/use-child-segment-data.ts index 4f4c6a532d..cdc8a0b22d 100644 --- a/web/app/components/datasets/documents/detail/completed/hooks/use-child-segment-data.ts +++ b/web/app/components/datasets/documents/detail/completed/hooks/use-child-segment-data.ts @@ -2,7 +2,7 @@ import type { ChildChunkDetail, ChildSegmentsResponse, SegmentDetailModel, Segme import { useQueryClient } from '@tanstack/react-query' import { useCallback, useEffect, useMemo, useRef } from 'react' import { useTranslation } from 'react-i18next' -import { useToastContext } from '@/app/components/base/toast' +import { useToastContext } from '@/app/components/base/toast/context' import { useEventEmitterContextContext } from '@/context/event-emitter' import { useChildSegmentList, diff --git a/web/app/components/datasets/documents/detail/completed/hooks/use-segment-list-data.ts b/web/app/components/datasets/documents/detail/completed/hooks/use-segment-list-data.ts index fd391d2864..aa91e9f464 100644 --- a/web/app/components/datasets/documents/detail/completed/hooks/use-segment-list-data.ts +++ b/web/app/components/datasets/documents/detail/completed/hooks/use-segment-list-data.ts @@ -4,7 +4,7 @@ import { useQueryClient } from '@tanstack/react-query' import { usePathname } from 'next/navigation' import { useCallback, useEffect, useMemo, useRef } from 'react' import { useTranslation } from 'react-i18next' -import { useToastContext } from '@/app/components/base/toast' +import { useToastContext } from '@/app/components/base/toast/context' import { useEventEmitterContextContext } from '@/context/event-emitter' import { ChunkingMode } from '@/models/datasets' import { diff --git a/web/app/components/datasets/documents/detail/completed/new-child-segment.tsx b/web/app/components/datasets/documents/detail/completed/new-child-segment.tsx index 89143662c6..7a6ae4b306 100644 --- a/web/app/components/datasets/documents/detail/completed/new-child-segment.tsx +++ b/web/app/components/datasets/documents/detail/completed/new-child-segment.tsx @@ -8,7 +8,7 @@ import { useContext } from 'use-context-selector' import { useShallow } from 'zustand/react/shallow' import { useStore as useAppStore } from '@/app/components/app/store' import Divider from '@/app/components/base/divider' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import { ChunkingMode } from '@/models/datasets' import { useAddChildSegment } from '@/service/knowledge/use-segment' import { cn } from '@/utils/classnames' diff --git a/web/app/components/datasets/documents/detail/embedding/index.tsx b/web/app/components/datasets/documents/detail/embedding/index.tsx index e89a85c6de..bd344800db 100644 --- a/web/app/components/datasets/documents/detail/embedding/index.tsx +++ b/web/app/components/datasets/documents/detail/embedding/index.tsx @@ -5,7 +5,7 @@ import * as React from 'react' import { useCallback } from 'react' import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import { useProcessRule } from '@/service/knowledge/use-dataset' import { useDocumentContext } from '../context' import { ProgressBar, RuleDetail, SegmentProgress, StatusHeader } from './components' diff --git a/web/app/components/datasets/documents/detail/metadata/hooks/__tests__/use-metadata-state.spec.ts b/web/app/components/datasets/documents/detail/metadata/hooks/__tests__/use-metadata-state.spec.ts index ab1d45338f..3d7b28c78c 100644 --- a/web/app/components/datasets/documents/detail/metadata/hooks/__tests__/use-metadata-state.spec.ts +++ b/web/app/components/datasets/documents/detail/metadata/hooks/__tests__/use-metadata-state.spec.ts @@ -3,7 +3,7 @@ import type { FullDocumentDetail } from '@/models/datasets' import { act, renderHook } from '@testing-library/react' import * as React from 'react' import { describe, expect, it, vi } from 'vitest' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import { useMetadataState } from '../use-metadata-state' diff --git a/web/app/components/datasets/documents/detail/metadata/hooks/use-metadata-state.ts b/web/app/components/datasets/documents/detail/metadata/hooks/use-metadata-state.ts index 08651b699e..f786609981 100644 --- a/web/app/components/datasets/documents/detail/metadata/hooks/use-metadata-state.ts +++ b/web/app/components/datasets/documents/detail/metadata/hooks/use-metadata-state.ts @@ -4,7 +4,7 @@ import type { DocType, FullDocumentDetail } from '@/models/datasets' import { useEffect, useState } from 'react' import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import { modifyDocMetadata } from '@/service/datasets' import { asyncRunSafe } from '@/utils' import { useDocumentContext } from '../../context' diff --git a/web/app/components/datasets/documents/detail/new-segment.tsx b/web/app/components/datasets/documents/detail/new-segment.tsx index 3a58d6ac06..f32c94bf70 100644 --- a/web/app/components/datasets/documents/detail/new-segment.tsx +++ b/web/app/components/datasets/documents/detail/new-segment.tsx @@ -9,7 +9,7 @@ import { useContext } from 'use-context-selector' import { useShallow } from 'zustand/react/shallow' import { useStore as useAppStore } from '@/app/components/app/store' import Divider from '@/app/components/base/divider' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import ImageUploaderInChunk from '@/app/components/datasets/common/image-uploader/image-uploader-in-chunk' import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail' import { ChunkingMode } from '@/models/datasets' diff --git a/web/app/components/datasets/documents/status-item/index.tsx b/web/app/components/datasets/documents/status-item/index.tsx index 60d837fd81..8d3abed7cf 100644 --- a/web/app/components/datasets/documents/status-item/index.tsx +++ b/web/app/components/datasets/documents/status-item/index.tsx @@ -8,7 +8,7 @@ import { useMemo } from 'react' import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' import Switch from '@/app/components/base/switch' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import Tooltip from '@/app/components/base/tooltip' import Indicator from '@/app/components/header/indicator' import { useDocumentDelete, useDocumentDisable, useDocumentEnable } from '@/service/knowledge/use-document' diff --git a/web/app/components/datasets/external-api/external-api-modal/__tests__/index.spec.tsx b/web/app/components/datasets/external-api/external-api-modal/__tests__/index.spec.tsx index a631de3ea0..66d9a163be 100644 --- a/web/app/components/datasets/external-api/external-api-modal/__tests__/index.spec.tsx +++ b/web/app/components/datasets/external-api/external-api-modal/__tests__/index.spec.tsx @@ -12,7 +12,7 @@ vi.mock('@/service/datasets', () => ({ })) const mockNotify = vi.fn() -vi.mock('@/app/components/base/toast', () => ({ +vi.mock('@/app/components/base/toast/context', () => ({ useToastContext: () => ({ notify: mockNotify, }), diff --git a/web/app/components/datasets/external-api/external-api-modal/index.tsx b/web/app/components/datasets/external-api/external-api-modal/index.tsx index a9a87d11bd..b6e870cdc1 100644 --- a/web/app/components/datasets/external-api/external-api-modal/index.tsx +++ b/web/app/components/datasets/external-api/external-api-modal/index.tsx @@ -19,7 +19,7 @@ import { PortalToFollowElem, PortalToFollowElemContent, } from '@/app/components/base/portal-to-follow-elem' -import { useToastContext } from '@/app/components/base/toast' +import { useToastContext } from '@/app/components/base/toast/context' import Tooltip from '@/app/components/base/tooltip' import { createExternalAPI } from '@/service/datasets' import Form from './Form' diff --git a/web/app/components/datasets/external-knowledge-base/connector/__tests__/index.spec.tsx b/web/app/components/datasets/external-knowledge-base/connector/__tests__/index.spec.tsx index ccd637887b..a6a60aa856 100644 --- a/web/app/components/datasets/external-knowledge-base/connector/__tests__/index.spec.tsx +++ b/web/app/components/datasets/external-knowledge-base/connector/__tests__/index.spec.tsx @@ -22,7 +22,7 @@ vi.mock('@/context/i18n', () => ({ })) const mockNotify = vi.fn() -vi.mock('@/app/components/base/toast', () => ({ +vi.mock('@/app/components/base/toast/context', () => ({ useToastContext: () => ({ notify: mockNotify, }), diff --git a/web/app/components/datasets/external-knowledge-base/connector/index.tsx b/web/app/components/datasets/external-knowledge-base/connector/index.tsx index 1545c0d232..cf36eed382 100644 --- a/web/app/components/datasets/external-knowledge-base/connector/index.tsx +++ b/web/app/components/datasets/external-knowledge-base/connector/index.tsx @@ -5,7 +5,7 @@ import { useRouter } from 'next/navigation' import * as React from 'react' import { useState } from 'react' import { trackEvent } from '@/app/components/base/amplitude' -import { useToastContext } from '@/app/components/base/toast' +import { useToastContext } from '@/app/components/base/toast/context' import ExternalKnowledgeBaseCreate from '@/app/components/datasets/external-knowledge-base/create' import { createExternalKnowledgeBase } from '@/service/datasets' diff --git a/web/app/components/header/account-dropdown/workplace-selector/index.spec.tsx b/web/app/components/header/account-dropdown/workplace-selector/index.spec.tsx index fc32b5f8df..20104b572c 100644 --- a/web/app/components/header/account-dropdown/workplace-selector/index.spec.tsx +++ b/web/app/components/header/account-dropdown/workplace-selector/index.spec.tsx @@ -1,7 +1,7 @@ import type { ProviderContextState } from '@/context/provider-context' import type { IWorkspace } from '@/models/common' import { fireEvent, render, screen, waitFor } from '@testing-library/react' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import { baseProviderContextValue, useProviderContext } from '@/context/provider-context' import { useWorkspacesContext } from '@/context/workspace-context' import { switchWorkspace } from '@/service/common' diff --git a/web/app/components/header/account-dropdown/workplace-selector/index.tsx b/web/app/components/header/account-dropdown/workplace-selector/index.tsx index 058935aa27..528686a26a 100644 --- a/web/app/components/header/account-dropdown/workplace-selector/index.tsx +++ b/web/app/components/header/account-dropdown/workplace-selector/index.tsx @@ -4,7 +4,7 @@ import { RiArrowDownSLine } from '@remixicon/react' import { Fragment } from 'react' import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import PlanBadge from '@/app/components/header/plan-badge' import { useWorkspacesContext } from '@/context/workspace-context' import { switchWorkspace } from '@/service/common' diff --git a/web/app/components/header/account-setting/api-based-extension-page/modal.spec.tsx b/web/app/components/header/account-setting/api-based-extension-page/modal.spec.tsx index 3903fbfcf3..884ee8df33 100644 --- a/web/app/components/header/account-setting/api-based-extension-page/modal.spec.tsx +++ b/web/app/components/header/account-setting/api-based-extension-page/modal.spec.tsx @@ -1,8 +1,8 @@ import type { TFunction } from 'i18next' -import type { IToastProps } from '@/app/components/base/toast' +import type { IToastProps } from '@/app/components/base/toast/context' import { fireEvent, render as RTLRender, screen, waitFor } from '@testing-library/react' import * as reactI18next from 'react-i18next' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import { useDocLink } from '@/context/i18n' import { addApiBasedExtension, updateApiBasedExtension } from '@/service/common' import ApiBasedExtensionModal from './modal' diff --git a/web/app/components/header/account-setting/api-based-extension-page/modal.tsx b/web/app/components/header/account-setting/api-based-extension-page/modal.tsx index b04981bf3c..efe6c46dcc 100644 --- a/web/app/components/header/account-setting/api-based-extension-page/modal.tsx +++ b/web/app/components/header/account-setting/api-based-extension-page/modal.tsx @@ -6,7 +6,7 @@ import { useTranslation } from 'react-i18next' import Button from '@/app/components/base/button' import { BookOpen01 } from '@/app/components/base/icons/src/vender/line/education' import Modal from '@/app/components/base/modal' -import { useToastContext } from '@/app/components/base/toast' +import { useToastContext } from '@/app/components/base/toast/context' import { useDocLink } from '@/context/i18n' import { addApiBasedExtension, diff --git a/web/app/components/header/account-setting/language-page/index.tsx b/web/app/components/header/account-setting/language-page/index.tsx index 2a0604421f..5751e88285 100644 --- a/web/app/components/header/account-setting/language-page/index.tsx +++ b/web/app/components/header/account-setting/language-page/index.tsx @@ -7,7 +7,7 @@ import { useState } from 'react' import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' import { SimpleSelect } from '@/app/components/base/select' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import { useAppContext } from '@/context/app-context' import { useLocale } from '@/context/i18n' import { setLocaleOnClient } from '@/i18n-config' diff --git a/web/app/components/header/account-setting/members-page/edit-workspace-modal/index.spec.tsx b/web/app/components/header/account-setting/members-page/edit-workspace-modal/index.spec.tsx index 46ce3f1992..ae0dd8cd4d 100644 --- a/web/app/components/header/account-setting/members-page/edit-workspace-modal/index.spec.tsx +++ b/web/app/components/header/account-setting/members-page/edit-workspace-modal/index.spec.tsx @@ -3,7 +3,7 @@ import type { ICurrentWorkspace } from '@/models/common' import { render, screen, waitFor } from '@testing-library/react' import userEvent from '@testing-library/user-event' import { vi } from 'vitest' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import { useAppContext } from '@/context/app-context' import { updateWorkspaceInfo } from '@/service/common' import EditWorkspaceModal from './index' diff --git a/web/app/components/header/account-setting/members-page/edit-workspace-modal/index.tsx b/web/app/components/header/account-setting/members-page/edit-workspace-modal/index.tsx index a702a83da9..1c3984b0b5 100644 --- a/web/app/components/header/account-setting/members-page/edit-workspace-modal/index.tsx +++ b/web/app/components/header/account-setting/members-page/edit-workspace-modal/index.tsx @@ -6,7 +6,7 @@ import { useContext } from 'use-context-selector' import Button from '@/app/components/base/button' import Input from '@/app/components/base/input' import Modal from '@/app/components/base/modal' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import { useAppContext } from '@/context/app-context' import { updateWorkspaceInfo } from '@/service/common' import { cn } from '@/utils/classnames' diff --git a/web/app/components/header/account-setting/members-page/invite-modal/index.spec.tsx b/web/app/components/header/account-setting/members-page/invite-modal/index.spec.tsx index ef55425ee0..82882c8be5 100644 --- a/web/app/components/header/account-setting/members-page/invite-modal/index.spec.tsx +++ b/web/app/components/header/account-setting/members-page/invite-modal/index.spec.tsx @@ -2,7 +2,7 @@ import type { InvitationResponse } from '@/models/common' import { render, screen, waitFor } from '@testing-library/react' import userEvent from '@testing-library/user-event' import { vi } from 'vitest' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import { useProviderContextSelector } from '@/context/provider-context' import { inviteMember } from '@/service/common' import InviteModal from './index' diff --git a/web/app/components/header/account-setting/members-page/invite-modal/index.tsx b/web/app/components/header/account-setting/members-page/invite-modal/index.tsx index a8c0da40bf..8e4e47e0b8 100644 --- a/web/app/components/header/account-setting/members-page/invite-modal/index.tsx +++ b/web/app/components/header/account-setting/members-page/invite-modal/index.tsx @@ -9,7 +9,7 @@ import { ReactMultiEmail } from 'react-multi-email' import { useContext } from 'use-context-selector' import Button from '@/app/components/base/button' import Modal from '@/app/components/base/modal' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import { emailRegex } from '@/config' import { useLocale } from '@/context/i18n' import { useProviderContextSelector } from '@/context/provider-context' diff --git a/web/app/components/header/account-setting/members-page/operation/index.spec.tsx b/web/app/components/header/account-setting/members-page/operation/index.spec.tsx index 661b2fbc83..e5e7fac10f 100644 --- a/web/app/components/header/account-setting/members-page/operation/index.spec.tsx +++ b/web/app/components/header/account-setting/members-page/operation/index.spec.tsx @@ -2,7 +2,7 @@ import type { Member } from '@/models/common' import { render, screen, waitFor } from '@testing-library/react' import userEvent from '@testing-library/user-event' import { vi } from 'vitest' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import Operation from './index' const mockUpdateMemberRole = vi.fn() diff --git a/web/app/components/header/account-setting/members-page/operation/index.tsx b/web/app/components/header/account-setting/members-page/operation/index.tsx index 88c8e250ea..306a67a093 100644 --- a/web/app/components/header/account-setting/members-page/operation/index.tsx +++ b/web/app/components/header/account-setting/members-page/operation/index.tsx @@ -9,7 +9,7 @@ import { PortalToFollowElemContent, PortalToFollowElemTrigger, } from '@/app/components/base/portal-to-follow-elem' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import { useProviderContext } from '@/context/provider-context' import { deleteMemberOrCancelInvitation, updateMemberRole } from '@/service/common' import { cn } from '@/utils/classnames' diff --git a/web/app/components/header/account-setting/members-page/transfer-ownership-modal/index.spec.tsx b/web/app/components/header/account-setting/members-page/transfer-ownership-modal/index.spec.tsx index 11a0a2db4a..4baa90a7fa 100644 --- a/web/app/components/header/account-setting/members-page/transfer-ownership-modal/index.spec.tsx +++ b/web/app/components/header/account-setting/members-page/transfer-ownership-modal/index.spec.tsx @@ -3,7 +3,7 @@ import type { ICurrentWorkspace } from '@/models/common' import { act, fireEvent, render, screen, waitFor } from '@testing-library/react' import userEvent from '@testing-library/user-event' import { vi } from 'vitest' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import { useAppContext } from '@/context/app-context' import { ownershipTransfer, sendOwnerEmail, verifyOwnerEmail } from '@/service/common' import { useMembers } from '@/service/use-common' diff --git a/web/app/components/header/account-setting/members-page/transfer-ownership-modal/index.tsx b/web/app/components/header/account-setting/members-page/transfer-ownership-modal/index.tsx index 21ea8aa1e9..c4f614737a 100644 --- a/web/app/components/header/account-setting/members-page/transfer-ownership-modal/index.tsx +++ b/web/app/components/header/account-setting/members-page/transfer-ownership-modal/index.tsx @@ -6,7 +6,7 @@ import { useContext } from 'use-context-selector' import Button from '@/app/components/base/button' import Input from '@/app/components/base/input' import Modal from '@/app/components/base/modal' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import { useAppContext } from '@/context/app-context' import { ownershipTransfer, diff --git a/web/app/components/header/account-setting/model-provider-page/model-auth/hooks/use-auth.spec.tsx b/web/app/components/header/account-setting/model-provider-page/model-auth/hooks/use-auth.spec.tsx index c2259f543c..b637fed894 100644 --- a/web/app/components/header/account-setting/model-provider-page/model-auth/hooks/use-auth.spec.tsx +++ b/web/app/components/header/account-setting/model-provider-page/model-auth/hooks/use-auth.spec.tsx @@ -20,7 +20,7 @@ const mockAddModelCredential = vi.fn() const mockEditProviderCredential = vi.fn() const mockEditModelCredential = vi.fn() -vi.mock('@/app/components/base/toast', () => ({ +vi.mock('@/app/components/base/toast/context', () => ({ useToastContext: () => ({ notify: mockNotify }), })) diff --git a/web/app/components/header/account-setting/model-provider-page/model-auth/hooks/use-auth.ts b/web/app/components/header/account-setting/model-provider-page/model-auth/hooks/use-auth.ts index 3576c749b2..4e01677b29 100644 --- a/web/app/components/header/account-setting/model-provider-page/model-auth/hooks/use-auth.ts +++ b/web/app/components/header/account-setting/model-provider-page/model-auth/hooks/use-auth.ts @@ -12,7 +12,7 @@ import { useState, } from 'react' import { useTranslation } from 'react-i18next' -import { useToastContext } from '@/app/components/base/toast' +import { useToastContext } from '@/app/components/base/toast/context' import { useModelModalHandler, useRefreshModel, diff --git a/web/app/components/header/account-setting/model-provider-page/provider-added-card/credential-panel.spec.tsx b/web/app/components/header/account-setting/model-provider-page/provider-added-card/credential-panel.spec.tsx index 554efc93d2..9f493d25e5 100644 --- a/web/app/components/header/account-setting/model-provider-page/provider-added-card/credential-panel.spec.tsx +++ b/web/app/components/header/account-setting/model-provider-page/provider-added-card/credential-panel.spec.tsx @@ -24,7 +24,7 @@ vi.mock('@/config', async (importOriginal) => { } }) -vi.mock('@/app/components/base/toast', () => ({ +vi.mock('@/app/components/base/toast/context', () => ({ useToastContext: () => ({ notify: mockNotify, }), diff --git a/web/app/components/header/account-setting/model-provider-page/provider-added-card/credential-panel.tsx b/web/app/components/header/account-setting/model-provider-page/provider-added-card/credential-panel.tsx index c46f9d56bd..efa768e7f5 100644 --- a/web/app/components/header/account-setting/model-provider-page/provider-added-card/credential-panel.tsx +++ b/web/app/components/header/account-setting/model-provider-page/provider-added-card/credential-panel.tsx @@ -3,7 +3,7 @@ import type { } from '../declarations' import { useMemo } from 'react' import { useTranslation } from 'react-i18next' -import { useToastContext } from '@/app/components/base/toast' +import { useToastContext } from '@/app/components/base/toast/context' import { ConfigProvider } from '@/app/components/header/account-setting/model-provider-page/model-auth' import { useCredentialStatus } from '@/app/components/header/account-setting/model-provider-page/model-auth/hooks' import Indicator from '@/app/components/header/indicator' diff --git a/web/app/components/header/account-setting/model-provider-page/provider-added-card/model-load-balancing-modal.spec.tsx b/web/app/components/header/account-setting/model-provider-page/provider-added-card/model-load-balancing-modal.spec.tsx index ea78234612..b945b50e9b 100644 --- a/web/app/components/header/account-setting/model-provider-page/provider-added-card/model-load-balancing-modal.spec.tsx +++ b/web/app/components/header/account-setting/model-provider-page/provider-added-card/model-load-balancing-modal.spec.tsx @@ -43,7 +43,7 @@ let mockCredentialData: CredentialData | undefined = { current_credential_name: 'Default', } -vi.mock('@/app/components/base/toast', () => ({ +vi.mock('@/app/components/base/toast/context', () => ({ useToastContext: () => ({ notify: mockNotify, }), diff --git a/web/app/components/header/account-setting/model-provider-page/provider-added-card/model-load-balancing-modal.tsx b/web/app/components/header/account-setting/model-provider-page/provider-added-card/model-load-balancing-modal.tsx index 93229f1257..0009237edc 100644 --- a/web/app/components/header/account-setting/model-provider-page/provider-added-card/model-load-balancing-modal.tsx +++ b/web/app/components/header/account-setting/model-provider-page/provider-added-card/model-load-balancing-modal.tsx @@ -12,7 +12,7 @@ import Button from '@/app/components/base/button' import Confirm from '@/app/components/base/confirm' import Loading from '@/app/components/base/loading' import Modal from '@/app/components/base/modal' -import { useToastContext } from '@/app/components/base/toast' +import { useToastContext } from '@/app/components/base/toast/context' import { SwitchCredentialInLoadBalancing } from '@/app/components/header/account-setting/model-provider-page/model-auth' import { useGetModelCredential, diff --git a/web/app/components/header/account-setting/model-provider-page/system-model-selector/index.spec.tsx b/web/app/components/header/account-setting/model-provider-page/system-model-selector/index.spec.tsx index 819bb71164..22186b34e1 100644 --- a/web/app/components/header/account-setting/model-provider-page/system-model-selector/index.spec.tsx +++ b/web/app/components/header/account-setting/model-provider-page/system-model-selector/index.spec.tsx @@ -42,7 +42,7 @@ vi.mock('@/context/provider-context', () => ({ }), })) -vi.mock('@/app/components/base/toast', () => ({ +vi.mock('@/app/components/base/toast/context', () => ({ useToastContext: () => ({ notify: mockNotify, }), diff --git a/web/app/components/header/account-setting/model-provider-page/system-model-selector/index.tsx b/web/app/components/header/account-setting/model-provider-page/system-model-selector/index.tsx index 29c71e04fc..5df062789b 100644 --- a/web/app/components/header/account-setting/model-provider-page/system-model-selector/index.tsx +++ b/web/app/components/header/account-setting/model-provider-page/system-model-selector/index.tsx @@ -12,7 +12,7 @@ import { PortalToFollowElemContent, PortalToFollowElemTrigger, } from '@/app/components/base/portal-to-follow-elem' -import { useToastContext } from '@/app/components/base/toast' +import { useToastContext } from '@/app/components/base/toast/context' import Tooltip from '@/app/components/base/tooltip' import { useAppContext } from '@/context/app-context' import { useProviderContext } from '@/context/provider-context' diff --git a/web/app/components/header/account-setting/plugin-page/SerpapiPlugin.spec.tsx b/web/app/components/header/account-setting/plugin-page/SerpapiPlugin.spec.tsx index de480d47a1..97a79815ff 100644 --- a/web/app/components/header/account-setting/plugin-page/SerpapiPlugin.spec.tsx +++ b/web/app/components/header/account-setting/plugin-page/SerpapiPlugin.spec.tsx @@ -1,6 +1,6 @@ import type { PluginProvider } from '@/models/common' import { act, fireEvent, render, screen, waitFor } from '@testing-library/react' -import { useToastContext } from '@/app/components/base/toast' +import { useToastContext } from '@/app/components/base/toast/context' import { useAppContext } from '@/context/app-context' import SerpapiPlugin from './SerpapiPlugin' import { updatePluginKey, validatePluginKey } from './utils' @@ -20,7 +20,7 @@ const mockEventEmitter = vi.hoisted(() => { } }) -vi.mock('@/app/components/base/toast', () => ({ +vi.mock('@/app/components/base/toast/context', () => ({ useToastContext: vi.fn(), })) diff --git a/web/app/components/header/account-setting/plugin-page/SerpapiPlugin.tsx b/web/app/components/header/account-setting/plugin-page/SerpapiPlugin.tsx index f6909fad28..fe8832e84b 100644 --- a/web/app/components/header/account-setting/plugin-page/SerpapiPlugin.tsx +++ b/web/app/components/header/account-setting/plugin-page/SerpapiPlugin.tsx @@ -2,7 +2,7 @@ import type { Form, ValidateValue } from '../key-validator/declarations' import type { PluginProvider } from '@/models/common' import Image from 'next/image' import { useTranslation } from 'react-i18next' -import { useToastContext } from '@/app/components/base/toast' +import { useToastContext } from '@/app/components/base/toast/context' import { useAppContext } from '@/context/app-context' import SerpapiLogo from '../../assets/serpapi.png' import KeyValidator from '../key-validator' diff --git a/web/app/components/header/account-setting/plugin-page/index.spec.tsx b/web/app/components/header/account-setting/plugin-page/index.spec.tsx index 654292443f..0654bb68aa 100644 --- a/web/app/components/header/account-setting/plugin-page/index.spec.tsx +++ b/web/app/components/header/account-setting/plugin-page/index.spec.tsx @@ -14,7 +14,7 @@ vi.mock('@/context/app-context', () => ({ useAppContext: vi.fn(), })) -vi.mock('@/app/components/base/toast', () => ({ +vi.mock('@/app/components/base/toast/context', () => ({ useToastContext: () => ({ notify: vi.fn(), }), diff --git a/web/app/components/header/index.spec.tsx b/web/app/components/header/index.spec.tsx index 2d5ce50fb8..36c85e6f08 100644 --- a/web/app/components/header/index.spec.tsx +++ b/web/app/components/header/index.spec.tsx @@ -52,7 +52,7 @@ vi.mock('@/app/components/header/plan-badge', () => ({ ), })) -vi.mock('@/context/workspace-context', () => ({ +vi.mock('@/context/workspace-context-provider', () => ({ WorkspaceProvider: ({ children }: { children?: React.ReactNode }) => children, })) diff --git a/web/app/components/header/index.tsx b/web/app/components/header/index.tsx index 210c62b660..64498a82f8 100644 --- a/web/app/components/header/index.tsx +++ b/web/app/components/header/index.tsx @@ -8,7 +8,7 @@ import { useAppContext } from '@/context/app-context' import { useGlobalPublicStore } from '@/context/global-public-context' import { useModalContext } from '@/context/modal-context' import { useProviderContext } from '@/context/provider-context' -import { WorkspaceProvider } from '@/context/workspace-context' +import { WorkspaceProvider } from '@/context/workspace-context-provider' import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints' import { Plan } from '../billing/type' import AccountDropdown from './account-dropdown' diff --git a/web/app/components/plugins/plugin-auth/__tests__/authorized-in-node.spec.tsx b/web/app/components/plugins/plugin-auth/__tests__/authorized-in-node.spec.tsx index 7e8208b995..91ffbaa24a 100644 --- a/web/app/components/plugins/plugin-auth/__tests__/authorized-in-node.spec.tsx +++ b/web/app/components/plugins/plugin-auth/__tests__/authorized-in-node.spec.tsx @@ -42,7 +42,7 @@ vi.mock('@/context/app-context', () => ({ }), })) -vi.mock('@/app/components/base/toast', () => ({ +vi.mock('@/app/components/base/toast/context', () => ({ useToastContext: () => ({ notify: vi.fn() }), })) diff --git a/web/app/components/plugins/plugin-auth/__tests__/plugin-auth-in-agent.spec.tsx b/web/app/components/plugins/plugin-auth/__tests__/plugin-auth-in-agent.spec.tsx index 6b66aca9dd..901c3ab49a 100644 --- a/web/app/components/plugins/plugin-auth/__tests__/plugin-auth-in-agent.spec.tsx +++ b/web/app/components/plugins/plugin-auth/__tests__/plugin-auth-in-agent.spec.tsx @@ -42,7 +42,7 @@ vi.mock('@/context/app-context', () => ({ }), })) -vi.mock('@/app/components/base/toast', () => ({ +vi.mock('@/app/components/base/toast/context', () => ({ useToastContext: () => ({ notify: vi.fn() }), })) diff --git a/web/app/components/plugins/plugin-auth/authorize/__tests__/api-key-modal.spec.tsx b/web/app/components/plugins/plugin-auth/authorize/__tests__/api-key-modal.spec.tsx index a99b3363d6..430149e50b 100644 --- a/web/app/components/plugins/plugin-auth/authorize/__tests__/api-key-modal.spec.tsx +++ b/web/app/components/plugins/plugin-auth/authorize/__tests__/api-key-modal.spec.tsx @@ -9,7 +9,7 @@ const mockAddPluginCredential = vi.fn().mockResolvedValue({}) const mockUpdatePluginCredential = vi.fn().mockResolvedValue({}) const mockFormValues = { isCheckValidated: true, values: { __name__: 'My Key', api_key: 'sk-123' } } -vi.mock('@/app/components/base/toast', () => ({ +vi.mock('@/app/components/base/toast/context', () => ({ useToastContext: () => ({ notify: mockNotify, }), diff --git a/web/app/components/plugins/plugin-auth/authorize/__tests__/authorize-components.spec.tsx b/web/app/components/plugins/plugin-auth/authorize/__tests__/authorize-components.spec.tsx index 51aa287fea..d120902e6d 100644 --- a/web/app/components/plugins/plugin-auth/authorize/__tests__/authorize-components.spec.tsx +++ b/web/app/components/plugins/plugin-auth/authorize/__tests__/authorize-components.spec.tsx @@ -95,7 +95,7 @@ vi.mock('@/app/components/base/form/form-scenarios/auth', () => ({ // Mock useToastContext const mockNotify = vi.fn() -vi.mock('@/app/components/base/toast', () => ({ +vi.mock('@/app/components/base/toast/context', () => ({ useToastContext: () => ({ notify: mockNotify }), })) diff --git a/web/app/components/plugins/plugin-auth/authorize/__tests__/oauth-client-settings.spec.tsx b/web/app/components/plugins/plugin-auth/authorize/__tests__/oauth-client-settings.spec.tsx index 61920e2869..f1b86f80ea 100644 --- a/web/app/components/plugins/plugin-auth/authorize/__tests__/oauth-client-settings.spec.tsx +++ b/web/app/components/plugins/plugin-auth/authorize/__tests__/oauth-client-settings.spec.tsx @@ -9,7 +9,7 @@ const mockDeletePluginOAuthCustomClient = vi.fn().mockResolvedValue({}) const mockInvalidPluginOAuthClientSchema = vi.fn() const mockFormValues = { isCheckValidated: true, values: { __oauth_client__: 'custom', client_id: 'test-id' } } -vi.mock('@/app/components/base/toast', () => ({ +vi.mock('@/app/components/base/toast/context', () => ({ useToastContext: () => ({ notify: mockNotify, }), diff --git a/web/app/components/plugins/plugin-auth/authorize/api-key-modal.tsx b/web/app/components/plugins/plugin-auth/authorize/api-key-modal.tsx index cc98ca3731..d8ab1cafdd 100644 --- a/web/app/components/plugins/plugin-auth/authorize/api-key-modal.tsx +++ b/web/app/components/plugins/plugin-auth/authorize/api-key-modal.tsx @@ -16,7 +16,7 @@ import AuthForm from '@/app/components/base/form/form-scenarios/auth' import { FormTypeEnum } from '@/app/components/base/form/types' import Loading from '@/app/components/base/loading' import Modal from '@/app/components/base/modal/modal' -import { useToastContext } from '@/app/components/base/toast' +import { useToastContext } from '@/app/components/base/toast/context' import { ReadmeEntrance } from '../../readme-panel/entrance' import { ReadmeShowType } from '../../readme-panel/store' import { diff --git a/web/app/components/plugins/plugin-auth/authorize/oauth-client-settings.tsx b/web/app/components/plugins/plugin-auth/authorize/oauth-client-settings.tsx index 7eb22ee4ac..28989da77c 100644 --- a/web/app/components/plugins/plugin-auth/authorize/oauth-client-settings.tsx +++ b/web/app/components/plugins/plugin-auth/authorize/oauth-client-settings.tsx @@ -17,7 +17,7 @@ import { useTranslation } from 'react-i18next' import Button from '@/app/components/base/button' import AuthForm from '@/app/components/base/form/form-scenarios/auth' import Modal from '@/app/components/base/modal/modal' -import { useToastContext } from '@/app/components/base/toast' +import { useToastContext } from '@/app/components/base/toast/context' import { ReadmeEntrance } from '../../readme-panel/entrance' import { ReadmeShowType } from '../../readme-panel/store' import { diff --git a/web/app/components/plugins/plugin-auth/authorized/__tests__/index.spec.tsx b/web/app/components/plugins/plugin-auth/authorized/__tests__/index.spec.tsx index f56c814222..a617c2543e 100644 --- a/web/app/components/plugins/plugin-auth/authorized/__tests__/index.spec.tsx +++ b/web/app/components/plugins/plugin-auth/authorized/__tests__/index.spec.tsx @@ -52,7 +52,7 @@ vi.mock('../../hooks/use-credential', () => ({ // Mock toast context const mockNotify = vi.fn() -vi.mock('@/app/components/base/toast', () => ({ +vi.mock('@/app/components/base/toast/context', () => ({ useToastContext: () => ({ notify: mockNotify, }), diff --git a/web/app/components/plugins/plugin-auth/authorized/index.tsx b/web/app/components/plugins/plugin-auth/authorized/index.tsx index ea58cd16c9..daee3131ad 100644 --- a/web/app/components/plugins/plugin-auth/authorized/index.tsx +++ b/web/app/components/plugins/plugin-auth/authorized/index.tsx @@ -19,7 +19,7 @@ import { PortalToFollowElemContent, PortalToFollowElemTrigger, } from '@/app/components/base/portal-to-follow-elem' -import { useToastContext } from '@/app/components/base/toast' +import { useToastContext } from '@/app/components/base/toast/context' import Indicator from '@/app/components/header/indicator' import { cn } from '@/utils/classnames' import Authorize from '../authorize' diff --git a/web/app/components/plugins/plugin-auth/hooks/__tests__/use-plugin-auth-action.spec.ts b/web/app/components/plugins/plugin-auth/hooks/__tests__/use-plugin-auth-action.spec.ts index d31b29ab85..f779623697 100644 --- a/web/app/components/plugins/plugin-auth/hooks/__tests__/use-plugin-auth-action.spec.ts +++ b/web/app/components/plugins/plugin-auth/hooks/__tests__/use-plugin-auth-action.spec.ts @@ -11,7 +11,7 @@ const mockSetPluginDefaultCredential = vi.fn().mockResolvedValue({}) const mockUpdatePluginCredential = vi.fn().mockResolvedValue({}) const mockNotify = vi.fn() -vi.mock('@/app/components/base/toast', () => ({ +vi.mock('@/app/components/base/toast/context', () => ({ useToastContext: () => ({ notify: mockNotify, }), diff --git a/web/app/components/plugins/plugin-auth/hooks/use-plugin-auth-action.ts b/web/app/components/plugins/plugin-auth/hooks/use-plugin-auth-action.ts index e9218e2d3d..5628c76cc3 100644 --- a/web/app/components/plugins/plugin-auth/hooks/use-plugin-auth-action.ts +++ b/web/app/components/plugins/plugin-auth/hooks/use-plugin-auth-action.ts @@ -5,7 +5,7 @@ import { useState, } from 'react' import { useTranslation } from 'react-i18next' -import { useToastContext } from '@/app/components/base/toast' +import { useToastContext } from '@/app/components/base/toast/context' import { useDeletePluginCredentialHook, useSetPluginDefaultCredentialHook, diff --git a/web/app/components/plugins/plugin-detail-panel/subscription-list/edit/__tests__/apikey-edit-modal.spec.tsx b/web/app/components/plugins/plugin-detail-panel/subscription-list/edit/__tests__/apikey-edit-modal.spec.tsx index af145df2da..e0fb7455ce 100644 --- a/web/app/components/plugins/plugin-detail-panel/subscription-list/edit/__tests__/apikey-edit-modal.spec.tsx +++ b/web/app/components/plugins/plugin-detail-panel/subscription-list/edit/__tests__/apikey-edit-modal.spec.tsx @@ -54,13 +54,16 @@ vi.mock('@/app/components/base/toast', async (importOriginal) => { default: { notify: (args: { type: string, message: string }) => mockToast(args), }, - useToastContext: () => ({ - notify: (args: { type: string, message: string }) => mockToast(args), - close: vi.fn(), - }), } }) +vi.mock('@/app/components/base/toast/context', () => ({ + useToastContext: () => ({ + notify: (args: { type: string, message: string }) => mockToast(args), + close: vi.fn(), + }), +})) + const createSubscription = (overrides: Partial = {}): TriggerSubscription => ({ id: 'sub-1', name: 'Subscription One', diff --git a/web/app/components/plugins/plugin-detail-panel/subscription-list/edit/__tests__/manual-edit-modal.spec.tsx b/web/app/components/plugins/plugin-detail-panel/subscription-list/edit/__tests__/manual-edit-modal.spec.tsx index c6144542ab..60a8428287 100644 --- a/web/app/components/plugins/plugin-detail-panel/subscription-list/edit/__tests__/manual-edit-modal.spec.tsx +++ b/web/app/components/plugins/plugin-detail-panel/subscription-list/edit/__tests__/manual-edit-modal.spec.tsx @@ -37,13 +37,16 @@ vi.mock('@/app/components/base/toast', async (importOriginal) => { default: { notify: (args: { type: string, message: string }) => mockToast(args), }, - useToastContext: () => ({ - notify: (args: { type: string, message: string }) => mockToast(args), - close: vi.fn(), - }), } }) +vi.mock('@/app/components/base/toast/context', () => ({ + useToastContext: () => ({ + notify: (args: { type: string, message: string }) => mockToast(args), + close: vi.fn(), + }), +})) + const createSubscription = (overrides: Partial = {}): TriggerSubscription => ({ id: 'sub-1', name: 'Subscription One', diff --git a/web/app/components/plugins/plugin-detail-panel/subscription-list/edit/__tests__/oauth-edit-modal.spec.tsx b/web/app/components/plugins/plugin-detail-panel/subscription-list/edit/__tests__/oauth-edit-modal.spec.tsx index 7bdcdbc936..8835b46695 100644 --- a/web/app/components/plugins/plugin-detail-panel/subscription-list/edit/__tests__/oauth-edit-modal.spec.tsx +++ b/web/app/components/plugins/plugin-detail-panel/subscription-list/edit/__tests__/oauth-edit-modal.spec.tsx @@ -37,13 +37,16 @@ vi.mock('@/app/components/base/toast', async (importOriginal) => { default: { notify: (args: { type: string, message: string }) => mockToast(args), }, - useToastContext: () => ({ - notify: (args: { type: string, message: string }) => mockToast(args), - close: vi.fn(), - }), } }) +vi.mock('@/app/components/base/toast/context', () => ({ + useToastContext: () => ({ + notify: (args: { type: string, message: string }) => mockToast(args), + close: vi.fn(), + }), +})) + const createSubscription = (overrides: Partial = {}): TriggerSubscription => ({ id: 'sub-1', name: 'Subscription One', diff --git a/web/app/components/plugins/plugin-detail-panel/tool-selector/components/__tests__/tool-credentials-form.spec.tsx b/web/app/components/plugins/plugin-detail-panel/tool-selector/components/__tests__/tool-credentials-form.spec.tsx index 20655d0139..cb5b929d29 100644 --- a/web/app/components/plugins/plugin-detail-panel/tool-selector/components/__tests__/tool-credentials-form.spec.tsx +++ b/web/app/components/plugins/plugin-detail-panel/tool-selector/components/__tests__/tool-credentials-form.spec.tsx @@ -12,6 +12,9 @@ vi.mock('@/utils/classnames', () => ({ vi.mock('@/app/components/base/toast', () => ({ default: { notify: vi.fn() }, +})) + +vi.mock('@/app/components/base/toast/context', () => ({ useToastContext: () => ({ notify: vi.fn() }), })) diff --git a/web/app/components/plugins/plugin-page/__tests__/context.spec.tsx b/web/app/components/plugins/plugin-page/__tests__/context.spec.tsx index ab0e35e042..389c161e8a 100644 --- a/web/app/components/plugins/plugin-page/__tests__/context.spec.tsx +++ b/web/app/components/plugins/plugin-page/__tests__/context.spec.tsx @@ -3,7 +3,8 @@ import { beforeEach, describe, expect, it, vi } from 'vitest' // Import mocks import { useGlobalPublicStore } from '@/context/global-public-context' -import { PluginPageContext, PluginPageContextProvider, usePluginPageContext } from '../context' +import { PluginPageContext, usePluginPageContext } from '../context' +import { PluginPageContextProvider } from '../context-provider' // Mock dependencies vi.mock('nuqs', () => ({ diff --git a/web/app/components/plugins/plugin-page/context.tsx b/web/app/components/plugins/plugin-page/context-provider.tsx similarity index 58% rename from web/app/components/plugins/plugin-page/context.tsx rename to web/app/components/plugins/plugin-page/context-provider.tsx index 01ec518347..83776b48f9 100644 --- a/web/app/components/plugins/plugin-page/context.tsx +++ b/web/app/components/plugins/plugin-page/context-provider.tsx @@ -1,24 +1,20 @@ 'use client' -import type { ReactNode, RefObject } from 'react' +import type { ReactNode } from 'react' +import type { PluginPageTab } from './context' import type { FilterState } from './filter-management' -import { noop } from 'es-toolkit/function' import { parseAsStringEnum, useQueryState } from 'nuqs' import { useMemo, useRef, useState, } from 'react' -import { - createContext, - useContextSelector, -} from 'use-context-selector' import { useGlobalPublicStore } from '@/context/global-public-context' import { PLUGIN_PAGE_TABS_MAP, usePluginPageTabs } from '../hooks' import { PLUGIN_TYPE_SEARCH_MAP } from '../marketplace/constants' - -export type PluginPageTab = typeof PLUGIN_PAGE_TABS_MAP[keyof typeof PLUGIN_PAGE_TABS_MAP] - | (typeof PLUGIN_TYPE_SEARCH_MAP)[keyof typeof PLUGIN_TYPE_SEARCH_MAP] +import { + PluginPageContext, +} from './context' const PLUGIN_PAGE_TAB_VALUES: PluginPageTab[] = [ PLUGIN_PAGE_TABS_MAP.plugins, @@ -29,42 +25,10 @@ const PLUGIN_PAGE_TAB_VALUES: PluginPageTab[] = [ const parseAsPluginPageTab = parseAsStringEnum(PLUGIN_PAGE_TAB_VALUES) .withDefault(PLUGIN_PAGE_TABS_MAP.plugins) -export type PluginPageContextValue = { - containerRef: RefObject - currentPluginID: string | undefined - setCurrentPluginID: (pluginID?: string) => void - filters: FilterState - setFilters: (filter: FilterState) => void - activeTab: PluginPageTab - setActiveTab: (tab: PluginPageTab) => void - options: Array<{ value: string, text: string }> -} - -const emptyContainerRef: RefObject = { current: null } - -export const PluginPageContext = createContext({ - containerRef: emptyContainerRef, - currentPluginID: undefined, - setCurrentPluginID: noop, - filters: { - categories: [], - tags: [], - searchQuery: '', - }, - setFilters: noop, - activeTab: PLUGIN_PAGE_TABS_MAP.plugins, - setActiveTab: noop, - options: [], -}) - type PluginPageContextProviderProps = { children: ReactNode } -export function usePluginPageContext(selector: (value: PluginPageContextValue) => any) { - return useContextSelector(PluginPageContext, selector) -} - export const PluginPageContextProvider = ({ children, }: PluginPageContextProviderProps) => { diff --git a/web/app/components/plugins/plugin-page/context.ts b/web/app/components/plugins/plugin-page/context.ts new file mode 100644 index 0000000000..04ab1eef19 --- /dev/null +++ b/web/app/components/plugins/plugin-page/context.ts @@ -0,0 +1,46 @@ +'use client' + +import type { RefObject } from 'react' +import type { PLUGIN_TYPE_SEARCH_MAP } from '../marketplace/constants' +import type { FilterState } from './filter-management' +import { noop } from 'es-toolkit/function' +import { + createContext, + useContextSelector, +} from 'use-context-selector' +import { PLUGIN_PAGE_TABS_MAP } from '../hooks' + +export type PluginPageTab = typeof PLUGIN_PAGE_TABS_MAP[keyof typeof PLUGIN_PAGE_TABS_MAP] + | (typeof PLUGIN_TYPE_SEARCH_MAP)[keyof typeof PLUGIN_TYPE_SEARCH_MAP] + +export type PluginPageContextValue = { + containerRef: RefObject + currentPluginID: string | undefined + setCurrentPluginID: (pluginID?: string) => void + filters: FilterState + setFilters: (filter: FilterState) => void + activeTab: PluginPageTab + setActiveTab: (tab: PluginPageTab) => void + options: Array<{ value: string, text: string }> +} + +const emptyContainerRef: RefObject = { current: null } + +export const PluginPageContext = createContext({ + containerRef: emptyContainerRef, + currentPluginID: undefined, + setCurrentPluginID: noop, + filters: { + categories: [], + tags: [], + searchQuery: '', + }, + setFilters: noop, + activeTab: PLUGIN_PAGE_TABS_MAP.plugins, + setActiveTab: noop, + options: [], +}) + +export function usePluginPageContext(selector: (value: PluginPageContextValue) => any) { + return useContextSelector(PluginPageContext, selector) +} diff --git a/web/app/components/plugins/plugin-page/index.tsx b/web/app/components/plugins/plugin-page/index.tsx index bec1eb60ab..6768361acf 100644 --- a/web/app/components/plugins/plugin-page/index.tsx +++ b/web/app/components/plugins/plugin-page/index.tsx @@ -28,10 +28,8 @@ import { PLUGIN_PAGE_TABS_MAP } from '../hooks' import InstallFromLocalPackage from '../install-plugin/install-from-local-package' import InstallFromMarketplace from '../install-plugin/install-from-marketplace' import { PLUGIN_TYPE_SEARCH_MAP } from '../marketplace/constants' -import { - PluginPageContextProvider, - usePluginPageContext, -} from './context' +import { usePluginPageContext } from './context' +import { PluginPageContextProvider } from './context-provider' import DebugInfo from './debug-info' import InstallPluginDropdown from './install-plugin-dropdown' import PluginTasks from './plugin-tasks' diff --git a/web/app/components/rag-pipeline/components/__tests__/update-dsl-modal.spec.tsx b/web/app/components/rag-pipeline/components/__tests__/update-dsl-modal.spec.tsx index 2f9b2172bd..98ad5f78f4 100644 --- a/web/app/components/rag-pipeline/components/__tests__/update-dsl-modal.spec.tsx +++ b/web/app/components/rag-pipeline/components/__tests__/update-dsl-modal.spec.tsx @@ -20,7 +20,7 @@ vi.mock('use-context-selector', () => ({ useContext: () => ({ notify: mockNotify }), })) -vi.mock('@/app/components/base/toast', () => ({ +vi.mock('@/app/components/base/toast/context', () => ({ ToastContext: { Provider: ({ children }: PropsWithChildren) => children }, })) diff --git a/web/app/components/rag-pipeline/components/rag-pipeline-header/__tests__/index.spec.tsx b/web/app/components/rag-pipeline/components/rag-pipeline-header/__tests__/index.spec.tsx index 0b858eaaa7..00c989acb0 100644 --- a/web/app/components/rag-pipeline/components/rag-pipeline-header/__tests__/index.spec.tsx +++ b/web/app/components/rag-pipeline/components/rag-pipeline-header/__tests__/index.spec.tsx @@ -155,7 +155,7 @@ vi.mock('@/app/components/base/amplitude', () => ({ })) const mockNotify = vi.fn() -vi.mock('@/app/components/base/toast', () => ({ +vi.mock('@/app/components/base/toast/context', () => ({ useToastContext: () => ({ notify: mockNotify, }), diff --git a/web/app/components/rag-pipeline/components/rag-pipeline-header/publisher/__tests__/index.spec.tsx b/web/app/components/rag-pipeline/components/rag-pipeline-header/publisher/__tests__/index.spec.tsx index 6129d3fe73..9cd1af2736 100644 --- a/web/app/components/rag-pipeline/components/rag-pipeline-header/publisher/__tests__/index.spec.tsx +++ b/web/app/components/rag-pipeline/components/rag-pipeline-header/publisher/__tests__/index.spec.tsx @@ -3,7 +3,7 @@ import { QueryClient, QueryClientProvider } from '@tanstack/react-query' import { fireEvent, render, screen, waitFor } from '@testing-library/react' import * as React from 'react' import { beforeEach, describe, expect, it, vi } from 'vitest' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import Publisher from '../index' import Popup from '../popup' diff --git a/web/app/components/rag-pipeline/components/rag-pipeline-header/publisher/__tests__/popup.spec.tsx b/web/app/components/rag-pipeline/components/rag-pipeline-header/publisher/__tests__/popup.spec.tsx index 71707721a4..48282820d8 100644 --- a/web/app/components/rag-pipeline/components/rag-pipeline-header/publisher/__tests__/popup.spec.tsx +++ b/web/app/components/rag-pipeline/components/rag-pipeline-header/publisher/__tests__/popup.spec.tsx @@ -57,7 +57,7 @@ vi.mock('@/app/components/workflow/store', () => ({ }), })) -vi.mock('@/app/components/base/toast', () => ({ +vi.mock('@/app/components/base/toast/context', () => ({ useToastContext: () => ({ notify: mockNotify }), })) diff --git a/web/app/components/rag-pipeline/components/rag-pipeline-header/publisher/popup.tsx b/web/app/components/rag-pipeline/components/rag-pipeline-header/publisher/popup.tsx index 2dd56b4277..371e4fd721 100644 --- a/web/app/components/rag-pipeline/components/rag-pipeline-header/publisher/popup.tsx +++ b/web/app/components/rag-pipeline/components/rag-pipeline-header/publisher/popup.tsx @@ -24,7 +24,7 @@ import Confirm from '@/app/components/base/confirm' import Divider from '@/app/components/base/divider' import { SparklesSoft } from '@/app/components/base/icons/src/public/common' import PremiumBadge from '@/app/components/base/premium-badge' -import { useToastContext } from '@/app/components/base/toast' +import { useToastContext } from '@/app/components/base/toast/context' import { useChecklistBeforePublish, } from '@/app/components/workflow/hooks' diff --git a/web/app/components/rag-pipeline/hooks/__tests__/index.spec.ts b/web/app/components/rag-pipeline/hooks/__tests__/index.spec.ts index 4c60e5133c..95fe763d61 100644 --- a/web/app/components/rag-pipeline/hooks/__tests__/index.spec.ts +++ b/web/app/components/rag-pipeline/hooks/__tests__/index.spec.ts @@ -31,7 +31,7 @@ vi.mock('@/app/components/workflow/store', () => ({ })) const mockNotify = vi.fn() -vi.mock('@/app/components/base/toast', () => ({ +vi.mock('@/app/components/base/toast/context', () => ({ useToastContext: () => ({ notify: mockNotify, }), diff --git a/web/app/components/rag-pipeline/hooks/__tests__/use-DSL.spec.ts b/web/app/components/rag-pipeline/hooks/__tests__/use-DSL.spec.ts index c0b983052d..c5603f2fc7 100644 --- a/web/app/components/rag-pipeline/hooks/__tests__/use-DSL.spec.ts +++ b/web/app/components/rag-pipeline/hooks/__tests__/use-DSL.spec.ts @@ -3,7 +3,7 @@ import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' import { useDSL } from '../use-DSL' const mockNotify = vi.fn() -vi.mock('@/app/components/base/toast', () => ({ +vi.mock('@/app/components/base/toast/context', () => ({ useToastContext: () => ({ notify: mockNotify }), })) diff --git a/web/app/components/rag-pipeline/hooks/__tests__/use-update-dsl-modal.spec.ts b/web/app/components/rag-pipeline/hooks/__tests__/use-update-dsl-modal.spec.ts index 942e337ad8..a50965fb3b 100644 --- a/web/app/components/rag-pipeline/hooks/__tests__/use-update-dsl-modal.spec.ts +++ b/web/app/components/rag-pipeline/hooks/__tests__/use-update-dsl-modal.spec.ts @@ -23,7 +23,7 @@ vi.mock('use-context-selector', () => ({ useContext: () => ({ notify: mockNotify }), })) -vi.mock('@/app/components/base/toast', () => ({ +vi.mock('@/app/components/base/toast/context', () => ({ ToastContext: {}, })) diff --git a/web/app/components/rag-pipeline/hooks/use-DSL.ts b/web/app/components/rag-pipeline/hooks/use-DSL.ts index 5c0f9def1c..f45cf35bdf 100644 --- a/web/app/components/rag-pipeline/hooks/use-DSL.ts +++ b/web/app/components/rag-pipeline/hooks/use-DSL.ts @@ -3,7 +3,7 @@ import { useState, } from 'react' import { useTranslation } from 'react-i18next' -import { useToastContext } from '@/app/components/base/toast' +import { useToastContext } from '@/app/components/base/toast/context' import { DSL_EXPORT_CHECK, } from '@/app/components/workflow/constants' diff --git a/web/app/components/rag-pipeline/hooks/use-update-dsl-modal.ts b/web/app/components/rag-pipeline/hooks/use-update-dsl-modal.ts index 3b86937417..48d2bc78b4 100644 --- a/web/app/components/rag-pipeline/hooks/use-update-dsl-modal.ts +++ b/web/app/components/rag-pipeline/hooks/use-update-dsl-modal.ts @@ -6,7 +6,7 @@ import { } from 'react' import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import { WORKFLOW_DATA_UPDATE } from '@/app/components/workflow/constants' import { usePluginDependencies } from '@/app/components/workflow/plugin-dependency/hooks' import { useWorkflowStore } from '@/app/components/workflow/store' diff --git a/web/app/components/workflow-app/components/workflow-header/__tests__/features-trigger.spec.tsx b/web/app/components/workflow-app/components/workflow-header/__tests__/features-trigger.spec.tsx index 4a7fd1275f..2318a1c7bc 100644 --- a/web/app/components/workflow-app/components/workflow-header/__tests__/features-trigger.spec.tsx +++ b/web/app/components/workflow-app/components/workflow-header/__tests__/features-trigger.spec.tsx @@ -4,7 +4,7 @@ import type { App } from '@/types/app' import { render, screen, waitFor } from '@testing-library/react' import userEvent from '@testing-library/user-event' import { useStore as useAppStore } from '@/app/components/app/store' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import { Plan } from '@/app/components/billing/type' import { BlockEnum, InputVarType } from '@/app/components/workflow/types' import FeaturesTrigger from '../features-trigger' diff --git a/web/app/components/workflow-app/components/workflow-header/features-trigger.tsx b/web/app/components/workflow-app/components/workflow-header/features-trigger.tsx index d58eb6c669..84603e9a13 100644 --- a/web/app/components/workflow-app/components/workflow-header/features-trigger.tsx +++ b/web/app/components/workflow-app/components/workflow-header/features-trigger.tsx @@ -17,7 +17,7 @@ import AppPublisher from '@/app/components/app/app-publisher' import { useStore as useAppStore } from '@/app/components/app/store' import Button from '@/app/components/base/button' import { useFeatures } from '@/app/components/base/features/hooks' -import { useToastContext } from '@/app/components/base/toast' +import { useToastContext } from '@/app/components/base/toast/context' import { Plan } from '@/app/components/billing/type' import { useChecklist, diff --git a/web/app/components/workflow-app/hooks/use-DSL.ts b/web/app/components/workflow-app/hooks/use-DSL.ts index 939e43b554..918a60f185 100644 --- a/web/app/components/workflow-app/hooks/use-DSL.ts +++ b/web/app/components/workflow-app/hooks/use-DSL.ts @@ -4,7 +4,7 @@ import { } from 'react' import { useTranslation } from 'react-i18next' import { useStore as useAppStore } from '@/app/components/app/store' -import { useToastContext } from '@/app/components/base/toast' +import { useToastContext } from '@/app/components/base/toast/context' import { DSL_EXPORT_CHECK, } from '@/app/components/workflow/constants' diff --git a/web/app/components/workflow/header/run-mode.tsx b/web/app/components/workflow/header/run-mode.tsx index 63c48ff0dc..943af13a92 100644 --- a/web/app/components/workflow/header/run-mode.tsx +++ b/web/app/components/workflow/header/run-mode.tsx @@ -5,7 +5,7 @@ import { useCallback, useEffect, useRef } from 'react' import { useTranslation } from 'react-i18next' import { trackEvent } from '@/app/components/base/amplitude' import { StopCircle } from '@/app/components/base/icons/src/vender/line/mediaAndDevices' -import { useToastContext } from '@/app/components/base/toast' +import { useToastContext } from '@/app/components/base/toast/context' import { useWorkflowRun, useWorkflowRunValidation, useWorkflowStartRun } from '@/app/components/workflow/hooks' import ShortcutsName from '@/app/components/workflow/shortcuts-name' import { useStore } from '@/app/components/workflow/store' diff --git a/web/app/components/workflow/hooks/__tests__/use-checklist.spec.ts b/web/app/components/workflow/hooks/__tests__/use-checklist.spec.ts index d72d001e0b..1b37055134 100644 --- a/web/app/components/workflow/hooks/__tests__/use-checklist.spec.ts +++ b/web/app/components/workflow/hooks/__tests__/use-checklist.spec.ts @@ -82,7 +82,7 @@ vi.mock('../index', () => ({ useNodesMetaData: () => ({ nodes: [], nodesMap: mockNodesMap }), })) -vi.mock('@/app/components/base/toast', () => ({ +vi.mock('@/app/components/base/toast/context', () => ({ useToastContext: () => ({ notify: vi.fn() }), })) diff --git a/web/app/components/workflow/hooks/use-checklist.ts b/web/app/components/workflow/hooks/use-checklist.ts index 642179aed7..dd1d012aed 100644 --- a/web/app/components/workflow/hooks/use-checklist.ts +++ b/web/app/components/workflow/hooks/use-checklist.ts @@ -22,7 +22,7 @@ import { import { useTranslation } from 'react-i18next' import { useEdges, useStoreApi } from 'reactflow' import { useStore as useAppStore } from '@/app/components/app/store' -import { useToastContext } from '@/app/components/base/toast' +import { useToastContext } from '@/app/components/base/toast/context' import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' import { useModelList } from '@/app/components/header/account-setting/model-provider-page/hooks' import useNodes from '@/app/components/workflow/store/workflow/use-nodes' diff --git a/web/app/components/workflow/note-node/note-editor/plugins/link-editor-plugin/hooks.ts b/web/app/components/workflow/note-node/note-editor/plugins/link-editor-plugin/hooks.ts index 3a084ef0af..511940d142 100644 --- a/web/app/components/workflow/note-node/note-editor/plugins/link-editor-plugin/hooks.ts +++ b/web/app/components/workflow/note-node/note-editor/plugins/link-editor-plugin/hooks.ts @@ -15,7 +15,7 @@ import { useEffect, } from 'react' import { useTranslation } from 'react-i18next' -import { useToastContext } from '@/app/components/base/toast' +import { useToastContext } from '@/app/components/base/toast/context' import { useNoteEditorStore } from '../../store' import { urlRegExp } from '../../utils' diff --git a/web/app/components/workflow/panel/chat-variable-panel/components/object-value-item.tsx b/web/app/components/workflow/panel/chat-variable-panel/components/object-value-item.tsx index ef4b4d71a9..74765c2b9e 100644 --- a/web/app/components/workflow/panel/chat-variable-panel/components/object-value-item.tsx +++ b/web/app/components/workflow/panel/chat-variable-panel/components/object-value-item.tsx @@ -5,7 +5,7 @@ import * as React from 'react' import { useCallback, useState } from 'react' import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import RemoveButton from '@/app/components/workflow/nodes/_base/components/remove-button' import VariableTypeSelector from '@/app/components/workflow/panel/chat-variable-panel/components/variable-type-select' import { ChatVarType } from '@/app/components/workflow/panel/chat-variable-panel/type' diff --git a/web/app/components/workflow/panel/chat-variable-panel/components/variable-modal.tsx b/web/app/components/workflow/panel/chat-variable-panel/components/variable-modal.tsx index 5c07cca3df..dd14bcf75a 100644 --- a/web/app/components/workflow/panel/chat-variable-panel/components/variable-modal.tsx +++ b/web/app/components/workflow/panel/chat-variable-panel/components/variable-modal.tsx @@ -7,7 +7,7 @@ import { useContext } from 'use-context-selector' import { v4 as uuid4 } from 'uuid' import Button from '@/app/components/base/button' import Input from '@/app/components/base/input' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import CodeEditor from '@/app/components/workflow/nodes/_base/components/editor/code-editor' import { CodeLanguage } from '@/app/components/workflow/nodes/code/types' import ArrayValueList from '@/app/components/workflow/panel/chat-variable-panel/components/array-value-list' diff --git a/web/app/components/workflow/panel/debug-and-preview/hooks.ts b/web/app/components/workflow/panel/debug-and-preview/hooks.ts index b31673ee26..3481733cd2 100644 --- a/web/app/components/workflow/panel/debug-and-preview/hooks.ts +++ b/web/app/components/workflow/panel/debug-and-preview/hooks.ts @@ -26,7 +26,7 @@ import { getProcessedFiles, getProcessedFilesFromResponse, } from '@/app/components/base/file-uploader/utils' -import { useToastContext } from '@/app/components/base/toast' +import { useToastContext } from '@/app/components/base/toast/context' import { CUSTOM_NODE, } from '@/app/components/workflow/constants' diff --git a/web/app/components/workflow/panel/env-panel/variable-modal.tsx b/web/app/components/workflow/panel/env-panel/variable-modal.tsx index 0e120ac77c..493a73405a 100644 --- a/web/app/components/workflow/panel/env-panel/variable-modal.tsx +++ b/web/app/components/workflow/panel/env-panel/variable-modal.tsx @@ -7,7 +7,7 @@ import { useContext } from 'use-context-selector' import { v4 as uuid4 } from 'uuid' import Button from '@/app/components/base/button' import Input from '@/app/components/base/input' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import Tooltip from '@/app/components/base/tooltip' import { useWorkflowStore } from '@/app/components/workflow/store' import { cn } from '@/utils/classnames' diff --git a/web/app/components/workflow/run/index.tsx b/web/app/components/workflow/run/index.tsx index 441002c86c..96acb1a026 100644 --- a/web/app/components/workflow/run/index.tsx +++ b/web/app/components/workflow/run/index.tsx @@ -6,7 +6,7 @@ import { useCallback, useEffect, useMemo, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' import Loading from '@/app/components/base/loading' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import { WorkflowRunningStatus } from '@/app/components/workflow/types' import { fetchRunDetail, fetchTracingList } from '@/service/log' import { cn } from '@/utils/classnames' diff --git a/web/app/components/workflow/update-dsl-modal.tsx b/web/app/components/workflow/update-dsl-modal.tsx index d33679ff1b..08107a0c24 100644 --- a/web/app/components/workflow/update-dsl-modal.tsx +++ b/web/app/components/workflow/update-dsl-modal.tsx @@ -24,7 +24,7 @@ import { useStore as useAppStore } from '@/app/components/app/store' import Button from '@/app/components/base/button' import Modal from '@/app/components/base/modal' import { FILE_EXTS } from '@/app/components/base/prompt-editor/constants' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import { usePluginDependencies } from '@/app/components/workflow/plugin-dependency/hooks' import { useEventEmitterContextContext } from '@/context/event-emitter' import { diff --git a/web/app/education-apply/education-apply-page.tsx b/web/app/education-apply/education-apply-page.tsx index 3f8d80a67e..636a471d3d 100644 --- a/web/app/education-apply/education-apply-page.tsx +++ b/web/app/education-apply/education-apply-page.tsx @@ -12,7 +12,7 @@ import { import { useTranslation } from 'react-i18next' import Button from '@/app/components/base/button' import Checkbox from '@/app/components/base/checkbox' -import { useToastContext } from '@/app/components/base/toast' +import { useToastContext } from '@/app/components/base/toast/context' import { EDUCATION_VERIFYING_LOCALSTORAGE_ITEM } from '@/app/education-apply/constants' import { useDocLink } from '@/context/i18n' import { useProviderContext } from '@/context/provider-context' diff --git a/web/context/datasets-context.tsx b/web/context/datasets-context.ts similarity index 100% rename from web/context/datasets-context.tsx rename to web/context/datasets-context.ts diff --git a/web/context/event-emitter-provider.tsx b/web/context/event-emitter-provider.tsx new file mode 100644 index 0000000000..da8d2d78c2 --- /dev/null +++ b/web/context/event-emitter-provider.tsx @@ -0,0 +1,22 @@ +'use client' + +import type { ReactNode } from 'react' +import type { EventEmitterValue } from './event-emitter' +import { useEventEmitter } from 'ahooks' +import { EventEmitterContext } from './event-emitter' + +type EventEmitterContextProviderProps = { + children: ReactNode +} + +export const EventEmitterContextProvider = ({ + children, +}: EventEmitterContextProviderProps) => { + const eventEmitter = useEventEmitter() + + return ( + + {children} + + ) +} diff --git a/web/context/event-emitter.tsx b/web/context/event-emitter.ts similarity index 53% rename from web/context/event-emitter.tsx rename to web/context/event-emitter.ts index 14b81eacb6..eb7794dfe1 100644 --- a/web/context/event-emitter.tsx +++ b/web/context/event-emitter.ts @@ -1,7 +1,6 @@ 'use client' import type { EventEmitter } from 'ahooks/lib/useEventEmitter' -import { useEventEmitter } from 'ahooks' import { createContext, useContext } from 'use-context-selector' /** @@ -16,25 +15,10 @@ export type EventEmitterMessage = { export type EventEmitterValue = string | EventEmitterMessage -const EventEmitterContext = createContext<{ eventEmitter: EventEmitter | null }>({ +export const EventEmitterContext = createContext<{ eventEmitter: EventEmitter | null }>({ eventEmitter: null, }) export const useEventEmitterContextContext = () => useContext(EventEmitterContext) -type EventEmitterContextProviderProps = { - children: React.ReactNode -} -export const EventEmitterContextProvider = ({ - children, -}: EventEmitterContextProviderProps) => { - const eventEmitter = useEventEmitter() - - return ( - - {children} - - ) -} - export default EventEmitterContext diff --git a/web/context/mitt-context-provider.tsx b/web/context/mitt-context-provider.tsx new file mode 100644 index 0000000000..b177694d8d --- /dev/null +++ b/web/context/mitt-context-provider.tsx @@ -0,0 +1,19 @@ +'use client' + +import type { ReactNode } from 'react' +import { useMitt } from '@/hooks/use-mitt' +import { MittContext } from './mitt-context' + +type MittProviderProps = { + children: ReactNode +} + +export const MittProvider = ({ children }: MittProviderProps) => { + const mitt = useMitt() + + return ( + + {children} + + ) +} diff --git a/web/context/mitt-context.tsx b/web/context/mitt-context.ts similarity index 66% rename from web/context/mitt-context.tsx rename to web/context/mitt-context.ts index 4317fc5660..5c4a0771c5 100644 --- a/web/context/mitt-context.tsx +++ b/web/context/mitt-context.ts @@ -1,6 +1,8 @@ +'use client' + +import type { useMitt } from '@/hooks/use-mitt' import { noop } from 'es-toolkit/function' import { createContext, useContext, useContextSelector } from 'use-context-selector' -import { useMitt } from '@/hooks/use-mitt' type ContextValueType = ReturnType export const MittContext = createContext({ @@ -8,16 +10,6 @@ export const MittContext = createContext({ useSubscribe: noop, }) -export const MittProvider = ({ children }: { children: React.ReactNode }) => { - const mitt = useMitt() - - return ( - - {children} - - ) -} - export const useMittContext = () => { return useContext(MittContext) } diff --git a/web/context/modal-context.tsx b/web/context/modal-context-provider.tsx similarity index 81% rename from web/context/modal-context.tsx rename to web/context/modal-context-provider.tsx index ae1184e5bf..8c64642f43 100644 --- a/web/context/modal-context.tsx +++ b/web/context/modal-context-provider.tsx @@ -1,32 +1,20 @@ 'use client' -import type { Dispatch, SetStateAction } from 'react' -import type { TriggerEventsLimitModalPayload } from './hooks/use-trigger-events-limit-modal' +import type { ReactNode, SetStateAction } from 'react' +import type { ModalState, ModelModalType } from './modal-context' import type { OpeningStatement } from '@/app/components/base/features/types' import type { CreateExternalAPIReq } from '@/app/components/datasets/external-api/declarations' import type { AccountSettingTab } from '@/app/components/header/account-setting/constants' -import type { - ConfigurationMethodEnum, - Credential, - CustomConfigurationModelFixedFields, - CustomModel, - ModelModalModeEnum, - ModelProvider, -} from '@/app/components/header/account-setting/model-provider-page/declarations' import type { ModelLoadBalancingModalProps } from '@/app/components/header/account-setting/model-provider-page/provider-added-card/model-load-balancing-modal' import type { UpdatePluginPayload } from '@/app/components/plugins/types' import type { InputVar } from '@/app/components/workflow/types' import type { ExpireNoticeModalPayloadProps } from '@/app/education-apply/expire-notice-modal' -import type { - ApiBasedExtension, - ExternalDataTool, -} from '@/models/common' +import type { ApiBasedExtension, ExternalDataTool } from '@/models/common' import type { ModerationConfig, PromptVariable } from '@/models/debug' -import { noop } from 'es-toolkit/function' import dynamic from 'next/dynamic' import { useCallback, useEffect, useRef, useState } from 'react' -import { createContext, useContext, useContextSelector } from 'use-context-selector' import { + DEFAULT_ACCOUNT_SETTING_TAB, isValidAccountSettingTab, } from '@/app/components/header/account-setting/constants' @@ -39,11 +27,10 @@ import { useAccountSettingModal, usePricingModal, } from '@/hooks/use-query-params' - +import { useTriggerEventsLimitModal } from './hooks/use-trigger-events-limit-modal' import { - - useTriggerEventsLimitModal, -} from './hooks/use-trigger-events-limit-modal' + ModalContext, +} from './modal-context' const AccountSetting = dynamic(() => import('@/app/components/header/account-setting'), { ssr: false, @@ -86,72 +73,8 @@ const TriggerEventsLimitModal = dynamic(() => import('@/app/components/billing/t ssr: false, }) -export type ModalState = { - payload: T - onCancelCallback?: () => void - onSaveCallback?: (newPayload?: T, formValues?: Record) => void - onRemoveCallback?: (newPayload?: T, formValues?: Record) => void - onEditCallback?: (newPayload: T) => void - onValidateBeforeSaveCallback?: (newPayload: T) => boolean - isEditMode?: boolean - datasetBindings?: { id: string, name: string }[] -} - -export type ModelModalType = { - currentProvider: ModelProvider - currentConfigurationMethod: ConfigurationMethodEnum - currentCustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields - isModelCredential?: boolean - credential?: Credential - model?: CustomModel - mode?: ModelModalModeEnum -} - -export type ModalContextState = { - setShowAccountSettingModal: Dispatch | null>> - setShowApiBasedExtensionModal: Dispatch | null>> - setShowModerationSettingModal: Dispatch | null>> - setShowExternalDataToolModal: Dispatch | null>> - setShowPricingModal: () => void - setShowAnnotationFullModal: () => void - setShowModelModal: Dispatch | null>> - setShowExternalKnowledgeAPIModal: Dispatch | null>> - setShowModelLoadBalancingModal: Dispatch> - setShowOpeningModal: Dispatch void - }> | null>> - setShowUpdatePluginModal: Dispatch | null>> - setShowEducationExpireNoticeModal: Dispatch | null>> - setShowTriggerEventsLimitModal: Dispatch | null>> -} - -const ModalContext = createContext({ - setShowAccountSettingModal: noop, - setShowApiBasedExtensionModal: noop, - setShowModerationSettingModal: noop, - setShowExternalDataToolModal: noop, - setShowPricingModal: noop, - setShowAnnotationFullModal: noop, - setShowModelModal: noop, - setShowExternalKnowledgeAPIModal: noop, - setShowModelLoadBalancingModal: noop, - setShowOpeningModal: noop, - setShowUpdatePluginModal: noop, - setShowEducationExpireNoticeModal: noop, - setShowTriggerEventsLimitModal: noop, -}) - -export const useModalContext = () => useContext(ModalContext) - -// Adding a dangling comma to avoid the generic parsing issue in tsx, see: -// https://github.com/microsoft/TypeScript/issues/15713 -export const useModalContextSelector = (selector: (state: ModalContextState) => T): T => - useContextSelector(ModalContext, selector) - type ModalContextProviderProps = { - children: React.ReactNode + children: ReactNode } export const ModalContextProvider = ({ children, diff --git a/web/context/modal-context.test.tsx b/web/context/modal-context.test.tsx index a0f6ff35ec..98f67a5473 100644 --- a/web/context/modal-context.test.tsx +++ b/web/context/modal-context.test.tsx @@ -2,7 +2,7 @@ import { act, screen, waitFor } from '@testing-library/react' import * as React from 'react' import { defaultPlan } from '@/app/components/billing/config' import { Plan } from '@/app/components/billing/type' -import { ModalContextProvider } from '@/context/modal-context' +import { ModalContextProvider } from '@/context/modal-context-provider' import { renderWithNuqs } from '@/test/nuqs-testing' vi.mock('@/config', async (importOriginal) => { diff --git a/web/context/modal-context.ts b/web/context/modal-context.ts new file mode 100644 index 0000000000..cc0ca28a42 --- /dev/null +++ b/web/context/modal-context.ts @@ -0,0 +1,92 @@ +'use client' + +import type { Dispatch, SetStateAction } from 'react' +import type { TriggerEventsLimitModalPayload } from './hooks/use-trigger-events-limit-modal' +import type { OpeningStatement } from '@/app/components/base/features/types' +import type { CreateExternalAPIReq } from '@/app/components/datasets/external-api/declarations' +import type { AccountSettingTab } from '@/app/components/header/account-setting/constants' +import type { + ConfigurationMethodEnum, + Credential, + CustomConfigurationModelFixedFields, + CustomModel, + ModelModalModeEnum, + ModelProvider, +} from '@/app/components/header/account-setting/model-provider-page/declarations' +import type { ModelLoadBalancingModalProps } from '@/app/components/header/account-setting/model-provider-page/provider-added-card/model-load-balancing-modal' +import type { UpdatePluginPayload } from '@/app/components/plugins/types' +import type { InputVar } from '@/app/components/workflow/types' +import type { ExpireNoticeModalPayloadProps } from '@/app/education-apply/expire-notice-modal' +import type { + ApiBasedExtension, + ExternalDataTool, +} from '@/models/common' +import type { ModerationConfig, PromptVariable } from '@/models/debug' +import { noop } from 'es-toolkit/function' +import { createContext, useContext, useContextSelector } from 'use-context-selector' + +export type ModalState = { + payload: T + onCancelCallback?: () => void + onSaveCallback?: (newPayload?: T, formValues?: Record) => void + onRemoveCallback?: (newPayload?: T, formValues?: Record) => void + onEditCallback?: (newPayload: T) => void + onValidateBeforeSaveCallback?: (newPayload: T) => boolean + isEditMode?: boolean + datasetBindings?: { id: string, name: string }[] +} + +export type ModelModalType = { + currentProvider: ModelProvider + currentConfigurationMethod: ConfigurationMethodEnum + currentCustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields + isModelCredential?: boolean + credential?: Credential + model?: CustomModel + mode?: ModelModalModeEnum +} + +export type ModalContextState = { + setShowAccountSettingModal: Dispatch | null>> + setShowApiBasedExtensionModal: Dispatch | null>> + setShowModerationSettingModal: Dispatch | null>> + setShowExternalDataToolModal: Dispatch | null>> + setShowPricingModal: () => void + setShowAnnotationFullModal: () => void + setShowModelModal: Dispatch | null>> + setShowExternalKnowledgeAPIModal: Dispatch | null>> + setShowModelLoadBalancingModal: Dispatch> + setShowOpeningModal: Dispatch void + }> | null>> + setShowUpdatePluginModal: Dispatch | null>> + setShowEducationExpireNoticeModal: Dispatch | null>> + setShowTriggerEventsLimitModal: Dispatch | null>> +} + +export const ModalContext = createContext({ + setShowAccountSettingModal: noop, + setShowApiBasedExtensionModal: noop, + setShowModerationSettingModal: noop, + setShowExternalDataToolModal: noop, + setShowPricingModal: noop, + setShowAnnotationFullModal: noop, + setShowModelModal: noop, + setShowExternalKnowledgeAPIModal: noop, + setShowModelLoadBalancingModal: noop, + setShowOpeningModal: noop, + setShowUpdatePluginModal: noop, + setShowEducationExpireNoticeModal: noop, + setShowTriggerEventsLimitModal: noop, +}) + +export const useModalContext = () => useContext(ModalContext) + +// Adding a dangling comma to avoid the generic parsing issue in tsx, see: +// https://github.com/microsoft/TypeScript/issues/15713 +export const useModalContextSelector = (selector: (state: ModalContextState) => T): T => + useContextSelector(ModalContext, selector) + +export default ModalContext diff --git a/web/context/provider-context.tsx b/web/context/provider-context-provider.tsx similarity index 70% rename from web/context/provider-context.tsx rename to web/context/provider-context-provider.tsx index 2a71d9cf93..ce7f2ba40c 100644 --- a/web/context/provider-context.tsx +++ b/web/context/provider-context-provider.tsx @@ -1,14 +1,10 @@ 'use client' -import type { Plan, UsagePlanInfo, UsageResetInfo } from '@/app/components/billing/type' -import type { Model, ModelProvider } from '@/app/components/header/account-setting/model-provider-page/declarations' -import type { RETRIEVE_METHOD } from '@/types/app' +import type { ReactNode } from 'react' import { useQueryClient } from '@tanstack/react-query' import dayjs from 'dayjs' -import { noop } from 'es-toolkit/function' import { useEffect, useState } from 'react' import { useTranslation } from 'react-i18next' -import { createContext, useContext, useContextSelector } from 'use-context-selector' import Toast from '@/app/components/base/toast' import { setZendeskConversationFields } from '@/app/components/base/zendesk/utils' import { defaultPlan } from '@/app/components/billing/config' @@ -25,93 +21,13 @@ import { useModelProviders, useSupportRetrievalMethods, } from '@/service/use-common' -import { - useEducationStatus, -} from '@/service/use-education' - -export type ProviderContextState = { - modelProviders: ModelProvider[] - refreshModelProviders: () => void - textGenerationModelList: Model[] - supportRetrievalMethods: RETRIEVE_METHOD[] - isAPIKeySet: boolean - plan: { - type: Plan - usage: UsagePlanInfo - total: UsagePlanInfo - reset: UsageResetInfo - } - isFetchedPlan: boolean - enableBilling: boolean - onPlanInfoChanged: () => void - enableReplaceWebAppLogo: boolean - modelLoadBalancingEnabled: boolean - datasetOperatorEnabled: boolean - enableEducationPlan: boolean - isEducationWorkspace: boolean - isEducationAccount: boolean - allowRefreshEducationVerify: boolean - educationAccountExpireAt: number | null - isLoadingEducationAccountInfo: boolean - isFetchingEducationAccountInfo: boolean - webappCopyrightEnabled: boolean - licenseLimit: { - workspace_members: { - size: number - limit: number - } - } - refreshLicenseLimit: () => void - isAllowTransferWorkspace: boolean - isAllowPublishAsCustomKnowledgePipelineTemplate: boolean - humanInputEmailDeliveryEnabled: boolean -} - -export const baseProviderContextValue: ProviderContextState = { - modelProviders: [], - refreshModelProviders: noop, - textGenerationModelList: [], - supportRetrievalMethods: [], - isAPIKeySet: true, - plan: defaultPlan, - isFetchedPlan: false, - enableBilling: false, - onPlanInfoChanged: noop, - enableReplaceWebAppLogo: false, - modelLoadBalancingEnabled: false, - datasetOperatorEnabled: false, - enableEducationPlan: false, - isEducationWorkspace: false, - isEducationAccount: false, - allowRefreshEducationVerify: false, - educationAccountExpireAt: null, - isLoadingEducationAccountInfo: false, - isFetchingEducationAccountInfo: false, - webappCopyrightEnabled: false, - licenseLimit: { - workspace_members: { - size: 0, - limit: 0, - }, - }, - refreshLicenseLimit: noop, - isAllowTransferWorkspace: false, - isAllowPublishAsCustomKnowledgePipelineTemplate: false, - humanInputEmailDeliveryEnabled: false, -} - -const ProviderContext = createContext(baseProviderContextValue) - -export const useProviderContext = () => useContext(ProviderContext) - -// Adding a dangling comma to avoid the generic parsing issue in tsx, see: -// https://github.com/microsoft/TypeScript/issues/15713 -export const useProviderContextSelector = (selector: (state: ProviderContextState) => T): T => - useContextSelector(ProviderContext, selector) +import { useEducationStatus } from '@/service/use-education' +import { ProviderContext } from './provider-context' type ProviderContextProviderProps = { - children: React.ReactNode + children: ReactNode } + export const ProviderContextProvider = ({ children, }: ProviderContextProviderProps) => { @@ -262,5 +178,3 @@ export const ProviderContextProvider = ({ ) } - -export default ProviderContext diff --git a/web/context/provider-context.ts b/web/context/provider-context.ts new file mode 100644 index 0000000000..27c43e7c91 --- /dev/null +++ b/web/context/provider-context.ts @@ -0,0 +1,90 @@ +'use client' + +import type { Plan, UsagePlanInfo, UsageResetInfo } from '@/app/components/billing/type' +import type { Model, ModelProvider } from '@/app/components/header/account-setting/model-provider-page/declarations' +import type { RETRIEVE_METHOD } from '@/types/app' +import { noop } from 'es-toolkit/function' +import { createContext, useContext, useContextSelector } from 'use-context-selector' +import { defaultPlan } from '@/app/components/billing/config' + +export type ProviderContextState = { + modelProviders: ModelProvider[] + refreshModelProviders: () => void + textGenerationModelList: Model[] + supportRetrievalMethods: RETRIEVE_METHOD[] + isAPIKeySet: boolean + plan: { + type: Plan + usage: UsagePlanInfo + total: UsagePlanInfo + reset: UsageResetInfo + } + isFetchedPlan: boolean + enableBilling: boolean + onPlanInfoChanged: () => void + enableReplaceWebAppLogo: boolean + modelLoadBalancingEnabled: boolean + datasetOperatorEnabled: boolean + enableEducationPlan: boolean + isEducationWorkspace: boolean + isEducationAccount: boolean + allowRefreshEducationVerify: boolean + educationAccountExpireAt: number | null + isLoadingEducationAccountInfo: boolean + isFetchingEducationAccountInfo: boolean + webappCopyrightEnabled: boolean + licenseLimit: { + workspace_members: { + size: number + limit: number + } + } + refreshLicenseLimit: () => void + isAllowTransferWorkspace: boolean + isAllowPublishAsCustomKnowledgePipelineTemplate: boolean + humanInputEmailDeliveryEnabled: boolean +} + +export const baseProviderContextValue: ProviderContextState = { + modelProviders: [], + refreshModelProviders: noop, + textGenerationModelList: [], + supportRetrievalMethods: [], + isAPIKeySet: true, + plan: defaultPlan, + isFetchedPlan: false, + enableBilling: false, + onPlanInfoChanged: noop, + enableReplaceWebAppLogo: false, + modelLoadBalancingEnabled: false, + datasetOperatorEnabled: false, + enableEducationPlan: false, + isEducationWorkspace: false, + isEducationAccount: false, + allowRefreshEducationVerify: false, + educationAccountExpireAt: null, + isLoadingEducationAccountInfo: false, + isFetchingEducationAccountInfo: false, + webappCopyrightEnabled: false, + licenseLimit: { + workspace_members: { + size: 0, + limit: 0, + }, + }, + refreshLicenseLimit: noop, + isAllowTransferWorkspace: false, + isAllowPublishAsCustomKnowledgePipelineTemplate: false, + humanInputEmailDeliveryEnabled: false, +} + +export const ProviderContext = createContext(baseProviderContextValue) + +export const useProviderContext = () => useContext(ProviderContext) + +// Adding a dangling comma to avoid the generic parsing issue in tsx, see: +// https://github.com/microsoft/TypeScript/issues/15713 +export const useProviderContextSelector = (selector: (state: ProviderContextState) => T): T => + useContextSelector(ProviderContext, selector) + +export default ProviderContext diff --git a/web/context/workspace-context-provider.tsx b/web/context/workspace-context-provider.tsx new file mode 100644 index 0000000000..afec62f710 --- /dev/null +++ b/web/context/workspace-context-provider.tsx @@ -0,0 +1,24 @@ +'use client' + +import type { ReactNode } from 'react' +import { useWorkspaces } from '@/service/use-common' +import { WorkspacesContext } from './workspace-context' + +type WorkspaceProviderProps = { + children: ReactNode +} + +export const WorkspaceProvider = ({ + children, +}: WorkspaceProviderProps) => { + const { data } = useWorkspaces() + + return ( + + {children} + + ) +} diff --git a/web/context/workspace-context.ts b/web/context/workspace-context.ts new file mode 100644 index 0000000000..e088d12f4e --- /dev/null +++ b/web/context/workspace-context.ts @@ -0,0 +1,16 @@ +'use client' + +import type { IWorkspace } from '@/models/common' +import { createContext, useContext } from 'use-context-selector' + +export type WorkspacesContextValue = { + workspaces: IWorkspace[] +} + +export const WorkspacesContext = createContext({ + workspaces: [], +}) + +export const useWorkspacesContext = () => useContext(WorkspacesContext) + +export default WorkspacesContext diff --git a/web/context/workspace-context.tsx b/web/context/workspace-context.tsx deleted file mode 100644 index 3834641bc1..0000000000 --- a/web/context/workspace-context.tsx +++ /dev/null @@ -1,36 +0,0 @@ -'use client' - -import type { IWorkspace } from '@/models/common' -import { createContext, useContext } from 'use-context-selector' -import { useWorkspaces } from '@/service/use-common' - -export type WorkspacesContextValue = { - workspaces: IWorkspace[] -} - -const WorkspacesContext = createContext({ - workspaces: [], -}) - -type IWorkspaceProviderProps = { - children: React.ReactNode -} - -export const WorkspaceProvider = ({ - children, -}: IWorkspaceProviderProps) => { - const { data } = useWorkspaces() - - return ( - - {children} - - ) -} - -export const useWorkspacesContext = () => useContext(WorkspacesContext) - -export default WorkspacesContext diff --git a/web/eslint-suppressions.json b/web/eslint-suppressions.json index 0880084320..22f225d1d0 100644 --- a/web/eslint-suppressions.json +++ b/web/eslint-suppressions.json @@ -968,11 +968,6 @@ "count": 6 } }, - "app/components/app/configuration/debug/debug-with-multiple-model/context.tsx": { - "react-refresh/only-export-components": { - "count": 1 - } - }, "app/components/app/configuration/debug/debug-with-multiple-model/debug-item.tsx": { "no-restricted-imports": { "count": 2 @@ -1561,7 +1556,7 @@ "count": 7 } }, - "app/components/base/chat/chat-with-history/context.tsx": { + "app/components/base/chat/chat-with-history/context.ts": { "ts/no-explicit-any": { "count": 7 } @@ -1715,11 +1710,6 @@ "count": 1 } }, - "app/components/base/chat/chat/context.tsx": { - "react-refresh/only-export-components": { - "count": 1 - } - }, "app/components/base/chat/chat/hooks.ts": { "react-hooks-extra/no-direct-set-state-in-use-effect": { "count": 2 @@ -1759,7 +1749,7 @@ "count": 7 } }, - "app/components/base/chat/embedded-chatbot/context.tsx": { + "app/components/base/chat/embedded-chatbot/context.ts": { "ts/no-explicit-any": { "count": 7 } @@ -2763,7 +2753,7 @@ "count": 2 } }, - "app/components/base/radio/context/index.tsx": { + "app/components/base/radio/context/index.ts": { "ts/no-explicit-any": { "count": 1 } @@ -2915,14 +2905,6 @@ "count": 1 } }, - "app/components/base/toast/index.tsx": { - "react-refresh/only-export-components": { - "count": 2 - }, - "tailwindcss/enforce-consistent-class-order": { - "count": 2 - } - }, "app/components/base/video-gallery/VideoPlayer.tsx": { "react-hooks-extra/no-direct-set-state-in-use-effect": { "count": 1 @@ -5683,10 +5665,7 @@ "count": 1 } }, - "app/components/plugins/plugin-page/context.tsx": { - "react-refresh/only-export-components": { - "count": 2 - }, + "app/components/plugins/plugin-page/context.ts": { "ts/no-explicit-any": { "count": 1 } @@ -9571,16 +9550,6 @@ "count": 5 } }, - "context/datasets-context.tsx": { - "react-refresh/only-export-components": { - "count": 1 - } - }, - "context/event-emitter.tsx": { - "react-refresh/only-export-components": { - "count": 1 - } - }, "context/external-api-panel-context.tsx": { "react-refresh/only-export-components": { "count": 1 @@ -9601,8 +9570,8 @@ "count": 3 } }, - "context/mitt-context.tsx": { - "react-refresh/only-export-components": { + "context/modal-context-provider.tsx": { + "ts/no-explicit-any": { "count": 3 } }, @@ -9611,18 +9580,12 @@ "count": 3 } }, - "context/modal-context.tsx": { - "react-refresh/only-export-components": { - "count": 2 - }, + "context/modal-context.ts": { "ts/no-explicit-any": { - "count": 5 + "count": 2 } }, - "context/provider-context.tsx": { - "react-refresh/only-export-components": { - "count": 3 - }, + "context/provider-context-provider.tsx": { "ts/no-explicit-any": { "count": 1 } @@ -9632,11 +9595,6 @@ "count": 1 } }, - "context/workspace-context.tsx": { - "react-refresh/only-export-components": { - "count": 1 - } - }, "hooks/use-async-window-open.spec.ts": { "ts/no-explicit-any": { "count": 6 diff --git a/web/hooks/use-import-dsl.ts b/web/hooks/use-import-dsl.ts index ba33db1e84..454f580b42 100644 --- a/web/hooks/use-import-dsl.ts +++ b/web/hooks/use-import-dsl.ts @@ -10,7 +10,7 @@ import { useState, } from 'react' import { useTranslation } from 'react-i18next' -import { useToastContext } from '@/app/components/base/toast' +import { useToastContext } from '@/app/components/base/toast/context' import { usePluginDependencies } from '@/app/components/workflow/plugin-dependency/hooks' import { NEED_REFRESH_APP_LIST_KEY } from '@/config' import { useSelector } from '@/context/app-context' From ebda5efe275154ca6015e665778a9f5c0678a6a8 Mon Sep 17 00:00:00 2001 From: yyh <92089059+lyzno1@users.noreply.github.com> Date: Thu, 5 Mar 2026 16:13:02 +0800 Subject: [PATCH 5/5] chore: prevent Storybook crash caused by vite-plugin-inspect (#33039) --- web/vite.config.ts | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/web/vite.config.ts b/web/vite.config.ts index b07e7ea7be..83a35f8558 100644 --- a/web/vite.config.ts +++ b/web/vite.config.ts @@ -160,6 +160,8 @@ if (import.meta.hot) { export default defineConfig(({ mode }) => { const isTest = mode === 'test' + const isStorybook = process.env.STORYBOOK === 'true' + || process.argv.some(arg => arg.toLowerCase().includes('storybook')) return { plugins: isTest @@ -176,14 +178,19 @@ export default defineConfig(({ mode }) => { }, } as Plugin, ] - : [ - Inspect(), - createCodeInspectorPlugin(), - createForceInspectorClientInjectionPlugin(), - react(), - vinext(), - customI18nHmrPlugin(), - ], + : isStorybook + ? [ + tsconfigPaths(), + react(), + ] + : [ + Inspect(), + createCodeInspectorPlugin(), + createForceInspectorClientInjectionPlugin(), + react(), + vinext(), + customI18nHmrPlugin(), + ], resolve: { alias: { '~@': __dirname, @@ -191,7 +198,7 @@ export default defineConfig(({ mode }) => { }, // vinext related config - ...(!isTest + ...(!isTest && !isStorybook ? { optimizeDeps: { exclude: ['nuqs'],