Merge branch 'main' into feat/model-auth

This commit is contained in:
zxhlyh 2025-08-13 10:12:44 +08:00
commit e69797d738
172 changed files with 8190 additions and 705 deletions

View File

@ -5,7 +5,7 @@ cd web && pnpm install
pipx install uv
echo 'alias start-api="cd /workspaces/dify/api && uv run python -m flask run --host 0.0.0.0 --port=5001 --debug"' >> ~/.bashrc
echo 'alias start-worker="cd /workspaces/dify/api && uv run python -m celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion"' >> ~/.bashrc
echo 'alias start-worker="cd /workspaces/dify/api && uv run python -m celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage"' >> ~/.bashrc
echo 'alias start-web="cd /workspaces/dify/web && pnpm dev"' >> ~/.bashrc
echo 'alias start-web-prod="cd /workspaces/dify/web && pnpm build && pnpm start"' >> ~/.bashrc
echo 'alias start-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env up -d"' >> ~/.bashrc

View File

@ -8,7 +8,7 @@ inputs:
uv-version:
description: UV version to set up
required: true
default: '~=0.7.11'
default: '0.8.9'
uv-lockfile:
description: Path to the UV lockfile to restore cache from
required: true
@ -26,7 +26,7 @@ runs:
python-version: ${{ inputs.python-version }}
- name: Install uv
uses: astral-sh/setup-uv@v5
uses: astral-sh/setup-uv@v6
with:
version: ${{ inputs.uv-version }}
python-version: ${{ inputs.python-version }}

View File

@ -4,7 +4,7 @@ FROM python:3.12-slim-bookworm AS base
WORKDIR /app/api
# Install uv
ENV UV_VERSION=0.7.11
ENV UV_VERSION=0.8.9
RUN pip install --no-cache-dir uv==${UV_VERSION}

View File

@ -74,7 +74,7 @@
10. If you need to handle and debug the async tasks (e.g. dataset importing and documents indexing), please start the worker service.
```bash
uv run celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin
uv run celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage
```
Addition, if you want to debug the celery scheduled tasks, you can use the following command in another terminal:

View File

