add schema-backed agenton sessions

This commit is contained in:
盐粒 Yanli 2026-05-07 22:42:07 +08:00
parent 31a1de4828
commit f316d19be6
15 changed files with 1147 additions and 246 deletions

View File

@ -0,0 +1,3 @@
# Dify Agent
Agenton documentation lives in [`docs/agenton/`](docs/agenton/).

View File

@ -0,0 +1,65 @@
# Agenton configuration and sessions
Agenton composes shared `Layer` instances into a named graph. Treat layer
instances as reusable capability definitions: config and dependency declarations
belong on the layer class or instance, while per-session runtime values belong
on the `LayerControl` created for that layer in a `CompositorSession`.
## Config, runtime state, and runtime handles
- **Config** is serializable graph input. Config-constructible layers declare a
`type_id` and a Pydantic `config_type`; builders validate node config before
calling `Layer.from_config(validated_config)`.
- **Runtime state** is serializable per-layer/per-session state. Layers declare a
Pydantic `runtime_state_type`; session snapshots persist this model with
`model_dump(mode="json")`.
- **Runtime handles** are live Python objects such as clients, open files, or
process handles. Layers declare a Pydantic `runtime_handles_type` with
`arbitrary_types_allowed=True`. Handles are never serialized; resume hooks
should rehydrate them from runtime state.
`Layer.__init_subclass__` infers `deps_type`, `config_type`,
`runtime_state_type`, and `runtime_handles_type` from generic base arguments
when possible. For example, `PlainLayer[NoLayerDeps, MyConfig, MyState,
MyHandles]` automatically installs those Pydantic schemas. Omitted schema slots
default to `EmptyLayerConfig`, `EmptyRuntimeState`, and `EmptyRuntimeHandles`.
Lifecycle hooks can annotate controls as `LayerControl[MyState, MyHandles]` to
get static checking and IDE completion for runtime state and handles.
## Registry and builder
Register config-constructible layers manually:
```python
registry = LayerRegistry()
registry.register_layer(PromptLayer) # uses PromptLayer.type_id == "plain.prompt"
```
Use `CompositorBuilder` to mix serializable config nodes with live instances:
```python
compositor = (
CompositorBuilder(registry)
.add_config({"layers": [{"name": "prompt", "type": "plain.prompt", "config": {"prefix": "Hi"}}]})
.add_instance(name="profile", layer=ObjectLayer(profile))
.build()
)
```
Use `.add_instance()` for layers that require Python objects or callables, such
as `ObjectLayer`, `ToolsLayer`, and dynamic tool layers.
## Session snapshot and restore
`Compositor.snapshot_session(session)` serializes non-active sessions, including
layer lifecycle state and runtime state. It rejects active sessions because live
handles cannot be snapshotted safely. Restore with
`Compositor.session_from_snapshot(snapshot)`; restored controls validate runtime
state with each layer schema and initialize empty runtime handles. Suspended
sessions resume through `on_context_resume`, where handles should be hydrated
from the restored runtime state.
Create sessions with `Compositor.new_session()` or
`Compositor.session_from_snapshot()`. `Compositor.enter()` validates that every
session control uses the target layer's runtime state and handle schemas before
any lifecycle hook runs.

View File

