From d162f7e5ef0db74d3396239c82e6283732f043ae Mon Sep 17 00:00:00 2001 From: wangxiaolei Date: Mon, 1 Dec 2025 14:14:19 +0800 Subject: [PATCH 1/8] feat(api): automatically `NODE_TYPE_CLASSES_MAPPING` generation from node class definitions (#28525) --- api/app_factory.py | 2 + api/core/app/entities/app_invoke_entities.py | 33 ++-- api/core/workflow/nodes/base/node.py | 58 +++++++ api/core/workflow/nodes/node_mapping.py | 160 +----------------- api/core/workflow/nodes/tool/tool_node.py | 16 +- api/extensions/ext_forward_refs.py | 49 ++++++ .../workflow/graph/test_graph_validation.py | 2 +- .../workflow/graph_engine/test_mock_nodes.py | 24 +-- .../workflow/nodes/base/test_base_node.py | 4 + .../test_get_node_type_classes_mapping.py | 84 +++++++++ .../core/workflow/nodes/test_base_node.py | 2 +- 11 files changed, 245 insertions(+), 189 deletions(-) create mode 100644 api/extensions/ext_forward_refs.py create mode 100644 api/tests/unit_tests/core/workflow/nodes/base/test_get_node_type_classes_mapping.py diff --git a/api/app_factory.py b/api/app_factory.py index 933cf294d1..ad2065682c 100644 --- a/api/app_factory.py +++ b/api/app_factory.py @@ -51,6 +51,7 @@ def initialize_extensions(app: DifyApp): ext_commands, ext_compress, ext_database, + ext_forward_refs, ext_hosting_provider, ext_import_modules, ext_logging, @@ -75,6 +76,7 @@ def initialize_extensions(app: DifyApp): ext_warnings, ext_import_modules, ext_orjson, + ext_forward_refs, ext_set_secretkey, ext_compress, ext_code_based_extension, diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py index 5143dbf1e8..81c355eb10 100644 --- a/api/core/app/entities/app_invoke_entities.py +++ b/api/core/app/entities/app_invoke_entities.py @@ -130,7 +130,7 @@ class AppGenerateEntity(BaseModel): # extra parameters, like: auto_generate_conversation_name extras: dict[str, Any] = Field(default_factory=dict) - # tracing instance + # tracing instance; use forward ref to avoid circular import at import time trace_manager: Optional["TraceQueueManager"] = None @@ -275,16 +275,23 @@ class RagPipelineGenerateEntity(WorkflowAppGenerateEntity): start_node_id: str | None = None -# Import TraceQueueManager at runtime to resolve forward references -from core.ops.ops_trace_manager import TraceQueueManager +# NOTE: Avoid importing heavy tracing modules at import time to prevent circular imports. +# Forward reference to TraceQueueManager is kept as a string; we rebuild with a stub now to +# avoid Pydantic forward-ref errors in test contexts, and with the real class at app startup. -# Rebuild models that use forward references -AppGenerateEntity.model_rebuild() -EasyUIBasedAppGenerateEntity.model_rebuild() -ConversationAppGenerateEntity.model_rebuild() -ChatAppGenerateEntity.model_rebuild() -CompletionAppGenerateEntity.model_rebuild() -AgentChatAppGenerateEntity.model_rebuild() -AdvancedChatAppGenerateEntity.model_rebuild() -WorkflowAppGenerateEntity.model_rebuild() -RagPipelineGenerateEntity.model_rebuild() + +# Minimal stub to satisfy Pydantic model_rebuild in environments where the real type is not importable yet. +class _TraceQueueManagerStub: + pass + + +_ns = {"TraceQueueManager": _TraceQueueManagerStub} +AppGenerateEntity.model_rebuild(_types_namespace=_ns) +EasyUIBasedAppGenerateEntity.model_rebuild(_types_namespace=_ns) +ConversationAppGenerateEntity.model_rebuild(_types_namespace=_ns) +ChatAppGenerateEntity.model_rebuild(_types_namespace=_ns) +CompletionAppGenerateEntity.model_rebuild(_types_namespace=_ns) +AgentChatAppGenerateEntity.model_rebuild(_types_namespace=_ns) +AdvancedChatAppGenerateEntity.model_rebuild(_types_namespace=_ns) +WorkflowAppGenerateEntity.model_rebuild(_types_namespace=_ns) +RagPipelineGenerateEntity.model_rebuild(_types_namespace=_ns) diff --git a/api/core/workflow/nodes/base/node.py b/api/core/workflow/nodes/base/node.py index 592bea0e16..c2e1105971 100644 --- a/api/core/workflow/nodes/base/node.py +++ b/api/core/workflow/nodes/base/node.py @@ -1,7 +1,11 @@ +import importlib import logging +import operator +import pkgutil from abc import abstractmethod from collections.abc import Generator, Mapping, Sequence from functools import singledispatchmethod +from types import MappingProxyType from typing import Any, ClassVar, Generic, TypeVar, cast, get_args, get_origin from uuid import uuid4 @@ -134,6 +138,34 @@ class Node(Generic[NodeDataT]): cls._node_data_type = node_data_type + # Skip base class itself + if cls is Node: + return + # Only register production node implementations defined under core.workflow.nodes.* + # This prevents test helper subclasses from polluting the global registry and + # accidentally overriding real node types (e.g., a test Answer node). + module_name = getattr(cls, "__module__", "") + # Only register concrete subclasses that define node_type and version() + node_type = cls.node_type + version = cls.version() + bucket = Node._registry.setdefault(node_type, {}) + if module_name.startswith("core.workflow.nodes."): + # Production node definitions take precedence and may override + bucket[version] = cls # type: ignore[index] + else: + # External/test subclasses may register but must not override production + bucket.setdefault(version, cls) # type: ignore[index] + # Maintain a "latest" pointer preferring numeric versions; fallback to lexicographic + version_keys = [v for v in bucket if v != "latest"] + numeric_pairs: list[tuple[str, int]] = [] + for v in version_keys: + numeric_pairs.append((v, int(v))) + if numeric_pairs: + latest_key = max(numeric_pairs, key=operator.itemgetter(1))[0] + else: + latest_key = max(version_keys) if version_keys else version + bucket["latest"] = bucket[latest_key] + @classmethod def _extract_node_data_type_from_generic(cls) -> type[BaseNodeData] | None: """ @@ -165,6 +197,9 @@ class Node(Generic[NodeDataT]): return None + # Global registry populated via __init_subclass__ + _registry: ClassVar[dict["NodeType", dict[str, type["Node"]]]] = {} + def __init__( self, id: str, @@ -395,6 +430,29 @@ class Node(Generic[NodeDataT]): # in `api/core/workflow/nodes/__init__.py`. raise NotImplementedError("subclasses of BaseNode must implement `version` method.") + @classmethod + def get_node_type_classes_mapping(cls) -> Mapping["NodeType", Mapping[str, type["Node"]]]: + """Return mapping of NodeType -> {version -> Node subclass} using __init_subclass__ registry. + + Import all modules under core.workflow.nodes so subclasses register themselves on import. + Then we return a readonly view of the registry to avoid accidental mutation. + """ + # Import all node modules to ensure they are loaded (thus registered) + import core.workflow.nodes as _nodes_pkg + + for _, _modname, _ in pkgutil.walk_packages(_nodes_pkg.__path__, _nodes_pkg.__name__ + "."): + # Avoid importing modules that depend on the registry to prevent circular imports + # e.g. node_factory imports node_mapping which builds the mapping here. + if _modname in { + "core.workflow.nodes.node_factory", + "core.workflow.nodes.node_mapping", + }: + continue + importlib.import_module(_modname) + + # Return a readonly view so callers can't mutate the registry by accident + return {nt: MappingProxyType(ver_map) for nt, ver_map in cls._registry.items()} + @property def retry(self) -> bool: return False diff --git a/api/core/workflow/nodes/node_mapping.py b/api/core/workflow/nodes/node_mapping.py index b926645f18..85df543a2a 100644 --- a/api/core/workflow/nodes/node_mapping.py +++ b/api/core/workflow/nodes/node_mapping.py @@ -1,165 +1,9 @@ from collections.abc import Mapping from core.workflow.enums import NodeType -from core.workflow.nodes.agent.agent_node import AgentNode -from core.workflow.nodes.answer.answer_node import AnswerNode from core.workflow.nodes.base.node import Node -from core.workflow.nodes.code import CodeNode -from core.workflow.nodes.datasource.datasource_node import DatasourceNode -from core.workflow.nodes.document_extractor import DocumentExtractorNode -from core.workflow.nodes.end.end_node import EndNode -from core.workflow.nodes.http_request import HttpRequestNode -from core.workflow.nodes.human_input import HumanInputNode -from core.workflow.nodes.if_else import IfElseNode -from core.workflow.nodes.iteration import IterationNode, IterationStartNode -from core.workflow.nodes.knowledge_index import KnowledgeIndexNode -from core.workflow.nodes.knowledge_retrieval import KnowledgeRetrievalNode -from core.workflow.nodes.list_operator import ListOperatorNode -from core.workflow.nodes.llm import LLMNode -from core.workflow.nodes.loop import LoopEndNode, LoopNode, LoopStartNode -from core.workflow.nodes.parameter_extractor import ParameterExtractorNode -from core.workflow.nodes.question_classifier import QuestionClassifierNode -from core.workflow.nodes.start import StartNode -from core.workflow.nodes.template_transform import TemplateTransformNode -from core.workflow.nodes.tool import ToolNode -from core.workflow.nodes.trigger_plugin import TriggerEventNode -from core.workflow.nodes.trigger_schedule import TriggerScheduleNode -from core.workflow.nodes.trigger_webhook import TriggerWebhookNode -from core.workflow.nodes.variable_aggregator import VariableAggregatorNode -from core.workflow.nodes.variable_assigner.v1 import VariableAssignerNode as VariableAssignerNodeV1 -from core.workflow.nodes.variable_assigner.v2 import VariableAssignerNode as VariableAssignerNodeV2 LATEST_VERSION = "latest" -# NOTE(QuantumGhost): This should be in sync with subclasses of BaseNode. -# Specifically, if you have introduced new node types, you should add them here. -# -# TODO(QuantumGhost): This could be automated with either metaclass or `__init_subclass__` -# hook. Try to avoid duplication of node information. -NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[Node]]] = { - NodeType.START: { - LATEST_VERSION: StartNode, - "1": StartNode, - }, - NodeType.END: { - LATEST_VERSION: EndNode, - "1": EndNode, - }, - NodeType.ANSWER: { - LATEST_VERSION: AnswerNode, - "1": AnswerNode, - }, - NodeType.LLM: { - LATEST_VERSION: LLMNode, - "1": LLMNode, - }, - NodeType.KNOWLEDGE_RETRIEVAL: { - LATEST_VERSION: KnowledgeRetrievalNode, - "1": KnowledgeRetrievalNode, - }, - NodeType.IF_ELSE: { - LATEST_VERSION: IfElseNode, - "1": IfElseNode, - }, - NodeType.CODE: { - LATEST_VERSION: CodeNode, - "1": CodeNode, - }, - NodeType.TEMPLATE_TRANSFORM: { - LATEST_VERSION: TemplateTransformNode, - "1": TemplateTransformNode, - }, - NodeType.QUESTION_CLASSIFIER: { - LATEST_VERSION: QuestionClassifierNode, - "1": QuestionClassifierNode, - }, - NodeType.HTTP_REQUEST: { - LATEST_VERSION: HttpRequestNode, - "1": HttpRequestNode, - }, - NodeType.TOOL: { - LATEST_VERSION: ToolNode, - # This is an issue that caused problems before. - # Logically, we shouldn't use two different versions to point to the same class here, - # but in order to maintain compatibility with historical data, this approach has been retained. - "2": ToolNode, - "1": ToolNode, - }, - NodeType.VARIABLE_AGGREGATOR: { - LATEST_VERSION: VariableAggregatorNode, - "1": VariableAggregatorNode, - }, - NodeType.LEGACY_VARIABLE_AGGREGATOR: { - LATEST_VERSION: VariableAggregatorNode, - "1": VariableAggregatorNode, - }, # original name of VARIABLE_AGGREGATOR - NodeType.ITERATION: { - LATEST_VERSION: IterationNode, - "1": IterationNode, - }, - NodeType.ITERATION_START: { - LATEST_VERSION: IterationStartNode, - "1": IterationStartNode, - }, - NodeType.LOOP: { - LATEST_VERSION: LoopNode, - "1": LoopNode, - }, - NodeType.LOOP_START: { - LATEST_VERSION: LoopStartNode, - "1": LoopStartNode, - }, - NodeType.LOOP_END: { - LATEST_VERSION: LoopEndNode, - "1": LoopEndNode, - }, - NodeType.PARAMETER_EXTRACTOR: { - LATEST_VERSION: ParameterExtractorNode, - "1": ParameterExtractorNode, - }, - NodeType.VARIABLE_ASSIGNER: { - LATEST_VERSION: VariableAssignerNodeV2, - "1": VariableAssignerNodeV1, - "2": VariableAssignerNodeV2, - }, - NodeType.DOCUMENT_EXTRACTOR: { - LATEST_VERSION: DocumentExtractorNode, - "1": DocumentExtractorNode, - }, - NodeType.LIST_OPERATOR: { - LATEST_VERSION: ListOperatorNode, - "1": ListOperatorNode, - }, - NodeType.AGENT: { - LATEST_VERSION: AgentNode, - # This is an issue that caused problems before. - # Logically, we shouldn't use two different versions to point to the same class here, - # but in order to maintain compatibility with historical data, this approach has been retained. - "2": AgentNode, - "1": AgentNode, - }, - NodeType.HUMAN_INPUT: { - LATEST_VERSION: HumanInputNode, - "1": HumanInputNode, - }, - NodeType.DATASOURCE: { - LATEST_VERSION: DatasourceNode, - "1": DatasourceNode, - }, - NodeType.KNOWLEDGE_INDEX: { - LATEST_VERSION: KnowledgeIndexNode, - "1": KnowledgeIndexNode, - }, - NodeType.TRIGGER_WEBHOOK: { - LATEST_VERSION: TriggerWebhookNode, - "1": TriggerWebhookNode, - }, - NodeType.TRIGGER_PLUGIN: { - LATEST_VERSION: TriggerEventNode, - "1": TriggerEventNode, - }, - NodeType.TRIGGER_SCHEDULE: { - LATEST_VERSION: TriggerScheduleNode, - "1": TriggerScheduleNode, - }, -} +# Mapping is built by Node.get_node_type_classes_mapping(), which imports and walks core.workflow.nodes +NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[Node]]] = Node.get_node_type_classes_mapping() diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index d8536474b1..2e7ec757b4 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -12,7 +12,6 @@ from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter from core.tools.errors import ToolInvokeError from core.tools.tool_engine import ToolEngine from core.tools.utils.message_transformer import ToolFileMessageTransformer -from core.tools.workflow_as_tool.tool import WorkflowTool from core.variables.segments import ArrayAnySegment, ArrayFileSegment from core.variables.variables import ArrayAnyVariable from core.workflow.enums import ( @@ -430,7 +429,7 @@ class ToolNode(Node[ToolNodeData]): metadata: dict[WorkflowNodeExecutionMetadataKey, Any] = { WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info, } - if usage.total_tokens > 0: + if isinstance(usage.total_tokens, int) and usage.total_tokens > 0: metadata[WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS] = usage.total_tokens metadata[WorkflowNodeExecutionMetadataKey.TOTAL_PRICE] = usage.total_price metadata[WorkflowNodeExecutionMetadataKey.CURRENCY] = usage.currency @@ -449,8 +448,17 @@ class ToolNode(Node[ToolNodeData]): @staticmethod def _extract_tool_usage(tool_runtime: Tool) -> LLMUsage: - if isinstance(tool_runtime, WorkflowTool): - return tool_runtime.latest_usage + # Avoid importing WorkflowTool at module import time; rely on duck typing + # Some runtimes expose `latest_usage`; mocks may synthesize arbitrary attributes. + latest = getattr(tool_runtime, "latest_usage", None) + # Normalize into a concrete LLMUsage. MagicMock returns truthy attribute objects + # for any name, so we must type-check here. + if isinstance(latest, LLMUsage): + return latest + if isinstance(latest, dict): + # Allow dict payloads from external runtimes + return LLMUsage.model_validate(latest) + # Fallback to empty usage when attribute is missing or not a valid payload return LLMUsage.empty_usage() @classmethod diff --git a/api/extensions/ext_forward_refs.py b/api/extensions/ext_forward_refs.py new file mode 100644 index 0000000000..c40b505b16 --- /dev/null +++ b/api/extensions/ext_forward_refs.py @@ -0,0 +1,49 @@ +import logging + +from dify_app import DifyApp + + +def is_enabled() -> bool: + return True + + +def init_app(app: DifyApp): + """Resolve Pydantic forward refs that would otherwise cause circular imports. + + Rebuilds models in core.app.entities.app_invoke_entities with the real TraceQueueManager type. + Safe to run multiple times. + """ + logger = logging.getLogger(__name__) + try: + from core.app.entities.app_invoke_entities import ( + AdvancedChatAppGenerateEntity, + AgentChatAppGenerateEntity, + AppGenerateEntity, + ChatAppGenerateEntity, + CompletionAppGenerateEntity, + ConversationAppGenerateEntity, + EasyUIBasedAppGenerateEntity, + RagPipelineGenerateEntity, + WorkflowAppGenerateEntity, + ) + from core.ops.ops_trace_manager import TraceQueueManager # heavy import, do it at startup only + + ns = {"TraceQueueManager": TraceQueueManager} + for Model in ( + AppGenerateEntity, + EasyUIBasedAppGenerateEntity, + ConversationAppGenerateEntity, + ChatAppGenerateEntity, + CompletionAppGenerateEntity, + AgentChatAppGenerateEntity, + AdvancedChatAppGenerateEntity, + WorkflowAppGenerateEntity, + RagPipelineGenerateEntity, + ): + try: + Model.model_rebuild(_types_namespace=ns) + except Exception as e: + logger.debug("model_rebuild skipped for %s: %s", Model.__name__, e) + except Exception as e: + # Don't block app startup; just log at debug level. + logger.debug("ext_forward_refs init skipped: %s", e) diff --git a/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py b/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py index 2597a3d65a..5716aae4c7 100644 --- a/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py +++ b/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py @@ -29,7 +29,7 @@ class _TestNode(Node[_TestNodeData]): @classmethod def version(cls) -> str: - return "test" + return "1" def __init__( self, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py index 68f57ee9fb..fd94a5e833 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py @@ -92,7 +92,7 @@ class MockLLMNode(MockNodeMixin, LLMNode): @classmethod def version(cls) -> str: """Return the version of this mock node.""" - return "mock-1" + return "1" def _run(self) -> Generator: """Execute mock LLM node.""" @@ -189,7 +189,7 @@ class MockAgentNode(MockNodeMixin, AgentNode): @classmethod def version(cls) -> str: """Return the version of this mock node.""" - return "mock-1" + return "1" def _run(self) -> Generator: """Execute mock agent node.""" @@ -241,7 +241,7 @@ class MockToolNode(MockNodeMixin, ToolNode): @classmethod def version(cls) -> str: """Return the version of this mock node.""" - return "mock-1" + return "1" def _run(self) -> Generator: """Execute mock tool node.""" @@ -294,7 +294,7 @@ class MockKnowledgeRetrievalNode(MockNodeMixin, KnowledgeRetrievalNode): @classmethod def version(cls) -> str: """Return the version of this mock node.""" - return "mock-1" + return "1" def _run(self) -> Generator: """Execute mock knowledge retrieval node.""" @@ -351,7 +351,7 @@ class MockHttpRequestNode(MockNodeMixin, HttpRequestNode): @classmethod def version(cls) -> str: """Return the version of this mock node.""" - return "mock-1" + return "1" def _run(self) -> Generator: """Execute mock HTTP request node.""" @@ -404,7 +404,7 @@ class MockQuestionClassifierNode(MockNodeMixin, QuestionClassifierNode): @classmethod def version(cls) -> str: """Return the version of this mock node.""" - return "mock-1" + return "1" def _run(self) -> Generator: """Execute mock question classifier node.""" @@ -452,7 +452,7 @@ class MockParameterExtractorNode(MockNodeMixin, ParameterExtractorNode): @classmethod def version(cls) -> str: """Return the version of this mock node.""" - return "mock-1" + return "1" def _run(self) -> Generator: """Execute mock parameter extractor node.""" @@ -502,7 +502,7 @@ class MockDocumentExtractorNode(MockNodeMixin, DocumentExtractorNode): @classmethod def version(cls) -> str: """Return the version of this mock node.""" - return "mock-1" + return "1" def _run(self) -> Generator: """Execute mock document extractor node.""" @@ -557,7 +557,7 @@ class MockIterationNode(MockNodeMixin, IterationNode): @classmethod def version(cls) -> str: """Return the version of this mock node.""" - return "mock-1" + return "1" def _create_graph_engine(self, index: int, item: Any): """Create a graph engine with MockNodeFactory instead of DifyNodeFactory.""" @@ -632,7 +632,7 @@ class MockLoopNode(MockNodeMixin, LoopNode): @classmethod def version(cls) -> str: """Return the version of this mock node.""" - return "mock-1" + return "1" def _create_graph_engine(self, start_at, root_node_id: str): """Create a graph engine with MockNodeFactory instead of DifyNodeFactory.""" @@ -694,7 +694,7 @@ class MockTemplateTransformNode(MockNodeMixin, TemplateTransformNode): @classmethod def version(cls) -> str: """Return the version of this mock node.""" - return "mock-1" + return "1" def _run(self) -> NodeRunResult: """Execute mock template transform node.""" @@ -780,7 +780,7 @@ class MockCodeNode(MockNodeMixin, CodeNode): @classmethod def version(cls) -> str: """Return the version of this mock node.""" - return "mock-1" + return "1" def _run(self) -> NodeRunResult: """Execute mock code node.""" diff --git a/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py b/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py index 6eead80ac9..488b47761b 100644 --- a/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py @@ -33,6 +33,10 @@ def test_ensure_subclasses_of_base_node_has_node_type_and_version_method_defined type_version_set: set[tuple[NodeType, str]] = set() for cls in classes: + # Only validate production node classes; skip test-defined subclasses and external helpers + module_name = getattr(cls, "__module__", "") + if not module_name.startswith("core."): + continue # Validate that 'version' is directly defined in the class (not inherited) by checking the class's __dict__ assert "version" in cls.__dict__, f"class {cls} should have version method defined (NOT INHERITED.)" node_type = cls.node_type diff --git a/api/tests/unit_tests/core/workflow/nodes/base/test_get_node_type_classes_mapping.py b/api/tests/unit_tests/core/workflow/nodes/base/test_get_node_type_classes_mapping.py new file mode 100644 index 0000000000..45d222b98c --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/base/test_get_node_type_classes_mapping.py @@ -0,0 +1,84 @@ +import types +from collections.abc import Mapping + +from core.workflow.enums import NodeType +from core.workflow.nodes.base.entities import BaseNodeData +from core.workflow.nodes.base.node import Node + +# Import concrete nodes we will assert on (numeric version path) +from core.workflow.nodes.variable_assigner.v1.node import ( + VariableAssignerNode as VariableAssignerV1, +) +from core.workflow.nodes.variable_assigner.v2.node import ( + VariableAssignerNode as VariableAssignerV2, +) + + +def test_variable_assigner_latest_prefers_highest_numeric_version(): + # Act + mapping: Mapping[NodeType, Mapping[str, type[Node]]] = Node.get_node_type_classes_mapping() + + # Assert basic presence + assert NodeType.VARIABLE_ASSIGNER in mapping + va_versions = mapping[NodeType.VARIABLE_ASSIGNER] + + # Both concrete versions must be present + assert va_versions.get("1") is VariableAssignerV1 + assert va_versions.get("2") is VariableAssignerV2 + + # And latest should point to numerically-highest version ("2") + assert va_versions.get("latest") is VariableAssignerV2 + + +def test_latest_prefers_highest_numeric_version(): + # Arrange: define two ephemeral subclasses with numeric versions under a NodeType + # that has no concrete implementations in production to avoid interference. + class _Version1(Node[BaseNodeData]): # type: ignore[misc] + node_type = NodeType.LEGACY_VARIABLE_AGGREGATOR + + def init_node_data(self, data): + pass + + def _run(self): + raise NotImplementedError + + @classmethod + def version(cls) -> str: + return "1" + + def _get_error_strategy(self): + return None + + def _get_retry_config(self): + return types.SimpleNamespace() # not used + + def _get_title(self) -> str: + return "version1" + + def _get_description(self): + return None + + def _get_default_value_dict(self): + return {} + + def get_base_node_data(self): + return types.SimpleNamespace(title="version1") + + class _Version2(_Version1): # type: ignore[misc] + @classmethod + def version(cls) -> str: + return "2" + + def _get_title(self) -> str: + return "version2" + + # Act: build a fresh mapping (it should now see our ephemeral subclasses) + mapping: Mapping[NodeType, Mapping[str, type[Node]]] = Node.get_node_type_classes_mapping() + + # Assert: both numeric versions exist for this NodeType; 'latest' points to the higher numeric version + assert NodeType.LEGACY_VARIABLE_AGGREGATOR in mapping + legacy_versions = mapping[NodeType.LEGACY_VARIABLE_AGGREGATOR] + + assert legacy_versions.get("1") is _Version1 + assert legacy_versions.get("2") is _Version2 + assert legacy_versions.get("latest") is _Version2 diff --git a/api/tests/unit_tests/core/workflow/nodes/test_base_node.py b/api/tests/unit_tests/core/workflow/nodes/test_base_node.py index 4a57ab2b89..1854cca236 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_base_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_base_node.py @@ -19,7 +19,7 @@ class _SampleNode(Node[_SampleNodeData]): @classmethod def version(cls) -> str: - return "sample-test" + return "1" def _run(self): raise NotImplementedError From f94972f6627463ecb1733ffcf9b7f8e5051d3f61 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 1 Dec 2025 15:44:52 +0800 Subject: [PATCH 2/8] chore(deps): bump @lexical/list from 0.36.2 to 0.38.2 in /web (#28961) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- web/package.json | 2 +- web/pnpm-lock.yaml | 70 ++++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 69 insertions(+), 3 deletions(-) diff --git a/web/package.json b/web/package.json index a646b26bab..11a8763566 100644 --- a/web/package.json +++ b/web/package.json @@ -53,7 +53,7 @@ "@hookform/resolvers": "^3.10.0", "@lexical/code": "^0.36.2", "@lexical/link": "^0.36.2", - "@lexical/list": "^0.36.2", + "@lexical/list": "^0.38.2", "@lexical/react": "^0.36.2", "@lexical/selection": "^0.37.0", "@lexical/text": "^0.38.2", diff --git a/web/pnpm-lock.yaml b/web/pnpm-lock.yaml index d65fb5e4f3..6038ec0153 100644 --- a/web/pnpm-lock.yaml +++ b/web/pnpm-lock.yaml @@ -85,8 +85,8 @@ importers: specifier: ^0.36.2 version: 0.36.2 '@lexical/list': - specifier: ^0.36.2 - version: 0.36.2 + specifier: ^0.38.2 + version: 0.38.2 '@lexical/react': specifier: ^0.36.2 version: 0.36.2(react-dom@19.1.1(react@19.1.1))(react@19.1.1)(yjs@13.6.27) @@ -2009,6 +2009,9 @@ packages: '@lexical/clipboard@0.37.0': resolution: {integrity: sha512-hRwASFX/ilaI5r8YOcZuQgONFshRgCPfdxfofNL7uruSFYAO6LkUhsjzZwUgf0DbmCJmbBADFw15FSthgCUhGA==} + '@lexical/clipboard@0.38.2': + resolution: {integrity: sha512-dDShUplCu8/o6BB9ousr3uFZ9bltR+HtleF/Tl8FXFNPpZ4AXhbLKUoJuucRuIr+zqT7RxEv/3M6pk/HEoE6NQ==} + '@lexical/code@0.36.2': resolution: {integrity: sha512-dfS62rNo3uKwNAJQ39zC+8gYX0k8UAoW7u+JPIqx+K2VPukZlvpsPLNGft15pdWBkHc7Pv+o9gJlB6gGv+EBfA==} @@ -2027,6 +2030,9 @@ packages: '@lexical/extension@0.37.0': resolution: {integrity: sha512-Z58f2tIdz9bn8gltUu5cVg37qROGha38dUZv20gI2GeNugXAkoPzJYEcxlI1D/26tkevJ/7VaFUr9PTk+iKmaA==} + '@lexical/extension@0.38.2': + resolution: {integrity: sha512-qbUNxEVjAC0kxp7hEMTzktj0/51SyJoIJWK6Gm790b4yNBq82fEPkksfuLkRg9VQUteD0RT1Nkjy8pho8nNamw==} + '@lexical/hashtag@0.36.2': resolution: {integrity: sha512-WdmKtzXFcahQT3ShFDeHF6LCR5C8yvFCj3ImI09rZwICrYeonbMrzsBUxS1joBz0HQ+ufF9Tx+RxLvGWx6WxzQ==} @@ -2039,6 +2045,9 @@ packages: '@lexical/html@0.37.0': resolution: {integrity: sha512-oTsBc45eL8/lmF7fqGR+UCjrJYP04gumzf5nk4TczrxWL2pM4GIMLLKG1mpQI2H1MDiRLzq3T/xdI7Gh74z7Zw==} + '@lexical/html@0.38.2': + resolution: {integrity: sha512-pC5AV+07bmHistRwgG3NJzBMlIzSdxYO6rJU4eBNzyR4becdiLsI4iuv+aY7PhfSv+SCs7QJ9oc4i5caq48Pkg==} + '@lexical/link@0.36.2': resolution: {integrity: sha512-Zb+DeHA1po8VMiOAAXsBmAHhfWmQttsUkI5oiZUmOXJruRuQ2rVr01NoxHpoEpLwHOABVNzD3PMbwov+g3c7lg==} @@ -2048,6 +2057,9 @@ packages: '@lexical/list@0.37.0': resolution: {integrity: sha512-AOC6yAA3mfNvJKbwo+kvAbPJI+13yF2ISA65vbA578CugvJ08zIVgM+pSzxquGhD0ioJY3cXVW7+gdkCP1qu5g==} + '@lexical/list@0.38.2': + resolution: {integrity: sha512-OQm9TzatlMrDZGxMxbozZEHzMJhKxAbH1TOnOGyFfzpfjbnFK2y8oLeVsfQZfZRmiqQS4Qc/rpFnRP2Ax5dsbA==} + '@lexical/mark@0.36.2': resolution: {integrity: sha512-n0MNXtGH+1i43hglgHjpQV0093HmIiFR7Budg2BJb8ZNzO1KZRqeXAHlA5ZzJ698FkAnS4R5bqG9tZ0JJHgAuA==} @@ -2078,12 +2090,18 @@ packages: '@lexical/selection@0.37.0': resolution: {integrity: sha512-Lix1s2r71jHfsTEs4q/YqK2s3uXKOnyA3fd1VDMWysO+bZzRwEO5+qyDvENZ0WrXSDCnlibNFV1HttWX9/zqyw==} + '@lexical/selection@0.38.2': + resolution: {integrity: sha512-eMFiWlBH6bEX9U9sMJ6PXPxVXTrihQfFeiIlWLuTpEIDF2HRz7Uo1KFRC/yN6q0DQaj7d9NZYA6Mei5DoQuz5w==} + '@lexical/table@0.36.2': resolution: {integrity: sha512-96rNNPiVbC65i+Jn1QzIsehCS7UVUc69ovrh9Bt4+pXDebZSdZai153Q7RUq8q3AQ5ocK4/SA2kLQfMu0grj3Q==} '@lexical/table@0.37.0': resolution: {integrity: sha512-g7S8ml8kIujEDLWlzYKETgPCQ2U9oeWqdytRuHjHGi/rjAAGHSej5IRqTPIMxNP3VVQHnBoQ+Y9hBtjiuddhgQ==} + '@lexical/table@0.38.2': + resolution: {integrity: sha512-uu0i7yz0nbClmHOO5ZFsinRJE6vQnFz2YPblYHAlNigiBedhqMwSv5bedrzDq8nTTHwych3mC63tcyKIrM+I1g==} + '@lexical/text@0.36.2': resolution: {integrity: sha512-IbbqgRdMAD6Uk9b2+qSVoy+8RVcczrz6OgXvg39+EYD+XEC7Rbw7kDTWzuNSJJpP7vxSO8YDZSaIlP5gNH3qKA==} @@ -2096,6 +2114,9 @@ packages: '@lexical/utils@0.37.0': resolution: {integrity: sha512-CFp4diY/kR5RqhzQSl/7SwsMod1sgLpI1FBifcOuJ6L/S6YywGpEB4B7aV5zqW21A/jU2T+2NZtxSUn6S+9gMg==} + '@lexical/utils@0.38.2': + resolution: {integrity: sha512-y+3rw15r4oAWIEXicUdNjfk8018dbKl7dWHqGHVEtqzAYefnEYdfD2FJ5KOTXfeoYfxi8yOW7FvzS4NZDi8Bfw==} + '@lexical/yjs@0.36.2': resolution: {integrity: sha512-gZ66Mw+uKXTO8KeX/hNKAinXbFg3gnNYraG76lBXCwb/Ka3q34upIY9FUeGOwGVaau3iIDQhE49I+6MugAX2FQ==} peerDependencies: @@ -10221,6 +10242,14 @@ snapshots: '@lexical/utils': 0.37.0 lexical: 0.37.0 + '@lexical/clipboard@0.38.2': + dependencies: + '@lexical/html': 0.38.2 + '@lexical/list': 0.38.2 + '@lexical/selection': 0.38.2 + '@lexical/utils': 0.38.2 + lexical: 0.37.0 + '@lexical/code@0.36.2': dependencies: '@lexical/utils': 0.36.2 @@ -10255,6 +10284,12 @@ snapshots: '@preact/signals-core': 1.12.1 lexical: 0.37.0 + '@lexical/extension@0.38.2': + dependencies: + '@lexical/utils': 0.38.2 + '@preact/signals-core': 1.12.1 + lexical: 0.37.0 + '@lexical/hashtag@0.36.2': dependencies: '@lexical/text': 0.36.2 @@ -10279,6 +10314,12 @@ snapshots: '@lexical/utils': 0.37.0 lexical: 0.37.0 + '@lexical/html@0.38.2': + dependencies: + '@lexical/selection': 0.38.2 + '@lexical/utils': 0.38.2 + lexical: 0.37.0 + '@lexical/link@0.36.2': dependencies: '@lexical/extension': 0.36.2 @@ -10299,6 +10340,13 @@ snapshots: '@lexical/utils': 0.37.0 lexical: 0.37.0 + '@lexical/list@0.38.2': + dependencies: + '@lexical/extension': 0.38.2 + '@lexical/selection': 0.38.2 + '@lexical/utils': 0.38.2 + lexical: 0.37.0 + '@lexical/mark@0.36.2': dependencies: '@lexical/utils': 0.36.2 @@ -10372,6 +10420,10 @@ snapshots: dependencies: lexical: 0.37.0 + '@lexical/selection@0.38.2': + dependencies: + lexical: 0.37.0 + '@lexical/table@0.36.2': dependencies: '@lexical/clipboard': 0.36.2 @@ -10386,6 +10438,13 @@ snapshots: '@lexical/utils': 0.37.0 lexical: 0.37.0 + '@lexical/table@0.38.2': + dependencies: + '@lexical/clipboard': 0.38.2 + '@lexical/extension': 0.38.2 + '@lexical/utils': 0.38.2 + lexical: 0.37.0 + '@lexical/text@0.36.2': dependencies: lexical: 0.37.0 @@ -10408,6 +10467,13 @@ snapshots: '@lexical/table': 0.37.0 lexical: 0.37.0 + '@lexical/utils@0.38.2': + dependencies: + '@lexical/list': 0.38.2 + '@lexical/selection': 0.38.2 + '@lexical/table': 0.38.2 + lexical: 0.37.0 + '@lexical/yjs@0.36.2(yjs@13.6.27)': dependencies: '@lexical/offset': 0.36.2 From 70dabe318ca4aeb3e1a8f90a525865b4d421e7d0 Mon Sep 17 00:00:00 2001 From: Gritty_dev <101377478+codomposer@users.noreply.github.com> Date: Mon, 1 Dec 2025 02:45:22 -0500 Subject: [PATCH 3/8] feat: complete test script of mail send task (#28963) --- .../unit_tests/tasks/test_mail_send_task.py | 1504 +++++++++++++++++ 1 file changed, 1504 insertions(+) create mode 100644 api/tests/unit_tests/tasks/test_mail_send_task.py diff --git a/api/tests/unit_tests/tasks/test_mail_send_task.py b/api/tests/unit_tests/tasks/test_mail_send_task.py new file mode 100644 index 0000000000..736871d784 --- /dev/null +++ b/api/tests/unit_tests/tasks/test_mail_send_task.py @@ -0,0 +1,1504 @@ +""" +Unit tests for mail send tasks. + +This module tests the mail sending functionality including: +- Email template rendering with internationalization +- SMTP integration with various configurations +- Retry logic for failed email sends +- Error handling and logging +""" + +import smtplib +from unittest.mock import MagicMock, patch + +import pytest + +from configs import dify_config +from configs.feature import TemplateMode +from libs.email_i18n import EmailType +from tasks.mail_inner_task import _render_template_with_strategy, send_inner_email_task +from tasks.mail_register_task import ( + send_email_register_mail_task, + send_email_register_mail_task_when_account_exist, +) +from tasks.mail_reset_password_task import ( + send_reset_password_mail_task, + send_reset_password_mail_task_when_account_not_exist, +) + + +class TestEmailTemplateRendering: + """Test email template rendering with various scenarios.""" + + def test_render_template_unsafe_mode(self): + """Test template rendering in unsafe mode with Jinja2 syntax.""" + # Arrange + body = "Hello {{ name }}, your code is {{ code }}" + substitutions = {"name": "John", "code": "123456"} + + # Act + with patch.object(dify_config, "MAIL_TEMPLATING_MODE", TemplateMode.UNSAFE): + result = _render_template_with_strategy(body, substitutions) + + # Assert + assert result == "Hello John, your code is 123456" + + def test_render_template_sandbox_mode(self): + """Test template rendering in sandbox mode for security.""" + # Arrange + body = "Hello {{ name }}, your code is {{ code }}" + substitutions = {"name": "Alice", "code": "654321"} + + # Act + with patch.object(dify_config, "MAIL_TEMPLATING_MODE", TemplateMode.SANDBOX): + with patch.object(dify_config, "MAIL_TEMPLATING_TIMEOUT", 3): + result = _render_template_with_strategy(body, substitutions) + + # Assert + assert result == "Hello Alice, your code is 654321" + + def test_render_template_disabled_mode(self): + """Test template rendering when templating is disabled.""" + # Arrange + body = "Hello {{ name }}, your code is {{ code }}" + substitutions = {"name": "Bob", "code": "999999"} + + # Act + with patch.object(dify_config, "MAIL_TEMPLATING_MODE", TemplateMode.DISABLED): + result = _render_template_with_strategy(body, substitutions) + + # Assert - should return body unchanged + assert result == "Hello {{ name }}, your code is {{ code }}" + + def test_render_template_sandbox_timeout(self): + """Test that sandbox mode respects timeout settings and range limits.""" + # Arrange - template with very large range (exceeds sandbox MAX_RANGE) + body = "{% for i in range(1000000) %}{{ i }}{% endfor %}" + substitutions: dict[str, str] = {} + + # Act & Assert - sandbox blocks ranges larger than MAX_RANGE (100000) + with patch.object(dify_config, "MAIL_TEMPLATING_MODE", TemplateMode.SANDBOX): + with patch.object(dify_config, "MAIL_TEMPLATING_TIMEOUT", 1): + # Should raise OverflowError for range too big + with pytest.raises((TimeoutError, RuntimeError, OverflowError)): + _render_template_with_strategy(body, substitutions) + + def test_render_template_invalid_mode(self): + """Test that invalid template mode raises ValueError.""" + # Arrange + body = "Test" + substitutions: dict[str, str] = {} + + # Act & Assert + with patch.object(dify_config, "MAIL_TEMPLATING_MODE", "invalid_mode"): + with pytest.raises(ValueError, match="Unsupported mail templating mode"): + _render_template_with_strategy(body, substitutions) + + def test_render_template_with_special_characters(self): + """Test template rendering with special characters and HTML.""" + # Arrange + body = "

Hello {{ name }}

Code: {{ code }}

" + substitutions = {"name": "Test", "code": "ABC&123"} + + # Act + with patch.object(dify_config, "MAIL_TEMPLATING_MODE", TemplateMode.SANDBOX): + result = _render_template_with_strategy(body, substitutions) + + # Assert + assert "Test" in result + assert "ABC&123" in result + + def test_render_template_missing_variable_sandbox(self): + """Test sandbox mode handles missing variables gracefully.""" + # Arrange + body = "Hello {{ name }}, your code is {{ missing_var }}" + substitutions = {"name": "John"} + + # Act - sandbox mode renders undefined variables as empty strings by default + with patch.object(dify_config, "MAIL_TEMPLATING_MODE", TemplateMode.SANDBOX): + result = _render_template_with_strategy(body, substitutions) + + # Assert - undefined variable is rendered as empty string + assert "Hello John" in result + assert "missing_var" not in result # Variable name should not appear in output + + +class TestSMTPIntegration: + """Test SMTP client integration with various configurations.""" + + @patch("libs.smtp.smtplib.SMTP_SSL") + def test_smtp_send_with_tls_ssl(self, mock_smtp_ssl): + """Test SMTP send with TLS using SMTP_SSL.""" + # Arrange + from libs.smtp import SMTPClient + + mock_server = MagicMock() + mock_smtp_ssl.return_value = mock_server + + client = SMTPClient( + server="smtp.example.com", + port=465, + username="user@example.com", + password="password123", + _from="noreply@example.com", + use_tls=True, + opportunistic_tls=False, + ) + + mail_data = {"to": "recipient@example.com", "subject": "Test Subject", "html": "

Test Content

"} + + # Act + client.send(mail_data) + + # Assert + mock_smtp_ssl.assert_called_once_with("smtp.example.com", 465, timeout=10) + mock_server.login.assert_called_once_with("user@example.com", "password123") + mock_server.sendmail.assert_called_once() + mock_server.quit.assert_called_once() + + @patch("libs.smtp.smtplib.SMTP") + def test_smtp_send_with_opportunistic_tls(self, mock_smtp): + """Test SMTP send with opportunistic TLS (STARTTLS).""" + # Arrange + from libs.smtp import SMTPClient + + mock_server = MagicMock() + mock_smtp.return_value = mock_server + + client = SMTPClient( + server="smtp.example.com", + port=587, + username="user@example.com", + password="password123", + _from="noreply@example.com", + use_tls=True, + opportunistic_tls=True, + ) + + mail_data = {"to": "recipient@example.com", "subject": "Test", "html": "

Content

"} + + # Act + client.send(mail_data) + + # Assert + mock_smtp.assert_called_once_with("smtp.example.com", 587, timeout=10) + mock_server.ehlo.assert_called() + mock_server.starttls.assert_called_once() + assert mock_server.ehlo.call_count == 2 # Before and after STARTTLS + mock_server.sendmail.assert_called_once() + mock_server.quit.assert_called_once() + + @patch("libs.smtp.smtplib.SMTP") + def test_smtp_send_without_tls(self, mock_smtp): + """Test SMTP send without TLS encryption.""" + # Arrange + from libs.smtp import SMTPClient + + mock_server = MagicMock() + mock_smtp.return_value = mock_server + + client = SMTPClient( + server="smtp.example.com", + port=25, + username="user@example.com", + password="password123", + _from="noreply@example.com", + use_tls=False, + opportunistic_tls=False, + ) + + mail_data = {"to": "recipient@example.com", "subject": "Test", "html": "

Content

"} + + # Act + client.send(mail_data) + + # Assert + mock_smtp.assert_called_once_with("smtp.example.com", 25, timeout=10) + mock_server.login.assert_called_once() + mock_server.sendmail.assert_called_once() + mock_server.quit.assert_called_once() + + @patch("libs.smtp.smtplib.SMTP") + def test_smtp_send_without_authentication(self, mock_smtp): + """Test SMTP send without authentication (empty credentials).""" + # Arrange + from libs.smtp import SMTPClient + + mock_server = MagicMock() + mock_smtp.return_value = mock_server + + client = SMTPClient( + server="smtp.example.com", + port=25, + username="", + password="", + _from="noreply@example.com", + use_tls=False, + opportunistic_tls=False, + ) + + mail_data = {"to": "recipient@example.com", "subject": "Test", "html": "

Content

"} + + # Act + client.send(mail_data) + + # Assert + mock_server.login.assert_not_called() # Should skip login with empty credentials + mock_server.sendmail.assert_called_once() + mock_server.quit.assert_called_once() + + @patch("libs.smtp.smtplib.SMTP_SSL") + def test_smtp_send_authentication_failure(self, mock_smtp_ssl): + """Test SMTP send handles authentication failure.""" + # Arrange + from libs.smtp import SMTPClient + + mock_server = MagicMock() + mock_smtp_ssl.return_value = mock_server + mock_server.login.side_effect = smtplib.SMTPAuthenticationError(535, b"Authentication failed") + + client = SMTPClient( + server="smtp.example.com", + port=465, + username="user@example.com", + password="wrong_password", + _from="noreply@example.com", + use_tls=True, + opportunistic_tls=False, + ) + + mail_data = {"to": "recipient@example.com", "subject": "Test", "html": "

Content

"} + + # Act & Assert + with pytest.raises(smtplib.SMTPAuthenticationError): + client.send(mail_data) + + mock_server.quit.assert_called_once() # Should still cleanup + + @patch("libs.smtp.smtplib.SMTP_SSL") + def test_smtp_send_timeout_error(self, mock_smtp_ssl): + """Test SMTP send handles timeout errors.""" + # Arrange + from libs.smtp import SMTPClient + + mock_smtp_ssl.side_effect = TimeoutError("Connection timeout") + + client = SMTPClient( + server="smtp.example.com", + port=465, + username="user@example.com", + password="password123", + _from="noreply@example.com", + use_tls=True, + opportunistic_tls=False, + ) + + mail_data = {"to": "recipient@example.com", "subject": "Test", "html": "

Content

"} + + # Act & Assert + with pytest.raises(TimeoutError): + client.send(mail_data) + + @patch("libs.smtp.smtplib.SMTP_SSL") + def test_smtp_send_connection_refused(self, mock_smtp_ssl): + """Test SMTP send handles connection refused errors.""" + # Arrange + from libs.smtp import SMTPClient + + mock_smtp_ssl.side_effect = ConnectionRefusedError("Connection refused") + + client = SMTPClient( + server="smtp.example.com", + port=465, + username="user@example.com", + password="password123", + _from="noreply@example.com", + use_tls=True, + opportunistic_tls=False, + ) + + mail_data = {"to": "recipient@example.com", "subject": "Test", "html": "

Content

"} + + # Act & Assert + with pytest.raises((ConnectionRefusedError, OSError)): + client.send(mail_data) + + @patch("libs.smtp.smtplib.SMTP_SSL") + def test_smtp_send_ensures_cleanup_on_error(self, mock_smtp_ssl): + """Test SMTP send ensures cleanup even when errors occur.""" + # Arrange + from libs.smtp import SMTPClient + + mock_server = MagicMock() + mock_smtp_ssl.return_value = mock_server + mock_server.sendmail.side_effect = smtplib.SMTPException("Send failed") + + client = SMTPClient( + server="smtp.example.com", + port=465, + username="user@example.com", + password="password123", + _from="noreply@example.com", + use_tls=True, + opportunistic_tls=False, + ) + + mail_data = {"to": "recipient@example.com", "subject": "Test", "html": "

Content

"} + + # Act & Assert + with pytest.raises(smtplib.SMTPException): + client.send(mail_data) + + # Verify cleanup was called + mock_server.quit.assert_called_once() + + +class TestMailTaskRetryLogic: + """Test retry logic for mail sending tasks.""" + + @patch("tasks.mail_register_task.mail") + def test_mail_task_skips_when_not_initialized(self, mock_mail): + """Test that mail tasks skip execution when mail is not initialized.""" + # Arrange + mock_mail.is_inited.return_value = False + + # Act + result = send_email_register_mail_task(language="en-US", to="test@example.com", code="123456") + + # Assert + assert result is None + mock_mail.is_inited.assert_called_once() + + @patch("tasks.mail_register_task.get_email_i18n_service") + @patch("tasks.mail_register_task.mail") + @patch("tasks.mail_register_task.logger") + def test_mail_task_logs_success(self, mock_logger, mock_mail, mock_email_service): + """Test that successful mail sends are logged properly.""" + # Arrange + mock_mail.is_inited.return_value = True + mock_service = MagicMock() + mock_email_service.return_value = mock_service + + # Act + send_email_register_mail_task(language="en-US", to="test@example.com", code="123456") + + # Assert + mock_service.send_email.assert_called_once_with( + email_type=EmailType.EMAIL_REGISTER, + language_code="en-US", + to="test@example.com", + template_context={"to": "test@example.com", "code": "123456"}, + ) + # Verify logging calls + assert mock_logger.info.call_count == 2 # Start and success logs + + @patch("tasks.mail_register_task.get_email_i18n_service") + @patch("tasks.mail_register_task.mail") + @patch("tasks.mail_register_task.logger") + def test_mail_task_logs_failure(self, mock_logger, mock_mail, mock_email_service): + """Test that failed mail sends are logged with exception details.""" + # Arrange + mock_mail.is_inited.return_value = True + mock_service = MagicMock() + mock_service.send_email.side_effect = Exception("SMTP connection failed") + mock_email_service.return_value = mock_service + + # Act + send_email_register_mail_task(language="en-US", to="test@example.com", code="123456") + + # Assert + mock_logger.exception.assert_called_once_with("Send email register mail to %s failed", "test@example.com") + + @patch("tasks.mail_reset_password_task.get_email_i18n_service") + @patch("tasks.mail_reset_password_task.mail") + def test_reset_password_task_success(self, mock_mail, mock_email_service): + """Test reset password task sends email successfully.""" + # Arrange + mock_mail.is_inited.return_value = True + mock_service = MagicMock() + mock_email_service.return_value = mock_service + + # Act + send_reset_password_mail_task(language="zh-Hans", to="user@example.com", code="RESET123") + + # Assert + mock_service.send_email.assert_called_once_with( + email_type=EmailType.RESET_PASSWORD, + language_code="zh-Hans", + to="user@example.com", + template_context={"to": "user@example.com", "code": "RESET123"}, + ) + + @patch("tasks.mail_reset_password_task.get_email_i18n_service") + @patch("tasks.mail_reset_password_task.mail") + @patch("tasks.mail_reset_password_task.dify_config") + def test_reset_password_when_account_not_exist_with_register(self, mock_config, mock_mail, mock_email_service): + """Test reset password task when account doesn't exist and registration is allowed.""" + # Arrange + mock_mail.is_inited.return_value = True + mock_config.CONSOLE_WEB_URL = "https://console.example.com" + mock_service = MagicMock() + mock_email_service.return_value = mock_service + + # Act + send_reset_password_mail_task_when_account_not_exist( + language="en-US", to="newuser@example.com", is_allow_register=True + ) + + # Assert + mock_service.send_email.assert_called_once() + call_args = mock_service.send_email.call_args + assert call_args[1]["email_type"] == EmailType.RESET_PASSWORD_WHEN_ACCOUNT_NOT_EXIST + assert call_args[1]["to"] == "newuser@example.com" + assert "sign_up_url" in call_args[1]["template_context"] + + @patch("tasks.mail_reset_password_task.get_email_i18n_service") + @patch("tasks.mail_reset_password_task.mail") + def test_reset_password_when_account_not_exist_without_register(self, mock_mail, mock_email_service): + """Test reset password task when account doesn't exist and registration is not allowed.""" + # Arrange + mock_mail.is_inited.return_value = True + mock_service = MagicMock() + mock_email_service.return_value = mock_service + + # Act + send_reset_password_mail_task_when_account_not_exist( + language="en-US", to="newuser@example.com", is_allow_register=False + ) + + # Assert + mock_service.send_email.assert_called_once() + call_args = mock_service.send_email.call_args + assert call_args[1]["email_type"] == EmailType.RESET_PASSWORD_WHEN_ACCOUNT_NOT_EXIST_NO_REGISTER + + +class TestMailTaskInternationalization: + """Test internationalization support in mail tasks.""" + + @patch("tasks.mail_register_task.get_email_i18n_service") + @patch("tasks.mail_register_task.mail") + def test_mail_task_with_english_language(self, mock_mail, mock_email_service): + """Test mail task with English language code.""" + # Arrange + mock_mail.is_inited.return_value = True + mock_service = MagicMock() + mock_email_service.return_value = mock_service + + # Act + send_email_register_mail_task(language="en-US", to="test@example.com", code="123456") + + # Assert + call_args = mock_service.send_email.call_args + assert call_args[1]["language_code"] == "en-US" + + @patch("tasks.mail_register_task.get_email_i18n_service") + @patch("tasks.mail_register_task.mail") + def test_mail_task_with_chinese_language(self, mock_mail, mock_email_service): + """Test mail task with Chinese language code.""" + # Arrange + mock_mail.is_inited.return_value = True + mock_service = MagicMock() + mock_email_service.return_value = mock_service + + # Act + send_email_register_mail_task(language="zh-Hans", to="test@example.com", code="123456") + + # Assert + call_args = mock_service.send_email.call_args + assert call_args[1]["language_code"] == "zh-Hans" + + @patch("tasks.mail_register_task.get_email_i18n_service") + @patch("tasks.mail_register_task.mail") + @patch("tasks.mail_register_task.dify_config") + def test_account_exist_task_includes_urls(self, mock_config, mock_mail, mock_email_service): + """Test account exist task includes proper URLs in template context.""" + # Arrange + mock_mail.is_inited.return_value = True + mock_config.CONSOLE_WEB_URL = "https://console.example.com" + mock_service = MagicMock() + mock_email_service.return_value = mock_service + + # Act + send_email_register_mail_task_when_account_exist( + language="en-US", to="existing@example.com", account_name="John Doe" + ) + + # Assert + call_args = mock_service.send_email.call_args + context = call_args[1]["template_context"] + assert context["login_url"] == "https://console.example.com/signin" + assert context["reset_password_url"] == "https://console.example.com/reset-password" + assert context["account_name"] == "John Doe" + + +class TestInnerEmailTask: + """Test inner email task with template rendering.""" + + @patch("tasks.mail_inner_task.get_email_i18n_service") + @patch("tasks.mail_inner_task.mail") + @patch("tasks.mail_inner_task._render_template_with_strategy") + def test_inner_email_task_renders_and_sends(self, mock_render, mock_mail, mock_email_service): + """Test inner email task renders template and sends email.""" + # Arrange + mock_mail.is_inited.return_value = True + mock_render.return_value = "

Hello John, your code is 123456

" + mock_service = MagicMock() + mock_email_service.return_value = mock_service + + to_list = ["user1@example.com", "user2@example.com"] + subject = "Test Subject" + body = "

Hello {{ name }}, your code is {{ code }}

" + substitutions = {"name": "John", "code": "123456"} + + # Act + send_inner_email_task(to=to_list, subject=subject, body=body, substitutions=substitutions) + + # Assert + mock_render.assert_called_once_with(body, substitutions) + mock_service.send_raw_email.assert_called_once_with( + to=to_list, subject=subject, html_content="

Hello John, your code is 123456

" + ) + + @patch("tasks.mail_inner_task.mail") + def test_inner_email_task_skips_when_not_initialized(self, mock_mail): + """Test inner email task skips when mail is not initialized.""" + # Arrange + mock_mail.is_inited.return_value = False + + # Act + result = send_inner_email_task(to=["test@example.com"], subject="Test", body="Body", substitutions={}) + + # Assert + assert result is None + + @patch("tasks.mail_inner_task.get_email_i18n_service") + @patch("tasks.mail_inner_task.mail") + @patch("tasks.mail_inner_task._render_template_with_strategy") + @patch("tasks.mail_inner_task.logger") + def test_inner_email_task_logs_failure(self, mock_logger, mock_render, mock_mail, mock_email_service): + """Test inner email task logs failures properly.""" + # Arrange + mock_mail.is_inited.return_value = True + mock_render.return_value = "

Content

" + mock_service = MagicMock() + mock_service.send_raw_email.side_effect = Exception("Send failed") + mock_email_service.return_value = mock_service + + to_list = ["user@example.com"] + + # Act + send_inner_email_task(to=to_list, subject="Test", body="Body", substitutions={}) + + # Assert + mock_logger.exception.assert_called_once() + + +class TestSendGridIntegration: + """Test SendGrid client integration.""" + + @patch("libs.sendgrid.sendgrid.SendGridAPIClient") + def test_sendgrid_send_success(self, mock_sg_client): + """Test SendGrid client sends email successfully.""" + # Arrange + from libs.sendgrid import SendGridClient + + mock_client_instance = MagicMock() + mock_sg_client.return_value = mock_client_instance + mock_response = MagicMock() + mock_response.status_code = 202 + mock_client_instance.client.mail.send.post.return_value = mock_response + + client = SendGridClient(sendgrid_api_key="test_api_key", _from="noreply@example.com") + + mail_data = {"to": "recipient@example.com", "subject": "Test Subject", "html": "

Test Content

"} + + # Act + client.send(mail_data) + + # Assert + mock_sg_client.assert_called_once_with(api_key="test_api_key") + mock_client_instance.client.mail.send.post.assert_called_once() + + @patch("libs.sendgrid.sendgrid.SendGridAPIClient") + def test_sendgrid_send_missing_recipient(self, mock_sg_client): + """Test SendGrid client raises error when recipient is missing.""" + # Arrange + from libs.sendgrid import SendGridClient + + client = SendGridClient(sendgrid_api_key="test_api_key", _from="noreply@example.com") + + mail_data = {"to": "", "subject": "Test Subject", "html": "

Test Content

"} + + # Act & Assert + with pytest.raises(ValueError, match="recipient address is missing"): + client.send(mail_data) + + @patch("libs.sendgrid.sendgrid.SendGridAPIClient") + def test_sendgrid_send_unauthorized_error(self, mock_sg_client): + """Test SendGrid client handles unauthorized errors.""" + # Arrange + from python_http_client.exceptions import UnauthorizedError + + from libs.sendgrid import SendGridClient + + mock_client_instance = MagicMock() + mock_sg_client.return_value = mock_client_instance + mock_client_instance.client.mail.send.post.side_effect = UnauthorizedError( + MagicMock(status_code=401), "Unauthorized" + ) + + client = SendGridClient(sendgrid_api_key="invalid_key", _from="noreply@example.com") + + mail_data = {"to": "recipient@example.com", "subject": "Test", "html": "

Content

"} + + # Act & Assert + with pytest.raises(UnauthorizedError): + client.send(mail_data) + + @patch("libs.sendgrid.sendgrid.SendGridAPIClient") + def test_sendgrid_send_forbidden_error(self, mock_sg_client): + """Test SendGrid client handles forbidden errors.""" + # Arrange + from python_http_client.exceptions import ForbiddenError + + from libs.sendgrid import SendGridClient + + mock_client_instance = MagicMock() + mock_sg_client.return_value = mock_client_instance + mock_client_instance.client.mail.send.post.side_effect = ForbiddenError(MagicMock(status_code=403), "Forbidden") + + client = SendGridClient(sendgrid_api_key="test_api_key", _from="invalid@example.com") + + mail_data = {"to": "recipient@example.com", "subject": "Test", "html": "

Content

"} + + # Act & Assert + with pytest.raises(ForbiddenError): + client.send(mail_data) + + @patch("libs.sendgrid.sendgrid.SendGridAPIClient") + def test_sendgrid_send_timeout_error(self, mock_sg_client): + """Test SendGrid client handles timeout errors.""" + # Arrange + from libs.sendgrid import SendGridClient + + mock_client_instance = MagicMock() + mock_sg_client.return_value = mock_client_instance + mock_client_instance.client.mail.send.post.side_effect = TimeoutError("Request timeout") + + client = SendGridClient(sendgrid_api_key="test_api_key", _from="noreply@example.com") + + mail_data = {"to": "recipient@example.com", "subject": "Test", "html": "

Content

"} + + # Act & Assert + with pytest.raises(TimeoutError): + client.send(mail_data) + + +class TestMailExtension: + """Test mail extension initialization and configuration.""" + + @patch("extensions.ext_mail.dify_config") + def test_mail_init_smtp_configuration(self, mock_config): + """Test mail extension initializes SMTP client correctly.""" + # Arrange + from extensions.ext_mail import Mail + + mock_config.MAIL_TYPE = "smtp" + mock_config.SMTP_SERVER = "smtp.example.com" + mock_config.SMTP_PORT = 465 + mock_config.SMTP_USERNAME = "user@example.com" + mock_config.SMTP_PASSWORD = "password123" + mock_config.SMTP_USE_TLS = True + mock_config.SMTP_OPPORTUNISTIC_TLS = False + mock_config.MAIL_DEFAULT_SEND_FROM = "noreply@example.com" + + mail = Mail() + mock_app = MagicMock() + + # Act + mail.init_app(mock_app) + + # Assert + assert mail.is_inited() is True + assert mail._client is not None + + @patch("extensions.ext_mail.dify_config") + def test_mail_init_without_mail_type(self, mock_config): + """Test mail extension skips initialization when MAIL_TYPE is not set.""" + # Arrange + from extensions.ext_mail import Mail + + mock_config.MAIL_TYPE = None + + mail = Mail() + mock_app = MagicMock() + + # Act + mail.init_app(mock_app) + + # Assert + assert mail.is_inited() is False + + @patch("extensions.ext_mail.dify_config") + def test_mail_send_validates_parameters(self, mock_config): + """Test mail send validates required parameters.""" + # Arrange + from extensions.ext_mail import Mail + + mail = Mail() + mail._client = MagicMock() + mail._default_send_from = "noreply@example.com" + + # Act & Assert - missing to + with pytest.raises(ValueError, match="mail to is not set"): + mail.send(to="", subject="Test", html="

Content

") + + # Act & Assert - missing subject + with pytest.raises(ValueError, match="mail subject is not set"): + mail.send(to="test@example.com", subject="", html="

Content

") + + # Act & Assert - missing html + with pytest.raises(ValueError, match="mail html is not set"): + mail.send(to="test@example.com", subject="Test", html="") + + @patch("extensions.ext_mail.dify_config") + def test_mail_send_uses_default_from(self, mock_config): + """Test mail send uses default from address when not provided.""" + # Arrange + from extensions.ext_mail import Mail + + mail = Mail() + mock_client = MagicMock() + mail._client = mock_client + mail._default_send_from = "default@example.com" + + # Act + mail.send(to="test@example.com", subject="Test", html="

Content

") + + # Assert + mock_client.send.assert_called_once() + call_args = mock_client.send.call_args[0][0] + assert call_args["from"] == "default@example.com" + + +class TestEmailI18nService: + """Test email internationalization service.""" + + @patch("libs.email_i18n.FlaskMailSender") + @patch("libs.email_i18n.FeatureBrandingService") + @patch("libs.email_i18n.FlaskEmailRenderer") + def test_email_service_sends_with_branding(self, mock_renderer_class, mock_branding_class, mock_sender_class): + """Test email service sends email with branding support.""" + # Arrange + from libs.email_i18n import EmailI18nConfig, EmailI18nService, EmailLanguage, EmailTemplate, EmailType + from services.feature_service import BrandingModel + + mock_renderer = MagicMock() + mock_renderer.render_template.return_value = "Rendered content" + mock_renderer_class.return_value = mock_renderer + + mock_branding = MagicMock() + mock_branding.get_branding_config.return_value = BrandingModel( + enabled=True, application_title="Custom App", logo="logo.png" + ) + mock_branding_class.return_value = mock_branding + + mock_sender = MagicMock() + mock_sender_class.return_value = mock_sender + + template = EmailTemplate( + subject="Test {application_title}", + template_path="templates/test.html", + branded_template_path="templates/branded/test.html", + ) + + config = EmailI18nConfig(templates={EmailType.EMAIL_REGISTER: {EmailLanguage.EN_US: template}}) + + service = EmailI18nService( + config=config, renderer=mock_renderer, branding_service=mock_branding, sender=mock_sender + ) + + # Act + service.send_email( + email_type=EmailType.EMAIL_REGISTER, + language_code="en-US", + to="test@example.com", + template_context={"code": "123456"}, + ) + + # Assert + mock_renderer.render_template.assert_called_once() + # Should use branded template + assert mock_renderer.render_template.call_args[0][0] == "templates/branded/test.html" + mock_sender.send_email.assert_called_once_with( + to="test@example.com", subject="Test Custom App", html_content="Rendered content" + ) + + @patch("libs.email_i18n.FlaskMailSender") + def test_email_service_send_raw_email_single_recipient(self, mock_sender_class): + """Test email service sends raw email to single recipient.""" + # Arrange + from libs.email_i18n import EmailI18nConfig, EmailI18nService + + mock_sender = MagicMock() + mock_sender_class.return_value = mock_sender + + service = EmailI18nService( + config=EmailI18nConfig(), + renderer=MagicMock(), + branding_service=MagicMock(), + sender=mock_sender, + ) + + # Act + service.send_raw_email(to="test@example.com", subject="Test", html_content="

Content

") + + # Assert + mock_sender.send_email.assert_called_once_with( + to="test@example.com", subject="Test", html_content="

Content

" + ) + + @patch("libs.email_i18n.FlaskMailSender") + def test_email_service_send_raw_email_multiple_recipients(self, mock_sender_class): + """Test email service sends raw email to multiple recipients.""" + # Arrange + from libs.email_i18n import EmailI18nConfig, EmailI18nService + + mock_sender = MagicMock() + mock_sender_class.return_value = mock_sender + + service = EmailI18nService( + config=EmailI18nConfig(), + renderer=MagicMock(), + branding_service=MagicMock(), + sender=mock_sender, + ) + + # Act + service.send_raw_email( + to=["user1@example.com", "user2@example.com"], subject="Test", html_content="

Content

" + ) + + # Assert + assert mock_sender.send_email.call_count == 2 + mock_sender.send_email.assert_any_call(to="user1@example.com", subject="Test", html_content="

Content

") + mock_sender.send_email.assert_any_call(to="user2@example.com", subject="Test", html_content="

Content

") + + +class TestPerformanceAndTiming: + """Test performance tracking and timing in mail tasks.""" + + @patch("tasks.mail_register_task.get_email_i18n_service") + @patch("tasks.mail_register_task.mail") + @patch("tasks.mail_register_task.logger") + @patch("tasks.mail_register_task.time") + def test_mail_task_tracks_execution_time(self, mock_time, mock_logger, mock_mail, mock_email_service): + """Test that mail tasks track and log execution time.""" + # Arrange + mock_mail.is_inited.return_value = True + mock_service = MagicMock() + mock_email_service.return_value = mock_service + + # Simulate time progression + mock_time.perf_counter.side_effect = [100.0, 100.5] # 0.5 second execution + + # Act + send_email_register_mail_task(language="en-US", to="test@example.com", code="123456") + + # Assert + assert mock_time.perf_counter.call_count == 2 + # Verify latency is logged + success_log_call = mock_logger.info.call_args_list[1] + assert "latency" in str(success_log_call) + + +class TestEdgeCasesAndErrorHandling: + """ + Test edge cases and error handling scenarios. + + This test class covers unusual inputs, boundary conditions, + and various error scenarios to ensure robust error handling. + """ + + @patch("extensions.ext_mail.dify_config") + def test_mail_init_invalid_smtp_config_missing_server(self, mock_config): + """ + Test mail initialization fails when SMTP server is missing. + + Validates that proper error is raised when required SMTP + configuration parameters are not provided. + """ + # Arrange + from extensions.ext_mail import Mail + + mock_config.MAIL_TYPE = "smtp" + mock_config.SMTP_SERVER = None # Missing required parameter + mock_config.SMTP_PORT = 465 + + mail = Mail() + mock_app = MagicMock() + + # Act & Assert + with pytest.raises(ValueError, match="SMTP_SERVER and SMTP_PORT are required"): + mail.init_app(mock_app) + + @patch("extensions.ext_mail.dify_config") + def test_mail_init_invalid_smtp_opportunistic_tls_without_tls(self, mock_config): + """ + Test mail initialization fails with opportunistic TLS but TLS disabled. + + Opportunistic TLS (STARTTLS) requires TLS to be enabled. + This test ensures the configuration is validated properly. + """ + # Arrange + from extensions.ext_mail import Mail + + mock_config.MAIL_TYPE = "smtp" + mock_config.SMTP_SERVER = "smtp.example.com" + mock_config.SMTP_PORT = 587 + mock_config.SMTP_USE_TLS = False # TLS disabled + mock_config.SMTP_OPPORTUNISTIC_TLS = True # But opportunistic TLS enabled + + mail = Mail() + mock_app = MagicMock() + + # Act & Assert + with pytest.raises(ValueError, match="SMTP_OPPORTUNISTIC_TLS is not supported without enabling SMTP_USE_TLS"): + mail.init_app(mock_app) + + @patch("extensions.ext_mail.dify_config") + def test_mail_init_unsupported_mail_type(self, mock_config): + """ + Test mail initialization fails with unsupported mail type. + + Ensures that only supported mail providers (smtp, sendgrid, resend) + are accepted and invalid types are rejected. + """ + # Arrange + from extensions.ext_mail import Mail + + mock_config.MAIL_TYPE = "unsupported_provider" + + mail = Mail() + mock_app = MagicMock() + + # Act & Assert + with pytest.raises(ValueError, match="Unsupported mail type"): + mail.init_app(mock_app) + + @patch("libs.smtp.smtplib.SMTP_SSL") + def test_smtp_send_with_empty_subject(self, mock_smtp_ssl): + """ + Test SMTP client handles empty subject gracefully. + + While not ideal, the SMTP client should be able to send + emails with empty subjects without crashing. + """ + # Arrange + from libs.smtp import SMTPClient + + mock_server = MagicMock() + mock_smtp_ssl.return_value = mock_server + + client = SMTPClient( + server="smtp.example.com", + port=465, + username="user@example.com", + password="password123", + _from="noreply@example.com", + use_tls=True, + opportunistic_tls=False, + ) + + # Email with empty subject + mail_data = {"to": "recipient@example.com", "subject": "", "html": "

Content

"} + + # Act + client.send(mail_data) + + # Assert - should still send successfully + mock_server.sendmail.assert_called_once() + + @patch("libs.smtp.smtplib.SMTP_SSL") + def test_smtp_send_with_unicode_characters(self, mock_smtp_ssl): + """ + Test SMTP client handles Unicode characters in email content. + + Ensures proper handling of international characters in + subject lines and email bodies. + """ + # Arrange + from libs.smtp import SMTPClient + + mock_server = MagicMock() + mock_smtp_ssl.return_value = mock_server + + client = SMTPClient( + server="smtp.example.com", + port=465, + username="user@example.com", + password="password123", + _from="noreply@example.com", + use_tls=True, + opportunistic_tls=False, + ) + + # Email with Unicode characters (Chinese, emoji, etc.) + mail_data = { + "to": "recipient@example.com", + "subject": "测试邮件 🎉 Test Email", + "html": "

你好世界 Hello World 🌍

", + } + + # Act + client.send(mail_data) + + # Assert + mock_server.sendmail.assert_called_once() + mock_server.quit.assert_called_once() + + @patch("tasks.mail_inner_task.get_email_i18n_service") + @patch("tasks.mail_inner_task.mail") + @patch("tasks.mail_inner_task._render_template_with_strategy") + def test_inner_email_task_with_empty_recipient_list(self, mock_render, mock_mail, mock_email_service): + """ + Test inner email task handles empty recipient list. + + When no recipients are provided, the task should handle + this gracefully without attempting to send emails. + """ + # Arrange + mock_mail.is_inited.return_value = True + mock_render.return_value = "

Content

" + mock_service = MagicMock() + mock_email_service.return_value = mock_service + + # Act + send_inner_email_task(to=[], subject="Test", body="Body", substitutions={}) + + # Assert + mock_service.send_raw_email.assert_called_once_with(to=[], subject="Test", html_content="

Content

") + + +class TestConcurrencyAndThreadSafety: + """ + Test concurrent execution and thread safety scenarios. + + These tests ensure that mail tasks can handle concurrent + execution without race conditions or resource conflicts. + """ + + @patch("tasks.mail_register_task.get_email_i18n_service") + @patch("tasks.mail_register_task.mail") + def test_multiple_mail_tasks_concurrent_execution(self, mock_mail, mock_email_service): + """ + Test multiple mail tasks can execute concurrently. + + Simulates concurrent execution of multiple mail tasks + to ensure thread safety and proper resource handling. + """ + # Arrange + mock_mail.is_inited.return_value = True + mock_service = MagicMock() + mock_email_service.return_value = mock_service + + # Act - simulate concurrent task execution + recipients = [f"user{i}@example.com" for i in range(5)] + for recipient in recipients: + send_email_register_mail_task(language="en-US", to=recipient, code="123456") + + # Assert - all tasks should complete successfully + assert mock_service.send_email.call_count == 5 + + +class TestResendIntegration: + """ + Test Resend email service integration. + + Resend is an alternative email provider that can be used + instead of SMTP or SendGrid. + """ + + @patch("builtins.__import__", side_effect=__import__) + @patch("extensions.ext_mail.dify_config") + def test_mail_init_resend_configuration(self, mock_config, mock_import): + """ + Test mail extension initializes Resend client correctly. + + Validates that Resend API key is properly configured + and the client is initialized. + """ + # Arrange + from extensions.ext_mail import Mail + + mock_config.MAIL_TYPE = "resend" + mock_config.RESEND_API_KEY = "re_test_api_key" + mock_config.RESEND_API_URL = None + mock_config.MAIL_DEFAULT_SEND_FROM = "noreply@example.com" + + # Create mock resend module + mock_resend = MagicMock() + mock_emails = MagicMock() + mock_resend.Emails = mock_emails + + # Override import for resend module + original_import = __import__ + + def custom_import(name, *args, **kwargs): + if name == "resend": + return mock_resend + return original_import(name, *args, **kwargs) + + mock_import.side_effect = custom_import + + mail = Mail() + mock_app = MagicMock() + + # Act + mail.init_app(mock_app) + + # Assert + assert mail.is_inited() is True + assert mock_resend.api_key == "re_test_api_key" + + @patch("builtins.__import__", side_effect=__import__) + @patch("extensions.ext_mail.dify_config") + def test_mail_init_resend_with_custom_url(self, mock_config, mock_import): + """ + Test mail extension initializes Resend with custom API URL. + + Some deployments may use a custom Resend API endpoint. + This test ensures custom URLs are properly configured. + """ + # Arrange + from extensions.ext_mail import Mail + + mock_config.MAIL_TYPE = "resend" + mock_config.RESEND_API_KEY = "re_test_api_key" + mock_config.RESEND_API_URL = "https://custom-resend.example.com" + mock_config.MAIL_DEFAULT_SEND_FROM = "noreply@example.com" + + # Create mock resend module + mock_resend = MagicMock() + mock_emails = MagicMock() + mock_resend.Emails = mock_emails + + # Override import for resend module + original_import = __import__ + + def custom_import(name, *args, **kwargs): + if name == "resend": + return mock_resend + return original_import(name, *args, **kwargs) + + mock_import.side_effect = custom_import + + mail = Mail() + mock_app = MagicMock() + + # Act + mail.init_app(mock_app) + + # Assert + assert mail.is_inited() is True + assert mock_resend.api_url == "https://custom-resend.example.com" + + @patch("extensions.ext_mail.dify_config") + def test_mail_init_resend_missing_api_key(self, mock_config): + """ + Test mail initialization fails when Resend API key is missing. + + Resend requires an API key to function. This test ensures + proper validation of required configuration. + """ + # Arrange + from extensions.ext_mail import Mail + + mock_config.MAIL_TYPE = "resend" + mock_config.RESEND_API_KEY = None # Missing API key + + mail = Mail() + mock_app = MagicMock() + + # Act & Assert + with pytest.raises(ValueError, match="RESEND_API_KEY is not set"): + mail.init_app(mock_app) + + +class TestTemplateContextValidation: + """ + Test template context validation and rendering. + + These tests ensure that template contexts are properly + validated and rendered with correct variable substitution. + """ + + @patch("tasks.mail_register_task.get_email_i18n_service") + @patch("tasks.mail_register_task.mail") + def test_mail_task_template_context_includes_all_required_fields(self, mock_mail, mock_email_service): + """ + Test that mail tasks include all required fields in template context. + + Template rendering requires specific context variables. + This test ensures all required fields are present. + """ + # Arrange + mock_mail.is_inited.return_value = True + mock_service = MagicMock() + mock_email_service.return_value = mock_service + + # Act + send_email_register_mail_task(language="en-US", to="test@example.com", code="ABC123") + + # Assert + call_args = mock_service.send_email.call_args + context = call_args[1]["template_context"] + + # Verify all required fields are present + assert "to" in context + assert "code" in context + assert context["to"] == "test@example.com" + assert context["code"] == "ABC123" + + def test_render_template_with_complex_nested_data(self): + """ + Test template rendering with complex nested data structures. + + Templates may need to access nested dictionaries or lists. + This test ensures complex data structures are handled correctly. + """ + # Arrange + body = ( + "User: {{ user.name }}, Items: " + "{% for item in items %}{{ item }}{% if not loop.last %}, {% endif %}{% endfor %}" + ) + substitutions = {"user": {"name": "John Doe"}, "items": ["apple", "banana", "cherry"]} + + # Act + with patch.object(dify_config, "MAIL_TEMPLATING_MODE", TemplateMode.SANDBOX): + result = _render_template_with_strategy(body, substitutions) + + # Assert + assert "John Doe" in result + assert "apple" in result + assert "banana" in result + assert "cherry" in result + + def test_render_template_with_conditional_logic(self): + """ + Test template rendering with conditional logic. + + Templates often use conditional statements to customize + content based on context variables. + """ + # Arrange + body = "{% if is_premium %}Premium User{% else %}Free User{% endif %}" + + # Act - Test with premium user + with patch.object(dify_config, "MAIL_TEMPLATING_MODE", TemplateMode.SANDBOX): + result_premium = _render_template_with_strategy(body, {"is_premium": True}) + result_free = _render_template_with_strategy(body, {"is_premium": False}) + + # Assert + assert "Premium User" in result_premium + assert "Free User" in result_free + + +class TestEmailValidation: + """ + Test email address validation and sanitization. + + These tests ensure that email addresses are properly + validated before sending to prevent errors. + """ + + @patch("extensions.ext_mail.dify_config") + def test_mail_send_with_invalid_email_format(self, mock_config): + """ + Test mail send with malformed email address. + + While the Mail class doesn't validate email format, + this test documents the current behavior. + """ + # Arrange + from extensions.ext_mail import Mail + + mail = Mail() + mock_client = MagicMock() + mail._client = mock_client + mail._default_send_from = "noreply@example.com" + + # Act - send to malformed email (no validation in Mail class) + mail.send(to="not-an-email", subject="Test", html="

Content

") + + # Assert - Mail class passes through to client + mock_client.send.assert_called_once() + + +class TestSMTPEdgeCases: + """ + Test SMTP-specific edge cases and error conditions. + + These tests cover various SMTP-specific scenarios that + may occur in production environments. + """ + + @patch("libs.smtp.smtplib.SMTP_SSL") + def test_smtp_send_with_very_large_email_body(self, mock_smtp_ssl): + """ + Test SMTP client handles large email bodies. + + Some emails may contain large HTML content with images + or extensive formatting. This test ensures they're handled. + """ + # Arrange + from libs.smtp import SMTPClient + + mock_server = MagicMock() + mock_smtp_ssl.return_value = mock_server + + client = SMTPClient( + server="smtp.example.com", + port=465, + username="user@example.com", + password="password123", + _from="noreply@example.com", + use_tls=True, + opportunistic_tls=False, + ) + + # Create a large HTML body (simulating a newsletter) + large_html = "" + "

Content paragraph

" * 1000 + "" + mail_data = {"to": "recipient@example.com", "subject": "Large Email", "html": large_html} + + # Act + client.send(mail_data) + + # Assert + mock_server.sendmail.assert_called_once() + # Verify the large content was included + sent_message = mock_server.sendmail.call_args[0][2] + assert len(sent_message) > 10000 # Should be a large message + + @patch("libs.smtp.smtplib.SMTP_SSL") + def test_smtp_send_with_multiple_recipients_in_to_field(self, mock_smtp_ssl): + """ + Test SMTP client with single recipient (current implementation). + + The current SMTPClient implementation sends to a single + recipient per call. This test documents that behavior. + """ + # Arrange + from libs.smtp import SMTPClient + + mock_server = MagicMock() + mock_smtp_ssl.return_value = mock_server + + client = SMTPClient( + server="smtp.example.com", + port=465, + username="user@example.com", + password="password123", + _from="noreply@example.com", + use_tls=True, + opportunistic_tls=False, + ) + + mail_data = {"to": "recipient@example.com", "subject": "Test", "html": "

Content

"} + + # Act + client.send(mail_data) + + # Assert - sends to single recipient + call_args = mock_server.sendmail.call_args + assert call_args[0][1] == "recipient@example.com" + + @patch("libs.smtp.smtplib.SMTP") + def test_smtp_send_with_whitespace_in_credentials(self, mock_smtp): + """ + Test SMTP client strips whitespace from credentials. + + The SMTPClient checks for non-empty credentials after stripping + whitespace to avoid authentication with blank credentials. + """ + # Arrange + from libs.smtp import SMTPClient + + mock_server = MagicMock() + mock_smtp.return_value = mock_server + + # Credentials with only whitespace + client = SMTPClient( + server="smtp.example.com", + port=25, + username=" ", # Only whitespace + password=" ", # Only whitespace + _from="noreply@example.com", + use_tls=False, + opportunistic_tls=False, + ) + + mail_data = {"to": "recipient@example.com", "subject": "Test", "html": "

Content

"} + + # Act + client.send(mail_data) + + # Assert - should NOT attempt login with whitespace-only credentials + mock_server.login.assert_not_called() + + +class TestLoggingAndMonitoring: + """ + Test logging and monitoring functionality. + + These tests ensure that mail tasks properly log their + execution for debugging and monitoring purposes. + """ + + @patch("tasks.mail_register_task.get_email_i18n_service") + @patch("tasks.mail_register_task.mail") + @patch("tasks.mail_register_task.logger") + def test_mail_task_logs_recipient_information(self, mock_logger, mock_mail, mock_email_service): + """ + Test that mail tasks log recipient information for audit trails. + + Logging recipient information helps with debugging and + tracking email delivery in production. + """ + # Arrange + mock_mail.is_inited.return_value = True + mock_service = MagicMock() + mock_email_service.return_value = mock_service + + # Act + send_email_register_mail_task(language="en-US", to="audit@example.com", code="123456") + + # Assert + # Check that recipient is logged in start message + start_log_call = mock_logger.info.call_args_list[0] + assert "audit@example.com" in str(start_log_call) + + @patch("tasks.mail_inner_task.get_email_i18n_service") + @patch("tasks.mail_inner_task.mail") + @patch("tasks.mail_inner_task.logger") + def test_inner_email_task_logs_subject_for_tracking(self, mock_logger, mock_mail, mock_email_service): + """ + Test that inner email task logs subject for tracking purposes. + + Logging email subjects helps identify which emails are being + sent and aids in debugging delivery issues. + """ + # Arrange + mock_mail.is_inited.return_value = True + mock_service = MagicMock() + mock_email_service.return_value = mock_service + + # Act + send_inner_email_task( + to=["user@example.com"], subject="Important Notification", body="

Body

", substitutions={} + ) + + # Assert + # Check that subject is logged + start_log_call = mock_logger.info.call_args_list[0] + assert "Important Notification" in str(start_log_call) From f4db5f99734c889a254c6a8fc3c47fad2d6640ca Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 1 Dec 2025 15:45:39 +0800 Subject: [PATCH 4/8] chore(deps): bump faker from 32.1.0 to 38.2.0 in /api (#28964) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- api/pyproject.toml | 2 +- api/uv.lock | 11 +++++------ 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/api/pyproject.toml b/api/pyproject.toml index a31fd758cc..d28ba91413 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -111,7 +111,7 @@ package = false dev = [ "coverage~=7.2.4", "dotenv-linter~=0.5.0", - "faker~=32.1.0", + "faker~=38.2.0", "lxml-stubs~=0.5.1", "ty~=0.0.1a19", "basedpyright~=1.31.0", diff --git a/api/uv.lock b/api/uv.lock index 963591ac27..f691e90837 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -1628,7 +1628,7 @@ dev = [ { name = "celery-types", specifier = ">=0.23.0" }, { name = "coverage", specifier = "~=7.2.4" }, { name = "dotenv-linter", specifier = "~=0.5.0" }, - { name = "faker", specifier = "~=32.1.0" }, + { name = "faker", specifier = "~=38.2.0" }, { name = "hypothesis", specifier = ">=6.131.15" }, { name = "import-linter", specifier = ">=2.3" }, { name = "lxml-stubs", specifier = "~=0.5.1" }, @@ -1859,15 +1859,14 @@ wheels = [ [[package]] name = "faker" -version = "32.1.0" +version = "38.2.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "python-dateutil" }, - { name = "typing-extensions" }, + { name = "tzdata" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/1c/2a/dd2c8f55d69013d0eee30ec4c998250fb7da957f5fe860ed077b3df1725b/faker-32.1.0.tar.gz", hash = "sha256:aac536ba04e6b7beb2332c67df78485fc29c1880ff723beac6d1efd45e2f10f5", size = 1850193, upload-time = "2024-11-12T22:04:34.812Z" } +sdist = { url = "https://files.pythonhosted.org/packages/64/27/022d4dbd4c20567b4c294f79a133cc2f05240ea61e0d515ead18c995c249/faker-38.2.0.tar.gz", hash = "sha256:20672803db9c7cb97f9b56c18c54b915b6f1d8991f63d1d673642dc43f5ce7ab", size = 1941469, upload-time = "2025-11-19T16:37:31.892Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/7e/fa/4a82dea32d6262a96e6841cdd4a45c11ac09eecdff018e745565410ac70e/Faker-32.1.0-py3-none-any.whl", hash = "sha256:c77522577863c264bdc9dad3a2a750ad3f7ee43ff8185072e482992288898814", size = 1889123, upload-time = "2024-11-12T22:04:32.298Z" }, + { url = "https://files.pythonhosted.org/packages/17/93/00c94d45f55c336434a15f98d906387e87ce28f9918e4444829a8fda432d/faker-38.2.0-py3-none-any.whl", hash = "sha256:35fe4a0a79dee0dc4103a6083ee9224941e7d3594811a50e3969e547b0d2ee65", size = 1980505, upload-time = "2025-11-19T16:37:30.208Z" }, ] [[package]] From 626d4f3e356fefede5937bd23551b9a2d0e5e5c0 Mon Sep 17 00:00:00 2001 From: yyh <92089059+lyzno1@users.noreply.github.com> Date: Mon, 1 Dec 2025 15:45:50 +0800 Subject: [PATCH 5/8] fix(web): use atomic selectors to fix Zustand v5 infinite loop (#28977) --- .../workflow/panel/debug-and-preview/chat-wrapper.tsx | 6 ++---- web/app/components/workflow/panel/inputs-panel.tsx | 5 +---- 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/web/app/components/workflow/panel/debug-and-preview/chat-wrapper.tsx b/web/app/components/workflow/panel/debug-and-preview/chat-wrapper.tsx index 6fba10bf81..682e91ea81 100644 --- a/web/app/components/workflow/panel/debug-and-preview/chat-wrapper.tsx +++ b/web/app/components/workflow/panel/debug-and-preview/chat-wrapper.tsx @@ -47,10 +47,8 @@ const ChatWrapper = ( const startVariables = startNode?.data.variables const appDetail = useAppStore(s => s.appDetail) const workflowStore = useWorkflowStore() - const { inputs, setInputs } = useStore(s => ({ - inputs: s.inputs, - setInputs: s.setInputs, - })) + const inputs = useStore(s => s.inputs) + const setInputs = useStore(s => s.setInputs) const initialInputs = useMemo(() => { const initInputs: Record = {} diff --git a/web/app/components/workflow/panel/inputs-panel.tsx b/web/app/components/workflow/panel/inputs-panel.tsx index 11492539df..4c9de03b8a 100644 --- a/web/app/components/workflow/panel/inputs-panel.tsx +++ b/web/app/components/workflow/panel/inputs-panel.tsx @@ -32,10 +32,7 @@ type Props = { const InputsPanel = ({ onRun }: Props) => { const { t } = useTranslation() const workflowStore = useWorkflowStore() - const { inputs } = useStore(s => ({ - inputs: s.inputs, - setInputs: s.setInputs, - })) + const inputs = useStore(s => s.inputs) const fileSettings = useHooksStore(s => s.configsMap?.fileSettings) const nodes = useNodes() const files = useStore(s => s.files) From 0a22bc5d05160afa0334e620a333699af1e2e2c0 Mon Sep 17 00:00:00 2001 From: yyh <92089059+lyzno1@users.noreply.github.com> Date: Mon, 1 Dec 2025 19:23:42 +0800 Subject: [PATCH 6/8] fix(web): use atomic selectors in AccessControlItem (#28983) --- .../app/app-access-control/access-control-item.tsx | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/web/app/components/app/app-access-control/access-control-item.tsx b/web/app/components/app/app-access-control/access-control-item.tsx index 0840902371..ce3bf5d275 100644 --- a/web/app/components/app/app-access-control/access-control-item.tsx +++ b/web/app/components/app/app-access-control/access-control-item.tsx @@ -1,6 +1,6 @@ 'use client' import type { FC, PropsWithChildren } from 'react' -import useAccessControlStore from '../../../../context/access-control-store' +import useAccessControlStore from '@/context/access-control-store' import type { AccessMode } from '@/models/access-control' type AccessControlItemProps = PropsWithChildren<{ @@ -8,7 +8,8 @@ type AccessControlItemProps = PropsWithChildren<{ }> const AccessControlItem: FC = ({ type, children }) => { - const { currentMenu, setCurrentMenu } = useAccessControlStore(s => ({ currentMenu: s.currentMenu, setCurrentMenu: s.setCurrentMenu })) + const currentMenu = useAccessControlStore(s => s.currentMenu) + const setCurrentMenu = useAccessControlStore(s => s.setCurrentMenu) if (currentMenu !== type) { return
Date: Mon, 1 Dec 2025 22:25:08 -0500 Subject: [PATCH 8/8] feat: complete test script of plugin manager (#28967) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- .../core/plugin/test_plugin_manager.py | 1422 +++++++++++++++++ 1 file changed, 1422 insertions(+) create mode 100644 api/tests/unit_tests/core/plugin/test_plugin_manager.py diff --git a/api/tests/unit_tests/core/plugin/test_plugin_manager.py b/api/tests/unit_tests/core/plugin/test_plugin_manager.py new file mode 100644 index 0000000000..510aedd551 --- /dev/null +++ b/api/tests/unit_tests/core/plugin/test_plugin_manager.py @@ -0,0 +1,1422 @@ +""" +Unit tests for Plugin Manager (PluginInstaller). + +This module tests the plugin management functionality including: +- Plugin discovery and listing +- Plugin loading and installation +- Plugin validation and manifest parsing +- Version compatibility checks +- Dependency resolution +""" + +import datetime +from unittest.mock import patch + +import httpx +import pytest +from packaging.version import Version +from requests import HTTPError + +from core.plugin.entities.bundle import PluginBundleDependency +from core.plugin.entities.plugin import ( + MissingPluginDependency, + PluginCategory, + PluginDeclaration, + PluginEntity, + PluginInstallation, + PluginInstallationSource, + PluginResourceRequirements, +) +from core.plugin.entities.plugin_daemon import ( + PluginDecodeResponse, + PluginInstallTask, + PluginInstallTaskStartResponse, + PluginInstallTaskStatus, + PluginListResponse, + PluginReadmeResponse, + PluginVerification, +) +from core.plugin.impl.exc import ( + PluginDaemonBadRequestError, + PluginDaemonInternalServerError, + PluginDaemonNotFoundError, +) +from core.plugin.impl.plugin import PluginInstaller +from core.tools.entities.common_entities import I18nObject +from models.provider_ids import GenericProviderID + + +class TestPluginDiscovery: + """Test plugin discovery functionality.""" + + @pytest.fixture + def plugin_installer(self): + """Create a PluginInstaller instance for testing.""" + return PluginInstaller() + + @pytest.fixture + def mock_plugin_entity(self): + """Create a mock PluginEntity for testing.""" + return PluginEntity( + id="entity-123", + created_at=datetime.datetime(2023, 1, 1, 0, 0, 0), + updated_at=datetime.datetime(2023, 1, 1, 0, 0, 0), + tenant_id="test-tenant", + endpoints_setups=0, + endpoints_active=0, + runtime_type="remote", + source=PluginInstallationSource.Marketplace, + meta={}, + plugin_id="plugin-123", + plugin_unique_identifier="test-org/test-plugin/1.0.0", + version="1.0.0", + checksum="abc123", + name="Test Plugin", + installation_id="install-123", + declaration=PluginDeclaration( + version="1.0.0", + author="test-author", + name="test-plugin", + description=I18nObject(en_US="Test plugin description", zh_Hans="测试插件描述"), + icon="icon.png", + label=I18nObject(en_US="Test Plugin", zh_Hans="测试插件"), + category=PluginCategory.Tool, + created_at=datetime.datetime.now(), + resource=PluginResourceRequirements(memory=512, permission=None), + plugins=PluginDeclaration.Plugins(), + meta=PluginDeclaration.Meta(version="1.0.0"), + ), + ) + + def test_list_plugins_success(self, plugin_installer, mock_plugin_entity): + """Test successful plugin listing.""" + # Arrange: Mock the HTTP response for listing plugins + mock_response = PluginListResponse(list=[mock_plugin_entity], total=1) + + with patch.object( + plugin_installer, "_request_with_plugin_daemon_response", return_value=mock_response + ) as mock_request: + # Act: List plugins for a tenant + result = plugin_installer.list_plugins("test-tenant") + + # Assert: Verify the request was made correctly + mock_request.assert_called_once() + assert len(result) == 1 + assert result[0].plugin_id == "plugin-123" + assert result[0].name == "Test Plugin" + + def test_list_plugins_with_pagination(self, plugin_installer, mock_plugin_entity): + """Test plugin listing with pagination support.""" + # Arrange: Mock paginated response + mock_response = PluginListResponse(list=[mock_plugin_entity], total=10) + + with patch.object( + plugin_installer, "_request_with_plugin_daemon_response", return_value=mock_response + ) as mock_request: + # Act: List plugins with pagination + result = plugin_installer.list_plugins_with_total("test-tenant", page=1, page_size=5) + + # Assert: Verify pagination parameters + mock_request.assert_called_once() + call_args = mock_request.call_args + assert call_args[1]["params"]["page"] == 1 + assert call_args[1]["params"]["page_size"] == 5 + assert result.total == 10 + + def test_list_plugins_empty_result(self, plugin_installer): + """Test plugin listing when no plugins are installed.""" + # Arrange: Mock empty response + mock_response = PluginListResponse(list=[], total=0) + + with patch.object(plugin_installer, "_request_with_plugin_daemon_response", return_value=mock_response): + # Act: List plugins + result = plugin_installer.list_plugins("test-tenant") + + # Assert: Verify empty list is returned + assert len(result) == 0 + + def test_fetch_plugin_by_identifier_found(self, plugin_installer): + """Test fetching a plugin by its unique identifier when it exists.""" + # Arrange: Mock successful fetch + with patch.object(plugin_installer, "_request_with_plugin_daemon_response", return_value=True) as mock_request: + # Act: Fetch plugin by identifier + result = plugin_installer.fetch_plugin_by_identifier("test-tenant", "test-org/test-plugin/1.0.0") + + # Assert: Verify the plugin was found + assert result is True + mock_request.assert_called_once() + + def test_fetch_plugin_by_identifier_not_found(self, plugin_installer): + """Test fetching a plugin by identifier when it doesn't exist.""" + # Arrange: Mock not found response + with patch.object(plugin_installer, "_request_with_plugin_daemon_response", return_value=False): + # Act: Fetch non-existent plugin + result = plugin_installer.fetch_plugin_by_identifier("test-tenant", "non-existent/plugin/1.0.0") + + # Assert: Verify the plugin was not found + assert result is False + + +class TestPluginLoading: + """Test plugin loading and installation functionality.""" + + @pytest.fixture + def plugin_installer(self): + """Create a PluginInstaller instance for testing.""" + return PluginInstaller() + + @pytest.fixture + def mock_plugin_declaration(self): + """Create a mock PluginDeclaration for testing.""" + return PluginDeclaration( + version="1.0.0", + author="test-author", + name="test-plugin", + description=I18nObject(en_US="Test plugin", zh_Hans="测试插件"), + icon="icon.png", + label=I18nObject(en_US="Test Plugin", zh_Hans="测试插件"), + category=PluginCategory.Tool, + created_at=datetime.datetime.now(), + resource=PluginResourceRequirements(memory=512, permission=None), + plugins=PluginDeclaration.Plugins(), + meta=PluginDeclaration.Meta(version="1.0.0"), + ) + + def test_upload_pkg_success(self, plugin_installer, mock_plugin_declaration): + """Test successful plugin package upload.""" + # Arrange: Create mock package data and expected response + pkg_data = b"mock-plugin-package-data" + mock_response = PluginDecodeResponse( + unique_identifier="test-org/test-plugin/1.0.0", + manifest=mock_plugin_declaration, + verification=PluginVerification(authorized_category=PluginVerification.AuthorizedCategory.Community), + ) + + with patch.object( + plugin_installer, "_request_with_plugin_daemon_response", return_value=mock_response + ) as mock_request: + # Act: Upload plugin package + result = plugin_installer.upload_pkg("test-tenant", pkg_data, verify_signature=False) + + # Assert: Verify upload was successful + assert result.unique_identifier == "test-org/test-plugin/1.0.0" + assert result.manifest.name == "test-plugin" + mock_request.assert_called_once() + + def test_upload_pkg_with_signature_verification(self, plugin_installer, mock_plugin_declaration): + """Test plugin package upload with signature verification enabled.""" + # Arrange: Create mock package data + pkg_data = b"signed-plugin-package" + mock_response = PluginDecodeResponse( + unique_identifier="verified-org/verified-plugin/1.0.0", + manifest=mock_plugin_declaration, + verification=PluginVerification(authorized_category=PluginVerification.AuthorizedCategory.Partner), + ) + + with patch.object( + plugin_installer, "_request_with_plugin_daemon_response", return_value=mock_response + ) as mock_request: + # Act: Upload with signature verification + result = plugin_installer.upload_pkg("test-tenant", pkg_data, verify_signature=True) + + # Assert: Verify signature verification was requested + call_args = mock_request.call_args + assert call_args[1]["data"]["verify_signature"] == "true" + assert result.verification.authorized_category == PluginVerification.AuthorizedCategory.Partner + + def test_install_from_identifiers_success(self, plugin_installer): + """Test successful plugin installation from identifiers.""" + # Arrange: Mock installation response + mock_response = PluginInstallTaskStartResponse(all_installed=False, task_id="task-123") + + with patch.object( + plugin_installer, "_request_with_plugin_daemon_response", return_value=mock_response + ) as mock_request: + # Act: Install plugins from identifiers + result = plugin_installer.install_from_identifiers( + tenant_id="test-tenant", + identifiers=["plugin1/1.0.0", "plugin2/2.0.0"], + source=PluginInstallationSource.Marketplace, + metas=[{"key": "value1"}, {"key": "value2"}], + ) + + # Assert: Verify installation task was created + assert result.task_id == "task-123" + assert result.all_installed is False + mock_request.assert_called_once() + + def test_install_from_identifiers_all_installed(self, plugin_installer): + """Test installation when all plugins are already installed.""" + # Arrange: Mock response indicating all plugins are installed + mock_response = PluginInstallTaskStartResponse(all_installed=True, task_id="") + + with patch.object(plugin_installer, "_request_with_plugin_daemon_response", return_value=mock_response): + # Act: Attempt to install already-installed plugins + result = plugin_installer.install_from_identifiers( + tenant_id="test-tenant", + identifiers=["existing-plugin/1.0.0"], + source=PluginInstallationSource.Package, + metas=[{}], + ) + + # Assert: Verify all_installed flag is True + assert result.all_installed is True + + def test_fetch_plugin_installation_task(self, plugin_installer): + """Test fetching a specific plugin installation task.""" + # Arrange: Mock installation task + mock_task = PluginInstallTask( + id="task-123", + created_at=datetime.datetime.now(), + updated_at=datetime.datetime.now(), + status=PluginInstallTaskStatus.Running, + total_plugins=3, + completed_plugins=1, + plugins=[], + ) + + with patch.object( + plugin_installer, "_request_with_plugin_daemon_response", return_value=mock_task + ) as mock_request: + # Act: Fetch installation task + result = plugin_installer.fetch_plugin_installation_task("test-tenant", "task-123") + + # Assert: Verify task details + assert result.status == PluginInstallTaskStatus.Running + assert result.total_plugins == 3 + assert result.completed_plugins == 1 + mock_request.assert_called_once() + + def test_uninstall_plugin_success(self, plugin_installer): + """Test successful plugin uninstallation.""" + # Arrange: Mock successful uninstall + with patch.object(plugin_installer, "_request_with_plugin_daemon_response", return_value=True) as mock_request: + # Act: Uninstall plugin + result = plugin_installer.uninstall("test-tenant", "install-123") + + # Assert: Verify uninstallation succeeded + assert result is True + mock_request.assert_called_once() + + def test_upgrade_plugin_success(self, plugin_installer): + """Test successful plugin upgrade.""" + # Arrange: Mock upgrade response + mock_response = PluginInstallTaskStartResponse(all_installed=False, task_id="upgrade-task-123") + + with patch.object( + plugin_installer, "_request_with_plugin_daemon_response", return_value=mock_response + ) as mock_request: + # Act: Upgrade plugin + result = plugin_installer.upgrade_plugin( + tenant_id="test-tenant", + original_plugin_unique_identifier="plugin/1.0.0", + new_plugin_unique_identifier="plugin/2.0.0", + source=PluginInstallationSource.Marketplace, + meta={"upgrade": "true"}, + ) + + # Assert: Verify upgrade task was created + assert result.task_id == "upgrade-task-123" + mock_request.assert_called_once() + + +class TestPluginValidation: + """Test plugin validation and manifest parsing.""" + + @pytest.fixture + def plugin_installer(self): + """Create a PluginInstaller instance for testing.""" + return PluginInstaller() + + def test_fetch_plugin_manifest_success(self, plugin_installer): + """Test successful plugin manifest fetching.""" + # Arrange: Create a valid plugin declaration + mock_manifest = PluginDeclaration( + version="1.0.0", + author="test-author", + name="test-plugin", + description=I18nObject(en_US="Test plugin", zh_Hans="测试插件"), + icon="icon.png", + label=I18nObject(en_US="Test Plugin", zh_Hans="测试插件"), + category=PluginCategory.Tool, + created_at=datetime.datetime.now(), + resource=PluginResourceRequirements(memory=512, permission=None), + plugins=PluginDeclaration.Plugins(), + meta=PluginDeclaration.Meta(version="1.0.0", minimum_dify_version="0.6.0"), + ) + + with patch.object( + plugin_installer, "_request_with_plugin_daemon_response", return_value=mock_manifest + ) as mock_request: + # Act: Fetch plugin manifest + result = plugin_installer.fetch_plugin_manifest("test-tenant", "test-org/test-plugin/1.0.0") + + # Assert: Verify manifest was fetched correctly + assert result.name == "test-plugin" + assert result.version == "1.0.0" + assert result.author == "test-author" + assert result.meta.minimum_dify_version == "0.6.0" + mock_request.assert_called_once() + + def test_decode_plugin_from_identifier(self, plugin_installer): + """Test decoding plugin information from identifier.""" + # Arrange: Create mock decode response + mock_declaration = PluginDeclaration( + version="2.0.0", + author="decode-author", + name="decode-plugin", + description=I18nObject(en_US="Decoded plugin", zh_Hans="解码插件"), + icon="icon.png", + label=I18nObject(en_US="Decode Plugin", zh_Hans="解码插件"), + category=PluginCategory.Model, + created_at=datetime.datetime.now(), + resource=PluginResourceRequirements(memory=1024, permission=None), + plugins=PluginDeclaration.Plugins(), + meta=PluginDeclaration.Meta(version="2.0.0"), + ) + + mock_response = PluginDecodeResponse( + unique_identifier="org/decode-plugin/2.0.0", + manifest=mock_declaration, + verification=None, + ) + + with patch.object(plugin_installer, "_request_with_plugin_daemon_response", return_value=mock_response): + # Act: Decode plugin from identifier + result = plugin_installer.decode_plugin_from_identifier("test-tenant", "org/decode-plugin/2.0.0") + + # Assert: Verify decoded information + assert result.unique_identifier == "org/decode-plugin/2.0.0" + assert result.manifest.name == "decode-plugin" + # Category will be Extension unless a model provider entity is provided + assert result.manifest.category == PluginCategory.Extension + + def test_plugin_manifest_invalid_version_format(self): + """Test that invalid version format raises validation error.""" + # Arrange & Act & Assert: Creating a declaration with invalid version should fail + with pytest.raises(ValueError, match="Invalid version format"): + PluginDeclaration( + version="invalid-version", # Invalid version format + author="test-author", + name="test-plugin", + description=I18nObject(en_US="Test", zh_Hans="测试"), + icon="icon.png", + label=I18nObject(en_US="Test", zh_Hans="测试"), + category=PluginCategory.Tool, + created_at=datetime.datetime.now(), + resource=PluginResourceRequirements(memory=512, permission=None), + plugins=PluginDeclaration.Plugins(), + meta=PluginDeclaration.Meta(version="1.0.0"), + ) + + def test_plugin_manifest_invalid_author_format(self): + """Test that invalid author format raises validation error.""" + # Arrange & Act & Assert: Creating a declaration with invalid author should fail + with pytest.raises(ValueError): + PluginDeclaration( + version="1.0.0", + author="invalid author with spaces!@#", # Invalid author format + name="test-plugin", + description=I18nObject(en_US="Test", zh_Hans="测试"), + icon="icon.png", + label=I18nObject(en_US="Test", zh_Hans="测试"), + category=PluginCategory.Tool, + created_at=datetime.datetime.now(), + resource=PluginResourceRequirements(memory=512, permission=None), + plugins=PluginDeclaration.Plugins(), + meta=PluginDeclaration.Meta(version="1.0.0"), + ) + + def test_plugin_manifest_invalid_name_format(self): + """Test that invalid plugin name format raises validation error.""" + # Arrange & Act & Assert: Creating a declaration with invalid name should fail + with pytest.raises(ValueError): + PluginDeclaration( + version="1.0.0", + author="test-author", + name="Invalid_Plugin_Name_With_Uppercase", # Invalid name format + description=I18nObject(en_US="Test", zh_Hans="测试"), + icon="icon.png", + label=I18nObject(en_US="Test", zh_Hans="测试"), + category=PluginCategory.Tool, + created_at=datetime.datetime.now(), + resource=PluginResourceRequirements(memory=512, permission=None), + plugins=PluginDeclaration.Plugins(), + meta=PluginDeclaration.Meta(version="1.0.0"), + ) + + def test_fetch_plugin_readme_success(self, plugin_installer): + """Test successful plugin readme fetching.""" + # Arrange: Mock readme response + mock_response = PluginReadmeResponse(content="# Test Plugin\n\nThis is a test plugin.", language="en_US") + + with patch.object(plugin_installer, "_request_with_plugin_daemon_response", return_value=mock_response): + # Act: Fetch plugin readme + result = plugin_installer.fetch_plugin_readme("test-tenant", "test-org/test-plugin/1.0.0", "en_US") + + # Assert: Verify readme content + assert result == "# Test Plugin\n\nThis is a test plugin." + + def test_fetch_plugin_readme_not_found(self, plugin_installer): + """Test fetching readme when it doesn't exist (404 error).""" + # Arrange: Mock HTTP 404 error - the actual implementation catches HTTPError from requests library + mock_error = HTTPError("404 Not Found") + + with patch.object(plugin_installer, "_request_with_plugin_daemon_response", side_effect=mock_error): + # Act: Fetch non-existent readme + result = plugin_installer.fetch_plugin_readme("test-tenant", "test-org/test-plugin/1.0.0", "en_US") + + # Assert: Verify empty string is returned for 404 + assert result == "" + + +class TestVersionCompatibility: + """Test version compatibility checks.""" + + def test_valid_version_format(self): + """Test that valid semantic versions are accepted.""" + # Arrange & Act: Create declarations with various valid version formats + valid_versions = ["1.0.0", "2.1.3", "0.0.1", "10.20.30"] + + for version in valid_versions: + # Assert: All valid versions should be accepted + declaration = PluginDeclaration( + version=version, + author="test-author", + name="test-plugin", + description=I18nObject(en_US="Test", zh_Hans="测试"), + icon="icon.png", + label=I18nObject(en_US="Test", zh_Hans="测试"), + category=PluginCategory.Tool, + created_at=datetime.datetime.now(), + resource=PluginResourceRequirements(memory=512, permission=None), + plugins=PluginDeclaration.Plugins(), + meta=PluginDeclaration.Meta(version=version), + ) + assert declaration.version == version + + def test_minimum_dify_version_validation(self): + """Test minimum Dify version validation.""" + # Arrange & Act: Create declaration with minimum Dify version + declaration = PluginDeclaration( + version="1.0.0", + author="test-author", + name="test-plugin", + description=I18nObject(en_US="Test", zh_Hans="测试"), + icon="icon.png", + label=I18nObject(en_US="Test", zh_Hans="测试"), + category=PluginCategory.Tool, + created_at=datetime.datetime.now(), + resource=PluginResourceRequirements(memory=512, permission=None), + plugins=PluginDeclaration.Plugins(), + meta=PluginDeclaration.Meta(version="1.0.0", minimum_dify_version="0.6.0"), + ) + + # Assert: Verify minimum version is set correctly + assert declaration.meta.minimum_dify_version == "0.6.0" + + def test_invalid_minimum_dify_version(self): + """Test that invalid minimum Dify version format raises error.""" + # Arrange & Act & Assert: Invalid minimum version should raise ValueError + with pytest.raises(ValueError, match="Invalid version format"): + PluginDeclaration.Meta(version="1.0.0", minimum_dify_version="invalid.version") + + def test_version_comparison_logic(self): + """Test version comparison using packaging.version.Version.""" + # Arrange: Create version objects for comparison + v1 = Version("1.0.0") + v2 = Version("2.0.0") + v3 = Version("1.5.0") + + # Act & Assert: Verify version comparison works correctly + assert v1 < v2 + assert v2 > v1 + assert v1 < v3 < v2 + assert v1 == Version("1.0.0") + + def test_plugin_upgrade_version_check(self): + """Test that plugin upgrade requires newer version.""" + # Arrange: Define old and new versions + old_version = Version("1.0.0") + new_version = Version("2.0.0") + same_version = Version("1.0.0") + + # Act & Assert: Verify version upgrade logic + assert new_version > old_version # Valid upgrade + assert not (same_version > old_version) # Invalid upgrade (same version) + + +class TestDependencyResolution: + """Test plugin dependency resolution.""" + + @pytest.fixture + def plugin_installer(self): + """Create a PluginInstaller instance for testing.""" + return PluginInstaller() + + def test_upload_bundle_with_dependencies(self, plugin_installer): + """Test uploading a plugin bundle and extracting dependencies.""" + # Arrange: Create mock bundle data and dependencies + bundle_data = b"mock-bundle-data" + mock_dependencies = [ + PluginBundleDependency( + type=PluginBundleDependency.Type.Marketplace, + value=PluginBundleDependency.Marketplace(organization="org1", plugin="plugin1", version="1.0.0"), + ), + PluginBundleDependency( + type=PluginBundleDependency.Type.Github, + value=PluginBundleDependency.Github( + repo_address="https://github.com/org/repo", + repo="org/repo", + release="v1.0.0", + packages="plugin.zip", + ), + ), + ] + + with patch.object( + plugin_installer, "_request_with_plugin_daemon_response", return_value=mock_dependencies + ) as mock_request: + # Act: Upload bundle + result = plugin_installer.upload_bundle("test-tenant", bundle_data, verify_signature=False) + + # Assert: Verify dependencies were extracted + assert len(result) == 2 + assert result[0].type == PluginBundleDependency.Type.Marketplace + assert result[1].type == PluginBundleDependency.Type.Github + mock_request.assert_called_once() + + def test_fetch_missing_dependencies(self, plugin_installer): + """Test fetching missing dependencies for plugins.""" + # Arrange: Mock missing dependencies response + mock_missing = [ + MissingPluginDependency(plugin_unique_identifier="dep1/1.0.0", current_identifier=None), + MissingPluginDependency(plugin_unique_identifier="dep2/2.0.0", current_identifier="dep2/1.0.0"), + ] + + with patch.object( + plugin_installer, "_request_with_plugin_daemon_response", return_value=mock_missing + ) as mock_request: + # Act: Fetch missing dependencies + result = plugin_installer.fetch_missing_dependencies("test-tenant", ["plugin1/1.0.0", "plugin2/2.0.0"]) + + # Assert: Verify missing dependencies were identified + assert len(result) == 2 + assert result[0].plugin_unique_identifier == "dep1/1.0.0" + assert result[1].current_identifier == "dep2/1.0.0" + mock_request.assert_called_once() + + def test_fetch_missing_dependencies_none_missing(self, plugin_installer): + """Test fetching missing dependencies when all are satisfied.""" + # Arrange: Mock empty missing dependencies + with patch.object(plugin_installer, "_request_with_plugin_daemon_response", return_value=[]): + # Act: Fetch missing dependencies + result = plugin_installer.fetch_missing_dependencies("test-tenant", ["plugin1/1.0.0"]) + + # Assert: Verify no missing dependencies + assert len(result) == 0 + + def test_fetch_plugin_installation_by_ids(self, plugin_installer): + """Test fetching plugin installations by their IDs.""" + # Arrange: Create mock plugin installations + mock_installations = [ + PluginInstallation( + id="install-1", + created_at=datetime.datetime.now(), + updated_at=datetime.datetime.now(), + tenant_id="test-tenant", + endpoints_setups=0, + endpoints_active=0, + runtime_type="remote", + source=PluginInstallationSource.Marketplace, + meta={}, + plugin_id="plugin-1", + plugin_unique_identifier="org/plugin1/1.0.0", + version="1.0.0", + checksum="abc123", + declaration=PluginDeclaration( + version="1.0.0", + author="author1", + name="plugin1", + description=I18nObject(en_US="Plugin 1", zh_Hans="插件1"), + icon="icon.png", + label=I18nObject(en_US="Plugin 1", zh_Hans="插件1"), + category=PluginCategory.Tool, + created_at=datetime.datetime.now(), + resource=PluginResourceRequirements(memory=512, permission=None), + plugins=PluginDeclaration.Plugins(), + meta=PluginDeclaration.Meta(version="1.0.0"), + ), + ) + ] + + with patch.object( + plugin_installer, "_request_with_plugin_daemon_response", return_value=mock_installations + ) as mock_request: + # Act: Fetch installations by IDs + result = plugin_installer.fetch_plugin_installation_by_ids("test-tenant", ["plugin-1", "plugin-2"]) + + # Assert: Verify installations were fetched + assert len(result) == 1 + assert result[0].plugin_id == "plugin-1" + mock_request.assert_called_once() + + def test_dependency_chain_resolution(self, plugin_installer): + """Test resolving a chain of dependencies.""" + # Arrange: Create a dependency chain scenario + # Plugin A depends on Plugin B, Plugin B depends on Plugin C + mock_missing = [ + MissingPluginDependency(plugin_unique_identifier="plugin-b/1.0.0", current_identifier=None), + MissingPluginDependency(plugin_unique_identifier="plugin-c/1.0.0", current_identifier=None), + ] + + with patch.object(plugin_installer, "_request_with_plugin_daemon_response", return_value=mock_missing): + # Act: Fetch missing dependencies for Plugin A + result = plugin_installer.fetch_missing_dependencies("test-tenant", ["plugin-a/1.0.0"]) + + # Assert: Verify all dependencies in the chain are identified + assert len(result) == 2 + identifiers = [dep.plugin_unique_identifier for dep in result] + assert "plugin-b/1.0.0" in identifiers + assert "plugin-c/1.0.0" in identifiers + + def test_check_tools_existence(self, plugin_installer): + """Test checking if plugin tools exist.""" + # Arrange: Create provider IDs to check using the correct format + provider_ids = [ + GenericProviderID("org1/plugin1/provider1"), + GenericProviderID("org2/plugin2/provider2"), + ] + + # Mock response indicating first exists, second doesn't + mock_response = [True, False] + + with patch.object( + plugin_installer, "_request_with_plugin_daemon_response", return_value=mock_response + ) as mock_request: + # Act: Check tools existence + result = plugin_installer.check_tools_existence("test-tenant", provider_ids) + + # Assert: Verify existence check results + assert len(result) == 2 + assert result[0] is True + assert result[1] is False + mock_request.assert_called_once() + + +class TestPluginTaskManagement: + """Test plugin installation task management.""" + + @pytest.fixture + def plugin_installer(self): + """Create a PluginInstaller instance for testing.""" + return PluginInstaller() + + def test_fetch_plugin_installation_tasks(self, plugin_installer): + """Test fetching multiple plugin installation tasks.""" + # Arrange: Create mock installation tasks + mock_tasks = [ + PluginInstallTask( + id="task-1", + created_at=datetime.datetime.now(), + updated_at=datetime.datetime.now(), + status=PluginInstallTaskStatus.Running, + total_plugins=2, + completed_plugins=1, + plugins=[], + ), + PluginInstallTask( + id="task-2", + created_at=datetime.datetime.now(), + updated_at=datetime.datetime.now(), + status=PluginInstallTaskStatus.Success, + total_plugins=1, + completed_plugins=1, + plugins=[], + ), + ] + + with patch.object( + plugin_installer, "_request_with_plugin_daemon_response", return_value=mock_tasks + ) as mock_request: + # Act: Fetch installation tasks + result = plugin_installer.fetch_plugin_installation_tasks("test-tenant", page=1, page_size=10) + + # Assert: Verify tasks were fetched + assert len(result) == 2 + assert result[0].status == PluginInstallTaskStatus.Running + assert result[1].status == PluginInstallTaskStatus.Success + mock_request.assert_called_once() + + def test_delete_plugin_installation_task(self, plugin_installer): + """Test deleting a specific plugin installation task.""" + # Arrange: Mock successful deletion + with patch.object(plugin_installer, "_request_with_plugin_daemon_response", return_value=True) as mock_request: + # Act: Delete installation task + result = plugin_installer.delete_plugin_installation_task("test-tenant", "task-123") + + # Assert: Verify deletion succeeded + assert result is True + mock_request.assert_called_once() + + def test_delete_all_plugin_installation_task_items(self, plugin_installer): + """Test deleting all plugin installation task items.""" + # Arrange: Mock successful deletion of all items + with patch.object(plugin_installer, "_request_with_plugin_daemon_response", return_value=True) as mock_request: + # Act: Delete all task items + result = plugin_installer.delete_all_plugin_installation_task_items("test-tenant") + + # Assert: Verify all items were deleted + assert result is True + mock_request.assert_called_once() + + def test_delete_plugin_installation_task_item(self, plugin_installer): + """Test deleting a specific item from an installation task.""" + # Arrange: Mock successful item deletion + with patch.object(plugin_installer, "_request_with_plugin_daemon_response", return_value=True) as mock_request: + # Act: Delete specific task item + result = plugin_installer.delete_plugin_installation_task_item( + "test-tenant", "task-123", "plugin-identifier" + ) + + # Assert: Verify item was deleted + assert result is True + mock_request.assert_called_once() + + +class TestErrorHandling: + """Test error handling in plugin manager.""" + + @pytest.fixture + def plugin_installer(self): + """Create a PluginInstaller instance for testing.""" + return PluginInstaller() + + def test_plugin_not_found_error(self, plugin_installer): + """Test handling of plugin not found error.""" + # Arrange: Mock plugin daemon not found error + with patch.object( + plugin_installer, + "_request_with_plugin_daemon_response", + side_effect=PluginDaemonNotFoundError("Plugin not found"), + ): + # Act & Assert: Verify error is raised + with pytest.raises(PluginDaemonNotFoundError): + plugin_installer.fetch_plugin_manifest("test-tenant", "non-existent/plugin/1.0.0") + + def test_plugin_bad_request_error(self, plugin_installer): + """Test handling of bad request error.""" + # Arrange: Mock bad request error + with patch.object( + plugin_installer, + "_request_with_plugin_daemon_response", + side_effect=PluginDaemonBadRequestError("Invalid request"), + ): + # Act & Assert: Verify error is raised + with pytest.raises(PluginDaemonBadRequestError): + plugin_installer.install_from_identifiers("test-tenant", [], PluginInstallationSource.Marketplace, []) + + def test_plugin_internal_server_error(self, plugin_installer): + """Test handling of internal server error.""" + # Arrange: Mock internal server error + with patch.object( + plugin_installer, + "_request_with_plugin_daemon_response", + side_effect=PluginDaemonInternalServerError("Internal error"), + ): + # Act & Assert: Verify error is raised + with pytest.raises(PluginDaemonInternalServerError): + plugin_installer.list_plugins("test-tenant") + + def test_http_error_handling(self, plugin_installer): + """Test handling of HTTP errors during requests.""" + # Arrange: Mock HTTP error + with patch.object(plugin_installer, "_request", side_effect=httpx.RequestError("Connection failed")): + # Act & Assert: Verify appropriate error handling + with pytest.raises(httpx.RequestError): + plugin_installer._request("GET", "test/path") + + +class TestPluginCategoryDetection: + """Test automatic plugin category detection.""" + + def test_category_defaults_to_extension_without_tool_provider(self): + """Test that plugins without tool providers default to Extension category.""" + # Arrange: Create declaration - category is auto-detected based on provider presence + # The model_validator in PluginDeclaration automatically sets category based on which provider is present + # Since we're not providing a tool provider entity, it defaults to Extension + # This test verifies that explicitly set categories are preserved + declaration = PluginDeclaration( + version="1.0.0", + author="test-author", + name="tool-plugin", + description=I18nObject(en_US="Tool plugin", zh_Hans="工具插件"), + icon="icon.png", + label=I18nObject(en_US="Tool Plugin", zh_Hans="工具插件"), + category=PluginCategory.Extension, # Will be Extension without a tool provider + created_at=datetime.datetime.now(), + resource=PluginResourceRequirements(memory=512, permission=None), + plugins=PluginDeclaration.Plugins(), + meta=PluginDeclaration.Meta(version="1.0.0"), + ) + + # Assert: Verify category defaults to Extension when no provider is specified + assert declaration.category == PluginCategory.Extension + + def test_category_defaults_to_extension_without_model_provider(self): + """Test that plugins without model providers default to Extension category.""" + # Arrange: Create declaration - without a model provider entity, defaults to Extension + # The category is auto-detected in the model_validator based on provider presence + declaration = PluginDeclaration( + version="1.0.0", + author="test-author", + name="model-plugin", + description=I18nObject(en_US="Model plugin", zh_Hans="模型插件"), + icon="icon.png", + label=I18nObject(en_US="Model Plugin", zh_Hans="模型插件"), + category=PluginCategory.Extension, # Will be Extension without a model provider + created_at=datetime.datetime.now(), + resource=PluginResourceRequirements(memory=1024, permission=None), + plugins=PluginDeclaration.Plugins(), + meta=PluginDeclaration.Meta(version="1.0.0"), + ) + + # Assert: Verify category defaults to Extension when no provider is specified + assert declaration.category == PluginCategory.Extension + + def test_extension_category_default(self): + """Test that plugins without specific providers default to Extension.""" + # Arrange: Create declaration without specific provider type + declaration = PluginDeclaration( + version="1.0.0", + author="test-author", + name="extension-plugin", + description=I18nObject(en_US="Extension plugin", zh_Hans="扩展插件"), + icon="icon.png", + label=I18nObject(en_US="Extension Plugin", zh_Hans="扩展插件"), + category=PluginCategory.Extension, + created_at=datetime.datetime.now(), + resource=PluginResourceRequirements(memory=512, permission=None), + plugins=PluginDeclaration.Plugins(), + meta=PluginDeclaration.Meta(version="1.0.0"), + ) + + # Assert: Verify category is Extension + assert declaration.category == PluginCategory.Extension + + +class TestPluginResourceRequirements: + """Test plugin resource requirements and permissions.""" + + def test_default_resource_requirements(self): + """ + Test that plugin resource requirements can be created with default values. + + Resource requirements define the memory and permissions needed for a plugin to run. + This test verifies that a basic resource requirement with only memory can be created. + """ + # Arrange & Act: Create resource requirements with only memory specified + resources = PluginResourceRequirements(memory=512, permission=None) + + # Assert: Verify memory is set correctly and permissions are None + assert resources.memory == 512 + assert resources.permission is None + + def test_resource_requirements_with_tool_permission(self): + """ + Test plugin resource requirements with tool permissions enabled. + + Tool permissions allow a plugin to provide tool functionality. + This test verifies that tool permissions can be properly configured. + """ + # Arrange & Act: Create resource requirements with tool permissions + resources = PluginResourceRequirements( + memory=1024, + permission=PluginResourceRequirements.Permission( + tool=PluginResourceRequirements.Permission.Tool(enabled=True) + ), + ) + + # Assert: Verify tool permissions are enabled + assert resources.memory == 1024 + assert resources.permission is not None + assert resources.permission.tool is not None + assert resources.permission.tool.enabled is True + + def test_resource_requirements_with_model_permissions(self): + """ + Test plugin resource requirements with model permissions. + + Model permissions allow a plugin to provide various AI model capabilities + including LLM, text embedding, rerank, TTS, speech-to-text, and moderation. + """ + # Arrange & Act: Create resource requirements with comprehensive model permissions + resources = PluginResourceRequirements( + memory=2048, + permission=PluginResourceRequirements.Permission( + model=PluginResourceRequirements.Permission.Model( + enabled=True, + llm=True, + text_embedding=True, + rerank=True, + tts=False, + speech2text=False, + moderation=True, + ) + ), + ) + + # Assert: Verify all model permissions are set correctly + assert resources.memory == 2048 + assert resources.permission.model.enabled is True + assert resources.permission.model.llm is True + assert resources.permission.model.text_embedding is True + assert resources.permission.model.rerank is True + assert resources.permission.model.tts is False + assert resources.permission.model.speech2text is False + assert resources.permission.model.moderation is True + + def test_resource_requirements_with_storage_permission(self): + """ + Test plugin resource requirements with storage permissions. + + Storage permissions allow a plugin to persist data with size limits. + The size must be between 1KB (1024 bytes) and 1GB (1073741824 bytes). + """ + # Arrange & Act: Create resource requirements with storage permissions + resources = PluginResourceRequirements( + memory=512, + permission=PluginResourceRequirements.Permission( + storage=PluginResourceRequirements.Permission.Storage(enabled=True, size=10485760) # 10MB + ), + ) + + # Assert: Verify storage permissions and size limits + assert resources.permission.storage.enabled is True + assert resources.permission.storage.size == 10485760 + + def test_resource_requirements_with_endpoint_permission(self): + """ + Test plugin resource requirements with endpoint permissions. + + Endpoint permissions allow a plugin to expose HTTP endpoints. + """ + # Arrange & Act: Create resource requirements with endpoint permissions + resources = PluginResourceRequirements( + memory=1024, + permission=PluginResourceRequirements.Permission( + endpoint=PluginResourceRequirements.Permission.Endpoint(enabled=True) + ), + ) + + # Assert: Verify endpoint permissions are enabled + assert resources.permission.endpoint.enabled is True + + def test_resource_requirements_with_node_permission(self): + """ + Test plugin resource requirements with node permissions. + + Node permissions allow a plugin to provide custom workflow nodes. + """ + # Arrange & Act: Create resource requirements with node permissions + resources = PluginResourceRequirements( + memory=768, + permission=PluginResourceRequirements.Permission( + node=PluginResourceRequirements.Permission.Node(enabled=True) + ), + ) + + # Assert: Verify node permissions are enabled + assert resources.permission.node.enabled is True + + +class TestPluginInstallationSources: + """Test different plugin installation sources.""" + + def test_marketplace_installation_source(self): + """ + Test plugin installation from marketplace source. + + Marketplace is the official plugin distribution channel where + verified and community plugins are available for installation. + """ + # Arrange & Act: Use marketplace as installation source + source = PluginInstallationSource.Marketplace + + # Assert: Verify source type + assert source == PluginInstallationSource.Marketplace + assert source.value == "marketplace" + + def test_github_installation_source(self): + """ + Test plugin installation from GitHub source. + + GitHub source allows installing plugins directly from GitHub repositories, + useful for development and testing unreleased versions. + """ + # Arrange & Act: Use GitHub as installation source + source = PluginInstallationSource.Github + + # Assert: Verify source type + assert source == PluginInstallationSource.Github + assert source.value == "github" + + def test_package_installation_source(self): + """ + Test plugin installation from package source. + + Package source allows installing plugins from local .difypkg files, + useful for private or custom plugins. + """ + # Arrange & Act: Use package as installation source + source = PluginInstallationSource.Package + + # Assert: Verify source type + assert source == PluginInstallationSource.Package + assert source.value == "package" + + def test_remote_installation_source(self): + """ + Test plugin installation from remote source. + + Remote source allows installing plugins from custom remote URLs. + """ + # Arrange & Act: Use remote as installation source + source = PluginInstallationSource.Remote + + # Assert: Verify source type + assert source == PluginInstallationSource.Remote + assert source.value == "remote" + + +class TestPluginBundleOperations: + """Test plugin bundle operations and dependency extraction.""" + + @pytest.fixture + def plugin_installer(self): + """Create a PluginInstaller instance for testing.""" + return PluginInstaller() + + def test_upload_bundle_with_marketplace_dependencies(self, plugin_installer): + """ + Test uploading a bundle with marketplace dependencies. + + Marketplace dependencies reference plugins available in the official marketplace + by organization, plugin name, and version. + """ + # Arrange: Create mock bundle with marketplace dependencies + bundle_data = b"mock-marketplace-bundle" + mock_dependencies = [ + PluginBundleDependency( + type=PluginBundleDependency.Type.Marketplace, + value=PluginBundleDependency.Marketplace( + organization="langgenius", plugin="search-tool", version="1.2.0" + ), + ) + ] + + with patch.object(plugin_installer, "_request_with_plugin_daemon_response", return_value=mock_dependencies): + # Act: Upload bundle + result = plugin_installer.upload_bundle("test-tenant", bundle_data) + + # Assert: Verify marketplace dependency was extracted + assert len(result) == 1 + assert result[0].type == PluginBundleDependency.Type.Marketplace + assert isinstance(result[0].value, PluginBundleDependency.Marketplace) + assert result[0].value.organization == "langgenius" + assert result[0].value.plugin == "search-tool" + + def test_upload_bundle_with_github_dependencies(self, plugin_installer): + """ + Test uploading a bundle with GitHub dependencies. + + GitHub dependencies reference plugins hosted on GitHub repositories + with specific releases and package files. + """ + # Arrange: Create mock bundle with GitHub dependencies + bundle_data = b"mock-github-bundle" + mock_dependencies = [ + PluginBundleDependency( + type=PluginBundleDependency.Type.Github, + value=PluginBundleDependency.Github( + repo_address="https://github.com/example/plugin", + repo="example/plugin", + release="v2.0.0", + packages="plugin-v2.0.0.zip", + ), + ) + ] + + with patch.object(plugin_installer, "_request_with_plugin_daemon_response", return_value=mock_dependencies): + # Act: Upload bundle + result = plugin_installer.upload_bundle("test-tenant", bundle_data) + + # Assert: Verify GitHub dependency was extracted + assert len(result) == 1 + assert result[0].type == PluginBundleDependency.Type.Github + assert isinstance(result[0].value, PluginBundleDependency.Github) + assert result[0].value.repo == "example/plugin" + assert result[0].value.release == "v2.0.0" + + def test_upload_bundle_with_package_dependencies(self, plugin_installer): + """ + Test uploading a bundle with package dependencies. + + Package dependencies include the full plugin manifest and unique identifier, + allowing for self-contained plugin bundles. + """ + # Arrange: Create mock bundle with package dependencies + bundle_data = b"mock-package-bundle" + mock_manifest = PluginDeclaration( + version="1.5.0", + author="bundle-author", + name="bundled-plugin", + description=I18nObject(en_US="Bundled plugin", zh_Hans="捆绑插件"), + icon="icon.png", + label=I18nObject(en_US="Bundled Plugin", zh_Hans="捆绑插件"), + category=PluginCategory.Extension, + created_at=datetime.datetime.now(), + resource=PluginResourceRequirements(memory=512, permission=None), + plugins=PluginDeclaration.Plugins(), + meta=PluginDeclaration.Meta(version="1.5.0"), + ) + + mock_dependencies = [ + PluginBundleDependency( + type=PluginBundleDependency.Type.Package, + value=PluginBundleDependency.Package( + unique_identifier="org/bundled-plugin/1.5.0", manifest=mock_manifest + ), + ) + ] + + with patch.object(plugin_installer, "_request_with_plugin_daemon_response", return_value=mock_dependencies): + # Act: Upload bundle + result = plugin_installer.upload_bundle("test-tenant", bundle_data) + + # Assert: Verify package dependency was extracted with manifest + assert len(result) == 1 + assert result[0].type == PluginBundleDependency.Type.Package + assert isinstance(result[0].value, PluginBundleDependency.Package) + assert result[0].value.unique_identifier == "org/bundled-plugin/1.5.0" + assert result[0].value.manifest.name == "bundled-plugin" + + def test_upload_bundle_with_mixed_dependencies(self, plugin_installer): + """ + Test uploading a bundle with multiple dependency types. + + Real-world plugin bundles often have dependencies from various sources: + marketplace plugins, GitHub repositories, and packaged plugins. + """ + # Arrange: Create mock bundle with mixed dependencies + bundle_data = b"mock-mixed-bundle" + mock_dependencies = [ + PluginBundleDependency( + type=PluginBundleDependency.Type.Marketplace, + value=PluginBundleDependency.Marketplace(organization="org1", plugin="plugin1", version="1.0.0"), + ), + PluginBundleDependency( + type=PluginBundleDependency.Type.Github, + value=PluginBundleDependency.Github( + repo_address="https://github.com/org2/plugin2", + repo="org2/plugin2", + release="v1.0.0", + packages="plugin2.zip", + ), + ), + ] + + with patch.object(plugin_installer, "_request_with_plugin_daemon_response", return_value=mock_dependencies): + # Act: Upload bundle + result = plugin_installer.upload_bundle("test-tenant", bundle_data, verify_signature=True) + + # Assert: Verify all dependency types were extracted + assert len(result) == 2 + assert result[0].type == PluginBundleDependency.Type.Marketplace + assert result[1].type == PluginBundleDependency.Type.Github + + +class TestPluginTaskStatusTransitions: + """Test plugin installation task status transitions and lifecycle.""" + + @pytest.fixture + def plugin_installer(self): + """Create a PluginInstaller instance for testing.""" + return PluginInstaller() + + def test_task_status_pending(self, plugin_installer): + """ + Test plugin installation task in pending status. + + Pending status indicates the task has been created but not yet started. + No plugins have been processed yet. + """ + # Arrange: Create mock task in pending status + mock_task = PluginInstallTask( + id="pending-task", + created_at=datetime.datetime.now(), + updated_at=datetime.datetime.now(), + status=PluginInstallTaskStatus.Pending, + total_plugins=3, + completed_plugins=0, # No plugins completed yet + plugins=[], + ) + + with patch.object(plugin_installer, "_request_with_plugin_daemon_response", return_value=mock_task): + # Act: Fetch task + result = plugin_installer.fetch_plugin_installation_task("test-tenant", "pending-task") + + # Assert: Verify pending status + assert result.status == PluginInstallTaskStatus.Pending + assert result.completed_plugins == 0 + assert result.total_plugins == 3 + + def test_task_status_running(self, plugin_installer): + """ + Test plugin installation task in running status. + + Running status indicates the task is actively installing plugins. + Some plugins may be completed while others are still in progress. + """ + # Arrange: Create mock task in running status + mock_task = PluginInstallTask( + id="running-task", + created_at=datetime.datetime.now(), + updated_at=datetime.datetime.now(), + status=PluginInstallTaskStatus.Running, + total_plugins=5, + completed_plugins=2, # 2 out of 5 completed + plugins=[], + ) + + with patch.object(plugin_installer, "_request_with_plugin_daemon_response", return_value=mock_task): + # Act: Fetch task + result = plugin_installer.fetch_plugin_installation_task("test-tenant", "running-task") + + # Assert: Verify running status and progress + assert result.status == PluginInstallTaskStatus.Running + assert result.completed_plugins == 2 + assert result.total_plugins == 5 + assert result.completed_plugins < result.total_plugins + + def test_task_status_success(self, plugin_installer): + """ + Test plugin installation task in success status. + + Success status indicates all plugins in the task have been + successfully installed without errors. + """ + # Arrange: Create mock task in success status + mock_task = PluginInstallTask( + id="success-task", + created_at=datetime.datetime.now(), + updated_at=datetime.datetime.now(), + status=PluginInstallTaskStatus.Success, + total_plugins=4, + completed_plugins=4, # All plugins completed + plugins=[], + ) + + with patch.object(plugin_installer, "_request_with_plugin_daemon_response", return_value=mock_task): + # Act: Fetch task + result = plugin_installer.fetch_plugin_installation_task("test-tenant", "success-task") + + # Assert: Verify success status and completion + assert result.status == PluginInstallTaskStatus.Success + assert result.completed_plugins == result.total_plugins + assert result.completed_plugins == 4 + + def test_task_status_failed(self, plugin_installer): + """ + Test plugin installation task in failed status. + + Failed status indicates the task encountered errors during installation. + Some plugins may have been installed before the failure occurred. + """ + # Arrange: Create mock task in failed status + mock_task = PluginInstallTask( + id="failed-task", + created_at=datetime.datetime.now(), + updated_at=datetime.datetime.now(), + status=PluginInstallTaskStatus.Failed, + total_plugins=3, + completed_plugins=1, # Only 1 completed before failure + plugins=[], + ) + + with patch.object(plugin_installer, "_request_with_plugin_daemon_response", return_value=mock_task): + # Act: Fetch task + result = plugin_installer.fetch_plugin_installation_task("test-tenant", "failed-task") + + # Assert: Verify failed status + assert result.status == PluginInstallTaskStatus.Failed + assert result.completed_plugins < result.total_plugins + + +class TestPluginI18nSupport: + """Test plugin internationalization (i18n) support.""" + + def test_plugin_with_multiple_languages(self): + """ + Test plugin declaration with multiple language support. + + Plugins should support multiple languages for descriptions and labels + to provide localized experiences for users worldwide. + """ + # Arrange & Act: Create plugin with English and Chinese support + declaration = PluginDeclaration( + version="1.0.0", + author="i18n-author", + name="multilang-plugin", + description=I18nObject( + en_US="A plugin with multilingual support", + zh_Hans="支持多语言的插件", + ja_JP="多言語対応のプラグイン", + ), + icon="icon.png", + label=I18nObject(en_US="Multilingual Plugin", zh_Hans="多语言插件", ja_JP="多言語プラグイン"), + category=PluginCategory.Extension, + created_at=datetime.datetime.now(), + resource=PluginResourceRequirements(memory=512, permission=None), + plugins=PluginDeclaration.Plugins(), + meta=PluginDeclaration.Meta(version="1.0.0"), + ) + + # Assert: Verify all language variants are preserved + assert declaration.description.en_US == "A plugin with multilingual support" + assert declaration.description.zh_Hans == "支持多语言的插件" + assert declaration.label.en_US == "Multilingual Plugin" + assert declaration.label.zh_Hans == "多语言插件" + + def test_plugin_readme_language_variants(self): + """ + Test fetching plugin README in different languages. + + Plugins can provide README files in multiple languages to help + users understand the plugin in their preferred language. + """ + # Arrange: Create plugin installer instance + plugin_installer = PluginInstaller() + + # Mock README responses for different languages + english_readme = PluginReadmeResponse( + content="# English README\n\nThis is the English version.", language="en_US" + ) + + chinese_readme = PluginReadmeResponse(content="# 中文说明\n\n这是中文版本。", language="zh_Hans") + + # Test English README + with patch.object(plugin_installer, "_request_with_plugin_daemon_response", return_value=english_readme): + # Act: Fetch English README + result_en = plugin_installer.fetch_plugin_readme("test-tenant", "plugin/1.0.0", "en_US") + + # Assert: Verify English content + assert "English README" in result_en + + # Test Chinese README + with patch.object(plugin_installer, "_request_with_plugin_daemon_response", return_value=chinese_readme): + # Act: Fetch Chinese README + result_zh = plugin_installer.fetch_plugin_readme("test-tenant", "plugin/1.0.0", "zh_Hans") + + # Assert: Verify Chinese content + assert "中文说明" in result_zh