@ -552,12 +552,18 @@ class RepositoryConfig(BaseSettings):
"""
CORE_WORKFLOW_EXECUTION_REPOSITORY: str = Field(
description="Repository implementation for WorkflowExecution. Specify as a module path",
description="Repository implementation for WorkflowExecution. Options: "
"'core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository' (default), "
"'core.repositories.celery_workflow_execution_repository.CeleryWorkflowExecutionRepository'",
default="core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository",
)
CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY: str = Field(
description="Repository implementation for WorkflowNodeExecution. Specify as a module path",
description="Repository implementation for WorkflowNodeExecution. Options: "
"'core.repositories.sqlalchemy_workflow_node_execution_repository."
"SQLAlchemyWorkflowNodeExecutionRepository' (default), "
"'core.repositories.celery_workflow_node_execution_repository."
"CeleryWorkflowNodeExecutionRepository'",
default="core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository",
)

View File

@ -144,7 +144,8 @@ class DatabaseConfig(BaseSettings):
default="postgresql",
)
@computed_field
@computed_field # type: ignore[misc]
@property
def SQLALCHEMY_DATABASE_URI(self) -> str:
db_extras = (
f"{self.DB_EXTRAS}&client_encoding={self.DB_CHARSET}" if self.DB_CHARSET else self.DB_EXTRAS

View File

@ -862,6 +862,10 @@ class ToolProviderMCPApi(Resource):
parser.add_argument("icon_type", type=str, required=True, nullable=False, location="json")
parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json", default="")
parser.add_argument("server_identifier", type=str, required=True, nullable=False, location="json")
parser.add_argument("timeout", type=float, required=False, nullable=False, location="json", default=30)
parser.add_argument(
"sse_read_timeout", type=float, required=False, nullable=False, location="json", default=300
)
args = parser.parse_args()
user = current_user
if not is_valid_url(args["server_url"]):
@ -876,6 +880,8 @@ class ToolProviderMCPApi(Resource):
icon_background=args["icon_background"],
user_id=user.id,
server_identifier=args["server_identifier"],
timeout=args["timeout"],
sse_read_timeout=args["sse_read_timeout"],
)
)
@ -891,6 +897,8 @@ class ToolProviderMCPApi(Resource):
parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json")
parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json")
parser.add_argument("server_identifier", type=str, required=True, nullable=False, location="json")
parser.add_argument("timeout", type=float, required=False, nullable=True, location="json")
parser.add_argument("sse_read_timeout", type=float, required=False, nullable=True, location="json")
args = parser.parse_args()
if not is_valid_url(args["server_url"]):
if "[__HIDDEN__]" in args["server_url"]:
@ -906,6 +914,8 @@ class ToolProviderMCPApi(Resource):
icon_type=args["icon_type"],
icon_background=args["icon_background"],
server_identifier=args["server_identifier"],
timeout=args.get("timeout"),
sse_read_timeout=args.get("sse_read_timeout"),
)
return {"result": "success"}

View File

@ -568,7 +568,7 @@ class AdvancedChatAppGenerateTaskPipeline:
)
yield workflow_finish_resp
self._base_task_pipeline._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
self._base_task_pipeline.queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
def _handle_workflow_partial_success_event(
self,
@ -600,7 +600,7 @@ class AdvancedChatAppGenerateTaskPipeline:
)
yield workflow_finish_resp
self._base_task_pipeline._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
self._base_task_pipeline.queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
def _handle_workflow_failed_event(
self,
@ -845,7 +845,7 @@ class AdvancedChatAppGenerateTaskPipeline:
# Initialize graph runtime state
graph_runtime_state: Optional[GraphRuntimeState] = None
for queue_message in self._base_task_pipeline._queue_manager.listen():
for queue_message in self._base_task_pipeline.queue_manager.listen():
event = queue_message.event
match event:
@ -959,11 +959,11 @@ class AdvancedChatAppGenerateTaskPipeline:
if self._base_task_pipeline._output_moderation_handler:
if self._base_task_pipeline._output_moderation_handler.should_direct_output():
self._task_state.answer = self._base_task_pipeline._output_moderation_handler.get_final_output()
self._base_task_pipeline._queue_manager.publish(
self._base_task_pipeline.queue_manager.publish(
QueueTextChunkEvent(text=self._task_state.answer), PublishFrom.TASK_PIPELINE
)
self._base_task_pipeline._queue_manager.publish(
self._base_task_pipeline.queue_manager.publish(
QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION), PublishFrom.TASK_PIPELINE
)
return True

View File

@ -711,7 +711,7 @@ class WorkflowAppGenerateTaskPipeline:
# Initialize graph runtime state
graph_runtime_state = None
for queue_message in self._base_task_pipeline._queue_manager.listen():
for queue_message in self._base_task_pipeline.queue_manager.listen():
event = queue_message.event
match event:

View File

@ -9,7 +9,6 @@ from core.app.app_config.entities import EasyUIBasedAppConfig, WorkflowUIBasedAp
from core.entities.provider_configuration import ProviderModelBundle
from core.file import File, FileUploadConfig
from core.model_runtime.entities.model_entities import AIModelEntity
from core.ops.ops_trace_manager import TraceQueueManager
class InvokeFrom(Enum):
@ -114,7 +113,8 @@ class AppGenerateEntity(BaseModel):
extras: dict[str, Any] = Field(default_factory=dict)
# tracing instance
trace_manager: Optional[TraceQueueManager] = None
# Using Any to avoid circular import with TraceQueueManager
trace_manager: Optional[Any] = None
class EasyUIBasedAppGenerateEntity(AppGenerateEntity):

View File

@ -37,7 +37,7 @@ class BasedGenerateTaskPipeline:
stream: bool,
) -> None:
self._application_generate_entity = application_generate_entity
self._queue_manager = queue_manager
self.queue_manager = queue_manager
self._start_at = time.perf_counter()
self._output_moderation_handler = self._init_output_moderation()
self._stream = stream
@ -113,7 +113,7 @@ class BasedGenerateTaskPipeline:
tenant_id=app_config.tenant_id,
app_id=app_config.app_id,
rule=ModerationRule(type=sensitive_word_avoidance.type, config=sensitive_word_avoidance.config),
queue_manager=self._queue_manager,
queue_manager=self.queue_manager,
)
return None

View File

@ -257,7 +257,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
Process stream response.
:return:
"""
for message in self._queue_manager.listen():
for message in self.queue_manager.listen():
if publisher:
publisher.publish(message)
event = message.event
@ -499,7 +499,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
if self._output_moderation_handler.should_direct_output():
# stop subscribe new token when output moderation should direct output
self._task_state.llm_result.message.content = self._output_moderation_handler.get_final_output()
self._queue_manager.publish(
self.queue_manager.publish(
QueueLLMChunkEvent(
chunk=LLMResultChunk(
model=self._task_state.llm_result.model,
@ -513,7 +513,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
PublishFrom.TASK_PIPELINE,
)
self._queue_manager.publish(
self.queue_manager.publish(
QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION), PublishFrom.TASK_PIPELINE
)
return True

View File

@ -327,7 +327,7 @@ def send_message(http_client: httpx.Client, endpoint_url: str, session_message:
)
response.raise_for_status()
logger.debug("Client message sent successfully: %s", response.status_code)
except Exception as exc:
except Exception:
logger.exception("Error sending message")
raise

View File

@ -55,14 +55,10 @@ DEFAULT_QUEUE_READ_TIMEOUT = 3
class StreamableHTTPError(Exception):
"""Base exception for StreamableHTTP transport errors."""
pass
class ResumptionError(StreamableHTTPError):
"""Raised when resumption request is invalid."""
pass
@dataclass
class RequestContext:
@ -74,7 +70,7 @@ class RequestContext:
session_message: SessionMessage
metadata: ClientMessageMetadata | None
server_to_client_queue: ServerToClientQueue # Renamed for clarity
sse_read_timeout: timedelta
sse_read_timeout: float
class StreamableHTTPTransport:
@ -84,8 +80,8 @@ class StreamableHTTPTransport:
self,
url: str,
headers: dict[str, Any] | None = None,
timeout: timedelta = timedelta(seconds=30),
sse_read_timeout: timedelta = timedelta(seconds=60 * 5),
timeout: float | timedelta = 30,
sse_read_timeout: float | timedelta = 60 * 5,
) -> None:
"""Initialize the StreamableHTTP transport.
@ -97,8 +93,10 @@ class StreamableHTTPTransport:
"""
self.url = url
self.headers = headers or {}
self.timeout = timeout
self.sse_read_timeout = sse_read_timeout
self.timeout = timeout.total_seconds() if isinstance(timeout, timedelta) else timeout
self.sse_read_timeout = (
sse_read_timeout.total_seconds() if isinstance(sse_read_timeout, timedelta) else sse_read_timeout
)
self.session_id: str | None = None
self.request_headers = {
ACCEPT: f"{JSON}, {SSE}",
@ -186,7 +184,7 @@ class StreamableHTTPTransport:
with ssrf_proxy_sse_connect(
self.url,
headers=headers,
timeout=httpx.Timeout(self.timeout.seconds, read=self.sse_read_timeout.seconds),
timeout=httpx.Timeout(self.timeout, read=self.sse_read_timeout),
client=client,
method="GET",
) as event_source:
@ -215,7 +213,7 @@ class StreamableHTTPTransport:
with ssrf_proxy_sse_connect(
self.url,
headers=headers,
timeout=httpx.Timeout(self.timeout.seconds, read=ctx.sse_read_timeout.seconds),
timeout=httpx.Timeout(self.timeout, read=self.sse_read_timeout),
client=ctx.client,
method="GET",
) as event_source:
@ -402,8 +400,8 @@ class StreamableHTTPTransport:
def streamablehttp_client(
url: str,
headers: dict[str, Any] | None = None,
timeout: timedelta = timedelta(seconds=30),
sse_read_timeout: timedelta = timedelta(seconds=60 * 5),
timeout: float | timedelta = 30,
sse_read_timeout: float | timedelta = 60 * 5,
terminate_on_close: bool = True,
) -> Generator[
tuple[
@ -436,7 +434,7 @@ def streamablehttp_client(
try:
with create_ssrf_proxy_mcp_http_client(
headers=transport.request_headers,
timeout=httpx.Timeout(transport.timeout.seconds, read=transport.sse_read_timeout.seconds),
timeout=httpx.Timeout(transport.timeout, read=transport.sse_read_timeout),
) as client:
# Define callbacks that need access to thread pool
def start_get_stream() -> None:

View File

@ -23,12 +23,18 @@ class MCPClient:
authed: bool = True,
authorization_code: Optional[str] = None,
for_list: bool = False,
headers: Optional[dict[str, str]] = None,
timeout: Optional[float] = None,
sse_read_timeout: Optional[float] = None,
):
# Initialize info
self.provider_id = provider_id
self.tenant_id = tenant_id
self.client_type = "streamable"
self.server_url = server_url
self.headers = headers or {}
self.timeout = timeout
self.sse_read_timeout = sse_read_timeout
# Authentication info
self.authed = authed
@ -43,7 +49,7 @@ class MCPClient:
self._session: Optional[ClientSession] = None
self._streams_context: Optional[AbstractContextManager[Any]] = None
self._session_context: Optional[ClientSession] = None
self.exit_stack = ExitStack()
self._exit_stack = ExitStack()
# Whether the client has been initialized
self._initialized = False
@ -90,21 +96,26 @@ class MCPClient:
headers = (
{"Authorization": f"{self.token.token_type.capitalize()} {self.token.access_token}"}
if self.authed and self.token
else {}
else self.headers
)
self._streams_context = client_factory(
url=self.server_url,
headers=headers,
timeout=self.timeout,
sse_read_timeout=self.sse_read_timeout,
)
self._streams_context = client_factory(url=self.server_url, headers=headers)
if not self._streams_context:
raise MCPConnectionError("Failed to create connection context")
# Use exit_stack to manage context managers properly
if method_name == "mcp":
read_stream, write_stream, _ = self.exit_stack.enter_context(self._streams_context)
read_stream, write_stream, _ = self._exit_stack.enter_context(self._streams_context)
streams = (read_stream, write_stream)
else: # sse_client
streams = self.exit_stack.enter_context(self._streams_context)
streams = self._exit_stack.enter_context(self._streams_context)
self._session_context = ClientSession(*streams)
self._session = self.exit_stack.enter_context(self._session_context)
self._session = self._exit_stack.enter_context(self._session_context)
session = cast(ClientSession, self._session)
session.initialize()
return
@ -120,9 +131,6 @@ class MCPClient:
if first_try:
return self.connect_server(client_factory, method_name, first_try=False)
except MCPConnectionError:
raise
def list_tools(self) -> list[Tool]:
"""Connect to an MCP server running with SSE transport"""
# List available tools to verify connection
@ -142,7 +150,7 @@ class MCPClient:
"""Clean up resources"""
try:
# ExitStack will handle proper cleanup of all managed context managers
self.exit_stack.close()
self._exit_stack.close()
except Exception as e:
logging.exception("Error during cleanup")
raise ValueError(f"Error during cleanup: {e}")

View File

@ -2,7 +2,6 @@ import logging
import queue
from collections.abc import Callable
from concurrent.futures import Future, ThreadPoolExecutor, TimeoutError
from contextlib import ExitStack
from datetime import timedelta
from types import TracebackType
from typing import Any, Generic, Self, TypeVar
@ -170,7 +169,6 @@ class BaseSession(
self._receive_notification_type = receive_notification_type
self._session_read_timeout_seconds = read_timeout_seconds
self._in_flight = {}
self._exit_stack = ExitStack()
# Initialize executor and future to None for proper cleanup checks
self._executor: ThreadPoolExecutor | None = None
self._receiver_future: Future | None = None
@ -377,7 +375,7 @@ class BaseSession(
self._handle_incoming(RuntimeError(f"Server Error: {message}"))
except queue.Empty:
continue
except Exception as e:
except Exception:
logging.exception("Error in message processing loop")
raise
@ -389,14 +387,12 @@ class BaseSession(
If the request is responded to within this method, it will not be
forwarded on to the message stream.
"""
pass
def _received_notification(self, notification: ReceiveNotificationT) -> None:
"""
Can be overridden by subclasses to handle a notification without needing
to listen on the message stream.
"""
pass
def send_progress_notification(
self, progress_token: str | int, progress: float, total: float | None = None
@ -405,11 +401,9 @@ class BaseSession(
Sends a progress notification for a request that is currently being
processed.
"""
pass
def _handle_incoming(
self,
req: RequestResponder[ReceiveRequestT, SendResultT] | ReceiveNotificationT | Exception,
) -> None:
"""A generic handler for incoming messages. Overwritten by subclasses."""
pass

View File

@ -1,3 +1,4 @@
import queue
from datetime import timedelta
from typing import Any, Protocol
@ -85,8 +86,8 @@ class ClientSession(
):
def __init__(
self,
read_stream,
write_stream,
read_stream: queue.Queue,
write_stream: queue.Queue,
read_timeout_seconds: timedelta | None = None,
sampling_callback: SamplingFnT | None = None,
list_roots_callback: ListRootsFnT | None = None,

View File

@ -99,13 +99,13 @@ class TokenBufferMemory:
prompt_messages.append(UserPromptMessage(content=message.query))
else:
prompt_message_contents: list[PromptMessageContentUnionTypes] = []
prompt_message_contents.append(TextPromptMessageContent(data=message.query))
for file in file_objs:
prompt_message = file_manager.to_prompt_message_content(
file,
image_detail_config=detail,
)
prompt_message_contents.append(prompt_message)
prompt_message_contents.append(TextPromptMessageContent(data=message.query))
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))

View File

@ -257,11 +257,6 @@ class ModelProviderFactory:
# scan all providers
plugin_model_provider_entities = self.get_plugin_model_providers()
# convert provider_configs to dict
provider_credentials_dict = {}
for provider_config in provider_configs:
provider_credentials_dict[provider_config.provider] = provider_config.credentials
# traverse all model_provider_extensions
providers = []
for plugin_model_provider_entity in plugin_model_provider_entities:

View File

@ -68,7 +68,7 @@ class CommonValidator:
if credential_form_schema.max_length:
if len(value) > credential_form_schema.max_length:
raise ValueError(
f"Variable {credential_form_schema.variable} length should not"
f"Variable {credential_form_schema.variable} length should not be"
f" greater than {credential_form_schema.max_length}"
)

View File

@ -1,10 +0,0 @@
import pydantic
from pydantic import BaseModel
def dump_model(model: BaseModel) -> dict:
if hasattr(pydantic, "model_dump"):
# FIXME mypy error, try to fix it instead of using type: ignore
return pydantic.model_dump(model) # type: ignore
else:
return model.model_dump()

View File

@ -109,8 +109,19 @@ class OracleVector(BaseVector):
)
def _get_connection(self) -> Connection:
connection = oracledb.connect(user=self.config.user, password=self.config.password, dsn=self.config.dsn)
return connection
if self.config.is_autonomous:
connection = oracledb.connect(
user=self.config.user,
password=self.config.password,
dsn=self.config.dsn,
config_dir=self.config.config_dir,
wallet_location=self.config.wallet_location,
wallet_password=self.config.wallet_password,
)
return connection
else:
connection = oracledb.connect(user=self.config.user, password=self.config.password, dsn=self.config.dsn)
return connection
def _create_connection_pool(self, config: OracleVectorConfig):
pool_params = {

View File

@ -5,10 +5,14 @@ This package contains concrete implementations of the repository interfaces
defined in the core.workflow.repository package.
"""
from core.repositories.celery_workflow_execution_repository import CeleryWorkflowExecutionRepository
from core.repositories.celery_workflow_node_execution_repository import CeleryWorkflowNodeExecutionRepository
from core.repositories.factory import DifyCoreRepositoryFactory, RepositoryImportError
from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
__all__ = [
"CeleryWorkflowExecutionRepository",
"CeleryWorkflowNodeExecutionRepository",
"DifyCoreRepositoryFactory",
"RepositoryImportError",
"SQLAlchemyWorkflowNodeExecutionRepository",

View File

@ -0,0 +1,126 @@
"""
Celery-based implementation of the WorkflowExecutionRepository.
This implementation uses Celery tasks for asynchronous storage operations,
providing improved performance by offloading database operations to background workers.
"""
import logging
from typing import Optional, Union
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker
from core.workflow.entities.workflow_execution import WorkflowExecution
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from libs.helper import extract_tenant_id
from models import Account, CreatorUserRole, EndUser
from models.enums import WorkflowRunTriggeredFrom
from tasks.workflow_execution_tasks import (
save_workflow_execution_task,
)
logger = logging.getLogger(__name__)
class CeleryWorkflowExecutionRepository(WorkflowExecutionRepository):
"""
Celery-based implementation of the WorkflowExecutionRepository interface.
This implementation provides asynchronous storage capabilities by using Celery tasks
to handle database operations in background workers. This improves performance by
reducing the blocking time for workflow execution storage operations.
Key features:
- Asynchronous save operations using Celery tasks
- Support for multi-tenancy through tenant/app filtering
- Automatic retry and error handling through Celery
"""
_session_factory: sessionmaker
_tenant_id: str
_app_id: Optional[str]
_triggered_from: Optional[WorkflowRunTriggeredFrom]
_creator_user_id: str
_creator_user_role: CreatorUserRole
def __init__(
self,
session_factory: sessionmaker | Engine,
user: Union[Account, EndUser],
app_id: Optional[str],
triggered_from: Optional[WorkflowRunTriggeredFrom],
):
"""
Initialize the repository with Celery task configuration and context information.
Args:
session_factory: SQLAlchemy sessionmaker or engine for fallback operations
user: Account or EndUser object containing tenant_id, user ID, and role information
app_id: App ID for filtering by application (can be None)
triggered_from: Source of the execution trigger (DEBUGGING or APP_RUN)
"""
# Store session factory for fallback operations
if isinstance(session_factory, Engine):
self._session_factory = sessionmaker(bind=session_factory, expire_on_commit=False)
elif isinstance(session_factory, sessionmaker):
self._session_factory = session_factory
else:
raise ValueError(
f"Invalid session_factory type {type(session_factory).__name__}; expected sessionmaker or Engine"
)
# Extract tenant_id from user
tenant_id = extract_tenant_id(user)
if not tenant_id:
raise ValueError("User must have a tenant_id or current_tenant_id")
self._tenant_id = tenant_id # type: ignore[assignment] # We've already checked tenant_id is not None
# Store app context
self._app_id = app_id
# Extract user context
self._triggered_from = triggered_from
self._creator_user_id = user.id
# Determine user role based on user type
self._creator_user_role = CreatorUserRole.ACCOUNT if isinstance(user, Account) else CreatorUserRole.END_USER
logger.info(
"Initialized CeleryWorkflowExecutionRepository for tenant %s, app %s, triggered_from %s",
self._tenant_id,
self._app_id,
self._triggered_from,
)
def save(self, execution: WorkflowExecution) -> None:
"""
Save or update a WorkflowExecution instance asynchronously using Celery.
This method queues the save operation as a Celery task and returns immediately,
providing improved performance for high-throughput scenarios.
Args:
execution: The WorkflowExecution instance to save or update
"""
try:
# Serialize execution for Celery task
execution_data = execution.model_dump()
# Queue the save operation as a Celery task (fire and forget)
save_workflow_execution_task.delay(
execution_data=execution_data,
tenant_id=self._tenant_id,
app_id=self._app_id or "",
triggered_from=self._triggered_from.value if self._triggered_from else "",
creator_user_id=self._creator_user_id,
creator_user_role=self._creator_user_role.value,
)
logger.debug("Queued async save for workflow execution: %s", execution.id_)
except Exception as e:
logger.exception("Failed to queue save operation for execution %s", execution.id_)
# In case of Celery failure, we could implement a fallback to synchronous save
# For now, we'll re-raise the exception
raise

View File

@ -0,0 +1,190 @@
"""
Celery-based implementation of the WorkflowNodeExecutionRepository.
This implementation uses Celery tasks for asynchronous storage operations,
providing improved performance by offloading database operations to background workers.
"""
import logging
from collections.abc import Sequence
from typing import Optional, Union
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution
from core.workflow.repositories.workflow_node_execution_repository import (
OrderConfig,
WorkflowNodeExecutionRepository,
)
from libs.helper import extract_tenant_id
from models import Account, CreatorUserRole, EndUser
from models.workflow import WorkflowNodeExecutionTriggeredFrom
from tasks.workflow_node_execution_tasks import (
save_workflow_node_execution_task,
)
logger = logging.getLogger(__name__)
class CeleryWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
"""
Celery-based implementation of the WorkflowNodeExecutionRepository interface.
This implementation provides asynchronous storage capabilities by using Celery tasks
to handle database operations in background workers. This improves performance by
reducing the blocking time for workflow node execution storage operations.
Key features:
- Asynchronous save operations using Celery tasks
- In-memory cache for immediate reads
- Support for multi-tenancy through tenant/app filtering
- Automatic retry and error handling through Celery
"""
_session_factory: sessionmaker
_tenant_id: str
_app_id: Optional[str]
_triggered_from: Optional[WorkflowNodeExecutionTriggeredFrom]
_creator_user_id: str
_creator_user_role: CreatorUserRole
_execution_cache: dict[str, WorkflowNodeExecution]
_workflow_execution_mapping: dict[str, list[str]]
def __init__(
self,
session_factory: sessionmaker | Engine,
user: Union[Account, EndUser],
app_id: Optional[str],
triggered_from: Optional[WorkflowNodeExecutionTriggeredFrom],
):
"""
Initialize the repository with Celery task configuration and context information.
Args:
session_factory: SQLAlchemy sessionmaker or engine for fallback operations
user: Account or EndUser object containing tenant_id, user ID, and role information
app_id: App ID for filtering by application (can be None)
triggered_from: Source of the execution trigger (SINGLE_STEP or WORKFLOW_RUN)
"""
# Store session factory for fallback operations
if isinstance(session_factory, Engine):
self._session_factory = sessionmaker(bind=session_factory, expire_on_commit=False)
elif isinstance(session_factory, sessionmaker):
self._session_factory = session_factory
else:
raise ValueError(
f"Invalid session_factory type {type(session_factory).__name__}; expected sessionmaker or Engine"
)
# Extract tenant_id from user
tenant_id = extract_tenant_id(user)
if not tenant_id:
raise ValueError("User must have a tenant_id or current_tenant_id")
self._tenant_id = tenant_id # type: ignore[assignment] # We've already checked tenant_id is not None
# Store app context
self._app_id = app_id
# Extract user context
self._triggered_from = triggered_from
self._creator_user_id = user.id
# Determine user role based on user type
self._creator_user_role = CreatorUserRole.ACCOUNT if isinstance(user, Account) else CreatorUserRole.END_USER
# In-memory cache for workflow node executions
self._execution_cache: dict[str, WorkflowNodeExecution] = {}
# Cache for mapping workflow_execution_ids to execution IDs for efficient retrieval
self._workflow_execution_mapping: dict[str, list[str]] = {}
logger.info(
"Initialized CeleryWorkflowNodeExecutionRepository for tenant %s, app %s, triggered_from %s",
self._tenant_id,
self._app_id,
self._triggered_from,
)
def save(self, execution: WorkflowNodeExecution) -> None:
"""
Save or update a WorkflowNodeExecution instance to cache and asynchronously to database.
This method stores the execution in cache immediately for fast reads and queues
the save operation as a Celery task without tracking the task status.
Args:
execution: The WorkflowNodeExecution instance to save or update
"""
try:
# Store in cache immediately for fast reads
self._execution_cache[execution.id] = execution
# Update workflow execution mapping for efficient retrieval
if execution.workflow_execution_id:
if execution.workflow_execution_id not in self._workflow_execution_mapping:
self._workflow_execution_mapping[execution.workflow_execution_id] = []
if execution.id not in self._workflow_execution_mapping[execution.workflow_execution_id]:
self._workflow_execution_mapping[execution.workflow_execution_id].append(execution.id)
# Serialize execution for Celery task
execution_data = execution.model_dump()
# Queue the save operation as a Celery task (fire and forget)
save_workflow_node_execution_task.delay(
execution_data=execution_data,
tenant_id=self._tenant_id,
app_id=self._app_id or "",
triggered_from=self._triggered_from.value if self._triggered_from else "",
creator_user_id=self._creator_user_id,
creator_user_role=self._creator_user_role.value,
)
logger.debug("Cached and queued async save for workflow node execution: %s", execution.id)
except Exception as e:
logger.exception("Failed to cache or queue save operation for node execution %s", execution.id)
# In case of Celery failure, we could implement a fallback to synchronous save
# For now, we'll re-raise the exception
raise
def get_by_workflow_run(
self,
workflow_run_id: str,
order_config: Optional[OrderConfig] = None,
) -> Sequence[WorkflowNodeExecution]:
"""
Retrieve all WorkflowNodeExecution instances for a specific workflow run from cache.
Args:
workflow_run_id: The workflow run ID
order_config: Optional configuration for ordering results
Returns:
A sequence of WorkflowNodeExecution instances
"""
try:
# Get execution IDs for this workflow run from cache
execution_ids = self._workflow_execution_mapping.get(workflow_run_id, [])
# Retrieve executions from cache
result = []
for execution_id in execution_ids:
if execution_id in self._execution_cache:
result.append(self._execution_cache[execution_id])
# Apply ordering if specified
if order_config and result:
# Sort based on the order configuration
reverse = order_config.order_direction == "desc"
# Sort by multiple fields if specified
for field_name in reversed(order_config.order_by):
result.sort(key=lambda x: getattr(x, field_name, 0), reverse=reverse)
logger.debug("Retrieved %d workflow node executions for run %s from cache", len(result), workflow_run_id)
return result
except Exception as e:
logger.exception("Failed to get workflow node executions for run %s from cache", workflow_run_id)
return []

View File

@ -94,11 +94,9 @@ class DifyCoreRepositoryFactory:
def _validate_constructor_signature(repository_class: type, required_params: list[str]) -> None:
"""
Validate that a repository class constructor accepts required parameters.
Args:
repository_class: The class to validate
required_params: List of required parameter names
Raises:
RepositoryImportError: If the constructor doesn't accept required parameters
"""
@ -158,10 +156,8 @@ class DifyCoreRepositoryFactory:
try:
repository_class = cls._import_class(class_path)
cls._validate_repository_interface(repository_class, WorkflowExecutionRepository)
cls._validate_constructor_signature(
repository_class, ["session_factory", "user", "app_id", "triggered_from"]
)
# All repository types now use the same constructor parameters
return repository_class( # type: ignore[no-any-return]
session_factory=session_factory,
user=user,
@ -204,10 +200,8 @@ class DifyCoreRepositoryFactory:
try:
repository_class = cls._import_class(class_path)
cls._validate_repository_interface(repository_class, WorkflowNodeExecutionRepository)
cls._validate_constructor_signature(
repository_class, ["session_factory", "user", "app_id", "triggered_from"]
)
# All repository types now use the same constructor parameters
return repository_class( # type: ignore[no-any-return]
session_factory=session_factory,
user=user,

View File

@ -12,8 +12,6 @@ from core.tools.errors import ToolProviderCredentialValidationError
class ToolProviderController(ABC):
entity: ToolProviderEntity
def __init__(self, entity: ToolProviderEntity) -> None:
self.entity = entity

View File

@ -1,5 +1,5 @@
import json
from typing import Any
from typing import Any, Optional
from core.mcp.types import Tool as RemoteMCPTool
from core.tools.__base.tool_provider import ToolProviderController
@ -19,15 +19,24 @@ from services.tools.tools_transform_service import ToolTransformService
class MCPToolProviderController(ToolProviderController):
provider_id: str
entity: ToolProviderEntityWithPlugin
def __init__(self, entity: ToolProviderEntityWithPlugin, provider_id: str, tenant_id: str, server_url: str) -> None:
def __init__(
self,
entity: ToolProviderEntityWithPlugin,
provider_id: str,
tenant_id: str,
server_url: str,
headers: Optional[dict[str, str]] = None,
timeout: Optional[float] = None,
sse_read_timeout: Optional[float] = None,
) -> None:
super().__init__(entity)
self.entity = entity
self.entity: ToolProviderEntityWithPlugin = entity
self.tenant_id = tenant_id
self.provider_id = provider_id
self.server_url = server_url
self.headers = headers or {}
self.timeout = timeout
self.sse_read_timeout = sse_read_timeout
@property
def provider_type(self) -> ToolProviderType:
@ -85,6 +94,9 @@ class MCPToolProviderController(ToolProviderController):
provider_id=db_provider.server_identifier or "",
tenant_id=db_provider.tenant_id or "",
server_url=db_provider.decrypted_server_url,
headers={}, # TODO: get headers from db provider
timeout=db_provider.timeout,
sse_read_timeout=db_provider.sse_read_timeout,
)
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
@ -111,6 +123,9 @@ class MCPToolProviderController(ToolProviderController):
icon=self.entity.identity.icon,
server_url=self.server_url,
provider_id=self.provider_id,
headers=self.headers,
timeout=self.timeout,
sse_read_timeout=self.sse_read_timeout,
)
def get_tools(self) -> list[MCPTool]: # type: ignore
@ -125,6 +140,9 @@ class MCPToolProviderController(ToolProviderController):
icon=self.entity.identity.icon,
server_url=self.server_url,
provider_id=self.provider_id,
headers=self.headers,
timeout=self.timeout,
sse_read_timeout=self.sse_read_timeout,
)
for tool_entity in self.entity.tools
]

View File

@ -13,13 +13,25 @@ from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, Too
class MCPTool(Tool):
def __init__(
self, entity: ToolEntity, runtime: ToolRuntime, tenant_id: str, icon: str, server_url: str, provider_id: str
self,
entity: ToolEntity,
runtime: ToolRuntime,
tenant_id: str,
icon: str,
server_url: str,
provider_id: str,
headers: Optional[dict[str, str]] = None,
timeout: Optional[float] = None,
sse_read_timeout: Optional[float] = None,
) -> None:
super().__init__(entity, runtime)
self.tenant_id = tenant_id
self.icon = icon
self.server_url = server_url
self.provider_id = provider_id
self.headers = headers or {}
self.timeout = timeout
self.sse_read_timeout = sse_read_timeout
def tool_provider_type(self) -> ToolProviderType:
return ToolProviderType.MCP
@ -35,7 +47,15 @@ class MCPTool(Tool):
from core.tools.errors import ToolInvokeError
try:
with MCPClient(self.server_url, self.provider_id, self.tenant_id, authed=True) as mcp_client:
with MCPClient(
self.server_url,
self.provider_id,
self.tenant_id,
authed=True,
headers=self.headers,
timeout=self.timeout,
sse_read_timeout=self.sse_read_timeout,
) as mcp_client:
tool_parameters = self._handle_none_parameter(tool_parameters)
result = mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters)
except MCPAuthError as e:
@ -72,6 +92,9 @@ class MCPTool(Tool):
icon=self.icon,
server_url=self.server_url,
provider_id=self.provider_id,
headers=self.headers,
timeout=self.timeout,
sse_read_timeout=self.sse_read_timeout,
)
def _handle_none_parameter(self, parameter: dict[str, Any]) -> dict[str, Any]:

View File

@ -789,9 +789,6 @@ class ToolManager:
"""
get api provider
"""
"""
get tool provider
"""
provider_name = provider
provider_obj: ApiToolProvider | None = (
db.session.query(ApiToolProvider)

View File

@ -4,4 +4,4 @@
#
# If the selector length is more than 2, the remaining parts are the keys / indexes paths used
# to extract part of the variable value.
MIN_SELECTORS_LENGTH = 2
SELECTORS_LENGTH = 2

View File

@ -7,8 +7,8 @@ from pydantic import BaseModel, Field
from core.file import File, FileAttribute, file_manager
from core.variables import Segment, SegmentGroup, Variable
from core.variables.consts import MIN_SELECTORS_LENGTH
from core.variables.segments import FileSegment, NoneSegment
from core.variables.consts import SELECTORS_LENGTH
from core.variables.segments import FileSegment, ObjectSegment
from core.variables.variables import VariableUnion
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from core.workflow.system_variable import SystemVariable
@ -24,7 +24,7 @@ class VariablePool(BaseModel):
# The first element of the selector is the node id, it's the first-level key in the dictionary.
# Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the
# elements of the selector except the first one.
variable_dictionary: defaultdict[str, Annotated[dict[int, VariableUnion], Field(default_factory=dict)]] = Field(
variable_dictionary: defaultdict[str, Annotated[dict[str, VariableUnion], Field(default_factory=dict)]] = Field(
description="Variables mapping",
default=defaultdict(dict),
)
@ -36,6 +36,7 @@ class VariablePool(BaseModel):
)
system_variables: SystemVariable = Field(
description="System variables",
default_factory=SystemVariable.empty,
)
environment_variables: Sequence[VariableUnion] = Field(
description="Environment variables.",
@ -58,23 +59,29 @@ class VariablePool(BaseModel):
def add(self, selector: Sequence[str], value: Any, /) -> None:
"""
Adds a variable to the variable pool.
Add a variable to the variable pool.
NOTE: You should not add a non-Segment value to the variable pool
even if it is allowed now.
This method accepts a selector path and a value, converting the value
to a Variable object if necessary before storing it in the pool.
Args:
selector (Sequence[str]): The selector for the variable.
value (VariableValue): The value of the variable.
selector: A two-element sequence containing [node_id, variable_name].
The selector must have exactly 2 elements to be valid.
value: The value to store. Can be a Variable, Segment, or any value
that can be converted to a Segment (str, int, float, dict, list, File).
Raises:
ValueError: If the selector is invalid.
ValueError: If selector length is not exactly 2 elements.
Returns:
None
Note:
While non-Segment values are currently accepted and automatically
converted, it's recommended to pass Segment or Variable objects directly.
"""
if len(selector) < MIN_SELECTORS_LENGTH:
raise ValueError("Invalid selector")
if len(selector) != SELECTORS_LENGTH:
raise ValueError(
f"Invalid selector: expected {SELECTORS_LENGTH} elements (node_id, variable_name), "
f"got {len(selector)} elements"
)
if isinstance(value, Variable):
variable = value
@ -84,57 +91,85 @@ class VariablePool(BaseModel):
segment = variable_factory.build_segment(value)
variable = variable_factory.segment_to_variable(segment=segment, selector=selector)
key, hash_key = self._selector_to_keys(selector)
node_id, name = self._selector_to_keys(selector)
# Based on the definition of `VariableUnion`,
# `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
self.variable_dictionary[key][hash_key] = cast(VariableUnion, variable)
self.variable_dictionary[node_id][name] = cast(VariableUnion, variable)
@classmethod
def _selector_to_keys(cls, selector: Sequence[str]) -> tuple[str, int]:
return selector[0], hash(tuple(selector[1:]))
def _selector_to_keys(cls, selector: Sequence[str]) -> tuple[str, str]:
return selector[0], selector[1]
def _has(self, selector: Sequence[str]) -> bool:
key, hash_key = self._selector_to_keys(selector)
if key not in self.variable_dictionary:
node_id, name = self._selector_to_keys(selector)
if node_id not in self.variable_dictionary:
return False
if hash_key not in self.variable_dictionary[key]:
if name not in self.variable_dictionary[node_id]:
return False
return True
def get(self, selector: Sequence[str], /) -> Segment | None:
"""
Retrieves the value from the variable pool based on the given selector.
Retrieve a variable's value from the pool as a Segment.
This method supports both simple selectors [node_id, variable_name] and
extended selectors that include attribute access for FileSegment and
ObjectSegment types.
Args:
selector (Sequence[str]): The selector used to identify the variable.
selector: A sequence with at least 2 elements:
- [node_id, variable_name]: Returns the full segment
- [node_id, variable_name, attr, ...]: Returns a nested value
from FileSegment (e.g., 'url', 'name') or ObjectSegment
Returns:
Any: The value associated with the given selector.
The Segment associated with the selector, or None if not found.
Returns None if selector has fewer than 2 elements.
Raises:
ValueError: If the selector is invalid.
ValueError: If attempting to access an invalid FileAttribute.
"""
if len(selector) < MIN_SELECTORS_LENGTH:
if len(selector) < SELECTORS_LENGTH:
return None
key, hash_key = self._selector_to_keys(selector)
value: Segment | None = self.variable_dictionary[key].get(hash_key)
node_id, name = self._selector_to_keys(selector)
segment: Segment | None = self.variable_dictionary[node_id].get(name)
if value is None:
selector, attr = selector[:-1], selector[-1]
if segment is None:
return None
if len(selector) == 2:
return segment
if isinstance(segment, FileSegment):
attr = selector[2]
# Python support `attr in FileAttribute` after 3.12
if attr not in {item.value for item in FileAttribute}:
return None
value = self.get(selector)
if not isinstance(value, FileSegment | NoneSegment):
return None
if isinstance(value, FileSegment):
attr = FileAttribute(attr)
attr_value = file_manager.get_attr(file=value.value, attr=attr)
return variable_factory.build_segment(attr_value)
return value
attr = FileAttribute(attr)
attr_value = file_manager.get_attr(file=segment.value, attr=attr)
return variable_factory.build_segment(attr_value)
return value
# Navigate through nested attributes
result: Any = segment
for attr in selector[2:]:
result = self._extract_value(result)
result = self._get_nested_attribute(result, attr)
if result is None:
return None
# Return result as Segment
return result if isinstance(result, Segment) else variable_factory.build_segment(result)
def _extract_value(self, obj: Any) -> Any:
"""Extract the actual value from an ObjectSegment."""
return obj.value if isinstance(obj, ObjectSegment) else obj
def _get_nested_attribute(self, obj: Mapping[str, Any], attr: str) -> Any:
"""Get a nested attribute from a dictionary-like object."""
if not isinstance(obj, dict):
return None
return obj.get(attr)
def remove(self, selector: Sequence[str], /):
"""

View File

@ -15,7 +15,7 @@ from configs import dify_config
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import AgentNodeStrategyInit, NodeRunResult
from core.workflow.entities.variable_pool import VariablePool, VariableValue
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager
from core.workflow.graph_engine.entities.event import (
@ -51,7 +51,6 @@ from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor
from core.workflow.nodes.enums import ErrorStrategy, FailBranchSourceHandle
from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
from core.workflow.utils import variable_utils
from libs.flask_utils import preserve_flask_contexts
from models.enums import UserFrom
from models.workflow import WorkflowType
@ -701,11 +700,9 @@ class GraphEngine:
route_node_state.status = RouteNodeState.Status.EXCEPTION
if run_result.outputs:
for variable_key, variable_value in run_result.outputs.items():
# append variables to variable pool recursively
self._append_variables_recursively(
node_id=node.node_id,
variable_key_list=[variable_key],
variable_value=variable_value,
# Add variables to variable pool
self.graph_runtime_state.variable_pool.add(
[node.node_id, variable_key], variable_value
)
yield NodeRunExceptionEvent(
error=run_result.error or "System Error",
@ -758,11 +755,9 @@ class GraphEngine:
# append node output variables to variable pool
if run_result.outputs:
for variable_key, variable_value in run_result.outputs.items():
# append variables to variable pool recursively
self._append_variables_recursively(
node_id=node.node_id,
variable_key_list=[variable_key],
variable_value=variable_value,
# Add variables to variable pool
self.graph_runtime_state.variable_pool.add(
[node.node_id, variable_key], variable_value
)
# When setting metadata, convert to dict first
@ -851,21 +846,6 @@ class GraphEngine:
logger.exception("Node %s run failed", node.title)
raise e
def _append_variables_recursively(self, node_id: str, variable_key_list: list[str], variable_value: VariableValue):
"""
Append variables recursively
:param node_id: node id
:param variable_key_list: variable key list
:param variable_value: variable value
:return:
"""
variable_utils.append_variables_recursively(
self.graph_runtime_state.variable_pool,
node_id,
variable_key_list,
variable_value,
)
def _is_timed_out(self, start_at: float, max_execution_time: int) -> bool:
"""
Check timeout

View File

@ -4,7 +4,7 @@ from typing import Any, TypeVar
from pydantic import BaseModel
from core.variables import Segment
from core.variables.consts import MIN_SELECTORS_LENGTH
from core.variables.consts import SELECTORS_LENGTH
from core.variables.types import SegmentType
# Use double underscore (`__`) prefix for internal variables
@ -23,7 +23,7 @@ _T = TypeVar("_T", bound=MutableMapping[str, Any])
def variable_to_processed_data(selector: Sequence[str], seg: Segment) -> UpdatedVariable:
if len(selector) < MIN_SELECTORS_LENGTH:
if len(selector) < SELECTORS_LENGTH:
raise Exception("selector too short")
node_id, var_name = selector[:2]
return UpdatedVariable(

View File

@ -4,7 +4,7 @@ from typing import Any, Optional, cast
from core.app.entities.app_invoke_entities import InvokeFrom
from core.variables import SegmentType, Variable
from core.variables.consts import MIN_SELECTORS_LENGTH
from core.variables.consts import SELECTORS_LENGTH
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
from core.workflow.entities.node_entities import NodeRunResult
@ -46,7 +46,7 @@ def _source_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_
selector = item.value
if not isinstance(selector, list):
raise InvalidDataError(f"selector is not a list, {node_id=}, {item=}")
if len(selector) < MIN_SELECTORS_LENGTH:
if len(selector) < SELECTORS_LENGTH:
raise InvalidDataError(f"selector too short, {node_id=}, {item=}")
selector_str = ".".join(selector)
key = f"{node_id}.#{selector_str}#"

View File

@ -1,29 +0,0 @@
from core.variables.segments import ObjectSegment, Segment
from core.workflow.entities.variable_pool import VariablePool, VariableValue
def append_variables_recursively(
pool: VariablePool, node_id: str, variable_key_list: list[str], variable_value: VariableValue | Segment
):
"""
Append variables recursively
:param pool: variable pool to append variables to
:param node_id: node id
:param variable_key_list: variable key list
:param variable_value: variable value
:return:
"""
pool.add([node_id] + variable_key_list, variable_value)
# if variable_value is a dict, then recursively append variables
if isinstance(variable_value, ObjectSegment):
variable_dict = variable_value.value
elif isinstance(variable_value, dict):
variable_dict = variable_value
else:
return
for key, value in variable_dict.items():
# construct new key list
new_key_list = variable_key_list + [key]
append_variables_recursively(pool, node_id=node_id, variable_key_list=new_key_list, variable_value=value)

View File

@ -3,9 +3,8 @@ from collections.abc import Mapping, Sequence
from typing import Any, Protocol
from core.variables import Variable
from core.variables.consts import MIN_SELECTORS_LENGTH
from core.variables.consts import SELECTORS_LENGTH
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.utils import variable_utils
class VariableLoader(Protocol):
@ -78,7 +77,7 @@ def load_into_variable_pool(
variables_to_load.append(list(selector))
loaded = variable_loader.load_variables(variables_to_load)
for var in loaded:
assert len(var.selector) >= MIN_SELECTORS_LENGTH, f"Invalid variable {var}"
variable_utils.append_variables_recursively(
variable_pool, node_id=var.selector[0], variable_key_list=list(var.selector[1:]), variable_value=var
)
assert len(var.selector) >= SELECTORS_LENGTH, f"Invalid variable {var}"
# Add variable directly to the pool
# The variable pool expects 2-element selectors [node_id, variable_name]
variable_pool.add([var.selector[0], var.selector[1]], var)

View File

@ -1,4 +1,5 @@
from collections.abc import Mapping
from decimal import Decimal
from typing import Any
from pydantic import BaseModel
@ -17,6 +18,9 @@ class WorkflowRuntimeTypeConverter:
return value
if isinstance(value, (bool, int, str, float)):
return value
if isinstance(value, Decimal):
# Convert Decimal to float for JSON serialization
return float(value)
if isinstance(value, Segment):
return self._to_json_encodable_recursive(value.value)
if isinstance(value, File):

View File

@ -32,7 +32,7 @@ if [[ "${MODE}" == "worker" ]]; then
exec celery -A app.celery worker -P ${CELERY_WORKER_CLASS:-gevent} $CONCURRENCY_OPTION \
--max-tasks-per-child ${MAX_TASK_PRE_CHILD:-50} --loglevel ${LOG_LEVEL:-INFO} \
-Q ${CELERY_QUEUES:-dataset,mail,ops_trace,app_deletion,plugin}
-Q ${CELERY_QUEUES:-dataset,mail,ops_trace,app_deletion,plugin,workflow_storage}
elif [[ "${MODE}" == "beat" ]]; then
exec celery -A app.celery beat --loglevel ${LOG_LEVEL:-INFO}

View File

@ -1,18 +1,23 @@
import functools
import logging
from collections.abc import Callable
from typing import Any, Union
from datetime import timedelta
from typing import TYPE_CHECKING, Any, Union
import redis
from redis import RedisError
from redis.cache import CacheConfig
from redis.cluster import ClusterNode, RedisCluster
from redis.connection import Connection, SSLConnection
from redis.lock import Lock
from redis.sentinel import Sentinel
from configs import dify_config
from dify_app import DifyApp
if TYPE_CHECKING:
from redis.lock import Lock
logger = logging.getLogger(__name__)
@ -28,8 +33,8 @@ class RedisClientWrapper:
a failover in a Sentinel-managed Redis setup.
Attributes:
_client (redis.Redis): The actual Redis client instance. It remains None until
initialized with the `initialize` method.
_client: The actual Redis client instance. It remains None until
initialized with the `initialize` method.
Methods:
initialize(client): Initializes the Redis client if it hasn't been initialized already.
@ -37,20 +42,78 @@ class RedisClientWrapper:
if the client is not initialized.
"""
def __init__(self):
_client: Union[redis.Redis, RedisCluster, None]
def __init__(self) -> None:
self._client = None
def initialize(self, client):
def initialize(self, client: Union[redis.Redis, RedisCluster]) -> None:
if self._client is None:
self._client = client
def __getattr__(self, item):
if TYPE_CHECKING:
# Type hints for IDE support and static analysis
# These are not executed at runtime but provide type information
def get(self, name: str | bytes) -> Any: ...
def set(
self,
name: str | bytes,
value: Any,
ex: int | None = None,
px: int | None = None,
nx: bool = False,
xx: bool = False,
keepttl: bool = False,
get: bool = False,
exat: int | None = None,
pxat: int | None = None,
) -> Any: ...
def setex(self, name: str | bytes, time: int | timedelta, value: Any) -> Any: ...
def setnx(self, name: str | bytes, value: Any) -> Any: ...
def delete(self, *names: str | bytes) -> Any: ...
def incr(self, name: str | bytes, amount: int = 1) -> Any: ...
def expire(
self,
name: str | bytes,
time: int | timedelta,
nx: bool = False,
xx: bool = False,
gt: bool = False,
lt: bool = False,
) -> Any: ...
def lock(
self,
name: str,
timeout: float | None = None,
sleep: float = 0.1,
blocking: bool = True,
blocking_timeout: float | None = None,
thread_local: bool = True,
) -> Lock: ...
def zadd(
self,
name: str | bytes,
mapping: dict[str | bytes | int | float, float | int | str | bytes],
nx: bool = False,
xx: bool = False,
ch: bool = False,
incr: bool = False,
gt: bool = False,
lt: bool = False,
) -> Any: ...
def zremrangebyscore(self, name: str | bytes, min: float | str, max: float | str) -> Any: ...
def zcard(self, name: str | bytes) -> Any: ...
def getdel(self, name: str | bytes) -> Any: ...
def __getattr__(self, item: str) -> Any:
if self._client is None:
raise RuntimeError("Redis client is not initialized. Call init_app first.")
return getattr(self._client, item)
redis_client = RedisClientWrapper()
redis_client: RedisClientWrapper = RedisClientWrapper()
def init_app(app: DifyApp):
@ -80,6 +143,9 @@ def init_app(app: DifyApp):
if dify_config.REDIS_USE_SENTINEL:
assert dify_config.REDIS_SENTINELS is not None, "REDIS_SENTINELS must be set when REDIS_USE_SENTINEL is True"
assert dify_config.REDIS_SENTINEL_SERVICE_NAME is not None, (
"REDIS_SENTINEL_SERVICE_NAME must be set when REDIS_USE_SENTINEL is True"
)
sentinel_hosts = [
(node.split(":")[0], int(node.split(":")[1])) for node in dify_config.REDIS_SENTINELS.split(",")
]

View File

@ -0,0 +1,33 @@
"""add timeout for tool_mcp_providers
Revision ID: fa8b0fa6f407
Revises: 532b3f888abf
Create Date: 2025-08-07 11:15:31.517985
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = 'fa8b0fa6f407'
down_revision = '532b3f888abf'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('tool_mcp_providers', schema=None) as batch_op:
batch_op.add_column(sa.Column('timeout', sa.Float(), server_default=sa.text('30'), nullable=False))
batch_op.add_column(sa.Column('sse_read_timeout', sa.Float(), server_default=sa.text('300'), nullable=False))
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('tool_mcp_providers', schema=None) as batch_op:
batch_op.drop_column('sse_read_timeout')
batch_op.drop_column('timeout')
# ### end Alembic commands ###

View File

@ -278,6 +278,8 @@ class MCPToolProvider(Base):
updated_at: Mapped[datetime] = mapped_column(
sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")
)
timeout: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("30"))
sse_read_timeout: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("300"))
def load_user(self) -> Account | None:
return db.session.query(Account).where(Account.id == self.user_id).first()

View File

@ -1,6 +1,6 @@
[project]
name = "dify-api"
version = "1.7.1"
version = "1.7.2"
requires-python = ">=3.11,<3.13"
dependencies = [
@ -162,6 +162,7 @@ dev = [
"pandas-stubs~=2.2.3",
"scipy-stubs>=1.15.3.0",
"types-python-http-client>=3.3.7.20240910",
"types-redis>=4.6.0.20241004",
]
############################################################

View File

@ -48,7 +48,6 @@ class DifyAPIRepositoryFactory(DifyCoreRepositoryFactory):
RepositoryImportError: If the configured repository cannot be imported or instantiated
"""
class_path = dify_config.API_WORKFLOW_NODE_EXECUTION_REPOSITORY
logger.debug("Creating DifyAPIWorkflowNodeExecutionRepository from: %s", class_path)
try:
repository_class = cls._import_class(class_path)
@ -86,7 +85,6 @@ class DifyAPIRepositoryFactory(DifyCoreRepositoryFactory):
RepositoryImportError: If the configured repository cannot be imported or instantiated
"""
class_path = dify_config.API_WORKFLOW_RUN_REPOSITORY
logger.debug("Creating APIWorkflowRunRepository from: %s", class_path)
try:
repository_class = cls._import_class(class_path)

View File

@ -24,9 +24,20 @@ def queue_monitor_task():
queue_name = "dataset"
threshold = dify_config.QUEUE_MONITOR_THRESHOLD
if threshold is None:
logging.warning(click.style("QUEUE_MONITOR_THRESHOLD is not configured, skipping monitoring", fg="yellow"))
return
try:
queue_length = celery_redis.llen(f"{queue_name}")
logging.info(click.style(f"Start monitor {queue_name}", fg="green"))
if queue_length is None:
logging.error(
click.style(f"Failed to get queue length for {queue_name} - Redis may be unavailable", fg="red")
)
return
logging.info(click.style(f"Queue length: {queue_length}", fg="green"))
if queue_length >= threshold:

View File

@ -59,6 +59,8 @@ class MCPToolManageService:
icon_type: str,
icon_background: str,
server_identifier: str,
timeout: float,
sse_read_timeout: float,
) -> ToolProviderApiEntity:
server_url_hash = hashlib.sha256(server_url.encode()).hexdigest()
existing_provider = (
@ -91,6 +93,8 @@ class MCPToolManageService:
tools="[]",
icon=json.dumps({"content": icon, "background": icon_background}) if icon_type == "emoji" else icon,
server_identifier=server_identifier,
timeout=timeout,
sse_read_timeout=sse_read_timeout,
)
db.session.add(mcp_tool)
db.session.commit()
@ -166,6 +170,8 @@ class MCPToolManageService:
icon_type: str,
icon_background: str,
server_identifier: str,
timeout: float | None = None,
sse_read_timeout: float | None = None,
):
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
@ -197,6 +203,10 @@ class MCPToolManageService:
mcp_provider.tools = reconnect_result["tools"]
mcp_provider.encrypted_credentials = reconnect_result["encrypted_credentials"]
if timeout is not None:
mcp_provider.timeout = timeout
if sse_read_timeout is not None:
mcp_provider.sse_read_timeout = sse_read_timeout
db.session.commit()
except IntegrityError as e:
db.session.rollback()

View File

@ -13,7 +13,7 @@ from sqlalchemy.sql.expression import and_, or_
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file.models import File
from core.variables import Segment, StringSegment, Variable
from core.variables.consts import MIN_SELECTORS_LENGTH
from core.variables.consts import SELECTORS_LENGTH
from core.variables.segments import ArrayFileSegment, FileSegment
from core.variables.types import SegmentType
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
@ -147,7 +147,7 @@ class WorkflowDraftVariableService:
) -> list[WorkflowDraftVariable]:
ors = []
for selector in selectors:
assert len(selector) >= MIN_SELECTORS_LENGTH, f"Invalid selector to get: {selector}"
assert len(selector) >= SELECTORS_LENGTH, f"Invalid selector to get: {selector}"
node_id, name = selector[:2]
ors.append(and_(WorkflowDraftVariable.node_id == node_id, WorkflowDraftVariable.name == name))
@ -608,7 +608,7 @@ class DraftVariableSaver:
for item in updated_variables:
selector = item.selector
if len(selector) < MIN_SELECTORS_LENGTH:
if len(selector) < SELECTORS_LENGTH:
raise Exception("selector too short")
# NOTE(QuantumGhost): only the following two kinds of variable could be updated by
# VariableAssigner: ConversationVariable and iteration variable.

View File

@ -56,19 +56,29 @@ def clean_dataset_task(
documents = db.session.query(Document).where(Document.dataset_id == dataset_id).all()
segments = db.session.query(DocumentSegment).where(DocumentSegment.dataset_id == dataset_id).all()
# Fix: Always clean vector database resources regardless of document existence
# This ensures all 33 vector databases properly drop tables/collections/indices
if doc_form is None:
# Use default paragraph index type for empty datasets to enable vector database cleanup
# Enhanced validation: Check if doc_form is None, empty string, or contains only whitespace
# This ensures all invalid doc_form values are properly handled
if doc_form is None or (isinstance(doc_form, str) and not doc_form.strip()):
# Use default paragraph index type for empty/invalid datasets to enable vector database cleanup
from core.rag.index_processor.constant.index_type import IndexType
doc_form = IndexType.PARAGRAPH_INDEX
logging.info(
click.style(f"No documents found, using default index type for cleanup: {doc_form}", fg="yellow")
click.style(f"Invalid doc_form detected, using default index type for cleanup: {doc_form}", fg="yellow")
)
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=True)
# Add exception handling around IndexProcessorFactory.clean() to prevent single point of failure
# This ensures Document/Segment deletion can continue even if vector database cleanup fails
try:
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=True)
logging.info(click.style(f"Successfully cleaned vector database for dataset: {dataset_id}", fg="green"))
except Exception as index_cleanup_error:
logging.exception(click.style(f"Failed to clean vector database for dataset {dataset_id}", fg="red"))
# Continue with document and segment deletion even if vector cleanup fails
logging.info(
click.style(f"Continuing with document and segment deletion for dataset: {dataset_id}", fg="yellow")
)
if documents is None or len(documents) == 0:
logging.info(click.style(f"No documents found for dataset: {dataset_id}", fg="green"))
@ -128,6 +138,14 @@ def clean_dataset_task(
click.style(f"Cleaned dataset when dataset deleted: {dataset_id} latency: {end_at - start_at}", fg="green")
)
except Exception:
# Add rollback to prevent dirty session state in case of exceptions
# This ensures the database session is properly cleaned up
try:
db.session.rollback()
logging.info(click.style(f"Rolled back database session for dataset: {dataset_id}", fg="yellow"))
except Exception as rollback_error:
logging.exception("Failed to rollback database session")
logging.exception("Cleaned dataset when dataset deleted failed")
finally:
db.session.close()

View File

@ -0,0 +1,136 @@
"""
Celery tasks for asynchronous workflow execution storage operations.
These tasks provide asynchronous storage capabilities for workflow execution data,
improving performance by offloading storage operations to background workers.
"""
import json
import logging
from celery import shared_task # type: ignore[import-untyped]
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from core.workflow.entities.workflow_execution import WorkflowExecution
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from extensions.ext_database import db
from models import CreatorUserRole, WorkflowRun
from models.enums import WorkflowRunTriggeredFrom
logger = logging.getLogger(__name__)
@shared_task(queue="workflow_storage", bind=True, max_retries=3, default_retry_delay=60)
def save_workflow_execution_task(
self,
execution_data: dict,
tenant_id: str,
app_id: str,
triggered_from: str,
creator_user_id: str,
creator_user_role: str,
) -> bool:
"""
Asynchronously save or update a workflow execution to the database.
Args:
execution_data: Serialized WorkflowExecution data
tenant_id: Tenant ID for multi-tenancy
app_id: Application ID
triggered_from: Source of the execution trigger
creator_user_id: ID of the user who created the execution
creator_user_role: Role of the user who created the execution
Returns:
True if successful, False otherwise
"""
try:
# Create a new session for this task
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
with session_factory() as session:
# Deserialize execution data
execution = WorkflowExecution.model_validate(execution_data)
# Check if workflow run already exists
existing_run = session.scalar(select(WorkflowRun).where(WorkflowRun.id == execution.id_))
if existing_run:
# Update existing workflow run
_update_workflow_run_from_execution(existing_run, execution)
logger.debug("Updated existing workflow run: %s", execution.id_)
else:
# Create new workflow run
workflow_run = _create_workflow_run_from_execution(
execution=execution,
tenant_id=tenant_id,
app_id=app_id,
triggered_from=WorkflowRunTriggeredFrom(triggered_from),
creator_user_id=creator_user_id,
creator_user_role=CreatorUserRole(creator_user_role),
)
session.add(workflow_run)
logger.debug("Created new workflow run: %s", execution.id_)
session.commit()
return True
except Exception as e:
logger.exception("Failed to save workflow execution %s", execution_data.get("id_", "unknown"))
# Retry the task with exponential backoff
raise self.retry(exc=e, countdown=60 * (2**self.request.retries))
def _create_workflow_run_from_execution(
execution: WorkflowExecution,
tenant_id: str,
app_id: str,
triggered_from: WorkflowRunTriggeredFrom,
creator_user_id: str,
creator_user_role: CreatorUserRole,
) -> WorkflowRun:
"""
Create a WorkflowRun database model from a WorkflowExecution domain entity.
"""
workflow_run = WorkflowRun()
workflow_run.id = execution.id_
workflow_run.tenant_id = tenant_id
workflow_run.app_id = app_id
workflow_run.workflow_id = execution.workflow_id
workflow_run.type = execution.workflow_type.value
workflow_run.triggered_from = triggered_from.value
workflow_run.version = execution.workflow_version
json_converter = WorkflowRuntimeTypeConverter()
workflow_run.graph = json.dumps(json_converter.to_json_encodable(execution.graph))
workflow_run.inputs = json.dumps(json_converter.to_json_encodable(execution.inputs))
workflow_run.status = execution.status.value
workflow_run.outputs = (
json.dumps(json_converter.to_json_encodable(execution.outputs)) if execution.outputs else "{}"
)
workflow_run.error = execution.error_message
workflow_run.elapsed_time = execution.elapsed_time
workflow_run.total_tokens = execution.total_tokens
workflow_run.total_steps = execution.total_steps
workflow_run.created_by_role = creator_user_role.value
workflow_run.created_by = creator_user_id
workflow_run.created_at = execution.started_at
workflow_run.finished_at = execution.finished_at
return workflow_run
def _update_workflow_run_from_execution(workflow_run: WorkflowRun, execution: WorkflowExecution) -> None:
"""
Update a WorkflowRun database model from a WorkflowExecution domain entity.
"""
json_converter = WorkflowRuntimeTypeConverter()
workflow_run.status = execution.status.value
workflow_run.outputs = (
json.dumps(json_converter.to_json_encodable(execution.outputs)) if execution.outputs else "{}"
)
workflow_run.error = execution.error_message
workflow_run.elapsed_time = execution.elapsed_time
workflow_run.total_tokens = execution.total_tokens
workflow_run.total_steps = execution.total_steps
workflow_run.finished_at = execution.finished_at

View File

@ -0,0 +1,171 @@
"""
Celery tasks for asynchronous workflow node execution storage operations.
These tasks provide asynchronous storage capabilities for workflow node execution data,
improving performance by offloading storage operations to background workers.
"""
import json
import logging
from celery import shared_task # type: ignore[import-untyped]
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from core.workflow.entities.workflow_node_execution import (
WorkflowNodeExecution,
)
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from extensions.ext_database import db
from models import CreatorUserRole, WorkflowNodeExecutionModel
from models.workflow import WorkflowNodeExecutionTriggeredFrom
logger = logging.getLogger(__name__)
@shared_task(queue="workflow_storage", bind=True, max_retries=3, default_retry_delay=60)
def save_workflow_node_execution_task(
self,
execution_data: dict,
tenant_id: str,
app_id: str,
triggered_from: str,
creator_user_id: str,
creator_user_role: str,
) -> bool:
"""
Asynchronously save or update a workflow node execution to the database.
Args:
execution_data: Serialized WorkflowNodeExecution data
tenant_id: Tenant ID for multi-tenancy
app_id: Application ID
triggered_from: Source of the execution trigger
creator_user_id: ID of the user who created the execution
creator_user_role: Role of the user who created the execution
Returns:
True if successful, False otherwise
"""
try:
# Create a new session for this task
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
with session_factory() as session:
# Deserialize execution data
execution = WorkflowNodeExecution.model_validate(execution_data)
# Check if node execution already exists
existing_execution = session.scalar(
select(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id == execution.id)
)
if existing_execution:
# Update existing node execution
_update_node_execution_from_domain(existing_execution, execution)
logger.debug("Updated existing workflow node execution: %s", execution.id)
else:
# Create new node execution
node_execution = _create_node_execution_from_domain(
execution=execution,
tenant_id=tenant_id,
app_id=app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom(triggered_from),
creator_user_id=creator_user_id,
creator_user_role=CreatorUserRole(creator_user_role),
)
session.add(node_execution)
logger.debug("Created new workflow node execution: %s", execution.id)
session.commit()
return True
except Exception as e:
logger.exception("Failed to save workflow node execution %s", execution_data.get("id", "unknown"))
# Retry the task with exponential backoff
raise self.retry(exc=e, countdown=60 * (2**self.request.retries))
def _create_node_execution_from_domain(
execution: WorkflowNodeExecution,
tenant_id: str,
app_id: str,
triggered_from: WorkflowNodeExecutionTriggeredFrom,
creator_user_id: str,
creator_user_role: CreatorUserRole,
) -> WorkflowNodeExecutionModel:
"""
Create a WorkflowNodeExecutionModel database model from a WorkflowNodeExecution domain entity.
"""
node_execution = WorkflowNodeExecutionModel()
node_execution.id = execution.id
node_execution.tenant_id = tenant_id
node_execution.app_id = app_id
node_execution.workflow_id = execution.workflow_id
node_execution.triggered_from = triggered_from.value
node_execution.workflow_run_id = execution.workflow_execution_id
node_execution.index = execution.index
node_execution.predecessor_node_id = execution.predecessor_node_id
node_execution.node_id = execution.node_id
node_execution.node_type = execution.node_type.value
node_execution.title = execution.title
node_execution.node_execution_id = execution.node_execution_id
# Serialize complex data as JSON
json_converter = WorkflowRuntimeTypeConverter()
node_execution.inputs = json.dumps(json_converter.to_json_encodable(execution.inputs)) if execution.inputs else "{}"
node_execution.process_data = (
json.dumps(json_converter.to_json_encodable(execution.process_data)) if execution.process_data else "{}"
)
node_execution.outputs = (
json.dumps(json_converter.to_json_encodable(execution.outputs)) if execution.outputs else "{}"
)
# Convert metadata enum keys to strings for JSON serialization
if execution.metadata:
metadata_for_json = {
key.value if hasattr(key, "value") else str(key): value for key, value in execution.metadata.items()
}
node_execution.execution_metadata = json.dumps(json_converter.to_json_encodable(metadata_for_json))
else:
node_execution.execution_metadata = "{}"
node_execution.status = execution.status.value
node_execution.error = execution.error
node_execution.elapsed_time = execution.elapsed_time
node_execution.created_by_role = creator_user_role.value
node_execution.created_by = creator_user_id
node_execution.created_at = execution.created_at
node_execution.finished_at = execution.finished_at
return node_execution
def _update_node_execution_from_domain(
node_execution: WorkflowNodeExecutionModel, execution: WorkflowNodeExecution
) -> None:
"""
Update a WorkflowNodeExecutionModel database model from a WorkflowNodeExecution domain entity.
"""
# Update serialized data
json_converter = WorkflowRuntimeTypeConverter()
node_execution.inputs = json.dumps(json_converter.to_json_encodable(execution.inputs)) if execution.inputs else "{}"
node_execution.process_data = (
json.dumps(json_converter.to_json_encodable(execution.process_data)) if execution.process_data else "{}"
)
node_execution.outputs = (
json.dumps(json_converter.to_json_encodable(execution.outputs)) if execution.outputs else "{}"
)
# Convert metadata enum keys to strings for JSON serialization
if execution.metadata:
metadata_for_json = {
key.value if hasattr(key, "value") else str(key): value for key, value in execution.metadata.items()
}
node_execution.execution_metadata = json.dumps(json_converter.to_json_encodable(metadata_for_json))
else:
node_execution.execution_metadata = "{}"
# Update other fields
node_execution.status = execution.status.value
node_execution.error = execution.error
node_execution.elapsed_time = execution.elapsed_time
node_execution.finished_at = execution.finished_at

View File

@ -55,8 +55,8 @@ def init_code_node(code_config: dict):
environment_variables=[],
conversation_variables=[],
)
variable_pool.add(["code", "123", "args1"], 1)
variable_pool.add(["code", "123", "args2"], 2)
variable_pool.add(["code", "args1"], 1)
variable_pool.add(["code", "args2"], 2)
node = CodeNode(
id=str(uuid.uuid4()),
@ -96,9 +96,9 @@ def test_execute_code(setup_code_executor_mock):
"variables": [
{
"variable": "args1",
"value_selector": ["1", "123", "args1"],
"value_selector": ["1", "args1"],
},
{"variable": "args2", "value_selector": ["1", "123", "args2"]},
{"variable": "args2", "value_selector": ["1", "args2"]},
],
"answer": "123",
"code_language": "python3",
@ -107,8 +107,8 @@ def test_execute_code(setup_code_executor_mock):
}
node = init_code_node(code_config)
node.graph_runtime_state.variable_pool.add(["1", "123", "args1"], 1)
node.graph_runtime_state.variable_pool.add(["1", "123", "args2"], 2)
node.graph_runtime_state.variable_pool.add(["1", "args1"], 1)
node.graph_runtime_state.variable_pool.add(["1", "args2"], 2)
# execute node
result = node._run()
@ -142,9 +142,9 @@ def test_execute_code_output_validator(setup_code_executor_mock):
"variables": [
{
"variable": "args1",
"value_selector": ["1", "123", "args1"],
"value_selector": ["1", "args1"],
},
{"variable": "args2", "value_selector": ["1", "123", "args2"]},
{"variable": "args2", "value_selector": ["1", "args2"]},
],
"answer": "123",
"code_language": "python3",
@ -153,8 +153,8 @@ def test_execute_code_output_validator(setup_code_executor_mock):
}
node = init_code_node(code_config)
node.graph_runtime_state.variable_pool.add(["1", "123", "args1"], 1)
node.graph_runtime_state.variable_pool.add(["1", "123", "args2"], 2)
node.graph_runtime_state.variable_pool.add(["1", "args1"], 1)
node.graph_runtime_state.variable_pool.add(["1", "args2"], 2)
# execute node
result = node._run()
@ -217,9 +217,9 @@ def test_execute_code_output_validator_depth():
"variables": [
{
"variable": "args1",
"value_selector": ["1", "123", "args1"],
"value_selector": ["1", "args1"],
},
{"variable": "args2", "value_selector": ["1", "123", "args2"]},
{"variable": "args2", "value_selector": ["1", "args2"]},
],
"answer": "123",
"code_language": "python3",
@ -307,9 +307,9 @@ def test_execute_code_output_object_list():
"variables": [
{
"variable": "args1",
"value_selector": ["1", "123", "args1"],
"value_selector": ["1", "args1"],
},
{"variable": "args2", "value_selector": ["1", "123", "args2"]},
{"variable": "args2", "value_selector": ["1", "args2"]},
],
"answer": "123",
"code_language": "python3",

View File

@ -49,8 +49,8 @@ def init_http_node(config: dict):
environment_variables=[],
conversation_variables=[],
)
variable_pool.add(["a", "b123", "args1"], 1)
variable_pool.add(["a", "b123", "args2"], 2)
variable_pool.add(["a", "args1"], 1)
variable_pool.add(["a", "args2"], 2)
node = HttpRequestNode(
id=str(uuid.uuid4()),
@ -171,7 +171,7 @@ def test_template(setup_http_mock):
"title": "http",
"desc": "",
"method": "get",
"url": "http://example.com/{{#a.b123.args2#}}",
"url": "http://example.com/{{#a.args2#}}",
"authorization": {
"type": "api-key",
"config": {
@ -180,8 +180,8 @@ def test_template(setup_http_mock):
"header": "api-key",
},
},
"headers": "X-Header:123\nX-Header2:{{#a.b123.args2#}}",
"params": "A:b\nTemplate:{{#a.b123.args2#}}",
"headers": "X-Header:123\nX-Header2:{{#a.args2#}}",
"params": "A:b\nTemplate:{{#a.args2#}}",
"body": None,
},
}
@ -223,7 +223,7 @@ def test_json(setup_http_mock):
{
"key": "",
"type": "text",
"value": '{"a": "{{#a.b123.args1#}}"}',
"value": '{"a": "{{#a.args1#}}"}',
},
],
},
@ -264,12 +264,12 @@ def test_x_www_form_urlencoded(setup_http_mock):
{
"key": "a",
"type": "text",
"value": "{{#a.b123.args1#}}",
"value": "{{#a.args1#}}",
},
{
"key": "b",
"type": "text",
"value": "{{#a.b123.args2#}}",
"value": "{{#a.args2#}}",
},
],
},
@ -310,12 +310,12 @@ def test_form_data(setup_http_mock):
{
"key": "a",
"type": "text",
"value": "{{#a.b123.args1#}}",
"value": "{{#a.args1#}}",
},
{
"key": "b",
"type": "text",
"value": "{{#a.b123.args2#}}",
"value": "{{#a.args2#}}",
},
],
},
@ -436,3 +436,87 @@ def test_multi_colons_parse(setup_http_mock):
assert 'form-data; name="Redirect"\r\n\r\nhttp://example6.com' in result.process_data.get("request", "")
# resp = result.outputs
# assert "http://example3.com" == resp.get("headers", {}).get("referer")
@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
def test_nested_object_variable_selector(setup_http_mock):
"""Test variable selector functionality with nested object properties."""
# Create independent test setup without affecting other tests
graph_config = {
"edges": [
{
"id": "start-source-next-target",
"source": "start",
"target": "1",
},
],
"nodes": [
{"data": {"type": "start"}, "id": "start"},
{
"id": "1",
"data": {
"title": "http",
"desc": "",
"method": "get",
"url": "http://example.com/{{#a.args2#}}/{{#a.args3.nested#}}",
"authorization": {
"type": "api-key",
"config": {
"type": "basic",
"api_key": "ak-xxx",
"header": "api-key",
},
},
"headers": "X-Header:{{#a.args3.nested#}}",
"params": "nested_param:{{#a.args3.nested#}}",
"body": None,
},
},
],
}
graph = Graph.init(graph_config=graph_config)
init_params = GraphInitParams(
tenant_id="1",
app_id="1",
workflow_type=WorkflowType.WORKFLOW,
workflow_id="1",
graph_config=graph_config,
user_id="1",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
)
# Create independent variable pool for this test only
variable_pool = VariablePool(
system_variables=SystemVariable(user_id="aaa", files=[]),
user_inputs={},
environment_variables=[],
conversation_variables=[],
)
variable_pool.add(["a", "args1"], 1)
variable_pool.add(["a", "args2"], 2)
variable_pool.add(["a", "args3"], {"nested": "nested_value"}) # Only for this test
node = HttpRequestNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config=graph_config["nodes"][1],
)
# Initialize node data
if "data" in graph_config["nodes"][1]:
node.init_node_data(graph_config["nodes"][1]["data"])
result = node._run()
assert result.process_data is not None
data = result.process_data.get("request", "")
# Verify nested object property is correctly resolved
assert "/2/nested_value" in data # URL path should contain resolved nested value
assert "X-Header: nested_value" in data # Header should contain nested value
assert "nested_param=nested_value" in data # Param should contain nested value

View File

@ -71,8 +71,8 @@ def init_parameter_extractor_node(config: dict):
environment_variables=[],
conversation_variables=[],
)
variable_pool.add(["a", "b123", "args1"], 1)
variable_pool.add(["a", "b123", "args2"], 2)
variable_pool.add(["a", "args1"], 1)
variable_pool.add(["a", "args2"], 2)
node = ParameterExtractorNode(
id=str(uuid.uuid4()),

View File

@ -26,9 +26,9 @@ def test_execute_code(setup_code_executor_mock):
"variables": [
{
"variable": "args1",
"value_selector": ["1", "123", "args1"],
"value_selector": ["1", "args1"],
},
{"variable": "args2", "value_selector": ["1", "123", "args2"]},
{"variable": "args2", "value_selector": ["1", "args2"]},
],
"template": code,
},
@ -66,8 +66,8 @@ def test_execute_code(setup_code_executor_mock):
environment_variables=[],
conversation_variables=[],
)
variable_pool.add(["1", "123", "args1"], 1)
variable_pool.add(["1", "123", "args2"], 3)
variable_pool.add(["1", "args1"], 1)
variable_pool.add(["1", "args2"], 3)
node = TemplateTransformNode(
id=str(uuid.uuid4()),

View File

@ -81,7 +81,7 @@ def test_tool_variable_invoke():
ToolParameterConfigurationManager.decrypt_tool_parameters = MagicMock(return_value={"format": "%Y-%m-%d %H:%M:%S"})
node.graph_runtime_state.variable_pool.add(["1", "123", "args1"], "1+1")
node.graph_runtime_state.variable_pool.add(["1", "args1"], "1+1")
# execute node
result = node._run()

View File

@ -0,0 +1,913 @@
import hashlib
from io import BytesIO
from unittest.mock import patch
import pytest
from faker import Faker
from werkzeug.exceptions import NotFound
from configs import dify_config
from models.account import Account, Tenant
from models.enums import CreatorUserRole
from models.model import EndUser, UploadFile
from services.errors.file import FileTooLargeError, UnsupportedFileTypeError
from services.file_service import FileService
class TestFileService:
"""Integration tests for FileService using testcontainers."""
@pytest.fixture
def mock_external_service_dependencies(self):
"""Mock setup for external service dependencies."""
with (
patch("services.file_service.storage") as mock_storage,
patch("services.file_service.file_helpers") as mock_file_helpers,
patch("services.file_service.ExtractProcessor") as mock_extract_processor,
):
# Setup default mock returns
mock_storage.save.return_value = None
mock_storage.load.return_value = BytesIO(b"mock file content")
mock_file_helpers.get_signed_file_url.return_value = "https://example.com/signed-url"
mock_file_helpers.verify_image_signature.return_value = True
mock_file_helpers.verify_file_signature.return_value = True
mock_extract_processor.load_from_upload_file.return_value = "extracted text content"
yield {
"storage": mock_storage,
"file_helpers": mock_file_helpers,
"extract_processor": mock_extract_processor,
}
def _create_test_account(self, db_session_with_containers, mock_external_service_dependencies):
"""
Helper method to create a test account for testing.
Args:
db_session_with_containers: Database session from testcontainers infrastructure
mock_external_service_dependencies: Mock dependencies
Returns:
Account: Created account instance
"""
fake = Faker()
# Create account
account = Account(
email=fake.email(),
name=fake.name(),
interface_language="en-US",
status="active",
)
from extensions.ext_database import db
db.session.add(account)
db.session.commit()
# Create tenant for the account
tenant = Tenant(
name=fake.company(),
status="normal",
)
db.session.add(tenant)
db.session.commit()
# Create tenant-account join
from models.account import TenantAccountJoin, TenantAccountRole
join = TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role=TenantAccountRole.OWNER.value,
current=True,
)
db.session.add(join)
db.session.commit()
# Set current tenant for account
account.current_tenant = tenant
return account
def _create_test_end_user(self, db_session_with_containers, mock_external_service_dependencies):
"""
Helper method to create a test end user for testing.
Args:
db_session_with_containers: Database session from testcontainers infrastructure
mock_external_service_dependencies: Mock dependencies
Returns:
EndUser: Created end user instance
"""
fake = Faker()
end_user = EndUser(
tenant_id=str(fake.uuid4()),
type="web",
name=fake.name(),
is_anonymous=False,
session_id=fake.uuid4(),
)
from extensions.ext_database import db
db.session.add(end_user)
db.session.commit()
return end_user
def _create_test_upload_file(self, db_session_with_containers, mock_external_service_dependencies, account):
"""
Helper method to create a test upload file for testing.
Args:
db_session_with_containers: Database session from testcontainers infrastructure
mock_external_service_dependencies: Mock dependencies
account: Account instance
Returns:
UploadFile: Created upload file instance
"""
fake = Faker()
upload_file = UploadFile(
tenant_id=account.current_tenant_id if hasattr(account, "current_tenant_id") else str(fake.uuid4()),
storage_type="local",
key=f"upload_files/test/{fake.uuid4()}.txt",
name="test_file.txt",
size=1024,
extension="txt",
mime_type="text/plain",
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,
created_at=fake.date_time(),
used=False,
hash=hashlib.sha3_256(b"test content").hexdigest(),
source_url="",
)
from extensions.ext_database import db
db.session.add(upload_file)
db.session.commit()
return upload_file
# Test upload_file method
def test_upload_file_success(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test successful file upload with valid parameters.
"""
fake = Faker()
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
filename = "test_document.pdf"
content = b"test file content"
mimetype = "application/pdf"
upload_file = FileService.upload_file(
filename=filename,
content=content,
mimetype=mimetype,
user=account,
)
assert upload_file is not None
assert upload_file.name == filename
assert upload_file.size == len(content)
assert upload_file.extension == "pdf"
assert upload_file.mime_type == mimetype
assert upload_file.created_by == account.id
assert upload_file.created_by_role == CreatorUserRole.ACCOUNT.value
assert upload_file.used is False
assert upload_file.hash == hashlib.sha3_256(content).hexdigest()
# Verify storage was called
mock_external_service_dependencies["storage"].save.assert_called_once()
# Verify database state
from extensions.ext_database import db
db.session.refresh(upload_file)
assert upload_file.id is not None
def test_upload_file_with_end_user(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test file upload with end user instead of account.
"""
fake = Faker()
end_user = self._create_test_end_user(db_session_with_containers, mock_external_service_dependencies)
filename = "test_image.jpg"
content = b"test image content"
mimetype = "image/jpeg"
upload_file = FileService.upload_file(
filename=filename,
content=content,
mimetype=mimetype,
user=end_user,
)
assert upload_file is not None
assert upload_file.created_by == end_user.id
assert upload_file.created_by_role == CreatorUserRole.END_USER.value
def test_upload_file_with_datasets_source(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test file upload with datasets source parameter.
"""
fake = Faker()
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
filename = "test_document.pdf"
content = b"test file content"
mimetype = "application/pdf"
upload_file = FileService.upload_file(
filename=filename,
content=content,
mimetype=mimetype,
user=account,
source="datasets",
source_url="https://example.com/source",
)
assert upload_file is not None
assert upload_file.source_url == "https://example.com/source"
def test_upload_file_invalid_filename_characters(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test file upload with invalid filename characters.
"""
fake = Faker()
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
filename = "test/file<name>.txt"
content = b"test content"
mimetype = "text/plain"
with pytest.raises(ValueError, match="Filename contains invalid characters"):
FileService.upload_file(
filename=filename,
content=content,
mimetype=mimetype,
user=account,
)
def test_upload_file_filename_too_long(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test file upload with filename that exceeds length limit.
"""
fake = Faker()
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
# Create a filename longer than 200 characters
long_name = "a" * 250
filename = f"{long_name}.txt"
content = b"test content"
mimetype = "text/plain"
upload_file = FileService.upload_file(
filename=filename,
content=content,
mimetype=mimetype,
user=account,
)
# Verify filename was truncated (the logic truncates the base name to 200 chars + extension)
# So the total length should be <= 200 + len(extension) + 1 (for the dot)
assert len(upload_file.name) <= 200 + len(upload_file.extension) + 1
assert upload_file.name.endswith(".txt")
# Verify the base name was truncated
base_name = upload_file.name[:-4] # Remove .txt
assert len(base_name) <= 200
def test_upload_file_datasets_unsupported_type(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test file upload for datasets with unsupported file type.
"""
fake = Faker()
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
filename = "test_image.jpg"
content = b"test content"
mimetype = "image/jpeg"
with pytest.raises(UnsupportedFileTypeError):
FileService.upload_file(
filename=filename,
content=content,
mimetype=mimetype,
user=account,
source="datasets",
)
def test_upload_file_too_large(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test file upload with file size exceeding limit.
"""
fake = Faker()
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
filename = "large_image.jpg"
# Create content larger than the limit
content = b"x" * (dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT * 1024 * 1024 + 1)
mimetype = "image/jpeg"
with pytest.raises(FileTooLargeError):
FileService.upload_file(
filename=filename,
content=content,
mimetype=mimetype,
user=account,
)
# Test is_file_size_within_limit method
def test_is_file_size_within_limit_image_success(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test file size check for image files within limit.
"""
extension = "jpg"
file_size = dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT * 1024 * 1024 # Exactly at limit
result = FileService.is_file_size_within_limit(extension=extension, file_size=file_size)
assert result is True
def test_is_file_size_within_limit_video_success(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test file size check for video files within limit.
"""
extension = "mp4"
file_size = dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT * 1024 * 1024 # Exactly at limit
result = FileService.is_file_size_within_limit(extension=extension, file_size=file_size)
assert result is True
def test_is_file_size_within_limit_audio_success(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test file size check for audio files within limit.
"""
extension = "mp3"
file_size = dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT * 1024 * 1024 # Exactly at limit
result = FileService.is_file_size_within_limit(extension=extension, file_size=file_size)
assert result is True
def test_is_file_size_within_limit_document_success(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test file size check for document files within limit.
"""
extension = "pdf"
file_size = dify_config.UPLOAD_FILE_SIZE_LIMIT * 1024 * 1024 # Exactly at limit
result = FileService.is_file_size_within_limit(extension=extension, file_size=file_size)
assert result is True
def test_is_file_size_within_limit_image_exceeded(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test file size check for image files exceeding limit.
"""
extension = "jpg"
file_size = dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT * 1024 * 1024 + 1 # Exceeds limit
result = FileService.is_file_size_within_limit(extension=extension, file_size=file_size)
assert result is False
def test_is_file_size_within_limit_unknown_extension(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test file size check for unknown file extension.
"""
extension = "xyz"
file_size = dify_config.UPLOAD_FILE_SIZE_LIMIT * 1024 * 1024 # Uses default limit
result = FileService.is_file_size_within_limit(extension=extension, file_size=file_size)
assert result is True
# Test upload_text method
def test_upload_text_success(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test successful text upload.
"""
fake = Faker()
text = "This is a test text content"
text_name = "test_text.txt"
# Mock current_user
with patch("services.file_service.current_user") as mock_current_user:
mock_current_user.current_tenant_id = str(fake.uuid4())
mock_current_user.id = str(fake.uuid4())
upload_file = FileService.upload_text(text=text, text_name=text_name)
assert upload_file is not None
assert upload_file.name == text_name
assert upload_file.size == len(text)
assert upload_file.extension == "txt"
assert upload_file.mime_type == "text/plain"
assert upload_file.used is True
assert upload_file.used_by == mock_current_user.id
# Verify storage was called
mock_external_service_dependencies["storage"].save.assert_called_once()
def test_upload_text_name_too_long(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test text upload with name that exceeds length limit.
"""
fake = Faker()
text = "test content"
long_name = "a" * 250 # Longer than 200 characters
# Mock current_user
with patch("services.file_service.current_user") as mock_current_user:
mock_current_user.current_tenant_id = str(fake.uuid4())
mock_current_user.id = str(fake.uuid4())
upload_file = FileService.upload_text(text=text, text_name=long_name)
# Verify name was truncated
assert len(upload_file.name) <= 200
assert upload_file.name == "a" * 200
# Test get_file_preview method
def test_get_file_preview_success(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test successful file preview generation.
"""
fake = Faker()
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
upload_file = self._create_test_upload_file(
db_session_with_containers, mock_external_service_dependencies, account
)
# Update file to have document extension
upload_file.extension = "pdf"
from extensions.ext_database import db
db.session.commit()
result = FileService.get_file_preview(file_id=upload_file.id)
assert result == "extracted text content"
mock_external_service_dependencies["extract_processor"].load_from_upload_file.assert_called_once()
def test_get_file_preview_file_not_found(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test file preview with non-existent file.
"""
fake = Faker()
non_existent_id = str(fake.uuid4())
with pytest.raises(NotFound, match="File not found"):
FileService.get_file_preview(file_id=non_existent_id)
def test_get_file_preview_unsupported_file_type(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test file preview with unsupported file type.
"""
fake = Faker()
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
upload_file = self._create_test_upload_file(
db_session_with_containers, mock_external_service_dependencies, account
)
# Update file to have non-document extension
upload_file.extension = "jpg"
from extensions.ext_database import db
db.session.commit()
with pytest.raises(UnsupportedFileTypeError):
FileService.get_file_preview(file_id=upload_file.id)
def test_get_file_preview_text_truncation(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test file preview with text that exceeds preview limit.
"""
fake = Faker()
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
upload_file = self._create_test_upload_file(
db_session_with_containers, mock_external_service_dependencies, account
)
# Update file to have document extension
upload_file.extension = "pdf"
from extensions.ext_database import db
db.session.commit()
# Mock long text content
long_text = "x" * 5000 # Longer than PREVIEW_WORDS_LIMIT
mock_external_service_dependencies["extract_processor"].load_from_upload_file.return_value = long_text
result = FileService.get_file_preview(file_id=upload_file.id)
assert len(result) == 3000 # PREVIEW_WORDS_LIMIT
assert result == "x" * 3000
# Test get_image_preview method
def test_get_image_preview_success(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test successful image preview generation.
"""
fake = Faker()
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
upload_file = self._create_test_upload_file(
db_session_with_containers, mock_external_service_dependencies, account
)
# Update file to have image extension
upload_file.extension = "jpg"
from extensions.ext_database import db
db.session.commit()
timestamp = "1234567890"
nonce = "test_nonce"
sign = "test_signature"
generator, mime_type = FileService.get_image_preview(
file_id=upload_file.id,
timestamp=timestamp,
nonce=nonce,
sign=sign,
)
assert generator is not None
assert mime_type == upload_file.mime_type
mock_external_service_dependencies["file_helpers"].verify_image_signature.assert_called_once()
def test_get_image_preview_invalid_signature(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test image preview with invalid signature.
"""
fake = Faker()
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
upload_file = self._create_test_upload_file(
db_session_with_containers, mock_external_service_dependencies, account
)
# Mock invalid signature
mock_external_service_dependencies["file_helpers"].verify_image_signature.return_value = False
timestamp = "1234567890"
nonce = "test_nonce"
sign = "invalid_signature"
with pytest.raises(NotFound, match="File not found or signature is invalid"):
FileService.get_image_preview(
file_id=upload_file.id,
timestamp=timestamp,
nonce=nonce,
sign=sign,
)
def test_get_image_preview_file_not_found(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test image preview with non-existent file.
"""
fake = Faker()
non_existent_id = str(fake.uuid4())
timestamp = "1234567890"
nonce = "test_nonce"
sign = "test_signature"
with pytest.raises(NotFound, match="File not found or signature is invalid"):
FileService.get_image_preview(
file_id=non_existent_id,
timestamp=timestamp,
nonce=nonce,
sign=sign,
)
def test_get_image_preview_unsupported_file_type(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test image preview with non-image file type.
"""
fake = Faker()
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
upload_file = self._create_test_upload_file(
db_session_with_containers, mock_external_service_dependencies, account
)
# Update file to have non-image extension
upload_file.extension = "pdf"
from extensions.ext_database import db
db.session.commit()
timestamp = "1234567890"
nonce = "test_nonce"
sign = "test_signature"
with pytest.raises(UnsupportedFileTypeError):
FileService.get_image_preview(
file_id=upload_file.id,
timestamp=timestamp,
nonce=nonce,
sign=sign,
)
# Test get_file_generator_by_file_id method
def test_get_file_generator_by_file_id_success(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test successful file generator retrieval.
"""
fake = Faker()
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
upload_file = self._create_test_upload_file(
db_session_with_containers, mock_external_service_dependencies, account
)
timestamp = "1234567890"
nonce = "test_nonce"
sign = "test_signature"
generator, file_obj = FileService.get_file_generator_by_file_id(
file_id=upload_file.id,
timestamp=timestamp,
nonce=nonce,
sign=sign,
)
assert generator is not None
assert file_obj == upload_file
mock_external_service_dependencies["file_helpers"].verify_file_signature.assert_called_once()
def test_get_file_generator_by_file_id_invalid_signature(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test file generator retrieval with invalid signature.
"""
fake = Faker()
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
upload_file = self._create_test_upload_file(
db_session_with_containers, mock_external_service_dependencies, account
)
# Mock invalid signature
mock_external_service_dependencies["file_helpers"].verify_file_signature.return_value = False
timestamp = "1234567890"
nonce = "test_nonce"
sign = "invalid_signature"
with pytest.raises(NotFound, match="File not found or signature is invalid"):
FileService.get_file_generator_by_file_id(
file_id=upload_file.id,
timestamp=timestamp,
nonce=nonce,
sign=sign,
)
def test_get_file_generator_by_file_id_file_not_found(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test file generator retrieval with non-existent file.
"""
fake = Faker()
non_existent_id = str(fake.uuid4())
timestamp = "1234567890"
nonce = "test_nonce"
sign = "test_signature"
with pytest.raises(NotFound, match="File not found or signature is invalid"):
FileService.get_file_generator_by_file_id(
file_id=non_existent_id,
timestamp=timestamp,
nonce=nonce,
sign=sign,
)
# Test get_public_image_preview method
def test_get_public_image_preview_success(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test successful public image preview generation.
"""
fake = Faker()
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
upload_file = self._create_test_upload_file(
db_session_with_containers, mock_external_service_dependencies, account
)
# Update file to have image extension
upload_file.extension = "jpg"
from extensions.ext_database import db
db.session.commit()
generator, mime_type = FileService.get_public_image_preview(file_id=upload_file.id)
assert generator is not None
assert mime_type == upload_file.mime_type
mock_external_service_dependencies["storage"].load.assert_called_once()
def test_get_public_image_preview_file_not_found(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test public image preview with non-existent file.
"""
fake = Faker()
non_existent_id = str(fake.uuid4())
with pytest.raises(NotFound, match="File not found or signature is invalid"):
FileService.get_public_image_preview(file_id=non_existent_id)
def test_get_public_image_preview_unsupported_file_type(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test public image preview with non-image file type.
"""
fake = Faker()
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
upload_file = self._create_test_upload_file(
db_session_with_containers, mock_external_service_dependencies, account
)
# Update file to have non-image extension
upload_file.extension = "pdf"
from extensions.ext_database import db
db.session.commit()
with pytest.raises(UnsupportedFileTypeError):
FileService.get_public_image_preview(file_id=upload_file.id)
# Test edge cases and boundary conditions
def test_upload_file_empty_content(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test file upload with empty content.
"""
fake = Faker()
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
filename = "empty.txt"
content = b""
mimetype = "text/plain"
upload_file = FileService.upload_file(
filename=filename,
content=content,
mimetype=mimetype,
user=account,
)
assert upload_file is not None
assert upload_file.size == 0
def test_upload_file_special_characters_in_name(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test file upload with special characters in filename (but valid ones).
"""
fake = Faker()
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
filename = "test-file_with_underscores_and.dots.txt"
content = b"test content"
mimetype = "text/plain"
upload_file = FileService.upload_file(
filename=filename,
content=content,
mimetype=mimetype,
user=account,
)
assert upload_file is not None
assert upload_file.name == filename
def test_upload_file_different_case_extensions(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test file upload with different case extensions.
"""
fake = Faker()
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
filename = "test.PDF"
content = b"test content"
mimetype = "application/pdf"
upload_file = FileService.upload_file(
filename=filename,
content=content,
mimetype=mimetype,
user=account,
)
assert upload_file is not None
assert upload_file.extension == "pdf" # Should be converted to lowercase
def test_upload_text_empty_text(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test text upload with empty text.
"""
fake = Faker()
text = ""
text_name = "empty.txt"
# Mock current_user
with patch("services.file_service.current_user") as mock_current_user:
mock_current_user.current_tenant_id = str(fake.uuid4())
mock_current_user.id = str(fake.uuid4())
upload_file = FileService.upload_text(text=text, text_name=text_name)
assert upload_file is not None
assert upload_file.size == 0
def test_file_size_limits_edge_cases(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test file size limits with edge case values.
"""
# Test exactly at limit
for extension, limit_config in [
("jpg", dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT),
("mp4", dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT),
("mp3", dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT),
("pdf", dify_config.UPLOAD_FILE_SIZE_LIMIT),
]:
file_size = limit_config * 1024 * 1024
result = FileService.is_file_size_within_limit(extension=extension, file_size=file_size)
assert result is True
# Test one byte over limit
file_size = limit_config * 1024 * 1024 + 1
result = FileService.is_file_size_within_limit(extension=extension, file_size=file_size)
assert result is False
def test_upload_file_with_source_url(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test file upload with source URL that gets overridden by signed URL.
"""
fake = Faker()
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
filename = "test.pdf"
content = b"test content"
mimetype = "application/pdf"
source_url = "https://original-source.com/file.pdf"
upload_file = FileService.upload_file(
filename=filename,
content=content,
mimetype=mimetype,
user=account,
source_url=source_url,
)
# When source_url is provided, it should be preserved
assert upload_file.source_url == source_url
# The signed URL should only be set when source_url is empty
# Let's test that scenario
upload_file2 = FileService.upload_file(
filename="test2.pdf",
content=b"test content 2",
mimetype="application/pdf",
user=account,
source_url="", # Empty source_url
)
# Should have the signed URL when source_url is empty
assert upload_file2.source_url == "https://example.com/signed-url"

View File

@ -0,0 +1,775 @@
from unittest.mock import patch
import pytest
from faker import Faker
from models.model import MessageFeedback
from services.app_service import AppService
from services.errors.message import (
FirstMessageNotExistsError,
LastMessageNotExistsError,
MessageNotExistsError,
SuggestedQuestionsAfterAnswerDisabledError,
)
from services.message_service import MessageService
class TestMessageService:
"""Integration tests for MessageService using testcontainers."""
@pytest.fixture
def mock_external_service_dependencies(self):
"""Mock setup for external service dependencies."""
with (
patch("services.account_service.FeatureService") as mock_account_feature_service,
patch("services.message_service.ModelManager") as mock_model_manager,
patch("services.message_service.WorkflowService") as mock_workflow_service,
patch("services.message_service.AdvancedChatAppConfigManager") as mock_app_config_manager,
patch("services.message_service.LLMGenerator") as mock_llm_generator,
patch("services.message_service.TraceQueueManager") as mock_trace_manager_class,
patch("services.message_service.TokenBufferMemory") as mock_token_buffer_memory,
):
# Setup default mock returns
mock_account_feature_service.get_features.return_value.billing.enabled = False
# Mock ModelManager
mock_model_instance = mock_model_manager.return_value.get_default_model_instance.return_value
mock_model_instance.get_tts_voices.return_value = [{"value": "test-voice"}]
# Mock get_model_instance method as well
mock_model_manager.return_value.get_model_instance.return_value = mock_model_instance
# Mock WorkflowService
mock_workflow = mock_workflow_service.return_value.get_published_workflow.return_value
mock_workflow_service.return_value.get_draft_workflow.return_value = mock_workflow
# Mock AdvancedChatAppConfigManager
mock_app_config = mock_app_config_manager.get_app_config.return_value
mock_app_config.additional_features.suggested_questions_after_answer = True
# Mock LLMGenerator
mock_llm_generator.generate_suggested_questions_after_answer.return_value = ["Question 1", "Question 2"]
# Mock TraceQueueManager
mock_trace_manager_instance = mock_trace_manager_class.return_value
# Mock TokenBufferMemory
mock_memory_instance = mock_token_buffer_memory.return_value
mock_memory_instance.get_history_prompt_text.return_value = "Mocked history prompt"
yield {
"account_feature_service": mock_account_feature_service,
"model_manager": mock_model_manager,
"workflow_service": mock_workflow_service,
"app_config_manager": mock_app_config_manager,
"llm_generator": mock_llm_generator,
"trace_manager_class": mock_trace_manager_class,
"trace_manager_instance": mock_trace_manager_instance,
"token_buffer_memory": mock_token_buffer_memory,
# "current_user": mock_current_user,
}
def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies):
"""
Helper method to create a test app and account for testing.
Args:
db_session_with_containers: Database session from testcontainers infrastructure
mock_external_service_dependencies: Mock dependencies
Returns:
tuple: (app, account) - Created app and account instances
"""
fake = Faker()
# Setup mocks for account creation
mock_external_service_dependencies[
"account_feature_service"
].get_system_features.return_value.is_allow_register = True
# Create account and tenant first
from services.account_service import AccountService, TenantService
account = AccountService.create_account(
email=fake.email(),
name=fake.name(),
interface_language="en-US",
password=fake.password(length=12),
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
tenant = account.current_tenant
# Setup app creation arguments
app_args = {
"name": fake.company(),
"description": fake.text(max_nb_chars=100),
"mode": "advanced-chat", # Use advanced-chat mode to use mocked workflow
"icon_type": "emoji",
"icon": "🤖",
"icon_background": "#FF6B6B",
"api_rph": 100,
"api_rpm": 10,
}
# Create app
app_service = AppService()
app = app_service.create_app(tenant.id, app_args, account)
# Setup current_user mock
self._mock_current_user(mock_external_service_dependencies, account.id, tenant.id)
return app, account
def _mock_current_user(self, mock_external_service_dependencies, account_id, tenant_id):
"""
Helper method to mock the current user for testing.
"""
# mock_external_service_dependencies["current_user"].id = account_id
# mock_external_service_dependencies["current_user"].current_tenant_id = tenant_id
def _create_test_conversation(self, app, account, fake):
"""
Helper method to create a test conversation with all required fields.
"""
from extensions.ext_database import db
from models.model import Conversation
conversation = Conversation(
app_id=app.id,
app_model_config_id=None,
model_provider=None,
model_id="",
override_model_configs=None,
mode=app.mode,
name=fake.sentence(),
inputs={},
introduction="",
system_instruction="",
system_instruction_tokens=0,
status="normal",
invoke_from="console",
from_source="console",
from_end_user_id=None,
from_account_id=account.id,
)
db.session.add(conversation)
db.session.flush()
return conversation
def _create_test_message(self, app, conversation, account, fake):
"""
Helper method to create a test message with all required fields.
"""
import json
from extensions.ext_database import db
from models.model import Message
message = Message(
app_id=app.id,
model_provider=None,
model_id="",
override_model_configs=None,
conversation_id=conversation.id,
inputs={},
query=fake.sentence(),
message=json.dumps([{"role": "user", "text": fake.sentence()}]),
message_tokens=0,
message_unit_price=0,
message_price_unit=0.001,
answer=fake.text(max_nb_chars=200),
answer_tokens=0,
answer_unit_price=0,
answer_price_unit=0.001,
parent_message_id=None,
provider_response_latency=0,
total_price=0,
currency="USD",
invoke_from="console",
from_source="console",
from_end_user_id=None,
from_account_id=account.id,
)
db.session.add(message)
db.session.commit()
return message
def test_pagination_by_first_id_success(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test successful pagination by first ID.
"""
fake = Faker()
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Create a conversation and multiple messages
conversation = self._create_test_conversation(app, account, fake)
messages = []
for i in range(5):
message = self._create_test_message(app, conversation, account, fake)
messages.append(message)
# Test pagination by first ID
result = MessageService.pagination_by_first_id(
app_model=app,
user=account,
conversation_id=conversation.id,
first_id=messages[2].id, # Use middle message as first_id
limit=2,
order="asc",
)
# Verify results
assert result.limit == 2
assert len(result.data) == 2
# total 5, from the middle, no more
assert result.has_more is False
# Verify messages are in ascending order
assert result.data[0].created_at <= result.data[1].created_at
def test_pagination_by_first_id_no_user(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test pagination by first ID when no user is provided.
"""
fake = Faker()
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Test pagination with no user
result = MessageService.pagination_by_first_id(
app_model=app, user=None, conversation_id=fake.uuid4(), first_id=None, limit=10
)
# Verify empty result
assert result.limit == 10
assert len(result.data) == 0
assert result.has_more is False
def test_pagination_by_first_id_no_conversation_id(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test pagination by first ID when no conversation ID is provided.
"""
fake = Faker()
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Test pagination with no conversation ID
result = MessageService.pagination_by_first_id(
app_model=app, user=account, conversation_id="", first_id=None, limit=10
)
# Verify empty result
assert result.limit == 10
assert len(result.data) == 0
assert result.has_more is False
def test_pagination_by_first_id_invalid_first_id(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test pagination by first ID with invalid first_id.
"""
fake = Faker()
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Create a conversation and message
conversation = self._create_test_conversation(app, account, fake)
self._create_test_message(app, conversation, account, fake)
# Test pagination with invalid first_id
with pytest.raises(FirstMessageNotExistsError):
MessageService.pagination_by_first_id(
app_model=app,
user=account,
conversation_id=conversation.id,
first_id=fake.uuid4(), # Non-existent message ID
limit=10,
)
def test_pagination_by_last_id_success(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test successful pagination by last ID.
"""
fake = Faker()
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Create a conversation and multiple messages
conversation = self._create_test_conversation(app, account, fake)
messages = []
for i in range(5):
message = self._create_test_message(app, conversation, account, fake)
messages.append(message)
# Test pagination by last ID
result = MessageService.pagination_by_last_id(
app_model=app,
user=account,
last_id=messages[2].id, # Use middle message as last_id
limit=2,
conversation_id=conversation.id,
)
# Verify results
assert result.limit == 2
assert len(result.data) == 2
# total 5, from the middle, no more
assert result.has_more is False
# Verify messages are in descending order
assert result.data[0].created_at >= result.data[1].created_at
def test_pagination_by_last_id_with_include_ids(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test pagination by last ID with include_ids filter.
"""
fake = Faker()
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Create a conversation and multiple messages
conversation = self._create_test_conversation(app, account, fake)
messages = []
for i in range(5):
message = self._create_test_message(app, conversation, account, fake)
messages.append(message)
# Test pagination with include_ids
include_ids = [messages[0].id, messages[1].id, messages[2].id]
result = MessageService.pagination_by_last_id(
app_model=app, user=account, last_id=messages[1].id, limit=2, include_ids=include_ids
)
# Verify results
assert result.limit == 2
assert len(result.data) <= 2
# Verify all returned messages are in include_ids
for message in result.data:
assert message.id in include_ids
def test_pagination_by_last_id_no_user(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test pagination by last ID when no user is provided.
"""
fake = Faker()
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Test pagination with no user
result = MessageService.pagination_by_last_id(app_model=app, user=None, last_id=None, limit=10)
# Verify empty result
assert result.limit == 10
assert len(result.data) == 0
assert result.has_more is False
def test_pagination_by_last_id_invalid_last_id(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test pagination by last ID with invalid last_id.
"""
fake = Faker()
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Create a conversation and message
conversation = self._create_test_conversation(app, account, fake)
self._create_test_message(app, conversation, account, fake)
# Test pagination with invalid last_id
with pytest.raises(LastMessageNotExistsError):
MessageService.pagination_by_last_id(
app_model=app,
user=account,
last_id=fake.uuid4(), # Non-existent message ID
limit=10,
conversation_id=conversation.id,
)
def test_create_feedback_success(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test successful creation of feedback.
"""
fake = Faker()
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Create a conversation and message
conversation = self._create_test_conversation(app, account, fake)
message = self._create_test_message(app, conversation, account, fake)
# Create feedback
rating = "like"
content = fake.text(max_nb_chars=100)
feedback = MessageService.create_feedback(
app_model=app, message_id=message.id, user=account, rating=rating, content=content
)
# Verify feedback was created correctly
assert feedback.app_id == app.id
assert feedback.conversation_id == conversation.id
assert feedback.message_id == message.id
assert feedback.rating == rating
assert feedback.content == content
assert feedback.from_source == "admin"
assert feedback.from_account_id == account.id
assert feedback.from_end_user_id is None
def test_create_feedback_no_user(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test creating feedback when no user is provided.
"""
fake = Faker()
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Create a conversation and message
conversation = self._create_test_conversation(app, account, fake)
message = self._create_test_message(app, conversation, account, fake)
# Test creating feedback with no user
with pytest.raises(ValueError, match="user cannot be None"):
MessageService.create_feedback(
app_model=app, message_id=message.id, user=None, rating="like", content=fake.text(max_nb_chars=100)
)
def test_create_feedback_update_existing(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test updating existing feedback.
"""
fake = Faker()
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Create a conversation and message
conversation = self._create_test_conversation(app, account, fake)
message = self._create_test_message(app, conversation, account, fake)
# Create initial feedback
initial_rating = "like"
initial_content = fake.text(max_nb_chars=100)
feedback = MessageService.create_feedback(
app_model=app, message_id=message.id, user=account, rating=initial_rating, content=initial_content
)
# Update feedback
updated_rating = "dislike"
updated_content = fake.text(max_nb_chars=100)
updated_feedback = MessageService.create_feedback(
app_model=app, message_id=message.id, user=account, rating=updated_rating, content=updated_content
)
# Verify feedback was updated correctly
assert updated_feedback.id == feedback.id
assert updated_feedback.rating == updated_rating
assert updated_feedback.content == updated_content
assert updated_feedback.rating != initial_rating
assert updated_feedback.content != initial_content
def test_create_feedback_delete_existing(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test deleting existing feedback by setting rating to None.
"""
fake = Faker()
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Create a conversation and message
conversation = self._create_test_conversation(app, account, fake)
message = self._create_test_message(app, conversation, account, fake)
# Create initial feedback
feedback = MessageService.create_feedback(
app_model=app, message_id=message.id, user=account, rating="like", content=fake.text(max_nb_chars=100)
)
# Delete feedback by setting rating to None
MessageService.create_feedback(app_model=app, message_id=message.id, user=account, rating=None, content=None)
# Verify feedback was deleted
from extensions.ext_database import db
deleted_feedback = db.session.query(MessageFeedback).filter(MessageFeedback.id == feedback.id).first()
assert deleted_feedback is None
def test_create_feedback_no_rating_when_not_exists(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test creating feedback with no rating when feedback doesn't exist.
"""
fake = Faker()
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Create a conversation and message
conversation = self._create_test_conversation(app, account, fake)
message = self._create_test_message(app, conversation, account, fake)
# Test creating feedback with no rating when no feedback exists
with pytest.raises(ValueError, match="rating cannot be None when feedback not exists"):
MessageService.create_feedback(
app_model=app, message_id=message.id, user=account, rating=None, content=None
)
def test_get_all_messages_feedbacks_success(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test successful retrieval of all message feedbacks.
"""
fake = Faker()
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Create multiple conversations and messages with feedbacks
feedbacks = []
for i in range(3):
conversation = self._create_test_conversation(app, account, fake)
message = self._create_test_message(app, conversation, account, fake)
feedback = MessageService.create_feedback(
app_model=app,
message_id=message.id,
user=account,
rating="like" if i % 2 == 0 else "dislike",
content=f"Feedback {i}: {fake.text(max_nb_chars=50)}",
)
feedbacks.append(feedback)
# Get all feedbacks
result = MessageService.get_all_messages_feedbacks(app, page=1, limit=10)
# Verify results
assert len(result) == 3
# Verify feedbacks are ordered by created_at desc
for i in range(len(result) - 1):
assert result[i]["created_at"] >= result[i + 1]["created_at"]
def test_get_all_messages_feedbacks_pagination(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test pagination of message feedbacks.
"""
fake = Faker()
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Create multiple conversations and messages with feedbacks
for i in range(5):
conversation = self._create_test_conversation(app, account, fake)
message = self._create_test_message(app, conversation, account, fake)
MessageService.create_feedback(
app_model=app, message_id=message.id, user=account, rating="like", content=f"Feedback {i}"
)
# Get feedbacks with pagination
result_page_1 = MessageService.get_all_messages_feedbacks(app, page=1, limit=3)
result_page_2 = MessageService.get_all_messages_feedbacks(app, page=2, limit=3)
# Verify pagination results
assert len(result_page_1) == 3
assert len(result_page_2) == 2
# Verify no overlap between pages
page_1_ids = {feedback["id"] for feedback in result_page_1}
page_2_ids = {feedback["id"] for feedback in result_page_2}
assert len(page_1_ids.intersection(page_2_ids)) == 0
def test_get_message_success(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test successful retrieval of message.
"""
fake = Faker()
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Create a conversation and message
conversation = self._create_test_conversation(app, account, fake)
message = self._create_test_message(app, conversation, account, fake)
# Get message
retrieved_message = MessageService.get_message(app_model=app, user=account, message_id=message.id)
# Verify message was retrieved correctly
assert retrieved_message.id == message.id
assert retrieved_message.app_id == app.id
assert retrieved_message.conversation_id == conversation.id
assert retrieved_message.from_source == "console"
assert retrieved_message.from_account_id == account.id
def test_get_message_not_exists(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test getting message that doesn't exist.
"""
fake = Faker()
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Test getting non-existent message
with pytest.raises(MessageNotExistsError):
MessageService.get_message(app_model=app, user=account, message_id=fake.uuid4())
def test_get_message_wrong_user(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test getting message with wrong user (different account).
"""
fake = Faker()
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Create a conversation and message
conversation = self._create_test_conversation(app, account, fake)
message = self._create_test_message(app, conversation, account, fake)
# Create another account
from services.account_service import AccountService, TenantService
other_account = AccountService.create_account(
email=fake.email(),
name=fake.name(),
interface_language="en-US",
password=fake.password(length=12),
)
TenantService.create_owner_tenant_if_not_exist(other_account, name=fake.company())
# Test getting message with different user
with pytest.raises(MessageNotExistsError):
MessageService.get_message(app_model=app, user=other_account, message_id=message.id)
def test_get_suggested_questions_after_answer_success(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test successful generation of suggested questions after answer.
"""
fake = Faker()
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Create a conversation and message
conversation = self._create_test_conversation(app, account, fake)
message = self._create_test_message(app, conversation, account, fake)
# Mock the LLMGenerator to return specific questions
mock_questions = ["What is AI?", "How does machine learning work?", "Tell me about neural networks"]
mock_external_service_dependencies[
"llm_generator"
].generate_suggested_questions_after_answer.return_value = mock_questions
# Get suggested questions
from core.app.entities.app_invoke_entities import InvokeFrom
result = MessageService.get_suggested_questions_after_answer(
app_model=app, user=account, message_id=message.id, invoke_from=InvokeFrom.SERVICE_API
)
# Verify results
assert result == mock_questions
# Verify LLMGenerator was called
mock_external_service_dependencies[
"llm_generator"
].generate_suggested_questions_after_answer.assert_called_once()
# Verify TraceQueueManager was called
mock_external_service_dependencies["trace_manager_instance"].add_trace_task.assert_called_once()
def test_get_suggested_questions_after_answer_no_user(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test getting suggested questions when no user is provided.
"""
fake = Faker()
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Create a conversation and message
conversation = self._create_test_conversation(app, account, fake)
message = self._create_test_message(app, conversation, account, fake)
# Test getting suggested questions with no user
from core.app.entities.app_invoke_entities import InvokeFrom
with pytest.raises(ValueError, match="user cannot be None"):
MessageService.get_suggested_questions_after_answer(
app_model=app, user=None, message_id=message.id, invoke_from=InvokeFrom.SERVICE_API
)
def test_get_suggested_questions_after_answer_disabled(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test getting suggested questions when feature is disabled.
"""
fake = Faker()
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Create a conversation and message
conversation = self._create_test_conversation(app, account, fake)
message = self._create_test_message(app, conversation, account, fake)
# Mock the feature to be disabled
mock_external_service_dependencies[
"app_config_manager"
].get_app_config.return_value.additional_features.suggested_questions_after_answer = False
# Test getting suggested questions when feature is disabled
from core.app.entities.app_invoke_entities import InvokeFrom
with pytest.raises(SuggestedQuestionsAfterAnswerDisabledError):
MessageService.get_suggested_questions_after_answer(
app_model=app, user=account, message_id=message.id, invoke_from=InvokeFrom.SERVICE_API
)
def test_get_suggested_questions_after_answer_no_workflow(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test getting suggested questions when no workflow exists.
"""
fake = Faker()
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Create a conversation and message
conversation = self._create_test_conversation(app, account, fake)
message = self._create_test_message(app, conversation, account, fake)
# Mock no workflow
mock_external_service_dependencies["workflow_service"].return_value.get_published_workflow.return_value = None
# Get suggested questions (should return empty list)
from core.app.entities.app_invoke_entities import InvokeFrom
result = MessageService.get_suggested_questions_after_answer(
app_model=app, user=account, message_id=message.id, invoke_from=InvokeFrom.SERVICE_API
)
# Verify empty result
assert result == []
def test_get_suggested_questions_after_answer_debugger_mode(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test getting suggested questions in debugger mode.
"""
fake = Faker()
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Create a conversation and message
conversation = self._create_test_conversation(app, account, fake)
message = self._create_test_message(app, conversation, account, fake)
# Mock questions
mock_questions = ["Debug question 1", "Debug question 2"]
mock_external_service_dependencies[
"llm_generator"
].generate_suggested_questions_after_answer.return_value = mock_questions
# Get suggested questions in debugger mode
from core.app.entities.app_invoke_entities import InvokeFrom
result = MessageService.get_suggested_questions_after_answer(
app_model=app, user=account, message_id=message.id, invoke_from=InvokeFrom.DEBUGGER
)
# Verify results
assert result == mock_questions
# Verify draft workflow was used instead of published workflow
mock_external_service_dependencies["workflow_service"].return_value.get_draft_workflow.assert_called_once_with(
app_model=app
)
# Verify TraceQueueManager was called
mock_external_service_dependencies["trace_manager_instance"].add_trace_task.assert_called_once()

View File

@ -12,6 +12,10 @@ from services.workflow_draft_variable_service import (
)
def _get_random_variable_name(fake: Faker):
return "".join(fake.random_letters(length=10))
class TestWorkflowDraftVariableService:
"""
Comprehensive integration tests for WorkflowDraftVariableService using testcontainers.
@ -112,7 +116,14 @@ class TestWorkflowDraftVariableService:
return workflow
def _create_test_variable(
self, db_session_with_containers, app_id, node_id, name, value, variable_type="conversation", fake=None
self,
db_session_with_containers,
app_id,
node_id,
name,
value,
variable_type: DraftVariableType = DraftVariableType.CONVERSATION,
fake=None,
):
"""
Helper method to create a test workflow draft variable with proper configuration.
@ -227,7 +238,13 @@ class TestWorkflowDraftVariableService:
db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, "var2", var2_value, fake=fake
)
var3 = self._create_test_variable(
db_session_with_containers, app.id, "test_node_1", "var3", var3_value, "node", fake=fake
db_session_with_containers,
app.id,
"test_node_1",
"var3",
var3_value,
variable_type=DraftVariableType.NODE,
fake=fake,
)
selectors = [
[CONVERSATION_VARIABLE_NODE_ID, "var1"],
@ -263,9 +280,14 @@ class TestWorkflowDraftVariableService:
fake = Faker()
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
for i in range(5):
test_value = StringSegment(value=fake.numerify("value##"))
test_value = StringSegment(value=fake.numerify("value######"))
self._create_test_variable(
db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, fake.word(), test_value, fake=fake
db_session_with_containers,
app.id,
CONVERSATION_VARIABLE_NODE_ID,
_get_random_variable_name(fake),
test_value,
fake=fake,
)
service = WorkflowDraftVariableService(db_session_with_containers)
result = service.list_variables_without_values(app.id, page=1, limit=3)
@ -291,10 +313,32 @@ class TestWorkflowDraftVariableService:
var1_value = StringSegment(value=fake.word())
var2_value = StringSegment(value=fake.word())
var3_value = StringSegment(value=fake.word())
self._create_test_variable(db_session_with_containers, app.id, node_id, "var1", var1_value, "node", fake=fake)
self._create_test_variable(db_session_with_containers, app.id, node_id, "var2", var3_value, "node", fake=fake)
self._create_test_variable(
db_session_with_containers, app.id, "other_node", "var3", var2_value, "node", fake=fake
db_session_with_containers,
app.id,
node_id,
"var1",
var1_value,
variable_type=DraftVariableType.NODE,
fake=fake,
)
self._create_test_variable(
db_session_with_containers,
app.id,
node_id,
"var2",
var3_value,
variable_type=DraftVariableType.NODE,
fake=fake,
)
self._create_test_variable(
db_session_with_containers,
app.id,
"other_node",
"var3",
var2_value,
variable_type=DraftVariableType.NODE,
fake=fake,
)
service = WorkflowDraftVariableService(db_session_with_containers)
result = service.list_node_variables(app.id, node_id)
@ -328,7 +372,13 @@ class TestWorkflowDraftVariableService:
)
sys_var_value = StringSegment(value=fake.word())
self._create_test_variable(
db_session_with_containers, app.id, SYSTEM_VARIABLE_NODE_ID, "sys_var", sys_var_value, "system", fake=fake
db_session_with_containers,
app.id,
SYSTEM_VARIABLE_NODE_ID,
"sys_var",
sys_var_value,
variable_type=DraftVariableType.SYS,
fake=fake,
)
service = WorkflowDraftVariableService(db_session_with_containers)
result = service.list_conversation_variables(app.id)
@ -480,14 +530,24 @@ class TestWorkflowDraftVariableService:
fake = Faker()
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
for i in range(3):
test_value = StringSegment(value=fake.numerify("value##"))
test_value = StringSegment(value=fake.numerify("value######"))
self._create_test_variable(
db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, fake.word(), test_value, fake=fake
db_session_with_containers,
app.id,
CONVERSATION_VARIABLE_NODE_ID,
_get_random_variable_name(fake),
test_value,
fake=fake,
)
other_app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
other_value = StringSegment(value=fake.word())
self._create_test_variable(
db_session_with_containers, other_app.id, CONVERSATION_VARIABLE_NODE_ID, fake.word(), other_value, fake=fake
db_session_with_containers,
other_app.id,
CONVERSATION_VARIABLE_NODE_ID,
_get_random_variable_name(fake),
other_value,
fake=fake,
)
from extensions.ext_database import db
@ -515,17 +575,34 @@ class TestWorkflowDraftVariableService:
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
node_id = fake.word()
for i in range(2):
test_value = StringSegment(value=fake.numerify("node_value##"))
test_value = StringSegment(value=fake.numerify("node_value######"))
self._create_test_variable(
db_session_with_containers, app.id, node_id, fake.word(), test_value, "node", fake=fake
db_session_with_containers,
app.id,
node_id,
_get_random_variable_name(fake),
test_value,
variable_type=DraftVariableType.NODE,
fake=fake,
)
other_node_value = StringSegment(value=fake.word())
self._create_test_variable(
db_session_with_containers, app.id, "other_node", fake.word(), other_node_value, "node", fake=fake
db_session_with_containers,
app.id,
"other_node",
_get_random_variable_name(fake),
other_node_value,
variable_type=DraftVariableType.NODE,
fake=fake,
)
conv_value = StringSegment(value=fake.word())
self._create_test_variable(
db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, fake.word(), conv_value, fake=fake
db_session_with_containers,
app.id,
CONVERSATION_VARIABLE_NODE_ID,
_get_random_variable_name(fake),
conv_value,
fake=fake,
)
from extensions.ext_database import db
@ -627,7 +704,7 @@ class TestWorkflowDraftVariableService:
SYSTEM_VARIABLE_NODE_ID,
"conversation_id",
conv_id_value,
"system",
variable_type=DraftVariableType.SYS,
fake=fake,
)
service = WorkflowDraftVariableService(db_session_with_containers)
@ -664,10 +741,22 @@ class TestWorkflowDraftVariableService:
sys_var1_value = StringSegment(value=fake.word())
sys_var2_value = StringSegment(value=fake.word())
sys_var1 = self._create_test_variable(
db_session_with_containers, app.id, SYSTEM_VARIABLE_NODE_ID, "sys_var1", sys_var1_value, "system", fake=fake
db_session_with_containers,
app.id,
SYSTEM_VARIABLE_NODE_ID,
"sys_var1",
sys_var1_value,
variable_type=DraftVariableType.SYS,
fake=fake,
)
sys_var2 = self._create_test_variable(
db_session_with_containers, app.id, SYSTEM_VARIABLE_NODE_ID, "sys_var2", sys_var2_value, "system", fake=fake
db_session_with_containers,
app.id,
SYSTEM_VARIABLE_NODE_ID,
"sys_var2",
sys_var2_value,
variable_type=DraftVariableType.SYS,
fake=fake,
)
conv_var_value = StringSegment(value=fake.word())
self._create_test_variable(
@ -701,10 +790,22 @@ class TestWorkflowDraftVariableService:
db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, "test_conv_var", test_value, fake=fake
)
sys_var = self._create_test_variable(
db_session_with_containers, app.id, SYSTEM_VARIABLE_NODE_ID, "test_sys_var", test_value, "system", fake=fake
db_session_with_containers,
app.id,
SYSTEM_VARIABLE_NODE_ID,
"test_sys_var",
test_value,
variable_type=DraftVariableType.SYS,
fake=fake,
)
node_var = self._create_test_variable(
db_session_with_containers, app.id, "test_node", "test_node_var", test_value, "node", fake=fake
db_session_with_containers,
app.id,
"test_node",
"test_node_var",
test_value,
variable_type=DraftVariableType.NODE,
fake=fake,
)
service = WorkflowDraftVariableService(db_session_with_containers)
retrieved_conv_var = service.get_conversation_variable(app.id, "test_conv_var")

View File

@ -0,0 +1,124 @@
import time
from unittest.mock import MagicMock, patch
import pytest
from core.app.features.rate_limiting.rate_limit import RateLimit
@pytest.fixture
def mock_redis():
"""Mock Redis client with realistic behavior for rate limiting tests."""
mock_client = MagicMock()
# Redis data storage for simulation
mock_data = {}
mock_hashes = {}
mock_expiry = {}
def mock_setex(key, ttl, value):
mock_data[key] = str(value)
mock_expiry[key] = time.time() + ttl.total_seconds() if hasattr(ttl, "total_seconds") else time.time() + ttl
return True
def mock_get(key):
if key in mock_data and (key not in mock_expiry or time.time() < mock_expiry[key]):
return mock_data[key].encode("utf-8")
return None
def mock_exists(key):
return key in mock_data or key in mock_hashes
def mock_expire(key, ttl):
if key in mock_data or key in mock_hashes:
mock_expiry[key] = time.time() + ttl.total_seconds() if hasattr(ttl, "total_seconds") else time.time() + ttl
return True
def mock_hset(key, field, value):
if key not in mock_hashes:
mock_hashes[key] = {}
mock_hashes[key][field] = str(value).encode("utf-8")
return True
def mock_hgetall(key):
return mock_hashes.get(key, {})
def mock_hdel(key, *fields):
if key in mock_hashes:
count = 0
for field in fields:
if field in mock_hashes[key]:
del mock_hashes[key][field]
count += 1
return count
return 0
def mock_hlen(key):
return len(mock_hashes.get(key, {}))
# Configure mock methods
mock_client.setex = mock_setex
mock_client.get = mock_get
mock_client.exists = mock_exists
mock_client.expire = mock_expire
mock_client.hset = mock_hset
mock_client.hgetall = mock_hgetall
mock_client.hdel = mock_hdel
mock_client.hlen = mock_hlen
# Store references for test verification
mock_client._mock_data = mock_data
mock_client._mock_hashes = mock_hashes
mock_client._mock_expiry = mock_expiry
return mock_client
@pytest.fixture
def mock_time():
"""Mock time.time() for deterministic tests."""
mock_time_val = 1000.0
def increment_time(seconds=1):
nonlocal mock_time_val
mock_time_val += seconds
return mock_time_val
with patch("time.time", return_value=mock_time_val) as mock:
mock.increment = increment_time
yield mock
@pytest.fixture
def sample_generator():
"""Sample generator for testing RateLimitGenerator."""
def _create_generator(items=None, raise_error=False):
items = items or ["item1", "item2", "item3"]
for item in items:
if raise_error and item == "item2":
raise ValueError("Test error")
yield item
return _create_generator
@pytest.fixture
def sample_mapping():
"""Sample mapping for testing RateLimitGenerator."""
return {"key1": "value1", "key2": "value2"}
@pytest.fixture(autouse=True)
def reset_rate_limit_instances():
"""Clear RateLimit singleton instances between tests."""
RateLimit._instance_dict.clear()
yield
RateLimit._instance_dict.clear()
@pytest.fixture
def redis_patch():
"""Patch redis_client globally for rate limit tests."""
with patch("core.app.features.rate_limiting.rate_limit.redis_client") as mock:
yield mock

View File

@ -0,0 +1,569 @@
import threading
import time
from datetime import timedelta
from unittest.mock import patch
import pytest
from core.app.features.rate_limiting.rate_limit import RateLimit
from core.errors.error import AppInvokeQuotaExceededError
class TestRateLimit:
"""Core rate limiting functionality tests."""
def test_should_return_same_instance_for_same_client_id(self, redis_patch):
"""Test singleton behavior for same client ID."""
redis_patch.configure_mock(
**{
"exists.return_value": False,
"setex.return_value": True,
}
)
rate_limit1 = RateLimit("client1", 5)
rate_limit2 = RateLimit("client1", 10) # Second instance with different limit
assert rate_limit1 is rate_limit2
# Current implementation: last constructor call overwrites max_active_requests
# This reflects the actual behavior where __init__ always sets max_active_requests
assert rate_limit1.max_active_requests == 10
def test_should_create_different_instances_for_different_client_ids(self, redis_patch):
"""Test different instances for different client IDs."""
redis_patch.configure_mock(
**{
"exists.return_value": False,
"setex.return_value": True,
}
)
rate_limit1 = RateLimit("client1", 5)
rate_limit2 = RateLimit("client2", 10)
assert rate_limit1 is not rate_limit2
assert rate_limit1.client_id == "client1"
assert rate_limit2.client_id == "client2"
def test_should_initialize_with_valid_parameters(self, redis_patch):
"""Test normal initialization."""
redis_patch.configure_mock(
**{
"exists.return_value": False,
"setex.return_value": True,
}
)
rate_limit = RateLimit("test_client", 5)
assert rate_limit.client_id == "test_client"
assert rate_limit.max_active_requests == 5
assert hasattr(rate_limit, "initialized")
redis_patch.setex.assert_called_once()
def test_should_skip_initialization_if_disabled(self):
"""Test no initialization when rate limiting is disabled."""
rate_limit = RateLimit("test_client", 0)
assert rate_limit.disabled()
assert not hasattr(rate_limit, "initialized")
def test_should_skip_reinitialization_of_existing_instance(self, redis_patch):
"""Test that existing instance doesn't reinitialize."""
redis_patch.configure_mock(
**{
"exists.return_value": False,
"setex.return_value": True,
}
)
RateLimit("client1", 5)
redis_patch.reset_mock()
RateLimit("client1", 10)
redis_patch.setex.assert_not_called()
def test_should_be_disabled_when_max_requests_is_zero_or_negative(self):
"""Test disabled state for zero or negative limits."""
rate_limit_zero = RateLimit("client1", 0)
rate_limit_negative = RateLimit("client2", -5)
assert rate_limit_zero.disabled()
assert rate_limit_negative.disabled()
def test_should_set_redis_keys_on_first_flush(self, redis_patch):
"""Test Redis keys are set correctly on initial flush."""
redis_patch.configure_mock(
**{
"exists.return_value": False,
"setex.return_value": True,
}
)
rate_limit = RateLimit("test_client", 5)
expected_max_key = "dify:rate_limit:test_client:max_active_requests"
redis_patch.setex.assert_called_with(expected_max_key, timedelta(days=1), 5)
def test_should_sync_max_requests_from_redis_on_subsequent_flush(self, redis_patch):
"""Test max requests syncs from Redis when key exists."""
redis_patch.configure_mock(
**{
"exists.return_value": True,
"get.return_value": b"10",
"expire.return_value": True,
}
)
rate_limit = RateLimit("test_client", 5)
rate_limit.flush_cache()
assert rate_limit.max_active_requests == 10
@patch("time.time")
def test_should_clean_timeout_requests_from_active_list(self, mock_time, redis_patch):
"""Test cleanup of timed-out requests."""
current_time = 1000.0
mock_time.return_value = current_time
# Setup mock Redis with timed-out requests
timeout_requests = {
b"req1": str(current_time - 700).encode(), # 700 seconds ago (timeout)
b"req2": str(current_time - 100).encode(), # 100 seconds ago (active)
}
redis_patch.configure_mock(
**{
"exists.return_value": True,
"get.return_value": b"5",
"expire.return_value": True,
"hgetall.return_value": timeout_requests,
"hdel.return_value": 1,
}
)
rate_limit = RateLimit("test_client", 5)
redis_patch.reset_mock() # Reset to avoid counting initialization calls
rate_limit.flush_cache()
# Verify timeout request was cleaned up
redis_patch.hdel.assert_called_once()
call_args = redis_patch.hdel.call_args[0]
assert call_args[0] == "dify:rate_limit:test_client:active_requests"
assert b"req1" in call_args # Timeout request should be removed
assert b"req2" not in call_args # Active request should remain
class TestRateLimitEnterExit:
"""Rate limiting enter/exit logic tests."""
def test_should_allow_request_within_limit(self, redis_patch):
"""Test allowing requests within the rate limit."""
redis_patch.configure_mock(
**{
"exists.return_value": False,
"setex.return_value": True,
"hlen.return_value": 2,
"hset.return_value": True,
}
)
rate_limit = RateLimit("test_client", 5)
request_id = rate_limit.enter()
assert request_id != RateLimit._UNLIMITED_REQUEST_ID
redis_patch.hset.assert_called_once()
def test_should_generate_request_id_if_not_provided(self, redis_patch):
"""Test auto-generation of request ID."""
redis_patch.configure_mock(
**{
"exists.return_value": False,
"setex.return_value": True,
"hlen.return_value": 0,
"hset.return_value": True,
}
)
rate_limit = RateLimit("test_client", 5)
request_id = rate_limit.enter()
assert len(request_id) == 36 # UUID format
def test_should_use_provided_request_id(self, redis_patch):
"""Test using provided request ID."""
redis_patch.configure_mock(
**{
"exists.return_value": False,
"setex.return_value": True,
"hlen.return_value": 0,
"hset.return_value": True,
}
)
rate_limit = RateLimit("test_client", 5)
custom_id = "custom_request_123"
request_id = rate_limit.enter(custom_id)
assert request_id == custom_id
def test_should_remove_request_on_exit(self, redis_patch):
"""Test request removal on exit."""
redis_patch.configure_mock(
**{
"hdel.return_value": 1,
}
)
rate_limit = RateLimit("test_client", 5)
rate_limit.exit("test_request_id")
redis_patch.hdel.assert_called_once_with("dify:rate_limit:test_client:active_requests", "test_request_id")
def test_should_raise_quota_exceeded_when_at_limit(self, redis_patch):
"""Test quota exceeded error when at limit."""
redis_patch.configure_mock(
**{
"exists.return_value": False,
"setex.return_value": True,
"hlen.return_value": 5, # At limit
}
)
rate_limit = RateLimit("test_client", 5)
with pytest.raises(AppInvokeQuotaExceededError) as exc_info:
rate_limit.enter()
assert "Too many requests" in str(exc_info.value)
assert "test_client" in str(exc_info.value)
def test_should_allow_request_after_previous_exit(self, redis_patch):
"""Test allowing new request after previous exit."""
redis_patch.configure_mock(
**{
"exists.return_value": False,
"setex.return_value": True,
"hlen.return_value": 4, # Under limit after exit
"hset.return_value": True,
"hdel.return_value": 1,
}
)
rate_limit = RateLimit("test_client", 5)
request_id = rate_limit.enter()
rate_limit.exit(request_id)
new_request_id = rate_limit.enter()
assert new_request_id is not None
@patch("time.time")
def test_should_flush_cache_when_interval_exceeded(self, mock_time, redis_patch):
"""Test cache flush when time interval exceeded."""
redis_patch.configure_mock(
**{
"exists.return_value": False,
"setex.return_value": True,
"hlen.return_value": 0,
}
)
mock_time.return_value = 1000.0
rate_limit = RateLimit("test_client", 5)
# Advance time beyond flush interval
mock_time.return_value = 1400.0 # 400 seconds later
redis_patch.reset_mock()
rate_limit.enter()
# Should have called setex again due to cache flush
redis_patch.setex.assert_called()
def test_should_return_unlimited_id_when_disabled(self):
"""Test unlimited ID return when rate limiting disabled."""
rate_limit = RateLimit("test_client", 0)
request_id = rate_limit.enter()
assert request_id == RateLimit._UNLIMITED_REQUEST_ID
def test_should_ignore_exit_for_unlimited_requests(self, redis_patch):
"""Test ignoring exit for unlimited requests."""
rate_limit = RateLimit("test_client", 0)
rate_limit.exit(RateLimit._UNLIMITED_REQUEST_ID)
redis_patch.hdel.assert_not_called()
class TestRateLimitGenerator:
"""Rate limit generator wrapper tests."""
def test_should_wrap_generator_and_iterate_normally(self, redis_patch, sample_generator):
"""Test normal generator iteration with rate limit wrapper."""
redis_patch.configure_mock(
**{
"exists.return_value": False,
"setex.return_value": True,
"hdel.return_value": 1,
}
)
rate_limit = RateLimit("test_client", 5)
generator = sample_generator()
request_id = "test_request"
wrapped_gen = rate_limit.generate(generator, request_id)
result = list(wrapped_gen)
assert result == ["item1", "item2", "item3"]
redis_patch.hdel.assert_called_once_with("dify:rate_limit:test_client:active_requests", request_id)
def test_should_handle_mapping_input_directly(self, sample_mapping):
"""Test direct return of mapping input."""
rate_limit = RateLimit("test_client", 0) # Disabled
result = rate_limit.generate(sample_mapping, "test_request")
assert result is sample_mapping
def test_should_cleanup_on_exception_during_iteration(self, redis_patch, sample_generator):
"""Test cleanup when exception occurs during iteration."""
redis_patch.configure_mock(
**{
"exists.return_value": False,
"setex.return_value": True,
"hdel.return_value": 1,
}
)
rate_limit = RateLimit("test_client", 5)
generator = sample_generator(raise_error=True)
request_id = "test_request"
wrapped_gen = rate_limit.generate(generator, request_id)
with pytest.raises(ValueError):
list(wrapped_gen)
redis_patch.hdel.assert_called_once_with("dify:rate_limit:test_client:active_requests", request_id)
def test_should_cleanup_on_explicit_close(self, redis_patch, sample_generator):
"""Test cleanup on explicit generator close."""
redis_patch.configure_mock(
**{
"exists.return_value": False,
"setex.return_value": True,
"hdel.return_value": 1,
}
)
rate_limit = RateLimit("test_client", 5)
generator = sample_generator()
request_id = "test_request"
wrapped_gen = rate_limit.generate(generator, request_id)
wrapped_gen.close()
redis_patch.hdel.assert_called_once()
def test_should_handle_generator_without_close_method(self, redis_patch):
"""Test handling generator without close method."""
redis_patch.configure_mock(
**{
"exists.return_value": False,
"setex.return_value": True,
"hdel.return_value": 1,
}
)
# Create a generator-like object without close method
class SimpleGenerator:
def __init__(self):
self.items = ["test"]
self.index = 0
def __iter__(self):
return self
def __next__(self):
if self.index >= len(self.items):
raise StopIteration
item = self.items[self.index]
self.index += 1
return item
rate_limit = RateLimit("test_client", 5)
generator = SimpleGenerator()
wrapped_gen = rate_limit.generate(generator, "test_request")
wrapped_gen.close() # Should not raise error
redis_patch.hdel.assert_called_once()
def test_should_prevent_iteration_after_close(self, redis_patch, sample_generator):
"""Test StopIteration after generator is closed."""
redis_patch.configure_mock(
**{
"exists.return_value": False,
"setex.return_value": True,
"hdel.return_value": 1,
}
)
rate_limit = RateLimit("test_client", 5)
generator = sample_generator()
wrapped_gen = rate_limit.generate(generator, "test_request")
wrapped_gen.close()
with pytest.raises(StopIteration):
next(wrapped_gen)
class TestRateLimitConcurrency:
"""Concurrent access safety tests."""
def test_should_handle_concurrent_instance_creation(self, redis_patch):
"""Test thread-safe singleton instance creation."""
redis_patch.configure_mock(
**{
"exists.return_value": False,
"setex.return_value": True,
}
)
instances = []
errors = []
def create_instance():
try:
instance = RateLimit("concurrent_client", 5)
instances.append(instance)
except Exception as e:
errors.append(e)
threads = [threading.Thread(target=create_instance) for _ in range(10)]
for t in threads:
t.start()
for t in threads:
t.join()
assert len(errors) == 0
assert len({id(inst) for inst in instances}) == 1 # All same instance
def test_should_handle_concurrent_enter_requests(self, redis_patch):
"""Test concurrent enter requests handling."""
# Setup mock to simulate realistic Redis behavior
request_count = 0
def mock_hlen(key):
nonlocal request_count
return request_count
def mock_hset(key, field, value):
nonlocal request_count
request_count += 1
return True
redis_patch.configure_mock(
**{
"exists.return_value": False,
"setex.return_value": True,
"hlen.side_effect": mock_hlen,
"hset.side_effect": mock_hset,
}
)
rate_limit = RateLimit("concurrent_client", 3)
results = []
errors = []
def try_enter():
try:
request_id = rate_limit.enter()
results.append(request_id)
except AppInvokeQuotaExceededError as e:
errors.append(e)
threads = [threading.Thread(target=try_enter) for _ in range(5)]
for t in threads:
t.start()
for t in threads:
t.join()
# Should have some successful requests and some quota exceeded
assert len(results) + len(errors) == 5
assert len(errors) > 0 # Some should be rejected
@patch("time.time")
def test_should_maintain_accurate_count_under_load(self, mock_time, redis_patch):
"""Test accurate count maintenance under concurrent load."""
mock_time.return_value = 1000.0
# Use real mock_redis fixture for better simulation
mock_client = self._create_mock_redis()
redis_patch.configure_mock(**mock_client)
rate_limit = RateLimit("load_test_client", 10)
active_requests = []
def enter_and_exit():
try:
request_id = rate_limit.enter()
active_requests.append(request_id)
time.sleep(0.01) # Simulate some work
rate_limit.exit(request_id)
active_requests.remove(request_id)
except AppInvokeQuotaExceededError:
pass # Expected under load
threads = [threading.Thread(target=enter_and_exit) for _ in range(20)]
for t in threads:
t.start()
for t in threads:
t.join()
# All requests should have been cleaned up
assert len(active_requests) == 0
def _create_mock_redis(self):
"""Create a thread-safe mock Redis for concurrency tests."""
import threading
lock = threading.Lock()
data = {}
hashes = {}
def mock_hlen(key):
with lock:
return len(hashes.get(key, {}))
def mock_hset(key, field, value):
with lock:
if key not in hashes:
hashes[key] = {}
hashes[key][field] = str(value).encode("utf-8")
return True
def mock_hdel(key, *fields):
with lock:
if key in hashes:
count = 0
for field in fields:
if field in hashes[key]:
del hashes[key][field]
count += 1
return count
return 0
return {
"exists.return_value": False,
"setex.return_value": True,
"hlen.side_effect": mock_hlen,
"hset.side_effect": mock_hset,
"hdel.side_effect": mock_hdel,
}

View File

@ -0,0 +1,247 @@
"""
Unit tests for CeleryWorkflowExecutionRepository.
These tests verify the Celery-based asynchronous storage functionality
for workflow execution data.
"""
from datetime import UTC, datetime
from unittest.mock import Mock, patch
from uuid import uuid4
import pytest
from core.repositories.celery_workflow_execution_repository import CeleryWorkflowExecutionRepository
from core.workflow.entities.workflow_execution import WorkflowExecution, WorkflowType
from models import Account, EndUser
from models.enums import WorkflowRunTriggeredFrom
@pytest.fixture
def mock_session_factory():
"""Mock SQLAlchemy session factory."""
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
# Create a real sessionmaker with in-memory SQLite for testing
engine = create_engine("sqlite:///:memory:")
return sessionmaker(bind=engine)
@pytest.fixture
def mock_account():
"""Mock Account user."""
account = Mock(spec=Account)
account.id = str(uuid4())
account.current_tenant_id = str(uuid4())
return account
@pytest.fixture
def mock_end_user():
"""Mock EndUser."""
user = Mock(spec=EndUser)
user.id = str(uuid4())
user.tenant_id = str(uuid4())
return user
@pytest.fixture
def sample_workflow_execution():
"""Sample WorkflowExecution for testing."""
return WorkflowExecution.new(
id_=str(uuid4()),
workflow_id=str(uuid4()),
workflow_type=WorkflowType.WORKFLOW,
workflow_version="1.0",
graph={"nodes": [], "edges": []},
inputs={"input1": "value1"},
started_at=datetime.now(UTC).replace(tzinfo=None),
)
class TestCeleryWorkflowExecutionRepository:
"""Test cases for CeleryWorkflowExecutionRepository."""
def test_init_with_sessionmaker(self, mock_session_factory, mock_account):
"""Test repository initialization with sessionmaker."""
app_id = "test-app-id"
triggered_from = WorkflowRunTriggeredFrom.APP_RUN
repo = CeleryWorkflowExecutionRepository(
session_factory=mock_session_factory,
user=mock_account,
app_id=app_id,
triggered_from=triggered_from,
)
assert repo._tenant_id == mock_account.current_tenant_id
assert repo._app_id == app_id
assert repo._triggered_from == triggered_from
assert repo._creator_user_id == mock_account.id
assert repo._creator_user_role is not None
def test_init_basic_functionality(self, mock_session_factory, mock_account):
"""Test repository initialization basic functionality."""
repo = CeleryWorkflowExecutionRepository(
session_factory=mock_session_factory,
user=mock_account,
app_id="test-app",
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
)
# Verify basic initialization
assert repo._tenant_id == mock_account.current_tenant_id
assert repo._app_id == "test-app"
assert repo._triggered_from == WorkflowRunTriggeredFrom.DEBUGGING
def test_init_with_end_user(self, mock_session_factory, mock_end_user):
"""Test repository initialization with EndUser."""
repo = CeleryWorkflowExecutionRepository(
session_factory=mock_session_factory,
user=mock_end_user,
app_id="test-app",
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
)
assert repo._tenant_id == mock_end_user.tenant_id
def test_init_without_tenant_id_raises_error(self, mock_session_factory):
"""Test that initialization fails without tenant_id."""
# Create a mock Account with no tenant_id
user = Mock(spec=Account)
user.current_tenant_id = None
user.id = str(uuid4())
with pytest.raises(ValueError, match="User must have a tenant_id"):
CeleryWorkflowExecutionRepository(
session_factory=mock_session_factory,
user=user,
app_id="test-app",
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
)
@patch("core.repositories.celery_workflow_execution_repository.save_workflow_execution_task")
def test_save_queues_celery_task(self, mock_task, mock_session_factory, mock_account, sample_workflow_execution):
"""Test that save operation queues a Celery task without tracking."""
repo = CeleryWorkflowExecutionRepository(
session_factory=mock_session_factory,
user=mock_account,
app_id="test-app",
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
)
repo.save(sample_workflow_execution)
# Verify Celery task was queued with correct parameters
mock_task.delay.assert_called_once()
call_args = mock_task.delay.call_args[1]
assert call_args["execution_data"] == sample_workflow_execution.model_dump()
assert call_args["tenant_id"] == mock_account.current_tenant_id
assert call_args["app_id"] == "test-app"
assert call_args["triggered_from"] == WorkflowRunTriggeredFrom.APP_RUN.value
assert call_args["creator_user_id"] == mock_account.id
# Verify no task tracking occurs (no _pending_saves attribute)
assert not hasattr(repo, "_pending_saves")
@patch("core.repositories.celery_workflow_execution_repository.save_workflow_execution_task")
def test_save_handles_celery_failure(
self, mock_task, mock_session_factory, mock_account, sample_workflow_execution
):
"""Test that save operation handles Celery task failures."""
mock_task.delay.side_effect = Exception("Celery is down")
repo = CeleryWorkflowExecutionRepository(
session_factory=mock_session_factory,
user=mock_account,
app_id="test-app",
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
)
with pytest.raises(Exception, match="Celery is down"):
repo.save(sample_workflow_execution)
@patch("core.repositories.celery_workflow_execution_repository.save_workflow_execution_task")
def test_save_operation_fire_and_forget(
self, mock_task, mock_session_factory, mock_account, sample_workflow_execution
):
"""Test that save operation works in fire-and-forget mode."""
repo = CeleryWorkflowExecutionRepository(
session_factory=mock_session_factory,
user=mock_account,
app_id="test-app",
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
)
# Test that save doesn't block or maintain state
repo.save(sample_workflow_execution)
# Verify no pending saves are tracked (no _pending_saves attribute)
assert not hasattr(repo, "_pending_saves")
@patch("core.repositories.celery_workflow_execution_repository.save_workflow_execution_task")
def test_multiple_save_operations(self, mock_task, mock_session_factory, mock_account):
"""Test multiple save operations work correctly."""
repo = CeleryWorkflowExecutionRepository(
session_factory=mock_session_factory,
user=mock_account,
app_id="test-app",
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
)
# Create multiple executions
exec1 = WorkflowExecution.new(
id_=str(uuid4()),
workflow_id=str(uuid4()),
workflow_type=WorkflowType.WORKFLOW,
workflow_version="1.0",
graph={"nodes": [], "edges": []},
inputs={"input1": "value1"},
started_at=datetime.now(UTC).replace(tzinfo=None),
)
exec2 = WorkflowExecution.new(
id_=str(uuid4()),
workflow_id=str(uuid4()),
workflow_type=WorkflowType.WORKFLOW,
workflow_version="1.0",
graph={"nodes": [], "edges": []},
inputs={"input2": "value2"},
started_at=datetime.now(UTC).replace(tzinfo=None),
)
# Save both executions
repo.save(exec1)
repo.save(exec2)
# Should work without issues and not maintain state (no _pending_saves attribute)
assert not hasattr(repo, "_pending_saves")
@patch("core.repositories.celery_workflow_execution_repository.save_workflow_execution_task")
def test_save_with_different_user_types(self, mock_task, mock_session_factory, mock_end_user):
"""Test save operation with different user types."""
repo = CeleryWorkflowExecutionRepository(
session_factory=mock_session_factory,
user=mock_end_user,
app_id="test-app",
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
)
execution = WorkflowExecution.new(
id_=str(uuid4()),
workflow_id=str(uuid4()),
workflow_type=WorkflowType.WORKFLOW,
workflow_version="1.0",
graph={"nodes": [], "edges": []},
inputs={"input1": "value1"},
started_at=datetime.now(UTC).replace(tzinfo=None),
)
repo.save(execution)
# Verify task was called with EndUser context
mock_task.delay.assert_called_once()
call_args = mock_task.delay.call_args[1]
assert call_args["tenant_id"] == mock_end_user.tenant_id
assert call_args["creator_user_id"] == mock_end_user.id

View File

@ -0,0 +1,349 @@
"""
Unit tests for CeleryWorkflowNodeExecutionRepository.
These tests verify the Celery-based asynchronous storage functionality
for workflow node execution data.
"""
from datetime import UTC, datetime
from unittest.mock import Mock, patch
from uuid import uuid4
import pytest
from core.repositories.celery_workflow_node_execution_repository import CeleryWorkflowNodeExecutionRepository
from core.workflow.entities.workflow_node_execution import (
WorkflowNodeExecution,
WorkflowNodeExecutionStatus,
)
from core.workflow.nodes.enums import NodeType
from core.workflow.repositories.workflow_node_execution_repository import OrderConfig
from models import Account, EndUser
from models.workflow import WorkflowNodeExecutionTriggeredFrom
@pytest.fixture
def mock_session_factory():
"""Mock SQLAlchemy session factory."""
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
# Create a real sessionmaker with in-memory SQLite for testing
engine = create_engine("sqlite:///:memory:")
return sessionmaker(bind=engine)
@pytest.fixture
def mock_account():
"""Mock Account user."""
account = Mock(spec=Account)
account.id = str(uuid4())
account.current_tenant_id = str(uuid4())
return account
@pytest.fixture
def mock_end_user():
"""Mock EndUser."""
user = Mock(spec=EndUser)
user.id = str(uuid4())
user.tenant_id = str(uuid4())
return user
@pytest.fixture
def sample_workflow_node_execution():
"""Sample WorkflowNodeExecution for testing."""
return WorkflowNodeExecution(
id=str(uuid4()),
node_execution_id=str(uuid4()),
workflow_id=str(uuid4()),
workflow_execution_id=str(uuid4()),
index=1,
node_id="test_node",
node_type=NodeType.START,
title="Test Node",
inputs={"input1": "value1"},
status=WorkflowNodeExecutionStatus.RUNNING,
created_at=datetime.now(UTC).replace(tzinfo=None),
)
class TestCeleryWorkflowNodeExecutionRepository:
"""Test cases for CeleryWorkflowNodeExecutionRepository."""
def test_init_with_sessionmaker(self, mock_session_factory, mock_account):
"""Test repository initialization with sessionmaker."""
app_id = "test-app-id"
triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN
repo = CeleryWorkflowNodeExecutionRepository(
session_factory=mock_session_factory,
user=mock_account,
app_id=app_id,
triggered_from=triggered_from,
)
assert repo._tenant_id == mock_account.current_tenant_id
assert repo._app_id == app_id
assert repo._triggered_from == triggered_from
assert repo._creator_user_id == mock_account.id
assert repo._creator_user_role is not None
def test_init_with_cache_initialized(self, mock_session_factory, mock_account):
"""Test repository initialization with cache properly initialized."""
repo = CeleryWorkflowNodeExecutionRepository(
session_factory=mock_session_factory,
user=mock_account,
app_id="test-app",
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
)
assert repo._execution_cache == {}
assert repo._workflow_execution_mapping == {}
def test_init_with_end_user(self, mock_session_factory, mock_end_user):
"""Test repository initialization with EndUser."""
repo = CeleryWorkflowNodeExecutionRepository(
session_factory=mock_session_factory,
user=mock_end_user,
app_id="test-app",
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
assert repo._tenant_id == mock_end_user.tenant_id
def test_init_without_tenant_id_raises_error(self, mock_session_factory):
"""Test that initialization fails without tenant_id."""
# Create a mock Account with no tenant_id
user = Mock(spec=Account)
user.current_tenant_id = None
user.id = str(uuid4())
with pytest.raises(ValueError, match="User must have a tenant_id"):
CeleryWorkflowNodeExecutionRepository(
session_factory=mock_session_factory,
user=user,
app_id="test-app",
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
@patch("core.repositories.celery_workflow_node_execution_repository.save_workflow_node_execution_task")
def test_save_caches_and_queues_celery_task(
self, mock_task, mock_session_factory, mock_account, sample_workflow_node_execution
):
"""Test that save operation caches execution and queues a Celery task."""
repo = CeleryWorkflowNodeExecutionRepository(
session_factory=mock_session_factory,
user=mock_account,
app_id="test-app",
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
repo.save(sample_workflow_node_execution)
# Verify Celery task was queued with correct parameters
mock_task.delay.assert_called_once()
call_args = mock_task.delay.call_args[1]
assert call_args["execution_data"] == sample_workflow_node_execution.model_dump()
assert call_args["tenant_id"] == mock_account.current_tenant_id
assert call_args["app_id"] == "test-app"
assert call_args["triggered_from"] == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value
assert call_args["creator_user_id"] == mock_account.id
# Verify execution is cached
assert sample_workflow_node_execution.id in repo._execution_cache
assert repo._execution_cache[sample_workflow_node_execution.id] == sample_workflow_node_execution
# Verify workflow execution mapping is updated
assert sample_workflow_node_execution.workflow_execution_id in repo._workflow_execution_mapping
assert (
sample_workflow_node_execution.id
in repo._workflow_execution_mapping[sample_workflow_node_execution.workflow_execution_id]
)
@patch("core.repositories.celery_workflow_node_execution_repository.save_workflow_node_execution_task")
def test_save_handles_celery_failure(
self, mock_task, mock_session_factory, mock_account, sample_workflow_node_execution
):
"""Test that save operation handles Celery task failures."""
mock_task.delay.side_effect = Exception("Celery is down")
repo = CeleryWorkflowNodeExecutionRepository(
session_factory=mock_session_factory,
user=mock_account,
app_id="test-app",
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
with pytest.raises(Exception, match="Celery is down"):
repo.save(sample_workflow_node_execution)
@patch("core.repositories.celery_workflow_node_execution_repository.save_workflow_node_execution_task")
def test_get_by_workflow_run_from_cache(
self, mock_task, mock_session_factory, mock_account, sample_workflow_node_execution
):
"""Test that get_by_workflow_run retrieves executions from cache."""
repo = CeleryWorkflowNodeExecutionRepository(
session_factory=mock_session_factory,
user=mock_account,
app_id="test-app",
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
# Save execution to cache first
repo.save(sample_workflow_node_execution)
workflow_run_id = sample_workflow_node_execution.workflow_execution_id
order_config = OrderConfig(order_by=["index"], order_direction="asc")
result = repo.get_by_workflow_run(workflow_run_id, order_config)
# Verify results were retrieved from cache
assert len(result) == 1
assert result[0].id == sample_workflow_node_execution.id
assert result[0] is sample_workflow_node_execution
def test_get_by_workflow_run_without_order_config(self, mock_session_factory, mock_account):
"""Test get_by_workflow_run without order configuration."""
repo = CeleryWorkflowNodeExecutionRepository(
session_factory=mock_session_factory,
user=mock_account,
app_id="test-app",
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
result = repo.get_by_workflow_run("workflow-run-id")
# Should return empty list since nothing in cache
assert len(result) == 0
@patch("core.repositories.celery_workflow_node_execution_repository.save_workflow_node_execution_task")
def test_cache_operations(self, mock_task, mock_session_factory, mock_account, sample_workflow_node_execution):
"""Test cache operations work correctly."""
repo = CeleryWorkflowNodeExecutionRepository(
session_factory=mock_session_factory,
user=mock_account,
app_id="test-app",
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
# Test saving to cache
repo.save(sample_workflow_node_execution)
# Verify cache contains the execution
assert sample_workflow_node_execution.id in repo._execution_cache
# Test retrieving from cache
result = repo.get_by_workflow_run(sample_workflow_node_execution.workflow_execution_id)
assert len(result) == 1
assert result[0].id == sample_workflow_node_execution.id
@patch("core.repositories.celery_workflow_node_execution_repository.save_workflow_node_execution_task")
def test_multiple_executions_same_workflow(self, mock_task, mock_session_factory, mock_account):
"""Test multiple executions for the same workflow."""
repo = CeleryWorkflowNodeExecutionRepository(
session_factory=mock_session_factory,
user=mock_account,
app_id="test-app",
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
# Create multiple executions for the same workflow
workflow_run_id = str(uuid4())
exec1 = WorkflowNodeExecution(
id=str(uuid4()),
node_execution_id=str(uuid4()),
workflow_id=str(uuid4()),
workflow_execution_id=workflow_run_id,
index=1,
node_id="node1",
node_type=NodeType.START,
title="Node 1",
inputs={"input1": "value1"},
status=WorkflowNodeExecutionStatus.RUNNING,
created_at=datetime.now(UTC).replace(tzinfo=None),
)
exec2 = WorkflowNodeExecution(
id=str(uuid4()),
node_execution_id=str(uuid4()),
workflow_id=str(uuid4()),
workflow_execution_id=workflow_run_id,
index=2,
node_id="node2",
node_type=NodeType.LLM,
title="Node 2",
inputs={"input2": "value2"},
status=WorkflowNodeExecutionStatus.RUNNING,
created_at=datetime.now(UTC).replace(tzinfo=None),
)
# Save both executions
repo.save(exec1)
repo.save(exec2)
# Verify both are cached and mapped
assert len(repo._execution_cache) == 2
assert len(repo._workflow_execution_mapping[workflow_run_id]) == 2
# Test retrieval
result = repo.get_by_workflow_run(workflow_run_id)
assert len(result) == 2
@patch("core.repositories.celery_workflow_node_execution_repository.save_workflow_node_execution_task")
def test_ordering_functionality(self, mock_task, mock_session_factory, mock_account):
"""Test ordering functionality works correctly."""
repo = CeleryWorkflowNodeExecutionRepository(
session_factory=mock_session_factory,
user=mock_account,
app_id="test-app",
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
# Create executions with different indices
workflow_run_id = str(uuid4())
exec1 = WorkflowNodeExecution(
id=str(uuid4()),
node_execution_id=str(uuid4()),
workflow_id=str(uuid4()),
workflow_execution_id=workflow_run_id,
index=2,
node_id="node2",
node_type=NodeType.START,
title="Node 2",
inputs={},
status=WorkflowNodeExecutionStatus.RUNNING,
created_at=datetime.now(UTC).replace(tzinfo=None),
)
exec2 = WorkflowNodeExecution(
id=str(uuid4()),
node_execution_id=str(uuid4()),
workflow_id=str(uuid4()),
workflow_execution_id=workflow_run_id,
index=1,
node_id="node1",
node_type=NodeType.LLM,
title="Node 1",
inputs={},
status=WorkflowNodeExecutionStatus.RUNNING,
created_at=datetime.now(UTC).replace(tzinfo=None),
)
# Save in random order
repo.save(exec1)
repo.save(exec2)
# Test ascending order
order_config = OrderConfig(order_by=["index"], order_direction="asc")
result = repo.get_by_workflow_run(workflow_run_id, order_config)
assert len(result) == 2
assert result[0].index == 1
assert result[1].index == 2
# Test descending order
order_config = OrderConfig(order_by=["index"], order_direction="desc")
result = repo.get_by_workflow_run(workflow_run_id, order_config)
assert len(result) == 2
assert result[0].index == 2
assert result[1].index == 1

View File

@ -59,7 +59,7 @@ class TestRepositoryFactory:
def get_by_id(self):
pass
# Create a mock interface with the same methods
# Create a mock interface class
class MockInterface:
def save(self):
pass
@ -67,20 +67,20 @@ class TestRepositoryFactory:
def get_by_id(self):
pass
# Should not raise an exception
# Should not raise an exception when all methods are present
DifyCoreRepositoryFactory._validate_repository_interface(MockRepository, MockInterface)
def test_validate_repository_interface_missing_methods(self):
"""Test interface validation with missing methods."""
# Create a mock class that doesn't implement all required methods
# Create a mock class that's missing required methods
class IncompleteRepository:
def save(self):
pass
# Missing get_by_id method
# Create a mock interface with required methods
# Create a mock interface that requires both methods
class MockInterface:
def save(self):
pass
@ -88,57 +88,39 @@ class TestRepositoryFactory:
def get_by_id(self):
pass
def missing_method(self):
pass
with pytest.raises(RepositoryImportError) as exc_info:
DifyCoreRepositoryFactory._validate_repository_interface(IncompleteRepository, MockInterface)
assert "does not implement required methods" in str(exc_info.value)
assert "get_by_id" in str(exc_info.value)
def test_validate_constructor_signature_success(self):
"""Test successful constructor signature validation."""
def test_validate_repository_interface_with_private_methods(self):
"""Test that private methods are ignored during interface validation."""
class MockRepository:
def __init__(self, session_factory, user, app_id, triggered_from):
def save(self):
pass
# Should not raise an exception
DifyCoreRepositoryFactory._validate_constructor_signature(
MockRepository, ["session_factory", "user", "app_id", "triggered_from"]
)
def test_validate_constructor_signature_missing_params(self):
"""Test constructor validation with missing parameters."""
class IncompleteRepository:
def __init__(self, session_factory, user):
# Missing app_id and triggered_from parameters
def _private_method(self):
pass
with pytest.raises(RepositoryImportError) as exc_info:
DifyCoreRepositoryFactory._validate_constructor_signature(
IncompleteRepository, ["session_factory", "user", "app_id", "triggered_from"]
)
assert "does not accept required parameters" in str(exc_info.value)
assert "app_id" in str(exc_info.value)
assert "triggered_from" in str(exc_info.value)
def test_validate_constructor_signature_inspection_error(self, mocker: MockerFixture):
"""Test constructor validation when inspection fails."""
# Mock inspect.signature to raise an exception
mocker.patch("inspect.signature", side_effect=Exception("Inspection failed"))
class MockRepository:
def __init__(self, session_factory):
# Create a mock interface with private methods
class MockInterface:
def save(self):
pass
with pytest.raises(RepositoryImportError) as exc_info:
DifyCoreRepositoryFactory._validate_constructor_signature(MockRepository, ["session_factory"])
assert "Failed to validate constructor signature" in str(exc_info.value)
def _private_method(self):
pass
# Should not raise exception - private methods should be ignored
DifyCoreRepositoryFactory._validate_repository_interface(MockRepository, MockInterface)
@patch("core.repositories.factory.dify_config")
def test_create_workflow_execution_repository_success(self, mock_config, mocker: MockerFixture):
"""Test successful creation of WorkflowExecutionRepository."""
def test_create_workflow_execution_repository_success(self, mock_config):
"""Test successful WorkflowExecutionRepository creation."""
# Setup mock configuration
mock_config.WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
mock_config.CORE_WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
# Create mock dependencies
mock_session_factory = MagicMock(spec=sessionmaker)
@ -146,7 +128,7 @@ class TestRepositoryFactory:
app_id = "test-app-id"
triggered_from = WorkflowRunTriggeredFrom.APP_RUN
# Mock the imported class to be a valid repository
# Create mock repository class and instance
mock_repository_class = MagicMock()
mock_repository_instance = MagicMock(spec=WorkflowExecutionRepository)
mock_repository_class.return_value = mock_repository_instance
@ -155,7 +137,6 @@ class TestRepositoryFactory:
with (
patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"),
patch.object(DifyCoreRepositoryFactory, "_validate_constructor_signature"),
):
result = DifyCoreRepositoryFactory.create_workflow_execution_repository(
session_factory=mock_session_factory,
@ -177,7 +158,7 @@ class TestRepositoryFactory:
def test_create_workflow_execution_repository_import_error(self, mock_config):
"""Test WorkflowExecutionRepository creation with import error."""
# Setup mock configuration with invalid class path
mock_config.WORKFLOW_EXECUTION_REPOSITORY = "invalid.module.InvalidClass"
mock_config.CORE_WORKFLOW_EXECUTION_REPOSITORY = "invalid.module.InvalidClass"
mock_session_factory = MagicMock(spec=sessionmaker)
mock_user = MagicMock(spec=Account)
@ -195,45 +176,46 @@ class TestRepositoryFactory:
def test_create_workflow_execution_repository_validation_error(self, mock_config, mocker: MockerFixture):
"""Test WorkflowExecutionRepository creation with validation error."""
# Setup mock configuration
mock_config.WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
mock_config.CORE_WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
mock_session_factory = MagicMock(spec=sessionmaker)
mock_user = MagicMock(spec=Account)
# Mock import to succeed but validation to fail
# Mock the import to succeed but validation to fail
mock_repository_class = MagicMock()
with (
patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
patch.object(
DifyCoreRepositoryFactory,
"_validate_repository_interface",
side_effect=RepositoryImportError("Interface validation failed"),
),
):
with pytest.raises(RepositoryImportError) as exc_info:
DifyCoreRepositoryFactory.create_workflow_execution_repository(
session_factory=mock_session_factory,
user=mock_user,
app_id="test-app-id",
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
)
assert "Interface validation failed" in str(exc_info.value)
mocker.patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class)
mocker.patch.object(
DifyCoreRepositoryFactory,
"_validate_repository_interface",
side_effect=RepositoryImportError("Interface validation failed"),
)
with pytest.raises(RepositoryImportError) as exc_info:
DifyCoreRepositoryFactory.create_workflow_execution_repository(
session_factory=mock_session_factory,
user=mock_user,
app_id="test-app-id",
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
)
assert "Interface validation failed" in str(exc_info.value)
@patch("core.repositories.factory.dify_config")
def test_create_workflow_execution_repository_instantiation_error(self, mock_config, mocker: MockerFixture):
def test_create_workflow_execution_repository_instantiation_error(self, mock_config):
"""Test WorkflowExecutionRepository creation with instantiation error."""
# Setup mock configuration
mock_config.WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
mock_config.CORE_WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
mock_session_factory = MagicMock(spec=sessionmaker)
mock_user = MagicMock(spec=Account)
# Mock import and validation to succeed but instantiation to fail
mock_repository_class = MagicMock(side_effect=Exception("Instantiation failed"))
# Create a mock repository class that raises exception on instantiation
mock_repository_class = MagicMock()
mock_repository_class.side_effect = Exception("Instantiation failed")
# Mock the validation methods to succeed
with (
patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"),
patch.object(DifyCoreRepositoryFactory, "_validate_constructor_signature"),
):
with pytest.raises(RepositoryImportError) as exc_info:
DifyCoreRepositoryFactory.create_workflow_execution_repository(
@ -245,18 +227,18 @@ class TestRepositoryFactory:
assert "Failed to create WorkflowExecutionRepository" in str(exc_info.value)
@patch("core.repositories.factory.dify_config")
def test_create_workflow_node_execution_repository_success(self, mock_config, mocker: MockerFixture):
"""Test successful creation of WorkflowNodeExecutionRepository."""
def test_create_workflow_node_execution_repository_success(self, mock_config):
"""Test successful WorkflowNodeExecutionRepository creation."""
# Setup mock configuration
mock_config.WORKFLOW_NODE_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
mock_config.CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
# Create mock dependencies
mock_session_factory = MagicMock(spec=sessionmaker)
mock_user = MagicMock(spec=EndUser)
app_id = "test-app-id"
triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN
triggered_from = WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP
# Mock the imported class to be a valid repository
# Create mock repository class and instance
mock_repository_class = MagicMock()
mock_repository_instance = MagicMock(spec=WorkflowNodeExecutionRepository)
mock_repository_class.return_value = mock_repository_instance
@ -265,7 +247,6 @@ class TestRepositoryFactory:
with (
patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"),
patch.object(DifyCoreRepositoryFactory, "_validate_constructor_signature"),
):
result = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=mock_session_factory,
@ -287,7 +268,7 @@ class TestRepositoryFactory:
def test_create_workflow_node_execution_repository_import_error(self, mock_config):
"""Test WorkflowNodeExecutionRepository creation with import error."""
# Setup mock configuration with invalid class path
mock_config.WORKFLOW_NODE_EXECUTION_REPOSITORY = "invalid.module.InvalidClass"
mock_config.CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY = "invalid.module.InvalidClass"
mock_session_factory = MagicMock(spec=sessionmaker)
mock_user = MagicMock(spec=EndUser)
@ -297,28 +278,83 @@ class TestRepositoryFactory:
session_factory=mock_session_factory,
user=mock_user,
app_id="test-app-id",
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
)
assert "Cannot import repository class" in str(exc_info.value)
def test_repository_import_error_exception(self):
"""Test RepositoryImportError exception."""
error_message = "Test error message"
exception = RepositoryImportError(error_message)
assert str(exception) == error_message
assert isinstance(exception, Exception)
@patch("core.repositories.factory.dify_config")
def test_create_workflow_node_execution_repository_validation_error(self, mock_config, mocker: MockerFixture):
"""Test WorkflowNodeExecutionRepository creation with validation error."""
# Setup mock configuration
mock_config.CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
mock_session_factory = MagicMock(spec=sessionmaker)
mock_user = MagicMock(spec=EndUser)
# Mock the import to succeed but validation to fail
mock_repository_class = MagicMock()
mocker.patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class)
mocker.patch.object(
DifyCoreRepositoryFactory,
"_validate_repository_interface",
side_effect=RepositoryImportError("Interface validation failed"),
)
with pytest.raises(RepositoryImportError) as exc_info:
DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=mock_session_factory,
user=mock_user,
app_id="test-app-id",
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
)
assert "Interface validation failed" in str(exc_info.value)
@patch("core.repositories.factory.dify_config")
def test_create_with_engine_instead_of_sessionmaker(self, mock_config, mocker: MockerFixture):
def test_create_workflow_node_execution_repository_instantiation_error(self, mock_config):
"""Test WorkflowNodeExecutionRepository creation with instantiation error."""
# Setup mock configuration
mock_config.CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
mock_session_factory = MagicMock(spec=sessionmaker)
mock_user = MagicMock(spec=EndUser)
# Create a mock repository class that raises exception on instantiation
mock_repository_class = MagicMock()
mock_repository_class.side_effect = Exception("Instantiation failed")
# Mock the validation methods to succeed
with (
patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"),
):
with pytest.raises(RepositoryImportError) as exc_info:
DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=mock_session_factory,
user=mock_user,
app_id="test-app-id",
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
)
assert "Failed to create WorkflowNodeExecutionRepository" in str(exc_info.value)
def test_repository_import_error_exception(self):
"""Test RepositoryImportError exception handling."""
error_message = "Custom error message"
error = RepositoryImportError(error_message)
assert str(error) == error_message
@patch("core.repositories.factory.dify_config")
def test_create_with_engine_instead_of_sessionmaker(self, mock_config):
"""Test repository creation with Engine instead of sessionmaker."""
# Setup mock configuration
mock_config.WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
mock_config.CORE_WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
# Create mock dependencies with Engine instead of sessionmaker
# Create mock dependencies using Engine instead of sessionmaker
mock_engine = MagicMock(spec=Engine)
mock_user = MagicMock(spec=Account)
app_id = "test-app-id"
triggered_from = WorkflowRunTriggeredFrom.APP_RUN
# Mock the imported class to be a valid repository
# Create mock repository class and instance
mock_repository_class = MagicMock()
mock_repository_instance = MagicMock(spec=WorkflowExecutionRepository)
mock_repository_class.return_value = mock_repository_instance
@ -327,129 +363,19 @@ class TestRepositoryFactory:
with (
patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"),
patch.object(DifyCoreRepositoryFactory, "_validate_constructor_signature"),
):
result = DifyCoreRepositoryFactory.create_workflow_execution_repository(
session_factory=mock_engine, # Using Engine instead of sessionmaker
user=mock_user,
app_id="test-app-id",
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
app_id=app_id,
triggered_from=triggered_from,
)
# Verify the repository was created with the Engine
# Verify the repository was created with correct parameters
mock_repository_class.assert_called_once_with(
session_factory=mock_engine,
user=mock_user,
app_id="test-app-id",
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
app_id=app_id,
triggered_from=triggered_from,
)
assert result is mock_repository_instance
@patch("core.repositories.factory.dify_config")
def test_create_workflow_node_execution_repository_validation_error(self, mock_config):
"""Test WorkflowNodeExecutionRepository creation with validation error."""
# Setup mock configuration
mock_config.WORKFLOW_NODE_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
mock_session_factory = MagicMock(spec=sessionmaker)
mock_user = MagicMock(spec=EndUser)
# Mock import to succeed but validation to fail
mock_repository_class = MagicMock()
with (
patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
patch.object(
DifyCoreRepositoryFactory,
"_validate_repository_interface",
side_effect=RepositoryImportError("Interface validation failed"),
),
):
with pytest.raises(RepositoryImportError) as exc_info:
DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=mock_session_factory,
user=mock_user,
app_id="test-app-id",
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
assert "Interface validation failed" in str(exc_info.value)
@patch("core.repositories.factory.dify_config")
def test_create_workflow_node_execution_repository_instantiation_error(self, mock_config):
"""Test WorkflowNodeExecutionRepository creation with instantiation error."""
# Setup mock configuration
mock_config.WORKFLOW_NODE_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
mock_session_factory = MagicMock(spec=sessionmaker)
mock_user = MagicMock(spec=EndUser)
# Mock import and validation to succeed but instantiation to fail
mock_repository_class = MagicMock(side_effect=Exception("Instantiation failed"))
with (
patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"),
patch.object(DifyCoreRepositoryFactory, "_validate_constructor_signature"),
):
with pytest.raises(RepositoryImportError) as exc_info:
DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=mock_session_factory,
user=mock_user,
app_id="test-app-id",
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
assert "Failed to create WorkflowNodeExecutionRepository" in str(exc_info.value)
def test_validate_repository_interface_with_private_methods(self):
"""Test interface validation ignores private methods."""
# Create a mock class with private methods
class MockRepository:
def save(self):
pass
def get_by_id(self):
pass
def _private_method(self):
pass
# Create a mock interface with private methods
class MockInterface:
def save(self):
pass
def get_by_id(self):
pass
def _private_method(self):
pass
# Should not raise an exception (private methods are ignored)
DifyCoreRepositoryFactory._validate_repository_interface(MockRepository, MockInterface)
def test_validate_constructor_signature_with_extra_params(self):
"""Test constructor validation with extra parameters (should pass)."""
class MockRepository:
def __init__(self, session_factory, user, app_id, triggered_from, extra_param=None):
pass
# Should not raise an exception (extra parameters are allowed)
DifyCoreRepositoryFactory._validate_constructor_signature(
MockRepository, ["session_factory", "user", "app_id", "triggered_from"]
)
def test_validate_constructor_signature_with_kwargs(self):
"""Test constructor validation with **kwargs (current implementation doesn't support this)."""
class MockRepository:
def __init__(self, session_factory, user, **kwargs):
pass
# Current implementation doesn't handle **kwargs, so this should raise an exception
with pytest.raises(RepositoryImportError) as exc_info:
DifyCoreRepositoryFactory._validate_constructor_signature(
MockRepository, ["session_factory", "user", "app_id", "triggered_from"]
)
assert "does not accept required parameters" in str(exc_info.value)
assert "app_id" in str(exc_info.value)
assert "triggered_from" in str(exc_info.value)

View File

@ -69,8 +69,12 @@ def test_get_file_attribute(pool, file):
def test_use_long_selector(pool):
pool.add(("node_1", "part_1", "part_2"), StringSegment(value="test_value"))
# The add method now only accepts 2-element selectors (node_id, variable_name)
# Store nested data as an ObjectSegment instead
nested_data = {"part_2": "test_value"}
pool.add(("node_1", "part_1"), ObjectSegment(value=nested_data))
# The get method supports longer selectors for nested access
result = pool.get(("node_1", "part_1", "part_2"))
assert result is not None
assert result.value == "test_value"
@ -280,8 +284,10 @@ class TestVariablePoolSerialization:
pool.add((self._NODE2_ID, "array_file"), ArrayFileSegment(value=[test_file]))
pool.add((self._NODE2_ID, "array_any"), ArrayAnySegment(value=["mixed", 123, {"key": "value"}]))
# Add nested variables
pool.add((self._NODE3_ID, "nested", "deep", "var"), StringSegment(value="deep_value"))
# Add nested variables as ObjectSegment
# The add method only accepts 2-element selectors
nested_obj = {"deep": {"var": "deep_value"}}
pool.add((self._NODE3_ID, "nested"), ObjectSegment(value=nested_obj))
def test_system_variables(self):
sys_vars = SystemVariable(

View File

@ -1,148 +0,0 @@
from typing import Any
from core.variables.segments import ObjectSegment, StringSegment
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.utils.variable_utils import append_variables_recursively
class TestAppendVariablesRecursively:
"""Test cases for append_variables_recursively function"""
def test_append_simple_dict_value(self):
"""Test appending a simple dictionary value"""
pool = VariablePool.empty()
node_id = "test_node"
variable_key_list = ["output"]
variable_value = {"name": "John", "age": 30}
append_variables_recursively(pool, node_id, variable_key_list, variable_value)
# Check that the main variable is added
main_var = pool.get([node_id] + variable_key_list)
assert main_var is not None
assert main_var.value == variable_value
# Check that nested variables are added recursively
name_var = pool.get([node_id] + variable_key_list + ["name"])
assert name_var is not None
assert name_var.value == "John"
age_var = pool.get([node_id] + variable_key_list + ["age"])
assert age_var is not None
assert age_var.value == 30
def test_append_object_segment_value(self):
"""Test appending an ObjectSegment value"""
pool = VariablePool.empty()
node_id = "test_node"
variable_key_list = ["result"]
# Create an ObjectSegment
obj_data = {"status": "success", "code": 200}
variable_value = ObjectSegment(value=obj_data)
append_variables_recursively(pool, node_id, variable_key_list, variable_value)
# Check that the main variable is added
main_var = pool.get([node_id] + variable_key_list)
assert main_var is not None
assert isinstance(main_var, ObjectSegment)
assert main_var.value == obj_data
# Check that nested variables are added recursively
status_var = pool.get([node_id] + variable_key_list + ["status"])
assert status_var is not None
assert status_var.value == "success"
code_var = pool.get([node_id] + variable_key_list + ["code"])
assert code_var is not None
assert code_var.value == 200
def test_append_nested_dict_value(self):
"""Test appending a nested dictionary value"""
pool = VariablePool.empty()
node_id = "test_node"
variable_key_list = ["data"]
variable_value = {
"user": {
"profile": {"name": "Alice", "email": "alice@example.com"},
"settings": {"theme": "dark", "notifications": True},
},
"metadata": {"version": "1.0", "timestamp": 1234567890},
}
append_variables_recursively(pool, node_id, variable_key_list, variable_value)
# Check deeply nested variables
name_var = pool.get([node_id] + variable_key_list + ["user", "profile", "name"])
assert name_var is not None
assert name_var.value == "Alice"
email_var = pool.get([node_id] + variable_key_list + ["user", "profile", "email"])
assert email_var is not None
assert email_var.value == "alice@example.com"
theme_var = pool.get([node_id] + variable_key_list + ["user", "settings", "theme"])
assert theme_var is not None
assert theme_var.value == "dark"
notifications_var = pool.get([node_id] + variable_key_list + ["user", "settings", "notifications"])
assert notifications_var is not None
assert notifications_var.value == 1 # Boolean True is converted to integer 1
version_var = pool.get([node_id] + variable_key_list + ["metadata", "version"])
assert version_var is not None
assert version_var.value == "1.0"
def test_append_non_dict_value(self):
"""Test appending a non-dictionary value (should not recurse)"""
pool = VariablePool.empty()
node_id = "test_node"
variable_key_list = ["simple"]
variable_value = "simple_string"
append_variables_recursively(pool, node_id, variable_key_list, variable_value)
# Check that only the main variable is added
main_var = pool.get([node_id] + variable_key_list)
assert main_var is not None
assert main_var.value == variable_value
# Ensure no additional variables are created
assert len(pool.variable_dictionary[node_id]) == 1
def test_append_segment_non_object_value(self):
"""Test appending a Segment that is not ObjectSegment (should not recurse)"""
pool = VariablePool.empty()
node_id = "test_node"
variable_key_list = ["text"]
variable_value = StringSegment(value="Hello World")
append_variables_recursively(pool, node_id, variable_key_list, variable_value)
# Check that only the main variable is added
main_var = pool.get([node_id] + variable_key_list)
assert main_var is not None
assert isinstance(main_var, StringSegment)
assert main_var.value == "Hello World"
# Ensure no additional variables are created
assert len(pool.variable_dictionary[node_id]) == 1
def test_append_empty_dict_value(self):
"""Test appending an empty dictionary value"""
pool = VariablePool.empty()
node_id = "test_node"
variable_key_list = ["empty"]
variable_value: dict[str, Any] = {}
append_variables_recursively(pool, node_id, variable_key_list, variable_value)
# Check that the main variable is added
main_var = pool.get([node_id] + variable_key_list)
assert main_var is not None
assert main_var.value == {}
# Ensure only the main variable is created (no recursion for empty dict)
assert len(pool.variable_dictionary[node_id]) == 1

View File

@ -1,5 +1,5 @@
version = 1
revision = 3
revision = 2
requires-python = ">=3.11, <3.13"
resolution-markers = [
"python_full_version >= '3.12.4' and platform_python_implementation != 'PyPy' and sys_platform == 'linux'",
@ -1236,7 +1236,7 @@ wheels = [
[[package]]
name = "dify-api"
version = "1.7.1"
version = "1.7.2"
source = { virtual = "." }
dependencies = [
{ name = "arize-phoenix-otel" },
@ -1371,6 +1371,7 @@ dev = [
{ name = "types-python-http-client" },
{ name = "types-pywin32" },
{ name = "types-pyyaml" },
{ name = "types-redis" },
{ name = "types-regex" },
{ name = "types-requests" },
{ name = "types-requests-oauthlib" },
@ -1557,6 +1558,7 @@ dev = [
{ name = "types-python-http-client", specifier = ">=3.3.7.20240910" },
{ name = "types-pywin32", specifier = "~=310.0.0" },
{ name = "types-pyyaml", specifier = "~=6.0.12" },
{ name = "types-redis", specifier = ">=4.6.0.20241004" },
{ name = "types-regex", specifier = "~=2024.11.6" },
{ name = "types-requests", specifier = "~=2.32.0" },
{ name = "types-requests-oauthlib", specifier = "~=2.0.0" },
@ -6064,6 +6066,19 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/99/5f/e0af6f7f6a260d9af67e1db4f54d732abad514252a7a378a6c4d17dd1036/types_pyyaml-6.0.12.20250516-py3-none-any.whl", hash = "sha256:8478208feaeb53a34cb5d970c56a7cd76b72659442e733e268a94dc72b2d0530", size = 20312, upload-time = "2025-05-16T03:08:04.019Z" },
]
[[package]]
name = "types-redis"
version = "4.6.0.20241004"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "cryptography" },
{ name = "types-pyopenssl" },
]
sdist = { url = "https://files.pythonhosted.org/packages/3a/95/c054d3ac940e8bac4ca216470c80c26688a0e79e09f520a942bb27da3386/types-redis-4.6.0.20241004.tar.gz", hash = "sha256:5f17d2b3f9091ab75384153bfa276619ffa1cf6a38da60e10d5e6749cc5b902e", size = 49679, upload-time = "2024-10-04T02:43:59.224Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/55/82/7d25dce10aad92d2226b269bce2f85cfd843b4477cd50245d7d40ecf8f89/types_redis-4.6.0.20241004-py3-none-any.whl", hash = "sha256:ef5da68cb827e5f606c8f9c0b49eeee4c2669d6d97122f301d3a55dc6a63f6ed", size = 58737, upload-time = "2024-10-04T02:43:57.968Z" },
]
[[package]]
name = "types-regex"
version = "2024.11.6.20250403"

View File

@ -8,4 +8,4 @@ cd "$SCRIPT_DIR/.."
uv --directory api run \
celery -A app.celery worker \
-P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion
-P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage

View File

@ -861,17 +861,23 @@ WORKFLOW_NODE_EXECUTION_STORAGE=rdbms
# Repository configuration
# Core workflow execution repository implementation
# Options:
# - core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository (default)
# - core.repositories.celery_workflow_execution_repository.CeleryWorkflowExecutionRepository
CORE_WORKFLOW_EXECUTION_REPOSITORY=core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository
# Core workflow node execution repository implementation
# Options:
# - core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository (default)
# - core.repositories.celery_workflow_node_execution_repository.CeleryWorkflowNodeExecutionRepository
CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY=core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository
# API workflow node execution repository implementation
API_WORKFLOW_NODE_EXECUTION_REPOSITORY=repositories.sqlalchemy_api_workflow_node_execution_repository.DifyAPISQLAlchemyWorkflowNodeExecutionRepository
# API workflow run repository implementation
API_WORKFLOW_RUN_REPOSITORY=repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository
# API workflow node execution repository implementation
API_WORKFLOW_NODE_EXECUTION_REPOSITORY=repositories.sqlalchemy_api_workflow_node_execution_repository.DifyAPISQLAlchemyWorkflowNodeExecutionRepository
# HTTP request node in workflow configuration
HTTP_REQUEST_NODE_MAX_BINARY_SIZE=10485760
HTTP_REQUEST_NODE_MAX_TEXT_SIZE=1048576
@ -907,6 +913,9 @@ TEXT_GENERATION_TIMEOUT_MS=60000
# Allow rendering unsafe URLs which have "data:" scheme.
ALLOW_UNSAFE_DATA_SCHEME=false
# Maximum number of tree depth in the workflow
MAX_TREE_DEPTH=50
# ------------------------------
# Environment Variables for db Service
# ------------------------------

View File

@ -2,7 +2,7 @@ x-shared-env: &shared-api-worker-env
services:
# API service
api:
image: langgenius/dify-api:1.7.1
image: langgenius/dify-api:1.7.2
restart: always
environment:
# Use the shared environment variables.
@ -31,7 +31,7 @@ services:
# worker service
# The Celery worker for processing the queue.
worker:
image: langgenius/dify-api:1.7.1
image: langgenius/dify-api:1.7.2
restart: always
environment:
# Use the shared environment variables.
@ -58,7 +58,7 @@ services:
# worker_beat service
# Celery beat for scheduling periodic tasks.
worker_beat:
image: langgenius/dify-api:1.7.1
image: langgenius/dify-api:1.7.2
restart: always
environment:
# Use the shared environment variables.
@ -76,7 +76,7 @@ services:
# Frontend web application.
web:
image: langgenius/dify-web:1.7.1
image: langgenius/dify-web:1.7.2
restart: always
environment:
CONSOLE_API_URL: ${CONSOLE_API_URL:-}
@ -96,6 +96,7 @@ services:
MAX_TOOLS_NUM: ${MAX_TOOLS_NUM:-10}
MAX_PARALLEL_LIMIT: ${MAX_PARALLEL_LIMIT:-10}
MAX_ITERATIONS_NUM: ${MAX_ITERATIONS_NUM:-99}
MAX_TREE_DEPTH: ${MAX_TREE_DEPTH:-50}
ENABLE_WEBSITE_JINAREADER: ${ENABLE_WEBSITE_JINAREADER:-true}
ENABLE_WEBSITE_FIRECRAWL: ${ENABLE_WEBSITE_FIRECRAWL:-true}
ENABLE_WEBSITE_WATERCRAWL: ${ENABLE_WEBSITE_WATERCRAWL:-true}

View File

@ -390,8 +390,8 @@ x-shared-env: &shared-api-worker-env
WORKFLOW_NODE_EXECUTION_STORAGE: ${WORKFLOW_NODE_EXECUTION_STORAGE:-rdbms}
CORE_WORKFLOW_EXECUTION_REPOSITORY: ${CORE_WORKFLOW_EXECUTION_REPOSITORY:-core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository}
CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY: ${CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY:-core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository}
API_WORKFLOW_NODE_EXECUTION_REPOSITORY: ${API_WORKFLOW_NODE_EXECUTION_REPOSITORY:-repositories.sqlalchemy_api_workflow_node_execution_repository.DifyAPISQLAlchemyWorkflowNodeExecutionRepository}
API_WORKFLOW_RUN_REPOSITORY: ${API_WORKFLOW_RUN_REPOSITORY:-repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository}
API_WORKFLOW_NODE_EXECUTION_REPOSITORY: ${API_WORKFLOW_NODE_EXECUTION_REPOSITORY:-repositories.sqlalchemy_api_workflow_node_execution_repository.DifyAPISQLAlchemyWorkflowNodeExecutionRepository}
HTTP_REQUEST_NODE_MAX_BINARY_SIZE: ${HTTP_REQUEST_NODE_MAX_BINARY_SIZE:-10485760}
HTTP_REQUEST_NODE_MAX_TEXT_SIZE: ${HTTP_REQUEST_NODE_MAX_TEXT_SIZE:-1048576}
HTTP_REQUEST_NODE_SSL_VERIFY: ${HTTP_REQUEST_NODE_SSL_VERIFY:-True}
@ -404,6 +404,7 @@ x-shared-env: &shared-api-worker-env
MAX_ITERATIONS_NUM: ${MAX_ITERATIONS_NUM:-99}
TEXT_GENERATION_TIMEOUT_MS: ${TEXT_GENERATION_TIMEOUT_MS:-60000}
ALLOW_UNSAFE_DATA_SCHEME: ${ALLOW_UNSAFE_DATA_SCHEME:-false}
MAX_TREE_DEPTH: ${MAX_TREE_DEPTH:-50}
POSTGRES_USER: ${POSTGRES_USER:-${DB_USERNAME}}
POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-${DB_PASSWORD}}
POSTGRES_DB: ${POSTGRES_DB:-${DB_DATABASE}}
@ -567,7 +568,7 @@ x-shared-env: &shared-api-worker-env
services:
# API service
api:
image: langgenius/dify-api:1.7.1
image: langgenius/dify-api:1.7.2
restart: always
environment:
# Use the shared environment variables.
@ -596,7 +597,7 @@ services:
# worker service
# The Celery worker for processing the queue.
worker:
image: langgenius/dify-api:1.7.1
image: langgenius/dify-api:1.7.2
restart: always
environment:
# Use the shared environment variables.
@ -623,7 +624,7 @@ services:
# worker_beat service
# Celery beat for scheduling periodic tasks.
worker_beat:
image: langgenius/dify-api:1.7.1
image: langgenius/dify-api:1.7.2
restart: always
environment:
# Use the shared environment variables.
@ -641,7 +642,7 @@ services:
# Frontend web application.
web:
image: langgenius/dify-web:1.7.1
image: langgenius/dify-web:1.7.2
restart: always
environment:
CONSOLE_API_URL: ${CONSOLE_API_URL:-}
@ -661,6 +662,7 @@ services:
MAX_TOOLS_NUM: ${MAX_TOOLS_NUM:-10}
MAX_PARALLEL_LIMIT: ${MAX_PARALLEL_LIMIT:-10}
MAX_ITERATIONS_NUM: ${MAX_ITERATIONS_NUM:-99}
MAX_TREE_DEPTH: ${MAX_TREE_DEPTH:-50}
ENABLE_WEBSITE_JINAREADER: ${ENABLE_WEBSITE_JINAREADER:-true}
ENABLE_WEBSITE_FIRECRAWL: ${ENABLE_WEBSITE_FIRECRAWL:-true}
ENABLE_WEBSITE_WATERCRAWL: ${ENABLE_WEBSITE_WATERCRAWL:-true}

View File

@ -0,0 +1,333 @@
import React from 'react'
import { fireEvent, render, screen } from '@testing-library/react'
import '@testing-library/jest-dom'
import CommandSelector from '../../app/components/goto-anything/command-selector'
import type { ActionItem } from '../../app/components/goto-anything/actions/types'
jest.mock('react-i18next', () => ({
useTranslation: () => ({
t: (key: string) => key,
}),
}))
jest.mock('cmdk', () => ({
Command: {
Group: ({ children, className }: any) => <div className={className}>{children}</div>,
Item: ({ children, onSelect, value, className }: any) => (
<div
className={className}
onClick={() => onSelect && onSelect()}
data-value={value}
data-testid={`command-item-${value}`}
>
{children}
</div>
),
},
}))
describe('CommandSelector', () => {
const mockActions: Record<string, ActionItem> = {
app: {
key: '@app',
shortcut: '@app',
title: 'Search Applications',
description: 'Search apps',
search: jest.fn(),
},
knowledge: {
key: '@knowledge',
shortcut: '@knowledge',
title: 'Search Knowledge',
description: 'Search knowledge bases',
search: jest.fn(),
},
plugin: {
key: '@plugin',
shortcut: '@plugin',
title: 'Search Plugins',
description: 'Search plugins',
search: jest.fn(),
},
node: {
key: '@node',
shortcut: '@node',
title: 'Search Nodes',
description: 'Search workflow nodes',
search: jest.fn(),
},
}
const mockOnCommandSelect = jest.fn()
const mockOnCommandValueChange = jest.fn()
beforeEach(() => {
jest.clearAllMocks()
})
describe('Basic Rendering', () => {
it('should render all actions when no filter is provided', () => {
render(
<CommandSelector
actions={mockActions}
onCommandSelect={mockOnCommandSelect}
/>,
)
expect(screen.getByTestId('command-item-@app')).toBeInTheDocument()
expect(screen.getByTestId('command-item-@knowledge')).toBeInTheDocument()
expect(screen.getByTestId('command-item-@plugin')).toBeInTheDocument()
expect(screen.getByTestId('command-item-@node')).toBeInTheDocument()
})
it('should render empty filter as showing all actions', () => {
render(
<CommandSelector
actions={mockActions}
onCommandSelect={mockOnCommandSelect}
searchFilter=""
/>,
)
expect(screen.getByTestId('command-item-@app')).toBeInTheDocument()
expect(screen.getByTestId('command-item-@knowledge')).toBeInTheDocument()
expect(screen.getByTestId('command-item-@plugin')).toBeInTheDocument()
expect(screen.getByTestId('command-item-@node')).toBeInTheDocument()
})
})
describe('Filtering Functionality', () => {
it('should filter actions based on searchFilter - single match', () => {
render(
<CommandSelector
actions={mockActions}
onCommandSelect={mockOnCommandSelect}
searchFilter="k"
/>,
)
expect(screen.queryByTestId('command-item-@app')).not.toBeInTheDocument()
expect(screen.getByTestId('command-item-@knowledge')).toBeInTheDocument()
expect(screen.queryByTestId('command-item-@plugin')).not.toBeInTheDocument()
expect(screen.queryByTestId('command-item-@node')).not.toBeInTheDocument()
})
it('should filter actions with multiple matches', () => {
render(
<CommandSelector
actions={mockActions}
onCommandSelect={mockOnCommandSelect}
searchFilter="p"
/>,
)
expect(screen.getByTestId('command-item-@app')).toBeInTheDocument()
expect(screen.queryByTestId('command-item-@knowledge')).not.toBeInTheDocument()
expect(screen.getByTestId('command-item-@plugin')).toBeInTheDocument()
expect(screen.queryByTestId('command-item-@node')).not.toBeInTheDocument()
})
it('should be case-insensitive when filtering', () => {
render(
<CommandSelector
actions={mockActions}
onCommandSelect={mockOnCommandSelect}
searchFilter="APP"
/>,
)
expect(screen.getByTestId('command-item-@app')).toBeInTheDocument()
expect(screen.queryByTestId('command-item-@knowledge')).not.toBeInTheDocument()
})
it('should match partial strings', () => {
render(
<CommandSelector
actions={mockActions}
onCommandSelect={mockOnCommandSelect}
searchFilter="nowl"
/>,
)
expect(screen.queryByTestId('command-item-@app')).not.toBeInTheDocument()
expect(screen.getByTestId('command-item-@knowledge')).toBeInTheDocument()
expect(screen.queryByTestId('command-item-@plugin')).not.toBeInTheDocument()
expect(screen.queryByTestId('command-item-@node')).not.toBeInTheDocument()
})
})
describe('Empty State', () => {
it('should show empty state when no matches found', () => {
render(
<CommandSelector
actions={mockActions}
onCommandSelect={mockOnCommandSelect}
searchFilter="xyz"
/>,
)
expect(screen.queryByTestId('command-item-@app')).not.toBeInTheDocument()
expect(screen.queryByTestId('command-item-@knowledge')).not.toBeInTheDocument()
expect(screen.queryByTestId('command-item-@plugin')).not.toBeInTheDocument()
expect(screen.queryByTestId('command-item-@node')).not.toBeInTheDocument()
expect(screen.getByText('app.gotoAnything.noMatchingCommands')).toBeInTheDocument()
expect(screen.getByText('app.gotoAnything.tryDifferentSearch')).toBeInTheDocument()
})
it('should not show empty state when filter is empty', () => {
render(
<CommandSelector
actions={mockActions}
onCommandSelect={mockOnCommandSelect}
searchFilter=""
/>,
)
expect(screen.queryByText('app.gotoAnything.noMatchingCommands')).not.toBeInTheDocument()
})
})
describe('Selection and Highlight Management', () => {
it('should call onCommandValueChange when filter changes and first item differs', () => {
const { rerender } = render(
<CommandSelector
actions={mockActions}
onCommandSelect={mockOnCommandSelect}
searchFilter=""
commandValue="@app"
onCommandValueChange={mockOnCommandValueChange}
/>,
)
rerender(
<CommandSelector
actions={mockActions}
onCommandSelect={mockOnCommandSelect}
searchFilter="k"
commandValue="@app"
onCommandValueChange={mockOnCommandValueChange}
/>,
)
expect(mockOnCommandValueChange).toHaveBeenCalledWith('@knowledge')
})
it('should not call onCommandValueChange if current value still exists', () => {
const { rerender } = render(
<CommandSelector
actions={mockActions}
onCommandSelect={mockOnCommandSelect}
searchFilter=""
commandValue="@app"
onCommandValueChange={mockOnCommandValueChange}
/>,
)
rerender(
<CommandSelector
actions={mockActions}
onCommandSelect={mockOnCommandSelect}
searchFilter="a"
commandValue="@app"
onCommandValueChange={mockOnCommandValueChange}
/>,
)
expect(mockOnCommandValueChange).not.toHaveBeenCalled()
})
it('should handle onCommandSelect callback correctly', () => {
render(
<CommandSelector
actions={mockActions}
onCommandSelect={mockOnCommandSelect}
searchFilter="k"
/>,
)
const knowledgeItem = screen.getByTestId('command-item-@knowledge')
fireEvent.click(knowledgeItem)
expect(mockOnCommandSelect).toHaveBeenCalledWith('@knowledge')
})
})
describe('Edge Cases', () => {
it('should handle empty actions object', () => {
render(
<CommandSelector
actions={{}}
onCommandSelect={mockOnCommandSelect}
searchFilter=""
/>,
)
expect(screen.getByText('app.gotoAnything.noMatchingCommands')).toBeInTheDocument()
})
it('should handle special characters in filter', () => {
render(
<CommandSelector
actions={mockActions}
onCommandSelect={mockOnCommandSelect}
searchFilter="@"
/>,
)
expect(screen.getByTestId('command-item-@app')).toBeInTheDocument()
expect(screen.getByTestId('command-item-@knowledge')).toBeInTheDocument()
expect(screen.getByTestId('command-item-@plugin')).toBeInTheDocument()
expect(screen.getByTestId('command-item-@node')).toBeInTheDocument()
})
it('should handle undefined onCommandValueChange gracefully', () => {
const { rerender } = render(
<CommandSelector
actions={mockActions}
onCommandSelect={mockOnCommandSelect}
searchFilter=""
/>,
)
expect(() => {
rerender(
<CommandSelector
actions={mockActions}
onCommandSelect={mockOnCommandSelect}
searchFilter="k"
/>,
)
}).not.toThrow()
})
})
describe('Backward Compatibility', () => {
it('should work without searchFilter prop (backward compatible)', () => {
render(
<CommandSelector
actions={mockActions}
onCommandSelect={mockOnCommandSelect}
/>,
)
expect(screen.getByTestId('command-item-@app')).toBeInTheDocument()
expect(screen.getByTestId('command-item-@knowledge')).toBeInTheDocument()
expect(screen.getByTestId('command-item-@plugin')).toBeInTheDocument()
expect(screen.getByTestId('command-item-@node')).toBeInTheDocument()
})
it('should work without commandValue and onCommandValueChange props', () => {
render(
<CommandSelector
actions={mockActions}
onCommandSelect={mockOnCommandSelect}
searchFilter="k"
/>,
)
expect(screen.getByTestId('command-item-@knowledge')).toBeInTheDocument()
expect(screen.queryByTestId('command-item-@app')).not.toBeInTheDocument()
})
})
})

View File

@ -0,0 +1,197 @@
/**
* Test GotoAnything search error handling mechanisms
*
* Main validations:
* 1. @plugin search error handling when API fails
* 2. Regular search (without @prefix) error handling when API fails
* 3. Verify consistent error handling across different search types
* 4. Ensure errors don't propagate to UI layer causing "search failed"
*/
import { Actions, searchAnything } from '@/app/components/goto-anything/actions'
import { postMarketplace } from '@/service/base'
import { fetchAppList } from '@/service/apps'
import { fetchDatasets } from '@/service/datasets'
// Mock API functions
jest.mock('@/service/base', () => ({
postMarketplace: jest.fn(),
}))
jest.mock('@/service/apps', () => ({
fetchAppList: jest.fn(),
}))
jest.mock('@/service/datasets', () => ({
fetchDatasets: jest.fn(),
}))
const mockPostMarketplace = postMarketplace as jest.MockedFunction<typeof postMarketplace>
const mockFetchAppList = fetchAppList as jest.MockedFunction<typeof fetchAppList>
const mockFetchDatasets = fetchDatasets as jest.MockedFunction<typeof fetchDatasets>
describe('GotoAnything Search Error Handling', () => {
beforeEach(() => {
jest.clearAllMocks()
// Suppress console.warn for clean test output
jest.spyOn(console, 'warn').mockImplementation(() => {
// Suppress console.warn for clean test output
})
})
afterEach(() => {
jest.restoreAllMocks()
})
describe('@plugin search error handling', () => {
it('should return empty array when API fails instead of throwing error', async () => {
// Mock marketplace API failure (403 permission denied)
mockPostMarketplace.mockRejectedValue(new Error('HTTP 403: Forbidden'))
const pluginAction = Actions.plugin
// Directly call plugin action's search method
const result = await pluginAction.search('@plugin', 'test', 'en')
// Should return empty array instead of throwing error
expect(result).toEqual([])
expect(mockPostMarketplace).toHaveBeenCalledWith('/plugins/search/advanced', {
body: {
page: 1,
page_size: 10,
query: 'test',
type: 'plugin',
},
})
})
it('should return empty array when user has no plugin data', async () => {
// Mock marketplace returning empty data
mockPostMarketplace.mockResolvedValue({
data: { plugins: [] },
})
const pluginAction = Actions.plugin
const result = await pluginAction.search('@plugin', '', 'en')
expect(result).toEqual([])
})
it('should return empty array when API returns unexpected data structure', async () => {
// Mock API returning unexpected data structure
mockPostMarketplace.mockResolvedValue({
data: null,
})
const pluginAction = Actions.plugin
const result = await pluginAction.search('@plugin', 'test', 'en')
expect(result).toEqual([])
})
})
describe('Other search types error handling', () => {
it('@app search should return empty array when API fails', async () => {
// Mock app API failure
mockFetchAppList.mockRejectedValue(new Error('API Error'))
const appAction = Actions.app
const result = await appAction.search('@app', 'test', 'en')
expect(result).toEqual([])
})
it('@knowledge search should return empty array when API fails', async () => {
// Mock knowledge API failure
mockFetchDatasets.mockRejectedValue(new Error('API Error'))
const knowledgeAction = Actions.knowledge
const result = await knowledgeAction.search('@knowledge', 'test', 'en')
expect(result).toEqual([])
})
})
describe('Unified search entry error handling', () => {
it('regular search (without @prefix) should return successful results even when partial APIs fail', async () => {
// Set app and knowledge success, plugin failure
mockFetchAppList.mockResolvedValue({ data: [], has_more: false, limit: 10, page: 1, total: 0 })
mockFetchDatasets.mockResolvedValue({ data: [], has_more: false, limit: 10, page: 1, total: 0 })
mockPostMarketplace.mockRejectedValue(new Error('Plugin API failed'))
const result = await searchAnything('en', 'test')
// Should return successful results even if plugin search fails
expect(result).toEqual([])
expect(console.warn).toHaveBeenCalledWith('Plugin search failed:', expect.any(Error))
})
it('@plugin dedicated search should return empty array when API fails', async () => {
// Mock plugin API failure
mockPostMarketplace.mockRejectedValue(new Error('Plugin service unavailable'))
const pluginAction = Actions.plugin
const result = await searchAnything('en', '@plugin test', pluginAction)
// Should return empty array instead of throwing error
expect(result).toEqual([])
})
it('@app dedicated search should return empty array when API fails', async () => {
// Mock app API failure
mockFetchAppList.mockRejectedValue(new Error('App service unavailable'))
const appAction = Actions.app
const result = await searchAnything('en', '@app test', appAction)
expect(result).toEqual([])
})
})
describe('Error handling consistency validation', () => {
it('all search types should return empty array when encountering errors', async () => {
// Mock all APIs to fail
mockPostMarketplace.mockRejectedValue(new Error('Plugin API failed'))
mockFetchAppList.mockRejectedValue(new Error('App API failed'))
mockFetchDatasets.mockRejectedValue(new Error('Dataset API failed'))
const actions = [
{ name: '@plugin', action: Actions.plugin },
{ name: '@app', action: Actions.app },
{ name: '@knowledge', action: Actions.knowledge },
]
for (const { name, action } of actions) {
const result = await action.search(name, 'test', 'en')
expect(result).toEqual([])
}
})
})
describe('Edge case testing', () => {
it('empty search term should be handled properly', async () => {
mockPostMarketplace.mockResolvedValue({ data: { plugins: [] } })
const result = await searchAnything('en', '@plugin ', Actions.plugin)
expect(result).toEqual([])
})
it('network timeout should be handled correctly', async () => {
const timeoutError = new Error('Network timeout')
timeoutError.name = 'TimeoutError'
mockPostMarketplace.mockRejectedValue(timeoutError)
const result = await searchAnything('en', '@plugin test', Actions.plugin)
expect(result).toEqual([])
})
it('JSON parsing errors should be handled correctly', async () => {
const parseError = new SyntaxError('Unexpected token in JSON')
mockPostMarketplace.mockRejectedValue(parseError)
const result = await searchAnything('en', '@plugin test', Actions.plugin)
expect(result).toEqual([])
})
})
})

View File

@ -8,6 +8,7 @@ import Header from '@/app/components/header'
import { EventEmitterContextProvider } from '@/context/event-emitter'
import { ProviderContextProvider } from '@/context/provider-context'
import { ModalContextProvider } from '@/context/modal-context'
import GotoAnything from '@/app/components/goto-anything'
const Layout = ({ children }: { children: ReactNode }) => {
return (
@ -22,6 +23,7 @@ const Layout = ({ children }: { children: ReactNode }) => {
<Header />
</HeaderWrapper>
{children}
<GotoAnything />
</ModalContextProvider>
</ProviderContextProvider>
</EventEmitterContextProvider>

View File

@ -32,7 +32,7 @@ export type InputProps = {
unit?: string
} & Omit<React.InputHTMLAttributes<HTMLInputElement>, 'size'> & VariantProps<typeof inputVariants>
const Input = ({
const Input = React.forwardRef<HTMLInputElement, InputProps>(({
size,
disabled,
destructive,
@ -47,12 +47,13 @@ const Input = ({
onChange = noop,
unit,
...props
}: InputProps) => {
}, ref) => {
const { t } = useTranslation()
return (
<div className={cn('relative w-full', wrapperClassName)}>
{showLeftIcon && <RiSearchLine className={cn('absolute left-2 top-1/2 h-4 w-4 -translate-y-1/2 text-components-input-text-placeholder')} />}
<input
ref={ref}
style={styleCss}
className={cn(
'w-full appearance-none border border-transparent bg-components-input-bg-normal py-[7px] text-components-input-text-filled caret-primary-600 outline-none placeholder:text-components-input-text-placeholder hover:border-components-input-border-hover hover:bg-components-input-bg-hover focus:border-components-input-border-active focus:bg-components-input-bg-active focus:shadow-xs',
@ -92,6 +93,8 @@ const Input = ({
}
</div>
)
}
})
Input.displayName = 'Input'
export default Input

View File

@ -15,6 +15,7 @@ type IModal = {
children?: React.ReactNode
closable?: boolean
overflowVisible?: boolean
highPriority?: boolean // For modals that need to appear above dropdowns
}
export default function Modal({
@ -27,10 +28,11 @@ export default function Modal({
children,
closable = false,
overflowVisible = false,
highPriority = false,
}: IModal) {
return (
<Transition appear show={isShow} as={Fragment}>
<Dialog as="div" className={classNames('relative z-[60]', wrapperClassName)} onClose={onClose}>
<Dialog as="div" className={classNames('relative', highPriority ? 'z-[1100]' : 'z-[60]', wrapperClassName)} onClose={onClose}>
<TransitionChild>
<div className={classNames(
'fixed inset-0 bg-background-overlay',

View File

@ -192,6 +192,7 @@ const SimpleSelect: FC<ISelectProps> = ({
const localPlaceholder = placeholder || t('common.placeholder.select')
const [selectedItem, setSelectedItem] = useState<Item | null>(null)
const [open, setOpen] = useState(false)
useEffect(() => {
let defaultSelect = null
@ -220,8 +221,11 @@ const SimpleSelect: FC<ISelectProps> = ({
<ListboxButton onClick={() => {
// get data-open, use setTimeout to ensure the attribute is set
setTimeout(() => {
if (listboxRef.current)
onOpenChange?.(listboxRef.current.getAttribute('data-open') !== null)
if (listboxRef.current) {
const isOpen = listboxRef.current.getAttribute('data-open') !== null
setOpen(isOpen)
onOpenChange?.(isOpen)
}
})
}} className={classNames(`flex h-full w-full items-center rounded-lg border-0 bg-components-input-bg-normal pl-3 pr-10 focus-visible:bg-state-base-hover-alt focus-visible:outline-none group-hover/simple-select:bg-state-base-hover-alt sm:text-sm sm:leading-6 ${disabled ? 'cursor-not-allowed' : 'cursor-pointer'}`, className)}>
<span className={classNames('system-sm-regular block truncate text-left text-components-input-text-filled', !selectedItem?.name && 'text-components-input-text-placeholder')}>{selectedItem?.name ?? localPlaceholder}</span>
@ -240,10 +244,17 @@ const SimpleSelect: FC<ISelectProps> = ({
/>
)
: (
<ChevronDownIcon
className="h-4 w-4 text-text-quaternary group-hover/simple-select:text-text-secondary"
aria-hidden="true"
/>
open ? (
<ChevronUpIcon
className="h-4 w-4 text-text-quaternary group-hover/simple-select:text-text-secondary"
aria-hidden="true"
/>
) : (
<ChevronDownIcon
className="h-4 w-4 text-text-quaternary group-hover/simple-select:text-text-secondary"
aria-hidden="true"
/>
)
)}
</span>
</ListboxButton>

View File

@ -79,7 +79,7 @@ const TagFilter: FC<TagFilterProps> = ({
className='block'
>
<div className={cn(
'flex h-8 cursor-pointer items-center gap-1 rounded-lg border-[0.5px] border-transparent bg-components-input-bg-normal px-2',
'flex h-8 cursor-pointer select-none items-center gap-1 rounded-lg border-[0.5px] border-transparent bg-components-input-bg-normal px-2',
!open && !!value.length && 'shadow-xs',
open && !!value.length && 'shadow-xs',
)}>
@ -123,7 +123,7 @@ const TagFilter: FC<TagFilterProps> = ({
{filteredTagList.map(tag => (
<div
key={tag.id}
className='flex cursor-pointer items-center gap-2 rounded-lg py-[6px] pl-3 pr-2 hover:bg-state-base-hover'
className='flex cursor-pointer select-none items-center gap-2 rounded-lg py-[6px] pl-3 pr-2 hover:bg-state-base-hover'
onClick={() => selectTag(tag)}
>
<div title={tag.name} className='grow truncate text-sm leading-5 text-text-tertiary'>{tag.name}</div>
@ -139,7 +139,7 @@ const TagFilter: FC<TagFilterProps> = ({
</div>
<div className='border-t-[0.5px] border-divider-regular' />
<div className='p-1'>
<div className='flex cursor-pointer items-center gap-2 rounded-lg py-[6px] pl-3 pr-2 hover:bg-state-base-hover' onClick={() => {
<div className='flex cursor-pointer select-none items-center gap-2 rounded-lg py-[6px] pl-3 pr-2 hover:bg-state-base-hover' onClick={() => {
setShowTagManagementModal(true)
setOpen(false)
}}>

View File

@ -69,9 +69,11 @@ const RenameDatasetModal = ({ show, dataset, onSuccess, onClose }: RenameDataset
isShow={show}
onClose={noop}
>
<div className='relative pb-2 text-xl font-medium leading-[30px] text-text-primary'>{t('datasetSettings.title')}</div>
<div className='absolute right-4 top-4 cursor-pointer p-2' onClick={onClose}>
<RiCloseLine className='h-4 w-4 text-text-tertiary' />
<div className='flex items-center justify-between pb-2'>
<div className='text-xl font-medium leading-[30px] text-text-primary'>{t('datasetSettings.title')}</div>
<div className='cursor-pointer p-2' onClick={onClose}>
<RiCloseLine className='h-4 w-4 text-text-tertiary' />
</div>
</div>
<div>
<div className={cn('flex flex-wrap items-center justify-between py-4')}>

View File

@ -0,0 +1,58 @@
import type { ActionItem, AppSearchResult } from './types'
import type { App } from '@/types/app'
import { fetchAppList } from '@/service/apps'
import AppIcon from '../../base/app-icon'
import { AppTypeIcon } from '../../app/type-selector'
import { getRedirectionPath } from '@/utils/app-redirection'
const parser = (apps: App[]): AppSearchResult[] => {
return apps.map(app => ({
id: app.id,
title: app.name,
description: app.description,
type: 'app' as const,
path: getRedirectionPath(true, {
id: app.id,
mode: app.mode,
}),
icon: (
<div className='relative shrink-0'>
<AppIcon
size='large'
iconType={app.icon_type}
icon={app.icon}
background={app.icon_background}
imageUrl={app.icon_url}
/>
<AppTypeIcon wrapperClassName='absolute -bottom-0.5 -right-0.5 w-4 h-4 rounded-[4px] border border-divider-regular outline outline-components-panel-on-panel-item-bg'
className='h-3 w-3' type={app.mode} />
</div>
),
data: app,
}))
}
export const appAction: ActionItem = {
key: '@app',
shortcut: '@app',
title: 'Search Applications',
description: 'Search and navigate to your applications',
// action,
search: async (_, searchTerm = '', _locale) => {
try {
const response = await fetchAppList({
url: 'apps',
params: {
page: 1,
name: searchTerm,
},
})
const apps = response?.data || []
return parser(apps)
}
catch (error) {
console.warn('App search failed:', error)
return []
}
},
}

View File

@ -0,0 +1,26 @@
export type CommandHandler = (args?: Record<string, any>) => void | Promise<void>
const handlers = new Map<string, CommandHandler>()
export const registerCommand = (name: string, handler: CommandHandler) => {
handlers.set(name, handler)
}
export const unregisterCommand = (name: string) => {
handlers.delete(name)
}
export const executeCommand = async (name: string, args?: Record<string, any>) => {
const handler = handlers.get(name)
if (!handler)
return
await handler(args)
}
export const registerCommands = (map: Record<string, CommandHandler>) => {
Object.entries(map).forEach(([name, handler]) => registerCommand(name, handler))
}
export const unregisterCommands = (names: string[]) => {
names.forEach(unregisterCommand)
}

View File

@ -0,0 +1,76 @@
import { appAction } from './app'
import { knowledgeAction } from './knowledge'
import { pluginAction } from './plugin'
import { workflowNodesAction } from './workflow-nodes'
import type { ActionItem, SearchResult } from './types'
import { commandAction } from './run'
export const Actions = {
app: appAction,
knowledge: knowledgeAction,
plugin: pluginAction,
run: commandAction,
node: workflowNodesAction,
}
export const searchAnything = async (
locale: string,
query: string,
actionItem?: ActionItem,
): Promise<SearchResult[]> => {
if (actionItem) {
const searchTerm = query.replace(actionItem.key, '').replace(actionItem.shortcut, '').trim()
try {
return await actionItem.search(query, searchTerm, locale)
}
catch (error) {
console.warn(`Search failed for ${actionItem.key}:`, error)
return []
}
}
if (query.startsWith('@'))
return []
// Use Promise.allSettled to handle partial failures gracefully
const searchPromises = Object.values(Actions).map(async (action) => {
try {
const results = await action.search(query, query, locale)
return { success: true, data: results, actionType: action.key }
}
catch (error) {
console.warn(`Search failed for ${action.key}:`, error)
return { success: false, data: [], actionType: action.key, error }
}
})
const settledResults = await Promise.allSettled(searchPromises)
const allResults: SearchResult[] = []
const failedActions: string[] = []
settledResults.forEach((result, index) => {
if (result.status === 'fulfilled' && result.value.success) {
allResults.push(...result.value.data)
}
else {
const actionKey = Object.values(Actions)[index]?.key || 'unknown'
failedActions.push(actionKey)
}
})
if (failedActions.length > 0)
console.warn(`Some search actions failed: ${failedActions.join(', ')}`)
return allResults
}
export const matchAction = (query: string, actions: Record<string, ActionItem>) => {
return Object.values(actions).find((action) => {
const reg = new RegExp(`^(${action.key}|${action.shortcut})(?:\\s|$)`)
return reg.test(query)
})
}
export * from './types'
export { appAction, knowledgeAction, pluginAction, workflowNodesAction }

View File

@ -0,0 +1,56 @@
import type { ActionItem, KnowledgeSearchResult } from './types'
import type { DataSet } from '@/models/datasets'
import { fetchDatasets } from '@/service/datasets'
import { Folder } from '../../base/icons/src/vender/solid/files'
import cn from '@/utils/classnames'
const EXTERNAL_PROVIDER = 'external' as const
const isExternalProvider = (provider: string): boolean => provider === EXTERNAL_PROVIDER
const parser = (datasets: DataSet[]): KnowledgeSearchResult[] => {
return datasets.map((dataset) => {
const path = isExternalProvider(dataset.provider) ? `/datasets/${dataset.id}/hitTesting` : `/datasets/${dataset.id}/documents`
return {
id: dataset.id,
title: dataset.name,
description: dataset.description,
type: 'knowledge' as const,
path,
icon: (
<div className={cn(
'flex shrink-0 items-center justify-center rounded-md border-[0.5px] border-[#E0EAFF] bg-[#F5F8FF] p-2.5',
!dataset.embedding_available && 'opacity-50 hover:opacity-100',
)}>
<Folder className='h-5 w-5 text-[#444CE7]' />
</div>
),
data: dataset,
}
})
}
export const knowledgeAction: ActionItem = {
key: '@knowledge',
shortcut: '@kb',
title: 'Search Knowledge Bases',
description: 'Search and navigate to your knowledge bases',
// action,
search: async (_, searchTerm = '', _locale) => {
try {
const response = await fetchDatasets({
url: '/datasets',
params: {
page: 1,
limit: 10,
keyword: searchTerm,
},
})
const datasets = response?.data || []
return parser(datasets)
}
catch (error) {
console.warn('Knowledge search failed:', error)
return []
}
},
}

View File

@ -0,0 +1,53 @@
import type { ActionItem, PluginSearchResult } from './types'
import { renderI18nObject } from '@/i18n-config'
import Icon from '../../plugins/card/base/card-icon'
import { postMarketplace } from '@/service/base'
import type { Plugin, PluginsFromMarketplaceResponse } from '../../plugins/types'
import { getPluginIconInMarketplace } from '../../plugins/marketplace/utils'
const parser = (plugins: Plugin[], locale: string): PluginSearchResult[] => {
return plugins.map((plugin) => {
return {
id: plugin.name,
title: renderI18nObject(plugin.label, locale) || plugin.name,
description: renderI18nObject(plugin.brief, locale) || '',
type: 'plugin' as const,
icon: <Icon src={plugin.icon} />,
data: plugin,
}
})
}
export const pluginAction: ActionItem = {
key: '@plugin',
shortcut: '@plugin',
title: 'Search Plugins',
description: 'Search and navigate to your plugins',
search: async (_, searchTerm = '', locale) => {
try {
const response = await postMarketplace<{ data: PluginsFromMarketplaceResponse }>('/plugins/search/advanced', {
body: {
page: 1,
page_size: 10,
query: searchTerm,
type: 'plugin',
},
})
if (!response?.data?.plugins) {
console.warn('Plugin search: Unexpected response structure', response)
return []
}
const list = response.data.plugins.map(plugin => ({
...plugin,
icon: getPluginIconInMarketplace(plugin),
}))
return parser(list, locale!)
}
catch (error) {
console.warn('Plugin search failed:', error)
return []
}
},
}

View File

@ -0,0 +1,33 @@
import type { CommandSearchResult } from './types'
import { languages } from '@/i18n-config/language'
import { RiTranslate } from '@remixicon/react'
import i18n from '@/i18n-config/i18next-config'
export const buildLanguageCommands = (query: string): CommandSearchResult[] => {
const q = query.toLowerCase()
const list = languages.filter(item => item.supported && (
!q || item.name.toLowerCase().includes(q) || String(item.value).toLowerCase().includes(q)
))
return list.map(item => ({
id: `lang-${item.value}`,
title: item.name,
description: i18n.t('app.gotoAnything.actions.languageChangeDesc'),
type: 'command' as const,
data: { command: 'i18n.set', args: { locale: item.value } },
}))
}
export const buildLanguageRootItem = (): CommandSearchResult => {
return {
id: 'category-language',
title: i18n.t('app.gotoAnything.actions.languageCategoryTitle'),
description: i18n.t('app.gotoAnything.actions.languageCategoryDesc'),
type: 'command',
icon: (
<div className='flex h-6 w-6 items-center justify-center rounded-md border-[0.5px] border-divider-regular bg-components-panel-bg'>
<RiTranslate className='h-4 w-4 text-text-tertiary' />
</div>
),
data: { command: 'nav.search', args: { query: '@run language ' } },
}
}

View File

@ -0,0 +1,61 @@
import type { CommandSearchResult } from './types'
import type { ReactNode } from 'react'
import { RiComputerLine, RiMoonLine, RiPaletteLine, RiSunLine } from '@remixicon/react'
import i18n from '@/i18n-config/i18next-config'
const THEME_ITEMS: { id: 'light' | 'dark' | 'system'; titleKey: string; descKey: string; icon: ReactNode }[] = [
{
id: 'system',
titleKey: 'app.gotoAnything.actions.themeSystem',
descKey: 'app.gotoAnything.actions.themeSystemDesc',
icon: <RiComputerLine className='h-4 w-4 text-text-tertiary' />,
},
{
id: 'light',
titleKey: 'app.gotoAnything.actions.themeLight',
descKey: 'app.gotoAnything.actions.themeLightDesc',
icon: <RiSunLine className='h-4 w-4 text-text-tertiary' />,
},
{
id: 'dark',
titleKey: 'app.gotoAnything.actions.themeDark',
descKey: 'app.gotoAnything.actions.themeDarkDesc',
icon: <RiMoonLine className='h-4 w-4 text-text-tertiary' />,
},
]
export const buildThemeCommands = (query: string, locale?: string): CommandSearchResult[] => {
const q = query.toLowerCase()
const list = THEME_ITEMS.filter(item =>
!q
|| i18n.t(item.titleKey, { lng: locale }).toLowerCase().includes(q)
|| item.id.includes(q),
)
return list.map(item => ({
id: item.id,
title: i18n.t(item.titleKey, { lng: locale }),
description: i18n.t(item.descKey, { lng: locale }),
type: 'command' as const,
icon: (
<div className='flex h-6 w-6 items-center justify-center rounded-md border-[0.5px] border-divider-regular bg-components-panel-bg'>
{item.icon}
</div>
),
data: { command: 'theme.set', args: { value: item.id } },
}))
}
export const buildThemeRootItem = (): CommandSearchResult => {
return {
id: 'category-theme',
title: i18n.t('app.gotoAnything.actions.themeCategoryTitle'),
description: i18n.t('app.gotoAnything.actions.themeCategoryDesc'),
type: 'command',
icon: (
<div className='flex h-6 w-6 items-center justify-center rounded-md border-[0.5px] border-divider-regular bg-components-panel-bg'>
<RiPaletteLine className='h-4 w-4 text-text-tertiary' />
</div>
),
data: { command: 'nav.search', args: { query: '@run theme ' } },
}
}

View File

@ -0,0 +1,97 @@
'use client'
import { useEffect } from 'react'
import type { ActionItem, CommandSearchResult } from './types'
import { buildLanguageCommands, buildLanguageRootItem } from './run-language'
import { buildThemeCommands, buildThemeRootItem } from './run-theme'
import i18n from '@/i18n-config/i18next-config'
import { executeCommand, registerCommands, unregisterCommands } from './command-bus'
import { useTheme } from 'next-themes'
import { setLocaleOnClient } from '@/i18n-config'
const rootParser = (query: string): CommandSearchResult[] => {
const q = query.toLowerCase()
const items: CommandSearchResult[] = []
if (!q || 'theme'.includes(q))
items.push(buildThemeRootItem())
if (!q || 'language'.includes(q) || 'lang'.includes(q))
items.push(buildLanguageRootItem())
return items
}
type RunContext = {
setTheme?: (value: 'light' | 'dark' | 'system') => void
setLocale?: (locale: string) => Promise<void>
search?: (query: string) => void
}
export const commandAction: ActionItem = {
key: '@run',
shortcut: '@run',
title: i18n.t('app.gotoAnything.actions.runTitle'),
description: i18n.t('app.gotoAnything.actions.runDesc'),
action: (result) => {
if (result.type !== 'command') return
const { command, args } = result.data
if (command === 'theme.set') {
executeCommand('theme.set', args)
return
}
if (command === 'i18n.set') {
executeCommand('i18n.set', args)
return
}
if (command === 'nav.search')
executeCommand('nav.search', args)
},
search: async (_, searchTerm = '') => {
const q = searchTerm.trim()
if (q.startsWith('theme'))
return buildThemeCommands(q.replace(/^theme\s*/, ''), i18n.language)
if (q.startsWith('language') || q.startsWith('lang'))
return buildLanguageCommands(q.replace(/^(language|lang)\s*/, ''))
// root categories
return rootParser(q)
},
}
// Register/unregister default handlers for @run commands with external dependencies.
export const registerRunCommands = (deps: {
setTheme?: (value: 'light' | 'dark' | 'system') => void
setLocale?: (locale: string) => Promise<void>
search?: (query: string) => void
}) => {
registerCommands({
'theme.set': async (args) => {
deps.setTheme?.(args?.value)
},
'i18n.set': async (args) => {
const locale = args?.locale
if (locale)
await deps.setLocale?.(locale)
},
'nav.search': (args) => {
const q = args?.query
if (q)
deps.search?.(q)
},
})
}
export const unregisterRunCommands = () => {
unregisterCommands(['theme.set', 'i18n.set', 'nav.search'])
}
export const RunCommandProvider = ({ onNavSearch }: { onNavSearch?: (q: string) => void }) => {
const theme = useTheme()
useEffect(() => {
registerRunCommands({
setTheme: theme.setTheme,
setLocale: setLocaleOnClient,
search: onNavSearch,
})
return () => unregisterRunCommands()
}, [theme.setTheme, onNavSearch])
return null
}

View File

@ -0,0 +1,58 @@
import type { ReactNode } from 'react'
import type { TypeWithI18N } from '../../base/form/types'
import type { App } from '@/types/app'
import type { Plugin } from '../../plugins/types'
import type { DataSet } from '@/models/datasets'
import type { CommonNodeType } from '../../workflow/types'
export type SearchResultType = 'app' | 'knowledge' | 'plugin' | 'workflow-node' | 'command'
export type BaseSearchResult<T = any> = {
id: string
title: string
description?: string
type: SearchResultType
path?: string
icon?: ReactNode
data: T
}
export type AppSearchResult = {
type: 'app'
} & BaseSearchResult<App>
export type PluginSearchResult = {
type: 'plugin'
} & BaseSearchResult<Plugin>
export type KnowledgeSearchResult = {
type: 'knowledge'
} & BaseSearchResult<DataSet>
export type WorkflowNodeSearchResult = {
type: 'workflow-node'
metadata?: {
nodeId: string
nodeData: CommonNodeType
}
} & BaseSearchResult<CommonNodeType>
export type CommandSearchResult = {
type: 'command'
} & BaseSearchResult<{ command: string; args?: Record<string, any> }>
export type SearchResult = AppSearchResult | PluginSearchResult | KnowledgeSearchResult | WorkflowNodeSearchResult | CommandSearchResult
export type ActionItem = {
key: '@app' | '@knowledge' | '@plugin' | '@node' | '@run'
shortcut: string
title: string | TypeWithI18N
description: string
action?: (data: SearchResult) => void
searchFn?: (searchTerm: string) => SearchResult[]
search: (
query: string,
searchTerm: string,
locale?: string,
) => (Promise<SearchResult[]> | SearchResult[])
}

View File

@ -0,0 +1,24 @@
import type { ActionItem } from './types'
// Create the workflow nodes action
export const workflowNodesAction: ActionItem = {
key: '@node',
shortcut: '@node',
title: 'Search Workflow Nodes',
description: 'Find and jump to nodes in the current workflow by name or type',
searchFn: undefined, // Will be set by useWorkflowSearch hook
search: async (_, searchTerm = '', locale) => {
try {
// Use the searchFn if available (set by useWorkflowSearch hook)
if (workflowNodesAction.searchFn)
return workflowNodesAction.searchFn(searchTerm)
// If not in workflow context, return empty array
return []
}
catch (error) {
console.warn('Workflow nodes search failed:', error)
return []
}
},
}

View File

@ -0,0 +1,89 @@
import type { FC } from 'react'
import { useEffect } from 'react'
import { Command } from 'cmdk'
import { useTranslation } from 'react-i18next'
import type { ActionItem } from './actions/types'
type Props = {
actions: Record<string, ActionItem>
onCommandSelect: (commandKey: string) => void
searchFilter?: string
commandValue?: string
onCommandValueChange?: (value: string) => void
}
const CommandSelector: FC<Props> = ({ actions, onCommandSelect, searchFilter, commandValue, onCommandValueChange }) => {
const { t } = useTranslation()
const filteredActions = Object.values(actions).filter((action) => {
if (!searchFilter)
return true
const filterLower = searchFilter.toLowerCase()
return action.shortcut.toLowerCase().includes(filterLower)
|| action.key.toLowerCase().includes(filterLower)
})
useEffect(() => {
if (filteredActions.length > 0 && onCommandValueChange) {
const currentValueExists = filteredActions.some(action => action.shortcut === commandValue)
if (!currentValueExists)
onCommandValueChange(filteredActions[0].shortcut)
}
}, [searchFilter, filteredActions.length])
if (filteredActions.length === 0) {
return (
<div className="p-4">
<div className="flex items-center justify-center py-8 text-center text-text-tertiary">
<div>
<div className="text-sm font-medium text-text-tertiary">
{t('app.gotoAnything.noMatchingCommands')}
</div>
<div className="mt-1 text-xs text-text-quaternary">
{t('app.gotoAnything.tryDifferentSearch')}
</div>
</div>
</div>
</div>
)
}
return (
<div className="p-4">
<div className="mb-3 text-left text-sm font-medium text-text-secondary">
{t('app.gotoAnything.selectSearchType')}
</div>
<Command.Group className="space-y-1">
{filteredActions.map(action => (
<Command.Item
key={action.key}
value={action.shortcut}
className="flex cursor-pointer items-center rounded-md
p-2.5
transition-all
duration-150 hover:bg-state-base-hover-alt aria-[selected=true]:bg-state-base-hover"
onSelect={() => onCommandSelect(action.shortcut)}
>
<span className="min-w-[4.5rem] text-left font-mono text-xs text-text-tertiary">
{action.shortcut}
</span>
<span className="ml-3 text-sm text-text-secondary">
{(() => {
const keyMap: Record<string, string> = {
'@app': 'app.gotoAnything.actions.searchApplicationsDesc',
'@plugin': 'app.gotoAnything.actions.searchPluginsDesc',
'@knowledge': 'app.gotoAnything.actions.searchKnowledgeBasesDesc',
'@run': 'app.gotoAnything.actions.runDesc',
'@node': 'app.gotoAnything.actions.searchWorkflowNodesDesc',
}
return t(keyMap[action.key])
})()}
</span>
</Command.Item>
))}
</Command.Group>
</div>
)
}
export default CommandSelector

View File

@ -0,0 +1,50 @@
'use client'
import type { ReactNode } from 'react'
import React, { createContext, useContext, useEffect, useState } from 'react'
import { usePathname } from 'next/navigation'
/**
* Interface for the GotoAnything context
*/
type GotoAnythingContextType = {
/**
* Whether the current page is a workflow page
*/
isWorkflowPage: boolean
}
// Create context with default values
const GotoAnythingContext = createContext<GotoAnythingContextType>({
isWorkflowPage: false,
})
/**
* Hook to use the GotoAnything context
*/
export const useGotoAnythingContext = () => useContext(GotoAnythingContext)
type GotoAnythingProviderProps = {
children: ReactNode
}
/**
* Provider component for GotoAnything context
*/
export const GotoAnythingProvider: React.FC<GotoAnythingProviderProps> = ({ children }) => {
const [isWorkflowPage, setIsWorkflowPage] = useState(false)
const pathname = usePathname()
// Update context based on current pathname
useEffect(() => {
// Check if current path contains workflow
const isWorkflow = pathname?.includes('/workflow') || false
setIsWorkflowPage(isWorkflow)
}, [pathname])
return (
<GotoAnythingContext.Provider value={{ isWorkflowPage }}>
{children}
</GotoAnythingContext.Provider>
)
}

View File

@ -0,0 +1,420 @@
'use client'
import type { FC } from 'react'
import { useCallback, useEffect, useMemo, useRef, useState } from 'react'
import { useRouter } from 'next/navigation'
import Modal from '@/app/components/base/modal'
import Input from '@/app/components/base/input'
import { useDebounce, useKeyPress } from 'ahooks'
import { getKeyboardKeyCodeBySystem, isEventTargetInputArea, isMac } from '@/app/components/workflow/utils/common'
import { selectWorkflowNode } from '@/app/components/workflow/utils/node-navigation'
import { RiSearchLine } from '@remixicon/react'
import { Actions as AllActions, type SearchResult, matchAction, searchAnything } from './actions'
import { GotoAnythingProvider, useGotoAnythingContext } from './context'
import { useQuery } from '@tanstack/react-query'
import { useGetLanguage } from '@/context/i18n'
import { useTranslation } from 'react-i18next'
import InstallFromMarketplace from '../plugins/install-plugin/install-from-marketplace'
import type { Plugin } from '../plugins/types'
import { Command } from 'cmdk'
import CommandSelector from './command-selector'
import { RunCommandProvider } from './actions/run'
type Props = {
onHide?: () => void
}
const GotoAnything: FC<Props> = ({
onHide,
}) => {
const router = useRouter()
const defaultLocale = useGetLanguage()
const { isWorkflowPage } = useGotoAnythingContext()
const { t } = useTranslation()
const [show, setShow] = useState<boolean>(false)
const [searchQuery, setSearchQuery] = useState<string>('')
const [cmdVal, setCmdVal] = useState<string>('')
const inputRef = useRef<HTMLInputElement>(null)
const handleNavSearch = useCallback((q: string) => {
setShow(true)
setSearchQuery(q)
requestAnimationFrame(() => inputRef.current?.focus())
}, [])
// Filter actions based on context
const Actions = useMemo(() => {
// Create a filtered copy of actions based on current page context
if (isWorkflowPage) {
// Include all actions on workflow pages
return AllActions
}
else {
// Exclude node action on non-workflow pages
const { app, knowledge, plugin, run } = AllActions
return { app, knowledge, plugin, run }
}
}, [isWorkflowPage])
const [activePlugin, setActivePlugin] = useState<Plugin>()
// Handle keyboard shortcuts
const handleToggleModal = useCallback((e: KeyboardEvent) => {
// Allow closing when modal is open, even if focus is in the search input
if (!show && isEventTargetInputArea(e.target as HTMLElement))
return
e.preventDefault()
setShow((prev) => {
if (!prev) {
// Opening modal - reset search state
setSearchQuery('')
}
return !prev
})
}, [show])
useKeyPress(`${getKeyboardKeyCodeBySystem('ctrl')}.k`, handleToggleModal, {
exactMatch: true,
useCapture: true,
})
useKeyPress(['esc'], (e) => {
if (show) {
e.preventDefault()
setShow(false)
setSearchQuery('')
}
})
const searchQueryDebouncedValue = useDebounce(searchQuery.trim(), {
wait: 300,
})
const isCommandsMode = searchQuery.trim() === '@' || (searchQuery.trim().startsWith('@') && !matchAction(searchQuery.trim(), Actions))
const searchMode = useMemo(() => {
if (isCommandsMode) return 'commands'
const query = searchQueryDebouncedValue.toLowerCase()
const action = matchAction(query, Actions)
return action ? action.key : 'general'
}, [searchQueryDebouncedValue, Actions, isCommandsMode])
const { data: searchResults = [], isLoading, isError, error } = useQuery(
{
queryKey: [
'goto-anything',
'search-result',
searchQueryDebouncedValue,
searchMode,
isWorkflowPage,
defaultLocale,
Object.keys(Actions).sort().join(','),
],
queryFn: async () => {
const query = searchQueryDebouncedValue.toLowerCase()
const action = matchAction(query, Actions)
return await searchAnything(defaultLocale, query, action)
},
enabled: !!searchQueryDebouncedValue && !isCommandsMode,
staleTime: 30000,
gcTime: 300000,
},
)
const handleCommandSelect = useCallback((commandKey: string) => {
setSearchQuery(`${commandKey} `)
setCmdVal('')
setTimeout(() => {
inputRef.current?.focus()
}, 0)
}, [])
// Handle navigation to selected result
const handleNavigate = useCallback((result: SearchResult) => {
setShow(false)
setSearchQuery('')
switch (result.type) {
case 'command': {
const action = Object.values(Actions).find(a => a.key === '@run')
action?.action?.(result)
break
}
case 'plugin':
setActivePlugin(result.data)
break
case 'workflow-node':
// Handle workflow node selection and navigation
if (result.metadata?.nodeId)
selectWorkflowNode(result.metadata.nodeId, true)
break
default:
if (result.path)
router.push(result.path)
}
}, [router])
// Group results by type
const groupedResults = useMemo(() => searchResults.reduce((acc, result) => {
if (!acc[result.type])
acc[result.type] = []
acc[result.type].push(result)
return acc
}, {} as { [key: string]: SearchResult[] }),
[searchResults])
const emptyResult = useMemo(() => {
if (searchResults.length || !searchQuery.trim() || isLoading || isCommandsMode)
return null
const isCommandSearch = searchMode !== 'general'
const commandType = isCommandSearch ? searchMode.replace('@', '') : ''
if (isError) {
return (
<div className="flex items-center justify-center py-8 text-center text-text-tertiary">
<div>
<div className='text-sm font-medium text-red-500'>{t('app.gotoAnything.searchTemporarilyUnavailable')}</div>
<div className='mt-1 text-xs text-text-quaternary'>
{t('app.gotoAnything.servicesUnavailableMessage')}
</div>
</div>
</div>
)
}
return (
<div className="flex items-center justify-center py-8 text-center text-text-tertiary">
<div>
<div className='text-sm font-medium'>
{isCommandSearch
? (() => {
const keyMap: Record<string, string> = {
app: 'app.gotoAnything.emptyState.noAppsFound',
plugin: 'app.gotoAnything.emptyState.noPluginsFound',
knowledge: 'app.gotoAnything.emptyState.noKnowledgeBasesFound',
node: 'app.gotoAnything.emptyState.noWorkflowNodesFound',
}
return t(keyMap[commandType] || 'app.gotoAnything.noResults')
})()
: t('app.gotoAnything.noResults')
}
</div>
<div className='mt-1 text-xs text-text-quaternary'>
{isCommandSearch
? t('app.gotoAnything.emptyState.tryDifferentTerm', { mode: searchMode })
: t('app.gotoAnything.emptyState.trySpecificSearch', { shortcuts: Object.values(Actions).map(action => action.shortcut).join(', ') })
}
</div>
</div>
</div>
)
}, [searchResults, searchQuery, Actions, searchMode, isLoading, isError, isCommandsMode])
const defaultUI = useMemo(() => {
if (searchQuery.trim())
return null
return (<div className="flex items-center justify-center py-12 text-center text-text-tertiary">
<div>
<div className='text-sm font-medium'>{t('app.gotoAnything.searchTitle')}</div>
<div className='mt-3 space-y-1 text-xs text-text-quaternary'>
<div>{t('app.gotoAnything.searchHint')}</div>
<div>{t('app.gotoAnything.commandHint')}</div>
</div>
</div>
</div>)
}, [searchQuery, Actions])
useEffect(() => {
if (show) {
requestAnimationFrame(() => {
inputRef.current?.focus()
})
}
return () => {
setCmdVal('')
}
}, [show])
return (
<>
<Modal
isShow={show}
onClose={() => {
setShow(false)
setSearchQuery('')
onHide?.()
}}
closable={false}
className='!w-[480px] !p-0'
highPriority={true}
>
<div className='flex flex-col rounded-2xl border border-components-panel-border bg-components-panel-bg shadow-xl'>
<Command
className='outline-none'
value={cmdVal}
onValueChange={setCmdVal}
disablePointerSelection
>
<div className='flex items-center gap-3 border-b border-divider-subtle bg-components-panel-bg-blur px-4 py-3'>
<RiSearchLine className='h-4 w-4 text-text-quaternary' />
<div className='flex flex-1 items-center gap-2'>
<Input
ref={inputRef}
value={searchQuery}
placeholder={t('app.gotoAnything.searchPlaceholder')}
onChange={(e) => {
setSearchQuery(e.target.value)
if (!e.target.value.startsWith('@'))
setCmdVal('')
}}
className='flex-1 !border-0 !bg-transparent !shadow-none'
wrapperClassName='flex-1 !border-0 !bg-transparent'
autoFocus
/>
{searchMode !== 'general' && (
<div className='flex items-center gap-1 rounded bg-blue-50 px-2 py-[2px] text-xs font-medium text-blue-600 dark:bg-blue-900/40 dark:text-blue-300'>
<span>{searchMode.replace('@', '').toUpperCase()}</span>
</div>
)}
</div>
<div className='text-xs text-text-quaternary'>
<span className='system-kbd rounded bg-gray-200 px-1 py-[2px] font-mono text-gray-700 dark:bg-gray-800 dark:text-gray-100'>
{isMac() ? '⌘' : 'Ctrl'}
</span>
<span className='system-kbd ml-1 rounded bg-gray-200 px-1 py-[2px] font-mono text-gray-700 dark:bg-gray-800 dark:text-gray-100'>
K
</span>
</div>
</div>
<Command.List className='max-h-[275px] min-h-[240px] overflow-y-auto'>
{isLoading && (
<div className="flex items-center justify-center py-8 text-center text-text-tertiary">
<div className="flex items-center gap-2">
<div className="h-4 w-4 animate-spin rounded-full border-2 border-gray-300 border-t-gray-600"></div>
<span className="text-sm">{t('app.gotoAnything.searching')}</span>
</div>
</div>
)}
{isError && (
<div className="flex items-center justify-center py-8 text-center text-text-tertiary">
<div>
<div className="text-sm font-medium text-red-500">{t('app.gotoAnything.searchFailed')}</div>
<div className="mt-1 text-xs text-text-quaternary">
{error.message}
</div>
</div>
</div>
)}
{!isLoading && !isError && (
<>
{isCommandsMode ? (
<CommandSelector
actions={Actions}
onCommandSelect={handleCommandSelect}
searchFilter={searchQuery.trim().substring(1)}
commandValue={cmdVal}
onCommandValueChange={setCmdVal}
/>
) : (
Object.entries(groupedResults).map(([type, results], groupIndex) => (
<Command.Group key={groupIndex} heading={(() => {
const typeMap: Record<string, string> = {
'app': 'app.gotoAnything.groups.apps',
'plugin': 'app.gotoAnything.groups.plugins',
'knowledge': 'app.gotoAnything.groups.knowledgeBases',
'workflow-node': 'app.gotoAnything.groups.workflowNodes',
}
return t(typeMap[type] || `${type}s`)
})()} className='p-2 capitalize text-text-secondary'>
{results.map(result => (
<Command.Item
key={`${result.type}-${result.id}`}
value={result.title}
className='flex cursor-pointer items-center gap-3 rounded-md p-3 will-change-[background-color] hover:bg-state-base-hover-alt aria-[selected=true]:bg-state-base-hover data-[selected=true]:bg-state-base-hover'
onSelect={() => handleNavigate(result)}
>
{result.icon}
<div className='min-w-0 flex-1'>
<div className='truncate font-medium text-text-secondary'>
{result.title}
</div>
{result.description && (
<div className='mt-0.5 truncate text-xs text-text-quaternary'>
{result.description}
</div>
)}
</div>
<div className='text-xs capitalize text-text-quaternary'>
{result.type}
</div>
</Command.Item>
))}
</Command.Group>
))
)}
{!isCommandsMode && emptyResult}
{!isCommandsMode && defaultUI}
</>
)}
</Command.List>
{(!!searchResults.length || isError) && (
<div className='border-t border-divider-subtle bg-components-panel-bg-blur px-4 py-2 text-xs text-text-tertiary'>
<div className='flex items-center justify-between'>
<span>
{isError ? (
<span className='text-red-500'>{t('app.gotoAnything.someServicesUnavailable')}</span>
) : (
<>
{t('app.gotoAnything.resultCount', { count: searchResults.length })}
{searchMode !== 'general' && (
<span className='ml-2 opacity-60'>
{t('app.gotoAnything.inScope', { scope: searchMode.replace('@', '') })}
</span>
)}
</>
)}
</span>
<span className='opacity-60'>
{searchMode !== 'general'
? t('app.gotoAnything.clearToSearchAll')
: t('app.gotoAnything.useAtForSpecific')
}
</span>
</div>
</div>
)}
</Command>
</div>
</Modal>
<RunCommandProvider onNavSearch={handleNavSearch} />
{
activePlugin && (
<InstallFromMarketplace
manifest={activePlugin}
uniqueIdentifier={activePlugin.latest_package_identifier}
onClose={() => setActivePlugin(undefined)}
onSuccess={() => setActivePlugin(undefined)}
/>
)
}
</>
)
}
/**
* GotoAnything component with context provider
*/
const GotoAnythingWithContext: FC<Props> = (props) => {
return (
<GotoAnythingProvider>
<GotoAnything {...props} />
</GotoAnythingProvider>
)
}
export default GotoAnythingWithContext

View File

@ -70,6 +70,7 @@ export default function LanguagePage() {
items={languages.filter(item => item.supported)}
onSelect={item => handleSelectLanguage(item)}
disabled={editing}
notClearable={true}
/>
</div>
<div className='mb-8'>
@ -79,6 +80,7 @@ export default function LanguagePage() {
items={timezones}
onSelect={item => handleSelectTimezone(item)}
disabled={editing}
notClearable={true}
/>
</div>
</>

View File

@ -1,6 +1,7 @@
import type { FC } from 'react'
import { useEffect, useRef, useState } from 'react'
import type { ModelParameterRule } from '../declarations'
import { useLanguage } from '../hooks'
import { isNullOrUndefined } from '../utils'
import cn from '@/utils/classnames'
import Switch from '@/app/components/base/switch'
@ -26,6 +27,7 @@ const ParameterItem: FC<ParameterItemProps> = ({
onSwitch,
isInWorkflow,
}) => {
const language = useLanguage()
const [localValue, setLocalValue] = useState(value)
const numberInputRef = useRef<HTMLInputElement>(null)

View File

@ -54,7 +54,7 @@ const TagsFilter = ({
onClick={() => setOpen(v => !v)}
>
<div className={cn(
'ml-0.5 mr-1.5 flex items-center text-text-tertiary ',
'ml-0.5 mr-1.5 flex select-none items-center text-text-tertiary',
size === 'large' && 'h-8 py-1',
size === 'small' && 'h-7 py-0.5 ',
// selectedTagsLength && 'text-text-secondary',
@ -80,7 +80,7 @@ const TagsFilter = ({
filteredOptions.map(option => (
<div
key={option.name}
className='flex h-7 cursor-pointer items-center rounded-lg px-2 py-1.5 hover:bg-state-base-hover'
className='flex h-7 cursor-pointer select-none items-center rounded-lg px-2 py-1.5 hover:bg-state-base-hover'
onClick={() => handleCheck(option.name)}
>
<Checkbox

View File

@ -48,7 +48,7 @@ const TagsFilter = ({
>
<PortalToFollowElemTrigger onClick={() => setOpen(v => !v)}>
<div className={cn(
'flex h-8 cursor-pointer items-center rounded-lg bg-components-input-bg-normal px-2 py-1 text-text-tertiary hover:bg-state-base-hover-alt',
'flex h-8 cursor-pointer select-none items-center rounded-lg bg-components-input-bg-normal px-2 py-1 text-text-tertiary hover:bg-state-base-hover-alt',
selectedTagsLength && 'text-text-secondary',
open && 'bg-state-base-hover',
)}>
@ -99,7 +99,7 @@ const TagsFilter = ({
filteredOptions.map(option => (
<div
key={option.name}
className='flex h-7 cursor-pointer items-center rounded-lg px-2 py-1.5 hover:bg-state-base-hover'
className='flex h-7 cursor-pointer select-none items-center rounded-lg px-2 py-1.5 hover:bg-state-base-hover'
onClick={() => handleCheck(option.name)}
>
<Checkbox

View File

@ -67,7 +67,7 @@ const LabelFilter: FC<LabelFilterProps> = ({
className='block'
>
<div className={cn(
'flex h-8 cursor-pointer items-center gap-1 rounded-lg border-[0.5px] border-transparent bg-components-input-bg-normal px-2 hover:bg-components-input-bg-hover',
'flex h-8 cursor-pointer select-none items-center gap-1 rounded-lg border-[0.5px] border-transparent bg-components-input-bg-normal px-2 hover:bg-components-input-bg-hover',
!open && !!value.length && 'shadow-xs',
open && !!value.length && 'shadow-xs',
)}>
@ -111,7 +111,7 @@ const LabelFilter: FC<LabelFilterProps> = ({
{filteredLabelList.map(label => (
<div
key={label.name}
className='flex cursor-pointer items-center gap-2 rounded-lg py-[6px] pl-3 pr-2 hover:bg-state-base-hover'
className='flex cursor-pointer select-none items-center gap-2 rounded-lg py-[6px] pl-3 pr-2 hover:bg-state-base-hover'
onClick={() => selectLabel(label)}
>
<div title={label.label} className='grow truncate text-sm leading-5 text-text-secondary'>{label.label}</div>

View File

@ -27,6 +27,8 @@ export type DuplicateAppModalProps = {
icon: string
icon_background?: string | null
server_identifier: string
timeout: number
sse_read_timeout: number
}) => void
onHide: () => void
}
@ -64,6 +66,8 @@ const MCPModal = ({
const [appIcon, setAppIcon] = useState<AppIconSelection>(getIcon(data))
const [showAppIconPicker, setShowAppIconPicker] = useState(false)
const [serverIdentifier, setServerIdentifier] = React.useState(data?.server_identifier || '')
const [timeout, setMcpTimeout] = React.useState(30)
const [sseReadTimeout, setSseReadTimeout] = React.useState(300)
const [isFetchingIcon, setIsFetchingIcon] = useState(false)
const appIconRef = useRef<HTMLDivElement>(null)
const isHovering = useHover(appIconRef)
@ -73,7 +77,7 @@ const MCPModal = ({
const urlPattern = /^(https?:\/\/)((([a-z\d]([a-z\d-]*[a-z\d])*)\.)+[a-z]{2,}|((\d{1,3}\.){3}\d{1,3})|localhost)(\:\d+)?(\/[-a-z\d%_.~+]*)*(\?[;&a-z\d%_.~+=-]*)?/i
return urlPattern.test(string)
}
catch (e) {
catch {
return false
}
}
@ -123,6 +127,8 @@ const MCPModal = ({
icon: appIcon.type === 'emoji' ? appIcon.icon : appIcon.fileId,
icon_background: appIcon.type === 'emoji' ? appIcon.background : undefined,
server_identifier: serverIdentifier.trim(),
timeout: timeout || 30,
sse_read_timeout: sseReadTimeout || 300,
})
if(isCreate)
onHide()
@ -201,6 +207,30 @@ const MCPModal = ({
</div>
)}
</div>
<div>
<div className='mb-1 flex h-6 items-center'>
<span className='system-sm-medium text-text-secondary'>{t('tools.mcp.modal.timeout')}</span>
</div>
<Input
type='number'
value={timeout}
onChange={e => setMcpTimeout(Number(e.target.value))}
onBlur={e => handleBlur(e.target.value.trim())}
placeholder={t('tools.mcp.modal.timeoutPlaceholder')}
/>
</div>
<div>
<div className='mb-1 flex h-6 items-center'>
<span className='system-sm-medium text-text-secondary'>{t('tools.mcp.modal.sseReadTimeout')}</span>
</div>
<Input
type='number'
value={sseReadTimeout}
onChange={e => setSseReadTimeout(Number(e.target.value))}
onBlur={e => handleBlur(e.target.value.trim())}
placeholder={t('tools.mcp.modal.timeoutPlaceholder')}
/>
</div>
</div>
<div className='flex flex-row-reverse pt-5'>
<Button disabled={!name || !url || !serverIdentifier || isFetchingIcon} className='ml-2' variant='primary' onClick={submit}>{data ? t('tools.mcp.modal.save') : t('tools.mcp.modal.confirm')}</Button>

View File

@ -57,6 +57,8 @@ export type Collection = {
server_url?: string
updated_at?: number
server_identifier?: string
timeout?: number
sse_read_timeout?: number
}
export type ToolParameter = {

View File

@ -18,3 +18,4 @@ export * from './use-workflow-mode'
export * from './use-workflow-refresh-draft'
export * from './use-inspect-vars-crud'
export * from './use-set-workflow-vars-with-value'
export * from './use-workflow-search'

Some files were not shown because too many files have changed in this diff Show More