mirror of
https://github.com/langgenius/dify.git
synced 2026-05-13 08:57:28 +08:00
add dify-agent python client
This commit is contained in:
parent
e470e9d4c5
commit
15f5c7064e
@ -5,6 +5,10 @@ configuration, Pydantic AI runtime execution, Redis run records, and per-run Red
|
||||
Streams event logs. The FastAPI application lives at
|
||||
`dify-agent/src/dify_agent/server/app.py`.
|
||||
|
||||
Public Python DTOs and event models are exported from
|
||||
`dify_agent.protocol.schemas`. `dify_agent.server.schemas` is intentionally
|
||||
server-only and should not be used by API consumers.
|
||||
|
||||
## Input model
|
||||
|
||||
Create-run requests accept a `CompositorConfig` and an optional
|
||||
@ -154,6 +158,52 @@ Replay can start from a cursor with either:
|
||||
|
||||
If both are provided, the `after` query parameter takes precedence.
|
||||
|
||||
## Python client
|
||||
|
||||
Use `dify_agent.client.Client` for both async and sync code. Async methods use
|
||||
normal names; sync methods add `_sync`.
|
||||
|
||||
```python {test="skip" lint="skip"}
|
||||
from dify_agent.client import Client
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
async with Client(base_url="http://localhost:8000") as client:
|
||||
run = await client.create_run(
|
||||
{
|
||||
"compositor": {
|
||||
"schema_version": 1,
|
||||
"layers": [{"name": "prompt", "type": "plain.prompt", "config": {"user": "hello"}}],
|
||||
}
|
||||
}
|
||||
)
|
||||
async for event in client.stream_events(run.run_id):
|
||||
print(event)
|
||||
```
|
||||
|
||||
```python {test="skip" lint="skip"}
|
||||
from dify_agent.client import Client
|
||||
|
||||
|
||||
with Client(base_url="http://localhost:8000") as client:
|
||||
run = client.create_run_sync(
|
||||
{
|
||||
"compositor": {
|
||||
"schema_version": 1,
|
||||
"layers": [{"name": "prompt", "type": "plain.prompt", "config": {"user": "hello"}}],
|
||||
}
|
||||
}
|
||||
)
|
||||
terminal = client.wait_run_sync(run.run_id)
|
||||
```
|
||||
|
||||
`stream_events` and `stream_events_sync` parse SSE without an extra dependency.
|
||||
They reconnect by default from the latest yielded event id and stop after
|
||||
`run_succeeded` or `run_failed`. They do not reconnect for HTTP 4xx responses,
|
||||
DTO validation failures, or malformed SSE frames. `create_run` and
|
||||
`create_run_sync` never retry `POST /runs`; if a timeout occurs, the caller must
|
||||
decide whether to inspect existing runs or submit a new run.
|
||||
|
||||
## Event types and order
|
||||
|
||||
A normal successful run emits:
|
||||
@ -184,3 +234,4 @@ See:
|
||||
|
||||
- `dify-agent/examples/dify_agent/dify_agent_examples/run_server_consumer.py` for cursor polling
|
||||
- `dify-agent/examples/dify_agent/dify_agent_examples/run_server_sse_consumer.py` for SSE consumption
|
||||
- `dify-agent/examples/dify_agent/dify_agent_examples/run_server_sync_client.py` for synchronous client usage
|
||||
|
||||
@ -14,6 +14,11 @@ such as the FastAPI server, Redis, or the plugin daemon.
|
||||
```snippet {path="/examples/dify_agent/dify_agent_examples/run_server_consumer.py"}
|
||||
```
|
||||
|
||||
## Use the synchronous client
|
||||
|
||||
```snippet {path="/examples/dify_agent/dify_agent_examples/run_server_sync_client.py"}
|
||||
```
|
||||
|
||||
## Stream run events with SSE
|
||||
|
||||
```snippet {path="/examples/dify_agent/dify_agent_examples/run_server_sse_consumer.py"}
|
||||
|
||||
@ -11,6 +11,7 @@ EXAMPLE_MODULES = (
|
||||
"run_pydantic_ai_agent",
|
||||
"run_server_consumer",
|
||||
"run_server_sse_consumer",
|
||||
"run_server_sync_client",
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
"""Example consumer for the Dify Agent run server.
|
||||
"""Async Python client example for the Dify Agent run server.
|
||||
|
||||
Requires Redis and a running API server. The server schedules runs in-process, for
|
||||
example:
|
||||
@ -7,21 +7,22 @@ example:
|
||||
|
||||
The default request uses the credential-free pydantic-ai TestModel profile. This
|
||||
script prints the created run and every event observed through cursor polling.
|
||||
``Client.create_run`` performs one POST attempt only; use polling or SSE replay to
|
||||
recover after client-side uncertainty.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
import httpx
|
||||
from dify_agent.client import Client
|
||||
|
||||
|
||||
API_BASE_URL = "http://localhost:8000"
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
async with httpx.AsyncClient(base_url=API_BASE_URL, timeout=30) as client:
|
||||
create_response = await client.post(
|
||||
"/runs",
|
||||
json={
|
||||
async with Client(base_url=API_BASE_URL) as client:
|
||||
run = await client.create_run(
|
||||
{
|
||||
"compositor": {
|
||||
"schema_version": 1,
|
||||
"layers": [
|
||||
@ -36,21 +37,17 @@ async def main() -> None:
|
||||
],
|
||||
},
|
||||
"agent_profile": {"provider": "test", "output_text": "Hello from the example TestModel."},
|
||||
},
|
||||
}
|
||||
)
|
||||
create_response.raise_for_status()
|
||||
run = create_response.json()
|
||||
print("created run", run)
|
||||
|
||||
cursor = "0-0"
|
||||
while True:
|
||||
events_response = await client.get(f"/runs/{run['run_id']}/events", params={"after": cursor})
|
||||
events_response.raise_for_status()
|
||||
page = events_response.json()
|
||||
cursor = page["next_cursor"] or cursor
|
||||
for event in page["events"]:
|
||||
page = await client.get_events(run.run_id, after=cursor)
|
||||
cursor = page.next_cursor or cursor
|
||||
for event in page.events:
|
||||
print("event", event)
|
||||
if event["type"] in {"run_succeeded", "run_failed"}:
|
||||
if event.type in {"run_succeeded", "run_failed"}:
|
||||
return
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
|
||||
@ -1,13 +1,14 @@
|
||||
"""SSE consumer sketch for the Dify Agent run server.
|
||||
"""Async SSE client example for the Dify Agent run server.
|
||||
|
||||
Create a run with ``run_server_consumer.py`` or any HTTP client, then set RUN_ID
|
||||
below and run this script while the server is available. It prints raw SSE frames
|
||||
without requiring model credentials.
|
||||
below and run this script while the server is available. The Python client parses
|
||||
SSE frames into typed protocol events and reconnects with the latest event id by
|
||||
default. Malformed frames and HTTP 4xx responses fail without reconnecting.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
import httpx
|
||||
from dify_agent.client import Client
|
||||
|
||||
|
||||
API_BASE_URL = "http://localhost:8000"
|
||||
@ -15,11 +16,9 @@ RUN_ID = "replace-with-run-id"
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
async with httpx.AsyncClient(base_url=API_BASE_URL, timeout=None) as client:
|
||||
async with client.stream("GET", f"/runs/{RUN_ID}/events/sse") as response:
|
||||
response.raise_for_status()
|
||||
async for line in response.aiter_lines():
|
||||
print(line)
|
||||
async with Client(base_url=API_BASE_URL, stream_timeout=None) as client:
|
||||
async for event in client.stream_events(RUN_ID):
|
||||
print(event)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -0,0 +1,40 @@
|
||||
"""Synchronous Python client example for the Dify Agent run server.
|
||||
|
||||
Requires the same running FastAPI server as the async examples. ``create_run_sync``
|
||||
does not retry ``POST /runs``; if a timeout occurs, inspect server state or create
|
||||
a new run explicitly rather than assuming the original request was not accepted.
|
||||
"""
|
||||
|
||||
from dify_agent.client import Client
|
||||
|
||||
|
||||
API_BASE_URL = "http://localhost:8000"
|
||||
|
||||
|
||||
def main() -> None:
|
||||
with Client(base_url=API_BASE_URL) as client:
|
||||
run = client.create_run_sync(
|
||||
{
|
||||
"compositor": {
|
||||
"schema_version": 1,
|
||||
"layers": [
|
||||
{
|
||||
"name": "prompt",
|
||||
"type": "plain.prompt",
|
||||
"config": {
|
||||
"prefix": "You are a concise assistant.",
|
||||
"user": "Say hello from the synchronous Dify Agent client example.",
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
"agent_profile": {"provider": "test", "output_text": "Hello from the sync TestModel."},
|
||||
}
|
||||
)
|
||||
print("created run", run)
|
||||
terminal = client.wait_run_sync(run.run_id, poll_interval_seconds=0.5)
|
||||
print("terminal status", terminal)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
21
dify-agent/src/dify_agent/client/__init__.py
Normal file
21
dify-agent/src/dify_agent/client/__init__.py
Normal file
@ -0,0 +1,21 @@
|
||||
"""Unified sync and async Python client for the Dify Agent run API."""
|
||||
|
||||
from ._client import (
|
||||
Client,
|
||||
DifyAgentClientError,
|
||||
DifyAgentHTTPError,
|
||||
DifyAgentNotFoundError,
|
||||
DifyAgentStreamError,
|
||||
DifyAgentTimeoutError,
|
||||
DifyAgentValidationError,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"Client",
|
||||
"DifyAgentClientError",
|
||||
"DifyAgentHTTPError",
|
||||
"DifyAgentNotFoundError",
|
||||
"DifyAgentStreamError",
|
||||
"DifyAgentTimeoutError",
|
||||
"DifyAgentValidationError",
|
||||
]
|
||||
667
dify-agent/src/dify_agent/client/_client.py
Normal file
667
dify-agent/src/dify_agent/client/_client.py
Normal file
@ -0,0 +1,667 @@
|
||||
"""HTTPX-based client for Dify Agent runs.
|
||||
|
||||
The client uses the public DTOs from ``dify_agent.protocol.schemas`` for all
|
||||
normal request and response parsing. It intentionally does not retry
|
||||
``POST /runs`` because create-run is not idempotent. SSE streams are the only
|
||||
operation with reconnect logic: transient stream/connect/read failures, stream
|
||||
timeouts, and HTTP 5xx stream responses reconnect with the latest observed event
|
||||
id, while HTTP 4xx responses, DTO validation failures, and malformed SSE frames
|
||||
fail immediately.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from collections.abc import AsyncIterator, Iterator
|
||||
from types import TracebackType
|
||||
from typing import Self, TypeVar, cast
|
||||
from urllib.parse import quote
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from dify_agent.protocol.schemas import (
|
||||
CreateRunRequest,
|
||||
CreateRunResponse,
|
||||
RUN_EVENT_ADAPTER,
|
||||
RunEvent,
|
||||
RunEventsResponse,
|
||||
RunStatusResponse,
|
||||
)
|
||||
|
||||
_ResponseModelT = TypeVar("_ResponseModelT", bound=BaseModel)
|
||||
_TERMINAL_EVENT_TYPES = {"run_succeeded", "run_failed"}
|
||||
_TERMINAL_RUN_STATUSES = {"succeeded", "failed"}
|
||||
|
||||
|
||||
class DifyAgentClientError(RuntimeError):
|
||||
"""Base class for errors raised by the Dify Agent Python client."""
|
||||
|
||||
|
||||
class DifyAgentHTTPError(DifyAgentClientError):
|
||||
"""Raised for HTTP 4xx/5xx responses not covered by a narrower subclass."""
|
||||
|
||||
status_code: int
|
||||
detail: object
|
||||
|
||||
def __init__(self, status_code: int, detail: object) -> None:
|
||||
self.status_code = status_code
|
||||
self.detail = detail
|
||||
super().__init__(f"Dify Agent HTTP {status_code}: {detail}")
|
||||
|
||||
|
||||
class DifyAgentNotFoundError(DifyAgentHTTPError):
|
||||
"""Raised when the server returns ``404`` for a run resource."""
|
||||
|
||||
|
||||
class DifyAgentValidationError(DifyAgentHTTPError):
|
||||
"""Raised for local input validation, invalid DTO responses, or HTTP ``422``."""
|
||||
|
||||
def __init__(self, detail: object, *, status_code: int = 422) -> None:
|
||||
super().__init__(status_code=status_code, detail=detail)
|
||||
|
||||
|
||||
class DifyAgentTimeoutError(DifyAgentClientError):
|
||||
"""Raised when an HTTPX timeout occurs outside successful SSE reconnects."""
|
||||
|
||||
|
||||
class DifyAgentStreamError(DifyAgentClientError):
|
||||
"""Raised for malformed SSE frames or exhausted SSE reconnect attempts."""
|
||||
|
||||
|
||||
class _ReconnectableStreamError(Exception):
|
||||
"""Internal wrapper for stream failures that may be retried by the caller."""
|
||||
|
||||
error: DifyAgentClientError
|
||||
|
||||
def __init__(self, error: DifyAgentClientError) -> None:
|
||||
self.error = error
|
||||
super().__init__(str(error))
|
||||
|
||||
|
||||
class _SSEDecoder:
|
||||
"""Incrementally decode SSE lines into typed run events.
|
||||
|
||||
The decoder keeps only the fields for the current frame. Comments are ignored,
|
||||
``data`` fields are joined with newlines as required by the SSE specification,
|
||||
and payload JSON is validated by ``RUN_EVENT_ADAPTER``. The frame ``id`` is
|
||||
copied into the decoded event only when the JSON payload omits ``event.id``.
|
||||
"""
|
||||
|
||||
_event_id: str | None
|
||||
_event_type: str | None
|
||||
_data_lines: list[str]
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._event_id = None
|
||||
self._event_type = None
|
||||
self._data_lines = []
|
||||
|
||||
def feed_line(self, raw_line: str) -> RunEvent | None:
|
||||
"""Consume one SSE line and return an event when a frame completes.
|
||||
|
||||
Empty lines dispatch the current frame. Comment-only frames and frames
|
||||
without ``data`` are ignored so server heartbeats do not surface to users.
|
||||
Malformed event payloads raise ``DifyAgentStreamError`` and must not be
|
||||
retried because replaying would repeat the same invalid frame.
|
||||
"""
|
||||
line = raw_line.rstrip("\r")
|
||||
if line == "":
|
||||
return self._dispatch()
|
||||
if line.startswith(":"):
|
||||
return None
|
||||
|
||||
field, separator, value = line.partition(":")
|
||||
if separator and value.startswith(" "):
|
||||
value = value[1:]
|
||||
if field == "id":
|
||||
self._event_id = value
|
||||
elif field == "event":
|
||||
self._event_type = value
|
||||
elif field == "data":
|
||||
self._data_lines.append(value)
|
||||
return None
|
||||
|
||||
def _dispatch(self) -> RunEvent | None:
|
||||
"""Validate and return the current frame, then clear decoder state."""
|
||||
if not self._data_lines:
|
||||
self._reset()
|
||||
return None
|
||||
|
||||
frame_id = self._event_id
|
||||
frame_event_type = self._event_type
|
||||
data = "\n".join(self._data_lines)
|
||||
self._reset()
|
||||
|
||||
try:
|
||||
event = RUN_EVENT_ADAPTER.validate_json(data)
|
||||
except ValidationError as exc:
|
||||
raise DifyAgentStreamError("malformed SSE data frame") from exc
|
||||
if frame_event_type is not None and frame_event_type != event.type:
|
||||
raise DifyAgentStreamError(
|
||||
f"SSE event field {frame_event_type!r} does not match payload type {event.type!r}"
|
||||
)
|
||||
if frame_id is not None and event.id is None:
|
||||
return event.model_copy(update={"id": frame_id})
|
||||
return event
|
||||
|
||||
def _reset(self) -> None:
|
||||
"""Clear the current frame without changing decoder configuration."""
|
||||
self._event_id = None
|
||||
self._event_type = None
|
||||
self._data_lines = []
|
||||
|
||||
|
||||
class Client:
|
||||
"""Unified synchronous and asynchronous client for Dify Agent runs.
|
||||
|
||||
The instance is intentionally small and stateful: it stores base URL, default
|
||||
headers, timeout settings, optional external HTTPX clients, and lazy-owned
|
||||
clients for whichever sync/async side is used. External clients are never
|
||||
closed by this wrapper. Owned sync clients close via ``close_sync`` or the
|
||||
sync context manager; owned async clients close via ``aclose`` or the async
|
||||
context manager.
|
||||
"""
|
||||
|
||||
_base_url: str
|
||||
_timeout: float | httpx.Timeout
|
||||
_stream_timeout: float | httpx.Timeout | None
|
||||
_headers: dict[str, str]
|
||||
_sync_http_client: httpx.Client | None
|
||||
_async_http_client: httpx.AsyncClient | None
|
||||
_owns_sync_http_client: bool
|
||||
_owns_async_http_client: bool
|
||||
_sync_closed: bool
|
||||
_async_closed: bool
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
base_url: str,
|
||||
timeout: float | httpx.Timeout = 30.0,
|
||||
stream_timeout: float | httpx.Timeout | None = None,
|
||||
headers: dict[str, str] | None = None,
|
||||
sync_http_client: httpx.Client | None = None,
|
||||
async_http_client: httpx.AsyncClient | None = None,
|
||||
) -> None:
|
||||
self._base_url = base_url.rstrip("/")
|
||||
self._timeout = timeout
|
||||
self._stream_timeout = stream_timeout
|
||||
self._headers = dict(headers or {})
|
||||
self._sync_http_client = sync_http_client
|
||||
self._async_http_client = async_http_client
|
||||
self._owns_sync_http_client = sync_http_client is None
|
||||
self._owns_async_http_client = async_http_client is None
|
||||
self._sync_closed = False
|
||||
self._async_closed = False
|
||||
|
||||
def __enter__(self) -> Self:
|
||||
"""Enter a sync context and return this client without opening the network."""
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_value: BaseException | None,
|
||||
traceback: TracebackType | None,
|
||||
) -> None:
|
||||
"""Close the owned sync HTTP client when leaving a sync context."""
|
||||
del exc_type, exc_value, traceback
|
||||
self.close_sync()
|
||||
|
||||
async def __aenter__(self) -> Self:
|
||||
"""Enter an async context and return this client without opening the network."""
|
||||
return self
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_value: BaseException | None,
|
||||
traceback: TracebackType | None,
|
||||
) -> None:
|
||||
"""Close owned async resources when leaving an async context."""
|
||||
del exc_type, exc_value, traceback
|
||||
await self.aclose()
|
||||
|
||||
def close_sync(self) -> None:
|
||||
"""Close the owned synchronous HTTPX client if it was created."""
|
||||
if self._sync_closed:
|
||||
return
|
||||
if self._owns_sync_http_client and self._sync_http_client is not None:
|
||||
self._sync_http_client.close()
|
||||
self._sync_closed = True
|
||||
|
||||
async def aclose(self) -> None:
|
||||
"""Close owned asynchronous resources and any owned sync client already opened."""
|
||||
if not self._async_closed:
|
||||
if self._owns_async_http_client and self._async_http_client is not None:
|
||||
await self._async_http_client.aclose()
|
||||
self._async_closed = True
|
||||
if self._owns_sync_http_client and self._sync_http_client is not None:
|
||||
self.close_sync()
|
||||
|
||||
async def create_run(self, request: CreateRunRequest | dict[str, object]) -> CreateRunResponse:
|
||||
"""Create one run and return its accepted status response.
|
||||
|
||||
Dict inputs are validated as ``CreateRunRequest`` before the request is
|
||||
sent. This method performs exactly one ``POST /runs`` attempt and maps
|
||||
HTTPX timeouts to ``DifyAgentTimeoutError``.
|
||||
"""
|
||||
request_model = _validate_create_run_request(request)
|
||||
try:
|
||||
response = await self._get_async_http_client().post(
|
||||
self._url("/runs"),
|
||||
content=request_model.model_dump_json(),
|
||||
headers=self._merged_headers({"Content-Type": "application/json"}),
|
||||
timeout=self._timeout,
|
||||
)
|
||||
except httpx.TimeoutException as exc:
|
||||
raise DifyAgentTimeoutError("create_run timed out") from exc
|
||||
except httpx.RequestError as exc:
|
||||
raise DifyAgentClientError(f"create_run request failed: {exc}") from exc
|
||||
return _parse_model_response(response, CreateRunResponse)
|
||||
|
||||
def create_run_sync(self, request: CreateRunRequest | dict[str, object]) -> CreateRunResponse:
|
||||
"""Synchronous variant of ``create_run`` with the same no-retry contract."""
|
||||
request_model = _validate_create_run_request(request)
|
||||
try:
|
||||
response = self._get_sync_http_client().post(
|
||||
self._url("/runs"),
|
||||
content=request_model.model_dump_json(),
|
||||
headers=self._merged_headers({"Content-Type": "application/json"}),
|
||||
timeout=self._timeout,
|
||||
)
|
||||
except httpx.TimeoutException as exc:
|
||||
raise DifyAgentTimeoutError("create_run_sync timed out") from exc
|
||||
except httpx.RequestError as exc:
|
||||
raise DifyAgentClientError(f"create_run_sync request failed: {exc}") from exc
|
||||
return _parse_model_response(response, CreateRunResponse)
|
||||
|
||||
async def get_run(self, run_id: str) -> RunStatusResponse:
|
||||
"""Return the current status for ``run_id`` or raise a mapped client error."""
|
||||
try:
|
||||
response = await self._get_async_http_client().get(
|
||||
self._url(f"/runs/{quote(run_id, safe='')}"),
|
||||
headers=self._merged_headers(),
|
||||
timeout=self._timeout,
|
||||
)
|
||||
except httpx.TimeoutException as exc:
|
||||
raise DifyAgentTimeoutError("get_run timed out") from exc
|
||||
except httpx.RequestError as exc:
|
||||
raise DifyAgentClientError(f"get_run request failed: {exc}") from exc
|
||||
return _parse_model_response(response, RunStatusResponse)
|
||||
|
||||
def get_run_sync(self, run_id: str) -> RunStatusResponse:
|
||||
"""Synchronous variant of ``get_run``."""
|
||||
try:
|
||||
response = self._get_sync_http_client().get(
|
||||
self._url(f"/runs/{quote(run_id, safe='')}"),
|
||||
headers=self._merged_headers(),
|
||||
timeout=self._timeout,
|
||||
)
|
||||
except httpx.TimeoutException as exc:
|
||||
raise DifyAgentTimeoutError("get_run_sync timed out") from exc
|
||||
except httpx.RequestError as exc:
|
||||
raise DifyAgentClientError(f"get_run_sync request failed: {exc}") from exc
|
||||
return _parse_model_response(response, RunStatusResponse)
|
||||
|
||||
async def get_events(self, run_id: str, *, after: str = "0-0", limit: int = 100) -> RunEventsResponse:
|
||||
"""Return one cursor-paginated page of events for ``run_id``."""
|
||||
try:
|
||||
response = await self._get_async_http_client().get(
|
||||
self._url(f"/runs/{quote(run_id, safe='')}/events"),
|
||||
params={"after": after, "limit": str(limit)},
|
||||
headers=self._merged_headers(),
|
||||
timeout=self._timeout,
|
||||
)
|
||||
except httpx.TimeoutException as exc:
|
||||
raise DifyAgentTimeoutError("get_events timed out") from exc
|
||||
except httpx.RequestError as exc:
|
||||
raise DifyAgentClientError(f"get_events request failed: {exc}") from exc
|
||||
return _parse_model_response(response, RunEventsResponse)
|
||||
|
||||
def get_events_sync(self, run_id: str, *, after: str = "0-0", limit: int = 100) -> RunEventsResponse:
|
||||
"""Synchronous variant of ``get_events``."""
|
||||
try:
|
||||
response = self._get_sync_http_client().get(
|
||||
self._url(f"/runs/{quote(run_id, safe='')}/events"),
|
||||
params={"after": after, "limit": str(limit)},
|
||||
headers=self._merged_headers(),
|
||||
timeout=self._timeout,
|
||||
)
|
||||
except httpx.TimeoutException as exc:
|
||||
raise DifyAgentTimeoutError("get_events_sync timed out") from exc
|
||||
except httpx.RequestError as exc:
|
||||
raise DifyAgentClientError(f"get_events_sync request failed: {exc}") from exc
|
||||
return _parse_model_response(response, RunEventsResponse)
|
||||
|
||||
async def stream_events(
|
||||
self,
|
||||
run_id: str,
|
||||
*,
|
||||
after: str | None = None,
|
||||
reconnect: bool = True,
|
||||
max_reconnects: int | None = None,
|
||||
reconnect_delay_seconds: float = 1.0,
|
||||
until_terminal: bool = True,
|
||||
) -> AsyncIterator[RunEvent]:
|
||||
"""Yield typed events from SSE with cursor-based reconnect.
|
||||
|
||||
The initial cursor is ``after`` or ``"0-0"``. After every yielded event
|
||||
with an id, reconnects resume from that id using the ``after`` query
|
||||
parameter. HTTP 5xx stream responses are retried, but HTTP 4xx responses,
|
||||
DTO validation failures, and malformed SSE frames are not retried. By
|
||||
default iteration stops after ``run_succeeded`` or ``run_failed``.
|
||||
"""
|
||||
_validate_stream_options(max_reconnects, reconnect_delay_seconds)
|
||||
cursor = after or "0-0"
|
||||
reconnect_attempts = 0
|
||||
while True:
|
||||
try:
|
||||
async for event in self._stream_events_once(run_id, after=cursor):
|
||||
if event.id is not None:
|
||||
cursor = event.id
|
||||
yield event
|
||||
if until_terminal and event.type in _TERMINAL_EVENT_TYPES:
|
||||
return
|
||||
except _ReconnectableStreamError as exc:
|
||||
if not reconnect:
|
||||
raise exc.error from exc
|
||||
reconnect_attempts = _next_reconnect_attempt(
|
||||
reconnect_attempts,
|
||||
max_reconnects=max_reconnects,
|
||||
error=exc.error,
|
||||
)
|
||||
await _sleep_async(reconnect_delay_seconds)
|
||||
continue
|
||||
if not reconnect:
|
||||
return
|
||||
reconnect_attempts = _next_reconnect_attempt(
|
||||
reconnect_attempts,
|
||||
max_reconnects=max_reconnects,
|
||||
error=DifyAgentStreamError("SSE stream ended before a terminal event"),
|
||||
)
|
||||
await _sleep_async(reconnect_delay_seconds)
|
||||
|
||||
def stream_events_sync(
|
||||
self,
|
||||
run_id: str,
|
||||
*,
|
||||
after: str | None = None,
|
||||
reconnect: bool = True,
|
||||
max_reconnects: int | None = None,
|
||||
reconnect_delay_seconds: float = 1.0,
|
||||
until_terminal: bool = True,
|
||||
) -> Iterator[RunEvent]:
|
||||
"""Synchronous variant of ``stream_events`` with the same reconnect rules."""
|
||||
_validate_stream_options(max_reconnects, reconnect_delay_seconds)
|
||||
cursor = after or "0-0"
|
||||
reconnect_attempts = 0
|
||||
while True:
|
||||
try:
|
||||
for event in self._stream_events_once_sync(run_id, after=cursor):
|
||||
if event.id is not None:
|
||||
cursor = event.id
|
||||
yield event
|
||||
if until_terminal and event.type in _TERMINAL_EVENT_TYPES:
|
||||
return
|
||||
except _ReconnectableStreamError as exc:
|
||||
if not reconnect:
|
||||
raise exc.error from exc
|
||||
reconnect_attempts = _next_reconnect_attempt(
|
||||
reconnect_attempts,
|
||||
max_reconnects=max_reconnects,
|
||||
error=exc.error,
|
||||
)
|
||||
_sleep_sync(reconnect_delay_seconds)
|
||||
continue
|
||||
if not reconnect:
|
||||
return
|
||||
reconnect_attempts = _next_reconnect_attempt(
|
||||
reconnect_attempts,
|
||||
max_reconnects=max_reconnects,
|
||||
error=DifyAgentStreamError("SSE stream ended before a terminal event"),
|
||||
)
|
||||
_sleep_sync(reconnect_delay_seconds)
|
||||
|
||||
async def wait_run(
|
||||
self,
|
||||
run_id: str,
|
||||
*,
|
||||
poll_interval_seconds: float = 1.0,
|
||||
timeout_seconds: float | None = None,
|
||||
) -> RunStatusResponse:
|
||||
"""Poll run status until it becomes terminal and return the final status."""
|
||||
_validate_wait_options(poll_interval_seconds, timeout_seconds)
|
||||
deadline = time.monotonic() + timeout_seconds if timeout_seconds is not None else None
|
||||
while True:
|
||||
status = await self.get_run(run_id)
|
||||
if status.status in _TERMINAL_RUN_STATUSES:
|
||||
return status
|
||||
sleep_for = _next_sleep_seconds(poll_interval_seconds, deadline)
|
||||
if sleep_for is None:
|
||||
raise DifyAgentTimeoutError(f"run {run_id!r} did not finish before timeout")
|
||||
await _sleep_async(sleep_for)
|
||||
|
||||
def wait_run_sync(
|
||||
self,
|
||||
run_id: str,
|
||||
*,
|
||||
poll_interval_seconds: float = 1.0,
|
||||
timeout_seconds: float | None = None,
|
||||
) -> RunStatusResponse:
|
||||
"""Synchronous variant of ``wait_run``."""
|
||||
_validate_wait_options(poll_interval_seconds, timeout_seconds)
|
||||
deadline = time.monotonic() + timeout_seconds if timeout_seconds is not None else None
|
||||
while True:
|
||||
status = self.get_run_sync(run_id)
|
||||
if status.status in _TERMINAL_RUN_STATUSES:
|
||||
return status
|
||||
sleep_for = _next_sleep_seconds(poll_interval_seconds, deadline)
|
||||
if sleep_for is None:
|
||||
raise DifyAgentTimeoutError(f"run {run_id!r} did not finish before timeout")
|
||||
_sleep_sync(sleep_for)
|
||||
|
||||
async def _stream_events_once(self, run_id: str, *, after: str) -> AsyncIterator[RunEvent]:
|
||||
"""Open one SSE connection and yield events until it ends or fails."""
|
||||
try:
|
||||
async with self._get_async_http_client().stream(
|
||||
"GET",
|
||||
self._url(f"/runs/{quote(run_id, safe='')}/events/sse"),
|
||||
params={"after": after},
|
||||
headers=self._merged_headers(),
|
||||
timeout=self._stream_timeout,
|
||||
) as response:
|
||||
if response.status_code >= 400:
|
||||
_ = await response.aread()
|
||||
_raise_for_stream_status(response)
|
||||
decoder = _SSEDecoder()
|
||||
async for line in response.aiter_lines():
|
||||
event = decoder.feed_line(line)
|
||||
if event is not None:
|
||||
yield event
|
||||
except DifyAgentHTTPError:
|
||||
raise
|
||||
except DifyAgentStreamError:
|
||||
raise
|
||||
except httpx.TimeoutException as exc:
|
||||
raise _ReconnectableStreamError(DifyAgentTimeoutError("SSE stream timed out")) from exc
|
||||
except httpx.TransportError as exc:
|
||||
raise _ReconnectableStreamError(DifyAgentStreamError(f"SSE stream failed: {exc}")) from exc
|
||||
except httpx.StreamError as exc:
|
||||
raise _ReconnectableStreamError(DifyAgentStreamError(f"SSE stream failed: {exc}")) from exc
|
||||
|
||||
def _stream_events_once_sync(self, run_id: str, *, after: str) -> Iterator[RunEvent]:
|
||||
"""Open one synchronous SSE connection and yield events until it ends or fails."""
|
||||
try:
|
||||
with self._get_sync_http_client().stream(
|
||||
"GET",
|
||||
self._url(f"/runs/{quote(run_id, safe='')}/events/sse"),
|
||||
params={"after": after},
|
||||
headers=self._merged_headers(),
|
||||
timeout=self._stream_timeout,
|
||||
) as response:
|
||||
if response.status_code >= 400:
|
||||
_ = response.read()
|
||||
_raise_for_stream_status(response)
|
||||
decoder = _SSEDecoder()
|
||||
for line in response.iter_lines():
|
||||
event = decoder.feed_line(line)
|
||||
if event is not None:
|
||||
yield event
|
||||
except DifyAgentHTTPError:
|
||||
raise
|
||||
except DifyAgentStreamError:
|
||||
raise
|
||||
except httpx.TimeoutException as exc:
|
||||
raise _ReconnectableStreamError(DifyAgentTimeoutError("SSE stream timed out")) from exc
|
||||
except httpx.TransportError as exc:
|
||||
raise _ReconnectableStreamError(DifyAgentStreamError(f"SSE stream failed: {exc}")) from exc
|
||||
except httpx.StreamError as exc:
|
||||
raise _ReconnectableStreamError(DifyAgentStreamError(f"SSE stream failed: {exc}")) from exc
|
||||
|
||||
def _get_sync_http_client(self) -> httpx.Client:
|
||||
"""Return an open sync HTTPX client, creating an owned one lazily."""
|
||||
if self._sync_closed:
|
||||
raise DifyAgentClientError("sync client is closed")
|
||||
if self._sync_http_client is None:
|
||||
self._sync_http_client = httpx.Client(timeout=self._timeout, headers=self._headers)
|
||||
return self._sync_http_client
|
||||
|
||||
def _get_async_http_client(self) -> httpx.AsyncClient:
|
||||
"""Return an open async HTTPX client, creating an owned one lazily."""
|
||||
if self._async_closed:
|
||||
raise DifyAgentClientError("async client is closed")
|
||||
if self._async_http_client is None:
|
||||
self._async_http_client = httpx.AsyncClient(timeout=self._timeout, headers=self._headers)
|
||||
return self._async_http_client
|
||||
|
||||
def _url(self, path: str) -> str:
|
||||
"""Build an absolute URL from the configured base and API path."""
|
||||
return f"{self._base_url}{path}"
|
||||
|
||||
def _merged_headers(self, extra: dict[str, str] | None = None) -> dict[str, str]:
|
||||
"""Return per-request headers without mutating client defaults."""
|
||||
headers = dict(self._headers)
|
||||
if extra is not None:
|
||||
headers.update(extra)
|
||||
return headers
|
||||
|
||||
|
||||
def _validate_create_run_request(request: CreateRunRequest | dict[str, object]) -> CreateRunRequest:
|
||||
"""Validate user input before creating a run."""
|
||||
if isinstance(request, CreateRunRequest):
|
||||
return request
|
||||
try:
|
||||
return CreateRunRequest.model_validate(request)
|
||||
except ValidationError as exc:
|
||||
raise DifyAgentValidationError(detail=exc.errors(include_url=False)) from exc
|
||||
|
||||
|
||||
def _parse_model_response(response: httpx.Response, model_type: type[_ResponseModelT]) -> _ResponseModelT:
|
||||
"""Map HTTP errors and parse a Pydantic response DTO."""
|
||||
_raise_for_status(response)
|
||||
try:
|
||||
return model_type.model_validate_json(response.content)
|
||||
except ValidationError as exc:
|
||||
raise DifyAgentValidationError(
|
||||
detail=exc.errors(include_url=False),
|
||||
status_code=response.status_code,
|
||||
) from exc
|
||||
|
||||
|
||||
def _raise_for_status(response: httpx.Response) -> None:
|
||||
"""Raise the configured client exception for HTTP 4xx/5xx responses."""
|
||||
if response.status_code < 400:
|
||||
return
|
||||
detail = _extract_error_detail(response)
|
||||
if response.status_code == 404:
|
||||
raise DifyAgentNotFoundError(status_code=response.status_code, detail=detail)
|
||||
if response.status_code == 422:
|
||||
raise DifyAgentValidationError(status_code=response.status_code, detail=detail)
|
||||
raise DifyAgentHTTPError(status_code=response.status_code, detail=detail)
|
||||
|
||||
|
||||
def _raise_for_stream_status(response: httpx.Response) -> None:
|
||||
"""Raise terminal 4xx errors or wrap retryable SSE 5xx responses."""
|
||||
try:
|
||||
_raise_for_status(response)
|
||||
except DifyAgentHTTPError as exc:
|
||||
if response.status_code >= 500:
|
||||
raise _ReconnectableStreamError(
|
||||
DifyAgentStreamError(f"SSE stream HTTP {response.status_code}: {exc.detail}")
|
||||
) from exc
|
||||
raise
|
||||
|
||||
|
||||
def _extract_error_detail(response: httpx.Response) -> object:
|
||||
"""Extract FastAPI's ``detail`` field when present, falling back to text."""
|
||||
try:
|
||||
payload = cast(object, response.json())
|
||||
except (ValueError, httpx.ResponseNotRead):
|
||||
return response.text or response.reason_phrase
|
||||
if isinstance(payload, dict) and "detail" in payload:
|
||||
return cast(object, payload["detail"])
|
||||
return cast(object, payload)
|
||||
|
||||
|
||||
def _next_reconnect_attempt(
|
||||
reconnect_attempts: int,
|
||||
*,
|
||||
max_reconnects: int | None,
|
||||
error: DifyAgentClientError,
|
||||
) -> int:
|
||||
"""Increment reconnect attempts or raise when the configured budget is spent."""
|
||||
if max_reconnects is not None and reconnect_attempts >= max_reconnects:
|
||||
raise DifyAgentStreamError("SSE stream reconnect attempts exhausted") from error
|
||||
return reconnect_attempts + 1
|
||||
|
||||
|
||||
def _validate_stream_options(max_reconnects: int | None, reconnect_delay_seconds: float) -> None:
|
||||
"""Reject stream options that cannot produce deterministic reconnect behavior."""
|
||||
if max_reconnects is not None and max_reconnects < 0:
|
||||
raise DifyAgentValidationError(detail="max_reconnects must be non-negative")
|
||||
if reconnect_delay_seconds < 0:
|
||||
raise DifyAgentValidationError(detail="reconnect_delay_seconds must be non-negative")
|
||||
|
||||
|
||||
def _validate_wait_options(poll_interval_seconds: float, timeout_seconds: float | None) -> None:
|
||||
"""Reject wait options that would make polling ambiguous."""
|
||||
if poll_interval_seconds < 0:
|
||||
raise DifyAgentValidationError(detail="poll_interval_seconds must be non-negative")
|
||||
if timeout_seconds is not None and timeout_seconds < 0:
|
||||
raise DifyAgentValidationError(detail="timeout_seconds must be non-negative")
|
||||
|
||||
|
||||
def _next_sleep_seconds(poll_interval_seconds: float, deadline: float | None) -> float | None:
|
||||
"""Return the next polling sleep duration, or ``None`` when timed out."""
|
||||
if deadline is None:
|
||||
return poll_interval_seconds
|
||||
remaining = deadline - time.monotonic()
|
||||
if remaining <= 0:
|
||||
return None
|
||||
return min(poll_interval_seconds, remaining)
|
||||
|
||||
|
||||
async def _sleep_async(seconds: float) -> None:
|
||||
"""Sleep asynchronously, skipping the call for zero-second test delays."""
|
||||
if seconds > 0:
|
||||
await asyncio.sleep(seconds)
|
||||
|
||||
|
||||
def _sleep_sync(seconds: float) -> None:
|
||||
"""Sleep synchronously, skipping the call for zero-second test delays."""
|
||||
if seconds > 0:
|
||||
time.sleep(seconds)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Client",
|
||||
"DifyAgentClientError",
|
||||
"DifyAgentHTTPError",
|
||||
"DifyAgentNotFoundError",
|
||||
"DifyAgentStreamError",
|
||||
"DifyAgentTimeoutError",
|
||||
"DifyAgentValidationError",
|
||||
]
|
||||
47
dify-agent/src/dify_agent/protocol/__init__.py
Normal file
47
dify-agent/src/dify_agent/protocol/__init__.py
Normal file
@ -0,0 +1,47 @@
|
||||
"""Public protocol exports shared by the Dify Agent server and clients."""
|
||||
|
||||
from .schemas import (
|
||||
RUN_EVENT_ADAPTER,
|
||||
AgentOutputRunEvent,
|
||||
AgentOutputRunEventData,
|
||||
AgentProfileConfig,
|
||||
BaseRunEvent,
|
||||
CreateRunRequest,
|
||||
CreateRunResponse,
|
||||
EmptyRunEventData,
|
||||
PydanticAIStreamRunEvent,
|
||||
RunEvent,
|
||||
RunEventType,
|
||||
RunEventsResponse,
|
||||
RunFailedEvent,
|
||||
RunFailedEventData,
|
||||
RunStartedEvent,
|
||||
RunStatus,
|
||||
RunStatusResponse,
|
||||
RunSucceededEvent,
|
||||
SessionSnapshotRunEvent,
|
||||
utc_now,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AgentProfileConfig",
|
||||
"AgentOutputRunEvent",
|
||||
"AgentOutputRunEventData",
|
||||
"BaseRunEvent",
|
||||
"CreateRunRequest",
|
||||
"CreateRunResponse",
|
||||
"EmptyRunEventData",
|
||||
"PydanticAIStreamRunEvent",
|
||||
"RUN_EVENT_ADAPTER",
|
||||
"RunEvent",
|
||||
"RunEventType",
|
||||
"RunEventsResponse",
|
||||
"RunFailedEvent",
|
||||
"RunFailedEventData",
|
||||
"RunStartedEvent",
|
||||
"RunStatus",
|
||||
"RunStatusResponse",
|
||||
"RunSucceededEvent",
|
||||
"SessionSnapshotRunEvent",
|
||||
"utc_now",
|
||||
]
|
||||
201
dify-agent/src/dify_agent/protocol/schemas.py
Normal file
201
dify-agent/src/dify_agent/protocol/schemas.py
Normal file
@ -0,0 +1,201 @@
|
||||
"""Public HTTP protocol schemas for the Dify Agent run API.
|
||||
|
||||
This module is the shared wire contract for the FastAPI server, runtime event
|
||||
producers, storage adapters, and Python client. The server accepts only
|
||||
registry-backed Agenton compositor configs, keeping HTTP input data-only and
|
||||
preventing unsafe import-path construction. Run events are append-only records;
|
||||
Redis stream ids (or in-memory equivalents in tests) are the public cursors used
|
||||
by polling and SSE replay. Event envelopes keep the public
|
||||
``id``/``run_id``/``type``/``data``/``created_at`` shape, while each ``type`` has
|
||||
a typed ``data`` model so OpenAPI, Redis replay, and clients parse the same
|
||||
payload contract.
|
||||
"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Annotated, ClassVar, Literal, TypeAlias
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
|
||||
from pydantic_ai.messages import AgentStreamEvent
|
||||
|
||||
from agenton.compositor import CompositorConfig, CompositorSessionSnapshot
|
||||
|
||||
|
||||
RunStatus = Literal["running", "succeeded", "failed"]
|
||||
RunEventType = Literal[
|
||||
"run_started",
|
||||
"pydantic_ai_event",
|
||||
"agent_output",
|
||||
"session_snapshot",
|
||||
"run_succeeded",
|
||||
"run_failed",
|
||||
]
|
||||
|
||||
|
||||
def utc_now() -> datetime:
|
||||
"""Return the timezone-aware timestamp format used by public schemas."""
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
|
||||
class AgentProfileConfig(BaseModel):
|
||||
"""Minimal model profile for the MVP runner.
|
||||
|
||||
``test`` uses pydantic-ai's ``TestModel`` and is credential-free. Other
|
||||
profiles can be added behind this schema without changing run/event storage.
|
||||
"""
|
||||
|
||||
provider: Literal["test"] = "test"
|
||||
output_text: str = "Hello from the Dify Agent test model."
|
||||
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
class CreateRunRequest(BaseModel):
|
||||
"""Request body for creating one async agent run."""
|
||||
|
||||
compositor: CompositorConfig
|
||||
session_snapshot: CompositorSessionSnapshot | None = None
|
||||
agent_profile: AgentProfileConfig = Field(default_factory=AgentProfileConfig)
|
||||
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
class CreateRunResponse(BaseModel):
|
||||
"""Response returned after a run has been persisted and scheduled locally."""
|
||||
|
||||
run_id: str
|
||||
status: RunStatus
|
||||
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
class RunStatusResponse(BaseModel):
|
||||
"""Current server-side status for one run."""
|
||||
|
||||
run_id: str
|
||||
status: RunStatus
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
error: str | None = None
|
||||
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
class EmptyRunEventData(BaseModel):
|
||||
"""Typed empty payload for lifecycle events that carry no extra data."""
|
||||
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
class AgentOutputRunEventData(BaseModel):
|
||||
"""Final agent output payload emitted before the session snapshot."""
|
||||
|
||||
output: str
|
||||
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
class RunFailedEventData(BaseModel):
|
||||
"""Terminal failure payload shown to polling and SSE consumers."""
|
||||
|
||||
error: str
|
||||
reason: str | None = None
|
||||
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
class BaseRunEvent(BaseModel):
|
||||
"""Shared append-only event envelope visible through polling and SSE."""
|
||||
|
||||
id: str | None = None
|
||||
run_id: str
|
||||
created_at: datetime = Field(default_factory=utc_now)
|
||||
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
class RunStartedEvent(BaseRunEvent):
|
||||
"""Run lifecycle event emitted before runtime execution starts."""
|
||||
|
||||
type: Literal["run_started"] = "run_started"
|
||||
data: EmptyRunEventData = Field(default_factory=EmptyRunEventData)
|
||||
|
||||
|
||||
class PydanticAIStreamRunEvent(BaseRunEvent):
|
||||
"""Pydantic AI stream event using the upstream typed event model."""
|
||||
|
||||
type: Literal["pydantic_ai_event"] = "pydantic_ai_event"
|
||||
data: AgentStreamEvent
|
||||
|
||||
|
||||
class AgentOutputRunEvent(BaseRunEvent):
|
||||
"""Run event carrying the final agent output string."""
|
||||
|
||||
type: Literal["agent_output"] = "agent_output"
|
||||
data: AgentOutputRunEventData
|
||||
|
||||
|
||||
class SessionSnapshotRunEvent(BaseRunEvent):
|
||||
"""Run event carrying the resumable Agenton session snapshot."""
|
||||
|
||||
type: Literal["session_snapshot"] = "session_snapshot"
|
||||
data: CompositorSessionSnapshot
|
||||
|
||||
|
||||
class RunSucceededEvent(BaseRunEvent):
|
||||
"""Terminal success event emitted after output and session snapshot."""
|
||||
|
||||
type: Literal["run_succeeded"] = "run_succeeded"
|
||||
data: EmptyRunEventData = Field(default_factory=EmptyRunEventData)
|
||||
|
||||
|
||||
class RunFailedEvent(BaseRunEvent):
|
||||
"""Terminal failure event emitted before the run status becomes failed."""
|
||||
|
||||
type: Literal["run_failed"] = "run_failed"
|
||||
data: RunFailedEventData
|
||||
|
||||
|
||||
RunEvent: TypeAlias = Annotated[
|
||||
RunStartedEvent
|
||||
| PydanticAIStreamRunEvent
|
||||
| AgentOutputRunEvent
|
||||
| SessionSnapshotRunEvent
|
||||
| RunSucceededEvent
|
||||
| RunFailedEvent,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
RUN_EVENT_ADAPTER: TypeAdapter[RunEvent] = TypeAdapter(RunEvent)
|
||||
|
||||
|
||||
class RunEventsResponse(BaseModel):
|
||||
"""Cursor-paginated event log response."""
|
||||
|
||||
run_id: str
|
||||
events: list[RunEvent]
|
||||
next_cursor: str | None = None
|
||||
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
__all__ = [
|
||||
"AgentProfileConfig",
|
||||
"AgentOutputRunEvent",
|
||||
"AgentOutputRunEventData",
|
||||
"BaseRunEvent",
|
||||
"CreateRunRequest",
|
||||
"CreateRunResponse",
|
||||
"EmptyRunEventData",
|
||||
"PydanticAIStreamRunEvent",
|
||||
"RUN_EVENT_ADAPTER",
|
||||
"RunEvent",
|
||||
"RunEventType",
|
||||
"RunEventsResponse",
|
||||
"RunFailedEvent",
|
||||
"RunFailedEventData",
|
||||
"RunStartedEvent",
|
||||
"RunStatus",
|
||||
"RunStatusResponse",
|
||||
"RunSucceededEvent",
|
||||
"SessionSnapshotRunEvent",
|
||||
"utc_now",
|
||||
]
|
||||
@ -13,7 +13,7 @@ from pydantic_ai.messages import UserContent
|
||||
from pydantic_ai.models.test import TestModel
|
||||
|
||||
from agenton.layers.types import PydanticAIPrompt, PydanticAITool
|
||||
from dify_agent.server.schemas import AgentProfileConfig
|
||||
from dify_agent.protocol.schemas import AgentProfileConfig
|
||||
|
||||
|
||||
def create_agent(
|
||||
|
||||
@ -11,7 +11,7 @@ from typing import Protocol
|
||||
from pydantic_ai.messages import AgentStreamEvent
|
||||
|
||||
from agenton.compositor import CompositorSessionSnapshot
|
||||
from dify_agent.server.schemas import (
|
||||
from dify_agent.protocol.schemas import (
|
||||
AgentOutputRunEvent,
|
||||
AgentOutputRunEventData,
|
||||
EmptyRunEventData,
|
||||
|
||||
@ -12,11 +12,12 @@ import logging
|
||||
from collections.abc import Callable
|
||||
from typing import Protocol
|
||||
|
||||
from dify_agent.protocol.schemas import CreateRunRequest
|
||||
from dify_agent.runtime.compositor_factory import build_pydantic_ai_compositor
|
||||
from dify_agent.runtime.event_sink import RunEventSink, emit_run_failed
|
||||
from dify_agent.runtime.runner import AgentRunRunner
|
||||
from dify_agent.runtime.user_prompt_validation import EMPTY_USER_PROMPTS_ERROR, has_non_blank_user_prompt
|
||||
from dify_agent.server.schemas import CreateRunRequest, RunRecord
|
||||
from dify_agent.server.schemas import RunRecord
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -11,6 +11,7 @@ from collections.abc import AsyncIterable
|
||||
from pydantic_ai.messages import AgentStreamEvent
|
||||
|
||||
from agenton.compositor import CompositorSessionSnapshot
|
||||
from dify_agent.protocol.schemas import CreateRunRequest
|
||||
from dify_agent.runtime.agent_factory import create_agent, normalize_user_input
|
||||
from dify_agent.runtime.compositor_factory import build_pydantic_ai_compositor
|
||||
from dify_agent.runtime.event_sink import (
|
||||
@ -23,7 +24,6 @@ from dify_agent.runtime.event_sink import (
|
||||
emit_session_snapshot,
|
||||
)
|
||||
from dify_agent.runtime.user_prompt_validation import EMPTY_USER_PROMPTS_ERROR, has_non_blank_user_prompt
|
||||
from dify_agent.server.schemas import CreateRunRequest
|
||||
|
||||
|
||||
class AgentRunValidationError(ValueError):
|
||||
|
||||
@ -13,10 +13,10 @@ from typing import Annotated
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, Query
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from dify_agent.protocol.schemas import CreateRunRequest, CreateRunResponse, RunEventsResponse, RunStatusResponse
|
||||
from dify_agent.runtime.compositor_factory import build_pydantic_ai_compositor
|
||||
from dify_agent.runtime.run_scheduler import RunScheduler, SchedulerStoppingError
|
||||
from dify_agent.runtime.user_prompt_validation import EMPTY_USER_PROMPTS_ERROR, has_non_blank_user_prompt
|
||||
from dify_agent.server.schemas import CreateRunRequest, CreateRunResponse, RunEventsResponse, RunStatusResponse
|
||||
from dify_agent.server.sse import sse_event_stream
|
||||
from dify_agent.storage.redis_run_store import RedisRunStore, RunNotFoundError
|
||||
|
||||
|
||||
@ -1,198 +1,41 @@
|
||||
"""Public API schemas for the Dify Agent run server.
|
||||
"""Server-only schemas and helpers for persisted run records.
|
||||
|
||||
The server accepts only registry-backed Agenton compositor configs. This keeps
|
||||
HTTP input data-only and prevents unsafe import-path construction. Run events are
|
||||
append-only records; Redis stream ids (or in-memory equivalents in tests) are the
|
||||
public cursors used by polling and SSE replay. Event envelopes keep the public
|
||||
``id``/``run_id``/``type``/``data``/``created_at`` shape, but each ``type`` has a
|
||||
typed ``data`` model so OpenAPI, Redis replay, and runtime producers agree on the
|
||||
payload contract.
|
||||
Public HTTP DTOs and run events live in ``dify_agent.protocol.schemas`` and are
|
||||
intentionally not re-exported here. Keeping this module server-only prevents old
|
||||
imports from silently depending on implementation modules while preserving the
|
||||
internal ``RunRecord`` model used by schedulers and Redis storage.
|
||||
"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Annotated, Literal, TypeAlias
|
||||
from datetime import datetime
|
||||
from typing import ClassVar
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter, field_validator
|
||||
from pydantic_ai.messages import AgentStreamEvent
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
from agenton.compositor import CompositorConfig, CompositorSessionSnapshot
|
||||
|
||||
|
||||
RunStatus = Literal["running", "succeeded", "failed"]
|
||||
RunEventType = Literal[
|
||||
"run_started",
|
||||
"pydantic_ai_event",
|
||||
"agent_output",
|
||||
"session_snapshot",
|
||||
"run_succeeded",
|
||||
"run_failed",
|
||||
]
|
||||
from dify_agent.protocol import schemas as _protocol_schemas
|
||||
|
||||
|
||||
def new_run_id() -> str:
|
||||
"""Return a stable external run id."""
|
||||
"""Return a stable external run id for newly persisted server records."""
|
||||
return str(uuid4())
|
||||
|
||||
|
||||
def utc_now() -> datetime:
|
||||
"""Return the timestamp format used by public schemas."""
|
||||
return datetime.now(timezone.utc)
|
||||
class RunRecord(BaseModel):
|
||||
"""Internal representation persisted for status reads.
|
||||
|
||||
|
||||
class AgentProfileConfig(BaseModel):
|
||||
"""Minimal model profile for the MVP runner.
|
||||
|
||||
``test`` uses pydantic-ai's ``TestModel`` and is credential-free. Other
|
||||
profiles can be added behind this schema without changing run/event storage.
|
||||
The embedded request and status use protocol types so persisted records stay
|
||||
JSON-compatible with the public API, but callers must import those DTOs from
|
||||
``dify_agent.protocol.schemas`` rather than this server-only module.
|
||||
"""
|
||||
|
||||
provider: Literal["test"] = "test"
|
||||
output_text: str = "Hello from the Dify Agent test model."
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
class CreateRunRequest(BaseModel):
|
||||
"""Request body for creating one async agent run."""
|
||||
|
||||
compositor: CompositorConfig
|
||||
session_snapshot: CompositorSessionSnapshot | None = None
|
||||
agent_profile: AgentProfileConfig = Field(default_factory=AgentProfileConfig)
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
class CreateRunResponse(BaseModel):
|
||||
"""Response returned after a run has been persisted and scheduled locally."""
|
||||
|
||||
run_id: str
|
||||
status: RunStatus
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
class RunStatusResponse(BaseModel):
|
||||
"""Current server-side status for one run."""
|
||||
|
||||
run_id: str
|
||||
status: RunStatus
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
status: _protocol_schemas.RunStatus
|
||||
request: _protocol_schemas.CreateRunRequest
|
||||
created_at: datetime = Field(default_factory=_protocol_schemas.utc_now)
|
||||
updated_at: datetime = Field(default_factory=_protocol_schemas.utc_now)
|
||||
error: str | None = None
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
class EmptyRunEventData(BaseModel):
|
||||
"""Typed empty payload for lifecycle events that carry no extra data."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
class AgentOutputRunEventData(BaseModel):
|
||||
"""Final agent output payload emitted before the session snapshot."""
|
||||
|
||||
output: str
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
class RunFailedEventData(BaseModel):
|
||||
"""Terminal failure payload shown to polling and SSE consumers."""
|
||||
|
||||
error: str
|
||||
reason: str | None = None
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
class BaseRunEvent(BaseModel):
|
||||
"""Shared append-only event envelope visible through polling and SSE."""
|
||||
|
||||
id: str | None = None
|
||||
run_id: str
|
||||
created_at: datetime = Field(default_factory=utc_now)
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
class RunStartedEvent(BaseRunEvent):
|
||||
"""Run lifecycle event emitted before runtime execution starts."""
|
||||
|
||||
type: Literal["run_started"] = "run_started"
|
||||
data: EmptyRunEventData = Field(default_factory=EmptyRunEventData)
|
||||
|
||||
|
||||
class PydanticAIStreamRunEvent(BaseRunEvent):
|
||||
"""Pydantic AI stream event using the upstream typed event model."""
|
||||
|
||||
type: Literal["pydantic_ai_event"] = "pydantic_ai_event"
|
||||
data: AgentStreamEvent
|
||||
|
||||
|
||||
class AgentOutputRunEvent(BaseRunEvent):
|
||||
"""Run event carrying the final agent output string."""
|
||||
|
||||
type: Literal["agent_output"] = "agent_output"
|
||||
data: AgentOutputRunEventData
|
||||
|
||||
|
||||
class SessionSnapshotRunEvent(BaseRunEvent):
|
||||
"""Run event carrying the resumable Agenton session snapshot."""
|
||||
|
||||
type: Literal["session_snapshot"] = "session_snapshot"
|
||||
data: CompositorSessionSnapshot
|
||||
|
||||
|
||||
class RunSucceededEvent(BaseRunEvent):
|
||||
"""Terminal success event emitted after output and session snapshot."""
|
||||
|
||||
type: Literal["run_succeeded"] = "run_succeeded"
|
||||
data: EmptyRunEventData = Field(default_factory=EmptyRunEventData)
|
||||
|
||||
|
||||
class RunFailedEvent(BaseRunEvent):
|
||||
"""Terminal failure event emitted before the run status becomes failed."""
|
||||
|
||||
type: Literal["run_failed"] = "run_failed"
|
||||
data: RunFailedEventData
|
||||
|
||||
|
||||
|
||||
RunEvent: TypeAlias = Annotated[
|
||||
RunStartedEvent
|
||||
| PydanticAIStreamRunEvent
|
||||
| AgentOutputRunEvent
|
||||
| SessionSnapshotRunEvent
|
||||
| RunSucceededEvent
|
||||
| RunFailedEvent,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
RUN_EVENT_ADAPTER = TypeAdapter(RunEvent)
|
||||
|
||||
|
||||
class RunEventsResponse(BaseModel):
|
||||
"""Cursor-paginated event log response."""
|
||||
|
||||
run_id: str
|
||||
events: list[RunEvent]
|
||||
next_cursor: str | None = None
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
class RunRecord(BaseModel):
|
||||
"""Internal representation persisted for status reads."""
|
||||
|
||||
run_id: str
|
||||
status: RunStatus
|
||||
request: CreateRunRequest
|
||||
created_at: datetime = Field(default_factory=utc_now)
|
||||
updated_at: datetime = Field(default_factory=utc_now)
|
||||
error: str | None = None
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
@field_validator("updated_at")
|
||||
@classmethod
|
||||
@ -203,26 +46,4 @@ class RunRecord(BaseModel):
|
||||
return value
|
||||
|
||||
|
||||
__all__ = [
|
||||
"AgentProfileConfig",
|
||||
"AgentOutputRunEvent",
|
||||
"AgentOutputRunEventData",
|
||||
"BaseRunEvent",
|
||||
"CreateRunRequest",
|
||||
"CreateRunResponse",
|
||||
"EmptyRunEventData",
|
||||
"PydanticAIStreamRunEvent",
|
||||
"RUN_EVENT_ADAPTER",
|
||||
"RunEvent",
|
||||
"RunEventsResponse",
|
||||
"RunFailedEvent",
|
||||
"RunFailedEventData",
|
||||
"RunRecord",
|
||||
"RunStartedEvent",
|
||||
"RunStatus",
|
||||
"RunStatusResponse",
|
||||
"RunSucceededEvent",
|
||||
"SessionSnapshotRunEvent",
|
||||
"new_run_id",
|
||||
"utc_now",
|
||||
]
|
||||
__all__ = ["RunRecord", "new_run_id"]
|
||||
|
||||
@ -7,7 +7,7 @@ name. Payload data is the full public ``RunEvent`` JSON object.
|
||||
|
||||
from collections.abc import AsyncIterable, AsyncIterator
|
||||
|
||||
from dify_agent.server.schemas import RUN_EVENT_ADAPTER, RunEvent
|
||||
from dify_agent.protocol.schemas import RUN_EVENT_ADAPTER, RunEvent
|
||||
|
||||
|
||||
def format_sse_event(event: RunEvent) -> str:
|
||||
|
||||
@ -12,17 +12,16 @@ from typing import cast
|
||||
|
||||
from redis.asyncio import Redis
|
||||
|
||||
from dify_agent.runtime.event_sink import RunEventSink
|
||||
from dify_agent.server.schemas import (
|
||||
from dify_agent.protocol.schemas import (
|
||||
CreateRunRequest,
|
||||
RUN_EVENT_ADAPTER,
|
||||
RunEvent,
|
||||
RunEventsResponse,
|
||||
RunRecord,
|
||||
RunStatus,
|
||||
new_run_id,
|
||||
utc_now,
|
||||
)
|
||||
from dify_agent.runtime.event_sink import RunEventSink
|
||||
from dify_agent.server.schemas import RunRecord, new_run_id
|
||||
from dify_agent.server.settings import DEFAULT_RUN_RETENTION_SECONDS
|
||||
from dify_agent.storage.redis_keys import run_events_key, run_record_key
|
||||
|
||||
|
||||
381
dify-agent/tests/local/dify_agent/client/test_client.py
Normal file
381
dify-agent/tests/local/dify_agent/client/test_client.py
Normal file
@ -0,0 +1,381 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from collections.abc import Iterator
|
||||
from datetime import UTC, datetime
|
||||
from typing import cast, override
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from dify_agent.client import (
|
||||
Client,
|
||||
DifyAgentHTTPError,
|
||||
DifyAgentNotFoundError,
|
||||
DifyAgentStreamError,
|
||||
DifyAgentTimeoutError,
|
||||
DifyAgentValidationError,
|
||||
)
|
||||
from dify_agent.protocol.schemas import (
|
||||
CreateRunRequest,
|
||||
EmptyRunEventData,
|
||||
RUN_EVENT_ADAPTER,
|
||||
RunEvent,
|
||||
RunEventsResponse,
|
||||
RunStartedEvent,
|
||||
RunSucceededEvent,
|
||||
)
|
||||
|
||||
|
||||
def _create_run_payload() -> dict[str, object]:
|
||||
return {
|
||||
"compositor": {
|
||||
"schema_version": 1,
|
||||
"layers": [{"name": "prompt", "type": "plain.prompt", "config": {"user": "hello"}}],
|
||||
},
|
||||
"agent_profile": {"provider": "test", "output_text": "done"},
|
||||
}
|
||||
|
||||
|
||||
def _event_frame(event: RunEvent, *, event_id: str | None = None, exclude_id: bool = False) -> str:
|
||||
payload = RUN_EVENT_ADAPTER.dump_json(event, exclude={"id"} if exclude_id else None).decode()
|
||||
lines: list[str] = []
|
||||
if event_id is not None:
|
||||
lines.append(f"id: {event_id}")
|
||||
lines.append(f"data: {payload}")
|
||||
return "\n".join(lines) + "\n\n"
|
||||
|
||||
|
||||
def _run_status_json(status: str) -> dict[str, object]:
|
||||
now = datetime(2026, 5, 11, tzinfo=UTC).isoformat()
|
||||
return {"run_id": "run-1", "status": status, "created_at": now, "updated_at": now, "error": None}
|
||||
|
||||
|
||||
class DisconnectingSyncStream(httpx.SyncByteStream):
|
||||
chunks: list[bytes]
|
||||
|
||||
def __init__(self, *chunks: str) -> None:
|
||||
self.chunks = [chunk.encode() for chunk in chunks]
|
||||
|
||||
@override
|
||||
def __iter__(self) -> Iterator[bytes]:
|
||||
yield from self.chunks
|
||||
raise httpx.ReadError("stream disconnected")
|
||||
|
||||
|
||||
def test_sync_methods_parse_protocol_dtos_and_validate_create_dict() -> None:
|
||||
def handler(request: httpx.Request) -> httpx.Response:
|
||||
if request.method == "POST" and request.url.path == "/runs":
|
||||
payload = cast(dict[str, object], json.loads(request.content))
|
||||
compositor = cast(dict[str, object], payload["compositor"])
|
||||
layers = cast(list[dict[str, object]], compositor["layers"])
|
||||
assert layers[0]["config"] == {"user": "hello"}
|
||||
assert payload["agent_profile"] == {"provider": "test", "output_text": "done"}
|
||||
return httpx.Response(202, json={"run_id": "run-1", "status": "running"})
|
||||
if request.method == "GET" and request.url.path == "/runs/run-1":
|
||||
return httpx.Response(200, json=_run_status_json("running"))
|
||||
if request.method == "GET" and request.url.path == "/runs/run-1/events":
|
||||
assert request.url.params["after"] == "0-0"
|
||||
assert request.url.params["limit"] == "10"
|
||||
event = RunStartedEvent(id="1-0", run_id="run-1")
|
||||
return httpx.Response(
|
||||
200,
|
||||
json={
|
||||
"run_id": "run-1",
|
||||
"events": [cast(object, json.loads(RUN_EVENT_ADAPTER.dump_json(event)))],
|
||||
"next_cursor": "1-0",
|
||||
},
|
||||
)
|
||||
raise AssertionError(f"unexpected request: {request.method} {request.url}")
|
||||
|
||||
http_client = httpx.Client(transport=httpx.MockTransport(handler))
|
||||
client = Client(base_url="http://testserver", sync_http_client=http_client)
|
||||
|
||||
created = client.create_run_sync(_create_run_payload())
|
||||
status = client.get_run_sync(created.run_id)
|
||||
events = client.get_events_sync(created.run_id, after="0-0", limit=10)
|
||||
|
||||
assert created.status == "running"
|
||||
assert status.status == "running"
|
||||
assert isinstance(events, RunEventsResponse)
|
||||
assert [event.type for event in events.events] == ["run_started"]
|
||||
|
||||
|
||||
def test_async_methods_and_wait_run_parse_protocol_dtos() -> None:
|
||||
statuses = iter(["running", "succeeded"])
|
||||
|
||||
def handler(request: httpx.Request) -> httpx.Response:
|
||||
if request.method == "POST" and request.url.path == "/runs":
|
||||
return httpx.Response(202, json={"run_id": "run-1", "status": "running"})
|
||||
if request.method == "GET" and request.url.path == "/runs/run-1":
|
||||
return httpx.Response(200, json=_run_status_json(next(statuses)))
|
||||
if request.method == "GET" and request.url.path == "/runs/run-1/events":
|
||||
return httpx.Response(200, json={"run_id": "run-1", "events": [], "next_cursor": "0-0"})
|
||||
raise AssertionError(f"unexpected request: {request.method} {request.url}")
|
||||
|
||||
async def scenario() -> None:
|
||||
http_client = httpx.AsyncClient(transport=httpx.MockTransport(handler))
|
||||
client = Client(base_url="http://testserver", async_http_client=http_client)
|
||||
request = CreateRunRequest.model_validate(_create_run_payload())
|
||||
|
||||
created = await client.create_run(request)
|
||||
events = await client.get_events(created.run_id)
|
||||
terminal = await client.wait_run(created.run_id, poll_interval_seconds=0)
|
||||
|
||||
assert created.run_id == "run-1"
|
||||
assert events.events == []
|
||||
assert terminal.status == "succeeded"
|
||||
await http_client.aclose()
|
||||
|
||||
asyncio.run(scenario())
|
||||
|
||||
|
||||
def test_error_mapping_and_create_run_input_validation() -> None:
|
||||
responses = iter(
|
||||
[
|
||||
httpx.Response(404, json={"detail": "run not found"}),
|
||||
httpx.Response(422, json={"detail": "invalid"}),
|
||||
httpx.Response(500, json={"detail": "boom"}),
|
||||
]
|
||||
)
|
||||
|
||||
def handler(_request: httpx.Request) -> httpx.Response:
|
||||
return next(responses)
|
||||
|
||||
client = Client(
|
||||
base_url="http://testserver",
|
||||
sync_http_client=httpx.Client(transport=httpx.MockTransport(handler)),
|
||||
)
|
||||
|
||||
with pytest.raises(DifyAgentNotFoundError) as not_found:
|
||||
_ = client.get_run_sync("missing")
|
||||
assert not_found.value.status_code == 404
|
||||
assert not_found.value.detail == "run not found"
|
||||
|
||||
with pytest.raises(DifyAgentValidationError) as validation:
|
||||
_ = client.get_run_sync("bad")
|
||||
assert validation.value.status_code == 422
|
||||
|
||||
with pytest.raises(DifyAgentHTTPError) as server_error:
|
||||
_ = client.get_run_sync("bad")
|
||||
assert server_error.value.status_code == 500
|
||||
|
||||
with pytest.raises(DifyAgentValidationError):
|
||||
_ = client.create_run_sync({"unknown": "field"})
|
||||
|
||||
|
||||
def test_http_timeout_maps_to_client_timeout_error() -> None:
|
||||
def handler(request: httpx.Request) -> httpx.Response:
|
||||
raise httpx.ReadTimeout("slow", request=request)
|
||||
|
||||
client = Client(
|
||||
base_url="http://testserver",
|
||||
sync_http_client=httpx.Client(transport=httpx.MockTransport(handler)),
|
||||
)
|
||||
|
||||
with pytest.raises(DifyAgentTimeoutError):
|
||||
_ = client.get_run_sync("run-1")
|
||||
|
||||
|
||||
def test_create_run_is_not_retried_after_timeout() -> None:
|
||||
attempts = 0
|
||||
|
||||
def handler(request: httpx.Request) -> httpx.Response:
|
||||
nonlocal attempts
|
||||
attempts += 1
|
||||
raise httpx.ConnectTimeout("cannot connect", request=request)
|
||||
|
||||
client = Client(
|
||||
base_url="http://testserver",
|
||||
sync_http_client=httpx.Client(transport=httpx.MockTransport(handler)),
|
||||
)
|
||||
|
||||
with pytest.raises(DifyAgentTimeoutError):
|
||||
_ = client.create_run_sync(_create_run_payload())
|
||||
assert attempts == 1
|
||||
|
||||
|
||||
def test_sync_sse_parser_supports_comments_multiline_data_and_id_fill() -> None:
|
||||
payload = RUN_EVENT_ADAPTER.dump_json(RunStartedEvent(run_id="run-1"), exclude={"id"}).decode()
|
||||
before_type, after_type = payload.split('"type"', maxsplit=1)
|
||||
body = f": keepalive\nid: 5-0\nevent: run_started\ndata: {before_type}\ndata: \"type\"{after_type}\n\n"
|
||||
|
||||
def handler(request: httpx.Request) -> httpx.Response:
|
||||
assert request.url.params["after"] == "0-0"
|
||||
return httpx.Response(200, content=body)
|
||||
|
||||
client = Client(
|
||||
base_url="http://testserver",
|
||||
sync_http_client=httpx.Client(transport=httpx.MockTransport(handler)),
|
||||
)
|
||||
|
||||
events = list(client.stream_events_sync("run-1", until_terminal=False, reconnect=False))
|
||||
|
||||
assert [event.id for event in events] == ["5-0"]
|
||||
assert [event.type for event in events] == ["run_started"]
|
||||
|
||||
|
||||
def test_stream_events_stops_after_terminal_event() -> None:
|
||||
calls = 0
|
||||
body = "".join(
|
||||
[
|
||||
_event_frame(RunStartedEvent(id="1-0", run_id="run-1")),
|
||||
_event_frame(RunSucceededEvent(id="2-0", run_id="run-1", data=EmptyRunEventData())),
|
||||
]
|
||||
)
|
||||
|
||||
def handler(_request: httpx.Request) -> httpx.Response:
|
||||
nonlocal calls
|
||||
calls += 1
|
||||
return httpx.Response(200, content=body)
|
||||
|
||||
client = Client(
|
||||
base_url="http://testserver",
|
||||
sync_http_client=httpx.Client(transport=httpx.MockTransport(handler)),
|
||||
)
|
||||
|
||||
events = list(client.stream_events_sync("run-1", reconnect_delay_seconds=0))
|
||||
|
||||
assert [event.type for event in events] == ["run_started", "run_succeeded"]
|
||||
assert calls == 1
|
||||
|
||||
|
||||
def test_stream_events_reconnects_from_latest_event_id() -> None:
|
||||
seen_after: list[str] = []
|
||||
|
||||
def handler(request: httpx.Request) -> httpx.Response:
|
||||
seen_after.append(request.url.params["after"])
|
||||
if len(seen_after) == 1:
|
||||
return httpx.Response(
|
||||
200,
|
||||
stream=DisconnectingSyncStream(_event_frame(RunStartedEvent(id="1-0", run_id="run-1"))),
|
||||
)
|
||||
return httpx.Response(200, content=_event_frame(RunSucceededEvent(id="2-0", run_id="run-1")))
|
||||
|
||||
client = Client(
|
||||
base_url="http://testserver",
|
||||
sync_http_client=httpx.Client(transport=httpx.MockTransport(handler)),
|
||||
)
|
||||
|
||||
events = list(client.stream_events_sync("run-1", reconnect_delay_seconds=0))
|
||||
|
||||
assert seen_after == ["0-0", "1-0"]
|
||||
assert [event.type for event in events] == ["run_started", "run_succeeded"]
|
||||
|
||||
|
||||
def test_stream_events_reconnects_after_http_5xx_response() -> None:
|
||||
seen_after: list[str] = []
|
||||
|
||||
def handler(request: httpx.Request) -> httpx.Response:
|
||||
seen_after.append(request.url.params["after"])
|
||||
if len(seen_after) == 1:
|
||||
return httpx.Response(503, json={"detail": "temporarily unavailable"})
|
||||
return httpx.Response(200, content=_event_frame(RunSucceededEvent(id="2-0", run_id="run-1")))
|
||||
|
||||
client = Client(
|
||||
base_url="http://testserver",
|
||||
sync_http_client=httpx.Client(transport=httpx.MockTransport(handler)),
|
||||
)
|
||||
|
||||
events = list(client.stream_events_sync("run-1", reconnect_delay_seconds=0))
|
||||
|
||||
assert seen_after == ["0-0", "0-0"]
|
||||
assert [event.type for event in events] == ["run_succeeded"]
|
||||
|
||||
|
||||
def test_stream_events_raises_when_reconnects_are_exhausted() -> None:
|
||||
calls = 0
|
||||
|
||||
def handler(_request: httpx.Request) -> httpx.Response:
|
||||
nonlocal calls
|
||||
calls += 1
|
||||
return httpx.Response(200, stream=DisconnectingSyncStream())
|
||||
|
||||
client = Client(
|
||||
base_url="http://testserver",
|
||||
sync_http_client=httpx.Client(transport=httpx.MockTransport(handler)),
|
||||
)
|
||||
|
||||
with pytest.raises(DifyAgentStreamError):
|
||||
_ = list(client.stream_events_sync("run-1", max_reconnects=1, reconnect_delay_seconds=0))
|
||||
assert calls == 2
|
||||
|
||||
|
||||
def test_malformed_sse_frame_does_not_reconnect() -> None:
|
||||
calls = 0
|
||||
|
||||
def handler(_request: httpx.Request) -> httpx.Response:
|
||||
nonlocal calls
|
||||
calls += 1
|
||||
return httpx.Response(200, content="data: not-json\n\n")
|
||||
|
||||
client = Client(
|
||||
base_url="http://testserver",
|
||||
sync_http_client=httpx.Client(transport=httpx.MockTransport(handler)),
|
||||
)
|
||||
|
||||
with pytest.raises(DifyAgentStreamError):
|
||||
_ = list(client.stream_events_sync("run-1", reconnect_delay_seconds=0))
|
||||
assert calls == 1
|
||||
|
||||
|
||||
def test_async_stream_events_yields_terminal_event() -> None:
|
||||
body = _event_frame(RunSucceededEvent(id="2-0", run_id="run-1"))
|
||||
|
||||
def handler(_request: httpx.Request) -> httpx.Response:
|
||||
return httpx.Response(200, content=body)
|
||||
|
||||
async def scenario() -> None:
|
||||
http_client = httpx.AsyncClient(transport=httpx.MockTransport(handler))
|
||||
client = Client(base_url="http://testserver", async_http_client=http_client)
|
||||
|
||||
events = [event async for event in client.stream_events("run-1")]
|
||||
|
||||
assert [event.type for event in events] == ["run_succeeded"]
|
||||
await http_client.aclose()
|
||||
|
||||
asyncio.run(scenario())
|
||||
|
||||
|
||||
def test_async_stream_events_reconnects_after_http_5xx_response() -> None:
|
||||
seen_after: list[str] = []
|
||||
|
||||
def handler(request: httpx.Request) -> httpx.Response:
|
||||
seen_after.append(request.url.params["after"])
|
||||
if len(seen_after) == 1:
|
||||
return httpx.Response(502, json={"detail": "bad gateway"})
|
||||
return httpx.Response(200, content=_event_frame(RunSucceededEvent(id="2-0", run_id="run-1")))
|
||||
|
||||
async def scenario() -> None:
|
||||
http_client = httpx.AsyncClient(transport=httpx.MockTransport(handler))
|
||||
client = Client(base_url="http://testserver", async_http_client=http_client)
|
||||
|
||||
events = [event async for event in client.stream_events("run-1", reconnect_delay_seconds=0)]
|
||||
|
||||
assert seen_after == ["0-0", "0-0"]
|
||||
assert [event.type for event in events] == ["run_succeeded"]
|
||||
await http_client.aclose()
|
||||
|
||||
asyncio.run(scenario())
|
||||
|
||||
|
||||
def test_stream_timeout_can_reconnect_until_terminal() -> None:
|
||||
calls = 0
|
||||
|
||||
def handler(request: httpx.Request) -> httpx.Response:
|
||||
nonlocal calls
|
||||
calls += 1
|
||||
if calls == 1:
|
||||
raise httpx.ReadTimeout("stream stalled", request=request)
|
||||
return httpx.Response(200, content=_event_frame(RunSucceededEvent(id="2-0", run_id="run-1")))
|
||||
|
||||
client = Client(
|
||||
base_url="http://testserver",
|
||||
sync_http_client=httpx.Client(transport=httpx.MockTransport(handler)),
|
||||
)
|
||||
|
||||
events = list(client.stream_events_sync("run-1", reconnect_delay_seconds=0))
|
||||
|
||||
assert calls == 2
|
||||
assert [event.type for event in events] == ["run_succeeded"]
|
||||
@ -0,0 +1,40 @@
|
||||
from pydantic_ai.messages import FinalResultEvent
|
||||
|
||||
from dify_agent.protocol.schemas import (
|
||||
RUN_EVENT_ADAPTER,
|
||||
AgentOutputRunEvent,
|
||||
AgentOutputRunEventData,
|
||||
PydanticAIStreamRunEvent,
|
||||
RunFailedEvent,
|
||||
RunFailedEventData,
|
||||
RunStartedEvent,
|
||||
)
|
||||
|
||||
|
||||
def test_run_event_adapter_round_trips_typed_variants() -> None:
|
||||
events = [
|
||||
RunStartedEvent(run_id="run-1"),
|
||||
PydanticAIStreamRunEvent(run_id="run-1", data=FinalResultEvent(tool_name=None, tool_call_id=None)),
|
||||
AgentOutputRunEvent(run_id="run-1", data=AgentOutputRunEventData(output="done")),
|
||||
RunFailedEvent(run_id="run-1", data=RunFailedEventData(error="boom", reason="shutdown")),
|
||||
]
|
||||
|
||||
for event in events:
|
||||
payload = RUN_EVENT_ADAPTER.dump_json(event)
|
||||
decoded = RUN_EVENT_ADAPTER.validate_json(payload)
|
||||
|
||||
assert decoded.type == event.type
|
||||
assert decoded.run_id == event.run_id
|
||||
|
||||
|
||||
def test_pydantic_ai_event_data_uses_agent_stream_event_model() -> None:
|
||||
event = RUN_EVENT_ADAPTER.validate_python(
|
||||
{
|
||||
"run_id": "run-1",
|
||||
"type": "pydantic_ai_event",
|
||||
"data": {"event_kind": "final_result", "tool_name": None, "tool_call_id": None},
|
||||
}
|
||||
)
|
||||
|
||||
assert isinstance(event, PydanticAIStreamRunEvent)
|
||||
assert isinstance(event.data, FinalResultEvent)
|
||||
@ -6,8 +6,9 @@ import pytest
|
||||
from pydantic import JsonValue
|
||||
|
||||
from agenton.compositor import CompositorConfig, LayerNodeConfig
|
||||
from dify_agent.protocol.schemas import CreateRunRequest, RunEvent, RunStatus
|
||||
from dify_agent.runtime.run_scheduler import RunScheduler, SchedulerStoppingError
|
||||
from dify_agent.server.schemas import CreateRunRequest, RunEvent, RunRecord, RunStatus
|
||||
from dify_agent.server.schemas import RunRecord
|
||||
|
||||
|
||||
def _request(user: str | list[str] = "hello") -> CreateRunRequest:
|
||||
|
||||
@ -3,9 +3,9 @@ import asyncio
|
||||
import pytest
|
||||
|
||||
from agenton.compositor import CompositorConfig, LayerNodeConfig
|
||||
from dify_agent.protocol.schemas import AgentProfileConfig, CreateRunRequest
|
||||
from dify_agent.runtime.event_sink import InMemoryRunEventSink
|
||||
from dify_agent.runtime.runner import AgentRunRunner, AgentRunValidationError
|
||||
from dify_agent.server.schemas import AgentProfileConfig, CreateRunRequest
|
||||
|
||||
|
||||
def test_runner_emits_terminal_success_and_snapshot() -> None:
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from dify_agent.protocol.schemas import CreateRunRequest
|
||||
from dify_agent.runtime.run_scheduler import SchedulerStoppingError
|
||||
from dify_agent.server.routes.runs import create_runs_router
|
||||
from dify_agent.server.schemas import RunRecord
|
||||
@ -122,7 +123,6 @@ def test_create_run_does_not_map_infrastructure_failure_to_422() -> None:
|
||||
|
||||
def _request():
|
||||
from agenton.compositor import CompositorConfig, LayerNodeConfig
|
||||
from dify_agent.server.schemas import CreateRunRequest
|
||||
|
||||
return CreateRunRequest(
|
||||
compositor=CompositorConfig(
|
||||
|
||||
@ -1,40 +1,12 @@
|
||||
from pydantic_ai.messages import FinalResultEvent
|
||||
|
||||
from dify_agent.server.schemas import (
|
||||
RUN_EVENT_ADAPTER,
|
||||
AgentOutputRunEvent,
|
||||
AgentOutputRunEventData,
|
||||
PydanticAIStreamRunEvent,
|
||||
RunFailedEvent,
|
||||
RunFailedEventData,
|
||||
RunStartedEvent,
|
||||
)
|
||||
import dify_agent.server.schemas as server_schemas
|
||||
|
||||
|
||||
def test_run_event_adapter_round_trips_typed_variants() -> None:
|
||||
events = [
|
||||
RunStartedEvent(run_id="run-1"),
|
||||
PydanticAIStreamRunEvent(run_id="run-1", data=FinalResultEvent(tool_name=None, tool_call_id=None)),
|
||||
AgentOutputRunEvent(run_id="run-1", data=AgentOutputRunEventData(output="done")),
|
||||
RunFailedEvent(run_id="run-1", data=RunFailedEventData(error="boom", reason="shutdown")),
|
||||
]
|
||||
|
||||
for event in events:
|
||||
payload = RUN_EVENT_ADAPTER.dump_json(event)
|
||||
decoded = RUN_EVENT_ADAPTER.validate_json(payload)
|
||||
|
||||
assert decoded.type == event.type
|
||||
assert decoded.run_id == event.run_id
|
||||
def test_server_schemas_do_not_reexport_public_protocol_dtos() -> None:
|
||||
assert server_schemas.__all__ == ["RunRecord", "new_run_id"]
|
||||
assert not hasattr(server_schemas, "CreateRunRequest")
|
||||
assert not hasattr(server_schemas, "RunStartedEvent")
|
||||
|
||||
|
||||
def test_pydantic_ai_event_data_uses_agent_stream_event_model() -> None:
|
||||
event = RUN_EVENT_ADAPTER.validate_python(
|
||||
{
|
||||
"run_id": "run-1",
|
||||
"type": "pydantic_ai_event",
|
||||
"data": {"event_kind": "final_result", "tool_name": None, "tool_call_id": None},
|
||||
}
|
||||
)
|
||||
|
||||
assert isinstance(event, PydanticAIStreamRunEvent)
|
||||
assert isinstance(event.data, FinalResultEvent)
|
||||
def test_server_schemas_keep_server_only_run_helpers() -> None:
|
||||
assert isinstance(server_schemas.new_run_id(), str)
|
||||
assert hasattr(server_schemas, "RunRecord")
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from dify_agent.server.schemas import RunStartedEvent
|
||||
from dify_agent.protocol.schemas import RunStartedEvent
|
||||
from dify_agent.server.sse import format_sse_event
|
||||
|
||||
|
||||
|
||||
@ -2,7 +2,7 @@ import asyncio
|
||||
from collections.abc import Mapping
|
||||
|
||||
from agenton.compositor import CompositorConfig, LayerNodeConfig
|
||||
from dify_agent.server.schemas import CreateRunRequest, RunStartedEvent
|
||||
from dify_agent.protocol.schemas import CreateRunRequest, RunStartedEvent
|
||||
from dify_agent.storage.redis_run_store import DEFAULT_RUN_RETENTION_SECONDS, RedisRunStore
|
||||
|
||||
|
||||
|
||||
@ -8,5 +8,6 @@ def test_dify_agent_examples_are_importable() -> None:
|
||||
"dify_agent_examples.run_pydantic_ai_agent",
|
||||
"dify_agent_examples.run_server_consumer",
|
||||
"dify_agent_examples.run_server_sse_consumer",
|
||||
"dify_agent_examples.run_server_sync_client",
|
||||
]:
|
||||
importlib.import_module(module_name)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user