mirror of https://github.com/langgenius/dify.git

commit e69797d738
Merge branch 'main' into feat/model-auth
@@ -5,7 +5,7 @@ cd web && pnpm install
 pipx install uv
 
 echo 'alias start-api="cd /workspaces/dify/api && uv run python -m flask run --host 0.0.0.0 --port=5001 --debug"' >> ~/.bashrc
-echo 'alias start-worker="cd /workspaces/dify/api && uv run python -m celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion"' >> ~/.bashrc
+echo 'alias start-worker="cd /workspaces/dify/api && uv run python -m celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage"' >> ~/.bashrc
 echo 'alias start-web="cd /workspaces/dify/web && pnpm dev"' >> ~/.bashrc
 echo 'alias start-web-prod="cd /workspaces/dify/web && pnpm build && pnpm start"' >> ~/.bashrc
 echo 'alias start-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env up -d"' >> ~/.bashrc
@@ -8,7 +8,7 @@ inputs:
   uv-version:
     description: UV version to set up
     required: true
-    default: '~=0.7.11'
+    default: '0.8.9'
   uv-lockfile:
     description: Path to the UV lockfile to restore cache from
     required: true
@@ -26,7 +26,7 @@ runs:
         python-version: ${{ inputs.python-version }}
 
     - name: Install uv
-      uses: astral-sh/setup-uv@v5
+      uses: astral-sh/setup-uv@v6
      with:
        version: ${{ inputs.uv-version }}
        python-version: ${{ inputs.python-version }}
@@ -4,7 +4,7 @@ FROM python:3.12-slim-bookworm AS base
 WORKDIR /app/api
 
 # Install uv
-ENV UV_VERSION=0.7.11
+ENV UV_VERSION=0.8.9
 
 RUN pip install --no-cache-dir uv==${UV_VERSION}
 
@@ -74,7 +74,7 @@
 10. If you need to handle and debug the async tasks (e.g. dataset importing and documents indexing), please start the worker service.
 
    ```bash
-   uv run celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin
+   uv run celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage
    ```
 
    In addition, if you want to debug the celery scheduled tasks, you can use the following command in another terminal:
@@ -552,12 +552,18 @@ class RepositoryConfig(BaseSettings):
     """
 
     CORE_WORKFLOW_EXECUTION_REPOSITORY: str = Field(
-        description="Repository implementation for WorkflowExecution. Specify as a module path",
+        description="Repository implementation for WorkflowExecution. Options: "
+        "'core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository' (default), "
+        "'core.repositories.celery_workflow_execution_repository.CeleryWorkflowExecutionRepository'",
         default="core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository",
     )
 
     CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY: str = Field(
-        description="Repository implementation for WorkflowNodeExecution. Specify as a module path",
+        description="Repository implementation for WorkflowNodeExecution. Options: "
+        "'core.repositories.sqlalchemy_workflow_node_execution_repository."
+        "SQLAlchemyWorkflowNodeExecutionRepository' (default), "
+        "'core.repositories.celery_workflow_node_execution_repository."
+        "CeleryWorkflowNodeExecutionRepository'",
         default="core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository",
     )
 
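Both settings take a dotted module path, so a deployment can switch storage backends without code changes. A minimal sketch (not the project's actual loader; the helper name is hypothetical) of how such a path is typically resolved at runtime:

```python
import importlib


def import_repository_class(dotted_path: str) -> type:
    """Resolve a 'package.module.ClassName' setting to the class it names (hypothetical helper)."""
    module_path, _, class_name = dotted_path.rpartition(".")
    module = importlib.import_module(module_path)
    return getattr(module, class_name)


# Example: opt in to the Celery-backed implementation via configuration.
# CORE_WORKFLOW_EXECUTION_REPOSITORY=core.repositories.celery_workflow_execution_repository.CeleryWorkflowExecutionRepository
```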
@@ -144,7 +144,8 @@ class DatabaseConfig(BaseSettings):
         default="postgresql",
     )
 
-    @computed_field
+    @computed_field  # type: ignore[misc]
+    @property
     def SQLALCHEMY_DATABASE_URI(self) -> str:
         db_extras = (
             f"{self.DB_EXTRAS}&client_encoding={self.DB_CHARSET}" if self.DB_CHARSET else self.DB_EXTRAS
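The added `@property` makes the decorator stack explicit, and the `# type: ignore[misc]` silences mypy's complaint about applying `computed_field` to a property. A self-contained sketch of the same pattern with simplified fields (not the real settings class):

```python
from pydantic import BaseModel, computed_field


class DatabaseConfigSketch(BaseModel):
    DB_HOST: str = "localhost"
    DB_PORT: int = 5432

    @computed_field  # type: ignore[misc]
    @property
    def SQLALCHEMY_DATABASE_URI(self) -> str:
        # Exposed as a read-only field included in model dumps and schema.
        return f"postgresql://{self.DB_HOST}:{self.DB_PORT}/dify"


print(DatabaseConfigSketch().model_dump()["SQLALCHEMY_DATABASE_URI"])
```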
@@ -862,6 +862,10 @@ class ToolProviderMCPApi(Resource):
         parser.add_argument("icon_type", type=str, required=True, nullable=False, location="json")
         parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json", default="")
         parser.add_argument("server_identifier", type=str, required=True, nullable=False, location="json")
+        parser.add_argument("timeout", type=float, required=False, nullable=False, location="json", default=30)
+        parser.add_argument(
+            "sse_read_timeout", type=float, required=False, nullable=False, location="json", default=300
+        )
         args = parser.parse_args()
         user = current_user
         if not is_valid_url(args["server_url"]):
@@ -876,6 +880,8 @@ class ToolProviderMCPApi(Resource):
                 icon_background=args["icon_background"],
                 user_id=user.id,
                 server_identifier=args["server_identifier"],
+                timeout=args["timeout"],
+                sse_read_timeout=args["sse_read_timeout"],
             )
         )
 
@@ -891,6 +897,8 @@ class ToolProviderMCPApi(Resource):
         parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json")
         parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json")
         parser.add_argument("server_identifier", type=str, required=True, nullable=False, location="json")
+        parser.add_argument("timeout", type=float, required=False, nullable=True, location="json")
+        parser.add_argument("sse_read_timeout", type=float, required=False, nullable=True, location="json")
         args = parser.parse_args()
         if not is_valid_url(args["server_url"]):
             if "[__HIDDEN__]" in args["server_url"]:
@@ -906,6 +914,8 @@ class ToolProviderMCPApi(Resource):
             icon_type=args["icon_type"],
             icon_background=args["icon_background"],
             server_identifier=args["server_identifier"],
+            timeout=args.get("timeout"),
+            sse_read_timeout=args.get("sse_read_timeout"),
         )
         return {"result": "success"}
 
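For illustration, a hedged client-side sketch of creating an MCP provider with the new fields; only the `timeout` and `sse_read_timeout` JSON fields come from the hunks above, while the route and auth header are placeholders that depend on the deployment:

```python
import httpx

payload = {
    "server_url": "https://mcp.example.com/sse",
    "name": "my-mcp-server",
    "icon": "tool",
    "icon_type": "emoji",
    "icon_background": "",
    "server_identifier": "my-mcp-server",
    "timeout": 30,            # seconds per HTTP request (server-side default: 30)
    "sse_read_timeout": 300,  # seconds to wait for the next SSE event (server-side default: 300)
}

response = httpx.post(
    "http://localhost:5001/console/api/workspaces/current/tool-provider/mcp",  # placeholder route
    json=payload,
    headers={"Authorization": "Bearer <console-token>"},  # placeholder credential
)
response.raise_for_status()
```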
@@ -568,7 +568,7 @@ class AdvancedChatAppGenerateTaskPipeline:
         )
 
         yield workflow_finish_resp
-        self._base_task_pipeline._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
+        self._base_task_pipeline.queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
 
     def _handle_workflow_partial_success_event(
         self,
@@ -600,7 +600,7 @@ class AdvancedChatAppGenerateTaskPipeline:
         )
 
         yield workflow_finish_resp
-        self._base_task_pipeline._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
+        self._base_task_pipeline.queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
 
     def _handle_workflow_failed_event(
         self,
@@ -845,7 +845,7 @@ class AdvancedChatAppGenerateTaskPipeline:
         # Initialize graph runtime state
         graph_runtime_state: Optional[GraphRuntimeState] = None
 
-        for queue_message in self._base_task_pipeline._queue_manager.listen():
+        for queue_message in self._base_task_pipeline.queue_manager.listen():
             event = queue_message.event
 
             match event:
@@ -959,11 +959,11 @@ class AdvancedChatAppGenerateTaskPipeline:
         if self._base_task_pipeline._output_moderation_handler:
             if self._base_task_pipeline._output_moderation_handler.should_direct_output():
                 self._task_state.answer = self._base_task_pipeline._output_moderation_handler.get_final_output()
-                self._base_task_pipeline._queue_manager.publish(
+                self._base_task_pipeline.queue_manager.publish(
                     QueueTextChunkEvent(text=self._task_state.answer), PublishFrom.TASK_PIPELINE
                 )
 
-            self._base_task_pipeline._queue_manager.publish(
+            self._base_task_pipeline.queue_manager.publish(
                 QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION), PublishFrom.TASK_PIPELINE
             )
             return True
@@ -711,7 +711,7 @@ class WorkflowAppGenerateTaskPipeline:
         # Initialize graph runtime state
         graph_runtime_state = None
 
-        for queue_message in self._base_task_pipeline._queue_manager.listen():
+        for queue_message in self._base_task_pipeline.queue_manager.listen():
             event = queue_message.event
 
             match event:
@@ -9,7 +9,6 @@ from core.app.app_config.entities import EasyUIBasedAppConfig, WorkflowUIBasedAp
 from core.entities.provider_configuration import ProviderModelBundle
 from core.file import File, FileUploadConfig
 from core.model_runtime.entities.model_entities import AIModelEntity
-from core.ops.ops_trace_manager import TraceQueueManager
 
 
 class InvokeFrom(Enum):
@@ -114,7 +113,8 @@ class AppGenerateEntity(BaseModel):
     extras: dict[str, Any] = Field(default_factory=dict)
 
     # tracing instance
-    trace_manager: Optional[TraceQueueManager] = None
+    # Using Any to avoid circular import with TraceQueueManager
+    trace_manager: Optional[Any] = None
 
 
 class EasyUIBasedAppGenerateEntity(AppGenerateEntity):
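Typing `trace_manager` as `Any` removes the runtime import that created the cycle. Another common way to keep the precise type for static analysis only is a `TYPE_CHECKING` guard; this sketch uses a plain dataclass rather than the project's pydantic entity, since pydantic would additionally need the forward reference resolved at model-build time:

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    # Evaluated only by type checkers, so no runtime import cycle is created.
    from core.ops.ops_trace_manager import TraceQueueManager


@dataclass
class TraceAwareEntitySketch:
    trace_manager: Optional[TraceQueueManager] = None
```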
@@ -37,7 +37,7 @@ class BasedGenerateTaskPipeline:
         stream: bool,
     ) -> None:
         self._application_generate_entity = application_generate_entity
-        self._queue_manager = queue_manager
+        self.queue_manager = queue_manager
         self._start_at = time.perf_counter()
         self._output_moderation_handler = self._init_output_moderation()
         self._stream = stream
@@ -113,7 +113,7 @@ class BasedGenerateTaskPipeline:
                 tenant_id=app_config.tenant_id,
                 app_id=app_config.app_id,
                 rule=ModerationRule(type=sensitive_word_avoidance.type, config=sensitive_word_avoidance.config),
-                queue_manager=self._queue_manager,
+                queue_manager=self.queue_manager,
             )
         return None
 
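The rename from `_queue_manager` to `queue_manager` matches how the attribute is actually used: the workflow pipelines above reach it through a composed `BasedGenerateTaskPipeline` instance, so it is effectively public API rather than an internal. A minimal sketch of that shape (class and method names simplified, not the real signatures):

```python
class BasedGenerateTaskPipelineSketch:
    def __init__(self, queue_manager) -> None:
        self.queue_manager = queue_manager  # public: read by composed pipelines


class AdvancedChatPipelineSketch:
    def __init__(self, base_task_pipeline: BasedGenerateTaskPipelineSketch) -> None:
        self._base_task_pipeline = base_task_pipeline

    def finish(self, event) -> None:
        # Cross-object access now targets a public attribute instead of _queue_manager.
        self._base_task_pipeline.queue_manager.publish(event)
```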
@@ -257,7 +257,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
         Process stream response.
         :return:
         """
-        for message in self._queue_manager.listen():
+        for message in self.queue_manager.listen():
             if publisher:
                 publisher.publish(message)
             event = message.event
@@ -499,7 +499,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
            if self._output_moderation_handler.should_direct_output():
                # stop subscribe new token when output moderation should direct output
                self._task_state.llm_result.message.content = self._output_moderation_handler.get_final_output()
-                self._queue_manager.publish(
+                self.queue_manager.publish(
                    QueueLLMChunkEvent(
                        chunk=LLMResultChunk(
                            model=self._task_state.llm_result.model,
@@ -513,7 +513,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
                    PublishFrom.TASK_PIPELINE,
                )
 
-                self._queue_manager.publish(
+                self.queue_manager.publish(
                    QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION), PublishFrom.TASK_PIPELINE
                )
                return True
@@ -327,7 +327,7 @@ def send_message(http_client: httpx.Client, endpoint_url: str, session_message:
        )
        response.raise_for_status()
        logger.debug("Client message sent successfully: %s", response.status_code)
-    except Exception as exc:
+    except Exception:
        logger.exception("Error sending message")
        raise
 
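`logging.exception` already records the active exception and its traceback, so the unused `as exc` binding can be dropped, which is exactly what this hunk does:

```python
import logging

logger = logging.getLogger(__name__)

try:
    1 / 0
except Exception:  # no need to bind the exception when it is not used
    logger.exception("Error sending message")  # logs the message plus the traceback
    raise
```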
@ -55,14 +55,10 @@ DEFAULT_QUEUE_READ_TIMEOUT = 3
|
|||
class StreamableHTTPError(Exception):
|
||||
"""Base exception for StreamableHTTP transport errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ResumptionError(StreamableHTTPError):
|
||||
"""Raised when resumption request is invalid."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequestContext:
|
||||
|
|
@ -74,7 +70,7 @@ class RequestContext:
|
|||
session_message: SessionMessage
|
||||
metadata: ClientMessageMetadata | None
|
||||
server_to_client_queue: ServerToClientQueue # Renamed for clarity
|
||||
sse_read_timeout: timedelta
|
||||
sse_read_timeout: float
|
||||
|
||||
|
||||
class StreamableHTTPTransport:
|
||||
|
|
@ -84,8 +80,8 @@ class StreamableHTTPTransport:
|
|||
self,
|
||||
url: str,
|
||||
headers: dict[str, Any] | None = None,
|
||||
timeout: timedelta = timedelta(seconds=30),
|
||||
sse_read_timeout: timedelta = timedelta(seconds=60 * 5),
|
||||
timeout: float | timedelta = 30,
|
||||
sse_read_timeout: float | timedelta = 60 * 5,
|
||||
) -> None:
|
||||
"""Initialize the StreamableHTTP transport.
|
||||
|
||||
|
|
@ -97,8 +93,10 @@ class StreamableHTTPTransport:
|
|||
"""
|
||||
self.url = url
|
||||
self.headers = headers or {}
|
||||
self.timeout = timeout
|
||||
self.sse_read_timeout = sse_read_timeout
|
||||
self.timeout = timeout.total_seconds() if isinstance(timeout, timedelta) else timeout
|
||||
self.sse_read_timeout = (
|
||||
sse_read_timeout.total_seconds() if isinstance(sse_read_timeout, timedelta) else sse_read_timeout
|
||||
)
|
||||
self.session_id: str | None = None
|
||||
self.request_headers = {
|
||||
ACCEPT: f"{JSON}, {SSE}",
|
||||
|
|
@ -186,7 +184,7 @@ class StreamableHTTPTransport:
|
|||
with ssrf_proxy_sse_connect(
|
||||
self.url,
|
||||
headers=headers,
|
||||
timeout=httpx.Timeout(self.timeout.seconds, read=self.sse_read_timeout.seconds),
|
||||
timeout=httpx.Timeout(self.timeout, read=self.sse_read_timeout),
|
||||
client=client,
|
||||
method="GET",
|
||||
) as event_source:
|
||||
|
|
@ -215,7 +213,7 @@ class StreamableHTTPTransport:
|
|||
with ssrf_proxy_sse_connect(
|
||||
self.url,
|
||||
headers=headers,
|
||||
timeout=httpx.Timeout(self.timeout.seconds, read=ctx.sse_read_timeout.seconds),
|
||||
timeout=httpx.Timeout(self.timeout, read=self.sse_read_timeout),
|
||||
client=ctx.client,
|
||||
method="GET",
|
||||
) as event_source:
|
||||
|
|
@ -402,8 +400,8 @@ class StreamableHTTPTransport:
|
|||
def streamablehttp_client(
|
||||
url: str,
|
||||
headers: dict[str, Any] | None = None,
|
||||
timeout: timedelta = timedelta(seconds=30),
|
||||
sse_read_timeout: timedelta = timedelta(seconds=60 * 5),
|
||||
timeout: float | timedelta = 30,
|
||||
sse_read_timeout: float | timedelta = 60 * 5,
|
||||
terminate_on_close: bool = True,
|
||||
) -> Generator[
|
||||
tuple[
|
||||
|
|
@ -436,7 +434,7 @@ def streamablehttp_client(
|
|||
try:
|
||||
with create_ssrf_proxy_mcp_http_client(
|
||||
headers=transport.request_headers,
|
||||
timeout=httpx.Timeout(transport.timeout.seconds, read=transport.sse_read_timeout.seconds),
|
||||
timeout=httpx.Timeout(transport.timeout, read=transport.sse_read_timeout),
|
||||
) as client:
|
||||
# Define callbacks that need access to thread pool
|
||||
def start_get_stream() -> None:
|
||||
|
|
|
|||
|
|
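The transport hunks in this region replace `timedelta`-only timeouts with plain seconds (while still accepting `timedelta`) and stop reading the lossy `.seconds` attribute. A small sketch of the normalization the updated `__init__` performs before handing the values to `httpx.Timeout`:

```python
from datetime import timedelta

import httpx


def to_seconds(value: float | timedelta) -> float:
    # Accept either plain seconds or a timedelta, as the updated transport does.
    return value.total_seconds() if isinstance(value, timedelta) else float(value)


timeout = to_seconds(timedelta(seconds=30))  # 30.0
sse_read_timeout = to_seconds(300)           # 300.0
http_timeout = httpx.Timeout(timeout, read=sse_read_timeout)
```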
@@ -23,12 +23,18 @@ class MCPClient:
         authed: bool = True,
         authorization_code: Optional[str] = None,
         for_list: bool = False,
+        headers: Optional[dict[str, str]] = None,
+        timeout: Optional[float] = None,
+        sse_read_timeout: Optional[float] = None,
     ):
         # Initialize info
         self.provider_id = provider_id
         self.tenant_id = tenant_id
         self.client_type = "streamable"
         self.server_url = server_url
+        self.headers = headers or {}
+        self.timeout = timeout
+        self.sse_read_timeout = sse_read_timeout
 
         # Authentication info
         self.authed = authed
@@ -43,7 +49,7 @@ class MCPClient:
         self._session: Optional[ClientSession] = None
         self._streams_context: Optional[AbstractContextManager[Any]] = None
         self._session_context: Optional[ClientSession] = None
-        self.exit_stack = ExitStack()
+        self._exit_stack = ExitStack()
 
         # Whether the client has been initialized
         self._initialized = False
@@ -90,21 +96,26 @@ class MCPClient:
         headers = (
             {"Authorization": f"{self.token.token_type.capitalize()} {self.token.access_token}"}
             if self.authed and self.token
-            else {}
+            else self.headers
         )
-        self._streams_context = client_factory(url=self.server_url, headers=headers)
+        self._streams_context = client_factory(
+            url=self.server_url,
+            headers=headers,
+            timeout=self.timeout,
+            sse_read_timeout=self.sse_read_timeout,
+        )
         if not self._streams_context:
             raise MCPConnectionError("Failed to create connection context")
 
         # Use exit_stack to manage context managers properly
         if method_name == "mcp":
-            read_stream, write_stream, _ = self.exit_stack.enter_context(self._streams_context)
+            read_stream, write_stream, _ = self._exit_stack.enter_context(self._streams_context)
             streams = (read_stream, write_stream)
         else:  # sse_client
-            streams = self.exit_stack.enter_context(self._streams_context)
+            streams = self._exit_stack.enter_context(self._streams_context)
 
         self._session_context = ClientSession(*streams)
-        self._session = self.exit_stack.enter_context(self._session_context)
+        self._session = self._exit_stack.enter_context(self._session_context)
         session = cast(ClientSession, self._session)
         session.initialize()
         return
@@ -120,9 +131,6 @@ class MCPClient:
             if first_try:
                 return self.connect_server(client_factory, method_name, first_try=False)
 
-        except MCPConnectionError:
-            raise
-
     def list_tools(self) -> list[Tool]:
         """Connect to an MCP server running with SSE transport"""
         # List available tools to verify connection
@@ -142,7 +150,7 @@ class MCPClient:
         """Clean up resources"""
         try:
             # ExitStack will handle proper cleanup of all managed context managers
-            self.exit_stack.close()
+            self._exit_stack.close()
         except Exception as e:
             logging.exception("Error during cleanup")
             raise ValueError(f"Error during cleanup: {e}")
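A hedged usage sketch of the extended constructor: the parameter names and the context-manager usage come from the hunks above and below, while the import path and placeholder values are assumptions:

```python
from core.mcp.mcp_client import MCPClient  # module path assumed from the surrounding context

with MCPClient(
    "https://mcp.example.com/sse",   # server_url
    "provider-123",                  # provider_id
    "tenant-456",                    # tenant_id
    authed=True,
    headers={"X-Custom-Header": "value"},
    timeout=30.0,            # per-request HTTP timeout, in seconds
    sse_read_timeout=300.0,  # how long to wait between SSE events, in seconds
) as mcp_client:
    tools = mcp_client.list_tools()
```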
@ -2,7 +2,6 @@ import logging
|
|||
import queue
|
||||
from collections.abc import Callable
|
||||
from concurrent.futures import Future, ThreadPoolExecutor, TimeoutError
|
||||
from contextlib import ExitStack
|
||||
from datetime import timedelta
|
||||
from types import TracebackType
|
||||
from typing import Any, Generic, Self, TypeVar
|
||||
|
|
@ -170,7 +169,6 @@ class BaseSession(
|
|||
self._receive_notification_type = receive_notification_type
|
||||
self._session_read_timeout_seconds = read_timeout_seconds
|
||||
self._in_flight = {}
|
||||
self._exit_stack = ExitStack()
|
||||
# Initialize executor and future to None for proper cleanup checks
|
||||
self._executor: ThreadPoolExecutor | None = None
|
||||
self._receiver_future: Future | None = None
|
||||
|
|
@ -377,7 +375,7 @@ class BaseSession(
|
|||
self._handle_incoming(RuntimeError(f"Server Error: {message}"))
|
||||
except queue.Empty:
|
||||
continue
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
logging.exception("Error in message processing loop")
|
||||
raise
|
||||
|
||||
|
|
@ -389,14 +387,12 @@ class BaseSession(
|
|||
If the request is responded to within this method, it will not be
|
||||
forwarded on to the message stream.
|
||||
"""
|
||||
pass
|
||||
|
||||
def _received_notification(self, notification: ReceiveNotificationT) -> None:
|
||||
"""
|
||||
Can be overridden by subclasses to handle a notification without needing
|
||||
to listen on the message stream.
|
||||
"""
|
||||
pass
|
||||
|
||||
def send_progress_notification(
|
||||
self, progress_token: str | int, progress: float, total: float | None = None
|
||||
|
|
@ -405,11 +401,9 @@ class BaseSession(
|
|||
Sends a progress notification for a request that is currently being
|
||||
processed.
|
||||
"""
|
||||
pass
|
||||
|
||||
def _handle_incoming(
|
||||
self,
|
||||
req: RequestResponder[ReceiveRequestT, SendResultT] | ReceiveNotificationT | Exception,
|
||||
) -> None:
|
||||
"""A generic handler for incoming messages. Overwritten by subclasses."""
|
||||
pass
|
||||
|
|
|
|||
|
|
@@ -1,3 +1,4 @@
+import queue
 from datetime import timedelta
 from typing import Any, Protocol
 
@@ -85,8 +86,8 @@ class ClientSession(
 ):
     def __init__(
         self,
-        read_stream,
-        write_stream,
+        read_stream: queue.Queue,
+        write_stream: queue.Queue,
         read_timeout_seconds: timedelta | None = None,
         sampling_callback: SamplingFnT | None = None,
         list_roots_callback: ListRootsFnT | None = None,
@@ -99,13 +99,13 @@ class TokenBufferMemory:
                 prompt_messages.append(UserPromptMessage(content=message.query))
             else:
                 prompt_message_contents: list[PromptMessageContentUnionTypes] = []
-                prompt_message_contents.append(TextPromptMessageContent(data=message.query))
                 for file in file_objs:
                     prompt_message = file_manager.to_prompt_message_content(
                         file,
                         image_detail_config=detail,
                     )
                     prompt_message_contents.append(prompt_message)
+                prompt_message_contents.append(TextPromptMessageContent(data=message.query))
 
                 prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
 
@@ -257,11 +257,6 @@ class ModelProviderFactory:
         # scan all providers
         plugin_model_provider_entities = self.get_plugin_model_providers()
 
-        # convert provider_configs to dict
-        provider_credentials_dict = {}
-        for provider_config in provider_configs:
-            provider_credentials_dict[provider_config.provider] = provider_config.credentials
-
         # traverse all model_provider_extensions
         providers = []
         for plugin_model_provider_entity in plugin_model_provider_entities:
@@ -68,7 +68,7 @@ class CommonValidator:
         if credential_form_schema.max_length:
             if len(value) > credential_form_schema.max_length:
                 raise ValueError(
-                    f"Variable {credential_form_schema.variable} length should not"
+                    f"Variable {credential_form_schema.variable} length should not be"
                     f" greater than {credential_form_schema.max_length}"
                 )
 
@@ -1,10 +0,0 @@
-import pydantic
-from pydantic import BaseModel
-
-
-def dump_model(model: BaseModel) -> dict:
-    if hasattr(pydantic, "model_dump"):
-        # FIXME mypy error, try to fix it instead of using type: ignore
-        return pydantic.model_dump(model)  # type: ignore
-    else:
-        return model.model_dump()
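The deleted helper dates from the pydantic v1 to v2 transition; on pydantic v2 the instance method is all that is needed, for example:

```python
from pydantic import BaseModel


class Example(BaseModel):
    name: str
    count: int = 0


data = Example(name="dify").model_dump()  # {'name': 'dify', 'count': 0}
```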
@@ -109,8 +109,19 @@
         )
 
     def _get_connection(self) -> Connection:
-        connection = oracledb.connect(user=self.config.user, password=self.config.password, dsn=self.config.dsn)
-        return connection
+        if self.config.is_autonomous:
+            connection = oracledb.connect(
+                user=self.config.user,
+                password=self.config.password,
+                dsn=self.config.dsn,
+                config_dir=self.config.config_dir,
+                wallet_location=self.config.wallet_location,
+                wallet_password=self.config.wallet_password,
+            )
+            return connection
+        else:
+            connection = oracledb.connect(user=self.config.user, password=self.config.password, dsn=self.config.dsn)
+            return connection
 
     def _create_connection_pool(self, config: OracleVectorConfig):
         pool_params = {
@@ -5,10 +5,14 @@ This package contains concrete implementations of the repository interfaces
 defined in the core.workflow.repository package.
 """
 
+from core.repositories.celery_workflow_execution_repository import CeleryWorkflowExecutionRepository
+from core.repositories.celery_workflow_node_execution_repository import CeleryWorkflowNodeExecutionRepository
 from core.repositories.factory import DifyCoreRepositoryFactory, RepositoryImportError
 from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
 
 __all__ = [
+    "CeleryWorkflowExecutionRepository",
+    "CeleryWorkflowNodeExecutionRepository",
     "DifyCoreRepositoryFactory",
     "RepositoryImportError",
     "SQLAlchemyWorkflowNodeExecutionRepository",
@ -0,0 +1,126 @@
|
|||
"""
|
||||
Celery-based implementation of the WorkflowExecutionRepository.
|
||||
|
||||
This implementation uses Celery tasks for asynchronous storage operations,
|
||||
providing improved performance by offloading database operations to background workers.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional, Union
|
||||
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.workflow.entities.workflow_execution import WorkflowExecution
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from libs.helper import extract_tenant_id
|
||||
from models import Account, CreatorUserRole, EndUser
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
from tasks.workflow_execution_tasks import (
|
||||
save_workflow_execution_task,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CeleryWorkflowExecutionRepository(WorkflowExecutionRepository):
|
||||
"""
|
||||
Celery-based implementation of the WorkflowExecutionRepository interface.
|
||||
|
||||
This implementation provides asynchronous storage capabilities by using Celery tasks
|
||||
to handle database operations in background workers. This improves performance by
|
||||
reducing the blocking time for workflow execution storage operations.
|
||||
|
||||
Key features:
|
||||
- Asynchronous save operations using Celery tasks
|
||||
- Support for multi-tenancy through tenant/app filtering
|
||||
- Automatic retry and error handling through Celery
|
||||
"""
|
||||
|
||||
_session_factory: sessionmaker
|
||||
_tenant_id: str
|
||||
_app_id: Optional[str]
|
||||
_triggered_from: Optional[WorkflowRunTriggeredFrom]
|
||||
_creator_user_id: str
|
||||
_creator_user_role: CreatorUserRole
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session_factory: sessionmaker | Engine,
|
||||
user: Union[Account, EndUser],
|
||||
app_id: Optional[str],
|
||||
triggered_from: Optional[WorkflowRunTriggeredFrom],
|
||||
):
|
||||
"""
|
||||
Initialize the repository with Celery task configuration and context information.
|
||||
|
||||
Args:
|
||||
session_factory: SQLAlchemy sessionmaker or engine for fallback operations
|
||||
user: Account or EndUser object containing tenant_id, user ID, and role information
|
||||
app_id: App ID for filtering by application (can be None)
|
||||
triggered_from: Source of the execution trigger (DEBUGGING or APP_RUN)
|
||||
"""
|
||||
# Store session factory for fallback operations
|
||||
if isinstance(session_factory, Engine):
|
||||
self._session_factory = sessionmaker(bind=session_factory, expire_on_commit=False)
|
||||
elif isinstance(session_factory, sessionmaker):
|
||||
self._session_factory = session_factory
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid session_factory type {type(session_factory).__name__}; expected sessionmaker or Engine"
|
||||
)
|
||||
|
||||
# Extract tenant_id from user
|
||||
tenant_id = extract_tenant_id(user)
|
||||
if not tenant_id:
|
||||
raise ValueError("User must have a tenant_id or current_tenant_id")
|
||||
self._tenant_id = tenant_id # type: ignore[assignment] # We've already checked tenant_id is not None
|
||||
|
||||
# Store app context
|
||||
self._app_id = app_id
|
||||
|
||||
# Extract user context
|
||||
self._triggered_from = triggered_from
|
||||
self._creator_user_id = user.id
|
||||
|
||||
# Determine user role based on user type
|
||||
self._creator_user_role = CreatorUserRole.ACCOUNT if isinstance(user, Account) else CreatorUserRole.END_USER
|
||||
|
||||
logger.info(
|
||||
"Initialized CeleryWorkflowExecutionRepository for tenant %s, app %s, triggered_from %s",
|
||||
self._tenant_id,
|
||||
self._app_id,
|
||||
self._triggered_from,
|
||||
)
|
||||
|
||||
def save(self, execution: WorkflowExecution) -> None:
|
||||
"""
|
||||
Save or update a WorkflowExecution instance asynchronously using Celery.
|
||||
|
||||
This method queues the save operation as a Celery task and returns immediately,
|
||||
providing improved performance for high-throughput scenarios.
|
||||
|
||||
Args:
|
||||
execution: The WorkflowExecution instance to save or update
|
||||
"""
|
||||
try:
|
||||
# Serialize execution for Celery task
|
||||
execution_data = execution.model_dump()
|
||||
|
||||
# Queue the save operation as a Celery task (fire and forget)
|
||||
save_workflow_execution_task.delay(
|
||||
execution_data=execution_data,
|
||||
tenant_id=self._tenant_id,
|
||||
app_id=self._app_id or "",
|
||||
triggered_from=self._triggered_from.value if self._triggered_from else "",
|
||||
creator_user_id=self._creator_user_id,
|
||||
creator_user_role=self._creator_user_role.value,
|
||||
)
|
||||
|
||||
logger.debug("Queued async save for workflow execution: %s", execution.id_)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed to queue save operation for execution %s", execution.id_)
|
||||
# In case of Celery failure, we could implement a fallback to synchronous save
|
||||
# For now, we'll re-raise the exception
|
||||
raise
|
||||
|
|
@ -0,0 +1,190 @@
|
|||
"""
|
||||
Celery-based implementation of the WorkflowNodeExecutionRepository.
|
||||
|
||||
This implementation uses Celery tasks for asynchronous storage operations,
|
||||
providing improved performance by offloading database operations to background workers.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
from typing import Optional, Union
|
||||
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution
|
||||
from core.workflow.repositories.workflow_node_execution_repository import (
|
||||
OrderConfig,
|
||||
WorkflowNodeExecutionRepository,
|
||||
)
|
||||
from libs.helper import extract_tenant_id
|
||||
from models import Account, CreatorUserRole, EndUser
|
||||
from models.workflow import WorkflowNodeExecutionTriggeredFrom
|
||||
from tasks.workflow_node_execution_tasks import (
|
||||
save_workflow_node_execution_task,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CeleryWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
|
||||
"""
|
||||
Celery-based implementation of the WorkflowNodeExecutionRepository interface.
|
||||
|
||||
This implementation provides asynchronous storage capabilities by using Celery tasks
|
||||
to handle database operations in background workers. This improves performance by
|
||||
reducing the blocking time for workflow node execution storage operations.
|
||||
|
||||
Key features:
|
||||
- Asynchronous save operations using Celery tasks
|
||||
- In-memory cache for immediate reads
|
||||
- Support for multi-tenancy through tenant/app filtering
|
||||
- Automatic retry and error handling through Celery
|
||||
"""
|
||||
|
||||
_session_factory: sessionmaker
|
||||
_tenant_id: str
|
||||
_app_id: Optional[str]
|
||||
_triggered_from: Optional[WorkflowNodeExecutionTriggeredFrom]
|
||||
_creator_user_id: str
|
||||
_creator_user_role: CreatorUserRole
|
||||
_execution_cache: dict[str, WorkflowNodeExecution]
|
||||
_workflow_execution_mapping: dict[str, list[str]]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session_factory: sessionmaker | Engine,
|
||||
user: Union[Account, EndUser],
|
||||
app_id: Optional[str],
|
||||
triggered_from: Optional[WorkflowNodeExecutionTriggeredFrom],
|
||||
):
|
||||
"""
|
||||
Initialize the repository with Celery task configuration and context information.
|
||||
|
||||
Args:
|
||||
session_factory: SQLAlchemy sessionmaker or engine for fallback operations
|
||||
user: Account or EndUser object containing tenant_id, user ID, and role information
|
||||
app_id: App ID for filtering by application (can be None)
|
||||
triggered_from: Source of the execution trigger (SINGLE_STEP or WORKFLOW_RUN)
|
||||
"""
|
||||
# Store session factory for fallback operations
|
||||
if isinstance(session_factory, Engine):
|
||||
self._session_factory = sessionmaker(bind=session_factory, expire_on_commit=False)
|
||||
elif isinstance(session_factory, sessionmaker):
|
||||
self._session_factory = session_factory
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid session_factory type {type(session_factory).__name__}; expected sessionmaker or Engine"
|
||||
)
|
||||
|
||||
# Extract tenant_id from user
|
||||
tenant_id = extract_tenant_id(user)
|
||||
if not tenant_id:
|
||||
raise ValueError("User must have a tenant_id or current_tenant_id")
|
||||
self._tenant_id = tenant_id # type: ignore[assignment] # We've already checked tenant_id is not None
|
||||
|
||||
# Store app context
|
||||
self._app_id = app_id
|
||||
|
||||
# Extract user context
|
||||
self._triggered_from = triggered_from
|
||||
self._creator_user_id = user.id
|
||||
|
||||
# Determine user role based on user type
|
||||
self._creator_user_role = CreatorUserRole.ACCOUNT if isinstance(user, Account) else CreatorUserRole.END_USER
|
||||
|
||||
# In-memory cache for workflow node executions
|
||||
self._execution_cache: dict[str, WorkflowNodeExecution] = {}
|
||||
|
||||
# Cache for mapping workflow_execution_ids to execution IDs for efficient retrieval
|
||||
self._workflow_execution_mapping: dict[str, list[str]] = {}
|
||||
|
||||
logger.info(
|
||||
"Initialized CeleryWorkflowNodeExecutionRepository for tenant %s, app %s, triggered_from %s",
|
||||
self._tenant_id,
|
||||
self._app_id,
|
||||
self._triggered_from,
|
||||
)
|
||||
|
||||
def save(self, execution: WorkflowNodeExecution) -> None:
|
||||
"""
|
||||
Save or update a WorkflowNodeExecution instance to cache and asynchronously to database.
|
||||
|
||||
This method stores the execution in cache immediately for fast reads and queues
|
||||
the save operation as a Celery task without tracking the task status.
|
||||
|
||||
Args:
|
||||
execution: The WorkflowNodeExecution instance to save or update
|
||||
"""
|
||||
try:
|
||||
# Store in cache immediately for fast reads
|
||||
self._execution_cache[execution.id] = execution
|
||||
|
||||
# Update workflow execution mapping for efficient retrieval
|
||||
if execution.workflow_execution_id:
|
||||
if execution.workflow_execution_id not in self._workflow_execution_mapping:
|
||||
self._workflow_execution_mapping[execution.workflow_execution_id] = []
|
||||
if execution.id not in self._workflow_execution_mapping[execution.workflow_execution_id]:
|
||||
self._workflow_execution_mapping[execution.workflow_execution_id].append(execution.id)
|
||||
|
||||
# Serialize execution for Celery task
|
||||
execution_data = execution.model_dump()
|
||||
|
||||
# Queue the save operation as a Celery task (fire and forget)
|
||||
save_workflow_node_execution_task.delay(
|
||||
execution_data=execution_data,
|
||||
tenant_id=self._tenant_id,
|
||||
app_id=self._app_id or "",
|
||||
triggered_from=self._triggered_from.value if self._triggered_from else "",
|
||||
creator_user_id=self._creator_user_id,
|
||||
creator_user_role=self._creator_user_role.value,
|
||||
)
|
||||
|
||||
logger.debug("Cached and queued async save for workflow node execution: %s", execution.id)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed to cache or queue save operation for node execution %s", execution.id)
|
||||
# In case of Celery failure, we could implement a fallback to synchronous save
|
||||
# For now, we'll re-raise the exception
|
||||
raise
|
||||
|
||||
def get_by_workflow_run(
|
||||
self,
|
||||
workflow_run_id: str,
|
||||
order_config: Optional[OrderConfig] = None,
|
||||
) -> Sequence[WorkflowNodeExecution]:
|
||||
"""
|
||||
Retrieve all WorkflowNodeExecution instances for a specific workflow run from cache.
|
||||
|
||||
Args:
|
||||
workflow_run_id: The workflow run ID
|
||||
order_config: Optional configuration for ordering results
|
||||
|
||||
Returns:
|
||||
A sequence of WorkflowNodeExecution instances
|
||||
"""
|
||||
try:
|
||||
# Get execution IDs for this workflow run from cache
|
||||
execution_ids = self._workflow_execution_mapping.get(workflow_run_id, [])
|
||||
|
||||
# Retrieve executions from cache
|
||||
result = []
|
||||
for execution_id in execution_ids:
|
||||
if execution_id in self._execution_cache:
|
||||
result.append(self._execution_cache[execution_id])
|
||||
|
||||
# Apply ordering if specified
|
||||
if order_config and result:
|
||||
# Sort based on the order configuration
|
||||
reverse = order_config.order_direction == "desc"
|
||||
|
||||
# Sort by multiple fields if specified
|
||||
for field_name in reversed(order_config.order_by):
|
||||
result.sort(key=lambda x: getattr(x, field_name, 0), reverse=reverse)
|
||||
|
||||
logger.debug("Retrieved %d workflow node executions for run %s from cache", len(result), workflow_run_id)
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed to get workflow node executions for run %s from cache", workflow_run_id)
|
||||
return []
|
||||
|
|
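The two new Celery-backed repositories above are opt-in: they only take effect when the repository settings from the configuration hunk earlier point at them, and they assume a Celery worker is consuming the new `workflow_storage` queue added to the worker commands at the top of this diff. A hedged sketch of the selection (the environment variable names come from `RepositoryConfig`; the actual wiring is done by `DifyCoreRepositoryFactory`):

```python
import os

# Route workflow (node) execution persistence through Celery tasks instead of
# writing synchronously via SQLAlchemy.
os.environ["CORE_WORKFLOW_EXECUTION_REPOSITORY"] = (
    "core.repositories.celery_workflow_execution_repository.CeleryWorkflowExecutionRepository"
)
os.environ["CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY"] = (
    "core.repositories.celery_workflow_node_execution_repository.CeleryWorkflowNodeExecutionRepository"
)
```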
@ -94,11 +94,9 @@ class DifyCoreRepositoryFactory:
|
|||
def _validate_constructor_signature(repository_class: type, required_params: list[str]) -> None:
|
||||
"""
|
||||
Validate that a repository class constructor accepts required parameters.
|
||||
|
||||
Args:
|
||||
repository_class: The class to validate
|
||||
required_params: List of required parameter names
|
||||
|
||||
Raises:
|
||||
RepositoryImportError: If the constructor doesn't accept required parameters
|
||||
"""
|
||||
|
|
@ -158,10 +156,8 @@ class DifyCoreRepositoryFactory:
|
|||
try:
|
||||
repository_class = cls._import_class(class_path)
|
||||
cls._validate_repository_interface(repository_class, WorkflowExecutionRepository)
|
||||
cls._validate_constructor_signature(
|
||||
repository_class, ["session_factory", "user", "app_id", "triggered_from"]
|
||||
)
|
||||
|
||||
# All repository types now use the same constructor parameters
|
||||
return repository_class( # type: ignore[no-any-return]
|
||||
session_factory=session_factory,
|
||||
user=user,
|
||||
|
|
@ -204,10 +200,8 @@ class DifyCoreRepositoryFactory:
|
|||
try:
|
||||
repository_class = cls._import_class(class_path)
|
||||
cls._validate_repository_interface(repository_class, WorkflowNodeExecutionRepository)
|
||||
cls._validate_constructor_signature(
|
||||
repository_class, ["session_factory", "user", "app_id", "triggered_from"]
|
||||
)
|
||||
|
||||
# All repository types now use the same constructor parameters
|
||||
return repository_class( # type: ignore[no-any-return]
|
||||
session_factory=session_factory,
|
||||
user=user,
|
||||
|
|
|
|||
|
|
@@ -12,8 +12,6 @@ from core.tools.errors import ToolProviderCredentialValidationError
 
 
 class ToolProviderController(ABC):
-    entity: ToolProviderEntity
-
     def __init__(self, entity: ToolProviderEntity) -> None:
         self.entity = entity
 
@ -1,5 +1,5 @@
|
|||
import json
|
||||
from typing import Any
|
||||
from typing import Any, Optional
|
||||
|
||||
from core.mcp.types import Tool as RemoteMCPTool
|
||||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
|
|
@ -19,15 +19,24 @@ from services.tools.tools_transform_service import ToolTransformService
|
|||
|
||||
|
||||
class MCPToolProviderController(ToolProviderController):
|
||||
provider_id: str
|
||||
entity: ToolProviderEntityWithPlugin
|
||||
|
||||
def __init__(self, entity: ToolProviderEntityWithPlugin, provider_id: str, tenant_id: str, server_url: str) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
entity: ToolProviderEntityWithPlugin,
|
||||
provider_id: str,
|
||||
tenant_id: str,
|
||||
server_url: str,
|
||||
headers: Optional[dict[str, str]] = None,
|
||||
timeout: Optional[float] = None,
|
||||
sse_read_timeout: Optional[float] = None,
|
||||
) -> None:
|
||||
super().__init__(entity)
|
||||
self.entity = entity
|
||||
self.entity: ToolProviderEntityWithPlugin = entity
|
||||
self.tenant_id = tenant_id
|
||||
self.provider_id = provider_id
|
||||
self.server_url = server_url
|
||||
self.headers = headers or {}
|
||||
self.timeout = timeout
|
||||
self.sse_read_timeout = sse_read_timeout
|
||||
|
||||
@property
|
||||
def provider_type(self) -> ToolProviderType:
|
||||
|
|
@ -85,6 +94,9 @@ class MCPToolProviderController(ToolProviderController):
|
|||
provider_id=db_provider.server_identifier or "",
|
||||
tenant_id=db_provider.tenant_id or "",
|
||||
server_url=db_provider.decrypted_server_url,
|
||||
headers={}, # TODO: get headers from db provider
|
||||
timeout=db_provider.timeout,
|
||||
sse_read_timeout=db_provider.sse_read_timeout,
|
||||
)
|
||||
|
||||
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
|
||||
|
|
@ -111,6 +123,9 @@ class MCPToolProviderController(ToolProviderController):
|
|||
icon=self.entity.identity.icon,
|
||||
server_url=self.server_url,
|
||||
provider_id=self.provider_id,
|
||||
headers=self.headers,
|
||||
timeout=self.timeout,
|
||||
sse_read_timeout=self.sse_read_timeout,
|
||||
)
|
||||
|
||||
def get_tools(self) -> list[MCPTool]: # type: ignore
|
||||
|
|
@ -125,6 +140,9 @@ class MCPToolProviderController(ToolProviderController):
|
|||
icon=self.entity.identity.icon,
|
||||
server_url=self.server_url,
|
||||
provider_id=self.provider_id,
|
||||
headers=self.headers,
|
||||
timeout=self.timeout,
|
||||
sse_read_timeout=self.sse_read_timeout,
|
||||
)
|
||||
for tool_entity in self.entity.tools
|
||||
]
|
||||
|
|
|
|||
|
|
@ -13,13 +13,25 @@ from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, Too
|
|||
|
||||
class MCPTool(Tool):
|
||||
def __init__(
|
||||
self, entity: ToolEntity, runtime: ToolRuntime, tenant_id: str, icon: str, server_url: str, provider_id: str
|
||||
self,
|
||||
entity: ToolEntity,
|
||||
runtime: ToolRuntime,
|
||||
tenant_id: str,
|
||||
icon: str,
|
||||
server_url: str,
|
||||
provider_id: str,
|
||||
headers: Optional[dict[str, str]] = None,
|
||||
timeout: Optional[float] = None,
|
||||
sse_read_timeout: Optional[float] = None,
|
||||
) -> None:
|
||||
super().__init__(entity, runtime)
|
||||
self.tenant_id = tenant_id
|
||||
self.icon = icon
|
||||
self.server_url = server_url
|
||||
self.provider_id = provider_id
|
||||
self.headers = headers or {}
|
||||
self.timeout = timeout
|
||||
self.sse_read_timeout = sse_read_timeout
|
||||
|
||||
def tool_provider_type(self) -> ToolProviderType:
|
||||
return ToolProviderType.MCP
|
||||
|
|
@ -35,7 +47,15 @@ class MCPTool(Tool):
|
|||
from core.tools.errors import ToolInvokeError
|
||||
|
||||
try:
|
||||
with MCPClient(self.server_url, self.provider_id, self.tenant_id, authed=True) as mcp_client:
|
||||
with MCPClient(
|
||||
self.server_url,
|
||||
self.provider_id,
|
||||
self.tenant_id,
|
||||
authed=True,
|
||||
headers=self.headers,
|
||||
timeout=self.timeout,
|
||||
sse_read_timeout=self.sse_read_timeout,
|
||||
) as mcp_client:
|
||||
tool_parameters = self._handle_none_parameter(tool_parameters)
|
||||
result = mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters)
|
||||
except MCPAuthError as e:
|
||||
|
|
@ -72,6 +92,9 @@ class MCPTool(Tool):
|
|||
icon=self.icon,
|
||||
server_url=self.server_url,
|
||||
provider_id=self.provider_id,
|
||||
headers=self.headers,
|
||||
timeout=self.timeout,
|
||||
sse_read_timeout=self.sse_read_timeout,
|
||||
)
|
||||
|
||||
def _handle_none_parameter(self, parameter: dict[str, Any]) -> dict[str, Any]:
|
||||
|
|
|
|||
|
|
@ -789,9 +789,6 @@ class ToolManager:
|
|||
"""
|
||||
get api provider
|
||||
"""
|
||||
"""
|
||||
get tool provider
|
||||
"""
|
||||
provider_name = provider
|
||||
provider_obj: ApiToolProvider | None = (
|
||||
db.session.query(ApiToolProvider)
|
||||
|
|
|
|||
|
|
@@ -4,4 +4,4 @@
 #
 # If the selector length is more than 2, the remaining parts are the keys / indexes paths used
 # to extract part of the variable value.
-MIN_SELECTORS_LENGTH = 2
+SELECTORS_LENGTH = 2
@@ -7,8 +7,8 @@ from pydantic import BaseModel, Field
 
 from core.file import File, FileAttribute, file_manager
 from core.variables import Segment, SegmentGroup, Variable
-from core.variables.consts import MIN_SELECTORS_LENGTH
-from core.variables.segments import FileSegment, NoneSegment
+from core.variables.consts import SELECTORS_LENGTH
+from core.variables.segments import FileSegment, ObjectSegment
 from core.variables.variables import VariableUnion
 from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
 from core.workflow.system_variable import SystemVariable
@@ -24,7 +24,7 @@ class VariablePool(BaseModel):
     # The first element of the selector is the node id, it's the first-level key in the dictionary.
     # Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the
     # elements of the selector except the first one.
-    variable_dictionary: defaultdict[str, Annotated[dict[int, VariableUnion], Field(default_factory=dict)]] = Field(
+    variable_dictionary: defaultdict[str, Annotated[dict[str, VariableUnion], Field(default_factory=dict)]] = Field(
         description="Variables mapping",
         default=defaultdict(dict),
     )
@@ -36,6 +36,7 @@ class VariablePool(BaseModel):
     )
     system_variables: SystemVariable = Field(
         description="System variables",
+        default_factory=SystemVariable.empty,
     )
     environment_variables: Sequence[VariableUnion] = Field(
         description="Environment variables.",
@@ -58,23 +59,29 @@
 
     def add(self, selector: Sequence[str], value: Any, /) -> None:
         """
-        Adds a variable to the variable pool.
+        Add a variable to the variable pool.
 
-        NOTE: You should not add a non-Segment value to the variable pool
-        even if it is allowed now.
+        This method accepts a selector path and a value, converting the value
+        to a Variable object if necessary before storing it in the pool.
 
         Args:
-            selector (Sequence[str]): The selector for the variable.
-            value (VariableValue): The value of the variable.
+            selector: A two-element sequence containing [node_id, variable_name].
+                The selector must have exactly 2 elements to be valid.
+            value: The value to store. Can be a Variable, Segment, or any value
+                that can be converted to a Segment (str, int, float, dict, list, File).
 
         Raises:
-            ValueError: If the selector is invalid.
+            ValueError: If selector length is not exactly 2 elements.
 
-        Returns:
-            None
+        Note:
+            While non-Segment values are currently accepted and automatically
+            converted, it's recommended to pass Segment or Variable objects directly.
         """
-        if len(selector) < MIN_SELECTORS_LENGTH:
-            raise ValueError("Invalid selector")
+        if len(selector) != SELECTORS_LENGTH:
+            raise ValueError(
+                f"Invalid selector: expected {SELECTORS_LENGTH} elements (node_id, variable_name), "
+                f"got {len(selector)} elements"
+            )
 
         if isinstance(value, Variable):
             variable = value
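With selectors now fixed at exactly `[node_id, variable_name]` and stored under the plain variable name rather than a hash, adding and reading values looks like this; a usage sketch assuming `pool` is an existing `VariablePool` (the attribute-access form for files and nested objects is handled by `get()` in the following hunk):

```python
pool.add(["start", "query"], "hello")        # stored as a Variable under node "start", name "query"
segment = pool.get(["start", "query"])       # -> Segment wrapping "hello"
value = segment.value if segment else None

pool.add(["start", "meta"], {"lang": "en"})  # dicts are converted to an ObjectSegment
lang = pool.get(["start", "meta", "lang"])   # nested lookup through the extended selector
```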
@ -84,57 +91,85 @@ class VariablePool(BaseModel):
|
|||
segment = variable_factory.build_segment(value)
|
||||
variable = variable_factory.segment_to_variable(segment=segment, selector=selector)
|
||||
|
||||
key, hash_key = self._selector_to_keys(selector)
|
||||
node_id, name = self._selector_to_keys(selector)
|
||||
# Based on the definition of `VariableUnion`,
|
||||
# `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
|
||||
self.variable_dictionary[key][hash_key] = cast(VariableUnion, variable)
|
||||
self.variable_dictionary[node_id][name] = cast(VariableUnion, variable)
|
||||
|
||||
@classmethod
|
||||
def _selector_to_keys(cls, selector: Sequence[str]) -> tuple[str, int]:
|
||||
return selector[0], hash(tuple(selector[1:]))
|
||||
def _selector_to_keys(cls, selector: Sequence[str]) -> tuple[str, str]:
|
||||
return selector[0], selector[1]
|
||||
|
||||
def _has(self, selector: Sequence[str]) -> bool:
|
||||
key, hash_key = self._selector_to_keys(selector)
|
||||
if key not in self.variable_dictionary:
|
||||
node_id, name = self._selector_to_keys(selector)
|
||||
if node_id not in self.variable_dictionary:
|
||||
return False
|
||||
if hash_key not in self.variable_dictionary[key]:
|
||||
if name not in self.variable_dictionary[node_id]:
|
||||
return False
|
||||
return True
|
||||
|
||||
def get(self, selector: Sequence[str], /) -> Segment | None:
|
||||
"""
|
||||
Retrieves the value from the variable pool based on the given selector.
|
||||
Retrieve a variable's value from the pool as a Segment.
|
||||
|
||||
This method supports both simple selectors [node_id, variable_name] and
|
||||
extended selectors that include attribute access for FileSegment and
|
||||
ObjectSegment types.
|
||||
|
||||
Args:
|
||||
selector (Sequence[str]): The selector used to identify the variable.
|
||||
selector: A sequence with at least 2 elements:
|
||||
- [node_id, variable_name]: Returns the full segment
|
||||
- [node_id, variable_name, attr, ...]: Returns a nested value
|
||||
from FileSegment (e.g., 'url', 'name') or ObjectSegment
|
||||
|
||||
Returns:
|
||||
Any: The value associated with the given selector.
|
||||
The Segment associated with the selector, or None if not found.
|
||||
Returns None if selector has fewer than 2 elements.
|
||||
|
||||
Raises:
|
||||
ValueError: If the selector is invalid.
|
||||
ValueError: If attempting to access an invalid FileAttribute.
|
||||
"""
|
||||
if len(selector) < MIN_SELECTORS_LENGTH:
|
||||
if len(selector) < SELECTORS_LENGTH:
|
||||
return None
|
||||
|
||||
key, hash_key = self._selector_to_keys(selector)
|
||||
value: Segment | None = self.variable_dictionary[key].get(hash_key)
|
||||
node_id, name = self._selector_to_keys(selector)
|
||||
segment: Segment | None = self.variable_dictionary[node_id].get(name)
|
||||
|
||||
if value is None:
|
||||
selector, attr = selector[:-1], selector[-1]
|
||||
if segment is None:
|
||||
return None
|
||||
|
||||
if len(selector) == 2:
|
||||
return segment
|
||||
|
||||
if isinstance(segment, FileSegment):
|
||||
attr = selector[2]
|
||||
# Python support `attr in FileAttribute` after 3.12
|
||||
if attr not in {item.value for item in FileAttribute}:
|
||||
return None
|
||||
value = self.get(selector)
|
||||
if not isinstance(value, FileSegment | NoneSegment):
|
||||
return None
|
||||
if isinstance(value, FileSegment):
|
||||
attr = FileAttribute(attr)
|
||||
attr_value = file_manager.get_attr(file=value.value, attr=attr)
|
||||
return variable_factory.build_segment(attr_value)
|
||||
return value
|
||||
attr = FileAttribute(attr)
|
||||
attr_value = file_manager.get_attr(file=segment.value, attr=attr)
|
||||
return variable_factory.build_segment(attr_value)
|
||||
|
||||
return value
|
||||
# Navigate through nested attributes
|
||||
result: Any = segment
|
||||
for attr in selector[2:]:
|
||||
result = self._extract_value(result)
|
||||
result = self._get_nested_attribute(result, attr)
|
||||
if result is None:
|
||||
return None
|
||||
|
||||
# Return result as Segment
|
||||
return result if isinstance(result, Segment) else variable_factory.build_segment(result)
|
||||
|
||||
def _extract_value(self, obj: Any) -> Any:
|
||||
"""Extract the actual value from an ObjectSegment."""
|
||||
return obj.value if isinstance(obj, ObjectSegment) else obj
|
||||
|
||||
def _get_nested_attribute(self, obj: Mapping[str, Any], attr: str) -> Any:
|
||||
"""Get a nested attribute from a dictionary-like object."""
|
||||
if not isinstance(obj, dict):
|
||||
return None
|
||||
return obj.get(attr)
|
||||
|
||||
def remove(self, selector: Sequence[str], /):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ from configs import dify_config
|
|||
from core.app.apps.exc import GenerateTaskStoppedError
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities.node_entities import AgentNodeStrategyInit, NodeRunResult
|
||||
from core.workflow.entities.variable_pool import VariablePool, VariableValue
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager
|
||||
from core.workflow.graph_engine.entities.event import (
|
||||
|
|
@ -51,7 +51,6 @@ from core.workflow.nodes.base import BaseNode
|
|||
from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor
|
||||
from core.workflow.nodes.enums import ErrorStrategy, FailBranchSourceHandle
|
||||
from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
|
||||
from core.workflow.utils import variable_utils
|
||||
from libs.flask_utils import preserve_flask_contexts
|
||||
from models.enums import UserFrom
|
||||
from models.workflow import WorkflowType
|
||||
|
|
@ -701,11 +700,9 @@ class GraphEngine:
|
|||
route_node_state.status = RouteNodeState.Status.EXCEPTION
|
||||
if run_result.outputs:
|
||||
for variable_key, variable_value in run_result.outputs.items():
|
||||
# append variables to variable pool recursively
|
||||
self._append_variables_recursively(
|
||||
node_id=node.node_id,
|
||||
variable_key_list=[variable_key],
|
||||
variable_value=variable_value,
|
||||
# Add variables to variable pool
|
||||
self.graph_runtime_state.variable_pool.add(
|
||||
[node.node_id, variable_key], variable_value
|
||||
)
|
||||
yield NodeRunExceptionEvent(
|
||||
error=run_result.error or "System Error",
|
||||
|
|
@ -758,11 +755,9 @@ class GraphEngine:
|
|||
# append node output variables to variable pool
|
||||
if run_result.outputs:
|
||||
for variable_key, variable_value in run_result.outputs.items():
|
||||
# append variables to variable pool recursively
|
||||
self._append_variables_recursively(
|
||||
node_id=node.node_id,
|
||||
variable_key_list=[variable_key],
|
||||
variable_value=variable_value,
|
||||
# Add variables to variable pool
|
||||
self.graph_runtime_state.variable_pool.add(
|
||||
[node.node_id, variable_key], variable_value
|
||||
)
|
||||
|
||||
# When setting metadata, convert to dict first
|
||||
|
|
@ -851,21 +846,6 @@ class GraphEngine:
|
|||
logger.exception("Node %s run failed", node.title)
|
||||
raise e
|
||||
|
||||
def _append_variables_recursively(self, node_id: str, variable_key_list: list[str], variable_value: VariableValue):
|
||||
"""
|
||||
Append variables recursively
|
||||
:param node_id: node id
|
||||
:param variable_key_list: variable key list
|
||||
:param variable_value: variable value
|
||||
:return:
|
||||
"""
|
||||
variable_utils.append_variables_recursively(
|
||||
self.graph_runtime_state.variable_pool,
|
||||
node_id,
|
||||
variable_key_list,
|
||||
variable_value,
|
||||
)
|
||||
|
||||
def _is_timed_out(self, start_at: float, max_execution_time: int) -> bool:
|
||||
"""
|
||||
Check timeout
|
||||
|
|
|
|||
|
|
@@ -4,7 +4,7 @@ from typing import Any, TypeVar
 from pydantic import BaseModel
 
 from core.variables import Segment
-from core.variables.consts import MIN_SELECTORS_LENGTH
+from core.variables.consts import SELECTORS_LENGTH
 from core.variables.types import SegmentType
 
 # Use double underscore (`__`) prefix for internal variables
@@ -23,7 +23,7 @@ _T = TypeVar("_T", bound=MutableMapping[str, Any])
 
 
 def variable_to_processed_data(selector: Sequence[str], seg: Segment) -> UpdatedVariable:
-    if len(selector) < MIN_SELECTORS_LENGTH:
+    if len(selector) < SELECTORS_LENGTH:
         raise Exception("selector too short")
     node_id, var_name = selector[:2]
     return UpdatedVariable(
@@ -4,7 +4,7 @@ from typing import Any, Optional, cast
|
|||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.variables import SegmentType, Variable
|
||||
from core.variables.consts import MIN_SELECTORS_LENGTH
|
||||
from core.variables.consts import SELECTORS_LENGTH
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
|
||||
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
|
|
@@ -46,7 +46,7 @@ def _source_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_
|
|||
selector = item.value
|
||||
if not isinstance(selector, list):
|
||||
raise InvalidDataError(f"selector is not a list, {node_id=}, {item=}")
|
||||
if len(selector) < MIN_SELECTORS_LENGTH:
|
||||
if len(selector) < SELECTORS_LENGTH:
|
||||
raise InvalidDataError(f"selector too short, {node_id=}, {item=}")
|
||||
selector_str = ".".join(selector)
|
||||
key = f"{node_id}.#{selector_str}#"
|
||||
|
|
|
|||
|
|
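The MIN_SELECTORS_LENGTH to SELECTORS_LENGTH rename above touches every place that validates a selector before splitting off the node id and variable name. Here is a small guard in the same spirit; the constant value of 2 is assumed, since the real definition in core.variables.consts is not part of this diff.

```python
from collections.abc import Sequence

# Assumption: selectors look like [node_id, variable_name, *nested_path];
# the real constant lives in core.variables.consts and may differ.
SELECTORS_LENGTH = 2


def split_selector(selector: Sequence[str]) -> tuple[str, str]:
    """Return (node_id, variable_name), rejecting selectors that are too short."""
    if len(selector) < SELECTORS_LENGTH:
        raise ValueError(f"selector too short: {selector!r}")
    node_id, var_name = selector[:2]
    return node_id, var_name


print(split_selector(["node_1", "result"]))           # ('node_1', 'result')
print(split_selector(["node_1", "result", "score"]))  # extra path parts beyond the first two are ignored here
```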
@@ -1,29 +0,0 @@
|
|||
from core.variables.segments import ObjectSegment, Segment
|
||||
from core.workflow.entities.variable_pool import VariablePool, VariableValue
|
||||
|
||||
|
||||
def append_variables_recursively(
|
||||
pool: VariablePool, node_id: str, variable_key_list: list[str], variable_value: VariableValue | Segment
|
||||
):
|
||||
"""
|
||||
Append variables recursively
|
||||
:param pool: variable pool to append variables to
|
||||
:param node_id: node id
|
||||
:param variable_key_list: variable key list
|
||||
:param variable_value: variable value
|
||||
:return:
|
||||
"""
|
||||
pool.add([node_id] + variable_key_list, variable_value)
|
||||
|
||||
# if variable_value is a dict, then recursively append variables
|
||||
if isinstance(variable_value, ObjectSegment):
|
||||
variable_dict = variable_value.value
|
||||
elif isinstance(variable_value, dict):
|
||||
variable_dict = variable_value
|
||||
else:
|
||||
return
|
||||
|
||||
for key, value in variable_dict.items():
|
||||
# construct new key list
|
||||
new_key_list = variable_key_list + [key]
|
||||
append_variables_recursively(pool, node_id=node_id, variable_key_list=new_key_list, variable_value=value)
|
||||
|
|
@@ -3,9 +3,8 @@ from collections.abc import Mapping, Sequence
|
|||
from typing import Any, Protocol
|
||||
|
||||
from core.variables import Variable
|
||||
from core.variables.consts import MIN_SELECTORS_LENGTH
|
||||
from core.variables.consts import SELECTORS_LENGTH
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.utils import variable_utils
|
||||
|
||||
|
||||
class VariableLoader(Protocol):
|
||||
|
|
@@ -78,7 +77,7 @@ def load_into_variable_pool(
|
|||
variables_to_load.append(list(selector))
|
||||
loaded = variable_loader.load_variables(variables_to_load)
|
||||
for var in loaded:
|
||||
assert len(var.selector) >= MIN_SELECTORS_LENGTH, f"Invalid variable {var}"
|
||||
variable_utils.append_variables_recursively(
|
||||
variable_pool, node_id=var.selector[0], variable_key_list=list(var.selector[1:]), variable_value=var
|
||||
)
|
||||
assert len(var.selector) >= SELECTORS_LENGTH, f"Invalid variable {var}"
|
||||
# Add variable directly to the pool
|
||||
# The variable pool expects 2-element selectors [node_id, variable_name]
|
||||
variable_pool.add([var.selector[0], var.selector[1]], var)
|
||||
|
|
|
|||
|
|
@@ -1,4 +1,5 @@
|
|||
from collections.abc import Mapping
|
||||
from decimal import Decimal
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
|
@@ -17,6 +18,9 @@ class WorkflowRuntimeTypeConverter:
|
|||
return value
|
||||
if isinstance(value, (bool, int, str, float)):
|
||||
return value
|
||||
if isinstance(value, Decimal):
|
||||
# Convert Decimal to float for JSON serialization
|
||||
return float(value)
|
||||
if isinstance(value, Segment):
|
||||
return self._to_json_encodable_recursive(value.value)
|
||||
if isinstance(value, File):
|
||||
|
|
|
|||
|
|
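The converter hunk above adds a Decimal branch so numeric outputs survive JSON encoding. A standalone illustration of why that branch matters (json.dumps rejects Decimal), using a simplified stand-in for the converter's recursive walk:

```python
import json
from decimal import Decimal


def to_json_encodable(value: object) -> object:
    """Simplified stand-in for the converter: only the Decimal branch and the recursion are shown."""
    if isinstance(value, Decimal):
        # json.dumps raises TypeError on Decimal, so convert to float first
        return float(value)
    if isinstance(value, dict):
        return {k: to_json_encodable(v) for k, v in value.items()}
    if isinstance(value, list):
        return [to_json_encodable(v) for v in value]
    return value


outputs = {"price": Decimal("19.99"), "quantities": [Decimal("1"), Decimal("2.5")]}
print(json.dumps(to_json_encodable(outputs)))  # {"price": 19.99, "quantities": [1.0, 2.5]}
```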
@@ -32,7 +32,7 @@ if [[ "${MODE}" == "worker" ]]; then
|
|||
|
||||
exec celery -A app.celery worker -P ${CELERY_WORKER_CLASS:-gevent} $CONCURRENCY_OPTION \
|
||||
--max-tasks-per-child ${MAX_TASK_PRE_CHILD:-50} --loglevel ${LOG_LEVEL:-INFO} \
|
||||
-Q ${CELERY_QUEUES:-dataset,mail,ops_trace,app_deletion,plugin}
|
||||
-Q ${CELERY_QUEUES:-dataset,mail,ops_trace,app_deletion,plugin,workflow_storage}
|
||||
|
||||
elif [[ "${MODE}" == "beat" ]]; then
|
||||
exec celery -A app.celery beat --loglevel ${LOG_LEVEL:-INFO}
|
||||
|
|
|
|||
|
|
@@ -1,18 +1,23 @@
|
|||
import functools
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from typing import Any, Union
|
||||
from datetime import timedelta
|
||||
from typing import TYPE_CHECKING, Any, Union
|
||||
|
||||
import redis
|
||||
from redis import RedisError
|
||||
from redis.cache import CacheConfig
|
||||
from redis.cluster import ClusterNode, RedisCluster
|
||||
from redis.connection import Connection, SSLConnection
|
||||
from redis.lock import Lock
|
||||
from redis.sentinel import Sentinel
|
||||
|
||||
from configs import dify_config
|
||||
from dify_app import DifyApp
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from redis.lock import Lock
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
|
@@ -28,8 +33,8 @@ class RedisClientWrapper:
|
|||
a failover in a Sentinel-managed Redis setup.
|
||||
|
||||
Attributes:
|
||||
_client (redis.Redis): The actual Redis client instance. It remains None until
|
||||
initialized with the `initialize` method.
|
||||
_client: The actual Redis client instance. It remains None until
|
||||
initialized with the `initialize` method.
|
||||
|
||||
Methods:
|
||||
initialize(client): Initializes the Redis client if it hasn't been initialized already.
|
||||
|
|
@@ -37,20 +42,78 @@ class RedisClientWrapper:
|
|||
if the client is not initialized.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
_client: Union[redis.Redis, RedisCluster, None]
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._client = None
|
||||
|
||||
def initialize(self, client):
|
||||
def initialize(self, client: Union[redis.Redis, RedisCluster]) -> None:
|
||||
if self._client is None:
|
||||
self._client = client
|
||||
|
||||
def __getattr__(self, item):
|
||||
if TYPE_CHECKING:
|
||||
# Type hints for IDE support and static analysis
|
||||
# These are not executed at runtime but provide type information
|
||||
def get(self, name: str | bytes) -> Any: ...
|
||||
|
||||
def set(
|
||||
self,
|
||||
name: str | bytes,
|
||||
value: Any,
|
||||
ex: int | None = None,
|
||||
px: int | None = None,
|
||||
nx: bool = False,
|
||||
xx: bool = False,
|
||||
keepttl: bool = False,
|
||||
get: bool = False,
|
||||
exat: int | None = None,
|
||||
pxat: int | None = None,
|
||||
) -> Any: ...
|
||||
|
||||
def setex(self, name: str | bytes, time: int | timedelta, value: Any) -> Any: ...
|
||||
def setnx(self, name: str | bytes, value: Any) -> Any: ...
|
||||
def delete(self, *names: str | bytes) -> Any: ...
|
||||
def incr(self, name: str | bytes, amount: int = 1) -> Any: ...
|
||||
def expire(
|
||||
self,
|
||||
name: str | bytes,
|
||||
time: int | timedelta,
|
||||
nx: bool = False,
|
||||
xx: bool = False,
|
||||
gt: bool = False,
|
||||
lt: bool = False,
|
||||
) -> Any: ...
|
||||
def lock(
|
||||
self,
|
||||
name: str,
|
||||
timeout: float | None = None,
|
||||
sleep: float = 0.1,
|
||||
blocking: bool = True,
|
||||
blocking_timeout: float | None = None,
|
||||
thread_local: bool = True,
|
||||
) -> Lock: ...
|
||||
def zadd(
|
||||
self,
|
||||
name: str | bytes,
|
||||
mapping: dict[str | bytes | int | float, float | int | str | bytes],
|
||||
nx: bool = False,
|
||||
xx: bool = False,
|
||||
ch: bool = False,
|
||||
incr: bool = False,
|
||||
gt: bool = False,
|
||||
lt: bool = False,
|
||||
) -> Any: ...
|
||||
def zremrangebyscore(self, name: str | bytes, min: float | str, max: float | str) -> Any: ...
|
||||
def zcard(self, name: str | bytes) -> Any: ...
|
||||
def getdel(self, name: str | bytes) -> Any: ...
|
||||
|
||||
def __getattr__(self, item: str) -> Any:
|
||||
if self._client is None:
|
||||
raise RuntimeError("Redis client is not initialized. Call init_app first.")
|
||||
return getattr(self._client, item)
|
||||
|
||||
|
||||
redis_client = RedisClientWrapper()
|
||||
redis_client: RedisClientWrapper = RedisClientWrapper()
|
||||
|
||||
|
||||
def init_app(app: DifyApp):
|
||||
|
|
@@ -80,6 +143,9 @@ def init_app(app: DifyApp):
|
|||
|
||||
if dify_config.REDIS_USE_SENTINEL:
|
||||
assert dify_config.REDIS_SENTINELS is not None, "REDIS_SENTINELS must be set when REDIS_USE_SENTINEL is True"
|
||||
assert dify_config.REDIS_SENTINEL_SERVICE_NAME is not None, (
|
||||
"REDIS_SENTINEL_SERVICE_NAME must be set when REDIS_USE_SENTINEL is True"
|
||||
)
|
||||
sentinel_hosts = [
|
||||
(node.split(":")[0], int(node.split(":")[1])) for node in dify_config.REDIS_SENTINELS.split(",")
|
||||
]
|
||||
|
|
|
|||
|
|
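The TYPE_CHECKING block added to RedisClientWrapper only informs type checkers; at runtime every call still flows through __getattr__ to the initialized client. Below is a reduced sketch of the same lazy-proxy pattern that runs without a Redis server; the get/set signatures are placeholders rather than redis-py's.

```python
from typing import TYPE_CHECKING, Any


class LazyProxy:
    """Holds no client until initialize() is called, then forwards every attribute access."""

    _client: Any

    def __init__(self) -> None:
        self._client = None

    def initialize(self, client: Any) -> None:
        if self._client is None:
            self._client = client

    if TYPE_CHECKING:
        # Visible to type checkers only; never executed at runtime.
        def get(self, name: str) -> Any: ...
        def set(self, name: str, value: Any) -> Any: ...

    def __getattr__(self, item: str) -> Any:
        if self._client is None:
            raise RuntimeError("client is not initialized")
        return getattr(self._client, item)


proxy = LazyProxy()
proxy.initialize({"answer": 42})  # any object works as the backing client in this sketch
print(proxy.get("answer"))        # 42, forwarded to dict.get via __getattr__
```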
@@ -0,0 +1,33 @@
|
|||
"""add timeout for tool_mcp_providers
|
||||
|
||||
Revision ID: fa8b0fa6f407
|
||||
Revises: 532b3f888abf
|
||||
Create Date: 2025-08-07 11:15:31.517985
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = 'fa8b0fa6f407'
|
||||
down_revision = '532b3f888abf'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('tool_mcp_providers', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('timeout', sa.Float(), server_default=sa.text('30'), nullable=False))
|
||||
batch_op.add_column(sa.Column('sse_read_timeout', sa.Float(), server_default=sa.text('300'), nullable=False))
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('tool_mcp_providers', schema=None) as batch_op:
|
||||
batch_op.drop_column('sse_read_timeout')
|
||||
batch_op.drop_column('timeout')
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
|
@@ -278,6 +278,8 @@ class MCPToolProvider(Base):
|
|||
updated_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")
|
||||
)
|
||||
timeout: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("30"))
|
||||
sse_read_timeout: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("300"))
|
||||
|
||||
def load_user(self) -> Account | None:
|
||||
return db.session.query(Account).where(Account.id == self.user_id).first()
|
||||
|
|
|
|||
|
|
@@ -1,6 +1,6 @@
|
|||
[project]
|
||||
name = "dify-api"
|
||||
version = "1.7.1"
|
||||
version = "1.7.2"
|
||||
requires-python = ">=3.11,<3.13"
|
||||
|
||||
dependencies = [
|
||||
|
|
@@ -162,6 +162,7 @@ dev = [
|
|||
"pandas-stubs~=2.2.3",
|
||||
"scipy-stubs>=1.15.3.0",
|
||||
"types-python-http-client>=3.3.7.20240910",
|
||||
"types-redis>=4.6.0.20241004",
|
||||
]
|
||||
|
||||
############################################################
|
||||
|
|
|
|||
|
|
@@ -48,7 +48,6 @@ class DifyAPIRepositoryFactory(DifyCoreRepositoryFactory):
|
|||
RepositoryImportError: If the configured repository cannot be imported or instantiated
|
||||
"""
|
||||
class_path = dify_config.API_WORKFLOW_NODE_EXECUTION_REPOSITORY
|
||||
logger.debug("Creating DifyAPIWorkflowNodeExecutionRepository from: %s", class_path)
|
||||
|
||||
try:
|
||||
repository_class = cls._import_class(class_path)
|
||||
|
|
@@ -86,7 +85,6 @@ class DifyAPIRepositoryFactory(DifyCoreRepositoryFactory):
|
|||
RepositoryImportError: If the configured repository cannot be imported or instantiated
|
||||
"""
|
||||
class_path = dify_config.API_WORKFLOW_RUN_REPOSITORY
|
||||
logger.debug("Creating APIWorkflowRunRepository from: %s", class_path)
|
||||
|
||||
try:
|
||||
repository_class = cls._import_class(class_path)
|
||||
|
|
|
|||
|
|
@@ -24,9 +24,20 @@ def queue_monitor_task():
|
|||
queue_name = "dataset"
|
||||
threshold = dify_config.QUEUE_MONITOR_THRESHOLD
|
||||
|
||||
if threshold is None:
|
||||
logging.warning(click.style("QUEUE_MONITOR_THRESHOLD is not configured, skipping monitoring", fg="yellow"))
|
||||
return
|
||||
|
||||
try:
|
||||
queue_length = celery_redis.llen(f"{queue_name}")
|
||||
logging.info(click.style(f"Start monitor {queue_name}", fg="green"))
|
||||
|
||||
if queue_length is None:
|
||||
logging.error(
|
||||
click.style(f"Failed to get queue length for {queue_name} - Redis may be unavailable", fg="red")
|
||||
)
|
||||
return
|
||||
|
||||
logging.info(click.style(f"Queue length: {queue_length}", fg="green"))
|
||||
|
||||
if queue_length >= threshold:
|
||||
|
|
|
|||
|
|
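queue_monitor_task above now returns early when the threshold is unset and when Redis reports no length. A compact, dependency-free version of that decision logic follows; the sample lengths and threshold values are made up.

```python
import logging
from typing import Optional

logging.basicConfig(level=logging.INFO)


def should_alert(queue_length: Optional[int], threshold: Optional[int]) -> bool:
    """Return True when an alert should fire; mirrors the task's guard clauses."""
    if threshold is None:
        logging.warning("threshold is not configured, skipping monitoring")
        return False
    if queue_length is None:
        logging.error("failed to get queue length - Redis may be unavailable")
        return False
    logging.info("queue length: %s", queue_length)
    return queue_length >= threshold


print(should_alert(queue_length=120, threshold=100))   # True: backlog above threshold
print(should_alert(queue_length=None, threshold=100))  # False: Redis unreachable
print(should_alert(queue_length=5, threshold=None))    # False: monitoring disabled
```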
@@ -59,6 +59,8 @@ class MCPToolManageService:
|
|||
icon_type: str,
|
||||
icon_background: str,
|
||||
server_identifier: str,
|
||||
timeout: float,
|
||||
sse_read_timeout: float,
|
||||
) -> ToolProviderApiEntity:
|
||||
server_url_hash = hashlib.sha256(server_url.encode()).hexdigest()
|
||||
existing_provider = (
|
||||
|
|
@@ -91,6 +93,8 @@ class MCPToolManageService:
|
|||
tools="[]",
|
||||
icon=json.dumps({"content": icon, "background": icon_background}) if icon_type == "emoji" else icon,
|
||||
server_identifier=server_identifier,
|
||||
timeout=timeout,
|
||||
sse_read_timeout=sse_read_timeout,
|
||||
)
|
||||
db.session.add(mcp_tool)
|
||||
db.session.commit()
|
||||
|
|
@@ -166,6 +170,8 @@ class MCPToolManageService:
|
|||
icon_type: str,
|
||||
icon_background: str,
|
||||
server_identifier: str,
|
||||
timeout: float | None = None,
|
||||
sse_read_timeout: float | None = None,
|
||||
):
|
||||
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
|
||||
|
||||
|
|
@@ -197,6 +203,10 @@ class MCPToolManageService:
|
|||
mcp_provider.tools = reconnect_result["tools"]
|
||||
mcp_provider.encrypted_credentials = reconnect_result["encrypted_credentials"]
|
||||
|
||||
if timeout is not None:
|
||||
mcp_provider.timeout = timeout
|
||||
if sse_read_timeout is not None:
|
||||
mcp_provider.sse_read_timeout = sse_read_timeout
|
||||
db.session.commit()
|
||||
except IntegrityError as e:
|
||||
db.session.rollback()
|
||||
|
|
|
|||
|
|
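The MCPToolManageService hunks thread two new per-provider settings, timeout and sse_read_timeout, requiring them on create and applying them on update only when supplied. A hedged sketch of that apply-only-if-provided logic, detached from SQLAlchemy; the 30/300 second defaults mirror the migration's server defaults.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class McpProviderSettings:
    # Defaults mirror the migration's server defaults (30s connect, 300s SSE read).
    timeout: float = 30.0
    sse_read_timeout: float = 300.0


def apply_update(
    provider: McpProviderSettings,
    timeout: Optional[float] = None,
    sse_read_timeout: Optional[float] = None,
) -> McpProviderSettings:
    """Only overwrite fields the caller actually supplied, like the update path above."""
    if timeout is not None:
        provider.timeout = timeout
    if sse_read_timeout is not None:
        provider.sse_read_timeout = sse_read_timeout
    return provider


provider = McpProviderSettings()
apply_update(provider, timeout=60.0)  # sse_read_timeout keeps its default
print(provider)                       # McpProviderSettings(timeout=60.0, sse_read_timeout=300.0)
```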
@@ -13,7 +13,7 @@ from sqlalchemy.sql.expression import and_, or_
|
|||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.file.models import File
|
||||
from core.variables import Segment, StringSegment, Variable
|
||||
from core.variables.consts import MIN_SELECTORS_LENGTH
|
||||
from core.variables.consts import SELECTORS_LENGTH
|
||||
from core.variables.segments import ArrayFileSegment, FileSegment
|
||||
from core.variables.types import SegmentType
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
||||
|
|
@@ -147,7 +147,7 @@ class WorkflowDraftVariableService:
|
|||
) -> list[WorkflowDraftVariable]:
|
||||
ors = []
|
||||
for selector in selectors:
|
||||
assert len(selector) >= MIN_SELECTORS_LENGTH, f"Invalid selector to get: {selector}"
|
||||
assert len(selector) >= SELECTORS_LENGTH, f"Invalid selector to get: {selector}"
|
||||
node_id, name = selector[:2]
|
||||
ors.append(and_(WorkflowDraftVariable.node_id == node_id, WorkflowDraftVariable.name == name))
|
||||
|
||||
|
|
@@ -608,7 +608,7 @@ class DraftVariableSaver:
|
|||
|
||||
for item in updated_variables:
|
||||
selector = item.selector
|
||||
if len(selector) < MIN_SELECTORS_LENGTH:
|
||||
if len(selector) < SELECTORS_LENGTH:
|
||||
raise Exception("selector too short")
|
||||
# NOTE(QuantumGhost): only the following two kinds of variable could be updated by
|
||||
# VariableAssigner: ConversationVariable and iteration variable.
|
||||
|
|
|
|||
|
|
@@ -56,19 +56,29 @@ def clean_dataset_task(
|
|||
documents = db.session.query(Document).where(Document.dataset_id == dataset_id).all()
|
||||
segments = db.session.query(DocumentSegment).where(DocumentSegment.dataset_id == dataset_id).all()
|
||||
|
||||
# Fix: Always clean vector database resources regardless of document existence
|
||||
# This ensures all 33 vector databases properly drop tables/collections/indices
|
||||
if doc_form is None:
|
||||
# Use default paragraph index type for empty datasets to enable vector database cleanup
|
||||
# Enhanced validation: Check if doc_form is None, empty string, or contains only whitespace
|
||||
# This ensures all invalid doc_form values are properly handled
|
||||
if doc_form is None or (isinstance(doc_form, str) and not doc_form.strip()):
|
||||
# Use default paragraph index type for empty/invalid datasets to enable vector database cleanup
|
||||
from core.rag.index_processor.constant.index_type import IndexType
|
||||
|
||||
doc_form = IndexType.PARAGRAPH_INDEX
|
||||
logging.info(
|
||||
click.style(f"No documents found, using default index type for cleanup: {doc_form}", fg="yellow")
|
||||
click.style(f"Invalid doc_form detected, using default index type for cleanup: {doc_form}", fg="yellow")
|
||||
)
|
||||
|
||||
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
|
||||
index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=True)
|
||||
# Add exception handling around IndexProcessorFactory.clean() to prevent single point of failure
|
||||
# This ensures Document/Segment deletion can continue even if vector database cleanup fails
|
||||
try:
|
||||
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
|
||||
index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=True)
|
||||
logging.info(click.style(f"Successfully cleaned vector database for dataset: {dataset_id}", fg="green"))
|
||||
except Exception as index_cleanup_error:
|
||||
logging.exception(click.style(f"Failed to clean vector database for dataset {dataset_id}", fg="red"))
|
||||
# Continue with document and segment deletion even if vector cleanup fails
|
||||
logging.info(
|
||||
click.style(f"Continuing with document and segment deletion for dataset: {dataset_id}", fg="yellow")
|
||||
)
|
||||
|
||||
if documents is None or len(documents) == 0:
|
||||
logging.info(click.style(f"No documents found for dataset: {dataset_id}", fg="green"))
|
||||
|
|
@@ -128,6 +138,14 @@ def clean_dataset_task(
|
|||
click.style(f"Cleaned dataset when dataset deleted: {dataset_id} latency: {end_at - start_at}", fg="green")
|
||||
)
|
||||
except Exception:
|
||||
# Add rollback to prevent dirty session state in case of exceptions
|
||||
# This ensures the database session is properly cleaned up
|
||||
try:
|
||||
db.session.rollback()
|
||||
logging.info(click.style(f"Rolled back database session for dataset: {dataset_id}", fg="yellow"))
|
||||
except Exception as rollback_error:
|
||||
logging.exception("Failed to rollback database session")
|
||||
|
||||
logging.exception("Cleaned dataset when dataset deleted failed")
|
||||
finally:
|
||||
db.session.close()
|
||||
|
|
|
|||
|
|
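The reworked clean_dataset_task wraps the vector-store cleanup in its own try block so a cleanup failure no longer blocks document and segment deletion, and the outer handler rolls the session back before logging. A distilled version of that control flow, with the storage calls replaced by stand-ins:

```python
import logging

logging.basicConfig(level=logging.INFO)


def clean_dataset(vector_cleanup, delete_rows, rollback) -> None:
    """Run cleanup steps so that a vector-store failure never blocks row deletion."""
    try:
        try:
            vector_cleanup()
        except Exception:
            # Isolated failure: log it and keep going with the relational cleanup.
            logging.exception("vector cleanup failed, continuing with row deletion")
        delete_rows()
    except Exception:
        # Unexpected failure anywhere else: roll back so the session is not left dirty.
        rollback()
        logging.exception("dataset cleanup failed")


def failing_vector_cleanup():
    raise RuntimeError("collection already dropped")


clean_dataset(
    vector_cleanup=failing_vector_cleanup,
    delete_rows=lambda: logging.info("documents and segments deleted"),
    rollback=lambda: logging.info("session rolled back"),
)
```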
@@ -0,0 +1,136 @@
|
|||
"""
|
||||
Celery tasks for asynchronous workflow execution storage operations.
|
||||
|
||||
These tasks provide asynchronous storage capabilities for workflow execution data,
|
||||
improving performance by offloading storage operations to background workers.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
|
||||
from celery import shared_task # type: ignore[import-untyped]
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.workflow.entities.workflow_execution import WorkflowExecution
|
||||
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
|
||||
from extensions.ext_database import db
|
||||
from models import CreatorUserRole, WorkflowRun
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@shared_task(queue="workflow_storage", bind=True, max_retries=3, default_retry_delay=60)
|
||||
def save_workflow_execution_task(
|
||||
self,
|
||||
execution_data: dict,
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
triggered_from: str,
|
||||
creator_user_id: str,
|
||||
creator_user_role: str,
|
||||
) -> bool:
|
||||
"""
|
||||
Asynchronously save or update a workflow execution to the database.
|
||||
|
||||
Args:
|
||||
execution_data: Serialized WorkflowExecution data
|
||||
tenant_id: Tenant ID for multi-tenancy
|
||||
app_id: Application ID
|
||||
triggered_from: Source of the execution trigger
|
||||
creator_user_id: ID of the user who created the execution
|
||||
creator_user_role: Role of the user who created the execution
|
||||
|
||||
Returns:
|
||||
True if successful, False otherwise
|
||||
"""
|
||||
try:
|
||||
# Create a new session for this task
|
||||
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
|
||||
with session_factory() as session:
|
||||
# Deserialize execution data
|
||||
execution = WorkflowExecution.model_validate(execution_data)
|
||||
|
||||
# Check if workflow run already exists
|
||||
existing_run = session.scalar(select(WorkflowRun).where(WorkflowRun.id == execution.id_))
|
||||
|
||||
if existing_run:
|
||||
# Update existing workflow run
|
||||
_update_workflow_run_from_execution(existing_run, execution)
|
||||
logger.debug("Updated existing workflow run: %s", execution.id_)
|
||||
else:
|
||||
# Create new workflow run
|
||||
workflow_run = _create_workflow_run_from_execution(
|
||||
execution=execution,
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
triggered_from=WorkflowRunTriggeredFrom(triggered_from),
|
||||
creator_user_id=creator_user_id,
|
||||
creator_user_role=CreatorUserRole(creator_user_role),
|
||||
)
|
||||
session.add(workflow_run)
|
||||
logger.debug("Created new workflow run: %s", execution.id_)
|
||||
|
||||
session.commit()
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed to save workflow execution %s", execution_data.get("id_", "unknown"))
|
||||
# Retry the task with exponential backoff
|
||||
raise self.retry(exc=e, countdown=60 * (2**self.request.retries))
|
||||
|
||||
|
||||
def _create_workflow_run_from_execution(
|
||||
execution: WorkflowExecution,
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
triggered_from: WorkflowRunTriggeredFrom,
|
||||
creator_user_id: str,
|
||||
creator_user_role: CreatorUserRole,
|
||||
) -> WorkflowRun:
|
||||
"""
|
||||
Create a WorkflowRun database model from a WorkflowExecution domain entity.
|
||||
"""
|
||||
workflow_run = WorkflowRun()
|
||||
workflow_run.id = execution.id_
|
||||
workflow_run.tenant_id = tenant_id
|
||||
workflow_run.app_id = app_id
|
||||
workflow_run.workflow_id = execution.workflow_id
|
||||
workflow_run.type = execution.workflow_type.value
|
||||
workflow_run.triggered_from = triggered_from.value
|
||||
workflow_run.version = execution.workflow_version
|
||||
json_converter = WorkflowRuntimeTypeConverter()
|
||||
workflow_run.graph = json.dumps(json_converter.to_json_encodable(execution.graph))
|
||||
workflow_run.inputs = json.dumps(json_converter.to_json_encodable(execution.inputs))
|
||||
workflow_run.status = execution.status.value
|
||||
workflow_run.outputs = (
|
||||
json.dumps(json_converter.to_json_encodable(execution.outputs)) if execution.outputs else "{}"
|
||||
)
|
||||
workflow_run.error = execution.error_message
|
||||
workflow_run.elapsed_time = execution.elapsed_time
|
||||
workflow_run.total_tokens = execution.total_tokens
|
||||
workflow_run.total_steps = execution.total_steps
|
||||
workflow_run.created_by_role = creator_user_role.value
|
||||
workflow_run.created_by = creator_user_id
|
||||
workflow_run.created_at = execution.started_at
|
||||
workflow_run.finished_at = execution.finished_at
|
||||
|
||||
return workflow_run
|
||||
|
||||
|
||||
def _update_workflow_run_from_execution(workflow_run: WorkflowRun, execution: WorkflowExecution) -> None:
|
||||
"""
|
||||
Update a WorkflowRun database model from a WorkflowExecution domain entity.
|
||||
"""
|
||||
json_converter = WorkflowRuntimeTypeConverter()
|
||||
workflow_run.status = execution.status.value
|
||||
workflow_run.outputs = (
|
||||
json.dumps(json_converter.to_json_encodable(execution.outputs)) if execution.outputs else "{}"
|
||||
)
|
||||
workflow_run.error = execution.error_message
|
||||
workflow_run.elapsed_time = execution.elapsed_time
|
||||
workflow_run.total_tokens = execution.total_tokens
|
||||
workflow_run.total_steps = execution.total_steps
|
||||
workflow_run.finished_at = execution.finished_at
|
||||
|
|
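How callers hand data to save_workflow_execution_task is outside this hunk; presumably a Celery-backed repository serializes the domain entity and enqueues it. A hypothetical caller-side sketch under that assumption follows - the stub model, queue arguments, and enqueue_save helper are invented for illustration; only model_dump and .delay are standard pydantic/Celery APIs.

```python
# Hypothetical caller-side sketch; the repository wiring shown here is an
# assumption, not code from this commit.
from pydantic import BaseModel


class WorkflowExecutionStub(BaseModel):
    # Stand-in for core.workflow.entities.workflow_execution.WorkflowExecution
    id_: str
    workflow_id: str
    status: str


def enqueue_save(execution: WorkflowExecutionStub, tenant_id: str, app_id: str) -> dict:
    """Serialize the domain entity the way the task expects (a plain, JSON-encodable dict)."""
    payload = execution.model_dump(mode="json")
    # In a real worker process this would presumably be something like:
    #   save_workflow_execution_task.delay(
    #       execution_data=payload, tenant_id=tenant_id, app_id=app_id,
    #       triggered_from="...", creator_user_id="...", creator_user_role="...",
    #   )
    return {"execution_data": payload, "tenant_id": tenant_id, "app_id": app_id}


print(enqueue_save(
    WorkflowExecutionStub(id_="run-1", workflow_id="wf-1", status="succeeded"),
    tenant_id="t-1",
    app_id="a-1",
))
```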
@@ -0,0 +1,171 @@
|
|||
"""
|
||||
Celery tasks for asynchronous workflow node execution storage operations.
|
||||
|
||||
These tasks provide asynchronous storage capabilities for workflow node execution data,
|
||||
improving performance by offloading storage operations to background workers.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
|
||||
from celery import shared_task # type: ignore[import-untyped]
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.workflow.entities.workflow_node_execution import (
|
||||
WorkflowNodeExecution,
|
||||
)
|
||||
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
|
||||
from extensions.ext_database import db
|
||||
from models import CreatorUserRole, WorkflowNodeExecutionModel
|
||||
from models.workflow import WorkflowNodeExecutionTriggeredFrom
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@shared_task(queue="workflow_storage", bind=True, max_retries=3, default_retry_delay=60)
|
||||
def save_workflow_node_execution_task(
|
||||
self,
|
||||
execution_data: dict,
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
triggered_from: str,
|
||||
creator_user_id: str,
|
||||
creator_user_role: str,
|
||||
) -> bool:
|
||||
"""
|
||||
Asynchronously save or update a workflow node execution to the database.
|
||||
|
||||
Args:
|
||||
execution_data: Serialized WorkflowNodeExecution data
|
||||
tenant_id: Tenant ID for multi-tenancy
|
||||
app_id: Application ID
|
||||
triggered_from: Source of the execution trigger
|
||||
creator_user_id: ID of the user who created the execution
|
||||
creator_user_role: Role of the user who created the execution
|
||||
|
||||
Returns:
|
||||
True if successful, False otherwise
|
||||
"""
|
||||
try:
|
||||
# Create a new session for this task
|
||||
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
|
||||
with session_factory() as session:
|
||||
# Deserialize execution data
|
||||
execution = WorkflowNodeExecution.model_validate(execution_data)
|
||||
|
||||
# Check if node execution already exists
|
||||
existing_execution = session.scalar(
|
||||
select(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id == execution.id)
|
||||
)
|
||||
|
||||
if existing_execution:
|
||||
# Update existing node execution
|
||||
_update_node_execution_from_domain(existing_execution, execution)
|
||||
logger.debug("Updated existing workflow node execution: %s", execution.id)
|
||||
else:
|
||||
# Create new node execution
|
||||
node_execution = _create_node_execution_from_domain(
|
||||
execution=execution,
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom(triggered_from),
|
||||
creator_user_id=creator_user_id,
|
||||
creator_user_role=CreatorUserRole(creator_user_role),
|
||||
)
|
||||
session.add(node_execution)
|
||||
logger.debug("Created new workflow node execution: %s", execution.id)
|
||||
|
||||
session.commit()
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed to save workflow node execution %s", execution_data.get("id", "unknown"))
|
||||
# Retry the task with exponential backoff
|
||||
raise self.retry(exc=e, countdown=60 * (2**self.request.retries))
|
||||
|
||||
|
||||
def _create_node_execution_from_domain(
|
||||
execution: WorkflowNodeExecution,
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
triggered_from: WorkflowNodeExecutionTriggeredFrom,
|
||||
creator_user_id: str,
|
||||
creator_user_role: CreatorUserRole,
|
||||
) -> WorkflowNodeExecutionModel:
|
||||
"""
|
||||
Create a WorkflowNodeExecutionModel database model from a WorkflowNodeExecution domain entity.
|
||||
"""
|
||||
node_execution = WorkflowNodeExecutionModel()
|
||||
node_execution.id = execution.id
|
||||
node_execution.tenant_id = tenant_id
|
||||
node_execution.app_id = app_id
|
||||
node_execution.workflow_id = execution.workflow_id
|
||||
node_execution.triggered_from = triggered_from.value
|
||||
node_execution.workflow_run_id = execution.workflow_execution_id
|
||||
node_execution.index = execution.index
|
||||
node_execution.predecessor_node_id = execution.predecessor_node_id
|
||||
node_execution.node_id = execution.node_id
|
||||
node_execution.node_type = execution.node_type.value
|
||||
node_execution.title = execution.title
|
||||
node_execution.node_execution_id = execution.node_execution_id
|
||||
|
||||
# Serialize complex data as JSON
|
||||
json_converter = WorkflowRuntimeTypeConverter()
|
||||
node_execution.inputs = json.dumps(json_converter.to_json_encodable(execution.inputs)) if execution.inputs else "{}"
|
||||
node_execution.process_data = (
|
||||
json.dumps(json_converter.to_json_encodable(execution.process_data)) if execution.process_data else "{}"
|
||||
)
|
||||
node_execution.outputs = (
|
||||
json.dumps(json_converter.to_json_encodable(execution.outputs)) if execution.outputs else "{}"
|
||||
)
|
||||
# Convert metadata enum keys to strings for JSON serialization
|
||||
if execution.metadata:
|
||||
metadata_for_json = {
|
||||
key.value if hasattr(key, "value") else str(key): value for key, value in execution.metadata.items()
|
||||
}
|
||||
node_execution.execution_metadata = json.dumps(json_converter.to_json_encodable(metadata_for_json))
|
||||
else:
|
||||
node_execution.execution_metadata = "{}"
|
||||
|
||||
node_execution.status = execution.status.value
|
||||
node_execution.error = execution.error
|
||||
node_execution.elapsed_time = execution.elapsed_time
|
||||
node_execution.created_by_role = creator_user_role.value
|
||||
node_execution.created_by = creator_user_id
|
||||
node_execution.created_at = execution.created_at
|
||||
node_execution.finished_at = execution.finished_at
|
||||
|
||||
return node_execution
|
||||
|
||||
|
||||
def _update_node_execution_from_domain(
|
||||
node_execution: WorkflowNodeExecutionModel, execution: WorkflowNodeExecution
|
||||
) -> None:
|
||||
"""
|
||||
Update a WorkflowNodeExecutionModel database model from a WorkflowNodeExecution domain entity.
|
||||
"""
|
||||
# Update serialized data
|
||||
json_converter = WorkflowRuntimeTypeConverter()
|
||||
node_execution.inputs = json.dumps(json_converter.to_json_encodable(execution.inputs)) if execution.inputs else "{}"
|
||||
node_execution.process_data = (
|
||||
json.dumps(json_converter.to_json_encodable(execution.process_data)) if execution.process_data else "{}"
|
||||
)
|
||||
node_execution.outputs = (
|
||||
json.dumps(json_converter.to_json_encodable(execution.outputs)) if execution.outputs else "{}"
|
||||
)
|
||||
# Convert metadata enum keys to strings for JSON serialization
|
||||
if execution.metadata:
|
||||
metadata_for_json = {
|
||||
key.value if hasattr(key, "value") else str(key): value for key, value in execution.metadata.items()
|
||||
}
|
||||
node_execution.execution_metadata = json.dumps(json_converter.to_json_encodable(metadata_for_json))
|
||||
else:
|
||||
node_execution.execution_metadata = "{}"
|
||||
|
||||
# Update other fields
|
||||
node_execution.status = execution.status.value
|
||||
node_execution.error = execution.error
|
||||
node_execution.elapsed_time = execution.elapsed_time
|
||||
node_execution.finished_at = execution.finished_at
|
||||
|
|
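Both storage tasks above retry via self.retry(exc=e, countdown=60 * (2**self.request.retries)) with max_retries=3, so a persistently failing save is attempted four times before giving up. The resulting backoff schedule, computed directly:

```python
MAX_RETRIES = 3          # from the @shared_task decorator above
BASE_DELAY_SECONDS = 60  # countdown = 60 * (2 ** retries)

# `retries` is how many retries have already happened when the next one is scheduled
delays = [BASE_DELAY_SECONDS * (2**retries) for retries in range(MAX_RETRIES)]
print(delays)                       # [60, 120, 240]
print(sum(delays) / 60, "minutes")  # 7.0 minutes of backoff before the task finally fails
```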
@@ -55,8 +55,8 @@ def init_code_node(code_config: dict):
|
|||
environment_variables=[],
|
||||
conversation_variables=[],
|
||||
)
|
||||
variable_pool.add(["code", "123", "args1"], 1)
|
||||
variable_pool.add(["code", "123", "args2"], 2)
|
||||
variable_pool.add(["code", "args1"], 1)
|
||||
variable_pool.add(["code", "args2"], 2)
|
||||
|
||||
node = CodeNode(
|
||||
id=str(uuid.uuid4()),
|
||||
|
|
@@ -96,9 +96,9 @@ def test_execute_code(setup_code_executor_mock):
|
|||
"variables": [
|
||||
{
|
||||
"variable": "args1",
|
||||
"value_selector": ["1", "123", "args1"],
|
||||
"value_selector": ["1", "args1"],
|
||||
},
|
||||
{"variable": "args2", "value_selector": ["1", "123", "args2"]},
|
||||
{"variable": "args2", "value_selector": ["1", "args2"]},
|
||||
],
|
||||
"answer": "123",
|
||||
"code_language": "python3",
|
||||
|
|
@@ -107,8 +107,8 @@ def test_execute_code(setup_code_executor_mock):
|
|||
}
|
||||
|
||||
node = init_code_node(code_config)
|
||||
node.graph_runtime_state.variable_pool.add(["1", "123", "args1"], 1)
|
||||
node.graph_runtime_state.variable_pool.add(["1", "123", "args2"], 2)
|
||||
node.graph_runtime_state.variable_pool.add(["1", "args1"], 1)
|
||||
node.graph_runtime_state.variable_pool.add(["1", "args2"], 2)
|
||||
|
||||
# execute node
|
||||
result = node._run()
|
||||
|
|
@@ -142,9 +142,9 @@ def test_execute_code_output_validator(setup_code_executor_mock):
|
|||
"variables": [
|
||||
{
|
||||
"variable": "args1",
|
||||
"value_selector": ["1", "123", "args1"],
|
||||
"value_selector": ["1", "args1"],
|
||||
},
|
||||
{"variable": "args2", "value_selector": ["1", "123", "args2"]},
|
||||
{"variable": "args2", "value_selector": ["1", "args2"]},
|
||||
],
|
||||
"answer": "123",
|
||||
"code_language": "python3",
|
||||
|
|
@@ -153,8 +153,8 @@ def test_execute_code_output_validator(setup_code_executor_mock):
|
|||
}
|
||||
|
||||
node = init_code_node(code_config)
|
||||
node.graph_runtime_state.variable_pool.add(["1", "123", "args1"], 1)
|
||||
node.graph_runtime_state.variable_pool.add(["1", "123", "args2"], 2)
|
||||
node.graph_runtime_state.variable_pool.add(["1", "args1"], 1)
|
||||
node.graph_runtime_state.variable_pool.add(["1", "args2"], 2)
|
||||
|
||||
# execute node
|
||||
result = node._run()
|
||||
|
|
@@ -217,9 +217,9 @@ def test_execute_code_output_validator_depth():
|
|||
"variables": [
|
||||
{
|
||||
"variable": "args1",
|
||||
"value_selector": ["1", "123", "args1"],
|
||||
"value_selector": ["1", "args1"],
|
||||
},
|
||||
{"variable": "args2", "value_selector": ["1", "123", "args2"]},
|
||||
{"variable": "args2", "value_selector": ["1", "args2"]},
|
||||
],
|
||||
"answer": "123",
|
||||
"code_language": "python3",
|
||||
|
|
@@ -307,9 +307,9 @@ def test_execute_code_output_object_list():
|
|||
"variables": [
|
||||
{
|
||||
"variable": "args1",
|
||||
"value_selector": ["1", "123", "args1"],
|
||||
"value_selector": ["1", "args1"],
|
||||
},
|
||||
{"variable": "args2", "value_selector": ["1", "123", "args2"]},
|
||||
{"variable": "args2", "value_selector": ["1", "args2"]},
|
||||
],
|
||||
"answer": "123",
|
||||
"code_language": "python3",
|
||||
|
|
|
|||
|
|
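The code-node test configs above now use two-element value_selector lists such as ["1", "args1"], and the HTTP request tests that follow reference the same variables in {{#node.var#}} template form. A tiny helper showing the assumed mapping between the template token and the selector list (the delimiter format is inferred from these tests, not taken from a parser in this diff):

```python
def token_to_selector(token: str) -> list[str]:
    """Turn '{{#a.args2#}}' into ['a', 'args2'] by stripping the delimiters and splitting on dots."""
    inner = token.removeprefix("{{#").removesuffix("#}}")
    return inner.split(".")


print(token_to_selector("{{#1.args1#}}"))         # ['1', 'args1']
print(token_to_selector("{{#a.args3.nested#}}"))  # ['a', 'args3', 'nested'] - nested path kept as extra parts
```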
@@ -49,8 +49,8 @@ def init_http_node(config: dict):
|
|||
environment_variables=[],
|
||||
conversation_variables=[],
|
||||
)
|
||||
variable_pool.add(["a", "b123", "args1"], 1)
|
||||
variable_pool.add(["a", "b123", "args2"], 2)
|
||||
variable_pool.add(["a", "args1"], 1)
|
||||
variable_pool.add(["a", "args2"], 2)
|
||||
|
||||
node = HttpRequestNode(
|
||||
id=str(uuid.uuid4()),
|
||||
|
|
@@ -171,7 +171,7 @@ def test_template(setup_http_mock):
|
|||
"title": "http",
|
||||
"desc": "",
|
||||
"method": "get",
|
||||
"url": "http://example.com/{{#a.b123.args2#}}",
|
||||
"url": "http://example.com/{{#a.args2#}}",
|
||||
"authorization": {
|
||||
"type": "api-key",
|
||||
"config": {
|
||||
|
|
@@ -180,8 +180,8 @@ def test_template(setup_http_mock):
|
|||
"header": "api-key",
|
||||
},
|
||||
},
|
||||
"headers": "X-Header:123\nX-Header2:{{#a.b123.args2#}}",
|
||||
"params": "A:b\nTemplate:{{#a.b123.args2#}}",
|
||||
"headers": "X-Header:123\nX-Header2:{{#a.args2#}}",
|
||||
"params": "A:b\nTemplate:{{#a.args2#}}",
|
||||
"body": None,
|
||||
},
|
||||
}
|
||||
|
|
@@ -223,7 +223,7 @@ def test_json(setup_http_mock):
|
|||
{
|
||||
"key": "",
|
||||
"type": "text",
|
||||
"value": '{"a": "{{#a.b123.args1#}}"}',
|
||||
"value": '{"a": "{{#a.args1#}}"}',
|
||||
},
|
||||
],
|
||||
},
|
||||
|
|
@@ -264,12 +264,12 @@ def test_x_www_form_urlencoded(setup_http_mock):
|
|||
{
|
||||
"key": "a",
|
||||
"type": "text",
|
||||
"value": "{{#a.b123.args1#}}",
|
||||
"value": "{{#a.args1#}}",
|
||||
},
|
||||
{
|
||||
"key": "b",
|
||||
"type": "text",
|
||||
"value": "{{#a.b123.args2#}}",
|
||||
"value": "{{#a.args2#}}",
|
||||
},
|
||||
],
|
||||
},
|
||||
|
|
@@ -310,12 +310,12 @@ def test_form_data(setup_http_mock):
|
|||
{
|
||||
"key": "a",
|
||||
"type": "text",
|
||||
"value": "{{#a.b123.args1#}}",
|
||||
"value": "{{#a.args1#}}",
|
||||
},
|
||||
{
|
||||
"key": "b",
|
||||
"type": "text",
|
||||
"value": "{{#a.b123.args2#}}",
|
||||
"value": "{{#a.args2#}}",
|
||||
},
|
||||
],
|
||||
},
|
||||
|
|
@@ -436,3 +436,87 @@ def test_multi_colons_parse(setup_http_mock):
|
|||
assert 'form-data; name="Redirect"\r\n\r\nhttp://example6.com' in result.process_data.get("request", "")
|
||||
# resp = result.outputs
|
||||
# assert "http://example3.com" == resp.get("headers", {}).get("referer")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
|
||||
def test_nested_object_variable_selector(setup_http_mock):
|
||||
"""Test variable selector functionality with nested object properties."""
|
||||
# Create independent test setup without affecting other tests
|
||||
graph_config = {
|
||||
"edges": [
|
||||
{
|
||||
"id": "start-source-next-target",
|
||||
"source": "start",
|
||||
"target": "1",
|
||||
},
|
||||
],
|
||||
"nodes": [
|
||||
{"data": {"type": "start"}, "id": "start"},
|
||||
{
|
||||
"id": "1",
|
||||
"data": {
|
||||
"title": "http",
|
||||
"desc": "",
|
||||
"method": "get",
|
||||
"url": "http://example.com/{{#a.args2#}}/{{#a.args3.nested#}}",
|
||||
"authorization": {
|
||||
"type": "api-key",
|
||||
"config": {
|
||||
"type": "basic",
|
||||
"api_key": "ak-xxx",
|
||||
"header": "api-key",
|
||||
},
|
||||
},
|
||||
"headers": "X-Header:{{#a.args3.nested#}}",
|
||||
"params": "nested_param:{{#a.args3.nested#}}",
|
||||
"body": None,
|
||||
},
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
graph = Graph.init(graph_config=graph_config)
|
||||
|
||||
init_params = GraphInitParams(
|
||||
tenant_id="1",
|
||||
app_id="1",
|
||||
workflow_type=WorkflowType.WORKFLOW,
|
||||
workflow_id="1",
|
||||
graph_config=graph_config,
|
||||
user_id="1",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
# Create independent variable pool for this test only
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable(user_id="aaa", files=[]),
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
conversation_variables=[],
|
||||
)
|
||||
variable_pool.add(["a", "args1"], 1)
|
||||
variable_pool.add(["a", "args2"], 2)
|
||||
variable_pool.add(["a", "args3"], {"nested": "nested_value"}) # Only for this test
|
||||
|
||||
node = HttpRequestNode(
|
||||
id=str(uuid.uuid4()),
|
||||
graph_init_params=init_params,
|
||||
graph=graph,
|
||||
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
|
||||
config=graph_config["nodes"][1],
|
||||
)
|
||||
|
||||
# Initialize node data
|
||||
if "data" in graph_config["nodes"][1]:
|
||||
node.init_node_data(graph_config["nodes"][1]["data"])
|
||||
|
||||
result = node._run()
|
||||
assert result.process_data is not None
|
||||
data = result.process_data.get("request", "")
|
||||
|
||||
# Verify nested object property is correctly resolved
|
||||
assert "/2/nested_value" in data # URL path should contain resolved nested value
|
||||
assert "X-Header: nested_value" in data # Header should contain nested value
|
||||
assert "nested_param=nested_value" in data # Param should contain nested value
|
||||
|
|
|
|||
|
|
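The new nested-selector test relies on the pool holding only a two-element entry ["a", "args3"] whose value is an object, with deeper parts like args3.nested resolved at lookup time. A standalone sketch of that lookup, with a dict standing in for the pool and its ObjectSegment values:

```python
from typing import Any

# Assumption: a plain dict keyed by (node_id, variable_name) stands in for the pool;
# the real pool stores Segment objects and exposes a get(selector) method.
pool: dict[tuple[str, str], Any] = {
    ("a", "args2"): 2,
    ("a", "args3"): {"nested": "nested_value"},
}


def resolve(selector: list[str]) -> Any:
    """Look up the first two parts in the pool, then walk any remaining parts into the stored object."""
    value = pool[(selector[0], selector[1])]
    for part in selector[2:]:
        value = value[part]
    return value


print(resolve(["a", "args2"]))            # 2
print(resolve(["a", "args3", "nested"]))  # 'nested_value'
```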
@@ -71,8 +71,8 @@ def init_parameter_extractor_node(config: dict):
|
|||
environment_variables=[],
|
||||
conversation_variables=[],
|
||||
)
|
||||
variable_pool.add(["a", "b123", "args1"], 1)
|
||||
variable_pool.add(["a", "b123", "args2"], 2)
|
||||
variable_pool.add(["a", "args1"], 1)
|
||||
variable_pool.add(["a", "args2"], 2)
|
||||
|
||||
node = ParameterExtractorNode(
|
||||
id=str(uuid.uuid4()),
|
||||
|
|
|
|||
|
|
@@ -26,9 +26,9 @@ def test_execute_code(setup_code_executor_mock):
|
|||
"variables": [
|
||||
{
|
||||
"variable": "args1",
|
||||
"value_selector": ["1", "123", "args1"],
|
||||
"value_selector": ["1", "args1"],
|
||||
},
|
||||
{"variable": "args2", "value_selector": ["1", "123", "args2"]},
|
||||
{"variable": "args2", "value_selector": ["1", "args2"]},
|
||||
],
|
||||
"template": code,
|
||||
},
|
||||
|
|
@@ -66,8 +66,8 @@ def test_execute_code(setup_code_executor_mock):
|
|||
environment_variables=[],
|
||||
conversation_variables=[],
|
||||
)
|
||||
variable_pool.add(["1", "123", "args1"], 1)
|
||||
variable_pool.add(["1", "123", "args2"], 3)
|
||||
variable_pool.add(["1", "args1"], 1)
|
||||
variable_pool.add(["1", "args2"], 3)
|
||||
|
||||
node = TemplateTransformNode(
|
||||
id=str(uuid.uuid4()),
|
||||
|
|
|
|||
|
|
@@ -81,7 +81,7 @@ def test_tool_variable_invoke():
|
|||
|
||||
ToolParameterConfigurationManager.decrypt_tool_parameters = MagicMock(return_value={"format": "%Y-%m-%d %H:%M:%S"})
|
||||
|
||||
node.graph_runtime_state.variable_pool.add(["1", "123", "args1"], "1+1")
|
||||
node.graph_runtime_state.variable_pool.add(["1", "args1"], "1+1")
|
||||
|
||||
# execute node
|
||||
result = node._run()
|
||||
|
|
|
|||
|
|
@@ -0,0 +1,913 @@
|
|||
import hashlib
|
||||
from io import BytesIO
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from configs import dify_config
|
||||
from models.account import Account, Tenant
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import EndUser, UploadFile
|
||||
from services.errors.file import FileTooLargeError, UnsupportedFileTypeError
|
||||
from services.file_service import FileService
|
||||
|
||||
|
||||
class TestFileService:
|
||||
"""Integration tests for FileService using testcontainers."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_external_service_dependencies(self):
|
||||
"""Mock setup for external service dependencies."""
|
||||
with (
|
||||
patch("services.file_service.storage") as mock_storage,
|
||||
patch("services.file_service.file_helpers") as mock_file_helpers,
|
||||
patch("services.file_service.ExtractProcessor") as mock_extract_processor,
|
||||
):
|
||||
# Setup default mock returns
|
||||
mock_storage.save.return_value = None
|
||||
mock_storage.load.return_value = BytesIO(b"mock file content")
|
||||
mock_file_helpers.get_signed_file_url.return_value = "https://example.com/signed-url"
|
||||
mock_file_helpers.verify_image_signature.return_value = True
|
||||
mock_file_helpers.verify_file_signature.return_value = True
|
||||
mock_extract_processor.load_from_upload_file.return_value = "extracted text content"
|
||||
|
||||
yield {
|
||||
"storage": mock_storage,
|
||||
"file_helpers": mock_file_helpers,
|
||||
"extract_processor": mock_extract_processor,
|
||||
}
|
||||
|
||||
def _create_test_account(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Helper method to create a test account for testing.
|
||||
|
||||
Args:
|
||||
db_session_with_containers: Database session from testcontainers infrastructure
|
||||
mock_external_service_dependencies: Mock dependencies
|
||||
|
||||
Returns:
|
||||
Account: Created account instance
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Create account
|
||||
account = Account(
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
status="active",
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(account)
|
||||
db.session.commit()
|
||||
|
||||
# Create tenant for the account
|
||||
tenant = Tenant(
|
||||
name=fake.company(),
|
||||
status="normal",
|
||||
)
|
||||
db.session.add(tenant)
|
||||
db.session.commit()
|
||||
|
||||
# Create tenant-account join
|
||||
from models.account import TenantAccountJoin, TenantAccountRole
|
||||
|
||||
join = TenantAccountJoin(
|
||||
tenant_id=tenant.id,
|
||||
account_id=account.id,
|
||||
role=TenantAccountRole.OWNER.value,
|
||||
current=True,
|
||||
)
|
||||
db.session.add(join)
|
||||
db.session.commit()
|
||||
|
||||
# Set current tenant for account
|
||||
account.current_tenant = tenant
|
||||
|
||||
return account
|
||||
|
||||
def _create_test_end_user(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Helper method to create a test end user for testing.
|
||||
|
||||
Args:
|
||||
db_session_with_containers: Database session from testcontainers infrastructure
|
||||
mock_external_service_dependencies: Mock dependencies
|
||||
|
||||
Returns:
|
||||
EndUser: Created end user instance
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
end_user = EndUser(
|
||||
tenant_id=str(fake.uuid4()),
|
||||
type="web",
|
||||
name=fake.name(),
|
||||
is_anonymous=False,
|
||||
session_id=fake.uuid4(),
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(end_user)
|
||||
db.session.commit()
|
||||
|
||||
return end_user
|
||||
|
||||
def _create_test_upload_file(self, db_session_with_containers, mock_external_service_dependencies, account):
|
||||
"""
|
||||
Helper method to create a test upload file for testing.
|
||||
|
||||
Args:
|
||||
db_session_with_containers: Database session from testcontainers infrastructure
|
||||
mock_external_service_dependencies: Mock dependencies
|
||||
account: Account instance
|
||||
|
||||
Returns:
|
||||
UploadFile: Created upload file instance
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
upload_file = UploadFile(
|
||||
tenant_id=account.current_tenant_id if hasattr(account, "current_tenant_id") else str(fake.uuid4()),
|
||||
storage_type="local",
|
||||
key=f"upload_files/test/{fake.uuid4()}.txt",
|
||||
name="test_file.txt",
|
||||
size=1024,
|
||||
extension="txt",
|
||||
mime_type="text/plain",
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=account.id,
|
||||
created_at=fake.date_time(),
|
||||
used=False,
|
||||
hash=hashlib.sha3_256(b"test content").hexdigest(),
|
||||
source_url="",
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(upload_file)
|
||||
db.session.commit()
|
||||
|
||||
return upload_file
|
||||
|
||||
# Test upload_file method
|
||||
def test_upload_file_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful file upload with valid parameters.
|
||||
"""
|
||||
fake = Faker()
|
||||
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
filename = "test_document.pdf"
|
||||
content = b"test file content"
|
||||
mimetype = "application/pdf"
|
||||
|
||||
upload_file = FileService.upload_file(
|
||||
filename=filename,
|
||||
content=content,
|
||||
mimetype=mimetype,
|
||||
user=account,
|
||||
)
|
||||
|
||||
assert upload_file is not None
|
||||
assert upload_file.name == filename
|
||||
assert upload_file.size == len(content)
|
||||
assert upload_file.extension == "pdf"
|
||||
assert upload_file.mime_type == mimetype
|
||||
assert upload_file.created_by == account.id
|
||||
assert upload_file.created_by_role == CreatorUserRole.ACCOUNT.value
|
||||
assert upload_file.used is False
|
||||
assert upload_file.hash == hashlib.sha3_256(content).hexdigest()
|
||||
|
||||
# Verify storage was called
|
||||
mock_external_service_dependencies["storage"].save.assert_called_once()
|
||||
|
||||
# Verify database state
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.refresh(upload_file)
|
||||
assert upload_file.id is not None
|
||||
|
||||
def test_upload_file_with_end_user(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test file upload with end user instead of account.
|
||||
"""
|
||||
fake = Faker()
|
||||
end_user = self._create_test_end_user(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
filename = "test_image.jpg"
|
||||
content = b"test image content"
|
||||
mimetype = "image/jpeg"
|
||||
|
||||
upload_file = FileService.upload_file(
|
||||
filename=filename,
|
||||
content=content,
|
||||
mimetype=mimetype,
|
||||
user=end_user,
|
||||
)
|
||||
|
||||
assert upload_file is not None
|
||||
assert upload_file.created_by == end_user.id
|
||||
assert upload_file.created_by_role == CreatorUserRole.END_USER.value
|
||||
|
||||
def test_upload_file_with_datasets_source(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test file upload with datasets source parameter.
|
||||
"""
|
||||
fake = Faker()
|
||||
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
filename = "test_document.pdf"
|
||||
content = b"test file content"
|
||||
mimetype = "application/pdf"
|
||||
|
||||
upload_file = FileService.upload_file(
|
||||
filename=filename,
|
||||
content=content,
|
||||
mimetype=mimetype,
|
||||
user=account,
|
||||
source="datasets",
|
||||
source_url="https://example.com/source",
|
||||
)
|
||||
|
||||
assert upload_file is not None
|
||||
assert upload_file.source_url == "https://example.com/source"
|
||||
|
||||
def test_upload_file_invalid_filename_characters(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test file upload with invalid filename characters.
|
||||
"""
|
||||
fake = Faker()
|
||||
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
filename = "test/file<name>.txt"
|
||||
content = b"test content"
|
||||
mimetype = "text/plain"
|
||||
|
||||
with pytest.raises(ValueError, match="Filename contains invalid characters"):
|
||||
FileService.upload_file(
|
||||
filename=filename,
|
||||
content=content,
|
||||
mimetype=mimetype,
|
||||
user=account,
|
||||
)
|
||||
|
||||
def test_upload_file_filename_too_long(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test file upload with filename that exceeds length limit.
|
||||
"""
|
||||
fake = Faker()
|
||||
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Create a filename longer than 200 characters
|
||||
long_name = "a" * 250
|
||||
filename = f"{long_name}.txt"
|
||||
content = b"test content"
|
||||
mimetype = "text/plain"
|
||||
|
||||
upload_file = FileService.upload_file(
|
||||
filename=filename,
|
||||
content=content,
|
||||
mimetype=mimetype,
|
||||
user=account,
|
||||
)
|
||||
|
||||
# Verify filename was truncated (the logic truncates the base name to 200 chars + extension)
|
||||
# So the total length should be <= 200 + len(extension) + 1 (for the dot)
|
||||
assert len(upload_file.name) <= 200 + len(upload_file.extension) + 1
|
||||
assert upload_file.name.endswith(".txt")
|
||||
# Verify the base name was truncated
|
||||
base_name = upload_file.name[:-4] # Remove .txt
|
||||
assert len(base_name) <= 200
|
||||
|
||||
def test_upload_file_datasets_unsupported_type(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test file upload for datasets with unsupported file type.
|
||||
"""
|
||||
fake = Faker()
|
||||
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
filename = "test_image.jpg"
|
||||
content = b"test content"
|
||||
mimetype = "image/jpeg"
|
||||
|
||||
with pytest.raises(UnsupportedFileTypeError):
|
||||
FileService.upload_file(
|
||||
filename=filename,
|
||||
content=content,
|
||||
mimetype=mimetype,
|
||||
user=account,
|
||||
source="datasets",
|
||||
)
|
||||
|
||||
def test_upload_file_too_large(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test file upload with file size exceeding limit.
|
||||
"""
|
||||
fake = Faker()
|
||||
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
filename = "large_image.jpg"
|
||||
# Create content larger than the limit
|
||||
content = b"x" * (dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT * 1024 * 1024 + 1)
|
||||
mimetype = "image/jpeg"
|
||||
|
||||
with pytest.raises(FileTooLargeError):
|
||||
FileService.upload_file(
|
||||
filename=filename,
|
||||
content=content,
|
||||
mimetype=mimetype,
|
||||
user=account,
|
||||
)
|
||||
|
||||
# Test is_file_size_within_limit method
|
||||
def test_is_file_size_within_limit_image_success(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test file size check for image files within limit.
|
||||
"""
|
||||
extension = "jpg"
|
||||
file_size = dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT * 1024 * 1024 # Exactly at limit
|
||||
|
||||
result = FileService.is_file_size_within_limit(extension=extension, file_size=file_size)
|
||||
|
||||
assert result is True
|
||||
|
||||
def test_is_file_size_within_limit_video_success(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test file size check for video files within limit.
|
||||
"""
|
||||
extension = "mp4"
|
||||
file_size = dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT * 1024 * 1024 # Exactly at limit
|
||||
|
||||
result = FileService.is_file_size_within_limit(extension=extension, file_size=file_size)
|
||||
|
||||
assert result is True
|
||||
|
||||
def test_is_file_size_within_limit_audio_success(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test file size check for audio files within limit.
|
||||
"""
|
||||
extension = "mp3"
|
||||
file_size = dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT * 1024 * 1024 # Exactly at limit
|
||||
|
||||
result = FileService.is_file_size_within_limit(extension=extension, file_size=file_size)
|
||||
|
||||
assert result is True
|
||||
|
||||
def test_is_file_size_within_limit_document_success(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test file size check for document files within limit.
|
||||
"""
|
||||
extension = "pdf"
|
||||
file_size = dify_config.UPLOAD_FILE_SIZE_LIMIT * 1024 * 1024 # Exactly at limit
|
||||
|
||||
result = FileService.is_file_size_within_limit(extension=extension, file_size=file_size)
|
||||
|
||||
assert result is True
|
||||
|
||||
def test_is_file_size_within_limit_image_exceeded(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test file size check for image files exceeding limit.
|
||||
"""
|
||||
extension = "jpg"
|
||||
file_size = dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT * 1024 * 1024 + 1 # Exceeds limit
|
||||
|
||||
result = FileService.is_file_size_within_limit(extension=extension, file_size=file_size)
|
||||
|
||||
assert result is False
|
||||
|
||||
def test_is_file_size_within_limit_unknown_extension(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test file size check for unknown file extension.
|
||||
"""
|
||||
extension = "xyz"
|
||||
file_size = dify_config.UPLOAD_FILE_SIZE_LIMIT * 1024 * 1024 # Uses default limit
|
||||
|
||||
result = FileService.is_file_size_within_limit(extension=extension, file_size=file_size)
|
||||
|
||||
assert result is True
|
||||
|
||||
# Test upload_text method
|
||||
def test_upload_text_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful text upload.
|
||||
"""
|
||||
fake = Faker()
|
||||
text = "This is a test text content"
|
||||
text_name = "test_text.txt"
|
||||
|
||||
# Mock current_user
|
||||
with patch("services.file_service.current_user") as mock_current_user:
|
||||
mock_current_user.current_tenant_id = str(fake.uuid4())
|
||||
mock_current_user.id = str(fake.uuid4())
|
||||
|
||||
upload_file = FileService.upload_text(text=text, text_name=text_name)
|
||||
|
||||
assert upload_file is not None
|
||||
assert upload_file.name == text_name
|
||||
assert upload_file.size == len(text)
|
||||
assert upload_file.extension == "txt"
|
||||
assert upload_file.mime_type == "text/plain"
|
||||
assert upload_file.used is True
|
||||
assert upload_file.used_by == mock_current_user.id
|
||||
|
||||
# Verify storage was called
|
||||
mock_external_service_dependencies["storage"].save.assert_called_once()
|
||||
|
||||
def test_upload_text_name_too_long(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test text upload with name that exceeds length limit.
|
||||
"""
|
||||
fake = Faker()
|
||||
text = "test content"
|
||||
long_name = "a" * 250 # Longer than 200 characters
|
||||
|
||||
# Mock current_user
|
||||
with patch("services.file_service.current_user") as mock_current_user:
|
||||
mock_current_user.current_tenant_id = str(fake.uuid4())
|
||||
mock_current_user.id = str(fake.uuid4())
|
||||
|
||||
upload_file = FileService.upload_text(text=text, text_name=long_name)
|
||||
|
||||
# Verify name was truncated
|
||||
assert len(upload_file.name) <= 200
|
||||
assert upload_file.name == "a" * 200
|
||||
|
||||
# Test get_file_preview method
|
||||
def test_get_file_preview_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful file preview generation.
|
||||
"""
|
||||
fake = Faker()
|
||||
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
upload_file = self._create_test_upload_file(
|
||||
db_session_with_containers, mock_external_service_dependencies, account
|
||||
)
|
||||
|
||||
# Update file to have document extension
|
||||
upload_file.extension = "pdf"
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.commit()
|
||||
|
||||
result = FileService.get_file_preview(file_id=upload_file.id)
|
||||
|
||||
assert result == "extracted text content"
|
||||
mock_external_service_dependencies["extract_processor"].load_from_upload_file.assert_called_once()
|
||||
|
||||
def test_get_file_preview_file_not_found(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test file preview with non-existent file.
|
||||
"""
|
||||
fake = Faker()
|
||||
non_existent_id = str(fake.uuid4())
|
||||
|
||||
with pytest.raises(NotFound, match="File not found"):
|
||||
FileService.get_file_preview(file_id=non_existent_id)
|
||||
|
||||
def test_get_file_preview_unsupported_file_type(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test file preview with unsupported file type.
|
||||
"""
|
||||
fake = Faker()
|
||||
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
upload_file = self._create_test_upload_file(
|
||||
db_session_with_containers, mock_external_service_dependencies, account
|
||||
)
|
||||
|
||||
# Update file to have non-document extension
|
||||
upload_file.extension = "jpg"
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.commit()
|
||||
|
||||
with pytest.raises(UnsupportedFileTypeError):
|
||||
FileService.get_file_preview(file_id=upload_file.id)
|
||||
|
||||
    def test_get_file_preview_text_truncation(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test file preview with text that exceeds preview limit.
        """
        fake = Faker()
        account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
        upload_file = self._create_test_upload_file(
            db_session_with_containers, mock_external_service_dependencies, account
        )

        # Update file to have document extension
        upload_file.extension = "pdf"
        from extensions.ext_database import db

        db.session.commit()

        # Mock long text content
        long_text = "x" * 5000  # Longer than PREVIEW_WORDS_LIMIT
        mock_external_service_dependencies["extract_processor"].load_from_upload_file.return_value = long_text

        result = FileService.get_file_preview(file_id=upload_file.id)

        assert len(result) == 3000  # PREVIEW_WORDS_LIMIT
        assert result == "x" * 3000

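The two assertions above pin down the preview contract: extracted text is cut to PREVIEW_WORDS_LIMIT characters (3000, per the inline comments). A minimal sketch of that behaviour — the slicing is an assumption; only the limit and the ExtractProcessor call are visible in the test:

```python
PREVIEW_WORDS_LIMIT = 3000  # value implied by the assertions above


def build_preview(extracted_text: str) -> str:
    # the service presumably returns at most the first PREVIEW_WORDS_LIMIT characters
    return extracted_text[:PREVIEW_WORDS_LIMIT]


assert build_preview("x" * 5000) == "x" * 3000
```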
# Test get_image_preview method
|
||||
def test_get_image_preview_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful image preview generation.
|
||||
"""
|
||||
fake = Faker()
|
||||
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
upload_file = self._create_test_upload_file(
|
||||
db_session_with_containers, mock_external_service_dependencies, account
|
||||
)
|
||||
|
||||
# Update file to have image extension
|
||||
upload_file.extension = "jpg"
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.commit()
|
||||
|
||||
timestamp = "1234567890"
|
||||
nonce = "test_nonce"
|
||||
sign = "test_signature"
|
||||
|
||||
generator, mime_type = FileService.get_image_preview(
|
||||
file_id=upload_file.id,
|
||||
timestamp=timestamp,
|
||||
nonce=nonce,
|
||||
sign=sign,
|
||||
)
|
||||
|
||||
assert generator is not None
|
||||
assert mime_type == upload_file.mime_type
|
||||
mock_external_service_dependencies["file_helpers"].verify_image_signature.assert_called_once()
|
||||
|
||||
    def test_get_image_preview_invalid_signature(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test image preview with invalid signature.
        """
        fake = Faker()
        account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
        upload_file = self._create_test_upload_file(
            db_session_with_containers, mock_external_service_dependencies, account
        )

        # Mock invalid signature
        mock_external_service_dependencies["file_helpers"].verify_image_signature.return_value = False

        timestamp = "1234567890"
        nonce = "test_nonce"
        sign = "invalid_signature"

        with pytest.raises(NotFound, match="File not found or signature is invalid"):
            FileService.get_image_preview(
                file_id=upload_file.id,
                timestamp=timestamp,
                nonce=nonce,
                sign=sign,
            )

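The preview endpoints are guarded by file_helpers.verify_image_signature / verify_file_signature, which the fixture stubs out, so the tests only show the inputs (file_id, timestamp, nonce, sign). A typical HMAC-style check of that shape looks like the sketch below; this is an illustration of the pattern, not Dify's actual helper:

```python
import hashlib
import hmac
import time


def verify_signature(secret_key: bytes, file_id: str, timestamp: str, nonce: str, sign: str, ttl: int = 300) -> bool:
    # reject stale links first (the freshness window is an assumption)
    if time.time() - int(timestamp) > ttl:
        return False
    message = f"{file_id}{timestamp}{nonce}".encode()
    expected = hmac.new(secret_key, message, hashlib.sha256).hexdigest()
    # compare in constant time against the presented signature
    return hmac.compare_digest(expected, sign)
```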
def test_get_image_preview_file_not_found(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test image preview with non-existent file.
|
||||
"""
|
||||
fake = Faker()
|
||||
non_existent_id = str(fake.uuid4())
|
||||
|
||||
timestamp = "1234567890"
|
||||
nonce = "test_nonce"
|
||||
sign = "test_signature"
|
||||
|
||||
with pytest.raises(NotFound, match="File not found or signature is invalid"):
|
||||
FileService.get_image_preview(
|
||||
file_id=non_existent_id,
|
||||
timestamp=timestamp,
|
||||
nonce=nonce,
|
||||
sign=sign,
|
||||
)
|
||||
|
||||
def test_get_image_preview_unsupported_file_type(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test image preview with non-image file type.
|
||||
"""
|
||||
fake = Faker()
|
||||
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
upload_file = self._create_test_upload_file(
|
||||
db_session_with_containers, mock_external_service_dependencies, account
|
||||
)
|
||||
|
||||
# Update file to have non-image extension
|
||||
upload_file.extension = "pdf"
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.commit()
|
||||
|
||||
timestamp = "1234567890"
|
||||
nonce = "test_nonce"
|
||||
sign = "test_signature"
|
||||
|
||||
with pytest.raises(UnsupportedFileTypeError):
|
||||
FileService.get_image_preview(
|
||||
file_id=upload_file.id,
|
||||
timestamp=timestamp,
|
||||
nonce=nonce,
|
||||
sign=sign,
|
||||
)
|
||||
|
||||
# Test get_file_generator_by_file_id method
|
||||
def test_get_file_generator_by_file_id_success(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful file generator retrieval.
|
||||
"""
|
||||
fake = Faker()
|
||||
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
upload_file = self._create_test_upload_file(
|
||||
db_session_with_containers, mock_external_service_dependencies, account
|
||||
)
|
||||
|
||||
timestamp = "1234567890"
|
||||
nonce = "test_nonce"
|
||||
sign = "test_signature"
|
||||
|
||||
generator, file_obj = FileService.get_file_generator_by_file_id(
|
||||
file_id=upload_file.id,
|
||||
timestamp=timestamp,
|
||||
nonce=nonce,
|
||||
sign=sign,
|
||||
)
|
||||
|
||||
assert generator is not None
|
||||
assert file_obj == upload_file
|
||||
mock_external_service_dependencies["file_helpers"].verify_file_signature.assert_called_once()
|
||||
|
||||
def test_get_file_generator_by_file_id_invalid_signature(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test file generator retrieval with invalid signature.
|
||||
"""
|
||||
fake = Faker()
|
||||
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
upload_file = self._create_test_upload_file(
|
||||
db_session_with_containers, mock_external_service_dependencies, account
|
||||
)
|
||||
|
||||
# Mock invalid signature
|
||||
mock_external_service_dependencies["file_helpers"].verify_file_signature.return_value = False
|
||||
|
||||
timestamp = "1234567890"
|
||||
nonce = "test_nonce"
|
||||
sign = "invalid_signature"
|
||||
|
||||
with pytest.raises(NotFound, match="File not found or signature is invalid"):
|
||||
FileService.get_file_generator_by_file_id(
|
||||
file_id=upload_file.id,
|
||||
timestamp=timestamp,
|
||||
nonce=nonce,
|
||||
sign=sign,
|
||||
)
|
||||
|
||||
def test_get_file_generator_by_file_id_file_not_found(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test file generator retrieval with non-existent file.
|
||||
"""
|
||||
fake = Faker()
|
||||
non_existent_id = str(fake.uuid4())
|
||||
|
||||
timestamp = "1234567890"
|
||||
nonce = "test_nonce"
|
||||
sign = "test_signature"
|
||||
|
||||
with pytest.raises(NotFound, match="File not found or signature is invalid"):
|
||||
FileService.get_file_generator_by_file_id(
|
||||
file_id=non_existent_id,
|
||||
timestamp=timestamp,
|
||||
nonce=nonce,
|
||||
sign=sign,
|
||||
)
|
||||
|
||||
# Test get_public_image_preview method
|
||||
def test_get_public_image_preview_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful public image preview generation.
|
||||
"""
|
||||
fake = Faker()
|
||||
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
upload_file = self._create_test_upload_file(
|
||||
db_session_with_containers, mock_external_service_dependencies, account
|
||||
)
|
||||
|
||||
# Update file to have image extension
|
||||
upload_file.extension = "jpg"
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.commit()
|
||||
|
||||
generator, mime_type = FileService.get_public_image_preview(file_id=upload_file.id)
|
||||
|
||||
assert generator is not None
|
||||
assert mime_type == upload_file.mime_type
|
||||
mock_external_service_dependencies["storage"].load.assert_called_once()
|
||||
|
||||
def test_get_public_image_preview_file_not_found(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test public image preview with non-existent file.
|
||||
"""
|
||||
fake = Faker()
|
||||
non_existent_id = str(fake.uuid4())
|
||||
|
||||
with pytest.raises(NotFound, match="File not found or signature is invalid"):
|
||||
FileService.get_public_image_preview(file_id=non_existent_id)
|
||||
|
||||
def test_get_public_image_preview_unsupported_file_type(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test public image preview with non-image file type.
|
||||
"""
|
||||
fake = Faker()
|
||||
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
upload_file = self._create_test_upload_file(
|
||||
db_session_with_containers, mock_external_service_dependencies, account
|
||||
)
|
||||
|
||||
# Update file to have non-image extension
|
||||
upload_file.extension = "pdf"
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.commit()
|
||||
|
||||
with pytest.raises(UnsupportedFileTypeError):
|
||||
FileService.get_public_image_preview(file_id=upload_file.id)
|
||||
|
||||
# Test edge cases and boundary conditions
|
||||
def test_upload_file_empty_content(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test file upload with empty content.
|
||||
"""
|
||||
fake = Faker()
|
||||
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
filename = "empty.txt"
|
||||
content = b""
|
||||
mimetype = "text/plain"
|
||||
|
||||
upload_file = FileService.upload_file(
|
||||
filename=filename,
|
||||
content=content,
|
||||
mimetype=mimetype,
|
||||
user=account,
|
||||
)
|
||||
|
||||
assert upload_file is not None
|
||||
assert upload_file.size == 0
|
||||
|
||||
def test_upload_file_special_characters_in_name(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test file upload with special characters in filename (but valid ones).
|
||||
"""
|
||||
fake = Faker()
|
||||
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
filename = "test-file_with_underscores_and.dots.txt"
|
||||
content = b"test content"
|
||||
mimetype = "text/plain"
|
||||
|
||||
upload_file = FileService.upload_file(
|
||||
filename=filename,
|
||||
content=content,
|
||||
mimetype=mimetype,
|
||||
user=account,
|
||||
)
|
||||
|
||||
assert upload_file is not None
|
||||
assert upload_file.name == filename
|
||||
|
||||
def test_upload_file_different_case_extensions(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test file upload with different case extensions.
|
||||
"""
|
||||
fake = Faker()
|
||||
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
filename = "test.PDF"
|
||||
content = b"test content"
|
||||
mimetype = "application/pdf"
|
||||
|
||||
upload_file = FileService.upload_file(
|
||||
filename=filename,
|
||||
content=content,
|
||||
mimetype=mimetype,
|
||||
user=account,
|
||||
)
|
||||
|
||||
assert upload_file is not None
|
||||
assert upload_file.extension == "pdf" # Should be converted to lowercase
|
||||
|
||||
def test_upload_text_empty_text(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test text upload with empty text.
|
||||
"""
|
||||
fake = Faker()
|
||||
text = ""
|
||||
text_name = "empty.txt"
|
||||
|
||||
# Mock current_user
|
||||
with patch("services.file_service.current_user") as mock_current_user:
|
||||
mock_current_user.current_tenant_id = str(fake.uuid4())
|
||||
mock_current_user.id = str(fake.uuid4())
|
||||
|
||||
upload_file = FileService.upload_text(text=text, text_name=text_name)
|
||||
|
||||
assert upload_file is not None
|
||||
assert upload_file.size == 0
|
||||
|
||||
    def test_file_size_limits_edge_cases(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test file size limits with edge case values.
        """
        # Test exactly at limit
        for extension, limit_config in [
            ("jpg", dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT),
            ("mp4", dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT),
            ("mp3", dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT),
            ("pdf", dify_config.UPLOAD_FILE_SIZE_LIMIT),
        ]:
            file_size = limit_config * 1024 * 1024
            result = FileService.is_file_size_within_limit(extension=extension, file_size=file_size)
            assert result is True

            # Test one byte over limit
            file_size = limit_config * 1024 * 1024 + 1
            result = FileService.is_file_size_within_limit(extension=extension, file_size=file_size)
            assert result is False

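These edge cases pin down the contract of is_file_size_within_limit: the limit is chosen by extension category and the comparison is inclusive at the boundary. A minimal sketch of that contract, assuming the extension groupings implied by the tests (jpg→image, mp4→video, mp3→audio, anything else→document); the real groupings and the import path are FileService internals, not shown here:

```python
from configs import dify_config  # assumed import path for the shared config object

IMAGE_EXTENSIONS = {"jpg", "jpeg", "png", "webp", "gif"}  # assumed grouping
VIDEO_EXTENSIONS = {"mp4", "mov", "mpeg"}                 # assumed grouping
AUDIO_EXTENSIONS = {"mp3", "m4a", "wav"}                  # assumed grouping


def is_file_size_within_limit(extension: str, file_size: int) -> bool:
    # pick the per-category limit (configured in MB), falling back to the generic document limit
    if extension in IMAGE_EXTENSIONS:
        limit_mb = dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT
    elif extension in VIDEO_EXTENSIONS:
        limit_mb = dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT
    elif extension in AUDIO_EXTENSIONS:
        limit_mb = dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT
    else:
        limit_mb = dify_config.UPLOAD_FILE_SIZE_LIMIT
    # the "exactly at limit" cases above require an inclusive comparison
    return file_size <= limit_mb * 1024 * 1024
```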
def test_upload_file_with_source_url(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test file upload with source URL that gets overridden by signed URL.
|
||||
"""
|
||||
fake = Faker()
|
||||
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
filename = "test.pdf"
|
||||
content = b"test content"
|
||||
mimetype = "application/pdf"
|
||||
source_url = "https://original-source.com/file.pdf"
|
||||
|
||||
upload_file = FileService.upload_file(
|
||||
filename=filename,
|
||||
content=content,
|
||||
mimetype=mimetype,
|
||||
user=account,
|
||||
source_url=source_url,
|
||||
)
|
||||
|
||||
# When source_url is provided, it should be preserved
|
||||
assert upload_file.source_url == source_url
|
||||
|
||||
# The signed URL should only be set when source_url is empty
|
||||
# Let's test that scenario
|
||||
upload_file2 = FileService.upload_file(
|
||||
filename="test2.pdf",
|
||||
content=b"test content 2",
|
||||
mimetype="application/pdf",
|
||||
user=account,
|
||||
source_url="", # Empty source_url
|
||||
)
|
||||
|
||||
        # Should have the signed URL when source_url is empty
        assert upload_file2.source_url == "https://example.com/signed-url"


@ -0,0 +1,775 @@
from unittest.mock import patch

import pytest
from faker import Faker

from models.model import MessageFeedback
from services.app_service import AppService
from services.errors.message import (
    FirstMessageNotExistsError,
    LastMessageNotExistsError,
    MessageNotExistsError,
    SuggestedQuestionsAfterAnswerDisabledError,
)
from services.message_service import MessageService


class TestMessageService:
    """Integration tests for MessageService using testcontainers."""

@pytest.fixture
|
||||
def mock_external_service_dependencies(self):
|
||||
"""Mock setup for external service dependencies."""
|
||||
with (
|
||||
patch("services.account_service.FeatureService") as mock_account_feature_service,
|
||||
patch("services.message_service.ModelManager") as mock_model_manager,
|
||||
patch("services.message_service.WorkflowService") as mock_workflow_service,
|
||||
patch("services.message_service.AdvancedChatAppConfigManager") as mock_app_config_manager,
|
||||
patch("services.message_service.LLMGenerator") as mock_llm_generator,
|
||||
patch("services.message_service.TraceQueueManager") as mock_trace_manager_class,
|
||||
patch("services.message_service.TokenBufferMemory") as mock_token_buffer_memory,
|
||||
):
|
||||
# Setup default mock returns
|
||||
mock_account_feature_service.get_features.return_value.billing.enabled = False
|
||||
|
||||
# Mock ModelManager
|
||||
mock_model_instance = mock_model_manager.return_value.get_default_model_instance.return_value
|
||||
mock_model_instance.get_tts_voices.return_value = [{"value": "test-voice"}]
|
||||
|
||||
# Mock get_model_instance method as well
|
||||
mock_model_manager.return_value.get_model_instance.return_value = mock_model_instance
|
||||
|
||||
# Mock WorkflowService
|
||||
mock_workflow = mock_workflow_service.return_value.get_published_workflow.return_value
|
||||
mock_workflow_service.return_value.get_draft_workflow.return_value = mock_workflow
|
||||
|
||||
# Mock AdvancedChatAppConfigManager
|
||||
mock_app_config = mock_app_config_manager.get_app_config.return_value
|
||||
mock_app_config.additional_features.suggested_questions_after_answer = True
|
||||
|
||||
# Mock LLMGenerator
|
||||
mock_llm_generator.generate_suggested_questions_after_answer.return_value = ["Question 1", "Question 2"]
|
||||
|
||||
# Mock TraceQueueManager
|
||||
mock_trace_manager_instance = mock_trace_manager_class.return_value
|
||||
|
||||
# Mock TokenBufferMemory
|
||||
mock_memory_instance = mock_token_buffer_memory.return_value
|
||||
mock_memory_instance.get_history_prompt_text.return_value = "Mocked history prompt"
|
||||
|
||||
yield {
|
||||
"account_feature_service": mock_account_feature_service,
|
||||
"model_manager": mock_model_manager,
|
||||
"workflow_service": mock_workflow_service,
|
||||
"app_config_manager": mock_app_config_manager,
|
||||
"llm_generator": mock_llm_generator,
|
||||
"trace_manager_class": mock_trace_manager_class,
|
||||
"trace_manager_instance": mock_trace_manager_instance,
|
||||
"token_buffer_memory": mock_token_buffer_memory,
|
||||
# "current_user": mock_current_user,
|
||||
}
|
||||
|
||||
def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Helper method to create a test app and account for testing.
|
||||
|
||||
Args:
|
||||
db_session_with_containers: Database session from testcontainers infrastructure
|
||||
mock_external_service_dependencies: Mock dependencies
|
||||
|
||||
Returns:
|
||||
tuple: (app, account) - Created app and account instances
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Setup mocks for account creation
|
||||
mock_external_service_dependencies[
|
||||
"account_feature_service"
|
||||
].get_system_features.return_value.is_allow_register = True
|
||||
|
||||
# Create account and tenant first
|
||||
from services.account_service import AccountService, TenantService
|
||||
|
||||
account = AccountService.create_account(
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Setup app creation arguments
|
||||
app_args = {
|
||||
"name": fake.company(),
|
||||
"description": fake.text(max_nb_chars=100),
|
||||
"mode": "advanced-chat", # Use advanced-chat mode to use mocked workflow
|
||||
"icon_type": "emoji",
|
||||
"icon": "🤖",
|
||||
"icon_background": "#FF6B6B",
|
||||
"api_rph": 100,
|
||||
"api_rpm": 10,
|
||||
}
|
||||
|
||||
# Create app
|
||||
app_service = AppService()
|
||||
app = app_service.create_app(tenant.id, app_args, account)
|
||||
|
||||
# Setup current_user mock
|
||||
self._mock_current_user(mock_external_service_dependencies, account.id, tenant.id)
|
||||
|
||||
return app, account
|
||||
|
||||
def _mock_current_user(self, mock_external_service_dependencies, account_id, tenant_id):
|
||||
"""
|
||||
Helper method to mock the current user for testing.
|
||||
"""
|
||||
# mock_external_service_dependencies["current_user"].id = account_id
|
||||
# mock_external_service_dependencies["current_user"].current_tenant_id = tenant_id
|
||||
|
||||
def _create_test_conversation(self, app, account, fake):
|
||||
"""
|
||||
Helper method to create a test conversation with all required fields.
|
||||
"""
|
||||
from extensions.ext_database import db
|
||||
from models.model import Conversation
|
||||
|
||||
conversation = Conversation(
|
||||
app_id=app.id,
|
||||
app_model_config_id=None,
|
||||
model_provider=None,
|
||||
model_id="",
|
||||
override_model_configs=None,
|
||||
mode=app.mode,
|
||||
name=fake.sentence(),
|
||||
inputs={},
|
||||
introduction="",
|
||||
system_instruction="",
|
||||
system_instruction_tokens=0,
|
||||
status="normal",
|
||||
invoke_from="console",
|
||||
from_source="console",
|
||||
from_end_user_id=None,
|
||||
from_account_id=account.id,
|
||||
)
|
||||
|
||||
db.session.add(conversation)
|
||||
db.session.flush()
|
||||
return conversation
|
||||
|
||||
def _create_test_message(self, app, conversation, account, fake):
|
||||
"""
|
||||
Helper method to create a test message with all required fields.
|
||||
"""
|
||||
import json
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models.model import Message
|
||||
|
||||
message = Message(
|
||||
app_id=app.id,
|
||||
model_provider=None,
|
||||
model_id="",
|
||||
override_model_configs=None,
|
||||
conversation_id=conversation.id,
|
||||
inputs={},
|
||||
query=fake.sentence(),
|
||||
message=json.dumps([{"role": "user", "text": fake.sentence()}]),
|
||||
message_tokens=0,
|
||||
message_unit_price=0,
|
||||
message_price_unit=0.001,
|
||||
answer=fake.text(max_nb_chars=200),
|
||||
answer_tokens=0,
|
||||
answer_unit_price=0,
|
||||
answer_price_unit=0.001,
|
||||
parent_message_id=None,
|
||||
provider_response_latency=0,
|
||||
total_price=0,
|
||||
currency="USD",
|
||||
invoke_from="console",
|
||||
from_source="console",
|
||||
from_end_user_id=None,
|
||||
from_account_id=account.id,
|
||||
)
|
||||
|
||||
db.session.add(message)
|
||||
db.session.commit()
|
||||
return message
|
||||
|
||||
    def test_pagination_by_first_id_success(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test successful pagination by first ID.
        """
        fake = Faker()
        app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)

        # Create a conversation and multiple messages
        conversation = self._create_test_conversation(app, account, fake)
        messages = []
        for i in range(5):
            message = self._create_test_message(app, conversation, account, fake)
            messages.append(message)

        # Test pagination by first ID
        result = MessageService.pagination_by_first_id(
            app_model=app,
            user=account,
            conversation_id=conversation.id,
            first_id=messages[2].id,  # Use middle message as first_id
            limit=2,
            order="asc",
        )

        # Verify results
        assert result.limit == 2
        assert len(result.data) == 2
        # total 5, from the middle, no more
        assert result.has_more is False
        # Verify messages are in ascending order
        assert result.data[0].created_at <= result.data[1].created_at

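The result object asserted here exposes limit, data, and has_more. A common way such pagination helpers derive has_more — and one consistent with the "total 5, from the middle, no more" comment above — is to fetch one row beyond the requested window and report whether the extra row came back. This is a sketch of that pattern, not MessageService's actual query:

```python
def window_with_has_more(rows: list, limit: int) -> tuple[list, bool]:
    # pretend `rows` is what the database returned for a "limit + 1" query
    fetched = rows[: limit + 1]
    has_more = len(fetched) > limit
    return fetched[:limit], has_more


data, has_more = window_with_has_more(["m4", "m5"], limit=2)
assert data == ["m4", "m5"] and has_more is False  # matches the shape asserted above
```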
def test_pagination_by_first_id_no_user(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test pagination by first ID when no user is provided.
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Test pagination with no user
|
||||
result = MessageService.pagination_by_first_id(
|
||||
app_model=app, user=None, conversation_id=fake.uuid4(), first_id=None, limit=10
|
||||
)
|
||||
|
||||
# Verify empty result
|
||||
assert result.limit == 10
|
||||
assert len(result.data) == 0
|
||||
assert result.has_more is False
|
||||
|
||||
def test_pagination_by_first_id_no_conversation_id(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test pagination by first ID when no conversation ID is provided.
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Test pagination with no conversation ID
|
||||
result = MessageService.pagination_by_first_id(
|
||||
app_model=app, user=account, conversation_id="", first_id=None, limit=10
|
||||
)
|
||||
|
||||
# Verify empty result
|
||||
assert result.limit == 10
|
||||
assert len(result.data) == 0
|
||||
assert result.has_more is False
|
||||
|
||||
def test_pagination_by_first_id_invalid_first_id(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test pagination by first ID with invalid first_id.
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Create a conversation and message
|
||||
conversation = self._create_test_conversation(app, account, fake)
|
||||
self._create_test_message(app, conversation, account, fake)
|
||||
|
||||
# Test pagination with invalid first_id
|
||||
with pytest.raises(FirstMessageNotExistsError):
|
||||
MessageService.pagination_by_first_id(
|
||||
app_model=app,
|
||||
user=account,
|
||||
conversation_id=conversation.id,
|
||||
first_id=fake.uuid4(), # Non-existent message ID
|
||||
limit=10,
|
||||
)
|
||||
|
||||
def test_pagination_by_last_id_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful pagination by last ID.
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Create a conversation and multiple messages
|
||||
conversation = self._create_test_conversation(app, account, fake)
|
||||
messages = []
|
||||
for i in range(5):
|
||||
message = self._create_test_message(app, conversation, account, fake)
|
||||
messages.append(message)
|
||||
|
||||
# Test pagination by last ID
|
||||
result = MessageService.pagination_by_last_id(
|
||||
app_model=app,
|
||||
user=account,
|
||||
last_id=messages[2].id, # Use middle message as last_id
|
||||
limit=2,
|
||||
conversation_id=conversation.id,
|
||||
)
|
||||
|
||||
# Verify results
|
||||
assert result.limit == 2
|
||||
assert len(result.data) == 2
|
||||
# total 5, from the middle, no more
|
||||
assert result.has_more is False
|
||||
# Verify messages are in descending order
|
||||
assert result.data[0].created_at >= result.data[1].created_at
|
||||
|
||||
def test_pagination_by_last_id_with_include_ids(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test pagination by last ID with include_ids filter.
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Create a conversation and multiple messages
|
||||
conversation = self._create_test_conversation(app, account, fake)
|
||||
messages = []
|
||||
for i in range(5):
|
||||
message = self._create_test_message(app, conversation, account, fake)
|
||||
messages.append(message)
|
||||
|
||||
# Test pagination with include_ids
|
||||
include_ids = [messages[0].id, messages[1].id, messages[2].id]
|
||||
result = MessageService.pagination_by_last_id(
|
||||
app_model=app, user=account, last_id=messages[1].id, limit=2, include_ids=include_ids
|
||||
)
|
||||
|
||||
# Verify results
|
||||
assert result.limit == 2
|
||||
assert len(result.data) <= 2
|
||||
# Verify all returned messages are in include_ids
|
||||
for message in result.data:
|
||||
assert message.id in include_ids
|
||||
|
||||
def test_pagination_by_last_id_no_user(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test pagination by last ID when no user is provided.
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Test pagination with no user
|
||||
result = MessageService.pagination_by_last_id(app_model=app, user=None, last_id=None, limit=10)
|
||||
|
||||
# Verify empty result
|
||||
assert result.limit == 10
|
||||
assert len(result.data) == 0
|
||||
assert result.has_more is False
|
||||
|
||||
def test_pagination_by_last_id_invalid_last_id(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test pagination by last ID with invalid last_id.
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Create a conversation and message
|
||||
conversation = self._create_test_conversation(app, account, fake)
|
||||
self._create_test_message(app, conversation, account, fake)
|
||||
|
||||
# Test pagination with invalid last_id
|
||||
with pytest.raises(LastMessageNotExistsError):
|
||||
MessageService.pagination_by_last_id(
|
||||
app_model=app,
|
||||
user=account,
|
||||
last_id=fake.uuid4(), # Non-existent message ID
|
||||
limit=10,
|
||||
conversation_id=conversation.id,
|
||||
)
|
||||
|
||||
def test_create_feedback_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful creation of feedback.
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Create a conversation and message
|
||||
conversation = self._create_test_conversation(app, account, fake)
|
||||
message = self._create_test_message(app, conversation, account, fake)
|
||||
|
||||
# Create feedback
|
||||
rating = "like"
|
||||
content = fake.text(max_nb_chars=100)
|
||||
feedback = MessageService.create_feedback(
|
||||
app_model=app, message_id=message.id, user=account, rating=rating, content=content
|
||||
)
|
||||
|
||||
# Verify feedback was created correctly
|
||||
assert feedback.app_id == app.id
|
||||
assert feedback.conversation_id == conversation.id
|
||||
assert feedback.message_id == message.id
|
||||
assert feedback.rating == rating
|
||||
assert feedback.content == content
|
||||
assert feedback.from_source == "admin"
|
||||
assert feedback.from_account_id == account.id
|
||||
assert feedback.from_end_user_id is None
|
||||
|
||||
def test_create_feedback_no_user(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test creating feedback when no user is provided.
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Create a conversation and message
|
||||
conversation = self._create_test_conversation(app, account, fake)
|
||||
message = self._create_test_message(app, conversation, account, fake)
|
||||
|
||||
# Test creating feedback with no user
|
||||
with pytest.raises(ValueError, match="user cannot be None"):
|
||||
MessageService.create_feedback(
|
||||
app_model=app, message_id=message.id, user=None, rating="like", content=fake.text(max_nb_chars=100)
|
||||
)
|
||||
|
||||
def test_create_feedback_update_existing(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test updating existing feedback.
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Create a conversation and message
|
||||
conversation = self._create_test_conversation(app, account, fake)
|
||||
message = self._create_test_message(app, conversation, account, fake)
|
||||
|
||||
# Create initial feedback
|
||||
initial_rating = "like"
|
||||
initial_content = fake.text(max_nb_chars=100)
|
||||
feedback = MessageService.create_feedback(
|
||||
app_model=app, message_id=message.id, user=account, rating=initial_rating, content=initial_content
|
||||
)
|
||||
|
||||
# Update feedback
|
||||
updated_rating = "dislike"
|
||||
updated_content = fake.text(max_nb_chars=100)
|
||||
updated_feedback = MessageService.create_feedback(
|
||||
app_model=app, message_id=message.id, user=account, rating=updated_rating, content=updated_content
|
||||
)
|
||||
|
||||
# Verify feedback was updated correctly
|
||||
assert updated_feedback.id == feedback.id
|
||||
assert updated_feedback.rating == updated_rating
|
||||
assert updated_feedback.content == updated_content
|
||||
assert updated_feedback.rating != initial_rating
|
||||
assert updated_feedback.content != initial_content
|
||||
|
||||
    def test_create_feedback_delete_existing(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test deleting existing feedback by setting rating to None.
        """
        fake = Faker()
        app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)

        # Create a conversation and message
        conversation = self._create_test_conversation(app, account, fake)
        message = self._create_test_message(app, conversation, account, fake)

        # Create initial feedback
        feedback = MessageService.create_feedback(
            app_model=app, message_id=message.id, user=account, rating="like", content=fake.text(max_nb_chars=100)
        )

        # Delete feedback by setting rating to None
        MessageService.create_feedback(app_model=app, message_id=message.id, user=account, rating=None, content=None)

        # Verify feedback was deleted
        from extensions.ext_database import db

        deleted_feedback = db.session.query(MessageFeedback).filter(MessageFeedback.id == feedback.id).first()
        assert deleted_feedback is None

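Taken together, the feedback tests describe create_feedback as an upsert keyed by message: create when absent (a rating is required), update when present, and delete when an existing feedback is given rating=None. A rough sketch of that branching, inferred from the observable behaviour only (the error message is the one the tests match against):

```python
def apply_feedback(existing, rating, content):
    """Return ("create" | "update" | "delete", payload) for the transitions the tests exercise."""
    if existing is None:
        if rating is None:
            raise ValueError("rating cannot be None when feedback not exists")
        return "create", {"rating": rating, "content": content}
    if rating is None:
        return "delete", None
    return "update", {"rating": rating, "content": content}
```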
def test_create_feedback_no_rating_when_not_exists(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test creating feedback with no rating when feedback doesn't exist.
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Create a conversation and message
|
||||
conversation = self._create_test_conversation(app, account, fake)
|
||||
message = self._create_test_message(app, conversation, account, fake)
|
||||
|
||||
# Test creating feedback with no rating when no feedback exists
|
||||
with pytest.raises(ValueError, match="rating cannot be None when feedback not exists"):
|
||||
MessageService.create_feedback(
|
||||
app_model=app, message_id=message.id, user=account, rating=None, content=None
|
||||
)
|
||||
|
||||
def test_get_all_messages_feedbacks_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful retrieval of all message feedbacks.
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Create multiple conversations and messages with feedbacks
|
||||
feedbacks = []
|
||||
for i in range(3):
|
||||
conversation = self._create_test_conversation(app, account, fake)
|
||||
message = self._create_test_message(app, conversation, account, fake)
|
||||
|
||||
feedback = MessageService.create_feedback(
|
||||
app_model=app,
|
||||
message_id=message.id,
|
||||
user=account,
|
||||
rating="like" if i % 2 == 0 else "dislike",
|
||||
content=f"Feedback {i}: {fake.text(max_nb_chars=50)}",
|
||||
)
|
||||
feedbacks.append(feedback)
|
||||
|
||||
# Get all feedbacks
|
||||
result = MessageService.get_all_messages_feedbacks(app, page=1, limit=10)
|
||||
|
||||
# Verify results
|
||||
assert len(result) == 3
|
||||
|
||||
# Verify feedbacks are ordered by created_at desc
|
||||
for i in range(len(result) - 1):
|
||||
assert result[i]["created_at"] >= result[i + 1]["created_at"]
|
||||
|
||||
def test_get_all_messages_feedbacks_pagination(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test pagination of message feedbacks.
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Create multiple conversations and messages with feedbacks
|
||||
for i in range(5):
|
||||
conversation = self._create_test_conversation(app, account, fake)
|
||||
message = self._create_test_message(app, conversation, account, fake)
|
||||
|
||||
MessageService.create_feedback(
|
||||
app_model=app, message_id=message.id, user=account, rating="like", content=f"Feedback {i}"
|
||||
)
|
||||
|
||||
# Get feedbacks with pagination
|
||||
result_page_1 = MessageService.get_all_messages_feedbacks(app, page=1, limit=3)
|
||||
result_page_2 = MessageService.get_all_messages_feedbacks(app, page=2, limit=3)
|
||||
|
||||
# Verify pagination results
|
||||
assert len(result_page_1) == 3
|
||||
assert len(result_page_2) == 2
|
||||
|
||||
# Verify no overlap between pages
|
||||
page_1_ids = {feedback["id"] for feedback in result_page_1}
|
||||
page_2_ids = {feedback["id"] for feedback in result_page_2}
|
||||
assert len(page_1_ids.intersection(page_2_ids)) == 0
|
||||
|
||||
def test_get_message_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful retrieval of message.
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Create a conversation and message
|
||||
conversation = self._create_test_conversation(app, account, fake)
|
||||
message = self._create_test_message(app, conversation, account, fake)
|
||||
|
||||
# Get message
|
||||
retrieved_message = MessageService.get_message(app_model=app, user=account, message_id=message.id)
|
||||
|
||||
# Verify message was retrieved correctly
|
||||
assert retrieved_message.id == message.id
|
||||
assert retrieved_message.app_id == app.id
|
||||
assert retrieved_message.conversation_id == conversation.id
|
||||
assert retrieved_message.from_source == "console"
|
||||
assert retrieved_message.from_account_id == account.id
|
||||
|
||||
def test_get_message_not_exists(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test getting message that doesn't exist.
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Test getting non-existent message
|
||||
with pytest.raises(MessageNotExistsError):
|
||||
MessageService.get_message(app_model=app, user=account, message_id=fake.uuid4())
|
||||
|
||||
def test_get_message_wrong_user(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test getting message with wrong user (different account).
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Create a conversation and message
|
||||
conversation = self._create_test_conversation(app, account, fake)
|
||||
message = self._create_test_message(app, conversation, account, fake)
|
||||
|
||||
# Create another account
|
||||
from services.account_service import AccountService, TenantService
|
||||
|
||||
other_account = AccountService.create_account(
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(other_account, name=fake.company())
|
||||
|
||||
# Test getting message with different user
|
||||
with pytest.raises(MessageNotExistsError):
|
||||
MessageService.get_message(app_model=app, user=other_account, message_id=message.id)
|
||||
|
||||
    def test_get_suggested_questions_after_answer_success(
        self, db_session_with_containers, mock_external_service_dependencies
    ):
        """
        Test successful generation of suggested questions after answer.
        """
        fake = Faker()
        app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)

        # Create a conversation and message
        conversation = self._create_test_conversation(app, account, fake)
        message = self._create_test_message(app, conversation, account, fake)

        # Mock the LLMGenerator to return specific questions
        mock_questions = ["What is AI?", "How does machine learning work?", "Tell me about neural networks"]
        mock_external_service_dependencies[
            "llm_generator"
        ].generate_suggested_questions_after_answer.return_value = mock_questions

        # Get suggested questions
        from core.app.entities.app_invoke_entities import InvokeFrom

        result = MessageService.get_suggested_questions_after_answer(
            app_model=app, user=account, message_id=message.id, invoke_from=InvokeFrom.SERVICE_API
        )

        # Verify results
        assert result == mock_questions

        # Verify LLMGenerator was called
        mock_external_service_dependencies[
            "llm_generator"
        ].generate_suggested_questions_after_answer.assert_called_once()

        # Verify TraceQueueManager was called
        mock_external_service_dependencies["trace_manager_instance"].add_trace_task.assert_called_once()

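Reading this test together with the fixture suggests the shape of get_suggested_questions_after_answer for an advanced-chat app: check the feature flag on the app config, pick the published workflow (or the draft one when invoked from the debugger, as a later test asserts), build history via TokenBufferMemory, ask LLMGenerator, and enqueue a trace task. A condensed, hedged outline of that orchestration — the collaborators are the mocks above, but the wiring and argument shapes are inferred, not copied from the service:

```python
from services.errors.message import SuggestedQuestionsAfterAnswerDisabledError


def suggested_questions_flow(app_config_manager, workflow_service, memory, llm_generator, trace_manager,
                             app_model, use_draft_workflow: bool) -> list[str]:
    # 1. feature flag lives on the app config (mocked as additional_features.suggested_questions_after_answer)
    app_config = app_config_manager.get_app_config(app_model)
    if not app_config.additional_features.suggested_questions_after_answer:
        raise SuggestedQuestionsAfterAnswerDisabledError()

    # 2. debugger invocations read the draft workflow, everything else the published one
    workflow = (workflow_service.get_draft_workflow(app_model=app_model) if use_draft_workflow
                else workflow_service.get_published_workflow(app_model=app_model))
    if not workflow:
        return []

    # 3. history text feeds the generator; a trace task is enqueued afterwards (argument is illustrative)
    histories = memory.get_history_prompt_text()
    questions = llm_generator.generate_suggested_questions_after_answer(histories)
    trace_manager.add_trace_task(questions)
    return questions
```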
def test_get_suggested_questions_after_answer_no_user(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test getting suggested questions when no user is provided.
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Create a conversation and message
|
||||
conversation = self._create_test_conversation(app, account, fake)
|
||||
message = self._create_test_message(app, conversation, account, fake)
|
||||
|
||||
# Test getting suggested questions with no user
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
|
||||
with pytest.raises(ValueError, match="user cannot be None"):
|
||||
MessageService.get_suggested_questions_after_answer(
|
||||
app_model=app, user=None, message_id=message.id, invoke_from=InvokeFrom.SERVICE_API
|
||||
)
|
||||
|
||||
def test_get_suggested_questions_after_answer_disabled(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test getting suggested questions when feature is disabled.
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Create a conversation and message
|
||||
conversation = self._create_test_conversation(app, account, fake)
|
||||
message = self._create_test_message(app, conversation, account, fake)
|
||||
|
||||
# Mock the feature to be disabled
|
||||
mock_external_service_dependencies[
|
||||
"app_config_manager"
|
||||
].get_app_config.return_value.additional_features.suggested_questions_after_answer = False
|
||||
|
||||
# Test getting suggested questions when feature is disabled
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
|
||||
with pytest.raises(SuggestedQuestionsAfterAnswerDisabledError):
|
||||
MessageService.get_suggested_questions_after_answer(
|
||||
app_model=app, user=account, message_id=message.id, invoke_from=InvokeFrom.SERVICE_API
|
||||
)
|
||||
|
||||
def test_get_suggested_questions_after_answer_no_workflow(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test getting suggested questions when no workflow exists.
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Create a conversation and message
|
||||
conversation = self._create_test_conversation(app, account, fake)
|
||||
message = self._create_test_message(app, conversation, account, fake)
|
||||
|
||||
# Mock no workflow
|
||||
mock_external_service_dependencies["workflow_service"].return_value.get_published_workflow.return_value = None
|
||||
|
||||
# Get suggested questions (should return empty list)
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
|
||||
result = MessageService.get_suggested_questions_after_answer(
|
||||
app_model=app, user=account, message_id=message.id, invoke_from=InvokeFrom.SERVICE_API
|
||||
)
|
||||
|
||||
# Verify empty result
|
||||
assert result == []
|
||||
|
||||
def test_get_suggested_questions_after_answer_debugger_mode(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test getting suggested questions in debugger mode.
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Create a conversation and message
|
||||
conversation = self._create_test_conversation(app, account, fake)
|
||||
message = self._create_test_message(app, conversation, account, fake)
|
||||
|
||||
# Mock questions
|
||||
mock_questions = ["Debug question 1", "Debug question 2"]
|
||||
mock_external_service_dependencies[
|
||||
"llm_generator"
|
||||
].generate_suggested_questions_after_answer.return_value = mock_questions
|
||||
|
||||
# Get suggested questions in debugger mode
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
|
||||
result = MessageService.get_suggested_questions_after_answer(
|
||||
app_model=app, user=account, message_id=message.id, invoke_from=InvokeFrom.DEBUGGER
|
||||
)
|
||||
|
||||
# Verify results
|
||||
assert result == mock_questions
|
||||
|
||||
# Verify draft workflow was used instead of published workflow
|
||||
mock_external_service_dependencies["workflow_service"].return_value.get_draft_workflow.assert_called_once_with(
|
||||
app_model=app
|
||||
)
|
||||
|
||||
# Verify TraceQueueManager was called
|
||||
mock_external_service_dependencies["trace_manager_instance"].add_trace_task.assert_called_once()
|
||||
|
|
@ -12,6 +12,10 @@ from services.workflow_draft_variable_service import (
)


def _get_random_variable_name(fake: Faker):
    return "".join(fake.random_letters(length=10))


class TestWorkflowDraftVariableService:
    """
    Comprehensive integration tests for WorkflowDraftVariableService using testcontainers.
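The new module-level helper replaces the fake.word() calls used further down in this file, presumably because ten random letters are far less likely to collide than a small word list when several variables are created in a loop. For example:

```python
from faker import Faker

fake = Faker()
name = _get_random_variable_name(fake)  # e.g. "qzjxkpwmra" — random_letters returns a list, joined into one name
```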
|
|
@ -112,7 +116,14 @@ class TestWorkflowDraftVariableService:
|
|||
return workflow
|
||||
|
||||
def _create_test_variable(
|
||||
self, db_session_with_containers, app_id, node_id, name, value, variable_type="conversation", fake=None
|
||||
self,
|
||||
db_session_with_containers,
|
||||
app_id,
|
||||
node_id,
|
||||
name,
|
||||
value,
|
||||
variable_type: DraftVariableType = DraftVariableType.CONVERSATION,
|
||||
fake=None,
|
||||
):
|
||||
"""
|
||||
Helper method to create a test workflow draft variable with proper configuration.
|
||||
|
|
@ -227,7 +238,13 @@ class TestWorkflowDraftVariableService:
|
|||
db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, "var2", var2_value, fake=fake
|
||||
)
|
||||
var3 = self._create_test_variable(
|
||||
db_session_with_containers, app.id, "test_node_1", "var3", var3_value, "node", fake=fake
|
||||
db_session_with_containers,
|
||||
app.id,
|
||||
"test_node_1",
|
||||
"var3",
|
||||
var3_value,
|
||||
variable_type=DraftVariableType.NODE,
|
||||
fake=fake,
|
||||
)
|
||||
selectors = [
|
||||
[CONVERSATION_VARIABLE_NODE_ID, "var1"],
|
||||
|
|
@ -263,9 +280,14 @@ class TestWorkflowDraftVariableService:
|
|||
fake = Faker()
|
||||
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
|
||||
for i in range(5):
|
||||
test_value = StringSegment(value=fake.numerify("value##"))
|
||||
test_value = StringSegment(value=fake.numerify("value######"))
|
||||
self._create_test_variable(
|
||||
db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, fake.word(), test_value, fake=fake
|
||||
db_session_with_containers,
|
||||
app.id,
|
||||
CONVERSATION_VARIABLE_NODE_ID,
|
||||
_get_random_variable_name(fake),
|
||||
test_value,
|
||||
fake=fake,
|
||||
)
|
||||
service = WorkflowDraftVariableService(db_session_with_containers)
|
||||
result = service.list_variables_without_values(app.id, page=1, limit=3)
|
||||
|
|
@ -291,10 +313,32 @@ class TestWorkflowDraftVariableService:
|
|||
var1_value = StringSegment(value=fake.word())
|
||||
var2_value = StringSegment(value=fake.word())
|
||||
var3_value = StringSegment(value=fake.word())
|
||||
self._create_test_variable(db_session_with_containers, app.id, node_id, "var1", var1_value, "node", fake=fake)
|
||||
self._create_test_variable(db_session_with_containers, app.id, node_id, "var2", var3_value, "node", fake=fake)
|
||||
self._create_test_variable(
|
||||
db_session_with_containers, app.id, "other_node", "var3", var2_value, "node", fake=fake
|
||||
db_session_with_containers,
|
||||
app.id,
|
||||
node_id,
|
||||
"var1",
|
||||
var1_value,
|
||||
variable_type=DraftVariableType.NODE,
|
||||
fake=fake,
|
||||
)
|
||||
self._create_test_variable(
|
||||
db_session_with_containers,
|
||||
app.id,
|
||||
node_id,
|
||||
"var2",
|
||||
var3_value,
|
||||
variable_type=DraftVariableType.NODE,
|
||||
fake=fake,
|
||||
)
|
||||
self._create_test_variable(
|
||||
db_session_with_containers,
|
||||
app.id,
|
||||
"other_node",
|
||||
"var3",
|
||||
var2_value,
|
||||
variable_type=DraftVariableType.NODE,
|
||||
fake=fake,
|
||||
)
|
||||
service = WorkflowDraftVariableService(db_session_with_containers)
|
||||
result = service.list_node_variables(app.id, node_id)
|
||||
|
|
@ -328,7 +372,13 @@ class TestWorkflowDraftVariableService:
|
|||
)
|
||||
sys_var_value = StringSegment(value=fake.word())
|
||||
self._create_test_variable(
|
||||
db_session_with_containers, app.id, SYSTEM_VARIABLE_NODE_ID, "sys_var", sys_var_value, "system", fake=fake
|
||||
db_session_with_containers,
|
||||
app.id,
|
||||
SYSTEM_VARIABLE_NODE_ID,
|
||||
"sys_var",
|
||||
sys_var_value,
|
||||
variable_type=DraftVariableType.SYS,
|
||||
fake=fake,
|
||||
)
|
||||
service = WorkflowDraftVariableService(db_session_with_containers)
|
||||
result = service.list_conversation_variables(app.id)
|
||||
|
|
@ -480,14 +530,24 @@ class TestWorkflowDraftVariableService:
|
|||
fake = Faker()
|
||||
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
|
||||
for i in range(3):
|
||||
test_value = StringSegment(value=fake.numerify("value##"))
|
||||
test_value = StringSegment(value=fake.numerify("value######"))
|
||||
self._create_test_variable(
|
||||
db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, fake.word(), test_value, fake=fake
|
||||
db_session_with_containers,
|
||||
app.id,
|
||||
CONVERSATION_VARIABLE_NODE_ID,
|
||||
_get_random_variable_name(fake),
|
||||
test_value,
|
||||
fake=fake,
|
||||
)
|
||||
other_app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
|
||||
other_value = StringSegment(value=fake.word())
|
||||
self._create_test_variable(
|
||||
db_session_with_containers, other_app.id, CONVERSATION_VARIABLE_NODE_ID, fake.word(), other_value, fake=fake
|
||||
db_session_with_containers,
|
||||
other_app.id,
|
||||
CONVERSATION_VARIABLE_NODE_ID,
|
||||
_get_random_variable_name(fake),
|
||||
other_value,
|
||||
fake=fake,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
|
||||
|
|
@ -515,17 +575,34 @@ class TestWorkflowDraftVariableService:
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
node_id = fake.word()
for i in range(2):
test_value = StringSegment(value=fake.numerify("node_value##"))
test_value = StringSegment(value=fake.numerify("node_value######"))
self._create_test_variable(
db_session_with_containers, app.id, node_id, fake.word(), test_value, "node", fake=fake
db_session_with_containers,
app.id,
node_id,
_get_random_variable_name(fake),
test_value,
variable_type=DraftVariableType.NODE,
fake=fake,
)
other_node_value = StringSegment(value=fake.word())
self._create_test_variable(
db_session_with_containers, app.id, "other_node", fake.word(), other_node_value, "node", fake=fake
db_session_with_containers,
app.id,
"other_node",
_get_random_variable_name(fake),
other_node_value,
variable_type=DraftVariableType.NODE,
fake=fake,
)
conv_value = StringSegment(value=fake.word())
self._create_test_variable(
db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, fake.word(), conv_value, fake=fake
db_session_with_containers,
app.id,
CONVERSATION_VARIABLE_NODE_ID,
_get_random_variable_name(fake),
conv_value,
fake=fake,
)
from extensions.ext_database import db

@ -627,7 +704,7 @@ class TestWorkflowDraftVariableService:
SYSTEM_VARIABLE_NODE_ID,
"conversation_id",
conv_id_value,
"system",
variable_type=DraftVariableType.SYS,
fake=fake,
)
service = WorkflowDraftVariableService(db_session_with_containers)

@ -664,10 +741,22 @@ class TestWorkflowDraftVariableService:
sys_var1_value = StringSegment(value=fake.word())
sys_var2_value = StringSegment(value=fake.word())
sys_var1 = self._create_test_variable(
db_session_with_containers, app.id, SYSTEM_VARIABLE_NODE_ID, "sys_var1", sys_var1_value, "system", fake=fake
db_session_with_containers,
app.id,
SYSTEM_VARIABLE_NODE_ID,
"sys_var1",
sys_var1_value,
variable_type=DraftVariableType.SYS,
fake=fake,
)
sys_var2 = self._create_test_variable(
db_session_with_containers, app.id, SYSTEM_VARIABLE_NODE_ID, "sys_var2", sys_var2_value, "system", fake=fake
db_session_with_containers,
app.id,
SYSTEM_VARIABLE_NODE_ID,
"sys_var2",
sys_var2_value,
variable_type=DraftVariableType.SYS,
fake=fake,
)
conv_var_value = StringSegment(value=fake.word())
self._create_test_variable(

@ -701,10 +790,22 @@ class TestWorkflowDraftVariableService:
db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, "test_conv_var", test_value, fake=fake
)
sys_var = self._create_test_variable(
db_session_with_containers, app.id, SYSTEM_VARIABLE_NODE_ID, "test_sys_var", test_value, "system", fake=fake
db_session_with_containers,
app.id,
SYSTEM_VARIABLE_NODE_ID,
"test_sys_var",
test_value,
variable_type=DraftVariableType.SYS,
fake=fake,
)
node_var = self._create_test_variable(
db_session_with_containers, app.id, "test_node", "test_node_var", test_value, "node", fake=fake
db_session_with_containers,
app.id,
"test_node",
"test_node_var",
test_value,
variable_type=DraftVariableType.NODE,
fake=fake,
)
service = WorkflowDraftVariableService(db_session_with_containers)
retrieved_conv_var = service.get_conversation_variable(app.id, "test_conv_var")

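Taken together, these hunks replace the old positional `"system"` / `"node"` scope strings with an explicit `variable_type=` keyword. A tiny illustrative stand-in for the enum involved (member values are assumptions; only the members exercised in these tests are shown):

```python
from enum import StrEnum


# Illustrative stand-in for the project's DraftVariableType enum, not its real definition.
class DraftVariableType(StrEnum):
    SYS = "sys"
    NODE = "node"
    CONVERSATION = "conversation"


# The refactor passes one of these members as variable_type= at every
# _create_test_variable call site instead of a bare "system"/"node" string.
print(DraftVariableType.NODE)  # -> "node"
```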
@ -0,0 +1,124 @@
import time
from unittest.mock import MagicMock, patch

import pytest

from core.app.features.rate_limiting.rate_limit import RateLimit


@pytest.fixture
def mock_redis():
"""Mock Redis client with realistic behavior for rate limiting tests."""
mock_client = MagicMock()

# Redis data storage for simulation
mock_data = {}
mock_hashes = {}
mock_expiry = {}

def mock_setex(key, ttl, value):
mock_data[key] = str(value)
mock_expiry[key] = time.time() + ttl.total_seconds() if hasattr(ttl, "total_seconds") else time.time() + ttl
return True

def mock_get(key):
if key in mock_data and (key not in mock_expiry or time.time() < mock_expiry[key]):
return mock_data[key].encode("utf-8")
return None

def mock_exists(key):
return key in mock_data or key in mock_hashes

def mock_expire(key, ttl):
if key in mock_data or key in mock_hashes:
mock_expiry[key] = time.time() + ttl.total_seconds() if hasattr(ttl, "total_seconds") else time.time() + ttl
return True

def mock_hset(key, field, value):
if key not in mock_hashes:
mock_hashes[key] = {}
mock_hashes[key][field] = str(value).encode("utf-8")
return True

def mock_hgetall(key):
return mock_hashes.get(key, {})

def mock_hdel(key, *fields):
if key in mock_hashes:
count = 0
for field in fields:
if field in mock_hashes[key]:
del mock_hashes[key][field]
count += 1
return count
return 0

def mock_hlen(key):
return len(mock_hashes.get(key, {}))

# Configure mock methods
mock_client.setex = mock_setex
mock_client.get = mock_get
mock_client.exists = mock_exists
mock_client.expire = mock_expire
mock_client.hset = mock_hset
mock_client.hgetall = mock_hgetall
mock_client.hdel = mock_hdel
mock_client.hlen = mock_hlen

# Store references for test verification
mock_client._mock_data = mock_data
mock_client._mock_hashes = mock_hashes
mock_client._mock_expiry = mock_expiry

return mock_client


@pytest.fixture
def mock_time():
"""Mock time.time() for deterministic tests."""
mock_time_val = 1000.0

def increment_time(seconds=1):
nonlocal mock_time_val
mock_time_val += seconds
return mock_time_val

with patch("time.time", return_value=mock_time_val) as mock:
mock.increment = increment_time
yield mock


@pytest.fixture
def sample_generator():
"""Sample generator for testing RateLimitGenerator."""

def _create_generator(items=None, raise_error=False):
items = items or ["item1", "item2", "item3"]
for item in items:
if raise_error and item == "item2":
raise ValueError("Test error")
yield item

return _create_generator


@pytest.fixture
def sample_mapping():
"""Sample mapping for testing RateLimitGenerator."""
return {"key1": "value1", "key2": "value2"}


@pytest.fixture(autouse=True)
def reset_rate_limit_instances():
"""Clear RateLimit singleton instances between tests."""
RateLimit._instance_dict.clear()
yield
RateLimit._instance_dict.clear()


@pytest.fixture
def redis_patch():
"""Patch redis_client globally for rate limit tests."""
with patch("core.app.features.rate_limiting.rate_limit.redis_client") as mock:
yield mock

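These fixtures are meant to be combined: `redis_patch` swaps out the module-level `redis_client`, while `mock_redis` supplies stateful fakes that can be wired onto the patched client whenever a test wants realistic behavior instead of bare return values. A minimal usage sketch, assuming the same import path used above (the exact Redis calls `RateLimit` makes may differ):

```python
from core.app.features.rate_limiting.rate_limit import RateLimit


def test_enter_then_exit_roundtrip(redis_patch, mock_redis):
    # Wire the stateful fakes onto the patched client so RateLimit's Redis
    # calls hit the in-memory dicts rather than MagicMock defaults.
    for name in ("setex", "get", "exists", "expire", "hset", "hgetall", "hdel", "hlen"):
        setattr(redis_patch, name, getattr(mock_redis, name))

    rate_limit = RateLimit("sketch_client", 2)
    request_id = rate_limit.enter()

    # The active-requests hash in the fake should now hold the entry...
    assert any(mock_redis._mock_hashes.values())

    # ...and exit() should remove it again.
    rate_limit.exit(request_id)
    assert all(len(entries) == 0 for entries in mock_redis._mock_hashes.values())
```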
@ -0,0 +1,569 @@
|
|||
import threading
|
||||
import time
|
||||
from datetime import timedelta
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.features.rate_limiting.rate_limit import RateLimit
|
||||
from core.errors.error import AppInvokeQuotaExceededError
|
||||
|
||||
|
||||
class TestRateLimit:
|
||||
"""Core rate limiting functionality tests."""
|
||||
|
||||
def test_should_return_same_instance_for_same_client_id(self, redis_patch):
|
||||
"""Test singleton behavior for same client ID."""
|
||||
redis_patch.configure_mock(
|
||||
**{
|
||||
"exists.return_value": False,
|
||||
"setex.return_value": True,
|
||||
}
|
||||
)
|
||||
|
||||
rate_limit1 = RateLimit("client1", 5)
|
||||
rate_limit2 = RateLimit("client1", 10) # Second instance with different limit
|
||||
|
||||
assert rate_limit1 is rate_limit2
|
||||
# Current implementation: last constructor call overwrites max_active_requests
|
||||
# This reflects the actual behavior where __init__ always sets max_active_requests
|
||||
assert rate_limit1.max_active_requests == 10
|
||||
|
||||
def test_should_create_different_instances_for_different_client_ids(self, redis_patch):
|
||||
"""Test different instances for different client IDs."""
|
||||
redis_patch.configure_mock(
|
||||
**{
|
||||
"exists.return_value": False,
|
||||
"setex.return_value": True,
|
||||
}
|
||||
)
|
||||
|
||||
rate_limit1 = RateLimit("client1", 5)
|
||||
rate_limit2 = RateLimit("client2", 10)
|
||||
|
||||
assert rate_limit1 is not rate_limit2
|
||||
assert rate_limit1.client_id == "client1"
|
||||
assert rate_limit2.client_id == "client2"
|
||||
|
||||
def test_should_initialize_with_valid_parameters(self, redis_patch):
|
||||
"""Test normal initialization."""
|
||||
redis_patch.configure_mock(
|
||||
**{
|
||||
"exists.return_value": False,
|
||||
"setex.return_value": True,
|
||||
}
|
||||
)
|
||||
|
||||
rate_limit = RateLimit("test_client", 5)
|
||||
|
||||
assert rate_limit.client_id == "test_client"
|
||||
assert rate_limit.max_active_requests == 5
|
||||
assert hasattr(rate_limit, "initialized")
|
||||
redis_patch.setex.assert_called_once()
|
||||
|
||||
def test_should_skip_initialization_if_disabled(self):
|
||||
"""Test no initialization when rate limiting is disabled."""
|
||||
rate_limit = RateLimit("test_client", 0)
|
||||
|
||||
assert rate_limit.disabled()
|
||||
assert not hasattr(rate_limit, "initialized")
|
||||
|
||||
def test_should_skip_reinitialization_of_existing_instance(self, redis_patch):
|
||||
"""Test that existing instance doesn't reinitialize."""
|
||||
redis_patch.configure_mock(
|
||||
**{
|
||||
"exists.return_value": False,
|
||||
"setex.return_value": True,
|
||||
}
|
||||
)
|
||||
|
||||
RateLimit("client1", 5)
|
||||
redis_patch.reset_mock()
|
||||
|
||||
RateLimit("client1", 10)
|
||||
|
||||
redis_patch.setex.assert_not_called()
|
||||
|
||||
def test_should_be_disabled_when_max_requests_is_zero_or_negative(self):
|
||||
"""Test disabled state for zero or negative limits."""
|
||||
rate_limit_zero = RateLimit("client1", 0)
|
||||
rate_limit_negative = RateLimit("client2", -5)
|
||||
|
||||
assert rate_limit_zero.disabled()
|
||||
assert rate_limit_negative.disabled()
|
||||
|
||||
def test_should_set_redis_keys_on_first_flush(self, redis_patch):
|
||||
"""Test Redis keys are set correctly on initial flush."""
|
||||
redis_patch.configure_mock(
|
||||
**{
|
||||
"exists.return_value": False,
|
||||
"setex.return_value": True,
|
||||
}
|
||||
)
|
||||
|
||||
rate_limit = RateLimit("test_client", 5)
|
||||
|
||||
expected_max_key = "dify:rate_limit:test_client:max_active_requests"
|
||||
redis_patch.setex.assert_called_with(expected_max_key, timedelta(days=1), 5)
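The key names asserted here, and reused throughout the file, follow a fixed layout. Reconstructing it from the expected values in the tests (not from `rate_limit.py` itself):

```python
# Reconstructed from the assertions in these tests, not from the implementation.
client_id = "test_client"

# Plain string key holding the allowed concurrency, written with setex and a one-day TTL.
max_key = f"dify:rate_limit:{client_id}:max_active_requests"

# Hash mapping request_id -> enter timestamp; hlen gives the active count,
# hset/hdel add and remove in-flight requests.
active_key = f"dify:rate_limit:{client_id}:active_requests"

assert max_key == "dify:rate_limit:test_client:max_active_requests"
assert active_key == "dify:rate_limit:test_client:active_requests"
```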
|
||||
|
||||
def test_should_sync_max_requests_from_redis_on_subsequent_flush(self, redis_patch):
|
||||
"""Test max requests syncs from Redis when key exists."""
|
||||
redis_patch.configure_mock(
|
||||
**{
|
||||
"exists.return_value": True,
|
||||
"get.return_value": b"10",
|
||||
"expire.return_value": True,
|
||||
}
|
||||
)
|
||||
|
||||
rate_limit = RateLimit("test_client", 5)
|
||||
rate_limit.flush_cache()
|
||||
|
||||
assert rate_limit.max_active_requests == 10
|
||||
|
||||
@patch("time.time")
|
||||
def test_should_clean_timeout_requests_from_active_list(self, mock_time, redis_patch):
|
||||
"""Test cleanup of timed-out requests."""
|
||||
current_time = 1000.0
|
||||
mock_time.return_value = current_time
|
||||
|
||||
# Setup mock Redis with timed-out requests
|
||||
timeout_requests = {
|
||||
b"req1": str(current_time - 700).encode(), # 700 seconds ago (timeout)
|
||||
b"req2": str(current_time - 100).encode(), # 100 seconds ago (active)
|
||||
}
|
||||
|
||||
redis_patch.configure_mock(
|
||||
**{
|
||||
"exists.return_value": True,
|
||||
"get.return_value": b"5",
|
||||
"expire.return_value": True,
|
||||
"hgetall.return_value": timeout_requests,
|
||||
"hdel.return_value": 1,
|
||||
}
|
||||
)
|
||||
|
||||
rate_limit = RateLimit("test_client", 5)
|
||||
redis_patch.reset_mock() # Reset to avoid counting initialization calls
|
||||
rate_limit.flush_cache()
|
||||
|
||||
# Verify timeout request was cleaned up
|
||||
redis_patch.hdel.assert_called_once()
|
||||
call_args = redis_patch.hdel.call_args[0]
|
||||
assert call_args[0] == "dify:rate_limit:test_client:active_requests"
|
||||
assert b"req1" in call_args # Timeout request should be removed
|
||||
assert b"req2" not in call_args # Active request should remain
|
||||
|
||||
|
||||
class TestRateLimitEnterExit:
|
||||
"""Rate limiting enter/exit logic tests."""
|
||||
|
||||
def test_should_allow_request_within_limit(self, redis_patch):
|
||||
"""Test allowing requests within the rate limit."""
|
||||
redis_patch.configure_mock(
|
||||
**{
|
||||
"exists.return_value": False,
|
||||
"setex.return_value": True,
|
||||
"hlen.return_value": 2,
|
||||
"hset.return_value": True,
|
||||
}
|
||||
)
|
||||
|
||||
rate_limit = RateLimit("test_client", 5)
|
||||
request_id = rate_limit.enter()
|
||||
|
||||
assert request_id != RateLimit._UNLIMITED_REQUEST_ID
|
||||
redis_patch.hset.assert_called_once()
|
||||
|
||||
def test_should_generate_request_id_if_not_provided(self, redis_patch):
|
||||
"""Test auto-generation of request ID."""
|
||||
redis_patch.configure_mock(
|
||||
**{
|
||||
"exists.return_value": False,
|
||||
"setex.return_value": True,
|
||||
"hlen.return_value": 0,
|
||||
"hset.return_value": True,
|
||||
}
|
||||
)
|
||||
|
||||
rate_limit = RateLimit("test_client", 5)
|
||||
request_id = rate_limit.enter()
|
||||
|
||||
assert len(request_id) == 36 # UUID format
|
||||
|
||||
def test_should_use_provided_request_id(self, redis_patch):
|
||||
"""Test using provided request ID."""
|
||||
redis_patch.configure_mock(
|
||||
**{
|
||||
"exists.return_value": False,
|
||||
"setex.return_value": True,
|
||||
"hlen.return_value": 0,
|
||||
"hset.return_value": True,
|
||||
}
|
||||
)
|
||||
|
||||
rate_limit = RateLimit("test_client", 5)
|
||||
custom_id = "custom_request_123"
|
||||
request_id = rate_limit.enter(custom_id)
|
||||
|
||||
assert request_id == custom_id
|
||||
|
||||
def test_should_remove_request_on_exit(self, redis_patch):
|
||||
"""Test request removal on exit."""
|
||||
redis_patch.configure_mock(
|
||||
**{
|
||||
"hdel.return_value": 1,
|
||||
}
|
||||
)
|
||||
|
||||
rate_limit = RateLimit("test_client", 5)
|
||||
rate_limit.exit("test_request_id")
|
||||
|
||||
redis_patch.hdel.assert_called_once_with("dify:rate_limit:test_client:active_requests", "test_request_id")
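The enter/exit pair exercised by these tests is the low-level API, and a caller is expected to balance the two even on failure. A hedged usage sketch, with the error type and import path taken from elsewhere in this file:

```python
from core.app.features.rate_limiting.rate_limit import RateLimit
from core.errors.error import AppInvokeQuotaExceededError


def handle_request(client_id: str, limit: int) -> str:
    rate_limit = RateLimit(client_id, limit)
    try:
        request_id = rate_limit.enter()  # raises AppInvokeQuotaExceededError at the limit
    except AppInvokeQuotaExceededError:
        return "too many requests"
    try:
        return "did some work"           # the guarded work would go here
    finally:
        rate_limit.exit(request_id)      # always release the slot
```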
|
||||
|
||||
def test_should_raise_quota_exceeded_when_at_limit(self, redis_patch):
|
||||
"""Test quota exceeded error when at limit."""
|
||||
redis_patch.configure_mock(
|
||||
**{
|
||||
"exists.return_value": False,
|
||||
"setex.return_value": True,
|
||||
"hlen.return_value": 5, # At limit
|
||||
}
|
||||
)
|
||||
|
||||
rate_limit = RateLimit("test_client", 5)
|
||||
|
||||
with pytest.raises(AppInvokeQuotaExceededError) as exc_info:
|
||||
rate_limit.enter()
|
||||
|
||||
assert "Too many requests" in str(exc_info.value)
|
||||
assert "test_client" in str(exc_info.value)
|
||||
|
||||
def test_should_allow_request_after_previous_exit(self, redis_patch):
|
||||
"""Test allowing new request after previous exit."""
|
||||
redis_patch.configure_mock(
|
||||
**{
|
||||
"exists.return_value": False,
|
||||
"setex.return_value": True,
|
||||
"hlen.return_value": 4, # Under limit after exit
|
||||
"hset.return_value": True,
|
||||
"hdel.return_value": 1,
|
||||
}
|
||||
)
|
||||
|
||||
rate_limit = RateLimit("test_client", 5)
|
||||
|
||||
request_id = rate_limit.enter()
|
||||
rate_limit.exit(request_id)
|
||||
|
||||
new_request_id = rate_limit.enter()
|
||||
assert new_request_id is not None
|
||||
|
||||
@patch("time.time")
|
||||
def test_should_flush_cache_when_interval_exceeded(self, mock_time, redis_patch):
|
||||
"""Test cache flush when time interval exceeded."""
|
||||
redis_patch.configure_mock(
|
||||
**{
|
||||
"exists.return_value": False,
|
||||
"setex.return_value": True,
|
||||
"hlen.return_value": 0,
|
||||
}
|
||||
)
|
||||
|
||||
mock_time.return_value = 1000.0
|
||||
rate_limit = RateLimit("test_client", 5)
|
||||
|
||||
# Advance time beyond flush interval
|
||||
mock_time.return_value = 1400.0 # 400 seconds later
|
||||
redis_patch.reset_mock()
|
||||
|
||||
rate_limit.enter()
|
||||
|
||||
# Should have called setex again due to cache flush
|
||||
redis_patch.setex.assert_called()
|
||||
|
||||
def test_should_return_unlimited_id_when_disabled(self):
|
||||
"""Test unlimited ID return when rate limiting disabled."""
|
||||
rate_limit = RateLimit("test_client", 0)
|
||||
request_id = rate_limit.enter()
|
||||
|
||||
assert request_id == RateLimit._UNLIMITED_REQUEST_ID
|
||||
|
||||
def test_should_ignore_exit_for_unlimited_requests(self, redis_patch):
|
||||
"""Test ignoring exit for unlimited requests."""
|
||||
rate_limit = RateLimit("test_client", 0)
|
||||
rate_limit.exit(RateLimit._UNLIMITED_REQUEST_ID)
|
||||
|
||||
redis_patch.hdel.assert_not_called()
|
||||
|
||||
|
||||
class TestRateLimitGenerator:
|
||||
"""Rate limit generator wrapper tests."""
|
||||
|
||||
def test_should_wrap_generator_and_iterate_normally(self, redis_patch, sample_generator):
|
||||
"""Test normal generator iteration with rate limit wrapper."""
|
||||
redis_patch.configure_mock(
|
||||
**{
|
||||
"exists.return_value": False,
|
||||
"setex.return_value": True,
|
||||
"hdel.return_value": 1,
|
||||
}
|
||||
)
|
||||
|
||||
rate_limit = RateLimit("test_client", 5)
|
||||
generator = sample_generator()
|
||||
request_id = "test_request"
|
||||
|
||||
wrapped_gen = rate_limit.generate(generator, request_id)
|
||||
result = list(wrapped_gen)
|
||||
|
||||
assert result == ["item1", "item2", "item3"]
|
||||
redis_patch.hdel.assert_called_once_with("dify:rate_limit:test_client:active_requests", request_id)
|
||||
|
||||
def test_should_handle_mapping_input_directly(self, sample_mapping):
|
||||
"""Test direct return of mapping input."""
|
||||
rate_limit = RateLimit("test_client", 0) # Disabled
|
||||
result = rate_limit.generate(sample_mapping, "test_request")
|
||||
|
||||
assert result is sample_mapping
|
||||
|
||||
def test_should_cleanup_on_exception_during_iteration(self, redis_patch, sample_generator):
|
||||
"""Test cleanup when exception occurs during iteration."""
|
||||
redis_patch.configure_mock(
|
||||
**{
|
||||
"exists.return_value": False,
|
||||
"setex.return_value": True,
|
||||
"hdel.return_value": 1,
|
||||
}
|
||||
)
|
||||
|
||||
rate_limit = RateLimit("test_client", 5)
|
||||
generator = sample_generator(raise_error=True)
|
||||
request_id = "test_request"
|
||||
|
||||
wrapped_gen = rate_limit.generate(generator, request_id)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
list(wrapped_gen)
|
||||
|
||||
redis_patch.hdel.assert_called_once_with("dify:rate_limit:test_client:active_requests", request_id)
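For streaming responses, the same pairing is delegated to the generator wrapper: the caller just iterates the wrapped generator, and the request slot is released when iteration finishes, fails, or is closed. A minimal consumption sketch consistent with the behavior asserted in this class:

```python
def stream_response(rate_limit, upstream_generator, request_id):
    # The wrapper returned by generate() releases the request slot itself when
    # iteration completes, raises, or the generator is closed, so no explicit
    # exit() call is needed on this path.
    yield from rate_limit.generate(upstream_generator, request_id)
```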
|
||||
|
||||
def test_should_cleanup_on_explicit_close(self, redis_patch, sample_generator):
|
||||
"""Test cleanup on explicit generator close."""
|
||||
redis_patch.configure_mock(
|
||||
**{
|
||||
"exists.return_value": False,
|
||||
"setex.return_value": True,
|
||||
"hdel.return_value": 1,
|
||||
}
|
||||
)
|
||||
|
||||
rate_limit = RateLimit("test_client", 5)
|
||||
generator = sample_generator()
|
||||
request_id = "test_request"
|
||||
|
||||
wrapped_gen = rate_limit.generate(generator, request_id)
|
||||
wrapped_gen.close()
|
||||
|
||||
redis_patch.hdel.assert_called_once()
|
||||
|
||||
def test_should_handle_generator_without_close_method(self, redis_patch):
|
||||
"""Test handling generator without close method."""
|
||||
redis_patch.configure_mock(
|
||||
**{
|
||||
"exists.return_value": False,
|
||||
"setex.return_value": True,
|
||||
"hdel.return_value": 1,
|
||||
}
|
||||
)
|
||||
|
||||
# Create a generator-like object without close method
|
||||
class SimpleGenerator:
|
||||
def __init__(self):
|
||||
self.items = ["test"]
|
||||
self.index = 0
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
if self.index >= len(self.items):
|
||||
raise StopIteration
|
||||
item = self.items[self.index]
|
||||
self.index += 1
|
||||
return item
|
||||
|
||||
rate_limit = RateLimit("test_client", 5)
|
||||
generator = SimpleGenerator()
|
||||
|
||||
wrapped_gen = rate_limit.generate(generator, "test_request")
|
||||
wrapped_gen.close() # Should not raise error
|
||||
|
||||
redis_patch.hdel.assert_called_once()
|
||||
|
||||
def test_should_prevent_iteration_after_close(self, redis_patch, sample_generator):
|
||||
"""Test StopIteration after generator is closed."""
|
||||
redis_patch.configure_mock(
|
||||
**{
|
||||
"exists.return_value": False,
|
||||
"setex.return_value": True,
|
||||
"hdel.return_value": 1,
|
||||
}
|
||||
)
|
||||
|
||||
rate_limit = RateLimit("test_client", 5)
|
||||
generator = sample_generator()
|
||||
|
||||
wrapped_gen = rate_limit.generate(generator, "test_request")
|
||||
wrapped_gen.close()
|
||||
|
||||
with pytest.raises(StopIteration):
|
||||
next(wrapped_gen)
|
||||
|
||||
|
||||
class TestRateLimitConcurrency:
|
||||
"""Concurrent access safety tests."""
|
||||
|
||||
def test_should_handle_concurrent_instance_creation(self, redis_patch):
|
||||
"""Test thread-safe singleton instance creation."""
|
||||
redis_patch.configure_mock(
|
||||
**{
|
||||
"exists.return_value": False,
|
||||
"setex.return_value": True,
|
||||
}
|
||||
)
|
||||
|
||||
instances = []
|
||||
errors = []
|
||||
|
||||
def create_instance():
|
||||
try:
|
||||
instance = RateLimit("concurrent_client", 5)
|
||||
instances.append(instance)
|
||||
except Exception as e:
|
||||
errors.append(e)
|
||||
|
||||
threads = [threading.Thread(target=create_instance) for _ in range(10)]
|
||||
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
assert len(errors) == 0
|
||||
assert len({id(inst) for inst in instances}) == 1 # All same instance
|
||||
|
||||
def test_should_handle_concurrent_enter_requests(self, redis_patch):
|
||||
"""Test concurrent enter requests handling."""
|
||||
# Setup mock to simulate realistic Redis behavior
|
||||
request_count = 0
|
||||
|
||||
def mock_hlen(key):
|
||||
nonlocal request_count
|
||||
return request_count
|
||||
|
||||
def mock_hset(key, field, value):
|
||||
nonlocal request_count
|
||||
request_count += 1
|
||||
return True
|
||||
|
||||
redis_patch.configure_mock(
|
||||
**{
|
||||
"exists.return_value": False,
|
||||
"setex.return_value": True,
|
||||
"hlen.side_effect": mock_hlen,
|
||||
"hset.side_effect": mock_hset,
|
||||
}
|
||||
)
|
||||
|
||||
rate_limit = RateLimit("concurrent_client", 3)
|
||||
results = []
|
||||
errors = []
|
||||
|
||||
def try_enter():
|
||||
try:
|
||||
request_id = rate_limit.enter()
|
||||
results.append(request_id)
|
||||
except AppInvokeQuotaExceededError as e:
|
||||
errors.append(e)
|
||||
|
||||
threads = [threading.Thread(target=try_enter) for _ in range(5)]
|
||||
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
# Should have some successful requests and some quota exceeded
|
||||
assert len(results) + len(errors) == 5
|
||||
assert len(errors) > 0 # Some should be rejected
|
||||
|
||||
@patch("time.time")
|
||||
def test_should_maintain_accurate_count_under_load(self, mock_time, redis_patch):
|
||||
"""Test accurate count maintenance under concurrent load."""
|
||||
mock_time.return_value = 1000.0
|
||||
|
||||
# Use real mock_redis fixture for better simulation
|
||||
mock_client = self._create_mock_redis()
|
||||
redis_patch.configure_mock(**mock_client)
|
||||
|
||||
rate_limit = RateLimit("load_test_client", 10)
|
||||
active_requests = []
|
||||
|
||||
def enter_and_exit():
|
||||
try:
|
||||
request_id = rate_limit.enter()
|
||||
active_requests.append(request_id)
|
||||
time.sleep(0.01) # Simulate some work
|
||||
rate_limit.exit(request_id)
|
||||
active_requests.remove(request_id)
|
||||
except AppInvokeQuotaExceededError:
|
||||
pass # Expected under load
|
||||
|
||||
threads = [threading.Thread(target=enter_and_exit) for _ in range(20)]
|
||||
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
# All requests should have been cleaned up
|
||||
assert len(active_requests) == 0
|
||||
|
||||
def _create_mock_redis(self):
|
||||
"""Create a thread-safe mock Redis for concurrency tests."""
|
||||
import threading
|
||||
|
||||
lock = threading.Lock()
|
||||
data = {}
|
||||
hashes = {}
|
||||
|
||||
def mock_hlen(key):
|
||||
with lock:
|
||||
return len(hashes.get(key, {}))
|
||||
|
||||
def mock_hset(key, field, value):
|
||||
with lock:
|
||||
if key not in hashes:
|
||||
hashes[key] = {}
|
||||
hashes[key][field] = str(value).encode("utf-8")
|
||||
return True
|
||||
|
||||
def mock_hdel(key, *fields):
|
||||
with lock:
|
||||
if key in hashes:
|
||||
count = 0
|
||||
for field in fields:
|
||||
if field in hashes[key]:
|
||||
del hashes[key][field]
|
||||
count += 1
|
||||
return count
|
||||
return 0
|
||||
|
||||
return {
|
||||
"exists.return_value": False,
|
||||
"setex.return_value": True,
|
||||
"hlen.side_effect": mock_hlen,
|
||||
"hset.side_effect": mock_hset,
|
||||
"hdel.side_effect": mock_hdel,
|
||||
}
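The dict returned above relies on `MagicMock.configure_mock(**mapping)`, where dotted keys set nested attributes such as `hlen.side_effect`. A tiny standalone illustration of that pattern:

```python
from unittest.mock import MagicMock

client = MagicMock()
client.configure_mock(**{
    "exists.return_value": False,        # "attribute.return_value" keys set plain return values
    "hlen.side_effect": lambda key: 0,   # "attribute.side_effect" keys install callables
})

assert client.exists("any-key") is False
assert client.hlen("any-key") == 0
```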
|
||||
|
|
@ -0,0 +1,247 @@
|
|||
"""
|
||||
Unit tests for CeleryWorkflowExecutionRepository.
|
||||
|
||||
These tests verify the Celery-based asynchronous storage functionality
|
||||
for workflow execution data.
|
||||
"""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import Mock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from core.repositories.celery_workflow_execution_repository import CeleryWorkflowExecutionRepository
|
||||
from core.workflow.entities.workflow_execution import WorkflowExecution, WorkflowType
|
||||
from models import Account, EndUser
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session_factory():
|
||||
"""Mock SQLAlchemy session factory."""
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
# Create a real sessionmaker with in-memory SQLite for testing
|
||||
engine = create_engine("sqlite:///:memory:")
|
||||
return sessionmaker(bind=engine)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_account():
|
||||
"""Mock Account user."""
|
||||
account = Mock(spec=Account)
|
||||
account.id = str(uuid4())
|
||||
account.current_tenant_id = str(uuid4())
|
||||
return account
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_end_user():
|
||||
"""Mock EndUser."""
|
||||
user = Mock(spec=EndUser)
|
||||
user.id = str(uuid4())
|
||||
user.tenant_id = str(uuid4())
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_workflow_execution():
|
||||
"""Sample WorkflowExecution for testing."""
|
||||
return WorkflowExecution.new(
|
||||
id_=str(uuid4()),
|
||||
workflow_id=str(uuid4()),
|
||||
workflow_type=WorkflowType.WORKFLOW,
|
||||
workflow_version="1.0",
|
||||
graph={"nodes": [], "edges": []},
|
||||
inputs={"input1": "value1"},
|
||||
started_at=datetime.now(UTC).replace(tzinfo=None),
|
||||
)
|
||||
|
||||
|
||||
class TestCeleryWorkflowExecutionRepository:
|
||||
"""Test cases for CeleryWorkflowExecutionRepository."""
|
||||
|
||||
def test_init_with_sessionmaker(self, mock_session_factory, mock_account):
|
||||
"""Test repository initialization with sessionmaker."""
|
||||
app_id = "test-app-id"
|
||||
triggered_from = WorkflowRunTriggeredFrom.APP_RUN
|
||||
|
||||
repo = CeleryWorkflowExecutionRepository(
|
||||
session_factory=mock_session_factory,
|
||||
user=mock_account,
|
||||
app_id=app_id,
|
||||
triggered_from=triggered_from,
|
||||
)
|
||||
|
||||
assert repo._tenant_id == mock_account.current_tenant_id
|
||||
assert repo._app_id == app_id
|
||||
assert repo._triggered_from == triggered_from
|
||||
assert repo._creator_user_id == mock_account.id
|
||||
assert repo._creator_user_role is not None
|
||||
|
||||
def test_init_basic_functionality(self, mock_session_factory, mock_account):
|
||||
"""Test repository initialization basic functionality."""
|
||||
repo = CeleryWorkflowExecutionRepository(
|
||||
session_factory=mock_session_factory,
|
||||
user=mock_account,
|
||||
app_id="test-app",
|
||||
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
|
||||
)
|
||||
|
||||
# Verify basic initialization
|
||||
assert repo._tenant_id == mock_account.current_tenant_id
|
||||
assert repo._app_id == "test-app"
|
||||
assert repo._triggered_from == WorkflowRunTriggeredFrom.DEBUGGING
|
||||
|
||||
def test_init_with_end_user(self, mock_session_factory, mock_end_user):
|
||||
"""Test repository initialization with EndUser."""
|
||||
repo = CeleryWorkflowExecutionRepository(
|
||||
session_factory=mock_session_factory,
|
||||
user=mock_end_user,
|
||||
app_id="test-app",
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
)
|
||||
|
||||
assert repo._tenant_id == mock_end_user.tenant_id
|
||||
|
||||
def test_init_without_tenant_id_raises_error(self, mock_session_factory):
|
||||
"""Test that initialization fails without tenant_id."""
|
||||
# Create a mock Account with no tenant_id
|
||||
user = Mock(spec=Account)
|
||||
user.current_tenant_id = None
|
||||
user.id = str(uuid4())
|
||||
|
||||
with pytest.raises(ValueError, match="User must have a tenant_id"):
|
||||
CeleryWorkflowExecutionRepository(
|
||||
session_factory=mock_session_factory,
|
||||
user=user,
|
||||
app_id="test-app",
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
)
|
||||
|
||||
@patch("core.repositories.celery_workflow_execution_repository.save_workflow_execution_task")
|
||||
def test_save_queues_celery_task(self, mock_task, mock_session_factory, mock_account, sample_workflow_execution):
|
||||
"""Test that save operation queues a Celery task without tracking."""
|
||||
repo = CeleryWorkflowExecutionRepository(
|
||||
session_factory=mock_session_factory,
|
||||
user=mock_account,
|
||||
app_id="test-app",
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
)
|
||||
|
||||
repo.save(sample_workflow_execution)
|
||||
|
||||
# Verify Celery task was queued with correct parameters
|
||||
mock_task.delay.assert_called_once()
|
||||
call_args = mock_task.delay.call_args[1]
|
||||
|
||||
assert call_args["execution_data"] == sample_workflow_execution.model_dump()
|
||||
assert call_args["tenant_id"] == mock_account.current_tenant_id
|
||||
assert call_args["app_id"] == "test-app"
|
||||
assert call_args["triggered_from"] == WorkflowRunTriggeredFrom.APP_RUN.value
|
||||
assert call_args["creator_user_id"] == mock_account.id
|
||||
|
||||
# Verify no task tracking occurs (no _pending_saves attribute)
|
||||
assert not hasattr(repo, "_pending_saves")
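The assertions above pin down the save contract without reading the repository's code: serialize the domain model, hand it to the Celery task with tenant/app/creator context, and keep no local state. A hedged, standalone reconstruction of that contract (illustrative only, not the actual `CeleryWorkflowExecutionRepository.save()`):

```python
# Illustrative reconstruction of the behavior asserted above; the real method may differ.
def save_via_celery(task, execution, *, tenant_id, app_id, triggered_from, creator_user_id):
    task.delay(
        execution_data=execution.model_dump(),  # the domain model is serialized, not passed by reference
        tenant_id=tenant_id,
        app_id=app_id,
        triggered_from=triggered_from.value,     # the enum is flattened to its string value
        creator_user_id=creator_user_id,
    )
    # Fire-and-forget: nothing is tracked locally, and any exception raised by
    # .delay() (for example, the broker being down) propagates to the caller.
```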
|
||||
|
||||
@patch("core.repositories.celery_workflow_execution_repository.save_workflow_execution_task")
|
||||
def test_save_handles_celery_failure(
|
||||
self, mock_task, mock_session_factory, mock_account, sample_workflow_execution
|
||||
):
|
||||
"""Test that save operation handles Celery task failures."""
|
||||
mock_task.delay.side_effect = Exception("Celery is down")
|
||||
|
||||
repo = CeleryWorkflowExecutionRepository(
|
||||
session_factory=mock_session_factory,
|
||||
user=mock_account,
|
||||
app_id="test-app",
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
)
|
||||
|
||||
with pytest.raises(Exception, match="Celery is down"):
|
||||
repo.save(sample_workflow_execution)
|
||||
|
||||
@patch("core.repositories.celery_workflow_execution_repository.save_workflow_execution_task")
|
||||
def test_save_operation_fire_and_forget(
|
||||
self, mock_task, mock_session_factory, mock_account, sample_workflow_execution
|
||||
):
|
||||
"""Test that save operation works in fire-and-forget mode."""
|
||||
repo = CeleryWorkflowExecutionRepository(
|
||||
session_factory=mock_session_factory,
|
||||
user=mock_account,
|
||||
app_id="test-app",
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
)
|
||||
|
||||
# Test that save doesn't block or maintain state
|
||||
repo.save(sample_workflow_execution)
|
||||
|
||||
# Verify no pending saves are tracked (no _pending_saves attribute)
|
||||
assert not hasattr(repo, "_pending_saves")
|
||||
|
||||
@patch("core.repositories.celery_workflow_execution_repository.save_workflow_execution_task")
|
||||
def test_multiple_save_operations(self, mock_task, mock_session_factory, mock_account):
|
||||
"""Test multiple save operations work correctly."""
|
||||
repo = CeleryWorkflowExecutionRepository(
|
||||
session_factory=mock_session_factory,
|
||||
user=mock_account,
|
||||
app_id="test-app",
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
)
|
||||
|
||||
# Create multiple executions
|
||||
exec1 = WorkflowExecution.new(
|
||||
id_=str(uuid4()),
|
||||
workflow_id=str(uuid4()),
|
||||
workflow_type=WorkflowType.WORKFLOW,
|
||||
workflow_version="1.0",
|
||||
graph={"nodes": [], "edges": []},
|
||||
inputs={"input1": "value1"},
|
||||
started_at=datetime.now(UTC).replace(tzinfo=None),
|
||||
)
|
||||
exec2 = WorkflowExecution.new(
|
||||
id_=str(uuid4()),
|
||||
workflow_id=str(uuid4()),
|
||||
workflow_type=WorkflowType.WORKFLOW,
|
||||
workflow_version="1.0",
|
||||
graph={"nodes": [], "edges": []},
|
||||
inputs={"input2": "value2"},
|
||||
started_at=datetime.now(UTC).replace(tzinfo=None),
|
||||
)
|
||||
|
||||
# Save both executions
|
||||
repo.save(exec1)
|
||||
repo.save(exec2)
|
||||
|
||||
# Should work without issues and not maintain state (no _pending_saves attribute)
|
||||
assert not hasattr(repo, "_pending_saves")
|
||||
|
||||
@patch("core.repositories.celery_workflow_execution_repository.save_workflow_execution_task")
|
||||
def test_save_with_different_user_types(self, mock_task, mock_session_factory, mock_end_user):
|
||||
"""Test save operation with different user types."""
|
||||
repo = CeleryWorkflowExecutionRepository(
|
||||
session_factory=mock_session_factory,
|
||||
user=mock_end_user,
|
||||
app_id="test-app",
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
)
|
||||
|
||||
execution = WorkflowExecution.new(
|
||||
id_=str(uuid4()),
|
||||
workflow_id=str(uuid4()),
|
||||
workflow_type=WorkflowType.WORKFLOW,
|
||||
workflow_version="1.0",
|
||||
graph={"nodes": [], "edges": []},
|
||||
inputs={"input1": "value1"},
|
||||
started_at=datetime.now(UTC).replace(tzinfo=None),
|
||||
)
|
||||
|
||||
repo.save(execution)
|
||||
|
||||
# Verify task was called with EndUser context
|
||||
mock_task.delay.assert_called_once()
|
||||
call_args = mock_task.delay.call_args[1]
|
||||
assert call_args["tenant_id"] == mock_end_user.tenant_id
|
||||
assert call_args["creator_user_id"] == mock_end_user.id
|
||||
|
|
@ -0,0 +1,349 @@
|
|||
"""
|
||||
Unit tests for CeleryWorkflowNodeExecutionRepository.
|
||||
|
||||
These tests verify the Celery-based asynchronous storage functionality
|
||||
for workflow node execution data.
|
||||
"""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import Mock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from core.repositories.celery_workflow_node_execution_repository import CeleryWorkflowNodeExecutionRepository
|
||||
from core.workflow.entities.workflow_node_execution import (
|
||||
WorkflowNodeExecution,
|
||||
WorkflowNodeExecutionStatus,
|
||||
)
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.repositories.workflow_node_execution_repository import OrderConfig
|
||||
from models import Account, EndUser
|
||||
from models.workflow import WorkflowNodeExecutionTriggeredFrom
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session_factory():
|
||||
"""Mock SQLAlchemy session factory."""
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
# Create a real sessionmaker with in-memory SQLite for testing
|
||||
engine = create_engine("sqlite:///:memory:")
|
||||
return sessionmaker(bind=engine)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_account():
|
||||
"""Mock Account user."""
|
||||
account = Mock(spec=Account)
|
||||
account.id = str(uuid4())
|
||||
account.current_tenant_id = str(uuid4())
|
||||
return account
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_end_user():
|
||||
"""Mock EndUser."""
|
||||
user = Mock(spec=EndUser)
|
||||
user.id = str(uuid4())
|
||||
user.tenant_id = str(uuid4())
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_workflow_node_execution():
|
||||
"""Sample WorkflowNodeExecution for testing."""
|
||||
return WorkflowNodeExecution(
|
||||
id=str(uuid4()),
|
||||
node_execution_id=str(uuid4()),
|
||||
workflow_id=str(uuid4()),
|
||||
workflow_execution_id=str(uuid4()),
|
||||
index=1,
|
||||
node_id="test_node",
|
||||
node_type=NodeType.START,
|
||||
title="Test Node",
|
||||
inputs={"input1": "value1"},
|
||||
status=WorkflowNodeExecutionStatus.RUNNING,
|
||||
created_at=datetime.now(UTC).replace(tzinfo=None),
|
||||
)
|
||||
|
||||
|
||||
class TestCeleryWorkflowNodeExecutionRepository:
|
||||
"""Test cases for CeleryWorkflowNodeExecutionRepository."""
|
||||
|
||||
def test_init_with_sessionmaker(self, mock_session_factory, mock_account):
|
||||
"""Test repository initialization with sessionmaker."""
|
||||
app_id = "test-app-id"
|
||||
triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN
|
||||
|
||||
repo = CeleryWorkflowNodeExecutionRepository(
|
||||
session_factory=mock_session_factory,
|
||||
user=mock_account,
|
||||
app_id=app_id,
|
||||
triggered_from=triggered_from,
|
||||
)
|
||||
|
||||
assert repo._tenant_id == mock_account.current_tenant_id
|
||||
assert repo._app_id == app_id
|
||||
assert repo._triggered_from == triggered_from
|
||||
assert repo._creator_user_id == mock_account.id
|
||||
assert repo._creator_user_role is not None
|
||||
|
||||
def test_init_with_cache_initialized(self, mock_session_factory, mock_account):
|
||||
"""Test repository initialization with cache properly initialized."""
|
||||
repo = CeleryWorkflowNodeExecutionRepository(
|
||||
session_factory=mock_session_factory,
|
||||
user=mock_account,
|
||||
app_id="test-app",
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
|
||||
)
|
||||
|
||||
assert repo._execution_cache == {}
|
||||
assert repo._workflow_execution_mapping == {}
|
||||
|
||||
def test_init_with_end_user(self, mock_session_factory, mock_end_user):
|
||||
"""Test repository initialization with EndUser."""
|
||||
repo = CeleryWorkflowNodeExecutionRepository(
|
||||
session_factory=mock_session_factory,
|
||||
user=mock_end_user,
|
||||
app_id="test-app",
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
assert repo._tenant_id == mock_end_user.tenant_id
|
||||
|
||||
def test_init_without_tenant_id_raises_error(self, mock_session_factory):
|
||||
"""Test that initialization fails without tenant_id."""
|
||||
# Create a mock Account with no tenant_id
|
||||
user = Mock(spec=Account)
|
||||
user.current_tenant_id = None
|
||||
user.id = str(uuid4())
|
||||
|
||||
with pytest.raises(ValueError, match="User must have a tenant_id"):
|
||||
CeleryWorkflowNodeExecutionRepository(
|
||||
session_factory=mock_session_factory,
|
||||
user=user,
|
||||
app_id="test-app",
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
@patch("core.repositories.celery_workflow_node_execution_repository.save_workflow_node_execution_task")
|
||||
def test_save_caches_and_queues_celery_task(
|
||||
self, mock_task, mock_session_factory, mock_account, sample_workflow_node_execution
|
||||
):
|
||||
"""Test that save operation caches execution and queues a Celery task."""
|
||||
repo = CeleryWorkflowNodeExecutionRepository(
|
||||
session_factory=mock_session_factory,
|
||||
user=mock_account,
|
||||
app_id="test-app",
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
repo.save(sample_workflow_node_execution)
|
||||
|
||||
# Verify Celery task was queued with correct parameters
|
||||
mock_task.delay.assert_called_once()
|
||||
call_args = mock_task.delay.call_args[1]
|
||||
|
||||
assert call_args["execution_data"] == sample_workflow_node_execution.model_dump()
|
||||
assert call_args["tenant_id"] == mock_account.current_tenant_id
|
||||
assert call_args["app_id"] == "test-app"
|
||||
assert call_args["triggered_from"] == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value
|
||||
assert call_args["creator_user_id"] == mock_account.id
|
||||
|
||||
# Verify execution is cached
|
||||
assert sample_workflow_node_execution.id in repo._execution_cache
|
||||
assert repo._execution_cache[sample_workflow_node_execution.id] == sample_workflow_node_execution
|
||||
|
||||
# Verify workflow execution mapping is updated
|
||||
assert sample_workflow_node_execution.workflow_execution_id in repo._workflow_execution_mapping
|
||||
assert (
|
||||
sample_workflow_node_execution.id
|
||||
in repo._workflow_execution_mapping[sample_workflow_node_execution.workflow_execution_id]
|
||||
)
|
||||
|
||||
@patch("core.repositories.celery_workflow_node_execution_repository.save_workflow_node_execution_task")
|
||||
def test_save_handles_celery_failure(
|
||||
self, mock_task, mock_session_factory, mock_account, sample_workflow_node_execution
|
||||
):
|
||||
"""Test that save operation handles Celery task failures."""
|
||||
mock_task.delay.side_effect = Exception("Celery is down")
|
||||
|
||||
repo = CeleryWorkflowNodeExecutionRepository(
|
||||
session_factory=mock_session_factory,
|
||||
user=mock_account,
|
||||
app_id="test-app",
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
with pytest.raises(Exception, match="Celery is down"):
|
||||
repo.save(sample_workflow_node_execution)
|
||||
|
||||
@patch("core.repositories.celery_workflow_node_execution_repository.save_workflow_node_execution_task")
|
||||
def test_get_by_workflow_run_from_cache(
|
||||
self, mock_task, mock_session_factory, mock_account, sample_workflow_node_execution
|
||||
):
|
||||
"""Test that get_by_workflow_run retrieves executions from cache."""
|
||||
repo = CeleryWorkflowNodeExecutionRepository(
|
||||
session_factory=mock_session_factory,
|
||||
user=mock_account,
|
||||
app_id="test-app",
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
# Save execution to cache first
|
||||
repo.save(sample_workflow_node_execution)
|
||||
|
||||
workflow_run_id = sample_workflow_node_execution.workflow_execution_id
|
||||
order_config = OrderConfig(order_by=["index"], order_direction="asc")
|
||||
|
||||
result = repo.get_by_workflow_run(workflow_run_id, order_config)
|
||||
|
||||
# Verify results were retrieved from cache
|
||||
assert len(result) == 1
|
||||
assert result[0].id == sample_workflow_node_execution.id
|
||||
assert result[0] is sample_workflow_node_execution
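The cache exercised here amounts to two dictionaries: executions keyed by id, plus a per-workflow-run index of execution ids. A minimal sketch of that structure and the cache-hit read path, inferred from the assertions rather than the repository source:

```python
from collections import defaultdict

# Sketch of the in-memory cache implied by these tests.
execution_cache: dict[str, object] = {}          # execution.id -> WorkflowNodeExecution
workflow_execution_mapping = defaultdict(list)   # workflow_execution_id -> [execution.id, ...]


def cache_put(execution) -> None:
    execution_cache[execution.id] = execution
    workflow_execution_mapping[execution.workflow_execution_id].append(execution.id)


def cache_get_by_workflow_run(workflow_run_id: str) -> list:
    return [execution_cache[eid] for eid in workflow_execution_mapping.get(workflow_run_id, [])]
```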
|
||||
|
||||
def test_get_by_workflow_run_without_order_config(self, mock_session_factory, mock_account):
|
||||
"""Test get_by_workflow_run without order configuration."""
|
||||
repo = CeleryWorkflowNodeExecutionRepository(
|
||||
session_factory=mock_session_factory,
|
||||
user=mock_account,
|
||||
app_id="test-app",
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
result = repo.get_by_workflow_run("workflow-run-id")
|
||||
|
||||
# Should return empty list since nothing in cache
|
||||
assert len(result) == 0
|
||||
|
||||
@patch("core.repositories.celery_workflow_node_execution_repository.save_workflow_node_execution_task")
|
||||
def test_cache_operations(self, mock_task, mock_session_factory, mock_account, sample_workflow_node_execution):
|
||||
"""Test cache operations work correctly."""
|
||||
repo = CeleryWorkflowNodeExecutionRepository(
|
||||
session_factory=mock_session_factory,
|
||||
user=mock_account,
|
||||
app_id="test-app",
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
# Test saving to cache
|
||||
repo.save(sample_workflow_node_execution)
|
||||
|
||||
# Verify cache contains the execution
|
||||
assert sample_workflow_node_execution.id in repo._execution_cache
|
||||
|
||||
# Test retrieving from cache
|
||||
result = repo.get_by_workflow_run(sample_workflow_node_execution.workflow_execution_id)
|
||||
assert len(result) == 1
|
||||
assert result[0].id == sample_workflow_node_execution.id
|
||||
|
||||
@patch("core.repositories.celery_workflow_node_execution_repository.save_workflow_node_execution_task")
|
||||
def test_multiple_executions_same_workflow(self, mock_task, mock_session_factory, mock_account):
|
||||
"""Test multiple executions for the same workflow."""
|
||||
repo = CeleryWorkflowNodeExecutionRepository(
|
||||
session_factory=mock_session_factory,
|
||||
user=mock_account,
|
||||
app_id="test-app",
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
# Create multiple executions for the same workflow
|
||||
workflow_run_id = str(uuid4())
|
||||
exec1 = WorkflowNodeExecution(
|
||||
id=str(uuid4()),
|
||||
node_execution_id=str(uuid4()),
|
||||
workflow_id=str(uuid4()),
|
||||
workflow_execution_id=workflow_run_id,
|
||||
index=1,
|
||||
node_id="node1",
|
||||
node_type=NodeType.START,
|
||||
title="Node 1",
|
||||
inputs={"input1": "value1"},
|
||||
status=WorkflowNodeExecutionStatus.RUNNING,
|
||||
created_at=datetime.now(UTC).replace(tzinfo=None),
|
||||
)
|
||||
exec2 = WorkflowNodeExecution(
|
||||
id=str(uuid4()),
|
||||
node_execution_id=str(uuid4()),
|
||||
workflow_id=str(uuid4()),
|
||||
workflow_execution_id=workflow_run_id,
|
||||
index=2,
|
||||
node_id="node2",
|
||||
node_type=NodeType.LLM,
|
||||
title="Node 2",
|
||||
inputs={"input2": "value2"},
|
||||
status=WorkflowNodeExecutionStatus.RUNNING,
|
||||
created_at=datetime.now(UTC).replace(tzinfo=None),
|
||||
)
|
||||
|
||||
# Save both executions
|
||||
repo.save(exec1)
|
||||
repo.save(exec2)
|
||||
|
||||
# Verify both are cached and mapped
|
||||
assert len(repo._execution_cache) == 2
|
||||
assert len(repo._workflow_execution_mapping[workflow_run_id]) == 2
|
||||
|
||||
# Test retrieval
|
||||
result = repo.get_by_workflow_run(workflow_run_id)
|
||||
assert len(result) == 2
|
||||
|
||||
@patch("core.repositories.celery_workflow_node_execution_repository.save_workflow_node_execution_task")
|
||||
def test_ordering_functionality(self, mock_task, mock_session_factory, mock_account):
|
||||
"""Test ordering functionality works correctly."""
|
||||
repo = CeleryWorkflowNodeExecutionRepository(
|
||||
session_factory=mock_session_factory,
|
||||
user=mock_account,
|
||||
app_id="test-app",
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
# Create executions with different indices
|
||||
workflow_run_id = str(uuid4())
|
||||
exec1 = WorkflowNodeExecution(
|
||||
id=str(uuid4()),
|
||||
node_execution_id=str(uuid4()),
|
||||
workflow_id=str(uuid4()),
|
||||
workflow_execution_id=workflow_run_id,
|
||||
index=2,
|
||||
node_id="node2",
|
||||
node_type=NodeType.START,
|
||||
title="Node 2",
|
||||
inputs={},
|
||||
status=WorkflowNodeExecutionStatus.RUNNING,
|
||||
created_at=datetime.now(UTC).replace(tzinfo=None),
|
||||
)
|
||||
exec2 = WorkflowNodeExecution(
|
||||
id=str(uuid4()),
|
||||
node_execution_id=str(uuid4()),
|
||||
workflow_id=str(uuid4()),
|
||||
workflow_execution_id=workflow_run_id,
|
||||
index=1,
|
||||
node_id="node1",
|
||||
node_type=NodeType.LLM,
|
||||
title="Node 1",
|
||||
inputs={},
|
||||
status=WorkflowNodeExecutionStatus.RUNNING,
|
||||
created_at=datetime.now(UTC).replace(tzinfo=None),
|
||||
)
|
||||
|
||||
# Save in random order
|
||||
repo.save(exec1)
|
||||
repo.save(exec2)
|
||||
|
||||
# Test ascending order
|
||||
order_config = OrderConfig(order_by=["index"], order_direction="asc")
|
||||
result = repo.get_by_workflow_run(workflow_run_id, order_config)
|
||||
assert len(result) == 2
|
||||
assert result[0].index == 1
|
||||
assert result[1].index == 2
|
||||
|
||||
# Test descending order
|
||||
order_config = OrderConfig(order_by=["index"], order_direction="desc")
|
||||
result = repo.get_by_workflow_run(workflow_run_id, order_config)
|
||||
assert len(result) == 2
|
||||
assert result[0].index == 2
|
||||
assert result[1].index == 1
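Because the cached executions live in plain Python lists, the `OrderConfig` handling verified here can be expressed as a simple in-memory sort. A hedged sketch using a single sort key, matching the asc/desc cases asserted above (the real repository may support multiple `order_by` fields):

```python
# Illustrative in-memory ordering consistent with the asserted behavior.
def apply_order(executions: list, order_config) -> list:
    if order_config is None or not order_config.order_by:
        return list(executions)
    field = order_config.order_by[0]
    return sorted(
        executions,
        key=lambda execution: getattr(execution, field),
        reverse=(order_config.order_direction == "desc"),
    )
```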
|
||||
|
|
@ -59,7 +59,7 @@ class TestRepositoryFactory:
|
|||
def get_by_id(self):
|
||||
pass
|
||||
|
||||
# Create a mock interface with the same methods
|
||||
# Create a mock interface class
|
||||
class MockInterface:
|
||||
def save(self):
|
||||
pass
|
||||
|
|
@ -67,20 +67,20 @@ class TestRepositoryFactory:
|
|||
def get_by_id(self):
|
||||
pass
|
||||
|
||||
# Should not raise an exception
|
||||
# Should not raise an exception when all methods are present
|
||||
DifyCoreRepositoryFactory._validate_repository_interface(MockRepository, MockInterface)
|
||||
|
||||
def test_validate_repository_interface_missing_methods(self):
|
||||
"""Test interface validation with missing methods."""
|
||||
|
||||
# Create a mock class that doesn't implement all required methods
|
||||
# Create a mock class that's missing required methods
|
||||
class IncompleteRepository:
|
||||
def save(self):
|
||||
pass
|
||||
|
||||
# Missing get_by_id method
|
||||
|
||||
# Create a mock interface with required methods
|
||||
# Create a mock interface that requires both methods
|
||||
class MockInterface:
|
||||
def save(self):
|
||||
pass
|
||||
|
|
@ -88,57 +88,39 @@ class TestRepositoryFactory:
|
|||
def get_by_id(self):
|
||||
pass
|
||||
|
||||
def missing_method(self):
|
||||
pass
|
||||
|
||||
with pytest.raises(RepositoryImportError) as exc_info:
|
||||
DifyCoreRepositoryFactory._validate_repository_interface(IncompleteRepository, MockInterface)
|
||||
assert "does not implement required methods" in str(exc_info.value)
|
||||
assert "get_by_id" in str(exc_info.value)
|
||||
|
||||
def test_validate_constructor_signature_success(self):
|
||||
"""Test successful constructor signature validation."""
|
||||
def test_validate_repository_interface_with_private_methods(self):
|
||||
"""Test that private methods are ignored during interface validation."""
|
||||
|
||||
class MockRepository:
|
||||
def __init__(self, session_factory, user, app_id, triggered_from):
|
||||
def save(self):
|
||||
pass
|
||||
|
||||
# Should not raise an exception
|
||||
DifyCoreRepositoryFactory._validate_constructor_signature(
|
||||
MockRepository, ["session_factory", "user", "app_id", "triggered_from"]
|
||||
)
|
||||
|
||||
def test_validate_constructor_signature_missing_params(self):
|
||||
"""Test constructor validation with missing parameters."""
|
||||
|
||||
class IncompleteRepository:
|
||||
def __init__(self, session_factory, user):
|
||||
# Missing app_id and triggered_from parameters
|
||||
def _private_method(self):
|
||||
pass
|
||||
|
||||
with pytest.raises(RepositoryImportError) as exc_info:
|
||||
DifyCoreRepositoryFactory._validate_constructor_signature(
|
||||
IncompleteRepository, ["session_factory", "user", "app_id", "triggered_from"]
|
||||
)
|
||||
assert "does not accept required parameters" in str(exc_info.value)
|
||||
assert "app_id" in str(exc_info.value)
|
||||
assert "triggered_from" in str(exc_info.value)
|
||||
|
||||
def test_validate_constructor_signature_inspection_error(self, mocker: MockerFixture):
|
||||
"""Test constructor validation when inspection fails."""
|
||||
# Mock inspect.signature to raise an exception
|
||||
mocker.patch("inspect.signature", side_effect=Exception("Inspection failed"))
|
||||
|
||||
class MockRepository:
|
||||
def __init__(self, session_factory):
|
||||
# Create a mock interface with private methods
|
||||
class MockInterface:
|
||||
def save(self):
|
||||
pass
|
||||
|
||||
with pytest.raises(RepositoryImportError) as exc_info:
|
||||
DifyCoreRepositoryFactory._validate_constructor_signature(MockRepository, ["session_factory"])
|
||||
assert "Failed to validate constructor signature" in str(exc_info.value)
|
||||
def _private_method(self):
|
||||
pass
|
||||
|
||||
# Should not raise exception - private methods should be ignored
|
||||
DifyCoreRepositoryFactory._validate_repository_interface(MockRepository, MockInterface)
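Interface validation of this shape, public methods only with private names ignored, can be written in a few lines with `inspect`. The following is an illustrative standalone version, not the factory's actual code (which raises `RepositoryImportError` rather than `TypeError`):

```python
import inspect


def validate_repository_interface(repository_class, interface_class) -> None:
    """Raise if repository_class lacks any public method declared on interface_class."""
    required = {
        name
        for name, _ in inspect.getmembers(interface_class, predicate=inspect.isfunction)
        if not name.startswith("_")  # private helpers on the interface are ignored
    }
    missing = {name for name in required if not callable(getattr(repository_class, name, None))}
    if missing:
        raise TypeError(
            f"{repository_class.__name__} does not implement required methods: {sorted(missing)}"
        )
```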
|
||||
|
||||
@patch("core.repositories.factory.dify_config")
|
||||
def test_create_workflow_execution_repository_success(self, mock_config, mocker: MockerFixture):
|
||||
"""Test successful creation of WorkflowExecutionRepository."""
|
||||
def test_create_workflow_execution_repository_success(self, mock_config):
|
||||
"""Test successful WorkflowExecutionRepository creation."""
|
||||
# Setup mock configuration
|
||||
mock_config.WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
|
||||
mock_config.CORE_WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
|
||||
|
||||
# Create mock dependencies
|
||||
mock_session_factory = MagicMock(spec=sessionmaker)
|
||||
|
|
@ -146,7 +128,7 @@ class TestRepositoryFactory:
|
|||
app_id = "test-app-id"
|
||||
triggered_from = WorkflowRunTriggeredFrom.APP_RUN
|
||||
|
||||
# Mock the imported class to be a valid repository
|
||||
# Create mock repository class and instance
|
||||
mock_repository_class = MagicMock()
|
||||
mock_repository_instance = MagicMock(spec=WorkflowExecutionRepository)
|
||||
mock_repository_class.return_value = mock_repository_instance
|
||||
|
|
@ -155,7 +137,6 @@ class TestRepositoryFactory:
|
|||
with (
|
||||
patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
|
||||
patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"),
|
||||
patch.object(DifyCoreRepositoryFactory, "_validate_constructor_signature"),
|
||||
):
|
||||
result = DifyCoreRepositoryFactory.create_workflow_execution_repository(
|
||||
session_factory=mock_session_factory,
|
||||
|
|
@ -177,7 +158,7 @@ class TestRepositoryFactory:
|
|||
def test_create_workflow_execution_repository_import_error(self, mock_config):
|
||||
"""Test WorkflowExecutionRepository creation with import error."""
|
||||
# Setup mock configuration with invalid class path
|
||||
mock_config.WORKFLOW_EXECUTION_REPOSITORY = "invalid.module.InvalidClass"
|
||||
mock_config.CORE_WORKFLOW_EXECUTION_REPOSITORY = "invalid.module.InvalidClass"
|
||||
|
||||
mock_session_factory = MagicMock(spec=sessionmaker)
|
||||
mock_user = MagicMock(spec=Account)
|
||||
|
|
@ -195,45 +176,46 @@ class TestRepositoryFactory:
|
|||
def test_create_workflow_execution_repository_validation_error(self, mock_config, mocker: MockerFixture):
|
||||
"""Test WorkflowExecutionRepository creation with validation error."""
|
||||
# Setup mock configuration
|
||||
mock_config.WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
|
||||
mock_config.CORE_WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
|
||||
|
||||
mock_session_factory = MagicMock(spec=sessionmaker)
|
||||
mock_user = MagicMock(spec=Account)
|
||||
|
||||
# Mock import to succeed but validation to fail
|
||||
# Mock the import to succeed but validation to fail
|
||||
mock_repository_class = MagicMock()
|
||||
with (
|
||||
patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
|
||||
patch.object(
|
||||
DifyCoreRepositoryFactory,
|
||||
"_validate_repository_interface",
|
||||
side_effect=RepositoryImportError("Interface validation failed"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(RepositoryImportError) as exc_info:
|
||||
DifyCoreRepositoryFactory.create_workflow_execution_repository(
|
||||
session_factory=mock_session_factory,
|
||||
user=mock_user,
|
||||
app_id="test-app-id",
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
)
|
||||
assert "Interface validation failed" in str(exc_info.value)
|
||||
mocker.patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class)
|
||||
mocker.patch.object(
|
||||
DifyCoreRepositoryFactory,
|
||||
"_validate_repository_interface",
|
||||
side_effect=RepositoryImportError("Interface validation failed"),
|
||||
)
|
||||
|
||||
with pytest.raises(RepositoryImportError) as exc_info:
|
||||
DifyCoreRepositoryFactory.create_workflow_execution_repository(
|
||||
session_factory=mock_session_factory,
|
||||
user=mock_user,
|
||||
app_id="test-app-id",
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
)
|
||||
assert "Interface validation failed" in str(exc_info.value)
|
||||
|
||||
@patch("core.repositories.factory.dify_config")
|
||||
def test_create_workflow_execution_repository_instantiation_error(self, mock_config, mocker: MockerFixture):
|
||||
def test_create_workflow_execution_repository_instantiation_error(self, mock_config):
|
||||
"""Test WorkflowExecutionRepository creation with instantiation error."""
|
||||
# Setup mock configuration
|
||||
mock_config.WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
|
||||
mock_config.CORE_WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
|
||||
|
||||
mock_session_factory = MagicMock(spec=sessionmaker)
|
||||
mock_user = MagicMock(spec=Account)
|
||||
|
||||
# Mock import and validation to succeed but instantiation to fail
|
||||
mock_repository_class = MagicMock(side_effect=Exception("Instantiation failed"))
|
||||
# Create a mock repository class that raises exception on instantiation
|
||||
mock_repository_class = MagicMock()
|
||||
mock_repository_class.side_effect = Exception("Instantiation failed")
|
||||
|
||||
# Mock the validation methods to succeed
|
||||
with (
|
||||
patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
|
||||
patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"),
|
||||
patch.object(DifyCoreRepositoryFactory, "_validate_constructor_signature"),
|
||||
):
|
||||
with pytest.raises(RepositoryImportError) as exc_info:
|
||||
DifyCoreRepositoryFactory.create_workflow_execution_repository(
|
||||
|
|
@ -245,18 +227,18 @@ class TestRepositoryFactory:
|
|||
assert "Failed to create WorkflowExecutionRepository" in str(exc_info.value)
|
||||
|
||||
@patch("core.repositories.factory.dify_config")
|
||||
def test_create_workflow_node_execution_repository_success(self, mock_config, mocker: MockerFixture):
|
||||
"""Test successful creation of WorkflowNodeExecutionRepository."""
|
||||
def test_create_workflow_node_execution_repository_success(self, mock_config):
|
||||
"""Test successful WorkflowNodeExecutionRepository creation."""
|
||||
# Setup mock configuration
|
||||
mock_config.WORKFLOW_NODE_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
|
||||
mock_config.CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
|
||||
|
||||
# Create mock dependencies
|
||||
mock_session_factory = MagicMock(spec=sessionmaker)
|
||||
mock_user = MagicMock(spec=EndUser)
|
||||
app_id = "test-app-id"
|
||||
triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN
|
||||
triggered_from = WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP
|
||||
|
||||
# Mock the imported class to be a valid repository
|
||||
# Create mock repository class and instance
|
||||
mock_repository_class = MagicMock()
|
||||
mock_repository_instance = MagicMock(spec=WorkflowNodeExecutionRepository)
|
||||
mock_repository_class.return_value = mock_repository_instance
|
||||
|
|
@ -265,7 +247,6 @@ class TestRepositoryFactory:
|
|||
with (
|
||||
patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
|
||||
patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"),
|
||||
patch.object(DifyCoreRepositoryFactory, "_validate_constructor_signature"),
|
||||
):
|
||||
result = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
|
||||
session_factory=mock_session_factory,
|
||||
|
|
@ -287,7 +268,7 @@ class TestRepositoryFactory:
|
|||
def test_create_workflow_node_execution_repository_import_error(self, mock_config):
|
||||
"""Test WorkflowNodeExecutionRepository creation with import error."""
|
||||
# Setup mock configuration with invalid class path
|
||||
mock_config.WORKFLOW_NODE_EXECUTION_REPOSITORY = "invalid.module.InvalidClass"
|
||||
mock_config.CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY = "invalid.module.InvalidClass"
|
||||
|
||||
mock_session_factory = MagicMock(spec=sessionmaker)
|
||||
mock_user = MagicMock(spec=EndUser)
|
||||
|
|
@ -297,28 +278,83 @@ class TestRepositoryFactory:
|
|||
session_factory=mock_session_factory,
|
||||
user=mock_user,
|
||||
app_id="test-app-id",
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
|
||||
)
|
||||
assert "Cannot import repository class" in str(exc_info.value)
|
||||
|
||||
def test_repository_import_error_exception(self):
|
||||
"""Test RepositoryImportError exception."""
|
||||
error_message = "Test error message"
|
||||
exception = RepositoryImportError(error_message)
|
||||
assert str(exception) == error_message
|
||||
assert isinstance(exception, Exception)
|
||||
@patch("core.repositories.factory.dify_config")
|
||||
def test_create_workflow_node_execution_repository_validation_error(self, mock_config, mocker: MockerFixture):
|
||||
"""Test WorkflowNodeExecutionRepository creation with validation error."""
|
||||
# Setup mock configuration
|
||||
mock_config.CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
|
||||
|
||||
mock_session_factory = MagicMock(spec=sessionmaker)
|
||||
mock_user = MagicMock(spec=EndUser)
|
||||
|
||||
# Mock the import to succeed but validation to fail
|
||||
mock_repository_class = MagicMock()
|
||||
mocker.patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class)
|
||||
mocker.patch.object(
|
||||
DifyCoreRepositoryFactory,
|
||||
"_validate_repository_interface",
|
||||
side_effect=RepositoryImportError("Interface validation failed"),
|
||||
)
|
||||
|
||||
with pytest.raises(RepositoryImportError) as exc_info:
|
||||
DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
|
||||
session_factory=mock_session_factory,
|
||||
user=mock_user,
|
||||
app_id="test-app-id",
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
|
||||
)
|
||||
assert "Interface validation failed" in str(exc_info.value)
|
||||
|
||||
@patch("core.repositories.factory.dify_config")
|
||||
def test_create_with_engine_instead_of_sessionmaker(self, mock_config, mocker: MockerFixture):
|
||||
def test_create_workflow_node_execution_repository_instantiation_error(self, mock_config):
|
||||
"""Test WorkflowNodeExecutionRepository creation with instantiation error."""
|
||||
# Setup mock configuration
|
||||
mock_config.CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
|
||||
|
||||
mock_session_factory = MagicMock(spec=sessionmaker)
|
||||
mock_user = MagicMock(spec=EndUser)
|
||||
|
||||
# Create a mock repository class that raises exception on instantiation
|
||||
mock_repository_class = MagicMock()
|
||||
mock_repository_class.side_effect = Exception("Instantiation failed")
|
||||
|
||||
# Mock the validation methods to succeed
|
||||
with (
|
||||
patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
|
||||
patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"),
|
||||
):
|
||||
with pytest.raises(RepositoryImportError) as exc_info:
|
||||
DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
|
||||
session_factory=mock_session_factory,
|
||||
user=mock_user,
|
||||
app_id="test-app-id",
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
|
||||
)
|
||||
assert "Failed to create WorkflowNodeExecutionRepository" in str(exc_info.value)
|
||||
|
||||
def test_repository_import_error_exception(self):
|
||||
"""Test RepositoryImportError exception handling."""
|
||||
error_message = "Custom error message"
|
||||
error = RepositoryImportError(error_message)
|
||||
assert str(error) == error_message
|
||||
|
||||
@patch("core.repositories.factory.dify_config")
|
||||
def test_create_with_engine_instead_of_sessionmaker(self, mock_config):
|
||||
"""Test repository creation with Engine instead of sessionmaker."""
|
||||
# Setup mock configuration
|
||||
mock_config.WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
|
||||
mock_config.CORE_WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
|
||||
|
||||
# Create mock dependencies with Engine instead of sessionmaker
|
||||
# Create mock dependencies using Engine instead of sessionmaker
|
||||
mock_engine = MagicMock(spec=Engine)
|
||||
mock_user = MagicMock(spec=Account)
|
||||
app_id = "test-app-id"
|
||||
triggered_from = WorkflowRunTriggeredFrom.APP_RUN
|
||||
|
||||
# Mock the imported class to be a valid repository
|
||||
# Create mock repository class and instance
|
||||
mock_repository_class = MagicMock()
|
||||
mock_repository_instance = MagicMock(spec=WorkflowExecutionRepository)
|
||||
mock_repository_class.return_value = mock_repository_instance
|
||||
|
|
@ -327,129 +363,19 @@ class TestRepositoryFactory:
|
|||
with (
|
||||
patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
|
||||
patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"),
|
||||
patch.object(DifyCoreRepositoryFactory, "_validate_constructor_signature"),
|
||||
):
|
||||
result = DifyCoreRepositoryFactory.create_workflow_execution_repository(
|
||||
session_factory=mock_engine, # Using Engine instead of sessionmaker
|
||||
user=mock_user,
|
||||
app_id="test-app-id",
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
app_id=app_id,
|
||||
triggered_from=triggered_from,
|
||||
)
|
||||
|
||||
# Verify the repository was created with the Engine
|
||||
# Verify the repository was created with correct parameters
|
||||
mock_repository_class.assert_called_once_with(
|
||||
session_factory=mock_engine,
|
||||
user=mock_user,
|
||||
app_id="test-app-id",
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
app_id=app_id,
|
||||
triggered_from=triggered_from,
|
||||
)
|
||||
assert result is mock_repository_instance
|
||||
|
||||
@patch("core.repositories.factory.dify_config")
|
||||
def test_create_workflow_node_execution_repository_validation_error(self, mock_config):
|
||||
"""Test WorkflowNodeExecutionRepository creation with validation error."""
|
||||
# Setup mock configuration
|
||||
mock_config.WORKFLOW_NODE_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
|
||||
|
||||
mock_session_factory = MagicMock(spec=sessionmaker)
|
||||
mock_user = MagicMock(spec=EndUser)
|
||||
|
||||
# Mock import to succeed but validation to fail
|
||||
mock_repository_class = MagicMock()
|
||||
with (
|
||||
patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
|
||||
patch.object(
|
||||
DifyCoreRepositoryFactory,
|
||||
"_validate_repository_interface",
|
||||
side_effect=RepositoryImportError("Interface validation failed"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(RepositoryImportError) as exc_info:
|
||||
DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
|
||||
session_factory=mock_session_factory,
|
||||
user=mock_user,
|
||||
app_id="test-app-id",
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
assert "Interface validation failed" in str(exc_info.value)
|
||||
|
||||
@patch("core.repositories.factory.dify_config")
|
||||
def test_create_workflow_node_execution_repository_instantiation_error(self, mock_config):
|
||||
"""Test WorkflowNodeExecutionRepository creation with instantiation error."""
|
||||
# Setup mock configuration
|
||||
mock_config.WORKFLOW_NODE_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
|
||||
|
||||
mock_session_factory = MagicMock(spec=sessionmaker)
|
||||
mock_user = MagicMock(spec=EndUser)
|
||||
|
||||
# Mock import and validation to succeed but instantiation to fail
|
||||
mock_repository_class = MagicMock(side_effect=Exception("Instantiation failed"))
|
||||
with (
|
||||
patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
|
||||
patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"),
|
||||
patch.object(DifyCoreRepositoryFactory, "_validate_constructor_signature"),
|
||||
):
|
||||
with pytest.raises(RepositoryImportError) as exc_info:
|
||||
DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
|
||||
session_factory=mock_session_factory,
|
||||
user=mock_user,
|
||||
app_id="test-app-id",
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
assert "Failed to create WorkflowNodeExecutionRepository" in str(exc_info.value)
|
||||
|
||||
def test_validate_repository_interface_with_private_methods(self):
|
||||
"""Test interface validation ignores private methods."""
|
||||
|
||||
# Create a mock class with private methods
|
||||
class MockRepository:
|
||||
def save(self):
|
||||
pass
|
||||
|
||||
def get_by_id(self):
|
||||
pass
|
||||
|
||||
def _private_method(self):
|
||||
pass
|
||||
|
||||
# Create a mock interface with private methods
|
||||
class MockInterface:
|
||||
def save(self):
|
||||
pass
|
||||
|
||||
def get_by_id(self):
|
||||
pass
|
||||
|
||||
def _private_method(self):
|
||||
pass
|
||||
|
||||
# Should not raise an exception (private methods are ignored)
|
||||
DifyCoreRepositoryFactory._validate_repository_interface(MockRepository, MockInterface)
|
||||
|
||||
def test_validate_constructor_signature_with_extra_params(self):
|
||||
"""Test constructor validation with extra parameters (should pass)."""
|
||||
|
||||
class MockRepository:
|
||||
def __init__(self, session_factory, user, app_id, triggered_from, extra_param=None):
|
||||
pass
|
||||
|
||||
# Should not raise an exception (extra parameters are allowed)
|
||||
DifyCoreRepositoryFactory._validate_constructor_signature(
|
||||
MockRepository, ["session_factory", "user", "app_id", "triggered_from"]
|
||||
)
|
||||
|
||||
def test_validate_constructor_signature_with_kwargs(self):
|
||||
"""Test constructor validation with **kwargs (current implementation doesn't support this)."""
|
||||
|
||||
class MockRepository:
|
||||
def __init__(self, session_factory, user, **kwargs):
|
||||
pass
|
||||
|
||||
# Current implementation doesn't handle **kwargs, so this should raise an exception
|
||||
with pytest.raises(RepositoryImportError) as exc_info:
|
||||
DifyCoreRepositoryFactory._validate_constructor_signature(
|
||||
MockRepository, ["session_factory", "user", "app_id", "triggered_from"]
|
||||
)
|
||||
assert "does not accept required parameters" in str(exc_info.value)
|
||||
assert "app_id" in str(exc_info.value)
|
||||
assert "triggered_from" in str(exc_info.value)
|
||||
@ -69,8 +69,12 @@ def test_get_file_attribute(pool, file):

def test_use_long_selector(pool):
pool.add(("node_1", "part_1", "part_2"), StringSegment(value="test_value"))
# The add method now only accepts 2-element selectors (node_id, variable_name)
# Store nested data as an ObjectSegment instead
nested_data = {"part_2": "test_value"}
pool.add(("node_1", "part_1"), ObjectSegment(value=nested_data))

# The get method supports longer selectors for nested access
result = pool.get(("node_1", "part_1", "part_2"))
assert result is not None
assert result.value == "test_value"
@ -280,8 +284,10 @@ class TestVariablePoolSerialization:
pool.add((self._NODE2_ID, "array_file"), ArrayFileSegment(value=[test_file]))
pool.add((self._NODE2_ID, "array_any"), ArrayAnySegment(value=["mixed", 123, {"key": "value"}]))

# Add nested variables
pool.add((self._NODE3_ID, "nested", "deep", "var"), StringSegment(value="deep_value"))
# Add nested variables as ObjectSegment
# The add method only accepts 2-element selectors
nested_obj = {"deep": {"var": "deep_value"}}
pool.add((self._NODE3_ID, "nested"), ObjectSegment(value=nested_obj))

def test_system_variables(self):
sys_vars = SystemVariable(
@ -1,148 +0,0 @@
|
|||
from typing import Any
|
||||
|
||||
from core.variables.segments import ObjectSegment, StringSegment
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.utils.variable_utils import append_variables_recursively
|
||||
|
||||
|
||||
class TestAppendVariablesRecursively:
|
||||
"""Test cases for append_variables_recursively function"""
|
||||
|
||||
def test_append_simple_dict_value(self):
|
||||
"""Test appending a simple dictionary value"""
|
||||
pool = VariablePool.empty()
|
||||
node_id = "test_node"
|
||||
variable_key_list = ["output"]
|
||||
variable_value = {"name": "John", "age": 30}
|
||||
|
||||
append_variables_recursively(pool, node_id, variable_key_list, variable_value)
|
||||
|
||||
# Check that the main variable is added
|
||||
main_var = pool.get([node_id] + variable_key_list)
|
||||
assert main_var is not None
|
||||
assert main_var.value == variable_value
|
||||
|
||||
# Check that nested variables are added recursively
|
||||
name_var = pool.get([node_id] + variable_key_list + ["name"])
|
||||
assert name_var is not None
|
||||
assert name_var.value == "John"
|
||||
|
||||
age_var = pool.get([node_id] + variable_key_list + ["age"])
|
||||
assert age_var is not None
|
||||
assert age_var.value == 30
|
||||
|
||||
def test_append_object_segment_value(self):
|
||||
"""Test appending an ObjectSegment value"""
|
||||
pool = VariablePool.empty()
|
||||
node_id = "test_node"
|
||||
variable_key_list = ["result"]
|
||||
|
||||
# Create an ObjectSegment
|
||||
obj_data = {"status": "success", "code": 200}
|
||||
variable_value = ObjectSegment(value=obj_data)
|
||||
|
||||
append_variables_recursively(pool, node_id, variable_key_list, variable_value)
|
||||
|
||||
# Check that the main variable is added
|
||||
main_var = pool.get([node_id] + variable_key_list)
|
||||
assert main_var is not None
|
||||
assert isinstance(main_var, ObjectSegment)
|
||||
assert main_var.value == obj_data
|
||||
|
||||
# Check that nested variables are added recursively
|
||||
status_var = pool.get([node_id] + variable_key_list + ["status"])
|
||||
assert status_var is not None
|
||||
assert status_var.value == "success"
|
||||
|
||||
code_var = pool.get([node_id] + variable_key_list + ["code"])
|
||||
assert code_var is not None
|
||||
assert code_var.value == 200
|
||||
|
||||
def test_append_nested_dict_value(self):
|
||||
"""Test appending a nested dictionary value"""
|
||||
pool = VariablePool.empty()
|
||||
node_id = "test_node"
|
||||
variable_key_list = ["data"]
|
||||
|
||||
variable_value = {
|
||||
"user": {
|
||||
"profile": {"name": "Alice", "email": "alice@example.com"},
|
||||
"settings": {"theme": "dark", "notifications": True},
|
||||
},
|
||||
"metadata": {"version": "1.0", "timestamp": 1234567890},
|
||||
}
|
||||
|
||||
append_variables_recursively(pool, node_id, variable_key_list, variable_value)
|
||||
|
||||
# Check deeply nested variables
|
||||
name_var = pool.get([node_id] + variable_key_list + ["user", "profile", "name"])
|
||||
assert name_var is not None
|
||||
assert name_var.value == "Alice"
|
||||
|
||||
email_var = pool.get([node_id] + variable_key_list + ["user", "profile", "email"])
|
||||
assert email_var is not None
|
||||
assert email_var.value == "alice@example.com"
|
||||
|
||||
theme_var = pool.get([node_id] + variable_key_list + ["user", "settings", "theme"])
|
||||
assert theme_var is not None
|
||||
assert theme_var.value == "dark"
|
||||
|
||||
notifications_var = pool.get([node_id] + variable_key_list + ["user", "settings", "notifications"])
|
||||
assert notifications_var is not None
|
||||
assert notifications_var.value == 1 # Boolean True is converted to integer 1
|
||||
|
||||
version_var = pool.get([node_id] + variable_key_list + ["metadata", "version"])
|
||||
assert version_var is not None
|
||||
assert version_var.value == "1.0"
|
||||
|
||||
def test_append_non_dict_value(self):
|
||||
"""Test appending a non-dictionary value (should not recurse)"""
|
||||
pool = VariablePool.empty()
|
||||
node_id = "test_node"
|
||||
variable_key_list = ["simple"]
|
||||
variable_value = "simple_string"
|
||||
|
||||
append_variables_recursively(pool, node_id, variable_key_list, variable_value)
|
||||
|
||||
# Check that only the main variable is added
|
||||
main_var = pool.get([node_id] + variable_key_list)
|
||||
assert main_var is not None
|
||||
assert main_var.value == variable_value
|
||||
|
||||
# Ensure no additional variables are created
|
||||
assert len(pool.variable_dictionary[node_id]) == 1
|
||||
|
||||
def test_append_segment_non_object_value(self):
|
||||
"""Test appending a Segment that is not ObjectSegment (should not recurse)"""
|
||||
pool = VariablePool.empty()
|
||||
node_id = "test_node"
|
||||
variable_key_list = ["text"]
|
||||
variable_value = StringSegment(value="Hello World")
|
||||
|
||||
append_variables_recursively(pool, node_id, variable_key_list, variable_value)
|
||||
|
||||
# Check that only the main variable is added
|
||||
main_var = pool.get([node_id] + variable_key_list)
|
||||
assert main_var is not None
|
||||
assert isinstance(main_var, StringSegment)
|
||||
assert main_var.value == "Hello World"
|
||||
|
||||
# Ensure no additional variables are created
|
||||
assert len(pool.variable_dictionary[node_id]) == 1
|
||||
|
||||
def test_append_empty_dict_value(self):
|
||||
"""Test appending an empty dictionary value"""
|
||||
pool = VariablePool.empty()
|
||||
node_id = "test_node"
|
||||
variable_key_list = ["empty"]
|
||||
variable_value: dict[str, Any] = {}
|
||||
|
||||
append_variables_recursively(pool, node_id, variable_key_list, variable_value)
|
||||
|
||||
# Check that the main variable is added
|
||||
main_var = pool.get([node_id] + variable_key_list)
|
||||
assert main_var is not None
|
||||
assert main_var.value == {}
|
||||
|
||||
# Ensure only the main variable is created (no recursion for empty dict)
|
||||
assert len(pool.variable_dictionary[node_id]) == 1
|
||||
api/uv.lock

@ -1,5 +1,5 @@
version = 1
revision = 3
revision = 2
requires-python = ">=3.11, <3.13"
resolution-markers = [
"python_full_version >= '3.12.4' and platform_python_implementation != 'PyPy' and sys_platform == 'linux'",
@ -1236,7 +1236,7 @@ wheels = [
|
|||
|
||||
[[package]]
|
||||
name = "dify-api"
|
||||
version = "1.7.1"
|
||||
version = "1.7.2"
|
||||
source = { virtual = "." }
|
||||
dependencies = [
|
||||
{ name = "arize-phoenix-otel" },
|
||||
|
|
@ -1371,6 +1371,7 @@ dev = [
|
|||
{ name = "types-python-http-client" },
|
||||
{ name = "types-pywin32" },
|
||||
{ name = "types-pyyaml" },
|
||||
{ name = "types-redis" },
|
||||
{ name = "types-regex" },
|
||||
{ name = "types-requests" },
|
||||
{ name = "types-requests-oauthlib" },
|
||||
|
|
@ -1557,6 +1558,7 @@ dev = [
|
|||
{ name = "types-python-http-client", specifier = ">=3.3.7.20240910" },
|
||||
{ name = "types-pywin32", specifier = "~=310.0.0" },
|
||||
{ name = "types-pyyaml", specifier = "~=6.0.12" },
|
||||
{ name = "types-redis", specifier = ">=4.6.0.20241004" },
|
||||
{ name = "types-regex", specifier = "~=2024.11.6" },
|
||||
{ name = "types-requests", specifier = "~=2.32.0" },
|
||||
{ name = "types-requests-oauthlib", specifier = "~=2.0.0" },
|
||||
|
|
@ -6064,6 +6066,19 @@ wheels = [
|
|||
{ url = "https://files.pythonhosted.org/packages/99/5f/e0af6f7f6a260d9af67e1db4f54d732abad514252a7a378a6c4d17dd1036/types_pyyaml-6.0.12.20250516-py3-none-any.whl", hash = "sha256:8478208feaeb53a34cb5d970c56a7cd76b72659442e733e268a94dc72b2d0530", size = 20312, upload-time = "2025-05-16T03:08:04.019Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "types-redis"
|
||||
version = "4.6.0.20241004"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "cryptography" },
|
||||
{ name = "types-pyopenssl" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/3a/95/c054d3ac940e8bac4ca216470c80c26688a0e79e09f520a942bb27da3386/types-redis-4.6.0.20241004.tar.gz", hash = "sha256:5f17d2b3f9091ab75384153bfa276619ffa1cf6a38da60e10d5e6749cc5b902e", size = 49679, upload-time = "2024-10-04T02:43:59.224Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/55/82/7d25dce10aad92d2226b269bce2f85cfd843b4477cd50245d7d40ecf8f89/types_redis-4.6.0.20241004-py3-none-any.whl", hash = "sha256:ef5da68cb827e5f606c8f9c0b49eeee4c2669d6d97122f301d3a55dc6a63f6ed", size = 58737, upload-time = "2024-10-04T02:43:57.968Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "types-regex"
|
||||
version = "2024.11.6.20250403"
@ -8,4 +8,4 @@ cd "$SCRIPT_DIR/.."

uv --directory api run \
celery -A app.celery worker \
-P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion
-P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage
@ -861,17 +861,23 @@ WORKFLOW_NODE_EXECUTION_STORAGE=rdbms

# Repository configuration
# Core workflow execution repository implementation
# Options:
# - core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository (default)
# - core.repositories.celery_workflow_execution_repository.CeleryWorkflowExecutionRepository
CORE_WORKFLOW_EXECUTION_REPOSITORY=core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository

# Core workflow node execution repository implementation
# Options:
# - core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository (default)
# - core.repositories.celery_workflow_node_execution_repository.CeleryWorkflowNodeExecutionRepository
CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY=core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository

# API workflow node execution repository implementation
API_WORKFLOW_NODE_EXECUTION_REPOSITORY=repositories.sqlalchemy_api_workflow_node_execution_repository.DifyAPISQLAlchemyWorkflowNodeExecutionRepository

# API workflow run repository implementation
API_WORKFLOW_RUN_REPOSITORY=repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository

# API workflow node execution repository implementation
API_WORKFLOW_NODE_EXECUTION_REPOSITORY=repositories.sqlalchemy_api_workflow_node_execution_repository.DifyAPISQLAlchemyWorkflowNodeExecutionRepository

# HTTP request node in workflow configuration
HTTP_REQUEST_NODE_MAX_BINARY_SIZE=10485760
HTTP_REQUEST_NODE_MAX_TEXT_SIZE=1048576
@ -907,6 +913,9 @@ TEXT_GENERATION_TIMEOUT_MS=60000
# Allow rendering unsafe URLs which have "data:" scheme.
ALLOW_UNSAFE_DATA_SCHEME=false

# Maximum tree depth in the workflow
MAX_TREE_DEPTH=50

# ------------------------------
# Environment Variables for db Service
# ------------------------------
@ -2,7 +2,7 @@ x-shared-env: &shared-api-worker-env
|
|||
services:
|
||||
# API service
|
||||
api:
|
||||
image: langgenius/dify-api:1.7.1
|
||||
image: langgenius/dify-api:1.7.2
|
||||
restart: always
|
||||
environment:
|
||||
# Use the shared environment variables.
|
||||
|
|
@ -31,7 +31,7 @@ services:
|
|||
# worker service
|
||||
# The Celery worker for processing the queue.
|
||||
worker:
|
||||
image: langgenius/dify-api:1.7.1
|
||||
image: langgenius/dify-api:1.7.2
|
||||
restart: always
|
||||
environment:
|
||||
# Use the shared environment variables.
|
||||
|
|
@ -58,7 +58,7 @@ services:
|
|||
# worker_beat service
|
||||
# Celery beat for scheduling periodic tasks.
|
||||
worker_beat:
|
||||
image: langgenius/dify-api:1.7.1
|
||||
image: langgenius/dify-api:1.7.2
|
||||
restart: always
|
||||
environment:
|
||||
# Use the shared environment variables.
|
||||
|
|
@ -76,7 +76,7 @@ services:
|
|||
|
||||
# Frontend web application.
|
||||
web:
|
||||
image: langgenius/dify-web:1.7.1
|
||||
image: langgenius/dify-web:1.7.2
|
||||
restart: always
|
||||
environment:
|
||||
CONSOLE_API_URL: ${CONSOLE_API_URL:-}
|
||||
|
|
@ -96,6 +96,7 @@ services:
|
|||
MAX_TOOLS_NUM: ${MAX_TOOLS_NUM:-10}
|
||||
MAX_PARALLEL_LIMIT: ${MAX_PARALLEL_LIMIT:-10}
|
||||
MAX_ITERATIONS_NUM: ${MAX_ITERATIONS_NUM:-99}
|
||||
MAX_TREE_DEPTH: ${MAX_TREE_DEPTH:-50}
|
||||
ENABLE_WEBSITE_JINAREADER: ${ENABLE_WEBSITE_JINAREADER:-true}
|
||||
ENABLE_WEBSITE_FIRECRAWL: ${ENABLE_WEBSITE_FIRECRAWL:-true}
|
||||
ENABLE_WEBSITE_WATERCRAWL: ${ENABLE_WEBSITE_WATERCRAWL:-true}
|
||||
@ -390,8 +390,8 @@ x-shared-env: &shared-api-worker-env
WORKFLOW_NODE_EXECUTION_STORAGE: ${WORKFLOW_NODE_EXECUTION_STORAGE:-rdbms}
CORE_WORKFLOW_EXECUTION_REPOSITORY: ${CORE_WORKFLOW_EXECUTION_REPOSITORY:-core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository}
CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY: ${CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY:-core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository}
API_WORKFLOW_NODE_EXECUTION_REPOSITORY: ${API_WORKFLOW_NODE_EXECUTION_REPOSITORY:-repositories.sqlalchemy_api_workflow_node_execution_repository.DifyAPISQLAlchemyWorkflowNodeExecutionRepository}
API_WORKFLOW_RUN_REPOSITORY: ${API_WORKFLOW_RUN_REPOSITORY:-repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository}
API_WORKFLOW_NODE_EXECUTION_REPOSITORY: ${API_WORKFLOW_NODE_EXECUTION_REPOSITORY:-repositories.sqlalchemy_api_workflow_node_execution_repository.DifyAPISQLAlchemyWorkflowNodeExecutionRepository}
HTTP_REQUEST_NODE_MAX_BINARY_SIZE: ${HTTP_REQUEST_NODE_MAX_BINARY_SIZE:-10485760}
HTTP_REQUEST_NODE_MAX_TEXT_SIZE: ${HTTP_REQUEST_NODE_MAX_TEXT_SIZE:-1048576}
HTTP_REQUEST_NODE_SSL_VERIFY: ${HTTP_REQUEST_NODE_SSL_VERIFY:-True}
@ -404,6 +404,7 @@ x-shared-env: &shared-api-worker-env
MAX_ITERATIONS_NUM: ${MAX_ITERATIONS_NUM:-99}
TEXT_GENERATION_TIMEOUT_MS: ${TEXT_GENERATION_TIMEOUT_MS:-60000}
ALLOW_UNSAFE_DATA_SCHEME: ${ALLOW_UNSAFE_DATA_SCHEME:-false}
MAX_TREE_DEPTH: ${MAX_TREE_DEPTH:-50}
POSTGRES_USER: ${POSTGRES_USER:-${DB_USERNAME}}
POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-${DB_PASSWORD}}
POSTGRES_DB: ${POSTGRES_DB:-${DB_DATABASE}}
@ -567,7 +568,7 @@ x-shared-env: &shared-api-worker-env
|
|||
services:
|
||||
# API service
|
||||
api:
|
||||
image: langgenius/dify-api:1.7.1
|
||||
image: langgenius/dify-api:1.7.2
|
||||
restart: always
|
||||
environment:
|
||||
# Use the shared environment variables.
|
||||
|
|
@ -596,7 +597,7 @@ services:
|
|||
# worker service
|
||||
# The Celery worker for processing the queue.
|
||||
worker:
|
||||
image: langgenius/dify-api:1.7.1
|
||||
image: langgenius/dify-api:1.7.2
|
||||
restart: always
|
||||
environment:
|
||||
# Use the shared environment variables.
|
||||
|
|
@ -623,7 +624,7 @@ services:
|
|||
# worker_beat service
|
||||
# Celery beat for scheduling periodic tasks.
|
||||
worker_beat:
|
||||
image: langgenius/dify-api:1.7.1
|
||||
image: langgenius/dify-api:1.7.2
|
||||
restart: always
|
||||
environment:
|
||||
# Use the shared environment variables.
|
||||
|
|
@ -641,7 +642,7 @@ services:
|
|||
|
||||
# Frontend web application.
|
||||
web:
|
||||
image: langgenius/dify-web:1.7.1
|
||||
image: langgenius/dify-web:1.7.2
|
||||
restart: always
|
||||
environment:
|
||||
CONSOLE_API_URL: ${CONSOLE_API_URL:-}
|
||||
|
|
@ -661,6 +662,7 @@ services:
|
|||
MAX_TOOLS_NUM: ${MAX_TOOLS_NUM:-10}
|
||||
MAX_PARALLEL_LIMIT: ${MAX_PARALLEL_LIMIT:-10}
|
||||
MAX_ITERATIONS_NUM: ${MAX_ITERATIONS_NUM:-99}
|
||||
MAX_TREE_DEPTH: ${MAX_TREE_DEPTH:-50}
|
||||
ENABLE_WEBSITE_JINAREADER: ${ENABLE_WEBSITE_JINAREADER:-true}
|
||||
ENABLE_WEBSITE_FIRECRAWL: ${ENABLE_WEBSITE_FIRECRAWL:-true}
|
||||
ENABLE_WEBSITE_WATERCRAWL: ${ENABLE_WEBSITE_WATERCRAWL:-true}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,333 @@
|
|||
import React from 'react'
|
||||
import { fireEvent, render, screen } from '@testing-library/react'
|
||||
import '@testing-library/jest-dom'
|
||||
import CommandSelector from '../../app/components/goto-anything/command-selector'
|
||||
import type { ActionItem } from '../../app/components/goto-anything/actions/types'
|
||||
|
||||
jest.mock('react-i18next', () => ({
|
||||
useTranslation: () => ({
|
||||
t: (key: string) => key,
|
||||
}),
|
||||
}))
|
||||
|
||||
jest.mock('cmdk', () => ({
|
||||
Command: {
|
||||
Group: ({ children, className }: any) => <div className={className}>{children}</div>,
|
||||
Item: ({ children, onSelect, value, className }: any) => (
|
||||
<div
|
||||
className={className}
|
||||
onClick={() => onSelect && onSelect()}
|
||||
data-value={value}
|
||||
data-testid={`command-item-${value}`}
|
||||
>
|
||||
{children}
|
||||
</div>
|
||||
),
|
||||
},
|
||||
}))
|
||||
|
||||
describe('CommandSelector', () => {
|
||||
const mockActions: Record<string, ActionItem> = {
|
||||
app: {
|
||||
key: '@app',
|
||||
shortcut: '@app',
|
||||
title: 'Search Applications',
|
||||
description: 'Search apps',
|
||||
search: jest.fn(),
|
||||
},
|
||||
knowledge: {
|
||||
key: '@knowledge',
|
||||
shortcut: '@knowledge',
|
||||
title: 'Search Knowledge',
|
||||
description: 'Search knowledge bases',
|
||||
search: jest.fn(),
|
||||
},
|
||||
plugin: {
|
||||
key: '@plugin',
|
||||
shortcut: '@plugin',
|
||||
title: 'Search Plugins',
|
||||
description: 'Search plugins',
|
||||
search: jest.fn(),
|
||||
},
|
||||
node: {
|
||||
key: '@node',
|
||||
shortcut: '@node',
|
||||
title: 'Search Nodes',
|
||||
description: 'Search workflow nodes',
|
||||
search: jest.fn(),
|
||||
},
|
||||
}
|
||||
|
||||
const mockOnCommandSelect = jest.fn()
|
||||
const mockOnCommandValueChange = jest.fn()
|
||||
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks()
|
||||
})
|
||||
|
||||
describe('Basic Rendering', () => {
|
||||
it('should render all actions when no filter is provided', () => {
|
||||
render(
|
||||
<CommandSelector
|
||||
actions={mockActions}
|
||||
onCommandSelect={mockOnCommandSelect}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByTestId('command-item-@app')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('command-item-@knowledge')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('command-item-@plugin')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('command-item-@node')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render empty filter as showing all actions', () => {
|
||||
render(
|
||||
<CommandSelector
|
||||
actions={mockActions}
|
||||
onCommandSelect={mockOnCommandSelect}
|
||||
searchFilter=""
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByTestId('command-item-@app')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('command-item-@knowledge')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('command-item-@plugin')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('command-item-@node')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Filtering Functionality', () => {
|
||||
it('should filter actions based on searchFilter - single match', () => {
|
||||
render(
|
||||
<CommandSelector
|
||||
actions={mockActions}
|
||||
onCommandSelect={mockOnCommandSelect}
|
||||
searchFilter="k"
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.queryByTestId('command-item-@app')).not.toBeInTheDocument()
|
||||
expect(screen.getByTestId('command-item-@knowledge')).toBeInTheDocument()
|
||||
expect(screen.queryByTestId('command-item-@plugin')).not.toBeInTheDocument()
|
||||
expect(screen.queryByTestId('command-item-@node')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should filter actions with multiple matches', () => {
|
||||
render(
|
||||
<CommandSelector
|
||||
actions={mockActions}
|
||||
onCommandSelect={mockOnCommandSelect}
|
||||
searchFilter="p"
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByTestId('command-item-@app')).toBeInTheDocument()
|
||||
expect(screen.queryByTestId('command-item-@knowledge')).not.toBeInTheDocument()
|
||||
expect(screen.getByTestId('command-item-@plugin')).toBeInTheDocument()
|
||||
expect(screen.queryByTestId('command-item-@node')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should be case-insensitive when filtering', () => {
|
||||
render(
|
||||
<CommandSelector
|
||||
actions={mockActions}
|
||||
onCommandSelect={mockOnCommandSelect}
|
||||
searchFilter="APP"
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByTestId('command-item-@app')).toBeInTheDocument()
|
||||
expect(screen.queryByTestId('command-item-@knowledge')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should match partial strings', () => {
|
||||
render(
|
||||
<CommandSelector
|
||||
actions={mockActions}
|
||||
onCommandSelect={mockOnCommandSelect}
|
||||
searchFilter="nowl"
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.queryByTestId('command-item-@app')).not.toBeInTheDocument()
|
||||
expect(screen.getByTestId('command-item-@knowledge')).toBeInTheDocument()
|
||||
expect(screen.queryByTestId('command-item-@plugin')).not.toBeInTheDocument()
|
||||
expect(screen.queryByTestId('command-item-@node')).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Empty State', () => {
|
||||
it('should show empty state when no matches found', () => {
|
||||
render(
|
||||
<CommandSelector
|
||||
actions={mockActions}
|
||||
onCommandSelect={mockOnCommandSelect}
|
||||
searchFilter="xyz"
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.queryByTestId('command-item-@app')).not.toBeInTheDocument()
|
||||
expect(screen.queryByTestId('command-item-@knowledge')).not.toBeInTheDocument()
|
||||
expect(screen.queryByTestId('command-item-@plugin')).not.toBeInTheDocument()
|
||||
expect(screen.queryByTestId('command-item-@node')).not.toBeInTheDocument()
|
||||
|
||||
expect(screen.getByText('app.gotoAnything.noMatchingCommands')).toBeInTheDocument()
|
||||
expect(screen.getByText('app.gotoAnything.tryDifferentSearch')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should not show empty state when filter is empty', () => {
|
||||
render(
|
||||
<CommandSelector
|
||||
actions={mockActions}
|
||||
onCommandSelect={mockOnCommandSelect}
|
||||
searchFilter=""
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.queryByText('app.gotoAnything.noMatchingCommands')).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Selection and Highlight Management', () => {
|
||||
it('should call onCommandValueChange when filter changes and first item differs', () => {
|
||||
const { rerender } = render(
|
||||
<CommandSelector
|
||||
actions={mockActions}
|
||||
onCommandSelect={mockOnCommandSelect}
|
||||
searchFilter=""
|
||||
commandValue="@app"
|
||||
onCommandValueChange={mockOnCommandValueChange}
|
||||
/>,
|
||||
)
|
||||
|
||||
rerender(
|
||||
<CommandSelector
|
||||
actions={mockActions}
|
||||
onCommandSelect={mockOnCommandSelect}
|
||||
searchFilter="k"
|
||||
commandValue="@app"
|
||||
onCommandValueChange={mockOnCommandValueChange}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(mockOnCommandValueChange).toHaveBeenCalledWith('@knowledge')
|
||||
})
|
||||
|
||||
it('should not call onCommandValueChange if current value still exists', () => {
|
||||
const { rerender } = render(
|
||||
<CommandSelector
|
||||
actions={mockActions}
|
||||
onCommandSelect={mockOnCommandSelect}
|
||||
searchFilter=""
|
||||
commandValue="@app"
|
||||
onCommandValueChange={mockOnCommandValueChange}
|
||||
/>,
|
||||
)
|
||||
|
||||
rerender(
|
||||
<CommandSelector
|
||||
actions={mockActions}
|
||||
onCommandSelect={mockOnCommandSelect}
|
||||
searchFilter="a"
|
||||
commandValue="@app"
|
||||
onCommandValueChange={mockOnCommandValueChange}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(mockOnCommandValueChange).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should handle onCommandSelect callback correctly', () => {
|
||||
render(
|
||||
<CommandSelector
|
||||
actions={mockActions}
|
||||
onCommandSelect={mockOnCommandSelect}
|
||||
searchFilter="k"
|
||||
/>,
|
||||
)
|
||||
|
||||
const knowledgeItem = screen.getByTestId('command-item-@knowledge')
|
||||
fireEvent.click(knowledgeItem)
|
||||
|
||||
expect(mockOnCommandSelect).toHaveBeenCalledWith('@knowledge')
|
||||
})
|
||||
})
|
||||
|
||||
describe('Edge Cases', () => {
|
||||
it('should handle empty actions object', () => {
|
||||
render(
|
||||
<CommandSelector
|
||||
actions={{}}
|
||||
onCommandSelect={mockOnCommandSelect}
|
||||
searchFilter=""
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByText('app.gotoAnything.noMatchingCommands')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should handle special characters in filter', () => {
|
||||
render(
|
||||
<CommandSelector
|
||||
actions={mockActions}
|
||||
onCommandSelect={mockOnCommandSelect}
|
||||
searchFilter="@"
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByTestId('command-item-@app')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('command-item-@knowledge')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('command-item-@plugin')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('command-item-@node')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should handle undefined onCommandValueChange gracefully', () => {
|
||||
const { rerender } = render(
|
||||
<CommandSelector
|
||||
actions={mockActions}
|
||||
onCommandSelect={mockOnCommandSelect}
|
||||
searchFilter=""
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(() => {
|
||||
rerender(
|
||||
<CommandSelector
|
||||
actions={mockActions}
|
||||
onCommandSelect={mockOnCommandSelect}
|
||||
searchFilter="k"
|
||||
/>,
|
||||
)
|
||||
}).not.toThrow()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Backward Compatibility', () => {
|
||||
it('should work without searchFilter prop (backward compatible)', () => {
|
||||
render(
|
||||
<CommandSelector
|
||||
actions={mockActions}
|
||||
onCommandSelect={mockOnCommandSelect}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByTestId('command-item-@app')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('command-item-@knowledge')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('command-item-@plugin')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('command-item-@node')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should work without commandValue and onCommandValueChange props', () => {
|
||||
render(
|
||||
<CommandSelector
|
||||
actions={mockActions}
|
||||
onCommandSelect={mockOnCommandSelect}
|
||||
searchFilter="k"
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByTestId('command-item-@knowledge')).toBeInTheDocument()
|
||||
expect(screen.queryByTestId('command-item-@app')).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
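The filtering behaviour these tests pin down amounts to a case-insensitive substring match of the typed text against each action's key. A standalone sketch of that rule, under the assumption that the component's real implementation is equivalent (it is not shown in this diff):

```ts
import type { ActionItem } from '../../app/components/goto-anything/actions/types'

// Mirrors the expectations above: '' keeps every action, 'APP' matches '@app',
// 'nowl' matches '@knowledge', and 'xyz' matches nothing.
const filterActions = (actions: Record<string, ActionItem>, searchFilter = ''): ActionItem[] => {
  const needle = searchFilter.toLowerCase()
  return Object.values(actions).filter(action => action.key.toLowerCase().includes(needle))
}
```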
@ -0,0 +1,197 @@
|
|||
/**
|
||||
* Test GotoAnything search error handling mechanisms
|
||||
*
|
||||
* Main validations:
|
||||
* 1. @plugin search error handling when API fails
|
||||
* 2. Regular search (without @prefix) error handling when API fails
|
||||
* 3. Verify consistent error handling across different search types
|
||||
* 4. Ensure errors don't propagate to UI layer causing "search failed"
|
||||
*/
|
||||
|
||||
import { Actions, searchAnything } from '@/app/components/goto-anything/actions'
|
||||
import { postMarketplace } from '@/service/base'
|
||||
import { fetchAppList } from '@/service/apps'
|
||||
import { fetchDatasets } from '@/service/datasets'
|
||||
|
||||
// Mock API functions
|
||||
jest.mock('@/service/base', () => ({
|
||||
postMarketplace: jest.fn(),
|
||||
}))
|
||||
|
||||
jest.mock('@/service/apps', () => ({
|
||||
fetchAppList: jest.fn(),
|
||||
}))
|
||||
|
||||
jest.mock('@/service/datasets', () => ({
|
||||
fetchDatasets: jest.fn(),
|
||||
}))
|
||||
|
||||
const mockPostMarketplace = postMarketplace as jest.MockedFunction<typeof postMarketplace>
|
||||
const mockFetchAppList = fetchAppList as jest.MockedFunction<typeof fetchAppList>
|
||||
const mockFetchDatasets = fetchDatasets as jest.MockedFunction<typeof fetchDatasets>
|
||||
|
||||
describe('GotoAnything Search Error Handling', () => {
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks()
|
||||
// Suppress console.warn for clean test output
|
||||
jest.spyOn(console, 'warn').mockImplementation(() => {
|
||||
// Suppress console.warn for clean test output
|
||||
})
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
jest.restoreAllMocks()
|
||||
})
|
||||
|
||||
describe('@plugin search error handling', () => {
|
||||
it('should return empty array when API fails instead of throwing error', async () => {
|
||||
// Mock marketplace API failure (403 permission denied)
|
||||
mockPostMarketplace.mockRejectedValue(new Error('HTTP 403: Forbidden'))
|
||||
|
||||
const pluginAction = Actions.plugin
|
||||
|
||||
// Directly call plugin action's search method
|
||||
const result = await pluginAction.search('@plugin', 'test', 'en')
|
||||
|
||||
// Should return empty array instead of throwing error
|
||||
expect(result).toEqual([])
|
||||
expect(mockPostMarketplace).toHaveBeenCalledWith('/plugins/search/advanced', {
|
||||
body: {
|
||||
page: 1,
|
||||
page_size: 10,
|
||||
query: 'test',
|
||||
type: 'plugin',
|
||||
},
|
||||
})
|
||||
})
|
||||
|
||||
it('should return empty array when user has no plugin data', async () => {
|
||||
// Mock marketplace returning empty data
|
||||
mockPostMarketplace.mockResolvedValue({
|
||||
data: { plugins: [] },
|
||||
})
|
||||
|
||||
const pluginAction = Actions.plugin
|
||||
const result = await pluginAction.search('@plugin', '', 'en')
|
||||
|
||||
expect(result).toEqual([])
|
||||
})
|
||||
|
||||
it('should return empty array when API returns unexpected data structure', async () => {
|
||||
// Mock API returning unexpected data structure
|
||||
mockPostMarketplace.mockResolvedValue({
|
||||
data: null,
|
||||
})
|
||||
|
||||
const pluginAction = Actions.plugin
|
||||
const result = await pluginAction.search('@plugin', 'test', 'en')
|
||||
|
||||
expect(result).toEqual([])
|
||||
})
|
||||
})
|
||||
|
||||
describe('Other search types error handling', () => {
|
||||
it('@app search should return empty array when API fails', async () => {
|
||||
// Mock app API failure
|
||||
mockFetchAppList.mockRejectedValue(new Error('API Error'))
|
||||
|
||||
const appAction = Actions.app
|
||||
const result = await appAction.search('@app', 'test', 'en')
|
||||
|
||||
expect(result).toEqual([])
|
||||
})
|
||||
|
||||
it('@knowledge search should return empty array when API fails', async () => {
|
||||
// Mock knowledge API failure
|
||||
mockFetchDatasets.mockRejectedValue(new Error('API Error'))
|
||||
|
||||
const knowledgeAction = Actions.knowledge
|
||||
const result = await knowledgeAction.search('@knowledge', 'test', 'en')
|
||||
|
||||
expect(result).toEqual([])
|
||||
})
|
||||
})
|
||||
|
||||
describe('Unified search entry error handling', () => {
|
||||
it('regular search (without @prefix) should return successful results even when partial APIs fail', async () => {
|
||||
// Set app and knowledge success, plugin failure
|
||||
mockFetchAppList.mockResolvedValue({ data: [], has_more: false, limit: 10, page: 1, total: 0 })
|
||||
mockFetchDatasets.mockResolvedValue({ data: [], has_more: false, limit: 10, page: 1, total: 0 })
|
||||
mockPostMarketplace.mockRejectedValue(new Error('Plugin API failed'))
|
||||
|
||||
const result = await searchAnything('en', 'test')
|
||||
|
||||
// Should return successful results even if plugin search fails
|
||||
expect(result).toEqual([])
|
||||
expect(console.warn).toHaveBeenCalledWith('Plugin search failed:', expect.any(Error))
|
||||
})
|
||||
|
||||
it('@plugin dedicated search should return empty array when API fails', async () => {
|
||||
// Mock plugin API failure
|
||||
mockPostMarketplace.mockRejectedValue(new Error('Plugin service unavailable'))
|
||||
|
||||
const pluginAction = Actions.plugin
|
||||
const result = await searchAnything('en', '@plugin test', pluginAction)
|
||||
|
||||
// Should return empty array instead of throwing error
|
||||
expect(result).toEqual([])
|
||||
})
|
||||
|
||||
it('@app dedicated search should return empty array when API fails', async () => {
|
||||
// Mock app API failure
|
||||
mockFetchAppList.mockRejectedValue(new Error('App service unavailable'))
|
||||
|
||||
const appAction = Actions.app
|
||||
const result = await searchAnything('en', '@app test', appAction)
|
||||
|
||||
expect(result).toEqual([])
|
||||
})
|
||||
})
|
||||
|
||||
describe('Error handling consistency validation', () => {
|
||||
it('all search types should return empty array when encountering errors', async () => {
|
||||
// Mock all APIs to fail
|
||||
mockPostMarketplace.mockRejectedValue(new Error('Plugin API failed'))
|
||||
mockFetchAppList.mockRejectedValue(new Error('App API failed'))
|
||||
mockFetchDatasets.mockRejectedValue(new Error('Dataset API failed'))
|
||||
|
||||
const actions = [
|
||||
{ name: '@plugin', action: Actions.plugin },
|
||||
{ name: '@app', action: Actions.app },
|
||||
{ name: '@knowledge', action: Actions.knowledge },
|
||||
]
|
||||
|
||||
for (const { name, action } of actions) {
|
||||
const result = await action.search(name, 'test', 'en')
|
||||
expect(result).toEqual([])
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
describe('Edge case testing', () => {
|
||||
it('empty search term should be handled properly', async () => {
|
||||
mockPostMarketplace.mockResolvedValue({ data: { plugins: [] } })
|
||||
|
||||
const result = await searchAnything('en', '@plugin ', Actions.plugin)
|
||||
expect(result).toEqual([])
|
||||
})
|
||||
|
||||
it('network timeout should be handled correctly', async () => {
|
||||
const timeoutError = new Error('Network timeout')
|
||||
timeoutError.name = 'TimeoutError'
|
||||
|
||||
mockPostMarketplace.mockRejectedValue(timeoutError)
|
||||
|
||||
const result = await searchAnything('en', '@plugin test', Actions.plugin)
|
||||
expect(result).toEqual([])
|
||||
})
|
||||
|
||||
it('JSON parsing errors should be handled correctly', async () => {
|
||||
const parseError = new SyntaxError('Unexpected token in JSON')
|
||||
mockPostMarketplace.mockRejectedValue(parseError)
|
||||
|
||||
const result = await searchAnything('en', '@plugin test', Actions.plugin)
|
||||
expect(result).toEqual([])
|
||||
})
|
||||
})
|
||||
})
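The contract these tests validate is that each action's search catches its own API failure and resolves to an empty list, so a failing backend never rejects the aggregate search. A condensed sketch of that pattern; fetchSomething is a stand-in for fetchAppList, fetchDatasets, or postMarketplace, not a real service function, and the types import path is assumed:

```ts
import type { SearchResult } from '@/app/components/goto-anything/actions/types'

// Hypothetical fetcher standing in for the real service calls.
declare function fetchSomething(query: string): Promise<SearchResult[]>

export const searchSafely = async (query: string): Promise<SearchResult[]> => {
  try {
    return await fetchSomething(query)
  }
  catch (error) {
    // Degrade to "no results" instead of surfacing a search failure in the UI.
    console.warn('Search failed:', error)
    return []
  }
}
```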
@ -8,6 +8,7 @@ import Header from '@/app/components/header'
import { EventEmitterContextProvider } from '@/context/event-emitter'
import { ProviderContextProvider } from '@/context/provider-context'
import { ModalContextProvider } from '@/context/modal-context'
import GotoAnything from '@/app/components/goto-anything'

const Layout = ({ children }: { children: ReactNode }) => {
return (
@ -22,6 +23,7 @@ const Layout = ({ children }: { children: ReactNode }) => {
<Header />
</HeaderWrapper>
{children}
<GotoAnything />
</ModalContextProvider>
</ProviderContextProvider>
</EventEmitterContextProvider>
@ -32,7 +32,7 @@ export type InputProps = {
unit?: string
} & Omit<React.InputHTMLAttributes<HTMLInputElement>, 'size'> & VariantProps<typeof inputVariants>

const Input = ({
const Input = React.forwardRef<HTMLInputElement, InputProps>(({
size,
disabled,
destructive,
@ -47,12 +47,13 @@ const Input = ({
onChange = noop,
unit,
...props
}: InputProps) => {
}, ref) => {
const { t } = useTranslation()
return (
<div className={cn('relative w-full', wrapperClassName)}>
{showLeftIcon && <RiSearchLine className={cn('absolute left-2 top-1/2 h-4 w-4 -translate-y-1/2 text-components-input-text-placeholder')} />}
<input
ref={ref}
style={styleCss}
className={cn(
'w-full appearance-none border border-transparent bg-components-input-bg-normal py-[7px] text-components-input-text-filled caret-primary-600 outline-none placeholder:text-components-input-text-placeholder hover:border-components-input-border-hover hover:bg-components-input-bg-hover focus:border-components-input-border-active focus:bg-components-input-bg-active focus:shadow-xs',
@ -92,6 +93,8 @@ const Input = ({
}
</div>
)
}
})

Input.displayName = 'Input'

export default Input
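Since Input is now wrapped in React.forwardRef, callers can reach the underlying input element. A minimal consumer sketch; the SearchBox component and the import path are illustrative assumptions, only the ref forwarding itself comes from this change:

```tsx
import React, { useRef } from 'react'
import Input from '@/app/components/base/input'

// Hypothetical consumer that focuses the input imperatively via the forwarded ref.
const SearchBox = () => {
  const inputRef = useRef<HTMLInputElement>(null)

  return (
    <div>
      <Input ref={inputRef} placeholder='Search...' />
      <button onClick={() => inputRef.current?.focus()}>Focus</button>
    </div>
  )
}

export default SearchBox
```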
@ -15,6 +15,7 @@ type IModal = {
children?: React.ReactNode
closable?: boolean
overflowVisible?: boolean
highPriority?: boolean // For modals that need to appear above dropdowns
}

export default function Modal({
@ -27,10 +28,11 @@ export default function Modal({
children,
closable = false,
overflowVisible = false,
highPriority = false,
}: IModal) {
return (
<Transition appear show={isShow} as={Fragment}>
<Dialog as="div" className={classNames('relative z-[60]', wrapperClassName)} onClose={onClose}>
<Dialog as="div" className={classNames('relative', highPriority ? 'z-[1100]' : 'z-[60]', wrapperClassName)} onClose={onClose}>
<TransitionChild>
<div className={classNames(
'fixed inset-0 bg-background-overlay',
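A usage sketch for the new prop; the ConfirmDelete component and its copy are made up for illustration, and the import path is assumed — only highPriority and the existing isShow/onClose/closable props come from this file:

```tsx
import React, { useState } from 'react'
import Modal from '@/app/components/base/modal'

// Hypothetical caller: a dialog opened from inside a dropdown, so it opts into
// the higher z-index layer that highPriority switches on.
const ConfirmDelete = () => {
  const [show, setShow] = useState(true)

  return (
    <Modal isShow={show} onClose={() => setShow(false)} closable highPriority>
      Are you sure you want to delete this item?
    </Modal>
  )
}

export default ConfirmDelete
```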
@ -192,6 +192,7 @@ const SimpleSelect: FC<ISelectProps> = ({
|
|||
const localPlaceholder = placeholder || t('common.placeholder.select')
|
||||
|
||||
const [selectedItem, setSelectedItem] = useState<Item | null>(null)
|
||||
const [open, setOpen] = useState(false)
|
||||
|
||||
useEffect(() => {
|
||||
let defaultSelect = null
|
||||
|
|
@ -220,8 +221,11 @@ const SimpleSelect: FC<ISelectProps> = ({
|
|||
<ListboxButton onClick={() => {
|
||||
// get data-open, use setTimeout to ensure the attribute is set
|
||||
setTimeout(() => {
|
||||
if (listboxRef.current)
|
||||
onOpenChange?.(listboxRef.current.getAttribute('data-open') !== null)
|
||||
if (listboxRef.current) {
|
||||
const isOpen = listboxRef.current.getAttribute('data-open') !== null
|
||||
setOpen(isOpen)
|
||||
onOpenChange?.(isOpen)
|
||||
}
|
||||
})
|
||||
}} className={classNames(`flex h-full w-full items-center rounded-lg border-0 bg-components-input-bg-normal pl-3 pr-10 focus-visible:bg-state-base-hover-alt focus-visible:outline-none group-hover/simple-select:bg-state-base-hover-alt sm:text-sm sm:leading-6 ${disabled ? 'cursor-not-allowed' : 'cursor-pointer'}`, className)}>
|
||||
<span className={classNames('system-sm-regular block truncate text-left text-components-input-text-filled', !selectedItem?.name && 'text-components-input-text-placeholder')}>{selectedItem?.name ?? localPlaceholder}</span>
|
||||
|
|
@ -240,10 +244,17 @@ const SimpleSelect: FC<ISelectProps> = ({
|
|||
/>
|
||||
)
|
||||
: (
|
||||
<ChevronDownIcon
|
||||
className="h-4 w-4 text-text-quaternary group-hover/simple-select:text-text-secondary"
|
||||
aria-hidden="true"
|
||||
/>
|
||||
open ? (
|
||||
<ChevronUpIcon
|
||||
className="h-4 w-4 text-text-quaternary group-hover/simple-select:text-text-secondary"
|
||||
aria-hidden="true"
|
||||
/>
|
||||
) : (
|
||||
<ChevronDownIcon
|
||||
className="h-4 w-4 text-text-quaternary group-hover/simple-select:text-text-secondary"
|
||||
aria-hidden="true"
|
||||
/>
|
||||
)
|
||||
)}
|
||||
</span>
|
||||
</ListboxButton>
|
||||
|
|
|
|||
|
|
@ -79,7 +79,7 @@ const TagFilter: FC<TagFilterProps> = ({
|
|||
className='block'
|
||||
>
|
||||
<div className={cn(
|
||||
'flex h-8 cursor-pointer items-center gap-1 rounded-lg border-[0.5px] border-transparent bg-components-input-bg-normal px-2',
|
||||
'flex h-8 cursor-pointer select-none items-center gap-1 rounded-lg border-[0.5px] border-transparent bg-components-input-bg-normal px-2',
|
||||
!open && !!value.length && 'shadow-xs',
|
||||
open && !!value.length && 'shadow-xs',
|
||||
)}>
|
||||
|
|
@ -123,7 +123,7 @@ const TagFilter: FC<TagFilterProps> = ({
|
|||
{filteredTagList.map(tag => (
|
||||
<div
|
||||
key={tag.id}
|
||||
className='flex cursor-pointer items-center gap-2 rounded-lg py-[6px] pl-3 pr-2 hover:bg-state-base-hover'
|
||||
className='flex cursor-pointer select-none items-center gap-2 rounded-lg py-[6px] pl-3 pr-2 hover:bg-state-base-hover'
|
||||
onClick={() => selectTag(tag)}
|
||||
>
|
||||
<div title={tag.name} className='grow truncate text-sm leading-5 text-text-tertiary'>{tag.name}</div>
|
||||
|
|
@ -139,7 +139,7 @@ const TagFilter: FC<TagFilterProps> = ({
|
|||
</div>
|
||||
<div className='border-t-[0.5px] border-divider-regular' />
|
||||
<div className='p-1'>
|
||||
<div className='flex cursor-pointer items-center gap-2 rounded-lg py-[6px] pl-3 pr-2 hover:bg-state-base-hover' onClick={() => {
|
||||
<div className='flex cursor-pointer select-none items-center gap-2 rounded-lg py-[6px] pl-3 pr-2 hover:bg-state-base-hover' onClick={() => {
|
||||
setShowTagManagementModal(true)
|
||||
setOpen(false)
|
||||
}}>
|
||||
|
|
|
|||
|
|
@ -69,9 +69,11 @@ const RenameDatasetModal = ({ show, dataset, onSuccess, onClose }: RenameDataset
isShow={show}
onClose={noop}
>
<div className='relative pb-2 text-xl font-medium leading-[30px] text-text-primary'>{t('datasetSettings.title')}</div>
<div className='absolute right-4 top-4 cursor-pointer p-2' onClick={onClose}>
<RiCloseLine className='h-4 w-4 text-text-tertiary' />
<div className='flex items-center justify-between pb-2'>
<div className='text-xl font-medium leading-[30px] text-text-primary'>{t('datasetSettings.title')}</div>
<div className='cursor-pointer p-2' onClick={onClose}>
<RiCloseLine className='h-4 w-4 text-text-tertiary' />
</div>
</div>
<div>
<div className={cn('flex flex-wrap items-center justify-between py-4')}>

@ -0,0 +1,58 @@
import type { ActionItem, AppSearchResult } from './types'
import type { App } from '@/types/app'
import { fetchAppList } from '@/service/apps'
import AppIcon from '../../base/app-icon'
import { AppTypeIcon } from '../../app/type-selector'
import { getRedirectionPath } from '@/utils/app-redirection'

const parser = (apps: App[]): AppSearchResult[] => {
  return apps.map(app => ({
    id: app.id,
    title: app.name,
    description: app.description,
    type: 'app' as const,
    path: getRedirectionPath(true, {
      id: app.id,
      mode: app.mode,
    }),
    icon: (
      <div className='relative shrink-0'>
        <AppIcon
          size='large'
          iconType={app.icon_type}
          icon={app.icon}
          background={app.icon_background}
          imageUrl={app.icon_url}
        />
        <AppTypeIcon wrapperClassName='absolute -bottom-0.5 -right-0.5 w-4 h-4 rounded-[4px] border border-divider-regular outline outline-components-panel-on-panel-item-bg'
          className='h-3 w-3' type={app.mode} />
      </div>
    ),
    data: app,
  }))
}

export const appAction: ActionItem = {
  key: '@app',
  shortcut: '@app',
  title: 'Search Applications',
  description: 'Search and navigate to your applications',
  // action,
  search: async (_, searchTerm = '', _locale) => {
    try {
      const response = await fetchAppList({
        url: 'apps',
        params: {
          page: 1,
          name: searchTerm,
        },
      })
      const apps = response?.data || []
      return parser(apps)
    }
    catch (error) {
      console.warn('App search failed:', error)
      return []
    }
  },
}

@ -0,0 +1,26 @@
export type CommandHandler = (args?: Record<string, any>) => void | Promise<void>

const handlers = new Map<string, CommandHandler>()

export const registerCommand = (name: string, handler: CommandHandler) => {
  handlers.set(name, handler)
}

export const unregisterCommand = (name: string) => {
  handlers.delete(name)
}

export const executeCommand = async (name: string, args?: Record<string, any>) => {
  const handler = handlers.get(name)
  if (!handler)
    return
  await handler(args)
}

export const registerCommands = (map: Record<string, CommandHandler>) => {
  Object.entries(map).forEach(([name, handler]) => registerCommand(name, handler))
}

export const unregisterCommands = (names: string[]) => {
  names.forEach(unregisterCommand)
}

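For orientation, here is a minimal usage sketch of the command bus added above (illustrative only, not part of the changeset; the 'theme.set' name mirrors the handlers registered later by the @run action, and the import path assumes a caller inside the same actions folder):

import { executeCommand, registerCommand, unregisterCommand } from './command-bus'

// Register a handler for a named command.
registerCommand('theme.set', async (args) => {
  // In the real UI this would call next-themes' setTheme; a log stands in here.
  console.log('switching theme to', args?.value)
})

// Dispatch the command from anywhere; unknown names are silently ignored.
await executeCommand('theme.set', { value: 'dark' })
await executeCommand('does.not.exist') // no-op

// Clean up when the owning component unmounts.
unregisterCommand('theme.set')
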
@ -0,0 +1,76 @@
import { appAction } from './app'
import { knowledgeAction } from './knowledge'
import { pluginAction } from './plugin'
import { workflowNodesAction } from './workflow-nodes'
import type { ActionItem, SearchResult } from './types'
import { commandAction } from './run'

export const Actions = {
  app: appAction,
  knowledge: knowledgeAction,
  plugin: pluginAction,
  run: commandAction,
  node: workflowNodesAction,
}

export const searchAnything = async (
  locale: string,
  query: string,
  actionItem?: ActionItem,
): Promise<SearchResult[]> => {
  if (actionItem) {
    const searchTerm = query.replace(actionItem.key, '').replace(actionItem.shortcut, '').trim()
    try {
      return await actionItem.search(query, searchTerm, locale)
    }
    catch (error) {
      console.warn(`Search failed for ${actionItem.key}:`, error)
      return []
    }
  }

  if (query.startsWith('@'))
    return []

  // Use Promise.allSettled to handle partial failures gracefully
  const searchPromises = Object.values(Actions).map(async (action) => {
    try {
      const results = await action.search(query, query, locale)
      return { success: true, data: results, actionType: action.key }
    }
    catch (error) {
      console.warn(`Search failed for ${action.key}:`, error)
      return { success: false, data: [], actionType: action.key, error }
    }
  })

  const settledResults = await Promise.allSettled(searchPromises)

  const allResults: SearchResult[] = []
  const failedActions: string[] = []

  settledResults.forEach((result, index) => {
    if (result.status === 'fulfilled' && result.value.success) {
      allResults.push(...result.value.data)
    }
    else {
      const actionKey = Object.values(Actions)[index]?.key || 'unknown'
      failedActions.push(actionKey)
    }
  })

  if (failedActions.length > 0)
    console.warn(`Some search actions failed: ${failedActions.join(', ')}`)

  return allResults
}

export const matchAction = (query: string, actions: Record<string, ActionItem>) => {
  return Object.values(actions).find((action) => {
    const reg = new RegExp(`^(${action.key}|${action.shortcut})(?:\\s|$)`)
    return reg.test(query)
  })
}

export * from './types'
export { appAction, knowledgeAction, pluginAction, workflowNodesAction }

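To make the routing above concrete, a short illustrative sketch (not part of the changeset) of how matchAction and searchAnything cooperate; the locale string and queries are placeholders, and the './actions' import path mirrors how the palette component imports this index later in the diff:

import { Actions, matchAction, searchAnything } from './actions'

const query = '@kb customer docs'

// '@kb' is the shortcut of knowledgeAction, so matchAction resolves it...
const action = matchAction(query, Actions)

// ...and searchAnything strips the prefix and delegates to that action only.
const scoped = await searchAnything('en-US', query, action)

// Without a matched action (and no leading '@'), every registered action is
// queried in parallel and partial failures are tolerated.
const global = await searchAnything('en-US', 'invoice')
console.log(scoped.length, global.length)
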
@ -0,0 +1,56 @@
import type { ActionItem, KnowledgeSearchResult } from './types'
import type { DataSet } from '@/models/datasets'
import { fetchDatasets } from '@/service/datasets'
import { Folder } from '../../base/icons/src/vender/solid/files'
import cn from '@/utils/classnames'

const EXTERNAL_PROVIDER = 'external' as const
const isExternalProvider = (provider: string): boolean => provider === EXTERNAL_PROVIDER

const parser = (datasets: DataSet[]): KnowledgeSearchResult[] => {
  return datasets.map((dataset) => {
    const path = isExternalProvider(dataset.provider) ? `/datasets/${dataset.id}/hitTesting` : `/datasets/${dataset.id}/documents`
    return {
      id: dataset.id,
      title: dataset.name,
      description: dataset.description,
      type: 'knowledge' as const,
      path,
      icon: (
        <div className={cn(
          'flex shrink-0 items-center justify-center rounded-md border-[0.5px] border-[#E0EAFF] bg-[#F5F8FF] p-2.5',
          !dataset.embedding_available && 'opacity-50 hover:opacity-100',
        )}>
          <Folder className='h-5 w-5 text-[#444CE7]' />
        </div>
      ),
      data: dataset,
    }
  })
}

export const knowledgeAction: ActionItem = {
  key: '@knowledge',
  shortcut: '@kb',
  title: 'Search Knowledge Bases',
  description: 'Search and navigate to your knowledge bases',
  // action,
  search: async (_, searchTerm = '', _locale) => {
    try {
      const response = await fetchDatasets({
        url: '/datasets',
        params: {
          page: 1,
          limit: 10,
          keyword: searchTerm,
        },
      })
      const datasets = response?.data || []
      return parser(datasets)
    }
    catch (error) {
      console.warn('Knowledge search failed:', error)
      return []
    }
  },
}

@ -0,0 +1,53 @@
import type { ActionItem, PluginSearchResult } from './types'
import { renderI18nObject } from '@/i18n-config'
import Icon from '../../plugins/card/base/card-icon'
import { postMarketplace } from '@/service/base'
import type { Plugin, PluginsFromMarketplaceResponse } from '../../plugins/types'
import { getPluginIconInMarketplace } from '../../plugins/marketplace/utils'

const parser = (plugins: Plugin[], locale: string): PluginSearchResult[] => {
  return plugins.map((plugin) => {
    return {
      id: plugin.name,
      title: renderI18nObject(plugin.label, locale) || plugin.name,
      description: renderI18nObject(plugin.brief, locale) || '',
      type: 'plugin' as const,
      icon: <Icon src={plugin.icon} />,
      data: plugin,
    }
  })
}

export const pluginAction: ActionItem = {
  key: '@plugin',
  shortcut: '@plugin',
  title: 'Search Plugins',
  description: 'Search and navigate to your plugins',
  search: async (_, searchTerm = '', locale) => {
    try {
      const response = await postMarketplace<{ data: PluginsFromMarketplaceResponse }>('/plugins/search/advanced', {
        body: {
          page: 1,
          page_size: 10,
          query: searchTerm,
          type: 'plugin',
        },
      })

      if (!response?.data?.plugins) {
        console.warn('Plugin search: Unexpected response structure', response)
        return []
      }

      const list = response.data.plugins.map(plugin => ({
        ...plugin,
        icon: getPluginIconInMarketplace(plugin),
      }))
      return parser(list, locale!)
    }
    catch (error) {
      console.warn('Plugin search failed:', error)
      return []
    }
  },
}

@ -0,0 +1,33 @@
import type { CommandSearchResult } from './types'
import { languages } from '@/i18n-config/language'
import { RiTranslate } from '@remixicon/react'
import i18n from '@/i18n-config/i18next-config'

export const buildLanguageCommands = (query: string): CommandSearchResult[] => {
  const q = query.toLowerCase()
  const list = languages.filter(item => item.supported && (
    !q || item.name.toLowerCase().includes(q) || String(item.value).toLowerCase().includes(q)
  ))
  return list.map(item => ({
    id: `lang-${item.value}`,
    title: item.name,
    description: i18n.t('app.gotoAnything.actions.languageChangeDesc'),
    type: 'command' as const,
    data: { command: 'i18n.set', args: { locale: item.value } },
  }))
}

export const buildLanguageRootItem = (): CommandSearchResult => {
  return {
    id: 'category-language',
    title: i18n.t('app.gotoAnything.actions.languageCategoryTitle'),
    description: i18n.t('app.gotoAnything.actions.languageCategoryDesc'),
    type: 'command',
    icon: (
      <div className='flex h-6 w-6 items-center justify-center rounded-md border-[0.5px] border-divider-regular bg-components-panel-bg'>
        <RiTranslate className='h-4 w-4 text-text-tertiary' />
      </div>
    ),
    data: { command: 'nav.search', args: { query: '@run language ' } },
  }
}

@ -0,0 +1,61 @@
import type { CommandSearchResult } from './types'
import type { ReactNode } from 'react'
import { RiComputerLine, RiMoonLine, RiPaletteLine, RiSunLine } from '@remixicon/react'
import i18n from '@/i18n-config/i18next-config'

const THEME_ITEMS: { id: 'light' | 'dark' | 'system'; titleKey: string; descKey: string; icon: ReactNode }[] = [
  {
    id: 'system',
    titleKey: 'app.gotoAnything.actions.themeSystem',
    descKey: 'app.gotoAnything.actions.themeSystemDesc',
    icon: <RiComputerLine className='h-4 w-4 text-text-tertiary' />,
  },
  {
    id: 'light',
    titleKey: 'app.gotoAnything.actions.themeLight',
    descKey: 'app.gotoAnything.actions.themeLightDesc',
    icon: <RiSunLine className='h-4 w-4 text-text-tertiary' />,
  },
  {
    id: 'dark',
    titleKey: 'app.gotoAnything.actions.themeDark',
    descKey: 'app.gotoAnything.actions.themeDarkDesc',
    icon: <RiMoonLine className='h-4 w-4 text-text-tertiary' />,
  },
]

export const buildThemeCommands = (query: string, locale?: string): CommandSearchResult[] => {
  const q = query.toLowerCase()
  const list = THEME_ITEMS.filter(item =>
    !q
    || i18n.t(item.titleKey, { lng: locale }).toLowerCase().includes(q)
    || item.id.includes(q),
  )
  return list.map(item => ({
    id: item.id,
    title: i18n.t(item.titleKey, { lng: locale }),
    description: i18n.t(item.descKey, { lng: locale }),
    type: 'command' as const,
    icon: (
      <div className='flex h-6 w-6 items-center justify-center rounded-md border-[0.5px] border-divider-regular bg-components-panel-bg'>
        {item.icon}
      </div>
    ),
    data: { command: 'theme.set', args: { value: item.id } },
  }))
}

export const buildThemeRootItem = (): CommandSearchResult => {
  return {
    id: 'category-theme',
    title: i18n.t('app.gotoAnything.actions.themeCategoryTitle'),
    description: i18n.t('app.gotoAnything.actions.themeCategoryDesc'),
    type: 'command',
    icon: (
      <div className='flex h-6 w-6 items-center justify-center rounded-md border-[0.5px] border-divider-regular bg-components-panel-bg'>
        <RiPaletteLine className='h-4 w-4 text-text-tertiary' />
      </div>
    ),
    data: { command: 'nav.search', args: { query: '@run theme ' } },
  }
}

@ -0,0 +1,97 @@
'use client'
import { useEffect } from 'react'
import type { ActionItem, CommandSearchResult } from './types'
import { buildLanguageCommands, buildLanguageRootItem } from './run-language'
import { buildThemeCommands, buildThemeRootItem } from './run-theme'
import i18n from '@/i18n-config/i18next-config'
import { executeCommand, registerCommands, unregisterCommands } from './command-bus'
import { useTheme } from 'next-themes'
import { setLocaleOnClient } from '@/i18n-config'

const rootParser = (query: string): CommandSearchResult[] => {
  const q = query.toLowerCase()
  const items: CommandSearchResult[] = []
  if (!q || 'theme'.includes(q))
    items.push(buildThemeRootItem())
  if (!q || 'language'.includes(q) || 'lang'.includes(q))
    items.push(buildLanguageRootItem())
  return items
}

type RunContext = {
  setTheme?: (value: 'light' | 'dark' | 'system') => void
  setLocale?: (locale: string) => Promise<void>
  search?: (query: string) => void
}

export const commandAction: ActionItem = {
  key: '@run',
  shortcut: '@run',
  title: i18n.t('app.gotoAnything.actions.runTitle'),
  description: i18n.t('app.gotoAnything.actions.runDesc'),
  action: (result) => {
    if (result.type !== 'command') return
    const { command, args } = result.data
    if (command === 'theme.set') {
      executeCommand('theme.set', args)
      return
    }
    if (command === 'i18n.set') {
      executeCommand('i18n.set', args)
      return
    }
    if (command === 'nav.search')
      executeCommand('nav.search', args)
  },
  search: async (_, searchTerm = '') => {
    const q = searchTerm.trim()
    if (q.startsWith('theme'))
      return buildThemeCommands(q.replace(/^theme\s*/, ''), i18n.language)
    if (q.startsWith('language') || q.startsWith('lang'))
      return buildLanguageCommands(q.replace(/^(language|lang)\s*/, ''))

    // root categories
    return rootParser(q)
  },
}

// Register/unregister default handlers for @run commands with external dependencies.
export const registerRunCommands = (deps: {
  setTheme?: (value: 'light' | 'dark' | 'system') => void
  setLocale?: (locale: string) => Promise<void>
  search?: (query: string) => void
}) => {
  registerCommands({
    'theme.set': async (args) => {
      deps.setTheme?.(args?.value)
    },
    'i18n.set': async (args) => {
      const locale = args?.locale
      if (locale)
        await deps.setLocale?.(locale)
    },
    'nav.search': (args) => {
      const q = args?.query
      if (q)
        deps.search?.(q)
    },
  })
}

export const unregisterRunCommands = () => {
  unregisterCommands(['theme.set', 'i18n.set', 'nav.search'])
}

export const RunCommandProvider = ({ onNavSearch }: { onNavSearch?: (q: string) => void }) => {
  const theme = useTheme()
  useEffect(() => {
    registerRunCommands({
      setTheme: theme.setTheme,
      setLocale: setLocaleOnClient,
      search: onNavSearch,
    })
    return () => unregisterRunCommands()
  }, [theme.setTheme, onNavSearch])

  return null
}

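A brief illustrative sketch (not part of the changeset) of how these pieces fit together: once the default handlers are registered, a command result produced by the @run search is dispatched through the command bus. The console.log callbacks stand in for the real setTheme/setLocale/search dependencies, and the './actions/run' import path mirrors how the palette component imports this module later in the diff:

import { commandAction, registerRunCommands, unregisterRunCommands } from './actions/run'

registerRunCommands({
  setTheme: value => console.log('theme ->', value),
  setLocale: async locale => console.log('locale ->', locale),
  search: query => console.log('reopen palette with', query),
})

// A result like the ones returned by commandAction.search(..., 'theme dark')
// carries the command payload; confirming it dispatches through the bus.
commandAction.action?.({
  id: 'dark',
  title: 'Dark',
  type: 'command',
  data: { command: 'theme.set', args: { value: 'dark' } },
})

unregisterRunCommands()
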
@ -0,0 +1,58 @@
import type { ReactNode } from 'react'
import type { TypeWithI18N } from '../../base/form/types'
import type { App } from '@/types/app'
import type { Plugin } from '../../plugins/types'
import type { DataSet } from '@/models/datasets'
import type { CommonNodeType } from '../../workflow/types'

export type SearchResultType = 'app' | 'knowledge' | 'plugin' | 'workflow-node' | 'command'

export type BaseSearchResult<T = any> = {
  id: string
  title: string
  description?: string
  type: SearchResultType
  path?: string
  icon?: ReactNode
  data: T
}

export type AppSearchResult = {
  type: 'app'
} & BaseSearchResult<App>

export type PluginSearchResult = {
  type: 'plugin'
} & BaseSearchResult<Plugin>

export type KnowledgeSearchResult = {
  type: 'knowledge'
} & BaseSearchResult<DataSet>

export type WorkflowNodeSearchResult = {
  type: 'workflow-node'
  metadata?: {
    nodeId: string
    nodeData: CommonNodeType
  }
} & BaseSearchResult<CommonNodeType>

export type CommandSearchResult = {
  type: 'command'
} & BaseSearchResult<{ command: string; args?: Record<string, any> }>

export type SearchResult = AppSearchResult | PluginSearchResult | KnowledgeSearchResult | WorkflowNodeSearchResult | CommandSearchResult

export type ActionItem = {
  key: '@app' | '@knowledge' | '@plugin' | '@node' | '@run'
  shortcut: string
  title: string | TypeWithI18N
  description: string
  action?: (data: SearchResult) => void
  searchFn?: (searchTerm: string) => SearchResult[]
  search: (
    query: string,
    searchTerm: string,
    locale?: string,
  ) => (Promise<SearchResult[]> | SearchResult[])
}

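As a small illustration of how the SearchResult union is meant to be consumed (not part of the changeset), the 'type' field acts as the discriminant and narrows 'data' accordingly; the describe helper below is hypothetical, and the './actions/types' import path matches how other components in this diff import the module:

import type { SearchResult } from './actions/types'

const describe = (result: SearchResult): string => {
  switch (result.type) {
    case 'app':
      return `App "${result.data.name}" (${result.data.mode})` // data narrows to App
    case 'knowledge':
      return `Dataset "${result.data.name}"` // data narrows to DataSet
    case 'command':
      return `Command ${result.data.command}` // data narrows to { command, args? }
    default:
      return result.title
  }
}
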
@ -0,0 +1,24 @@
import type { ActionItem } from './types'

// Create the workflow nodes action
export const workflowNodesAction: ActionItem = {
  key: '@node',
  shortcut: '@node',
  title: 'Search Workflow Nodes',
  description: 'Find and jump to nodes in the current workflow by name or type',
  searchFn: undefined, // Will be set by useWorkflowSearch hook
  search: async (_, searchTerm = '', locale) => {
    try {
      // Use the searchFn if available (set by useWorkflowSearch hook)
      if (workflowNodesAction.searchFn)
        return workflowNodesAction.searchFn(searchTerm)

      // If not in workflow context, return empty array
      return []
    }
    catch (error) {
      console.warn('Workflow nodes search failed:', error)
      return []
    }
  },
}

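The searchFn indirection above is filled in elsewhere; the real logic lives in the use-workflow-search hook, which is not shown in this diff. A hypothetical sketch of the injection pattern (useInjectNodeSearch is an invented name for illustration, and the import paths are assumed):

import { useEffect } from 'react'
import { workflowNodesAction } from './actions/workflow-nodes'
import type { SearchResult } from './actions/types'

// A workflow page could assign searchFn while mounted and clear it on unmount,
// so '@node' only returns results when a workflow canvas is actually open.
export const useInjectNodeSearch = (searchNodes: (term: string) => SearchResult[]) => {
  useEffect(() => {
    workflowNodesAction.searchFn = searchNodes
    return () => {
      workflowNodesAction.searchFn = undefined
    }
  }, [searchNodes])
}
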
@ -0,0 +1,89 @@
import type { FC } from 'react'
import { useEffect } from 'react'
import { Command } from 'cmdk'
import { useTranslation } from 'react-i18next'
import type { ActionItem } from './actions/types'

type Props = {
  actions: Record<string, ActionItem>
  onCommandSelect: (commandKey: string) => void
  searchFilter?: string
  commandValue?: string
  onCommandValueChange?: (value: string) => void
}

const CommandSelector: FC<Props> = ({ actions, onCommandSelect, searchFilter, commandValue, onCommandValueChange }) => {
  const { t } = useTranslation()

  const filteredActions = Object.values(actions).filter((action) => {
    if (!searchFilter)
      return true
    const filterLower = searchFilter.toLowerCase()
    return action.shortcut.toLowerCase().includes(filterLower)
      || action.key.toLowerCase().includes(filterLower)
  })

  useEffect(() => {
    if (filteredActions.length > 0 && onCommandValueChange) {
      const currentValueExists = filteredActions.some(action => action.shortcut === commandValue)
      if (!currentValueExists)
        onCommandValueChange(filteredActions[0].shortcut)
    }
  }, [searchFilter, filteredActions.length])

  if (filteredActions.length === 0) {
    return (
      <div className="p-4">
        <div className="flex items-center justify-center py-8 text-center text-text-tertiary">
          <div>
            <div className="text-sm font-medium text-text-tertiary">
              {t('app.gotoAnything.noMatchingCommands')}
            </div>
            <div className="mt-1 text-xs text-text-quaternary">
              {t('app.gotoAnything.tryDifferentSearch')}
            </div>
          </div>
        </div>
      </div>
    )
  }

  return (
    <div className="p-4">
      <div className="mb-3 text-left text-sm font-medium text-text-secondary">
        {t('app.gotoAnything.selectSearchType')}
      </div>
      <Command.Group className="space-y-1">
        {filteredActions.map(action => (
          <Command.Item
            key={action.key}
            value={action.shortcut}
            className="flex cursor-pointer items-center rounded-md p-2.5 transition-all duration-150 hover:bg-state-base-hover-alt aria-[selected=true]:bg-state-base-hover"
            onSelect={() => onCommandSelect(action.shortcut)}
          >
            <span className="min-w-[4.5rem] text-left font-mono text-xs text-text-tertiary">
              {action.shortcut}
            </span>
            <span className="ml-3 text-sm text-text-secondary">
              {(() => {
                const keyMap: Record<string, string> = {
                  '@app': 'app.gotoAnything.actions.searchApplicationsDesc',
                  '@plugin': 'app.gotoAnything.actions.searchPluginsDesc',
                  '@knowledge': 'app.gotoAnything.actions.searchKnowledgeBasesDesc',
                  '@run': 'app.gotoAnything.actions.runDesc',
                  '@node': 'app.gotoAnything.actions.searchWorkflowNodesDesc',
                }
                return t(keyMap[action.key])
              })()}
            </span>
          </Command.Item>
        ))}
      </Command.Group>
    </div>
  )
}

export default CommandSelector

@ -0,0 +1,50 @@
'use client'

import type { ReactNode } from 'react'
import React, { createContext, useContext, useEffect, useState } from 'react'
import { usePathname } from 'next/navigation'

/**
 * Interface for the GotoAnything context
 */
type GotoAnythingContextType = {
  /**
   * Whether the current page is a workflow page
   */
  isWorkflowPage: boolean
}

// Create context with default values
const GotoAnythingContext = createContext<GotoAnythingContextType>({
  isWorkflowPage: false,
})

/**
 * Hook to use the GotoAnything context
 */
export const useGotoAnythingContext = () => useContext(GotoAnythingContext)

type GotoAnythingProviderProps = {
  children: ReactNode
}

/**
 * Provider component for GotoAnything context
 */
export const GotoAnythingProvider: React.FC<GotoAnythingProviderProps> = ({ children }) => {
  const [isWorkflowPage, setIsWorkflowPage] = useState(false)
  const pathname = usePathname()

  // Update context based on current pathname
  useEffect(() => {
    // Check if current path contains workflow
    const isWorkflow = pathname?.includes('/workflow') || false
    setIsWorkflowPage(isWorkflow)
  }, [pathname])

  return (
    <GotoAnythingContext.Provider value={{ isWorkflowPage }}>
      {children}
    </GotoAnythingContext.Provider>
  )
}

@ -0,0 +1,420 @@
'use client'

import type { FC } from 'react'
import { useCallback, useEffect, useMemo, useRef, useState } from 'react'
import { useRouter } from 'next/navigation'
import Modal from '@/app/components/base/modal'
import Input from '@/app/components/base/input'
import { useDebounce, useKeyPress } from 'ahooks'
import { getKeyboardKeyCodeBySystem, isEventTargetInputArea, isMac } from '@/app/components/workflow/utils/common'
import { selectWorkflowNode } from '@/app/components/workflow/utils/node-navigation'
import { RiSearchLine } from '@remixicon/react'
import { Actions as AllActions, type SearchResult, matchAction, searchAnything } from './actions'
import { GotoAnythingProvider, useGotoAnythingContext } from './context'
import { useQuery } from '@tanstack/react-query'
import { useGetLanguage } from '@/context/i18n'
import { useTranslation } from 'react-i18next'
import InstallFromMarketplace from '../plugins/install-plugin/install-from-marketplace'
import type { Plugin } from '../plugins/types'
import { Command } from 'cmdk'
import CommandSelector from './command-selector'
import { RunCommandProvider } from './actions/run'

type Props = {
  onHide?: () => void
}
const GotoAnything: FC<Props> = ({
  onHide,
}) => {
  const router = useRouter()
  const defaultLocale = useGetLanguage()
  const { isWorkflowPage } = useGotoAnythingContext()
  const { t } = useTranslation()
  const [show, setShow] = useState<boolean>(false)
  const [searchQuery, setSearchQuery] = useState<string>('')
  const [cmdVal, setCmdVal] = useState<string>('')
  const inputRef = useRef<HTMLInputElement>(null)
  const handleNavSearch = useCallback((q: string) => {
    setShow(true)
    setSearchQuery(q)
    requestAnimationFrame(() => inputRef.current?.focus())
  }, [])
  // Filter actions based on context
  const Actions = useMemo(() => {
    // Create a filtered copy of actions based on current page context
    if (isWorkflowPage) {
      // Include all actions on workflow pages
      return AllActions
    }
    else {
      // Exclude node action on non-workflow pages
      const { app, knowledge, plugin, run } = AllActions
      return { app, knowledge, plugin, run }
    }
  }, [isWorkflowPage])

  const [activePlugin, setActivePlugin] = useState<Plugin>()

  // Handle keyboard shortcuts
  const handleToggleModal = useCallback((e: KeyboardEvent) => {
    // Allow closing when modal is open, even if focus is in the search input
    if (!show && isEventTargetInputArea(e.target as HTMLElement))
      return
    e.preventDefault()
    setShow((prev) => {
      if (!prev) {
        // Opening modal - reset search state
        setSearchQuery('')
      }
      return !prev
    })
  }, [show])

  useKeyPress(`${getKeyboardKeyCodeBySystem('ctrl')}.k`, handleToggleModal, {
    exactMatch: true,
    useCapture: true,
  })

  useKeyPress(['esc'], (e) => {
    if (show) {
      e.preventDefault()
      setShow(false)
      setSearchQuery('')
    }
  })

  const searchQueryDebouncedValue = useDebounce(searchQuery.trim(), {
    wait: 300,
  })

  const isCommandsMode = searchQuery.trim() === '@' || (searchQuery.trim().startsWith('@') && !matchAction(searchQuery.trim(), Actions))

  const searchMode = useMemo(() => {
    if (isCommandsMode) return 'commands'

    const query = searchQueryDebouncedValue.toLowerCase()
    const action = matchAction(query, Actions)
    return action ? action.key : 'general'
  }, [searchQueryDebouncedValue, Actions, isCommandsMode])

  const { data: searchResults = [], isLoading, isError, error } = useQuery(
    {
      queryKey: [
        'goto-anything',
        'search-result',
        searchQueryDebouncedValue,
        searchMode,
        isWorkflowPage,
        defaultLocale,
        Object.keys(Actions).sort().join(','),
      ],
      queryFn: async () => {
        const query = searchQueryDebouncedValue.toLowerCase()
        const action = matchAction(query, Actions)
        return await searchAnything(defaultLocale, query, action)
      },
      enabled: !!searchQueryDebouncedValue && !isCommandsMode,
      staleTime: 30000,
      gcTime: 300000,
    },
  )

  const handleCommandSelect = useCallback((commandKey: string) => {
    setSearchQuery(`${commandKey} `)
    setCmdVal('')
    setTimeout(() => {
      inputRef.current?.focus()
    }, 0)
  }, [])

  // Handle navigation to selected result
  const handleNavigate = useCallback((result: SearchResult) => {
    setShow(false)
    setSearchQuery('')

    switch (result.type) {
      case 'command': {
        const action = Object.values(Actions).find(a => a.key === '@run')
        action?.action?.(result)
        break
      }
      case 'plugin':
        setActivePlugin(result.data)
        break
      case 'workflow-node':
        // Handle workflow node selection and navigation
        if (result.metadata?.nodeId)
          selectWorkflowNode(result.metadata.nodeId, true)

        break
      default:
        if (result.path)
          router.push(result.path)
    }
  }, [router])

  // Group results by type
  const groupedResults = useMemo(() => searchResults.reduce((acc, result) => {
    if (!acc[result.type])
      acc[result.type] = []

    acc[result.type].push(result)
    return acc
  }, {} as { [key: string]: SearchResult[] }),
  [searchResults])

  const emptyResult = useMemo(() => {
    if (searchResults.length || !searchQuery.trim() || isLoading || isCommandsMode)
      return null

    const isCommandSearch = searchMode !== 'general'
    const commandType = isCommandSearch ? searchMode.replace('@', '') : ''

    if (isError) {
      return (
        <div className="flex items-center justify-center py-8 text-center text-text-tertiary">
          <div>
            <div className='text-sm font-medium text-red-500'>{t('app.gotoAnything.searchTemporarilyUnavailable')}</div>
            <div className='mt-1 text-xs text-text-quaternary'>
              {t('app.gotoAnything.servicesUnavailableMessage')}
            </div>
          </div>
        </div>
      )
    }

    return (
      <div className="flex items-center justify-center py-8 text-center text-text-tertiary">
        <div>
          <div className='text-sm font-medium'>
            {isCommandSearch
              ? (() => {
                const keyMap: Record<string, string> = {
                  app: 'app.gotoAnything.emptyState.noAppsFound',
                  plugin: 'app.gotoAnything.emptyState.noPluginsFound',
                  knowledge: 'app.gotoAnything.emptyState.noKnowledgeBasesFound',
                  node: 'app.gotoAnything.emptyState.noWorkflowNodesFound',
                }
                return t(keyMap[commandType] || 'app.gotoAnything.noResults')
              })()
              : t('app.gotoAnything.noResults')
            }
          </div>
          <div className='mt-1 text-xs text-text-quaternary'>
            {isCommandSearch
              ? t('app.gotoAnything.emptyState.tryDifferentTerm', { mode: searchMode })
              : t('app.gotoAnything.emptyState.trySpecificSearch', { shortcuts: Object.values(Actions).map(action => action.shortcut).join(', ') })
            }
          </div>
        </div>
      </div>
    )
  }, [searchResults, searchQuery, Actions, searchMode, isLoading, isError, isCommandsMode])

  const defaultUI = useMemo(() => {
    if (searchQuery.trim())
      return null

    return (<div className="flex items-center justify-center py-12 text-center text-text-tertiary">
      <div>
        <div className='text-sm font-medium'>{t('app.gotoAnything.searchTitle')}</div>
        <div className='mt-3 space-y-1 text-xs text-text-quaternary'>
          <div>{t('app.gotoAnything.searchHint')}</div>
          <div>{t('app.gotoAnything.commandHint')}</div>
        </div>
      </div>
    </div>)
  }, [searchQuery, Actions])

  useEffect(() => {
    if (show) {
      requestAnimationFrame(() => {
        inputRef.current?.focus()
      })
    }
    return () => {
      setCmdVal('')
    }
  }, [show])

  return (
    <>
      <Modal
        isShow={show}
        onClose={() => {
          setShow(false)
          setSearchQuery('')
          onHide?.()
        }}
        closable={false}
        className='!w-[480px] !p-0'
        highPriority={true}
      >
        <div className='flex flex-col rounded-2xl border border-components-panel-border bg-components-panel-bg shadow-xl'>
          <Command
            className='outline-none'
            value={cmdVal}
            onValueChange={setCmdVal}
            disablePointerSelection
          >
            <div className='flex items-center gap-3 border-b border-divider-subtle bg-components-panel-bg-blur px-4 py-3'>
              <RiSearchLine className='h-4 w-4 text-text-quaternary' />
              <div className='flex flex-1 items-center gap-2'>
                <Input
                  ref={inputRef}
                  value={searchQuery}
                  placeholder={t('app.gotoAnything.searchPlaceholder')}
                  onChange={(e) => {
                    setSearchQuery(e.target.value)
                    if (!e.target.value.startsWith('@'))
                      setCmdVal('')
                  }}
                  className='flex-1 !border-0 !bg-transparent !shadow-none'
                  wrapperClassName='flex-1 !border-0 !bg-transparent'
                  autoFocus
                />
                {searchMode !== 'general' && (
                  <div className='flex items-center gap-1 rounded bg-blue-50 px-2 py-[2px] text-xs font-medium text-blue-600 dark:bg-blue-900/40 dark:text-blue-300'>
                    <span>{searchMode.replace('@', '').toUpperCase()}</span>
                  </div>
                )}
              </div>
              <div className='text-xs text-text-quaternary'>
                <span className='system-kbd rounded bg-gray-200 px-1 py-[2px] font-mono text-gray-700 dark:bg-gray-800 dark:text-gray-100'>
                  {isMac() ? '⌘' : 'Ctrl'}
                </span>
                <span className='system-kbd ml-1 rounded bg-gray-200 px-1 py-[2px] font-mono text-gray-700 dark:bg-gray-800 dark:text-gray-100'>
                  K
                </span>
              </div>
            </div>

            <Command.List className='max-h-[275px] min-h-[240px] overflow-y-auto'>
              {isLoading && (
                <div className="flex items-center justify-center py-8 text-center text-text-tertiary">
                  <div className="flex items-center gap-2">
                    <div className="h-4 w-4 animate-spin rounded-full border-2 border-gray-300 border-t-gray-600"></div>
                    <span className="text-sm">{t('app.gotoAnything.searching')}</span>
                  </div>
                </div>
              )}
              {isError && (
                <div className="flex items-center justify-center py-8 text-center text-text-tertiary">
                  <div>
                    <div className="text-sm font-medium text-red-500">{t('app.gotoAnything.searchFailed')}</div>
                    <div className="mt-1 text-xs text-text-quaternary">
                      {error.message}
                    </div>
                  </div>
                </div>
              )}
              {!isLoading && !isError && (
                <>
                  {isCommandsMode ? (
                    <CommandSelector
                      actions={Actions}
                      onCommandSelect={handleCommandSelect}
                      searchFilter={searchQuery.trim().substring(1)}
                      commandValue={cmdVal}
                      onCommandValueChange={setCmdVal}
                    />
                  ) : (
                    Object.entries(groupedResults).map(([type, results], groupIndex) => (
                      <Command.Group key={groupIndex} heading={(() => {
                        const typeMap: Record<string, string> = {
                          'app': 'app.gotoAnything.groups.apps',
                          'plugin': 'app.gotoAnything.groups.plugins',
                          'knowledge': 'app.gotoAnything.groups.knowledgeBases',
                          'workflow-node': 'app.gotoAnything.groups.workflowNodes',
                        }
                        return t(typeMap[type] || `${type}s`)
                      })()} className='p-2 capitalize text-text-secondary'>
                        {results.map(result => (
                          <Command.Item
                            key={`${result.type}-${result.id}`}
                            value={result.title}
                            className='flex cursor-pointer items-center gap-3 rounded-md p-3 will-change-[background-color] hover:bg-state-base-hover-alt aria-[selected=true]:bg-state-base-hover data-[selected=true]:bg-state-base-hover'
                            onSelect={() => handleNavigate(result)}
                          >
                            {result.icon}
                            <div className='min-w-0 flex-1'>
                              <div className='truncate font-medium text-text-secondary'>
                                {result.title}
                              </div>
                              {result.description && (
                                <div className='mt-0.5 truncate text-xs text-text-quaternary'>
                                  {result.description}
                                </div>
                              )}
                            </div>
                            <div className='text-xs capitalize text-text-quaternary'>
                              {result.type}
                            </div>
                          </Command.Item>
                        ))}
                      </Command.Group>
                    ))
                  )}
                  {!isCommandsMode && emptyResult}
                  {!isCommandsMode && defaultUI}
                </>
              )}
            </Command.List>

            {(!!searchResults.length || isError) && (
              <div className='border-t border-divider-subtle bg-components-panel-bg-blur px-4 py-2 text-xs text-text-tertiary'>
                <div className='flex items-center justify-between'>
                  <span>
                    {isError ? (
                      <span className='text-red-500'>{t('app.gotoAnything.someServicesUnavailable')}</span>
                    ) : (
                      <>
                        {t('app.gotoAnything.resultCount', { count: searchResults.length })}
                        {searchMode !== 'general' && (
                          <span className='ml-2 opacity-60'>
                            {t('app.gotoAnything.inScope', { scope: searchMode.replace('@', '') })}
                          </span>
                        )}
                      </>
                    )}
                  </span>
                  <span className='opacity-60'>
                    {searchMode !== 'general'
                      ? t('app.gotoAnything.clearToSearchAll')
                      : t('app.gotoAnything.useAtForSpecific')
                    }
                  </span>
                </div>
              </div>
            )}
          </Command>
        </div>

      </Modal>
      <RunCommandProvider onNavSearch={handleNavSearch} />
      {
        activePlugin && (
          <InstallFromMarketplace
            manifest={activePlugin}
            uniqueIdentifier={activePlugin.latest_package_identifier}
            onClose={() => setActivePlugin(undefined)}
            onSuccess={() => setActivePlugin(undefined)}
          />
        )
      }
    </>
  )
}

/**
 * GotoAnything component with context provider
 */
const GotoAnythingWithContext: FC<Props> = (props) => {
  return (
    <GotoAnythingProvider>
      <GotoAnything {...props} />
    </GotoAnythingProvider>
  )
}

export default GotoAnythingWithContext

@ -70,6 +70,7 @@ export default function LanguagePage() {
items={languages.filter(item => item.supported)}
onSelect={item => handleSelectLanguage(item)}
disabled={editing}
notClearable={true}
/>
</div>
<div className='mb-8'>

@ -79,6 +80,7 @@ export default function LanguagePage() {
items={timezones}
onSelect={item => handleSelectTimezone(item)}
disabled={editing}
notClearable={true}
/>
</div>
</>

@ -1,6 +1,7 @@
import type { FC } from 'react'
import { useEffect, useRef, useState } from 'react'
import type { ModelParameterRule } from '../declarations'
import { useLanguage } from '../hooks'
import { isNullOrUndefined } from '../utils'
import cn from '@/utils/classnames'
import Switch from '@/app/components/base/switch'

@ -26,6 +27,7 @@ const ParameterItem: FC<ParameterItemProps> = ({
onSwitch,
isInWorkflow,
}) => {
const language = useLanguage()
const [localValue, setLocalValue] = useState(value)
const numberInputRef = useRef<HTMLInputElement>(null)

@ -54,7 +54,7 @@ const TagsFilter = ({
onClick={() => setOpen(v => !v)}
>
<div className={cn(
'ml-0.5 mr-1.5 flex items-center text-text-tertiary ',
'ml-0.5 mr-1.5 flex select-none items-center text-text-tertiary',
size === 'large' && 'h-8 py-1',
size === 'small' && 'h-7 py-0.5 ',
// selectedTagsLength && 'text-text-secondary',

@ -80,7 +80,7 @@ const TagsFilter = ({
filteredOptions.map(option => (
<div
key={option.name}
className='flex h-7 cursor-pointer items-center rounded-lg px-2 py-1.5 hover:bg-state-base-hover'
className='flex h-7 cursor-pointer select-none items-center rounded-lg px-2 py-1.5 hover:bg-state-base-hover'
onClick={() => handleCheck(option.name)}
>
<Checkbox

@ -48,7 +48,7 @@ const TagsFilter = ({
>
<PortalToFollowElemTrigger onClick={() => setOpen(v => !v)}>
<div className={cn(
'flex h-8 cursor-pointer items-center rounded-lg bg-components-input-bg-normal px-2 py-1 text-text-tertiary hover:bg-state-base-hover-alt',
'flex h-8 cursor-pointer select-none items-center rounded-lg bg-components-input-bg-normal px-2 py-1 text-text-tertiary hover:bg-state-base-hover-alt',
selectedTagsLength && 'text-text-secondary',
open && 'bg-state-base-hover',
)}>

@ -99,7 +99,7 @@ const TagsFilter = ({
filteredOptions.map(option => (
<div
key={option.name}
className='flex h-7 cursor-pointer items-center rounded-lg px-2 py-1.5 hover:bg-state-base-hover'
className='flex h-7 cursor-pointer select-none items-center rounded-lg px-2 py-1.5 hover:bg-state-base-hover'
onClick={() => handleCheck(option.name)}
>
<Checkbox

@ -67,7 +67,7 @@ const LabelFilter: FC<LabelFilterProps> = ({
className='block'
>
<div className={cn(
'flex h-8 cursor-pointer items-center gap-1 rounded-lg border-[0.5px] border-transparent bg-components-input-bg-normal px-2 hover:bg-components-input-bg-hover',
'flex h-8 cursor-pointer select-none items-center gap-1 rounded-lg border-[0.5px] border-transparent bg-components-input-bg-normal px-2 hover:bg-components-input-bg-hover',
!open && !!value.length && 'shadow-xs',
open && !!value.length && 'shadow-xs',
)}>

@ -111,7 +111,7 @@ const LabelFilter: FC<LabelFilterProps> = ({
{filteredLabelList.map(label => (
<div
key={label.name}
className='flex cursor-pointer items-center gap-2 rounded-lg py-[6px] pl-3 pr-2 hover:bg-state-base-hover'
className='flex cursor-pointer select-none items-center gap-2 rounded-lg py-[6px] pl-3 pr-2 hover:bg-state-base-hover'
onClick={() => selectLabel(label)}
>
<div title={label.label} className='grow truncate text-sm leading-5 text-text-secondary'>{label.label}</div>

@ -27,6 +27,8 @@ export type DuplicateAppModalProps = {
icon: string
icon_background?: string | null
server_identifier: string
timeout: number
sse_read_timeout: number
}) => void
onHide: () => void
}

@ -64,6 +66,8 @@ const MCPModal = ({
const [appIcon, setAppIcon] = useState<AppIconSelection>(getIcon(data))
const [showAppIconPicker, setShowAppIconPicker] = useState(false)
const [serverIdentifier, setServerIdentifier] = React.useState(data?.server_identifier || '')
const [timeout, setMcpTimeout] = React.useState(30)
const [sseReadTimeout, setSseReadTimeout] = React.useState(300)
const [isFetchingIcon, setIsFetchingIcon] = useState(false)
const appIconRef = useRef<HTMLDivElement>(null)
const isHovering = useHover(appIconRef)

@ -73,7 +77,7 @@ const MCPModal = ({
const urlPattern = /^(https?:\/\/)((([a-z\d]([a-z\d-]*[a-z\d])*)\.)+[a-z]{2,}|((\d{1,3}\.){3}\d{1,3})|localhost)(\:\d+)?(\/[-a-z\d%_.~+]*)*(\?[;&a-z\d%_.~+=-]*)?/i
return urlPattern.test(string)
}
catch (e) {
catch {
return false
}
}

@ -123,6 +127,8 @@ const MCPModal = ({
icon: appIcon.type === 'emoji' ? appIcon.icon : appIcon.fileId,
icon_background: appIcon.type === 'emoji' ? appIcon.background : undefined,
server_identifier: serverIdentifier.trim(),
timeout: timeout || 30,
sse_read_timeout: sseReadTimeout || 300,
})
if(isCreate)
onHide()

@ -201,6 +207,30 @@ const MCPModal = ({
</div>
)}
</div>
<div>
<div className='mb-1 flex h-6 items-center'>
<span className='system-sm-medium text-text-secondary'>{t('tools.mcp.modal.timeout')}</span>
</div>
<Input
type='number'
value={timeout}
onChange={e => setMcpTimeout(Number(e.target.value))}
onBlur={e => handleBlur(e.target.value.trim())}
placeholder={t('tools.mcp.modal.timeoutPlaceholder')}
/>
</div>
<div>
<div className='mb-1 flex h-6 items-center'>
<span className='system-sm-medium text-text-secondary'>{t('tools.mcp.modal.sseReadTimeout')}</span>
</div>
<Input
type='number'
value={sseReadTimeout}
onChange={e => setSseReadTimeout(Number(e.target.value))}
onBlur={e => handleBlur(e.target.value.trim())}
placeholder={t('tools.mcp.modal.timeoutPlaceholder')}
/>
</div>
</div>
<div className='flex flex-row-reverse pt-5'>
<Button disabled={!name || !url || !serverIdentifier || isFetchingIcon} className='ml-2' variant='primary' onClick={submit}>{data ? t('tools.mcp.modal.save') : t('tools.mcp.modal.confirm')}</Button>

@ -57,6 +57,8 @@ export type Collection = {
server_url?: string
updated_at?: number
server_identifier?: string
timeout?: number
sse_read_timeout?: number
}

export type ToolParameter = {

@ -18,3 +18,4 @@ export * from './use-workflow-mode'
export * from './use-workflow-refresh-draft'
export * from './use-inspect-vars-crud'
export * from './use-set-workflow-vars-with-value'
export * from './use-workflow-search'

Some files were not shown because too many files have changed in this diff.