Merge branch 'main' into feat/summary-index

This commit is contained in:
zxhlyh 2026-01-14 13:40:15 +08:00
commit 830a7fb034
84 changed files with 2140 additions and 827 deletions

View File

@ -4,7 +4,7 @@ from pydantic_settings import BaseSettings
class VolcengineTOSStorageConfig(BaseSettings):
"""
Configuration settings for Volcengine Tinder Object Storage (TOS)
Configuration settings for Volcengine Torch Object Storage (TOS)
"""
VOLCENGINE_TOS_BUCKET_NAME: str | None = Field(

View File

@ -7,7 +7,7 @@ from typing import Literal, cast
import sqlalchemy as sa
from flask import request
from flask_restx import Resource, fields, marshal, marshal_with
from pydantic import BaseModel
from pydantic import BaseModel, Field
from sqlalchemy import asc, desc, select
from werkzeug.exceptions import Forbidden, NotFound
@ -104,6 +104,15 @@ class DocumentRenamePayload(BaseModel):
name: str
class DocumentDatasetListParam(BaseModel):
page: int = Field(1, title="Page", description="Page number.")
limit: int = Field(20, title="Limit", description="Page size.")
search: str | None = Field(None, alias="keyword", title="Search", description="Search keyword.")
sort_by: str = Field("-created_at", alias="sort", title="SortBy", description="Sort by field.")
status: str | None = Field(None, title="Status", description="Document status.")
fetch_val: str = Field("false", alias="fetch")
register_schema_models(
console_ns,
KnowledgeConfig,
@ -225,14 +234,16 @@ class DatasetDocumentListApi(Resource):
def get(self, dataset_id):
current_user, current_tenant_id = current_account_with_tenant()
dataset_id = str(dataset_id)
page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int)
search = request.args.get("keyword", default=None, type=str)
sort = request.args.get("sort", default="-created_at", type=str)
status = request.args.get("status", default=None, type=str)
raw_args = request.args.to_dict()
param = DocumentDatasetListParam.model_validate(raw_args)
page = param.page
limit = param.limit
search = param.search
sort = param.sort_by
status = param.status
# "yes", "true", "t", "y", "1" convert to True, while others convert to False.
try:
fetch_val = request.args.get("fetch", default="false")
fetch_val = param.fetch_val
if isinstance(fetch_val, bool):
fetch = fetch_val
else:

View File

@ -24,7 +24,7 @@ from core.app.layers.conversation_variable_persist_layer import ConversationVari
from core.db.session_factory import session_factory
from core.moderation.base import ModerationError
from core.moderation.input_moderation import InputModeration
from core.variables.variables import VariableUnion
from core.variables.variables import Variable
from core.workflow.enums import WorkflowType
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
from core.workflow.graph_engine.layers.base import GraphEngineLayer
@ -149,8 +149,8 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
system_variables=system_inputs,
user_inputs=inputs,
environment_variables=self._workflow.environment_variables,
# Based on the definition of `VariableUnion`,
# `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
# Based on the definition of `Variable`,
# `VariableBase` instances can be safely used as `Variable` since they are compatible.
conversation_variables=conversation_variables,
)
@ -318,7 +318,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
trace_manager=app_generate_entity.trace_manager,
)
def _initialize_conversation_variables(self) -> list[VariableUnion]:
def _initialize_conversation_variables(self) -> list[Variable]:
"""
Initialize conversation variables for the current conversation.
@ -343,7 +343,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
conversation_variables = [var.to_variable() for var in existing_variables]
session.commit()
return cast(list[VariableUnion], conversation_variables)
return cast(list[Variable], conversation_variables)
def _load_existing_conversation_variables(self, session: Session) -> list[ConversationVariable]:
"""

View File

