From 226f14a20f183753d413d3d18f6eefb75cd44233 Mon Sep 17 00:00:00 2001
From: -LAN-
Date: Thu, 4 Sep 2025 15:35:20 +0800
Subject: [PATCH] feat(graph_engine): implement scale down worker

Signed-off-by: -LAN-
---
Review notes (text in this section is ignored by `git am`):
 * Idle durations are measured with time.monotonic() instead of time.time(),
   so wall-clock adjustments cannot trigger or suppress a scale-down.
 * check_and_scale() only attempts a scale-down when the same pass did not
   scale up; previously both could happen back to back on stale counts.
 * The "index" line below still carries the pre-review blob hash.
 * NOTE(review): the debug log added to start() reads node_count, which is
   assigned outside the lines shown in this patch; confirm it is bound when
   initial_count is passed explicitly.
 * NOTE(review): _remove_worker() joins the worker (up to 2s) while the pool
   lock is held; worker callbacks that need the lock wait during that time.

 .../worker_management/simple_worker_pool.py   | 169 +++++++++++++++++-
 1 file changed, 166 insertions(+), 3 deletions(-)

diff --git a/api/core/workflow/graph_engine/worker_management/simple_worker_pool.py b/api/core/workflow/graph_engine/worker_management/simple_worker_pool.py
index 94b8ff3ca2..367c2b36fc 100644
--- a/api/core/workflow/graph_engine/worker_management/simple_worker_pool.py
+++ b/api/core/workflow/graph_engine/worker_management/simple_worker_pool.py
@@ -5,8 +5,10 @@ This is a simpler implementation that merges WorkerPool, ActivityTracker,
 DynamicScaler, and WorkerFactory into a single class.
 """
 
+import logging
 import queue
 import threading
+import time
 from typing import TYPE_CHECKING, final
 
 from configs import dify_config
@@ -15,6 +17,8 @@ from core.workflow.graph_events import GraphNodeEventBase
 
 from ..worker import Worker
 
+logger = logging.getLogger(__name__)
+
 if TYPE_CHECKING:
     from contextvars import Context
 
@@ -74,6 +78,10 @@ class SimpleWorkerPool:
         self._lock = threading.RLock()
         self._running = False
 
+        # Per-worker idle tracking for scale-down: time.monotonic() idle start (0.0 = not idle)
+        self._worker_idle_times: dict[int, float] = {}
+        self._worker_active_states: dict[int, bool] = {}
+
     def start(self, initial_count: int | None = None) -> None:
         """
         Start the worker pool.
@@ -97,6 +105,14 @@ class SimpleWorkerPool:
                 else:
                     initial_count = min(self._min_workers + 2, self._max_workers)
 
+            logger.debug(
+                "Starting worker pool: %d workers (nodes=%d, min=%d, max=%d)",
+                initial_count,
+                node_count,
+                self._min_workers,
+                self._max_workers,
+            )
+
             # Create initial workers
             for _ in range(initial_count):
                 self._create_worker()
@@ -105,6 +121,10 @@ class SimpleWorkerPool:
         """Stop all workers in the pool."""
         with self._lock:
             self._running = False
+            worker_count = len(self._workers)
+
+            if worker_count > 0:
+                logger.debug("Stopping worker pool: %d workers", worker_count)
 
             # Stop all workers
             for worker in self._workers:
@@ -116,6 +136,8 @@ class SimpleWorkerPool:
                     worker.join(timeout=10.0)
 
             self._workers.clear()
+            self._worker_active_states.clear()
+            self._worker_idle_times.clear()
 
     def _create_worker(self) -> None:
         """Create and start a new worker."""
@@ -129,11 +151,146 @@ class SimpleWorkerPool:
             worker_id=worker_id,
             flask_app=self._flask_app,
             context_vars=self._context_vars,
+            on_idle_callback=self._on_worker_idle,
+            on_active_callback=self._on_worker_active,
         )
 
         worker.start()
         self._workers.append(worker)
 
