dify/dify-agent/tests/local/dify_agent/runtime/test_run_scheduler.py
盐粒 Yanli 0f06aa2fdd
feat(dify-agent): sync agent progress (#36633)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-05-26 03:14:10 +00:00

312 lines
11 KiB
Python

import asyncio
from collections import defaultdict
from collections.abc import Mapping
import httpx
import pytest
from agenton.compositor import CompositorSessionSnapshot, LayerSessionSnapshot
from agenton.layers import LifecycleState
from agenton_collections.layers.plain import PromptLayerConfig
from dify_agent.layers.output import DIFY_OUTPUT_LAYER_TYPE_ID, DifyOutputLayerConfig
from dify_agent.protocol import DIFY_AGENT_OUTPUT_LAYER_ID
from dify_agent.protocol.schemas import (
CreateRunRequest,
RunComposition,
RunEvent,
RunLayerSpec,
RunStatus,
)
from dify_agent.runtime.run_scheduler import RunScheduler, SchedulerStoppingError
from dify_agent.server.schemas import RunRecord
def _request(
    user: str | list[str] = "hello",
    *,
    output_config: Mapping[str, object] | DifyOutputLayerConfig | None = None,
) -> CreateRunRequest:
    """Build a ``CreateRunRequest`` whose composition starts with a ``plain.prompt`` layer.

    Args:
        user: User prompt(s) handed to ``PromptLayerConfig(user=...)``.
        output_config: When not ``None``, a Dify output layer (named
            ``DIFY_AGENT_OUTPUT_LAYER_ID``) with this config is appended after the
            prompt layer.

    Returns:
        The request containing one or two layers, in prompt-then-output order.
    """
    prompt_layer = RunLayerSpec(name="prompt", type="plain.prompt", config=PromptLayerConfig(user=user))
    if output_config is None:
        return CreateRunRequest(composition=RunComposition(layers=[prompt_layer]))

    output_layer = RunLayerSpec(
        name=DIFY_AGENT_OUTPUT_LAYER_ID,
        type=DIFY_OUTPUT_LAYER_TYPE_ID,
        config=output_config,
    )
    return CreateRunRequest(composition=RunComposition(layers=[prompt_layer, output_layer]))
def _recursive_output_schema() -> dict[str, object]:
return {
"type": "object",
"properties": {"node": {"$ref": "#/$defs/node"}},
"$defs": {
"node": {
"type": "object",
"properties": {"child": {"$ref": "#/$defs/node"}},
"additionalProperties": False,
}
},
"additionalProperties": False,
}
class FakeStore:
    """In-memory stand-in for the run store used by ``RunScheduler``.

    Created runs, appended events, statuses, and error messages are kept in plain
    dictionaries keyed by run id, so tests can assert on them directly.
    """

    records: dict[str, RunRecord]
    events: dict[str, list[RunEvent]]
    statuses: dict[str, RunStatus]
    errors: dict[str, str | None]

    def __init__(self) -> None:
        self.records = {}
        # defaultdict so looking up a run with no events yields an empty list.
        self.events = defaultdict(list)
        self.statuses = {}
        self.errors = {}

    async def create_run(self) -> RunRecord:
        """Persist a new ``running`` run with a sequential id (``run-1``, ``run-2``, ...)."""
        new_id = f"run-{len(self.records) + 1}"
        created = RunRecord(run_id=new_id, status="running")
        self.records[new_id] = created
        self.statuses[new_id] = "running"
        return created

    async def append_event(self, event: RunEvent) -> str:
        """Store a copy of ``event`` stamped with a per-run, 1-based string id.

        Returns:
            The id assigned to the stored copy.
        """
        run_events = self.events[event.run_id]
        assigned_id = str(len(run_events) + 1)
        run_events.append(event.model_copy(update={"id": assigned_id}))
        return assigned_id

    async def update_status(self, run_id: str, status: RunStatus, error: str | None = None) -> None:
        """Overwrite the status and error for ``run_id``. Omitting ``error`` stores ``None``."""
        self.statuses[run_id] = status
        self.errors[run_id] = error
class SlowCreateStore(FakeStore):
    """``FakeStore`` whose ``create_run`` parks until the test releases it.

    This lets a test hold a run creation "in flight" and observe what other
    scheduler operations do in the meantime.
    """

    create_started: asyncio.Event
    release_create: asyncio.Event

    def __init__(self, *, create_started: asyncio.Event, release_create: asyncio.Event) -> None:
        super().__init__()
        self.create_started, self.release_create = create_started, release_create

    async def create_run(self) -> RunRecord:
        """Signal ``create_started``, wait for ``release_create``, then create the run normally."""
        _ = self.create_started.set()
        await self.release_create.wait()
        return await super().create_run()
class ControlledRunner:
    """Runner stub that signals it has started and then blocks until released.

    Tests hand it out through ``runner_factory`` so they control exactly when
    a run finishes.
    """

    started: asyncio.Event
    release: asyncio.Event

    def __init__(self, *, started: asyncio.Event, release: asyncio.Event) -> None:
        self.started, self.release = started, release

    async def run(self) -> None:
        """Set ``started``, then wait until ``release`` is set."""
        _ = self.started.set()
        await self.release.wait()
def test_create_run_starts_background_task_and_returns_running() -> None:
    """``create_run`` returns a running record and tracks its task until the runner finishes."""

    async def scenario() -> None:
        fake_store = FakeStore()
        runner_started = asyncio.Event()
        runner_release = asyncio.Event()

        async with httpx.AsyncClient() as http_client:
            scheduler = RunScheduler(
                store=fake_store,
                plugin_daemon_http_client=http_client,
                runner_factory=lambda _record, _request: ControlledRunner(
                    started=runner_started, release=runner_release
                ),
            )
            run_record = await scheduler.create_run(_request())
            await asyncio.wait_for(runner_started.wait(), timeout=1)

            assert run_record.status == "running"
            assert list(scheduler.active_tasks) == [run_record.run_id]

            _ = runner_release.set()
            run_task = scheduler.active_tasks[run_record.run_id]
            await asyncio.wait_for(run_task, timeout=1)
            # Yield once more so the scheduler can drop the finished task before the emptiness check.
            await asyncio.sleep(0)
            assert scheduler.active_tasks == {}

    asyncio.run(scenario())
def test_shutdown_marks_unfinished_runs_failed_and_appends_event() -> None:
    """Shutdown with no grace period fails a still-running run and records ``run_failed``."""

    async def scenario() -> None:
        fake_store = FakeStore()
        runner_started = asyncio.Event()

        async with httpx.AsyncClient() as http_client:
            scheduler = RunScheduler(
                store=fake_store,
                plugin_daemon_http_client=http_client,
                shutdown_grace_seconds=0,
                # The release event is never set, so the runner stays blocked until shutdown cancels it.
                runner_factory=lambda _record, _request: ControlledRunner(
                    started=runner_started, release=asyncio.Event()
                ),
            )
            run_record = await scheduler.create_run(_request())
            await asyncio.wait_for(runner_started.wait(), timeout=1)

            await scheduler.shutdown()

            run_id = run_record.run_id
            assert scheduler.stopping is True
            assert scheduler.active_tasks == {}
            assert fake_store.statuses[run_id] == "failed"
            assert fake_store.errors[run_id] == "run cancelled during server shutdown"
            assert [evt.type for evt in fake_store.events[run_id]] == ["run_failed"]

    asyncio.run(scenario())
def test_create_run_accepts_blank_prompt_and_runner_fails_asynchronously() -> None:
    """Blank user prompts are accepted and persisted. The failure is reported by the run itself."""

    async def scenario() -> None:
        fake_store = FakeStore()

        async with httpx.AsyncClient() as http_client:
            scheduler = RunScheduler(store=fake_store, plugin_daemon_http_client=http_client)
            run_record = await scheduler.create_run(_request(["", " "]))
            await asyncio.wait_for(scheduler.active_tasks[run_record.run_id], timeout=1)

            run_id = run_record.run_id
            assert fake_store.records == {run_id: run_record}
            assert [evt.type for evt in fake_store.events[run_id]] == ["run_started", "run_failed"]
            assert fake_store.statuses[run_id] == "failed"
            assert fake_store.errors[run_id] == "run.user_prompts must not be empty"

    asyncio.run(scenario())
def test_create_run_accepts_invalid_output_schema_and_runner_fails_asynchronously() -> None:
    """A recursive output JSON schema is accepted at creation and rejected when the run executes."""

    async def scenario() -> None:
        fake_store = FakeStore()

        async with httpx.AsyncClient() as http_client:
            scheduler = RunScheduler(store=fake_store, plugin_daemon_http_client=http_client)
            recursive_request = _request(
                output_config={
                    "json_schema": _recursive_output_schema(),
                }
            )
            run_record = await scheduler.create_run(recursive_request)
            await asyncio.wait_for(scheduler.active_tasks[run_record.run_id], timeout=1)

            run_id = run_record.run_id
            assert fake_store.records == {run_id: run_record}
            assert [evt.type for evt in fake_store.events[run_id]] == ["run_started", "run_failed"]
            assert fake_store.statuses[run_id] == "failed"
            assert "Recursive $defs refs are not supported" in (fake_store.errors[run_id] or "")

    asyncio.run(scenario())
def test_create_run_honors_explicit_empty_layer_providers_by_failing_after_persisting() -> None:
    """An explicit ``layer_providers=()`` is used as given.

    The run is still persisted, then fails with an error that mentions ``plain.prompt``.
    """

    async def scenario() -> None:
        fake_store = FakeStore()

        async with httpx.AsyncClient() as http_client:
            scheduler = RunScheduler(
                store=fake_store,
                plugin_daemon_http_client=http_client,
                layer_providers=(),
            )
            run_record = await scheduler.create_run(_request())
            await asyncio.wait_for(scheduler.active_tasks[run_record.run_id], timeout=1)

            run_id = run_record.run_id
            assert fake_store.records == {run_id: run_record}
            assert [evt.type for evt in fake_store.events[run_id]] == ["run_started", "run_failed"]
            assert fake_store.statuses[run_id] == "failed"
            assert "plain.prompt" in (fake_store.errors[run_id] or "")

    asyncio.run(scenario())
def test_create_run_accepts_closed_session_snapshot_and_runner_fails_asynchronously() -> None:
    """A snapshot with a CLOSED layer is accepted at creation and rejected when the run executes."""

    async def scenario() -> None:
        fake_store = FakeStore()

        async with httpx.AsyncClient() as http_client:
            scheduler = RunScheduler(store=fake_store, plugin_daemon_http_client=http_client)

            closed_prompt = LayerSessionSnapshot(
                name="prompt",
                lifecycle_state=LifecycleState.CLOSED,
                runtime_state={},
            )
            snapshot_request = _request()
            snapshot_request.session_snapshot = CompositorSessionSnapshot(layers=[closed_prompt])

            run_record = await scheduler.create_run(snapshot_request)
            await asyncio.wait_for(scheduler.active_tasks[run_record.run_id], timeout=1)

            run_id = run_record.run_id
            assert fake_store.records == {run_id: run_record}
            assert [evt.type for evt in fake_store.events[run_id]] == ["run_started", "run_failed"]
            assert fake_store.statuses[run_id] == "failed"
            assert "CLOSED snapshots cannot be entered" in (fake_store.errors[run_id] or "")

    asyncio.run(scenario())
def test_create_run_rejects_after_shutdown_starts() -> None:
    """After ``shutdown()`` completes, ``create_run`` raises ``SchedulerStoppingError``."""

    async def scenario() -> None:
        async with httpx.AsyncClient() as http_client:
            scheduler = RunScheduler(store=FakeStore(), plugin_daemon_http_client=http_client)
            await scheduler.shutdown()

            with pytest.raises(SchedulerStoppingError):
                await scheduler.create_run(_request())

    asyncio.run(scenario())
def test_create_run_rejects_invalid_request_after_shutdown_without_persisting() -> None:
    """Once stopping, even a blank-prompt request is rejected and nothing reaches the store."""

    async def scenario() -> None:
        fake_store = FakeStore()

        async with httpx.AsyncClient() as http_client:
            scheduler = RunScheduler(store=fake_store, plugin_daemon_http_client=http_client)
            await scheduler.shutdown()

            with pytest.raises(SchedulerStoppingError):
                _ = await scheduler.create_run(_request(["", " "]))

            assert fake_store.records == {}

    asyncio.run(scenario())
def test_shutdown_waits_for_in_flight_create_to_register_before_cancelling() -> None:
    """Shutdown lets an in-flight ``create_run`` register its run, then cancels that run."""

    async def scenario() -> None:
        create_started = asyncio.Event()
        release_create = asyncio.Event()
        runner_started = asyncio.Event()
        slow_store = SlowCreateStore(create_started=create_started, release_create=release_create)

        async with httpx.AsyncClient() as http_client:
            scheduler = RunScheduler(
                store=slow_store,
                plugin_daemon_http_client=http_client,
                shutdown_grace_seconds=0,
                runner_factory=lambda _record, _request: ControlledRunner(
                    started=runner_started, release=asyncio.Event()
                ),
            )

            # Park create_run inside the store so the creation is "in flight".
            pending_create = asyncio.create_task(scheduler.create_run(_request()))
            await asyncio.wait_for(create_started.wait(), timeout=1)

            # Start shutdown while the create is blocked. It must neither finish nor flip `stopping` yet.
            pending_shutdown = asyncio.create_task(scheduler.shutdown())
            await asyncio.sleep(0)
            assert pending_shutdown.done() is False
            assert scheduler.stopping is False

            # Let the create complete. Shutdown then fails the freshly registered run.
            _ = release_create.set()
            run_record = await asyncio.wait_for(pending_create, timeout=1)
            await asyncio.wait_for(pending_shutdown, timeout=1)

            run_id = run_record.run_id
            assert scheduler.stopping is True
            assert scheduler.active_tasks == {}
            assert slow_store.statuses[run_id] == "failed"
            assert [evt.type for evt in slow_store.events[run_id]] == ["run_failed"]

            with pytest.raises(SchedulerStoppingError):
                await scheduler.create_run(_request())

    asyncio.run(scenario())