@ -8,10 +8,9 @@ from inspect import signature
from typing_extensions import override
from agenton.compositor import Compositor, CompositorLayerConfig
from agenton.compositor import CompositorBuilder, LayerRegistry
from agenton.layers import LayerControl, LayerDeps, NoLayerDeps, PlainLayer
from agenton.layers.types import PlainPromptType, PlainToolType
from agenton_collections.layers.plain import DynamicToolsLayer, ObjectLayer, ToolsLayer, with_object
from agenton_collections.layers.plain import DynamicToolsLayer, ObjectLayer, PromptLayer, ToolsLayer, with_object
@dataclass(frozen=True, slots=True)
@ -75,51 +74,41 @@ async def main() -> None:
)
trace = TraceLayer()
compositor = Compositor[PlainPromptType, PlainToolType].from_config(
{
"layers": [
{
"name": "base_prompt",
"layer": {
"import_path": "agenton_collections.layers.plain:PromptLayer",
registry = LayerRegistry()
registry.register_layer(PromptLayer)
compositor = (
CompositorBuilder(registry)
.add_config(
{
"layers": [
{
"name": "base_prompt",
"type": "plain.prompt",
"config": {
"prefix": "Use config dicts for serializable layers.",
"suffix": "Before finalizing, make the result easy to scan.",
},
},
},
{
"name": "extra_prompt",
"layer": {
"import_path": "agenton_collections.layers.plain:PromptLayer",
{
"name": "extra_prompt",
"type": "plain.prompt",
"config": {
"prefix": "Use constructed instances for objects, local code, and callables.",
},
},
},
CompositorLayerConfig(
name="profile",
layer=ObjectLayer[AgentProfile](profile),
),
CompositorLayerConfig(
name="profile_prompt",
# deps maps dependency field names to layer names only when
# they differ.
# deps={"profile": "profile"},
layer=ProfilePromptLayer(),
),
CompositorLayerConfig(
name="tools",
layer=ToolsLayer(tool_entries=(count_words,)),
),
CompositorLayerConfig(
name="dynamic_tools",
deps={"object_layer": "profile"},
layer=DynamicToolsLayer[AgentProfile](tool_entries=(write_tagline,)),
),
CompositorLayerConfig(name="trace", layer=trace),
]
},
]
}
)
.add_instance(name="profile", layer=ObjectLayer[AgentProfile](profile))
.add_instance(name="profile_prompt", layer=ProfilePromptLayer())
.add_instance(name="tools", layer=ToolsLayer(tool_entries=(count_words,)))
.add_instance(
name="dynamic_tools",
deps={"object_layer": "profile"},
layer=DynamicToolsLayer[AgentProfile](tool_entries=(write_tagline,)),
)
.add_instance(name="trace", layer=trace)
.build()
)
print("Prompts:")

View File

@ -12,9 +12,8 @@ from pydantic_ai.messages import BuiltinToolCallPart, ModelMessage, ToolCallPart
from pydantic_ai.models.openai import OpenAIChatModel # pyright: ignore[reportDeprecated]
from pydantic_ai.models.test import TestModel
from agenton.compositor import Compositor, CompositorLayerConfig
from agenton.layers.types import AllPromptTypes, AllToolTypes, PydanticAIPrompt, PydanticAITool
from agenton_collections.layers.plain import ObjectLayer, ToolsLayer
from agenton.compositor import CompositorBuilder, LayerRegistry
from agenton_collections.layers.plain import ObjectLayer, PromptLayer, ToolsLayer
from agenton_collections.layers.pydantic_ai import PydanticAIBridgeLayer
from agenton_collections.transformers import PYDANTIC_AI_TRANSFORMERS
@ -55,40 +54,32 @@ async def main() -> None:
tool_entries=(write_tagline,),
)
compositor = Compositor[
PydanticAIPrompt[object],
PydanticAITool[object],
AllPromptTypes,
AllToolTypes,
].from_config(
{
"layers": [
{
"name": "base_prompt",
"layer": {
"import_path": "agenton_collections.layers.plain:PromptLayer",
registry = LayerRegistry()
registry.register_layer(PromptLayer)
compositor = (
CompositorBuilder(registry)
.add_config(
{
"layers": [
{
"name": "base_prompt",
"type": "plain.prompt",
"config": {
"prefix": "Use the available tools before answering.",
"suffix": "Return concise, inspectable output.",
},
},
},
CompositorLayerConfig(
name="profile",
layer=ObjectLayer[AgentProfile](profile),
),
CompositorLayerConfig(
name="plain_tools",
layer=ToolsLayer(tool_entries=(count_words,)),
),
CompositorLayerConfig(
name="pydantic_ai_bridge",
deps={"object_layer": "profile"},
layer=pydantic_ai_bridge,
),
]
},
**PYDANTIC_AI_TRANSFORMERS,
]
}
)
.add_instance(name="profile", layer=ObjectLayer[AgentProfile](profile))
.add_instance(name="plain_tools", layer=ToolsLayer(tool_entries=(count_words,)))
.add_instance(
name="pydantic_ai_bridge",
deps={"object_layer": "profile"},
layer=pydantic_ai_bridge,
)
.build(**PYDANTIC_AI_TRANSFORMERS)
)
async with compositor.enter():

View File

@ -0,0 +1,72 @@
"""Run with: uv run --project dify-agent python examples/agenton/session_snapshot.py."""
from __future__ import annotations
import asyncio
from collections import OrderedDict
from dataclasses import dataclass
from typing import ClassVar
from pydantic import BaseModel, ConfigDict
from typing_extensions import override
from agenton.compositor import Compositor
from agenton.layers import LayerControl, NoLayerDeps, PlainLayer, PlainPromptType, PlainToolType
class ConnectionState(BaseModel):
    """Serializable per-session connection state; persisted in snapshots."""

    model_config = ConfigDict(extra="forbid", validate_assignment=True)

    # Identifier that resume hooks use to rehydrate a live handle.
    connection_id: str = "demo-connection"
class ConnectionHandle:
    """Live stand-in for an open connection; never serialized into snapshots."""

    connection_id: str

    def __init__(self, connection_id: str) -> None:
        # The only state is the id the handle was opened with.
        self.connection_id = connection_id
class ConnectionHandles(BaseModel):
    """Runtime-handle schema holding one optional live connection object."""

    model_config = ConfigDict(
        extra="forbid",
        validate_assignment=True,
        # ConnectionHandle is a plain Python object, not a Pydantic type.
        arbitrary_types_allowed=True,
    )

    connection: ConnectionHandle | None = None
@dataclass(slots=True)
class ConnectionLayer(PlainLayer[NoLayerDeps]):
    """Demo layer that opens a fake connection handle for each session."""

    runtime_state_type: ClassVar[type[BaseModel]] = ConnectionState
    runtime_handles_type: ClassVar[type[BaseModel]] = ConnectionHandles

    @override
    async def on_context_create(self, control: LayerControl) -> None:
        state, handles = control.runtime_state, control.runtime_handles
        assert isinstance(state, ConnectionState)
        assert isinstance(handles, ConnectionHandles)
        # Fresh session: open the handle directly from the configured id.
        handles.connection = ConnectionHandle(state.connection_id)

    @override
    async def on_context_resume(self, control: LayerControl) -> None:
        state, handles = control.runtime_state, control.runtime_handles
        assert isinstance(state, ConnectionState)
        assert isinstance(handles, ConnectionHandles)
        # Restored session: rebuild the live handle from serialized state.
        handles.connection = ConnectionHandle(f"restored:{state.connection_id}")
async def main() -> None:
    """Suspend a session, snapshot it, then restore and rehydrate handles."""
    compositor: Compositor[PlainPromptType, PlainToolType] = Compositor(
        layers=OrderedDict([("connection", ConnectionLayer())])
    )

    # First entry ends in suspend so the session stays snapshot-eligible.
    session = compositor.new_session()
    async with compositor.enter(session) as active_session:
        active_session.suspend_on_exit()

    # Snapshots carry lifecycle and runtime state only; handles are dropped.
    snapshot = compositor.snapshot_session(session)
    print("Snapshot:", snapshot.model_dump(mode="json"))

    # Resume hooks rebuild the live handle from the restored runtime state.
    restored = compositor.session_from_snapshot(snapshot)
    async with compositor.enter(restored):
        rehydrated = restored.layer("connection").runtime_handles
        assert isinstance(rehydrated, ConnectionHandles)
        assert rehydrated.connection is not None
        print("Rehydrated handle:", rehydrated.connection.connection_id)


if __name__ == "__main__":
    asyncio.run(main())

View File

@ -16,6 +16,10 @@ layer names as values. Prompt aggregation depends on insertion order: prefix
prompts are collected from first to last layer, while suffix prompts are
collected in reverse.
Serializable graph config uses registry type ids rather than import paths.
``CompositorBuilder`` resolves config nodes through ``LayerRegistry`` and can
mix those nodes with live layer instances for Python objects and callables.
``Compositor.enter`` enters layers in compositor order and exits them in reverse
order through ``AsyncExitStack``. It accepts an optional ``CompositorSession``
whose layer controls must match the compositor layer names and order. When
@ -30,13 +34,12 @@ returns those wrapped items unchanged.
"""
from collections import OrderedDict
from collections.abc import AsyncIterator, Callable, Iterable, Sequence
from collections.abc import AsyncIterator, Callable, Iterable, Mapping as MappingABC, Sequence
from contextlib import AsyncExitStack, asynccontextmanager
from dataclasses import dataclass, field
from importlib import import_module
from typing import TYPE_CHECKING, Annotated, Any, Generic, Mapping, TypedDict, cast
from typing import Any, Generic, Mapping, TypedDict, cast
from pydantic import AfterValidator, BaseModel, ConfigDict, Field, JsonValue
from pydantic import BaseModel, ConfigDict, Field, JsonValue
from typing_extensions import Self, TypeVar
from agenton.layers.base import Layer, LayerControl, LifecycleState
@ -48,31 +51,6 @@ LayerPromptT = TypeVar("LayerPromptT", default=AllPromptTypes)
LayerToolT = TypeVar("LayerToolT", default=AllToolTypes)
class ImportedLayerConfig(BaseModel):
"""Config for constructing one layer from an import path."""
import_path: str
config: Any = None
model_config = ConfigDict(arbitrary_types_allowed=True)
def create_layer(self) -> Layer[Any, Any, Any]:
"""Import the target layer class and create it from config."""
try:
import_module_name, import_target = self.import_path.rsplit(":", 1)
except ValueError as e:
raise ValueError(
f"Invalid import string '{self.import_path}'. "
"It should be in the format 'module:ClassName'."
) from e
layer_t = getattr(import_module(import_module_name), import_target)
if not isinstance(layer_t, type) or not issubclass(layer_t, Layer):
raise TypeError(f"Imported target '{self.import_path}' must be a Layer subclass.")
return layer_t.from_config(config=self.config)
LayerSpec = Layer[Any, Any, Any] | ImportedLayerConfig
type CompositorTransformer[InputT, OutputT] = Callable[[Sequence[InputT]], Sequence[OutputT]]
@ -98,53 +76,30 @@ def _validate_config_model_input[ModelT: BaseModel](
return model_type.model_validate(value)
class CompositorLayerConfig(BaseModel):
"""Config entry for one named layer in a compositor.
``layer`` may be either an already constructed layer instance or an
``ImportedLayerConfig``. Direct instances are already initialized, so config
for imported layers lives inside ``ImportedLayerConfig`` instead of beside
the graph node fields.
"""
class LayerNodeConfig(BaseModel):
"""Serializable config for one registry-backed layer node."""
name: str
type: str
config: JsonValue = Field(default_factory=dict)
deps: Mapping[str, str] = Field(default_factory=dict)
layer: LayerSpec
metadata: Mapping[str, JsonValue] = Field(default_factory=dict)
model_config = ConfigDict(arbitrary_types_allowed=True)
def create_layer(self) -> Layer[Any, Any, Any]:
"""Create or return the configured layer instance."""
if isinstance(self.layer, Layer):
return self.layer
return self.layer.create_layer()
type CompositorLayerConfigValue = _ConfigModelValue[CompositorLayerConfig]
def _validate_layer_config_input(value: CompositorLayerConfigValue) -> CompositorLayerConfig:
return _validate_config_model_input(CompositorLayerConfig, value)
type CompositorLayerConfigInput = Annotated[
CompositorLayerConfigValue,
AfterValidator(_validate_layer_config_input),
]
model_config = ConfigDict(extra="forbid")
class CompositorConfig(BaseModel):
"""Serializable config for constructing a compositor graph.
``layers`` accepts ready-made ``CompositorLayerConfig`` instances, raw JSON
values, or JSON-encoded strings/bytes. After validation, callers always see
normalized ``CompositorLayerConfig`` objects.
The graph references layer implementations by registry type id. Live Python
objects and callables are intentionally excluded; compose those with
``CompositorBuilder.add_instance``.
"""
if TYPE_CHECKING:
layers: list[CompositorLayerConfig]
else:
layers: list[CompositorLayerConfigInput]
schema_version: int = 1
layers: list[LayerNodeConfig]
model_config = ConfigDict(extra="forbid")
type CompositorConfigValue = _ConfigModelValue[CompositorConfig] | Mapping[str, object]
@ -154,24 +109,88 @@ def _validate_compositor_config_input(value: CompositorConfigValue) -> Composito
return _validate_config_model_input(CompositorConfig, value)
@dataclass(frozen=True, slots=True)
class LayerDescriptor:
    """Registry descriptor inferred from a layer class."""

    # Stable registry key referenced by serializable config nodes.
    type_id: str
    # Concrete Layer subclass instantiated via ``from_config``.
    layer_type: type[Layer[Any, Any, Any, Any, Any, Any]]
    # Pydantic schema used to validate a config node's ``config`` payload.
    config_type: type[BaseModel]
    # Pydantic schema for serializable per-session runtime state.
    runtime_state_type: type[BaseModel]
    # Pydantic schema for live (never-serialized) runtime handles.
    runtime_handles_type: type[BaseModel]
class LayerRegistry:
    """Manual registry for config-constructible layer classes.

    Registration infers config and runtime schemas from layer class attributes.
    A registered layer must have a type id, either declared as ``type_id`` on the
    class or supplied to ``register_layer``.
    """

    __slots__ = ("_descriptors",)

    _descriptors: dict[str, LayerDescriptor]

    def __init__(self) -> None:
        self._descriptors = {}

    def register_layer(
        self,
        layer_type: type[Layer[Any, Any, Any, Any, Any, Any]],
        *,
        type_id: str | None = None,
    ) -> None:
        """Register ``layer_type`` under its inferred or explicit type id.

        Raises:
            TypeError: If the resolved type id is not a string.
            ValueError: If no non-empty type id is available, or the id is
                already registered.
        """
        # getattr keeps classes that never declared ``type_id`` on the
        # ValueError path below instead of raising AttributeError here.
        resolved_type_id = type_id or getattr(layer_type, "type_id", None)
        if resolved_type_id is not None and not isinstance(resolved_type_id, str):
            raise TypeError(f"Layer type id for '{layer_type.__qualname__}' must be a string.")
        if not resolved_type_id:
            # Covers both a missing declaration and an empty-string id.
            raise ValueError(f"Layer '{layer_type.__qualname__}' must declare a type_id or be registered with one.")
        if resolved_type_id in self._descriptors:
            raise ValueError(f"Layer type id '{resolved_type_id}' is already registered.")
        self._descriptors[resolved_type_id] = LayerDescriptor(
            type_id=resolved_type_id,
            layer_type=layer_type,
            config_type=layer_type.config_type,
            runtime_state_type=layer_type.runtime_state_type,
            runtime_handles_type=layer_type.runtime_handles_type,
        )

    def resolve(self, type_id: str) -> LayerDescriptor:
        """Return the descriptor for ``type_id`` or raise ``KeyError``."""
        try:
            return self._descriptors[type_id]
        except KeyError as e:
            raise KeyError(f"Layer type id '{type_id}' is not registered.") from e

    def descriptors(self) -> Mapping[str, LayerDescriptor]:
        """Return a detached copy of registered descriptors keyed by type id."""
        return dict(self._descriptors)
class CompositorSession:
"""External lifecycle session for layer contexts entered by a compositor.
A session owns one ``LayerControl`` per compositor layer name, preserving
compositor order. Broadcast methods are convenience APIs for setting every
layer's per-entry exit intent; ``layer`` allows explicit per-layer control
when callers need partial suspend/delete behavior. A mixed session with any
closed layer cannot be entered again because compositor entry is all-or-none.
compositor order. Controls must be created from the matching layer schemas;
prefer ``Compositor.new_session`` or ``Compositor.session_from_snapshot`` for
public session construction. Broadcast methods are convenience APIs for
setting every layer's per-entry exit intent; ``layer`` allows explicit
per-layer control when callers need partial suspend/delete behavior. A mixed
session with any closed layer cannot be entered again because compositor
entry is all-or-none.
"""
__slots__ = ("layer_controls",)
layer_controls: OrderedDict[str, LayerControl]
def __init__(self, layer_names: Iterable[str]) -> None:
self.layer_controls = OrderedDict(
(layer_name, LayerControl()) for layer_name in layer_names
)
def __init__(self, layer_names: Iterable[str] | Mapping[str, LayerControl]) -> None:
if isinstance(layer_names, MappingABC):
self.layer_controls = OrderedDict(layer_names.items())
return
self.layer_controls = OrderedDict((layer_name, LayerControl()) for layer_name in layer_names)
def suspend_on_exit(self) -> None:
"""Request suspend behavior for every layer when this entry exits."""
@ -188,6 +207,124 @@ class CompositorSession:
return self.layer_controls[name]
class LayerSessionSnapshot(BaseModel):
    """Serializable snapshot for one layer control."""

    # Layer name; must match the compositor graph on restore.
    name: str
    # Lifecycle state at snapshot time (active controls are rejected upstream).
    state: LifecycleState
    # JSON-mode dump of the layer's runtime-state model.
    runtime_state: dict[str, JsonValue]

    model_config = ConfigDict(extra="forbid")
class CompositorSessionSnapshot(BaseModel):
    """Serializable compositor session snapshot.

    Snapshots include runtime state only. Live runtime handles are intentionally
    excluded and must be rehydrated by resume hooks using runtime state.
    """

    # Wire-format version; restore rejects anything other than the current value.
    schema_version: int = 1
    # One entry per compositor layer, preserving compositor order.
    layers: list[LayerSessionSnapshot]

    model_config = ConfigDict(extra="forbid")
@dataclass(frozen=True, slots=True)
class _LayerBuildEntry:
    """Pending builder entry: one named layer plus its dependency wiring."""

    # Graph-unique layer name; duplicates are rejected at build time.
    name: str
    # Constructed layer instance (from validated config or supplied directly).
    layer: Layer[Any, Any, Any, Any, Any, Any]
    # Maps the layer's dependency field names to other layer names.
    deps: Mapping[str, str]
class CompositorBuilder:
"""Build compositors from registry config nodes and live instances."""
__slots__ = ("_registry", "_entries")
_registry: LayerRegistry
_entries: list[_LayerBuildEntry]
def __init__(self, registry: LayerRegistry) -> None:
self._registry = registry
self._entries = []
def add_config(self, config: CompositorConfigValue) -> Self:
"""Add all layers from a serializable compositor config."""
conf = _validate_compositor_config_input(config)
if conf.schema_version != 1:
raise ValueError(f"Unsupported compositor config schema_version: {conf.schema_version}.")
for layer_conf in conf.layers:
self.add_config_layer(
name=layer_conf.name,
type=layer_conf.type,
config=layer_conf.config,
deps=layer_conf.deps,
)
return self
def add_config_layer(
self,
*,
name: str,
type: str,
config: object | None = None,
deps: Mapping[str, str] | None = None,
) -> Self:
"""Resolve, validate, and add one registry-backed layer config node."""
descriptor = self._registry.resolve(type)
raw_config = {} if config is None else config
validated_config = descriptor.config_type.model_validate(raw_config)
layer = descriptor.layer_type.from_config(cast(Any, validated_config))
self.add_instance(name=name, layer=layer, deps=deps)
return self
def add_instance(
self,
*,
name: str,
layer: Layer[Any, Any, Any, Any, Any, Any],
deps: Mapping[str, str] | None = None,
) -> Self:
"""Add a live layer instance, useful for Python objects and callables."""
self._entries.append(_LayerBuildEntry(name=name, layer=layer, deps=dict(deps or {})))
return self
def build[PromptT, ToolT, LayerPromptT, LayerToolT](
self,
*,
prompt_transformer: CompositorTransformer[LayerPromptT, PromptT] | None = None,
tool_transformer: CompositorTransformer[LayerToolT, ToolT] | None = None,
) -> "Compositor[PromptT, ToolT, LayerPromptT, LayerToolT]":
"""Validate names/dependencies, bind deps, and return a compositor."""
layers: OrderedDict[str, Layer[Any, Any, Any, Any, Any, Any]] = OrderedDict()
deps_name_mapping: dict[str, Mapping[str, str]] = {}
for entry in self._entries:
if entry.name in layers:
raise ValueError(f"Duplicate layer name '{entry.name}'.")
layers[entry.name] = entry.layer
deps_name_mapping[entry.name] = entry.deps
layer_names = set(layers)
for layer_name, deps in deps_name_mapping.items():
declared_deps = layers[layer_name].dependency_names()
unknown_dep_keys = set(deps) - declared_deps
if unknown_dep_keys:
names = ", ".join(sorted(unknown_dep_keys))
raise ValueError(f"Layer '{layer_name}' declares unknown dependency keys: {names}.")
missing_targets = set(deps.values()) - layer_names
if missing_targets:
names = ", ".join(sorted(missing_targets))
raise ValueError(f"Layer '{layer_name}' depends on undefined layer names: {names}.")
return Compositor(
layers=layers,
deps_name_mapping=deps_name_mapping,
prompt_transformer=prompt_transformer,
tool_transformer=tool_transformer,
)
@dataclass(kw_only=True)
class Compositor(Generic[PromptT, ToolT, LayerPromptT, LayerToolT]):
"""Framework-neutral ordered layer graph with lifecycle and aggregation.
@ -199,7 +336,7 @@ class Compositor(Generic[PromptT, ToolT, LayerPromptT, LayerToolT]):
from exposed item types.
"""
layers: OrderedDict[str, Layer[Any, Any, Any]]
layers: OrderedDict[str, Layer[Any, Any, Any, Any, Any, Any]]
deps_name_mapping: Mapping[str, Mapping[str, str]] = field(default_factory=dict)
prompt_transformer: CompositorTransformer[LayerPromptT, PromptT] | None = None
tool_transformer: CompositorTransformer[LayerToolT, ToolT] | None = None
@ -213,19 +350,12 @@ class Compositor(Generic[PromptT, ToolT, LayerPromptT, LayerToolT]):
cls,
conf: CompositorConfigValue,
*,
registry: LayerRegistry,
prompt_transformer: CompositorTransformer[LayerPromptT, PromptT] | None = None,
tool_transformer: CompositorTransformer[LayerToolT, ToolT] | None = None,
) -> Self:
"""Create layers from config-like input and bind named dependencies."""
conf = _validate_compositor_config_input(conf)
layers: OrderedDict[str, Layer[Any, Any, Any]] = OrderedDict()
for layer_conf in conf.layers:
layers[layer_conf.name] = layer_conf.create_layer()
deps_name_mapping = {layer_conf.name: layer_conf.deps for layer_conf in conf.layers}
return cls(
layers=layers,
deps_name_mapping=deps_name_mapping,
) -> "Compositor[PromptT, ToolT, LayerPromptT, LayerToolT]":
"""Create a compositor from registry-backed serializable config."""
return CompositorBuilder(registry).add_config(conf).build(
prompt_transformer=prompt_transformer,
tool_transformer=tool_transformer,
)
@ -257,7 +387,60 @@ class Compositor(Generic[PromptT, ToolT, LayerPromptT, LayerToolT]):
def new_session(self) -> CompositorSession:
"""Create a fresh lifecycle session matching this compositor's layer order."""
return CompositorSession(self.layers)
return CompositorSession(
OrderedDict((layer_name, layer.new_control()) for layer_name, layer in self.layers.items())
)
def snapshot_session(self, session: CompositorSession) -> CompositorSessionSnapshot:
    """Serialize non-active session lifecycle state and runtime state.

    Runtime handles are live Python objects and are intentionally excluded.

    Raises:
        RuntimeError: If any layer control is currently active.
    """
    self._validate_session(session)
    active_layers = [
        name
        for name, control in session.layer_controls.items()
        if control.state is LifecycleState.ACTIVE
    ]
    if active_layers:
        names = ", ".join(active_layers)
        raise RuntimeError(f"Cannot snapshot active compositor session layers: {names}.")
    layer_snapshots: list[LayerSessionSnapshot] = []
    for name, control in session.layer_controls.items():
        # JSON-mode dump keeps the snapshot fully serializable.
        dumped = cast(dict[str, JsonValue], control.runtime_state.model_dump(mode="json"))
        layer_snapshots.append(
            LayerSessionSnapshot(name=name, state=control.state, runtime_state=dumped)
        )
    return CompositorSessionSnapshot(layers=layer_snapshots)
def session_from_snapshot(self, snapshot: CompositorSessionSnapshot | JsonValue | str | bytes) -> CompositorSession:
    """Restore a session from a snapshot and reinitialize empty handles.

    Raises:
        ValueError: On schema-version mismatch, layer-name mismatch, or any
            snapshotted layer recorded as active.
    """
    parsed = _validate_config_model_input(CompositorSessionSnapshot, snapshot)
    if parsed.schema_version != 1:
        raise ValueError(f"Unsupported compositor session snapshot schema_version: {parsed.schema_version}.")
    actual_names = tuple(entry.name for entry in parsed.layers)
    expected_names = tuple(self.layers)
    if actual_names != expected_names:
        expected = ", ".join(expected_names)
        actual = ", ".join(actual_names)
        raise ValueError(
            "CompositorSessionSnapshot layer names must match compositor layers in order. "
            f"Expected [{expected}], got [{actual}]."
        )
    active_layers = [entry.name for entry in parsed.layers if entry.state is LifecycleState.ACTIVE]
    if active_layers:
        names = ", ".join(active_layers)
        raise ValueError(f"Cannot restore active compositor session layers from snapshot: {names}.")
    # Each layer re-validates its own runtime state and creates empty handles.
    controls: OrderedDict[str, LayerControl] = OrderedDict()
    for entry in parsed.layers:
        controls[entry.name] = self.layers[entry.name].new_control(
            state=entry.state,
            runtime_state=entry.runtime_state,
        )
    return CompositorSession(controls)
@asynccontextmanager
async def enter(
@ -288,6 +471,18 @@ class Compositor(Generic[PromptT, ToolT, LayerPromptT, LayerToolT]):
"CompositorSession layer names must match compositor layers in order. "
f"Expected [{expected}], got [{actual}]."
)
for layer_name, layer in self.layers.items():
control = session.layer_controls[layer_name]
if not isinstance(control.runtime_state, layer.runtime_state_type):
raise TypeError(
f"CompositorSession layer '{layer_name}' runtime_state must be "
f"{layer.runtime_state_type.__name__}, got {type(control.runtime_state).__name__}."
)
if not isinstance(control.runtime_handles, layer.runtime_handles_type):
raise TypeError(
f"CompositorSession layer '{layer_name}' runtime_handles must be "
f"{layer.runtime_handles_type.__name__}, got {type(control.runtime_handles).__name__}."
)
def _ensure_session_can_enter(self, session: CompositorSession) -> None:
"""Reject active or closed layer controls before any layer side effects."""
@ -330,14 +525,15 @@ class Compositor(Generic[PromptT, ToolT, LayerPromptT, LayerToolT]):
__all__ = [
"Compositor",
"CompositorBuilder",
"CompositorConfig",
"CompositorConfigValue",
"CompositorLayerConfigInput",
"CompositorSessionSnapshot",
"CompositorSession",
"CompositorTransformer",
"CompositorTransformerKwargs",
"CompositorLayerConfig",
"CompositorLayerConfigValue",
"ImportedLayerConfig",
"LayerSpec",
"LayerDescriptor",
"LayerNodeConfig",
"LayerRegistry",
"LayerSessionSnapshot",
]

View File

@ -5,7 +5,17 @@
families while keeping concrete reusable layers in ``agenton_collections``.
"""
from agenton.layers.base import ExitIntent, Layer, LayerControl, LayerDeps, LifecycleState, NoLayerDeps
from agenton.layers.base import (
EmptyLayerConfig,
EmptyRuntimeHandles,
EmptyRuntimeState,
ExitIntent,
Layer,
LayerControl,
LayerDeps,
LifecycleState,
NoLayerDeps,
)
from agenton.layers.types import (
AllPromptTypes,
AllToolTypes,
@ -29,6 +39,9 @@ __all__ = [
"LayerControl",
"LifecycleState",
"ExitIntent",
"EmptyLayerConfig",
"EmptyRuntimeState",
"EmptyRuntimeHandles",
"NoLayerDeps",
"PlainLayer",
"PlainPrompt",

View File

@ -1,11 +1,13 @@
"""Core layer abstractions and typed dependency binding.
Layers declare their dependency shape with ``Layer[DepsT, PromptT, ToolT]``.
Layers declare their dependency shape with ``Layer[DepsT, PromptT, ToolT, ...]``.
``DepsT`` must be a ``LayerDeps`` subclass whose annotated members are concrete
``Layer`` subclasses or modern optional dependencies such as ``SomeLayer |
None``. The base class infers ``deps_type`` from the generic base when possible,
while still allowing subclasses to set ``deps_type`` explicitly for unusual
inheritance patterns.
None``. The optional trailing generic slots declare Pydantic schemas for config,
serializable runtime state, and live runtime handles. The base class infers
``deps_type`` and schema class attributes from the generic base when possible,
while still allowing subclasses to set them explicitly for unusual inheritance
patterns.
``Layer.bind_deps`` is the mutation point for dependency state. Layer
implementations should treat ``self.deps`` as unavailable until a compositor or
@ -17,9 +19,10 @@ machine and per-session runtime owner. A fresh control starts in
while active or closed controls are rejected to prevent ambiguous nested or
post-delete reuse. Exit behavior is selected per entry with ``ExitIntent`` and
resets to delete on every successful enter. Layer instances are shared graph and
capability definitions, so session-local ids, handles, clients, and other
runtime values generated by lifecycle hooks belong in
``LayerControl.runtime_state`` rather than on ``self``.
capability definitions, so session-local serializable ids, checkpoints, and
other snapshot data belong in ``LayerControl.runtime_state``; live clients,
connections, and process handles belong in ``LayerControl.runtime_handles``.
Neither category should be stored on ``self`` when it is session-local.
``Layer`` is framework-neutral over prompt and tool item types. The native
``prefix_prompts``, ``suffix_prompts``, and ``tools`` properties are the layer
@ -34,9 +37,18 @@ from contextlib import AbstractAsyncContextManager, asynccontextmanager
from dataclasses import dataclass, field
from enum import StrEnum
from types import UnionType
from typing import Any, Mapping, Sequence, Union, cast, get_args, get_origin, get_type_hints
from typing import Any, ClassVar, Generic, Mapping, Sequence, Union, cast, get_args, get_origin, get_type_hints
from typing_extensions import Self
from pydantic import BaseModel, ConfigDict
from typing_extensions import Self, TypeVar
_DepsT = TypeVar("_DepsT", bound="LayerDeps")
_PromptT = TypeVar("_PromptT")
_ToolT = TypeVar("_ToolT")
_ConfigT = TypeVar("_ConfigT", bound=BaseModel, default="EmptyLayerConfig")
_RuntimeStateT = TypeVar("_RuntimeStateT", bound=BaseModel, default="EmptyRuntimeState")
_RuntimeHandlesT = TypeVar("_RuntimeHandlesT", bound=BaseModel, default="EmptyRuntimeHandles")
class LayerDeps:
@ -47,7 +59,7 @@ class LayerDeps:
are always assigned as attributes; missing optional values become ``None``.
"""
def __init__(self, **deps: "Layer[Any, Any, Any] | None") -> None:
def __init__(self, **deps: "Layer[Any, Any, Any, Any, Any, Any] | None") -> None:
dep_specs = _get_dep_specs(type(self))
missing_names = {name for name, spec in dep_specs.items() if not spec.optional} - deps.keys()
if missing_names:
@ -79,6 +91,28 @@ class NoLayerDeps(LayerDeps):
"""Dependency container for layers that do not require other layers."""
class EmptyLayerConfig(BaseModel):
"""Default serializable config schema for layers without config."""
model_config = ConfigDict(extra="forbid")
class EmptyRuntimeState(BaseModel):
"""Default serializable per-session runtime state schema."""
model_config = ConfigDict(extra="forbid", validate_assignment=True)
class EmptyRuntimeHandles(BaseModel):
    """Default live per-session runtime handle schema.

    Handles may contain arbitrary Python objects and are intentionally excluded
    from session snapshots.
    """

    # arbitrary_types_allowed lets subclasses store live, non-serializable objects.
    model_config = ConfigDict(extra="forbid", validate_assignment=True, arbitrary_types_allowed=True)
class LifecycleState(StrEnum):
"""Externally observable lifecycle state for a layer control."""
@ -96,26 +130,31 @@ class ExitIntent(StrEnum):
@dataclass(slots=True)
class LayerControl:
class LayerControl(Generic[_RuntimeStateT, _RuntimeHandlesT]):
"""Stateful control slot passed into a layer entry context.
``Layer.enter`` requires the caller to provide this object. The control owns
the layer lifecycle state, the current entry's exit intent, and arbitrary
per-session runtime state. Call ``suspend_on_exit`` before leaving the
per-session runtime state and live handles. Call ``suspend_on_exit`` before leaving the
context to make a later entry resume; call ``delete_on_exit`` or do nothing
for the default delete behavior. Store session-local ids and resource
handles in ``runtime_state`` so concurrent or later sessions do not share
mutable runtime data through the layer instance.
for the default delete behavior. Store session-local serializable ids,
checkpoints, and other snapshot data in ``runtime_state``. Store live
clients, connections, process handles, and other non-serializable objects in
``runtime_handles``. Do not put either kind of session-local data on the
shared layer instance.
``runtime_state`` intentionally persists after suspend and delete. Suspend,
resume, and delete hooks can inspect the same values created on entry, and
callers may inspect closed-session diagnostics after exit. Reuse is still
governed by ``state``: a closed control cannot be entered again.
governed by ``state``: a closed control cannot be entered again. Runtime
handles are not serialized in snapshots and should be rehydrated from
runtime state in resume hooks.
"""
state: LifecycleState = LifecycleState.NEW
exit_intent: ExitIntent = ExitIntent.DELETE
runtime_state: dict[str, object] = field(default_factory=dict)
runtime_state: _RuntimeStateT = field(default_factory=lambda: cast(_RuntimeStateT, EmptyRuntimeState()))
runtime_handles: _RuntimeHandlesT = field(default_factory=lambda: cast(_RuntimeHandlesT, EmptyRuntimeHandles()))
def suspend_on_exit(self) -> None:
"""Request suspend behavior when the current layer entry exits."""
@ -130,11 +169,14 @@ class LayerControl:
class LayerDepSpec:
"""Runtime dependency specification derived from a deps annotation."""
layer_type: type["Layer[Any, Any, Any]"]
layer_type: type["Layer[Any, Any, Any, Any, Any, Any]"]
optional: bool = False
class Layer[DepsT: LayerDeps, PromptT, ToolT](ABC):
class Layer(
ABC,
Generic[_DepsT, _PromptT, _ToolT, _ConfigT, _RuntimeStateT, _RuntimeHandlesT],
):
"""Framework-neutral base class for prompt/tool layers.
Subclasses expose optional prompt fragments and tools through typed
@ -147,15 +189,20 @@ class Layer[DepsT: LayerDeps, PromptT, ToolT](ABC):
extra runtime resources.
"""
deps_type: type[DepsT]
deps: DepsT
deps_type: type[_DepsT]
deps: _DepsT
type_id: ClassVar[str | None] = None
config_type: ClassVar[type[BaseModel]] = EmptyLayerConfig
runtime_state_type: ClassVar[type[BaseModel]] = EmptyRuntimeState
runtime_handles_type: ClassVar[type[BaseModel]] = EmptyRuntimeHandles
def __init_subclass__(cls) -> None:
super().__init_subclass__()
is_generic_template = _is_generic_layer_template(cls)
deps_type = cls.__dict__.get("deps_type")
if deps_type is None:
deps_type = _infer_deps_type(cls) or getattr(cls, "deps_type", None)
if deps_type is None and _is_generic_layer_template(cls):
if deps_type is None and is_generic_template:
return
if deps_type is not None:
cls.deps_type = deps_type # pyright: ignore[reportAttributeAccessIssue]
@ -164,25 +211,63 @@ class Layer[DepsT: LayerDeps, PromptT, ToolT](ABC):
if not isinstance(deps_type, type) or not issubclass(deps_type, LayerDeps):
raise TypeError(f"{cls.__name__}.deps_type must be a LayerDeps subclass.")
_get_dep_specs(deps_type)
_init_schema_type(cls, "config_type", _infer_schema_type(cls, 3, "config_type"), EmptyLayerConfig)
_init_schema_type(
cls,
"runtime_state_type",
_infer_schema_type(cls, 4, "runtime_state_type"),
EmptyRuntimeState,
)
_init_schema_type(
cls,
"runtime_handles_type",
_infer_schema_type(cls, 5, "runtime_handles_type"),
EmptyRuntimeHandles,
)
@classmethod
def from_config(cls: type[Self], config: Any) -> Self:
"""Create a layer from serialized config.
def from_config(cls: type[Self], config: _ConfigT) -> Self:
"""Create a layer from schema-validated serialized config.
Layers are not config-constructible by default. Subclasses that accept
config should override this method and validate dynamic input before
constructing the layer.
Registries/builders validate raw config with ``config_type`` before
calling this method. Layers are not config-constructible by default.
Subclasses that accept config should override this method and consume
the typed Pydantic model for their schema.
"""
raise TypeError(f"{cls.__name__} cannot be created from config.")
def bind_deps(self, deps: Mapping[str, "Layer[Any, Any, Any] | None"]) -> None:
@classmethod
def dependency_names(cls) -> frozenset[str]:
    """Return dependency field names declared by this layer's deps schema."""
    # _get_dep_specs maps dep names to specs; only the key set is exposed here.
    return frozenset(_get_dep_specs(cls.deps_type))
def new_control(
    self,
    *,
    state: LifecycleState = LifecycleState.NEW,
    runtime_state: object | None = None,
) -> LayerControl[_RuntimeStateT, _RuntimeHandlesT]:
    """Create a schema-validated per-session control for this layer.

    ``runtime_state`` is validated through ``runtime_state_type`` and live
    handles are always initialized empty through ``runtime_handles_type``.

    Args:
        state: Initial lifecycle state, defaulting to ``LifecycleState.NEW``.
        runtime_state: Raw serializable state; ``None`` validates an empty
            mapping so the schema's field defaults apply.
    """
    # None means "no prior state": validate {} so schema defaults populate fields.
    raw_runtime_state = {} if runtime_state is None else runtime_state
    return LayerControl(
        state=state,
        exit_intent=ExitIntent.DELETE,
        runtime_state=cast(_RuntimeStateT, self.runtime_state_type.model_validate(raw_runtime_state)),
        runtime_handles=cast(_RuntimeHandlesT, self.runtime_handles_type.model_validate({})),
    )
def bind_deps(self, deps: Mapping[str, "Layer[Any, Any, Any, Any, Any, Any] | None"]) -> None:
"""Bind this layer's declared dependencies from a name-to-layer mapping.
The mapping may include more layers than the declared dependency fields.
Only names declared by ``deps_type`` are selected and validated. Missing
optional deps are bound as ``None``.
"""
resolved_deps: dict[str, Layer[Any, Any, Any] | None] = {}
resolved_deps: dict[str, Layer[Any, Any, Any, Any, Any, Any] | None] = {}
for name, spec in _get_dep_specs(self.deps_type).items():
if name not in deps:
if spec.optional:
@ -194,7 +279,7 @@ class Layer[DepsT: LayerDeps, PromptT, ToolT](ABC):
resolved_deps[name] = deps[name]
self.deps = self.deps_type(**resolved_deps)
def enter(self, control: LayerControl) -> AbstractAsyncContextManager[None]:
def enter(self, control: LayerControl[_RuntimeStateT, _RuntimeHandlesT]) -> AbstractAsyncContextManager[None]:
"""Return the layer's async entry context manager.
``control`` is the lifecycle control slot for this entry. Subclasses can
@ -204,7 +289,7 @@ class Layer[DepsT: LayerDeps, PromptT, ToolT](ABC):
return self.lifecycle_enter(control)
@asynccontextmanager
async def lifecycle_enter(self, control: LayerControl) -> AsyncIterator[None]:
async def lifecycle_enter(self, control: LayerControl[_RuntimeStateT, _RuntimeHandlesT]) -> AsyncIterator[None]:
"""Run the default explicit lifecycle state machine for one entry."""
if control.state is LifecycleState.NEW:
control.exit_intent = ExitIntent.DELETE
@ -233,37 +318,37 @@ class Layer[DepsT: LayerDeps, PromptT, ToolT](ABC):
await self.on_context_delete(control)
control.state = LifecycleState.CLOSED
async def on_context_create(self, control: LayerControl) -> None:
async def on_context_create(self, control: LayerControl[_RuntimeStateT, _RuntimeHandlesT]) -> None:
"""Run when the layer context is entered from ``LifecycleState.NEW``."""
async def on_context_delete(self, control: LayerControl) -> None:
async def on_context_delete(self, control: LayerControl[_RuntimeStateT, _RuntimeHandlesT]) -> None:
"""Run when the layer context exits with ``ExitIntent.DELETE``."""
async def on_context_suspend(self, control: LayerControl) -> None:
async def on_context_suspend(self, control: LayerControl[_RuntimeStateT, _RuntimeHandlesT]) -> None:
"""Run when the layer context exits with ``ExitIntent.SUSPEND``."""
async def on_context_resume(self, control: LayerControl) -> None:
async def on_context_resume(self, control: LayerControl[_RuntimeStateT, _RuntimeHandlesT]) -> None:
"""Run when the layer context enters from ``LifecycleState.SUSPENDED``."""
@property
def prefix_prompts(self) -> Sequence[PromptT]:
def prefix_prompts(self) -> Sequence[_PromptT]:
return []
@property
def suffix_prompts(self) -> Sequence[PromptT]:
def suffix_prompts(self) -> Sequence[_PromptT]:
return []
@property
def tools(self) -> Sequence[ToolT]:
def tools(self) -> Sequence[_ToolT]:
return []
@abstractmethod
def wrap_prompt(self, prompt: PromptT) -> object:
def wrap_prompt(self, prompt: _PromptT) -> object:
"""Wrap a native prompt item for compositor aggregation."""
raise NotImplementedError
@abstractmethod
def wrap_tool(self, tool: ToolT) -> object:
def wrap_tool(self, tool: _ToolT) -> object:
"""Wrap a native tool item for compositor aggregation."""
raise NotImplementedError
@ -297,45 +382,107 @@ def _as_dep_spec(annotation: object) -> LayerDepSpec | None:
return LayerDepSpec(layer_type=layer_type)
def _as_layer_type(annotation: object) -> type[Layer[Any, Any, Any]] | None:
def _as_layer_type(annotation: object) -> type[Layer[Any, Any, Any, Any, Any, Any]] | None:
runtime_type = get_origin(annotation) or annotation
if isinstance(runtime_type, type) and issubclass(runtime_type, Layer):
return cast(type[Layer[Any, Any, Any]], runtime_type)
return cast(type[Layer[Any, Any, Any, Any, Any, Any]], runtime_type)
return None
def _infer_deps_type(layer_type: type[Layer[Any, Any, Any]]) -> type[LayerDeps] | None:
return _infer_deps_type_from_bases(layer_type, {})
def _infer_deps_type(layer_type: type[Layer[Any, Any, Any, Any, Any, Any]]) -> type[LayerDeps] | None:
inferred = _infer_layer_generic_arg(layer_type, 0, {})
if inferred is None:
return None
return _as_deps_type(inferred)
def _infer_deps_type_from_bases(
layer_type: type[Layer[Any, Any, Any]],
def _infer_schema_type(
    layer_type: type[Layer[Any, Any, Any, Any, Any, Any]],
    index: int,
    attr_name: str,
) -> type[BaseModel] | None:
    """Resolve one schema slot from generic arguments, or ``None`` when unset.

    Tries named type parameters first (typed family bases), then falls back to
    the positional ``Layer`` generic argument at ``index``.

    Raises:
        TypeError: If a type argument was supplied but is not a Pydantic
            ``BaseModel`` subclass.
    """
    inferred = _infer_schema_generic_arg(layer_type, attr_name, {}) or _infer_layer_generic_arg(layer_type, index, {})
    if inferred is None:
        return None
    schema_type = _as_model_type(inferred)
    if schema_type is None:
        raise TypeError(f"{layer_type.__name__}.{attr_name} must be a Pydantic BaseModel subclass.")
    return schema_type
def _infer_schema_generic_arg(
    layer_type: type[Layer[Any, Any, Any, Any, Any, Any]],
    attr_name: str,
    substitutions: Mapping[object, object],
) -> object | None:
    """Infer schema type arguments exposed by typed layer family bases.

    Recursively walks ``__orig_bases__``, substituting already-bound type
    variables, and returns the argument bound to a type parameter whose name
    matches the requested schema slot. Returns ``None`` when no base pins it.
    """
    # Map each schema attribute to the type-parameter names that can carry it.
    expected_names = {
        "config_type": {"ConfigT", "_ConfigT"},
        "runtime_state_type": {"RuntimeStateT", "_RuntimeStateT"},
        "runtime_handles_type": {"RuntimeHandlesT", "_RuntimeHandlesT"},
    }[attr_name]
    for base in getattr(layer_type, "__orig_bases__", ()):
        origin = get_origin(base) or base
        # Resolve type variables already bound by outer parameterizations.
        args = tuple(_substitute_type(arg, substitutions) for arg in get_args(base))
        if not isinstance(origin, type) or not issubclass(origin, Layer):
            continue
        params = _generic_params(origin)
        for param, arg in zip(params, args):
            if getattr(param, "__name__", None) in expected_names:
                return arg
        # Recurse with this base's own parameter-to-argument bindings layered in.
        next_substitutions = dict(substitutions)
        next_substitutions.update(_generic_arg_substitutions(origin, args))
        inferred = _infer_schema_generic_arg(origin, attr_name, next_substitutions)
        if inferred is not None:
            return inferred
    return None
def _infer_layer_generic_arg(
layer_type: type[Layer[Any, Any, Any, Any, Any, Any]],
index: int,
substitutions: Mapping[object, object],
) -> object | None:
"""Infer one concrete ``Layer`` generic argument through inheritance.
This walks through intermediate generic base classes so subclasses can omit
an explicit ``deps_type`` in common cases such as ``class X(Base[YDeps])``.
explicit class attributes in common cases such as ``class X(Base[YDeps])``.
"""
for base in getattr(layer_type, "__orig_bases__", ()):
origin = get_origin(base) or base
args = tuple(_substitute_type(arg, substitutions) for arg in get_args(base))
if origin is Layer:
if not args:
if len(args) <= index:
continue
return _as_deps_type(args[0])
return args[index]
if not isinstance(origin, type) or not issubclass(origin, Layer):
continue
next_substitutions = dict(substitutions)
next_substitutions.update(_generic_arg_substitutions(origin, args))
inferred = _infer_deps_type_from_bases(origin, next_substitutions)
inferred = _infer_layer_generic_arg(origin, index, next_substitutions)
if inferred is not None:
return inferred
return None
def _init_schema_type(
    layer_type: type[Layer[Any, Any, Any, Any, Any, Any]],
    attr_name: str,
    inferred_schema_type: type[BaseModel] | None,
    default_schema_type: type[BaseModel],
) -> None:
    """Install and validate one schema class attribute on a layer subclass.

    Resolution order: an explicit attribute in the class body wins, then the
    schema inferred from generic arguments, then an inherited attribute, then
    the framework default.

    Raises:
        TypeError: If the resolved value is not a Pydantic ``BaseModel`` subclass.
    """
    # Only consult __dict__ so an inherited attribute does not mask inference.
    schema_type = layer_type.__dict__.get(attr_name)
    if schema_type is None:
        schema_type = inferred_schema_type or getattr(layer_type, attr_name, default_schema_type)
        setattr(layer_type, attr_name, schema_type)
    if not isinstance(schema_type, type) or not issubclass(schema_type, BaseModel):
        raise TypeError(f"{layer_type.__name__}.{attr_name} must be a Pydantic BaseModel subclass.")
def _substitute_type(value: object, substitutions: Mapping[object, object]) -> object:
if value in substitutions:
return substitutions[value]
@ -359,10 +506,15 @@ def _substitute_type(value: object, substitutions: Mapping[object, object]) -> o
def _generic_arg_substitutions(origin: type[Any], args: Sequence[object]) -> dict[object, object]:
params = _generic_params(origin)
return dict(zip(params, args))
def _generic_params(origin: type[Any]) -> Sequence[object]:
params = getattr(origin, "__type_params__", ())
if not params:
params = getattr(origin, "__parameters__", ())
return dict(zip(params, args))
return params
def _as_deps_type(value: object) -> type[LayerDeps] | None:
@ -372,7 +524,14 @@ def _as_deps_type(value: object) -> type[LayerDeps] | None:
return None
def _is_generic_layer_template(layer_type: type[Layer[Any, Any, Any]]) -> bool:
def _as_model_type(value: object) -> type[BaseModel] | None:
    """Return ``value`` as a Pydantic model class, or ``None`` if it is not one."""
    # Parameterized generics resolve to their origin class first.
    candidate = get_origin(value) or value
    if not isinstance(candidate, type):
        return None
    return candidate if issubclass(candidate, BaseModel) else None
def _is_generic_layer_template(layer_type: type[Layer[Any, Any, Any, Any, Any, Any]]) -> bool:
return bool(getattr(layer_type, "__type_params__", ())) or bool(
getattr(layer_type, "__parameters__", ())
)

View File

@ -2,7 +2,10 @@
``Layer`` itself is framework-neutral. This module defines typed layer families
that bind its prompt/tool generic slots to concrete contracts, such as ordinary
string prompts with plain callable tools or pydantic-ai prompt/tool shapes.
string prompts with plain callable tools or pydantic-ai prompt/tool shapes. The
families keep the trailing schema generic slots open so concrete layers can have
``config_type``, ``runtime_state_type``, and ``runtime_handles_type`` inferred
from type arguments instead of repeated class attributes.
Tagged aggregate aliases cover code paths that can accept any supported
prompt/tool family without changing the plain and pydantic-ai layer contracts.
Pydantic-ai names are imported for static analysis only, so ``agenton`` can be
@ -14,15 +17,17 @@ from __future__ import annotations
from collections.abc import Callable
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Literal
from typing import TYPE_CHECKING, Any, Generic, Literal
from typing_extensions import final, override
from typing_extensions import TypeVar, final, override
if TYPE_CHECKING:
from pydantic_ai import Tool
from pydantic_ai.tools import SystemPromptFunc
from agenton.layers.base import Layer, LayerDeps
from pydantic import BaseModel
from agenton.layers.base import EmptyLayerConfig, EmptyRuntimeHandles, EmptyRuntimeState, Layer, LayerDeps
type PlainPrompt = str
type PlainTool = Callable[..., Any]
@ -68,7 +73,17 @@ type AllPromptTypes = PlainPromptType | PydanticAIPromptType[Any]
type AllToolTypes = PlainToolType | PydanticAIToolType[Any]
class PlainLayer[DepsT: LayerDeps](Layer[DepsT, PlainPrompt, PlainTool]):
_DepsT = TypeVar("_DepsT", bound=LayerDeps)
_ConfigT = TypeVar("_ConfigT", bound=BaseModel, default=EmptyLayerConfig)
_RuntimeStateT = TypeVar("_RuntimeStateT", bound=BaseModel, default=EmptyRuntimeState)
_RuntimeHandlesT = TypeVar("_RuntimeHandlesT", bound=BaseModel, default=EmptyRuntimeHandles)
_AgentDepsT = TypeVar("_AgentDepsT")
class PlainLayer(
Generic[_DepsT, _ConfigT, _RuntimeStateT, _RuntimeHandlesT],
Layer[_DepsT, PlainPrompt, PlainTool, _ConfigT, _RuntimeStateT, _RuntimeHandlesT],
):
"""Layer base for ordinary string prompts and plain-callable tools."""
@final
@ -82,8 +97,16 @@ class PlainLayer[DepsT: LayerDeps](Layer[DepsT, PlainPrompt, PlainTool]):
return PlainToolType(tool)
class PydanticAILayer[DepsT: LayerDeps, AgentDepsT](
Layer[DepsT, PydanticAIPrompt[AgentDepsT], PydanticAITool[AgentDepsT]]
class PydanticAILayer(
Generic[_DepsT, _AgentDepsT, _ConfigT, _RuntimeStateT, _RuntimeHandlesT],
Layer[
_DepsT,
PydanticAIPrompt[_AgentDepsT],
PydanticAITool[_AgentDepsT],
_ConfigT,
_RuntimeStateT,
_RuntimeHandlesT,
],
):
"""Layer base for pydantic-ai prompt and tool adapters."""
@ -91,13 +114,13 @@ class PydanticAILayer[DepsT: LayerDeps, AgentDepsT](
@override
def wrap_prompt(
self,
prompt: PydanticAIPrompt[AgentDepsT],
) -> PydanticAIPromptType[AgentDepsT]:
prompt: PydanticAIPrompt[_AgentDepsT],
) -> PydanticAIPromptType[_AgentDepsT]:
return PydanticAIPromptType(prompt)
@final
@override
def wrap_tool(self, tool: PydanticAITool[AgentDepsT]) -> PydanticAIToolType[AgentDepsT]:
def wrap_tool(self, tool: PydanticAITool[_AgentDepsT]) -> PydanticAIToolType[_AgentDepsT]:
return PydanticAIToolType(tool)

View File

@ -1,6 +1,6 @@
"""Reusable collection layers for the plain layer family."""
from agenton_collections.layers.plain.basic import ObjectLayer, PromptLayer, ToolsLayer
from agenton_collections.layers.plain.basic import ObjectLayer, PromptLayer, PromptLayerConfig, ToolsLayer
from agenton_collections.layers.plain.dynamic_tools import (
DynamicToolsLayer,
DynamicToolsLayerDeps,
@ -12,6 +12,7 @@ __all__ = [
"DynamicToolsLayerDeps",
"ObjectLayer",
"PromptLayer",
"PromptLayerConfig",
"ToolsLayer",
"with_object",
]

View File

@ -10,30 +10,46 @@ from collections.abc import Callable, Sequence
from dataclasses import dataclass, field
from typing import Any
from pydantic import TypeAdapter
from pydantic import BaseModel, ConfigDict, Field
from agenton.layers.base import NoLayerDeps
from agenton.layers.types import PlainLayer
class PromptLayerConfig(BaseModel):
    """Serializable config schema for ``PromptLayer``."""

    # Each slot accepts either one string or a list of prompt fragments.
    prefix: list[str] | str = Field(default_factory=list)
    suffix: list[str] | str = Field(default_factory=list)

    model_config = ConfigDict(extra="forbid")
@dataclass
class ObjectLayer[ObjectT](PlainLayer[NoLayerDeps]):
"""Layer that stores one typed object for downstream dependencies."""
"""Layer that stores one typed object for downstream dependencies.
Object layers are instance-only because arbitrary Python objects are not
serializable graph config. Add them with ``CompositorBuilder.add_instance``.
"""
value: ObjectT
@dataclass
class PromptLayer(PlainLayer[NoLayerDeps]):
class PromptLayer(PlainLayer[NoLayerDeps, PromptLayerConfig]):
"""Layer that contributes configured prefix and suffix prompt fragments."""
type_id = "plain.prompt"
prefix: list[str] | str = field(default_factory=list)
suffix: list[str] | str = field(default_factory=list)
@classmethod
def from_config(cls, config: Any):
"""Validate prompt config against this dataclass."""
return _PROMPT_LAYER_ADAPTER.validate_python(config)
def from_config(cls, config: BaseModel):
"""Create a prompt layer from validated prompt config."""
validated_config = PromptLayerConfig.model_validate(config)
return cls(prefix=validated_config.prefix, suffix=validated_config.suffix)
@property
def prefix_prompts(self) -> list[str]:
@ -50,7 +66,11 @@ class PromptLayer(PlainLayer[NoLayerDeps]):
@dataclass
class ToolsLayer(PlainLayer[NoLayerDeps]):
"""Layer that contributes configured plain-callable tools."""
"""Layer that contributes configured plain-callable tools.
Tool layers are instance-only because Python callables are live objects. Add
them with ``CompositorBuilder.add_instance``.
"""
tool_entries: Sequence[Callable[..., Any]] = ()
@ -59,10 +79,9 @@ class ToolsLayer(PlainLayer[NoLayerDeps]):
return list(self.tool_entries)
_PROMPT_LAYER_ADAPTER = TypeAdapter(PromptLayer)
__all__ = [
"ObjectLayer",
"PromptLayerConfig",
"PromptLayer",
"ToolsLayer",
]

View File

@ -0,0 +1,257 @@
import asyncio
from collections import OrderedDict
from dataclasses import dataclass
from pydantic import BaseModel, ConfigDict, ValidationError
from typing_extensions import override
from agenton.compositor import Compositor, CompositorBuilder, CompositorSession, LayerRegistry
from agenton.layers import EmptyLayerConfig, LayerControl, LayerDeps, NoLayerDeps, PlainLayer, PlainPromptType, PlainToolType
from agenton_collections.layers.plain import ObjectLayer, PromptLayer
def test_registry_infers_descriptor_and_rejects_duplicate_or_missing_type_id() -> None:
    """Registration resolves by type_id and rejects duplicate or invalid ids."""
    registry = LayerRegistry()
    registry.register_layer(PromptLayer)

    resolved = registry.resolve("plain.prompt")
    assert resolved.layer_type is PromptLayer
    assert resolved.config_type is PromptLayer.config_type

    def expect_failure(action, error_type, check_message) -> None:
        # Run the registration attempt and pin the exact failure surface.
        try:
            action()
        except error_type as e:
            assert check_message(str(e))
        else:
            raise AssertionError(f"Expected {error_type.__name__}.")

    expect_failure(
        lambda: registry.register_layer(PromptLayer),
        ValueError,
        lambda message: message == "Layer type id 'plain.prompt' is already registered.",
    )
    expect_failure(
        lambda: registry.register_layer(InstanceOnlyLayer),
        ValueError,
        lambda message: "must declare a type_id" in message,
    )
    expect_failure(
        lambda: registry.register_layer(InstanceOnlyLayer, type_id=123),  # pyright: ignore[reportArgumentType]
        TypeError,
        lambda message: message == "Layer type id for 'InstanceOnlyLayer' must be a string.",
    )
@dataclass(slots=True)
class InstanceOnlyLayer(PlainLayer[NoLayerDeps]):
    """Layer with no ``type_id``; registrable only with an explicit id."""

    pass
def test_builder_creates_config_layers_with_typed_validation() -> None:
    """Builder validates node config against the registered layer's schema."""
    registry = LayerRegistry()
    registry.register_layer(PromptLayer)
    compositor = (
        CompositorBuilder(registry)
        .add_config_layer(
            name="prompt",
            type="plain.prompt",
            config={"prefix": "hello", "suffix": ["bye"]},
        )
        .build()
    )
    # A bare string prefix and a list suffix both validate and emit prompts.
    assert [prompt.value for prompt in compositor.prompts] == ["hello", "bye"]
    # Unknown config keys must fail schema validation at add time.
    try:
        CompositorBuilder(registry).add_config_layer(
            name="bad",
            type="plain.prompt",
            config={"unknown": "field"},
        )
    except ValidationError:
        pass
    else:
        raise AssertionError("Expected ValidationError.")
class ObjectConsumerDeps(LayerDeps):
    """Deps schema requiring one ``ObjectLayer[str]`` under the name ``obj``."""

    obj: ObjectLayer[str]  # pyright: ignore[reportUninitializedInstanceVariable]
@dataclass(slots=True)
class ObjectConsumerLayer(PlainLayer[ObjectConsumerDeps]):
    """Layer that surfaces its object dependency's value as a prefix prompt."""

    @property
    @override
    def prefix_prompts(self) -> list[str]:
        # Reads the bound ObjectLayer dependency; deps must be bound first.
        return [self.deps.obj.value]
def test_builder_mixes_config_and_instances_and_rejects_invalid_deps() -> None:
    """Config and instance layers compose; invalid dep wiring is rejected."""
    registry = LayerRegistry()
    registry.register_layer(PromptLayer)
    compositor = (
        CompositorBuilder(registry)
        .add_config({"layers": [{"name": "prompt", "type": "plain.prompt", "config": {"prefix": "cfg"}}]})
        .add_instance(name="obj", layer=ObjectLayer("instance"))
        .add_instance(name="consumer", layer=ObjectConsumerLayer(), deps={"obj": "obj"})
        .build()
    )
    # Config-built prompt comes first, then the consumer's dependency value.
    assert [prompt.value for prompt in compositor.prompts] == ["cfg", "instance"]
    # Dep keys must match the fields declared by the layer's deps schema.
    try:
        CompositorBuilder(registry).add_instance(
            name="consumer",
            layer=ObjectConsumerLayer(),
            deps={"missing_dep_key": "obj"},
        ).build()
    except ValueError as e:
        assert str(e) == "Layer 'consumer' declares unknown dependency keys: missing_dep_key."
    else:
        raise AssertionError("Expected ValueError.")
    # Dep targets must name layers that exist in the graph.
    try:
        CompositorBuilder(registry).add_instance(
            name="consumer",
            layer=ObjectConsumerLayer(),
            deps={"obj": "missing_target"},
        ).build()
    except ValueError as e:
        assert str(e) == "Layer 'consumer' depends on undefined layer names: missing_target."
    else:
        raise AssertionError("Expected ValueError.")
class HandleState(BaseModel):
    """Serializable per-session runtime state holding only a resource id."""

    resource_id: str = ""

    model_config = ConfigDict(extra="forbid", validate_assignment=True)
class HandleBox:
    """Opaque string holder standing in for a live, non-serializable handle."""

    value: str

    def __init__(self, value: str) -> None:
        self.value = value
class HandleModels(BaseModel):
    """Runtime-handle schema carrying one live ``HandleBox`` (never snapshotted)."""

    handle: HandleBox | None = None

    model_config = ConfigDict(extra="forbid", validate_assignment=True, arbitrary_types_allowed=True)
@dataclass(slots=True)
class HandleLayer(PlainLayer[NoLayerDeps, EmptyLayerConfig, HandleState, HandleModels]):
    """Layer that builds a live ``HandleBox`` from serializable resource state."""

    # Hook invocation counters; instance-level, so shared across sessions.
    created: int = 0
    resumed: int = 0

    @override
    async def on_context_create(self, control: LayerControl[HandleState, HandleModels]) -> None:
        self.created += 1
        # Fresh entry: build the live handle directly from the resource id.
        control.runtime_handles.handle = HandleBox(control.runtime_state.resource_id)

    @override
    async def on_context_resume(self, control: LayerControl[HandleState, HandleModels]) -> None:
        self.resumed += 1
        # Resume: rehydrate the handle from persisted runtime state.
        control.runtime_handles.handle = HandleBox(f"resumed:{control.runtime_state.resource_id}")
def test_new_session_uses_layer_runtime_schemas() -> None:
    """New sessions materialize controls typed by the layer's runtime schemas."""
    layers: OrderedDict[str, HandleLayer] = OrderedDict(handle=HandleLayer())
    compositor: Compositor[PlainPromptType, PlainToolType] = Compositor(layers=layers)
    control = compositor.new_session().layer("handle")
    assert isinstance(control.runtime_state, HandleState)
    assert isinstance(control.runtime_handles, HandleModels)
def test_enter_rejects_bad_session_runtime_schemas_before_layer_hooks() -> None:
    """Entering with a mistyped runtime_state fails before any layer hook runs."""
    layer = HandleLayer()
    compositor: Compositor[PlainPromptType, PlainToolType] = Compositor(layers=OrderedDict([("handle", layer)]))
    # A default LayerControl carries EmptyRuntimeState, not the layer's HandleState.
    bad_session = CompositorSession(OrderedDict([("handle", LayerControl())]))

    async def run() -> None:
        async with compositor.enter(bad_session):
            pass

    try:
        asyncio.run(run())
    except TypeError as e:
        assert str(e) == (
            "CompositorSession layer 'handle' runtime_state must be HandleState, "
            "got EmptyRuntimeState."
        )
    else:
        raise AssertionError("Expected TypeError.")
    # on_context_create never ran, proving the schema check happened first.
    assert layer.created == 0
def test_snapshot_rejects_active_sessions_and_excludes_handles() -> None:
    """Snapshots are refused while layers are active and omit runtime handles."""
    compositor: Compositor[PlainPromptType, PlainToolType] = Compositor(
        layers=OrderedDict([("handle", HandleLayer())])
    )
    session = compositor.session_from_snapshot(
        {"layers": [{"name": "handle", "state": "new", "runtime_state": {"resource_id": "abc"}}]}
    )

    async def run() -> None:
        async with compositor.enter(session):
            # Inside the context the layer is active, so snapshotting must fail.
            try:
                compositor.snapshot_session(session)
            except RuntimeError as e:
                assert str(e) == "Cannot snapshot active compositor session layers: handle."
            else:
                raise AssertionError("Expected RuntimeError.")

    asyncio.run(run())
    snapshot = compositor.snapshot_session(session)
    # Only serializable runtime_state appears; live handles are excluded.
    assert snapshot.model_dump(mode="json") == {
        "schema_version": 1,
        "layers": [{"name": "handle", "state": "closed", "runtime_state": {"resource_id": "abc"}}],
    }
def test_restore_validates_runtime_state_and_resume_rehydrates_handles() -> None:
    """Snapshot restore validates state; resume recreates the live handle."""
    layer = HandleLayer()
    compositor: Compositor[PlainPromptType, PlainToolType] = Compositor(layers=OrderedDict([("handle", layer)]))
    # Unknown runtime_state keys are rejected by the layer's HandleState schema.
    try:
        compositor.session_from_snapshot(
            {"layers": [{"name": "handle", "state": "suspended", "runtime_state": {"wrong": "field"}}]}
        )
    except ValidationError:
        pass
    else:
        raise AssertionError("Expected ValidationError.")
    restored = compositor.session_from_snapshot(
        {"layers": [{"name": "handle", "state": "suspended", "runtime_state": {"resource_id": "abc"}}]}
    )

    async def run() -> None:
        async with compositor.enter(restored):
            control = restored.layer("handle")
            # on_context_resume rebuilt the handle from the restored resource_id.
            assert isinstance(control.runtime_handles, HandleModels)
            assert control.runtime_handles.handle is not None
            assert control.runtime_handles.handle.value == "resumed:abc"

    asyncio.run(run())
    assert layer.resumed == 1
def test_session_from_snapshot_rejects_active_layer_state() -> None:
    """Snapshots claiming an active layer state cannot be restored."""
    compositor: Compositor[PlainPromptType, PlainToolType] = Compositor(
        layers=OrderedDict([("handle", HandleLayer())])
    )
    try:
        compositor.session_from_snapshot(
            {"layers": [{"name": "handle", "state": "active", "runtime_state": {"resource_id": "abc"}}]}
        )
    except ValueError as e:
        assert str(e) == "Cannot restore active compositor session layers from snapshot: handle."
    else:
        raise AssertionError("Expected ValueError.")

View File

@ -4,11 +4,14 @@ from collections.abc import Iterator
from dataclasses import dataclass, field
from itertools import count
from pydantic import BaseModel, ConfigDict
from typing_extensions import override
from agenton.compositor import Compositor, CompositorSession
from agenton.layers import (
ExitIntent,
EmptyLayerConfig,
EmptyRuntimeHandles,
LayerControl,
LifecycleState,
NoLayerDeps,
@ -239,22 +242,30 @@ def test_failed_resume_keeps_control_reusable_as_suspended() -> None:
assert session.layer("trace").state is LifecycleState.CLOSED
class RuntimeState(BaseModel):
    """Per-session trace of lifecycle hook activity, keyed by a runtime id."""

    runtime_id: int | None = None
    resumed_runtime_id: int | None = None
    deleted_runtime_id: int | None = None

    model_config = ConfigDict(extra="forbid", validate_assignment=True)
@dataclass(slots=True)
class RuntimeStateLayer(PlainLayer[NoLayerDeps]):
class RuntimeStateLayer(PlainLayer[NoLayerDeps, EmptyLayerConfig, RuntimeState]):
next_id: Iterator[int] = field(default_factory=lambda: count(1))
@override
async def on_context_create(self, control: LayerControl) -> None:
async def on_context_create(self, control: LayerControl[RuntimeState, EmptyRuntimeHandles]) -> None:
runtime_id = next(self.next_id)
control.runtime_state["runtime_id"] = runtime_id
control.runtime_state.runtime_id = runtime_id
@override
async def on_context_resume(self, control: LayerControl) -> None:
control.runtime_state["resumed_runtime_id"] = control.runtime_state["runtime_id"]
async def on_context_resume(self, control: LayerControl[RuntimeState, EmptyRuntimeHandles]) -> None:
control.runtime_state.resumed_runtime_id = control.runtime_state.runtime_id
@override
async def on_context_delete(self, control: LayerControl) -> None:
control.runtime_state["deleted_runtime_id"] = control.runtime_state["runtime_id"]
async def on_context_delete(self, control: LayerControl[RuntimeState, EmptyRuntimeHandles]) -> None:
control.runtime_state.deleted_runtime_id = control.runtime_state.runtime_id
def test_runtime_state_is_per_session_and_survives_suspend_resume_delete() -> None:
@ -275,12 +286,12 @@ def test_runtime_state_is_per_session_and_survives_suspend_resume_delete() -> No
asyncio.run(run())
assert first_session.layer("trace").runtime_state == {
assert first_session.layer("trace").runtime_state.model_dump(exclude_none=True) == {
"runtime_id": 1,
"resumed_runtime_id": 1,
"deleted_runtime_id": 1,
}
assert second_session.layer("trace").runtime_state == {
assert second_session.layer("trace").runtime_state.model_dump(exclude_none=True) == {
"runtime_id": 2,
"deleted_runtime_id": 2,
}

View File

@ -0,0 +1,94 @@
from dataclasses import dataclass
from pydantic import BaseModel, ConfigDict
from agenton.compositor import LayerRegistry
from agenton.layers import EmptyLayerConfig, EmptyRuntimeHandles, EmptyRuntimeState, LayerControl, NoLayerDeps, PlainLayer
class InferredConfig(BaseModel):
    """Config schema used to verify inference from generic type arguments."""

    value: str = "configured"

    model_config = ConfigDict(extra="forbid")
class InferredState(BaseModel):
    """Runtime-state schema used to verify inference from generic arguments."""

    count: int = 0

    model_config = ConfigDict(extra="forbid", validate_assignment=True)
class InferredHandles(BaseModel):
    """Runtime-handle schema used to verify inference from generic arguments."""

    token: object | None = None

    model_config = ConfigDict(extra="forbid", validate_assignment=True, arbitrary_types_allowed=True)
@dataclass(slots=True)
class GenericSchemaLayer(PlainLayer[NoLayerDeps, InferredConfig, InferredState, InferredHandles]):
    """Layer whose three schema slots come from explicit generic arguments."""

    type_id = "test.generic-schema"

    async def on_context_create(self, control: LayerControl[InferredState, InferredHandles]) -> None:
        # Exercises both typed runtime containers on entry.
        control.runtime_state.count += 1
        control.runtime_handles.token = object()
@dataclass(slots=True)
class DefaultSchemaLayer(PlainLayer[NoLayerDeps]):
    """Layer that omits every schema slot, exercising the Empty* defaults."""

    type_id = "test.default-schema"
def test_layer_infers_config_runtime_state_and_handles_from_generics() -> None:
    """Generic base arguments should install the declared Pydantic schemas."""
    # Class-level schema slots are populated by Layer.__init_subclass__.
    assert GenericSchemaLayer.config_type is InferredConfig
    assert GenericSchemaLayer.runtime_state_type is InferredState
    assert GenericSchemaLayer.runtime_handles_type is InferredHandles

    # A plain dict passed as runtime_state is validated into the schema model.
    control = GenericSchemaLayer().new_control(runtime_state={"count": 3})
    state = control.runtime_state
    assert isinstance(state, InferredState)
    assert state.count == 3
    assert isinstance(control.runtime_handles, InferredHandles)
def test_layer_uses_empty_schema_defaults_when_omitted() -> None:
    """Omitted generic slots should fall back to the Empty* schema defaults."""
    assert DefaultSchemaLayer.config_type is EmptyLayerConfig
    assert DefaultSchemaLayer.runtime_state_type is EmptyRuntimeState
    assert DefaultSchemaLayer.runtime_handles_type is EmptyRuntimeHandles

    # Controls built with no arguments get instances of the empty schemas.
    control = DefaultSchemaLayer().new_control()
    assert isinstance(control.runtime_state, EmptyRuntimeState)
    assert isinstance(control.runtime_handles, EmptyRuntimeHandles)
def test_invalid_declared_schema_type_is_rejected_clearly() -> None:
    """Non-BaseModel schema declarations must fail fast with a clear message."""
    # Invalid schema declared directly as a class attribute.
    message: str | None = None
    try:
        class InvalidSchemaLayer(PlainLayer[NoLayerDeps]):
            config_type = dict  # pyright: ignore[reportAssignmentType]
    except TypeError as exc:
        message = str(exc)
    assert message == "InvalidSchemaLayer.config_type must be a Pydantic BaseModel subclass."

    # Invalid schema supplied through a generic base argument.
    message = None
    try:
        class InvalidGenericSchemaLayer(PlainLayer[NoLayerDeps, dict[str, object]]):  # pyright: ignore[reportInvalidTypeArguments]
            pass
    except TypeError as exc:
        message = str(exc)
    assert message == "InvalidGenericSchemaLayer.config_type must be a Pydantic BaseModel subclass."
def test_registry_descriptor_uses_inferred_schema_types() -> None:
    """Registry descriptors should surface the layer's inferred schema types."""
    registry = LayerRegistry()
    registry.register_layer(GenericSchemaLayer)

    # Resolution is keyed by the layer's declared type_id.
    descriptor = registry.resolve("test.generic-schema")
    assert descriptor.config_type is InferredConfig
    assert descriptor.runtime_state_type is InferredState
    assert descriptor.runtime_handles_type is InferredHandles

View File

@ -38,3 +38,11 @@ def test_agenton_pydantic_ai_example_smoke() -> None:
assert "ToolCallPart: count_words(" in result.stdout
assert "ToolCallPart: write_tagline(" in result.stdout
assert "TextPart:" in result.stdout
def test_agenton_session_snapshot_example_smoke() -> None:
    """The session-snapshot example should run cleanly and print its markers."""
    completed = _run_example("examples/agenton/session_snapshot.py")
    # Surface stderr in the failure message when the example crashes.
    assert completed.returncode == 0, completed.stderr
    expected_markers = (
        "Snapshot:",
        "Rehydrated handle: restored:demo-connection",
    )
    for marker in expected_markers:
        assert marker in completed.stdout