commit 658caa2ae7 (parent 3c95ff4782)

    refactor dify-agent runs to local scheduler
@@ -1,8 +1,9 @@
# Dify Agent Run API

The Dify Agent API exposes asynchronous agent runs backed by Agenton compositor
configuration, Pydantic AI runtime execution, and Redis Streams event logs. The
FastAPI application lives at `dify-agent/src/dify_agent/server/app.py`.
configuration, Pydantic AI runtime execution, Redis run records, and per-run Redis
Streams event logs. The FastAPI application lives at
`dify-agent/src/dify_agent/server/app.py`.

## Input model

@@ -14,8 +15,8 @@ field becomes `Compositor.user_prompts` and is passed to Pydantic AI as the run
input.

Blank user input is rejected. A request with no user prompt, an empty string, or
only whitespace strings such as `"user": ["", " "]` returns `422` from the API
or a runner validation error if it reaches worker execution.
only whitespace strings such as `"user": ["", " "]` returns `422` before a run
record is created.

The server does not implement a Pydantic AI history layer. Resumable Agenton
state is represented only by `session_snapshot`.
@@ -57,10 +58,13 @@ Response (`202 Accepted`):

```json
{
  "run_id": "4a7f9a98-5c55-48d0-8f3e-87ef2cf81234",
  "status": "queued"
  "status": "running"
}
```

The server persists the run record and schedules execution immediately in the
same FastAPI process. Redis is not used as a job queue.

`agent_profile.provider` currently supports the credential-free `test` profile.
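
As a quick orientation, a minimal client sketch for creating a run is shown below. It is an editor illustration, not part of the commit: it assumes a server running locally on uvicorn's default port 8000 and reuses the `plain.prompt` compositor payload that appears in this commit's tests.

```python
# Minimal sketch: create a run against a locally running server.
# Assumptions: host/port (localhost:8000) and the `plain.prompt` test payload.
import requests

payload = {
    "compositor": {
        "schema_version": 1,
        "layers": [
            {"name": "prompt", "type": "plain.prompt", "config": {"user": "hello"}}
        ],
    }
}

response = requests.post("http://localhost:8000/runs", json=payload, timeout=30)
response.raise_for_status()  # 202 Accepted on success; blank prompts return 422
body = response.json()
print(body["run_id"], body["status"])  # status is already "running"
```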

Validation error example (`422`):
@@ -91,7 +95,6 @@ Response:

Status values are:

- `queued`
- `running`
- `succeeded`
- `failed`

@@ -1,8 +1,8 @@
# Operating the Dify Agent Run Server

This guide describes how to run the MVP Dify Agent API server and worker. The
server is implemented in `dify-agent/src/dify_agent/server/app.py` and uses Redis
for run records, job queues, and event streams.
This guide describes how to run the MVP Dify Agent API server. The server is
implemented in `dify-agent/src/dify_agent/server/app.py` and uses Redis for run
records and per-run event streams only.

## Default local startup

@@ -15,7 +15,7 @@ uv run --project dify-agent uvicorn dify_agent.server.app:app --reload
By default, the FastAPI lifespan creates both:

- one Redis-backed run store used by HTTP routes
- one embedded Redis Streams worker task that executes queued runs
- one process-local scheduler that starts background `asyncio` run tasks

This means local development needs one uvicorn process plus Redis. Run execution
still happens outside request handlers, so client disconnects do not cancel the
@@ -29,54 +29,32 @@ also reads `.env` and `dify-agent/.env` when present.
| Environment variable | Default | Description |
| --- | --- | --- |
| `DIFY_AGENT_REDIS_URL` | `redis://localhost:6379/0` | Redis connection URL. |
| `DIFY_AGENT_REDIS_PREFIX` | `dify-agent` | Prefix for Redis record, job, and event keys. |
| `DIFY_AGENT_WORKER_ENABLED` | `true` | Starts the embedded worker in the FastAPI process when true. |
| `DIFY_AGENT_WORKER_GROUP_NAME` | `run-workers` | Redis consumer group used by workers. |
| `DIFY_AGENT_WORKER_CONSUMER_NAME` | unset | Explicit consumer name. If unset, the API process uses `api-{hostname}-{pid}`; the standalone worker uses `worker-1`. |
| `DIFY_AGENT_WORKER_PENDING_IDLE_MS` | `600000` | Idle time before a pending job may be reclaimed with `XAUTOCLAIM` (10 minutes). |

Boolean settings accept Pydantic settings values such as `false`, `0`, or `no`.
| `DIFY_AGENT_REDIS_PREFIX` | `dify-agent` | Prefix for Redis record and event keys. |
| `DIFY_AGENT_SHUTDOWN_GRACE_SECONDS` | `30` | Seconds to wait for active local runs during graceful shutdown before cancellation. |

Example `.env`:

```env
DIFY_AGENT_REDIS_URL=redis://localhost:6379/0
DIFY_AGENT_REDIS_PREFIX=dify-agent-dev
DIFY_AGENT_WORKER_ENABLED=true
DIFY_AGENT_WORKER_PENDING_IDLE_MS=600000
DIFY_AGENT_SHUTDOWN_GRACE_SECONDS=30
```
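
For tests or embedded use, the server can also be configured in code rather than through environment variables. The sketch below is an editor illustration built on the `ServerSettings` and `create_app` names introduced by this commit; the host and port are assumptions.

```python
# Minimal sketch: build the app with explicit settings instead of env vars.
import uvicorn

from dify_agent.server.app import create_app
from dify_agent.server.settings import ServerSettings

settings = ServerSettings(
    redis_url="redis://localhost:6379/0",
    redis_prefix="dify-agent-dev",
    shutdown_grace_seconds=30,
)
app = create_app(settings)

if __name__ == "__main__":
    # Equivalent to the `uv run ... uvicorn` command above, with code-level settings.
    uvicorn.run(app, host="127.0.0.1", port=8000)
```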

## Running a separate worker
## Scheduling and shutdown semantics

For deployments that want to scale HTTP and worker processes independently,
disable the embedded worker and start a worker process separately:
`POST /runs` validates the compositor, persists a `running` run record, and starts
an `asyncio` task in the same process. There is no Redis job stream, consumer
group, pending reclaim, or automatic retry layer.

```bash
DIFY_AGENT_WORKER_ENABLED=false \
uv run --project dify-agent uvicorn dify_agent.server.app:app
During FastAPI shutdown the scheduler rejects new runs, waits up to
`DIFY_AGENT_SHUTDOWN_GRACE_SECONDS` for active tasks, then cancels remaining tasks
and best-effort appends a `run_failed` event plus failed status. A hard process
crash can still leave active runs stuck as `running`; there is no in-service
recovery or worker handoff.

uv run --project dify-agent python -m dify_agent.worker.job_worker
```

Use the same Redis URL, prefix, and worker group for the API process and all
standalone workers. Give each live worker a unique
`DIFY_AGENT_WORKER_CONSUMER_NAME` when running multiple standalone workers.

## Redis Streams reliability

Run creation stores the run record and enqueues the worker job in one Redis
transaction (`MULTI/EXEC`). A create request either persists both pieces or fails
without leaving a queued run that has no job.

Workers read jobs from a Redis Streams consumer group. If a worker crashes after
receiving a job but before acknowledging it, Redis keeps the entry pending. On
later iterations, workers call `XAUTOCLAIM` and reclaim entries idle for at least
`DIFY_AGENT_WORKER_PENDING_IDLE_MS` before reading new `>` entries. The default
idle time is `600000` milliseconds (10 minutes).

Choose the pending idle value according to your longest expected run time. A
value that is too short can cause a healthy long-running job to be reclaimed by
another worker; a value that is too long delays recovery after crashes.
Horizontal scaling is possible by running multiple API processes against the same
Redis prefix, but each process executes only the runs it accepted. Redis provides
shared status/event visibility, not load balancing or queued-job recovery.
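
From a client's perspective, the scheduling model above reduces to "create, then poll". The sketch below is an editor illustration: it assumes a local server on port 8000, reuses the test compositor payload, and assumes the status response exposes a `status` field as described in this guide.

```python
# Minimal sketch: create a run, then poll its status until it is terminal.
import time

import requests

BASE = "http://localhost:8000"
payload = {
    "compositor": {
        "schema_version": 1,
        "layers": [
            {"name": "prompt", "type": "plain.prompt", "config": {"user": "hello"}}
        ],
    }
}

run_id = requests.post(f"{BASE}/runs", json=payload, timeout=30).json()["run_id"]

while True:
    status = requests.get(f"{BASE}/runs/{run_id}", timeout=30).json()["status"]
    if status in ("succeeded", "failed"):
        break
    time.sleep(1)  # the event endpoints give finer-grained progress than polling

print(run_id, status)
```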

## Run inputs and session snapshots

@@ -103,8 +81,8 @@ whose Agenton layers provide user input. With the MVP registry, use
```

`config.user` can be a string or a list of strings. Empty or whitespace-only
effective prompts are rejected with `422` at the API boundary or with a runner
validation error if they reach execution.
effective prompts are rejected during create-run validation before the run is
persisted or scheduled.

There is no Pydantic AI history layer. To resume Agenton layer state, pass the
`session_snapshot` emitted by a previous run together with a compositor that has
@@ -115,8 +93,8 @@ the same layer names and order.

Use the HTTP status endpoint for coarse state and the event endpoints for detailed
progress:

- `POST /runs` creates a queued run.
- `GET /runs/{run_id}` returns `queued`, `running`, `succeeded`, or `failed`.
- `POST /runs` creates a running run and schedules it locally.
- `GET /runs/{run_id}` returns `running`, `succeeded`, or `failed`.
- `GET /runs/{run_id}/events` polls the Redis Stream event log with `after` and
  `next_cursor` cursors.
- `GET /runs/{run_id}/events/sse` replays and streams events over SSE. The SSE

@@ -1,7 +1,7 @@
"""Example consumer for the Dify Agent run server.

Requires Redis and a running API server. The server starts its Redis Streams
worker in the same process by default, for example:
Requires Redis and a running API server. The server schedules runs in-process, for
example:

    uv run --project dify-agent uvicorn dify_agent.server.app:app --reload

dify-agent/src/dify_agent/runtime/run_scheduler.py (new file, 145 lines)
@@ -0,0 +1,145 @@
"""In-process scheduling for Dify Agent runs.

The scheduler is intentionally process-local: it persists a run record, starts an
``asyncio.Task`` for ``AgentRunRunner.run()``, and keeps only a transient active
task registry. Redis remains the durable source for status and event streams, but
there is no Redis job queue or cross-process handoff. If the process crashes,
currently active runs are lost until an external operator marks or retries them.
"""

import asyncio
import logging
from collections.abc import Callable
from typing import Protocol

from dify_agent.runtime.compositor_factory import build_pydantic_ai_compositor
from dify_agent.runtime.event_sink import RunEventSink, emit_run_event
from dify_agent.runtime.runner import AgentRunRunner
from dify_agent.runtime.user_prompt_validation import EMPTY_USER_PROMPTS_ERROR, has_non_blank_user_prompt
from dify_agent.server.schemas import CreateRunRequest, RunRecord

logger = logging.getLogger(__name__)


class SchedulerStoppingError(RuntimeError):
    """Raised when a create-run request arrives after shutdown has started."""


class RunStore(RunEventSink, Protocol):
    """Persistence boundary needed by the scheduler."""

    async def create_run(self, request: CreateRunRequest) -> RunRecord:
        """Persist a new run record and return it with status ``running``."""
        ...


class RunnableRun(Protocol):
    """Executable unit for one scheduled run."""

    async def run(self) -> None:
        """Run until terminal status/events have been written or cancellation occurs."""
        ...


type RunRunnerFactory = Callable[[RunRecord], RunnableRun]


class RunScheduler:
    """Owns process-local run tasks and best-effort graceful shutdown.

    ``active_tasks`` is mutated only on the event loop that calls ``create_run``
    and ``shutdown``. The task registry is not durable; it exists so the lifespan
    hook can wait for in-flight work and mark cancelled runs failed before Redis is
    closed. A lock guards the stopping flag, run persistence, and task
    registration so shutdown cannot complete while a run is between record
    creation and active-task tracking.
    """

    store: RunStore
    shutdown_grace_seconds: float
    active_tasks: dict[str, asyncio.Task[None]]
    stopping: bool
    runner_factory: RunRunnerFactory
    _lifecycle_lock: asyncio.Lock

    def __init__(
        self,
        *,
        store: RunStore,
        shutdown_grace_seconds: float = 30,
        runner_factory: RunRunnerFactory | None = None,
    ) -> None:
        self.store = store
        self.shutdown_grace_seconds = shutdown_grace_seconds
        self.active_tasks = {}
        self.stopping = False
        self.runner_factory = runner_factory or self._default_runner_factory
        self._lifecycle_lock = asyncio.Lock()

    async def create_run(self, request: CreateRunRequest) -> RunRecord:
        """Validate, persist, and schedule one run in the current process.

        The returned record is already ``running``. The background task is removed
        from ``active_tasks`` when it finishes, regardless of success or failure.
        """
        compositor = build_pydantic_ai_compositor(request.compositor)
        if not has_non_blank_user_prompt(compositor.user_prompts):
            raise ValueError(EMPTY_USER_PROMPTS_ERROR)

        async with self._lifecycle_lock:
            if self.stopping:
                raise SchedulerStoppingError("run scheduler is shutting down")
            record = await self.store.create_run(request)
            task = asyncio.create_task(self._run_record(record), name=f"dify-agent-run-{record.run_id}")
            self.active_tasks[record.run_id] = task
            task.add_done_callback(lambda _task, run_id=record.run_id: self.active_tasks.pop(run_id, None))
        return record

    async def shutdown(self) -> None:
        """Stop accepting runs, wait briefly, then cancel and fail unfinished runs."""
        async with self._lifecycle_lock:
            self.stopping = True
        if not self.active_tasks:
            return
        tasks_by_run_id = dict(self.active_tasks)
        done, pending = await asyncio.wait(tasks_by_run_id.values(), timeout=self.shutdown_grace_seconds)
        del done
        if not pending:
            return

        pending_run_ids = [run_id for run_id, task in tasks_by_run_id.items() if task in pending]
        for task in pending:
            _ = task.cancel()
        _ = await asyncio.gather(*pending, return_exceptions=True)
        for run_id in pending_run_ids:
            await self._mark_cancelled_run_failed(run_id)

    async def _run_record(self, record: RunRecord) -> None:
        """Execute a stored run and log failures already reflected in events."""
        try:
            await self.runner_factory(record).run()
        except asyncio.CancelledError:
            raise
        except Exception:
            logger.exception("scheduled run failed", extra={"run_id": record.run_id})

    def _default_runner_factory(self, record: RunRecord) -> RunnableRun:
        """Create the production runner for a stored run record."""
        return AgentRunRunner(sink=self.store, request=record.request, run_id=record.run_id)

    async def _mark_cancelled_run_failed(self, run_id: str) -> None:
        """Best-effort failure event/status for shutdown-cancelled runs."""
        message = "run cancelled during server shutdown"
        try:
            _ = await emit_run_event(
                self.store,
                run_id=run_id,
                type="run_failed",
                data={"error": message, "reason": "shutdown"},
            )
            await self.store.update_status(run_id, "failed", message)
        except Exception:
            logger.exception("failed to mark cancelled run failed", extra={"run_id": run_id})


__all__ = ["RunScheduler", "SchedulerStoppingError"]

@@ -1,4 +1,4 @@
"""Runtime execution for one queued Dify Agent run.
"""Runtime execution for one scheduled Dify Agent run.

The runner is storage-agnostic: it builds an Agenton compositor, enters or
resumes its session, runs pydantic-ai with ``compositor.user_prompts`` as the user

@@ -1,7 +1,7 @@
"""Validation for effective user prompts produced by Agenton compositors.

Validation happens after safe compositor construction so API and worker paths use
the same semantics as the actual pydantic-ai input. Blank string fragments do not
Validation happens after safe compositor construction so scheduler and runner
paths use the same semantics as the actual pydantic-ai input. Blank string fragments do not
count as meaningful input; non-string ``UserContent`` is treated as intentional
content because rich media/message parts do not have a universal whitespace
representation.

@ -1,78 +1,54 @@
|
||||
"""FastAPI application factory for the Dify Agent run server.
|
||||
|
||||
The HTTP process owns Redis clients, route wiring, and by default one embedded
|
||||
Redis Streams worker task. Run execution still happens outside request handlers,
|
||||
so client latency and disconnects do not control the agent runtime, but local
|
||||
development only needs one ``uvicorn`` process plus Redis.
|
||||
The HTTP process owns Redis clients, route wiring, and a process-local scheduler.
|
||||
Run execution happens in background ``asyncio`` tasks rather than request
|
||||
handlers, so client disconnects do not cancel the agent runtime. Redis persists
|
||||
run records and per-run event streams only; it is not used as a job queue.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import socket
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import asynccontextmanager, suppress
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI
|
||||
from redis.asyncio import Redis
|
||||
|
||||
from dify_agent.runtime.run_scheduler import RunScheduler
|
||||
from dify_agent.server.routes.runs import create_runs_router
|
||||
from dify_agent.server.settings import ServerSettings
|
||||
from dify_agent.storage.redis_run_store import RedisRunStore
|
||||
from dify_agent.worker.job_worker import RunJobWorker
|
||||
|
||||
|
||||
def create_app(settings: ServerSettings | None = None) -> FastAPI:
|
||||
"""Build the FastAPI app with one shared Redis-backed run store and worker."""
|
||||
"""Build the FastAPI app with one shared Redis store and local scheduler."""
|
||||
resolved_settings = settings or ServerSettings()
|
||||
state: dict[str, RedisRunStore] = {}
|
||||
state: dict[str, RedisRunStore | RunScheduler] = {}
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(_app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
redis = Redis.from_url(resolved_settings.redis_url)
|
||||
store = RedisRunStore(redis, prefix=resolved_settings.redis_prefix)
|
||||
scheduler = RunScheduler(store=store, shutdown_grace_seconds=resolved_settings.shutdown_grace_seconds)
|
||||
state["store"] = store
|
||||
worker_task: asyncio.Task[None] | None = None
|
||||
if resolved_settings.worker_enabled:
|
||||
worker = RunJobWorker(
|
||||
store=store,
|
||||
group_name=resolved_settings.worker_group_name,
|
||||
consumer_name=_worker_consumer_name(resolved_settings),
|
||||
pending_idle_ms=resolved_settings.worker_pending_idle_ms,
|
||||
)
|
||||
worker_task = asyncio.create_task(worker.run_forever(), name="dify-agent-run-worker")
|
||||
# Give the worker one loop turn so startup tests and immediate failures observe the task.
|
||||
await asyncio.sleep(0)
|
||||
state["scheduler"] = scheduler
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
if worker_task is not None:
|
||||
_ = worker_task.cancel()
|
||||
with suppress(asyncio.CancelledError):
|
||||
await worker_task
|
||||
await scheduler.shutdown()
|
||||
await redis.aclose()
|
||||
|
||||
app = FastAPI(title="Dify Agent Run Server", version="0.1.0", lifespan=lifespan)
|
||||
|
||||
def get_store() -> RedisRunStore:
|
||||
return state["store"]
|
||||
return state["store"] # pyright: ignore[reportReturnType]
|
||||
|
||||
app.include_router(create_runs_router(get_store))
|
||||
def get_scheduler() -> RunScheduler:
|
||||
return state["scheduler"] # pyright: ignore[reportReturnType]
|
||||
|
||||
app.include_router(create_runs_router(get_store, get_scheduler))
|
||||
return app
|
||||
|
||||
|
||||
app = create_app()
|
||||
|
||||
|
||||
def _worker_consumer_name(settings: ServerSettings) -> str:
|
||||
"""Return a stable-enough consumer name for this API process.
|
||||
|
||||
Redis consumer names should be unique per live process. The explicit setting
|
||||
is useful for tests or controlled deployments; otherwise hostname and PID
|
||||
distinguish common ``uvicorn --workers`` and reload processes.
|
||||
"""
|
||||
if settings.worker_consumer_name:
|
||||
return settings.worker_consumer_name
|
||||
return f"api-{socket.gethostname()}-{os.getpid()}"
|
||||
|
||||
|
||||
__all__ = ["app", "create_app"]
|
||||
|
||||
@ -1,8 +1,10 @@
|
||||
"""FastAPI routes for asynchronous agent runs.
|
||||
|
||||
Controllers translate storage/validation errors into HTTP status codes and keep
|
||||
worker execution out of the request path. A created run is only queued; clients
|
||||
observe progress through polling or SSE replay.
|
||||
Controllers translate known validation and shutdown errors into HTTP status codes.
|
||||
Unexpected scheduler or storage failures are intentionally left for FastAPI's
|
||||
server-error handling so infrastructure problems are not reported as client input
|
||||
errors. Created runs are scheduled in the current process and observed through
|
||||
status polling or SSE replay backed by Redis event streams.
|
||||
"""
|
||||
|
||||
from collections.abc import Callable
|
||||
@ -12,23 +14,27 @@ from fastapi import APIRouter, Depends, Header, HTTPException, Query
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from dify_agent.runtime.compositor_factory import build_pydantic_ai_compositor
|
||||
from dify_agent.runtime.run_scheduler import RunScheduler, SchedulerStoppingError
|
||||
from dify_agent.runtime.user_prompt_validation import EMPTY_USER_PROMPTS_ERROR, has_non_blank_user_prompt
|
||||
from dify_agent.server.schemas import CreateRunRequest, CreateRunResponse, RunEventsResponse, RunStatusResponse
|
||||
from dify_agent.server.sse import sse_event_stream
|
||||
from dify_agent.storage.redis_run_store import RedisRunStore, RunNotFoundError
|
||||
|
||||
|
||||
def create_runs_router(get_store: Callable[[], RedisRunStore]) -> APIRouter:
|
||||
def create_runs_router(get_store: Callable[[], RedisRunStore], get_scheduler: Callable[[], RunScheduler]) -> APIRouter:
|
||||
"""Create routes bound to the application's store dependency provider."""
|
||||
router = APIRouter(prefix="/runs", tags=["runs"])
|
||||
|
||||
async def store_dep() -> RedisRunStore:
|
||||
return get_store()
|
||||
|
||||
async def scheduler_dep() -> RunScheduler:
|
||||
return get_scheduler()
|
||||
|
||||
@router.post("", response_model=CreateRunResponse, status_code=202)
|
||||
async def create_run(
|
||||
request: CreateRunRequest,
|
||||
store: Annotated[RedisRunStore, Depends(store_dep)],
|
||||
scheduler: Annotated[RunScheduler, Depends(scheduler_dep)],
|
||||
) -> CreateRunResponse:
|
||||
try:
|
||||
compositor = build_pydantic_ai_compositor(request.compositor)
|
||||
@ -36,7 +42,11 @@ def create_runs_router(get_store: Callable[[], RedisRunStore]) -> APIRouter:
|
||||
raise HTTPException(status_code=422, detail=str(exc)) from exc
|
||||
if not has_non_blank_user_prompt(compositor.user_prompts):
|
||||
raise HTTPException(status_code=422, detail=EMPTY_USER_PROMPTS_ERROR)
|
||||
record = await store.create_run(request)
|
||||
|
||||
try:
|
||||
record = await scheduler.create_run(request)
|
||||
except SchedulerStoppingError as exc:
|
||||
raise HTTPException(status_code=503, detail="run scheduler is shutting down") from exc
|
||||
return CreateRunResponse(run_id=record.run_id, status=record.status)
|
||||
|
||||
@router.get("/{run_id}", response_model=RunStatusResponse)
|
||||
|
||||
@ -15,7 +15,7 @@ from pydantic import BaseModel, ConfigDict, Field, JsonValue, field_validator
|
||||
from agenton.compositor import CompositorConfig, CompositorSessionSnapshot
|
||||
|
||||
|
||||
RunStatus = Literal["queued", "running", "succeeded", "failed"]
|
||||
RunStatus = Literal["running", "succeeded", "failed"]
|
||||
RunEventType = Literal[
|
||||
"run_started",
|
||||
"pydantic_ai_event",
|
||||
@ -60,7 +60,7 @@ class CreateRunRequest(BaseModel):
|
||||
|
||||
|
||||
class CreateRunResponse(BaseModel):
|
||||
"""Response returned after a run job has been durably queued."""
|
||||
"""Response returned after a run has been persisted and scheduled locally."""
|
||||
|
||||
run_id: str
|
||||
status: RunStatus
|
||||
@ -102,15 +102,6 @@ class RunEventsResponse(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
class RunnerJob(BaseModel):
|
||||
"""Durable worker payload stored in Redis streams."""
|
||||
|
||||
run_id: str
|
||||
request: CreateRunRequest
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
class RunRecord(BaseModel):
|
||||
"""Internal representation persisted for status reads."""
|
||||
|
||||
@ -141,7 +132,6 @@ __all__ = [
|
||||
"RunRecord",
|
||||
"RunStatus",
|
||||
"RunStatusResponse",
|
||||
"RunnerJob",
|
||||
"new_run_id",
|
||||
"utc_now",
|
||||
]
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
"""Configuration for the FastAPI run server and embedded worker."""
|
||||
"""Configuration for the FastAPI run server."""
|
||||
|
||||
from typing import ClassVar
|
||||
|
||||
@ -6,20 +6,11 @@ from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
class ServerSettings(BaseSettings):
|
||||
"""Environment-backed settings shared by HTTP routes and the run worker.
|
||||
|
||||
The default deployment mode runs the Redis Streams worker inside the FastAPI
|
||||
process so a single ``uvicorn`` command is enough for local development and
|
||||
small deployments. Set ``DIFY_AGENT_WORKER_ENABLED=false`` when running a
|
||||
separate worker process or when only the HTTP API should be started.
|
||||
"""
|
||||
"""Environment-backed settings for Redis persistence and local scheduling."""
|
||||
|
||||
redis_url: str = "redis://localhost:6379/0"
|
||||
redis_prefix: str = "dify-agent"
|
||||
worker_enabled: bool = True
|
||||
worker_group_name: str = "run-workers"
|
||||
worker_consumer_name: str | None = None
|
||||
worker_pending_idle_ms: int = 600_000
|
||||
shutdown_grace_seconds: float = 30
|
||||
|
||||
model_config: ClassVar[SettingsConfigDict] = SettingsConfigDict(
|
||||
env_prefix="DIFY_AGENT_",
|
||||
|
||||
@ -1,8 +1,4 @@
|
||||
"""Redis key helpers for the run server.
|
||||
|
||||
Keys are centralized so workers, projectors, and HTTP routes can share the same
|
||||
stream/hash layout without duplicating string formats.
|
||||
"""
|
||||
"""Redis key helpers for run records and per-run event streams."""
|
||||
|
||||
|
||||
def run_record_key(prefix: str, run_id: str) -> str:
|
||||
@ -15,9 +11,4 @@ def run_events_key(prefix: str, run_id: str) -> str:
|
||||
return f"{prefix}:runs:{run_id}:events"
|
||||
|
||||
|
||||
def run_jobs_key(prefix: str) -> str:
|
||||
"""Return the Redis stream key holding queued run jobs."""
|
||||
return f"{prefix}:runs:jobs"
|
||||
|
||||
|
||||
__all__ = ["run_events_key", "run_jobs_key", "run_record_key"]
|
||||
__all__ = ["run_events_key", "run_record_key"]
|
||||
|
||||
@ -1,9 +1,9 @@
|
||||
"""Redis Streams-backed run persistence.
|
||||
"""Redis-backed run records and per-run event streams.
|
||||
|
||||
The store writes run records as JSON strings and events/jobs as Redis streams.
|
||||
HTTP event cursors are Redis stream ids; ``0-0`` means replay from the beginning
|
||||
for polling and SSE. The worker uses the jobs stream directly and updates the run
|
||||
record through the same status/event sink protocol as tests.
|
||||
The store writes run records as JSON strings and events as Redis streams. HTTP
|
||||
event cursors are Redis stream ids; ``0-0`` means replay from the beginning for
|
||||
polling and SSE. Execution is scheduled in-process by
|
||||
``dify_agent.runtime.run_scheduler``; Redis is not a job queue.
|
||||
"""
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
@ -19,11 +19,10 @@ from dify_agent.server.schemas import (
|
||||
RunEventsResponse,
|
||||
RunRecord,
|
||||
RunStatus,
|
||||
RunnerJob,
|
||||
new_run_id,
|
||||
utc_now,
|
||||
)
|
||||
from dify_agent.storage.redis_keys import run_events_key, run_jobs_key, run_record_key
|
||||
from dify_agent.storage.redis_keys import run_events_key, run_record_key
|
||||
|
||||
|
||||
class RunNotFoundError(LookupError):
|
||||
@ -31,7 +30,7 @@ class RunNotFoundError(LookupError):
|
||||
|
||||
|
||||
class RedisRunStore(RunEventSink):
|
||||
"""Async Redis implementation for run records, jobs, and events."""
|
||||
"""Async Redis implementation for run records and event logs."""
|
||||
|
||||
redis: Redis
|
||||
prefix: str
|
||||
@ -41,19 +40,10 @@ class RedisRunStore(RunEventSink):
|
||||
self.prefix = prefix
|
||||
|
||||
async def create_run(self, request: CreateRunRequest) -> RunRecord:
|
||||
"""Persist a queued run and enqueue its worker job atomically.
|
||||
|
||||
The run record and jobs stream entry are one durability boundary: either
|
||||
both are committed by Redis ``MULTI/EXEC`` or neither is visible. This
|
||||
prevents permanently queued records with no corresponding worker job.
|
||||
"""
|
||||
"""Persist a running run record without enqueueing external work."""
|
||||
run_id = new_run_id()
|
||||
record = RunRecord(run_id=run_id, status="queued", request=request)
|
||||
job = RunnerJob(run_id=run_id, request=request)
|
||||
async with self.redis.pipeline(transaction=True) as pipe:
|
||||
pipe.set(run_record_key(self.prefix, run_id), record.model_dump_json())
|
||||
pipe.xadd(run_jobs_key(self.prefix), {"payload": job.model_dump_json()})
|
||||
await pipe.execute()
|
||||
record = RunRecord(run_id=run_id, status="running", request=request)
|
||||
await self.redis.set(run_record_key(self.prefix, run_id), record.model_dump_json())
|
||||
return record
|
||||
|
||||
async def get_run(self, run_id: str) -> RunRecord:
|
||||
|
||||
@ -1,158 +0,0 @@
|
||||
"""Redis Streams worker for executing queued runs.
|
||||
|
||||
This worker is asyncio/uvloop compatible and intentionally does not use Celery.
|
||||
It reads jobs from the shared Redis stream, executes them through
|
||||
``AgentRunRunner``, and acknowledges entries only after terminal status/events
|
||||
have been written.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from typing import Protocol, cast
|
||||
|
||||
from redis.asyncio import Redis
|
||||
|
||||
from dify_agent.runtime.runner import AgentRunRunner
|
||||
from dify_agent.server.schemas import RunnerJob
|
||||
from dify_agent.server.settings import ServerSettings
|
||||
from dify_agent.storage.redis_keys import run_jobs_key
|
||||
from dify_agent.storage.redis_run_store import RedisRunStore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class JobRunner(Protocol):
|
||||
"""Executable unit for one decoded run job."""
|
||||
|
||||
async def run(self) -> None:
|
||||
"""Execute the job and write terminal status/events."""
|
||||
...
|
||||
|
||||
|
||||
type JobRunnerFactory = Callable[[RunnerJob], JobRunner]
|
||||
|
||||
|
||||
def create_default_job_runner(store: RedisRunStore, job: RunnerJob) -> JobRunner:
|
||||
"""Create the production runner for a decoded Redis job."""
|
||||
return AgentRunRunner(sink=store, request=job.request, run_id=job.run_id)
|
||||
|
||||
|
||||
class RunJobWorker:
|
||||
"""Long-running worker that consumes the run jobs stream."""
|
||||
|
||||
store: RedisRunStore
|
||||
group_name: str
|
||||
consumer_name: str
|
||||
pending_idle_ms: int
|
||||
runner_factory: JobRunnerFactory
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
store: RedisRunStore,
|
||||
group_name: str = "run-workers",
|
||||
consumer_name: str = "worker-1",
|
||||
pending_idle_ms: int = 600_000,
|
||||
runner_factory: JobRunnerFactory | None = None,
|
||||
) -> None:
|
||||
self.store = store
|
||||
self.group_name = group_name
|
||||
self.consumer_name = consumer_name
|
||||
self.pending_idle_ms = pending_idle_ms
|
||||
self.runner_factory = runner_factory or (lambda job: create_default_job_runner(store, job))
|
||||
|
||||
async def run_forever(self) -> None:
|
||||
"""Continuously read and execute jobs until cancelled."""
|
||||
jobs_key = run_jobs_key(self.store.prefix)
|
||||
await self._ensure_group(jobs_key)
|
||||
while True:
|
||||
await self.process_once(jobs_key, block_ms=30_000)
|
||||
|
||||
async def process_once(self, jobs_key: str | None = None, *, block_ms: int = 30_000) -> bool:
|
||||
"""Process one stale pending or new job entry.
|
||||
|
||||
Stale pending entries are reclaimed before blocking on new work. This
|
||||
covers worker crashes after ``XREADGROUP`` delivery but before ``XACK``:
|
||||
Redis keeps the entry pending, and another worker can claim it after the
|
||||
configured idle timeout instead of leaving the run stuck forever.
|
||||
"""
|
||||
resolved_jobs_key = jobs_key or run_jobs_key(self.store.prefix)
|
||||
claimed = await self._claim_stale_pending(resolved_jobs_key)
|
||||
if claimed:
|
||||
for entry_id, fields in claimed:
|
||||
await self._handle_entry(resolved_jobs_key, entry_id, fields)
|
||||
return True
|
||||
|
||||
response = await self.store.redis.xreadgroup(
|
||||
self.group_name,
|
||||
self.consumer_name,
|
||||
{resolved_jobs_key: ">"},
|
||||
count=1,
|
||||
block=block_ms,
|
||||
)
|
||||
for _stream_name, entries in response:
|
||||
for entry_id, fields in entries:
|
||||
await self._handle_entry(resolved_jobs_key, entry_id, fields)
|
||||
return True
|
||||
return False
|
||||
|
||||
async def _claim_stale_pending(self, jobs_key: str) -> list[tuple[object, dict[object, object]]]:
|
||||
"""Claim stale pending jobs from crashed consumers."""
|
||||
response = await self.store.redis.xautoclaim(
|
||||
jobs_key,
|
||||
self.group_name,
|
||||
self.consumer_name,
|
||||
min_idle_time=self.pending_idle_ms,
|
||||
start_id="0-0",
|
||||
count=1,
|
||||
)
|
||||
if len(response) >= 2:
|
||||
entries = response[1]
|
||||
return list(entries)
|
||||
return []
|
||||
|
||||
async def _ensure_group(self, jobs_key: str) -> None:
|
||||
"""Create the Redis consumer group if needed."""
|
||||
try:
|
||||
await self.store.redis.xgroup_create(jobs_key, self.group_name, id="0", mkstream=True)
|
||||
except Exception as exc:
|
||||
if "BUSYGROUP" not in str(exc):
|
||||
raise
|
||||
|
||||
async def _handle_entry(self, jobs_key: str, entry_id: object, fields: dict[object, object]) -> None:
|
||||
"""Decode and execute one stream entry."""
|
||||
payload = fields.get(b"payload") or fields.get("payload")
|
||||
if isinstance(payload, bytes):
|
||||
payload = payload.decode()
|
||||
if not isinstance(payload, str | bytes | bytearray):
|
||||
raise ValueError("Redis job payload must be JSON text")
|
||||
job = RunnerJob.model_validate_json(payload)
|
||||
try:
|
||||
await self.runner_factory(job).run()
|
||||
except Exception:
|
||||
logger.exception("run worker failed", extra={"run_id": job.run_id})
|
||||
finally:
|
||||
await self.store.redis.xack(jobs_key, self.group_name, cast(str | bytes, entry_id))
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
"""Run the worker using environment settings."""
|
||||
settings = ServerSettings()
|
||||
redis = Redis.from_url(settings.redis_url)
|
||||
try:
|
||||
await RunJobWorker(
|
||||
store=RedisRunStore(redis, prefix=settings.redis_prefix),
|
||||
group_name=settings.worker_group_name,
|
||||
consumer_name=settings.worker_consumer_name or "worker-1",
|
||||
pending_idle_ms=settings.worker_pending_idle_ms,
|
||||
).run_forever()
|
||||
finally:
|
||||
await redis.aclose()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
||||
|
||||
__all__ = ["RunJobWorker", "main"]
|
||||
@ -1,21 +0,0 @@
|
||||
"""Lightweight run-event projector service.
|
||||
|
||||
The MVP writes status directly from the runner/store, so this projector currently
|
||||
acts as an async-compatible extension point for future derived views. Keeping the
|
||||
module explicit documents that Redis Streams, not Celery, are the service
|
||||
boundary for background processing.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
|
||||
class RunProjector:
|
||||
"""No-op projector placeholder with a cancellable service loop."""
|
||||
|
||||
async def run_forever(self) -> None:
|
||||
"""Stay alive until cancelled; future projections can be added here."""
|
||||
while True:
|
||||
await asyncio.sleep(3600)
|
||||
|
||||
|
||||
__all__ = ["RunProjector"]
|
||||
180
dify-agent/tests/local/dify_agent/runtime/test_run_scheduler.py
Normal file
180
dify-agent/tests/local/dify_agent/runtime/test_run_scheduler.py
Normal file
@ -0,0 +1,180 @@
|
||||
import asyncio
|
||||
from collections import defaultdict
|
||||
from typing import cast
|
||||
|
||||
import pytest
|
||||
from pydantic import JsonValue
|
||||
|
||||
from agenton.compositor import CompositorConfig, LayerNodeConfig
|
||||
from dify_agent.runtime.run_scheduler import RunScheduler, SchedulerStoppingError
|
||||
from dify_agent.server.schemas import CreateRunRequest, RunEvent, RunRecord, RunStatus
|
||||
|
||||
|
||||
def _request(user: str | list[str] = "hello") -> CreateRunRequest:
|
||||
return CreateRunRequest(
|
||||
compositor=CompositorConfig(
|
||||
layers=[LayerNodeConfig(name="prompt", type="plain.prompt", config=cast(JsonValue, {"user": user}))]
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class FakeStore:
|
||||
records: dict[str, RunRecord]
|
||||
events: dict[str, list[RunEvent]]
|
||||
statuses: dict[str, RunStatus]
|
||||
errors: dict[str, str | None]
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.records = {}
|
||||
self.events = defaultdict(list)
|
||||
self.statuses = {}
|
||||
self.errors = {}
|
||||
|
||||
async def create_run(self, request: CreateRunRequest) -> RunRecord:
|
||||
run_id = f"run-{len(self.records) + 1}"
|
||||
record = RunRecord(run_id=run_id, status="running", request=request)
|
||||
self.records[run_id] = record
|
||||
self.statuses[run_id] = "running"
|
||||
return record
|
||||
|
||||
async def append_event(self, event: RunEvent) -> str:
|
||||
event_id = str(len(self.events[event.run_id]) + 1)
|
||||
self.events[event.run_id].append(event.model_copy(update={"id": event_id}))
|
||||
return event_id
|
||||
|
||||
async def update_status(self, run_id: str, status: RunStatus, error: str | None = None) -> None:
|
||||
self.statuses[run_id] = status
|
||||
self.errors[run_id] = error
|
||||
|
||||
|
||||
class SlowCreateStore(FakeStore):
|
||||
create_started: asyncio.Event
|
||||
release_create: asyncio.Event
|
||||
|
||||
def __init__(self, *, create_started: asyncio.Event, release_create: asyncio.Event) -> None:
|
||||
super().__init__()
|
||||
self.create_started = create_started
|
||||
self.release_create = release_create
|
||||
|
||||
async def create_run(self, request: CreateRunRequest) -> RunRecord:
|
||||
_ = self.create_started.set()
|
||||
await self.release_create.wait()
|
||||
return await super().create_run(request)
|
||||
|
||||
|
||||
class ControlledRunner:
|
||||
started: asyncio.Event
|
||||
release: asyncio.Event
|
||||
|
||||
def __init__(self, *, started: asyncio.Event, release: asyncio.Event) -> None:
|
||||
self.started = started
|
||||
self.release = release
|
||||
|
||||
async def run(self) -> None:
|
||||
_ = self.started.set()
|
||||
await self.release.wait()
|
||||
|
||||
|
||||
def test_create_run_starts_background_task_and_returns_running() -> None:
|
||||
async def scenario() -> None:
|
||||
store = FakeStore()
|
||||
started = asyncio.Event()
|
||||
release = asyncio.Event()
|
||||
scheduler = RunScheduler(
|
||||
store=store,
|
||||
runner_factory=lambda _record: ControlledRunner(started=started, release=release),
|
||||
)
|
||||
|
||||
record = await scheduler.create_run(_request())
|
||||
await asyncio.wait_for(started.wait(), timeout=1)
|
||||
|
||||
assert record.status == "running"
|
||||
assert list(scheduler.active_tasks) == [record.run_id]
|
||||
_ = release.set()
|
||||
await asyncio.wait_for(scheduler.active_tasks[record.run_id], timeout=1)
|
||||
await asyncio.sleep(0)
|
||||
assert scheduler.active_tasks == {}
|
||||
|
||||
asyncio.run(scenario())
|
||||
|
||||
|
||||
def test_shutdown_marks_unfinished_runs_failed_and_appends_event() -> None:
|
||||
async def scenario() -> None:
|
||||
store = FakeStore()
|
||||
started = asyncio.Event()
|
||||
scheduler = RunScheduler(
|
||||
store=store,
|
||||
shutdown_grace_seconds=0,
|
||||
runner_factory=lambda _record: ControlledRunner(started=started, release=asyncio.Event()),
|
||||
)
|
||||
record = await scheduler.create_run(_request())
|
||||
await asyncio.wait_for(started.wait(), timeout=1)
|
||||
|
||||
await scheduler.shutdown()
|
||||
|
||||
assert scheduler.stopping is True
|
||||
assert scheduler.active_tasks == {}
|
||||
assert store.statuses[record.run_id] == "failed"
|
||||
assert store.errors[record.run_id] == "run cancelled during server shutdown"
|
||||
assert [event.type for event in store.events[record.run_id]] == ["run_failed"]
|
||||
|
||||
asyncio.run(scenario())
|
||||
|
||||
|
||||
def test_create_run_rejects_blank_prompt_before_persisting() -> None:
|
||||
async def scenario() -> None:
|
||||
store = FakeStore()
|
||||
scheduler = RunScheduler(store=store)
|
||||
|
||||
with pytest.raises(ValueError, match="compositor.user_prompts must not be empty"):
|
||||
await scheduler.create_run(_request(["", " "]))
|
||||
|
||||
assert store.records == {}
|
||||
|
||||
asyncio.run(scenario())
|
||||
|
||||
|
||||
def test_create_run_rejects_after_shutdown_starts() -> None:
|
||||
async def scenario() -> None:
|
||||
scheduler = RunScheduler(store=FakeStore())
|
||||
await scheduler.shutdown()
|
||||
|
||||
with pytest.raises(SchedulerStoppingError):
|
||||
await scheduler.create_run(_request())
|
||||
|
||||
asyncio.run(scenario())
|
||||
|
||||
|
||||
def test_shutdown_waits_for_in_flight_create_to_register_before_cancelling() -> None:
|
||||
async def scenario() -> None:
|
||||
create_started = asyncio.Event()
|
||||
release_create = asyncio.Event()
|
||||
runner_started = asyncio.Event()
|
||||
store = SlowCreateStore(create_started=create_started, release_create=release_create)
|
||||
scheduler = RunScheduler(
|
||||
store=store,
|
||||
shutdown_grace_seconds=0,
|
||||
runner_factory=lambda _record: ControlledRunner(started=runner_started, release=asyncio.Event()),
|
||||
)
|
||||
|
||||
create_task = asyncio.create_task(scheduler.create_run(_request()))
|
||||
await asyncio.wait_for(create_started.wait(), timeout=1)
|
||||
shutdown_task = asyncio.create_task(scheduler.shutdown())
|
||||
await asyncio.sleep(0)
|
||||
|
||||
assert shutdown_task.done() is False
|
||||
assert scheduler.stopping is False
|
||||
|
||||
_ = release_create.set()
|
||||
record = await asyncio.wait_for(create_task, timeout=1)
|
||||
await asyncio.wait_for(shutdown_task, timeout=1)
|
||||
|
||||
assert scheduler.stopping is True
|
||||
assert scheduler.active_tasks == {}
|
||||
assert store.statuses[record.run_id] == "failed"
|
||||
assert [event.type for event in store.events[record.run_id]] == ["run_failed"]
|
||||
|
||||
with pytest.raises(SchedulerStoppingError):
|
||||
await scheduler.create_run(_request())
|
||||
|
||||
asyncio.run(scenario())
|
||||
@ -1,6 +1,3 @@
|
||||
import asyncio
|
||||
from typing import ClassVar
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
@ -19,74 +16,43 @@ class FakeRedis:
|
||||
self.closed = True
|
||||
|
||||
|
||||
class FakeRunJobWorker:
|
||||
created: ClassVar[list["FakeRunJobWorker"]] = []
|
||||
class FakeRunScheduler:
|
||||
created: list["FakeRunScheduler"] = []
|
||||
|
||||
group_name: str
|
||||
consumer_name: str
|
||||
pending_idle_ms: int
|
||||
started: bool
|
||||
cancelled: bool
|
||||
shutdown_grace_seconds: float
|
||||
shutdown_called: bool
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
store: object,
|
||||
group_name: str,
|
||||
consumer_name: str,
|
||||
pending_idle_ms: int,
|
||||
shutdown_grace_seconds: float,
|
||||
) -> None:
|
||||
del store
|
||||
self.group_name = group_name
|
||||
self.consumer_name = consumer_name
|
||||
self.pending_idle_ms = pending_idle_ms
|
||||
self.started = False
|
||||
self.cancelled = False
|
||||
self.shutdown_grace_seconds = shutdown_grace_seconds
|
||||
self.shutdown_called = False
|
||||
self.created.append(self)
|
||||
|
||||
async def run_forever(self) -> None:
|
||||
self.started = True
|
||||
try:
|
||||
await asyncio.get_running_loop().create_future()
|
||||
except asyncio.CancelledError:
|
||||
self.cancelled = True
|
||||
raise
|
||||
async def shutdown(self) -> None:
|
||||
self.shutdown_called = True
|
||||
|
||||
|
||||
def test_create_app_starts_and_cancels_embedded_worker(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_create_app_creates_scheduler_and_closes_after_shutdown(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
fake_redis = FakeRedis()
|
||||
FakeRunJobWorker.created.clear()
|
||||
FakeRunScheduler.created.clear()
|
||||
monkeypatch.setattr(app_module.Redis, "from_url", lambda _url: fake_redis)
|
||||
monkeypatch.setattr(app_module, "RunJobWorker", FakeRunJobWorker)
|
||||
monkeypatch.setattr(app_module, "RunScheduler", FakeRunScheduler)
|
||||
|
||||
settings = ServerSettings(
|
||||
redis_url="redis://example.invalid/0",
|
||||
redis_prefix="test",
|
||||
worker_enabled=True,
|
||||
worker_group_name="workers",
|
||||
worker_consumer_name="consumer-a",
|
||||
worker_pending_idle_ms=5,
|
||||
shutdown_grace_seconds=5,
|
||||
)
|
||||
|
||||
with TestClient(create_app(settings)):
|
||||
assert len(FakeRunJobWorker.created) == 1
|
||||
worker = FakeRunJobWorker.created[0]
|
||||
assert worker.started is True
|
||||
assert worker.group_name == "workers"
|
||||
assert worker.consumer_name == "consumer-a"
|
||||
assert worker.pending_idle_ms == 5
|
||||
|
||||
assert FakeRunJobWorker.created[0].cancelled is True
|
||||
assert fake_redis.closed is True
|
||||
|
||||
|
||||
def test_create_app_can_disable_embedded_worker(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
fake_redis = FakeRedis()
|
||||
FakeRunJobWorker.created.clear()
|
||||
monkeypatch.setattr(app_module.Redis, "from_url", lambda _url: fake_redis)
|
||||
monkeypatch.setattr(app_module, "RunJobWorker", FakeRunJobWorker)
|
||||
|
||||
with TestClient(create_app(ServerSettings(worker_enabled=False))):
|
||||
assert FakeRunJobWorker.created == []
|
||||
assert len(FakeRunScheduler.created) == 1
|
||||
scheduler = FakeRunScheduler.created[0]
|
||||
assert scheduler.shutdown_grace_seconds == 5
|
||||
|
||||
assert FakeRunScheduler.created[0].shutdown_called is True
|
||||
assert fake_redis.closed is True
|
||||
|
||||
@ -1,18 +1,26 @@
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from dify_agent.runtime.run_scheduler import SchedulerStoppingError
|
||||
from dify_agent.server.routes.runs import create_runs_router
|
||||
from dify_agent.server.schemas import RunRecord
|
||||
|
||||
|
||||
class FakeScheduler:
|
||||
async def create_run(self, request: object) -> object:
|
||||
raise AssertionError("blank prompt requests must be rejected before scheduling")
|
||||
|
||||
|
||||
class FakeStore:
|
||||
async def create_run(self, request: object) -> object:
|
||||
raise AssertionError("blank prompt requests must be rejected before enqueue")
|
||||
pass
|
||||
|
||||
|
||||
def test_create_run_rejects_effectively_blank_user_prompt_list() -> None:
|
||||
from fastapi import FastAPI
|
||||
|
||||
app = FastAPI()
|
||||
app.include_router(create_runs_router(lambda: FakeStore())) # pyright: ignore[reportArgumentType]
|
||||
app.include_router(
|
||||
create_runs_router(lambda: FakeStore(), lambda: FakeScheduler()) # pyright: ignore[reportArgumentType]
|
||||
)
|
||||
client = TestClient(app)
|
||||
|
||||
response = client.post(
|
||||
@ -27,3 +35,97 @@ def test_create_run_rejects_effectively_blank_user_prompt_list() -> None:
|
||||
|
||||
assert response.status_code == 422
|
||||
assert response.json()["detail"] == "compositor.user_prompts must not be empty"
|
||||
|
||||
|
||||
def test_create_run_returns_running_from_scheduler() -> None:
|
||||
from fastapi import FastAPI
|
||||
|
||||
class CapturingScheduler:
|
||||
async def create_run(self, request: object) -> RunRecord:
|
||||
del request
|
||||
return RunRecord(run_id="run-1", status="running", request=_request())
|
||||
|
||||
app = FastAPI()
|
||||
app.include_router(
|
||||
create_runs_router(lambda: FakeStore(), lambda: CapturingScheduler()) # pyright: ignore[reportArgumentType]
|
||||
)
|
||||
client = TestClient(app)
|
||||
|
||||
response = client.post(
|
||||
"/runs",
|
||||
json={
|
||||
"compositor": {
|
||||
"schema_version": 1,
|
||||
"layers": [{"name": "prompt", "type": "plain.prompt", "config": {"user": "hello"}}],
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 202
|
||||
assert response.json() == {"run_id": "run-1", "status": "running"}
|
||||
|
||||
|
||||
def test_create_run_returns_503_when_scheduler_is_stopping() -> None:
|
||||
from fastapi import FastAPI
|
||||
|
||||
class StoppingScheduler:
|
||||
async def create_run(self, request: object) -> RunRecord:
|
||||
del request
|
||||
raise SchedulerStoppingError("run scheduler is shutting down")
|
||||
|
||||
app = FastAPI()
|
||||
app.include_router(
|
||||
create_runs_router(lambda: FakeStore(), lambda: StoppingScheduler()) # pyright: ignore[reportArgumentType]
|
||||
)
|
||||
client = TestClient(app)
|
||||
|
||||
response = client.post(
|
||||
"/runs",
|
||||
json={
|
||||
"compositor": {
|
||||
"schema_version": 1,
|
||||
"layers": [{"name": "prompt", "type": "plain.prompt", "config": {"user": "hello"}}],
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 503
|
||||
assert response.json()["detail"] == "run scheduler is shutting down"
|
||||
|
||||
|
||||
def test_create_run_does_not_map_infrastructure_failure_to_422() -> None:
|
||||
from fastapi import FastAPI
|
||||
|
||||
class FailingScheduler:
|
||||
async def create_run(self, request: object) -> RunRecord:
|
||||
del request
|
||||
raise RuntimeError("redis unavailable")
|
||||
|
||||
app = FastAPI()
|
||||
app.include_router(
|
||||
create_runs_router(lambda: FakeStore(), lambda: FailingScheduler()) # pyright: ignore[reportArgumentType]
|
||||
)
|
||||
client = TestClient(app, raise_server_exceptions=False)
|
||||
|
||||
response = client.post(
|
||||
"/runs",
|
||||
json={
|
||||
"compositor": {
|
||||
"schema_version": 1,
|
||||
"layers": [{"name": "prompt", "type": "plain.prompt", "config": {"user": "hello"}}],
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 500
|
||||
|
||||
|
||||
def _request():
|
||||
from agenton.compositor import CompositorConfig, LayerNodeConfig
|
||||
from dify_agent.server.schemas import CreateRunRequest
|
||||
|
||||
return CreateRunRequest(
|
||||
compositor=CompositorConfig(
|
||||
layers=[LayerNodeConfig(name="prompt", type="plain.prompt", config={"user": "hello"})]
|
||||
)
|
||||
)
|
||||
|
||||
@ -1,8 +1,6 @@
|
||||
import asyncio
|
||||
from collections.abc import Mapping
|
||||
|
||||
import pytest
|
||||
|
||||
from agenton.compositor import CompositorConfig, LayerNodeConfig
|
||||
from dify_agent.server.schemas import CreateRunRequest
|
||||
from dify_agent.storage.redis_run_store import RedisRunStore
|
||||
@ -16,74 +14,26 @@ def _request() -> CreateRunRequest:
|
||||
)
|
||||
|
||||
|
||||
class FakePipeline:
|
||||
staged: list[tuple[str, str, object]]
|
||||
executed: bool
|
||||
fail_execute: bool
|
||||
|
||||
def __init__(self, *, fail_execute: bool = False) -> None:
|
||||
self.staged = []
|
||||
self.executed = False
|
||||
self.fail_execute = fail_execute
|
||||
|
||||
async def __aenter__(self) -> "FakePipeline":
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type: object, exc: object, traceback: object) -> None:
|
||||
return None
|
||||
|
||||
def set(self, key: str, value: object) -> None:
|
||||
self.staged.append(("set", key, value))
|
||||
|
||||
def xadd(self, key: str, fields: Mapping[str, object]) -> None:
|
||||
self.staged.append(("xadd", key, dict(fields)))
|
||||
|
||||
async def execute(self) -> None:
|
||||
if self.fail_execute:
|
||||
raise RuntimeError("transaction failed")
|
||||
self.executed = True
|
||||
|
||||
|
||||
class FakeRedis:
|
||||
pipeline_instance: FakePipeline
|
||||
direct_commands: list[str]
|
||||
commands: list[tuple[str, str, object]]
|
||||
|
||||
def __init__(self, pipeline: FakePipeline) -> None:
|
||||
self.pipeline_instance = pipeline
|
||||
self.direct_commands = []
|
||||
|
||||
def pipeline(self, *, transaction: bool) -> FakePipeline:
|
||||
assert transaction is True
|
||||
return self.pipeline_instance
|
||||
def __init__(self) -> None:
|
||||
self.commands = []
|
||||
|
||||
async def set(self, key: str, value: object) -> None:
|
||||
self.direct_commands.append(f"set:{key}")
|
||||
self.commands.append(("set", key, value))
|
||||
|
||||
async def xadd(self, key: str, fields: Mapping[str, object]) -> str:
|
||||
self.direct_commands.append(f"xadd:{key}")
|
||||
self.commands.append(("xadd", key, dict(fields)))
|
||||
return "1-0"
|
||||
|
||||
|
||||
def test_create_run_writes_record_and_job_in_one_transaction() -> None:
|
||||
pipeline = FakePipeline()
|
||||
redis = FakeRedis(pipeline)
|
||||
def test_create_run_writes_running_record_without_job_queue() -> None:
|
||||
redis = FakeRedis()
|
||||
store = RedisRunStore(redis, prefix="test") # pyright: ignore[reportArgumentType]
|
||||
|
||||
record = asyncio.run(store.create_run(_request()))
|
||||
|
||||
assert record.status == "queued"
|
||||
assert pipeline.executed is True
|
||||
assert [command[0] for command in pipeline.staged] == ["set", "xadd"]
|
||||
assert redis.direct_commands == []
|
||||
|
||||
|
||||
def test_create_run_does_not_fall_back_to_partial_writes_when_transaction_fails() -> None:
|
||||
pipeline = FakePipeline(fail_execute=True)
|
||||
redis = FakeRedis(pipeline)
|
||||
store = RedisRunStore(redis, prefix="test") # pyright: ignore[reportArgumentType]
|
||||
|
||||
with pytest.raises(RuntimeError, match="transaction failed"):
|
||||
asyncio.run(store.create_run(_request()))
|
||||
|
||||
assert pipeline.executed is False
|
||||
assert redis.direct_commands == []
|
||||
assert record.status == "running"
|
||||
assert [command[0] for command in redis.commands] == ["set"]
|
||||
assert redis.commands[0][1] == f"test:runs:{record.run_id}:record"
|
||||
|
||||
@ -1,90 +0,0 @@
|
||||
import asyncio
|
||||
from collections.abc import Mapping
|
||||
from typing import cast
|
||||
|
||||
from agenton.compositor import CompositorConfig, LayerNodeConfig
|
||||
from dify_agent.server.schemas import CreateRunRequest, RunnerJob
|
||||
from dify_agent.storage.redis_run_store import RedisRunStore
|
||||
from dify_agent.worker.job_worker import JobRunner, RunJobWorker
|
||||
|
||||
|
||||
def _job() -> RunnerJob:
|
||||
request = CreateRunRequest(
|
||||
compositor=CompositorConfig(
|
||||
layers=[LayerNodeConfig(name="prompt", type="plain.prompt", config={"user": "hello"})]
|
||||
)
|
||||
)
|
||||
return RunnerJob(run_id="run-1", request=request)
|
||||
|
||||
|
||||
class FakeRunner:
|
||||
ran: bool
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.ran = False
|
||||
|
||||
async def run(self) -> None:
|
||||
self.ran = True
|
||||
|
||||
|
||||
class FakeRedis:
|
||||
xreadgroup_called: bool
|
||||
acked: list[tuple[str, str, str | bytes]]
|
||||
claimed_payload: str
|
||||
|
||||
def __init__(self, claimed_payload: str) -> None:
|
||||
self.xreadgroup_called = False
|
||||
self.acked = []
|
||||
self.claimed_payload = claimed_payload
|
||||
|
||||
async def xautoclaim(
|
||||
self,
|
||||
name: str,
|
||||
groupname: str,
|
||||
consumername: str,
|
||||
min_idle_time: int,
|
||||
start_id: str,
|
||||
count: int,
|
||||
) -> tuple[str, list[tuple[bytes, dict[bytes, bytes]]], list[bytes]]:
|
||||
assert name == "test:runs:jobs"
|
||||
assert groupname == "workers"
|
||||
assert consumername == "worker-b"
|
||||
assert min_idle_time == 10
|
||||
assert start_id == "0-0"
|
||||
assert count == 1
|
||||
return "0-0", [(b"1-0", {b"payload": self.claimed_payload.encode()})], []
|
||||
|
||||
async def xreadgroup(
|
||||
self,
|
||||
groupname: str,
|
||||
consumername: str,
|
||||
streams: Mapping[str, str],
|
||||
count: int,
|
||||
block: int,
|
||||
) -> list[tuple[str, list[tuple[bytes, dict[bytes, bytes]]]]]:
|
||||
self.xreadgroup_called = True
|
||||
return []
|
||||
|
||||
async def xack(self, name: str, groupname: str, entry_id: str | bytes) -> None:
|
||||
self.acked.append((name, groupname, entry_id))
|
||||
|
||||
|
||||
def test_process_once_reclaims_stale_pending_job_before_reading_new_entries() -> None:
|
||||
job = _job()
|
||||
runner = FakeRunner()
|
||||
redis = FakeRedis(job.model_dump_json())
|
||||
store = RedisRunStore(cast(object, redis), prefix="test") # pyright: ignore[reportArgumentType]
|
||||
worker = RunJobWorker(
|
||||
store=store,
|
||||
group_name="workers",
|
||||
consumer_name="worker-b",
|
||||
pending_idle_ms=10,
|
||||
runner_factory=lambda _job: cast(JobRunner, runner),
|
||||
)
|
||||
|
||||
processed = asyncio.run(worker.process_once(block_ms=0))
|
||||
|
||||
assert processed is True
|
||||
assert runner.ran is True
|
||||
assert redis.xreadgroup_called is False
|
||||
assert redis.acked == [("test:runs:jobs", "workers", b"1-0")]
|
||||