diff --git a/api/core/rag/extractor/firecrawl/firecrawl_app.py b/api/core/rag/extractor/firecrawl/firecrawl_app.py index c20ecd2b89..789ac8557d 100644 --- a/api/core/rag/extractor/firecrawl/firecrawl_app.py +++ b/api/core/rag/extractor/firecrawl/firecrawl_app.py @@ -25,7 +25,7 @@ class FirecrawlApp: } if params: json_data.update(params) - response = self._post_request(f"{self.base_url}/v1/scrape", json_data, headers) + response = self._post_request(f"{self.base_url}/v2/scrape", json_data, headers) if response.status_code == 200: response_data = response.json() data = response_data["data"] @@ -42,7 +42,7 @@ class FirecrawlApp: json_data = {"url": url} if params: json_data.update(params) - response = self._post_request(f"{self.base_url}/v1/crawl", json_data, headers) + response = self._post_request(f"{self.base_url}/v2/crawl", json_data, headers) if response.status_code == 200: # There's also another two fields in the response: "success" (bool) and "url" (str) job_id = response.json().get("id") @@ -51,9 +51,25 @@ class FirecrawlApp: self._handle_error(response, "start crawl job") return "" # unreachable + def map(self, url: str, params: dict[str, Any] | None = None) -> dict[str, Any]: + # Documentation: https://docs.firecrawl.dev/api-reference/endpoint/map + headers = self._prepare_headers() + json_data: dict[str, Any] = {"url": url, "integration": "dify"} + if params: + # Pass through provided params, including optional "sitemap": "only" | "include" | "skip" + json_data.update(params) + response = self._post_request(f"{self.base_url}/v2/map", json_data, headers) + if response.status_code == 200: + return cast(dict[str, Any], response.json()) + elif response.status_code in {402, 409, 500, 429, 408}: + self._handle_error(response, "start map job") + return {} + else: + raise Exception(f"Failed to start map job. Status code: {response.status_code}") + def check_crawl_status(self, job_id) -> dict[str, Any]: headers = self._prepare_headers() - response = self._get_request(f"{self.base_url}/v1/crawl/{job_id}", headers) + response = self._get_request(f"{self.base_url}/v2/crawl/{job_id}", headers) if response.status_code == 200: crawl_status_response = response.json() if crawl_status_response.get("status") == "completed": @@ -135,12 +151,16 @@ class FirecrawlApp: "lang": "en", "country": "us", "timeout": 60000, - "ignoreInvalidURLs": False, + "ignoreInvalidURLs": True, "scrapeOptions": {}, + "sources": [ + {"type": "web"}, + ], + "integration": "dify", } if params: json_data.update(params) - response = self._post_request(f"{self.base_url}/v1/search", json_data, headers) + response = self._post_request(f"{self.base_url}/v2/search", json_data, headers) if response.status_code == 200: response_data = response.json() if not response_data.get("success"): diff --git a/api/core/workflow/graph_engine/command_channels/redis_channel.py b/api/core/workflow/graph_engine/command_channels/redis_channel.py index c841459170..527647ae3b 100644 --- a/api/core/workflow/graph_engine/command_channels/redis_channel.py +++ b/api/core/workflow/graph_engine/command_channels/redis_channel.py @@ -41,6 +41,7 @@ class RedisChannel: self._redis = redis_client self._key = channel_key self._command_ttl = command_ttl + self._pending_key = f"{channel_key}:pending" def fetch_commands(self) -> list[GraphEngineCommand]: """ @@ -49,6 +50,9 @@ class RedisChannel: Returns: List of pending commands (drains the Redis list) """ + if not self._has_pending_commands(): + return [] + commands: list[GraphEngineCommand] = [] # Use pipeline for atomic operations @@ -85,6 +89,7 @@ class RedisChannel: with self._redis.pipeline() as pipe: pipe.rpush(self._key, command_json) pipe.expire(self._key, self._command_ttl) + pipe.set(self._pending_key, "1", ex=self._command_ttl) pipe.execute() def _deserialize_command(self, data: dict[str, Any]) -> GraphEngineCommand | None: @@ -112,3 +117,17 @@ class RedisChannel: except (ValueError, TypeError): return None + + def _has_pending_commands(self) -> bool: + """ + Check and consume the pending marker to avoid unnecessary list reads. + + Returns: + True if commands should be fetched from Redis. + """ + with self._redis.pipeline() as pipe: + pipe.get(self._pending_key) + pipe.delete(self._pending_key) + pending_value, _ = pipe.execute() + + return pending_value is not None diff --git a/api/core/workflow/graph_engine/event_management/event_handlers.py b/api/core/workflow/graph_engine/event_management/event_handlers.py index 7247b17967..1cb5851ab1 100644 --- a/api/core/workflow/graph_engine/event_management/event_handlers.py +++ b/api/core/workflow/graph_engine/event_management/event_handlers.py @@ -7,6 +7,7 @@ from collections.abc import Mapping from functools import singledispatchmethod from typing import TYPE_CHECKING, final +from core.model_runtime.entities.llm_entities import LLMUsage from core.workflow.entities import GraphRuntimeState from core.workflow.enums import ErrorStrategy, NodeExecutionType from core.workflow.graph import Graph @@ -125,6 +126,7 @@ class EventHandler: node_execution = self._graph_execution.get_or_create_node_execution(event.node_id) is_initial_attempt = node_execution.retry_count == 0 node_execution.mark_started(event.id) + self._graph_runtime_state.increment_node_run_steps() # Track in response coordinator for stream ordering self._response_coordinator.track_node_execution(event.node_id, event.id) @@ -163,6 +165,8 @@ class EventHandler: node_execution = self._graph_execution.get_or_create_node_execution(event.node_id) node_execution.mark_taken() + self._accumulate_node_usage(event.node_run_result.llm_usage) + # Store outputs in variable pool self._store_node_outputs(event.node_id, event.node_run_result.outputs) @@ -212,6 +216,8 @@ class EventHandler: node_execution.mark_failed(event.error) self._graph_execution.record_node_failure() + self._accumulate_node_usage(event.node_run_result.llm_usage) + result = self._error_handler.handle_node_failure(event) if result: @@ -235,6 +241,8 @@ class EventHandler: node_execution = self._graph_execution.get_or_create_node_execution(event.node_id) node_execution.mark_taken() + self._accumulate_node_usage(event.node_run_result.llm_usage) + # Persist outputs produced by the exception strategy (e.g. default values) self._store_node_outputs(event.node_id, event.node_run_result.outputs) @@ -286,6 +294,19 @@ class EventHandler: self._state_manager.enqueue_node(event.node_id) self._state_manager.start_execution(event.node_id) + def _accumulate_node_usage(self, usage: LLMUsage) -> None: + """Accumulate token usage into the shared runtime state.""" + if usage.total_tokens <= 0: + return + + self._graph_runtime_state.add_tokens(usage.total_tokens) + + current_usage = self._graph_runtime_state.llm_usage + if current_usage.total_tokens == 0: + self._graph_runtime_state.llm_usage = usage + else: + self._graph_runtime_state.llm_usage = current_usage.plus(usage) + def _store_node_outputs(self, node_id: str, outputs: Mapping[str, object]) -> None: """ Store node outputs in the variable pool. diff --git a/api/core/workflow/graph_engine/orchestration/dispatcher.py b/api/core/workflow/graph_engine/orchestration/dispatcher.py index a7229ce4e8..8340c10b49 100644 --- a/api/core/workflow/graph_engine/orchestration/dispatcher.py +++ b/api/core/workflow/graph_engine/orchestration/dispatcher.py @@ -8,7 +8,12 @@ import threading import time from typing import TYPE_CHECKING, final -from core.workflow.graph_events.base import GraphNodeEventBase +from core.workflow.graph_events import ( + GraphNodeEventBase, + NodeRunExceptionEvent, + NodeRunFailedEvent, + NodeRunSucceededEvent, +) from ..event_management import EventManager from .execution_coordinator import ExecutionCoordinator @@ -72,13 +77,16 @@ class Dispatcher: if self._thread and self._thread.is_alive(): self._thread.join(timeout=10.0) + _COMMAND_TRIGGER_EVENTS = ( + NodeRunSucceededEvent, + NodeRunFailedEvent, + NodeRunExceptionEvent, + ) + def _dispatcher_loop(self) -> None: """Main dispatcher loop.""" try: while not self._stop_event.is_set(): - # Check for commands - self._execution_coordinator.check_commands() - # Check for scaling self._execution_coordinator.check_scaling() @@ -87,6 +95,8 @@ class Dispatcher: event = self._event_queue.get(timeout=0.1) # Route to the event handler self._event_handler.dispatch(event) + if self._should_check_commands(event): + self._execution_coordinator.check_commands() self._event_queue.task_done() except queue.Empty: # Check if execution is complete @@ -102,3 +112,7 @@ class Dispatcher: # Signal the event emitter that execution is complete if self._event_emitter: self._event_emitter.mark_complete() + + def _should_check_commands(self, event: GraphNodeEventBase) -> bool: + """Return True if the event represents a node completion.""" + return isinstance(event, self._COMMAND_TRIGGER_EVENTS) diff --git a/api/services/website_service.py b/api/services/website_service.py index 37588d6ba5..a23f01ec71 100644 --- a/api/services/website_service.py +++ b/api/services/website_service.py @@ -23,6 +23,7 @@ class CrawlOptions: only_main_content: bool = False includes: str | None = None excludes: str | None = None + prompt: str | None = None max_depth: int | None = None use_sitemap: bool = True @@ -70,6 +71,7 @@ class WebsiteCrawlApiRequest: only_main_content=self.options.get("only_main_content", False), includes=self.options.get("includes"), excludes=self.options.get("excludes"), + prompt=self.options.get("prompt"), max_depth=self.options.get("max_depth"), use_sitemap=self.options.get("use_sitemap", True), ) @@ -174,6 +176,7 @@ class WebsiteService: def _crawl_with_firecrawl(cls, request: CrawlRequest, api_key: str, config: dict) -> dict[str, Any]: firecrawl_app = FirecrawlApp(api_key=api_key, base_url=config.get("base_url")) + params: dict[str, Any] if not request.options.crawl_sub_pages: params = { "includePaths": [], @@ -188,8 +191,10 @@ class WebsiteService: "limit": request.options.limit, "scrapeOptions": {"onlyMainContent": request.options.only_main_content}, } - if request.options.max_depth: - params["maxDepth"] = request.options.max_depth + + # Add optional prompt for Firecrawl v2 crawl-params compatibility + if request.options.prompt: + params["prompt"] = request.options.prompt job_id = firecrawl_app.crawl_url(request.url, params) website_crawl_time_cache_key = f"website_crawl_{job_id}" diff --git a/api/tests/unit_tests/core/workflow/graph_engine/command_channels/test_redis_channel.py b/api/tests/unit_tests/core/workflow/graph_engine/command_channels/test_redis_channel.py index 7ebccf83a7..8677325d4e 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/command_channels/test_redis_channel.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/command_channels/test_redis_channel.py @@ -35,11 +35,15 @@ class TestRedisChannel: """Test sending a command to Redis.""" mock_redis = MagicMock() mock_pipe = MagicMock() - mock_redis.pipeline.return_value.__enter__ = MagicMock(return_value=mock_pipe) - mock_redis.pipeline.return_value.__exit__ = MagicMock(return_value=None) + context = MagicMock() + context.__enter__.return_value = mock_pipe + context.__exit__.return_value = None + mock_redis.pipeline.return_value = context channel = RedisChannel(mock_redis, "test:key", 3600) + pending_key = "test:key:pending" + # Create a test command command = GraphEngineCommand(command_type=CommandType.ABORT) @@ -55,6 +59,7 @@ class TestRedisChannel: # Verify expire was set mock_pipe.expire.assert_called_once_with("test:key", 3600) + mock_pipe.set.assert_called_once_with(pending_key, "1", ex=3600) # Verify execute was called mock_pipe.execute.assert_called_once() @@ -62,33 +67,48 @@ class TestRedisChannel: def test_fetch_commands_empty(self): """Test fetching commands when Redis list is empty.""" mock_redis = MagicMock() - mock_pipe = MagicMock() - mock_redis.pipeline.return_value.__enter__ = MagicMock(return_value=mock_pipe) - mock_redis.pipeline.return_value.__exit__ = MagicMock(return_value=None) + pending_pipe = MagicMock() + fetch_pipe = MagicMock() + pending_context = MagicMock() + fetch_context = MagicMock() + pending_context.__enter__.return_value = pending_pipe + pending_context.__exit__.return_value = None + fetch_context.__enter__.return_value = fetch_pipe + fetch_context.__exit__.return_value = None + mock_redis.pipeline.side_effect = [pending_context] - # Simulate empty list - mock_pipe.execute.return_value = [[], 1] # Empty list, delete successful + # No pending marker + pending_pipe.execute.return_value = [None, 0] + mock_redis.llen.return_value = 0 channel = RedisChannel(mock_redis, "test:key") commands = channel.fetch_commands() assert commands == [] - mock_pipe.lrange.assert_called_once_with("test:key", 0, -1) - mock_pipe.delete.assert_called_once_with("test:key") + mock_redis.pipeline.assert_called_once() + fetch_pipe.lrange.assert_not_called() + fetch_pipe.delete.assert_not_called() def test_fetch_commands_with_abort_command(self): """Test fetching abort commands from Redis.""" mock_redis = MagicMock() - mock_pipe = MagicMock() - mock_redis.pipeline.return_value.__enter__ = MagicMock(return_value=mock_pipe) - mock_redis.pipeline.return_value.__exit__ = MagicMock(return_value=None) + pending_pipe = MagicMock() + fetch_pipe = MagicMock() + pending_context = MagicMock() + fetch_context = MagicMock() + pending_context.__enter__.return_value = pending_pipe + pending_context.__exit__.return_value = None + fetch_context.__enter__.return_value = fetch_pipe + fetch_context.__exit__.return_value = None + mock_redis.pipeline.side_effect = [pending_context, fetch_context] # Create abort command data abort_command = AbortCommand() command_json = json.dumps(abort_command.model_dump()) # Simulate Redis returning one command - mock_pipe.execute.return_value = [[command_json.encode()], 1] + pending_pipe.execute.return_value = [b"1", 1] + fetch_pipe.execute.return_value = [[command_json.encode()], 1] channel = RedisChannel(mock_redis, "test:key") commands = channel.fetch_commands() @@ -100,9 +120,15 @@ class TestRedisChannel: def test_fetch_commands_multiple(self): """Test fetching multiple commands from Redis.""" mock_redis = MagicMock() - mock_pipe = MagicMock() - mock_redis.pipeline.return_value.__enter__ = MagicMock(return_value=mock_pipe) - mock_redis.pipeline.return_value.__exit__ = MagicMock(return_value=None) + pending_pipe = MagicMock() + fetch_pipe = MagicMock() + pending_context = MagicMock() + fetch_context = MagicMock() + pending_context.__enter__.return_value = pending_pipe + pending_context.__exit__.return_value = None + fetch_context.__enter__.return_value = fetch_pipe + fetch_context.__exit__.return_value = None + mock_redis.pipeline.side_effect = [pending_context, fetch_context] # Create multiple commands command1 = GraphEngineCommand(command_type=CommandType.ABORT) @@ -112,7 +138,8 @@ class TestRedisChannel: command2_json = json.dumps(command2.model_dump()) # Simulate Redis returning multiple commands - mock_pipe.execute.return_value = [[command1_json.encode(), command2_json.encode()], 1] + pending_pipe.execute.return_value = [b"1", 1] + fetch_pipe.execute.return_value = [[command1_json.encode(), command2_json.encode()], 1] channel = RedisChannel(mock_redis, "test:key") commands = channel.fetch_commands() @@ -124,9 +151,15 @@ class TestRedisChannel: def test_fetch_commands_skips_invalid_json(self): """Test that invalid JSON commands are skipped.""" mock_redis = MagicMock() - mock_pipe = MagicMock() - mock_redis.pipeline.return_value.__enter__ = MagicMock(return_value=mock_pipe) - mock_redis.pipeline.return_value.__exit__ = MagicMock(return_value=None) + pending_pipe = MagicMock() + fetch_pipe = MagicMock() + pending_context = MagicMock() + fetch_context = MagicMock() + pending_context.__enter__.return_value = pending_pipe + pending_context.__exit__.return_value = None + fetch_context.__enter__.return_value = fetch_pipe + fetch_context.__exit__.return_value = None + mock_redis.pipeline.side_effect = [pending_context, fetch_context] # Mix valid and invalid JSON valid_command = AbortCommand() @@ -134,7 +167,8 @@ class TestRedisChannel: invalid_json = b"invalid json {" # Simulate Redis returning mixed valid/invalid commands - mock_pipe.execute.return_value = [[invalid_json, valid_json.encode()], 1] + pending_pipe.execute.return_value = [b"1", 1] + fetch_pipe.execute.return_value = [[invalid_json, valid_json.encode()], 1] channel = RedisChannel(mock_redis, "test:key") commands = channel.fetch_commands() @@ -187,13 +221,20 @@ class TestRedisChannel: def test_atomic_fetch_and_clear(self): """Test that fetch_commands atomically fetches and clears the list.""" mock_redis = MagicMock() - mock_pipe = MagicMock() - mock_redis.pipeline.return_value.__enter__ = MagicMock(return_value=mock_pipe) - mock_redis.pipeline.return_value.__exit__ = MagicMock(return_value=None) + pending_pipe = MagicMock() + fetch_pipe = MagicMock() + pending_context = MagicMock() + fetch_context = MagicMock() + pending_context.__enter__.return_value = pending_pipe + pending_context.__exit__.return_value = None + fetch_context.__enter__.return_value = fetch_pipe + fetch_context.__exit__.return_value = None + mock_redis.pipeline.side_effect = [pending_context, fetch_context] command = AbortCommand() command_json = json.dumps(command.model_dump()) - mock_pipe.execute.return_value = [[command_json.encode()], 1] + pending_pipe.execute.return_value = [b"1", 1] + fetch_pipe.execute.return_value = [[command_json.encode()], 1] channel = RedisChannel(mock_redis, "test:key") @@ -202,7 +243,29 @@ class TestRedisChannel: assert len(commands) == 1 # Verify both lrange and delete were called in the pipeline - assert mock_pipe.lrange.call_count == 1 - assert mock_pipe.delete.call_count == 1 - mock_pipe.lrange.assert_called_with("test:key", 0, -1) - mock_pipe.delete.assert_called_with("test:key") + assert fetch_pipe.lrange.call_count == 1 + assert fetch_pipe.delete.call_count == 1 + fetch_pipe.lrange.assert_called_with("test:key", 0, -1) + fetch_pipe.delete.assert_called_with("test:key") + + def test_fetch_commands_without_pending_marker_returns_empty(self): + """Ensure we avoid unnecessary list reads when pending flag is missing.""" + mock_redis = MagicMock() + pending_pipe = MagicMock() + fetch_pipe = MagicMock() + pending_context = MagicMock() + fetch_context = MagicMock() + pending_context.__enter__.return_value = pending_pipe + pending_context.__exit__.return_value = None + fetch_context.__enter__.return_value = fetch_pipe + fetch_context.__exit__.return_value = None + mock_redis.pipeline.side_effect = [pending_context, fetch_context] + + # Pending flag absent + pending_pipe.execute.return_value = [None, 0] + channel = RedisChannel(mock_redis, "test:key") + commands = channel.fetch_commands() + + assert commands == [] + mock_redis.llen.assert_not_called() + assert mock_redis.pipeline.call_count == 1 diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_dispatcher.py b/api/tests/unit_tests/core/workflow/graph_engine/test_dispatcher.py new file mode 100644 index 0000000000..830fc0884d --- /dev/null +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_dispatcher.py @@ -0,0 +1,104 @@ +"""Tests for dispatcher command checking behavior.""" + +from __future__ import annotations + +import queue +from datetime import datetime + +from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus +from core.workflow.graph_engine.event_management.event_manager import EventManager +from core.workflow.graph_engine.orchestration.dispatcher import Dispatcher +from core.workflow.graph_events import NodeRunStartedEvent, NodeRunSucceededEvent +from core.workflow.node_events import NodeRunResult + + +class _StubExecutionCoordinator: + """Stub execution coordinator that tracks command checks.""" + + def __init__(self) -> None: + self.command_checks = 0 + self.scaling_checks = 0 + self._execution_complete = False + self.mark_complete_called = False + self.failed = False + + def check_commands(self) -> None: + self.command_checks += 1 + + def check_scaling(self) -> None: + self.scaling_checks += 1 + + def is_execution_complete(self) -> bool: + return self._execution_complete + + def mark_complete(self) -> None: + self.mark_complete_called = True + + def mark_failed(self, error: Exception) -> None: # pragma: no cover - defensive, not triggered in tests + self.failed = True + + def set_execution_complete(self) -> None: + self._execution_complete = True + + +class _StubEventHandler: + """Minimal event handler that marks execution complete after handling an event.""" + + def __init__(self, coordinator: _StubExecutionCoordinator) -> None: + self._coordinator = coordinator + self.events = [] + + def dispatch(self, event) -> None: + self.events.append(event) + self._coordinator.set_execution_complete() + + +def _run_dispatcher_for_event(event) -> int: + """Run the dispatcher loop for a single event and return command check count.""" + event_queue: queue.Queue = queue.Queue() + event_queue.put(event) + + coordinator = _StubExecutionCoordinator() + event_handler = _StubEventHandler(coordinator) + event_manager = EventManager() + + dispatcher = Dispatcher( + event_queue=event_queue, + event_handler=event_handler, + event_collector=event_manager, + execution_coordinator=coordinator, + ) + + dispatcher._dispatcher_loop() + + return coordinator.command_checks + + +def _make_started_event() -> NodeRunStartedEvent: + return NodeRunStartedEvent( + id="start-event", + node_id="node-1", + node_type=NodeType.CODE, + node_title="Test Node", + start_at=datetime.utcnow(), + ) + + +def _make_succeeded_event() -> NodeRunSucceededEvent: + return NodeRunSucceededEvent( + id="success-event", + node_id="node-1", + node_type=NodeType.CODE, + node_title="Test Node", + start_at=datetime.utcnow(), + node_run_result=NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED), + ) + + +def test_dispatcher_checks_commands_after_node_completion() -> None: + """Dispatcher should only check commands after node completion events.""" + started_checks = _run_dispatcher_for_event(_make_started_event()) + succeeded_checks = _run_dispatcher_for_event(_make_succeeded_event()) + + assert started_checks == 0 + assert succeeded_checks == 1 diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_redis_stop_integration.py b/api/tests/unit_tests/core/workflow/graph_engine/test_redis_stop_integration.py index bd41fdeee5..e191246bed 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_redis_stop_integration.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_redis_stop_integration.py @@ -132,15 +132,22 @@ class TestRedisStopIntegration: """Test RedisChannel correctly fetches and deserializes commands.""" # Setup mock_redis = MagicMock() - mock_pipeline = MagicMock() - mock_redis.pipeline.return_value.__enter__ = Mock(return_value=mock_pipeline) - mock_redis.pipeline.return_value.__exit__ = Mock(return_value=None) + pending_pipe = MagicMock() + fetch_pipe = MagicMock() + pending_context = MagicMock() + fetch_context = MagicMock() + pending_context.__enter__.return_value = pending_pipe + pending_context.__exit__.return_value = None + fetch_context.__enter__.return_value = fetch_pipe + fetch_context.__exit__.return_value = None + mock_redis.pipeline.side_effect = [pending_context, fetch_context] # Mock command data abort_command_json = json.dumps({"command_type": CommandType.ABORT, "reason": "Test abort", "payload": None}) # Mock pipeline execute to return commands - mock_pipeline.execute.return_value = [ + pending_pipe.execute.return_value = [b"1", 1] + fetch_pipe.execute.return_value = [ [abort_command_json.encode()], # lrange result True, # delete result ] @@ -158,19 +165,29 @@ class TestRedisStopIntegration: assert commands[0].reason == "Test abort" # Verify Redis operations - mock_pipeline.lrange.assert_called_once_with(channel_key, 0, -1) - mock_pipeline.delete.assert_called_once_with(channel_key) + pending_pipe.get.assert_called_once_with(f"{channel_key}:pending") + pending_pipe.delete.assert_called_once_with(f"{channel_key}:pending") + fetch_pipe.lrange.assert_called_once_with(channel_key, 0, -1) + fetch_pipe.delete.assert_called_once_with(channel_key) + assert mock_redis.pipeline.call_count == 2 def test_redis_channel_fetch_commands_handles_invalid_json(self): """Test RedisChannel gracefully handles invalid JSON in commands.""" # Setup mock_redis = MagicMock() - mock_pipeline = MagicMock() - mock_redis.pipeline.return_value.__enter__ = Mock(return_value=mock_pipeline) - mock_redis.pipeline.return_value.__exit__ = Mock(return_value=None) + pending_pipe = MagicMock() + fetch_pipe = MagicMock() + pending_context = MagicMock() + fetch_context = MagicMock() + pending_context.__enter__.return_value = pending_pipe + pending_context.__exit__.return_value = None + fetch_context.__enter__.return_value = fetch_pipe + fetch_context.__exit__.return_value = None + mock_redis.pipeline.side_effect = [pending_context, fetch_context] # Mock invalid command data - mock_pipeline.execute.return_value = [ + pending_pipe.execute.return_value = [b"1", 1] + fetch_pipe.execute.return_value = [ [b"invalid json", b'{"command_type": "invalid_type"}'], # lrange result True, # delete result ] diff --git a/web/app/components/app-sidebar/dataset-info/menu.tsx b/web/app/components/app-sidebar/dataset-info/menu.tsx index fd560ce643..6f91c9c513 100644 --- a/web/app/components/app-sidebar/dataset-info/menu.tsx +++ b/web/app/components/app-sidebar/dataset-info/menu.tsx @@ -3,6 +3,7 @@ import { useTranslation } from 'react-i18next' import MenuItem from './menu-item' import { RiDeleteBinLine, RiEditLine, RiFileDownloadLine } from '@remixicon/react' import Divider from '../../base/divider' +import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail' type MenuProps = { showDelete: boolean @@ -18,6 +19,7 @@ const Menu = ({ detectIsUsedByApp, }: MenuProps) => { const { t } = useTranslation() + const runtimeMode = useDatasetDetailContextWithSelector(state => state.dataset?.runtime_mode) return (
@@ -27,11 +29,13 @@ const Menu = ({ name={t('common.operation.edit')} handleClick={openRenameModal} /> - + {runtimeMode === 'rag_pipeline' && ( + + )}
{showDelete && ( <> diff --git a/web/app/components/app/configuration/base/icons/citation.tsx b/web/app/components/app/configuration/base/icons/citation.tsx deleted file mode 100644 index 3aa6b0f0e1..0000000000 --- a/web/app/components/app/configuration/base/icons/citation.tsx +++ /dev/null @@ -1,29 +0,0 @@ -import type { SVGProps } from 'react' - -const CitationIcon = (props: SVGProps) => ( - -) - -export default CitationIcon diff --git a/web/app/components/app/configuration/base/icons/more-like-this-icon.tsx b/web/app/components/app/configuration/base/icons/more-like-this-icon.tsx deleted file mode 100644 index 74c808eb39..0000000000 --- a/web/app/components/app/configuration/base/icons/more-like-this-icon.tsx +++ /dev/null @@ -1,14 +0,0 @@ -'use client' -import type { FC } from 'react' -import React from 'react' - -const MoreLikeThisIcon: FC = () => { - return ( - - - - - - ) -} -export default React.memo(MoreLikeThisIcon) diff --git a/web/app/components/app/configuration/base/icons/suggested-questions-after-answer-icon.tsx b/web/app/components/app/configuration/base/icons/suggested-questions-after-answer-icon.tsx deleted file mode 100644 index cabc2e4d73..0000000000 --- a/web/app/components/app/configuration/base/icons/suggested-questions-after-answer-icon.tsx +++ /dev/null @@ -1,12 +0,0 @@ -'use client' -import type { FC } from 'react' -import React from 'react' - -const SuggestedQuestionsAfterAnswerIcon: FC = () => { - return ( - - - - ) -} -export default React.memo(SuggestedQuestionsAfterAnswerIcon) diff --git a/web/app/components/app/configuration/config-var/index.tsx b/web/app/components/app/configuration/config-var/index.tsx index 2ac68227e3..b9227c6846 100644 --- a/web/app/components/app/configuration/config-var/index.tsx +++ b/web/app/components/app/configuration/config-var/index.tsx @@ -1,10 +1,11 @@ 'use client' import type { FC } from 'react' -import React, { useState } from 'react' +import React, { useMemo, useState } from 'react' import { useTranslation } from 'react-i18next' import { useBoolean } from 'ahooks' import { useContext } from 'use-context-selector' import produce from 'immer' +import { ReactSortable } from 'react-sortablejs' import Panel from '../base/feature-panel' import EditModal from './config-modal' import VarItem from './var-item' @@ -22,6 +23,7 @@ import { useModalContext } from '@/context/modal-context' import { useEventEmitterContextContext } from '@/context/event-emitter' import type { InputVar } from '@/app/components/workflow/types' import { InputVarType } from '@/app/components/workflow/types' +import cn from '@/utils/classnames' export const ADD_EXTERNAL_DATA_TOOL = 'ADD_EXTERNAL_DATA_TOOL' @@ -218,6 +220,16 @@ const ConfigVar: FC = ({ promptVariables, readonly, onPromptVar showEditModal() } + + const promptVariablesWithIds = useMemo(() => promptVariables.map((item) => { + return { + id: item.key, + variable: { ...item }, + } + }), [promptVariables]) + + const canDrag = !readonly && promptVariables.length > 1 + return ( = ({ promptVariables, readonly, onPromptVar )} {hasVar && (
- {promptVariables.map(({ key, name, type, required, config, icon, icon_background }, index) => ( - handleConfig({ type, key, index, name, config, icon, icon_background })} - onRemove={() => handleRemoveVar(index)} - /> - ))} + { onPromptVariablesChange?.(list.map(item => item.variable)) }} + handle='.handle' + ghostClass='opacity-50' + animation={150} + > + {promptVariablesWithIds.map((item, index) => { + const { key, name, type, required, config, icon, icon_background } = item.variable + return ( + handleConfig({ type, key, index, name, config, icon, icon_background })} + onRemove={() => handleRemoveVar(index)} + canDrag={canDrag} + /> + ) + })} +
)} diff --git a/web/app/components/app/configuration/config-var/var-item.tsx b/web/app/components/app/configuration/config-var/var-item.tsx index 78ed4b1031..88cd5d7843 100644 --- a/web/app/components/app/configuration/config-var/var-item.tsx +++ b/web/app/components/app/configuration/config-var/var-item.tsx @@ -3,6 +3,7 @@ import type { FC } from 'react' import React, { useState } from 'react' import { RiDeleteBinLine, + RiDraggable, RiEditLine, } from '@remixicon/react' import type { IInputTypeIconProps } from './input-type-icon' @@ -12,6 +13,7 @@ import Badge from '@/app/components/base/badge' import cn from '@/utils/classnames' type ItemProps = { + className?: string readonly?: boolean name: string label: string @@ -19,9 +21,11 @@ type ItemProps = { type: string onEdit: () => void onRemove: () => void + canDrag?: boolean } const VarItem: FC = ({ + className, readonly, name, label, @@ -29,12 +33,16 @@ const VarItem: FC = ({ type, onEdit, onRemove, + canDrag, }) => { const [isDeleting, setIsDeleting] = useState(false) return ( -
- +
+ + {canDrag && ( + + )}
{name} diff --git a/web/app/components/app/configuration/config/feature/use-feature.tsx b/web/app/components/app/configuration/config/feature/use-feature.tsx deleted file mode 100644 index acc08dd4a4..0000000000 --- a/web/app/components/app/configuration/config/feature/use-feature.tsx +++ /dev/null @@ -1,96 +0,0 @@ -import React, { useEffect } from 'react' - -function useFeature({ - introduction, - setIntroduction, - moreLikeThis, - setMoreLikeThis, - suggestedQuestionsAfterAnswer, - setSuggestedQuestionsAfterAnswer, - speechToText, - setSpeechToText, - textToSpeech, - setTextToSpeech, - citation, - setCitation, - annotation, - setAnnotation, - moderation, - setModeration, -}: { - introduction: string - setIntroduction: (introduction: string) => void - moreLikeThis: boolean - setMoreLikeThis: (moreLikeThis: boolean) => void - suggestedQuestionsAfterAnswer: boolean - setSuggestedQuestionsAfterAnswer: (suggestedQuestionsAfterAnswer: boolean) => void - speechToText: boolean - setSpeechToText: (speechToText: boolean) => void - textToSpeech: boolean - setTextToSpeech: (textToSpeech: boolean) => void - citation: boolean - setCitation: (citation: boolean) => void - annotation: boolean - setAnnotation: (annotation: boolean) => void - moderation: boolean - setModeration: (moderation: boolean) => void -}) { - const [tempShowOpeningStatement, setTempShowOpeningStatement] = React.useState(!!introduction) - useEffect(() => { - // wait to api data back - if (introduction) - setTempShowOpeningStatement(true) - }, [introduction]) - - // const [tempMoreLikeThis, setTempMoreLikeThis] = React.useState(moreLikeThis) - // useEffect(() => { - // setTempMoreLikeThis(moreLikeThis) - // }, [moreLikeThis]) - - const featureConfig = { - openingStatement: tempShowOpeningStatement, - moreLikeThis, - suggestedQuestionsAfterAnswer, - speechToText, - textToSpeech, - citation, - annotation, - moderation, - } - const handleFeatureChange = (key: string, value: boolean) => { - switch (key) { - case 'openingStatement': - if (!value) - setIntroduction('') - - setTempShowOpeningStatement(value) - break - case 'moreLikeThis': - setMoreLikeThis(value) - break - case 'suggestedQuestionsAfterAnswer': - setSuggestedQuestionsAfterAnswer(value) - break - case 'speechToText': - setSpeechToText(value) - break - case 'textToSpeech': - setTextToSpeech(value) - break - case 'citation': - setCitation(value) - break - case 'annotation': - setAnnotation(value) - break - case 'moderation': - setModeration(value) - } - } - return { - featureConfig, - handleFeatureChange, - } -} - -export default useFeature diff --git a/web/app/components/app/configuration/prompt-mode/advanced-mode-waring.tsx b/web/app/components/app/configuration/prompt-mode/advanced-mode-waring.tsx deleted file mode 100644 index f207cddd16..0000000000 --- a/web/app/components/app/configuration/prompt-mode/advanced-mode-waring.tsx +++ /dev/null @@ -1,50 +0,0 @@ -'use client' -import type { FC } from 'react' -import React from 'react' -import { useTranslation } from 'react-i18next' -import { useDocLink } from '@/context/i18n' -type Props = { - onReturnToSimpleMode: () => void -} - -const AdvancedModeWarning: FC = ({ - onReturnToSimpleMode, -}) => { - const { t } = useTranslation() - const docLink = useDocLink() - const [show, setShow] = React.useState(true) - if (!show) - return null - return ( -
-
{t('appDebug.promptMode.advancedWarning.title')}
-
-
- {t('appDebug.promptMode.advancedWarning.description')} - - {t('appDebug.promptMode.advancedWarning.learnMore')} - -
- -
-
-
{t('appDebug.promptMode.switchBack')}
-
-
setShow(false)} - >{t('appDebug.promptMode.advancedWarning.ok')}
-
- -
-
- ) -} -export default React.memo(AdvancedModeWarning) diff --git a/web/app/components/base/auto-height-textarea/common.tsx b/web/app/components/base/auto-height-textarea/common.tsx deleted file mode 100644 index eb0275cfcd..0000000000 --- a/web/app/components/base/auto-height-textarea/common.tsx +++ /dev/null @@ -1,54 +0,0 @@ -import { useEffect, useRef } from 'react' -import cn from '@/utils/classnames' - -type AutoHeightTextareaProps - = & React.DetailedHTMLProps, HTMLTextAreaElement> - & { outerClassName?: string } - -const AutoHeightTextarea = ( - { - ref: outRef, - outerClassName, - value, - className, - placeholder, - autoFocus, - disabled, - ...rest - }: AutoHeightTextareaProps & { - ref: React.RefObject; - }, -) => { - const innerRef = useRef(null) - const ref = outRef || innerRef - - useEffect(() => { - if (autoFocus && !disabled && value) { - if (typeof ref !== 'function') { - ref.current?.setSelectionRange(`${value}`.length, `${value}`.length) - ref.current?.focus() - } - } - }, [autoFocus, disabled, ref]) - return ( - (
-
-
- {!value ? placeholder : `${value}`.replace(/\n$/, '\n ')} -
-