diff --git a/api/configs/middleware/storage/volcengine_tos_storage_config.py b/api/configs/middleware/storage/volcengine_tos_storage_config.py index be01f2dc36..2a35300401 100644 --- a/api/configs/middleware/storage/volcengine_tos_storage_config.py +++ b/api/configs/middleware/storage/volcengine_tos_storage_config.py @@ -4,7 +4,7 @@ from pydantic_settings import BaseSettings class VolcengineTOSStorageConfig(BaseSettings): """ - Configuration settings for Volcengine Tinder Object Storage (TOS) + Configuration settings for Volcengine Torch Object Storage (TOS) """ VOLCENGINE_TOS_BUCKET_NAME: str | None = Field( diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index ac78d3854b..707d90f044 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -7,7 +7,7 @@ from typing import Literal, cast import sqlalchemy as sa from flask import request from flask_restx import Resource, fields, marshal, marshal_with -from pydantic import BaseModel +from pydantic import BaseModel, Field from sqlalchemy import asc, desc, select from werkzeug.exceptions import Forbidden, NotFound @@ -104,6 +104,15 @@ class DocumentRenamePayload(BaseModel): name: str +class DocumentDatasetListParam(BaseModel): + page: int = Field(1, title="Page", description="Page number.") + limit: int = Field(20, title="Limit", description="Page size.") + search: str | None = Field(None, alias="keyword", title="Search", description="Search keyword.") + sort_by: str = Field("-created_at", alias="sort", title="SortBy", description="Sort by field.") + status: str | None = Field(None, title="Status", description="Document status.") + fetch_val: str = Field("false", alias="fetch") + + register_schema_models( console_ns, KnowledgeConfig, @@ -225,14 +234,16 @@ class DatasetDocumentListApi(Resource): def get(self, dataset_id): current_user, current_tenant_id = current_account_with_tenant() dataset_id = str(dataset_id) - page = request.args.get("page", default=1, type=int) - limit = request.args.get("limit", default=20, type=int) - search = request.args.get("keyword", default=None, type=str) - sort = request.args.get("sort", default="-created_at", type=str) - status = request.args.get("status", default=None, type=str) + raw_args = request.args.to_dict() + param = DocumentDatasetListParam.model_validate(raw_args) + page = param.page + limit = param.limit + search = param.search + sort = param.sort_by + status = param.status # "yes", "true", "t", "y", "1" convert to True, while others convert to False. 
try: - fetch_val = request.args.get("fetch", default="false") + fetch_val = param.fetch_val if isinstance(fetch_val, bool): fetch = fetch_val else: diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index d636548f2b..a258144d35 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -24,7 +24,7 @@ from core.app.layers.conversation_variable_persist_layer import ConversationVari from core.db.session_factory import session_factory from core.moderation.base import ModerationError from core.moderation.input_moderation import InputModeration -from core.variables.variables import VariableUnion +from core.variables.variables import Variable from core.workflow.enums import WorkflowType from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel from core.workflow.graph_engine.layers.base import GraphEngineLayer @@ -149,8 +149,8 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): system_variables=system_inputs, user_inputs=inputs, environment_variables=self._workflow.environment_variables, - # Based on the definition of `VariableUnion`, - # `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible. + # Based on the definition of `Variable`, + # `VariableBase` instances can be safely used as `Variable` since they are compatible. conversation_variables=conversation_variables, ) @@ -318,7 +318,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): trace_manager=app_generate_entity.trace_manager, ) - def _initialize_conversation_variables(self) -> list[VariableUnion]: + def _initialize_conversation_variables(self) -> list[Variable]: """ Initialize conversation variables for the current conversation. @@ -343,7 +343,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): conversation_variables = [var.to_variable() for var in existing_variables] session.commit() - return cast(list[VariableUnion], conversation_variables) + return cast(list[Variable], conversation_variables) def _load_existing_conversation_variables(self, session: Session) -> list[ConversationVariable]: """ diff --git a/api/core/app/layers/conversation_variable_persist_layer.py b/api/core/app/layers/conversation_variable_persist_layer.py index 77cc00bdc9..c070845b73 100644 --- a/api/core/app/layers/conversation_variable_persist_layer.py +++ b/api/core/app/layers/conversation_variable_persist_layer.py @@ -1,6 +1,6 @@ import logging -from core.variables import Variable +from core.variables import VariableBase from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID from core.workflow.conversation_variable_updater import ConversationVariableUpdater from core.workflow.enums import NodeType @@ -44,7 +44,7 @@ class ConversationVariablePersistenceLayer(GraphEngineLayer): if selector[0] != CONVERSATION_VARIABLE_NODE_ID: continue variable = self.graph_runtime_state.variable_pool.get(selector) - if not isinstance(variable, Variable): + if not isinstance(variable, VariableBase): logger.warning( "Conversation variable not found in variable pool. 
selector=%s", selector, diff --git a/api/core/ops/aliyun_trace/aliyun_trace.py b/api/core/ops/aliyun_trace/aliyun_trace.py index cf6659150f..22ad756c91 100644 --- a/api/core/ops/aliyun_trace/aliyun_trace.py +++ b/api/core/ops/aliyun_trace/aliyun_trace.py @@ -55,7 +55,7 @@ from core.ops.entities.trace_entity import ( ToolTraceInfo, WorkflowTraceInfo, ) -from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository +from core.repositories import DifyCoreRepositoryFactory from core.workflow.entities import WorkflowNodeExecution from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey from extensions.ext_database import db @@ -275,7 +275,7 @@ class AliyunDataTrace(BaseTraceInstance): service_account = self.get_service_account_with_tenant(app_id) session_factory = sessionmaker(bind=db.engine) - workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( + workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( session_factory=session_factory, user=service_account, app_id=app_id, diff --git a/api/core/tools/workflow_as_tool/tool.py b/api/core/tools/workflow_as_tool/tool.py index 81a1d54199..389db8a972 100644 --- a/api/core/tools/workflow_as_tool/tool.py +++ b/api/core/tools/workflow_as_tool/tool.py @@ -7,8 +7,8 @@ from typing import Any, cast from flask import has_request_context from sqlalchemy import select -from sqlalchemy.orm import Session +from core.db.session_factory import session_factory from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod from core.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata from core.tools.__base.tool import Tool @@ -20,7 +20,6 @@ from core.tools.entities.tool_entities import ( ToolProviderType, ) from core.tools.errors import ToolInvokeError -from extensions.ext_database import db from factories.file_factory import build_from_mapping from libs.login import current_user from models import Account, Tenant @@ -230,30 +229,32 @@ class WorkflowTool(Tool): """ Resolve user from database (worker/Celery context). 
""" + with session_factory.create_session() as session: + tenant_stmt = select(Tenant).where(Tenant.id == self.runtime.tenant_id) + tenant = session.scalar(tenant_stmt) + if not tenant: + return None + + user_stmt = select(Account).where(Account.id == user_id) + user = session.scalar(user_stmt) + if user: + user.current_tenant = tenant + session.expunge(user) + return user + + end_user_stmt = select(EndUser).where(EndUser.id == user_id, EndUser.tenant_id == tenant.id) + end_user = session.scalar(end_user_stmt) + if end_user: + session.expunge(end_user) + return end_user - tenant_stmt = select(Tenant).where(Tenant.id == self.runtime.tenant_id) - tenant = db.session.scalar(tenant_stmt) - if not tenant: return None - user_stmt = select(Account).where(Account.id == user_id) - user = db.session.scalar(user_stmt) - if user: - user.current_tenant = tenant - return user - - end_user_stmt = select(EndUser).where(EndUser.id == user_id, EndUser.tenant_id == tenant.id) - end_user = db.session.scalar(end_user_stmt) - if end_user: - return end_user - - return None - def _get_workflow(self, app_id: str, version: str) -> Workflow: """ get the workflow by app id and version """ - with Session(db.engine, expire_on_commit=False) as session, session.begin(): + with session_factory.create_session() as session, session.begin(): if not version: stmt = ( select(Workflow) @@ -265,22 +266,24 @@ class WorkflowTool(Tool): stmt = select(Workflow).where(Workflow.app_id == app_id, Workflow.version == version) workflow = session.scalar(stmt) - if not workflow: - raise ValueError("workflow not found or not published") + if not workflow: + raise ValueError("workflow not found or not published") - return workflow + session.expunge(workflow) + return workflow def _get_app(self, app_id: str) -> App: """ get the app by app id """ stmt = select(App).where(App.id == app_id) - with Session(db.engine, expire_on_commit=False) as session, session.begin(): + with session_factory.create_session() as session, session.begin(): app = session.scalar(stmt) - if not app: - raise ValueError("app not found") + if not app: + raise ValueError("app not found") - return app + session.expunge(app) + return app def _transform_args(self, tool_parameters: dict) -> tuple[dict, list[dict]]: """ diff --git a/api/core/variables/__init__.py b/api/core/variables/__init__.py index 7a1cbf9940..7498224923 100644 --- a/api/core/variables/__init__.py +++ b/api/core/variables/__init__.py @@ -30,6 +30,7 @@ from .variables import ( SecretVariable, StringVariable, Variable, + VariableBase, ) __all__ = [ @@ -62,4 +63,5 @@ __all__ = [ "StringSegment", "StringVariable", "Variable", + "VariableBase", ] diff --git a/api/core/variables/segments.py b/api/core/variables/segments.py index 406b4e6f93..8330f1fe19 100644 --- a/api/core/variables/segments.py +++ b/api/core/variables/segments.py @@ -232,7 +232,7 @@ def get_segment_discriminator(v: Any) -> SegmentType | None: # - All variants in `SegmentUnion` must inherit from the `Segment` class. # - The union must include all non-abstract subclasses of `Segment`, except: # - `SegmentGroup`, which is not added to the variable pool. -# - `Variable` and its subclasses, which are handled by `VariableUnion`. +# - `VariableBase` and its subclasses, which are handled by `Variable`. 
SegmentUnion: TypeAlias = Annotated[ ( Annotated[NoneSegment, Tag(SegmentType.NONE)] diff --git a/api/core/variables/variables.py b/api/core/variables/variables.py index 9fd0bbc5b2..a19c53918d 100644 --- a/api/core/variables/variables.py +++ b/api/core/variables/variables.py @@ -27,7 +27,7 @@ from .segments import ( from .types import SegmentType -class Variable(Segment): +class VariableBase(Segment): """ A variable is a segment that has a name. @@ -45,23 +45,23 @@ class Variable(Segment): selector: Sequence[str] = Field(default_factory=list) -class StringVariable(StringSegment, Variable): +class StringVariable(StringSegment, VariableBase): pass -class FloatVariable(FloatSegment, Variable): +class FloatVariable(FloatSegment, VariableBase): pass -class IntegerVariable(IntegerSegment, Variable): +class IntegerVariable(IntegerSegment, VariableBase): pass -class ObjectVariable(ObjectSegment, Variable): +class ObjectVariable(ObjectSegment, VariableBase): pass -class ArrayVariable(ArraySegment, Variable): +class ArrayVariable(ArraySegment, VariableBase): pass @@ -89,16 +89,16 @@ class SecretVariable(StringVariable): return encrypter.obfuscated_token(self.value) -class NoneVariable(NoneSegment, Variable): +class NoneVariable(NoneSegment, VariableBase): value_type: SegmentType = SegmentType.NONE value: None = None -class FileVariable(FileSegment, Variable): +class FileVariable(FileSegment, VariableBase): pass -class BooleanVariable(BooleanSegment, Variable): +class BooleanVariable(BooleanSegment, VariableBase): pass @@ -139,13 +139,13 @@ class RAGPipelineVariableInput(BaseModel): value: Any -# The `VariableUnion`` type is used to enable serialization and deserialization with Pydantic. -# Use `Variable` for type hinting when serialization is not required. +# The `Variable` type is used to enable serialization and deserialization with Pydantic. +# Use `VariableBase` for type hinting when serialization is not required. # # Note: -# - All variants in `VariableUnion` must inherit from the `Variable` class. -# - The union must include all non-abstract subclasses of `Segment`, except: -VariableUnion: TypeAlias = Annotated[ +# - All variants in `Variable` must inherit from the `VariableBase` class. +# - The union must include all non-abstract subclasses of `VariableBase`. +Variable: TypeAlias = Annotated[ ( Annotated[NoneVariable, Tag(SegmentType.NONE)] | Annotated[StringVariable, Tag(SegmentType.STRING)] diff --git a/api/core/workflow/conversation_variable_updater.py b/api/core/workflow/conversation_variable_updater.py index fd78248c17..75f47691da 100644 --- a/api/core/workflow/conversation_variable_updater.py +++ b/api/core/workflow/conversation_variable_updater.py @@ -1,7 +1,7 @@ import abc from typing import Protocol -from core.variables import Variable +from core.variables import VariableBase class ConversationVariableUpdater(Protocol): @@ -20,12 +20,12 @@ class ConversationVariableUpdater(Protocol): """ @abc.abstractmethod - def update(self, conversation_id: str, variable: "Variable"): + def update(self, conversation_id: str, variable: "VariableBase"): """ Updates the value of the specified conversation variable in the underlying storage. :param conversation_id: The ID of the conversation to update. Typically references `ConversationVariable.id`. - :param variable: The `Variable` instance containing the updated value. + :param variable: The `VariableBase` instance containing the updated value. 
""" pass diff --git a/api/core/workflow/graph_engine/entities/commands.py b/api/core/workflow/graph_engine/entities/commands.py index 6dce03c94d..41276eb444 100644 --- a/api/core/workflow/graph_engine/entities/commands.py +++ b/api/core/workflow/graph_engine/entities/commands.py @@ -11,7 +11,7 @@ from typing import Any from pydantic import BaseModel, Field -from core.variables.variables import VariableUnion +from core.variables.variables import Variable class CommandType(StrEnum): @@ -46,7 +46,7 @@ class PauseCommand(GraphEngineCommand): class VariableUpdate(BaseModel): """Represents a single variable update instruction.""" - value: VariableUnion = Field(description="New variable value") + value: Variable = Field(description="New variable value") class UpdateVariablesCommand(GraphEngineCommand): diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index e5d86414c1..91df2e4e0b 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -11,7 +11,7 @@ from typing_extensions import TypeIs from core.model_runtime.entities.llm_entities import LLMUsage from core.variables import IntegerVariable, NoneSegment from core.variables.segments import ArrayAnySegment, ArraySegment -from core.variables.variables import VariableUnion +from core.variables.variables import Variable from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID from core.workflow.enums import ( NodeExecutionType, @@ -240,7 +240,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): datetime, list[GraphNodeEventBase], object | None, - dict[str, VariableUnion], + dict[str, Variable], LLMUsage, ] ], @@ -308,7 +308,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): item: object, flask_app: Flask, context_vars: contextvars.Context, - ) -> tuple[datetime, list[GraphNodeEventBase], object | None, dict[str, VariableUnion], LLMUsage]: + ) -> tuple[datetime, list[GraphNodeEventBase], object | None, dict[str, Variable], LLMUsage]: """Execute a single iteration in parallel mode and return results.""" with preserve_flask_contexts(flask_app=flask_app, context_vars=context_vars): iter_start_at = datetime.now(UTC).replace(tzinfo=None) @@ -515,11 +515,11 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): return variable_mapping - def _extract_conversation_variable_snapshot(self, *, variable_pool: VariablePool) -> dict[str, VariableUnion]: + def _extract_conversation_variable_snapshot(self, *, variable_pool: VariablePool) -> dict[str, Variable]: conversation_variables = variable_pool.variable_dictionary.get(CONVERSATION_VARIABLE_NODE_ID, {}) return {name: variable.model_copy(deep=True) for name, variable in conversation_variables.items()} - def _sync_conversation_variables_from_snapshot(self, snapshot: dict[str, VariableUnion]) -> None: + def _sync_conversation_variables_from_snapshot(self, snapshot: dict[str, Variable]) -> None: parent_pool = self.graph_runtime_state.variable_pool parent_conversations = parent_pool.variable_dictionary.get(CONVERSATION_VARIABLE_NODE_ID, {}) diff --git a/api/core/workflow/nodes/variable_assigner/v1/node.py b/api/core/workflow/nodes/variable_assigner/v1/node.py index ac2870aa65..9f5818f4bb 100644 --- a/api/core/workflow/nodes/variable_assigner/v1/node.py +++ b/api/core/workflow/nodes/variable_assigner/v1/node.py @@ -1,7 +1,7 @@ from collections.abc import Mapping, Sequence from typing import TYPE_CHECKING, Any -from 
core.variables import SegmentType, Variable +from core.variables import SegmentType, VariableBase from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID from core.workflow.entities import GraphInitParams from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus @@ -73,7 +73,7 @@ class VariableAssignerNode(Node[VariableAssignerData]): assigned_variable_selector = self.node_data.assigned_variable_selector # Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject original_variable = self.graph_runtime_state.variable_pool.get(assigned_variable_selector) - if not isinstance(original_variable, Variable): + if not isinstance(original_variable, VariableBase): raise VariableOperatorNodeError("assigned variable not found") match self.node_data.write_mode: diff --git a/api/core/workflow/nodes/variable_assigner/v2/node.py b/api/core/workflow/nodes/variable_assigner/v2/node.py index 486e6bb6a7..5857702e72 100644 --- a/api/core/workflow/nodes/variable_assigner/v2/node.py +++ b/api/core/workflow/nodes/variable_assigner/v2/node.py @@ -2,7 +2,7 @@ import json from collections.abc import Mapping, MutableMapping, Sequence from typing import TYPE_CHECKING, Any -from core.variables import SegmentType, Variable +from core.variables import SegmentType, VariableBase from core.variables.consts import SELECTORS_LENGTH from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus @@ -118,7 +118,7 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]): # ==================== Validation Part # Check if variable exists - if not isinstance(variable, Variable): + if not isinstance(variable, VariableBase): raise VariableNotFoundError(variable_selector=item.variable_selector) # Check if operation is supported @@ -192,7 +192,7 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]): for selector in updated_variable_selectors: variable = self.graph_runtime_state.variable_pool.get(selector) - if not isinstance(variable, Variable): + if not isinstance(variable, VariableBase): raise VariableNotFoundError(variable_selector=selector) process_data[variable.name] = variable.value @@ -213,7 +213,7 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]): def _handle_item( self, *, - variable: Variable, + variable: VariableBase, operation: Operation, value: Any, ): diff --git a/api/core/workflow/runtime/variable_pool.py b/api/core/workflow/runtime/variable_pool.py index 85ceb9d59e..d205c6ac8f 100644 --- a/api/core/workflow/runtime/variable_pool.py +++ b/api/core/workflow/runtime/variable_pool.py @@ -9,10 +9,10 @@ from typing import Annotated, Any, Union, cast from pydantic import BaseModel, Field from core.file import File, FileAttribute, file_manager -from core.variables import Segment, SegmentGroup, Variable +from core.variables import Segment, SegmentGroup, VariableBase from core.variables.consts import SELECTORS_LENGTH from core.variables.segments import FileSegment, ObjectSegment -from core.variables.variables import RAGPipelineVariableInput, VariableUnion +from core.variables.variables import RAGPipelineVariableInput, Variable from core.workflow.constants import ( CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, @@ -32,7 +32,7 @@ class VariablePool(BaseModel): # The first element of the selector is the node id, it's the first-level key in the dictionary. # Other elements of the selector are the keys in the second-level dictionary. 
To get the key, we hash the # elements of the selector except the first one. - variable_dictionary: defaultdict[str, Annotated[dict[str, VariableUnion], Field(default_factory=dict)]] = Field( + variable_dictionary: defaultdict[str, Annotated[dict[str, Variable], Field(default_factory=dict)]] = Field( description="Variables mapping", default=defaultdict(dict), ) @@ -46,13 +46,13 @@ class VariablePool(BaseModel): description="System variables", default_factory=SystemVariable.empty, ) - environment_variables: Sequence[VariableUnion] = Field( + environment_variables: Sequence[Variable] = Field( description="Environment variables.", - default_factory=list[VariableUnion], + default_factory=list[Variable], ) - conversation_variables: Sequence[VariableUnion] = Field( + conversation_variables: Sequence[Variable] = Field( description="Conversation variables.", - default_factory=list[VariableUnion], + default_factory=list[Variable], ) rag_pipeline_variables: list[RAGPipelineVariableInput] = Field( description="RAG pipeline variables.", @@ -105,7 +105,7 @@ class VariablePool(BaseModel): f"got {len(selector)} elements" ) - if isinstance(value, Variable): + if isinstance(value, VariableBase): variable = value elif isinstance(value, Segment): variable = variable_factory.segment_to_variable(segment=value, selector=selector) @@ -114,9 +114,9 @@ class VariablePool(BaseModel): variable = variable_factory.segment_to_variable(segment=segment, selector=selector) node_id, name = self._selector_to_keys(selector) - # Based on the definition of `VariableUnion`, - # `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible. - self.variable_dictionary[node_id][name] = cast(VariableUnion, variable) + # Based on the definition of `Variable`, + # `VariableBase` instances can be safely used as `Variable` since they are compatible. + self.variable_dictionary[node_id][name] = cast(Variable, variable) @classmethod def _selector_to_keys(cls, selector: Sequence[str]) -> tuple[str, str]: diff --git a/api/core/workflow/variable_loader.py b/api/core/workflow/variable_loader.py index ea0bdc3537..7992785fe1 100644 --- a/api/core/workflow/variable_loader.py +++ b/api/core/workflow/variable_loader.py @@ -2,7 +2,7 @@ import abc from collections.abc import Mapping, Sequence from typing import Any, Protocol -from core.variables import Variable +from core.variables import VariableBase from core.variables.consts import SELECTORS_LENGTH from core.workflow.runtime import VariablePool @@ -26,7 +26,7 @@ class VariableLoader(Protocol): """ @abc.abstractmethod - def load_variables(self, selectors: list[list[str]]) -> list[Variable]: + def load_variables(self, selectors: list[list[str]]) -> list[VariableBase]: """Load variables based on the provided selectors. If the selectors are empty, this method should return an empty list. @@ -36,7 +36,7 @@ class VariableLoader(Protocol): :param: selectors: a list of string list, each inner list should have at least two elements: - the first element is the node ID, - the second element is the variable name. - :return: a list of Variable objects that match the provided selectors. + :return: a list of VariableBase objects that match the provided selectors. """ pass @@ -46,7 +46,7 @@ class _DummyVariableLoader(VariableLoader): Serves as a placeholder when no variable loading is needed. 
""" - def load_variables(self, selectors: list[list[str]]) -> list[Variable]: + def load_variables(self, selectors: list[list[str]]) -> list[VariableBase]: return [] diff --git a/api/extensions/ext_logstore.py b/api/extensions/ext_logstore.py index 502f0bb46b..cda2d1ad1e 100644 --- a/api/extensions/ext_logstore.py +++ b/api/extensions/ext_logstore.py @@ -10,6 +10,7 @@ import os from dotenv import load_dotenv +from configs import dify_config from dify_app import DifyApp logger = logging.getLogger(__name__) @@ -19,12 +20,17 @@ def is_enabled() -> bool: """ Check if logstore extension is enabled. + Logstore is considered enabled when: + 1. All required Aliyun SLS environment variables are set + 2. At least one repository configuration points to a logstore implementation + Returns: - True if all required Aliyun SLS environment variables are set, False otherwise + True if logstore should be initialized, False otherwise """ # Load environment variables from .env file load_dotenv() + # Check if Aliyun SLS connection parameters are configured required_vars = [ "ALIYUN_SLS_ACCESS_KEY_ID", "ALIYUN_SLS_ACCESS_KEY_SECRET", @@ -33,24 +39,32 @@ def is_enabled() -> bool: "ALIYUN_SLS_PROJECT_NAME", ] - all_set = all(os.environ.get(var) for var in required_vars) + sls_vars_set = all(os.environ.get(var) for var in required_vars) - if not all_set: - logger.info("Logstore extension disabled: required Aliyun SLS environment variables not set") + if not sls_vars_set: + return False - return all_set + # Check if any repository configuration points to logstore implementation + repository_configs = [ + dify_config.CORE_WORKFLOW_EXECUTION_REPOSITORY, + dify_config.CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY, + dify_config.API_WORKFLOW_NODE_EXECUTION_REPOSITORY, + dify_config.API_WORKFLOW_RUN_REPOSITORY, + ] + + uses_logstore = any("logstore" in config.lower() for config in repository_configs) + + if not uses_logstore: + return False + + logger.info("Logstore extension enabled: SLS variables set and repository configured to use logstore") + return True def init_app(app: DifyApp): """ Initialize logstore on application startup. - - This function: - 1. Creates Aliyun SLS project if it doesn't exist - 2. Creates logstores (workflow_execution, workflow_node_execution) if they don't exist - 3. Creates indexes with field configurations based on PostgreSQL table structures - - This operation is idempotent and only executes once during application startup. + If initialization fails, the application continues running without logstore features. Args: app: The Dify application instance @@ -58,17 +72,23 @@ def init_app(app: DifyApp): try: from extensions.logstore.aliyun_logstore import AliyunLogStore - logger.info("Initializing logstore...") + logger.info("Initializing Aliyun SLS Logstore...") - # Create logstore client and initialize project/logstores/indexes + # Create logstore client and initialize resources logstore_client = AliyunLogStore() logstore_client.init_project_logstore() - # Attach to app for potential later use app.extensions["logstore"] = logstore_client logger.info("Logstore initialized successfully") + except Exception: - logger.exception("Failed to initialize logstore") - # Don't raise - allow application to continue even if logstore init fails - # This ensures that the application can still run if logstore is misconfigured + logger.exception( + "Logstore initialization failed. Configuration: endpoint=%s, region=%s, project=%s, timeout=%ss. 
" + "Application will continue but logstore features will NOT work.", + os.environ.get("ALIYUN_SLS_ENDPOINT"), + os.environ.get("ALIYUN_SLS_REGION"), + os.environ.get("ALIYUN_SLS_PROJECT_NAME"), + os.environ.get("ALIYUN_SLS_CHECK_CONNECTIVITY_TIMEOUT", "30"), + ) + # Don't raise - allow application to continue even if logstore setup fails diff --git a/api/extensions/logstore/aliyun_logstore.py b/api/extensions/logstore/aliyun_logstore.py index 8c64a25be4..f6a4765f14 100644 --- a/api/extensions/logstore/aliyun_logstore.py +++ b/api/extensions/logstore/aliyun_logstore.py @@ -2,6 +2,7 @@ from __future__ import annotations import logging import os +import socket import threading import time from collections.abc import Sequence @@ -179,9 +180,18 @@ class AliyunLogStore: self.region: str = os.environ.get("ALIYUN_SLS_REGION", "") self.project_name: str = os.environ.get("ALIYUN_SLS_PROJECT_NAME", "") self.logstore_ttl: int = int(os.environ.get("ALIYUN_SLS_LOGSTORE_TTL", 365)) - self.log_enabled: bool = os.environ.get("SQLALCHEMY_ECHO", "false").lower() == "true" + self.log_enabled: bool = ( + os.environ.get("SQLALCHEMY_ECHO", "false").lower() == "true" + or os.environ.get("LOGSTORE_SQL_ECHO", "false").lower() == "true" + ) self.pg_mode_enabled: bool = os.environ.get("LOGSTORE_PG_MODE_ENABLED", "true").lower() == "true" + # Get timeout configuration + check_timeout = int(os.environ.get("ALIYUN_SLS_CHECK_CONNECTIVITY_TIMEOUT", 30)) + + # Pre-check endpoint connectivity to prevent indefinite hangs + self._check_endpoint_connectivity(self.endpoint, check_timeout) + # Initialize SDK client self.client = LogClient( self.endpoint, self.access_key_id, self.access_key_secret, auth_version=AUTH_VERSION_4, region=self.region @@ -199,6 +209,49 @@ class AliyunLogStore: self.__class__._initialized = True + @staticmethod + def _check_endpoint_connectivity(endpoint: str, timeout: int) -> None: + """ + Check if the SLS endpoint is reachable before creating LogClient. + Prevents indefinite hangs when the endpoint is unreachable. 
+ + Args: + endpoint: SLS endpoint URL + timeout: Connection timeout in seconds + + Raises: + ConnectionError: If endpoint is not reachable + """ + # Parse endpoint URL to extract hostname and port + from urllib.parse import urlparse + + parsed_url = urlparse(endpoint if "://" in endpoint else f"http://{endpoint}") + hostname = parsed_url.hostname + port = parsed_url.port or (443 if parsed_url.scheme == "https" else 80) + + if not hostname: + raise ConnectionError(f"Invalid endpoint URL: {endpoint}") + + sock = None + try: + # Create socket and set timeout + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.settimeout(timeout) + sock.connect((hostname, port)) + except Exception as e: + # Catch all exceptions and provide clear error message + error_type = type(e).__name__ + raise ConnectionError( + f"Cannot connect to {hostname}:{port} (timeout={timeout}s): [{error_type}] {e}" + ) from e + finally: + # Ensure socket is properly closed + if sock: + try: + sock.close() + except Exception: # noqa: S110 + pass # Ignore errors during cleanup + @property def supports_pg_protocol(self) -> bool: """Check if PG protocol is supported and enabled.""" @@ -220,19 +273,16 @@ class AliyunLogStore: try: self._use_pg_protocol = self._pg_client.init_connection() if self._use_pg_protocol: - logger.info("Successfully connected to project %s using PG protocol", self.project_name) + logger.info("Using PG protocol for project %s", self.project_name) # Check if scan_index is enabled for all logstores self._check_and_disable_pg_if_scan_index_disabled() return True else: - logger.info("PG connection failed for project %s. Will use SDK mode.", self.project_name) + logger.info("Using SDK mode for project %s", self.project_name) return False except Exception as e: - logger.warning( - "Failed to establish PG connection for project %s: %s. Will use SDK mode.", - self.project_name, - str(e), - ) + logger.info("Using SDK mode for project %s", self.project_name) + logger.debug("PG connection details: %s", str(e)) self._use_pg_protocol = False return False @@ -246,10 +296,6 @@ class AliyunLogStore: if self._use_pg_protocol: return - logger.info( - "Attempting delayed PG connection for newly created project %s ...", - self.project_name, - ) self._attempt_pg_connection_init() self.__class__._pg_connection_timer = None @@ -284,11 +330,7 @@ class AliyunLogStore: if project_is_new: # For newly created projects, schedule delayed PG connection self._use_pg_protocol = False - logger.info( - "Project %s is newly created. Will use SDK mode and schedule PG connection attempt in %d seconds.", - self.project_name, - self.__class__._pg_connection_delay, - ) + logger.info("Using SDK mode for project %s (newly created)", self.project_name) if self.__class__._pg_connection_timer is not None: self.__class__._pg_connection_timer.cancel() self.__class__._pg_connection_timer = threading.Timer( @@ -299,7 +341,6 @@ class AliyunLogStore: self.__class__._pg_connection_timer.start() else: # For existing projects, attempt PG connection immediately - logger.info("Project %s already exists. Attempting PG connection...", self.project_name) self._attempt_pg_connection_init() def _check_and_disable_pg_if_scan_index_disabled(self) -> None: @@ -318,9 +359,9 @@ class AliyunLogStore: existing_config = self.get_existing_index_config(logstore_name) if existing_config and not existing_config.scan_index: logger.info( - "Logstore %s has scan_index=false, USE SDK mode for read/write operations. 
" - "PG protocol requires scan_index to be enabled.", + "Logstore %s requires scan_index enabled, using SDK mode for project %s", logstore_name, + self.project_name, ) self._use_pg_protocol = False # Close PG connection if it was initialized @@ -748,7 +789,6 @@ class AliyunLogStore: reverse=reverse, ) - # Log query info if SQLALCHEMY_ECHO is enabled if self.log_enabled: logger.info( "[LogStore] GET_LOGS | logstore=%s | project=%s | query=%s | " @@ -770,7 +810,6 @@ class AliyunLogStore: for log in logs: result.append(log.get_contents()) - # Log result count if SQLALCHEMY_ECHO is enabled if self.log_enabled: logger.info( "[LogStore] GET_LOGS RESULT | logstore=%s | returned_count=%d", @@ -845,7 +884,6 @@ class AliyunLogStore: query=full_query, ) - # Log query info if SQLALCHEMY_ECHO is enabled if self.log_enabled: logger.info( "[LogStore-SDK] EXECUTE_SQL | logstore=%s | project=%s | from_time=%d | to_time=%d | full_query=%s", @@ -853,8 +891,7 @@ class AliyunLogStore: self.project_name, from_time, to_time, - query, - sql, + full_query, ) try: @@ -865,7 +902,6 @@ class AliyunLogStore: for log in logs: result.append(log.get_contents()) - # Log result count if SQLALCHEMY_ECHO is enabled if self.log_enabled: logger.info( "[LogStore-SDK] EXECUTE_SQL RESULT | logstore=%s | returned_count=%d", diff --git a/api/extensions/logstore/aliyun_logstore_pg.py b/api/extensions/logstore/aliyun_logstore_pg.py index 35aa51ce53..874c20d144 100644 --- a/api/extensions/logstore/aliyun_logstore_pg.py +++ b/api/extensions/logstore/aliyun_logstore_pg.py @@ -7,8 +7,7 @@ from contextlib import contextmanager from typing import Any import psycopg2 -import psycopg2.pool -from psycopg2 import InterfaceError, OperationalError +from sqlalchemy import create_engine from configs import dify_config @@ -16,11 +15,7 @@ logger = logging.getLogger(__name__) class AliyunLogStorePG: - """ - PostgreSQL protocol support for Aliyun SLS LogStore. - - Handles PG connection pooling and operations for regions that support PG protocol. - """ + """PostgreSQL protocol support for Aliyun SLS LogStore using SQLAlchemy connection pool.""" def __init__(self, access_key_id: str, access_key_secret: str, endpoint: str, project_name: str): """ @@ -36,24 +31,11 @@ class AliyunLogStorePG: self._access_key_secret = access_key_secret self._endpoint = endpoint self.project_name = project_name - self._pg_pool: psycopg2.pool.SimpleConnectionPool | None = None + self._engine: Any = None # SQLAlchemy Engine self._use_pg_protocol = False def _check_port_connectivity(self, host: str, port: int, timeout: float = 2.0) -> bool: - """ - Check if a TCP port is reachable using socket connection. - - This provides a fast check before attempting full database connection, - preventing long waits when connecting to unsupported regions. - - Args: - host: Hostname or IP address - port: Port number - timeout: Connection timeout in seconds (default: 2.0) - - Returns: - True if port is reachable, False otherwise - """ + """Fast TCP port check to avoid long waits on unsupported regions.""" try: sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.settimeout(timeout) @@ -65,166 +47,101 @@ class AliyunLogStorePG: return False def init_connection(self) -> bool: - """ - Initialize PostgreSQL connection pool for SLS PG protocol support. - - Attempts to connect to SLS using PostgreSQL protocol. If successful, sets - _use_pg_protocol to True and creates a connection pool. If connection fails - (region doesn't support PG protocol or other errors), returns False. 
- - Returns: - True if PG protocol is supported and initialized, False otherwise - """ + """Initialize SQLAlchemy connection pool with pool_recycle and TCP keepalive support.""" try: - # Extract hostname from endpoint (remove protocol if present) pg_host = self._endpoint.replace("http://", "").replace("https://", "") - # Get pool configuration - pg_max_connections = int(os.environ.get("ALIYUN_SLS_PG_MAX_CONNECTIONS", 10)) + # Pool configuration + pool_size = int(os.environ.get("ALIYUN_SLS_PG_POOL_SIZE", 5)) + max_overflow = int(os.environ.get("ALIYUN_SLS_PG_MAX_OVERFLOW", 5)) + pool_recycle = int(os.environ.get("ALIYUN_SLS_PG_POOL_RECYCLE", 3600)) + pool_pre_ping = os.environ.get("ALIYUN_SLS_PG_POOL_PRE_PING", "false").lower() == "true" - logger.debug( - "Check PG protocol connection to SLS: host=%s, project=%s", - pg_host, - self.project_name, - ) + logger.debug("Check PG protocol connection to SLS: host=%s, project=%s", pg_host, self.project_name) - # Fast port connectivity check before attempting full connection - # This prevents long waits when connecting to unsupported regions + # Fast port check to avoid long waits if not self._check_port_connectivity(pg_host, 5432, timeout=1.0): - logger.info( - "USE SDK mode for read/write operations, host=%s", - pg_host, - ) + logger.debug("Using SDK mode for host=%s", pg_host) return False - # Create connection pool - self._pg_pool = psycopg2.pool.SimpleConnectionPool( - minconn=1, - maxconn=pg_max_connections, - host=pg_host, - port=5432, - database=self.project_name, - user=self._access_key_id, - password=self._access_key_secret, - sslmode="require", - connect_timeout=5, - application_name=f"Dify-{dify_config.project.version}", + # Build connection URL + from urllib.parse import quote_plus + + username = quote_plus(self._access_key_id) + password = quote_plus(self._access_key_secret) + database_url = ( + f"postgresql+psycopg2://{username}:{password}@{pg_host}:5432/{self.project_name}?sslmode=require" ) - # Note: Skip test query because SLS PG protocol only supports SELECT/INSERT on actual tables - # Connection pool creation success already indicates connectivity + # Create SQLAlchemy engine with connection pool + self._engine = create_engine( + database_url, + pool_size=pool_size, + max_overflow=max_overflow, + pool_recycle=pool_recycle, + pool_pre_ping=pool_pre_ping, + pool_timeout=30, + connect_args={ + "connect_timeout": 5, + "application_name": f"Dify-{dify_config.project.version}-fixautocommit", + "keepalives": 1, + "keepalives_idle": 60, + "keepalives_interval": 10, + "keepalives_count": 5, + }, + ) self._use_pg_protocol = True logger.info( - "PG protocol initialized successfully for SLS project=%s. Will use PG for read/write operations.", + "PG protocol initialized for SLS project=%s (pool_size=%d, pool_recycle=%ds)", self.project_name, + pool_size, + pool_recycle, ) return True except Exception as e: - # PG connection failed - fallback to SDK mode self._use_pg_protocol = False - if self._pg_pool: + if self._engine: try: - self._pg_pool.closeall() + self._engine.dispose() except Exception: - logger.debug("Failed to close PG connection pool during cleanup, ignoring") - self._pg_pool = None + logger.debug("Failed to dispose engine during cleanup, ignoring") + self._engine = None - logger.info( - "PG protocol connection failed (region may not support PG protocol): %s. 
" - "Falling back to SDK mode for read/write operations.", - str(e), - ) - return False - - def _is_connection_valid(self, conn: Any) -> bool: - """ - Check if a connection is still valid. - - Args: - conn: psycopg2 connection object - - Returns: - True if connection is valid, False otherwise - """ - try: - # Check if connection is closed - if conn.closed: - return False - - # Quick ping test - execute a lightweight query - # For SLS PG protocol, we can't use SELECT 1 without FROM, - # so we just check the connection status - with conn.cursor() as cursor: - cursor.execute("SELECT 1") - cursor.fetchone() - return True - except Exception: + logger.debug("Using SDK mode for region: %s", str(e)) return False @contextmanager def _get_connection(self): - """ - Context manager to get a PostgreSQL connection from the pool. + """Get connection from SQLAlchemy pool. Pool handles recycle, invalidation, and keepalive automatically.""" + if not self._engine: + raise RuntimeError("SQLAlchemy engine is not initialized") - Automatically validates and refreshes stale connections. - - Note: Aliyun SLS PG protocol does not support transactions, so we always - use autocommit mode. - - Yields: - psycopg2 connection object - - Raises: - RuntimeError: If PG pool is not initialized - """ - if not self._pg_pool: - raise RuntimeError("PG connection pool is not initialized") - - conn = self._pg_pool.getconn() + connection = self._engine.raw_connection() try: - # Validate connection and get a fresh one if needed - if not self._is_connection_valid(conn): - logger.debug("Connection is stale, marking as bad and getting a new one") - # Mark connection as bad and get a new one - self._pg_pool.putconn(conn, close=True) - conn = self._pg_pool.getconn() - - # Aliyun SLS PG protocol does not support transactions, always use autocommit - conn.autocommit = True - yield conn + connection.autocommit = True # SLS PG protocol does not support transactions + yield connection + except Exception: + raise finally: - # Return connection to pool (or close if it's bad) - if self._is_connection_valid(conn): - self._pg_pool.putconn(conn) - else: - self._pg_pool.putconn(conn, close=True) + connection.close() def close(self) -> None: - """Close the PostgreSQL connection pool.""" - if self._pg_pool: + """Dispose SQLAlchemy engine and close all connections.""" + if self._engine: try: - self._pg_pool.closeall() - logger.info("PG connection pool closed") + self._engine.dispose() + logger.info("SQLAlchemy engine disposed") except Exception: - logger.exception("Failed to close PG connection pool") + logger.exception("Failed to dispose engine") def _is_retriable_error(self, error: Exception) -> bool: - """ - Check if an error is retriable (connection-related issues). 
- - Args: - error: Exception to check - - Returns: - True if the error is retriable, False otherwise - """ - # Retry on connection-related errors - if isinstance(error, (OperationalError, InterfaceError)): + """Check if error is retriable (connection-related issues).""" + # Check for psycopg2 connection errors directly + if isinstance(error, (psycopg2.OperationalError, psycopg2.InterfaceError)): return True - # Check error message for specific connection issues error_msg = str(error).lower() retriable_patterns = [ "connection", @@ -234,34 +151,18 @@ class AliyunLogStorePG: "reset by peer", "no route to host", "network", + "operational error", + "interface error", ] return any(pattern in error_msg for pattern in retriable_patterns) def put_log(self, logstore: str, contents: Sequence[tuple[str, str]], log_enabled: bool = False) -> None: - """ - Write log to SLS using PostgreSQL protocol with automatic retry. - - Note: SLS PG protocol only supports INSERT (not UPDATE). This uses append-only - writes with log_version field for versioning, same as SDK implementation. - - Args: - logstore: Name of the logstore table - contents: List of (field_name, value) tuples - log_enabled: Whether to enable logging - - Raises: - psycopg2.Error: If database operation fails after all retries - """ + """Write log to SLS using INSERT with automatic retry (3 attempts with exponential backoff).""" if not contents: return - # Extract field names and values from contents fields = [field_name for field_name, _ in contents] values = [value for _, value in contents] - - # Build INSERT statement with literal values - # Note: Aliyun SLS PG protocol doesn't support parameterized queries, - # so we need to use mogrify to safely create literal values field_list = ", ".join([f'"{field}"' for field in fields]) if log_enabled: @@ -272,67 +173,40 @@ class AliyunLogStorePG: len(contents), ) - # Retry configuration max_retries = 3 - retry_delay = 0.1 # Start with 100ms + retry_delay = 0.1 for attempt in range(max_retries): try: with self._get_connection() as conn: with conn.cursor() as cursor: - # Use mogrify to safely convert values to SQL literals placeholders = ", ".join(["%s"] * len(fields)) values_literal = cursor.mogrify(f"({placeholders})", values).decode("utf-8") insert_sql = f'INSERT INTO "{logstore}" ({field_list}) VALUES {values_literal}' cursor.execute(insert_sql) - # Success - exit retry loop return except psycopg2.Error as e: - # Check if error is retriable if not self._is_retriable_error(e): - # Not a retriable error (e.g., data validation error), fail immediately - logger.exception( - "Failed to put logs to logstore %s via PG protocol (non-retriable error)", - logstore, - ) + logger.exception("Failed to put logs to logstore %s (non-retriable error)", logstore) raise - # Retriable error - log and retry if we have attempts left if attempt < max_retries - 1: logger.warning( - "Failed to put logs to logstore %s via PG protocol (attempt %d/%d): %s. Retrying...", + "Failed to put logs to logstore %s (attempt %d/%d): %s. 
Retrying...", logstore, attempt + 1, max_retries, str(e), ) time.sleep(retry_delay) - retry_delay *= 2 # Exponential backoff + retry_delay *= 2 else: - # Last attempt failed - logger.exception( - "Failed to put logs to logstore %s via PG protocol after %d attempts", - logstore, - max_retries, - ) + logger.exception("Failed to put logs to logstore %s after %d attempts", logstore, max_retries) raise def execute_sql(self, sql: str, logstore: str, log_enabled: bool = False) -> list[dict[str, Any]]: - """ - Execute SQL query using PostgreSQL protocol with automatic retry. - - Args: - sql: SQL query string - logstore: Name of the logstore (for logging purposes) - log_enabled: Whether to enable logging - - Returns: - List of result rows as dictionaries - - Raises: - psycopg2.Error: If database operation fails after all retries - """ + """Execute SQL query with automatic retry (3 attempts with exponential backoff).""" if log_enabled: logger.info( "[LogStore-PG] EXECUTE_SQL | logstore=%s | project=%s | sql=%s", @@ -341,20 +215,16 @@ class AliyunLogStorePG: sql, ) - # Retry configuration max_retries = 3 - retry_delay = 0.1 # Start with 100ms + retry_delay = 0.1 for attempt in range(max_retries): try: with self._get_connection() as conn: with conn.cursor() as cursor: cursor.execute(sql) - - # Get column names from cursor description columns = [desc[0] for desc in cursor.description] - # Fetch all results and convert to list of dicts result = [] for row in cursor.fetchall(): row_dict = {} @@ -372,36 +242,31 @@ class AliyunLogStorePG: return result except psycopg2.Error as e: - # Check if error is retriable if not self._is_retriable_error(e): - # Not a retriable error (e.g., SQL syntax error), fail immediately logger.exception( - "Failed to execute SQL query on logstore %s via PG protocol (non-retriable error): sql=%s", + "Failed to execute SQL on logstore %s (non-retriable error): sql=%s", logstore, sql, ) raise - # Retriable error - log and retry if we have attempts left if attempt < max_retries - 1: logger.warning( - "Failed to execute SQL query on logstore %s via PG protocol (attempt %d/%d): %s. Retrying...", + "Failed to execute SQL on logstore %s (attempt %d/%d): %s. Retrying...", logstore, attempt + 1, max_retries, str(e), ) time.sleep(retry_delay) - retry_delay *= 2 # Exponential backoff + retry_delay *= 2 else: - # Last attempt failed logger.exception( - "Failed to execute SQL query on logstore %s via PG protocol after %d attempts: sql=%s", + "Failed to execute SQL on logstore %s after %d attempts: sql=%s", logstore, max_retries, sql, ) raise - # This line should never be reached due to raise above, but makes type checker happy return [] diff --git a/api/extensions/logstore/repositories/__init__.py b/api/extensions/logstore/repositories/__init__.py index e69de29bb2..b5a4fcf844 100644 --- a/api/extensions/logstore/repositories/__init__.py +++ b/api/extensions/logstore/repositories/__init__.py @@ -0,0 +1,29 @@ +""" +LogStore repository utilities. +""" + +from typing import Any + + +def safe_float(value: Any, default: float = 0.0) -> float: + """ + Safely convert a value to float, handling 'null' strings and None. + """ + if value is None or value in {"null", ""}: + return default + try: + return float(value) + except (ValueError, TypeError): + return default + + +def safe_int(value: Any, default: int = 0) -> int: + """ + Safely convert a value to int, handling 'null' strings and None. 
+ """ + if value is None or value in {"null", ""}: + return default + try: + return int(float(value)) + except (ValueError, TypeError): + return default diff --git a/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py b/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py index 8c804d6bb5..f67723630b 100644 --- a/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py +++ b/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py @@ -14,6 +14,8 @@ from typing import Any from sqlalchemy.orm import sessionmaker from extensions.logstore.aliyun_logstore import AliyunLogStore +from extensions.logstore.repositories import safe_float, safe_int +from extensions.logstore.sql_escape import escape_identifier, escape_logstore_query_value from models.workflow import WorkflowNodeExecutionModel from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository @@ -52,9 +54,8 @@ def _dict_to_workflow_node_execution_model(data: dict[str, Any]) -> WorkflowNode model.created_by_role = data.get("created_by_role") or "" model.created_by = data.get("created_by") or "" - # Numeric fields with defaults - model.index = int(data.get("index", 0)) - model.elapsed_time = float(data.get("elapsed_time", 0)) + model.index = safe_int(data.get("index", 0)) + model.elapsed_time = safe_float(data.get("elapsed_time", 0)) # Optional fields model.workflow_run_id = data.get("workflow_run_id") @@ -130,6 +131,12 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep node_id, ) try: + # Escape parameters to prevent SQL injection + escaped_tenant_id = escape_identifier(tenant_id) + escaped_app_id = escape_identifier(app_id) + escaped_workflow_id = escape_identifier(workflow_id) + escaped_node_id = escape_identifier(node_id) + # Check if PG protocol is supported if self.logstore_client.supports_pg_protocol: # Use PG protocol with SQL query (get latest version of each record) @@ -138,10 +145,10 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn FROM "{AliyunLogStore.workflow_node_execution_logstore}" - WHERE tenant_id = '{tenant_id}' - AND app_id = '{app_id}' - AND workflow_id = '{workflow_id}' - AND node_id = '{node_id}' + WHERE tenant_id = '{escaped_tenant_id}' + AND app_id = '{escaped_app_id}' + AND workflow_id = '{escaped_workflow_id}' + AND node_id = '{escaped_node_id}' AND __time__ > 0 ) AS subquery WHERE rn = 1 LIMIT 100 @@ -153,7 +160,8 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep else: # Use SDK with LogStore query syntax query = ( - f"tenant_id: {tenant_id} and app_id: {app_id} and workflow_id: {workflow_id} and node_id: {node_id}" + f"tenant_id: {escaped_tenant_id} and app_id: {escaped_app_id} " + f"and workflow_id: {escaped_workflow_id} and node_id: {escaped_node_id}" ) from_time = 0 to_time = int(time.time()) # now @@ -227,6 +235,11 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep workflow_run_id, ) try: + # Escape parameters to prevent SQL injection + escaped_tenant_id = escape_identifier(tenant_id) + escaped_app_id = escape_identifier(app_id) + escaped_workflow_run_id = escape_identifier(workflow_run_id) + # Check if PG protocol is supported if self.logstore_client.supports_pg_protocol: # Use PG protocol with SQL query (get latest version of 
each record) @@ -235,9 +248,9 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn FROM "{AliyunLogStore.workflow_node_execution_logstore}" - WHERE tenant_id = '{tenant_id}' - AND app_id = '{app_id}' - AND workflow_run_id = '{workflow_run_id}' + WHERE tenant_id = '{escaped_tenant_id}' + AND app_id = '{escaped_app_id}' + AND workflow_run_id = '{escaped_workflow_run_id}' AND __time__ > 0 ) AS subquery WHERE rn = 1 LIMIT 1000 @@ -248,7 +261,10 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep ) else: # Use SDK with LogStore query syntax - query = f"tenant_id: {tenant_id} and app_id: {app_id} and workflow_run_id: {workflow_run_id}" + query = ( + f"tenant_id: {escaped_tenant_id} and app_id: {escaped_app_id} " + f"and workflow_run_id: {escaped_workflow_run_id}" + ) from_time = 0 to_time = int(time.time()) # now @@ -313,16 +329,24 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep """ logger.debug("get_execution_by_id: execution_id=%s, tenant_id=%s", execution_id, tenant_id) try: + # Escape parameters to prevent SQL injection + escaped_execution_id = escape_identifier(execution_id) + # Check if PG protocol is supported if self.logstore_client.supports_pg_protocol: # Use PG protocol with SQL query (get latest version of record) - tenant_filter = f"AND tenant_id = '{tenant_id}'" if tenant_id else "" + if tenant_id: + escaped_tenant_id = escape_identifier(tenant_id) + tenant_filter = f"AND tenant_id = '{escaped_tenant_id}'" + else: + tenant_filter = "" + sql_query = f""" SELECT * FROM ( SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn FROM "{AliyunLogStore.workflow_node_execution_logstore}" - WHERE id = '{execution_id}' {tenant_filter} AND __time__ > 0 + WHERE id = '{escaped_execution_id}' {tenant_filter} AND __time__ > 0 ) AS subquery WHERE rn = 1 LIMIT 1 """ @@ -332,10 +356,14 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep ) else: # Use SDK with LogStore query syntax + # Note: Values must be quoted in LogStore query syntax to prevent injection if tenant_id: - query = f"id: {execution_id} and tenant_id: {tenant_id}" + query = ( + f"id:{escape_logstore_query_value(execution_id)} " + f"and tenant_id:{escape_logstore_query_value(tenant_id)}" + ) else: - query = f"id: {execution_id}" + query = f"id:{escape_logstore_query_value(execution_id)}" from_time = 0 to_time = int(time.time()) # now diff --git a/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py b/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py index 252cdcc4df..14382ed876 100644 --- a/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py +++ b/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py @@ -10,6 +10,7 @@ Key Features: - Optimized deduplication using finished_at IS NOT NULL filter - Window functions only when necessary (running status queries) - Multi-tenant data isolation and security +- SQL injection prevention via parameter escaping """ import logging @@ -22,6 +23,8 @@ from typing import Any, cast from sqlalchemy.orm import sessionmaker from extensions.logstore.aliyun_logstore import AliyunLogStore +from extensions.logstore.repositories import safe_float, safe_int +from extensions.logstore.sql_escape import escape_identifier, escape_logstore_query_value, escape_sql_string from 
libs.infinite_scroll_pagination import InfiniteScrollPagination from models.enums import WorkflowRunTriggeredFrom from models.workflow import WorkflowRun @@ -63,10 +66,9 @@ def _dict_to_workflow_run(data: dict[str, Any]) -> WorkflowRun: model.created_by_role = data.get("created_by_role") or "" model.created_by = data.get("created_by") or "" - # Numeric fields with defaults - model.total_tokens = int(data.get("total_tokens", 0)) - model.total_steps = int(data.get("total_steps", 0)) - model.exceptions_count = int(data.get("exceptions_count", 0)) + model.total_tokens = safe_int(data.get("total_tokens", 0)) + model.total_steps = safe_int(data.get("total_steps", 0)) + model.exceptions_count = safe_int(data.get("exceptions_count", 0)) # Optional fields model.graph = data.get("graph") @@ -101,7 +103,8 @@ def _dict_to_workflow_run(data: dict[str, Any]) -> WorkflowRun: if model.finished_at and model.created_at: model.elapsed_time = (model.finished_at - model.created_at).total_seconds() else: - model.elapsed_time = float(data.get("elapsed_time", 0)) + # Use safe conversion to handle 'null' strings and None values + model.elapsed_time = safe_float(data.get("elapsed_time", 0)) return model @@ -165,16 +168,26 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository): status, ) # Convert triggered_from to list if needed - if isinstance(triggered_from, WorkflowRunTriggeredFrom): + if isinstance(triggered_from, (WorkflowRunTriggeredFrom, str)): triggered_from_list = [triggered_from] else: triggered_from_list = list(triggered_from) - # Build triggered_from filter - triggered_from_filter = " OR ".join([f"triggered_from='{tf.value}'" for tf in triggered_from_list]) + # Escape parameters to prevent SQL injection + escaped_tenant_id = escape_identifier(tenant_id) + escaped_app_id = escape_identifier(app_id) - # Build status filter - status_filter = f"AND status='{status}'" if status else "" + # Build triggered_from filter with escaped values + # Support both enum and string values for triggered_from + triggered_from_filter = " OR ".join( + [ + f"triggered_from='{escape_sql_string(tf.value if isinstance(tf, WorkflowRunTriggeredFrom) else tf)}'" + for tf in triggered_from_list + ] + ) + + # Build status filter with escaped value + status_filter = f"AND status='{escape_sql_string(status)}'" if status else "" # Build last_id filter for pagination # Note: This is simplified. 
In production, you'd need to track created_at from last record @@ -188,8 +201,8 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository): SELECT * FROM ( SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) AS rn FROM {AliyunLogStore.workflow_execution_logstore} - WHERE tenant_id='{tenant_id}' - AND app_id='{app_id}' + WHERE tenant_id='{escaped_tenant_id}' + AND app_id='{escaped_app_id}' AND ({triggered_from_filter}) {status_filter} {last_id_filter} @@ -232,6 +245,11 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository): logger.debug("get_workflow_run_by_id: tenant_id=%s, app_id=%s, run_id=%s", tenant_id, app_id, run_id) try: + # Escape parameters to prevent SQL injection + escaped_run_id = escape_identifier(run_id) + escaped_tenant_id = escape_identifier(tenant_id) + escaped_app_id = escape_identifier(app_id) + # Check if PG protocol is supported if self.logstore_client.supports_pg_protocol: # Use PG protocol with SQL query (get latest version of record) @@ -240,7 +258,10 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository): SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn FROM "{AliyunLogStore.workflow_execution_logstore}" - WHERE id = '{run_id}' AND tenant_id = '{tenant_id}' AND app_id = '{app_id}' AND __time__ > 0 + WHERE id = '{escaped_run_id}' + AND tenant_id = '{escaped_tenant_id}' + AND app_id = '{escaped_app_id}' + AND __time__ > 0 ) AS subquery WHERE rn = 1 LIMIT 100 """ @@ -250,7 +271,12 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository): ) else: # Use SDK with LogStore query syntax - query = f"id: {run_id} and tenant_id: {tenant_id} and app_id: {app_id}" + # Note: Values must be quoted in LogStore query syntax to prevent injection + query = ( + f"id:{escape_logstore_query_value(run_id)} " + f"and tenant_id:{escape_logstore_query_value(tenant_id)} " + f"and app_id:{escape_logstore_query_value(app_id)}" + ) from_time = 0 to_time = int(time.time()) # now @@ -323,6 +349,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository): logger.debug("get_workflow_run_by_id_without_tenant: run_id=%s", run_id) try: + # Escape parameter to prevent SQL injection + escaped_run_id = escape_identifier(run_id) + # Check if PG protocol is supported if self.logstore_client.supports_pg_protocol: # Use PG protocol with SQL query (get latest version of record) @@ -331,7 +360,7 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository): SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn FROM "{AliyunLogStore.workflow_execution_logstore}" - WHERE id = '{run_id}' AND __time__ > 0 + WHERE id = '{escaped_run_id}' AND __time__ > 0 ) AS subquery WHERE rn = 1 LIMIT 100 """ @@ -341,7 +370,8 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository): ) else: # Use SDK with LogStore query syntax - query = f"id: {run_id}" + # Note: Values must be quoted in LogStore query syntax + query = f"id:{escape_logstore_query_value(run_id)}" from_time = 0 to_time = int(time.time()) # now @@ -410,6 +440,11 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository): triggered_from, status, ) + # Escape parameters to prevent SQL injection + escaped_tenant_id = escape_identifier(tenant_id) + escaped_app_id = escape_identifier(app_id) + escaped_triggered_from = escape_sql_string(triggered_from) + # Build time range filter time_filter = "" if time_range: @@ -418,6 +453,8 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository): # If status is provided, 
simple count if status: + escaped_status = escape_sql_string(status) + if status == "running": # Running status requires window function sql = f""" @@ -425,9 +462,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository): FROM ( SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) AS rn FROM {AliyunLogStore.workflow_execution_logstore} - WHERE tenant_id='{tenant_id}' - AND app_id='{app_id}' - AND triggered_from='{triggered_from}' + WHERE tenant_id='{escaped_tenant_id}' + AND app_id='{escaped_app_id}' + AND triggered_from='{escaped_triggered_from}' AND status='running' {time_filter} ) t @@ -438,10 +475,10 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository): sql = f""" SELECT COUNT(DISTINCT id) as count FROM {AliyunLogStore.workflow_execution_logstore} - WHERE tenant_id='{tenant_id}' - AND app_id='{app_id}' - AND triggered_from='{triggered_from}' - AND status='{status}' + WHERE tenant_id='{escaped_tenant_id}' + AND app_id='{escaped_app_id}' + AND triggered_from='{escaped_triggered_from}' + AND status='{escaped_status}' AND finished_at IS NOT NULL {time_filter} """ @@ -467,13 +504,14 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository): # No status filter - get counts grouped by status # Use optimized query for finished runs, separate query for running try: + # Escape parameters (already escaped above, reuse variables) # Count finished runs grouped by status finished_sql = f""" SELECT status, COUNT(DISTINCT id) as count FROM {AliyunLogStore.workflow_execution_logstore} - WHERE tenant_id='{tenant_id}' - AND app_id='{app_id}' - AND triggered_from='{triggered_from}' + WHERE tenant_id='{escaped_tenant_id}' + AND app_id='{escaped_app_id}' + AND triggered_from='{escaped_triggered_from}' AND finished_at IS NOT NULL {time_filter} GROUP BY status @@ -485,9 +523,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository): FROM ( SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) AS rn FROM {AliyunLogStore.workflow_execution_logstore} - WHERE tenant_id='{tenant_id}' - AND app_id='{app_id}' - AND triggered_from='{triggered_from}' + WHERE tenant_id='{escaped_tenant_id}' + AND app_id='{escaped_app_id}' + AND triggered_from='{escaped_triggered_from}' AND status='running' {time_filter} ) t @@ -546,7 +584,13 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository): logger.debug( "get_daily_runs_statistics: tenant_id=%s, app_id=%s, triggered_from=%s", tenant_id, app_id, triggered_from ) - # Build time range filter + + # Escape parameters to prevent SQL injection + escaped_tenant_id = escape_identifier(tenant_id) + escaped_app_id = escape_identifier(app_id) + escaped_triggered_from = escape_sql_string(triggered_from) + + # Build time range filter (datetime.isoformat() is safe) time_filter = "" if start_date: time_filter += f" AND __time__ >= to_unixtime(from_iso8601_timestamp('{start_date.isoformat()}'))" @@ -557,9 +601,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository): sql = f""" SELECT DATE(from_unixtime(__time__)) as date, COUNT(DISTINCT id) as runs FROM {AliyunLogStore.workflow_execution_logstore} - WHERE tenant_id='{tenant_id}' - AND app_id='{app_id}' - AND triggered_from='{triggered_from}' + WHERE tenant_id='{escaped_tenant_id}' + AND app_id='{escaped_app_id}' + AND triggered_from='{escaped_triggered_from}' AND finished_at IS NOT NULL {time_filter} GROUP BY date @@ -601,7 +645,13 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository): app_id, triggered_from, ) - # Build 
time range filter + + # Escape parameters to prevent SQL injection + escaped_tenant_id = escape_identifier(tenant_id) + escaped_app_id = escape_identifier(app_id) + escaped_triggered_from = escape_sql_string(triggered_from) + + # Build time range filter (datetime.isoformat() is safe) time_filter = "" if start_date: time_filter += f" AND __time__ >= to_unixtime(from_iso8601_timestamp('{start_date.isoformat()}'))" @@ -611,9 +661,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository): sql = f""" SELECT DATE(from_unixtime(__time__)) as date, COUNT(DISTINCT created_by) as terminal_count FROM {AliyunLogStore.workflow_execution_logstore} - WHERE tenant_id='{tenant_id}' - AND app_id='{app_id}' - AND triggered_from='{triggered_from}' + WHERE tenant_id='{escaped_tenant_id}' + AND app_id='{escaped_app_id}' + AND triggered_from='{escaped_triggered_from}' AND finished_at IS NOT NULL {time_filter} GROUP BY date @@ -655,7 +705,13 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository): app_id, triggered_from, ) - # Build time range filter + + # Escape parameters to prevent SQL injection + escaped_tenant_id = escape_identifier(tenant_id) + escaped_app_id = escape_identifier(app_id) + escaped_triggered_from = escape_sql_string(triggered_from) + + # Build time range filter (datetime.isoformat() is safe) time_filter = "" if start_date: time_filter += f" AND __time__ >= to_unixtime(from_iso8601_timestamp('{start_date.isoformat()}'))" @@ -665,9 +721,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository): sql = f""" SELECT DATE(from_unixtime(__time__)) as date, SUM(total_tokens) as token_count FROM {AliyunLogStore.workflow_execution_logstore} - WHERE tenant_id='{tenant_id}' - AND app_id='{app_id}' - AND triggered_from='{triggered_from}' + WHERE tenant_id='{escaped_tenant_id}' + AND app_id='{escaped_app_id}' + AND triggered_from='{escaped_triggered_from}' AND finished_at IS NOT NULL {time_filter} GROUP BY date @@ -709,7 +765,13 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository): app_id, triggered_from, ) - # Build time range filter + + # Escape parameters to prevent SQL injection + escaped_tenant_id = escape_identifier(tenant_id) + escaped_app_id = escape_identifier(app_id) + escaped_triggered_from = escape_sql_string(triggered_from) + + # Build time range filter (datetime.isoformat() is safe) time_filter = "" if start_date: time_filter += f" AND __time__ >= to_unixtime(from_iso8601_timestamp('{start_date.isoformat()}'))" @@ -726,9 +788,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository): created_by, COUNT(DISTINCT id) AS interactions FROM {AliyunLogStore.workflow_execution_logstore} - WHERE tenant_id='{tenant_id}' - AND app_id='{app_id}' - AND triggered_from='{triggered_from}' + WHERE tenant_id='{escaped_tenant_id}' + AND app_id='{escaped_app_id}' + AND triggered_from='{escaped_triggered_from}' AND finished_at IS NOT NULL {time_filter} GROUP BY date, created_by diff --git a/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py b/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py index 1119534d52..9928879a7b 100644 --- a/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py +++ b/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py @@ -10,6 +10,7 @@ from sqlalchemy.orm import sessionmaker from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository from core.workflow.entities import 
WorkflowExecution from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository +from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter from extensions.logstore.aliyun_logstore import AliyunLogStore from libs.helper import extract_tenant_id from models import ( @@ -22,18 +23,6 @@ from models.enums import WorkflowRunTriggeredFrom logger = logging.getLogger(__name__) -def to_serializable(obj): - """ - Convert non-JSON-serializable objects into JSON-compatible formats. - - - Uses `to_dict()` if it's a callable method. - - Falls back to string representation. - """ - if hasattr(obj, "to_dict") and callable(obj.to_dict): - return obj.to_dict() - return str(obj) - - class LogstoreWorkflowExecutionRepository(WorkflowExecutionRepository): def __init__( self, @@ -79,7 +68,7 @@ class LogstoreWorkflowExecutionRepository(WorkflowExecutionRepository): # Control flag for dual-write (write to both LogStore and SQL database) # Set to True to enable dual-write for safe migration, False to use LogStore only - self._enable_dual_write = os.environ.get("LOGSTORE_DUAL_WRITE_ENABLED", "true").lower() == "true" + self._enable_dual_write = os.environ.get("LOGSTORE_DUAL_WRITE_ENABLED", "false").lower() == "true" # Control flag for whether to write the `graph` field to LogStore. # If LOGSTORE_ENABLE_PUT_GRAPH_FIELD is "true", write the full `graph` field; @@ -113,6 +102,9 @@ class LogstoreWorkflowExecutionRepository(WorkflowExecutionRepository): # Generate log_version as nanosecond timestamp for record versioning log_version = str(time.time_ns()) + # Use WorkflowRuntimeTypeConverter to handle complex types (Segment, File, etc.) + json_converter = WorkflowRuntimeTypeConverter() + logstore_model = [ ("id", domain_model.id_), ("log_version", log_version), # Add log_version field for append-only writes @@ -127,19 +119,19 @@ class LogstoreWorkflowExecutionRepository(WorkflowExecutionRepository): ("version", domain_model.workflow_version), ( "graph", - json.dumps(domain_model.graph, ensure_ascii=False, default=to_serializable) + json.dumps(json_converter.to_json_encodable(domain_model.graph), ensure_ascii=False) if domain_model.graph and self._enable_put_graph_field else "{}", ), ( "inputs", - json.dumps(domain_model.inputs, ensure_ascii=False, default=to_serializable) + json.dumps(json_converter.to_json_encodable(domain_model.inputs), ensure_ascii=False) if domain_model.inputs else "{}", ), ( "outputs", - json.dumps(domain_model.outputs, ensure_ascii=False, default=to_serializable) + json.dumps(json_converter.to_json_encodable(domain_model.outputs), ensure_ascii=False) if domain_model.outputs else "{}", ), diff --git a/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py b/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py index 400a089516..4897171b12 100644 --- a/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py +++ b/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py @@ -24,6 +24,8 @@ from core.workflow.enums import NodeType from core.workflow.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter from extensions.logstore.aliyun_logstore import AliyunLogStore +from extensions.logstore.repositories import safe_float, safe_int +from extensions.logstore.sql_escape import escape_identifier from libs.helper import 
extract_tenant_id from models import ( Account, @@ -73,7 +75,7 @@ def _dict_to_workflow_node_execution(data: dict[str, Any]) -> WorkflowNodeExecut node_execution_id=data.get("node_execution_id"), workflow_id=data.get("workflow_id", ""), workflow_execution_id=data.get("workflow_run_id"), - index=int(data.get("index", 0)), + index=safe_int(data.get("index", 0)), predecessor_node_id=data.get("predecessor_node_id"), node_id=data.get("node_id", ""), node_type=NodeType(data.get("node_type", "start")), @@ -83,7 +85,7 @@ def _dict_to_workflow_node_execut outputs=outputs, status=status, error=data.get("error"), - elapsed_time=float(data.get("elapsed_time", 0.0)), + elapsed_time=safe_float(data.get("elapsed_time", 0.0)), metadata=domain_metadata, created_at=created_at, finished_at=finished_at, @@ -147,7 +149,7 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository): # Control flag for dual-write (write to both LogStore and SQL database) # Set to True to enable dual-write for safe migration, False to use LogStore only - self._enable_dual_write = os.environ.get("LOGSTORE_DUAL_WRITE_ENABLED", "true").lower() == "true" + self._enable_dual_write = os.environ.get("LOGSTORE_DUAL_WRITE_ENABLED", "false").lower() == "true" def _to_logstore_model(self, domain_model: WorkflowNodeExecution) -> Sequence[tuple[str, str]]: logger.debug( @@ -274,16 +276,34 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository): Save or update the inputs, process_data, or outputs associated with a specific node_execution record. - For LogStore implementation, this is similar to save() since we always write - complete records. We append a new record with updated data fields. + For the LogStore implementation, this is a no-op for the LogStore write because save() + already writes all fields including inputs, process_data, and outputs. The caller + typically calls save() first to persist status/metadata, then calls save_execution_data() + to persist data fields. Since LogStore writes complete records atomically, skipping the + second write here avoids duplicate records. + + However, if dual-write is enabled, we still need to call the SQL repository's + save_execution_data() method to keep the SQL database up to date.
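+
+        A minimal sketch of the intended call pattern (the `repo` and
+        `execution` objects here are assumed):
+
+            repo.save(execution)                 # complete record written to LogStore
+            repo.save_execution_data(execution)  # LogStore no-op; dual-writes to SQL when enabled
+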
Args: execution: The NodeExecution instance with data to save """ - logger.debug("save_execution_data: id=%s, node_execution_id=%s", execution.id, execution.node_execution_id) - # In LogStore, we simply write a new complete record with the data - # The log_version timestamp will ensure this is treated as the latest version - self.save(execution) + logger.debug( + "save_execution_data: no-op for LogStore (data already saved by save()): id=%s, node_execution_id=%s", + execution.id, + execution.node_execution_id, + ) + # No-op for LogStore: save() already writes all fields including inputs, process_data, and outputs. + # Calling save() again would create a duplicate record in the append-only LogStore. + + # Dual-write to SQL database if enabled (for safe migration) + if self._enable_dual_write: + try: + self.sql_repository.save_execution_data(execution) + logger.debug("Dual-write: saved node execution data to SQL database: id=%s", execution.id) + except Exception: + logger.exception("Failed to dual-write node execution data to SQL database: id=%s", execution.id) + # Don't raise - LogStore write succeeded, SQL is just a backup def get_by_workflow_run( self, @@ -292,8 +312,8 @@ ) -> Sequence[WorkflowNodeExecution]: """ Retrieve all NodeExecution instances for a specific workflow run. - Uses LogStore SQL query with finished_at IS NOT NULL filter for deduplication. - This ensures we only get the final version of each node execution. + Uses a LogStore SQL query with a window function for deduplication. + This ensures we only get the latest version of each node execution record. Args: workflow_run_id: The workflow run ID order_config: Optional configuration for ordering results @@ -304,16 +324,19 @@ A list of NodeExecution instances Note: - This method filters by finished_at IS NOT NULL to avoid duplicates from - version updates. For complete history including intermediate states, - a different query strategy would be needed. + This method uses the ROW_NUMBER() window function partitioned by node_execution_id + to keep only the latest version (highest log_version) of each node execution.
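+
+        Example (illustrative; the run ID and the OrderConfig field values are assumptions):
+
+            executions = repo.get_by_workflow_run(
+                workflow_run_id="run-123",
+                order_config=OrderConfig(order_by=["index"], order_direction="asc"),
+            )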
""" logger.debug("get_by_workflow_run: workflow_run_id=%s, order_config=%s", workflow_run_id, order_config) - # Build SQL query with deduplication using finished_at IS NOT NULL - # This optimization avoids window functions for common case where we only - # want the final state of each node execution + # Build SQL query with deduplication using window function + # ROW_NUMBER() OVER (PARTITION BY node_execution_id ORDER BY log_version DESC) + # ensures we get the latest version of each node execution - # Build ORDER BY clause + # Escape parameters to prevent SQL injection + escaped_workflow_run_id = escape_identifier(workflow_run_id) + escaped_tenant_id = escape_identifier(self._tenant_id) + + # Build ORDER BY clause for outer query order_clause = "" if order_config and order_config.order_by: order_fields = [] @@ -327,16 +350,23 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository): if order_fields: order_clause = "ORDER BY " + ", ".join(order_fields) - sql = f""" - SELECT * - FROM {AliyunLogStore.workflow_node_execution_logstore} - WHERE workflow_run_id='{workflow_run_id}' - AND tenant_id='{self._tenant_id}' - AND finished_at IS NOT NULL - """ - + # Build app_id filter for subquery + app_id_filter = "" if self._app_id: - sql += f" AND app_id='{self._app_id}'" + escaped_app_id = escape_identifier(self._app_id) + app_id_filter = f" AND app_id='{escaped_app_id}'" + + # Use window function to get latest version of each node execution + sql = f""" + SELECT * FROM ( + SELECT *, ROW_NUMBER() OVER (PARTITION BY node_execution_id ORDER BY log_version DESC) AS rn + FROM {AliyunLogStore.workflow_node_execution_logstore} + WHERE workflow_run_id='{escaped_workflow_run_id}' + AND tenant_id='{escaped_tenant_id}' + {app_id_filter} + ) t + WHERE rn = 1 + """ if order_clause: sql += f" {order_clause}" diff --git a/api/extensions/logstore/sql_escape.py b/api/extensions/logstore/sql_escape.py new file mode 100644 index 0000000000..d88d6bd959 --- /dev/null +++ b/api/extensions/logstore/sql_escape.py @@ -0,0 +1,134 @@ +""" +SQL Escape Utility for LogStore Queries + +This module provides escaping utilities to prevent injection attacks in LogStore queries. + +LogStore supports two query modes: +1. PG Protocol Mode: Uses SQL syntax with single quotes for strings +2. SDK Mode: Uses LogStore query syntax (key: value) with double quotes + +Key Security Concerns: +- Prevent tenant A from accessing tenant B's data via injection +- SLS queries are read-only, so we focus on data access control +- Different escaping strategies for SQL vs LogStore query syntax +""" + + +def escape_sql_string(value: str) -> str: + """ + Escape a string value for safe use in SQL queries. + + This function escapes single quotes by doubling them, which is the standard + SQL escaping method. This prevents SQL injection by ensuring that user input + cannot break out of string literals. 
+ + Args: + value: The string value to escape + + Returns: + Escaped string safe for use in SQL queries + + Examples: + >>> escape_sql_string("normal_value") + "normal_value" + >>> escape_sql_string("value' OR '1'='1") + "value'' OR ''1''=''1" + >>> escape_sql_string("tenant's_id") + "tenant''s_id" + + Security: + - Prevents breaking out of string literals + - Stops injection attacks like: ' OR '1'='1 + - Protects against cross-tenant data access + """ + if not value: + return value + + # Escape single quotes by doubling them (standard SQL escaping) + # This prevents breaking out of string literals in SQL queries + return value.replace("'", "''") + + +def escape_identifier(value: str) -> str: + """ + Escape an identifier (tenant_id, app_id, run_id, etc.) for safe SQL use. + + This function is for PG protocol mode (SQL syntax). + For SDK mode, use escape_logstore_query_value() instead. + + Args: + value: The identifier value to escape + + Returns: + Escaped identifier safe for use in SQL queries + + Examples: + >>> escape_identifier("550e8400-e29b-41d4-a716-446655440000") + "550e8400-e29b-41d4-a716-446655440000" + >>> escape_identifier("tenant_id' OR '1'='1") + "tenant_id'' OR ''1''=''1" + + Security: + - Prevents SQL injection via identifiers + - Stops cross-tenant access attempts + - Works for UUIDs, alphanumeric IDs, and similar identifiers + """ + # For identifiers, use the same escaping as strings + # This is simple and effective for preventing injection + return escape_sql_string(value) + + +def escape_logstore_query_value(value: str) -> str: + """ + Escape value for LogStore query syntax (SDK mode). + + LogStore query syntax rules: + 1. Keywords (and/or/not) are case-insensitive + 2. Single quotes are ordinary characters (no special meaning) + 3. Double quotes wrap values: key:"value" + 4. Backslash is the escape character: + - \" for double quote inside value + - \\ for backslash itself + 5. Parentheses can change query structure + + To prevent injection: + - Wrap value in double quotes to treat special chars as literals + - Escape backslashes and double quotes using backslash + + Args: + value: The value to escape for LogStore query syntax + + Returns: + Quoted and escaped value safe for LogStore query syntax (includes the quotes) + + Examples: + >>> escape_logstore_query_value("normal_value") + '"normal_value"' + >>> escape_logstore_query_value("value or field:evil") + '"value or field:evil"' # 'or' and ':' are now literals + >>> escape_logstore_query_value('value"test') + '"value\\"test"' # Internal double quote escaped + >>> escape_logstore_query_value('value\\test') + '"value\\\\test"' # Backslash escaped + + Security: + - Prevents injection via and/or/not keywords + - Prevents injection via colons (:) + - Prevents injection via parentheses + - Protects against cross-tenant data access + + Note: + Escape order is critical: backslash first, then double quotes. + Otherwise, we'd double-escape the escape character itself. 
+ """ + if not value: + return '""' + + # IMPORTANT: Escape backslashes FIRST, then double quotes + # This prevents double-escaping (e.g., " -> \" -> \\" incorrectly) + escaped = value.replace("\\", "\\\\") # \ -> \\ + escaped = escaped.replace('"', '\\"') # " -> \" + + # Wrap in double quotes to treat as literal string + # This prevents and/or/not/:/() from being interpreted as operators + return f'"{escaped}"' diff --git a/api/factories/variable_factory.py b/api/factories/variable_factory.py index 494194369a..3f030ae127 100644 --- a/api/factories/variable_factory.py +++ b/api/factories/variable_factory.py @@ -38,7 +38,7 @@ from core.variables.variables import ( ObjectVariable, SecretVariable, StringVariable, - Variable, + VariableBase, ) from core.workflow.constants import ( CONVERSATION_VARIABLE_NODE_ID, @@ -72,25 +72,25 @@ SEGMENT_TO_VARIABLE_MAP = { } -def build_conversation_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable: +def build_conversation_variable_from_mapping(mapping: Mapping[str, Any], /) -> VariableBase: if not mapping.get("name"): raise VariableError("missing name") return _build_variable_from_mapping(mapping=mapping, selector=[CONVERSATION_VARIABLE_NODE_ID, mapping["name"]]) -def build_environment_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable: +def build_environment_variable_from_mapping(mapping: Mapping[str, Any], /) -> VariableBase: if not mapping.get("name"): raise VariableError("missing name") return _build_variable_from_mapping(mapping=mapping, selector=[ENVIRONMENT_VARIABLE_NODE_ID, mapping["name"]]) -def build_pipeline_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable: +def build_pipeline_variable_from_mapping(mapping: Mapping[str, Any], /) -> VariableBase: if not mapping.get("variable"): raise VariableError("missing variable") return mapping["variable"] -def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequence[str]) -> Variable: +def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequence[str]) -> VariableBase: """ This factory function is used to create the environment variable or the conversation variable, not support the File type. 
@@ -100,7 +100,7 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen if (value := mapping.get("value")) is None: raise VariableError("missing value") - result: Variable + result: VariableBase match value_type: case SegmentType.STRING: result = StringVariable.model_validate(mapping) @@ -134,7 +134,7 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen raise VariableError(f"variable size {result.size} exceeds limit {dify_config.MAX_VARIABLE_SIZE}") if not result.selector: result = result.model_copy(update={"selector": selector}) - return cast(Variable, result) + return cast(VariableBase, result) def build_segment(value: Any, /) -> Segment: @@ -285,8 +285,8 @@ def segment_to_variable( id: str | None = None, name: str | None = None, description: str = "", -) -> Variable: - if isinstance(segment, Variable): +) -> VariableBase: + if isinstance(segment, VariableBase): return segment name = name or selector[-1] id = id or str(uuid4()) @@ -297,7 +297,7 @@ def segment_to_variable( variable_class = SEGMENT_TO_VARIABLE_MAP[segment_type] return cast( - Variable, + VariableBase, variable_class( id=id, name=name, diff --git a/api/fields/workflow_fields.py b/api/fields/workflow_fields.py index d037b0c442..2755f77f61 100644 --- a/api/fields/workflow_fields.py +++ b/api/fields/workflow_fields.py @@ -1,7 +1,7 @@ from flask_restx import fields from core.helper import encrypter -from core.variables import SecretVariable, SegmentType, Variable +from core.variables import SecretVariable, SegmentType, VariableBase from fields.member_fields import simple_account_fields from libs.helper import TimestampField @@ -21,7 +21,7 @@ class EnvironmentVariableField(fields.Raw): "value_type": value.value_type.value, "description": value.description, } - if isinstance(value, Variable): + if isinstance(value, VariableBase): return { "id": value.id, "name": value.name, diff --git a/api/models/workflow.py b/api/models/workflow.py index 072c6100b5..5d92da3fa1 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -1,11 +1,9 @@ -from __future__ import annotations - import json import logging from collections.abc import Generator, Mapping, Sequence from datetime import datetime from enum import StrEnum -from typing import TYPE_CHECKING, Any, Union, cast +from typing import TYPE_CHECKING, Any, Optional, Union, cast from uuid import uuid4 import sqlalchemy as sa @@ -46,7 +44,7 @@ if TYPE_CHECKING: from constants import DEFAULT_FILE_NUMBER_LIMITS, HIDDEN_VALUE from core.helper import encrypter -from core.variables import SecretVariable, Segment, SegmentType, Variable +from core.variables import SecretVariable, Segment, SegmentType, VariableBase from factories import variable_factory from libs import helper @@ -69,7 +67,7 @@ class WorkflowType(StrEnum): RAG_PIPELINE = "rag-pipeline" @classmethod - def value_of(cls, value: str) -> WorkflowType: + def value_of(cls, value: str) -> "WorkflowType": """ Get value of given mode. @@ -82,7 +80,7 @@ class WorkflowType(StrEnum): raise ValueError(f"invalid workflow type value {value}") @classmethod - def from_app_mode(cls, app_mode: Union[str, AppMode]) -> WorkflowType: + def from_app_mode(cls, app_mode: Union[str, "AppMode"]) -> "WorkflowType": """ Get workflow type from app mode. 
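With `from __future__ import annotations` removed above, annotations in models/workflow.py are evaluated eagerly again, which is why self-referencing return types throughout this file are rewritten as quoted strings. A minimal sketch of the pattern:

    class Node:
        # An unquoted `Node` here would raise NameError, because the class name
        # is not bound yet while the class body is being evaluated; the string
        # form is resolved lazily by type checkers and typing.get_type_hints().
        def clone(self) -> "Node":
            return Node()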
@@ -178,12 +176,12 @@ class Workflow(Base): # bug graph: str, features: str, created_by: str, - environment_variables: Sequence[Variable], - conversation_variables: Sequence[Variable], + environment_variables: Sequence[VariableBase], + conversation_variables: Sequence[VariableBase], rag_pipeline_variables: list[dict], marked_name: str = "", marked_comment: str = "", - ) -> Workflow: + ) -> "Workflow": workflow = Workflow() workflow.id = str(uuid4()) workflow.tenant_id = tenant_id @@ -447,7 +445,7 @@ class Workflow(Base): # bug # decrypt secret variables value def decrypt_func( - var: Variable, + var: VariableBase, ) -> StringVariable | IntegerVariable | FloatVariable | SecretVariable: if isinstance(var, SecretVariable): return var.model_copy(update={"value": encrypter.decrypt_token(tenant_id=tenant_id, token=var.value)}) @@ -463,7 +461,7 @@ class Workflow(Base): # bug return decrypted_results @environment_variables.setter - def environment_variables(self, value: Sequence[Variable]): + def environment_variables(self, value: Sequence[VariableBase]): if not value: self._environment_variables = "{}" return @@ -487,7 +485,7 @@ class Workflow(Base): # bug value[i] = origin_variables_dictionary[variable.id].model_copy(update={"name": variable.name}) # encrypt secret variables value - def encrypt_func(var: Variable) -> Variable: + def encrypt_func(var: VariableBase) -> VariableBase: if isinstance(var, SecretVariable): return var.model_copy(update={"value": encrypter.encrypt_token(tenant_id=tenant_id, token=var.value)}) else: @@ -517,7 +515,7 @@ class Workflow(Base): # bug return result @property - def conversation_variables(self) -> Sequence[Variable]: + def conversation_variables(self) -> Sequence[VariableBase]: # TODO: find some way to init `self._conversation_variables` when instance created. 
if self._conversation_variables is None: self._conversation_variables = "{}" @@ -527,7 +525,7 @@ class Workflow(Base): # bug return results @conversation_variables.setter - def conversation_variables(self, value: Sequence[Variable]): + def conversation_variables(self, value: Sequence[VariableBase]): self._conversation_variables = json.dumps( {var.name: var.model_dump() for var in value}, ensure_ascii=False, @@ -622,7 +620,7 @@ class WorkflowRun(Base): finished_at: Mapped[datetime | None] = mapped_column(DateTime) exceptions_count: Mapped[int] = mapped_column(sa.Integer, server_default=sa.text("0"), nullable=True) - pause: Mapped[WorkflowPause | None] = orm.relationship( + pause: Mapped[Optional["WorkflowPause"]] = orm.relationship( "WorkflowPause", primaryjoin="WorkflowRun.id == foreign(WorkflowPause.workflow_run_id)", uselist=False, @@ -692,7 +690,7 @@ class WorkflowRun(Base): } @classmethod - def from_dict(cls, data: dict[str, Any]) -> WorkflowRun: + def from_dict(cls, data: dict[str, Any]) -> "WorkflowRun": return cls( id=data.get("id"), tenant_id=data.get("tenant_id"), @@ -844,7 +842,7 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo created_by: Mapped[str] = mapped_column(StringUUID) finished_at: Mapped[datetime | None] = mapped_column(DateTime) - offload_data: Mapped[list[WorkflowNodeExecutionOffload]] = orm.relationship( + offload_data: Mapped[list["WorkflowNodeExecutionOffload"]] = orm.relationship( "WorkflowNodeExecutionOffload", primaryjoin="WorkflowNodeExecutionModel.id == foreign(WorkflowNodeExecutionOffload.node_execution_id)", uselist=True, @@ -854,13 +852,13 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo @staticmethod def preload_offload_data( - query: Select[tuple[WorkflowNodeExecutionModel]] | orm.Query[WorkflowNodeExecutionModel], + query: Select[tuple["WorkflowNodeExecutionModel"]] | orm.Query["WorkflowNodeExecutionModel"], ): return query.options(orm.selectinload(WorkflowNodeExecutionModel.offload_data)) @staticmethod def preload_offload_data_and_files( - query: Select[tuple[WorkflowNodeExecutionModel]] | orm.Query[WorkflowNodeExecutionModel], + query: Select[tuple["WorkflowNodeExecutionModel"]] | orm.Query["WorkflowNodeExecutionModel"], ): return query.options( orm.selectinload(WorkflowNodeExecutionModel.offload_data).options( @@ -935,7 +933,7 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo ) return extras - def _get_offload_by_type(self, type_: ExecutionOffLoadType) -> WorkflowNodeExecutionOffload | None: + def _get_offload_by_type(self, type_: ExecutionOffLoadType) -> Optional["WorkflowNodeExecutionOffload"]: return next(iter([i for i in self.offload_data if i.type_ == type_]), None) @property @@ -1049,7 +1047,7 @@ class WorkflowNodeExecutionOffload(Base): back_populates="offload_data", ) - file: Mapped[UploadFile | None] = orm.relationship( + file: Mapped[Optional["UploadFile"]] = orm.relationship( foreign_keys=[file_id], lazy="raise", uselist=False, @@ -1067,7 +1065,7 @@ class WorkflowAppLogCreatedFrom(StrEnum): INSTALLED_APP = "installed-app" @classmethod - def value_of(cls, value: str) -> WorkflowAppLogCreatedFrom: + def value_of(cls, value: str) -> "WorkflowAppLogCreatedFrom": """ Get value of given mode. 
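Both Workflow.conversation_variables above and ConversationVariable below round-trip variables through a JSON object keyed by variable name; a sketch of the stored shape (the field values are assumptions based on the setter's model_dump() call):

    import json

    # What the setter writes into the backing column:
    stored = json.dumps(
        {"topic": {"id": "11111111-2222-3333-4444-555555555555", "name": "topic",
                   "value": "demo", "value_type": "string"}},
        ensure_ascii=False,
    )

    # What the getter parses back before rebuilding each variable via the factory:
    mappings = list(json.loads(stored).values())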
@@ -1184,7 +1182,7 @@ class ConversationVariable(TypeBase): ) @classmethod - def from_variable(cls, *, app_id: str, conversation_id: str, variable: Variable) -> ConversationVariable: + def from_variable(cls, *, app_id: str, conversation_id: str, variable: VariableBase) -> "ConversationVariable": obj = cls( id=variable.id, app_id=app_id, @@ -1193,7 +1191,7 @@ class ConversationVariable(TypeBase): ) return obj - def to_variable(self) -> Variable: + def to_variable(self) -> VariableBase: mapping = json.loads(self.data) return variable_factory.build_conversation_variable_from_mapping(mapping) @@ -1337,7 +1335,7 @@ class WorkflowDraftVariable(Base): ) # Relationship to WorkflowDraftVariableFile - variable_file: Mapped[WorkflowDraftVariableFile | None] = orm.relationship( + variable_file: Mapped[Optional["WorkflowDraftVariableFile"]] = orm.relationship( foreign_keys=[file_id], lazy="raise", uselist=False, @@ -1507,7 +1505,7 @@ class WorkflowDraftVariable(Base): node_execution_id: str | None, description: str = "", file_id: str | None = None, - ) -> WorkflowDraftVariable: + ) -> "WorkflowDraftVariable": variable = WorkflowDraftVariable() variable.id = str(uuid4()) variable.created_at = naive_utc_now() @@ -1530,7 +1528,7 @@ class WorkflowDraftVariable(Base): name: str, value: Segment, description: str = "", - ) -> WorkflowDraftVariable: + ) -> "WorkflowDraftVariable": variable = cls._new( app_id=app_id, node_id=CONVERSATION_VARIABLE_NODE_ID, @@ -1551,7 +1549,7 @@ class WorkflowDraftVariable(Base): value: Segment, node_execution_id: str, editable: bool = False, - ) -> WorkflowDraftVariable: + ) -> "WorkflowDraftVariable": variable = cls._new( app_id=app_id, node_id=SYSTEM_VARIABLE_NODE_ID, @@ -1574,7 +1572,7 @@ class WorkflowDraftVariable(Base): visible: bool = True, editable: bool = True, file_id: str | None = None, - ) -> WorkflowDraftVariable: + ) -> "WorkflowDraftVariable": variable = cls._new( app_id=app_id, node_id=node_id, @@ -1670,7 +1668,7 @@ class WorkflowDraftVariableFile(Base): ) # Relationship to UploadFile - upload_file: Mapped[UploadFile] = orm.relationship( + upload_file: Mapped["UploadFile"] = orm.relationship( foreign_keys=[upload_file_id], lazy="raise", uselist=False, @@ -1737,7 +1735,7 @@ class WorkflowPause(DefaultFieldsMixin, Base): state_object_key: Mapped[str] = mapped_column(String(length=255), nullable=False) # Relationship to WorkflowRun - workflow_run: Mapped[WorkflowRun] = orm.relationship( + workflow_run: Mapped["WorkflowRun"] = orm.relationship( foreign_keys=[workflow_run_id], # require explicit preloading. 
lazy="raise", @@ -1793,7 +1791,7 @@ class WorkflowPauseReason(DefaultFieldsMixin, Base): ) @classmethod - def from_entity(cls, pause_reason: PauseReason) -> WorkflowPauseReason: + def from_entity(cls, pause_reason: PauseReason) -> "WorkflowPauseReason": if isinstance(pause_reason, HumanInputRequired): return cls( type_=PauseReasonType.HUMAN_INPUT_REQUIRED, form_id=pause_reason.form_id, node_id=pause_reason.node_id diff --git a/api/pyproject.toml b/api/pyproject.toml index 7d2d68bc8d..28bd591d17 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "dify-api" -version = "1.11.2" +version = "1.11.3" requires-python = ">=3.11,<3.13" dependencies = [ diff --git a/api/services/conversation_variable_updater.py b/api/services/conversation_variable_updater.py index acc0ec2b22..92008d5ff1 100644 --- a/api/services/conversation_variable_updater.py +++ b/api/services/conversation_variable_updater.py @@ -1,7 +1,7 @@ from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker -from core.variables.variables import Variable +from core.variables.variables import VariableBase from models import ConversationVariable @@ -13,7 +13,7 @@ class ConversationVariableUpdater: def __init__(self, session_maker: sessionmaker[Session]) -> None: self._session_maker: sessionmaker[Session] = session_maker - def update(self, conversation_id: str, variable: Variable) -> None: + def update(self, conversation_id: str, variable: VariableBase) -> None: stmt = select(ConversationVariable).where( ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id ) diff --git a/api/services/enterprise/base.py b/api/services/enterprise/base.py index bdc960aa2d..e3832475aa 100644 --- a/api/services/enterprise/base.py +++ b/api/services/enterprise/base.py @@ -1,9 +1,14 @@ +import logging import os from collections.abc import Mapping from typing import Any import httpx +from core.helper.trace_id_helper import generate_traceparent_header + +logger = logging.getLogger(__name__) + class BaseRequest: proxies: Mapping[str, str] | None = { @@ -38,6 +43,15 @@ class BaseRequest: headers = {"Content-Type": "application/json", cls.secret_key_header: cls.secret_key} url = f"{cls.base_url}{endpoint}" mounts = cls._build_mounts() + + try: + # ensure traceparent even when OTEL is disabled + traceparent = generate_traceparent_header() + if traceparent: + headers["traceparent"] = traceparent + except Exception: + logger.debug("Failed to generate traceparent header", exc_info=True) + with httpx.Client(mounts=mounts) as client: response = client.request(method, url, json=json, params=params, headers=headers) return response.json() diff --git a/api/services/plugin/plugin_service.py b/api/services/plugin/plugin_service.py index b8303eb724..411c335c17 100644 --- a/api/services/plugin/plugin_service.py +++ b/api/services/plugin/plugin_service.py @@ -3,6 +3,7 @@ from collections.abc import Mapping, Sequence from mimetypes import guess_type from pydantic import BaseModel +from sqlalchemy import select from yarl import URL from configs import dify_config @@ -25,7 +26,9 @@ from core.plugin.entities.plugin_daemon import ( from core.plugin.impl.asset import PluginAssetManager from core.plugin.impl.debugging import PluginDebuggingClient from core.plugin.impl.plugin import PluginInstaller +from extensions.ext_database import db from extensions.ext_redis import redis_client +from models.provider import ProviderCredential from models.provider_ids import GenericProviderID from 
services.errors.plugin import PluginInstallationForbiddenError from services.feature_service import FeatureService, PluginInstallationScope @@ -506,6 +509,33 @@ class PluginService: @staticmethod def uninstall(tenant_id: str, plugin_installation_id: str) -> bool: manager = PluginInstaller() + + # Get plugin info before uninstalling to delete associated credentials + try: + plugins = manager.list_plugins(tenant_id) + plugin = next((p for p in plugins if p.installation_id == plugin_installation_id), None) + + if plugin: + plugin_id = plugin.plugin_id + logger.info("Deleting credentials for plugin: %s", plugin_id) + + # Delete provider credentials that match this plugin + credentials = db.session.scalars( + select(ProviderCredential).where( + ProviderCredential.tenant_id == tenant_id, + ProviderCredential.provider_name.like(f"{plugin_id}/%"), + ) + ).all() + + for cred in credentials: + db.session.delete(cred) + + db.session.commit() + logger.info("Deleted %d credentials for plugin: %s", len(credentials), plugin_id) + except Exception as e: + logger.warning("Failed to delete credentials: %s", e) + # Continue with uninstall even if credential deletion fails + return manager.uninstall(tenant_id, plugin_installation_id) @staticmethod diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index 1ba64813ba..2d8418900c 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -36,7 +36,7 @@ from core.rag.entities.event import ( ) from core.repositories.factory import DifyCoreRepositoryFactory from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository -from core.variables.variables import Variable +from core.variables.variables import VariableBase from core.workflow.entities.workflow_node_execution import ( WorkflowNodeExecution, WorkflowNodeExecutionStatus, @@ -270,8 +270,8 @@ class RagPipelineService: graph: dict, unique_hash: str | None, account: Account, - environment_variables: Sequence[Variable], - conversation_variables: Sequence[Variable], + environment_variables: Sequence[VariableBase], + conversation_variables: Sequence[VariableBase], rag_pipeline_variables: list, ) -> Workflow: """ diff --git a/api/services/workflow_draft_variable_service.py b/api/services/workflow_draft_variable_service.py index 9407a2b3f0..70b0190231 100644 --- a/api/services/workflow_draft_variable_service.py +++ b/api/services/workflow_draft_variable_service.py @@ -15,7 +15,7 @@ from sqlalchemy.sql.expression import and_, or_ from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom from core.file.models import File -from core.variables import Segment, StringSegment, Variable +from core.variables import Segment, StringSegment, VariableBase from core.variables.consts import SELECTORS_LENGTH from core.variables.segments import ( ArrayFileSegment, @@ -77,14 +77,14 @@ class DraftVarLoader(VariableLoader): # Application ID for which variables are being loaded. 
_app_id: str _tenant_id: str - _fallback_variables: Sequence[Variable] + _fallback_variables: Sequence[VariableBase] def __init__( self, engine: Engine, app_id: str, tenant_id: str, - fallback_variables: Sequence[Variable] | None = None, + fallback_variables: Sequence[VariableBase] | None = None, ): self._engine = engine self._app_id = app_id @@ -94,12 +94,12 @@ class DraftVarLoader(VariableLoader): def _selector_to_tuple(self, selector: Sequence[str]) -> tuple[str, str]: return (selector[0], selector[1]) - def load_variables(self, selectors: list[list[str]]) -> list[Variable]: + def load_variables(self, selectors: list[list[str]]) -> list[VariableBase]: if not selectors: return [] - # Map each selector (as a tuple via `_selector_to_tuple`) to its corresponding Variable instance. - variable_by_selector: dict[tuple[str, str], Variable] = {} + # Map each selector (as a tuple via `_selector_to_tuple`) to its corresponding variable instance. + variable_by_selector: dict[tuple[str, str], VariableBase] = {} with Session(bind=self._engine, expire_on_commit=False) as session: srv = WorkflowDraftVariableService(session) @@ -145,7 +145,7 @@ class DraftVarLoader(VariableLoader): return list(variable_by_selector.values()) - def _load_offloaded_variable(self, draft_var: WorkflowDraftVariable) -> tuple[tuple[str, str], Variable]: + def _load_offloaded_variable(self, draft_var: WorkflowDraftVariable) -> tuple[tuple[str, str], VariableBase]: # This logic is closely tied to `WorkflowDraftVariableService._try_offload_large_variable` # and must remain synchronized with it. # Ideally, these should be co-located for better maintainability. diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index b45a167b73..d8c3159178 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -13,8 +13,8 @@ from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from core.file import File from core.repositories import DifyCoreRepositoryFactory -from core.variables import Variable -from core.variables.variables import VariableUnion +from core.variables import VariableBase +from core.variables.variables import Variable from core.workflow.entities import WorkflowNodeExecution from core.workflow.enums import ErrorStrategy, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from core.workflow.errors import WorkflowNodeRunFailedError @@ -198,8 +198,8 @@ class WorkflowService: features: dict, unique_hash: str | None, account: Account, - environment_variables: Sequence[Variable], - conversation_variables: Sequence[Variable], + environment_variables: Sequence[VariableBase], + conversation_variables: Sequence[VariableBase], ) -> Workflow: """ Sync draft workflow @@ -1044,7 +1044,7 @@ def _setup_variable_pool( workflow: Workflow, node_type: NodeType, conversation_id: str, - conversation_variables: list[Variable], + conversation_variables: list[VariableBase], ): # Only inject system variables for START node type. if node_type == NodeType.START or node_type.is_trigger_node: @@ -1070,9 +1070,9 @@ def _setup_variable_pool( system_variables=system_variable, user_inputs=user_inputs, environment_variables=workflow.environment_variables, - # Based on the definition of `VariableUnion`, - # `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
- conversation_variables=cast(list[VariableUnion], conversation_variables), # + # Based on the definition of `Variable`, + # `VariableBase` instances can be safely used as `Variable` since they are compatible. + conversation_variables=cast(list[Variable], conversation_variables), # ) return variable_pool diff --git a/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py b/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py index 5d180c7cbc..cd45292488 100644 --- a/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py +++ b/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py @@ -228,11 +228,28 @@ def test_resolve_user_from_database_falls_back_to_end_user(monkeypatch: pytest.M def scalar(self, _stmt): return self.results.pop(0) + # SQLAlchemy Session APIs used by code under test + def expunge(self, *_args, **_kwargs): + pass + + def close(self): + pass + + # support `with session_factory.create_session() as session:` + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + self.close() + tenant = SimpleNamespace(id="tenant_id") end_user = SimpleNamespace(id="end_user_id", tenant_id="tenant_id") - db_stub = SimpleNamespace(session=StubSession([tenant, None, end_user])) - monkeypatch.setattr("core.tools.workflow_as_tool.tool.db", db_stub) + # Monkeypatch session factory to return our stub session + monkeypatch.setattr( + "core.tools.workflow_as_tool.tool.session_factory.create_session", + lambda: StubSession([tenant, None, end_user]), + ) entity = ToolEntity( identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"), @@ -266,8 +283,23 @@ def test_resolve_user_from_database_returns_none_when_no_tenant(monkeypatch: pyt def scalar(self, _stmt): return self.results.pop(0) - db_stub = SimpleNamespace(session=StubSession([None])) - monkeypatch.setattr("core.tools.workflow_as_tool.tool.db", db_stub) + def expunge(self, *_args, **_kwargs): + pass + + def close(self): + pass + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + self.close() + + # Monkeypatch session factory to return our stub session with no tenant + monkeypatch.setattr( + "core.tools.workflow_as_tool.tool.session_factory.create_session", + lambda: StubSession([None]), + ) entity = ToolEntity( identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"), diff --git a/api/tests/unit_tests/core/variables/test_segment.py b/api/tests/unit_tests/core/variables/test_segment.py index af4f96ba23..aa16c8af1c 100644 --- a/api/tests/unit_tests/core/variables/test_segment.py +++ b/api/tests/unit_tests/core/variables/test_segment.py @@ -35,7 +35,6 @@ from core.variables.variables import ( SecretVariable, StringVariable, Variable, - VariableUnion, ) from core.workflow.runtime import VariablePool from core.workflow.system_variable import SystemVariable @@ -96,7 +95,7 @@ class _Segments(BaseModel): class _Variables(BaseModel): - variables: list[VariableUnion] + variables: list[Variable] def create_test_file( @@ -194,7 +193,7 @@ class TestSegmentDumpAndLoad: # Create one instance of each variable type test_file = create_test_file() - all_variables: list[VariableUnion] = [ + all_variables: list[Variable] = [ NoneVariable(name="none_var"), StringVariable(value="test string", name="string_var"), IntegerVariable(value=42, name="int_var"), diff --git a/api/tests/unit_tests/core/variables/test_variables.py b/api/tests/unit_tests/core/variables/test_variables.py index 
925142892c..fb4b18b57a 100644 --- a/api/tests/unit_tests/core/variables/test_variables.py +++ b/api/tests/unit_tests/core/variables/test_variables.py @@ -11,7 +11,7 @@ from core.variables import ( SegmentType, StringVariable, ) -from core.variables.variables import Variable +from core.variables.variables import VariableBase def test_frozen_variables(): @@ -76,7 +76,7 @@ def test_object_variable_to_object(): def test_variable_to_object(): - var: Variable = StringVariable(name="text", value="text") + var: VariableBase = StringVariable(name="text", value="text") assert var.to_object() == "text" var = IntegerVariable(name="integer", value=42) assert var.to_object() == 42 diff --git a/api/tests/unit_tests/core/workflow/test_variable_pool.py b/api/tests/unit_tests/core/workflow/test_variable_pool.py index 9733bf60eb..b8869dbf1d 100644 --- a/api/tests/unit_tests/core/workflow/test_variable_pool.py +++ b/api/tests/unit_tests/core/workflow/test_variable_pool.py @@ -24,7 +24,7 @@ from core.variables.variables import ( IntegerVariable, ObjectVariable, StringVariable, - VariableUnion, + Variable, ) from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID from core.workflow.runtime import VariablePool @@ -160,7 +160,7 @@ class TestVariablePoolSerialization: ) # Create environment variables with all types including ArrayFileVariable - env_vars: list[VariableUnion] = [ + env_vars: list[Variable] = [ StringVariable( id="env_string_id", name="env_string", @@ -182,7 +182,7 @@ class TestVariablePoolSerialization: ] # Create conversation variables with complex data - conv_vars: list[VariableUnion] = [ + conv_vars: list[Variable] = [ StringVariable( id="conv_string_id", name="conv_string", diff --git a/api/tests/unit_tests/extensions/logstore/__init__.py b/api/tests/unit_tests/extensions/logstore/__init__.py new file mode 100644 index 0000000000..fe9ada9128 --- /dev/null +++ b/api/tests/unit_tests/extensions/logstore/__init__.py @@ -0,0 +1 @@ +"""LogStore extension unit tests.""" diff --git a/api/tests/unit_tests/extensions/logstore/test_sql_escape.py b/api/tests/unit_tests/extensions/logstore/test_sql_escape.py new file mode 100644 index 0000000000..63172b3f9b --- /dev/null +++ b/api/tests/unit_tests/extensions/logstore/test_sql_escape.py @@ -0,0 +1,469 @@ +""" +Unit tests for SQL escape utility functions. + +These tests ensure that SQL injection attacks are properly prevented +in LogStore queries, particularly for cross-tenant access scenarios. 
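+
+Run them directly with pytest, e.g. from the api/ directory:
+
+    pytest tests/unit_tests/extensions/logstore/test_sql_escape.py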
+""" + +import pytest + +from extensions.logstore.sql_escape import escape_identifier, escape_logstore_query_value, escape_sql_string + + +class TestEscapeSQLString: + """Test escape_sql_string function.""" + + def test_escape_empty_string(self): + """Test escaping empty string.""" + assert escape_sql_string("") == "" + + def test_escape_normal_string(self): + """Test escaping string without special characters.""" + assert escape_sql_string("tenant_abc123") == "tenant_abc123" + assert escape_sql_string("app-uuid-1234") == "app-uuid-1234" + + def test_escape_single_quote(self): + """Test escaping single quote.""" + # Single quote should be doubled + assert escape_sql_string("tenant'id") == "tenant''id" + assert escape_sql_string("O'Reilly") == "O''Reilly" + + def test_escape_multiple_quotes(self): + """Test escaping multiple single quotes.""" + assert escape_sql_string("a'b'c") == "a''b''c" + assert escape_sql_string("'''") == "''''''" + + # === SQL Injection Attack Scenarios === + + def test_prevent_boolean_injection(self): + """Test prevention of boolean injection attacks.""" + # Classic OR 1=1 attack + malicious_input = "tenant' OR '1'='1" + escaped = escape_sql_string(malicious_input) + assert escaped == "tenant'' OR ''1''=''1" + + # When used in SQL, this becomes a safe string literal + sql = f"WHERE tenant_id='{escaped}'" + assert sql == "WHERE tenant_id='tenant'' OR ''1''=''1'" + # The entire input is now a string literal that won't match any tenant + + def test_prevent_or_injection(self): + """Test prevention of OR-based injection.""" + malicious_input = "tenant_a' OR tenant_id='tenant_b" + escaped = escape_sql_string(malicious_input) + assert escaped == "tenant_a'' OR tenant_id=''tenant_b" + + sql = f"WHERE tenant_id='{escaped}'" + # The OR is now part of the string literal, not SQL logic + assert "OR tenant_id=" in sql + # The SQL has: opening ', doubled internal quotes '', and closing ' + assert sql == "WHERE tenant_id='tenant_a'' OR tenant_id=''tenant_b'" + + def test_prevent_union_injection(self): + """Test prevention of UNION-based injection.""" + malicious_input = "xxx' UNION SELECT password FROM users WHERE '1'='1" + escaped = escape_sql_string(malicious_input) + assert escaped == "xxx'' UNION SELECT password FROM users WHERE ''1''=''1" + + # UNION becomes part of the string literal + assert "UNION" in escaped + assert escaped.count("''") == 4 # All internal quotes are doubled + + def test_prevent_comment_injection(self): + """Test prevention of comment-based injection.""" + # SQL comment to bypass remaining conditions + malicious_input = "tenant' --" + escaped = escape_sql_string(malicious_input) + assert escaped == "tenant'' --" + + sql = f"WHERE tenant_id='{escaped}' AND deleted=false" + # The -- is now inside the string, not a SQL comment + assert "--" in sql + assert "AND deleted=false" in sql # This part is NOT commented out + + def test_prevent_semicolon_injection(self): + """Test prevention of semicolon-based multi-statement injection.""" + malicious_input = "tenant'; DROP TABLE users; --" + escaped = escape_sql_string(malicious_input) + assert escaped == "tenant''; DROP TABLE users; --" + + # Semicolons and DROP are now part of the string + assert "DROP TABLE" in escaped + + def test_prevent_time_based_blind_injection(self): + """Test prevention of time-based blind SQL injection.""" + malicious_input = "tenant' AND SLEEP(5) --" + escaped = escape_sql_string(malicious_input) + assert escaped == "tenant'' AND SLEEP(5) --" + + # SLEEP becomes part of the string + 
assert "SLEEP" in escaped + + def test_prevent_wildcard_injection(self): + """Test prevention of wildcard-based injection.""" + malicious_input = "tenant' OR tenant_id LIKE '%" + escaped = escape_sql_string(malicious_input) + assert escaped == "tenant'' OR tenant_id LIKE ''%" + + # The LIKE and wildcard are now part of the string + assert "LIKE" in escaped + + def test_prevent_null_byte_injection(self): + """Test handling of null bytes.""" + # Null bytes can sometimes bypass filters + malicious_input = "tenant\x00' OR '1'='1" + escaped = escape_sql_string(malicious_input) + # Null byte is preserved, but quote is escaped + assert "''1''=''1" in escaped + + # === Real-world SAAS Scenarios === + + def test_cross_tenant_access_attempt(self): + """Test prevention of cross-tenant data access.""" + # Attacker tries to access another tenant's data + attacker_input = "tenant_b' OR tenant_id='tenant_a" + escaped = escape_sql_string(attacker_input) + + sql = f"SELECT * FROM workflow_runs WHERE tenant_id='{escaped}'" + # The query will look for a tenant literally named "tenant_b' OR tenant_id='tenant_a" + # which doesn't exist - preventing access to either tenant's data + assert "tenant_b'' OR tenant_id=''tenant_a" in sql + + def test_cross_app_access_attempt(self): + """Test prevention of cross-application data access.""" + attacker_input = "app1' OR app_id='app2" + escaped = escape_sql_string(attacker_input) + + sql = f"WHERE app_id='{escaped}'" + # Cannot access app2's data + assert "app1'' OR app_id=''app2" in sql + + def test_bypass_status_filter(self): + """Test prevention of bypassing status filters.""" + # Try to see all statuses instead of just 'running' + attacker_input = "running' OR status LIKE '%" + escaped = escape_sql_string(attacker_input) + + sql = f"WHERE status='{escaped}'" + # Status condition is not bypassed + assert "running'' OR status LIKE ''%" in sql + + # === Edge Cases === + + def test_escape_only_quotes(self): + """Test string with only quotes.""" + assert escape_sql_string("'") == "''" + assert escape_sql_string("''") == "''''" + + def test_escape_mixed_content(self): + """Test string with mixed quotes and other chars.""" + input_str = "It's a 'test' of O'Reilly's code" + escaped = escape_sql_string(input_str) + assert escaped == "It''s a ''test'' of O''Reilly''s code" + + def test_escape_unicode_with_quotes(self): + """Test Unicode strings with quotes.""" + input_str = "租户' OR '1'='1" + escaped = escape_sql_string(input_str) + assert escaped == "租户'' OR ''1''=''1" + + +class TestEscapeIdentifier: + """Test escape_identifier function.""" + + def test_escape_uuid(self): + """Test escaping UUID identifiers.""" + uuid = "550e8400-e29b-41d4-a716-446655440000" + assert escape_identifier(uuid) == uuid + + def test_escape_alphanumeric_id(self): + """Test escaping alphanumeric identifiers.""" + assert escape_identifier("tenant_123") == "tenant_123" + assert escape_identifier("app-abc-123") == "app-abc-123" + + def test_escape_identifier_with_quote(self): + """Test escaping identifier with single quote.""" + malicious = "tenant' OR '1'='1" + escaped = escape_identifier(malicious) + assert escaped == "tenant'' OR ''1''=''1" + + def test_identifier_injection_attempt(self): + """Test prevention of injection through identifiers.""" + # Common identifier injection patterns + test_cases = [ + ("id' OR '1'='1", "id'' OR ''1''=''1"), + ("id'; DROP TABLE", "id''; DROP TABLE"), + ("id' UNION SELECT", "id'' UNION SELECT"), + ] + + for malicious, expected in test_cases: + assert 
escape_identifier(malicious) == expected + + +class TestSQLInjectionIntegration: + """Integration tests simulating real SQL construction scenarios.""" + + def test_complete_where_clause_safety(self): + """Test that a complete WHERE clause is safe from injection.""" + # Simulating typical query construction + tenant_id = "tenant' OR '1'='1" + app_id = "app' UNION SELECT" + run_id = "run' --" + + escaped_tenant = escape_identifier(tenant_id) + escaped_app = escape_identifier(app_id) + escaped_run = escape_identifier(run_id) + + sql = f""" + SELECT * FROM workflow_runs + WHERE tenant_id='{escaped_tenant}' + AND app_id='{escaped_app}' + AND id='{escaped_run}' + """ + + # Verify all special characters are escaped + assert "tenant'' OR ''1''=''1" in sql + assert "app'' UNION SELECT" in sql + assert "run'' --" in sql + + # Verify SQL structure is preserved (3 conditions with AND) + assert sql.count("AND") == 2 + + def test_multiple_conditions_with_injection_attempts(self): + """Test multiple conditions all attempting injection.""" + conditions = { + "tenant_id": "t1' OR tenant_id='t2", + "app_id": "a1' OR app_id='a2", + "status": "running' OR '1'='1", + } + + where_parts = [] + for field, value in conditions.items(): + escaped = escape_sql_string(value) + where_parts.append(f"{field}='{escaped}'") + + where_clause = " AND ".join(where_parts) + + # All injection attempts are neutralized + assert "t1'' OR tenant_id=''t2" in where_clause + assert "a1'' OR app_id=''a2" in where_clause + assert "running'' OR ''1''=''1" in where_clause + + # AND structure is preserved + assert where_clause.count(" AND ") == 2 + + @pytest.mark.parametrize( + ("attack_vector", "description"), + [ + ("' OR '1'='1", "Boolean injection"), + ("' OR '1'='1' --", "Boolean with comment"), + ("' UNION SELECT * FROM users --", "Union injection"), + ("'; DROP TABLE workflow_runs; --", "Destructive command"), + ("' AND SLEEP(10) --", "Time-based blind"), + ("' OR tenant_id LIKE '%", "Wildcard injection"), + ("admin' --", "Comment bypass"), + ("' OR 1=1 LIMIT 1 --", "Limit bypass"), + ], + ) + def test_common_injection_vectors(self, attack_vector, description): + """Test protection against common injection attack vectors.""" + escaped = escape_sql_string(attack_vector) + + # Build SQL + sql = f"WHERE tenant_id='{escaped}'" + + # Verify the attack string is now a safe literal + # The key indicator: all internal single quotes are doubled + internal_quotes = escaped.count("''") + original_quotes = attack_vector.count("'") + + # Each original quote should be doubled + assert internal_quotes == original_quotes + + # Verify SQL has exactly 2 quotes (opening and closing) + assert sql.count("'") >= 2 # At least opening and closing + + def test_logstore_specific_scenario(self): + """Test SQL injection prevention in LogStore-specific scenarios.""" + # Simulate LogStore query with window function + tenant_id = "tenant' OR '1'='1" + app_id = "app' UNION SELECT" + + escaped_tenant = escape_identifier(tenant_id) + escaped_app = escape_identifier(app_id) + + sql = f""" + SELECT * FROM ( + SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn + FROM workflow_execution_logstore + WHERE tenant_id='{escaped_tenant}' + AND app_id='{escaped_app}' + AND __time__ > 0 + ) AS subquery WHERE rn = 1 + """ + + # Complex query structure is maintained + assert "ROW_NUMBER()" in sql + assert "PARTITION BY id" in sql + + # Injection attempts are escaped + assert "tenant'' OR ''1''=''1" in sql + assert "app'' UNION SELECT" in sql + + +# 
==================================================================================== +# Tests for LogStore Query Syntax (SDK Mode) +# ==================================================================================== + + +class TestLogStoreQueryEscape: + """Test escape_logstore_query_value for SDK mode query syntax.""" + + def test_normal_value(self): + """Test escaping normal alphanumeric value.""" + value = "550e8400-e29b-41d4-a716-446655440000" + escaped = escape_logstore_query_value(value) + + # Should be wrapped in double quotes + assert escaped == '"550e8400-e29b-41d4-a716-446655440000"' + + def test_empty_value(self): + """Test escaping empty string.""" + assert escape_logstore_query_value("") == '""' + + def test_value_with_and_keyword(self): + """Test that 'and' keyword is neutralized when quoted.""" + malicious = "value and field:evil" + escaped = escape_logstore_query_value(malicious) + + # Should be wrapped in quotes, making 'and' a literal + assert escaped == '"value and field:evil"' + + # Simulate using in query + query = f"tenant_id:{escaped}" + assert query == 'tenant_id:"value and field:evil"' + + def test_value_with_or_keyword(self): + """Test that 'or' keyword is neutralized when quoted.""" + malicious = "tenant_a or tenant_id:tenant_b" + escaped = escape_logstore_query_value(malicious) + + assert escaped == '"tenant_a or tenant_id:tenant_b"' + + query = f"tenant_id:{escaped}" + assert "or" in query # Present but as literal string + + def test_value_with_not_keyword(self): + """Test that 'not' keyword is neutralized when quoted.""" + malicious = "not field:value" + escaped = escape_logstore_query_value(malicious) + + assert escaped == '"not field:value"' + + def test_value_with_parentheses(self): + """Test that parentheses are neutralized when quoted.""" + malicious = "(tenant_a or tenant_b)" + escaped = escape_logstore_query_value(malicious) + + assert escaped == '"(tenant_a or tenant_b)"' + assert "(" in escaped # Present as literal + assert ")" in escaped # Present as literal + + def test_value_with_colon(self): + """Test that colons are neutralized when quoted.""" + malicious = "field:value" + escaped = escape_logstore_query_value(malicious) + + assert escaped == '"field:value"' + assert ":" in escaped # Present as literal + + def test_value_with_double_quotes(self): + """Test that internal double quotes are escaped.""" + value_with_quotes = 'tenant"test"value' + escaped = escape_logstore_query_value(value_with_quotes) + + # Double quotes should be escaped with backslash + assert escaped == '"tenant\\"test\\"value"' + # Should have outer quotes plus escaped inner quotes + assert '\\"' in escaped + + def test_value_with_backslash(self): + """Test that backslashes are escaped.""" + value_with_backslash = "tenant\\test" + escaped = escape_logstore_query_value(value_with_backslash) + + # Backslash should be escaped + assert escaped == '"tenant\\\\test"' + assert "\\\\" in escaped + + def test_value_with_backslash_and_quote(self): + """Test escaping both backslash and double quote.""" + value = 'path\\to\\"file"' + escaped = escape_logstore_query_value(value) + + # Both should be escaped + assert escaped == '"path\\\\to\\\\\\"file\\""' + # Verify escape order is correct + assert "\\\\" in escaped # Escaped backslash + assert '\\"' in escaped # Escaped double quote + + def test_complex_injection_attempt(self): + """Test complex injection combining multiple operators.""" + malicious = 'tenant_a" or (tenant_id:"tenant_b" and app_id:"evil")' + escaped = 
escape_logstore_query_value(malicious) + + # All special chars should be literals or escaped + assert escaped.startswith('"') + assert escaped.endswith('"') + # Inner double quotes escaped, operators become literals + assert "or" in escaped + assert "and" in escaped + assert '\\"' in escaped # Escaped quotes + + def test_only_backslash(self): + """Test escaping a single backslash.""" + assert escape_logstore_query_value("\\") == '"\\\\"' + + def test_only_double_quote(self): + """Test escaping a single double quote.""" + assert escape_logstore_query_value('"') == '"\\""' + + def test_multiple_backslashes(self): + """Test escaping multiple consecutive backslashes.""" + assert escape_logstore_query_value("\\\\\\") == '"\\\\\\\\\\\\"' # 3 backslashes -> 6 + + def test_escape_sequence_like_input(self): + """Test that existing escape sequences are properly escaped.""" + # Input looks like already escaped, but we still escape it + value = 'value\\"test' + escaped = escape_logstore_query_value(value) + # \\ -> \\\\, " -> \" + assert escaped == '"value\\\\\\"test"' + + +@pytest.mark.parametrize( + ("attack_scenario", "field", "malicious_value"), + [ + ("Cross-tenant via OR", "tenant_id", "tenant_a or tenant_id:tenant_b"), + ("Cross-app via AND", "app_id", "app_a and (app_id:app_b or app_id:app_c)"), + ("Boolean logic", "status", "succeeded or status:failed"), + ("Negation", "tenant_id", "not tenant_a"), + ("Field injection", "run_id", "run123 and tenant_id:evil_tenant"), + ("Parentheses grouping", "app_id", "app1 or (app_id:app2 and tenant_id:tenant2)"), + ("Quote breaking attempt", "tenant_id", 'tenant" or "1"="1'), + ("Backslash escape bypass", "app_id", "app\\ and app_id:evil"), + ], +) +def test_logstore_query_injection_scenarios(attack_scenario: str, field: str, malicious_value: str): + """Test that various LogStore query injection attempts are neutralized.""" + escaped = escape_logstore_query_value(malicious_value) + + # Build query + query = f"{field}:{escaped}" + + # All operators should be within quoted string (literals) + assert escaped.startswith('"') + assert escaped.endswith('"') + + # Verify the full query structure is safe + assert query.count(":") >= 1 # At least the main field:value separator diff --git a/api/tests/unit_tests/services/enterprise/__init__.py b/api/tests/unit_tests/services/enterprise/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/services/enterprise/test_traceparent_propagation.py b/api/tests/unit_tests/services/enterprise/test_traceparent_propagation.py new file mode 100644 index 0000000000..87c03f13a3 --- /dev/null +++ b/api/tests/unit_tests/services/enterprise/test_traceparent_propagation.py @@ -0,0 +1,59 @@ +"""Unit tests for traceparent header propagation in EnterpriseRequest. + +This test module verifies that the W3C traceparent header is properly +generated and included in HTTP requests made by EnterpriseRequest. 
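+
+For reference, a W3C traceparent value has the form
+<version>-<trace-id>-<parent-id>-<trace-flags> (2, 32, 16, and 2 lowercase hex
+characters respectively), e.g. 00-5b8aa5a2d2c872e8321cf37308d69df2-051581bf3bb55c45-01.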
+""" + +from unittest.mock import MagicMock, patch + +import pytest + +from services.enterprise.base import EnterpriseRequest + + +class TestTraceparentPropagation: + """Unit tests for traceparent header propagation.""" + + @pytest.fixture + def mock_enterprise_config(self): + """Mock EnterpriseRequest configuration.""" + with ( + patch.object(EnterpriseRequest, "base_url", "https://enterprise-api.example.com"), + patch.object(EnterpriseRequest, "secret_key", "test-secret-key"), + patch.object(EnterpriseRequest, "secret_key_header", "Enterprise-Api-Secret-Key"), + ): + yield + + @pytest.fixture + def mock_httpx_client(self): + """Mock httpx.Client for testing.""" + with patch("services.enterprise.base.httpx.Client") as mock_client_class: + mock_client = MagicMock() + mock_client_class.return_value.__enter__.return_value = mock_client + mock_client_class.return_value.__exit__.return_value = None + + # Setup default response + mock_response = MagicMock() + mock_response.json.return_value = {"result": "success"} + mock_client.request.return_value = mock_response + + yield mock_client + + def test_traceparent_header_included_when_generated(self, mock_enterprise_config, mock_httpx_client): + """Test that traceparent header is included when successfully generated.""" + # Arrange + expected_traceparent = "00-5b8aa5a2d2c872e8321cf37308d69df2-051581bf3bb55c45-01" + + with patch("services.enterprise.base.generate_traceparent_header", return_value=expected_traceparent): + # Act + EnterpriseRequest.send_request("GET", "/test") + + # Assert + mock_httpx_client.request.assert_called_once() + call_args = mock_httpx_client.request.call_args + headers = call_args[1]["headers"] + + assert "traceparent" in headers + assert headers["traceparent"] == expected_traceparent + assert headers["Content-Type"] == "application/json" + assert headers["Enterprise-Api-Secret-Key"] == "test-secret-key" diff --git a/api/uv.lock b/api/uv.lock index a999c4ee18..aacf408902 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -453,15 +453,15 @@ wheels = [ [[package]] name = "azure-core" -version = "1.36.0" +version = "1.38.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "requests" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/0a/c4/d4ff3bc3ddf155156460bff340bbe9533f99fac54ddea165f35a8619f162/azure_core-1.36.0.tar.gz", hash = "sha256:22e5605e6d0bf1d229726af56d9e92bc37b6e726b141a18be0b4d424131741b7", size = 351139, upload-time = "2025-10-15T00:33:49.083Z" } +sdist = { url = "https://files.pythonhosted.org/packages/dc/1b/e503e08e755ea94e7d3419c9242315f888fc664211c90d032e40479022bf/azure_core-1.38.0.tar.gz", hash = "sha256:8194d2682245a3e4e3151a667c686464c3786fed7918b394d035bdcd61bb5993", size = 363033, upload-time = "2026-01-12T17:03:05.535Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/b1/3c/b90d5afc2e47c4a45f4bba00f9c3193b0417fad5ad3bb07869f9d12832aa/azure_core-1.36.0-py3-none-any.whl", hash = "sha256:fee9923a3a753e94a259563429f3644aaf05c486d45b1215d098115102d91d3b", size = 213302, upload-time = "2025-10-15T00:33:51.058Z" }, + { url = "https://files.pythonhosted.org/packages/fc/d8/b8fcba9464f02b121f39de2db2bf57f0b216fe11d014513d666e8634380d/azure_core-1.38.0-py3-none-any.whl", hash = "sha256:ab0c9b2cd71fecb1842d52c965c95285d3cfb38902f6766e4a471f1cd8905335", size = 217825, upload-time = "2026-01-12T17:03:07.291Z" }, ] [[package]] @@ -1368,7 +1368,7 @@ wheels = [ [[package]] name = "dify-api" -version = "1.11.2" +version = "1.11.3" source 
= { virtual = "." } dependencies = [ { name = "aliyun-log-python-sdk" }, @@ -1965,11 +1965,11 @@ wheels = [ [[package]] name = "filelock" -version = "3.20.0" +version = "3.20.3" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/58/46/0028a82567109b5ef6e4d2a1f04a583fb513e6cf9527fcdd09afd817deeb/filelock-3.20.0.tar.gz", hash = "sha256:711e943b4ec6be42e1d4e6690b48dc175c822967466bb31c0c293f34334c13f4", size = 18922, upload-time = "2025-10-08T18:03:50.056Z" } +sdist = { url = "https://files.pythonhosted.org/packages/1d/65/ce7f1b70157833bf3cb851b556a37d4547ceafc158aa9b34b36782f23696/filelock-3.20.3.tar.gz", hash = "sha256:18c57ee915c7ec61cff0ecf7f0f869936c7c30191bb0cf406f1341778d0834e1", size = 19485, upload-time = "2026-01-09T17:55:05.421Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/76/91/7216b27286936c16f5b4d0c530087e4a54eead683e6b0b73dd0c64844af6/filelock-3.20.0-py3-none-any.whl", hash = "sha256:339b4732ffda5cd79b13f4e2711a31b0365ce445d95d243bb996273d072546a2", size = 16054, upload-time = "2025-10-08T18:03:48.35Z" }, + { url = "https://files.pythonhosted.org/packages/b5/36/7fb70f04bf00bc646cd5bb45aa9eddb15e19437a28b8fb2b4a5249fac770/filelock-3.20.3-py3-none-any.whl", hash = "sha256:4b0dda527ee31078689fc205ec4f1c1bf7d56cf88b6dc9426c4f230e46c2dce1", size = 16701, upload-time = "2026-01-09T17:55:04.334Z" }, ] [[package]] diff --git a/docker/.env.example b/docker/.env.example index e7cb8711ce..9a3a7239c6 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -1037,18 +1037,26 @@ WORKFLOW_NODE_EXECUTION_STORAGE=rdbms # Options: # - core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository (default) # - core.repositories.celery_workflow_execution_repository.CeleryWorkflowExecutionRepository +# - extensions.logstore.repositories.logstore_workflow_execution_repository.LogstoreWorkflowExecutionRepository CORE_WORKFLOW_EXECUTION_REPOSITORY=core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository # Core workflow node execution repository implementation # Options: # - core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository (default) # - core.repositories.celery_workflow_node_execution_repository.CeleryWorkflowNodeExecutionRepository +# - extensions.logstore.repositories.logstore_workflow_node_execution_repository.LogstoreWorkflowNodeExecutionRepository CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY=core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository # API workflow run repository implementation +# Options: +# - repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository (default) +# - extensions.logstore.repositories.logstore_api_workflow_run_repository.LogstoreAPIWorkflowRunRepository API_WORKFLOW_RUN_REPOSITORY=repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository # API workflow node execution repository implementation +# Options: +# - repositories.sqlalchemy_api_workflow_node_execution_repository.DifyAPISQLAlchemyWorkflowNodeExecutionRepository (default) +# - extensions.logstore.repositories.logstore_api_workflow_node_execution_repository.LogstoreAPIWorkflowNodeExecutionRepository API_WORKFLOW_NODE_EXECUTION_REPOSITORY=repositories.sqlalchemy_api_workflow_node_execution_repository.DifyAPISQLAlchemyWorkflowNodeExecutionRepository # Workflow log cleanup configuration diff --git 
a/docker/docker-compose-template.yaml b/docker/docker-compose-template.yaml index 709aff23df..aada39569e 100644 --- a/docker/docker-compose-template.yaml +++ b/docker/docker-compose-template.yaml @@ -21,7 +21,7 @@ services: # API service api: - image: langgenius/dify-api:1.11.2 + image: langgenius/dify-api:1.11.3 restart: always environment: # Use the shared environment variables. @@ -63,7 +63,7 @@ services: # worker service # The Celery worker for processing all queues (dataset, workflow, mail, etc.) worker: - image: langgenius/dify-api:1.11.2 + image: langgenius/dify-api:1.11.3 restart: always environment: # Use the shared environment variables. @@ -102,7 +102,7 @@ services: # worker_beat service # Celery beat for scheduling periodic tasks. worker_beat: - image: langgenius/dify-api:1.11.2 + image: langgenius/dify-api:1.11.3 restart: always environment: # Use the shared environment variables. @@ -132,7 +132,7 @@ services: # Frontend web application. web: - image: langgenius/dify-web:1.11.2 + image: langgenius/dify-web:1.11.3 restart: always environment: CONSOLE_API_URL: ${CONSOLE_API_URL:-} diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 041f60aaa2..fcb07dda36 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -704,7 +704,7 @@ services: # API service api: - image: langgenius/dify-api:1.11.2 + image: langgenius/dify-api:1.11.3 restart: always environment: # Use the shared environment variables. @@ -746,7 +746,7 @@ services: # worker service # The Celery worker for processing all queues (dataset, workflow, mail, etc.) worker: - image: langgenius/dify-api:1.11.2 + image: langgenius/dify-api:1.11.3 restart: always environment: # Use the shared environment variables. @@ -785,7 +785,7 @@ services: # worker_beat service # Celery beat for scheduling periodic tasks. worker_beat: - image: langgenius/dify-api:1.11.2 + image: langgenius/dify-api:1.11.3 restart: always environment: # Use the shared environment variables. @@ -815,7 +815,7 @@ services: # Frontend web application. 
web: - image: langgenius/dify-web:1.11.2 + image: langgenius/dify-web:1.11.3 restart: always environment: CONSOLE_API_URL: ${CONSOLE_API_URL:-} diff --git a/web/__tests__/embedded-user-id-store.test.tsx b/web/__tests__/embedded-user-id-store.test.tsx index 276b22bcd7..901218e76b 100644 --- a/web/__tests__/embedded-user-id-store.test.tsx +++ b/web/__tests__/embedded-user-id-store.test.tsx @@ -53,6 +53,7 @@ vi.mock('@/context/global-public-context', () => { ) return { useGlobalPublicStore, + useIsSystemFeaturesPending: () => false, } }) diff --git a/web/app/components/app-initializer.tsx b/web/app/components/app-initializer.tsx index e30646eb3f..3410ecbe9a 100644 --- a/web/app/components/app-initializer.tsx +++ b/web/app/components/app-initializer.tsx @@ -9,8 +9,8 @@ import { EDUCATION_VERIFY_URL_SEARCHPARAMS_ACTION, EDUCATION_VERIFYING_LOCALSTORAGE_ITEM, } from '@/app/education-apply/constants' -import { fetchSetupStatus } from '@/service/common' import { sendGAEvent } from '@/utils/gtag' +import { fetchSetupStatusWithCache } from '@/utils/setup-status' import { resolvePostLoginRedirect } from '../signin/utils/post-login-redirect' import { trackEvent } from './base/amplitude' @@ -33,15 +33,8 @@ export const AppInitializer = ({ const isSetupFinished = useCallback(async () => { try { - if (localStorage.getItem('setup_status') === 'finished') - return true - const setUpStatus = await fetchSetupStatus() - if (setUpStatus.step !== 'finished') { - localStorage.removeItem('setup_status') - return false - } - localStorage.setItem('setup_status', 'finished') - return true + const setUpStatus = await fetchSetupStatusWithCache() + return setUpStatus.step === 'finished' } catch (error) { console.error(error) diff --git a/web/app/components/app/app-access-control/access-control.spec.tsx b/web/app/components/app/app-access-control/access-control.spec.tsx index dd9acd3479..b73ed5c266 100644 --- a/web/app/components/app/app-access-control/access-control.spec.tsx +++ b/web/app/components/app/app-access-control/access-control.spec.tsx @@ -34,13 +34,6 @@ vi.mock('@/context/app-context', () => ({ }), })) -vi.mock('@/service/common', () => ({ - fetchCurrentWorkspace: vi.fn(), - fetchLangGeniusVersion: vi.fn(), - fetchUserProfile: vi.fn(), - getSystemFeatures: vi.fn(), -})) - vi.mock('@/service/access-control', () => ({ useAppWhiteListSubjects: (...args: unknown[]) => mockUseAppWhiteListSubjects(...args), useSearchForWhiteListCandidates: (...args: unknown[]) => mockUseSearchForWhiteListCandidates(...args), @@ -125,7 +118,6 @@ const resetAccessControlStore = () => { const resetGlobalStore = () => { useGlobalPublicStore.setState({ systemFeatures: defaultSystemFeatures, - isGlobalPending: false, }) } diff --git a/web/app/components/base/chat/chat-with-history/hooks.spec.tsx b/web/app/components/base/chat/chat-with-history/hooks.spec.tsx index a6d51d8643..f6a8f25cbb 100644 --- a/web/app/components/base/chat/chat-with-history/hooks.spec.tsx +++ b/web/app/components/base/chat/chat-with-history/hooks.spec.tsx @@ -170,8 +170,12 @@ describe('useChatWithHistory', () => { await waitFor(() => { expect(mockFetchChatList).toHaveBeenCalledWith('conversation-1', false, 'app-1') }) - expect(result.current.pinnedConversationList).toEqual(pinnedData.data) - expect(result.current.conversationList).toEqual(listData.data) + await waitFor(() => { + expect(result.current.pinnedConversationList).toEqual(pinnedData.data) + }) + await waitFor(() => { + expect(result.current.conversationList).toEqual(listData.data) + }) }) }) diff --git 
a/web/app/components/billing/pricing/plans/cloud-plan-item/index.spec.tsx b/web/app/components/billing/pricing/plans/cloud-plan-item/index.spec.tsx index 4473ef98fa..680243a474 100644 --- a/web/app/components/billing/pricing/plans/cloud-plan-item/index.spec.tsx +++ b/web/app/components/billing/pricing/plans/cloud-plan-item/index.spec.tsx @@ -3,7 +3,8 @@ import { fireEvent, render, screen, waitFor } from '@testing-library/react' import * as React from 'react' import { useAppContext } from '@/context/app-context' import { useAsyncWindowOpen } from '@/hooks/use-async-window-open' -import { fetchBillingUrl, fetchSubscriptionUrls } from '@/service/billing' +import { fetchSubscriptionUrls } from '@/service/billing' +import { consoleClient } from '@/service/client' import Toast from '../../../../base/toast' import { ALL_PLANS } from '../../../config' import { Plan } from '../../../type' @@ -21,10 +22,15 @@ vi.mock('@/context/app-context', () => ({ })) vi.mock('@/service/billing', () => ({ - fetchBillingUrl: vi.fn(), fetchSubscriptionUrls: vi.fn(), })) +vi.mock('@/service/client', () => ({ + consoleClient: { + billingUrl: vi.fn(), + }, +})) + vi.mock('@/hooks/use-async-window-open', () => ({ useAsyncWindowOpen: vi.fn(), })) @@ -37,7 +43,7 @@ vi.mock('../../assets', () => ({ const mockUseAppContext = useAppContext as Mock const mockUseAsyncWindowOpen = useAsyncWindowOpen as Mock -const mockFetchBillingUrl = fetchBillingUrl as Mock +const mockBillingUrl = consoleClient.billingUrl as Mock const mockFetchSubscriptionUrls = fetchSubscriptionUrls as Mock const mockToastNotify = Toast.notify as Mock @@ -69,7 +75,7 @@ beforeEach(() => { vi.clearAllMocks() mockUseAppContext.mockReturnValue({ isCurrentWorkspaceManager: true }) mockUseAsyncWindowOpen.mockReturnValue(vi.fn(async open => await open())) - mockFetchBillingUrl.mockResolvedValue({ url: 'https://billing.example' }) + mockBillingUrl.mockResolvedValue({ url: 'https://billing.example' }) mockFetchSubscriptionUrls.mockResolvedValue({ url: 'https://subscription.example' }) assignedHref = '' }) @@ -143,7 +149,7 @@ describe('CloudPlanItem', () => { type: 'error', message: 'billing.buyPermissionDeniedTip', })) - expect(mockFetchBillingUrl).not.toHaveBeenCalled() + expect(mockBillingUrl).not.toHaveBeenCalled() }) it('should open billing portal when upgrading current paid plan', async () => { @@ -162,7 +168,7 @@ describe('CloudPlanItem', () => { fireEvent.click(screen.getByRole('button', { name: 'billing.plansCommon.currentPlan' })) await waitFor(() => { - expect(mockFetchBillingUrl).toHaveBeenCalledTimes(1) + expect(mockBillingUrl).toHaveBeenCalledTimes(1) }) expect(openWindow).toHaveBeenCalledTimes(1) }) diff --git a/web/app/components/billing/pricing/plans/cloud-plan-item/index.tsx b/web/app/components/billing/pricing/plans/cloud-plan-item/index.tsx index b694dc57e2..d9c4d3f75b 100644 --- a/web/app/components/billing/pricing/plans/cloud-plan-item/index.tsx +++ b/web/app/components/billing/pricing/plans/cloud-plan-item/index.tsx @@ -6,7 +6,8 @@ import { useMemo } from 'react' import { useTranslation } from 'react-i18next' import { useAppContext } from '@/context/app-context' import { useAsyncWindowOpen } from '@/hooks/use-async-window-open' -import { fetchBillingUrl, fetchSubscriptionUrls } from '@/service/billing' +import { fetchSubscriptionUrls } from '@/service/billing' +import { consoleClient } from '@/service/client' import Toast from '../../../../base/toast' import { ALL_PLANS } from '../../../config' import { Plan } from '../../../type' @@ -76,7 
+77,7 @@ const CloudPlanItem: FC = ({ try { if (isCurrentPaidPlan) { await openAsyncWindow(async () => { - const res = await fetchBillingUrl() + const res = await consoleClient.billingUrl() if (res.url) return res.url throw new Error('Failed to open billing page') diff --git a/web/app/components/header/account-setting/data-source-page-new/hooks/use-marketplace-all-plugins.ts b/web/app/components/header/account-setting/data-source-page-new/hooks/use-marketplace-all-plugins.ts index 0c2154210c..90ef6e78a4 100644 --- a/web/app/components/header/account-setting/data-source-page-new/hooks/use-marketplace-all-plugins.ts +++ b/web/app/components/header/account-setting/data-source-page-new/hooks/use-marketplace-all-plugins.ts @@ -30,8 +30,8 @@ export const useMarketplaceAllPlugins = (providers: any[], searchText: string) = category: PluginCategoryEnum.datasource, exclude, type: 'plugin', - sortBy: 'install_count', - sortOrder: 'DESC', + sort_by: 'install_count', + sort_order: 'DESC', }) } else { @@ -39,10 +39,10 @@ export const useMarketplaceAllPlugins = (providers: any[], searchText: string) = query: '', category: PluginCategoryEnum.datasource, type: 'plugin', - pageSize: 1000, + page_size: 1000, exclude, - sortBy: 'install_count', - sortOrder: 'DESC', + sort_by: 'install_count', + sort_order: 'DESC', }) } }, [queryPlugins, queryPluginsWithDebounced, searchText, exclude]) diff --git a/web/app/components/header/account-setting/model-provider-page/hooks.ts b/web/app/components/header/account-setting/model-provider-page/hooks.ts index 0e35f0fb31..6aba41d4e4 100644 --- a/web/app/components/header/account-setting/model-provider-page/hooks.ts +++ b/web/app/components/header/account-setting/model-provider-page/hooks.ts @@ -275,8 +275,8 @@ export const useMarketplaceAllPlugins = (providers: ModelProvider[], searchText: category: PluginCategoryEnum.model, exclude, type: 'plugin', - sortBy: 'install_count', - sortOrder: 'DESC', + sort_by: 'install_count', + sort_order: 'DESC', }) } else { @@ -284,10 +284,10 @@ export const useMarketplaceAllPlugins = (providers: ModelProvider[], searchText: query: '', category: PluginCategoryEnum.model, type: 'plugin', - pageSize: 1000, + page_size: 1000, exclude, - sortBy: 'install_count', - sortOrder: 'DESC', + sort_by: 'install_count', + sort_order: 'DESC', }) } }, [queryPlugins, queryPluginsWithDebounced, searchText, exclude]) diff --git a/web/app/components/plugins/marketplace/hooks.ts b/web/app/components/plugins/marketplace/hooks.ts index b1e4f50767..60ba0e0bee 100644 --- a/web/app/components/plugins/marketplace/hooks.ts +++ b/web/app/components/plugins/marketplace/hooks.ts @@ -100,11 +100,11 @@ export const useMarketplacePlugins = () => { const [queryParams, setQueryParams] = useState() const normalizeParams = useCallback((pluginsSearchParams: PluginsSearchParams) => { - const pageSize = pluginsSearchParams.pageSize || 40 + const page_size = pluginsSearchParams.page_size || 40 return { ...pluginsSearchParams, - pageSize, + page_size, } }, []) @@ -116,20 +116,20 @@ export const useMarketplacePlugins = () => { plugins: [] as Plugin[], total: 0, page: 1, - pageSize: 40, + page_size: 40, } } const params = normalizeParams(queryParams) const { query, - sortBy, - sortOrder, + sort_by, + sort_order, category, tags, exclude, type, - pageSize, + page_size, } = params const pluginOrBundle = type === 'bundle' ? 
'bundles' : 'plugins' @@ -137,10 +137,10 @@ export const useMarketplacePlugins = () => { const res = await postMarketplace<{ data: PluginsFromMarketplaceResponse }>(`/${pluginOrBundle}/search/advanced`, { body: { page: pageParam, - page_size: pageSize, + page_size, query, - sort_by: sortBy, - sort_order: sortOrder, + sort_by, + sort_order, category: category !== 'all' ? category : '', tags, exclude, @@ -154,7 +154,7 @@ export const useMarketplacePlugins = () => { plugins: resPlugins.map(plugin => getFormattedPlugin(plugin)), total: res.data.total, page: pageParam, - pageSize, + page_size, } } catch { @@ -162,13 +162,13 @@ export const useMarketplacePlugins = () => { plugins: [], total: 0, page: pageParam, - pageSize, + page_size, } } }, getNextPageParam: (lastPage) => { const nextPage = lastPage.page + 1 - const loaded = lastPage.page * lastPage.pageSize + const loaded = lastPage.page * lastPage.page_size return loaded < (lastPage.total || 0) ? nextPage : undefined }, initialPageParam: 1, diff --git a/web/app/components/plugins/marketplace/hydration-server.tsx b/web/app/components/plugins/marketplace/hydration-server.tsx index 0aa544cff1..b01f4dd463 100644 --- a/web/app/components/plugins/marketplace/hydration-server.tsx +++ b/web/app/components/plugins/marketplace/hydration-server.tsx @@ -2,8 +2,8 @@ import type { SearchParams } from 'nuqs' import { dehydrate, HydrationBoundary } from '@tanstack/react-query' import { createLoader } from 'nuqs/server' import { getQueryClientServer } from '@/context/query-client-server' +import { marketplaceQuery } from '@/service/client' import { PLUGIN_CATEGORY_WITH_COLLECTIONS } from './constants' -import { marketplaceKeys } from './query' import { marketplaceSearchParamsParsers } from './search-params' import { getCollectionsParams, getMarketplaceCollectionsAndPlugins } from './utils' @@ -23,7 +23,7 @@ async function getDehydratedState(searchParams?: Promise) { const queryClient = getQueryClientServer() await queryClient.prefetchQuery({ - queryKey: marketplaceKeys.collections(getCollectionsParams(params.category)), + queryKey: marketplaceQuery.collections.queryKey({ input: { query: getCollectionsParams(params.category) } }), queryFn: () => getMarketplaceCollectionsAndPlugins(getCollectionsParams(params.category)), }) return dehydrate(queryClient) diff --git a/web/app/components/plugins/marketplace/index.spec.tsx b/web/app/components/plugins/marketplace/index.spec.tsx index 1a3cd15b6b..dc2513ac05 100644 --- a/web/app/components/plugins/marketplace/index.spec.tsx +++ b/web/app/components/plugins/marketplace/index.spec.tsx @@ -60,10 +60,10 @@ vi.mock('@/service/use-plugins', () => ({ // Mock tanstack query const mockFetchNextPage = vi.fn() const mockHasNextPage = false -let mockInfiniteQueryData: { pages: Array<{ plugins: unknown[], total: number, page: number, pageSize: number }> } | undefined +let mockInfiniteQueryData: { pages: Array<{ plugins: unknown[], total: number, page: number, page_size: number }> } | undefined let capturedInfiniteQueryFn: ((ctx: { pageParam: number, signal: AbortSignal }) => Promise) | null = null let capturedQueryFn: ((ctx: { signal: AbortSignal }) => Promise) | null = null -let capturedGetNextPageParam: ((lastPage: { page: number, pageSize: number, total: number }) => number | undefined) | null = null +let capturedGetNextPageParam: ((lastPage: { page: number, page_size: number, total: number }) => number | undefined) | null = null vi.mock('@tanstack/react-query', () => ({ useQuery: vi.fn(({ queryFn, enabled }: { queryFn: (ctx: 
{ signal: AbortSignal }) => Promise, enabled: boolean }) => { @@ -83,7 +83,7 @@ vi.mock('@tanstack/react-query', () => ({ }), useInfiniteQuery: vi.fn(({ queryFn, getNextPageParam, enabled: _enabled }: { queryFn: (ctx: { pageParam: number, signal: AbortSignal }) => Promise - getNextPageParam: (lastPage: { page: number, pageSize: number, total: number }) => number | undefined + getNextPageParam: (lastPage: { page: number, page_size: number, total: number }) => number | undefined enabled: boolean }) => { // Capture queryFn and getNextPageParam for later testing @@ -97,9 +97,9 @@ vi.mock('@tanstack/react-query', () => ({ // Call getNextPageParam to increase coverage if (getNextPageParam) { // Test with more data available - getNextPageParam({ page: 1, pageSize: 40, total: 100 }) + getNextPageParam({ page: 1, page_size: 40, total: 100 }) // Test with no more data - getNextPageParam({ page: 3, pageSize: 40, total: 100 }) + getNextPageParam({ page: 3, page_size: 40, total: 100 }) } return { data: mockInfiniteQueryData, @@ -151,6 +151,7 @@ vi.mock('@/service/base', () => ({ // Mock config vi.mock('@/config', () => ({ + API_PREFIX: '/api', APP_VERSION: '1.0.0', IS_MARKETPLACE: false, MARKETPLACE_API_PREFIX: 'https://marketplace.dify.ai/api/v1', @@ -731,10 +732,10 @@ describe('useMarketplacePlugins', () => { expect(() => { result.current.queryPlugins({ query: 'test', - sortBy: 'install_count', - sortOrder: 'DESC', + sort_by: 'install_count', + sort_order: 'DESC', category: 'tool', - pageSize: 20, + page_size: 20, }) }).not.toThrow() }) @@ -747,7 +748,7 @@ describe('useMarketplacePlugins', () => { result.current.queryPlugins({ query: 'test', type: 'bundle', - pageSize: 40, + page_size: 40, }) }).not.toThrow() }) @@ -798,8 +799,8 @@ describe('useMarketplacePlugins', () => { result.current.queryPlugins({ query: 'test', category: 'all', - sortBy: 'install_count', - sortOrder: 'DESC', + sort_by: 'install_count', + sort_order: 'DESC', }) }).not.toThrow() }) @@ -824,7 +825,7 @@ describe('useMarketplacePlugins', () => { expect(() => { result.current.queryPlugins({ query: 'test', - pageSize: 100, + page_size: 100, }) }).not.toThrow() }) @@ -843,7 +844,7 @@ describe('Hooks queryFn Coverage', () => { // Set mock data to have pages mockInfiniteQueryData = { pages: [ - { plugins: [{ name: 'plugin1' }], total: 10, page: 1, pageSize: 40 }, + { plugins: [{ name: 'plugin1' }], total: 10, page: 1, page_size: 40 }, ], } @@ -863,8 +864,8 @@ describe('Hooks queryFn Coverage', () => { it('should expose page and total from infinite query data', async () => { mockInfiniteQueryData = { pages: [ - { plugins: [{ name: 'plugin1' }, { name: 'plugin2' }], total: 20, page: 1, pageSize: 40 }, - { plugins: [{ name: 'plugin3' }], total: 20, page: 2, pageSize: 40 }, + { plugins: [{ name: 'plugin1' }, { name: 'plugin2' }], total: 20, page: 1, page_size: 40 }, + { plugins: [{ name: 'plugin3' }], total: 20, page: 2, page_size: 40 }, ], } @@ -893,7 +894,7 @@ describe('Hooks queryFn Coverage', () => { it('should return total from first page when query is set and data exists', async () => { mockInfiniteQueryData = { pages: [ - { plugins: [], total: 50, page: 1, pageSize: 40 }, + { plugins: [], total: 50, page: 1, page_size: 40 }, ], } @@ -917,8 +918,8 @@ describe('Hooks queryFn Coverage', () => { type: 'plugin', query: 'search test', category: 'model', - sortBy: 'version_updated_at', - sortOrder: 'ASC', + sort_by: 'version_updated_at', + sort_order: 'ASC', }) expect(result.current).toBeDefined() @@ -1027,13 +1028,13 @@ describe('Advanced 
Hook Integration', () => { // Test with all possible parameters result.current.queryPlugins({ query: 'comprehensive test', - sortBy: 'install_count', - sortOrder: 'DESC', + sort_by: 'install_count', + sort_order: 'DESC', category: 'tool', tags: ['tag1', 'tag2'], exclude: ['excluded-plugin'], type: 'plugin', - pageSize: 50, + page_size: 50, }) expect(result.current).toBeDefined() @@ -1081,9 +1082,9 @@ describe('Direct queryFn Coverage', () => { result.current.queryPlugins({ query: 'direct test', category: 'tool', - sortBy: 'install_count', - sortOrder: 'DESC', - pageSize: 40, + sort_by: 'install_count', + sort_order: 'DESC', + page_size: 40, }) // Now queryFn should be captured and enabled @@ -1255,7 +1256,7 @@ describe('Direct queryFn Coverage', () => { result.current.queryPlugins({ query: 'structure test', - pageSize: 20, + page_size: 20, }) if (capturedInfiniteQueryFn) { @@ -1264,14 +1265,14 @@ describe('Direct queryFn Coverage', () => { plugins: unknown[] total: number page: number - pageSize: number + page_size: number } // Verify the returned structure expect(response).toHaveProperty('plugins') expect(response).toHaveProperty('total') expect(response).toHaveProperty('page') - expect(response).toHaveProperty('pageSize') + expect(response).toHaveProperty('page_size') } }) }) @@ -1296,7 +1297,7 @@ describe('flatMap Coverage', () => { ], total: 5, page: 1, - pageSize: 40, + page_size: 40, }, { plugins: [ @@ -1304,7 +1305,7 @@ describe('flatMap Coverage', () => { ], total: 5, page: 2, - pageSize: 40, + page_size: 40, }, ], } @@ -1336,8 +1337,8 @@ describe('flatMap Coverage', () => { it('should test hook with pages data for flatMap path', async () => { mockInfiniteQueryData = { pages: [ - { plugins: [], total: 100, page: 1, pageSize: 40 }, - { plugins: [], total: 100, page: 2, pageSize: 40 }, + { plugins: [], total: 100, page: 1, page_size: 40 }, + { plugins: [], total: 100, page: 2, page_size: 40 }, ], } @@ -1371,7 +1372,7 @@ describe('flatMap Coverage', () => { plugins: unknown[] total: number page: number - pageSize: number + page_size: number } // When error is caught, should return fallback data expect(response.plugins).toEqual([]) @@ -1392,15 +1393,15 @@ describe('flatMap Coverage', () => { // Test getNextPageParam function directly if (capturedGetNextPageParam) { // When there are more pages - const nextPage = capturedGetNextPageParam({ page: 1, pageSize: 40, total: 100 }) + const nextPage = capturedGetNextPageParam({ page: 1, page_size: 40, total: 100 }) expect(nextPage).toBe(2) // When all data is loaded - const noMorePages = capturedGetNextPageParam({ page: 3, pageSize: 40, total: 100 }) + const noMorePages = capturedGetNextPageParam({ page: 3, page_size: 40, total: 100 }) expect(noMorePages).toBeUndefined() // Edge case: exactly at boundary - const atBoundary = capturedGetNextPageParam({ page: 2, pageSize: 50, total: 100 }) + const atBoundary = capturedGetNextPageParam({ page: 2, page_size: 50, total: 100 }) expect(atBoundary).toBeUndefined() } }) @@ -1427,7 +1428,7 @@ describe('flatMap Coverage', () => { plugins: unknown[] total: number page: number - pageSize: number + page_size: number } // Catch block should return fallback values expect(response.plugins).toEqual([]) @@ -1446,7 +1447,7 @@ describe('flatMap Coverage', () => { plugins: [{ name: 'test-plugin-1' }, { name: 'test-plugin-2' }], total: 10, page: 1, - pageSize: 40, + page_size: 40, }, ], } @@ -1489,9 +1490,12 @@ describe('Async Utils', () => { { type: 'plugin', org: 'test', name: 'plugin2' }, ] - globalThis.fetch = 
vi.fn().mockResolvedValue({ - json: () => Promise.resolve({ data: { plugins: mockPlugins } }), - }) + globalThis.fetch = vi.fn().mockResolvedValue( + new Response(JSON.stringify({ data: { plugins: mockPlugins } }), { + status: 200, + headers: { 'Content-Type': 'application/json' }, + }), + ) const { getMarketplacePluginsByCollectionId } = await import('./utils') const result = await getMarketplacePluginsByCollectionId('test-collection', { @@ -1514,19 +1518,26 @@ describe('Async Utils', () => { }) it('should pass abort signal when provided', async () => { - const mockPlugins = [{ type: 'plugin', org: 'test', name: 'plugin1' }] - globalThis.fetch = vi.fn().mockResolvedValue({ - json: () => Promise.resolve({ data: { plugins: mockPlugins } }), - }) + const mockPlugins = [{ type: 'plugins', org: 'test', name: 'plugin1' }] + globalThis.fetch = vi.fn().mockResolvedValue( + new Response(JSON.stringify({ data: { plugins: mockPlugins } }), { + status: 200, + headers: { 'Content-Type': 'application/json' }, + }), + ) const controller = new AbortController() const { getMarketplacePluginsByCollectionId } = await import('./utils') await getMarketplacePluginsByCollectionId('test-collection', {}, { signal: controller.signal }) + // oRPC uses Request objects, so check that fetch was called with a Request containing the right URL expect(globalThis.fetch).toHaveBeenCalledWith( - expect.any(String), - expect.objectContaining({ signal: controller.signal }), + expect.any(Request), + expect.any(Object), ) + const call = vi.mocked(globalThis.fetch).mock.calls[0] + const request = call[0] as Request + expect(request.url).toContain('test-collection') }) }) @@ -1535,19 +1546,25 @@ describe('Async Utils', () => { const mockCollections = [ { name: 'collection1', label: {}, description: {}, rule: '', created_at: '', updated_at: '' }, ] - const mockPlugins = [{ type: 'plugin', org: 'test', name: 'plugin1' }] + const mockPlugins = [{ type: 'plugins', org: 'test', name: 'plugin1' }] let callCount = 0 globalThis.fetch = vi.fn().mockImplementation(() => { callCount++ if (callCount === 1) { - return Promise.resolve({ - json: () => Promise.resolve({ data: { collections: mockCollections } }), - }) + return Promise.resolve( + new Response(JSON.stringify({ data: { collections: mockCollections } }), { + status: 200, + headers: { 'Content-Type': 'application/json' }, + }), + ) } - return Promise.resolve({ - json: () => Promise.resolve({ data: { plugins: mockPlugins } }), - }) + return Promise.resolve( + new Response(JSON.stringify({ data: { plugins: mockPlugins } }), { + status: 200, + headers: { 'Content-Type': 'application/json' }, + }), + ) }) const { getMarketplaceCollectionsAndPlugins } = await import('./utils') @@ -1571,9 +1588,12 @@ describe('Async Utils', () => { }) it('should append condition and type to URL when provided', async () => { - globalThis.fetch = vi.fn().mockResolvedValue({ - json: () => Promise.resolve({ data: { collections: [] } }), - }) + globalThis.fetch = vi.fn().mockResolvedValue( + new Response(JSON.stringify({ data: { collections: [] } }), { + status: 200, + headers: { 'Content-Type': 'application/json' }, + }), + ) const { getMarketplaceCollectionsAndPlugins } = await import('./utils') await getMarketplaceCollectionsAndPlugins({ @@ -1581,10 +1601,11 @@ describe('Async Utils', () => { type: 'bundle', }) - expect(globalThis.fetch).toHaveBeenCalledWith( - expect.stringContaining('condition=category=tool'), - expect.any(Object), - ) + // oRPC uses Request objects, so check that fetch was called with a 
Request containing the right URL + expect(globalThis.fetch).toHaveBeenCalled() + const call = vi.mocked(globalThis.fetch).mock.calls[0] + const request = call[0] as Request + expect(request.url).toContain('condition=category%3Dtool') }) }) }) diff --git a/web/app/components/plugins/marketplace/query.ts b/web/app/components/plugins/marketplace/query.ts index c5a1421146..35d99a2bd5 100644 --- a/web/app/components/plugins/marketplace/query.ts +++ b/web/app/components/plugins/marketplace/query.ts @@ -1,22 +1,14 @@ -import type { CollectionsAndPluginsSearchParams, PluginsSearchParams } from './types' +import type { PluginsSearchParams } from './types' +import type { MarketPlaceInputs } from '@/contract/router' import { useInfiniteQuery, useQuery } from '@tanstack/react-query' +import { marketplaceQuery } from '@/service/client' import { getMarketplaceCollectionsAndPlugins, getMarketplacePlugins } from './utils' -// TODO: Avoid manual maintenance of query keys and better service management, -// https://github.com/langgenius/dify/issues/30342 - -export const marketplaceKeys = { - all: ['marketplace'] as const, - collections: (params?: CollectionsAndPluginsSearchParams) => [...marketplaceKeys.all, 'collections', params] as const, - collectionPlugins: (collectionId: string, params?: CollectionsAndPluginsSearchParams) => [...marketplaceKeys.all, 'collectionPlugins', collectionId, params] as const, - plugins: (params?: PluginsSearchParams) => [...marketplaceKeys.all, 'plugins', params] as const, -} - export function useMarketplaceCollectionsAndPlugins( - collectionsParams: CollectionsAndPluginsSearchParams, + collectionsParams: MarketPlaceInputs['collections']['query'], ) { return useQuery({ - queryKey: marketplaceKeys.collections(collectionsParams), + queryKey: marketplaceQuery.collections.queryKey({ input: { query: collectionsParams } }), queryFn: ({ signal }) => getMarketplaceCollectionsAndPlugins(collectionsParams, { signal }), }) } @@ -25,11 +17,16 @@ export function useMarketplacePlugins( queryParams: PluginsSearchParams | undefined, ) { return useInfiniteQuery({ - queryKey: marketplaceKeys.plugins(queryParams), + queryKey: marketplaceQuery.searchAdvanced.queryKey({ + input: { + body: queryParams!, + params: { kind: queryParams?.type === 'bundle' ? 'bundles' : 'plugins' }, + }, + }), queryFn: ({ pageParam = 1, signal }) => getMarketplacePlugins(queryParams, pageParam, signal), getNextPageParam: (lastPage) => { const nextPage = lastPage.page + 1 - const loaded = lastPage.page * lastPage.pageSize + const loaded = lastPage.page * lastPage.page_size return loaded < (lastPage.total || 0) ? nextPage : undefined }, initialPageParam: 1, diff --git a/web/app/components/plugins/marketplace/state.ts b/web/app/components/plugins/marketplace/state.ts index 1c1abfc0a1..9c76a21e92 100644 --- a/web/app/components/plugins/marketplace/state.ts +++ b/web/app/components/plugins/marketplace/state.ts @@ -26,8 +26,8 @@ export function useMarketplaceData() { query: searchPluginText, category: activePluginType === PLUGIN_TYPE_SEARCH_MAP.all ? 
undefined : activePluginType, tags: filterPluginTags, - sortBy: sort.sortBy, - sortOrder: sort.sortOrder, + sort_by: sort.sortBy, + sort_order: sort.sortOrder, type: getMarketplaceListFilterType(activePluginType), } }, [isSearchMode, searchPluginText, activePluginType, filterPluginTags, sort]) diff --git a/web/app/components/plugins/marketplace/types.ts b/web/app/components/plugins/marketplace/types.ts index 4145f69248..e4e2dbd935 100644 --- a/web/app/components/plugins/marketplace/types.ts +++ b/web/app/components/plugins/marketplace/types.ts @@ -30,9 +30,9 @@ export type MarketplaceCollectionPluginsResponse = { export type PluginsSearchParams = { query: string page?: number - pageSize?: number - sortBy?: string - sortOrder?: string + page_size?: number + sort_by?: string + sort_order?: string category?: string tags?: string[] exclude?: string[] diff --git a/web/app/components/plugins/marketplace/utils.ts b/web/app/components/plugins/marketplace/utils.ts index eaf299314c..01f3c59284 100644 --- a/web/app/components/plugins/marketplace/utils.ts +++ b/web/app/components/plugins/marketplace/utils.ts @@ -4,14 +4,12 @@ import type { MarketplaceCollection, PluginsSearchParams, } from '@/app/components/plugins/marketplace/types' -import type { Plugin, PluginsFromMarketplaceResponse } from '@/app/components/plugins/types' +import type { Plugin } from '@/app/components/plugins/types' import { PluginCategoryEnum } from '@/app/components/plugins/types' import { - APP_VERSION, - IS_MARKETPLACE, MARKETPLACE_API_PREFIX, } from '@/config' -import { postMarketplace } from '@/service/base' +import { marketplaceClient } from '@/service/client' import { getMarketplaceUrl } from '@/utils/var' import { PLUGIN_TYPE_SEARCH_MAP } from './constants' @@ -19,10 +17,6 @@ type MarketplaceFetchOptions = { signal?: AbortSignal } -const getMarketplaceHeaders = () => new Headers({ - 'X-Dify-Version': !IS_MARKETPLACE ? 
APP_VERSION : '999.0.0', -}) - export const getPluginIconInMarketplace = (plugin: Plugin) => { if (plugin.type === 'bundle') return `${MARKETPLACE_API_PREFIX}/bundles/${plugin.org}/${plugin.name}/icon` @@ -65,24 +59,15 @@ export const getMarketplacePluginsByCollectionId = async ( let plugins: Plugin[] = [] try { - const url = `${MARKETPLACE_API_PREFIX}/collections/${collectionId}/plugins` - const headers = getMarketplaceHeaders() - const marketplaceCollectionPluginsData = await globalThis.fetch( - url, - { - cache: 'no-store', - method: 'POST', - headers, - signal: options?.signal, - body: JSON.stringify({ - category: query?.category, - exclude: query?.exclude, - type: query?.type, - }), + const marketplaceCollectionPluginsDataJson = await marketplaceClient.collectionPlugins({ + params: { + collectionId, }, - ) - const marketplaceCollectionPluginsDataJson = await marketplaceCollectionPluginsData.json() - plugins = (marketplaceCollectionPluginsDataJson.data.plugins || []).map((plugin: Plugin) => getFormattedPlugin(plugin)) + body: query, + }, { + signal: options?.signal, + }) + plugins = (marketplaceCollectionPluginsDataJson.data?.plugins || []).map(plugin => getFormattedPlugin(plugin)) } // eslint-disable-next-line unused-imports/no-unused-vars catch (e) { @@ -99,22 +84,16 @@ export const getMarketplaceCollectionsAndPlugins = async ( let marketplaceCollections: MarketplaceCollection[] = [] let marketplaceCollectionPluginsMap: Record = {} try { - let marketplaceUrl = `${MARKETPLACE_API_PREFIX}/collections?page=1&page_size=100` - if (query?.condition) - marketplaceUrl += `&condition=${query.condition}` - if (query?.type) - marketplaceUrl += `&type=${query.type}` - const headers = getMarketplaceHeaders() - const marketplaceCollectionsData = await globalThis.fetch( - marketplaceUrl, - { - headers, - cache: 'no-store', - signal: options?.signal, + const marketplaceCollectionsDataJson = await marketplaceClient.collections({ + query: { + ...query, + page: 1, + page_size: 100, }, - ) - const marketplaceCollectionsDataJson = await marketplaceCollectionsData.json() - marketplaceCollections = marketplaceCollectionsDataJson.data.collections || [] + }, { + signal: options?.signal, + }) + marketplaceCollections = marketplaceCollectionsDataJson.data?.collections || [] await Promise.all(marketplaceCollections.map(async (collection: MarketplaceCollection) => { const plugins = await getMarketplacePluginsByCollectionId(collection.name, query, options) @@ -143,42 +122,42 @@ export const getMarketplacePlugins = async ( plugins: [] as Plugin[], total: 0, page: 1, - pageSize: 40, + page_size: 40, } } const { query, - sortBy, - sortOrder, + sort_by, + sort_order, category, tags, type, - pageSize = 40, + page_size = 40, } = queryParams - const pluginOrBundle = type === 'bundle' ? 'bundles' : 'plugins' try { - const res = await postMarketplace<{ data: PluginsFromMarketplaceResponse }>(`/${pluginOrBundle}/search/advanced`, { + const res = await marketplaceClient.searchAdvanced({ + params: { + kind: type === 'bundle' ? 'bundles' : 'plugins', + }, body: { page: pageParam, - page_size: pageSize, + page_size, query, - sort_by: sortBy, - sort_order: sortOrder, + sort_by, + sort_order, category: category !== 'all' ? 
category : '', tags, - type, }, - signal, - }) + }, { signal }) const resPlugins = res.data.bundles || res.data.plugins || [] return { plugins: resPlugins.map(plugin => getFormattedPlugin(plugin)), total: res.data.total, page: pageParam, - pageSize, + page_size, } } catch { @@ -186,7 +165,7 @@ export const getMarketplacePlugins = async ( plugins: [], total: 0, page: pageParam, - pageSize, + page_size, } } } diff --git a/web/app/components/workflow/hooks/use-nodes-interactions.ts b/web/app/components/workflow/hooks/use-nodes-interactions.ts index afb47d5994..8277e7dac8 100644 --- a/web/app/components/workflow/hooks/use-nodes-interactions.ts +++ b/web/app/components/workflow/hooks/use-nodes-interactions.ts @@ -1602,6 +1602,7 @@ export const useNodesInteractions = () => { const offsetX = currentPosition.x - x const offsetY = currentPosition.y - y let idMapping: Record = {} + const parentChildrenToAppend: { parentId: string, childId: string, childType: BlockEnum }[] = [] clipboardElements.forEach((nodeToPaste, index) => { const nodeType = nodeToPaste.data.type @@ -1615,6 +1616,7 @@ export const useNodesInteractions = () => { _isBundled: false, _connectedSourceHandleIds: [], _connectedTargetHandleIds: [], + _dimmed: false, title: genNewNodeTitleFromOld(nodeToPaste.data.title), }, position: { @@ -1682,27 +1684,24 @@ export const useNodesInteractions = () => { return // handle paste to nested block - if (selectedNode.data.type === BlockEnum.Iteration) { - newNode.data.isInIteration = true - newNode.data.iteration_id = selectedNode.data.iteration_id - newNode.parentId = selectedNode.id - newNode.positionAbsolute = { - x: newNode.position.x, - y: newNode.position.y, - } - // set position base on parent node - newNode.position = getNestedNodePosition(newNode, selectedNode) - } - else if (selectedNode.data.type === BlockEnum.Loop) { - newNode.data.isInLoop = true - newNode.data.loop_id = selectedNode.data.loop_id + if (selectedNode.data.type === BlockEnum.Iteration || selectedNode.data.type === BlockEnum.Loop) { + const isIteration = selectedNode.data.type === BlockEnum.Iteration + + newNode.data.isInIteration = isIteration + newNode.data.iteration_id = isIteration ? selectedNode.id : undefined + newNode.data.isInLoop = !isIteration + newNode.data.loop_id = !isIteration ? selectedNode.id : undefined + newNode.parentId = selectedNode.id + newNode.zIndex = isIteration ? 
ITERATION_CHILDREN_Z_INDEX : LOOP_CHILDREN_Z_INDEX newNode.positionAbsolute = { x: newNode.position.x, y: newNode.position.y, } // set position base on parent node newNode.position = getNestedNodePosition(newNode, selectedNode) + // update parent children array like native add + parentChildrenToAppend.push({ parentId: selectedNode.id, childId: newNode.id, childType: newNode.data.type }) } } } @@ -1733,7 +1732,17 @@ export const useNodesInteractions = () => { } }) - setNodes([...nodes, ...nodesToPaste]) + const newNodes = produce(nodes, (draft: Node[]) => { + parentChildrenToAppend.forEach(({ parentId, childId, childType }) => { + const p = draft.find(n => n.id === parentId) + if (p) { + p.data._children?.push({ nodeId: childId, nodeType: childType }) + } + }) + draft.push(...nodesToPaste) + }) + + setNodes(newNodes) setEdges([...edges, ...edgesToPaste]) saveStateToHistory(WorkflowHistoryEvent.NodePaste, { nodeId: nodesToPaste?.[0]?.id, diff --git a/web/app/components/workflow/nodes/loop/use-interactions.ts b/web/app/components/workflow/nodes/loop/use-interactions.ts index 006d8f963b..5e8f6ae36c 100644 --- a/web/app/components/workflow/nodes/loop/use-interactions.ts +++ b/web/app/components/workflow/nodes/loop/use-interactions.ts @@ -7,6 +7,7 @@ import { useCallback } from 'react' import { useStoreApi } from 'reactflow' import { useNodesMetaData } from '@/app/components/workflow/hooks' import { + LOOP_CHILDREN_Z_INDEX, LOOP_PADDING, } from '../../constants' import { @@ -114,9 +115,7 @@ export const useNodeLoopInteractions = () => { return childrenNodes.map((child, index) => { const childNodeType = child.data.type as BlockEnum - const { - defaultValue, - } = nodesMetaDataMap![childNodeType] + const { defaultValue } = nodesMetaDataMap![childNodeType] const nodesWithSameType = nodes.filter(node => node.data.type === childNodeType) const { newNode } = generateNewNode({ type: getNodeCustomTypeByNodeDataType(childNodeType), @@ -127,15 +126,17 @@ export const useNodeLoopInteractions = () => { _isBundled: false, _connectedSourceHandleIds: [], _connectedTargetHandleIds: [], + _dimmed: false, title: nodesWithSameType.length > 0 ? 
`${defaultValue.title} ${nodesWithSameType.length + 1}`
            : defaultValue.title,
+          isInLoop: true,
           loop_id: newNodeId,
-
+          type: childNodeType,
         },
         position: child.position,
         positionAbsolute: child.positionAbsolute,
         parentId: newNodeId,
         extent: child.extent,
-        zIndex: child.zIndex,
+        zIndex: LOOP_CHILDREN_Z_INDEX,
       })
       newNode.id = `${newNodeId}${newNode.id + index}`
       return newNode
diff --git a/web/app/install/installForm.spec.tsx b/web/app/install/installForm.spec.tsx
index 74602f916a..17ce35d6a1 100644
--- a/web/app/install/installForm.spec.tsx
+++ b/web/app/install/installForm.spec.tsx
@@ -16,9 +16,16 @@ vi.mock('@/service/common', () => ({
   fetchInitValidateStatus: vi.fn(),
   setup: vi.fn(),
   login: vi.fn(),
-  getSystemFeatures: vi.fn(),
 }))

+vi.mock('@/context/global-public-context', async (importOriginal) => {
+  const actual = await importOriginal()
+  return {
+    ...actual,
+    useIsSystemFeaturesPending: () => false,
+  }
+})
+
 const mockFetchSetupStatus = vi.mocked(fetchSetupStatus)
 const mockFetchInitValidateStatus = vi.mocked(fetchInitValidateStatus)
 const mockSetup = vi.mocked(setup)
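The partial mock above keeps every other export of the module intact and overrides a single hook. Spelling out the generic argument to `importOriginal` keeps the `...actual` spread type-safe under Vitest; a sketch of the same pattern (the explicit type annotation is an assumption, not part of this diff):

```ts
vi.mock('@/context/global-public-context', async (importOriginal) => {
  // Load the real module so the zustand store and other hooks stay intact.
  const actual = await importOriginal<typeof import('@/context/global-public-context')>()
  return {
    ...actual,
    // Stub only the query-backed pending flag.
    useIsSystemFeaturesPending: vi.fn(() => false),
  }
})
```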
diff --git a/web/context/global-public-context.tsx b/web/context/global-public-context.tsx
index c2742bb7a9..3a570fc7ef 100644
--- a/web/context/global-public-context.tsx
+++ b/web/context/global-public-context.tsx
@@ -2,42 +2,61 @@ import type { FC, PropsWithChildren } from 'react'
 import type { SystemFeatures } from '@/types/feature'
 import { useQuery } from '@tanstack/react-query'
-import { useEffect } from 'react'
 import { create } from 'zustand'
 import Loading from '@/app/components/base/loading'
-import { getSystemFeatures } from '@/service/common'
+import { consoleClient } from '@/service/client'
 import { defaultSystemFeatures } from '@/types/feature'
+import { fetchSetupStatusWithCache } from '@/utils/setup-status'

 type GlobalPublicStore = {
-  isGlobalPending: boolean
-  setIsGlobalPending: (isPending: boolean) => void
   systemFeatures: SystemFeatures
   setSystemFeatures: (systemFeatures: SystemFeatures) => void
 }

 export const useGlobalPublicStore = create<GlobalPublicStore>(set => ({
-  isGlobalPending: true,
-  setIsGlobalPending: (isPending: boolean) => set(() => ({ isGlobalPending: isPending })),
   systemFeatures: defaultSystemFeatures,
   setSystemFeatures: (systemFeatures: SystemFeatures) => set(() => ({ systemFeatures })),
 }))

+const systemFeaturesQueryKey = ['systemFeatures'] as const
+const setupStatusQueryKey = ['setupStatus'] as const
+
+async function fetchSystemFeatures() {
+  const data = await consoleClient.systemFeatures()
+  const { setSystemFeatures } = useGlobalPublicStore.getState()
+  setSystemFeatures({ ...defaultSystemFeatures, ...data })
+  return data
+}
+
+export function useSystemFeaturesQuery() {
+  return useQuery({
+    queryKey: systemFeaturesQueryKey,
+    queryFn: fetchSystemFeatures,
+  })
+}
+
+export function useIsSystemFeaturesPending() {
+  const { isPending } = useSystemFeaturesQuery()
+  return isPending
+}
+
+export function useSetupStatusQuery() {
+  return useQuery({
+    queryKey: setupStatusQueryKey,
+    queryFn: fetchSetupStatusWithCache,
+    staleTime: Infinity,
+  })
+}
+
 const GlobalPublicStoreProvider: FC<PropsWithChildren> = ({
   children,
 }) => {
-  const { isPending, data } = useQuery({
-    queryKey: ['systemFeatures'],
-    queryFn: getSystemFeatures,
-  })
-  const { setSystemFeatures, setIsGlobalPending: setIsPending } = useGlobalPublicStore()
-  useEffect(() => {
-    if (data)
-      setSystemFeatures({ ...defaultSystemFeatures, ...data })
-  }, [data, setSystemFeatures])
+  // Fetch systemFeatures and setupStatus in parallel to reduce waterfall.
+  // setupStatus is prefetched here and cached in localStorage for AppInitializer.
+  const { isPending } = useSystemFeaturesQuery()

-  useEffect(() => {
-    setIsPending(isPending)
-  }, [isPending, setIsPending])
+  // Prefetch setupStatus for AppInitializer (result not needed here)
+  useSetupStatusQuery()

   if (isPending)
     return <Loading />
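With `isGlobalPending` removed from the store, consumers read pending state through the query-backed hook instead; every call site shares the `['systemFeatures']` query key, so TanStack Query deduplicates the underlying request. A minimal consumer sketch (the component and its markup are illustrative):

```tsx
import { useGlobalPublicStore, useIsSystemFeaturesPending } from '@/context/global-public-context'

function BrandingTitle() {
  // Shares the cached systemFeatures query; no extra request is issued.
  const isPending = useIsSystemFeaturesPending()
  const branding = useGlobalPublicStore(s => s.systemFeatures.branding)

  if (isPending)
    return null
  return <span>{branding.application_title}</span>
}
```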
diff --git a/web/context/web-app-context.tsx b/web/context/web-app-context.tsx
index e6680c95a5..c5488a565c 100644
--- a/web/context/web-app-context.tsx
+++ b/web/context/web-app-context.tsx
@@ -10,7 +10,7 @@ import { getProcessedSystemVariablesFromUrlParams } from '@/app/components/base/
 import Loading from '@/app/components/base/loading'
 import { AccessMode } from '@/models/access-control'
 import { useGetWebAppAccessModeByCode } from '@/service/use-share'
-import { useGlobalPublicStore } from './global-public-context'
+import { useIsSystemFeaturesPending } from './global-public-context'

 type WebAppStore = {
   shareCode: string | null
@@ -65,7 +65,7 @@ const getShareCodeFromPathname = (pathname: string): string | null => {
 }

 const WebAppStoreProvider: FC<PropsWithChildren> = ({ children }) => {
-  const isGlobalPending = useGlobalPublicStore(s => s.isGlobalPending)
+  const isGlobalPending = useIsSystemFeaturesPending()
   const updateWebAppAccessMode = useWebAppStore(state => state.updateWebAppAccessMode)
   const updateShareCode = useWebAppStore(state => state.updateShareCode)
   const updateEmbeddedUserId = useWebAppStore(state => state.updateEmbeddedUserId)
diff --git a/web/contract/base.ts b/web/contract/base.ts
new file mode 100644
index 0000000000..764db9d554
--- /dev/null
+++ b/web/contract/base.ts
@@ -0,0 +1,3 @@
+import { oc } from '@orpc/contract'
+
+export const base = oc.$route({ inputStructure: 'detailed' })
diff --git a/web/contract/console.ts b/web/contract/console.ts
new file mode 100644
index 0000000000..ec929d1357
--- /dev/null
+++ b/web/contract/console.ts
@@ -0,0 +1,34 @@
+import type { SystemFeatures } from '@/types/feature'
+import { type } from '@orpc/contract'
+import { base } from './base'
+
+export const systemFeaturesContract = base
+  .route({
+    path: '/system-features',
+    method: 'GET',
+  })
+  .input(type())
+  .output(type<SystemFeatures>())
+
+export const billingUrlContract = base
+  .route({
+    path: '/billing/invoices',
+    method: 'GET',
+  })
+  .input(type())
+  .output(type<{ url: string }>())
+
+export const bindPartnerStackContract = base
+  .route({
+    path: '/billing/partners/{partnerKey}/tenants',
+    method: 'PUT',
+  })
+  .input(type<{
+    params: {
+      partnerKey: string
+    }
+    body: {
+      click_id: string
+    }
+  }>())
+  .output(type())
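Each console contract pairs a route template with typed input and output; the `OpenAPILink` configured in web/service/client.ts later turns a client call into the corresponding HTTP request. Under the detailed input structure declared in web/contract/base.ts, a `bindPartnerStack` call translates roughly as follows (values are illustrative):

```ts
import { consoleClient } from '@/service/client'

// Sends PUT {API_PREFIX}/billing/partners/acme/tenants
// with the JSON body {"click_id":"abc123"}.
// `params` fills {partnerKey} in the route; `body` becomes the payload.
await consoleClient.bindPartnerStack({
  params: { partnerKey: 'acme' },
  body: { click_id: 'abc123' },
})
```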
diff --git a/web/contract/marketplace.ts b/web/contract/marketplace.ts
new file mode 100644
index 0000000000..3573ba5c24
--- /dev/null
+++ b/web/contract/marketplace.ts
@@ -0,0 +1,56 @@
+import type { CollectionsAndPluginsSearchParams, MarketplaceCollection, PluginsSearchParams } from '@/app/components/plugins/marketplace/types'
+import type { Plugin, PluginsFromMarketplaceResponse } from '@/app/components/plugins/types'
+import { type } from '@orpc/contract'
+import { base } from './base'
+
+export const collectionsContract = base
+  .route({
+    path: '/collections',
+    method: 'GET',
+  })
+  .input(
+    type<{
+      query?: CollectionsAndPluginsSearchParams & { page?: number, page_size?: number }
+    }>(),
+  )
+  .output(
+    type<{
+      data?: {
+        collections?: MarketplaceCollection[]
+      }
+    }>(),
+  )
+
+export const collectionPluginsContract = base
+  .route({
+    path: '/collections/{collectionId}/plugins',
+    method: 'POST',
+  })
+  .input(
+    type<{
+      params: {
+        collectionId: string
+      }
+      body?: CollectionsAndPluginsSearchParams
+    }>(),
+  )
+  .output(
+    type<{
+      data?: {
+        plugins?: Plugin[]
+      }
+    }>(),
+  )
+
+export const searchAdvancedContract = base
+  .route({
+    path: '/{kind}/search/advanced',
+    method: 'POST',
+  })
+  .input(type<{
+    params: {
+      kind: 'plugins' | 'bundles'
+    }
+    body: Omit<PluginsSearchParams, 'type'>
+  }>())
+  .output(type<{ data: PluginsFromMarketplaceResponse }>())
diff --git a/web/contract/router.ts b/web/contract/router.ts
new file mode 100644
index 0000000000..d83cffb7b8
--- /dev/null
+++ b/web/contract/router.ts
@@ -0,0 +1,19 @@
+import type { InferContractRouterInputs } from '@orpc/contract'
+import { billingUrlContract, bindPartnerStackContract, systemFeaturesContract } from './console'
+import { collectionPluginsContract, collectionsContract, searchAdvancedContract } from './marketplace'
+
+export const marketplaceRouterContract = {
+  collections: collectionsContract,
+  collectionPlugins: collectionPluginsContract,
+  searchAdvanced: searchAdvancedContract,
+}
+
+export type MarketPlaceInputs = InferContractRouterInputs<typeof marketplaceRouterContract>
+
+export const consoleRouterContract = {
+  systemFeatures: systemFeaturesContract,
+  billingUrl: billingUrlContract,
+  bindPartnerStack: bindPartnerStackContract,
+}
+
+export type ConsoleInputs = InferContractRouterInputs<typeof consoleRouterContract>
diff --git a/web/hooks/use-document-title.spec.ts b/web/hooks/use-document-title.spec.ts
index 3909978591..7ce1e693db 100644
--- a/web/hooks/use-document-title.spec.ts
+++ b/web/hooks/use-document-title.spec.ts
@@ -1,5 +1,5 @@
 import { act, renderHook } from '@testing-library/react'
-import { useGlobalPublicStore } from '@/context/global-public-context'
+import { useGlobalPublicStore, useIsSystemFeaturesPending } from '@/context/global-public-context'
 /**
  * Test suite for useDocumentTitle hook
  *
@@ -15,19 +15,25 @@
 import { defaultSystemFeatures } from '@/types/feature'
 import useDocumentTitle from './use-document-title'

-vi.mock('@/service/common', () => ({
-  getSystemFeatures: vi.fn(() => ({ ...defaultSystemFeatures })),
-}))
+vi.mock('@/context/global-public-context', async (importOriginal) => {
+  const actual = await importOriginal()
+  return {
+    ...actual,
+    useIsSystemFeaturesPending: vi.fn(() => false),
+  }
+})

 /**
  * Test behavior when system features are still loading
  * Title should remain empty to prevent flicker
  */
 describe('title should be empty if systemFeatures is pending', () => {
-  act(() => {
-    useGlobalPublicStore.setState({
-      systemFeatures: { ...defaultSystemFeatures, branding: { ...defaultSystemFeatures.branding, enabled: false } },
-      isGlobalPending: true,
+  beforeEach(() => {
+    vi.mocked(useIsSystemFeaturesPending).mockReturnValue(true)
+    act(() => {
+      useGlobalPublicStore.setState({
+        systemFeatures: { ...defaultSystemFeatures, branding: { ...defaultSystemFeatures.branding, enabled: false } },
+      })
     })
   })
 /**
@@ -52,9 +58,9 @@
  */
 describe('use default branding', () => {
   beforeEach(() => {
+    vi.mocked(useIsSystemFeaturesPending).mockReturnValue(false)
     act(() => {
       useGlobalPublicStore.setState({
-        isGlobalPending: false,
         systemFeatures: { ...defaultSystemFeatures, branding: { ...defaultSystemFeatures.branding, enabled: false } },
       })
     })
@@ -84,9 +90,9 @@
  */
 describe('use specific branding', () => {
   beforeEach(() => {
+    vi.mocked(useIsSystemFeaturesPending).mockReturnValue(false)
     act(() => {
       useGlobalPublicStore.setState({
-        isGlobalPending: false,
         systemFeatures: { ...defaultSystemFeatures, branding: { ...defaultSystemFeatures.branding, enabled: true, application_title: 'Test' } },
       })
     })
diff --git a/web/hooks/use-document-title.ts b/web/hooks/use-document-title.ts
index bb69aeb20f..37b31a7dea 100644
--- a/web/hooks/use-document-title.ts
+++ 
b/web/hooks/use-document-title.ts @@ -1,11 +1,11 @@ 'use client' import { useFavicon, useTitle } from 'ahooks' import { useEffect } from 'react' -import { useGlobalPublicStore } from '@/context/global-public-context' +import { useGlobalPublicStore, useIsSystemFeaturesPending } from '@/context/global-public-context' import { basePath } from '@/utils/var' export default function useDocumentTitle(title: string) { - const isPending = useGlobalPublicStore(s => s.isGlobalPending) + const isPending = useIsSystemFeaturesPending() const systemFeatures = useGlobalPublicStore(s => s.systemFeatures) const prefix = title ? `${title} - ` : '' let titleStr = '' diff --git a/web/package.json b/web/package.json index 4019e49cd9..fab33f7608 100644 --- a/web/package.json +++ b/web/package.json @@ -1,7 +1,7 @@ { "name": "dify-web", "type": "module", - "version": "1.11.2", + "version": "1.11.3", "private": true, "packageManager": "pnpm@10.27.0+sha512.72d699da16b1179c14ba9e64dc71c9a40988cbdc65c264cb0e489db7de917f20dcf4d64d8723625f2969ba52d4b7e2a1170682d9ac2a5dcaeaab732b7e16f04a", "imports": { @@ -69,6 +69,10 @@ "@monaco-editor/react": "^4.7.0", "@octokit/core": "^6.1.6", "@octokit/request-error": "^6.1.8", + "@orpc/client": "^1.13.4", + "@orpc/contract": "^1.13.4", + "@orpc/openapi-client": "^1.13.4", + "@orpc/tanstack-query": "^1.13.4", "@remixicon/react": "^4.7.0", "@sentry/react": "^8.55.0", "@svgdotjs/svg.js": "^3.2.5", diff --git a/web/pnpm-lock.yaml b/web/pnpm-lock.yaml index 853c366025..c8797e3d65 100644 --- a/web/pnpm-lock.yaml +++ b/web/pnpm-lock.yaml @@ -108,6 +108,18 @@ importers: '@octokit/request-error': specifier: ^6.1.8 version: 6.1.8 + '@orpc/client': + specifier: ^1.13.4 + version: 1.13.4 + '@orpc/contract': + specifier: ^1.13.4 + version: 1.13.4 + '@orpc/openapi-client': + specifier: ^1.13.4 + version: 1.13.4 + '@orpc/tanstack-query': + specifier: ^1.13.4 + version: 1.13.4(@orpc/client@1.13.4)(@tanstack/query-core@5.90.12) '@remixicon/react': specifier: ^4.7.0 version: 4.7.0(react@19.2.3) @@ -2291,6 +2303,38 @@ packages: '@open-draft/until@2.1.0': resolution: {integrity: sha512-U69T3ItWHvLwGg5eJ0n3I62nWuE6ilHlmz7zM0npLBRvPRd7e6NYmg54vvRtP5mZG7kZqZCFVdsTWo7BPtBujg==} + '@orpc/client@1.13.4': + resolution: {integrity: sha512-s13GPMeoooJc5Th2EaYT5HMFtWG8S03DUVytYfJv8pIhP87RYKl94w52A36denH6r/B4LaAgBeC9nTAOslK+Og==} + + '@orpc/contract@1.13.4': + resolution: {integrity: sha512-TIxyaF67uOlihCRcasjHZxguZpbqfNK7aMrDLnhoufmQBE4OKvguNzmrOFHgsuM0OXoopX0Nuhun1ccaxKP10A==} + + '@orpc/openapi-client@1.13.4': + resolution: {integrity: sha512-tRUcY4E6sgpS5bY/9nNES/Q/PMyYyPOsI4TuhwLhfgxOb0GFPwYKJ6Kif7KFNOhx4fkN/jTOfE1nuWuIZU1gyg==} + + '@orpc/shared@1.13.4': + resolution: {integrity: sha512-TYt9rLG/BUkNQBeQ6C1tEiHS/Seb8OojHgj9GlvqyjHJhMZx5qjsIyTW6RqLPZJ4U2vgK6x4Her36+tlFCKJug==} + peerDependencies: + '@opentelemetry/api': '>=1.9.0' + peerDependenciesMeta: + '@opentelemetry/api': + optional: true + + '@orpc/standard-server-fetch@1.13.4': + resolution: {integrity: sha512-/zmKwnuxfAXbppJpgr1CMnQX3ptPlYcDzLz1TaVzz9VG/Xg58Ov3YhabS2Oi1utLVhy5t4kaCppUducAvoKN+A==} + + '@orpc/standard-server-peer@1.13.4': + resolution: {integrity: sha512-UfqnTLqevjCKUk4cmImOG8cQUwANpV1dp9e9u2O1ki6BRBsg/zlXFg6G2N6wP0zr9ayIiO1d2qJdH55yl/1BNw==} + + '@orpc/standard-server@1.13.4': + resolution: {integrity: sha512-ZOzgfVp6XUg+wVYw+gqesfRfGPtQbnBIrIiSnFMtZF+6ncmFJeF2Shc4RI2Guqc0Qz25juy8Ogo4tX3YqysOcg==} + + '@orpc/tanstack-query@1.13.4': + resolution: {integrity: 
sha512-gCL/kh3kf6OUGKfXxSoOZpcX1jNYzxGfo/PkLQKX7ui4xiTbfWw3sCDF30sNS4I7yAOnBwDwJ3N2xzfkTftOBg==} + peerDependencies: + '@orpc/client': 1.13.4 + '@tanstack/query-core': '>=5.80.2' + '@oxc-resolver/binding-android-arm-eabi@11.15.0': resolution: {integrity: sha512-Q+lWuFfq7whNelNJIP1dhXaVz4zO9Tu77GcQHyxDWh3MaCoO2Bisphgzmsh4ZoUe2zIchQh6OvQL99GlWHg9Tw==} cpu: [arm] @@ -6685,6 +6729,9 @@ packages: resolution: {integrity: sha512-7x81NCL719oNbsq/3mh+hVrAWmFuEYUqrq/Iw3kUzH8ReypT9QQ0BLoJS7/G9k6N81XjW4qHWtjWwe/9eLy1EQ==} engines: {node: '>=12'} + openapi-types@12.1.3: + resolution: {integrity: sha512-N4YtSYJqghVu4iek2ZUvcN/0aqH1kRDuNqzcycDxhOUpg7GdvLa2F3DgS6yBNhInhv2r/6I0Flkn7CqL8+nIcw==} + opener@1.5.2: resolution: {integrity: sha512-ur5UIdyw5Y7yEj9wLzhqXiy6GZ3Mwx0yGI+5sMn2r0N0v3cKJvUmFH5yPP+WXh9e0xfyzyJX95D8l088DNFj7A==} hasBin: true @@ -7081,6 +7128,10 @@ packages: queue-microtask@1.2.3: resolution: {integrity: sha512-NuaNSa6flKT5JaSYQzJok04JzTL1CA6aGhv5rfLW3PgqA+M2ChpZQnAC8h8i4ZFkBS8X5RqkDBHA7r4hej3K9A==} + radash@12.1.1: + resolution: {integrity: sha512-h36JMxKRqrAxVD8201FrCpyeNuUY9Y5zZwujr20fFO77tpUtGa6EZzfKw/3WaiBX95fq7+MpsuMLNdSnORAwSA==} + engines: {node: '>=14.18.0'} + randombytes@2.1.0: resolution: {integrity: sha512-vYl3iOX+4CKUWuxGi9Ukhie6fsqXqS9FE2Zaic4tNFD2N2QQaXOMFbuKK4QmDHC0JO6B1Zp41J0LpT0oR68amQ==} @@ -7826,6 +7877,10 @@ packages: tabbable@6.3.0: resolution: {integrity: sha512-EIHvdY5bPLuWForiR/AN2Bxngzpuwn1is4asboytXtpTgsArc+WmSJKVLlhdh71u7jFcryDqB2A8lQvj78MkyQ==} + tagged-tag@1.0.0: + resolution: {integrity: sha512-yEFYrVhod+hdNyx7g5Bnkkb0G6si8HJurOoOEgC8B/O0uXLHlaey/65KRv6cuWBNhBgHKAROVpc7QyYqE5gFng==} + engines: {node: '>=20'} + tailwind-merge@2.6.0: resolution: {integrity: sha512-P+Vu1qXfzediirmHOC3xKGAYeZtPcV9g76X+xg2FD4tYgR71ewMA35Y3sCz3zhiN/dwefRpJX0yBcgwi1fXNQA==} @@ -8027,13 +8082,17 @@ packages: resolution: {integrity: sha512-5zknd7Dss75pMSED270A1RQS3KloqRJA9XbXLe0eCxyw7xXFb3rd+9B0UQ/0E+LQT6lnrLviEolYORlRWamn4w==} engines: {node: '>=16'} + type-fest@5.4.0: + resolution: {integrity: sha512-wfkA6r0tBpVfGiyO+zbf9e10QkRQSlK9F2UvyfnjoCmrvH2bjHyhPzhugSBOuq1dog3P0+FKckqe+Xf6WKVjwg==} + engines: {node: '>=20'} + typescript@5.9.3: resolution: {integrity: sha512-jl1vZzPDinLr9eUt3J/t7V6FgNEw9QjvBPdysz9KfQDD41fQrC2Y4vKQdiaUpFT4bXlb1RHhLpp8wtm6M5TgSw==} engines: {node: '>=14.17'} hasBin: true - ufo@1.6.1: - resolution: {integrity: sha512-9a4/uxlTWJ4+a5i0ooc1rU7C7YOw3wT+UGqdeNNHWnOF9qcMBgLRS+4IYUqbczewFx4mLEig6gawh7X6mFlEkA==} + ufo@1.6.2: + resolution: {integrity: sha512-heMioaxBcG9+Znsda5Q8sQbWnLJSl98AFDXTO80wELWEzX3hordXsTdxrIfMQoO9IY1MEnoGoPjpoKpMj+Yx0Q==} uglify-js@3.19.3: resolution: {integrity: sha512-v3Xu+yuwBXisp6QYTcH4UbH+xYJXqnq2m/LtQVWKWzYc1iehYnLixoQDN9FH6/j9/oybfd6W9Ghwkl8+UMKTKQ==} @@ -10638,6 +10697,66 @@ snapshots: '@open-draft/until@2.1.0': {} + '@orpc/client@1.13.4': + dependencies: + '@orpc/shared': 1.13.4 + '@orpc/standard-server': 1.13.4 + '@orpc/standard-server-fetch': 1.13.4 + '@orpc/standard-server-peer': 1.13.4 + transitivePeerDependencies: + - '@opentelemetry/api' + + '@orpc/contract@1.13.4': + dependencies: + '@orpc/client': 1.13.4 + '@orpc/shared': 1.13.4 + '@standard-schema/spec': 1.1.0 + openapi-types: 12.1.3 + transitivePeerDependencies: + - '@opentelemetry/api' + + '@orpc/openapi-client@1.13.4': + dependencies: + '@orpc/client': 1.13.4 + '@orpc/contract': 1.13.4 + '@orpc/shared': 1.13.4 + '@orpc/standard-server': 1.13.4 + transitivePeerDependencies: + - '@opentelemetry/api' + + '@orpc/shared@1.13.4': + dependencies: + radash: 12.1.1 + type-fest: 5.4.0 + 
+  '@orpc/standard-server-fetch@1.13.4':
+    dependencies:
+      '@orpc/shared': 1.13.4
+      '@orpc/standard-server': 1.13.4
+    transitivePeerDependencies:
+      - '@opentelemetry/api'
+
+  '@orpc/standard-server-peer@1.13.4':
+    dependencies:
+      '@orpc/shared': 1.13.4
+      '@orpc/standard-server': 1.13.4
+    transitivePeerDependencies:
+      - '@opentelemetry/api'
+
+  '@orpc/standard-server@1.13.4':
+    dependencies:
+      '@orpc/shared': 1.13.4
+    transitivePeerDependencies:
+      - '@opentelemetry/api'
+
+  '@orpc/tanstack-query@1.13.4(@orpc/client@1.13.4)(@tanstack/query-core@5.90.12)':
+    dependencies:
+      '@orpc/client': 1.13.4
+      '@orpc/shared': 1.13.4
+      '@tanstack/query-core': 5.90.12
+    transitivePeerDependencies:
+      - '@opentelemetry/api'
+
   '@oxc-resolver/binding-android-arm-eabi@11.15.0':
     optional: true

@@ -15603,7 +15722,7 @@ snapshots:
       acorn: 8.15.0
       pathe: 2.0.3
       pkg-types: 1.3.1
-      ufo: 1.6.1
+      ufo: 1.6.2

   monaco-editor@0.55.1:
     dependencies:
@@ -15766,6 +15885,8 @@ snapshots:
       is-docker: 2.2.1
       is-wsl: 2.2.0

+  openapi-types@12.1.3: {}
+
   opener@1.5.2: {}

   optionator@0.9.4:
@@ -16181,6 +16302,8 @@ snapshots:
   queue-microtask@1.2.3: {}

+  radash@12.1.1: {}
+
   randombytes@2.1.0:
     dependencies:
       safe-buffer: 5.2.1
@@ -17098,6 +17221,8 @@ snapshots:
   tabbable@6.3.0: {}

+  tagged-tag@1.0.0: {}
+
   tailwind-merge@2.6.0: {}

   tailwindcss@3.4.18(tsx@4.21.0)(yaml@2.8.2):
@@ -17305,9 +17430,13 @@ snapshots:
   type-fest@4.2.0:
     optional: true

+  type-fest@5.4.0:
+    dependencies:
+      tagged-tag: 1.0.0
+
   typescript@5.9.3: {}

-  ufo@1.6.1: {}
+  ufo@1.6.2: {}

   uglify-js@3.19.3: {}
diff --git a/web/service/base.ts b/web/service/base.ts
index 2ab115f96c..fb32ce6bcf 100644
--- a/web/service/base.ts
+++ b/web/service/base.ts
@@ -81,6 +81,11 @@ export type IOtherOptions = {
   needAllResponseContent?: boolean
   deleteContentType?: boolean
   silent?: boolean
+
+  /** If true, behaves like standard fetch: no URL prefix, returns raw Response */
+  fetchCompat?: boolean
+  request?: Request
+
   onData?: IOnData // for stream
   onThought?: IOnThought
   onFile?: IOnFile
diff --git a/web/service/billing.ts b/web/service/billing.ts
index f06c4f06c6..075ab71ade 100644
--- a/web/service/billing.ts
+++ b/web/service/billing.ts
@@ -1,5 +1,5 @@
 import type { CurrentPlanInfoBackend, SubscriptionUrlsBackend } from '@/app/components/billing/type'
-import { get, put } from './base'
+import { get } from './base'

 export const fetchCurrentPlanInfo = () => {
   return get<CurrentPlanInfoBackend>('/features')
@@ -8,17 +8,3 @@ export const fetchSubscriptionUrls = (plan: string, interval: string) => {
   return get<SubscriptionUrlsBackend>(`/billing/subscription?plan=${plan}&interval=${interval}`)
 }
-
-export const fetchBillingUrl = () => {
-  return get<{ url: string }>('/billing/invoices')
-}
-
-export const bindPartnerStackInfo = (partnerKey: string, clickId: string) => {
-  return put(`/billing/partners/${partnerKey}/tenants`, {
-    body: {
-      click_id: clickId,
-    },
-  }, {
-    silent: true,
-  })
-}
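The new `fetchCompat` flag is what lets the oRPC link in web/service/client.ts below reuse the shared `request` helper as if it were plain `fetch`: URL prefixing and JSON handling are skipped and the raw `Response` is resolved, which is exactly what an `OpenAPILink` `fetch` implementation must return. A sketch of a direct call under that contract (constructing the `Request` by hand is illustrative; in this PR the link supplies it, and the cast reflects the raw-Response behavior rather than a declared return type):

```ts
import { request } from '@/service/base'

// Forward a prebuilt Request through the shared pipeline (auth headers,
// retries disabled) while behaving like plain fetch.
const req = new Request('https://cloud.example.com/console/api/system-features')
const res = (await request(req.url, {}, {
  fetchCompat: true,
  request: req,
})) as Response
console.log(res.status)
```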
diff --git a/web/service/client.ts b/web/service/client.ts
new file mode 100644
index 0000000000..c9c92ddd15
--- /dev/null
+++ b/web/service/client.ts
@@ -0,0 +1,61 @@
+import type { ContractRouterClient } from '@orpc/contract'
+import type { JsonifiedClient } from '@orpc/openapi-client'
+import { createORPCClient, onError } from '@orpc/client'
+import { OpenAPILink } from '@orpc/openapi-client/fetch'
+import { createTanstackQueryUtils } from '@orpc/tanstack-query'
+import {
+  API_PREFIX,
+  APP_VERSION,
+  IS_MARKETPLACE,
+  MARKETPLACE_API_PREFIX,
+} from '@/config'
+import {
+  consoleRouterContract,
+  marketplaceRouterContract,
+} from '@/contract/router'
+import { request } from './base'
+
+const getMarketplaceHeaders = () => new Headers({
+  'X-Dify-Version': !IS_MARKETPLACE ? APP_VERSION : '999.0.0',
+})
+
+const marketplaceLink = new OpenAPILink(marketplaceRouterContract, {
+  url: MARKETPLACE_API_PREFIX,
+  headers: () => (getMarketplaceHeaders()),
+  fetch: (request, init) => {
+    return globalThis.fetch(request, {
+      ...init,
+      cache: 'no-store',
+    })
+  },
+  interceptors: [
+    onError((error) => {
+      console.error(error)
+    }),
+  ],
+})
+
+export const marketplaceClient: JsonifiedClient<ContractRouterClient<typeof marketplaceRouterContract>> = createORPCClient(marketplaceLink)
+export const marketplaceQuery = createTanstackQueryUtils(marketplaceClient, { path: ['marketplace'] })
+
+const consoleLink = new OpenAPILink(consoleRouterContract, {
+  url: API_PREFIX,
+  fetch: (input, init) => {
+    return request(
+      input.url,
+      init,
+      {
+        fetchCompat: true,
+        request: input,
+      },
+    )
+  },
+  interceptors: [
+    onError((error) => {
+      console.error(error)
+    }),
+  ],
+})
+
+export const consoleClient: JsonifiedClient<ContractRouterClient<typeof consoleRouterContract>> = createORPCClient(consoleLink)
+export const consoleQuery = createTanstackQueryUtils(consoleClient, { path: ['console'] })
diff --git a/web/service/common.ts b/web/service/common.ts
index 5fc4850d5f..70211d10d3 100644
--- a/web/service/common.ts
+++ b/web/service/common.ts
@@ -34,7 +34,6 @@ import type {
   UserProfileOriginResponse,
 } from '@/models/common'
 import type { RETRIEVE_METHOD } from '@/types/app'
-import type { SystemFeatures } from '@/types/feature'
 import { del, get, patch, post, put } from './base'

 type LoginSuccess = {
@@ -307,10 +306,6 @@ export const fetchSupportRetrievalMethods = (url: string): Promise => {
   return get(url)
 }

-export const getSystemFeatures = (): Promise<SystemFeatures> => {
-  return get('/system-features')
-}
-
 export const enableModel = (url: string, body: { model: string, model_type: ModelTypeEnum }): Promise =>
   patch(url, { body })
diff --git a/web/service/fetch.ts b/web/service/fetch.ts
index d0af932d73..13be7ae97b 100644
--- a/web/service/fetch.ts
+++ b/web/service/fetch.ts
@@ -136,6 +136,8 @@ async function base(url: string, options: FetchOptionType = {}, otherOptions:
     needAllResponseContent,
     deleteContentType,
     getAbortController,
+    fetchCompat = false,
+    request,
   } = otherOptions

   let base: string
@@ -181,7 +183,7 @@ async function base(url: string, options: FetchOptionType = {}, otherOptions:
     },
   })

-  const res = await client(fetchPathname, {
+  const res = await client(request || fetchPathname, {
     ...init,
     headers,
     credentials: isMarketplaceAPI
@@ -190,8 +192,8 @@ async function base(url: string, options: FetchOptionType = {}, otherOptions:
     retry: {
       methods: [],
     },
-    ...(bodyStringify ? { json: body } : { body: body as BodyInit }),
-    searchParams: params,
+    ...(bodyStringify && !fetchCompat ? { json: body } : { body: body as BodyInit }),
+    searchParams: !fetchCompat ? 
params : undefined, fetch(resource: RequestInfo | URL, options?: RequestInit) { if (resource instanceof Request && options) { const mergedHeaders = new Headers(options.headers || {}) @@ -204,7 +206,7 @@ async function base(url: string, options: FetchOptionType = {}, otherOptions: }, }) - if (needAllResponseContent) + if (needAllResponseContent || fetchCompat) return res as T const contentType = res.headers.get('content-type') if ( diff --git a/web/service/use-billing.ts b/web/service/use-billing.ts index 3dc2b8a994..794b192d5c 100644 --- a/web/service/use-billing.ts +++ b/web/service/use-billing.ts @@ -1,21 +1,22 @@ import { useMutation, useQuery } from '@tanstack/react-query' -import { bindPartnerStackInfo, fetchBillingUrl } from '@/service/billing' - -const NAME_SPACE = 'billing' +import { consoleClient, consoleQuery } from '@/service/client' export const useBindPartnerStackInfo = () => { return useMutation({ - mutationKey: [NAME_SPACE, 'bind-partner-stack'], - mutationFn: (data: { partnerKey: string, clickId: string }) => bindPartnerStackInfo(data.partnerKey, data.clickId), + mutationKey: consoleQuery.bindPartnerStack.mutationKey(), + mutationFn: (data: { partnerKey: string, clickId: string }) => consoleClient.bindPartnerStack({ + params: { partnerKey: data.partnerKey }, + body: { click_id: data.clickId }, + }), }) } export const useBillingUrl = (enabled: boolean) => { return useQuery({ - queryKey: [NAME_SPACE, 'url'], + queryKey: consoleQuery.billingUrl.queryKey(), enabled, queryFn: async () => { - const res = await fetchBillingUrl() + const res = await consoleClient.billingUrl() return res.url }, }) diff --git a/web/service/use-plugins.ts b/web/service/use-plugins.ts index 4e9776df97..5267503a11 100644 --- a/web/service/use-plugins.ts +++ b/web/service/use-plugins.ts @@ -488,23 +488,23 @@ export const useMutationPluginsFromMarketplace = () => { mutationFn: (pluginsSearchParams: PluginsSearchParams) => { const { query, - sortBy, - sortOrder, + sort_by, + sort_order, category, tags, exclude, type, page = 1, - pageSize = 40, + page_size = 40, } = pluginsSearchParams const pluginOrBundle = type === 'bundle' ? 'bundles' : 'plugins' return postMarketplace<{ data: PluginsFromMarketplaceResponse }>(`/${pluginOrBundle}/search/advanced`, { body: { page, - page_size: pageSize, + page_size, query, - sort_by: sortBy, - sort_order: sortOrder, + sort_by, + sort_order, category: category !== 'all' ? category : '', tags, exclude, @@ -535,23 +535,23 @@ export const useFetchPluginListOrBundleList = (pluginsSearchParams: PluginsSearc queryFn: () => { const { query, - sortBy, - sortOrder, + sort_by, + sort_order, category, tags, exclude, type, page = 1, - pageSize = 40, + page_size = 40, } = pluginsSearchParams const pluginOrBundle = type === 'bundle' ? 'bundles' : 'plugins' return postMarketplace<{ data: PluginsFromMarketplaceResponse }>(`/${pluginOrBundle}/search/advanced`, { body: { page, - page_size: pageSize, + page_size, query, - sort_by: sortBy, - sort_order: sortOrder, + sort_by, + sort_order, category: category !== 'all' ? 
category : '', tags, exclude, diff --git a/web/utils/setup-status.spec.ts b/web/utils/setup-status.spec.ts new file mode 100644 index 0000000000..be96b43eba --- /dev/null +++ b/web/utils/setup-status.spec.ts @@ -0,0 +1,139 @@ +import type { SetupStatusResponse } from '@/models/common' + +import { fetchSetupStatus } from '@/service/common' + +import { fetchSetupStatusWithCache } from './setup-status' + +vi.mock('@/service/common', () => ({ + fetchSetupStatus: vi.fn(), +})) + +const mockFetchSetupStatus = vi.mocked(fetchSetupStatus) + +describe('setup-status utilities', () => { + beforeEach(() => { + vi.clearAllMocks() + localStorage.clear() + }) + + describe('fetchSetupStatusWithCache', () => { + describe('when cache exists', () => { + it('should return cached finished status without API call', async () => { + localStorage.setItem('setup_status', 'finished') + + const result = await fetchSetupStatusWithCache() + + expect(result).toEqual({ step: 'finished' }) + expect(mockFetchSetupStatus).not.toHaveBeenCalled() + }) + + it('should not modify localStorage when returning cached value', async () => { + localStorage.setItem('setup_status', 'finished') + + await fetchSetupStatusWithCache() + + expect(localStorage.getItem('setup_status')).toBe('finished') + }) + }) + + describe('when cache does not exist', () => { + it('should call API and cache finished status', async () => { + const apiResponse: SetupStatusResponse = { step: 'finished' } + mockFetchSetupStatus.mockResolvedValue(apiResponse) + + const result = await fetchSetupStatusWithCache() + + expect(mockFetchSetupStatus).toHaveBeenCalledTimes(1) + expect(result).toEqual(apiResponse) + expect(localStorage.getItem('setup_status')).toBe('finished') + }) + + it('should call API and remove cache when not finished', async () => { + const apiResponse: SetupStatusResponse = { step: 'not_started' } + mockFetchSetupStatus.mockResolvedValue(apiResponse) + + const result = await fetchSetupStatusWithCache() + + expect(mockFetchSetupStatus).toHaveBeenCalledTimes(1) + expect(result).toEqual(apiResponse) + expect(localStorage.getItem('setup_status')).toBeNull() + }) + + it('should clear stale cache when API returns not_started', async () => { + localStorage.setItem('setup_status', 'some_invalid_value') + const apiResponse: SetupStatusResponse = { step: 'not_started' } + mockFetchSetupStatus.mockResolvedValue(apiResponse) + + const result = await fetchSetupStatusWithCache() + + expect(result).toEqual(apiResponse) + expect(localStorage.getItem('setup_status')).toBeNull() + }) + }) + + describe('cache edge cases', () => { + it('should call API when cache value is empty string', async () => { + localStorage.setItem('setup_status', '') + const apiResponse: SetupStatusResponse = { step: 'finished' } + mockFetchSetupStatus.mockResolvedValue(apiResponse) + + const result = await fetchSetupStatusWithCache() + + expect(mockFetchSetupStatus).toHaveBeenCalledTimes(1) + expect(result).toEqual(apiResponse) + }) + + it('should call API when cache value is not "finished"', async () => { + localStorage.setItem('setup_status', 'not_started') + const apiResponse: SetupStatusResponse = { step: 'finished' } + mockFetchSetupStatus.mockResolvedValue(apiResponse) + + const result = await fetchSetupStatusWithCache() + + expect(mockFetchSetupStatus).toHaveBeenCalledTimes(1) + expect(result).toEqual(apiResponse) + }) + + it('should call API when localStorage key does not exist', async () => { + const apiResponse: SetupStatusResponse = { step: 'finished' } + 
mockFetchSetupStatus.mockResolvedValue(apiResponse)
+
+        const result = await fetchSetupStatusWithCache()
+
+        expect(mockFetchSetupStatus).toHaveBeenCalledTimes(1)
+        expect(result).toEqual(apiResponse)
+      })
+    })
+
+    describe('API response handling', () => {
+      it('should preserve setup_at from API response', async () => {
+        const setupDate = new Date('2024-01-01')
+        const apiResponse: SetupStatusResponse = {
+          step: 'finished',
+          setup_at: setupDate,
+        }
+        mockFetchSetupStatus.mockResolvedValue(apiResponse)
+
+        const result = await fetchSetupStatusWithCache()
+
+        expect(result).toEqual(apiResponse)
+        expect(result.setup_at).toEqual(setupDate)
+      })
+
+      it('should propagate API errors', async () => {
+        const apiError = new Error('Network error')
+        mockFetchSetupStatus.mockRejectedValue(apiError)
+
+        await expect(fetchSetupStatusWithCache()).rejects.toThrow('Network error')
+      })
+
+      it('should not update cache when API call fails', async () => {
+        mockFetchSetupStatus.mockRejectedValue(new Error('API error'))
+
+        await expect(fetchSetupStatusWithCache()).rejects.toThrow()
+
+        expect(localStorage.getItem('setup_status')).toBeNull()
+      })
+    })
+  })
+})
diff --git a/web/utils/setup-status.ts b/web/utils/setup-status.ts
new file mode 100644
index 0000000000..7a2810bffd
--- /dev/null
+++ b/web/utils/setup-status.ts
@@ -0,0 +1,21 @@
+import type { SetupStatusResponse } from '@/models/common'
+import { fetchSetupStatus } from '@/service/common'
+
+const SETUP_STATUS_KEY = 'setup_status'
+
+const isSetupStatusCached = (): boolean =>
+  localStorage.getItem(SETUP_STATUS_KEY) === 'finished'
+
+export const fetchSetupStatusWithCache = async (): Promise<SetupStatusResponse> => {
+  if (isSetupStatusCached())
+    return { step: 'finished' }
+
+  const status = await fetchSetupStatus()
+
+  if (status.step === 'finished')
+    localStorage.setItem(SETUP_STATUS_KEY, 'finished')
+  else
+    localStorage.removeItem(SETUP_STATUS_KEY)
+
+  return status
+}
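Combined with `useSetupStatusQuery` (`staleTime: Infinity`) in the global provider, this gives two cache layers: TanStack Query pins the result for the lifetime of the tab, and localStorage pins the terminal 'finished' state across sessions, so a fully set-up instance never re-issues the request. A usage sketch (the redirect is illustrative, not part of this PR):

```ts
import { fetchSetupStatusWithCache } from '@/utils/setup-status'

// First run after setup finishes: one API call, then the result is pinned
// in localStorage. Later runs resolve from the cache without a request.
const status = await fetchSetupStatusWithCache()
if (status.step !== 'finished')
  window.location.href = '/install' // illustrative redirect
```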