@ -1,6 +1,6 @@
import logging
from core.variables import Variable
from core.variables import VariableBase
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
from core.workflow.enums import NodeType
@ -44,7 +44,7 @@ class ConversationVariablePersistenceLayer(GraphEngineLayer):
if selector[0] != CONVERSATION_VARIABLE_NODE_ID:
continue
variable = self.graph_runtime_state.variable_pool.get(selector)
if not isinstance(variable, Variable):
if not isinstance(variable, VariableBase):
logger.warning(
"Conversation variable not found in variable pool. selector=%s",
selector,

View File

@ -55,7 +55,7 @@ from core.ops.entities.trace_entity import (
ToolTraceInfo,
WorkflowTraceInfo,
)
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.repositories import DifyCoreRepositoryFactory
from core.workflow.entities import WorkflowNodeExecution
from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey
from extensions.ext_database import db
@ -275,7 +275,7 @@ class AliyunDataTrace(BaseTraceInstance):
service_account = self.get_service_account_with_tenant(app_id)
session_factory = sessionmaker(bind=db.engine)
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory,
user=service_account,
app_id=app_id,

View File

@ -7,8 +7,8 @@ from typing import Any, cast
from flask import has_request_context
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.db.session_factory import session_factory
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
from core.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
from core.tools.__base.tool import Tool
@ -20,7 +20,6 @@ from core.tools.entities.tool_entities import (
ToolProviderType,
)
from core.tools.errors import ToolInvokeError
from extensions.ext_database import db
from factories.file_factory import build_from_mapping
from libs.login import current_user
from models import Account, Tenant
@ -230,30 +229,32 @@ class WorkflowTool(Tool):
"""
Resolve user from database (worker/Celery context).
"""
with session_factory.create_session() as session:
tenant_stmt = select(Tenant).where(Tenant.id == self.runtime.tenant_id)
tenant = session.scalar(tenant_stmt)
if not tenant:
return None
user_stmt = select(Account).where(Account.id == user_id)
user = session.scalar(user_stmt)
if user:
user.current_tenant = tenant
session.expunge(user)
return user
end_user_stmt = select(EndUser).where(EndUser.id == user_id, EndUser.tenant_id == tenant.id)
end_user = session.scalar(end_user_stmt)
if end_user:
session.expunge(end_user)
return end_user
tenant_stmt = select(Tenant).where(Tenant.id == self.runtime.tenant_id)
tenant = db.session.scalar(tenant_stmt)
if not tenant:
return None
user_stmt = select(Account).where(Account.id == user_id)
user = db.session.scalar(user_stmt)
if user:
user.current_tenant = tenant
return user
end_user_stmt = select(EndUser).where(EndUser.id == user_id, EndUser.tenant_id == tenant.id)
end_user = db.session.scalar(end_user_stmt)
if end_user:
return end_user
return None
def _get_workflow(self, app_id: str, version: str) -> Workflow:
"""
get the workflow by app id and version
"""
with Session(db.engine, expire_on_commit=False) as session, session.begin():
with session_factory.create_session() as session, session.begin():
if not version:
stmt = (
select(Workflow)
@ -265,22 +266,24 @@ class WorkflowTool(Tool):
stmt = select(Workflow).where(Workflow.app_id == app_id, Workflow.version == version)
workflow = session.scalar(stmt)
if not workflow:
raise ValueError("workflow not found or not published")
if not workflow:
raise ValueError("workflow not found or not published")
return workflow
session.expunge(workflow)
return workflow
def _get_app(self, app_id: str) -> App:
"""
get the app by app id
"""
stmt = select(App).where(App.id == app_id)
with Session(db.engine, expire_on_commit=False) as session, session.begin():
with session_factory.create_session() as session, session.begin():
app = session.scalar(stmt)
if not app:
raise ValueError("app not found")
if not app:
raise ValueError("app not found")
return app
session.expunge(app)
return app
def _transform_args(self, tool_parameters: dict) -> tuple[dict, list[dict]]:
"""

View File

@ -30,6 +30,7 @@ from .variables import (
SecretVariable,
StringVariable,
Variable,
VariableBase,
)
__all__ = [
@ -62,4 +63,5 @@ __all__ = [
"StringSegment",
"StringVariable",
"Variable",
"VariableBase",
]

View File

@ -232,7 +232,7 @@ def get_segment_discriminator(v: Any) -> SegmentType | None:
# - All variants in `SegmentUnion` must inherit from the `Segment` class.
# - The union must include all non-abstract subclasses of `Segment`, except:
# - `SegmentGroup`, which is not added to the variable pool.
# - `Variable` and its subclasses, which are handled by `VariableUnion`.
# - `VariableBase` and its subclasses, which are handled by `Variable`.
SegmentUnion: TypeAlias = Annotated[
(
Annotated[NoneSegment, Tag(SegmentType.NONE)]

View File

@ -27,7 +27,7 @@ from .segments import (
from .types import SegmentType
class Variable(Segment):
class VariableBase(Segment):
"""
A variable is a segment that has a name.
@ -45,23 +45,23 @@ class Variable(Segment):
selector: Sequence[str] = Field(default_factory=list)
class StringVariable(StringSegment, Variable):
class StringVariable(StringSegment, VariableBase):
pass
class FloatVariable(FloatSegment, Variable):
class FloatVariable(FloatSegment, VariableBase):
pass
class IntegerVariable(IntegerSegment, Variable):
class IntegerVariable(IntegerSegment, VariableBase):
pass
class ObjectVariable(ObjectSegment, Variable):
class ObjectVariable(ObjectSegment, VariableBase):
pass
class ArrayVariable(ArraySegment, Variable):
class ArrayVariable(ArraySegment, VariableBase):
pass
@ -89,16 +89,16 @@ class SecretVariable(StringVariable):
return encrypter.obfuscated_token(self.value)
class NoneVariable(NoneSegment, Variable):
class NoneVariable(NoneSegment, VariableBase):
value_type: SegmentType = SegmentType.NONE
value: None = None
class FileVariable(FileSegment, Variable):
class FileVariable(FileSegment, VariableBase):
pass
class BooleanVariable(BooleanSegment, Variable):
class BooleanVariable(BooleanSegment, VariableBase):
pass
@ -139,13 +139,13 @@ class RAGPipelineVariableInput(BaseModel):
value: Any
# The `VariableUnion`` type is used to enable serialization and deserialization with Pydantic.
# Use `Variable` for type hinting when serialization is not required.
# The `Variable` type is used to enable serialization and deserialization with Pydantic.
# Use `VariableBase` for type hinting when serialization is not required.
#
# Note:
# - All variants in `VariableUnion` must inherit from the `Variable` class.
# - The union must include all non-abstract subclasses of `Segment`, except:
VariableUnion: TypeAlias = Annotated[
# - All variants in `Variable` must inherit from the `VariableBase` class.
# - The union must include all non-abstract subclasses of `VariableBase`.
Variable: TypeAlias = Annotated[
(
Annotated[NoneVariable, Tag(SegmentType.NONE)]
| Annotated[StringVariable, Tag(SegmentType.STRING)]

View File

@ -1,7 +1,7 @@
import abc
from typing import Protocol
from core.variables import Variable
from core.variables import VariableBase
class ConversationVariableUpdater(Protocol):
@ -20,12 +20,12 @@ class ConversationVariableUpdater(Protocol):
"""
@abc.abstractmethod
def update(self, conversation_id: str, variable: "Variable"):
def update(self, conversation_id: str, variable: "VariableBase"):
"""
Updates the value of the specified conversation variable in the underlying storage.
:param conversation_id: The ID of the conversation to update. Typically references `ConversationVariable.id`.
:param variable: The `Variable` instance containing the updated value.
:param variable: The `VariableBase` instance containing the updated value.
"""
pass

View File

@ -11,7 +11,7 @@ from typing import Any
from pydantic import BaseModel, Field
from core.variables.variables import VariableUnion
from core.variables.variables import Variable
class CommandType(StrEnum):
@ -46,7 +46,7 @@ class PauseCommand(GraphEngineCommand):
class VariableUpdate(BaseModel):
"""Represents a single variable update instruction."""
value: VariableUnion = Field(description="New variable value")
value: Variable = Field(description="New variable value")
class UpdateVariablesCommand(GraphEngineCommand):

View File

@ -11,7 +11,7 @@ from typing_extensions import TypeIs
from core.model_runtime.entities.llm_entities import LLMUsage
from core.variables import IntegerVariable, NoneSegment
from core.variables.segments import ArrayAnySegment, ArraySegment
from core.variables.variables import VariableUnion
from core.variables.variables import Variable
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.enums import (
NodeExecutionType,
@ -240,7 +240,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
datetime,
list[GraphNodeEventBase],
object | None,
dict[str, VariableUnion],
dict[str, Variable],
LLMUsage,
]
],
@ -308,7 +308,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
item: object,
flask_app: Flask,
context_vars: contextvars.Context,
) -> tuple[datetime, list[GraphNodeEventBase], object | None, dict[str, VariableUnion], LLMUsage]:
) -> tuple[datetime, list[GraphNodeEventBase], object | None, dict[str, Variable], LLMUsage]:
"""Execute a single iteration in parallel mode and return results."""
with preserve_flask_contexts(flask_app=flask_app, context_vars=context_vars):
iter_start_at = datetime.now(UTC).replace(tzinfo=None)
@ -515,11 +515,11 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
return variable_mapping
def _extract_conversation_variable_snapshot(self, *, variable_pool: VariablePool) -> dict[str, VariableUnion]:
def _extract_conversation_variable_snapshot(self, *, variable_pool: VariablePool) -> dict[str, Variable]:
conversation_variables = variable_pool.variable_dictionary.get(CONVERSATION_VARIABLE_NODE_ID, {})
return {name: variable.model_copy(deep=True) for name, variable in conversation_variables.items()}
def _sync_conversation_variables_from_snapshot(self, snapshot: dict[str, VariableUnion]) -> None:
def _sync_conversation_variables_from_snapshot(self, snapshot: dict[str, Variable]) -> None:
parent_pool = self.graph_runtime_state.variable_pool
parent_conversations = parent_pool.variable_dictionary.get(CONVERSATION_VARIABLE_NODE_ID, {})

View File

@ -1,7 +1,7 @@
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any
from core.variables import SegmentType, Variable
from core.variables import SegmentType, VariableBase
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.entities import GraphInitParams
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
@ -73,7 +73,7 @@ class VariableAssignerNode(Node[VariableAssignerData]):
assigned_variable_selector = self.node_data.assigned_variable_selector
# Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject
original_variable = self.graph_runtime_state.variable_pool.get(assigned_variable_selector)
if not isinstance(original_variable, Variable):
if not isinstance(original_variable, VariableBase):
raise VariableOperatorNodeError("assigned variable not found")
match self.node_data.write_mode:

View File

@ -2,7 +2,7 @@ import json
from collections.abc import Mapping, MutableMapping, Sequence
from typing import TYPE_CHECKING, Any
from core.variables import SegmentType, Variable
from core.variables import SegmentType, VariableBase
from core.variables.consts import SELECTORS_LENGTH
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
@ -118,7 +118,7 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]):
# ==================== Validation Part
# Check if variable exists
if not isinstance(variable, Variable):
if not isinstance(variable, VariableBase):
raise VariableNotFoundError(variable_selector=item.variable_selector)
# Check if operation is supported
@ -192,7 +192,7 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]):
for selector in updated_variable_selectors:
variable = self.graph_runtime_state.variable_pool.get(selector)
if not isinstance(variable, Variable):
if not isinstance(variable, VariableBase):
raise VariableNotFoundError(variable_selector=selector)
process_data[variable.name] = variable.value
@ -213,7 +213,7 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]):
def _handle_item(
self,
*,
variable: Variable,
variable: VariableBase,
operation: Operation,
value: Any,
):

View File

@ -9,10 +9,10 @@ from typing import Annotated, Any, Union, cast
from pydantic import BaseModel, Field
from core.file import File, FileAttribute, file_manager
from core.variables import Segment, SegmentGroup, Variable
from core.variables import Segment, SegmentGroup, VariableBase
from core.variables.consts import SELECTORS_LENGTH
from core.variables.segments import FileSegment, ObjectSegment
from core.variables.variables import RAGPipelineVariableInput, VariableUnion
from core.variables.variables import RAGPipelineVariableInput, Variable
from core.workflow.constants import (
CONVERSATION_VARIABLE_NODE_ID,
ENVIRONMENT_VARIABLE_NODE_ID,
@ -32,7 +32,7 @@ class VariablePool(BaseModel):
# The first element of the selector is the node id, it's the first-level key in the dictionary.
# Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the
# elements of the selector except the first one.
variable_dictionary: defaultdict[str, Annotated[dict[str, VariableUnion], Field(default_factory=dict)]] = Field(
variable_dictionary: defaultdict[str, Annotated[dict[str, Variable], Field(default_factory=dict)]] = Field(
description="Variables mapping",
default=defaultdict(dict),
)
@ -46,13 +46,13 @@ class VariablePool(BaseModel):
description="System variables",
default_factory=SystemVariable.empty,
)
environment_variables: Sequence[VariableUnion] = Field(
environment_variables: Sequence[Variable] = Field(
description="Environment variables.",
default_factory=list[VariableUnion],
default_factory=list[Variable],
)
conversation_variables: Sequence[VariableUnion] = Field(
conversation_variables: Sequence[Variable] = Field(
description="Conversation variables.",
default_factory=list[VariableUnion],
default_factory=list[Variable],
)
rag_pipeline_variables: list[RAGPipelineVariableInput] = Field(
description="RAG pipeline variables.",
@ -105,7 +105,7 @@ class VariablePool(BaseModel):
f"got {len(selector)} elements"
)
if isinstance(value, Variable):
if isinstance(value, VariableBase):
variable = value
elif isinstance(value, Segment):
variable = variable_factory.segment_to_variable(segment=value, selector=selector)
@ -114,9 +114,9 @@ class VariablePool(BaseModel):
variable = variable_factory.segment_to_variable(segment=segment, selector=selector)
node_id, name = self._selector_to_keys(selector)
# Based on the definition of `VariableUnion`,
# `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
self.variable_dictionary[node_id][name] = cast(VariableUnion, variable)
# Based on the definition of `Variable`,
# `VariableBase` instances can be safely used as `Variable` since they are compatible.
self.variable_dictionary[node_id][name] = cast(Variable, variable)
@classmethod
def _selector_to_keys(cls, selector: Sequence[str]) -> tuple[str, str]:

View File

@ -2,7 +2,7 @@ import abc
from collections.abc import Mapping, Sequence
from typing import Any, Protocol
from core.variables import Variable
from core.variables import VariableBase
from core.variables.consts import SELECTORS_LENGTH
from core.workflow.runtime import VariablePool
@ -26,7 +26,7 @@ class VariableLoader(Protocol):
"""
@abc.abstractmethod
def load_variables(self, selectors: list[list[str]]) -> list[Variable]:
def load_variables(self, selectors: list[list[str]]) -> list[VariableBase]:
"""Load variables based on the provided selectors. If the selectors are empty,
this method should return an empty list.
@ -36,7 +36,7 @@ class VariableLoader(Protocol):
:param: selectors: a list of string list, each inner list should have at least two elements:
- the first element is the node ID,
- the second element is the variable name.
:return: a list of Variable objects that match the provided selectors.
:return: a list of VariableBase objects that match the provided selectors.
"""
pass
@ -46,7 +46,7 @@ class _DummyVariableLoader(VariableLoader):
Serves as a placeholder when no variable loading is needed.
"""
def load_variables(self, selectors: list[list[str]]) -> list[Variable]:
def load_variables(self, selectors: list[list[str]]) -> list[VariableBase]:
return []

View File

@ -10,6 +10,7 @@ import os
from dotenv import load_dotenv
from configs import dify_config
from dify_app import DifyApp
logger = logging.getLogger(__name__)
@ -19,12 +20,17 @@ def is_enabled() -> bool:
"""
Check if logstore extension is enabled.
Logstore is considered enabled when:
1. All required Aliyun SLS environment variables are set
2. At least one repository configuration points to a logstore implementation
Returns:
True if all required Aliyun SLS environment variables are set, False otherwise
True if logstore should be initialized, False otherwise
"""
# Load environment variables from .env file
load_dotenv()
# Check if Aliyun SLS connection parameters are configured
required_vars = [
"ALIYUN_SLS_ACCESS_KEY_ID",
"ALIYUN_SLS_ACCESS_KEY_SECRET",
@ -33,24 +39,32 @@ def is_enabled() -> bool:
"ALIYUN_SLS_PROJECT_NAME",
]
all_set = all(os.environ.get(var) for var in required_vars)
sls_vars_set = all(os.environ.get(var) for var in required_vars)
if not all_set:
logger.info("Logstore extension disabled: required Aliyun SLS environment variables not set")
if not sls_vars_set:
return False
return all_set
# Check if any repository configuration points to logstore implementation
repository_configs = [
dify_config.CORE_WORKFLOW_EXECUTION_REPOSITORY,
dify_config.CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY,
dify_config.API_WORKFLOW_NODE_EXECUTION_REPOSITORY,
dify_config.API_WORKFLOW_RUN_REPOSITORY,
]
uses_logstore = any("logstore" in config.lower() for config in repository_configs)
if not uses_logstore:
return False
logger.info("Logstore extension enabled: SLS variables set and repository configured to use logstore")
return True
def init_app(app: DifyApp):
"""
Initialize logstore on application startup.
This function:
1. Creates Aliyun SLS project if it doesn't exist
2. Creates logstores (workflow_execution, workflow_node_execution) if they don't exist
3. Creates indexes with field configurations based on PostgreSQL table structures
This operation is idempotent and only executes once during application startup.
If initialization fails, the application continues running without logstore features.
Args:
app: The Dify application instance
@ -58,17 +72,23 @@ def init_app(app: DifyApp):
try:
from extensions.logstore.aliyun_logstore import AliyunLogStore
logger.info("Initializing logstore...")
logger.info("Initializing Aliyun SLS Logstore...")
# Create logstore client and initialize project/logstores/indexes
# Create logstore client and initialize resources
logstore_client = AliyunLogStore()
logstore_client.init_project_logstore()
# Attach to app for potential later use
app.extensions["logstore"] = logstore_client
logger.info("Logstore initialized successfully")
except Exception:
logger.exception("Failed to initialize logstore")
# Don't raise - allow application to continue even if logstore init fails
# This ensures that the application can still run if logstore is misconfigured
logger.exception(
"Logstore initialization failed. Configuration: endpoint=%s, region=%s, project=%s, timeout=%ss. "
"Application will continue but logstore features will NOT work.",
os.environ.get("ALIYUN_SLS_ENDPOINT"),
os.environ.get("ALIYUN_SLS_REGION"),
os.environ.get("ALIYUN_SLS_PROJECT_NAME"),
os.environ.get("ALIYUN_SLS_CHECK_CONNECTIVITY_TIMEOUT", "30"),
)
# Don't raise - allow application to continue even if logstore setup fails

View File

@ -2,6 +2,7 @@ from __future__ import annotations
import logging
import os
import socket
import threading
import time
from collections.abc import Sequence
@ -179,9 +180,18 @@ class AliyunLogStore:
self.region: str = os.environ.get("ALIYUN_SLS_REGION", "")
self.project_name: str = os.environ.get("ALIYUN_SLS_PROJECT_NAME", "")
self.logstore_ttl: int = int(os.environ.get("ALIYUN_SLS_LOGSTORE_TTL", 365))
self.log_enabled: bool = os.environ.get("SQLALCHEMY_ECHO", "false").lower() == "true"
self.log_enabled: bool = (
os.environ.get("SQLALCHEMY_ECHO", "false").lower() == "true"
or os.environ.get("LOGSTORE_SQL_ECHO", "false").lower() == "true"
)
self.pg_mode_enabled: bool = os.environ.get("LOGSTORE_PG_MODE_ENABLED", "true").lower() == "true"
# Get timeout configuration
check_timeout = int(os.environ.get("ALIYUN_SLS_CHECK_CONNECTIVITY_TIMEOUT", 30))
# Pre-check endpoint connectivity to prevent indefinite hangs
self._check_endpoint_connectivity(self.endpoint, check_timeout)
# Initialize SDK client
self.client = LogClient(
self.endpoint, self.access_key_id, self.access_key_secret, auth_version=AUTH_VERSION_4, region=self.region
@ -199,6 +209,49 @@ class AliyunLogStore:
self.__class__._initialized = True
@staticmethod
def _check_endpoint_connectivity(endpoint: str, timeout: int) -> None:
"""
Check if the SLS endpoint is reachable before creating LogClient.
Prevents indefinite hangs when the endpoint is unreachable.
Args:
endpoint: SLS endpoint URL
timeout: Connection timeout in seconds
Raises:
ConnectionError: If endpoint is not reachable
"""
# Parse endpoint URL to extract hostname and port
from urllib.parse import urlparse
parsed_url = urlparse(endpoint if "://" in endpoint else f"http://{endpoint}")
hostname = parsed_url.hostname
port = parsed_url.port or (443 if parsed_url.scheme == "https" else 80)
if not hostname:
raise ConnectionError(f"Invalid endpoint URL: {endpoint}")
sock = None
try:
# Create socket and set timeout
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.settimeout(timeout)
sock.connect((hostname, port))
except Exception as e:
# Catch all exceptions and provide clear error message
error_type = type(e).__name__
raise ConnectionError(
f"Cannot connect to {hostname}:{port} (timeout={timeout}s): [{error_type}] {e}"
) from e
finally:
# Ensure socket is properly closed
if sock:
try:
sock.close()
except Exception: # noqa: S110
pass # Ignore errors during cleanup
@property
def supports_pg_protocol(self) -> bool:
"""Check if PG protocol is supported and enabled."""
@ -220,19 +273,16 @@ class AliyunLogStore:
try:
self._use_pg_protocol = self._pg_client.init_connection()
if self._use_pg_protocol:
logger.info("Successfully connected to project %s using PG protocol", self.project_name)
logger.info("Using PG protocol for project %s", self.project_name)
# Check if scan_index is enabled for all logstores
self._check_and_disable_pg_if_scan_index_disabled()
return True
else:
logger.info("PG connection failed for project %s. Will use SDK mode.", self.project_name)
logger.info("Using SDK mode for project %s", self.project_name)
return False
except Exception as e:
logger.warning(
"Failed to establish PG connection for project %s: %s. Will use SDK mode.",
self.project_name,
str(e),
)
logger.info("Using SDK mode for project %s", self.project_name)
logger.debug("PG connection details: %s", str(e))
self._use_pg_protocol = False
return False
@ -246,10 +296,6 @@ class AliyunLogStore:
if self._use_pg_protocol:
return
logger.info(
"Attempting delayed PG connection for newly created project %s ...",
self.project_name,
)
self._attempt_pg_connection_init()
self.__class__._pg_connection_timer = None
@ -284,11 +330,7 @@ class AliyunLogStore:
if project_is_new:
# For newly created projects, schedule delayed PG connection
self._use_pg_protocol = False
logger.info(
"Project %s is newly created. Will use SDK mode and schedule PG connection attempt in %d seconds.",
self.project_name,
self.__class__._pg_connection_delay,
)
logger.info("Using SDK mode for project %s (newly created)", self.project_name)
if self.__class__._pg_connection_timer is not None:
self.__class__._pg_connection_timer.cancel()
self.__class__._pg_connection_timer = threading.Timer(
@ -299,7 +341,6 @@ class AliyunLogStore:
self.__class__._pg_connection_timer.start()
else:
# For existing projects, attempt PG connection immediately
logger.info("Project %s already exists. Attempting PG connection...", self.project_name)
self._attempt_pg_connection_init()
def _check_and_disable_pg_if_scan_index_disabled(self) -> None:
@ -318,9 +359,9 @@ class AliyunLogStore:
existing_config = self.get_existing_index_config(logstore_name)
if existing_config and not existing_config.scan_index:
logger.info(
"Logstore %s has scan_index=false, USE SDK mode for read/write operations. "
"PG protocol requires scan_index to be enabled.",
"Logstore %s requires scan_index enabled, using SDK mode for project %s",
logstore_name,
self.project_name,
)
self._use_pg_protocol = False
# Close PG connection if it was initialized
@ -748,7 +789,6 @@ class AliyunLogStore:
reverse=reverse,
)
# Log query info if SQLALCHEMY_ECHO is enabled
if self.log_enabled:
logger.info(
"[LogStore] GET_LOGS | logstore=%s | project=%s | query=%s | "
@ -770,7 +810,6 @@ class AliyunLogStore:
for log in logs:
result.append(log.get_contents())
# Log result count if SQLALCHEMY_ECHO is enabled
if self.log_enabled:
logger.info(
"[LogStore] GET_LOGS RESULT | logstore=%s | returned_count=%d",
@ -845,7 +884,6 @@ class AliyunLogStore:
query=full_query,
)
# Log query info if SQLALCHEMY_ECHO is enabled
if self.log_enabled:
logger.info(
"[LogStore-SDK] EXECUTE_SQL | logstore=%s | project=%s | from_time=%d | to_time=%d | full_query=%s",
@ -853,8 +891,7 @@ class AliyunLogStore:
self.project_name,
from_time,
to_time,
query,
sql,
full_query,
)
try:
@ -865,7 +902,6 @@ class AliyunLogStore:
for log in logs:
result.append(log.get_contents())
# Log result count if SQLALCHEMY_ECHO is enabled
if self.log_enabled:
logger.info(
"[LogStore-SDK] EXECUTE_SQL RESULT | logstore=%s | returned_count=%d",

View File

@ -7,8 +7,7 @@ from contextlib import contextmanager
from typing import Any
import psycopg2
import psycopg2.pool
from psycopg2 import InterfaceError, OperationalError
from sqlalchemy import create_engine
from configs import dify_config
@ -16,11 +15,7 @@ logger = logging.getLogger(__name__)
class AliyunLogStorePG:
"""
PostgreSQL protocol support for Aliyun SLS LogStore.
Handles PG connection pooling and operations for regions that support PG protocol.
"""
"""PostgreSQL protocol support for Aliyun SLS LogStore using SQLAlchemy connection pool."""
def __init__(self, access_key_id: str, access_key_secret: str, endpoint: str, project_name: str):
"""
@ -36,24 +31,11 @@ class AliyunLogStorePG:
self._access_key_secret = access_key_secret
self._endpoint = endpoint
self.project_name = project_name
self._pg_pool: psycopg2.pool.SimpleConnectionPool | None = None
self._engine: Any = None # SQLAlchemy Engine
self._use_pg_protocol = False
def _check_port_connectivity(self, host: str, port: int, timeout: float = 2.0) -> bool:
"""
Check if a TCP port is reachable using socket connection.
This provides a fast check before attempting full database connection,
preventing long waits when connecting to unsupported regions.
Args:
host: Hostname or IP address
port: Port number
timeout: Connection timeout in seconds (default: 2.0)
Returns:
True if port is reachable, False otherwise
"""
"""Fast TCP port check to avoid long waits on unsupported regions."""
try:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.settimeout(timeout)
@ -65,166 +47,101 @@ class AliyunLogStorePG:
return False
def init_connection(self) -> bool:
"""
Initialize PostgreSQL connection pool for SLS PG protocol support.
Attempts to connect to SLS using PostgreSQL protocol. If successful, sets
_use_pg_protocol to True and creates a connection pool. If connection fails
(region doesn't support PG protocol or other errors), returns False.
Returns:
True if PG protocol is supported and initialized, False otherwise
"""
"""Initialize SQLAlchemy connection pool with pool_recycle and TCP keepalive support."""
try:
# Extract hostname from endpoint (remove protocol if present)
pg_host = self._endpoint.replace("http://", "").replace("https://", "")
# Get pool configuration
pg_max_connections = int(os.environ.get("ALIYUN_SLS_PG_MAX_CONNECTIONS", 10))
# Pool configuration
pool_size = int(os.environ.get("ALIYUN_SLS_PG_POOL_SIZE", 5))
max_overflow = int(os.environ.get("ALIYUN_SLS_PG_MAX_OVERFLOW", 5))
pool_recycle = int(os.environ.get("ALIYUN_SLS_PG_POOL_RECYCLE", 3600))
pool_pre_ping = os.environ.get("ALIYUN_SLS_PG_POOL_PRE_PING", "false").lower() == "true"
logger.debug(
"Check PG protocol connection to SLS: host=%s, project=%s",
pg_host,
self.project_name,
)
logger.debug("Check PG protocol connection to SLS: host=%s, project=%s", pg_host, self.project_name)
# Fast port connectivity check before attempting full connection
# This prevents long waits when connecting to unsupported regions
# Fast port check to avoid long waits
if not self._check_port_connectivity(pg_host, 5432, timeout=1.0):
logger.info(
"USE SDK mode for read/write operations, host=%s",
pg_host,
)
logger.debug("Using SDK mode for host=%s", pg_host)
return False
# Create connection pool
self._pg_pool = psycopg2.pool.SimpleConnectionPool(
minconn=1,
maxconn=pg_max_connections,
host=pg_host,
port=5432,
database=self.project_name,
user=self._access_key_id,
password=self._access_key_secret,
sslmode="require",
connect_timeout=5,
application_name=f"Dify-{dify_config.project.version}",
# Build connection URL
from urllib.parse import quote_plus
username = quote_plus(self._access_key_id)
password = quote_plus(self._access_key_secret)
database_url = (
f"postgresql+psycopg2://{username}:{password}@{pg_host}:5432/{self.project_name}?sslmode=require"
)
# Note: Skip test query because SLS PG protocol only supports SELECT/INSERT on actual tables
# Connection pool creation success already indicates connectivity
# Create SQLAlchemy engine with connection pool
self._engine = create_engine(
database_url,
pool_size=pool_size,
max_overflow=max_overflow,
pool_recycle=pool_recycle,
pool_pre_ping=pool_pre_ping,
pool_timeout=30,
connect_args={
"connect_timeout": 5,
"application_name": f"Dify-{dify_config.project.version}-fixautocommit",
"keepalives": 1,
"keepalives_idle": 60,
"keepalives_interval": 10,
"keepalives_count": 5,
},
)
self._use_pg_protocol = True
logger.info(
"PG protocol initialized successfully for SLS project=%s. Will use PG for read/write operations.",
"PG protocol initialized for SLS project=%s (pool_size=%d, pool_recycle=%ds)",
self.project_name,
pool_size,
pool_recycle,
)
return True
except Exception as e:
# PG connection failed - fallback to SDK mode
self._use_pg_protocol = False
if self._pg_pool:
if self._engine:
try:
self._pg_pool.closeall()
self._engine.dispose()
except Exception:
logger.debug("Failed to close PG connection pool during cleanup, ignoring")
self._pg_pool = None
logger.debug("Failed to dispose engine during cleanup, ignoring")
self._engine = None
logger.info(
"PG protocol connection failed (region may not support PG protocol): %s. "
"Falling back to SDK mode for read/write operations.",
str(e),
)
return False
def _is_connection_valid(self, conn: Any) -> bool:
"""
Check if a connection is still valid.
Args:
conn: psycopg2 connection object
Returns:
True if connection is valid, False otherwise
"""
try:
# Check if connection is closed
if conn.closed:
return False
# Quick ping test - execute a lightweight query
# For SLS PG protocol, we can't use SELECT 1 without FROM,
# so we just check the connection status
with conn.cursor() as cursor:
cursor.execute("SELECT 1")
cursor.fetchone()
return True
except Exception:
logger.debug("Using SDK mode for region: %s", str(e))
return False
@contextmanager
def _get_connection(self):
"""
Context manager to get a PostgreSQL connection from the pool.
"""Get connection from SQLAlchemy pool. Pool handles recycle, invalidation, and keepalive automatically."""
if not self._engine:
raise RuntimeError("SQLAlchemy engine is not initialized")
Automatically validates and refreshes stale connections.
Note: Aliyun SLS PG protocol does not support transactions, so we always
use autocommit mode.
Yields:
psycopg2 connection object
Raises:
RuntimeError: If PG pool is not initialized
"""
if not self._pg_pool:
raise RuntimeError("PG connection pool is not initialized")
conn = self._pg_pool.getconn()
connection = self._engine.raw_connection()
try:
# Validate connection and get a fresh one if needed
if not self._is_connection_valid(conn):
logger.debug("Connection is stale, marking as bad and getting a new one")
# Mark connection as bad and get a new one
self._pg_pool.putconn(conn, close=True)
conn = self._pg_pool.getconn()
# Aliyun SLS PG protocol does not support transactions, always use autocommit
conn.autocommit = True
yield conn
connection.autocommit = True # SLS PG protocol does not support transactions
yield connection
except Exception:
raise
finally:
# Return connection to pool (or close if it's bad)
if self._is_connection_valid(conn):
self._pg_pool.putconn(conn)
else:
self._pg_pool.putconn(conn, close=True)
connection.close()
def close(self) -> None:
"""Close the PostgreSQL connection pool."""
if self._pg_pool:
"""Dispose SQLAlchemy engine and close all connections."""
if self._engine:
try:
self._pg_pool.closeall()
logger.info("PG connection pool closed")
self._engine.dispose()
logger.info("SQLAlchemy engine disposed")
except Exception:
logger.exception("Failed to close PG connection pool")
logger.exception("Failed to dispose engine")
def _is_retriable_error(self, error: Exception) -> bool:
"""
Check if an error is retriable (connection-related issues).
Args:
error: Exception to check
Returns:
True if the error is retriable, False otherwise
"""
# Retry on connection-related errors
if isinstance(error, (OperationalError, InterfaceError)):
"""Check if error is retriable (connection-related issues)."""
# Check for psycopg2 connection errors directly
if isinstance(error, (psycopg2.OperationalError, psycopg2.InterfaceError)):
return True
# Check error message for specific connection issues
error_msg = str(error).lower()
retriable_patterns = [
"connection",
@ -234,34 +151,18 @@ class AliyunLogStorePG:
"reset by peer",
"no route to host",
"network",
"operational error",
"interface error",
]
return any(pattern in error_msg for pattern in retriable_patterns)
def put_log(self, logstore: str, contents: Sequence[tuple[str, str]], log_enabled: bool = False) -> None:
"""
Write log to SLS using PostgreSQL protocol with automatic retry.
Note: SLS PG protocol only supports INSERT (not UPDATE). This uses append-only
writes with log_version field for versioning, same as SDK implementation.
Args:
logstore: Name of the logstore table
contents: List of (field_name, value) tuples
log_enabled: Whether to enable logging
Raises:
psycopg2.Error: If database operation fails after all retries
"""
"""Write log to SLS using INSERT with automatic retry (3 attempts with exponential backoff)."""
if not contents:
return
# Extract field names and values from contents
fields = [field_name for field_name, _ in contents]
values = [value for _, value in contents]
# Build INSERT statement with literal values
# Note: Aliyun SLS PG protocol doesn't support parameterized queries,
# so we need to use mogrify to safely create literal values
field_list = ", ".join([f'"{field}"' for field in fields])
if log_enabled:
@ -272,67 +173,40 @@ class AliyunLogStorePG:
len(contents),
)
# Retry configuration
max_retries = 3
retry_delay = 0.1 # Start with 100ms
retry_delay = 0.1
for attempt in range(max_retries):
try:
with self._get_connection() as conn:
with conn.cursor() as cursor:
# Use mogrify to safely convert values to SQL literals
placeholders = ", ".join(["%s"] * len(fields))
values_literal = cursor.mogrify(f"({placeholders})", values).decode("utf-8")
insert_sql = f'INSERT INTO "{logstore}" ({field_list}) VALUES {values_literal}'
cursor.execute(insert_sql)
# Success - exit retry loop
return
except psycopg2.Error as e:
# Check if error is retriable
if not self._is_retriable_error(e):
# Not a retriable error (e.g., data validation error), fail immediately
logger.exception(
"Failed to put logs to logstore %s via PG protocol (non-retriable error)",
logstore,
)
logger.exception("Failed to put logs to logstore %s (non-retriable error)", logstore)
raise
# Retriable error - log and retry if we have attempts left
if attempt < max_retries - 1:
logger.warning(
"Failed to put logs to logstore %s via PG protocol (attempt %d/%d): %s. Retrying...",
"Failed to put logs to logstore %s (attempt %d/%d): %s. Retrying...",
logstore,
attempt + 1,
max_retries,
str(e),
)
time.sleep(retry_delay)
retry_delay *= 2 # Exponential backoff
retry_delay *= 2
else:
# Last attempt failed
logger.exception(
"Failed to put logs to logstore %s via PG protocol after %d attempts",
logstore,
max_retries,
)
logger.exception("Failed to put logs to logstore %s after %d attempts", logstore, max_retries)
raise
def execute_sql(self, sql: str, logstore: str, log_enabled: bool = False) -> list[dict[str, Any]]:
"""
Execute SQL query using PostgreSQL protocol with automatic retry.
Args:
sql: SQL query string
logstore: Name of the logstore (for logging purposes)
log_enabled: Whether to enable logging
Returns:
List of result rows as dictionaries
Raises:
psycopg2.Error: If database operation fails after all retries
"""
"""Execute SQL query with automatic retry (3 attempts with exponential backoff)."""
if log_enabled:
logger.info(
"[LogStore-PG] EXECUTE_SQL | logstore=%s | project=%s | sql=%s",
@ -341,20 +215,16 @@ class AliyunLogStorePG:
sql,
)
# Retry configuration
max_retries = 3
retry_delay = 0.1 # Start with 100ms
retry_delay = 0.1
for attempt in range(max_retries):
try:
with self._get_connection() as conn:
with conn.cursor() as cursor:
cursor.execute(sql)
# Get column names from cursor description
columns = [desc[0] for desc in cursor.description]
# Fetch all results and convert to list of dicts
result = []
for row in cursor.fetchall():
row_dict = {}
@ -372,36 +242,31 @@ class AliyunLogStorePG:
return result
except psycopg2.Error as e:
# Check if error is retriable
if not self._is_retriable_error(e):
# Not a retriable error (e.g., SQL syntax error), fail immediately
logger.exception(
"Failed to execute SQL query on logstore %s via PG protocol (non-retriable error): sql=%s",
"Failed to execute SQL on logstore %s (non-retriable error): sql=%s",
logstore,
sql,
)
raise
# Retriable error - log and retry if we have attempts left
if attempt < max_retries - 1:
logger.warning(
"Failed to execute SQL query on logstore %s via PG protocol (attempt %d/%d): %s. Retrying...",
"Failed to execute SQL on logstore %s (attempt %d/%d): %s. Retrying...",
logstore,
attempt + 1,
max_retries,
str(e),
)
time.sleep(retry_delay)
retry_delay *= 2 # Exponential backoff
retry_delay *= 2
else:
# Last attempt failed
logger.exception(
"Failed to execute SQL query on logstore %s via PG protocol after %d attempts: sql=%s",
"Failed to execute SQL on logstore %s after %d attempts: sql=%s",
logstore,
max_retries,
sql,
)
raise
# This line should never be reached due to raise above, but makes type checker happy
return []

View File

@ -0,0 +1,29 @@
"""
LogStore repository utilities.
"""
from typing import Any
def safe_float(value: Any, default: float = 0.0) -> float:
"""
Safely convert a value to float, handling 'null' strings and None.
"""
if value is None or value in {"null", ""}:
return default
try:
return float(value)
except (ValueError, TypeError):
return default
def safe_int(value: Any, default: int = 0) -> int:
"""
Safely convert a value to int, handling 'null' strings and None.
"""
if value is None or value in {"null", ""}:
return default
try:
return int(float(value))
except (ValueError, TypeError):
return default

View File

@ -14,6 +14,8 @@ from typing import Any
from sqlalchemy.orm import sessionmaker
from extensions.logstore.aliyun_logstore import AliyunLogStore
from extensions.logstore.repositories import safe_float, safe_int
from extensions.logstore.sql_escape import escape_identifier, escape_logstore_query_value
from models.workflow import WorkflowNodeExecutionModel
from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository
@ -52,9 +54,8 @@ def _dict_to_workflow_node_execution_model(data: dict[str, Any]) -> WorkflowNode
model.created_by_role = data.get("created_by_role") or ""
model.created_by = data.get("created_by") or ""
# Numeric fields with defaults
model.index = int(data.get("index", 0))
model.elapsed_time = float(data.get("elapsed_time", 0))
model.index = safe_int(data.get("index", 0))
model.elapsed_time = safe_float(data.get("elapsed_time", 0))
# Optional fields
model.workflow_run_id = data.get("workflow_run_id")
@ -130,6 +131,12 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
node_id,
)
try:
# Escape parameters to prevent SQL injection
escaped_tenant_id = escape_identifier(tenant_id)
escaped_app_id = escape_identifier(app_id)
escaped_workflow_id = escape_identifier(workflow_id)
escaped_node_id = escape_identifier(node_id)
# Check if PG protocol is supported
if self.logstore_client.supports_pg_protocol:
# Use PG protocol with SQL query (get latest version of each record)
@ -138,10 +145,10 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
SELECT *,
ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn
FROM "{AliyunLogStore.workflow_node_execution_logstore}"
WHERE tenant_id = '{tenant_id}'
AND app_id = '{app_id}'
AND workflow_id = '{workflow_id}'
AND node_id = '{node_id}'
WHERE tenant_id = '{escaped_tenant_id}'
AND app_id = '{escaped_app_id}'
AND workflow_id = '{escaped_workflow_id}'
AND node_id = '{escaped_node_id}'
AND __time__ > 0
) AS subquery WHERE rn = 1
LIMIT 100
@ -153,7 +160,8 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
else:
# Use SDK with LogStore query syntax
query = (
f"tenant_id: {tenant_id} and app_id: {app_id} and workflow_id: {workflow_id} and node_id: {node_id}"
f"tenant_id: {escaped_tenant_id} and app_id: {escaped_app_id} "
f"and workflow_id: {escaped_workflow_id} and node_id: {escaped_node_id}"
)
from_time = 0
to_time = int(time.time()) # now
@ -227,6 +235,11 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
workflow_run_id,
)
try:
# Escape parameters to prevent SQL injection
escaped_tenant_id = escape_identifier(tenant_id)
escaped_app_id = escape_identifier(app_id)
escaped_workflow_run_id = escape_identifier(workflow_run_id)
# Check if PG protocol is supported
if self.logstore_client.supports_pg_protocol:
# Use PG protocol with SQL query (get latest version of each record)
@ -235,9 +248,9 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
SELECT *,
ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn
FROM "{AliyunLogStore.workflow_node_execution_logstore}"
WHERE tenant_id = '{tenant_id}'
AND app_id = '{app_id}'
AND workflow_run_id = '{workflow_run_id}'
WHERE tenant_id = '{escaped_tenant_id}'
AND app_id = '{escaped_app_id}'
AND workflow_run_id = '{escaped_workflow_run_id}'
AND __time__ > 0
) AS subquery WHERE rn = 1
LIMIT 1000
@ -248,7 +261,10 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
)
else:
# Use SDK with LogStore query syntax
query = f"tenant_id: {tenant_id} and app_id: {app_id} and workflow_run_id: {workflow_run_id}"
query = (
f"tenant_id: {escaped_tenant_id} and app_id: {escaped_app_id} "
f"and workflow_run_id: {escaped_workflow_run_id}"
)
from_time = 0
to_time = int(time.time()) # now
@ -313,16 +329,24 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
"""
logger.debug("get_execution_by_id: execution_id=%s, tenant_id=%s", execution_id, tenant_id)
try:
# Escape parameters to prevent SQL injection
escaped_execution_id = escape_identifier(execution_id)
# Check if PG protocol is supported
if self.logstore_client.supports_pg_protocol:
# Use PG protocol with SQL query (get latest version of record)
tenant_filter = f"AND tenant_id = '{tenant_id}'" if tenant_id else ""
if tenant_id:
escaped_tenant_id = escape_identifier(tenant_id)
tenant_filter = f"AND tenant_id = '{escaped_tenant_id}'"
else:
tenant_filter = ""
sql_query = f"""
SELECT * FROM (
SELECT *,
ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn
FROM "{AliyunLogStore.workflow_node_execution_logstore}"
WHERE id = '{execution_id}' {tenant_filter} AND __time__ > 0
WHERE id = '{escaped_execution_id}' {tenant_filter} AND __time__ > 0
) AS subquery WHERE rn = 1
LIMIT 1
"""
@ -332,10 +356,14 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
)
else:
# Use SDK with LogStore query syntax
# Note: Values must be quoted in LogStore query syntax to prevent injection
if tenant_id:
query = f"id: {execution_id} and tenant_id: {tenant_id}"
query = (
f"id:{escape_logstore_query_value(execution_id)} "
f"and tenant_id:{escape_logstore_query_value(tenant_id)}"
)
else:
query = f"id: {execution_id}"
query = f"id:{escape_logstore_query_value(execution_id)}"
from_time = 0
to_time = int(time.time()) # now

View File

@ -10,6 +10,7 @@ Key Features:
- Optimized deduplication using finished_at IS NOT NULL filter
- Window functions only when necessary (running status queries)
- Multi-tenant data isolation and security
- SQL injection prevention via parameter escaping
"""
import logging
@ -22,6 +23,8 @@ from typing import Any, cast
from sqlalchemy.orm import sessionmaker
from extensions.logstore.aliyun_logstore import AliyunLogStore
from extensions.logstore.repositories import safe_float, safe_int
from extensions.logstore.sql_escape import escape_identifier, escape_logstore_query_value, escape_sql_string
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models.enums import WorkflowRunTriggeredFrom
from models.workflow import WorkflowRun
@ -63,10 +66,9 @@ def _dict_to_workflow_run(data: dict[str, Any]) -> WorkflowRun:
model.created_by_role = data.get("created_by_role") or ""
model.created_by = data.get("created_by") or ""
# Numeric fields with defaults
model.total_tokens = int(data.get("total_tokens", 0))
model.total_steps = int(data.get("total_steps", 0))
model.exceptions_count = int(data.get("exceptions_count", 0))
model.total_tokens = safe_int(data.get("total_tokens", 0))
model.total_steps = safe_int(data.get("total_steps", 0))
model.exceptions_count = safe_int(data.get("exceptions_count", 0))
# Optional fields
model.graph = data.get("graph")
@ -101,7 +103,8 @@ def _dict_to_workflow_run(data: dict[str, Any]) -> WorkflowRun:
if model.finished_at and model.created_at:
model.elapsed_time = (model.finished_at - model.created_at).total_seconds()
else:
model.elapsed_time = float(data.get("elapsed_time", 0))
# Use safe conversion to handle 'null' strings and None values
model.elapsed_time = safe_float(data.get("elapsed_time", 0))
return model
@ -165,16 +168,26 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
status,
)
# Convert triggered_from to list if needed
if isinstance(triggered_from, WorkflowRunTriggeredFrom):
if isinstance(triggered_from, (WorkflowRunTriggeredFrom, str)):
triggered_from_list = [triggered_from]
else:
triggered_from_list = list(triggered_from)
# Build triggered_from filter
triggered_from_filter = " OR ".join([f"triggered_from='{tf.value}'" for tf in triggered_from_list])
# Escape parameters to prevent SQL injection
escaped_tenant_id = escape_identifier(tenant_id)
escaped_app_id = escape_identifier(app_id)
# Build status filter
status_filter = f"AND status='{status}'" if status else ""
# Build triggered_from filter with escaped values
# Support both enum and string values for triggered_from
triggered_from_filter = " OR ".join(
[
f"triggered_from='{escape_sql_string(tf.value if isinstance(tf, WorkflowRunTriggeredFrom) else tf)}'"
for tf in triggered_from_list
]
)
# Build status filter with escaped value
status_filter = f"AND status='{escape_sql_string(status)}'" if status else ""
# Build last_id filter for pagination
# Note: This is simplified. In production, you'd need to track created_at from last record
@ -188,8 +201,8 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
SELECT * FROM (
SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) AS rn
FROM {AliyunLogStore.workflow_execution_logstore}
WHERE tenant_id='{tenant_id}'
AND app_id='{app_id}'
WHERE tenant_id='{escaped_tenant_id}'
AND app_id='{escaped_app_id}'
AND ({triggered_from_filter})
{status_filter}
{last_id_filter}
@ -232,6 +245,11 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
logger.debug("get_workflow_run_by_id: tenant_id=%s, app_id=%s, run_id=%s", tenant_id, app_id, run_id)
try:
# Escape parameters to prevent SQL injection
escaped_run_id = escape_identifier(run_id)
escaped_tenant_id = escape_identifier(tenant_id)
escaped_app_id = escape_identifier(app_id)
# Check if PG protocol is supported
if self.logstore_client.supports_pg_protocol:
# Use PG protocol with SQL query (get latest version of record)
@ -240,7 +258,10 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
SELECT *,
ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn
FROM "{AliyunLogStore.workflow_execution_logstore}"
WHERE id = '{run_id}' AND tenant_id = '{tenant_id}' AND app_id = '{app_id}' AND __time__ > 0
WHERE id = '{escaped_run_id}'
AND tenant_id = '{escaped_tenant_id}'
AND app_id = '{escaped_app_id}'
AND __time__ > 0
) AS subquery WHERE rn = 1
LIMIT 100
"""
@ -250,7 +271,12 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
)
else:
# Use SDK with LogStore query syntax
query = f"id: {run_id} and tenant_id: {tenant_id} and app_id: {app_id}"
# Note: Values must be quoted in LogStore query syntax to prevent injection
query = (
f"id:{escape_logstore_query_value(run_id)} "
f"and tenant_id:{escape_logstore_query_value(tenant_id)} "
f"and app_id:{escape_logstore_query_value(app_id)}"
)
from_time = 0
to_time = int(time.time()) # now
@ -323,6 +349,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
logger.debug("get_workflow_run_by_id_without_tenant: run_id=%s", run_id)
try:
# Escape parameter to prevent SQL injection
escaped_run_id = escape_identifier(run_id)
# Check if PG protocol is supported
if self.logstore_client.supports_pg_protocol:
# Use PG protocol with SQL query (get latest version of record)
@ -331,7 +360,7 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
SELECT *,
ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn
FROM "{AliyunLogStore.workflow_execution_logstore}"
WHERE id = '{run_id}' AND __time__ > 0
WHERE id = '{escaped_run_id}' AND __time__ > 0
) AS subquery WHERE rn = 1
LIMIT 100
"""
@ -341,7 +370,8 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
)
else:
# Use SDK with LogStore query syntax
query = f"id: {run_id}"
# Note: Values must be quoted in LogStore query syntax
query = f"id:{escape_logstore_query_value(run_id)}"
from_time = 0
to_time = int(time.time()) # now
@ -410,6 +440,11 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
triggered_from,
status,
)
# Escape parameters to prevent SQL injection
escaped_tenant_id = escape_identifier(tenant_id)
escaped_app_id = escape_identifier(app_id)
escaped_triggered_from = escape_sql_string(triggered_from)
# Build time range filter
time_filter = ""
if time_range:
@ -418,6 +453,8 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
# If status is provided, simple count
if status:
escaped_status = escape_sql_string(status)
if status == "running":
# Running status requires window function
sql = f"""
@ -425,9 +462,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
FROM (
SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) AS rn
FROM {AliyunLogStore.workflow_execution_logstore}
WHERE tenant_id='{tenant_id}'
AND app_id='{app_id}'
AND triggered_from='{triggered_from}'
WHERE tenant_id='{escaped_tenant_id}'
AND app_id='{escaped_app_id}'
AND triggered_from='{escaped_triggered_from}'
AND status='running'
{time_filter}
) t
@ -438,10 +475,10 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
sql = f"""
SELECT COUNT(DISTINCT id) as count
FROM {AliyunLogStore.workflow_execution_logstore}
WHERE tenant_id='{tenant_id}'
AND app_id='{app_id}'
AND triggered_from='{triggered_from}'
AND status='{status}'
WHERE tenant_id='{escaped_tenant_id}'
AND app_id='{escaped_app_id}'
AND triggered_from='{escaped_triggered_from}'
AND status='{escaped_status}'
AND finished_at IS NOT NULL
{time_filter}
"""
@ -467,13 +504,14 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
# No status filter - get counts grouped by status
# Use optimized query for finished runs, separate query for running
try:
# Escape parameters (already escaped above, reuse variables)
# Count finished runs grouped by status
finished_sql = f"""
SELECT status, COUNT(DISTINCT id) as count
FROM {AliyunLogStore.workflow_execution_logstore}
WHERE tenant_id='{tenant_id}'
AND app_id='{app_id}'
AND triggered_from='{triggered_from}'
WHERE tenant_id='{escaped_tenant_id}'
AND app_id='{escaped_app_id}'
AND triggered_from='{escaped_triggered_from}'
AND finished_at IS NOT NULL
{time_filter}
GROUP BY status
@ -485,9 +523,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
FROM (
SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) AS rn
FROM {AliyunLogStore.workflow_execution_logstore}
WHERE tenant_id='{tenant_id}'
AND app_id='{app_id}'
AND triggered_from='{triggered_from}'
WHERE tenant_id='{escaped_tenant_id}'
AND app_id='{escaped_app_id}'
AND triggered_from='{escaped_triggered_from}'
AND status='running'
{time_filter}
) t
@ -546,7 +584,13 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
logger.debug(
"get_daily_runs_statistics: tenant_id=%s, app_id=%s, triggered_from=%s", tenant_id, app_id, triggered_from
)
# Build time range filter
# Escape parameters to prevent SQL injection
escaped_tenant_id = escape_identifier(tenant_id)
escaped_app_id = escape_identifier(app_id)
escaped_triggered_from = escape_sql_string(triggered_from)
# Build time range filter (datetime.isoformat() is safe)
time_filter = ""
if start_date:
time_filter += f" AND __time__ >= to_unixtime(from_iso8601_timestamp('{start_date.isoformat()}'))"
@ -557,9 +601,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
sql = f"""
SELECT DATE(from_unixtime(__time__)) as date, COUNT(DISTINCT id) as runs
FROM {AliyunLogStore.workflow_execution_logstore}
WHERE tenant_id='{tenant_id}'
AND app_id='{app_id}'
AND triggered_from='{triggered_from}'
WHERE tenant_id='{escaped_tenant_id}'
AND app_id='{escaped_app_id}'
AND triggered_from='{escaped_triggered_from}'
AND finished_at IS NOT NULL
{time_filter}
GROUP BY date
@ -601,7 +645,13 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
app_id,
triggered_from,
)
# Build time range filter
# Escape parameters to prevent SQL injection
escaped_tenant_id = escape_identifier(tenant_id)
escaped_app_id = escape_identifier(app_id)
escaped_triggered_from = escape_sql_string(triggered_from)
# Build time range filter (datetime.isoformat() is safe)
time_filter = ""
if start_date:
time_filter += f" AND __time__ >= to_unixtime(from_iso8601_timestamp('{start_date.isoformat()}'))"
@ -611,9 +661,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
sql = f"""
SELECT DATE(from_unixtime(__time__)) as date, COUNT(DISTINCT created_by) as terminal_count
FROM {AliyunLogStore.workflow_execution_logstore}
WHERE tenant_id='{tenant_id}'
AND app_id='{app_id}'
AND triggered_from='{triggered_from}'
WHERE tenant_id='{escaped_tenant_id}'
AND app_id='{escaped_app_id}'
AND triggered_from='{escaped_triggered_from}'
AND finished_at IS NOT NULL
{time_filter}
GROUP BY date
@ -655,7 +705,13 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
app_id,
triggered_from,
)
# Build time range filter
# Escape parameters to prevent SQL injection
escaped_tenant_id = escape_identifier(tenant_id)
escaped_app_id = escape_identifier(app_id)
escaped_triggered_from = escape_sql_string(triggered_from)
# Build time range filter (datetime.isoformat() is safe)
time_filter = ""
if start_date:
time_filter += f" AND __time__ >= to_unixtime(from_iso8601_timestamp('{start_date.isoformat()}'))"
@ -665,9 +721,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
sql = f"""
SELECT DATE(from_unixtime(__time__)) as date, SUM(total_tokens) as token_count
FROM {AliyunLogStore.workflow_execution_logstore}
WHERE tenant_id='{tenant_id}'
AND app_id='{app_id}'
AND triggered_from='{triggered_from}'
WHERE tenant_id='{escaped_tenant_id}'
AND app_id='{escaped_app_id}'
AND triggered_from='{escaped_triggered_from}'
AND finished_at IS NOT NULL
{time_filter}
GROUP BY date
@ -709,7 +765,13 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
app_id,
triggered_from,
)
# Build time range filter
# Escape parameters to prevent SQL injection
escaped_tenant_id = escape_identifier(tenant_id)
escaped_app_id = escape_identifier(app_id)
escaped_triggered_from = escape_sql_string(triggered_from)
# Build time range filter (datetime.isoformat() is safe)
time_filter = ""
if start_date:
time_filter += f" AND __time__ >= to_unixtime(from_iso8601_timestamp('{start_date.isoformat()}'))"
@ -726,9 +788,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
created_by,
COUNT(DISTINCT id) AS interactions
FROM {AliyunLogStore.workflow_execution_logstore}
WHERE tenant_id='{tenant_id}'
AND app_id='{app_id}'
AND triggered_from='{triggered_from}'
WHERE tenant_id='{escaped_tenant_id}'
AND app_id='{escaped_app_id}'
AND triggered_from='{escaped_triggered_from}'
AND finished_at IS NOT NULL
{time_filter}
GROUP BY date, created_by

View File

@ -10,6 +10,7 @@ from sqlalchemy.orm import sessionmaker
from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
from core.workflow.entities import WorkflowExecution
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from extensions.logstore.aliyun_logstore import AliyunLogStore
from libs.helper import extract_tenant_id
from models import (
@ -22,18 +23,6 @@ from models.enums import WorkflowRunTriggeredFrom
logger = logging.getLogger(__name__)
def to_serializable(obj):
"""
Convert non-JSON-serializable objects into JSON-compatible formats.
- Uses `to_dict()` if it's a callable method.
- Falls back to string representation.
"""
if hasattr(obj, "to_dict") and callable(obj.to_dict):
return obj.to_dict()
return str(obj)
class LogstoreWorkflowExecutionRepository(WorkflowExecutionRepository):
def __init__(
self,
@ -79,7 +68,7 @@ class LogstoreWorkflowExecutionRepository(WorkflowExecutionRepository):
# Control flag for dual-write (write to both LogStore and SQL database)
# Set to True to enable dual-write for safe migration, False to use LogStore only
self._enable_dual_write = os.environ.get("LOGSTORE_DUAL_WRITE_ENABLED", "true").lower() == "true"
self._enable_dual_write = os.environ.get("LOGSTORE_DUAL_WRITE_ENABLED", "false").lower() == "true"
# Control flag for whether to write the `graph` field to LogStore.
# If LOGSTORE_ENABLE_PUT_GRAPH_FIELD is "true", write the full `graph` field;
@ -113,6 +102,9 @@ class LogstoreWorkflowExecutionRepository(WorkflowExecutionRepository):
# Generate log_version as nanosecond timestamp for record versioning
log_version = str(time.time_ns())
# Use WorkflowRuntimeTypeConverter to handle complex types (Segment, File, etc.)
json_converter = WorkflowRuntimeTypeConverter()
logstore_model = [
("id", domain_model.id_),
("log_version", log_version), # Add log_version field for append-only writes
@ -127,19 +119,19 @@ class LogstoreWorkflowExecutionRepository(WorkflowExecutionRepository):
("version", domain_model.workflow_version),
(
"graph",
json.dumps(domain_model.graph, ensure_ascii=False, default=to_serializable)
json.dumps(json_converter.to_json_encodable(domain_model.graph), ensure_ascii=False)
if domain_model.graph and self._enable_put_graph_field
else "{}",
),
(
"inputs",
json.dumps(domain_model.inputs, ensure_ascii=False, default=to_serializable)
json.dumps(json_converter.to_json_encodable(domain_model.inputs), ensure_ascii=False)
if domain_model.inputs
else "{}",
),
(
"outputs",
json.dumps(domain_model.outputs, ensure_ascii=False, default=to_serializable)
json.dumps(json_converter.to_json_encodable(domain_model.outputs), ensure_ascii=False)
if domain_model.outputs
else "{}",
),

View File

@ -24,6 +24,8 @@ from core.workflow.enums import NodeType
from core.workflow.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from extensions.logstore.aliyun_logstore import AliyunLogStore
from extensions.logstore.repositories import safe_float, safe_int
from extensions.logstore.sql_escape import escape_identifier
from libs.helper import extract_tenant_id
from models import (
Account,
@ -73,7 +75,7 @@ def _dict_to_workflow_node_execution(data: dict[str, Any]) -> WorkflowNodeExecut
node_execution_id=data.get("node_execution_id"),
workflow_id=data.get("workflow_id", ""),
workflow_execution_id=data.get("workflow_run_id"),
index=int(data.get("index", 0)),
index=safe_int(data.get("index", 0)),
predecessor_node_id=data.get("predecessor_node_id"),
node_id=data.get("node_id", ""),
node_type=NodeType(data.get("node_type", "start")),
@ -83,7 +85,7 @@ def _dict_to_workflow_node_execution(data: dict[str, Any]) -> WorkflowNodeExecut
outputs=outputs,
status=status,
error=data.get("error"),
elapsed_time=float(data.get("elapsed_time", 0.0)),
elapsed_time=safe_float(data.get("elapsed_time", 0.0)),
metadata=domain_metadata,
created_at=created_at,
finished_at=finished_at,
@ -147,7 +149,7 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
# Control flag for dual-write (write to both LogStore and SQL database)
# Set to True to enable dual-write for safe migration, False to use LogStore only
self._enable_dual_write = os.environ.get("LOGSTORE_DUAL_WRITE_ENABLED", "true").lower() == "true"
self._enable_dual_write = os.environ.get("LOGSTORE_DUAL_WRITE_ENABLED", "false").lower() == "true"
def _to_logstore_model(self, domain_model: WorkflowNodeExecution) -> Sequence[tuple[str, str]]:
logger.debug(
@ -274,16 +276,34 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
Save or update the inputs, process_data, or outputs associated with a specific
node_execution record.
For LogStore implementation, this is similar to save() since we always write
complete records. We append a new record with updated data fields.
For LogStore implementation, this is a no-op for the LogStore write because save()
already writes all fields including inputs, process_data, and outputs. The caller
typically calls save() first to persist status/metadata, then calls save_execution_data()
to persist data fields. Since LogStore writes complete records atomically, we don't
need a separate write here to avoid duplicate records.
However, if dual-write is enabled, we still need to call the SQL repository's
save_execution_data() method to properly update the SQL database.
Args:
execution: The NodeExecution instance with data to save
"""
logger.debug("save_execution_data: id=%s, node_execution_id=%s", execution.id, execution.node_execution_id)
# In LogStore, we simply write a new complete record with the data
# The log_version timestamp will ensure this is treated as the latest version
self.save(execution)
logger.debug(
"save_execution_data: no-op for LogStore (data already saved by save()): id=%s, node_execution_id=%s",
execution.id,
execution.node_execution_id,
)
# No-op for LogStore: save() already writes all fields including inputs, process_data, and outputs
# Calling save() again would create a duplicate record in the append-only LogStore
# Dual-write to SQL database if enabled (for safe migration)
if self._enable_dual_write:
try:
self.sql_repository.save_execution_data(execution)
logger.debug("Dual-write: saved node execution data to SQL database: id=%s", execution.id)
except Exception:
logger.exception("Failed to dual-write node execution data to SQL database: id=%s", execution.id)
# Don't raise - LogStore write succeeded, SQL is just a backup
def get_by_workflow_run(
self,
@ -292,8 +312,8 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
) -> Sequence[WorkflowNodeExecution]:
"""
Retrieve all NodeExecution instances for a specific workflow run.
Uses LogStore SQL query with finished_at IS NOT NULL filter for deduplication.
This ensures we only get the final version of each node execution.
Uses LogStore SQL query with window function to get the latest version of each node execution.
This ensures we only get the most recent version of each node execution record.
Args:
workflow_run_id: The workflow run ID
order_config: Optional configuration for ordering results
@ -304,16 +324,19 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
A list of NodeExecution instances
Note:
This method filters by finished_at IS NOT NULL to avoid duplicates from
version updates. For complete history including intermediate states,
a different query strategy would be needed.
This method uses ROW_NUMBER() window function partitioned by node_execution_id
to get the latest version (highest log_version) of each node execution.
"""
logger.debug("get_by_workflow_run: workflow_run_id=%s, order_config=%s", workflow_run_id, order_config)
# Build SQL query with deduplication using finished_at IS NOT NULL
# This optimization avoids window functions for common case where we only
# want the final state of each node execution
# Build SQL query with deduplication using window function
# ROW_NUMBER() OVER (PARTITION BY node_execution_id ORDER BY log_version DESC)
# ensures we get the latest version of each node execution
# Build ORDER BY clause
# Escape parameters to prevent SQL injection
escaped_workflow_run_id = escape_identifier(workflow_run_id)
escaped_tenant_id = escape_identifier(self._tenant_id)
# Build ORDER BY clause for outer query
order_clause = ""
if order_config and order_config.order_by:
order_fields = []
@ -327,16 +350,23 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
if order_fields:
order_clause = "ORDER BY " + ", ".join(order_fields)
sql = f"""
SELECT *
FROM {AliyunLogStore.workflow_node_execution_logstore}
WHERE workflow_run_id='{workflow_run_id}'
AND tenant_id='{self._tenant_id}'
AND finished_at IS NOT NULL
"""
# Build app_id filter for subquery
app_id_filter = ""
if self._app_id:
sql += f" AND app_id='{self._app_id}'"
escaped_app_id = escape_identifier(self._app_id)
app_id_filter = f" AND app_id='{escaped_app_id}'"
# Use window function to get latest version of each node execution
sql = f"""
SELECT * FROM (
SELECT *, ROW_NUMBER() OVER (PARTITION BY node_execution_id ORDER BY log_version DESC) AS rn
FROM {AliyunLogStore.workflow_node_execution_logstore}
WHERE workflow_run_id='{escaped_workflow_run_id}'
AND tenant_id='{escaped_tenant_id}'
{app_id_filter}
) t
WHERE rn = 1
"""
if order_clause:
sql += f" {order_clause}"

View File

@ -0,0 +1,134 @@
"""
SQL Escape Utility for LogStore Queries
This module provides escaping utilities to prevent injection attacks in LogStore queries.
LogStore supports two query modes:
1. PG Protocol Mode: Uses SQL syntax with single quotes for strings
2. SDK Mode: Uses LogStore query syntax (key: value) with double quotes
Key Security Concerns:
- Prevent tenant A from accessing tenant B's data via injection
- SLS queries are read-only, so we focus on data access control
- Different escaping strategies for SQL vs LogStore query syntax
"""
def escape_sql_string(value: str) -> str:
"""
Escape a string value for safe use in SQL queries.
This function escapes single quotes by doubling them, which is the standard
SQL escaping method. This prevents SQL injection by ensuring that user input
cannot break out of string literals.
Args:
value: The string value to escape
Returns:
Escaped string safe for use in SQL queries
Examples:
>>> escape_sql_string("normal_value")
"normal_value"
>>> escape_sql_string("value' OR '1'='1")
"value'' OR ''1''=''1"
>>> escape_sql_string("tenant's_id")
"tenant''s_id"
Security:
- Prevents breaking out of string literals
- Stops injection attacks like: ' OR '1'='1
- Protects against cross-tenant data access
"""
if not value:
return value
# Escape single quotes by doubling them (standard SQL escaping)
# This prevents breaking out of string literals in SQL queries
return value.replace("'", "''")
def escape_identifier(value: str) -> str:
"""
Escape an identifier (tenant_id, app_id, run_id, etc.) for safe SQL use.
This function is for PG protocol mode (SQL syntax).
For SDK mode, use escape_logstore_query_value() instead.
Args:
value: The identifier value to escape
Returns:
Escaped identifier safe for use in SQL queries
Examples:
>>> escape_identifier("550e8400-e29b-41d4-a716-446655440000")
"550e8400-e29b-41d4-a716-446655440000"
>>> escape_identifier("tenant_id' OR '1'='1")
"tenant_id'' OR ''1''=''1"
Security:
- Prevents SQL injection via identifiers
- Stops cross-tenant access attempts
- Works for UUIDs, alphanumeric IDs, and similar identifiers
"""
# For identifiers, use the same escaping as strings
# This is simple and effective for preventing injection
return escape_sql_string(value)
def escape_logstore_query_value(value: str) -> str:
"""
Escape value for LogStore query syntax (SDK mode).
LogStore query syntax rules:
1. Keywords (and/or/not) are case-insensitive
2. Single quotes are ordinary characters (no special meaning)
3. Double quotes wrap values: key:"value"
4. Backslash is the escape character:
- \" for double quote inside value
- \\ for backslash itself
5. Parentheses can change query structure
To prevent injection:
- Wrap value in double quotes to treat special chars as literals
- Escape backslashes and double quotes using backslash
Args:
value: The value to escape for LogStore query syntax
Returns:
Quoted and escaped value safe for LogStore query syntax (includes the quotes)
Examples:
>>> escape_logstore_query_value("normal_value")
'"normal_value"'
>>> escape_logstore_query_value("value or field:evil")
'"value or field:evil"' # 'or' and ':' are now literals
>>> escape_logstore_query_value('value"test')
'"value\\"test"' # Internal double quote escaped
>>> escape_logstore_query_value('value\\test')
'"value\\\\test"' # Backslash escaped
Security:
- Prevents injection via and/or/not keywords
- Prevents injection via colons (:)
- Prevents injection via parentheses
- Protects against cross-tenant data access
Note:
Escape order is critical: backslash first, then double quotes.
Otherwise, we'd double-escape the escape character itself.
"""
if not value:
return '""'
# IMPORTANT: Escape backslashes FIRST, then double quotes
# This prevents double-escaping (e.g., " -> \" -> \\" incorrectly)
escaped = value.replace("\\", "\\\\") # \ -> \\
escaped = escaped.replace('"', '\\"') # " -> \"
# Wrap in double quotes to treat as literal string
# This prevents and/or/not/:/() from being interpreted as operators
return f'"{escaped}"'

View File

@ -38,7 +38,7 @@ from core.variables.variables import (
ObjectVariable,
SecretVariable,
StringVariable,
Variable,
VariableBase,
)
from core.workflow.constants import (
CONVERSATION_VARIABLE_NODE_ID,
@ -72,25 +72,25 @@ SEGMENT_TO_VARIABLE_MAP = {
}
def build_conversation_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
def build_conversation_variable_from_mapping(mapping: Mapping[str, Any], /) -> VariableBase:
if not mapping.get("name"):
raise VariableError("missing name")
return _build_variable_from_mapping(mapping=mapping, selector=[CONVERSATION_VARIABLE_NODE_ID, mapping["name"]])
def build_environment_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
def build_environment_variable_from_mapping(mapping: Mapping[str, Any], /) -> VariableBase:
if not mapping.get("name"):
raise VariableError("missing name")
return _build_variable_from_mapping(mapping=mapping, selector=[ENVIRONMENT_VARIABLE_NODE_ID, mapping["name"]])
def build_pipeline_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
def build_pipeline_variable_from_mapping(mapping: Mapping[str, Any], /) -> VariableBase:
if not mapping.get("variable"):
raise VariableError("missing variable")
return mapping["variable"]
def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequence[str]) -> Variable:
def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequence[str]) -> VariableBase:
"""
This factory function is used to create the environment variable or the conversation variable,
not support the File type.
@ -100,7 +100,7 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen
if (value := mapping.get("value")) is None:
raise VariableError("missing value")
result: Variable
result: VariableBase
match value_type:
case SegmentType.STRING:
result = StringVariable.model_validate(mapping)
@ -134,7 +134,7 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen
raise VariableError(f"variable size {result.size} exceeds limit {dify_config.MAX_VARIABLE_SIZE}")
if not result.selector:
result = result.model_copy(update={"selector": selector})
return cast(Variable, result)
return cast(VariableBase, result)
def build_segment(value: Any, /) -> Segment:
@ -285,8 +285,8 @@ def segment_to_variable(
id: str | None = None,
name: str | None = None,
description: str = "",
) -> Variable:
if isinstance(segment, Variable):
) -> VariableBase:
if isinstance(segment, VariableBase):
return segment
name = name or selector[-1]
id = id or str(uuid4())
@ -297,7 +297,7 @@ def segment_to_variable(
variable_class = SEGMENT_TO_VARIABLE_MAP[segment_type]
return cast(
Variable,
VariableBase,
variable_class(
id=id,
name=name,

View File

@ -1,7 +1,7 @@
from flask_restx import fields
from core.helper import encrypter
from core.variables import SecretVariable, SegmentType, Variable
from core.variables import SecretVariable, SegmentType, VariableBase
from fields.member_fields import simple_account_fields
from libs.helper import TimestampField
@ -21,7 +21,7 @@ class EnvironmentVariableField(fields.Raw):
"value_type": value.value_type.value,
"description": value.description,
}
if isinstance(value, Variable):
if isinstance(value, VariableBase):
return {
"id": value.id,
"name": value.name,

View File

@ -1,11 +1,9 @@
from __future__ import annotations
import json
import logging
from collections.abc import Generator, Mapping, Sequence
from datetime import datetime
from enum import StrEnum
from typing import TYPE_CHECKING, Any, Union, cast
from typing import TYPE_CHECKING, Any, Optional, Union, cast
from uuid import uuid4
import sqlalchemy as sa
@ -46,7 +44,7 @@ if TYPE_CHECKING:
from constants import DEFAULT_FILE_NUMBER_LIMITS, HIDDEN_VALUE
from core.helper import encrypter
from core.variables import SecretVariable, Segment, SegmentType, Variable
from core.variables import SecretVariable, Segment, SegmentType, VariableBase
from factories import variable_factory
from libs import helper
@ -69,7 +67,7 @@ class WorkflowType(StrEnum):
RAG_PIPELINE = "rag-pipeline"
@classmethod
def value_of(cls, value: str) -> WorkflowType:
def value_of(cls, value: str) -> "WorkflowType":
"""
Get value of given mode.
@ -82,7 +80,7 @@ class WorkflowType(StrEnum):
raise ValueError(f"invalid workflow type value {value}")
@classmethod
def from_app_mode(cls, app_mode: Union[str, AppMode]) -> WorkflowType:
def from_app_mode(cls, app_mode: Union[str, "AppMode"]) -> "WorkflowType":
"""
Get workflow type from app mode.
@ -178,12 +176,12 @@ class Workflow(Base): # bug
graph: str,
features: str,
created_by: str,
environment_variables: Sequence[Variable],
conversation_variables: Sequence[Variable],
environment_variables: Sequence[VariableBase],
conversation_variables: Sequence[VariableBase],
rag_pipeline_variables: list[dict],
marked_name: str = "",
marked_comment: str = "",
) -> Workflow:
) -> "Workflow":
workflow = Workflow()
workflow.id = str(uuid4())
workflow.tenant_id = tenant_id
@ -447,7 +445,7 @@ class Workflow(Base): # bug
# decrypt secret variables value
def decrypt_func(
var: Variable,
var: VariableBase,
) -> StringVariable | IntegerVariable | FloatVariable | SecretVariable:
if isinstance(var, SecretVariable):
return var.model_copy(update={"value": encrypter.decrypt_token(tenant_id=tenant_id, token=var.value)})
@ -463,7 +461,7 @@ class Workflow(Base): # bug
return decrypted_results
@environment_variables.setter
def environment_variables(self, value: Sequence[Variable]):
def environment_variables(self, value: Sequence[VariableBase]):
if not value:
self._environment_variables = "{}"
return
@ -487,7 +485,7 @@ class Workflow(Base): # bug
value[i] = origin_variables_dictionary[variable.id].model_copy(update={"name": variable.name})
# encrypt secret variables value
def encrypt_func(var: Variable) -> Variable:
def encrypt_func(var: VariableBase) -> VariableBase:
if isinstance(var, SecretVariable):
return var.model_copy(update={"value": encrypter.encrypt_token(tenant_id=tenant_id, token=var.value)})
else:
@ -517,7 +515,7 @@ class Workflow(Base): # bug
return result
@property
def conversation_variables(self) -> Sequence[Variable]:
def conversation_variables(self) -> Sequence[VariableBase]:
# TODO: find some way to init `self._conversation_variables` when instance created.
if self._conversation_variables is None:
self._conversation_variables = "{}"
@ -527,7 +525,7 @@ class Workflow(Base): # bug
return results
@conversation_variables.setter
def conversation_variables(self, value: Sequence[Variable]):
def conversation_variables(self, value: Sequence[VariableBase]):
self._conversation_variables = json.dumps(
{var.name: var.model_dump() for var in value},
ensure_ascii=False,
@ -622,7 +620,7 @@ class WorkflowRun(Base):
finished_at: Mapped[datetime | None] = mapped_column(DateTime)
exceptions_count: Mapped[int] = mapped_column(sa.Integer, server_default=sa.text("0"), nullable=True)
pause: Mapped[WorkflowPause | None] = orm.relationship(
pause: Mapped[Optional["WorkflowPause"]] = orm.relationship(
"WorkflowPause",
primaryjoin="WorkflowRun.id == foreign(WorkflowPause.workflow_run_id)",
uselist=False,
@ -692,7 +690,7 @@ class WorkflowRun(Base):
}
@classmethod
def from_dict(cls, data: dict[str, Any]) -> WorkflowRun:
def from_dict(cls, data: dict[str, Any]) -> "WorkflowRun":
return cls(
id=data.get("id"),
tenant_id=data.get("tenant_id"),
@ -844,7 +842,7 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
created_by: Mapped[str] = mapped_column(StringUUID)
finished_at: Mapped[datetime | None] = mapped_column(DateTime)
offload_data: Mapped[list[WorkflowNodeExecutionOffload]] = orm.relationship(
offload_data: Mapped[list["WorkflowNodeExecutionOffload"]] = orm.relationship(
"WorkflowNodeExecutionOffload",
primaryjoin="WorkflowNodeExecutionModel.id == foreign(WorkflowNodeExecutionOffload.node_execution_id)",
uselist=True,
@ -854,13 +852,13 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
@staticmethod
def preload_offload_data(
query: Select[tuple[WorkflowNodeExecutionModel]] | orm.Query[WorkflowNodeExecutionModel],
query: Select[tuple["WorkflowNodeExecutionModel"]] | orm.Query["WorkflowNodeExecutionModel"],
):
return query.options(orm.selectinload(WorkflowNodeExecutionModel.offload_data))
@staticmethod
def preload_offload_data_and_files(
query: Select[tuple[WorkflowNodeExecutionModel]] | orm.Query[WorkflowNodeExecutionModel],
query: Select[tuple["WorkflowNodeExecutionModel"]] | orm.Query["WorkflowNodeExecutionModel"],
):
return query.options(
orm.selectinload(WorkflowNodeExecutionModel.offload_data).options(
@ -935,7 +933,7 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
)
return extras
def _get_offload_by_type(self, type_: ExecutionOffLoadType) -> WorkflowNodeExecutionOffload | None:
def _get_offload_by_type(self, type_: ExecutionOffLoadType) -> Optional["WorkflowNodeExecutionOffload"]:
return next(iter([i for i in self.offload_data if i.type_ == type_]), None)
@property
@ -1049,7 +1047,7 @@ class WorkflowNodeExecutionOffload(Base):
back_populates="offload_data",
)
file: Mapped[UploadFile | None] = orm.relationship(
file: Mapped[Optional["UploadFile"]] = orm.relationship(
foreign_keys=[file_id],
lazy="raise",
uselist=False,
@ -1067,7 +1065,7 @@ class WorkflowAppLogCreatedFrom(StrEnum):
INSTALLED_APP = "installed-app"
@classmethod
def value_of(cls, value: str) -> WorkflowAppLogCreatedFrom:
def value_of(cls, value: str) -> "WorkflowAppLogCreatedFrom":
"""
Get value of given mode.
@ -1184,7 +1182,7 @@ class ConversationVariable(TypeBase):
)
@classmethod
def from_variable(cls, *, app_id: str, conversation_id: str, variable: Variable) -> ConversationVariable:
def from_variable(cls, *, app_id: str, conversation_id: str, variable: VariableBase) -> "ConversationVariable":
obj = cls(
id=variable.id,
app_id=app_id,
@ -1193,7 +1191,7 @@ class ConversationVariable(TypeBase):
)
return obj
def to_variable(self) -> Variable:
def to_variable(self) -> VariableBase:
mapping = json.loads(self.data)
return variable_factory.build_conversation_variable_from_mapping(mapping)
@ -1337,7 +1335,7 @@ class WorkflowDraftVariable(Base):
)
# Relationship to WorkflowDraftVariableFile
variable_file: Mapped[WorkflowDraftVariableFile | None] = orm.relationship(
variable_file: Mapped[Optional["WorkflowDraftVariableFile"]] = orm.relationship(
foreign_keys=[file_id],
lazy="raise",
uselist=False,
@ -1507,7 +1505,7 @@ class WorkflowDraftVariable(Base):
node_execution_id: str | None,
description: str = "",
file_id: str | None = None,
) -> WorkflowDraftVariable:
) -> "WorkflowDraftVariable":
variable = WorkflowDraftVariable()
variable.id = str(uuid4())
variable.created_at = naive_utc_now()
@ -1530,7 +1528,7 @@ class WorkflowDraftVariable(Base):
name: str,
value: Segment,
description: str = "",
) -> WorkflowDraftVariable:
) -> "WorkflowDraftVariable":
variable = cls._new(
app_id=app_id,
node_id=CONVERSATION_VARIABLE_NODE_ID,
@ -1551,7 +1549,7 @@ class WorkflowDraftVariable(Base):
value: Segment,
node_execution_id: str,
editable: bool = False,
) -> WorkflowDraftVariable:
) -> "WorkflowDraftVariable":
variable = cls._new(
app_id=app_id,
node_id=SYSTEM_VARIABLE_NODE_ID,
@ -1574,7 +1572,7 @@ class WorkflowDraftVariable(Base):
visible: bool = True,
editable: bool = True,
file_id: str | None = None,
) -> WorkflowDraftVariable:
) -> "WorkflowDraftVariable":
variable = cls._new(
app_id=app_id,
node_id=node_id,
@ -1670,7 +1668,7 @@ class WorkflowDraftVariableFile(Base):
)
# Relationship to UploadFile
upload_file: Mapped[UploadFile] = orm.relationship(
upload_file: Mapped["UploadFile"] = orm.relationship(
foreign_keys=[upload_file_id],
lazy="raise",
uselist=False,
@ -1737,7 +1735,7 @@ class WorkflowPause(DefaultFieldsMixin, Base):
state_object_key: Mapped[str] = mapped_column(String(length=255), nullable=False)
# Relationship to WorkflowRun
workflow_run: Mapped[WorkflowRun] = orm.relationship(
workflow_run: Mapped["WorkflowRun"] = orm.relationship(
foreign_keys=[workflow_run_id],
# require explicit preloading.
lazy="raise",
@ -1793,7 +1791,7 @@ class WorkflowPauseReason(DefaultFieldsMixin, Base):
)
@classmethod
def from_entity(cls, pause_reason: PauseReason) -> WorkflowPauseReason:
def from_entity(cls, pause_reason: PauseReason) -> "WorkflowPauseReason":
if isinstance(pause_reason, HumanInputRequired):
return cls(
type_=PauseReasonType.HUMAN_INPUT_REQUIRED, form_id=pause_reason.form_id, node_id=pause_reason.node_id

View File

@ -1,6 +1,6 @@
[project]
name = "dify-api"
version = "1.11.2"
version = "1.11.3"
requires-python = ">=3.11,<3.13"
dependencies = [

View File

@ -1,7 +1,7 @@
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
from core.variables.variables import Variable
from core.variables.variables import VariableBase
from models import ConversationVariable
@ -13,7 +13,7 @@ class ConversationVariableUpdater:
def __init__(self, session_maker: sessionmaker[Session]) -> None:
self._session_maker: sessionmaker[Session] = session_maker
def update(self, conversation_id: str, variable: Variable) -> None:
def update(self, conversation_id: str, variable: VariableBase) -> None:
stmt = select(ConversationVariable).where(
ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id
)

View File

@ -1,9 +1,14 @@
import logging
import os
from collections.abc import Mapping
from typing import Any
import httpx
from core.helper.trace_id_helper import generate_traceparent_header
logger = logging.getLogger(__name__)
class BaseRequest:
proxies: Mapping[str, str] | None = {
@ -38,6 +43,15 @@ class BaseRequest:
headers = {"Content-Type": "application/json", cls.secret_key_header: cls.secret_key}
url = f"{cls.base_url}{endpoint}"
mounts = cls._build_mounts()
try:
# ensure traceparent even when OTEL is disabled
traceparent = generate_traceparent_header()
if traceparent:
headers["traceparent"] = traceparent
except Exception:
logger.debug("Failed to generate traceparent header", exc_info=True)
with httpx.Client(mounts=mounts) as client:
response = client.request(method, url, json=json, params=params, headers=headers)
return response.json()

View File

@ -3,6 +3,7 @@ from collections.abc import Mapping, Sequence
from mimetypes import guess_type
from pydantic import BaseModel
from sqlalchemy import select
from yarl import URL
from configs import dify_config
@ -25,7 +26,9 @@ from core.plugin.entities.plugin_daemon import (
from core.plugin.impl.asset import PluginAssetManager
from core.plugin.impl.debugging import PluginDebuggingClient
from core.plugin.impl.plugin import PluginInstaller
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.provider import ProviderCredential
from models.provider_ids import GenericProviderID
from services.errors.plugin import PluginInstallationForbiddenError
from services.feature_service import FeatureService, PluginInstallationScope
@ -506,6 +509,33 @@ class PluginService:
@staticmethod
def uninstall(tenant_id: str, plugin_installation_id: str) -> bool:
manager = PluginInstaller()
# Get plugin info before uninstalling to delete associated credentials
try:
plugins = manager.list_plugins(tenant_id)
plugin = next((p for p in plugins if p.installation_id == plugin_installation_id), None)
if plugin:
plugin_id = plugin.plugin_id
logger.info("Deleting credentials for plugin: %s", plugin_id)
# Delete provider credentials that match this plugin
credentials = db.session.scalars(
select(ProviderCredential).where(
ProviderCredential.tenant_id == tenant_id,
ProviderCredential.provider_name.like(f"{plugin_id}/%"),
)
).all()
for cred in credentials:
db.session.delete(cred)
db.session.commit()
logger.info("Deleted %d credentials for plugin: %s", len(credentials), plugin_id)
except Exception as e:
logger.warning("Failed to delete credentials: %s", e)
# Continue with uninstall even if credential deletion fails
return manager.uninstall(tenant_id, plugin_installation_id)
@staticmethod

View File

@ -36,7 +36,7 @@ from core.rag.entities.event import (
)
from core.repositories.factory import DifyCoreRepositoryFactory
from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
from core.variables.variables import Variable
from core.variables.variables import VariableBase
from core.workflow.entities.workflow_node_execution import (
WorkflowNodeExecution,
WorkflowNodeExecutionStatus,
@ -270,8 +270,8 @@ class RagPipelineService:
graph: dict,
unique_hash: str | None,
account: Account,
environment_variables: Sequence[Variable],
conversation_variables: Sequence[Variable],
environment_variables: Sequence[VariableBase],
conversation_variables: Sequence[VariableBase],
rag_pipeline_variables: list,
) -> Workflow:
"""

View File

@ -15,7 +15,7 @@ from sqlalchemy.sql.expression import and_, or_
from configs import dify_config
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file.models import File
from core.variables import Segment, StringSegment, Variable
from core.variables import Segment, StringSegment, VariableBase
from core.variables.consts import SELECTORS_LENGTH
from core.variables.segments import (
ArrayFileSegment,
@ -77,14 +77,14 @@ class DraftVarLoader(VariableLoader):
# Application ID for which variables are being loaded.
_app_id: str
_tenant_id: str
_fallback_variables: Sequence[Variable]
_fallback_variables: Sequence[VariableBase]
def __init__(
self,
engine: Engine,
app_id: str,
tenant_id: str,
fallback_variables: Sequence[Variable] | None = None,
fallback_variables: Sequence[VariableBase] | None = None,
):
self._engine = engine
self._app_id = app_id
@ -94,12 +94,12 @@ class DraftVarLoader(VariableLoader):
def _selector_to_tuple(self, selector: Sequence[str]) -> tuple[str, str]:
return (selector[0], selector[1])
def load_variables(self, selectors: list[list[str]]) -> list[Variable]:
def load_variables(self, selectors: list[list[str]]) -> list[VariableBase]:
if not selectors:
return []
# Map each selector (as a tuple via `_selector_to_tuple`) to its corresponding Variable instance.
variable_by_selector: dict[tuple[str, str], Variable] = {}
# Map each selector (as a tuple via `_selector_to_tuple`) to its corresponding variable instance.
variable_by_selector: dict[tuple[str, str], VariableBase] = {}
with Session(bind=self._engine, expire_on_commit=False) as session:
srv = WorkflowDraftVariableService(session)
@ -145,7 +145,7 @@ class DraftVarLoader(VariableLoader):
return list(variable_by_selector.values())
def _load_offloaded_variable(self, draft_var: WorkflowDraftVariable) -> tuple[tuple[str, str], Variable]:
def _load_offloaded_variable(self, draft_var: WorkflowDraftVariable) -> tuple[tuple[str, str], VariableBase]:
# This logic is closely tied to `WorkflowDraftVaribleService._try_offload_large_variable`
# and must remain synchronized with it.
# Ideally, these should be co-located for better maintainability.

View File

@ -13,8 +13,8 @@ from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.file import File
from core.repositories import DifyCoreRepositoryFactory
from core.variables import Variable
from core.variables.variables import VariableUnion
from core.variables import VariableBase
from core.variables.variables import Variable
from core.workflow.entities import WorkflowNodeExecution
from core.workflow.enums import ErrorStrategy, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.errors import WorkflowNodeRunFailedError
@ -198,8 +198,8 @@ class WorkflowService:
features: dict,
unique_hash: str | None,
account: Account,
environment_variables: Sequence[Variable],
conversation_variables: Sequence[Variable],
environment_variables: Sequence[VariableBase],
conversation_variables: Sequence[VariableBase],
) -> Workflow:
"""
Sync draft workflow
@ -1044,7 +1044,7 @@ def _setup_variable_pool(
workflow: Workflow,
node_type: NodeType,
conversation_id: str,
conversation_variables: list[Variable],
conversation_variables: list[VariableBase],
):
# Only inject system variables for START node type.
if node_type == NodeType.START or node_type.is_trigger_node:
@ -1070,9 +1070,9 @@ def _setup_variable_pool(
system_variables=system_variable,
user_inputs=user_inputs,
environment_variables=workflow.environment_variables,
# Based on the definition of `VariableUnion`,
# `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
conversation_variables=cast(list[VariableUnion], conversation_variables), #
# Based on the definition of `Variable`,
# `VariableBase` instances can be safely used as `Variable` since they are compatible.
conversation_variables=cast(list[Variable], conversation_variables), #
)
return variable_pool

View File

@ -228,11 +228,28 @@ def test_resolve_user_from_database_falls_back_to_end_user(monkeypatch: pytest.M
def scalar(self, _stmt):
return self.results.pop(0)
# SQLAlchemy Session APIs used by code under test
def expunge(self, *_args, **_kwargs):
pass
def close(self):
pass
# support `with session_factory.create_session() as session:`
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
self.close()
tenant = SimpleNamespace(id="tenant_id")
end_user = SimpleNamespace(id="end_user_id", tenant_id="tenant_id")
db_stub = SimpleNamespace(session=StubSession([tenant, None, end_user]))
monkeypatch.setattr("core.tools.workflow_as_tool.tool.db", db_stub)
# Monkeypatch session factory to return our stub session
monkeypatch.setattr(
"core.tools.workflow_as_tool.tool.session_factory.create_session",
lambda: StubSession([tenant, None, end_user]),
)
entity = ToolEntity(
identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"),
@ -266,8 +283,23 @@ def test_resolve_user_from_database_returns_none_when_no_tenant(monkeypatch: pyt
def scalar(self, _stmt):
return self.results.pop(0)
db_stub = SimpleNamespace(session=StubSession([None]))
monkeypatch.setattr("core.tools.workflow_as_tool.tool.db", db_stub)
def expunge(self, *_args, **_kwargs):
pass
def close(self):
pass
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
self.close()
# Monkeypatch session factory to return our stub session with no tenant
monkeypatch.setattr(
"core.tools.workflow_as_tool.tool.session_factory.create_session",
lambda: StubSession([None]),
)
entity = ToolEntity(
identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"),

View File

@ -35,7 +35,6 @@ from core.variables.variables import (
SecretVariable,
StringVariable,
Variable,
VariableUnion,
)
from core.workflow.runtime import VariablePool
from core.workflow.system_variable import SystemVariable
@ -96,7 +95,7 @@ class _Segments(BaseModel):
class _Variables(BaseModel):
variables: list[VariableUnion]
variables: list[Variable]
def create_test_file(
@ -194,7 +193,7 @@ class TestSegmentDumpAndLoad:
# Create one instance of each variable type
test_file = create_test_file()
all_variables: list[VariableUnion] = [
all_variables: list[Variable] = [
NoneVariable(name="none_var"),
StringVariable(value="test string", name="string_var"),
IntegerVariable(value=42, name="int_var"),

View File

@ -11,7 +11,7 @@ from core.variables import (
SegmentType,
StringVariable,
)
from core.variables.variables import Variable
from core.variables.variables import VariableBase
def test_frozen_variables():
@ -76,7 +76,7 @@ def test_object_variable_to_object():
def test_variable_to_object():
var: Variable = StringVariable(name="text", value="text")
var: VariableBase = StringVariable(name="text", value="text")
assert var.to_object() == "text"
var = IntegerVariable(name="integer", value=42)
assert var.to_object() == 42

View File

@ -24,7 +24,7 @@ from core.variables.variables import (
IntegerVariable,
ObjectVariable,
StringVariable,
VariableUnion,
Variable,
)
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from core.workflow.runtime import VariablePool
@ -160,7 +160,7 @@ class TestVariablePoolSerialization:
)
# Create environment variables with all types including ArrayFileVariable
env_vars: list[VariableUnion] = [
env_vars: list[Variable] = [
StringVariable(
id="env_string_id",
name="env_string",
@ -182,7 +182,7 @@ class TestVariablePoolSerialization:
]
# Create conversation variables with complex data
conv_vars: list[VariableUnion] = [
conv_vars: list[Variable] = [
StringVariable(
id="conv_string_id",
name="conv_string",

View File

@ -0,0 +1 @@
"""LogStore extension unit tests."""

View File

@ -0,0 +1,469 @@
"""
Unit tests for SQL escape utility functions.
These tests ensure that SQL injection attacks are properly prevented
in LogStore queries, particularly for cross-tenant access scenarios.
"""
import pytest
from extensions.logstore.sql_escape import escape_identifier, escape_logstore_query_value, escape_sql_string
class TestEscapeSQLString:
"""Test escape_sql_string function."""
def test_escape_empty_string(self):
"""Test escaping empty string."""
assert escape_sql_string("") == ""
def test_escape_normal_string(self):
"""Test escaping string without special characters."""
assert escape_sql_string("tenant_abc123") == "tenant_abc123"
assert escape_sql_string("app-uuid-1234") == "app-uuid-1234"
def test_escape_single_quote(self):
"""Test escaping single quote."""
# Single quote should be doubled
assert escape_sql_string("tenant'id") == "tenant''id"
assert escape_sql_string("O'Reilly") == "O''Reilly"
def test_escape_multiple_quotes(self):
"""Test escaping multiple single quotes."""
assert escape_sql_string("a'b'c") == "a''b''c"
assert escape_sql_string("'''") == "''''''"
# === SQL Injection Attack Scenarios ===
def test_prevent_boolean_injection(self):
"""Test prevention of boolean injection attacks."""
# Classic OR 1=1 attack
malicious_input = "tenant' OR '1'='1"
escaped = escape_sql_string(malicious_input)
assert escaped == "tenant'' OR ''1''=''1"
# When used in SQL, this becomes a safe string literal
sql = f"WHERE tenant_id='{escaped}'"
assert sql == "WHERE tenant_id='tenant'' OR ''1''=''1'"
# The entire input is now a string literal that won't match any tenant
def test_prevent_or_injection(self):
"""Test prevention of OR-based injection."""
malicious_input = "tenant_a' OR tenant_id='tenant_b"
escaped = escape_sql_string(malicious_input)
assert escaped == "tenant_a'' OR tenant_id=''tenant_b"
sql = f"WHERE tenant_id='{escaped}'"
# The OR is now part of the string literal, not SQL logic
assert "OR tenant_id=" in sql
# The SQL has: opening ', doubled internal quotes '', and closing '
assert sql == "WHERE tenant_id='tenant_a'' OR tenant_id=''tenant_b'"
def test_prevent_union_injection(self):
"""Test prevention of UNION-based injection."""
malicious_input = "xxx' UNION SELECT password FROM users WHERE '1'='1"
escaped = escape_sql_string(malicious_input)
assert escaped == "xxx'' UNION SELECT password FROM users WHERE ''1''=''1"
# UNION becomes part of the string literal
assert "UNION" in escaped
assert escaped.count("''") == 4 # All internal quotes are doubled
def test_prevent_comment_injection(self):
"""Test prevention of comment-based injection."""
# SQL comment to bypass remaining conditions
malicious_input = "tenant' --"
escaped = escape_sql_string(malicious_input)
assert escaped == "tenant'' --"
sql = f"WHERE tenant_id='{escaped}' AND deleted=false"
# The -- is now inside the string, not a SQL comment
assert "--" in sql
assert "AND deleted=false" in sql # This part is NOT commented out
def test_prevent_semicolon_injection(self):
"""Test prevention of semicolon-based multi-statement injection."""
malicious_input = "tenant'; DROP TABLE users; --"
escaped = escape_sql_string(malicious_input)
assert escaped == "tenant''; DROP TABLE users; --"
# Semicolons and DROP are now part of the string
assert "DROP TABLE" in escaped
def test_prevent_time_based_blind_injection(self):
"""Test prevention of time-based blind SQL injection."""
malicious_input = "tenant' AND SLEEP(5) --"
escaped = escape_sql_string(malicious_input)
assert escaped == "tenant'' AND SLEEP(5) --"
# SLEEP becomes part of the string
assert "SLEEP" in escaped
def test_prevent_wildcard_injection(self):
"""Test prevention of wildcard-based injection."""
malicious_input = "tenant' OR tenant_id LIKE '%"
escaped = escape_sql_string(malicious_input)
assert escaped == "tenant'' OR tenant_id LIKE ''%"
# The LIKE and wildcard are now part of the string
assert "LIKE" in escaped
def test_prevent_null_byte_injection(self):
"""Test handling of null bytes."""
# Null bytes can sometimes bypass filters
malicious_input = "tenant\x00' OR '1'='1"
escaped = escape_sql_string(malicious_input)
# Null byte is preserved, but quote is escaped
assert "''1''=''1" in escaped
# === Real-world SAAS Scenarios ===
def test_cross_tenant_access_attempt(self):
"""Test prevention of cross-tenant data access."""
# Attacker tries to access another tenant's data
attacker_input = "tenant_b' OR tenant_id='tenant_a"
escaped = escape_sql_string(attacker_input)
sql = f"SELECT * FROM workflow_runs WHERE tenant_id='{escaped}'"
# The query will look for a tenant literally named "tenant_b' OR tenant_id='tenant_a"
# which doesn't exist - preventing access to either tenant's data
assert "tenant_b'' OR tenant_id=''tenant_a" in sql
def test_cross_app_access_attempt(self):
"""Test prevention of cross-application data access."""
attacker_input = "app1' OR app_id='app2"
escaped = escape_sql_string(attacker_input)
sql = f"WHERE app_id='{escaped}'"
# Cannot access app2's data
assert "app1'' OR app_id=''app2" in sql
def test_bypass_status_filter(self):
"""Test prevention of bypassing status filters."""
# Try to see all statuses instead of just 'running'
attacker_input = "running' OR status LIKE '%"
escaped = escape_sql_string(attacker_input)
sql = f"WHERE status='{escaped}'"
# Status condition is not bypassed
assert "running'' OR status LIKE ''%" in sql
# === Edge Cases ===
def test_escape_only_quotes(self):
"""Test string with only quotes."""
assert escape_sql_string("'") == "''"
assert escape_sql_string("''") == "''''"
def test_escape_mixed_content(self):
"""Test string with mixed quotes and other chars."""
input_str = "It's a 'test' of O'Reilly's code"
escaped = escape_sql_string(input_str)
assert escaped == "It''s a ''test'' of O''Reilly''s code"
def test_escape_unicode_with_quotes(self):
"""Test Unicode strings with quotes."""
input_str = "租户' OR '1'='1"
escaped = escape_sql_string(input_str)
assert escaped == "租户'' OR ''1''=''1"
class TestEscapeIdentifier:
"""Test escape_identifier function."""
def test_escape_uuid(self):
"""Test escaping UUID identifiers."""
uuid = "550e8400-e29b-41d4-a716-446655440000"
assert escape_identifier(uuid) == uuid
def test_escape_alphanumeric_id(self):
"""Test escaping alphanumeric identifiers."""
assert escape_identifier("tenant_123") == "tenant_123"
assert escape_identifier("app-abc-123") == "app-abc-123"
def test_escape_identifier_with_quote(self):
"""Test escaping identifier with single quote."""
malicious = "tenant' OR '1'='1"
escaped = escape_identifier(malicious)
assert escaped == "tenant'' OR ''1''=''1"
def test_identifier_injection_attempt(self):
"""Test prevention of injection through identifiers."""
# Common identifier injection patterns
test_cases = [
("id' OR '1'='1", "id'' OR ''1''=''1"),
("id'; DROP TABLE", "id''; DROP TABLE"),
("id' UNION SELECT", "id'' UNION SELECT"),
]
for malicious, expected in test_cases:
assert escape_identifier(malicious) == expected
class TestSQLInjectionIntegration:
"""Integration tests simulating real SQL construction scenarios."""
def test_complete_where_clause_safety(self):
"""Test that a complete WHERE clause is safe from injection."""
# Simulating typical query construction
tenant_id = "tenant' OR '1'='1"
app_id = "app' UNION SELECT"
run_id = "run' --"
escaped_tenant = escape_identifier(tenant_id)
escaped_app = escape_identifier(app_id)
escaped_run = escape_identifier(run_id)
sql = f"""
SELECT * FROM workflow_runs
WHERE tenant_id='{escaped_tenant}'
AND app_id='{escaped_app}'
AND id='{escaped_run}'
"""
# Verify all special characters are escaped
assert "tenant'' OR ''1''=''1" in sql
assert "app'' UNION SELECT" in sql
assert "run'' --" in sql
# Verify SQL structure is preserved (3 conditions with AND)
assert sql.count("AND") == 2
def test_multiple_conditions_with_injection_attempts(self):
"""Test multiple conditions all attempting injection."""
conditions = {
"tenant_id": "t1' OR tenant_id='t2",
"app_id": "a1' OR app_id='a2",
"status": "running' OR '1'='1",
}
where_parts = []
for field, value in conditions.items():
escaped = escape_sql_string(value)
where_parts.append(f"{field}='{escaped}'")
where_clause = " AND ".join(where_parts)
# All injection attempts are neutralized
assert "t1'' OR tenant_id=''t2" in where_clause
assert "a1'' OR app_id=''a2" in where_clause
assert "running'' OR ''1''=''1" in where_clause
# AND structure is preserved
assert where_clause.count(" AND ") == 2
@pytest.mark.parametrize(
("attack_vector", "description"),
[
("' OR '1'='1", "Boolean injection"),
("' OR '1'='1' --", "Boolean with comment"),
("' UNION SELECT * FROM users --", "Union injection"),
("'; DROP TABLE workflow_runs; --", "Destructive command"),
("' AND SLEEP(10) --", "Time-based blind"),
("' OR tenant_id LIKE '%", "Wildcard injection"),
("admin' --", "Comment bypass"),
("' OR 1=1 LIMIT 1 --", "Limit bypass"),
],
)
def test_common_injection_vectors(self, attack_vector, description):
"""Test protection against common injection attack vectors."""
escaped = escape_sql_string(attack_vector)
# Build SQL
sql = f"WHERE tenant_id='{escaped}'"
# Verify the attack string is now a safe literal
# The key indicator: all internal single quotes are doubled
internal_quotes = escaped.count("''")
original_quotes = attack_vector.count("'")
# Each original quote should be doubled
assert internal_quotes == original_quotes
# Verify SQL has exactly 2 quotes (opening and closing)
assert sql.count("'") >= 2 # At least opening and closing
def test_logstore_specific_scenario(self):
"""Test SQL injection prevention in LogStore-specific scenarios."""
# Simulate LogStore query with window function
tenant_id = "tenant' OR '1'='1"
app_id = "app' UNION SELECT"
escaped_tenant = escape_identifier(tenant_id)
escaped_app = escape_identifier(app_id)
sql = f"""
SELECT * FROM (
SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn
FROM workflow_execution_logstore
WHERE tenant_id='{escaped_tenant}'
AND app_id='{escaped_app}'
AND __time__ > 0
) AS subquery WHERE rn = 1
"""
# Complex query structure is maintained
assert "ROW_NUMBER()" in sql
assert "PARTITION BY id" in sql
# Injection attempts are escaped
assert "tenant'' OR ''1''=''1" in sql
assert "app'' UNION SELECT" in sql
# ====================================================================================
# Tests for LogStore Query Syntax (SDK Mode)
# ====================================================================================
class TestLogStoreQueryEscape:
"""Test escape_logstore_query_value for SDK mode query syntax."""
def test_normal_value(self):
"""Test escaping normal alphanumeric value."""
value = "550e8400-e29b-41d4-a716-446655440000"
escaped = escape_logstore_query_value(value)
# Should be wrapped in double quotes
assert escaped == '"550e8400-e29b-41d4-a716-446655440000"'
def test_empty_value(self):
"""Test escaping empty string."""
assert escape_logstore_query_value("") == '""'
def test_value_with_and_keyword(self):
"""Test that 'and' keyword is neutralized when quoted."""
malicious = "value and field:evil"
escaped = escape_logstore_query_value(malicious)
# Should be wrapped in quotes, making 'and' a literal
assert escaped == '"value and field:evil"'
# Simulate using in query
query = f"tenant_id:{escaped}"
assert query == 'tenant_id:"value and field:evil"'
def test_value_with_or_keyword(self):
"""Test that 'or' keyword is neutralized when quoted."""
malicious = "tenant_a or tenant_id:tenant_b"
escaped = escape_logstore_query_value(malicious)
assert escaped == '"tenant_a or tenant_id:tenant_b"'
query = f"tenant_id:{escaped}"
assert "or" in query # Present but as literal string
def test_value_with_not_keyword(self):
"""Test that 'not' keyword is neutralized when quoted."""
malicious = "not field:value"
escaped = escape_logstore_query_value(malicious)
assert escaped == '"not field:value"'
def test_value_with_parentheses(self):
"""Test that parentheses are neutralized when quoted."""
malicious = "(tenant_a or tenant_b)"
escaped = escape_logstore_query_value(malicious)
assert escaped == '"(tenant_a or tenant_b)"'
assert "(" in escaped # Present as literal
assert ")" in escaped # Present as literal
def test_value_with_colon(self):
"""Test that colons are neutralized when quoted."""
malicious = "field:value"
escaped = escape_logstore_query_value(malicious)
assert escaped == '"field:value"'
assert ":" in escaped # Present as literal
def test_value_with_double_quotes(self):
"""Test that internal double quotes are escaped."""
value_with_quotes = 'tenant"test"value'
escaped = escape_logstore_query_value(value_with_quotes)
# Double quotes should be escaped with backslash
assert escaped == '"tenant\\"test\\"value"'
# Should have outer quotes plus escaped inner quotes
assert '\\"' in escaped
def test_value_with_backslash(self):
"""Test that backslashes are escaped."""
value_with_backslash = "tenant\\test"
escaped = escape_logstore_query_value(value_with_backslash)
# Backslash should be escaped
assert escaped == '"tenant\\\\test"'
assert "\\\\" in escaped
def test_value_with_backslash_and_quote(self):
"""Test escaping both backslash and double quote."""
value = 'path\\to\\"file"'
escaped = escape_logstore_query_value(value)
# Both should be escaped
assert escaped == '"path\\\\to\\\\\\"file\\""'
# Verify escape order is correct
assert "\\\\" in escaped # Escaped backslash
assert '\\"' in escaped # Escaped double quote
def test_complex_injection_attempt(self):
"""Test complex injection combining multiple operators."""
malicious = 'tenant_a" or (tenant_id:"tenant_b" and app_id:"evil")'
escaped = escape_logstore_query_value(malicious)
# All special chars should be literals or escaped
assert escaped.startswith('"')
assert escaped.endswith('"')
# Inner double quotes escaped, operators become literals
assert "or" in escaped
assert "and" in escaped
assert '\\"' in escaped # Escaped quotes
def test_only_backslash(self):
"""Test escaping a single backslash."""
assert escape_logstore_query_value("\\") == '"\\\\"'
def test_only_double_quote(self):
"""Test escaping a single double quote."""
assert escape_logstore_query_value('"') == '"\\""'
def test_multiple_backslashes(self):
"""Test escaping multiple consecutive backslashes."""
assert escape_logstore_query_value("\\\\\\") == '"\\\\\\\\\\\\"' # 3 backslashes -> 6
def test_escape_sequence_like_input(self):
"""Test that existing escape sequences are properly escaped."""
# Input looks like already escaped, but we still escape it
value = 'value\\"test'
escaped = escape_logstore_query_value(value)
# \\ -> \\\\, " -> \"
assert escaped == '"value\\\\\\"test"'
@pytest.mark.parametrize(
("attack_scenario", "field", "malicious_value"),
[
("Cross-tenant via OR", "tenant_id", "tenant_a or tenant_id:tenant_b"),
("Cross-app via AND", "app_id", "app_a and (app_id:app_b or app_id:app_c)"),
("Boolean logic", "status", "succeeded or status:failed"),
("Negation", "tenant_id", "not tenant_a"),
("Field injection", "run_id", "run123 and tenant_id:evil_tenant"),
("Parentheses grouping", "app_id", "app1 or (app_id:app2 and tenant_id:tenant2)"),
("Quote breaking attempt", "tenant_id", 'tenant" or "1"="1'),
("Backslash escape bypass", "app_id", "app\\ and app_id:evil"),
],
)
def test_logstore_query_injection_scenarios(attack_scenario: str, field: str, malicious_value: str):
"""Test that various LogStore query injection attempts are neutralized."""
escaped = escape_logstore_query_value(malicious_value)
# Build query
query = f"{field}:{escaped}"
# All operators should be within quoted string (literals)
assert escaped.startswith('"')
assert escaped.endswith('"')
# Verify the full query structure is safe
assert query.count(":") >= 1 # At least the main field:value separator

View File

@ -0,0 +1,59 @@
"""Unit tests for traceparent header propagation in EnterpriseRequest.
This test module verifies that the W3C traceparent header is properly
generated and included in HTTP requests made by EnterpriseRequest.
"""
from unittest.mock import MagicMock, patch
import pytest
from services.enterprise.base import EnterpriseRequest
class TestTraceparentPropagation:
"""Unit tests for traceparent header propagation."""
@pytest.fixture
def mock_enterprise_config(self):
"""Mock EnterpriseRequest configuration."""
with (
patch.object(EnterpriseRequest, "base_url", "https://enterprise-api.example.com"),
patch.object(EnterpriseRequest, "secret_key", "test-secret-key"),
patch.object(EnterpriseRequest, "secret_key_header", "Enterprise-Api-Secret-Key"),
):
yield
@pytest.fixture
def mock_httpx_client(self):
"""Mock httpx.Client for testing."""
with patch("services.enterprise.base.httpx.Client") as mock_client_class:
mock_client = MagicMock()
mock_client_class.return_value.__enter__.return_value = mock_client
mock_client_class.return_value.__exit__.return_value = None
# Setup default response
mock_response = MagicMock()
mock_response.json.return_value = {"result": "success"}
mock_client.request.return_value = mock_response
yield mock_client
def test_traceparent_header_included_when_generated(self, mock_enterprise_config, mock_httpx_client):
"""Test that traceparent header is included when successfully generated."""
# Arrange
expected_traceparent = "00-5b8aa5a2d2c872e8321cf37308d69df2-051581bf3bb55c45-01"
with patch("services.enterprise.base.generate_traceparent_header", return_value=expected_traceparent):
# Act
EnterpriseRequest.send_request("GET", "/test")
# Assert
mock_httpx_client.request.assert_called_once()
call_args = mock_httpx_client.request.call_args
headers = call_args[1]["headers"]
assert "traceparent" in headers
assert headers["traceparent"] == expected_traceparent
assert headers["Content-Type"] == "application/json"
assert headers["Enterprise-Api-Secret-Key"] == "test-secret-key"

14
api/uv.lock generated
View File

@ -453,15 +453,15 @@ wheels = [
[[package]]
name = "azure-core"
version = "1.36.0"
version = "1.38.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "requests" },
{ name = "typing-extensions" },
]
sdist = { url = "https://files.pythonhosted.org/packages/0a/c4/d4ff3bc3ddf155156460bff340bbe9533f99fac54ddea165f35a8619f162/azure_core-1.36.0.tar.gz", hash = "sha256:22e5605e6d0bf1d229726af56d9e92bc37b6e726b141a18be0b4d424131741b7", size = 351139, upload-time = "2025-10-15T00:33:49.083Z" }
sdist = { url = "https://files.pythonhosted.org/packages/dc/1b/e503e08e755ea94e7d3419c9242315f888fc664211c90d032e40479022bf/azure_core-1.38.0.tar.gz", hash = "sha256:8194d2682245a3e4e3151a667c686464c3786fed7918b394d035bdcd61bb5993", size = 363033, upload-time = "2026-01-12T17:03:05.535Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/b1/3c/b90d5afc2e47c4a45f4bba00f9c3193b0417fad5ad3bb07869f9d12832aa/azure_core-1.36.0-py3-none-any.whl", hash = "sha256:fee9923a3a753e94a259563429f3644aaf05c486d45b1215d098115102d91d3b", size = 213302, upload-time = "2025-10-15T00:33:51.058Z" },
{ url = "https://files.pythonhosted.org/packages/fc/d8/b8fcba9464f02b121f39de2db2bf57f0b216fe11d014513d666e8634380d/azure_core-1.38.0-py3-none-any.whl", hash = "sha256:ab0c9b2cd71fecb1842d52c965c95285d3cfb38902f6766e4a471f1cd8905335", size = 217825, upload-time = "2026-01-12T17:03:07.291Z" },
]
[[package]]
@ -1368,7 +1368,7 @@ wheels = [
[[package]]
name = "dify-api"
version = "1.11.2"
version = "1.11.3"
source = { virtual = "." }
dependencies = [
{ name = "aliyun-log-python-sdk" },
@ -1965,11 +1965,11 @@ wheels = [
[[package]]
name = "filelock"
version = "3.20.0"
version = "3.20.3"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/58/46/0028a82567109b5ef6e4d2a1f04a583fb513e6cf9527fcdd09afd817deeb/filelock-3.20.0.tar.gz", hash = "sha256:711e943b4ec6be42e1d4e6690b48dc175c822967466bb31c0c293f34334c13f4", size = 18922, upload-time = "2025-10-08T18:03:50.056Z" }
sdist = { url = "https://files.pythonhosted.org/packages/1d/65/ce7f1b70157833bf3cb851b556a37d4547ceafc158aa9b34b36782f23696/filelock-3.20.3.tar.gz", hash = "sha256:18c57ee915c7ec61cff0ecf7f0f869936c7c30191bb0cf406f1341778d0834e1", size = 19485, upload-time = "2026-01-09T17:55:05.421Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/76/91/7216b27286936c16f5b4d0c530087e4a54eead683e6b0b73dd0c64844af6/filelock-3.20.0-py3-none-any.whl", hash = "sha256:339b4732ffda5cd79b13f4e2711a31b0365ce445d95d243bb996273d072546a2", size = 16054, upload-time = "2025-10-08T18:03:48.35Z" },
{ url = "https://files.pythonhosted.org/packages/b5/36/7fb70f04bf00bc646cd5bb45aa9eddb15e19437a28b8fb2b4a5249fac770/filelock-3.20.3-py3-none-any.whl", hash = "sha256:4b0dda527ee31078689fc205ec4f1c1bf7d56cf88b6dc9426c4f230e46c2dce1", size = 16701, upload-time = "2026-01-09T17:55:04.334Z" },
]
[[package]]

View File

@ -1037,18 +1037,26 @@ WORKFLOW_NODE_EXECUTION_STORAGE=rdbms
# Options:
# - core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository (default)
# - core.repositories.celery_workflow_execution_repository.CeleryWorkflowExecutionRepository
# - extensions.logstore.repositories.logstore_workflow_execution_repository.LogstoreWorkflowExecutionRepository
CORE_WORKFLOW_EXECUTION_REPOSITORY=core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository
# Core workflow node execution repository implementation
# Options:
# - core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository (default)
# - core.repositories.celery_workflow_node_execution_repository.CeleryWorkflowNodeExecutionRepository
# - extensions.logstore.repositories.logstore_workflow_node_execution_repository.LogstoreWorkflowNodeExecutionRepository
CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY=core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository
# API workflow run repository implementation
# Options:
# - repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository (default)
# - extensions.logstore.repositories.logstore_api_workflow_run_repository.LogstoreAPIWorkflowRunRepository
API_WORKFLOW_RUN_REPOSITORY=repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository
# API workflow node execution repository implementation
# Options:
# - repositories.sqlalchemy_api_workflow_node_execution_repository.DifyAPISQLAlchemyWorkflowNodeExecutionRepository (default)
# - extensions.logstore.repositories.logstore_api_workflow_node_execution_repository.LogstoreAPIWorkflowNodeExecutionRepository
API_WORKFLOW_NODE_EXECUTION_REPOSITORY=repositories.sqlalchemy_api_workflow_node_execution_repository.DifyAPISQLAlchemyWorkflowNodeExecutionRepository
# Workflow log cleanup configuration

View File

@ -21,7 +21,7 @@ services:
# API service
api:
image: langgenius/dify-api:1.11.2
image: langgenius/dify-api:1.11.3
restart: always
environment:
# Use the shared environment variables.
@ -63,7 +63,7 @@ services:
# worker service
# The Celery worker for processing all queues (dataset, workflow, mail, etc.)
worker:
image: langgenius/dify-api:1.11.2
image: langgenius/dify-api:1.11.3
restart: always
environment:
# Use the shared environment variables.
@ -102,7 +102,7 @@ services:
# worker_beat service
# Celery beat for scheduling periodic tasks.
worker_beat:
image: langgenius/dify-api:1.11.2
image: langgenius/dify-api:1.11.3
restart: always
environment:
# Use the shared environment variables.
@ -132,7 +132,7 @@ services:
# Frontend web application.
web:
image: langgenius/dify-web:1.11.2
image: langgenius/dify-web:1.11.3
restart: always
environment:
CONSOLE_API_URL: ${CONSOLE_API_URL:-}

View File

@ -704,7 +704,7 @@ services:
# API service
api:
image: langgenius/dify-api:1.11.2
image: langgenius/dify-api:1.11.3
restart: always
environment:
# Use the shared environment variables.
@ -746,7 +746,7 @@ services:
# worker service
# The Celery worker for processing all queues (dataset, workflow, mail, etc.)
worker:
image: langgenius/dify-api:1.11.2
image: langgenius/dify-api:1.11.3
restart: always
environment:
# Use the shared environment variables.
@ -785,7 +785,7 @@ services:
# worker_beat service
# Celery beat for scheduling periodic tasks.
worker_beat:
image: langgenius/dify-api:1.11.2
image: langgenius/dify-api:1.11.3
restart: always
environment:
# Use the shared environment variables.
@ -815,7 +815,7 @@ services:
# Frontend web application.
web:
image: langgenius/dify-web:1.11.2
image: langgenius/dify-web:1.11.3
restart: always
environment:
CONSOLE_API_URL: ${CONSOLE_API_URL:-}

View File

@ -53,6 +53,7 @@ vi.mock('@/context/global-public-context', () => {
)
return {
useGlobalPublicStore,
useIsSystemFeaturesPending: () => false,
}
})

View File

@ -9,8 +9,8 @@ import {
EDUCATION_VERIFY_URL_SEARCHPARAMS_ACTION,
EDUCATION_VERIFYING_LOCALSTORAGE_ITEM,
} from '@/app/education-apply/constants'
import { fetchSetupStatus } from '@/service/common'
import { sendGAEvent } from '@/utils/gtag'
import { fetchSetupStatusWithCache } from '@/utils/setup-status'
import { resolvePostLoginRedirect } from '../signin/utils/post-login-redirect'
import { trackEvent } from './base/amplitude'
@ -33,15 +33,8 @@ export const AppInitializer = ({
const isSetupFinished = useCallback(async () => {
try {
if (localStorage.getItem('setup_status') === 'finished')
return true
const setUpStatus = await fetchSetupStatus()
if (setUpStatus.step !== 'finished') {
localStorage.removeItem('setup_status')
return false
}
localStorage.setItem('setup_status', 'finished')
return true
const setUpStatus = await fetchSetupStatusWithCache()
return setUpStatus.step === 'finished'
}
catch (error) {
console.error(error)

View File

@ -34,13 +34,6 @@ vi.mock('@/context/app-context', () => ({
}),
}))
vi.mock('@/service/common', () => ({
fetchCurrentWorkspace: vi.fn(),
fetchLangGeniusVersion: vi.fn(),
fetchUserProfile: vi.fn(),
getSystemFeatures: vi.fn(),
}))
vi.mock('@/service/access-control', () => ({
useAppWhiteListSubjects: (...args: unknown[]) => mockUseAppWhiteListSubjects(...args),
useSearchForWhiteListCandidates: (...args: unknown[]) => mockUseSearchForWhiteListCandidates(...args),
@ -125,7 +118,6 @@ const resetAccessControlStore = () => {
const resetGlobalStore = () => {
useGlobalPublicStore.setState({
systemFeatures: defaultSystemFeatures,
isGlobalPending: false,
})
}

View File

@ -170,8 +170,12 @@ describe('useChatWithHistory', () => {
await waitFor(() => {
expect(mockFetchChatList).toHaveBeenCalledWith('conversation-1', false, 'app-1')
})
expect(result.current.pinnedConversationList).toEqual(pinnedData.data)
expect(result.current.conversationList).toEqual(listData.data)
await waitFor(() => {
expect(result.current.pinnedConversationList).toEqual(pinnedData.data)
})
await waitFor(() => {
expect(result.current.conversationList).toEqual(listData.data)
})
})
})

View File

@ -3,7 +3,8 @@ import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import * as React from 'react'
import { useAppContext } from '@/context/app-context'
import { useAsyncWindowOpen } from '@/hooks/use-async-window-open'
import { fetchBillingUrl, fetchSubscriptionUrls } from '@/service/billing'
import { fetchSubscriptionUrls } from '@/service/billing'
import { consoleClient } from '@/service/client'
import Toast from '../../../../base/toast'
import { ALL_PLANS } from '../../../config'
import { Plan } from '../../../type'
@ -21,10 +22,15 @@ vi.mock('@/context/app-context', () => ({
}))
vi.mock('@/service/billing', () => ({
fetchBillingUrl: vi.fn(),
fetchSubscriptionUrls: vi.fn(),
}))
vi.mock('@/service/client', () => ({
consoleClient: {
billingUrl: vi.fn(),
},
}))
vi.mock('@/hooks/use-async-window-open', () => ({
useAsyncWindowOpen: vi.fn(),
}))
@ -37,7 +43,7 @@ vi.mock('../../assets', () => ({
const mockUseAppContext = useAppContext as Mock
const mockUseAsyncWindowOpen = useAsyncWindowOpen as Mock
const mockFetchBillingUrl = fetchBillingUrl as Mock
const mockBillingUrl = consoleClient.billingUrl as Mock
const mockFetchSubscriptionUrls = fetchSubscriptionUrls as Mock
const mockToastNotify = Toast.notify as Mock
@ -69,7 +75,7 @@ beforeEach(() => {
vi.clearAllMocks()
mockUseAppContext.mockReturnValue({ isCurrentWorkspaceManager: true })
mockUseAsyncWindowOpen.mockReturnValue(vi.fn(async open => await open()))
mockFetchBillingUrl.mockResolvedValue({ url: 'https://billing.example' })
mockBillingUrl.mockResolvedValue({ url: 'https://billing.example' })
mockFetchSubscriptionUrls.mockResolvedValue({ url: 'https://subscription.example' })
assignedHref = ''
})
@ -143,7 +149,7 @@ describe('CloudPlanItem', () => {
type: 'error',
message: 'billing.buyPermissionDeniedTip',
}))
expect(mockFetchBillingUrl).not.toHaveBeenCalled()
expect(mockBillingUrl).not.toHaveBeenCalled()
})
it('should open billing portal when upgrading current paid plan', async () => {
@ -162,7 +168,7 @@ describe('CloudPlanItem', () => {
fireEvent.click(screen.getByRole('button', { name: 'billing.plansCommon.currentPlan' }))
await waitFor(() => {
expect(mockFetchBillingUrl).toHaveBeenCalledTimes(1)
expect(mockBillingUrl).toHaveBeenCalledTimes(1)
})
expect(openWindow).toHaveBeenCalledTimes(1)
})

View File

@ -6,7 +6,8 @@ import { useMemo } from 'react'
import { useTranslation } from 'react-i18next'
import { useAppContext } from '@/context/app-context'
import { useAsyncWindowOpen } from '@/hooks/use-async-window-open'
import { fetchBillingUrl, fetchSubscriptionUrls } from '@/service/billing'
import { fetchSubscriptionUrls } from '@/service/billing'
import { consoleClient } from '@/service/client'
import Toast from '../../../../base/toast'
import { ALL_PLANS } from '../../../config'
import { Plan } from '../../../type'
@ -76,7 +77,7 @@ const CloudPlanItem: FC<CloudPlanItemProps> = ({
try {
if (isCurrentPaidPlan) {
await openAsyncWindow(async () => {
const res = await fetchBillingUrl()
const res = await consoleClient.billingUrl()
if (res.url)
return res.url
throw new Error('Failed to open billing page')

View File

@ -30,8 +30,8 @@ export const useMarketplaceAllPlugins = (providers: any[], searchText: string) =
category: PluginCategoryEnum.datasource,
exclude,
type: 'plugin',
sortBy: 'install_count',
sortOrder: 'DESC',
sort_by: 'install_count',
sort_order: 'DESC',
})
}
else {
@ -39,10 +39,10 @@ export const useMarketplaceAllPlugins = (providers: any[], searchText: string) =
query: '',
category: PluginCategoryEnum.datasource,
type: 'plugin',
pageSize: 1000,
page_size: 1000,
exclude,
sortBy: 'install_count',
sortOrder: 'DESC',
sort_by: 'install_count',
sort_order: 'DESC',
})
}
}, [queryPlugins, queryPluginsWithDebounced, searchText, exclude])

View File

@ -275,8 +275,8 @@ export const useMarketplaceAllPlugins = (providers: ModelProvider[], searchText:
category: PluginCategoryEnum.model,
exclude,
type: 'plugin',
sortBy: 'install_count',
sortOrder: 'DESC',
sort_by: 'install_count',
sort_order: 'DESC',
})
}
else {
@ -284,10 +284,10 @@ export const useMarketplaceAllPlugins = (providers: ModelProvider[], searchText:
query: '',
category: PluginCategoryEnum.model,
type: 'plugin',
pageSize: 1000,
page_size: 1000,
exclude,
sortBy: 'install_count',
sortOrder: 'DESC',
sort_by: 'install_count',
sort_order: 'DESC',
})
}
}, [queryPlugins, queryPluginsWithDebounced, searchText, exclude])

View File

@ -100,11 +100,11 @@ export const useMarketplacePlugins = () => {
const [queryParams, setQueryParams] = useState<PluginsSearchParams>()
const normalizeParams = useCallback((pluginsSearchParams: PluginsSearchParams) => {
const pageSize = pluginsSearchParams.pageSize || 40
const page_size = pluginsSearchParams.page_size || 40
return {
...pluginsSearchParams,
pageSize,
page_size,
}
}, [])
@ -116,20 +116,20 @@ export const useMarketplacePlugins = () => {
plugins: [] as Plugin[],
total: 0,
page: 1,
pageSize: 40,
page_size: 40,
}
}
const params = normalizeParams(queryParams)
const {
query,
sortBy,
sortOrder,
sort_by,
sort_order,
category,
tags,
exclude,
type,
pageSize,
page_size,
} = params
const pluginOrBundle = type === 'bundle' ? 'bundles' : 'plugins'
@ -137,10 +137,10 @@ export const useMarketplacePlugins = () => {
const res = await postMarketplace<{ data: PluginsFromMarketplaceResponse }>(`/${pluginOrBundle}/search/advanced`, {
body: {
page: pageParam,
page_size: pageSize,
page_size,
query,
sort_by: sortBy,
sort_order: sortOrder,
sort_by,
sort_order,
category: category !== 'all' ? category : '',
tags,
exclude,
@ -154,7 +154,7 @@ export const useMarketplacePlugins = () => {
plugins: resPlugins.map(plugin => getFormattedPlugin(plugin)),
total: res.data.total,
page: pageParam,
pageSize,
page_size,
}
}
catch {
@ -162,13 +162,13 @@ export const useMarketplacePlugins = () => {
plugins: [],
total: 0,
page: pageParam,
pageSize,
page_size,
}
}
},
getNextPageParam: (lastPage) => {
const nextPage = lastPage.page + 1
const loaded = lastPage.page * lastPage.pageSize
const loaded = lastPage.page * lastPage.page_size
return loaded < (lastPage.total || 0) ? nextPage : undefined
},
initialPageParam: 1,

View File

@ -2,8 +2,8 @@ import type { SearchParams } from 'nuqs'
import { dehydrate, HydrationBoundary } from '@tanstack/react-query'
import { createLoader } from 'nuqs/server'
import { getQueryClientServer } from '@/context/query-client-server'
import { marketplaceQuery } from '@/service/client'
import { PLUGIN_CATEGORY_WITH_COLLECTIONS } from './constants'
import { marketplaceKeys } from './query'
import { marketplaceSearchParamsParsers } from './search-params'
import { getCollectionsParams, getMarketplaceCollectionsAndPlugins } from './utils'
@ -23,7 +23,7 @@ async function getDehydratedState(searchParams?: Promise<SearchParams>) {
const queryClient = getQueryClientServer()
await queryClient.prefetchQuery({
queryKey: marketplaceKeys.collections(getCollectionsParams(params.category)),
queryKey: marketplaceQuery.collections.queryKey({ input: { query: getCollectionsParams(params.category) } }),
queryFn: () => getMarketplaceCollectionsAndPlugins(getCollectionsParams(params.category)),
})
return dehydrate(queryClient)

View File

@ -60,10 +60,10 @@ vi.mock('@/service/use-plugins', () => ({
// Mock tanstack query
const mockFetchNextPage = vi.fn()
const mockHasNextPage = false
let mockInfiniteQueryData: { pages: Array<{ plugins: unknown[], total: number, page: number, pageSize: number }> } | undefined
let mockInfiniteQueryData: { pages: Array<{ plugins: unknown[], total: number, page: number, page_size: number }> } | undefined
let capturedInfiniteQueryFn: ((ctx: { pageParam: number, signal: AbortSignal }) => Promise<unknown>) | null = null
let capturedQueryFn: ((ctx: { signal: AbortSignal }) => Promise<unknown>) | null = null
let capturedGetNextPageParam: ((lastPage: { page: number, pageSize: number, total: number }) => number | undefined) | null = null
let capturedGetNextPageParam: ((lastPage: { page: number, page_size: number, total: number }) => number | undefined) | null = null
vi.mock('@tanstack/react-query', () => ({
useQuery: vi.fn(({ queryFn, enabled }: { queryFn: (ctx: { signal: AbortSignal }) => Promise<unknown>, enabled: boolean }) => {
@ -83,7 +83,7 @@ vi.mock('@tanstack/react-query', () => ({
}),
useInfiniteQuery: vi.fn(({ queryFn, getNextPageParam, enabled: _enabled }: {
queryFn: (ctx: { pageParam: number, signal: AbortSignal }) => Promise<unknown>
getNextPageParam: (lastPage: { page: number, pageSize: number, total: number }) => number | undefined
getNextPageParam: (lastPage: { page: number, page_size: number, total: number }) => number | undefined
enabled: boolean
}) => {
// Capture queryFn and getNextPageParam for later testing
@ -97,9 +97,9 @@ vi.mock('@tanstack/react-query', () => ({
// Call getNextPageParam to increase coverage
if (getNextPageParam) {
// Test with more data available
getNextPageParam({ page: 1, pageSize: 40, total: 100 })
getNextPageParam({ page: 1, page_size: 40, total: 100 })
// Test with no more data
getNextPageParam({ page: 3, pageSize: 40, total: 100 })
getNextPageParam({ page: 3, page_size: 40, total: 100 })
}
return {
data: mockInfiniteQueryData,
@ -151,6 +151,7 @@ vi.mock('@/service/base', () => ({
// Mock config
vi.mock('@/config', () => ({
API_PREFIX: '/api',
APP_VERSION: '1.0.0',
IS_MARKETPLACE: false,
MARKETPLACE_API_PREFIX: 'https://marketplace.dify.ai/api/v1',
@ -731,10 +732,10 @@ describe('useMarketplacePlugins', () => {
expect(() => {
result.current.queryPlugins({
query: 'test',
sortBy: 'install_count',
sortOrder: 'DESC',
sort_by: 'install_count',
sort_order: 'DESC',
category: 'tool',
pageSize: 20,
page_size: 20,
})
}).not.toThrow()
})
@ -747,7 +748,7 @@ describe('useMarketplacePlugins', () => {
result.current.queryPlugins({
query: 'test',
type: 'bundle',
pageSize: 40,
page_size: 40,
})
}).not.toThrow()
})
@ -798,8 +799,8 @@ describe('useMarketplacePlugins', () => {
result.current.queryPlugins({
query: 'test',
category: 'all',
sortBy: 'install_count',
sortOrder: 'DESC',
sort_by: 'install_count',
sort_order: 'DESC',
})
}).not.toThrow()
})
@ -824,7 +825,7 @@ describe('useMarketplacePlugins', () => {
expect(() => {
result.current.queryPlugins({
query: 'test',
pageSize: 100,
page_size: 100,
})
}).not.toThrow()
})
@ -843,7 +844,7 @@ describe('Hooks queryFn Coverage', () => {
// Set mock data to have pages
mockInfiniteQueryData = {
pages: [
{ plugins: [{ name: 'plugin1' }], total: 10, page: 1, pageSize: 40 },
{ plugins: [{ name: 'plugin1' }], total: 10, page: 1, page_size: 40 },
],
}
@ -863,8 +864,8 @@ describe('Hooks queryFn Coverage', () => {
it('should expose page and total from infinite query data', async () => {
mockInfiniteQueryData = {
pages: [
{ plugins: [{ name: 'plugin1' }, { name: 'plugin2' }], total: 20, page: 1, pageSize: 40 },
{ plugins: [{ name: 'plugin3' }], total: 20, page: 2, pageSize: 40 },
{ plugins: [{ name: 'plugin1' }, { name: 'plugin2' }], total: 20, page: 1, page_size: 40 },
{ plugins: [{ name: 'plugin3' }], total: 20, page: 2, page_size: 40 },
],
}
@ -893,7 +894,7 @@ describe('Hooks queryFn Coverage', () => {
it('should return total from first page when query is set and data exists', async () => {
mockInfiniteQueryData = {
pages: [
{ plugins: [], total: 50, page: 1, pageSize: 40 },
{ plugins: [], total: 50, page: 1, page_size: 40 },
],
}
@ -917,8 +918,8 @@ describe('Hooks queryFn Coverage', () => {
type: 'plugin',
query: 'search test',
category: 'model',
sortBy: 'version_updated_at',
sortOrder: 'ASC',
sort_by: 'version_updated_at',
sort_order: 'ASC',
})
expect(result.current).toBeDefined()
@ -1027,13 +1028,13 @@ describe('Advanced Hook Integration', () => {
// Test with all possible parameters
result.current.queryPlugins({
query: 'comprehensive test',
sortBy: 'install_count',
sortOrder: 'DESC',
sort_by: 'install_count',
sort_order: 'DESC',
category: 'tool',
tags: ['tag1', 'tag2'],
exclude: ['excluded-plugin'],
type: 'plugin',
pageSize: 50,
page_size: 50,
})
expect(result.current).toBeDefined()
@ -1081,9 +1082,9 @@ describe('Direct queryFn Coverage', () => {
result.current.queryPlugins({
query: 'direct test',
category: 'tool',
sortBy: 'install_count',
sortOrder: 'DESC',
pageSize: 40,
sort_by: 'install_count',
sort_order: 'DESC',
page_size: 40,
})
// Now queryFn should be captured and enabled
@ -1255,7 +1256,7 @@ describe('Direct queryFn Coverage', () => {
result.current.queryPlugins({
query: 'structure test',
pageSize: 20,
page_size: 20,
})
if (capturedInfiniteQueryFn) {
@ -1264,14 +1265,14 @@ describe('Direct queryFn Coverage', () => {
plugins: unknown[]
total: number
page: number
pageSize: number
page_size: number
}
// Verify the returned structure
expect(response).toHaveProperty('plugins')
expect(response).toHaveProperty('total')
expect(response).toHaveProperty('page')
expect(response).toHaveProperty('pageSize')
expect(response).toHaveProperty('page_size')
}
})
})
@ -1296,7 +1297,7 @@ describe('flatMap Coverage', () => {
],
total: 5,
page: 1,
pageSize: 40,
page_size: 40,
},
{
plugins: [
@ -1304,7 +1305,7 @@ describe('flatMap Coverage', () => {
],
total: 5,
page: 2,
pageSize: 40,
page_size: 40,
},
],
}
@ -1336,8 +1337,8 @@ describe('flatMap Coverage', () => {
it('should test hook with pages data for flatMap path', async () => {
mockInfiniteQueryData = {
pages: [
{ plugins: [], total: 100, page: 1, pageSize: 40 },
{ plugins: [], total: 100, page: 2, pageSize: 40 },
{ plugins: [], total: 100, page: 1, page_size: 40 },
{ plugins: [], total: 100, page: 2, page_size: 40 },
],
}
@ -1371,7 +1372,7 @@ describe('flatMap Coverage', () => {
plugins: unknown[]
total: number
page: number
pageSize: number
page_size: number
}
// When error is caught, should return fallback data
expect(response.plugins).toEqual([])
@ -1392,15 +1393,15 @@ describe('flatMap Coverage', () => {
// Test getNextPageParam function directly
if (capturedGetNextPageParam) {
// When there are more pages
const nextPage = capturedGetNextPageParam({ page: 1, pageSize: 40, total: 100 })
const nextPage = capturedGetNextPageParam({ page: 1, page_size: 40, total: 100 })
expect(nextPage).toBe(2)
// When all data is loaded
const noMorePages = capturedGetNextPageParam({ page: 3, pageSize: 40, total: 100 })
const noMorePages = capturedGetNextPageParam({ page: 3, page_size: 40, total: 100 })
expect(noMorePages).toBeUndefined()
// Edge case: exactly at boundary
const atBoundary = capturedGetNextPageParam({ page: 2, pageSize: 50, total: 100 })
const atBoundary = capturedGetNextPageParam({ page: 2, page_size: 50, total: 100 })
expect(atBoundary).toBeUndefined()
}
})
@ -1427,7 +1428,7 @@ describe('flatMap Coverage', () => {
plugins: unknown[]
total: number
page: number
pageSize: number
page_size: number
}
// Catch block should return fallback values
expect(response.plugins).toEqual([])
@ -1446,7 +1447,7 @@ describe('flatMap Coverage', () => {
plugins: [{ name: 'test-plugin-1' }, { name: 'test-plugin-2' }],
total: 10,
page: 1,
pageSize: 40,
page_size: 40,
},
],
}
@ -1489,9 +1490,12 @@ describe('Async Utils', () => {
{ type: 'plugin', org: 'test', name: 'plugin2' },
]
globalThis.fetch = vi.fn().mockResolvedValue({
json: () => Promise.resolve({ data: { plugins: mockPlugins } }),
})
globalThis.fetch = vi.fn().mockResolvedValue(
new Response(JSON.stringify({ data: { plugins: mockPlugins } }), {
status: 200,
headers: { 'Content-Type': 'application/json' },
}),
)
const { getMarketplacePluginsByCollectionId } = await import('./utils')
const result = await getMarketplacePluginsByCollectionId('test-collection', {
@ -1514,19 +1518,26 @@ describe('Async Utils', () => {
})
it('should pass abort signal when provided', async () => {
const mockPlugins = [{ type: 'plugin', org: 'test', name: 'plugin1' }]
globalThis.fetch = vi.fn().mockResolvedValue({
json: () => Promise.resolve({ data: { plugins: mockPlugins } }),
})
const mockPlugins = [{ type: 'plugins', org: 'test', name: 'plugin1' }]
globalThis.fetch = vi.fn().mockResolvedValue(
new Response(JSON.stringify({ data: { plugins: mockPlugins } }), {
status: 200,
headers: { 'Content-Type': 'application/json' },
}),
)
const controller = new AbortController()
const { getMarketplacePluginsByCollectionId } = await import('./utils')
await getMarketplacePluginsByCollectionId('test-collection', {}, { signal: controller.signal })
// oRPC uses Request objects, so check that fetch was called with a Request containing the right URL
expect(globalThis.fetch).toHaveBeenCalledWith(
expect.any(String),
expect.objectContaining({ signal: controller.signal }),
expect.any(Request),
expect.any(Object),
)
const call = vi.mocked(globalThis.fetch).mock.calls[0]
const request = call[0] as Request
expect(request.url).toContain('test-collection')
})
})
@ -1535,19 +1546,25 @@ describe('Async Utils', () => {
const mockCollections = [
{ name: 'collection1', label: {}, description: {}, rule: '', created_at: '', updated_at: '' },
]
const mockPlugins = [{ type: 'plugin', org: 'test', name: 'plugin1' }]
const mockPlugins = [{ type: 'plugins', org: 'test', name: 'plugin1' }]
let callCount = 0
globalThis.fetch = vi.fn().mockImplementation(() => {
callCount++
if (callCount === 1) {
return Promise.resolve({
json: () => Promise.resolve({ data: { collections: mockCollections } }),
})
return Promise.resolve(
new Response(JSON.stringify({ data: { collections: mockCollections } }), {
status: 200,
headers: { 'Content-Type': 'application/json' },
}),
)
}
return Promise.resolve({
json: () => Promise.resolve({ data: { plugins: mockPlugins } }),
})
return Promise.resolve(
new Response(JSON.stringify({ data: { plugins: mockPlugins } }), {
status: 200,
headers: { 'Content-Type': 'application/json' },
}),
)
})
const { getMarketplaceCollectionsAndPlugins } = await import('./utils')
@ -1571,9 +1588,12 @@ describe('Async Utils', () => {
})
it('should append condition and type to URL when provided', async () => {
globalThis.fetch = vi.fn().mockResolvedValue({
json: () => Promise.resolve({ data: { collections: [] } }),
})
globalThis.fetch = vi.fn().mockResolvedValue(
new Response(JSON.stringify({ data: { collections: [] } }), {
status: 200,
headers: { 'Content-Type': 'application/json' },
}),
)
const { getMarketplaceCollectionsAndPlugins } = await import('./utils')
await getMarketplaceCollectionsAndPlugins({
@ -1581,10 +1601,11 @@ describe('Async Utils', () => {
type: 'bundle',
})
expect(globalThis.fetch).toHaveBeenCalledWith(
expect.stringContaining('condition=category=tool'),
expect.any(Object),
)
// oRPC uses Request objects, so check that fetch was called with a Request containing the right URL
expect(globalThis.fetch).toHaveBeenCalled()
const call = vi.mocked(globalThis.fetch).mock.calls[0]
const request = call[0] as Request
expect(request.url).toContain('condition=category%3Dtool')
})
})
})

View File

@ -1,22 +1,14 @@
import type { CollectionsAndPluginsSearchParams, PluginsSearchParams } from './types'
import type { PluginsSearchParams } from './types'
import type { MarketPlaceInputs } from '@/contract/router'
import { useInfiniteQuery, useQuery } from '@tanstack/react-query'
import { marketplaceQuery } from '@/service/client'
import { getMarketplaceCollectionsAndPlugins, getMarketplacePlugins } from './utils'
// TODO: Avoid manual maintenance of query keys and better service management,
// https://github.com/langgenius/dify/issues/30342
export const marketplaceKeys = {
all: ['marketplace'] as const,
collections: (params?: CollectionsAndPluginsSearchParams) => [...marketplaceKeys.all, 'collections', params] as const,
collectionPlugins: (collectionId: string, params?: CollectionsAndPluginsSearchParams) => [...marketplaceKeys.all, 'collectionPlugins', collectionId, params] as const,
plugins: (params?: PluginsSearchParams) => [...marketplaceKeys.all, 'plugins', params] as const,
}
export function useMarketplaceCollectionsAndPlugins(
collectionsParams: CollectionsAndPluginsSearchParams,
collectionsParams: MarketPlaceInputs['collections']['query'],
) {
return useQuery({
queryKey: marketplaceKeys.collections(collectionsParams),
queryKey: marketplaceQuery.collections.queryKey({ input: { query: collectionsParams } }),
queryFn: ({ signal }) => getMarketplaceCollectionsAndPlugins(collectionsParams, { signal }),
})
}
@ -25,11 +17,16 @@ export function useMarketplacePlugins(
queryParams: PluginsSearchParams | undefined,
) {
return useInfiniteQuery({
queryKey: marketplaceKeys.plugins(queryParams),
queryKey: marketplaceQuery.searchAdvanced.queryKey({
input: {
body: queryParams!,
params: { kind: queryParams?.type === 'bundle' ? 'bundles' : 'plugins' },
},
}),
queryFn: ({ pageParam = 1, signal }) => getMarketplacePlugins(queryParams, pageParam, signal),
getNextPageParam: (lastPage) => {
const nextPage = lastPage.page + 1
const loaded = lastPage.page * lastPage.pageSize
const loaded = lastPage.page * lastPage.page_size
return loaded < (lastPage.total || 0) ? nextPage : undefined
},
initialPageParam: 1,

View File

@ -26,8 +26,8 @@ export function useMarketplaceData() {
query: searchPluginText,
category: activePluginType === PLUGIN_TYPE_SEARCH_MAP.all ? undefined : activePluginType,
tags: filterPluginTags,
sortBy: sort.sortBy,
sortOrder: sort.sortOrder,
sort_by: sort.sortBy,
sort_order: sort.sortOrder,
type: getMarketplaceListFilterType(activePluginType),
}
}, [isSearchMode, searchPluginText, activePluginType, filterPluginTags, sort])

View File

@ -30,9 +30,9 @@ export type MarketplaceCollectionPluginsResponse = {
export type PluginsSearchParams = {
query: string
page?: number
pageSize?: number
sortBy?: string
sortOrder?: string
page_size?: number
sort_by?: string
sort_order?: string
category?: string
tags?: string[]
exclude?: string[]

View File

@ -4,14 +4,12 @@ import type {
MarketplaceCollection,
PluginsSearchParams,
} from '@/app/components/plugins/marketplace/types'
import type { Plugin, PluginsFromMarketplaceResponse } from '@/app/components/plugins/types'
import type { Plugin } from '@/app/components/plugins/types'
import { PluginCategoryEnum } from '@/app/components/plugins/types'
import {
APP_VERSION,
IS_MARKETPLACE,
MARKETPLACE_API_PREFIX,
} from '@/config'
import { postMarketplace } from '@/service/base'
import { marketplaceClient } from '@/service/client'
import { getMarketplaceUrl } from '@/utils/var'
import { PLUGIN_TYPE_SEARCH_MAP } from './constants'
@ -19,10 +17,6 @@ type MarketplaceFetchOptions = {
signal?: AbortSignal
}
const getMarketplaceHeaders = () => new Headers({
'X-Dify-Version': !IS_MARKETPLACE ? APP_VERSION : '999.0.0',
})
export const getPluginIconInMarketplace = (plugin: Plugin) => {
if (plugin.type === 'bundle')
return `${MARKETPLACE_API_PREFIX}/bundles/${plugin.org}/${plugin.name}/icon`
@ -65,24 +59,15 @@ export const getMarketplacePluginsByCollectionId = async (
let plugins: Plugin[] = []
try {
const url = `${MARKETPLACE_API_PREFIX}/collections/${collectionId}/plugins`
const headers = getMarketplaceHeaders()
const marketplaceCollectionPluginsData = await globalThis.fetch(
url,
{
cache: 'no-store',
method: 'POST',
headers,
signal: options?.signal,
body: JSON.stringify({
category: query?.category,
exclude: query?.exclude,
type: query?.type,
}),
const marketplaceCollectionPluginsDataJson = await marketplaceClient.collectionPlugins({
params: {
collectionId,
},
)
const marketplaceCollectionPluginsDataJson = await marketplaceCollectionPluginsData.json()
plugins = (marketplaceCollectionPluginsDataJson.data.plugins || []).map((plugin: Plugin) => getFormattedPlugin(plugin))
body: query,
}, {
signal: options?.signal,
})
plugins = (marketplaceCollectionPluginsDataJson.data?.plugins || []).map(plugin => getFormattedPlugin(plugin))
}
// eslint-disable-next-line unused-imports/no-unused-vars
catch (e) {
@ -99,22 +84,16 @@ export const getMarketplaceCollectionsAndPlugins = async (
let marketplaceCollections: MarketplaceCollection[] = []
let marketplaceCollectionPluginsMap: Record<string, Plugin[]> = {}
try {
let marketplaceUrl = `${MARKETPLACE_API_PREFIX}/collections?page=1&page_size=100`
if (query?.condition)
marketplaceUrl += `&condition=${query.condition}`
if (query?.type)
marketplaceUrl += `&type=${query.type}`
const headers = getMarketplaceHeaders()
const marketplaceCollectionsData = await globalThis.fetch(
marketplaceUrl,
{
headers,
cache: 'no-store',
signal: options?.signal,
const marketplaceCollectionsDataJson = await marketplaceClient.collections({
query: {
...query,
page: 1,
page_size: 100,
},
)
const marketplaceCollectionsDataJson = await marketplaceCollectionsData.json()
marketplaceCollections = marketplaceCollectionsDataJson.data.collections || []
}, {
signal: options?.signal,
})
marketplaceCollections = marketplaceCollectionsDataJson.data?.collections || []
await Promise.all(marketplaceCollections.map(async (collection: MarketplaceCollection) => {
const plugins = await getMarketplacePluginsByCollectionId(collection.name, query, options)
@ -143,42 +122,42 @@ export const getMarketplacePlugins = async (
plugins: [] as Plugin[],
total: 0,
page: 1,
pageSize: 40,
page_size: 40,
}
}
const {
query,
sortBy,
sortOrder,
sort_by,
sort_order,
category,
tags,
type,
pageSize = 40,
page_size = 40,
} = queryParams
const pluginOrBundle = type === 'bundle' ? 'bundles' : 'plugins'
try {
const res = await postMarketplace<{ data: PluginsFromMarketplaceResponse }>(`/${pluginOrBundle}/search/advanced`, {
const res = await marketplaceClient.searchAdvanced({
params: {
kind: type === 'bundle' ? 'bundles' : 'plugins',
},
body: {
page: pageParam,
page_size: pageSize,
page_size,
query,
sort_by: sortBy,
sort_order: sortOrder,
sort_by,
sort_order,
category: category !== 'all' ? category : '',
tags,
type,
},
signal,
})
}, { signal })
const resPlugins = res.data.bundles || res.data.plugins || []
return {
plugins: resPlugins.map(plugin => getFormattedPlugin(plugin)),
total: res.data.total,
page: pageParam,
pageSize,
page_size,
}
}
catch {
@ -186,7 +165,7 @@ export const getMarketplacePlugins = async (
plugins: [],
total: 0,
page: pageParam,
pageSize,
page_size,
}
}
}

View File

@ -1602,6 +1602,7 @@ export const useNodesInteractions = () => {
const offsetX = currentPosition.x - x
const offsetY = currentPosition.y - y
let idMapping: Record<string, string> = {}
const parentChildrenToAppend: { parentId: string, childId: string, childType: BlockEnum }[] = []
clipboardElements.forEach((nodeToPaste, index) => {
const nodeType = nodeToPaste.data.type
@ -1615,6 +1616,7 @@ export const useNodesInteractions = () => {
_isBundled: false,
_connectedSourceHandleIds: [],
_connectedTargetHandleIds: [],
_dimmed: false,
title: genNewNodeTitleFromOld(nodeToPaste.data.title),
},
position: {
@ -1682,27 +1684,24 @@ export const useNodesInteractions = () => {
return
// handle paste to nested block
if (selectedNode.data.type === BlockEnum.Iteration) {
newNode.data.isInIteration = true
newNode.data.iteration_id = selectedNode.data.iteration_id
newNode.parentId = selectedNode.id
newNode.positionAbsolute = {
x: newNode.position.x,
y: newNode.position.y,
}
// set position base on parent node
newNode.position = getNestedNodePosition(newNode, selectedNode)
}
else if (selectedNode.data.type === BlockEnum.Loop) {
newNode.data.isInLoop = true
newNode.data.loop_id = selectedNode.data.loop_id
if (selectedNode.data.type === BlockEnum.Iteration || selectedNode.data.type === BlockEnum.Loop) {
const isIteration = selectedNode.data.type === BlockEnum.Iteration
newNode.data.isInIteration = isIteration
newNode.data.iteration_id = isIteration ? selectedNode.id : undefined
newNode.data.isInLoop = !isIteration
newNode.data.loop_id = !isIteration ? selectedNode.id : undefined
newNode.parentId = selectedNode.id
newNode.zIndex = isIteration ? ITERATION_CHILDREN_Z_INDEX : LOOP_CHILDREN_Z_INDEX
newNode.positionAbsolute = {
x: newNode.position.x,
y: newNode.position.y,
}
// set position base on parent node
newNode.position = getNestedNodePosition(newNode, selectedNode)
// update parent children array like native add
parentChildrenToAppend.push({ parentId: selectedNode.id, childId: newNode.id, childType: newNode.data.type })
}
}
}
@ -1733,7 +1732,17 @@ export const useNodesInteractions = () => {
}
})
setNodes([...nodes, ...nodesToPaste])
const newNodes = produce(nodes, (draft: Node[]) => {
parentChildrenToAppend.forEach(({ parentId, childId, childType }) => {
const p = draft.find(n => n.id === parentId)
if (p) {
p.data._children?.push({ nodeId: childId, nodeType: childType })
}
})
draft.push(...nodesToPaste)
})
setNodes(newNodes)
setEdges([...edges, ...edgesToPaste])
saveStateToHistory(WorkflowHistoryEvent.NodePaste, {
nodeId: nodesToPaste?.[0]?.id,

View File

@ -7,6 +7,7 @@ import { useCallback } from 'react'
import { useStoreApi } from 'reactflow'
import { useNodesMetaData } from '@/app/components/workflow/hooks'
import {
LOOP_CHILDREN_Z_INDEX,
LOOP_PADDING,
} from '../../constants'
import {
@ -114,9 +115,7 @@ export const useNodeLoopInteractions = () => {
return childrenNodes.map((child, index) => {
const childNodeType = child.data.type as BlockEnum
const {
defaultValue,
} = nodesMetaDataMap![childNodeType]
const { defaultValue } = nodesMetaDataMap![childNodeType]
const nodesWithSameType = nodes.filter(node => node.data.type === childNodeType)
const { newNode } = generateNewNode({
type: getNodeCustomTypeByNodeDataType(childNodeType),
@ -127,15 +126,17 @@ export const useNodeLoopInteractions = () => {
_isBundled: false,
_connectedSourceHandleIds: [],
_connectedTargetHandleIds: [],
_dimmed: false,
title: nodesWithSameType.length > 0 ? `${defaultValue.title} ${nodesWithSameType.length + 1}` : defaultValue.title,
isInLoop: true,
loop_id: newNodeId,
type: childNodeType,
},
position: child.position,
positionAbsolute: child.positionAbsolute,
parentId: newNodeId,
extent: child.extent,
zIndex: child.zIndex,
zIndex: LOOP_CHILDREN_Z_INDEX,
})
newNode.id = `${newNodeId}${newNode.id + index}`
return newNode

View File

@ -16,9 +16,16 @@ vi.mock('@/service/common', () => ({
fetchInitValidateStatus: vi.fn(),
setup: vi.fn(),
login: vi.fn(),
getSystemFeatures: vi.fn(),
}))
vi.mock('@/context/global-public-context', async (importOriginal) => {
const actual = await importOriginal<typeof import('@/context/global-public-context')>()
return {
...actual,
useIsSystemFeaturesPending: () => false,
}
})
const mockFetchSetupStatus = vi.mocked(fetchSetupStatus)
const mockFetchInitValidateStatus = vi.mocked(fetchInitValidateStatus)
const mockSetup = vi.mocked(setup)

View File

@ -2,42 +2,61 @@
import type { FC, PropsWithChildren } from 'react'
import type { SystemFeatures } from '@/types/feature'
import { useQuery } from '@tanstack/react-query'
import { useEffect } from 'react'
import { create } from 'zustand'
import Loading from '@/app/components/base/loading'
import { getSystemFeatures } from '@/service/common'
import { consoleClient } from '@/service/client'
import { defaultSystemFeatures } from '@/types/feature'
import { fetchSetupStatusWithCache } from '@/utils/setup-status'
type GlobalPublicStore = {
isGlobalPending: boolean
setIsGlobalPending: (isPending: boolean) => void
systemFeatures: SystemFeatures
setSystemFeatures: (systemFeatures: SystemFeatures) => void
}
export const useGlobalPublicStore = create<GlobalPublicStore>(set => ({
isGlobalPending: true,
setIsGlobalPending: (isPending: boolean) => set(() => ({ isGlobalPending: isPending })),
systemFeatures: defaultSystemFeatures,
setSystemFeatures: (systemFeatures: SystemFeatures) => set(() => ({ systemFeatures })),
}))
const systemFeaturesQueryKey = ['systemFeatures'] as const
const setupStatusQueryKey = ['setupStatus'] as const
async function fetchSystemFeatures() {
const data = await consoleClient.systemFeatures()
const { setSystemFeatures } = useGlobalPublicStore.getState()
setSystemFeatures({ ...defaultSystemFeatures, ...data })
return data
}
export function useSystemFeaturesQuery() {
return useQuery({
queryKey: systemFeaturesQueryKey,
queryFn: fetchSystemFeatures,
})
}
export function useIsSystemFeaturesPending() {
const { isPending } = useSystemFeaturesQuery()
return isPending
}
export function useSetupStatusQuery() {
return useQuery({
queryKey: setupStatusQueryKey,
queryFn: fetchSetupStatusWithCache,
staleTime: Infinity,
})
}
const GlobalPublicStoreProvider: FC<PropsWithChildren> = ({
children,
}) => {
const { isPending, data } = useQuery({
queryKey: ['systemFeatures'],
queryFn: getSystemFeatures,
})
const { setSystemFeatures, setIsGlobalPending: setIsPending } = useGlobalPublicStore()
useEffect(() => {
if (data)
setSystemFeatures({ ...defaultSystemFeatures, ...data })
}, [data, setSystemFeatures])
// Fetch systemFeatures and setupStatus in parallel to reduce waterfall.
// setupStatus is prefetched here and cached in localStorage for AppInitializer.
const { isPending } = useSystemFeaturesQuery()
useEffect(() => {
setIsPending(isPending)
}, [isPending, setIsPending])
// Prefetch setupStatus for AppInitializer (result not needed here)
useSetupStatusQuery()
if (isPending)
return <div className="flex h-screen w-screen items-center justify-center"><Loading /></div>

View File

@ -10,7 +10,7 @@ import { getProcessedSystemVariablesFromUrlParams } from '@/app/components/base/
import Loading from '@/app/components/base/loading'
import { AccessMode } from '@/models/access-control'
import { useGetWebAppAccessModeByCode } from '@/service/use-share'
import { useGlobalPublicStore } from './global-public-context'
import { useIsSystemFeaturesPending } from './global-public-context'
type WebAppStore = {
shareCode: string | null
@ -65,7 +65,7 @@ const getShareCodeFromPathname = (pathname: string): string | null => {
}
const WebAppStoreProvider: FC<PropsWithChildren> = ({ children }) => {
const isGlobalPending = useGlobalPublicStore(s => s.isGlobalPending)
const isGlobalPending = useIsSystemFeaturesPending()
const updateWebAppAccessMode = useWebAppStore(state => state.updateWebAppAccessMode)
const updateShareCode = useWebAppStore(state => state.updateShareCode)
const updateEmbeddedUserId = useWebAppStore(state => state.updateEmbeddedUserId)

3
web/contract/base.ts Normal file
View File

@ -0,0 +1,3 @@
import { oc } from '@orpc/contract'
export const base = oc.$route({ inputStructure: 'detailed' })

34
web/contract/console.ts Normal file
View File

@ -0,0 +1,34 @@
import type { SystemFeatures } from '@/types/feature'
import { type } from '@orpc/contract'
import { base } from './base'
export const systemFeaturesContract = base
.route({
path: '/system-features',
method: 'GET',
})
.input(type<unknown>())
.output(type<SystemFeatures>())
export const billingUrlContract = base
.route({
path: '/billing/invoices',
method: 'GET',
})
.input(type<unknown>())
.output(type<{ url: string }>())
export const bindPartnerStackContract = base
.route({
path: '/billing/partners/{partnerKey}/tenants',
method: 'PUT',
})
.input(type<{
params: {
partnerKey: string
}
body: {
click_id: string
}
}>())
.output(type<unknown>())

View File

@ -0,0 +1,56 @@
import type { CollectionsAndPluginsSearchParams, MarketplaceCollection, PluginsSearchParams } from '@/app/components/plugins/marketplace/types'
import type { Plugin, PluginsFromMarketplaceResponse } from '@/app/components/plugins/types'
import { type } from '@orpc/contract'
import { base } from './base'
export const collectionsContract = base
.route({
path: '/collections',
method: 'GET',
})
.input(
type<{
query?: CollectionsAndPluginsSearchParams & { page?: number, page_size?: number }
}>(),
)
.output(
type<{
data?: {
collections?: MarketplaceCollection[]
}
}>(),
)
export const collectionPluginsContract = base
.route({
path: '/collections/{collectionId}/plugins',
method: 'POST',
})
.input(
type<{
params: {
collectionId: string
}
body?: CollectionsAndPluginsSearchParams
}>(),
)
.output(
type<{
data?: {
plugins?: Plugin[]
}
}>(),
)
export const searchAdvancedContract = base
.route({
path: '/{kind}/search/advanced',
method: 'POST',
})
.input(type<{
params: {
kind: 'plugins' | 'bundles'
}
body: Omit<PluginsSearchParams, 'type'>
}>())
.output(type<{ data: PluginsFromMarketplaceResponse }>())

19
web/contract/router.ts Normal file
View File

@ -0,0 +1,19 @@
import type { InferContractRouterInputs } from '@orpc/contract'
import { billingUrlContract, bindPartnerStackContract, systemFeaturesContract } from './console'
import { collectionPluginsContract, collectionsContract, searchAdvancedContract } from './marketplace'
export const marketplaceRouterContract = {
collections: collectionsContract,
collectionPlugins: collectionPluginsContract,
searchAdvanced: searchAdvancedContract,
}
export type MarketPlaceInputs = InferContractRouterInputs<typeof marketplaceRouterContract>
export const consoleRouterContract = {
systemFeatures: systemFeaturesContract,
billingUrl: billingUrlContract,
bindPartnerStack: bindPartnerStackContract,
}
export type ConsoleInputs = InferContractRouterInputs<typeof consoleRouterContract>

View File

@ -1,5 +1,5 @@
import { act, renderHook } from '@testing-library/react'
import { useGlobalPublicStore } from '@/context/global-public-context'
import { useGlobalPublicStore, useIsSystemFeaturesPending } from '@/context/global-public-context'
/**
* Test suite for useDocumentTitle hook
*
@ -15,19 +15,25 @@ import { useGlobalPublicStore } from '@/context/global-public-context'
import { defaultSystemFeatures } from '@/types/feature'
import useDocumentTitle from './use-document-title'
vi.mock('@/service/common', () => ({
getSystemFeatures: vi.fn(() => ({ ...defaultSystemFeatures })),
}))
vi.mock('@/context/global-public-context', async (importOriginal) => {
const actual = await importOriginal<typeof import('@/context/global-public-context')>()
return {
...actual,
useIsSystemFeaturesPending: vi.fn(() => false),
}
})
/**
* Test behavior when system features are still loading
* Title should remain empty to prevent flicker
*/
describe('title should be empty if systemFeatures is pending', () => {
act(() => {
useGlobalPublicStore.setState({
systemFeatures: { ...defaultSystemFeatures, branding: { ...defaultSystemFeatures.branding, enabled: false } },
isGlobalPending: true,
beforeEach(() => {
vi.mocked(useIsSystemFeaturesPending).mockReturnValue(true)
act(() => {
useGlobalPublicStore.setState({
systemFeatures: { ...defaultSystemFeatures, branding: { ...defaultSystemFeatures.branding, enabled: false } },
})
})
})
/**
@ -52,9 +58,9 @@ describe('title should be empty if systemFeatures is pending', () => {
*/
describe('use default branding', () => {
beforeEach(() => {
vi.mocked(useIsSystemFeaturesPending).mockReturnValue(false)
act(() => {
useGlobalPublicStore.setState({
isGlobalPending: false,
systemFeatures: { ...defaultSystemFeatures, branding: { ...defaultSystemFeatures.branding, enabled: false } },
})
})
@ -84,9 +90,9 @@ describe('use default branding', () => {
*/
describe('use specific branding', () => {
beforeEach(() => {
vi.mocked(useIsSystemFeaturesPending).mockReturnValue(false)
act(() => {
useGlobalPublicStore.setState({
isGlobalPending: false,
systemFeatures: { ...defaultSystemFeatures, branding: { ...defaultSystemFeatures.branding, enabled: true, application_title: 'Test' } },
})
})

View File

@ -1,11 +1,11 @@
'use client'
import { useFavicon, useTitle } from 'ahooks'
import { useEffect } from 'react'
import { useGlobalPublicStore } from '@/context/global-public-context'
import { useGlobalPublicStore, useIsSystemFeaturesPending } from '@/context/global-public-context'
import { basePath } from '@/utils/var'
export default function useDocumentTitle(title: string) {
const isPending = useGlobalPublicStore(s => s.isGlobalPending)
const isPending = useIsSystemFeaturesPending()
const systemFeatures = useGlobalPublicStore(s => s.systemFeatures)
const prefix = title ? `${title} - ` : ''
let titleStr = ''

View File

@ -1,7 +1,7 @@
{
"name": "dify-web",
"type": "module",
"version": "1.11.2",
"version": "1.11.3",
"private": true,
"packageManager": "pnpm@10.27.0+sha512.72d699da16b1179c14ba9e64dc71c9a40988cbdc65c264cb0e489db7de917f20dcf4d64d8723625f2969ba52d4b7e2a1170682d9ac2a5dcaeaab732b7e16f04a",
"imports": {
@ -69,6 +69,10 @@
"@monaco-editor/react": "^4.7.0",
"@octokit/core": "^6.1.6",
"@octokit/request-error": "^6.1.8",
"@orpc/client": "^1.13.4",
"@orpc/contract": "^1.13.4",
"@orpc/openapi-client": "^1.13.4",
"@orpc/tanstack-query": "^1.13.4",
"@remixicon/react": "^4.7.0",
"@sentry/react": "^8.55.0",
"@svgdotjs/svg.js": "^3.2.5",

137
web/pnpm-lock.yaml generated
View File

@ -108,6 +108,18 @@ importers:
'@octokit/request-error':
specifier: ^6.1.8
version: 6.1.8
'@orpc/client':
specifier: ^1.13.4
version: 1.13.4
'@orpc/contract':
specifier: ^1.13.4
version: 1.13.4
'@orpc/openapi-client':
specifier: ^1.13.4
version: 1.13.4
'@orpc/tanstack-query':
specifier: ^1.13.4
version: 1.13.4(@orpc/client@1.13.4)(@tanstack/query-core@5.90.12)
'@remixicon/react':
specifier: ^4.7.0
version: 4.7.0(react@19.2.3)
@ -2291,6 +2303,38 @@ packages:
'@open-draft/until@2.1.0':
resolution: {integrity: sha512-U69T3ItWHvLwGg5eJ0n3I62nWuE6ilHlmz7zM0npLBRvPRd7e6NYmg54vvRtP5mZG7kZqZCFVdsTWo7BPtBujg==}
'@orpc/client@1.13.4':
resolution: {integrity: sha512-s13GPMeoooJc5Th2EaYT5HMFtWG8S03DUVytYfJv8pIhP87RYKl94w52A36denH6r/B4LaAgBeC9nTAOslK+Og==}
'@orpc/contract@1.13.4':
resolution: {integrity: sha512-TIxyaF67uOlihCRcasjHZxguZpbqfNK7aMrDLnhoufmQBE4OKvguNzmrOFHgsuM0OXoopX0Nuhun1ccaxKP10A==}
'@orpc/openapi-client@1.13.4':
resolution: {integrity: sha512-tRUcY4E6sgpS5bY/9nNES/Q/PMyYyPOsI4TuhwLhfgxOb0GFPwYKJ6Kif7KFNOhx4fkN/jTOfE1nuWuIZU1gyg==}
'@orpc/shared@1.13.4':
resolution: {integrity: sha512-TYt9rLG/BUkNQBeQ6C1tEiHS/Seb8OojHgj9GlvqyjHJhMZx5qjsIyTW6RqLPZJ4U2vgK6x4Her36+tlFCKJug==}
peerDependencies:
'@opentelemetry/api': '>=1.9.0'
peerDependenciesMeta:
'@opentelemetry/api':
optional: true
'@orpc/standard-server-fetch@1.13.4':
resolution: {integrity: sha512-/zmKwnuxfAXbppJpgr1CMnQX3ptPlYcDzLz1TaVzz9VG/Xg58Ov3YhabS2Oi1utLVhy5t4kaCppUducAvoKN+A==}
'@orpc/standard-server-peer@1.13.4':
resolution: {integrity: sha512-UfqnTLqevjCKUk4cmImOG8cQUwANpV1dp9e9u2O1ki6BRBsg/zlXFg6G2N6wP0zr9ayIiO1d2qJdH55yl/1BNw==}
'@orpc/standard-server@1.13.4':
resolution: {integrity: sha512-ZOzgfVp6XUg+wVYw+gqesfRfGPtQbnBIrIiSnFMtZF+6ncmFJeF2Shc4RI2Guqc0Qz25juy8Ogo4tX3YqysOcg==}
'@orpc/tanstack-query@1.13.4':
resolution: {integrity: sha512-gCL/kh3kf6OUGKfXxSoOZpcX1jNYzxGfo/PkLQKX7ui4xiTbfWw3sCDF30sNS4I7yAOnBwDwJ3N2xzfkTftOBg==}
peerDependencies:
'@orpc/client': 1.13.4
'@tanstack/query-core': '>=5.80.2'
'@oxc-resolver/binding-android-arm-eabi@11.15.0':
resolution: {integrity: sha512-Q+lWuFfq7whNelNJIP1dhXaVz4zO9Tu77GcQHyxDWh3MaCoO2Bisphgzmsh4ZoUe2zIchQh6OvQL99GlWHg9Tw==}
cpu: [arm]
@ -6685,6 +6729,9 @@ packages:
resolution: {integrity: sha512-7x81NCL719oNbsq/3mh+hVrAWmFuEYUqrq/Iw3kUzH8ReypT9QQ0BLoJS7/G9k6N81XjW4qHWtjWwe/9eLy1EQ==}
engines: {node: '>=12'}
openapi-types@12.1.3:
resolution: {integrity: sha512-N4YtSYJqghVu4iek2ZUvcN/0aqH1kRDuNqzcycDxhOUpg7GdvLa2F3DgS6yBNhInhv2r/6I0Flkn7CqL8+nIcw==}
opener@1.5.2:
resolution: {integrity: sha512-ur5UIdyw5Y7yEj9wLzhqXiy6GZ3Mwx0yGI+5sMn2r0N0v3cKJvUmFH5yPP+WXh9e0xfyzyJX95D8l088DNFj7A==}
hasBin: true
@ -7081,6 +7128,10 @@ packages:
queue-microtask@1.2.3:
resolution: {integrity: sha512-NuaNSa6flKT5JaSYQzJok04JzTL1CA6aGhv5rfLW3PgqA+M2ChpZQnAC8h8i4ZFkBS8X5RqkDBHA7r4hej3K9A==}
radash@12.1.1:
resolution: {integrity: sha512-h36JMxKRqrAxVD8201FrCpyeNuUY9Y5zZwujr20fFO77tpUtGa6EZzfKw/3WaiBX95fq7+MpsuMLNdSnORAwSA==}
engines: {node: '>=14.18.0'}
randombytes@2.1.0:
resolution: {integrity: sha512-vYl3iOX+4CKUWuxGi9Ukhie6fsqXqS9FE2Zaic4tNFD2N2QQaXOMFbuKK4QmDHC0JO6B1Zp41J0LpT0oR68amQ==}
@ -7826,6 +7877,10 @@ packages:
tabbable@6.3.0:
resolution: {integrity: sha512-EIHvdY5bPLuWForiR/AN2Bxngzpuwn1is4asboytXtpTgsArc+WmSJKVLlhdh71u7jFcryDqB2A8lQvj78MkyQ==}
tagged-tag@1.0.0:
resolution: {integrity: sha512-yEFYrVhod+hdNyx7g5Bnkkb0G6si8HJurOoOEgC8B/O0uXLHlaey/65KRv6cuWBNhBgHKAROVpc7QyYqE5gFng==}
engines: {node: '>=20'}
tailwind-merge@2.6.0:
resolution: {integrity: sha512-P+Vu1qXfzediirmHOC3xKGAYeZtPcV9g76X+xg2FD4tYgR71ewMA35Y3sCz3zhiN/dwefRpJX0yBcgwi1fXNQA==}
@ -8027,13 +8082,17 @@ packages:
resolution: {integrity: sha512-5zknd7Dss75pMSED270A1RQS3KloqRJA9XbXLe0eCxyw7xXFb3rd+9B0UQ/0E+LQT6lnrLviEolYORlRWamn4w==}
engines: {node: '>=16'}
type-fest@5.4.0:
resolution: {integrity: sha512-wfkA6r0tBpVfGiyO+zbf9e10QkRQSlK9F2UvyfnjoCmrvH2bjHyhPzhugSBOuq1dog3P0+FKckqe+Xf6WKVjwg==}
engines: {node: '>=20'}
typescript@5.9.3:
resolution: {integrity: sha512-jl1vZzPDinLr9eUt3J/t7V6FgNEw9QjvBPdysz9KfQDD41fQrC2Y4vKQdiaUpFT4bXlb1RHhLpp8wtm6M5TgSw==}
engines: {node: '>=14.17'}
hasBin: true
ufo@1.6.1:
resolution: {integrity: sha512-9a4/uxlTWJ4+a5i0ooc1rU7C7YOw3wT+UGqdeNNHWnOF9qcMBgLRS+4IYUqbczewFx4mLEig6gawh7X6mFlEkA==}
ufo@1.6.2:
resolution: {integrity: sha512-heMioaxBcG9+Znsda5Q8sQbWnLJSl98AFDXTO80wELWEzX3hordXsTdxrIfMQoO9IY1MEnoGoPjpoKpMj+Yx0Q==}
uglify-js@3.19.3:
resolution: {integrity: sha512-v3Xu+yuwBXisp6QYTcH4UbH+xYJXqnq2m/LtQVWKWzYc1iehYnLixoQDN9FH6/j9/oybfd6W9Ghwkl8+UMKTKQ==}
@ -10638,6 +10697,66 @@ snapshots:
'@open-draft/until@2.1.0': {}
'@orpc/client@1.13.4':
dependencies:
'@orpc/shared': 1.13.4
'@orpc/standard-server': 1.13.4
'@orpc/standard-server-fetch': 1.13.4
'@orpc/standard-server-peer': 1.13.4
transitivePeerDependencies:
- '@opentelemetry/api'
'@orpc/contract@1.13.4':
dependencies:
'@orpc/client': 1.13.4
'@orpc/shared': 1.13.4
'@standard-schema/spec': 1.1.0
openapi-types: 12.1.3
transitivePeerDependencies:
- '@opentelemetry/api'
'@orpc/openapi-client@1.13.4':
dependencies:
'@orpc/client': 1.13.4
'@orpc/contract': 1.13.4
'@orpc/shared': 1.13.4
'@orpc/standard-server': 1.13.4
transitivePeerDependencies:
- '@opentelemetry/api'
'@orpc/shared@1.13.4':
dependencies:
radash: 12.1.1
type-fest: 5.4.0
'@orpc/standard-server-fetch@1.13.4':
dependencies:
'@orpc/shared': 1.13.4
'@orpc/standard-server': 1.13.4
transitivePeerDependencies:
- '@opentelemetry/api'
'@orpc/standard-server-peer@1.13.4':
dependencies:
'@orpc/shared': 1.13.4
'@orpc/standard-server': 1.13.4
transitivePeerDependencies:
- '@opentelemetry/api'
'@orpc/standard-server@1.13.4':
dependencies:
'@orpc/shared': 1.13.4
transitivePeerDependencies:
- '@opentelemetry/api'
'@orpc/tanstack-query@1.13.4(@orpc/client@1.13.4)(@tanstack/query-core@5.90.12)':
dependencies:
'@orpc/client': 1.13.4
'@orpc/shared': 1.13.4
'@tanstack/query-core': 5.90.12
transitivePeerDependencies:
- '@opentelemetry/api'
'@oxc-resolver/binding-android-arm-eabi@11.15.0':
optional: true
@ -15603,7 +15722,7 @@ snapshots:
acorn: 8.15.0
pathe: 2.0.3
pkg-types: 1.3.1
ufo: 1.6.1
ufo: 1.6.2
monaco-editor@0.55.1:
dependencies:
@ -15766,6 +15885,8 @@ snapshots:
is-docker: 2.2.1
is-wsl: 2.2.0
openapi-types@12.1.3: {}
opener@1.5.2: {}
optionator@0.9.4:
@ -16181,6 +16302,8 @@ snapshots:
queue-microtask@1.2.3: {}
radash@12.1.1: {}
randombytes@2.1.0:
dependencies:
safe-buffer: 5.2.1
@ -17098,6 +17221,8 @@ snapshots:
tabbable@6.3.0: {}
tagged-tag@1.0.0: {}
tailwind-merge@2.6.0: {}
tailwindcss@3.4.18(tsx@4.21.0)(yaml@2.8.2):
@ -17305,9 +17430,13 @@ snapshots:
type-fest@4.2.0:
optional: true
type-fest@5.4.0:
dependencies:
tagged-tag: 1.0.0
typescript@5.9.3: {}
ufo@1.6.1: {}
ufo@1.6.2: {}
uglify-js@3.19.3: {}

View File

@ -81,6 +81,11 @@ export type IOtherOptions = {
needAllResponseContent?: boolean
deleteContentType?: boolean
silent?: boolean
/** If true, behaves like standard fetch: no URL prefix, returns raw Response */
fetchCompat?: boolean
request?: Request
onData?: IOnData // for stream
onThought?: IOnThought
onFile?: IOnFile

View File

@ -1,5 +1,5 @@
import type { CurrentPlanInfoBackend, SubscriptionUrlsBackend } from '@/app/components/billing/type'
import { get, put } from './base'
import { get } from './base'
export const fetchCurrentPlanInfo = () => {
return get<CurrentPlanInfoBackend>('/features')
@ -8,17 +8,3 @@ export const fetchCurrentPlanInfo = () => {
export const fetchSubscriptionUrls = (plan: string, interval: string) => {
return get<SubscriptionUrlsBackend>(`/billing/subscription?plan=${plan}&interval=${interval}`)
}
export const fetchBillingUrl = () => {
return get<{ url: string }>('/billing/invoices')
}
export const bindPartnerStackInfo = (partnerKey: string, clickId: string) => {
return put(`/billing/partners/${partnerKey}/tenants`, {
body: {
click_id: clickId,
},
}, {
silent: true,
})
}

61
web/service/client.ts Normal file
View File

@ -0,0 +1,61 @@
import type { ContractRouterClient } from '@orpc/contract'
import type { JsonifiedClient } from '@orpc/openapi-client'
import { createORPCClient, onError } from '@orpc/client'
import { OpenAPILink } from '@orpc/openapi-client/fetch'
import { createTanstackQueryUtils } from '@orpc/tanstack-query'
import {
API_PREFIX,
APP_VERSION,
IS_MARKETPLACE,
MARKETPLACE_API_PREFIX,
} from '@/config'
import {
consoleRouterContract,
marketplaceRouterContract,
} from '@/contract/router'
import { request } from './base'
const getMarketplaceHeaders = () => new Headers({
'X-Dify-Version': !IS_MARKETPLACE ? APP_VERSION : '999.0.0',
})
const marketplaceLink = new OpenAPILink(marketplaceRouterContract, {
url: MARKETPLACE_API_PREFIX,
headers: () => (getMarketplaceHeaders()),
fetch: (request, init) => {
return globalThis.fetch(request, {
...init,
cache: 'no-store',
})
},
interceptors: [
onError((error) => {
console.error(error)
}),
],
})
export const marketplaceClient: JsonifiedClient<ContractRouterClient<typeof marketplaceRouterContract>> = createORPCClient(marketplaceLink)
export const marketplaceQuery = createTanstackQueryUtils(marketplaceClient, { path: ['marketplace'] })
const consoleLink = new OpenAPILink(consoleRouterContract, {
url: API_PREFIX,
fetch: (input, init) => {
return request(
input.url,
init,
{
fetchCompat: true,
request: input,
},
)
},
interceptors: [
onError((error) => {
console.error(error)
}),
],
})
export const consoleClient: JsonifiedClient<ContractRouterClient<typeof consoleRouterContract>> = createORPCClient(consoleLink)
export const consoleQuery = createTanstackQueryUtils(consoleClient, { path: ['console'] })

View File

@ -34,7 +34,6 @@ import type {
UserProfileOriginResponse,
} from '@/models/common'
import type { RETRIEVE_METHOD } from '@/types/app'
import type { SystemFeatures } from '@/types/feature'
import { del, get, patch, post, put } from './base'
type LoginSuccess = {
@ -307,10 +306,6 @@ export const fetchSupportRetrievalMethods = (url: string): Promise<RetrievalMeth
return get<RetrievalMethodsRes>(url)
}
export const getSystemFeatures = (): Promise<SystemFeatures> => {
return get<SystemFeatures>('/system-features')
}
export const enableModel = (url: string, body: { model: string, model_type: ModelTypeEnum }): Promise<CommonResponse> =>
patch<CommonResponse>(url, { body })

View File

@ -136,6 +136,8 @@ async function base<T>(url: string, options: FetchOptionType = {}, otherOptions:
needAllResponseContent,
deleteContentType,
getAbortController,
fetchCompat = false,
request,
} = otherOptions
let base: string
@ -181,7 +183,7 @@ async function base<T>(url: string, options: FetchOptionType = {}, otherOptions:
},
})
const res = await client(fetchPathname, {
const res = await client(request || fetchPathname, {
...init,
headers,
credentials: isMarketplaceAPI
@ -190,8 +192,8 @@ async function base<T>(url: string, options: FetchOptionType = {}, otherOptions:
retry: {
methods: [],
},
...(bodyStringify ? { json: body } : { body: body as BodyInit }),
searchParams: params,
...(bodyStringify && !fetchCompat ? { json: body } : { body: body as BodyInit }),
searchParams: !fetchCompat ? params : undefined,
fetch(resource: RequestInfo | URL, options?: RequestInit) {
if (resource instanceof Request && options) {
const mergedHeaders = new Headers(options.headers || {})
@ -204,7 +206,7 @@ async function base<T>(url: string, options: FetchOptionType = {}, otherOptions:
},
})
if (needAllResponseContent)
if (needAllResponseContent || fetchCompat)
return res as T
const contentType = res.headers.get('content-type')
if (

View File

@ -1,21 +1,22 @@
import { useMutation, useQuery } from '@tanstack/react-query'
import { bindPartnerStackInfo, fetchBillingUrl } from '@/service/billing'
const NAME_SPACE = 'billing'
import { consoleClient, consoleQuery } from '@/service/client'
export const useBindPartnerStackInfo = () => {
return useMutation({
mutationKey: [NAME_SPACE, 'bind-partner-stack'],
mutationFn: (data: { partnerKey: string, clickId: string }) => bindPartnerStackInfo(data.partnerKey, data.clickId),
mutationKey: consoleQuery.bindPartnerStack.mutationKey(),
mutationFn: (data: { partnerKey: string, clickId: string }) => consoleClient.bindPartnerStack({
params: { partnerKey: data.partnerKey },
body: { click_id: data.clickId },
}),
})
}
export const useBillingUrl = (enabled: boolean) => {
return useQuery({
queryKey: [NAME_SPACE, 'url'],
queryKey: consoleQuery.billingUrl.queryKey(),
enabled,
queryFn: async () => {
const res = await fetchBillingUrl()
const res = await consoleClient.billingUrl()
return res.url
},
})

View File

@ -488,23 +488,23 @@ export const useMutationPluginsFromMarketplace = () => {
mutationFn: (pluginsSearchParams: PluginsSearchParams) => {
const {
query,
sortBy,
sortOrder,
sort_by,
sort_order,
category,
tags,
exclude,
type,
page = 1,
pageSize = 40,
page_size = 40,
} = pluginsSearchParams
const pluginOrBundle = type === 'bundle' ? 'bundles' : 'plugins'
return postMarketplace<{ data: PluginsFromMarketplaceResponse }>(`/${pluginOrBundle}/search/advanced`, {
body: {
page,
page_size: pageSize,
page_size,
query,
sort_by: sortBy,
sort_order: sortOrder,
sort_by,
sort_order,
category: category !== 'all' ? category : '',
tags,
exclude,
@ -535,23 +535,23 @@ export const useFetchPluginListOrBundleList = (pluginsSearchParams: PluginsSearc
queryFn: () => {
const {
query,
sortBy,
sortOrder,
sort_by,
sort_order,
category,
tags,
exclude,
type,
page = 1,
pageSize = 40,
page_size = 40,
} = pluginsSearchParams
const pluginOrBundle = type === 'bundle' ? 'bundles' : 'plugins'
return postMarketplace<{ data: PluginsFromMarketplaceResponse }>(`/${pluginOrBundle}/search/advanced`, {
body: {
page,
page_size: pageSize,
page_size,
query,
sort_by: sortBy,
sort_order: sortOrder,
sort_by,
sort_order,
category: category !== 'all' ? category : '',
tags,
exclude,

View File

@ -0,0 +1,139 @@
import type { SetupStatusResponse } from '@/models/common'
import { fetchSetupStatus } from '@/service/common'
import { fetchSetupStatusWithCache } from './setup-status'
vi.mock('@/service/common', () => ({
fetchSetupStatus: vi.fn(),
}))
const mockFetchSetupStatus = vi.mocked(fetchSetupStatus)
describe('setup-status utilities', () => {
beforeEach(() => {
vi.clearAllMocks()
localStorage.clear()
})
describe('fetchSetupStatusWithCache', () => {
describe('when cache exists', () => {
it('should return cached finished status without API call', async () => {
localStorage.setItem('setup_status', 'finished')
const result = await fetchSetupStatusWithCache()
expect(result).toEqual({ step: 'finished' })
expect(mockFetchSetupStatus).not.toHaveBeenCalled()
})
it('should not modify localStorage when returning cached value', async () => {
localStorage.setItem('setup_status', 'finished')
await fetchSetupStatusWithCache()
expect(localStorage.getItem('setup_status')).toBe('finished')
})
})
describe('when cache does not exist', () => {
it('should call API and cache finished status', async () => {
const apiResponse: SetupStatusResponse = { step: 'finished' }
mockFetchSetupStatus.mockResolvedValue(apiResponse)
const result = await fetchSetupStatusWithCache()
expect(mockFetchSetupStatus).toHaveBeenCalledTimes(1)
expect(result).toEqual(apiResponse)
expect(localStorage.getItem('setup_status')).toBe('finished')
})
it('should call API and remove cache when not finished', async () => {
const apiResponse: SetupStatusResponse = { step: 'not_started' }
mockFetchSetupStatus.mockResolvedValue(apiResponse)
const result = await fetchSetupStatusWithCache()
expect(mockFetchSetupStatus).toHaveBeenCalledTimes(1)
expect(result).toEqual(apiResponse)
expect(localStorage.getItem('setup_status')).toBeNull()
})
it('should clear stale cache when API returns not_started', async () => {
localStorage.setItem('setup_status', 'some_invalid_value')
const apiResponse: SetupStatusResponse = { step: 'not_started' }
mockFetchSetupStatus.mockResolvedValue(apiResponse)
const result = await fetchSetupStatusWithCache()
expect(result).toEqual(apiResponse)
expect(localStorage.getItem('setup_status')).toBeNull()
})
})
describe('cache edge cases', () => {
it('should call API when cache value is empty string', async () => {
localStorage.setItem('setup_status', '')
const apiResponse: SetupStatusResponse = { step: 'finished' }
mockFetchSetupStatus.mockResolvedValue(apiResponse)
const result = await fetchSetupStatusWithCache()
expect(mockFetchSetupStatus).toHaveBeenCalledTimes(1)
expect(result).toEqual(apiResponse)
})
it('should call API when cache value is not "finished"', async () => {
localStorage.setItem('setup_status', 'not_started')
const apiResponse: SetupStatusResponse = { step: 'finished' }
mockFetchSetupStatus.mockResolvedValue(apiResponse)
const result = await fetchSetupStatusWithCache()
expect(mockFetchSetupStatus).toHaveBeenCalledTimes(1)
expect(result).toEqual(apiResponse)
})
it('should call API when localStorage key does not exist', async () => {
const apiResponse: SetupStatusResponse = { step: 'finished' }
mockFetchSetupStatus.mockResolvedValue(apiResponse)
const result = await fetchSetupStatusWithCache()
expect(mockFetchSetupStatus).toHaveBeenCalledTimes(1)
expect(result).toEqual(apiResponse)
})
})
describe('API response handling', () => {
it('should preserve setup_at from API response', async () => {
const setupDate = new Date('2024-01-01')
const apiResponse: SetupStatusResponse = {
step: 'finished',
setup_at: setupDate,
}
mockFetchSetupStatus.mockResolvedValue(apiResponse)
const result = await fetchSetupStatusWithCache()
expect(result).toEqual(apiResponse)
expect(result.setup_at).toEqual(setupDate)
})
it('should propagate API errors', async () => {
const apiError = new Error('Network error')
mockFetchSetupStatus.mockRejectedValue(apiError)
await expect(fetchSetupStatusWithCache()).rejects.toThrow('Network error')
})
it('should not update cache when API call fails', async () => {
mockFetchSetupStatus.mockRejectedValue(new Error('API error'))
await expect(fetchSetupStatusWithCache()).rejects.toThrow()
expect(localStorage.getItem('setup_status')).toBeNull()
})
})
})
})

21
web/utils/setup-status.ts Normal file
View File

@ -0,0 +1,21 @@
import type { SetupStatusResponse } from '@/models/common'
import { fetchSetupStatus } from '@/service/common'
const SETUP_STATUS_KEY = 'setup_status'
const isSetupStatusCached = (): boolean =>
localStorage.getItem(SETUP_STATUS_KEY) === 'finished'
export const fetchSetupStatusWithCache = async (): Promise<SetupStatusResponse> => {
if (isSetupStatusCached())
return { step: 'finished' }
const status = await fetchSetupStatus()
if (status.step === 'finished')
localStorage.setItem(SETUP_STATUS_KEY, 'finished')
else
localStorage.removeItem(SETUP_STATUS_KEY)
return status
}