mirror of
https://github.com/langgenius/dify.git
synced 2026-05-09 21:28:25 +08:00
add schema-backed agenton sessions
This commit is contained in:
parent
31a1de4828
commit
f316d19be6
@ -0,0 +1,3 @@
|
||||
# Dify Agent
|
||||
|
||||
Agenton documentation lives in [`docs/agenton/`](docs/agenton/).
|
||||
65
dify-agent/docs/agenton/README.md
Normal file
65
dify-agent/docs/agenton/README.md
Normal file
@ -0,0 +1,65 @@
|
||||
# Agenton configuration and sessions
|
||||
|
||||
Agenton composes shared `Layer` instances into a named graph. Treat layer
|
||||
instances as reusable capability definitions: config and dependency declarations
|
||||
belong on the layer class or instance, while per-session runtime values belong
|
||||
on the `LayerControl` created for that layer in a `CompositorSession`.
|
||||
|
||||
## Config, runtime state, and runtime handles
|
||||
|
||||
- **Config** is serializable graph input. Config-constructible layers declare a
|
||||
`type_id` and a Pydantic `config_type`; builders validate node config before
|
||||
calling `Layer.from_config(validated_config)`.
|
||||
- **Runtime state** is serializable per-layer/per-session state. Layers declare a
|
||||
Pydantic `runtime_state_type`; session snapshots persist this model with
|
||||
`model_dump(mode="json")`.
|
||||
- **Runtime handles** are live Python objects such as clients, open files, or
|
||||
process handles. Layers declare a Pydantic `runtime_handles_type` with
|
||||
`arbitrary_types_allowed=True`. Handles are never serialized; resume hooks
|
||||
should rehydrate them from runtime state.
|
||||
|
||||
`Layer.__init_subclass__` infers `deps_type`, `config_type`,
|
||||
`runtime_state_type`, and `runtime_handles_type` from generic base arguments
|
||||
when possible. For example, `PlainLayer[NoLayerDeps, MyConfig, MyState,
|
||||
MyHandles]` automatically installs those Pydantic schemas. Omitted schema slots
|
||||
default to `EmptyLayerConfig`, `EmptyRuntimeState`, and `EmptyRuntimeHandles`.
|
||||
Lifecycle hooks can annotate controls as `LayerControl[MyState, MyHandles]` to
|
||||
get static checking and IDE completion for runtime state and handles.
|
||||
|
||||
## Registry and builder
|
||||
|
||||
Register config-constructible layers manually:
|
||||
|
||||
```python
|
||||
registry = LayerRegistry()
|
||||
registry.register_layer(PromptLayer) # uses PromptLayer.type_id == "plain.prompt"
|
||||
```
|
||||
|
||||
Use `CompositorBuilder` to mix serializable config nodes with live instances:
|
||||
|
||||
```python
|
||||
compositor = (
|
||||
CompositorBuilder(registry)
|
||||
.add_config({"layers": [{"name": "prompt", "type": "plain.prompt", "config": {"prefix": "Hi"}}]})
|
||||
.add_instance(name="profile", layer=ObjectLayer(profile))
|
||||
.build()
|
||||
)
|
||||
```
|
||||
|
||||
Use `.add_instance()` for layers that require Python objects or callables, such
|
||||
as `ObjectLayer`, `ToolsLayer`, and dynamic tool layers.
|
||||
|
||||
## Session snapshot and restore
|
||||
|
||||
`Compositor.snapshot_session(session)` serializes non-active sessions, including
|
||||
layer lifecycle state and runtime state. It rejects active sessions because live
|
||||
handles cannot be snapshotted safely. Restore with
|
||||
`Compositor.session_from_snapshot(snapshot)`; restored controls validate runtime
|
||||
state with each layer schema and initialize empty runtime handles. Suspended
|
||||
sessions resume through `on_context_resume`, where handles should be hydrated
|
||||
from the restored runtime state.
|
||||
|
||||
Create sessions with `Compositor.new_session()` or
|
||||
`Compositor.session_from_snapshot()`. `Compositor.enter()` validates that every
|
||||
session control uses the target layer's runtime state and handle schemas before
|
||||
any lifecycle hook runs.
|
||||
@ -8,10 +8,9 @@ from inspect import signature
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from agenton.compositor import Compositor, CompositorLayerConfig
|
||||
from agenton.compositor import CompositorBuilder, LayerRegistry
|
||||
from agenton.layers import LayerControl, LayerDeps, NoLayerDeps, PlainLayer
|
||||
from agenton.layers.types import PlainPromptType, PlainToolType
|
||||
from agenton_collections.layers.plain import DynamicToolsLayer, ObjectLayer, ToolsLayer, with_object
|
||||
from agenton_collections.layers.plain import DynamicToolsLayer, ObjectLayer, PromptLayer, ToolsLayer, with_object
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
@ -75,51 +74,41 @@ async def main() -> None:
|
||||
)
|
||||
trace = TraceLayer()
|
||||
|
||||
compositor = Compositor[PlainPromptType, PlainToolType].from_config(
|
||||
{
|
||||
"layers": [
|
||||
{
|
||||
"name": "base_prompt",
|
||||
"layer": {
|
||||
"import_path": "agenton_collections.layers.plain:PromptLayer",
|
||||
registry = LayerRegistry()
|
||||
registry.register_layer(PromptLayer)
|
||||
compositor = (
|
||||
CompositorBuilder(registry)
|
||||
.add_config(
|
||||
{
|
||||
"layers": [
|
||||
{
|
||||
"name": "base_prompt",
|
||||
"type": "plain.prompt",
|
||||
"config": {
|
||||
"prefix": "Use config dicts for serializable layers.",
|
||||
"suffix": "Before finalizing, make the result easy to scan.",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "extra_prompt",
|
||||
"layer": {
|
||||
"import_path": "agenton_collections.layers.plain:PromptLayer",
|
||||
{
|
||||
"name": "extra_prompt",
|
||||
"type": "plain.prompt",
|
||||
"config": {
|
||||
"prefix": "Use constructed instances for objects, local code, and callables.",
|
||||
},
|
||||
},
|
||||
},
|
||||
CompositorLayerConfig(
|
||||
name="profile",
|
||||
layer=ObjectLayer[AgentProfile](profile),
|
||||
),
|
||||
CompositorLayerConfig(
|
||||
name="profile_prompt",
|
||||
# deps maps dependency field names to layer names only when
|
||||
# they differ.
|
||||
# deps={"profile": "profile"},
|
||||
layer=ProfilePromptLayer(),
|
||||
),
|
||||
CompositorLayerConfig(
|
||||
name="tools",
|
||||
layer=ToolsLayer(tool_entries=(count_words,)),
|
||||
),
|
||||
CompositorLayerConfig(
|
||||
name="dynamic_tools",
|
||||
deps={"object_layer": "profile"},
|
||||
layer=DynamicToolsLayer[AgentProfile](tool_entries=(write_tagline,)),
|
||||
),
|
||||
CompositorLayerConfig(name="trace", layer=trace),
|
||||
]
|
||||
},
|
||||
]
|
||||
}
|
||||
)
|
||||
.add_instance(name="profile", layer=ObjectLayer[AgentProfile](profile))
|
||||
.add_instance(name="profile_prompt", layer=ProfilePromptLayer())
|
||||
.add_instance(name="tools", layer=ToolsLayer(tool_entries=(count_words,)))
|
||||
.add_instance(
|
||||
name="dynamic_tools",
|
||||
deps={"object_layer": "profile"},
|
||||
layer=DynamicToolsLayer[AgentProfile](tool_entries=(write_tagline,)),
|
||||
)
|
||||
.add_instance(name="trace", layer=trace)
|
||||
.build()
|
||||
)
|
||||
|
||||
print("Prompts:")
|
||||
|
||||
@ -12,9 +12,8 @@ from pydantic_ai.messages import BuiltinToolCallPart, ModelMessage, ToolCallPart
|
||||
from pydantic_ai.models.openai import OpenAIChatModel # pyright: ignore[reportDeprecated]
|
||||
from pydantic_ai.models.test import TestModel
|
||||
|
||||
from agenton.compositor import Compositor, CompositorLayerConfig
|
||||
from agenton.layers.types import AllPromptTypes, AllToolTypes, PydanticAIPrompt, PydanticAITool
|
||||
from agenton_collections.layers.plain import ObjectLayer, ToolsLayer
|
||||
from agenton.compositor import CompositorBuilder, LayerRegistry
|
||||
from agenton_collections.layers.plain import ObjectLayer, PromptLayer, ToolsLayer
|
||||
from agenton_collections.layers.pydantic_ai import PydanticAIBridgeLayer
|
||||
from agenton_collections.transformers import PYDANTIC_AI_TRANSFORMERS
|
||||
|
||||
@ -55,40 +54,32 @@ async def main() -> None:
|
||||
tool_entries=(write_tagline,),
|
||||
)
|
||||
|
||||
compositor = Compositor[
|
||||
PydanticAIPrompt[object],
|
||||
PydanticAITool[object],
|
||||
AllPromptTypes,
|
||||
AllToolTypes,
|
||||
].from_config(
|
||||
{
|
||||
"layers": [
|
||||
{
|
||||
"name": "base_prompt",
|
||||
"layer": {
|
||||
"import_path": "agenton_collections.layers.plain:PromptLayer",
|
||||
registry = LayerRegistry()
|
||||
registry.register_layer(PromptLayer)
|
||||
compositor = (
|
||||
CompositorBuilder(registry)
|
||||
.add_config(
|
||||
{
|
||||
"layers": [
|
||||
{
|
||||
"name": "base_prompt",
|
||||
"type": "plain.prompt",
|
||||
"config": {
|
||||
"prefix": "Use the available tools before answering.",
|
||||
"suffix": "Return concise, inspectable output.",
|
||||
},
|
||||
},
|
||||
},
|
||||
CompositorLayerConfig(
|
||||
name="profile",
|
||||
layer=ObjectLayer[AgentProfile](profile),
|
||||
),
|
||||
CompositorLayerConfig(
|
||||
name="plain_tools",
|
||||
layer=ToolsLayer(tool_entries=(count_words,)),
|
||||
),
|
||||
CompositorLayerConfig(
|
||||
name="pydantic_ai_bridge",
|
||||
deps={"object_layer": "profile"},
|
||||
layer=pydantic_ai_bridge,
|
||||
),
|
||||
]
|
||||
},
|
||||
**PYDANTIC_AI_TRANSFORMERS,
|
||||
]
|
||||
}
|
||||
)
|
||||
.add_instance(name="profile", layer=ObjectLayer[AgentProfile](profile))
|
||||
.add_instance(name="plain_tools", layer=ToolsLayer(tool_entries=(count_words,)))
|
||||
.add_instance(
|
||||
name="pydantic_ai_bridge",
|
||||
deps={"object_layer": "profile"},
|
||||
layer=pydantic_ai_bridge,
|
||||
)
|
||||
.build(**PYDANTIC_AI_TRANSFORMERS)
|
||||
)
|
||||
|
||||
async with compositor.enter():
|
||||
|
||||
72
dify-agent/examples/agenton/session_snapshot.py
Normal file
72
dify-agent/examples/agenton/session_snapshot.py
Normal file
@ -0,0 +1,72 @@
|
||||
"""Run with: uv run --project dify-agent python examples/agenton/session_snapshot.py."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass
|
||||
from typing import ClassVar
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from typing_extensions import override
|
||||
|
||||
from agenton.compositor import Compositor
|
||||
from agenton.layers import LayerControl, NoLayerDeps, PlainLayer, PlainPromptType, PlainToolType
|
||||
|
||||
|
||||
class ConnectionState(BaseModel):
|
||||
connection_id: str = "demo-connection"
|
||||
|
||||
model_config = ConfigDict(extra="forbid", validate_assignment=True)
|
||||
|
||||
|
||||
class ConnectionHandle:
|
||||
def __init__(self, connection_id: str) -> None:
|
||||
self.connection_id = connection_id
|
||||
|
||||
|
||||
class ConnectionHandles(BaseModel):
|
||||
connection: ConnectionHandle | None = None
|
||||
|
||||
model_config = ConfigDict(extra="forbid", validate_assignment=True, arbitrary_types_allowed=True)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ConnectionLayer(PlainLayer[NoLayerDeps]):
|
||||
runtime_state_type: ClassVar[type[BaseModel]] = ConnectionState
|
||||
runtime_handles_type: ClassVar[type[BaseModel]] = ConnectionHandles
|
||||
|
||||
@override
|
||||
async def on_context_create(self, control: LayerControl) -> None:
|
||||
assert isinstance(control.runtime_state, ConnectionState)
|
||||
assert isinstance(control.runtime_handles, ConnectionHandles)
|
||||
control.runtime_handles.connection = ConnectionHandle(control.runtime_state.connection_id)
|
||||
|
||||
@override
|
||||
async def on_context_resume(self, control: LayerControl) -> None:
|
||||
assert isinstance(control.runtime_state, ConnectionState)
|
||||
assert isinstance(control.runtime_handles, ConnectionHandles)
|
||||
control.runtime_handles.connection = ConnectionHandle(f"restored:{control.runtime_state.connection_id}")
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
compositor: Compositor[PlainPromptType, PlainToolType] = Compositor(
|
||||
layers=OrderedDict([("connection", ConnectionLayer())])
|
||||
)
|
||||
session = compositor.new_session()
|
||||
async with compositor.enter(session) as active_session:
|
||||
active_session.suspend_on_exit()
|
||||
|
||||
snapshot = compositor.snapshot_session(session)
|
||||
print("Snapshot:", snapshot.model_dump(mode="json"))
|
||||
|
||||
restored = compositor.session_from_snapshot(snapshot)
|
||||
async with compositor.enter(restored):
|
||||
handles = restored.layer("connection").runtime_handles
|
||||
assert isinstance(handles, ConnectionHandles)
|
||||
assert handles.connection is not None
|
||||
print("Rehydrated handle:", handles.connection.connection_id)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@ -16,6 +16,10 @@ layer names as values. Prompt aggregation depends on insertion order: prefix
|
||||
prompts are collected from first to last layer, while suffix prompts are
|
||||
collected in reverse.
|
||||
|
||||
Serializable graph config uses registry type ids rather than import paths.
|
||||
``CompositorBuilder`` resolves config nodes through ``LayerRegistry`` and can
|
||||
mix those nodes with live layer instances for Python objects and callables.
|
||||
|
||||
``Compositor.enter`` enters layers in compositor order and exits them in reverse
|
||||
order through ``AsyncExitStack``. It accepts an optional ``CompositorSession``
|
||||
whose layer controls must match the compositor layer names and order. When
|
||||
@ -30,13 +34,12 @@ returns those wrapped items unchanged.
|
||||
"""
|
||||
|
||||
from collections import OrderedDict
|
||||
from collections.abc import AsyncIterator, Callable, Iterable, Sequence
|
||||
from collections.abc import AsyncIterator, Callable, Iterable, Mapping as MappingABC, Sequence
|
||||
from contextlib import AsyncExitStack, asynccontextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from importlib import import_module
|
||||
from typing import TYPE_CHECKING, Annotated, Any, Generic, Mapping, TypedDict, cast
|
||||
from typing import Any, Generic, Mapping, TypedDict, cast
|
||||
|
||||
from pydantic import AfterValidator, BaseModel, ConfigDict, Field, JsonValue
|
||||
from pydantic import BaseModel, ConfigDict, Field, JsonValue
|
||||
from typing_extensions import Self, TypeVar
|
||||
|
||||
from agenton.layers.base import Layer, LayerControl, LifecycleState
|
||||
@ -48,31 +51,6 @@ LayerPromptT = TypeVar("LayerPromptT", default=AllPromptTypes)
|
||||
LayerToolT = TypeVar("LayerToolT", default=AllToolTypes)
|
||||
|
||||
|
||||
class ImportedLayerConfig(BaseModel):
|
||||
"""Config for constructing one layer from an import path."""
|
||||
|
||||
import_path: str
|
||||
config: Any = None
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
def create_layer(self) -> Layer[Any, Any, Any]:
|
||||
"""Import the target layer class and create it from config."""
|
||||
try:
|
||||
import_module_name, import_target = self.import_path.rsplit(":", 1)
|
||||
except ValueError as e:
|
||||
raise ValueError(
|
||||
f"Invalid import string '{self.import_path}'. "
|
||||
"It should be in the format 'module:ClassName'."
|
||||
) from e
|
||||
|
||||
layer_t = getattr(import_module(import_module_name), import_target)
|
||||
if not isinstance(layer_t, type) or not issubclass(layer_t, Layer):
|
||||
raise TypeError(f"Imported target '{self.import_path}' must be a Layer subclass.")
|
||||
return layer_t.from_config(config=self.config)
|
||||
|
||||
|
||||
LayerSpec = Layer[Any, Any, Any] | ImportedLayerConfig
|
||||
type CompositorTransformer[InputT, OutputT] = Callable[[Sequence[InputT]], Sequence[OutputT]]
|
||||
|
||||
|
||||
@ -98,53 +76,30 @@ def _validate_config_model_input[ModelT: BaseModel](
|
||||
return model_type.model_validate(value)
|
||||
|
||||
|
||||
class CompositorLayerConfig(BaseModel):
|
||||
"""Config entry for one named layer in a compositor.
|
||||
|
||||
``layer`` may be either an already constructed layer instance or an
|
||||
``ImportedLayerConfig``. Direct instances are already initialized, so config
|
||||
for imported layers lives inside ``ImportedLayerConfig`` instead of beside
|
||||
the graph node fields.
|
||||
"""
|
||||
class LayerNodeConfig(BaseModel):
|
||||
"""Serializable config for one registry-backed layer node."""
|
||||
|
||||
name: str
|
||||
type: str
|
||||
config: JsonValue = Field(default_factory=dict)
|
||||
deps: Mapping[str, str] = Field(default_factory=dict)
|
||||
layer: LayerSpec
|
||||
metadata: Mapping[str, JsonValue] = Field(default_factory=dict)
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
def create_layer(self) -> Layer[Any, Any, Any]:
|
||||
"""Create or return the configured layer instance."""
|
||||
if isinstance(self.layer, Layer):
|
||||
return self.layer
|
||||
return self.layer.create_layer()
|
||||
|
||||
|
||||
type CompositorLayerConfigValue = _ConfigModelValue[CompositorLayerConfig]
|
||||
|
||||
|
||||
def _validate_layer_config_input(value: CompositorLayerConfigValue) -> CompositorLayerConfig:
|
||||
return _validate_config_model_input(CompositorLayerConfig, value)
|
||||
|
||||
|
||||
type CompositorLayerConfigInput = Annotated[
|
||||
CompositorLayerConfigValue,
|
||||
AfterValidator(_validate_layer_config_input),
|
||||
]
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
class CompositorConfig(BaseModel):
|
||||
"""Serializable config for constructing a compositor graph.
|
||||
|
||||
``layers`` accepts ready-made ``CompositorLayerConfig`` instances, raw JSON
|
||||
values, or JSON-encoded strings/bytes. After validation, callers always see
|
||||
normalized ``CompositorLayerConfig`` objects.
|
||||
The graph references layer implementations by registry type id. Live Python
|
||||
objects and callables are intentionally excluded; compose those with
|
||||
``CompositorBuilder.add_instance``.
|
||||
"""
|
||||
|
||||
if TYPE_CHECKING:
|
||||
layers: list[CompositorLayerConfig]
|
||||
else:
|
||||
layers: list[CompositorLayerConfigInput]
|
||||
schema_version: int = 1
|
||||
layers: list[LayerNodeConfig]
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
type CompositorConfigValue = _ConfigModelValue[CompositorConfig] | Mapping[str, object]
|
||||
@ -154,24 +109,88 @@ def _validate_compositor_config_input(value: CompositorConfigValue) -> Composito
|
||||
return _validate_config_model_input(CompositorConfig, value)
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class LayerDescriptor:
|
||||
"""Registry descriptor inferred from a layer class."""
|
||||
|
||||
type_id: str
|
||||
layer_type: type[Layer[Any, Any, Any, Any, Any, Any]]
|
||||
config_type: type[BaseModel]
|
||||
runtime_state_type: type[BaseModel]
|
||||
runtime_handles_type: type[BaseModel]
|
||||
|
||||
|
||||
class LayerRegistry:
|
||||
"""Manual registry for config-constructible layer classes.
|
||||
|
||||
Registration infers config and runtime schemas from layer class attributes.
|
||||
A registered layer must have a type id, either declared as ``type_id`` on the
|
||||
class or supplied to ``register_layer``.
|
||||
"""
|
||||
|
||||
__slots__ = ("_descriptors",)
|
||||
|
||||
_descriptors: dict[str, LayerDescriptor]
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._descriptors = {}
|
||||
|
||||
def register_layer(
|
||||
self,
|
||||
layer_type: type[Layer[Any, Any, Any, Any, Any, Any]],
|
||||
*,
|
||||
type_id: str | None = None,
|
||||
) -> None:
|
||||
"""Register ``layer_type`` under its inferred or explicit type id."""
|
||||
resolved_type_id = type_id or layer_type.type_id
|
||||
if resolved_type_id is not None and not isinstance(resolved_type_id, str):
|
||||
raise TypeError(f"Layer type id for '{layer_type.__qualname__}' must be a string.")
|
||||
if resolved_type_id is None or not resolved_type_id:
|
||||
raise ValueError(f"Layer '{layer_type.__qualname__}' must declare a type_id or be registered with one.")
|
||||
if resolved_type_id in self._descriptors:
|
||||
raise ValueError(f"Layer type id '{resolved_type_id}' is already registered.")
|
||||
self._descriptors[resolved_type_id] = LayerDescriptor(
|
||||
type_id=resolved_type_id,
|
||||
layer_type=layer_type,
|
||||
config_type=layer_type.config_type,
|
||||
runtime_state_type=layer_type.runtime_state_type,
|
||||
runtime_handles_type=layer_type.runtime_handles_type,
|
||||
)
|
||||
|
||||
def resolve(self, type_id: str) -> LayerDescriptor:
|
||||
"""Return the descriptor for ``type_id`` or raise ``KeyError``."""
|
||||
try:
|
||||
return self._descriptors[type_id]
|
||||
except KeyError as e:
|
||||
raise KeyError(f"Layer type id '{type_id}' is not registered.") from e
|
||||
|
||||
def descriptors(self) -> Mapping[str, LayerDescriptor]:
|
||||
"""Return registered descriptors keyed by type id."""
|
||||
return dict(self._descriptors)
|
||||
|
||||
|
||||
class CompositorSession:
|
||||
"""External lifecycle session for layer contexts entered by a compositor.
|
||||
|
||||
A session owns one ``LayerControl`` per compositor layer name, preserving
|
||||
compositor order. Broadcast methods are convenience APIs for setting every
|
||||
layer's per-entry exit intent; ``layer`` allows explicit per-layer control
|
||||
when callers need partial suspend/delete behavior. A mixed session with any
|
||||
closed layer cannot be entered again because compositor entry is all-or-none.
|
||||
compositor order. Controls must be created from the matching layer schemas;
|
||||
prefer ``Compositor.new_session`` or ``Compositor.session_from_snapshot`` for
|
||||
public session construction. Broadcast methods are convenience APIs for
|
||||
setting every layer's per-entry exit intent; ``layer`` allows explicit
|
||||
per-layer control when callers need partial suspend/delete behavior. A mixed
|
||||
session with any closed layer cannot be entered again because compositor
|
||||
entry is all-or-none.
|
||||
"""
|
||||
|
||||
__slots__ = ("layer_controls",)
|
||||
|
||||
layer_controls: OrderedDict[str, LayerControl]
|
||||
|
||||
def __init__(self, layer_names: Iterable[str]) -> None:
|
||||
self.layer_controls = OrderedDict(
|
||||
(layer_name, LayerControl()) for layer_name in layer_names
|
||||
)
|
||||
def __init__(self, layer_names: Iterable[str] | Mapping[str, LayerControl]) -> None:
|
||||
if isinstance(layer_names, MappingABC):
|
||||
self.layer_controls = OrderedDict(layer_names.items())
|
||||
return
|
||||
self.layer_controls = OrderedDict((layer_name, LayerControl()) for layer_name in layer_names)
|
||||
|
||||
def suspend_on_exit(self) -> None:
|
||||
"""Request suspend behavior for every layer when this entry exits."""
|
||||
@ -188,6 +207,124 @@ class CompositorSession:
|
||||
return self.layer_controls[name]
|
||||
|
||||
|
||||
class LayerSessionSnapshot(BaseModel):
|
||||
"""Serializable snapshot for one layer control."""
|
||||
|
||||
name: str
|
||||
state: LifecycleState
|
||||
runtime_state: dict[str, JsonValue]
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
class CompositorSessionSnapshot(BaseModel):
|
||||
"""Serializable compositor session snapshot.
|
||||
|
||||
Snapshots include runtime state only. Live runtime handles are intentionally
|
||||
excluded and must be rehydrated by resume hooks using runtime state.
|
||||
"""
|
||||
|
||||
schema_version: int = 1
|
||||
layers: list[LayerSessionSnapshot]
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class _LayerBuildEntry:
|
||||
name: str
|
||||
layer: Layer[Any, Any, Any, Any, Any, Any]
|
||||
deps: Mapping[str, str]
|
||||
|
||||
|
||||
class CompositorBuilder:
|
||||
"""Build compositors from registry config nodes and live instances."""
|
||||
|
||||
__slots__ = ("_registry", "_entries")
|
||||
|
||||
_registry: LayerRegistry
|
||||
_entries: list[_LayerBuildEntry]
|
||||
|
||||
def __init__(self, registry: LayerRegistry) -> None:
|
||||
self._registry = registry
|
||||
self._entries = []
|
||||
|
||||
def add_config(self, config: CompositorConfigValue) -> Self:
|
||||
"""Add all layers from a serializable compositor config."""
|
||||
conf = _validate_compositor_config_input(config)
|
||||
if conf.schema_version != 1:
|
||||
raise ValueError(f"Unsupported compositor config schema_version: {conf.schema_version}.")
|
||||
for layer_conf in conf.layers:
|
||||
self.add_config_layer(
|
||||
name=layer_conf.name,
|
||||
type=layer_conf.type,
|
||||
config=layer_conf.config,
|
||||
deps=layer_conf.deps,
|
||||
)
|
||||
return self
|
||||
|
||||
def add_config_layer(
|
||||
self,
|
||||
*,
|
||||
name: str,
|
||||
type: str,
|
||||
config: object | None = None,
|
||||
deps: Mapping[str, str] | None = None,
|
||||
) -> Self:
|
||||
"""Resolve, validate, and add one registry-backed layer config node."""
|
||||
descriptor = self._registry.resolve(type)
|
||||
raw_config = {} if config is None else config
|
||||
validated_config = descriptor.config_type.model_validate(raw_config)
|
||||
layer = descriptor.layer_type.from_config(cast(Any, validated_config))
|
||||
self.add_instance(name=name, layer=layer, deps=deps)
|
||||
return self
|
||||
|
||||
def add_instance(
|
||||
self,
|
||||
*,
|
||||
name: str,
|
||||
layer: Layer[Any, Any, Any, Any, Any, Any],
|
||||
deps: Mapping[str, str] | None = None,
|
||||
) -> Self:
|
||||
"""Add a live layer instance, useful for Python objects and callables."""
|
||||
self._entries.append(_LayerBuildEntry(name=name, layer=layer, deps=dict(deps or {})))
|
||||
return self
|
||||
|
||||
def build[PromptT, ToolT, LayerPromptT, LayerToolT](
|
||||
self,
|
||||
*,
|
||||
prompt_transformer: CompositorTransformer[LayerPromptT, PromptT] | None = None,
|
||||
tool_transformer: CompositorTransformer[LayerToolT, ToolT] | None = None,
|
||||
) -> "Compositor[PromptT, ToolT, LayerPromptT, LayerToolT]":
|
||||
"""Validate names/dependencies, bind deps, and return a compositor."""
|
||||
layers: OrderedDict[str, Layer[Any, Any, Any, Any, Any, Any]] = OrderedDict()
|
||||
deps_name_mapping: dict[str, Mapping[str, str]] = {}
|
||||
for entry in self._entries:
|
||||
if entry.name in layers:
|
||||
raise ValueError(f"Duplicate layer name '{entry.name}'.")
|
||||
layers[entry.name] = entry.layer
|
||||
deps_name_mapping[entry.name] = entry.deps
|
||||
|
||||
layer_names = set(layers)
|
||||
for layer_name, deps in deps_name_mapping.items():
|
||||
declared_deps = layers[layer_name].dependency_names()
|
||||
unknown_dep_keys = set(deps) - declared_deps
|
||||
if unknown_dep_keys:
|
||||
names = ", ".join(sorted(unknown_dep_keys))
|
||||
raise ValueError(f"Layer '{layer_name}' declares unknown dependency keys: {names}.")
|
||||
missing_targets = set(deps.values()) - layer_names
|
||||
if missing_targets:
|
||||
names = ", ".join(sorted(missing_targets))
|
||||
raise ValueError(f"Layer '{layer_name}' depends on undefined layer names: {names}.")
|
||||
|
||||
return Compositor(
|
||||
layers=layers,
|
||||
deps_name_mapping=deps_name_mapping,
|
||||
prompt_transformer=prompt_transformer,
|
||||
tool_transformer=tool_transformer,
|
||||
)
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class Compositor(Generic[PromptT, ToolT, LayerPromptT, LayerToolT]):
|
||||
"""Framework-neutral ordered layer graph with lifecycle and aggregation.
|
||||
@ -199,7 +336,7 @@ class Compositor(Generic[PromptT, ToolT, LayerPromptT, LayerToolT]):
|
||||
from exposed item types.
|
||||
"""
|
||||
|
||||
layers: OrderedDict[str, Layer[Any, Any, Any]]
|
||||
layers: OrderedDict[str, Layer[Any, Any, Any, Any, Any, Any]]
|
||||
deps_name_mapping: Mapping[str, Mapping[str, str]] = field(default_factory=dict)
|
||||
prompt_transformer: CompositorTransformer[LayerPromptT, PromptT] | None = None
|
||||
tool_transformer: CompositorTransformer[LayerToolT, ToolT] | None = None
|
||||
@ -213,19 +350,12 @@ class Compositor(Generic[PromptT, ToolT, LayerPromptT, LayerToolT]):
|
||||
cls,
|
||||
conf: CompositorConfigValue,
|
||||
*,
|
||||
registry: LayerRegistry,
|
||||
prompt_transformer: CompositorTransformer[LayerPromptT, PromptT] | None = None,
|
||||
tool_transformer: CompositorTransformer[LayerToolT, ToolT] | None = None,
|
||||
) -> Self:
|
||||
"""Create layers from config-like input and bind named dependencies."""
|
||||
conf = _validate_compositor_config_input(conf)
|
||||
layers: OrderedDict[str, Layer[Any, Any, Any]] = OrderedDict()
|
||||
for layer_conf in conf.layers:
|
||||
layers[layer_conf.name] = layer_conf.create_layer()
|
||||
|
||||
deps_name_mapping = {layer_conf.name: layer_conf.deps for layer_conf in conf.layers}
|
||||
return cls(
|
||||
layers=layers,
|
||||
deps_name_mapping=deps_name_mapping,
|
||||
) -> "Compositor[PromptT, ToolT, LayerPromptT, LayerToolT]":
|
||||
"""Create a compositor from registry-backed serializable config."""
|
||||
return CompositorBuilder(registry).add_config(conf).build(
|
||||
prompt_transformer=prompt_transformer,
|
||||
tool_transformer=tool_transformer,
|
||||
)
|
||||
@ -257,7 +387,60 @@ class Compositor(Generic[PromptT, ToolT, LayerPromptT, LayerToolT]):
|
||||
|
||||
def new_session(self) -> CompositorSession:
|
||||
"""Create a fresh lifecycle session matching this compositor's layer order."""
|
||||
return CompositorSession(self.layers)
|
||||
return CompositorSession(
|
||||
OrderedDict((layer_name, layer.new_control()) for layer_name, layer in self.layers.items())
|
||||
)
|
||||
|
||||
def snapshot_session(self, session: CompositorSession) -> CompositorSessionSnapshot:
|
||||
"""Serialize non-active session lifecycle state and runtime state.
|
||||
|
||||
Runtime handles are live Python objects and are intentionally excluded.
|
||||
"""
|
||||
self._validate_session(session)
|
||||
active_layers = [name for name, control in session.layer_controls.items() if control.state is LifecycleState.ACTIVE]
|
||||
if active_layers:
|
||||
names = ", ".join(active_layers)
|
||||
raise RuntimeError(f"Cannot snapshot active compositor session layers: {names}.")
|
||||
return CompositorSessionSnapshot(
|
||||
layers=[
|
||||
LayerSessionSnapshot(
|
||||
name=name,
|
||||
state=control.state,
|
||||
runtime_state=cast(dict[str, JsonValue], control.runtime_state.model_dump(mode="json")),
|
||||
)
|
||||
for name, control in session.layer_controls.items()
|
||||
]
|
||||
)
|
||||
|
||||
def session_from_snapshot(self, snapshot: CompositorSessionSnapshot | JsonValue | str | bytes) -> CompositorSession:
|
||||
"""Restore a session from a snapshot and reinitialize empty handles."""
|
||||
snapshot = _validate_config_model_input(CompositorSessionSnapshot, snapshot)
|
||||
if snapshot.schema_version != 1:
|
||||
raise ValueError(f"Unsupported compositor session snapshot schema_version: {snapshot.schema_version}.")
|
||||
snapshot_layer_names = tuple(layer.name for layer in snapshot.layers)
|
||||
expected_layer_names = tuple(self.layers)
|
||||
if snapshot_layer_names != expected_layer_names:
|
||||
expected = ", ".join(expected_layer_names)
|
||||
actual = ", ".join(snapshot_layer_names)
|
||||
raise ValueError(
|
||||
"CompositorSessionSnapshot layer names must match compositor layers in order. "
|
||||
f"Expected [{expected}], got [{actual}]."
|
||||
)
|
||||
active_layers = [layer.name for layer in snapshot.layers if layer.state is LifecycleState.ACTIVE]
|
||||
if active_layers:
|
||||
names = ", ".join(active_layers)
|
||||
raise ValueError(f"Cannot restore active compositor session layers from snapshot: {names}.")
|
||||
controls = OrderedDict(
|
||||
(
|
||||
layer_snapshot.name,
|
||||
self.layers[layer_snapshot.name].new_control(
|
||||
state=layer_snapshot.state,
|
||||
runtime_state=layer_snapshot.runtime_state,
|
||||
),
|
||||
)
|
||||
for layer_snapshot in snapshot.layers
|
||||
)
|
||||
return CompositorSession(controls)
|
||||
|
||||
@asynccontextmanager
|
||||
async def enter(
|
||||
@ -288,6 +471,18 @@ class Compositor(Generic[PromptT, ToolT, LayerPromptT, LayerToolT]):
|
||||
"CompositorSession layer names must match compositor layers in order. "
|
||||
f"Expected [{expected}], got [{actual}]."
|
||||
)
|
||||
for layer_name, layer in self.layers.items():
|
||||
control = session.layer_controls[layer_name]
|
||||
if not isinstance(control.runtime_state, layer.runtime_state_type):
|
||||
raise TypeError(
|
||||
f"CompositorSession layer '{layer_name}' runtime_state must be "
|
||||
f"{layer.runtime_state_type.__name__}, got {type(control.runtime_state).__name__}."
|
||||
)
|
||||
if not isinstance(control.runtime_handles, layer.runtime_handles_type):
|
||||
raise TypeError(
|
||||
f"CompositorSession layer '{layer_name}' runtime_handles must be "
|
||||
f"{layer.runtime_handles_type.__name__}, got {type(control.runtime_handles).__name__}."
|
||||
)
|
||||
|
||||
def _ensure_session_can_enter(self, session: CompositorSession) -> None:
|
||||
"""Reject active or closed layer controls before any layer side effects."""
|
||||
@ -330,14 +525,15 @@ class Compositor(Generic[PromptT, ToolT, LayerPromptT, LayerToolT]):
|
||||
|
||||
__all__ = [
|
||||
"Compositor",
|
||||
"CompositorBuilder",
|
||||
"CompositorConfig",
|
||||
"CompositorConfigValue",
|
||||
"CompositorLayerConfigInput",
|
||||
"CompositorSessionSnapshot",
|
||||
"CompositorSession",
|
||||
"CompositorTransformer",
|
||||
"CompositorTransformerKwargs",
|
||||
"CompositorLayerConfig",
|
||||
"CompositorLayerConfigValue",
|
||||
"ImportedLayerConfig",
|
||||
"LayerSpec",
|
||||
"LayerDescriptor",
|
||||
"LayerNodeConfig",
|
||||
"LayerRegistry",
|
||||
"LayerSessionSnapshot",
|
||||
]
|
||||
|
||||
@ -5,7 +5,17 @@
|
||||
families while keeping concrete reusable layers in ``agenton_collections``.
|
||||
"""
|
||||
|
||||
from agenton.layers.base import ExitIntent, Layer, LayerControl, LayerDeps, LifecycleState, NoLayerDeps
|
||||
from agenton.layers.base import (
|
||||
EmptyLayerConfig,
|
||||
EmptyRuntimeHandles,
|
||||
EmptyRuntimeState,
|
||||
ExitIntent,
|
||||
Layer,
|
||||
LayerControl,
|
||||
LayerDeps,
|
||||
LifecycleState,
|
||||
NoLayerDeps,
|
||||
)
|
||||
from agenton.layers.types import (
|
||||
AllPromptTypes,
|
||||
AllToolTypes,
|
||||
@ -29,6 +39,9 @@ __all__ = [
|
||||
"LayerControl",
|
||||
"LifecycleState",
|
||||
"ExitIntent",
|
||||
"EmptyLayerConfig",
|
||||
"EmptyRuntimeState",
|
||||
"EmptyRuntimeHandles",
|
||||
"NoLayerDeps",
|
||||
"PlainLayer",
|
||||
"PlainPrompt",
|
||||
|
||||
@ -1,11 +1,13 @@
|
||||
"""Core layer abstractions and typed dependency binding.
|
||||
|
||||
Layers declare their dependency shape with ``Layer[DepsT, PromptT, ToolT]``.
|
||||
Layers declare their dependency shape with ``Layer[DepsT, PromptT, ToolT, ...]``.
|
||||
``DepsT`` must be a ``LayerDeps`` subclass whose annotated members are concrete
|
||||
``Layer`` subclasses or modern optional dependencies such as ``SomeLayer |
|
||||
None``. The base class infers ``deps_type`` from the generic base when possible,
|
||||
while still allowing subclasses to set ``deps_type`` explicitly for unusual
|
||||
inheritance patterns.
|
||||
None``. The optional trailing generic slots declare Pydantic schemas for config,
|
||||
serializable runtime state, and live runtime handles. The base class infers
|
||||
``deps_type`` and schema class attributes from the generic base when possible,
|
||||
while still allowing subclasses to set them explicitly for unusual inheritance
|
||||
patterns.
|
||||
|
||||
``Layer.bind_deps`` is the mutation point for dependency state. Layer
|
||||
implementations should treat ``self.deps`` as unavailable until a compositor or
|
||||
@ -17,9 +19,10 @@ machine and per-session runtime owner. A fresh control starts in
|
||||
while active or closed controls are rejected to prevent ambiguous nested or
|
||||
post-delete reuse. Exit behavior is selected per entry with ``ExitIntent`` and
|
||||
resets to delete on every successful enter. Layer instances are shared graph and
|
||||
capability definitions, so session-local ids, handles, clients, and other
|
||||
runtime values generated by lifecycle hooks belong in
|
||||
``LayerControl.runtime_state`` rather than on ``self``.
|
||||
capability definitions, so session-local serializable ids, checkpoints, and
|
||||
other snapshot data belong in ``LayerControl.runtime_state``; live clients,
|
||||
connections, and process handles belong in ``LayerControl.runtime_handles``.
|
||||
Neither category should be stored on ``self`` when it is session-local.
|
||||
|
||||
``Layer`` is framework-neutral over prompt and tool item types. The native
|
||||
``prefix_prompts``, ``suffix_prompts``, and ``tools`` properties are the layer
|
||||
@ -34,9 +37,18 @@ from contextlib import AbstractAsyncContextManager, asynccontextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from enum import StrEnum
|
||||
from types import UnionType
|
||||
from typing import Any, Mapping, Sequence, Union, cast, get_args, get_origin, get_type_hints
|
||||
from typing import Any, ClassVar, Generic, Mapping, Sequence, Union, cast, get_args, get_origin, get_type_hints
|
||||
|
||||
from typing_extensions import Self
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from typing_extensions import Self, TypeVar
|
||||
|
||||
|
||||
_DepsT = TypeVar("_DepsT", bound="LayerDeps")
|
||||
_PromptT = TypeVar("_PromptT")
|
||||
_ToolT = TypeVar("_ToolT")
|
||||
_ConfigT = TypeVar("_ConfigT", bound=BaseModel, default="EmptyLayerConfig")
|
||||
_RuntimeStateT = TypeVar("_RuntimeStateT", bound=BaseModel, default="EmptyRuntimeState")
|
||||
_RuntimeHandlesT = TypeVar("_RuntimeHandlesT", bound=BaseModel, default="EmptyRuntimeHandles")
|
||||
|
||||
|
||||
class LayerDeps:
|
||||
@ -47,7 +59,7 @@ class LayerDeps:
|
||||
are always assigned as attributes; missing optional values become ``None``.
|
||||
"""
|
||||
|
||||
def __init__(self, **deps: "Layer[Any, Any, Any] | None") -> None:
|
||||
def __init__(self, **deps: "Layer[Any, Any, Any, Any, Any, Any] | None") -> None:
|
||||
dep_specs = _get_dep_specs(type(self))
|
||||
missing_names = {name for name, spec in dep_specs.items() if not spec.optional} - deps.keys()
|
||||
if missing_names:
|
||||
@ -79,6 +91,28 @@ class NoLayerDeps(LayerDeps):
|
||||
"""Dependency container for layers that do not require other layers."""
|
||||
|
||||
|
||||
class EmptyLayerConfig(BaseModel):
|
||||
"""Default serializable config schema for layers without config."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
class EmptyRuntimeState(BaseModel):
|
||||
"""Default serializable per-session runtime state schema."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid", validate_assignment=True)
|
||||
|
||||
|
||||
class EmptyRuntimeHandles(BaseModel):
|
||||
"""Default live per-session runtime handle schema.
|
||||
|
||||
Handles may contain arbitrary Python objects and are intentionally excluded
|
||||
from session snapshots.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(extra="forbid", validate_assignment=True, arbitrary_types_allowed=True)
|
||||
|
||||
|
||||
class LifecycleState(StrEnum):
|
||||
"""Externally observable lifecycle state for a layer control."""
|
||||
|
||||
@ -96,26 +130,31 @@ class ExitIntent(StrEnum):
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class LayerControl:
|
||||
class LayerControl(Generic[_RuntimeStateT, _RuntimeHandlesT]):
|
||||
"""Stateful control slot passed into a layer entry context.
|
||||
|
||||
``Layer.enter`` requires the caller to provide this object. The control owns
|
||||
the layer lifecycle state, the current entry's exit intent, and arbitrary
|
||||
per-session runtime state. Call ``suspend_on_exit`` before leaving the
|
||||
per-session runtime state and live handles. Call ``suspend_on_exit`` before leaving the
|
||||
context to make a later entry resume; call ``delete_on_exit`` or do nothing
|
||||
for the default delete behavior. Store session-local ids and resource
|
||||
handles in ``runtime_state`` so concurrent or later sessions do not share
|
||||
mutable runtime data through the layer instance.
|
||||
for the default delete behavior. Store session-local serializable ids,
|
||||
checkpoints, and other snapshot data in ``runtime_state``. Store live
|
||||
clients, connections, process handles, and other non-serializable objects in
|
||||
``runtime_handles``. Do not put either kind of session-local data on the
|
||||
shared layer instance.
|
||||
|
||||
``runtime_state`` intentionally persists after suspend and delete. Suspend,
|
||||
resume, and delete hooks can inspect the same values created on entry, and
|
||||
callers may inspect closed-session diagnostics after exit. Reuse is still
|
||||
governed by ``state``: a closed control cannot be entered again.
|
||||
governed by ``state``: a closed control cannot be entered again. Runtime
|
||||
handles are not serialized in snapshots and should be rehydrated from
|
||||
runtime state in resume hooks.
|
||||
"""
|
||||
|
||||
state: LifecycleState = LifecycleState.NEW
|
||||
exit_intent: ExitIntent = ExitIntent.DELETE
|
||||
runtime_state: dict[str, object] = field(default_factory=dict)
|
||||
runtime_state: _RuntimeStateT = field(default_factory=lambda: cast(_RuntimeStateT, EmptyRuntimeState()))
|
||||
runtime_handles: _RuntimeHandlesT = field(default_factory=lambda: cast(_RuntimeHandlesT, EmptyRuntimeHandles()))
|
||||
|
||||
def suspend_on_exit(self) -> None:
|
||||
"""Request suspend behavior when the current layer entry exits."""
|
||||
@ -130,11 +169,14 @@ class LayerControl:
|
||||
class LayerDepSpec:
|
||||
"""Runtime dependency specification derived from a deps annotation."""
|
||||
|
||||
layer_type: type["Layer[Any, Any, Any]"]
|
||||
layer_type: type["Layer[Any, Any, Any, Any, Any, Any]"]
|
||||
optional: bool = False
|
||||
|
||||
|
||||
class Layer[DepsT: LayerDeps, PromptT, ToolT](ABC):
|
||||
class Layer(
|
||||
ABC,
|
||||
Generic[_DepsT, _PromptT, _ToolT, _ConfigT, _RuntimeStateT, _RuntimeHandlesT],
|
||||
):
|
||||
"""Framework-neutral base class for prompt/tool layers.
|
||||
|
||||
Subclasses expose optional prompt fragments and tools through typed
|
||||
@ -147,15 +189,20 @@ class Layer[DepsT: LayerDeps, PromptT, ToolT](ABC):
|
||||
extra runtime resources.
|
||||
"""
|
||||
|
||||
deps_type: type[DepsT]
|
||||
deps: DepsT
|
||||
deps_type: type[_DepsT]
|
||||
deps: _DepsT
|
||||
type_id: ClassVar[str | None] = None
|
||||
config_type: ClassVar[type[BaseModel]] = EmptyLayerConfig
|
||||
runtime_state_type: ClassVar[type[BaseModel]] = EmptyRuntimeState
|
||||
runtime_handles_type: ClassVar[type[BaseModel]] = EmptyRuntimeHandles
|
||||
|
||||
def __init_subclass__(cls) -> None:
|
||||
super().__init_subclass__()
|
||||
is_generic_template = _is_generic_layer_template(cls)
|
||||
deps_type = cls.__dict__.get("deps_type")
|
||||
if deps_type is None:
|
||||
deps_type = _infer_deps_type(cls) or getattr(cls, "deps_type", None)
|
||||
if deps_type is None and _is_generic_layer_template(cls):
|
||||
if deps_type is None and is_generic_template:
|
||||
return
|
||||
if deps_type is not None:
|
||||
cls.deps_type = deps_type # pyright: ignore[reportAttributeAccessIssue]
|
||||
@ -164,25 +211,63 @@ class Layer[DepsT: LayerDeps, PromptT, ToolT](ABC):
|
||||
if not isinstance(deps_type, type) or not issubclass(deps_type, LayerDeps):
|
||||
raise TypeError(f"{cls.__name__}.deps_type must be a LayerDeps subclass.")
|
||||
_get_dep_specs(deps_type)
|
||||
_init_schema_type(cls, "config_type", _infer_schema_type(cls, 3, "config_type"), EmptyLayerConfig)
|
||||
_init_schema_type(
|
||||
cls,
|
||||
"runtime_state_type",
|
||||
_infer_schema_type(cls, 4, "runtime_state_type"),
|
||||
EmptyRuntimeState,
|
||||
)
|
||||
_init_schema_type(
|
||||
cls,
|
||||
"runtime_handles_type",
|
||||
_infer_schema_type(cls, 5, "runtime_handles_type"),
|
||||
EmptyRuntimeHandles,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls: type[Self], config: Any) -> Self:
|
||||
"""Create a layer from serialized config.
|
||||
def from_config(cls: type[Self], config: _ConfigT) -> Self:
|
||||
"""Create a layer from schema-validated serialized config.
|
||||
|
||||
Layers are not config-constructible by default. Subclasses that accept
|
||||
config should override this method and validate dynamic input before
|
||||
constructing the layer.
|
||||
Registries/builders validate raw config with ``config_type`` before
|
||||
calling this method. Layers are not config-constructible by default.
|
||||
Subclasses that accept config should override this method and consume
|
||||
the typed Pydantic model for their schema.
|
||||
"""
|
||||
raise TypeError(f"{cls.__name__} cannot be created from config.")
|
||||
|
||||
def bind_deps(self, deps: Mapping[str, "Layer[Any, Any, Any] | None"]) -> None:
|
||||
@classmethod
|
||||
def dependency_names(cls) -> frozenset[str]:
|
||||
"""Return dependency field names declared by this layer's deps schema."""
|
||||
return frozenset(_get_dep_specs(cls.deps_type))
|
||||
|
||||
def new_control(
|
||||
self,
|
||||
*,
|
||||
state: LifecycleState = LifecycleState.NEW,
|
||||
runtime_state: object | None = None,
|
||||
) -> LayerControl[_RuntimeStateT, _RuntimeHandlesT]:
|
||||
"""Create a schema-validated per-session control for this layer.
|
||||
|
||||
``runtime_state`` is validated through ``runtime_state_type`` and live
|
||||
handles are always initialized empty through ``runtime_handles_type``.
|
||||
"""
|
||||
raw_runtime_state = {} if runtime_state is None else runtime_state
|
||||
return LayerControl(
|
||||
state=state,
|
||||
exit_intent=ExitIntent.DELETE,
|
||||
runtime_state=cast(_RuntimeStateT, self.runtime_state_type.model_validate(raw_runtime_state)),
|
||||
runtime_handles=cast(_RuntimeHandlesT, self.runtime_handles_type.model_validate({})),
|
||||
)
|
||||
|
||||
def bind_deps(self, deps: Mapping[str, "Layer[Any, Any, Any, Any, Any, Any] | None"]) -> None:
|
||||
"""Bind this layer's declared dependencies from a name-to-layer mapping.
|
||||
|
||||
The mapping may include more layers than the declared dependency fields.
|
||||
Only names declared by ``deps_type`` are selected and validated. Missing
|
||||
optional deps are bound as ``None``.
|
||||
"""
|
||||
resolved_deps: dict[str, Layer[Any, Any, Any] | None] = {}
|
||||
resolved_deps: dict[str, Layer[Any, Any, Any, Any, Any, Any] | None] = {}
|
||||
for name, spec in _get_dep_specs(self.deps_type).items():
|
||||
if name not in deps:
|
||||
if spec.optional:
|
||||
@ -194,7 +279,7 @@ class Layer[DepsT: LayerDeps, PromptT, ToolT](ABC):
|
||||
resolved_deps[name] = deps[name]
|
||||
self.deps = self.deps_type(**resolved_deps)
|
||||
|
||||
def enter(self, control: LayerControl) -> AbstractAsyncContextManager[None]:
|
||||
def enter(self, control: LayerControl[_RuntimeStateT, _RuntimeHandlesT]) -> AbstractAsyncContextManager[None]:
|
||||
"""Return the layer's async entry context manager.
|
||||
|
||||
``control`` is the lifecycle control slot for this entry. Subclasses can
|
||||
@ -204,7 +289,7 @@ class Layer[DepsT: LayerDeps, PromptT, ToolT](ABC):
|
||||
return self.lifecycle_enter(control)
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifecycle_enter(self, control: LayerControl) -> AsyncIterator[None]:
|
||||
async def lifecycle_enter(self, control: LayerControl[_RuntimeStateT, _RuntimeHandlesT]) -> AsyncIterator[None]:
|
||||
"""Run the default explicit lifecycle state machine for one entry."""
|
||||
if control.state is LifecycleState.NEW:
|
||||
control.exit_intent = ExitIntent.DELETE
|
||||
@ -233,37 +318,37 @@ class Layer[DepsT: LayerDeps, PromptT, ToolT](ABC):
|
||||
await self.on_context_delete(control)
|
||||
control.state = LifecycleState.CLOSED
|
||||
|
||||
async def on_context_create(self, control: LayerControl) -> None:
|
||||
async def on_context_create(self, control: LayerControl[_RuntimeStateT, _RuntimeHandlesT]) -> None:
|
||||
"""Run when the layer context is entered from ``LifecycleState.NEW``."""
|
||||
|
||||
async def on_context_delete(self, control: LayerControl) -> None:
|
||||
async def on_context_delete(self, control: LayerControl[_RuntimeStateT, _RuntimeHandlesT]) -> None:
|
||||
"""Run when the layer context exits with ``ExitIntent.DELETE``."""
|
||||
|
||||
async def on_context_suspend(self, control: LayerControl) -> None:
|
||||
async def on_context_suspend(self, control: LayerControl[_RuntimeStateT, _RuntimeHandlesT]) -> None:
|
||||
"""Run when the layer context exits with ``ExitIntent.SUSPEND``."""
|
||||
|
||||
async def on_context_resume(self, control: LayerControl) -> None:
|
||||
async def on_context_resume(self, control: LayerControl[_RuntimeStateT, _RuntimeHandlesT]) -> None:
|
||||
"""Run when the layer context enters from ``LifecycleState.SUSPENDED``."""
|
||||
|
||||
@property
|
||||
def prefix_prompts(self) -> Sequence[PromptT]:
|
||||
def prefix_prompts(self) -> Sequence[_PromptT]:
|
||||
return []
|
||||
|
||||
@property
|
||||
def suffix_prompts(self) -> Sequence[PromptT]:
|
||||
def suffix_prompts(self) -> Sequence[_PromptT]:
|
||||
return []
|
||||
|
||||
@property
|
||||
def tools(self) -> Sequence[ToolT]:
|
||||
def tools(self) -> Sequence[_ToolT]:
|
||||
return []
|
||||
|
||||
@abstractmethod
|
||||
def wrap_prompt(self, prompt: PromptT) -> object:
|
||||
def wrap_prompt(self, prompt: _PromptT) -> object:
|
||||
"""Wrap a native prompt item for compositor aggregation."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def wrap_tool(self, tool: ToolT) -> object:
|
||||
def wrap_tool(self, tool: _ToolT) -> object:
|
||||
"""Wrap a native tool item for compositor aggregation."""
|
||||
raise NotImplementedError
|
||||
|
||||
@ -297,45 +382,107 @@ def _as_dep_spec(annotation: object) -> LayerDepSpec | None:
|
||||
return LayerDepSpec(layer_type=layer_type)
|
||||
|
||||
|
||||
def _as_layer_type(annotation: object) -> type[Layer[Any, Any, Any]] | None:
|
||||
def _as_layer_type(annotation: object) -> type[Layer[Any, Any, Any, Any, Any, Any]] | None:
|
||||
runtime_type = get_origin(annotation) or annotation
|
||||
if isinstance(runtime_type, type) and issubclass(runtime_type, Layer):
|
||||
return cast(type[Layer[Any, Any, Any]], runtime_type)
|
||||
return cast(type[Layer[Any, Any, Any, Any, Any, Any]], runtime_type)
|
||||
return None
|
||||
|
||||
|
||||
def _infer_deps_type(layer_type: type[Layer[Any, Any, Any]]) -> type[LayerDeps] | None:
|
||||
return _infer_deps_type_from_bases(layer_type, {})
|
||||
def _infer_deps_type(layer_type: type[Layer[Any, Any, Any, Any, Any, Any]]) -> type[LayerDeps] | None:
|
||||
inferred = _infer_layer_generic_arg(layer_type, 0, {})
|
||||
if inferred is None:
|
||||
return None
|
||||
return _as_deps_type(inferred)
|
||||
|
||||
|
||||
def _infer_deps_type_from_bases(
|
||||
layer_type: type[Layer[Any, Any, Any]],
|
||||
def _infer_schema_type(
|
||||
layer_type: type[Layer[Any, Any, Any, Any, Any, Any]],
|
||||
index: int,
|
||||
attr_name: str,
|
||||
) -> type[BaseModel] | None:
|
||||
inferred = _infer_schema_generic_arg(layer_type, attr_name, {}) or _infer_layer_generic_arg(layer_type, index, {})
|
||||
if inferred is None:
|
||||
return None
|
||||
schema_type = _as_model_type(inferred)
|
||||
if schema_type is None:
|
||||
raise TypeError(f"{layer_type.__name__}.{attr_name} must be a Pydantic BaseModel subclass.")
|
||||
return schema_type
|
||||
|
||||
|
||||
def _infer_schema_generic_arg(
|
||||
layer_type: type[Layer[Any, Any, Any, Any, Any, Any]],
|
||||
attr_name: str,
|
||||
substitutions: Mapping[object, object],
|
||||
) -> type[LayerDeps] | None:
|
||||
"""Infer the concrete deps container through generic Layer inheritance.
|
||||
) -> object | None:
|
||||
"""Infer schema type arguments exposed by typed layer family bases."""
|
||||
expected_names = {
|
||||
"config_type": {"ConfigT", "_ConfigT"},
|
||||
"runtime_state_type": {"RuntimeStateT", "_RuntimeStateT"},
|
||||
"runtime_handles_type": {"RuntimeHandlesT", "_RuntimeHandlesT"},
|
||||
}[attr_name]
|
||||
for base in getattr(layer_type, "__orig_bases__", ()):
|
||||
origin = get_origin(base) or base
|
||||
args = tuple(_substitute_type(arg, substitutions) for arg in get_args(base))
|
||||
if not isinstance(origin, type) or not issubclass(origin, Layer):
|
||||
continue
|
||||
|
||||
params = _generic_params(origin)
|
||||
for param, arg in zip(params, args):
|
||||
if getattr(param, "__name__", None) in expected_names:
|
||||
return arg
|
||||
|
||||
next_substitutions = dict(substitutions)
|
||||
next_substitutions.update(_generic_arg_substitutions(origin, args))
|
||||
inferred = _infer_schema_generic_arg(origin, attr_name, next_substitutions)
|
||||
if inferred is not None:
|
||||
return inferred
|
||||
return None
|
||||
|
||||
|
||||
def _infer_layer_generic_arg(
|
||||
layer_type: type[Layer[Any, Any, Any, Any, Any, Any]],
|
||||
index: int,
|
||||
substitutions: Mapping[object, object],
|
||||
) -> object | None:
|
||||
"""Infer one concrete ``Layer`` generic argument through inheritance.
|
||||
|
||||
This walks through intermediate generic base classes so subclasses can omit
|
||||
an explicit ``deps_type`` in common cases such as ``class X(Base[YDeps])``.
|
||||
explicit class attributes in common cases such as ``class X(Base[YDeps])``.
|
||||
"""
|
||||
for base in getattr(layer_type, "__orig_bases__", ()):
|
||||
origin = get_origin(base) or base
|
||||
args = tuple(_substitute_type(arg, substitutions) for arg in get_args(base))
|
||||
if origin is Layer:
|
||||
if not args:
|
||||
if len(args) <= index:
|
||||
continue
|
||||
return _as_deps_type(args[0])
|
||||
return args[index]
|
||||
|
||||
if not isinstance(origin, type) or not issubclass(origin, Layer):
|
||||
continue
|
||||
|
||||
next_substitutions = dict(substitutions)
|
||||
next_substitutions.update(_generic_arg_substitutions(origin, args))
|
||||
inferred = _infer_deps_type_from_bases(origin, next_substitutions)
|
||||
inferred = _infer_layer_generic_arg(origin, index, next_substitutions)
|
||||
if inferred is not None:
|
||||
return inferred
|
||||
return None
|
||||
|
||||
|
||||
def _init_schema_type(
|
||||
layer_type: type[Layer[Any, Any, Any, Any, Any, Any]],
|
||||
attr_name: str,
|
||||
inferred_schema_type: type[BaseModel] | None,
|
||||
default_schema_type: type[BaseModel],
|
||||
) -> None:
|
||||
schema_type = layer_type.__dict__.get(attr_name)
|
||||
if schema_type is None:
|
||||
schema_type = inferred_schema_type or getattr(layer_type, attr_name, default_schema_type)
|
||||
setattr(layer_type, attr_name, schema_type)
|
||||
if not isinstance(schema_type, type) or not issubclass(schema_type, BaseModel):
|
||||
raise TypeError(f"{layer_type.__name__}.{attr_name} must be a Pydantic BaseModel subclass.")
|
||||
|
||||
|
||||
def _substitute_type(value: object, substitutions: Mapping[object, object]) -> object:
|
||||
if value in substitutions:
|
||||
return substitutions[value]
|
||||
@ -359,10 +506,15 @@ def _substitute_type(value: object, substitutions: Mapping[object, object]) -> o
|
||||
|
||||
|
||||
def _generic_arg_substitutions(origin: type[Any], args: Sequence[object]) -> dict[object, object]:
|
||||
params = _generic_params(origin)
|
||||
return dict(zip(params, args))
|
||||
|
||||
|
||||
def _generic_params(origin: type[Any]) -> Sequence[object]:
|
||||
params = getattr(origin, "__type_params__", ())
|
||||
if not params:
|
||||
params = getattr(origin, "__parameters__", ())
|
||||
return dict(zip(params, args))
|
||||
return params
|
||||
|
||||
|
||||
def _as_deps_type(value: object) -> type[LayerDeps] | None:
|
||||
@ -372,7 +524,14 @@ def _as_deps_type(value: object) -> type[LayerDeps] | None:
|
||||
return None
|
||||
|
||||
|
||||
def _is_generic_layer_template(layer_type: type[Layer[Any, Any, Any]]) -> bool:
|
||||
def _as_model_type(value: object) -> type[BaseModel] | None:
|
||||
runtime_type = get_origin(value) or value
|
||||
if isinstance(runtime_type, type) and issubclass(runtime_type, BaseModel):
|
||||
return runtime_type
|
||||
return None
|
||||
|
||||
|
||||
def _is_generic_layer_template(layer_type: type[Layer[Any, Any, Any, Any, Any, Any]]) -> bool:
|
||||
return bool(getattr(layer_type, "__type_params__", ())) or bool(
|
||||
getattr(layer_type, "__parameters__", ())
|
||||
)
|
||||
|
||||
@ -2,7 +2,10 @@
|
||||
|
||||
``Layer`` itself is framework-neutral. This module defines typed layer families
|
||||
that bind its prompt/tool generic slots to concrete contracts, such as ordinary
|
||||
string prompts with plain callable tools or pydantic-ai prompt/tool shapes.
|
||||
string prompts with plain callable tools or pydantic-ai prompt/tool shapes. The
|
||||
families keep the trailing schema generic slots open so concrete layers can have
|
||||
``config_type``, ``runtime_state_type``, and ``runtime_handles_type`` inferred
|
||||
from type arguments instead of repeated class attributes.
|
||||
Tagged aggregate aliases cover code paths that can accept any supported
|
||||
prompt/tool family without changing the plain and pydantic-ai layer contracts.
|
||||
Pydantic-ai names are imported for static analysis only, so ``agenton`` can be
|
||||
@ -14,15 +17,17 @@ from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
from typing import TYPE_CHECKING, Any, Generic, Literal
|
||||
|
||||
from typing_extensions import final, override
|
||||
from typing_extensions import TypeVar, final, override
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic_ai import Tool
|
||||
from pydantic_ai.tools import SystemPromptFunc
|
||||
|
||||
from agenton.layers.base import Layer, LayerDeps
|
||||
from pydantic import BaseModel
|
||||
|
||||
from agenton.layers.base import EmptyLayerConfig, EmptyRuntimeHandles, EmptyRuntimeState, Layer, LayerDeps
|
||||
|
||||
type PlainPrompt = str
|
||||
type PlainTool = Callable[..., Any]
|
||||
@ -68,7 +73,17 @@ type AllPromptTypes = PlainPromptType | PydanticAIPromptType[Any]
|
||||
type AllToolTypes = PlainToolType | PydanticAIToolType[Any]
|
||||
|
||||
|
||||
class PlainLayer[DepsT: LayerDeps](Layer[DepsT, PlainPrompt, PlainTool]):
|
||||
_DepsT = TypeVar("_DepsT", bound=LayerDeps)
|
||||
_ConfigT = TypeVar("_ConfigT", bound=BaseModel, default=EmptyLayerConfig)
|
||||
_RuntimeStateT = TypeVar("_RuntimeStateT", bound=BaseModel, default=EmptyRuntimeState)
|
||||
_RuntimeHandlesT = TypeVar("_RuntimeHandlesT", bound=BaseModel, default=EmptyRuntimeHandles)
|
||||
_AgentDepsT = TypeVar("_AgentDepsT")
|
||||
|
||||
|
||||
class PlainLayer(
|
||||
Generic[_DepsT, _ConfigT, _RuntimeStateT, _RuntimeHandlesT],
|
||||
Layer[_DepsT, PlainPrompt, PlainTool, _ConfigT, _RuntimeStateT, _RuntimeHandlesT],
|
||||
):
|
||||
"""Layer base for ordinary string prompts and plain-callable tools."""
|
||||
|
||||
@final
|
||||
@ -82,8 +97,16 @@ class PlainLayer[DepsT: LayerDeps](Layer[DepsT, PlainPrompt, PlainTool]):
|
||||
return PlainToolType(tool)
|
||||
|
||||
|
||||
class PydanticAILayer[DepsT: LayerDeps, AgentDepsT](
|
||||
Layer[DepsT, PydanticAIPrompt[AgentDepsT], PydanticAITool[AgentDepsT]]
|
||||
class PydanticAILayer(
|
||||
Generic[_DepsT, _AgentDepsT, _ConfigT, _RuntimeStateT, _RuntimeHandlesT],
|
||||
Layer[
|
||||
_DepsT,
|
||||
PydanticAIPrompt[_AgentDepsT],
|
||||
PydanticAITool[_AgentDepsT],
|
||||
_ConfigT,
|
||||
_RuntimeStateT,
|
||||
_RuntimeHandlesT,
|
||||
],
|
||||
):
|
||||
"""Layer base for pydantic-ai prompt and tool adapters."""
|
||||
|
||||
@ -91,13 +114,13 @@ class PydanticAILayer[DepsT: LayerDeps, AgentDepsT](
|
||||
@override
|
||||
def wrap_prompt(
|
||||
self,
|
||||
prompt: PydanticAIPrompt[AgentDepsT],
|
||||
) -> PydanticAIPromptType[AgentDepsT]:
|
||||
prompt: PydanticAIPrompt[_AgentDepsT],
|
||||
) -> PydanticAIPromptType[_AgentDepsT]:
|
||||
return PydanticAIPromptType(prompt)
|
||||
|
||||
@final
|
||||
@override
|
||||
def wrap_tool(self, tool: PydanticAITool[AgentDepsT]) -> PydanticAIToolType[AgentDepsT]:
|
||||
def wrap_tool(self, tool: PydanticAITool[_AgentDepsT]) -> PydanticAIToolType[_AgentDepsT]:
|
||||
return PydanticAIToolType(tool)
|
||||
|
||||
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
"""Reusable collection layers for the plain layer family."""
|
||||
|
||||
from agenton_collections.layers.plain.basic import ObjectLayer, PromptLayer, ToolsLayer
|
||||
from agenton_collections.layers.plain.basic import ObjectLayer, PromptLayer, PromptLayerConfig, ToolsLayer
|
||||
from agenton_collections.layers.plain.dynamic_tools import (
|
||||
DynamicToolsLayer,
|
||||
DynamicToolsLayerDeps,
|
||||
@ -12,6 +12,7 @@ __all__ = [
|
||||
"DynamicToolsLayerDeps",
|
||||
"ObjectLayer",
|
||||
"PromptLayer",
|
||||
"PromptLayerConfig",
|
||||
"ToolsLayer",
|
||||
"with_object",
|
||||
]
|
||||
|
||||
@ -10,30 +10,46 @@ from collections.abc import Callable, Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from pydantic import TypeAdapter
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from agenton.layers.base import NoLayerDeps
|
||||
from agenton.layers.types import PlainLayer
|
||||
|
||||
|
||||
class PromptLayerConfig(BaseModel):
|
||||
"""Serializable config schema for ``PromptLayer``."""
|
||||
|
||||
prefix: list[str] | str = Field(default_factory=list)
|
||||
suffix: list[str] | str = Field(default_factory=list)
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
@dataclass
|
||||
class ObjectLayer[ObjectT](PlainLayer[NoLayerDeps]):
|
||||
"""Layer that stores one typed object for downstream dependencies."""
|
||||
"""Layer that stores one typed object for downstream dependencies.
|
||||
|
||||
Object layers are instance-only because arbitrary Python objects are not
|
||||
serializable graph config. Add them with ``CompositorBuilder.add_instance``.
|
||||
"""
|
||||
|
||||
value: ObjectT
|
||||
|
||||
|
||||
@dataclass
|
||||
class PromptLayer(PlainLayer[NoLayerDeps]):
|
||||
class PromptLayer(PlainLayer[NoLayerDeps, PromptLayerConfig]):
|
||||
"""Layer that contributes configured prefix and suffix prompt fragments."""
|
||||
|
||||
type_id = "plain.prompt"
|
||||
|
||||
prefix: list[str] | str = field(default_factory=list)
|
||||
suffix: list[str] | str = field(default_factory=list)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Any):
|
||||
"""Validate prompt config against this dataclass."""
|
||||
return _PROMPT_LAYER_ADAPTER.validate_python(config)
|
||||
def from_config(cls, config: BaseModel):
|
||||
"""Create a prompt layer from validated prompt config."""
|
||||
validated_config = PromptLayerConfig.model_validate(config)
|
||||
return cls(prefix=validated_config.prefix, suffix=validated_config.suffix)
|
||||
|
||||
@property
|
||||
def prefix_prompts(self) -> list[str]:
|
||||
@ -50,7 +66,11 @@ class PromptLayer(PlainLayer[NoLayerDeps]):
|
||||
|
||||
@dataclass
|
||||
class ToolsLayer(PlainLayer[NoLayerDeps]):
|
||||
"""Layer that contributes configured plain-callable tools."""
|
||||
"""Layer that contributes configured plain-callable tools.
|
||||
|
||||
Tool layers are instance-only because Python callables are live objects. Add
|
||||
them with ``CompositorBuilder.add_instance``.
|
||||
"""
|
||||
|
||||
tool_entries: Sequence[Callable[..., Any]] = ()
|
||||
|
||||
@ -59,10 +79,9 @@ class ToolsLayer(PlainLayer[NoLayerDeps]):
|
||||
return list(self.tool_entries)
|
||||
|
||||
|
||||
_PROMPT_LAYER_ADAPTER = TypeAdapter(PromptLayer)
|
||||
|
||||
__all__ = [
|
||||
"ObjectLayer",
|
||||
"PromptLayerConfig",
|
||||
"PromptLayer",
|
||||
"ToolsLayer",
|
||||
]
|
||||
|
||||
@ -0,0 +1,257 @@
|
||||
import asyncio
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, ValidationError
|
||||
from typing_extensions import override
|
||||
|
||||
from agenton.compositor import Compositor, CompositorBuilder, CompositorSession, LayerRegistry
|
||||
from agenton.layers import EmptyLayerConfig, LayerControl, LayerDeps, NoLayerDeps, PlainLayer, PlainPromptType, PlainToolType
|
||||
from agenton_collections.layers.plain import ObjectLayer, PromptLayer
|
||||
|
||||
|
||||
def test_registry_infers_descriptor_and_rejects_duplicate_or_missing_type_id() -> None:
|
||||
registry = LayerRegistry()
|
||||
registry.register_layer(PromptLayer)
|
||||
|
||||
descriptor = registry.resolve("plain.prompt")
|
||||
assert descriptor.layer_type is PromptLayer
|
||||
assert descriptor.config_type is PromptLayer.config_type
|
||||
|
||||
try:
|
||||
registry.register_layer(PromptLayer)
|
||||
except ValueError as e:
|
||||
assert str(e) == "Layer type id 'plain.prompt' is already registered."
|
||||
else:
|
||||
raise AssertionError("Expected ValueError.")
|
||||
|
||||
try:
|
||||
registry.register_layer(InstanceOnlyLayer)
|
||||
except ValueError as e:
|
||||
assert "must declare a type_id" in str(e)
|
||||
else:
|
||||
raise AssertionError("Expected ValueError.")
|
||||
|
||||
try:
|
||||
registry.register_layer(InstanceOnlyLayer, type_id=123) # pyright: ignore[reportArgumentType]
|
||||
except TypeError as e:
|
||||
assert str(e) == "Layer type id for 'InstanceOnlyLayer' must be a string."
|
||||
else:
|
||||
raise AssertionError("Expected TypeError.")
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class InstanceOnlyLayer(PlainLayer[NoLayerDeps]):
|
||||
pass
|
||||
|
||||
|
||||
def test_builder_creates_config_layers_with_typed_validation() -> None:
|
||||
registry = LayerRegistry()
|
||||
registry.register_layer(PromptLayer)
|
||||
|
||||
compositor = (
|
||||
CompositorBuilder(registry)
|
||||
.add_config_layer(
|
||||
name="prompt",
|
||||
type="plain.prompt",
|
||||
config={"prefix": "hello", "suffix": ["bye"]},
|
||||
)
|
||||
.build()
|
||||
)
|
||||
|
||||
assert [prompt.value for prompt in compositor.prompts] == ["hello", "bye"]
|
||||
|
||||
try:
|
||||
CompositorBuilder(registry).add_config_layer(
|
||||
name="bad",
|
||||
type="plain.prompt",
|
||||
config={"unknown": "field"},
|
||||
)
|
||||
except ValidationError:
|
||||
pass
|
||||
else:
|
||||
raise AssertionError("Expected ValidationError.")
|
||||
|
||||
|
||||
class ObjectConsumerDeps(LayerDeps):
|
||||
obj: ObjectLayer[str] # pyright: ignore[reportUninitializedInstanceVariable]
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ObjectConsumerLayer(PlainLayer[ObjectConsumerDeps]):
|
||||
@property
|
||||
@override
|
||||
def prefix_prompts(self) -> list[str]:
|
||||
return [self.deps.obj.value]
|
||||
|
||||
|
||||
def test_builder_mixes_config_and_instances_and_rejects_invalid_deps() -> None:
|
||||
registry = LayerRegistry()
|
||||
registry.register_layer(PromptLayer)
|
||||
|
||||
compositor = (
|
||||
CompositorBuilder(registry)
|
||||
.add_config({"layers": [{"name": "prompt", "type": "plain.prompt", "config": {"prefix": "cfg"}}]})
|
||||
.add_instance(name="obj", layer=ObjectLayer("instance"))
|
||||
.add_instance(name="consumer", layer=ObjectConsumerLayer(), deps={"obj": "obj"})
|
||||
.build()
|
||||
)
|
||||
|
||||
assert [prompt.value for prompt in compositor.prompts] == ["cfg", "instance"]
|
||||
|
||||
try:
|
||||
CompositorBuilder(registry).add_instance(
|
||||
name="consumer",
|
||||
layer=ObjectConsumerLayer(),
|
||||
deps={"missing_dep_key": "obj"},
|
||||
).build()
|
||||
except ValueError as e:
|
||||
assert str(e) == "Layer 'consumer' declares unknown dependency keys: missing_dep_key."
|
||||
else:
|
||||
raise AssertionError("Expected ValueError.")
|
||||
|
||||
try:
|
||||
CompositorBuilder(registry).add_instance(
|
||||
name="consumer",
|
||||
layer=ObjectConsumerLayer(),
|
||||
deps={"obj": "missing_target"},
|
||||
).build()
|
||||
except ValueError as e:
|
||||
assert str(e) == "Layer 'consumer' depends on undefined layer names: missing_target."
|
||||
else:
|
||||
raise AssertionError("Expected ValueError.")
|
||||
|
||||
|
||||
class HandleState(BaseModel):
|
||||
resource_id: str = ""
|
||||
|
||||
model_config = ConfigDict(extra="forbid", validate_assignment=True)
|
||||
|
||||
|
||||
class HandleBox:
|
||||
def __init__(self, value: str) -> None:
|
||||
self.value = value
|
||||
|
||||
|
||||
class HandleModels(BaseModel):
|
||||
handle: HandleBox | None = None
|
||||
|
||||
model_config = ConfigDict(extra="forbid", validate_assignment=True, arbitrary_types_allowed=True)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class HandleLayer(PlainLayer[NoLayerDeps, EmptyLayerConfig, HandleState, HandleModels]):
|
||||
created: int = 0
|
||||
resumed: int = 0
|
||||
|
||||
@override
|
||||
async def on_context_create(self, control: LayerControl[HandleState, HandleModels]) -> None:
|
||||
self.created += 1
|
||||
control.runtime_handles.handle = HandleBox(control.runtime_state.resource_id)
|
||||
|
||||
@override
|
||||
async def on_context_resume(self, control: LayerControl[HandleState, HandleModels]) -> None:
|
||||
self.resumed += 1
|
||||
control.runtime_handles.handle = HandleBox(f"resumed:{control.runtime_state.resource_id}")
|
||||
|
||||
|
||||
def test_new_session_uses_layer_runtime_schemas() -> None:
|
||||
compositor: Compositor[PlainPromptType, PlainToolType] = Compositor(
|
||||
layers=OrderedDict([("handle", HandleLayer())])
|
||||
)
|
||||
session = compositor.new_session()
|
||||
|
||||
assert isinstance(session.layer("handle").runtime_state, HandleState)
|
||||
assert isinstance(session.layer("handle").runtime_handles, HandleModels)
|
||||
|
||||
|
||||
def test_enter_rejects_bad_session_runtime_schemas_before_layer_hooks() -> None:
|
||||
layer = HandleLayer()
|
||||
compositor: Compositor[PlainPromptType, PlainToolType] = Compositor(layers=OrderedDict([("handle", layer)]))
|
||||
bad_session = CompositorSession(OrderedDict([("handle", LayerControl())]))
|
||||
|
||||
async def run() -> None:
|
||||
async with compositor.enter(bad_session):
|
||||
pass
|
||||
|
||||
try:
|
||||
asyncio.run(run())
|
||||
except TypeError as e:
|
||||
assert str(e) == (
|
||||
"CompositorSession layer 'handle' runtime_state must be HandleState, "
|
||||
"got EmptyRuntimeState."
|
||||
)
|
||||
else:
|
||||
raise AssertionError("Expected TypeError.")
|
||||
|
||||
assert layer.created == 0
|
||||
|
||||
|
||||
def test_snapshot_rejects_active_sessions_and_excludes_handles() -> None:
|
||||
compositor: Compositor[PlainPromptType, PlainToolType] = Compositor(
|
||||
layers=OrderedDict([("handle", HandleLayer())])
|
||||
)
|
||||
session = compositor.session_from_snapshot(
|
||||
{"layers": [{"name": "handle", "state": "new", "runtime_state": {"resource_id": "abc"}}]}
|
||||
)
|
||||
|
||||
async def run() -> None:
|
||||
async with compositor.enter(session):
|
||||
try:
|
||||
compositor.snapshot_session(session)
|
||||
except RuntimeError as e:
|
||||
assert str(e) == "Cannot snapshot active compositor session layers: handle."
|
||||
else:
|
||||
raise AssertionError("Expected RuntimeError.")
|
||||
|
||||
asyncio.run(run())
|
||||
|
||||
snapshot = compositor.snapshot_session(session)
|
||||
assert snapshot.model_dump(mode="json") == {
|
||||
"schema_version": 1,
|
||||
"layers": [{"name": "handle", "state": "closed", "runtime_state": {"resource_id": "abc"}}],
|
||||
}
|
||||
|
||||
|
||||
def test_restore_validates_runtime_state_and_resume_rehydrates_handles() -> None:
|
||||
layer = HandleLayer()
|
||||
compositor: Compositor[PlainPromptType, PlainToolType] = Compositor(layers=OrderedDict([("handle", layer)]))
|
||||
|
||||
try:
|
||||
compositor.session_from_snapshot(
|
||||
{"layers": [{"name": "handle", "state": "suspended", "runtime_state": {"wrong": "field"}}]}
|
||||
)
|
||||
except ValidationError:
|
||||
pass
|
||||
else:
|
||||
raise AssertionError("Expected ValidationError.")
|
||||
|
||||
restored = compositor.session_from_snapshot(
|
||||
{"layers": [{"name": "handle", "state": "suspended", "runtime_state": {"resource_id": "abc"}}]}
|
||||
)
|
||||
|
||||
async def run() -> None:
|
||||
async with compositor.enter(restored):
|
||||
control = restored.layer("handle")
|
||||
assert isinstance(control.runtime_handles, HandleModels)
|
||||
assert control.runtime_handles.handle is not None
|
||||
assert control.runtime_handles.handle.value == "resumed:abc"
|
||||
|
||||
asyncio.run(run())
|
||||
|
||||
assert layer.resumed == 1
|
||||
|
||||
|
||||
def test_session_from_snapshot_rejects_active_layer_state() -> None:
|
||||
compositor: Compositor[PlainPromptType, PlainToolType] = Compositor(
|
||||
layers=OrderedDict([("handle", HandleLayer())])
|
||||
)
|
||||
|
||||
try:
|
||||
compositor.session_from_snapshot(
|
||||
{"layers": [{"name": "handle", "state": "active", "runtime_state": {"resource_id": "abc"}}]}
|
||||
)
|
||||
except ValueError as e:
|
||||
assert str(e) == "Cannot restore active compositor session layers from snapshot: handle."
|
||||
else:
|
||||
raise AssertionError("Expected ValueError.")
|
||||
@ -4,11 +4,14 @@ from collections.abc import Iterator
|
||||
from dataclasses import dataclass, field
|
||||
from itertools import count
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from typing_extensions import override
|
||||
|
||||
from agenton.compositor import Compositor, CompositorSession
|
||||
from agenton.layers import (
|
||||
ExitIntent,
|
||||
EmptyLayerConfig,
|
||||
EmptyRuntimeHandles,
|
||||
LayerControl,
|
||||
LifecycleState,
|
||||
NoLayerDeps,
|
||||
@ -239,22 +242,30 @@ def test_failed_resume_keeps_control_reusable_as_suspended() -> None:
|
||||
assert session.layer("trace").state is LifecycleState.CLOSED
|
||||
|
||||
|
||||
class RuntimeState(BaseModel):
|
||||
runtime_id: int | None = None
|
||||
resumed_runtime_id: int | None = None
|
||||
deleted_runtime_id: int | None = None
|
||||
|
||||
model_config = ConfigDict(extra="forbid", validate_assignment=True)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class RuntimeStateLayer(PlainLayer[NoLayerDeps]):
|
||||
class RuntimeStateLayer(PlainLayer[NoLayerDeps, EmptyLayerConfig, RuntimeState]):
|
||||
next_id: Iterator[int] = field(default_factory=lambda: count(1))
|
||||
|
||||
@override
|
||||
async def on_context_create(self, control: LayerControl) -> None:
|
||||
async def on_context_create(self, control: LayerControl[RuntimeState, EmptyRuntimeHandles]) -> None:
|
||||
runtime_id = next(self.next_id)
|
||||
control.runtime_state["runtime_id"] = runtime_id
|
||||
control.runtime_state.runtime_id = runtime_id
|
||||
|
||||
@override
|
||||
async def on_context_resume(self, control: LayerControl) -> None:
|
||||
control.runtime_state["resumed_runtime_id"] = control.runtime_state["runtime_id"]
|
||||
async def on_context_resume(self, control: LayerControl[RuntimeState, EmptyRuntimeHandles]) -> None:
|
||||
control.runtime_state.resumed_runtime_id = control.runtime_state.runtime_id
|
||||
|
||||
@override
|
||||
async def on_context_delete(self, control: LayerControl) -> None:
|
||||
control.runtime_state["deleted_runtime_id"] = control.runtime_state["runtime_id"]
|
||||
async def on_context_delete(self, control: LayerControl[RuntimeState, EmptyRuntimeHandles]) -> None:
|
||||
control.runtime_state.deleted_runtime_id = control.runtime_state.runtime_id
|
||||
|
||||
|
||||
def test_runtime_state_is_per_session_and_survives_suspend_resume_delete() -> None:
|
||||
@ -275,12 +286,12 @@ def test_runtime_state_is_per_session_and_survives_suspend_resume_delete() -> No
|
||||
|
||||
asyncio.run(run())
|
||||
|
||||
assert first_session.layer("trace").runtime_state == {
|
||||
assert first_session.layer("trace").runtime_state.model_dump(exclude_none=True) == {
|
||||
"runtime_id": 1,
|
||||
"resumed_runtime_id": 1,
|
||||
"deleted_runtime_id": 1,
|
||||
}
|
||||
assert second_session.layer("trace").runtime_state == {
|
||||
assert second_session.layer("trace").runtime_state.model_dump(exclude_none=True) == {
|
||||
"runtime_id": 2,
|
||||
"deleted_runtime_id": 2,
|
||||
}
|
||||
|
||||
@ -0,0 +1,94 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from agenton.compositor import LayerRegistry
|
||||
from agenton.layers import EmptyLayerConfig, EmptyRuntimeHandles, EmptyRuntimeState, LayerControl, NoLayerDeps, PlainLayer
|
||||
|
||||
|
||||
class InferredConfig(BaseModel):
|
||||
value: str = "configured"
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
class InferredState(BaseModel):
|
||||
count: int = 0
|
||||
|
||||
model_config = ConfigDict(extra="forbid", validate_assignment=True)
|
||||
|
||||
|
||||
class InferredHandles(BaseModel):
|
||||
token: object | None = None
|
||||
|
||||
model_config = ConfigDict(extra="forbid", validate_assignment=True, arbitrary_types_allowed=True)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class GenericSchemaLayer(PlainLayer[NoLayerDeps, InferredConfig, InferredState, InferredHandles]):
|
||||
type_id = "test.generic-schema"
|
||||
|
||||
async def on_context_create(self, control: LayerControl[InferredState, InferredHandles]) -> None:
|
||||
control.runtime_state.count += 1
|
||||
control.runtime_handles.token = object()
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class DefaultSchemaLayer(PlainLayer[NoLayerDeps]):
|
||||
type_id = "test.default-schema"
|
||||
|
||||
|
||||
def test_layer_infers_config_runtime_state_and_handles_from_generics() -> None:
|
||||
layer = GenericSchemaLayer()
|
||||
control = layer.new_control(runtime_state={"count": 3})
|
||||
|
||||
assert GenericSchemaLayer.config_type is InferredConfig
|
||||
assert GenericSchemaLayer.runtime_state_type is InferredState
|
||||
assert GenericSchemaLayer.runtime_handles_type is InferredHandles
|
||||
assert isinstance(control.runtime_state, InferredState)
|
||||
assert control.runtime_state.count == 3
|
||||
assert isinstance(control.runtime_handles, InferredHandles)
|
||||
|
||||
|
||||
def test_layer_uses_empty_schema_defaults_when_omitted() -> None:
|
||||
layer = DefaultSchemaLayer()
|
||||
control = layer.new_control()
|
||||
|
||||
assert DefaultSchemaLayer.config_type is EmptyLayerConfig
|
||||
assert DefaultSchemaLayer.runtime_state_type is EmptyRuntimeState
|
||||
assert DefaultSchemaLayer.runtime_handles_type is EmptyRuntimeHandles
|
||||
assert isinstance(control.runtime_state, EmptyRuntimeState)
|
||||
assert isinstance(control.runtime_handles, EmptyRuntimeHandles)
|
||||
|
||||
|
||||
def test_invalid_declared_schema_type_is_rejected_clearly() -> None:
|
||||
try:
|
||||
|
||||
class InvalidSchemaLayer(PlainLayer[NoLayerDeps]):
|
||||
config_type = dict # pyright: ignore[reportAssignmentType]
|
||||
|
||||
except TypeError as e:
|
||||
assert str(e) == "InvalidSchemaLayer.config_type must be a Pydantic BaseModel subclass."
|
||||
else:
|
||||
raise AssertionError("Expected TypeError.")
|
||||
|
||||
try:
|
||||
|
||||
class InvalidGenericSchemaLayer(PlainLayer[NoLayerDeps, dict[str, object]]): # pyright: ignore[reportInvalidTypeArguments]
|
||||
pass
|
||||
|
||||
except TypeError as e:
|
||||
assert str(e) == "InvalidGenericSchemaLayer.config_type must be a Pydantic BaseModel subclass."
|
||||
else:
|
||||
raise AssertionError("Expected TypeError.")
|
||||
|
||||
|
||||
def test_registry_descriptor_uses_inferred_schema_types() -> None:
|
||||
registry = LayerRegistry()
|
||||
registry.register_layer(GenericSchemaLayer)
|
||||
|
||||
descriptor = registry.resolve("test.generic-schema")
|
||||
|
||||
assert descriptor.config_type is InferredConfig
|
||||
assert descriptor.runtime_state_type is InferredState
|
||||
assert descriptor.runtime_handles_type is InferredHandles
|
||||
@ -38,3 +38,11 @@ def test_agenton_pydantic_ai_example_smoke() -> None:
|
||||
assert "ToolCallPart: count_words(" in result.stdout
|
||||
assert "ToolCallPart: write_tagline(" in result.stdout
|
||||
assert "TextPart:" in result.stdout
|
||||
|
||||
|
||||
def test_agenton_session_snapshot_example_smoke() -> None:
|
||||
result = _run_example("examples/agenton/session_snapshot.py")
|
||||
|
||||
assert result.returncode == 0, result.stderr
|
||||
assert "Snapshot:" in result.stdout
|
||||
assert "Rehydrated handle: restored:demo-connection" in result.stdout
|
||||
|
||||
Loading…
Reference in New Issue
Block a user