+        # Initialize tracking: a new worker counts as active with no idle timestamp
+        self._worker_active_states[worker_id] = True
+        self._worker_idle_times[worker_id] = 0.0
+
+    def _on_worker_idle(self, worker_id: int) -> None:
+        """Handle worker becoming idle."""
+        with self._lock:
+            if worker_id not in self._worker_active_states:
+                return
+
+            # Record the idle start only on an active -> idle transition (monotonic clock)
+            if self._worker_active_states.get(worker_id, False):
+                self._worker_active_states[worker_id] = False
+                self._worker_idle_times[worker_id] = time.monotonic()
+
+    def _on_worker_active(self, worker_id: int) -> None:
+        """Handle worker becoming active."""
+        with self._lock:
+            if worker_id not in self._worker_active_states:
+                return
+
+            # Mark as active and clear idle time
+            self._worker_active_states[worker_id] = True
+            self._worker_idle_times[worker_id] = 0.0
+
+    def _remove_worker(self, worker: Worker, worker_id: int) -> None:
+        """Remove a specific worker from the pool; check_and_scale() calls this with self._lock held."""
+        # Stop the worker
+        worker.stop()
+
+        # Wait for it to finish
+        if worker.is_alive():
+            worker.join(timeout=2.0)
+
+        # Remove from list and tracking
+        if worker in self._workers:
+            self._workers.remove(worker)
+
+        # Clean up tracking
+        self._worker_active_states.pop(worker_id, None)
+        self._worker_idle_times.pop(worker_id, None)
+
+    def _try_scale_up(self, queue_depth: int, current_count: int) -> bool:
+        """
+        Try to scale up workers if needed.
+
+        Args:
+            queue_depth: Current queue depth
+            current_count: Current number of workers
+
+        Returns:
+            True if scaled up, False otherwise
+        """
+        if queue_depth > self._scale_up_threshold and current_count < self._max_workers:
+            old_count = current_count
+            self._create_worker()
+
+            logger.debug(
+                "Scaled up workers: %d -> %d (queue_depth=%d exceeded threshold=%d)",
+                old_count,
+                len(self._workers),
+                queue_depth,
+                self._scale_up_threshold,
+            )
+            return True
+        return False
+
+    def _try_scale_down(self, queue_depth: int, current_count: int, active_count: int, idle_count: int) -> bool:
+        """
+        Try to scale down workers if we have excess capacity.
+
+        Args:
+            queue_depth: Current queue depth
+            current_count: Current number of workers
+            active_count: Number of active workers
+            idle_count: Number of idle workers
+
+        Returns:
+            True if scaled down, False otherwise
+        """
+        # Skip if we're at minimum or have no idle workers
+        if current_count <= self._min_workers or idle_count == 0:
+            return False
+
+        # Check if we have excess capacity
+        has_excess_capacity = (
+            queue_depth <= active_count  # Active workers can handle current queue
+            or idle_count > active_count  # More idle than active workers
+            or (queue_depth == 0 and idle_count > 0)  # No work and have idle workers
+        )
+
+        if not has_excess_capacity:
+            return False
+
+        # Find an idle worker to remove; monotonic clock keeps durations immune to wall-clock jumps
+        current_time = time.monotonic()
+        workers_to_remove = []
+
+        for worker in self._workers:
+            worker_id = worker._worker_id
+
+            # Check if worker is idle and has exceeded idle time threshold
+            if not self._worker_active_states.get(worker_id, True) and self._worker_idle_times.get(worker_id, 0) > 0:
+                idle_duration = current_time - self._worker_idle_times[worker_id]
+                if idle_duration >= self._scale_down_idle_time:
+                    # Don't remove if it would leave us unable to handle the queue
+                    remaining_workers = current_count - len(workers_to_remove) - 1
+                    if remaining_workers >= self._min_workers and remaining_workers >= max(1, queue_depth // 2):
+                        workers_to_remove.append((worker, worker_id))
+                        # Only remove one worker per check to avoid aggressive scaling
+                        break
+
+        # Remove idle workers if any found
+        if workers_to_remove:
+            old_count = current_count
+            for worker, worker_id in workers_to_remove:
+                self._remove_worker(worker, worker_id)
+
+            logger.debug(
+                "Scaled down workers: %d -> %d (removed %d idle workers after %.1fs, "
+                "queue_depth=%d, active=%d, idle=%d)",
+                old_count,
+                len(self._workers),
+                len(workers_to_remove),
+                self._scale_down_idle_time,
+                queue_depth,
+                active_count,
+                idle_count - len(workers_to_remove),
+            )
+            return True
+
+        return False
+
     def check_and_scale(self) -> None:
         """Check and perform scaling if needed."""
         with self._lock:
@@ -143,9 +300,15 @@ class SimpleWorkerPool:
             current_count = len(self._workers)
             queue_depth = self._ready_queue.qsize()
 
-            # Simple scaling logic
-            if queue_depth > self._scale_up_threshold and current_count < self._max_workers:
-                self._create_worker()
+            # Count active vs idle workers
+            active_count = sum(1 for state in self._worker_active_states.values() if state)
+            idle_count = current_count - active_count
+
+            # Scale up if the queue is backing up. A scale-down is only attempted when
+            # no worker was just added: the counts above would be stale, and growing
+            # and shrinking the pool in the same pass is pure churn.
+            if not self._try_scale_up(queue_depth, current_count):
+                self._try_scale_down(queue_depth, current_count, active_count, idle_count)
 
     def get_worker_count(self) -> int:
         """Get current number of workers."""