From 15f5c7064e8c44359c7e05efe8ff3ae836d4af4b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=9B=90=E7=B2=92=20Yanli?= Date: Tue, 12 May 2026 00:08:55 +0800 Subject: [PATCH] add dify-agent python client --- dify-agent/docs/dify-agent/api/index.md | 51 ++ dify-agent/docs/dify-agent/examples/index.md | 5 + .../dify_agent_examples/__main__.py | 1 + .../run_server_consumer.py | 27 +- .../run_server_sse_consumer.py | 17 +- .../run_server_sync_client.py | 40 ++ dify-agent/src/dify_agent/client/__init__.py | 21 + dify-agent/src/dify_agent/client/_client.py | 667 ++++++++++++++++++ .../src/dify_agent/protocol/__init__.py | 47 ++ dify-agent/src/dify_agent/protocol/schemas.py | 201 ++++++ .../src/dify_agent/runtime/agent_factory.py | 2 +- .../src/dify_agent/runtime/event_sink.py | 2 +- .../src/dify_agent/runtime/run_scheduler.py | 3 +- dify-agent/src/dify_agent/runtime/runner.py | 2 +- .../src/dify_agent/server/routes/runs.py | 2 +- dify-agent/src/dify_agent/server/schemas.py | 221 +----- dify-agent/src/dify_agent/server/sse.py | 2 +- .../src/dify_agent/storage/redis_run_store.py | 7 +- .../local/dify_agent/client/test_client.py | 381 ++++++++++ .../protocol/test_protocol_schemas.py | 40 ++ .../dify_agent/runtime/test_run_scheduler.py | 3 +- .../local/dify_agent/runtime/test_runner.py | 2 +- .../dify_agent/server/test_runs_routes.py | 2 +- .../local/dify_agent/server/test_schemas.py | 44 +- .../tests/local/dify_agent/server/test_sse.py | 2 +- .../storage/test_redis_run_store.py | 2 +- .../examples/test_dify_agent_examples.py | 1 + 27 files changed, 1520 insertions(+), 275 deletions(-) create mode 100644 dify-agent/examples/dify_agent/dify_agent_examples/run_server_sync_client.py create mode 100644 dify-agent/src/dify_agent/client/__init__.py create mode 100644 dify-agent/src/dify_agent/client/_client.py create mode 100644 dify-agent/src/dify_agent/protocol/__init__.py create mode 100644 dify-agent/src/dify_agent/protocol/schemas.py create mode 100644 dify-agent/tests/local/dify_agent/client/test_client.py create mode 100644 dify-agent/tests/local/dify_agent/protocol/test_protocol_schemas.py diff --git a/dify-agent/docs/dify-agent/api/index.md b/dify-agent/docs/dify-agent/api/index.md index 14c764df4b..c61e94bed8 100644 --- a/dify-agent/docs/dify-agent/api/index.md +++ b/dify-agent/docs/dify-agent/api/index.md @@ -5,6 +5,10 @@ configuration, Pydantic AI runtime execution, Redis run records, and per-run Red Streams event logs. The FastAPI application lives at `dify-agent/src/dify_agent/server/app.py`. +Public Python DTOs and event models are exported from +`dify_agent.protocol.schemas`. `dify_agent.server.schemas` is intentionally +server-only and should not be used by API consumers. + ## Input model Create-run requests accept a `CompositorConfig` and an optional @@ -154,6 +158,52 @@ Replay can start from a cursor with either: If both are provided, the `after` query parameter takes precedence. +## Python client + +Use `dify_agent.client.Client` for both async and sync code. Async methods use +normal names; sync methods add `_sync`. + +```python {test="skip" lint="skip"} +from dify_agent.client import Client + + +async def main() -> None: + async with Client(base_url="http://localhost:8000") as client: + run = await client.create_run( + { + "compositor": { + "schema_version": 1, + "layers": [{"name": "prompt", "type": "plain.prompt", "config": {"user": "hello"}}], + } + } + ) + async for event in client.stream_events(run.run_id): + print(event) +``` + +```python {test="skip" lint="skip"} +from dify_agent.client import Client + + +with Client(base_url="http://localhost:8000") as client: + run = client.create_run_sync( + { + "compositor": { + "schema_version": 1, + "layers": [{"name": "prompt", "type": "plain.prompt", "config": {"user": "hello"}}], + } + } + ) + terminal = client.wait_run_sync(run.run_id) +``` + +`stream_events` and `stream_events_sync` parse SSE without an extra dependency. +They reconnect by default from the latest yielded event id and stop after +`run_succeeded` or `run_failed`. They do not reconnect for HTTP 4xx responses, +DTO validation failures, or malformed SSE frames. `create_run` and +`create_run_sync` never retry `POST /runs`; if a timeout occurs, the caller must +decide whether to inspect existing runs or submit a new run. + ## Event types and order A normal successful run emits: @@ -184,3 +234,4 @@ See: - `dify-agent/examples/dify_agent/dify_agent_examples/run_server_consumer.py` for cursor polling - `dify-agent/examples/dify_agent/dify_agent_examples/run_server_sse_consumer.py` for SSE consumption +- `dify-agent/examples/dify_agent/dify_agent_examples/run_server_sync_client.py` for synchronous client usage diff --git a/dify-agent/docs/dify-agent/examples/index.md b/dify-agent/docs/dify-agent/examples/index.md index ec2175d639..8624dc790e 100644 --- a/dify-agent/docs/dify-agent/examples/index.md +++ b/dify-agent/docs/dify-agent/examples/index.md @@ -14,6 +14,11 @@ such as the FastAPI server, Redis, or the plugin daemon. ```snippet {path="/examples/dify_agent/dify_agent_examples/run_server_consumer.py"} ``` +## Use the synchronous client + +```snippet {path="/examples/dify_agent/dify_agent_examples/run_server_sync_client.py"} +``` + ## Stream run events with SSE ```snippet {path="/examples/dify_agent/dify_agent_examples/run_server_sse_consumer.py"} diff --git a/dify-agent/examples/dify_agent/dify_agent_examples/__main__.py b/dify-agent/examples/dify_agent/dify_agent_examples/__main__.py index 0eec2ec21f..204d2b6b8a 100644 --- a/dify-agent/examples/dify_agent/dify_agent_examples/__main__.py +++ b/dify-agent/examples/dify_agent/dify_agent_examples/__main__.py @@ -11,6 +11,7 @@ EXAMPLE_MODULES = ( "run_pydantic_ai_agent", "run_server_consumer", "run_server_sse_consumer", + "run_server_sync_client", ) diff --git a/dify-agent/examples/dify_agent/dify_agent_examples/run_server_consumer.py b/dify-agent/examples/dify_agent/dify_agent_examples/run_server_consumer.py index 6542463bd0..4955629646 100644 --- a/dify-agent/examples/dify_agent/dify_agent_examples/run_server_consumer.py +++ b/dify-agent/examples/dify_agent/dify_agent_examples/run_server_consumer.py @@ -1,4 +1,4 @@ -"""Example consumer for the Dify Agent run server. +"""Async Python client example for the Dify Agent run server. Requires Redis and a running API server. The server schedules runs in-process, for example: @@ -7,21 +7,22 @@ example: The default request uses the credential-free pydantic-ai TestModel profile. This script prints the created run and every event observed through cursor polling. +``Client.create_run`` performs one POST attempt only; use polling or SSE replay to +recover after client-side uncertainty. """ import asyncio -import httpx +from dify_agent.client import Client API_BASE_URL = "http://localhost:8000" async def main() -> None: - async with httpx.AsyncClient(base_url=API_BASE_URL, timeout=30) as client: - create_response = await client.post( - "/runs", - json={ + async with Client(base_url=API_BASE_URL) as client: + run = await client.create_run( + { "compositor": { "schema_version": 1, "layers": [ @@ -36,21 +37,17 @@ async def main() -> None: ], }, "agent_profile": {"provider": "test", "output_text": "Hello from the example TestModel."}, - }, + } ) - create_response.raise_for_status() - run = create_response.json() print("created run", run) cursor = "0-0" while True: - events_response = await client.get(f"/runs/{run['run_id']}/events", params={"after": cursor}) - events_response.raise_for_status() - page = events_response.json() - cursor = page["next_cursor"] or cursor - for event in page["events"]: + page = await client.get_events(run.run_id, after=cursor) + cursor = page.next_cursor or cursor + for event in page.events: print("event", event) - if event["type"] in {"run_succeeded", "run_failed"}: + if event.type in {"run_succeeded", "run_failed"}: return await asyncio.sleep(0.5) diff --git a/dify-agent/examples/dify_agent/dify_agent_examples/run_server_sse_consumer.py b/dify-agent/examples/dify_agent/dify_agent_examples/run_server_sse_consumer.py index 727667fa4b..8e0979973a 100644 --- a/dify-agent/examples/dify_agent/dify_agent_examples/run_server_sse_consumer.py +++ b/dify-agent/examples/dify_agent/dify_agent_examples/run_server_sse_consumer.py @@ -1,13 +1,14 @@ -"""SSE consumer sketch for the Dify Agent run server. +"""Async SSE client example for the Dify Agent run server. Create a run with ``run_server_consumer.py`` or any HTTP client, then set RUN_ID -below and run this script while the server is available. It prints raw SSE frames -without requiring model credentials. +below and run this script while the server is available. The Python client parses +SSE frames into typed protocol events and reconnects with the latest event id by +default. Malformed frames and HTTP 4xx responses fail without reconnecting. """ import asyncio -import httpx +from dify_agent.client import Client API_BASE_URL = "http://localhost:8000" @@ -15,11 +16,9 @@ RUN_ID = "replace-with-run-id" async def main() -> None: - async with httpx.AsyncClient(base_url=API_BASE_URL, timeout=None) as client: - async with client.stream("GET", f"/runs/{RUN_ID}/events/sse") as response: - response.raise_for_status() - async for line in response.aiter_lines(): - print(line) + async with Client(base_url=API_BASE_URL, stream_timeout=None) as client: + async for event in client.stream_events(RUN_ID): + print(event) if __name__ == "__main__": diff --git a/dify-agent/examples/dify_agent/dify_agent_examples/run_server_sync_client.py b/dify-agent/examples/dify_agent/dify_agent_examples/run_server_sync_client.py new file mode 100644 index 0000000000..43f63de943 --- /dev/null +++ b/dify-agent/examples/dify_agent/dify_agent_examples/run_server_sync_client.py @@ -0,0 +1,40 @@ +"""Synchronous Python client example for the Dify Agent run server. + +Requires the same running FastAPI server as the async examples. ``create_run_sync`` +does not retry ``POST /runs``; if a timeout occurs, inspect server state or create +a new run explicitly rather than assuming the original request was not accepted. +""" + +from dify_agent.client import Client + + +API_BASE_URL = "http://localhost:8000" + + +def main() -> None: + with Client(base_url=API_BASE_URL) as client: + run = client.create_run_sync( + { + "compositor": { + "schema_version": 1, + "layers": [ + { + "name": "prompt", + "type": "plain.prompt", + "config": { + "prefix": "You are a concise assistant.", + "user": "Say hello from the synchronous Dify Agent client example.", + }, + } + ], + }, + "agent_profile": {"provider": "test", "output_text": "Hello from the sync TestModel."}, + } + ) + print("created run", run) + terminal = client.wait_run_sync(run.run_id, poll_interval_seconds=0.5) + print("terminal status", terminal) + + +if __name__ == "__main__": + main() diff --git a/dify-agent/src/dify_agent/client/__init__.py b/dify-agent/src/dify_agent/client/__init__.py new file mode 100644 index 0000000000..ff0027b291 --- /dev/null +++ b/dify-agent/src/dify_agent/client/__init__.py @@ -0,0 +1,21 @@ +"""Unified sync and async Python client for the Dify Agent run API.""" + +from ._client import ( + Client, + DifyAgentClientError, + DifyAgentHTTPError, + DifyAgentNotFoundError, + DifyAgentStreamError, + DifyAgentTimeoutError, + DifyAgentValidationError, +) + +__all__ = [ + "Client", + "DifyAgentClientError", + "DifyAgentHTTPError", + "DifyAgentNotFoundError", + "DifyAgentStreamError", + "DifyAgentTimeoutError", + "DifyAgentValidationError", +] diff --git a/dify-agent/src/dify_agent/client/_client.py b/dify-agent/src/dify_agent/client/_client.py new file mode 100644 index 0000000000..a956f080c3 --- /dev/null +++ b/dify-agent/src/dify_agent/client/_client.py @@ -0,0 +1,667 @@ +"""HTTPX-based client for Dify Agent runs. + +The client uses the public DTOs from ``dify_agent.protocol.schemas`` for all +normal request and response parsing. It intentionally does not retry +``POST /runs`` because create-run is not idempotent. SSE streams are the only +operation with reconnect logic: transient stream/connect/read failures, stream +timeouts, and HTTP 5xx stream responses reconnect with the latest observed event +id, while HTTP 4xx responses, DTO validation failures, and malformed SSE frames +fail immediately. +""" + +from __future__ import annotations + +import asyncio +import time +from collections.abc import AsyncIterator, Iterator +from types import TracebackType +from typing import Self, TypeVar, cast +from urllib.parse import quote + +import httpx +from pydantic import BaseModel, ValidationError + +from dify_agent.protocol.schemas import ( + CreateRunRequest, + CreateRunResponse, + RUN_EVENT_ADAPTER, + RunEvent, + RunEventsResponse, + RunStatusResponse, +) + +_ResponseModelT = TypeVar("_ResponseModelT", bound=BaseModel) +_TERMINAL_EVENT_TYPES = {"run_succeeded", "run_failed"} +_TERMINAL_RUN_STATUSES = {"succeeded", "failed"} + + +class DifyAgentClientError(RuntimeError): + """Base class for errors raised by the Dify Agent Python client.""" + + +class DifyAgentHTTPError(DifyAgentClientError): + """Raised for HTTP 4xx/5xx responses not covered by a narrower subclass.""" + + status_code: int + detail: object + + def __init__(self, status_code: int, detail: object) -> None: + self.status_code = status_code + self.detail = detail + super().__init__(f"Dify Agent HTTP {status_code}: {detail}") + + +class DifyAgentNotFoundError(DifyAgentHTTPError): + """Raised when the server returns ``404`` for a run resource.""" + + +class DifyAgentValidationError(DifyAgentHTTPError): + """Raised for local input validation, invalid DTO responses, or HTTP ``422``.""" + + def __init__(self, detail: object, *, status_code: int = 422) -> None: + super().__init__(status_code=status_code, detail=detail) + + +class DifyAgentTimeoutError(DifyAgentClientError): + """Raised when an HTTPX timeout occurs outside successful SSE reconnects.""" + + +class DifyAgentStreamError(DifyAgentClientError): + """Raised for malformed SSE frames or exhausted SSE reconnect attempts.""" + + +class _ReconnectableStreamError(Exception): + """Internal wrapper for stream failures that may be retried by the caller.""" + + error: DifyAgentClientError + + def __init__(self, error: DifyAgentClientError) -> None: + self.error = error + super().__init__(str(error)) + + +class _SSEDecoder: + """Incrementally decode SSE lines into typed run events. + + The decoder keeps only the fields for the current frame. Comments are ignored, + ``data`` fields are joined with newlines as required by the SSE specification, + and payload JSON is validated by ``RUN_EVENT_ADAPTER``. The frame ``id`` is + copied into the decoded event only when the JSON payload omits ``event.id``. + """ + + _event_id: str | None + _event_type: str | None + _data_lines: list[str] + + def __init__(self) -> None: + self._event_id = None + self._event_type = None + self._data_lines = [] + + def feed_line(self, raw_line: str) -> RunEvent | None: + """Consume one SSE line and return an event when a frame completes. + + Empty lines dispatch the current frame. Comment-only frames and frames + without ``data`` are ignored so server heartbeats do not surface to users. + Malformed event payloads raise ``DifyAgentStreamError`` and must not be + retried because replaying would repeat the same invalid frame. + """ + line = raw_line.rstrip("\r") + if line == "": + return self._dispatch() + if line.startswith(":"): + return None + + field, separator, value = line.partition(":") + if separator and value.startswith(" "): + value = value[1:] + if field == "id": + self._event_id = value + elif field == "event": + self._event_type = value + elif field == "data": + self._data_lines.append(value) + return None + + def _dispatch(self) -> RunEvent | None: + """Validate and return the current frame, then clear decoder state.""" + if not self._data_lines: + self._reset() + return None + + frame_id = self._event_id + frame_event_type = self._event_type + data = "\n".join(self._data_lines) + self._reset() + + try: + event = RUN_EVENT_ADAPTER.validate_json(data) + except ValidationError as exc: + raise DifyAgentStreamError("malformed SSE data frame") from exc + if frame_event_type is not None and frame_event_type != event.type: + raise DifyAgentStreamError( + f"SSE event field {frame_event_type!r} does not match payload type {event.type!r}" + ) + if frame_id is not None and event.id is None: + return event.model_copy(update={"id": frame_id}) + return event + + def _reset(self) -> None: + """Clear the current frame without changing decoder configuration.""" + self._event_id = None + self._event_type = None + self._data_lines = [] + + +class Client: + """Unified synchronous and asynchronous client for Dify Agent runs. + + The instance is intentionally small and stateful: it stores base URL, default + headers, timeout settings, optional external HTTPX clients, and lazy-owned + clients for whichever sync/async side is used. External clients are never + closed by this wrapper. Owned sync clients close via ``close_sync`` or the + sync context manager; owned async clients close via ``aclose`` or the async + context manager. + """ + + _base_url: str + _timeout: float | httpx.Timeout + _stream_timeout: float | httpx.Timeout | None + _headers: dict[str, str] + _sync_http_client: httpx.Client | None + _async_http_client: httpx.AsyncClient | None + _owns_sync_http_client: bool + _owns_async_http_client: bool + _sync_closed: bool + _async_closed: bool + + def __init__( + self, + *, + base_url: str, + timeout: float | httpx.Timeout = 30.0, + stream_timeout: float | httpx.Timeout | None = None, + headers: dict[str, str] | None = None, + sync_http_client: httpx.Client | None = None, + async_http_client: httpx.AsyncClient | None = None, + ) -> None: + self._base_url = base_url.rstrip("/") + self._timeout = timeout + self._stream_timeout = stream_timeout + self._headers = dict(headers or {}) + self._sync_http_client = sync_http_client + self._async_http_client = async_http_client + self._owns_sync_http_client = sync_http_client is None + self._owns_async_http_client = async_http_client is None + self._sync_closed = False + self._async_closed = False + + def __enter__(self) -> Self: + """Enter a sync context and return this client without opening the network.""" + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + """Close the owned sync HTTP client when leaving a sync context.""" + del exc_type, exc_value, traceback + self.close_sync() + + async def __aenter__(self) -> Self: + """Enter an async context and return this client without opening the network.""" + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + """Close owned async resources when leaving an async context.""" + del exc_type, exc_value, traceback + await self.aclose() + + def close_sync(self) -> None: + """Close the owned synchronous HTTPX client if it was created.""" + if self._sync_closed: + return + if self._owns_sync_http_client and self._sync_http_client is not None: + self._sync_http_client.close() + self._sync_closed = True + + async def aclose(self) -> None: + """Close owned asynchronous resources and any owned sync client already opened.""" + if not self._async_closed: + if self._owns_async_http_client and self._async_http_client is not None: + await self._async_http_client.aclose() + self._async_closed = True + if self._owns_sync_http_client and self._sync_http_client is not None: + self.close_sync() + + async def create_run(self, request: CreateRunRequest | dict[str, object]) -> CreateRunResponse: + """Create one run and return its accepted status response. + + Dict inputs are validated as ``CreateRunRequest`` before the request is + sent. This method performs exactly one ``POST /runs`` attempt and maps + HTTPX timeouts to ``DifyAgentTimeoutError``. + """ + request_model = _validate_create_run_request(request) + try: + response = await self._get_async_http_client().post( + self._url("/runs"), + content=request_model.model_dump_json(), + headers=self._merged_headers({"Content-Type": "application/json"}), + timeout=self._timeout, + ) + except httpx.TimeoutException as exc: + raise DifyAgentTimeoutError("create_run timed out") from exc + except httpx.RequestError as exc: + raise DifyAgentClientError(f"create_run request failed: {exc}") from exc + return _parse_model_response(response, CreateRunResponse) + + def create_run_sync(self, request: CreateRunRequest | dict[str, object]) -> CreateRunResponse: + """Synchronous variant of ``create_run`` with the same no-retry contract.""" + request_model = _validate_create_run_request(request) + try: + response = self._get_sync_http_client().post( + self._url("/runs"), + content=request_model.model_dump_json(), + headers=self._merged_headers({"Content-Type": "application/json"}), + timeout=self._timeout, + ) + except httpx.TimeoutException as exc: + raise DifyAgentTimeoutError("create_run_sync timed out") from exc + except httpx.RequestError as exc: + raise DifyAgentClientError(f"create_run_sync request failed: {exc}") from exc + return _parse_model_response(response, CreateRunResponse) + + async def get_run(self, run_id: str) -> RunStatusResponse: + """Return the current status for ``run_id`` or raise a mapped client error.""" + try: + response = await self._get_async_http_client().get( + self._url(f"/runs/{quote(run_id, safe='')}"), + headers=self._merged_headers(), + timeout=self._timeout, + ) + except httpx.TimeoutException as exc: + raise DifyAgentTimeoutError("get_run timed out") from exc + except httpx.RequestError as exc: + raise DifyAgentClientError(f"get_run request failed: {exc}") from exc + return _parse_model_response(response, RunStatusResponse) + + def get_run_sync(self, run_id: str) -> RunStatusResponse: + """Synchronous variant of ``get_run``.""" + try: + response = self._get_sync_http_client().get( + self._url(f"/runs/{quote(run_id, safe='')}"), + headers=self._merged_headers(), + timeout=self._timeout, + ) + except httpx.TimeoutException as exc: + raise DifyAgentTimeoutError("get_run_sync timed out") from exc + except httpx.RequestError as exc: + raise DifyAgentClientError(f"get_run_sync request failed: {exc}") from exc + return _parse_model_response(response, RunStatusResponse) + + async def get_events(self, run_id: str, *, after: str = "0-0", limit: int = 100) -> RunEventsResponse: + """Return one cursor-paginated page of events for ``run_id``.""" + try: + response = await self._get_async_http_client().get( + self._url(f"/runs/{quote(run_id, safe='')}/events"), + params={"after": after, "limit": str(limit)}, + headers=self._merged_headers(), + timeout=self._timeout, + ) + except httpx.TimeoutException as exc: + raise DifyAgentTimeoutError("get_events timed out") from exc + except httpx.RequestError as exc: + raise DifyAgentClientError(f"get_events request failed: {exc}") from exc + return _parse_model_response(response, RunEventsResponse) + + def get_events_sync(self, run_id: str, *, after: str = "0-0", limit: int = 100) -> RunEventsResponse: + """Synchronous variant of ``get_events``.""" + try: + response = self._get_sync_http_client().get( + self._url(f"/runs/{quote(run_id, safe='')}/events"), + params={"after": after, "limit": str(limit)}, + headers=self._merged_headers(), + timeout=self._timeout, + ) + except httpx.TimeoutException as exc: + raise DifyAgentTimeoutError("get_events_sync timed out") from exc + except httpx.RequestError as exc: + raise DifyAgentClientError(f"get_events_sync request failed: {exc}") from exc + return _parse_model_response(response, RunEventsResponse) + + async def stream_events( + self, + run_id: str, + *, + after: str | None = None, + reconnect: bool = True, + max_reconnects: int | None = None, + reconnect_delay_seconds: float = 1.0, + until_terminal: bool = True, + ) -> AsyncIterator[RunEvent]: + """Yield typed events from SSE with cursor-based reconnect. + + The initial cursor is ``after`` or ``"0-0"``. After every yielded event + with an id, reconnects resume from that id using the ``after`` query + parameter. HTTP 5xx stream responses are retried, but HTTP 4xx responses, + DTO validation failures, and malformed SSE frames are not retried. By + default iteration stops after ``run_succeeded`` or ``run_failed``. + """ + _validate_stream_options(max_reconnects, reconnect_delay_seconds) + cursor = after or "0-0" + reconnect_attempts = 0 + while True: + try: + async for event in self._stream_events_once(run_id, after=cursor): + if event.id is not None: + cursor = event.id + yield event + if until_terminal and event.type in _TERMINAL_EVENT_TYPES: + return + except _ReconnectableStreamError as exc: + if not reconnect: + raise exc.error from exc + reconnect_attempts = _next_reconnect_attempt( + reconnect_attempts, + max_reconnects=max_reconnects, + error=exc.error, + ) + await _sleep_async(reconnect_delay_seconds) + continue + if not reconnect: + return + reconnect_attempts = _next_reconnect_attempt( + reconnect_attempts, + max_reconnects=max_reconnects, + error=DifyAgentStreamError("SSE stream ended before a terminal event"), + ) + await _sleep_async(reconnect_delay_seconds) + + def stream_events_sync( + self, + run_id: str, + *, + after: str | None = None, + reconnect: bool = True, + max_reconnects: int | None = None, + reconnect_delay_seconds: float = 1.0, + until_terminal: bool = True, + ) -> Iterator[RunEvent]: + """Synchronous variant of ``stream_events`` with the same reconnect rules.""" + _validate_stream_options(max_reconnects, reconnect_delay_seconds) + cursor = after or "0-0" + reconnect_attempts = 0 + while True: + try: + for event in self._stream_events_once_sync(run_id, after=cursor): + if event.id is not None: + cursor = event.id + yield event + if until_terminal and event.type in _TERMINAL_EVENT_TYPES: + return + except _ReconnectableStreamError as exc: + if not reconnect: + raise exc.error from exc + reconnect_attempts = _next_reconnect_attempt( + reconnect_attempts, + max_reconnects=max_reconnects, + error=exc.error, + ) + _sleep_sync(reconnect_delay_seconds) + continue + if not reconnect: + return + reconnect_attempts = _next_reconnect_attempt( + reconnect_attempts, + max_reconnects=max_reconnects, + error=DifyAgentStreamError("SSE stream ended before a terminal event"), + ) + _sleep_sync(reconnect_delay_seconds) + + async def wait_run( + self, + run_id: str, + *, + poll_interval_seconds: float = 1.0, + timeout_seconds: float | None = None, + ) -> RunStatusResponse: + """Poll run status until it becomes terminal and return the final status.""" + _validate_wait_options(poll_interval_seconds, timeout_seconds) + deadline = time.monotonic() + timeout_seconds if timeout_seconds is not None else None + while True: + status = await self.get_run(run_id) + if status.status in _TERMINAL_RUN_STATUSES: + return status + sleep_for = _next_sleep_seconds(poll_interval_seconds, deadline) + if sleep_for is None: + raise DifyAgentTimeoutError(f"run {run_id!r} did not finish before timeout") + await _sleep_async(sleep_for) + + def wait_run_sync( + self, + run_id: str, + *, + poll_interval_seconds: float = 1.0, + timeout_seconds: float | None = None, + ) -> RunStatusResponse: + """Synchronous variant of ``wait_run``.""" + _validate_wait_options(poll_interval_seconds, timeout_seconds) + deadline = time.monotonic() + timeout_seconds if timeout_seconds is not None else None + while True: + status = self.get_run_sync(run_id) + if status.status in _TERMINAL_RUN_STATUSES: + return status + sleep_for = _next_sleep_seconds(poll_interval_seconds, deadline) + if sleep_for is None: + raise DifyAgentTimeoutError(f"run {run_id!r} did not finish before timeout") + _sleep_sync(sleep_for) + + async def _stream_events_once(self, run_id: str, *, after: str) -> AsyncIterator[RunEvent]: + """Open one SSE connection and yield events until it ends or fails.""" + try: + async with self._get_async_http_client().stream( + "GET", + self._url(f"/runs/{quote(run_id, safe='')}/events/sse"), + params={"after": after}, + headers=self._merged_headers(), + timeout=self._stream_timeout, + ) as response: + if response.status_code >= 400: + _ = await response.aread() + _raise_for_stream_status(response) + decoder = _SSEDecoder() + async for line in response.aiter_lines(): + event = decoder.feed_line(line) + if event is not None: + yield event + except DifyAgentHTTPError: + raise + except DifyAgentStreamError: + raise + except httpx.TimeoutException as exc: + raise _ReconnectableStreamError(DifyAgentTimeoutError("SSE stream timed out")) from exc + except httpx.TransportError as exc: + raise _ReconnectableStreamError(DifyAgentStreamError(f"SSE stream failed: {exc}")) from exc + except httpx.StreamError as exc: + raise _ReconnectableStreamError(DifyAgentStreamError(f"SSE stream failed: {exc}")) from exc + + def _stream_events_once_sync(self, run_id: str, *, after: str) -> Iterator[RunEvent]: + """Open one synchronous SSE connection and yield events until it ends or fails.""" + try: + with self._get_sync_http_client().stream( + "GET", + self._url(f"/runs/{quote(run_id, safe='')}/events/sse"), + params={"after": after}, + headers=self._merged_headers(), + timeout=self._stream_timeout, + ) as response: + if response.status_code >= 400: + _ = response.read() + _raise_for_stream_status(response) + decoder = _SSEDecoder() + for line in response.iter_lines(): + event = decoder.feed_line(line) + if event is not None: + yield event + except DifyAgentHTTPError: + raise + except DifyAgentStreamError: + raise + except httpx.TimeoutException as exc: + raise _ReconnectableStreamError(DifyAgentTimeoutError("SSE stream timed out")) from exc + except httpx.TransportError as exc: + raise _ReconnectableStreamError(DifyAgentStreamError(f"SSE stream failed: {exc}")) from exc + except httpx.StreamError as exc: + raise _ReconnectableStreamError(DifyAgentStreamError(f"SSE stream failed: {exc}")) from exc + + def _get_sync_http_client(self) -> httpx.Client: + """Return an open sync HTTPX client, creating an owned one lazily.""" + if self._sync_closed: + raise DifyAgentClientError("sync client is closed") + if self._sync_http_client is None: + self._sync_http_client = httpx.Client(timeout=self._timeout, headers=self._headers) + return self._sync_http_client + + def _get_async_http_client(self) -> httpx.AsyncClient: + """Return an open async HTTPX client, creating an owned one lazily.""" + if self._async_closed: + raise DifyAgentClientError("async client is closed") + if self._async_http_client is None: + self._async_http_client = httpx.AsyncClient(timeout=self._timeout, headers=self._headers) + return self._async_http_client + + def _url(self, path: str) -> str: + """Build an absolute URL from the configured base and API path.""" + return f"{self._base_url}{path}" + + def _merged_headers(self, extra: dict[str, str] | None = None) -> dict[str, str]: + """Return per-request headers without mutating client defaults.""" + headers = dict(self._headers) + if extra is not None: + headers.update(extra) + return headers + + +def _validate_create_run_request(request: CreateRunRequest | dict[str, object]) -> CreateRunRequest: + """Validate user input before creating a run.""" + if isinstance(request, CreateRunRequest): + return request + try: + return CreateRunRequest.model_validate(request) + except ValidationError as exc: + raise DifyAgentValidationError(detail=exc.errors(include_url=False)) from exc + + +def _parse_model_response(response: httpx.Response, model_type: type[_ResponseModelT]) -> _ResponseModelT: + """Map HTTP errors and parse a Pydantic response DTO.""" + _raise_for_status(response) + try: + return model_type.model_validate_json(response.content) + except ValidationError as exc: + raise DifyAgentValidationError( + detail=exc.errors(include_url=False), + status_code=response.status_code, + ) from exc + + +def _raise_for_status(response: httpx.Response) -> None: + """Raise the configured client exception for HTTP 4xx/5xx responses.""" + if response.status_code < 400: + return + detail = _extract_error_detail(response) + if response.status_code == 404: + raise DifyAgentNotFoundError(status_code=response.status_code, detail=detail) + if response.status_code == 422: + raise DifyAgentValidationError(status_code=response.status_code, detail=detail) + raise DifyAgentHTTPError(status_code=response.status_code, detail=detail) + + +def _raise_for_stream_status(response: httpx.Response) -> None: + """Raise terminal 4xx errors or wrap retryable SSE 5xx responses.""" + try: + _raise_for_status(response) + except DifyAgentHTTPError as exc: + if response.status_code >= 500: + raise _ReconnectableStreamError( + DifyAgentStreamError(f"SSE stream HTTP {response.status_code}: {exc.detail}") + ) from exc + raise + + +def _extract_error_detail(response: httpx.Response) -> object: + """Extract FastAPI's ``detail`` field when present, falling back to text.""" + try: + payload = cast(object, response.json()) + except (ValueError, httpx.ResponseNotRead): + return response.text or response.reason_phrase + if isinstance(payload, dict) and "detail" in payload: + return cast(object, payload["detail"]) + return cast(object, payload) + + +def _next_reconnect_attempt( + reconnect_attempts: int, + *, + max_reconnects: int | None, + error: DifyAgentClientError, +) -> int: + """Increment reconnect attempts or raise when the configured budget is spent.""" + if max_reconnects is not None and reconnect_attempts >= max_reconnects: + raise DifyAgentStreamError("SSE stream reconnect attempts exhausted") from error + return reconnect_attempts + 1 + + +def _validate_stream_options(max_reconnects: int | None, reconnect_delay_seconds: float) -> None: + """Reject stream options that cannot produce deterministic reconnect behavior.""" + if max_reconnects is not None and max_reconnects < 0: + raise DifyAgentValidationError(detail="max_reconnects must be non-negative") + if reconnect_delay_seconds < 0: + raise DifyAgentValidationError(detail="reconnect_delay_seconds must be non-negative") + + +def _validate_wait_options(poll_interval_seconds: float, timeout_seconds: float | None) -> None: + """Reject wait options that would make polling ambiguous.""" + if poll_interval_seconds < 0: + raise DifyAgentValidationError(detail="poll_interval_seconds must be non-negative") + if timeout_seconds is not None and timeout_seconds < 0: + raise DifyAgentValidationError(detail="timeout_seconds must be non-negative") + + +def _next_sleep_seconds(poll_interval_seconds: float, deadline: float | None) -> float | None: + """Return the next polling sleep duration, or ``None`` when timed out.""" + if deadline is None: + return poll_interval_seconds + remaining = deadline - time.monotonic() + if remaining <= 0: + return None + return min(poll_interval_seconds, remaining) + + +async def _sleep_async(seconds: float) -> None: + """Sleep asynchronously, skipping the call for zero-second test delays.""" + if seconds > 0: + await asyncio.sleep(seconds) + + +def _sleep_sync(seconds: float) -> None: + """Sleep synchronously, skipping the call for zero-second test delays.""" + if seconds > 0: + time.sleep(seconds) + + +__all__ = [ + "Client", + "DifyAgentClientError", + "DifyAgentHTTPError", + "DifyAgentNotFoundError", + "DifyAgentStreamError", + "DifyAgentTimeoutError", + "DifyAgentValidationError", +] diff --git a/dify-agent/src/dify_agent/protocol/__init__.py b/dify-agent/src/dify_agent/protocol/__init__.py new file mode 100644 index 0000000000..4c00d7ffeb --- /dev/null +++ b/dify-agent/src/dify_agent/protocol/__init__.py @@ -0,0 +1,47 @@ +"""Public protocol exports shared by the Dify Agent server and clients.""" + +from .schemas import ( + RUN_EVENT_ADAPTER, + AgentOutputRunEvent, + AgentOutputRunEventData, + AgentProfileConfig, + BaseRunEvent, + CreateRunRequest, + CreateRunResponse, + EmptyRunEventData, + PydanticAIStreamRunEvent, + RunEvent, + RunEventType, + RunEventsResponse, + RunFailedEvent, + RunFailedEventData, + RunStartedEvent, + RunStatus, + RunStatusResponse, + RunSucceededEvent, + SessionSnapshotRunEvent, + utc_now, +) + +__all__ = [ + "AgentProfileConfig", + "AgentOutputRunEvent", + "AgentOutputRunEventData", + "BaseRunEvent", + "CreateRunRequest", + "CreateRunResponse", + "EmptyRunEventData", + "PydanticAIStreamRunEvent", + "RUN_EVENT_ADAPTER", + "RunEvent", + "RunEventType", + "RunEventsResponse", + "RunFailedEvent", + "RunFailedEventData", + "RunStartedEvent", + "RunStatus", + "RunStatusResponse", + "RunSucceededEvent", + "SessionSnapshotRunEvent", + "utc_now", +] diff --git a/dify-agent/src/dify_agent/protocol/schemas.py b/dify-agent/src/dify_agent/protocol/schemas.py new file mode 100644 index 0000000000..b7628bb978 --- /dev/null +++ b/dify-agent/src/dify_agent/protocol/schemas.py @@ -0,0 +1,201 @@ +"""Public HTTP protocol schemas for the Dify Agent run API. + +This module is the shared wire contract for the FastAPI server, runtime event +producers, storage adapters, and Python client. The server accepts only +registry-backed Agenton compositor configs, keeping HTTP input data-only and +preventing unsafe import-path construction. Run events are append-only records; +Redis stream ids (or in-memory equivalents in tests) are the public cursors used +by polling and SSE replay. Event envelopes keep the public +``id``/``run_id``/``type``/``data``/``created_at`` shape, while each ``type`` has +a typed ``data`` model so OpenAPI, Redis replay, and clients parse the same +payload contract. +""" + +from datetime import datetime, timezone +from typing import Annotated, ClassVar, Literal, TypeAlias + +from pydantic import BaseModel, ConfigDict, Field, TypeAdapter +from pydantic_ai.messages import AgentStreamEvent + +from agenton.compositor import CompositorConfig, CompositorSessionSnapshot + + +RunStatus = Literal["running", "succeeded", "failed"] +RunEventType = Literal[ + "run_started", + "pydantic_ai_event", + "agent_output", + "session_snapshot", + "run_succeeded", + "run_failed", +] + + +def utc_now() -> datetime: + """Return the timezone-aware timestamp format used by public schemas.""" + return datetime.now(timezone.utc) + + +class AgentProfileConfig(BaseModel): + """Minimal model profile for the MVP runner. + + ``test`` uses pydantic-ai's ``TestModel`` and is credential-free. Other + profiles can be added behind this schema without changing run/event storage. + """ + + provider: Literal["test"] = "test" + output_text: str = "Hello from the Dify Agent test model." + + model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid") + + +class CreateRunRequest(BaseModel): + """Request body for creating one async agent run.""" + + compositor: CompositorConfig + session_snapshot: CompositorSessionSnapshot | None = None + agent_profile: AgentProfileConfig = Field(default_factory=AgentProfileConfig) + + model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid") + + +class CreateRunResponse(BaseModel): + """Response returned after a run has been persisted and scheduled locally.""" + + run_id: str + status: RunStatus + + model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid") + + +class RunStatusResponse(BaseModel): + """Current server-side status for one run.""" + + run_id: str + status: RunStatus + created_at: datetime + updated_at: datetime + error: str | None = None + + model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid") + + +class EmptyRunEventData(BaseModel): + """Typed empty payload for lifecycle events that carry no extra data.""" + + model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid") + + +class AgentOutputRunEventData(BaseModel): + """Final agent output payload emitted before the session snapshot.""" + + output: str + + model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid") + + +class RunFailedEventData(BaseModel): + """Terminal failure payload shown to polling and SSE consumers.""" + + error: str + reason: str | None = None + + model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid") + + +class BaseRunEvent(BaseModel): + """Shared append-only event envelope visible through polling and SSE.""" + + id: str | None = None + run_id: str + created_at: datetime = Field(default_factory=utc_now) + + model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid") + + +class RunStartedEvent(BaseRunEvent): + """Run lifecycle event emitted before runtime execution starts.""" + + type: Literal["run_started"] = "run_started" + data: EmptyRunEventData = Field(default_factory=EmptyRunEventData) + + +class PydanticAIStreamRunEvent(BaseRunEvent): + """Pydantic AI stream event using the upstream typed event model.""" + + type: Literal["pydantic_ai_event"] = "pydantic_ai_event" + data: AgentStreamEvent + + +class AgentOutputRunEvent(BaseRunEvent): + """Run event carrying the final agent output string.""" + + type: Literal["agent_output"] = "agent_output" + data: AgentOutputRunEventData + + +class SessionSnapshotRunEvent(BaseRunEvent): + """Run event carrying the resumable Agenton session snapshot.""" + + type: Literal["session_snapshot"] = "session_snapshot" + data: CompositorSessionSnapshot + + +class RunSucceededEvent(BaseRunEvent): + """Terminal success event emitted after output and session snapshot.""" + + type: Literal["run_succeeded"] = "run_succeeded" + data: EmptyRunEventData = Field(default_factory=EmptyRunEventData) + + +class RunFailedEvent(BaseRunEvent): + """Terminal failure event emitted before the run status becomes failed.""" + + type: Literal["run_failed"] = "run_failed" + data: RunFailedEventData + + +RunEvent: TypeAlias = Annotated[ + RunStartedEvent + | PydanticAIStreamRunEvent + | AgentOutputRunEvent + | SessionSnapshotRunEvent + | RunSucceededEvent + | RunFailedEvent, + Field(discriminator="type"), +] +RUN_EVENT_ADAPTER: TypeAdapter[RunEvent] = TypeAdapter(RunEvent) + + +class RunEventsResponse(BaseModel): + """Cursor-paginated event log response.""" + + run_id: str + events: list[RunEvent] + next_cursor: str | None = None + + model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid") + + +__all__ = [ + "AgentProfileConfig", + "AgentOutputRunEvent", + "AgentOutputRunEventData", + "BaseRunEvent", + "CreateRunRequest", + "CreateRunResponse", + "EmptyRunEventData", + "PydanticAIStreamRunEvent", + "RUN_EVENT_ADAPTER", + "RunEvent", + "RunEventType", + "RunEventsResponse", + "RunFailedEvent", + "RunFailedEventData", + "RunStartedEvent", + "RunStatus", + "RunStatusResponse", + "RunSucceededEvent", + "SessionSnapshotRunEvent", + "utc_now", +] diff --git a/dify-agent/src/dify_agent/runtime/agent_factory.py b/dify-agent/src/dify_agent/runtime/agent_factory.py index d0130c7271..59ed9ef359 100644 --- a/dify-agent/src/dify_agent/runtime/agent_factory.py +++ b/dify-agent/src/dify_agent/runtime/agent_factory.py @@ -13,7 +13,7 @@ from pydantic_ai.messages import UserContent from pydantic_ai.models.test import TestModel from agenton.layers.types import PydanticAIPrompt, PydanticAITool -from dify_agent.server.schemas import AgentProfileConfig +from dify_agent.protocol.schemas import AgentProfileConfig def create_agent( diff --git a/dify-agent/src/dify_agent/runtime/event_sink.py b/dify-agent/src/dify_agent/runtime/event_sink.py index 552961750d..70658e26bd 100644 --- a/dify-agent/src/dify_agent/runtime/event_sink.py +++ b/dify-agent/src/dify_agent/runtime/event_sink.py @@ -11,7 +11,7 @@ from typing import Protocol from pydantic_ai.messages import AgentStreamEvent from agenton.compositor import CompositorSessionSnapshot -from dify_agent.server.schemas import ( +from dify_agent.protocol.schemas import ( AgentOutputRunEvent, AgentOutputRunEventData, EmptyRunEventData, diff --git a/dify-agent/src/dify_agent/runtime/run_scheduler.py b/dify-agent/src/dify_agent/runtime/run_scheduler.py index 329893d290..93423ade75 100644 --- a/dify-agent/src/dify_agent/runtime/run_scheduler.py +++ b/dify-agent/src/dify_agent/runtime/run_scheduler.py @@ -12,11 +12,12 @@ import logging from collections.abc import Callable from typing import Protocol +from dify_agent.protocol.schemas import CreateRunRequest from dify_agent.runtime.compositor_factory import build_pydantic_ai_compositor from dify_agent.runtime.event_sink import RunEventSink, emit_run_failed from dify_agent.runtime.runner import AgentRunRunner from dify_agent.runtime.user_prompt_validation import EMPTY_USER_PROMPTS_ERROR, has_non_blank_user_prompt -from dify_agent.server.schemas import CreateRunRequest, RunRecord +from dify_agent.server.schemas import RunRecord logger = logging.getLogger(__name__) diff --git a/dify-agent/src/dify_agent/runtime/runner.py b/dify-agent/src/dify_agent/runtime/runner.py index 296d14d1c5..102ec7c2de 100644 --- a/dify-agent/src/dify_agent/runtime/runner.py +++ b/dify-agent/src/dify_agent/runtime/runner.py @@ -11,6 +11,7 @@ from collections.abc import AsyncIterable from pydantic_ai.messages import AgentStreamEvent from agenton.compositor import CompositorSessionSnapshot +from dify_agent.protocol.schemas import CreateRunRequest from dify_agent.runtime.agent_factory import create_agent, normalize_user_input from dify_agent.runtime.compositor_factory import build_pydantic_ai_compositor from dify_agent.runtime.event_sink import ( @@ -23,7 +24,6 @@ from dify_agent.runtime.event_sink import ( emit_session_snapshot, ) from dify_agent.runtime.user_prompt_validation import EMPTY_USER_PROMPTS_ERROR, has_non_blank_user_prompt -from dify_agent.server.schemas import CreateRunRequest class AgentRunValidationError(ValueError): diff --git a/dify-agent/src/dify_agent/server/routes/runs.py b/dify-agent/src/dify_agent/server/routes/runs.py index 8f0d23ddb6..e279e26d33 100644 --- a/dify-agent/src/dify_agent/server/routes/runs.py +++ b/dify-agent/src/dify_agent/server/routes/runs.py @@ -13,10 +13,10 @@ from typing import Annotated from fastapi import APIRouter, Depends, Header, HTTPException, Query from fastapi.responses import StreamingResponse +from dify_agent.protocol.schemas import CreateRunRequest, CreateRunResponse, RunEventsResponse, RunStatusResponse from dify_agent.runtime.compositor_factory import build_pydantic_ai_compositor from dify_agent.runtime.run_scheduler import RunScheduler, SchedulerStoppingError from dify_agent.runtime.user_prompt_validation import EMPTY_USER_PROMPTS_ERROR, has_non_blank_user_prompt -from dify_agent.server.schemas import CreateRunRequest, CreateRunResponse, RunEventsResponse, RunStatusResponse from dify_agent.server.sse import sse_event_stream from dify_agent.storage.redis_run_store import RedisRunStore, RunNotFoundError diff --git a/dify-agent/src/dify_agent/server/schemas.py b/dify-agent/src/dify_agent/server/schemas.py index 7b3ff93f92..75d34ece34 100644 --- a/dify-agent/src/dify_agent/server/schemas.py +++ b/dify-agent/src/dify_agent/server/schemas.py @@ -1,198 +1,41 @@ -"""Public API schemas for the Dify Agent run server. +"""Server-only schemas and helpers for persisted run records. -The server accepts only registry-backed Agenton compositor configs. This keeps -HTTP input data-only and prevents unsafe import-path construction. Run events are -append-only records; Redis stream ids (or in-memory equivalents in tests) are the -public cursors used by polling and SSE replay. Event envelopes keep the public -``id``/``run_id``/``type``/``data``/``created_at`` shape, but each ``type`` has a -typed ``data`` model so OpenAPI, Redis replay, and runtime producers agree on the -payload contract. +Public HTTP DTOs and run events live in ``dify_agent.protocol.schemas`` and are +intentionally not re-exported here. Keeping this module server-only prevents old +imports from silently depending on implementation modules while preserving the +internal ``RunRecord`` model used by schedulers and Redis storage. """ -from datetime import datetime, timezone -from typing import Annotated, Literal, TypeAlias +from datetime import datetime +from typing import ClassVar from uuid import uuid4 -from pydantic import BaseModel, ConfigDict, Field, TypeAdapter, field_validator -from pydantic_ai.messages import AgentStreamEvent +from pydantic import BaseModel, ConfigDict, Field, field_validator -from agenton.compositor import CompositorConfig, CompositorSessionSnapshot - - -RunStatus = Literal["running", "succeeded", "failed"] -RunEventType = Literal[ - "run_started", - "pydantic_ai_event", - "agent_output", - "session_snapshot", - "run_succeeded", - "run_failed", -] +from dify_agent.protocol import schemas as _protocol_schemas def new_run_id() -> str: - """Return a stable external run id.""" + """Return a stable external run id for newly persisted server records.""" return str(uuid4()) -def utc_now() -> datetime: - """Return the timestamp format used by public schemas.""" - return datetime.now(timezone.utc) +class RunRecord(BaseModel): + """Internal representation persisted for status reads. - -class AgentProfileConfig(BaseModel): - """Minimal model profile for the MVP runner. - - ``test`` uses pydantic-ai's ``TestModel`` and is credential-free. Other - profiles can be added behind this schema without changing run/event storage. + The embedded request and status use protocol types so persisted records stay + JSON-compatible with the public API, but callers must import those DTOs from + ``dify_agent.protocol.schemas`` rather than this server-only module. """ - provider: Literal["test"] = "test" - output_text: str = "Hello from the Dify Agent test model." - - model_config = ConfigDict(extra="forbid") - - -class CreateRunRequest(BaseModel): - """Request body for creating one async agent run.""" - - compositor: CompositorConfig - session_snapshot: CompositorSessionSnapshot | None = None - agent_profile: AgentProfileConfig = Field(default_factory=AgentProfileConfig) - - model_config = ConfigDict(extra="forbid") - - -class CreateRunResponse(BaseModel): - """Response returned after a run has been persisted and scheduled locally.""" - run_id: str - status: RunStatus - - model_config = ConfigDict(extra="forbid") - - -class RunStatusResponse(BaseModel): - """Current server-side status for one run.""" - - run_id: str - status: RunStatus - created_at: datetime - updated_at: datetime + status: _protocol_schemas.RunStatus + request: _protocol_schemas.CreateRunRequest + created_at: datetime = Field(default_factory=_protocol_schemas.utc_now) + updated_at: datetime = Field(default_factory=_protocol_schemas.utc_now) error: str | None = None - model_config = ConfigDict(extra="forbid") - - -class EmptyRunEventData(BaseModel): - """Typed empty payload for lifecycle events that carry no extra data.""" - - model_config = ConfigDict(extra="forbid") - - -class AgentOutputRunEventData(BaseModel): - """Final agent output payload emitted before the session snapshot.""" - - output: str - - model_config = ConfigDict(extra="forbid") - - -class RunFailedEventData(BaseModel): - """Terminal failure payload shown to polling and SSE consumers.""" - - error: str - reason: str | None = None - - model_config = ConfigDict(extra="forbid") - - -class BaseRunEvent(BaseModel): - """Shared append-only event envelope visible through polling and SSE.""" - - id: str | None = None - run_id: str - created_at: datetime = Field(default_factory=utc_now) - - model_config = ConfigDict(extra="forbid") - - -class RunStartedEvent(BaseRunEvent): - """Run lifecycle event emitted before runtime execution starts.""" - - type: Literal["run_started"] = "run_started" - data: EmptyRunEventData = Field(default_factory=EmptyRunEventData) - - -class PydanticAIStreamRunEvent(BaseRunEvent): - """Pydantic AI stream event using the upstream typed event model.""" - - type: Literal["pydantic_ai_event"] = "pydantic_ai_event" - data: AgentStreamEvent - - -class AgentOutputRunEvent(BaseRunEvent): - """Run event carrying the final agent output string.""" - - type: Literal["agent_output"] = "agent_output" - data: AgentOutputRunEventData - - -class SessionSnapshotRunEvent(BaseRunEvent): - """Run event carrying the resumable Agenton session snapshot.""" - - type: Literal["session_snapshot"] = "session_snapshot" - data: CompositorSessionSnapshot - - -class RunSucceededEvent(BaseRunEvent): - """Terminal success event emitted after output and session snapshot.""" - - type: Literal["run_succeeded"] = "run_succeeded" - data: EmptyRunEventData = Field(default_factory=EmptyRunEventData) - - -class RunFailedEvent(BaseRunEvent): - """Terminal failure event emitted before the run status becomes failed.""" - - type: Literal["run_failed"] = "run_failed" - data: RunFailedEventData - - - -RunEvent: TypeAlias = Annotated[ - RunStartedEvent - | PydanticAIStreamRunEvent - | AgentOutputRunEvent - | SessionSnapshotRunEvent - | RunSucceededEvent - | RunFailedEvent, - Field(discriminator="type"), -] -RUN_EVENT_ADAPTER = TypeAdapter(RunEvent) - - -class RunEventsResponse(BaseModel): - """Cursor-paginated event log response.""" - - run_id: str - events: list[RunEvent] - next_cursor: str | None = None - - model_config = ConfigDict(extra="forbid") - - -class RunRecord(BaseModel): - """Internal representation persisted for status reads.""" - - run_id: str - status: RunStatus - request: CreateRunRequest - created_at: datetime = Field(default_factory=utc_now) - updated_at: datetime = Field(default_factory=utc_now) - error: str | None = None - - model_config = ConfigDict(extra="forbid") + model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid") @field_validator("updated_at") @classmethod @@ -203,26 +46,4 @@ class RunRecord(BaseModel): return value -__all__ = [ - "AgentProfileConfig", - "AgentOutputRunEvent", - "AgentOutputRunEventData", - "BaseRunEvent", - "CreateRunRequest", - "CreateRunResponse", - "EmptyRunEventData", - "PydanticAIStreamRunEvent", - "RUN_EVENT_ADAPTER", - "RunEvent", - "RunEventsResponse", - "RunFailedEvent", - "RunFailedEventData", - "RunRecord", - "RunStartedEvent", - "RunStatus", - "RunStatusResponse", - "RunSucceededEvent", - "SessionSnapshotRunEvent", - "new_run_id", - "utc_now", -] +__all__ = ["RunRecord", "new_run_id"] diff --git a/dify-agent/src/dify_agent/server/sse.py b/dify-agent/src/dify_agent/server/sse.py index 0e917120a5..72a880ab0f 100644 --- a/dify-agent/src/dify_agent/server/sse.py +++ b/dify-agent/src/dify_agent/server/sse.py @@ -7,7 +7,7 @@ name. Payload data is the full public ``RunEvent`` JSON object. from collections.abc import AsyncIterable, AsyncIterator -from dify_agent.server.schemas import RUN_EVENT_ADAPTER, RunEvent +from dify_agent.protocol.schemas import RUN_EVENT_ADAPTER, RunEvent def format_sse_event(event: RunEvent) -> str: diff --git a/dify-agent/src/dify_agent/storage/redis_run_store.py b/dify-agent/src/dify_agent/storage/redis_run_store.py index f5c56815c9..f9183b37ba 100644 --- a/dify-agent/src/dify_agent/storage/redis_run_store.py +++ b/dify-agent/src/dify_agent/storage/redis_run_store.py @@ -12,17 +12,16 @@ from typing import cast from redis.asyncio import Redis -from dify_agent.runtime.event_sink import RunEventSink -from dify_agent.server.schemas import ( +from dify_agent.protocol.schemas import ( CreateRunRequest, RUN_EVENT_ADAPTER, RunEvent, RunEventsResponse, - RunRecord, RunStatus, - new_run_id, utc_now, ) +from dify_agent.runtime.event_sink import RunEventSink +from dify_agent.server.schemas import RunRecord, new_run_id from dify_agent.server.settings import DEFAULT_RUN_RETENTION_SECONDS from dify_agent.storage.redis_keys import run_events_key, run_record_key diff --git a/dify-agent/tests/local/dify_agent/client/test_client.py b/dify-agent/tests/local/dify_agent/client/test_client.py new file mode 100644 index 0000000000..aa9eba527f --- /dev/null +++ b/dify-agent/tests/local/dify_agent/client/test_client.py @@ -0,0 +1,381 @@ +from __future__ import annotations + +import asyncio +import json +from collections.abc import Iterator +from datetime import UTC, datetime +from typing import cast, override + +import httpx +import pytest + +from dify_agent.client import ( + Client, + DifyAgentHTTPError, + DifyAgentNotFoundError, + DifyAgentStreamError, + DifyAgentTimeoutError, + DifyAgentValidationError, +) +from dify_agent.protocol.schemas import ( + CreateRunRequest, + EmptyRunEventData, + RUN_EVENT_ADAPTER, + RunEvent, + RunEventsResponse, + RunStartedEvent, + RunSucceededEvent, +) + + +def _create_run_payload() -> dict[str, object]: + return { + "compositor": { + "schema_version": 1, + "layers": [{"name": "prompt", "type": "plain.prompt", "config": {"user": "hello"}}], + }, + "agent_profile": {"provider": "test", "output_text": "done"}, + } + + +def _event_frame(event: RunEvent, *, event_id: str | None = None, exclude_id: bool = False) -> str: + payload = RUN_EVENT_ADAPTER.dump_json(event, exclude={"id"} if exclude_id else None).decode() + lines: list[str] = [] + if event_id is not None: + lines.append(f"id: {event_id}") + lines.append(f"data: {payload}") + return "\n".join(lines) + "\n\n" + + +def _run_status_json(status: str) -> dict[str, object]: + now = datetime(2026, 5, 11, tzinfo=UTC).isoformat() + return {"run_id": "run-1", "status": status, "created_at": now, "updated_at": now, "error": None} + + +class DisconnectingSyncStream(httpx.SyncByteStream): + chunks: list[bytes] + + def __init__(self, *chunks: str) -> None: + self.chunks = [chunk.encode() for chunk in chunks] + + @override + def __iter__(self) -> Iterator[bytes]: + yield from self.chunks + raise httpx.ReadError("stream disconnected") + + +def test_sync_methods_parse_protocol_dtos_and_validate_create_dict() -> None: + def handler(request: httpx.Request) -> httpx.Response: + if request.method == "POST" and request.url.path == "/runs": + payload = cast(dict[str, object], json.loads(request.content)) + compositor = cast(dict[str, object], payload["compositor"]) + layers = cast(list[dict[str, object]], compositor["layers"]) + assert layers[0]["config"] == {"user": "hello"} + assert payload["agent_profile"] == {"provider": "test", "output_text": "done"} + return httpx.Response(202, json={"run_id": "run-1", "status": "running"}) + if request.method == "GET" and request.url.path == "/runs/run-1": + return httpx.Response(200, json=_run_status_json("running")) + if request.method == "GET" and request.url.path == "/runs/run-1/events": + assert request.url.params["after"] == "0-0" + assert request.url.params["limit"] == "10" + event = RunStartedEvent(id="1-0", run_id="run-1") + return httpx.Response( + 200, + json={ + "run_id": "run-1", + "events": [cast(object, json.loads(RUN_EVENT_ADAPTER.dump_json(event)))], + "next_cursor": "1-0", + }, + ) + raise AssertionError(f"unexpected request: {request.method} {request.url}") + + http_client = httpx.Client(transport=httpx.MockTransport(handler)) + client = Client(base_url="http://testserver", sync_http_client=http_client) + + created = client.create_run_sync(_create_run_payload()) + status = client.get_run_sync(created.run_id) + events = client.get_events_sync(created.run_id, after="0-0", limit=10) + + assert created.status == "running" + assert status.status == "running" + assert isinstance(events, RunEventsResponse) + assert [event.type for event in events.events] == ["run_started"] + + +def test_async_methods_and_wait_run_parse_protocol_dtos() -> None: + statuses = iter(["running", "succeeded"]) + + def handler(request: httpx.Request) -> httpx.Response: + if request.method == "POST" and request.url.path == "/runs": + return httpx.Response(202, json={"run_id": "run-1", "status": "running"}) + if request.method == "GET" and request.url.path == "/runs/run-1": + return httpx.Response(200, json=_run_status_json(next(statuses))) + if request.method == "GET" and request.url.path == "/runs/run-1/events": + return httpx.Response(200, json={"run_id": "run-1", "events": [], "next_cursor": "0-0"}) + raise AssertionError(f"unexpected request: {request.method} {request.url}") + + async def scenario() -> None: + http_client = httpx.AsyncClient(transport=httpx.MockTransport(handler)) + client = Client(base_url="http://testserver", async_http_client=http_client) + request = CreateRunRequest.model_validate(_create_run_payload()) + + created = await client.create_run(request) + events = await client.get_events(created.run_id) + terminal = await client.wait_run(created.run_id, poll_interval_seconds=0) + + assert created.run_id == "run-1" + assert events.events == [] + assert terminal.status == "succeeded" + await http_client.aclose() + + asyncio.run(scenario()) + + +def test_error_mapping_and_create_run_input_validation() -> None: + responses = iter( + [ + httpx.Response(404, json={"detail": "run not found"}), + httpx.Response(422, json={"detail": "invalid"}), + httpx.Response(500, json={"detail": "boom"}), + ] + ) + + def handler(_request: httpx.Request) -> httpx.Response: + return next(responses) + + client = Client( + base_url="http://testserver", + sync_http_client=httpx.Client(transport=httpx.MockTransport(handler)), + ) + + with pytest.raises(DifyAgentNotFoundError) as not_found: + _ = client.get_run_sync("missing") + assert not_found.value.status_code == 404 + assert not_found.value.detail == "run not found" + + with pytest.raises(DifyAgentValidationError) as validation: + _ = client.get_run_sync("bad") + assert validation.value.status_code == 422 + + with pytest.raises(DifyAgentHTTPError) as server_error: + _ = client.get_run_sync("bad") + assert server_error.value.status_code == 500 + + with pytest.raises(DifyAgentValidationError): + _ = client.create_run_sync({"unknown": "field"}) + + +def test_http_timeout_maps_to_client_timeout_error() -> None: + def handler(request: httpx.Request) -> httpx.Response: + raise httpx.ReadTimeout("slow", request=request) + + client = Client( + base_url="http://testserver", + sync_http_client=httpx.Client(transport=httpx.MockTransport(handler)), + ) + + with pytest.raises(DifyAgentTimeoutError): + _ = client.get_run_sync("run-1") + + +def test_create_run_is_not_retried_after_timeout() -> None: + attempts = 0 + + def handler(request: httpx.Request) -> httpx.Response: + nonlocal attempts + attempts += 1 + raise httpx.ConnectTimeout("cannot connect", request=request) + + client = Client( + base_url="http://testserver", + sync_http_client=httpx.Client(transport=httpx.MockTransport(handler)), + ) + + with pytest.raises(DifyAgentTimeoutError): + _ = client.create_run_sync(_create_run_payload()) + assert attempts == 1 + + +def test_sync_sse_parser_supports_comments_multiline_data_and_id_fill() -> None: + payload = RUN_EVENT_ADAPTER.dump_json(RunStartedEvent(run_id="run-1"), exclude={"id"}).decode() + before_type, after_type = payload.split('"type"', maxsplit=1) + body = f": keepalive\nid: 5-0\nevent: run_started\ndata: {before_type}\ndata: \"type\"{after_type}\n\n" + + def handler(request: httpx.Request) -> httpx.Response: + assert request.url.params["after"] == "0-0" + return httpx.Response(200, content=body) + + client = Client( + base_url="http://testserver", + sync_http_client=httpx.Client(transport=httpx.MockTransport(handler)), + ) + + events = list(client.stream_events_sync("run-1", until_terminal=False, reconnect=False)) + + assert [event.id for event in events] == ["5-0"] + assert [event.type for event in events] == ["run_started"] + + +def test_stream_events_stops_after_terminal_event() -> None: + calls = 0 + body = "".join( + [ + _event_frame(RunStartedEvent(id="1-0", run_id="run-1")), + _event_frame(RunSucceededEvent(id="2-0", run_id="run-1", data=EmptyRunEventData())), + ] + ) + + def handler(_request: httpx.Request) -> httpx.Response: + nonlocal calls + calls += 1 + return httpx.Response(200, content=body) + + client = Client( + base_url="http://testserver", + sync_http_client=httpx.Client(transport=httpx.MockTransport(handler)), + ) + + events = list(client.stream_events_sync("run-1", reconnect_delay_seconds=0)) + + assert [event.type for event in events] == ["run_started", "run_succeeded"] + assert calls == 1 + + +def test_stream_events_reconnects_from_latest_event_id() -> None: + seen_after: list[str] = [] + + def handler(request: httpx.Request) -> httpx.Response: + seen_after.append(request.url.params["after"]) + if len(seen_after) == 1: + return httpx.Response( + 200, + stream=DisconnectingSyncStream(_event_frame(RunStartedEvent(id="1-0", run_id="run-1"))), + ) + return httpx.Response(200, content=_event_frame(RunSucceededEvent(id="2-0", run_id="run-1"))) + + client = Client( + base_url="http://testserver", + sync_http_client=httpx.Client(transport=httpx.MockTransport(handler)), + ) + + events = list(client.stream_events_sync("run-1", reconnect_delay_seconds=0)) + + assert seen_after == ["0-0", "1-0"] + assert [event.type for event in events] == ["run_started", "run_succeeded"] + + +def test_stream_events_reconnects_after_http_5xx_response() -> None: + seen_after: list[str] = [] + + def handler(request: httpx.Request) -> httpx.Response: + seen_after.append(request.url.params["after"]) + if len(seen_after) == 1: + return httpx.Response(503, json={"detail": "temporarily unavailable"}) + return httpx.Response(200, content=_event_frame(RunSucceededEvent(id="2-0", run_id="run-1"))) + + client = Client( + base_url="http://testserver", + sync_http_client=httpx.Client(transport=httpx.MockTransport(handler)), + ) + + events = list(client.stream_events_sync("run-1", reconnect_delay_seconds=0)) + + assert seen_after == ["0-0", "0-0"] + assert [event.type for event in events] == ["run_succeeded"] + + +def test_stream_events_raises_when_reconnects_are_exhausted() -> None: + calls = 0 + + def handler(_request: httpx.Request) -> httpx.Response: + nonlocal calls + calls += 1 + return httpx.Response(200, stream=DisconnectingSyncStream()) + + client = Client( + base_url="http://testserver", + sync_http_client=httpx.Client(transport=httpx.MockTransport(handler)), + ) + + with pytest.raises(DifyAgentStreamError): + _ = list(client.stream_events_sync("run-1", max_reconnects=1, reconnect_delay_seconds=0)) + assert calls == 2 + + +def test_malformed_sse_frame_does_not_reconnect() -> None: + calls = 0 + + def handler(_request: httpx.Request) -> httpx.Response: + nonlocal calls + calls += 1 + return httpx.Response(200, content="data: not-json\n\n") + + client = Client( + base_url="http://testserver", + sync_http_client=httpx.Client(transport=httpx.MockTransport(handler)), + ) + + with pytest.raises(DifyAgentStreamError): + _ = list(client.stream_events_sync("run-1", reconnect_delay_seconds=0)) + assert calls == 1 + + +def test_async_stream_events_yields_terminal_event() -> None: + body = _event_frame(RunSucceededEvent(id="2-0", run_id="run-1")) + + def handler(_request: httpx.Request) -> httpx.Response: + return httpx.Response(200, content=body) + + async def scenario() -> None: + http_client = httpx.AsyncClient(transport=httpx.MockTransport(handler)) + client = Client(base_url="http://testserver", async_http_client=http_client) + + events = [event async for event in client.stream_events("run-1")] + + assert [event.type for event in events] == ["run_succeeded"] + await http_client.aclose() + + asyncio.run(scenario()) + + +def test_async_stream_events_reconnects_after_http_5xx_response() -> None: + seen_after: list[str] = [] + + def handler(request: httpx.Request) -> httpx.Response: + seen_after.append(request.url.params["after"]) + if len(seen_after) == 1: + return httpx.Response(502, json={"detail": "bad gateway"}) + return httpx.Response(200, content=_event_frame(RunSucceededEvent(id="2-0", run_id="run-1"))) + + async def scenario() -> None: + http_client = httpx.AsyncClient(transport=httpx.MockTransport(handler)) + client = Client(base_url="http://testserver", async_http_client=http_client) + + events = [event async for event in client.stream_events("run-1", reconnect_delay_seconds=0)] + + assert seen_after == ["0-0", "0-0"] + assert [event.type for event in events] == ["run_succeeded"] + await http_client.aclose() + + asyncio.run(scenario()) + + +def test_stream_timeout_can_reconnect_until_terminal() -> None: + calls = 0 + + def handler(request: httpx.Request) -> httpx.Response: + nonlocal calls + calls += 1 + if calls == 1: + raise httpx.ReadTimeout("stream stalled", request=request) + return httpx.Response(200, content=_event_frame(RunSucceededEvent(id="2-0", run_id="run-1"))) + + client = Client( + base_url="http://testserver", + sync_http_client=httpx.Client(transport=httpx.MockTransport(handler)), + ) + + events = list(client.stream_events_sync("run-1", reconnect_delay_seconds=0)) + + assert calls == 2 + assert [event.type for event in events] == ["run_succeeded"] diff --git a/dify-agent/tests/local/dify_agent/protocol/test_protocol_schemas.py b/dify-agent/tests/local/dify_agent/protocol/test_protocol_schemas.py new file mode 100644 index 0000000000..0a52316bbb --- /dev/null +++ b/dify-agent/tests/local/dify_agent/protocol/test_protocol_schemas.py @@ -0,0 +1,40 @@ +from pydantic_ai.messages import FinalResultEvent + +from dify_agent.protocol.schemas import ( + RUN_EVENT_ADAPTER, + AgentOutputRunEvent, + AgentOutputRunEventData, + PydanticAIStreamRunEvent, + RunFailedEvent, + RunFailedEventData, + RunStartedEvent, +) + + +def test_run_event_adapter_round_trips_typed_variants() -> None: + events = [ + RunStartedEvent(run_id="run-1"), + PydanticAIStreamRunEvent(run_id="run-1", data=FinalResultEvent(tool_name=None, tool_call_id=None)), + AgentOutputRunEvent(run_id="run-1", data=AgentOutputRunEventData(output="done")), + RunFailedEvent(run_id="run-1", data=RunFailedEventData(error="boom", reason="shutdown")), + ] + + for event in events: + payload = RUN_EVENT_ADAPTER.dump_json(event) + decoded = RUN_EVENT_ADAPTER.validate_json(payload) + + assert decoded.type == event.type + assert decoded.run_id == event.run_id + + +def test_pydantic_ai_event_data_uses_agent_stream_event_model() -> None: + event = RUN_EVENT_ADAPTER.validate_python( + { + "run_id": "run-1", + "type": "pydantic_ai_event", + "data": {"event_kind": "final_result", "tool_name": None, "tool_call_id": None}, + } + ) + + assert isinstance(event, PydanticAIStreamRunEvent) + assert isinstance(event.data, FinalResultEvent) diff --git a/dify-agent/tests/local/dify_agent/runtime/test_run_scheduler.py b/dify-agent/tests/local/dify_agent/runtime/test_run_scheduler.py index 13984e3760..06d6af2e59 100644 --- a/dify-agent/tests/local/dify_agent/runtime/test_run_scheduler.py +++ b/dify-agent/tests/local/dify_agent/runtime/test_run_scheduler.py @@ -6,8 +6,9 @@ import pytest from pydantic import JsonValue from agenton.compositor import CompositorConfig, LayerNodeConfig +from dify_agent.protocol.schemas import CreateRunRequest, RunEvent, RunStatus from dify_agent.runtime.run_scheduler import RunScheduler, SchedulerStoppingError -from dify_agent.server.schemas import CreateRunRequest, RunEvent, RunRecord, RunStatus +from dify_agent.server.schemas import RunRecord def _request(user: str | list[str] = "hello") -> CreateRunRequest: diff --git a/dify-agent/tests/local/dify_agent/runtime/test_runner.py b/dify-agent/tests/local/dify_agent/runtime/test_runner.py index c0cecb470e..92930ccb18 100644 --- a/dify-agent/tests/local/dify_agent/runtime/test_runner.py +++ b/dify-agent/tests/local/dify_agent/runtime/test_runner.py @@ -3,9 +3,9 @@ import asyncio import pytest from agenton.compositor import CompositorConfig, LayerNodeConfig +from dify_agent.protocol.schemas import AgentProfileConfig, CreateRunRequest from dify_agent.runtime.event_sink import InMemoryRunEventSink from dify_agent.runtime.runner import AgentRunRunner, AgentRunValidationError -from dify_agent.server.schemas import AgentProfileConfig, CreateRunRequest def test_runner_emits_terminal_success_and_snapshot() -> None: diff --git a/dify-agent/tests/local/dify_agent/server/test_runs_routes.py b/dify-agent/tests/local/dify_agent/server/test_runs_routes.py index ccb6f82fe3..083152ae5e 100644 --- a/dify-agent/tests/local/dify_agent/server/test_runs_routes.py +++ b/dify-agent/tests/local/dify_agent/server/test_runs_routes.py @@ -1,5 +1,6 @@ from fastapi.testclient import TestClient +from dify_agent.protocol.schemas import CreateRunRequest from dify_agent.runtime.run_scheduler import SchedulerStoppingError from dify_agent.server.routes.runs import create_runs_router from dify_agent.server.schemas import RunRecord @@ -122,7 +123,6 @@ def test_create_run_does_not_map_infrastructure_failure_to_422() -> None: def _request(): from agenton.compositor import CompositorConfig, LayerNodeConfig - from dify_agent.server.schemas import CreateRunRequest return CreateRunRequest( compositor=CompositorConfig( diff --git a/dify-agent/tests/local/dify_agent/server/test_schemas.py b/dify-agent/tests/local/dify_agent/server/test_schemas.py index e4be884b2e..e627c9e53f 100644 --- a/dify-agent/tests/local/dify_agent/server/test_schemas.py +++ b/dify-agent/tests/local/dify_agent/server/test_schemas.py @@ -1,40 +1,12 @@ -from pydantic_ai.messages import FinalResultEvent - -from dify_agent.server.schemas import ( - RUN_EVENT_ADAPTER, - AgentOutputRunEvent, - AgentOutputRunEventData, - PydanticAIStreamRunEvent, - RunFailedEvent, - RunFailedEventData, - RunStartedEvent, -) +import dify_agent.server.schemas as server_schemas -def test_run_event_adapter_round_trips_typed_variants() -> None: - events = [ - RunStartedEvent(run_id="run-1"), - PydanticAIStreamRunEvent(run_id="run-1", data=FinalResultEvent(tool_name=None, tool_call_id=None)), - AgentOutputRunEvent(run_id="run-1", data=AgentOutputRunEventData(output="done")), - RunFailedEvent(run_id="run-1", data=RunFailedEventData(error="boom", reason="shutdown")), - ] - - for event in events: - payload = RUN_EVENT_ADAPTER.dump_json(event) - decoded = RUN_EVENT_ADAPTER.validate_json(payload) - - assert decoded.type == event.type - assert decoded.run_id == event.run_id +def test_server_schemas_do_not_reexport_public_protocol_dtos() -> None: + assert server_schemas.__all__ == ["RunRecord", "new_run_id"] + assert not hasattr(server_schemas, "CreateRunRequest") + assert not hasattr(server_schemas, "RunStartedEvent") -def test_pydantic_ai_event_data_uses_agent_stream_event_model() -> None: - event = RUN_EVENT_ADAPTER.validate_python( - { - "run_id": "run-1", - "type": "pydantic_ai_event", - "data": {"event_kind": "final_result", "tool_name": None, "tool_call_id": None}, - } - ) - - assert isinstance(event, PydanticAIStreamRunEvent) - assert isinstance(event.data, FinalResultEvent) +def test_server_schemas_keep_server_only_run_helpers() -> None: + assert isinstance(server_schemas.new_run_id(), str) + assert hasattr(server_schemas, "RunRecord") diff --git a/dify-agent/tests/local/dify_agent/server/test_sse.py b/dify-agent/tests/local/dify_agent/server/test_sse.py index f54249453e..64201a8080 100644 --- a/dify-agent/tests/local/dify_agent/server/test_sse.py +++ b/dify-agent/tests/local/dify_agent/server/test_sse.py @@ -1,4 +1,4 @@ -from dify_agent.server.schemas import RunStartedEvent +from dify_agent.protocol.schemas import RunStartedEvent from dify_agent.server.sse import format_sse_event diff --git a/dify-agent/tests/local/dify_agent/storage/test_redis_run_store.py b/dify-agent/tests/local/dify_agent/storage/test_redis_run_store.py index e76ec5437c..2be665df22 100644 --- a/dify-agent/tests/local/dify_agent/storage/test_redis_run_store.py +++ b/dify-agent/tests/local/dify_agent/storage/test_redis_run_store.py @@ -2,7 +2,7 @@ import asyncio from collections.abc import Mapping from agenton.compositor import CompositorConfig, LayerNodeConfig -from dify_agent.server.schemas import CreateRunRequest, RunStartedEvent +from dify_agent.protocol.schemas import CreateRunRequest, RunStartedEvent from dify_agent.storage.redis_run_store import DEFAULT_RUN_RETENTION_SECONDS, RedisRunStore diff --git a/dify-agent/tests/local/examples/test_dify_agent_examples.py b/dify-agent/tests/local/examples/test_dify_agent_examples.py index 8feb3cd4a7..5b4cd3d9b9 100644 --- a/dify-agent/tests/local/examples/test_dify_agent_examples.py +++ b/dify-agent/tests/local/examples/test_dify_agent_examples.py @@ -8,5 +8,6 @@ def test_dify_agent_examples_are_importable() -> None: "dify_agent_examples.run_pydantic_ai_agent", "dify_agent_examples.run_server_consumer", "dify_agent_examples.run_server_sse_consumer", + "dify_agent_examples.run_server_sync_client", ]: importlib.import_module(module_name)