refactor the design of tmp leave control

This commit is contained in:
盐粒 Yanli 2026-04-29 04:09:32 +08:00
parent 5a7eb7fdb6
commit 5dfd318907
6 changed files with 195 additions and 70 deletions

View File

@ -9,7 +9,7 @@ from inspect import signature
from typing_extensions import override
from agenton.compositor import Compositor, CompositorLayerConfig
from agenton.layers import LayerContextSignal, LayerDeps, NoLayerDeps, PlainLayer
from agenton.layers import LayerControl, LayerDeps, NoLayerDeps, PlainLayer
from agenton_collections.plain import DynamicToolsLayer, ObjectLayer, ToolsLayer, with_object
@ -41,19 +41,19 @@ class TraceLayer(PlainLayer[NoLayerDeps]):
events: list[str] = field(default_factory=list)
@override
async def on_context_create(self, signal: LayerContextSignal) -> None:
async def on_context_create(self, control: LayerControl) -> None:
self.events.append("create")
@override
async def on_context_temporarily_leave(self, signal: LayerContextSignal) -> None:
self.events.append("temporary_leave")
async def on_context_tmp_leave(self, control: LayerControl) -> None:
self.events.append("tmp_leave")
@override
async def on_context_reenter(self, signal: LayerContextSignal) -> None:
async def on_context_reenter(self, control: LayerControl) -> None:
self.events.append("reenter")
@override
async def on_context_delete(self, signal: LayerContextSignal) -> None:
async def on_context_delete(self, control: LayerControl) -> None:
self.events.append("delete")
@ -128,9 +128,9 @@ async def main() -> None:
print(f"- {tool.__name__}{signature(tool)}")
print([tool("layer composition") for tool in compositor.tools])
async with compositor.context() as context:
context.temporary_leave = True
async with compositor.context():
async with compositor.enter() as lifecycle_control:
lifecycle_control.tmp_leave = True
async with compositor.enter(lifecycle_control):
pass
print("\nLifecycle:", trace.events)

View File

@ -12,14 +12,16 @@ layer names as values. Prompt aggregation depends on insertion order: prefix
prompts are collected from first to last layer, while suffix prompts are
collected in reverse.
``Compositor.context`` enters layer contexts in compositor order and exits them
in reverse order through ``AsyncExitStack``. It yields per-layer lifecycle
signals so callers can mark individual layers, or all layers, as temporarily
leaving.
``Compositor.enter`` enters layers in compositor order and exits them in
reverse order through ``AsyncExitStack``. It accepts an optional
``CompositorControl`` whose keys must match the compositor layer names. When
omitted, one is created from the compositor's layer names. Reuse the same
``CompositorControl`` after setting ``tmp_leave`` to reenter those layer
contexts.
"""
from collections import OrderedDict
from collections.abc import AsyncIterator
from collections.abc import AsyncIterator, Iterable
from contextlib import AsyncExitStack, asynccontextmanager
from dataclasses import dataclass, field
from importlib import import_module
@ -28,7 +30,7 @@ from typing import TYPE_CHECKING, Annotated, Any, Mapping, cast
from pydantic import AfterValidator, BaseModel, ConfigDict, Field, JsonValue
from typing_extensions import Self
from agenton.layers.base import Layer, LayerContextSignal
from agenton.layers.base import Layer, LayerControl
class ImportedLayerConfig(BaseModel):
@ -127,21 +129,27 @@ def _validate_compositor_config_input(value: CompositorConfigValue) -> Composito
return _validate_config_model_input(CompositorConfig, value)
@dataclass(slots=True)
class CompositorContext:
"""Signal slots for layer contexts entered by a compositor."""
class CompositorControl:
"""External controls for layer entry contexts entered by a compositor."""
signals: OrderedDict[str, LayerContextSignal]
__slots__ = ("layer_controls",)
layer_controls: OrderedDict[str, LayerControl]
def __init__(self, layer_names: Iterable[str]) -> None:
self.layer_controls = OrderedDict(
(layer_name, LayerControl()) for layer_name in layer_names
)
@property
def temporary_leave(self) -> bool:
"""Whether any entered layer is currently marked for temporary leave."""
return any(signal.temporary_leave for signal in self.signals.values())
def tmp_leave(self) -> bool:
"""Whether any entered layer control is marked for temporary leave."""
return any(control.tmp_leave for control in self.layer_controls.values())
@temporary_leave.setter
def temporary_leave(self, value: bool) -> None:
for signal in self.signals.values():
signal.temporary_leave = value
@tmp_leave.setter
def tmp_leave(self, value: bool) -> None:
for control in self.layer_controls.values():
control.tmp_leave = value
@dataclass(kw_only=True)
@ -192,15 +200,33 @@ class Compositor[PromptT, ToolT]:
self._deps_bound = True
@asynccontextmanager
async def context(self) -> AsyncIterator[CompositorContext]:
"""Enter each layer context in order and yield their signal slots."""
async def enter(
self,
control: CompositorControl | None = None,
) -> AsyncIterator[CompositorControl]:
"""Enter each layer context in order and yield compositor control."""
if not self._deps_bound:
raise RuntimeError("Compositor deps must be bound before entering context.")
signals: OrderedDict[str, LayerContextSignal] = OrderedDict()
if control is None:
control = CompositorControl(self.layers)
self._validate_control(control)
async with AsyncExitStack() as stack:
for layer_name, layer in self.layers.items():
signals[layer_name] = await stack.enter_async_context(layer.context())
yield CompositorContext(signals=signals)
await stack.enter_async_context(layer.enter(control.layer_controls[layer_name]))
yield control
def _validate_control(self, control: CompositorControl) -> None:
expected_layer_names = tuple(self.layers)
actual_layer_names = tuple(control.layer_controls)
if actual_layer_names != expected_layer_names:
expected = ", ".join(expected_layer_names)
actual = ", ".join(actual_layer_names)
raise ValueError(
"CompositorControl layer names must match compositor layers in order. "
f"Expected [{expected}], got [{actual}]."
)
@property
def prompts(self) -> list[PromptT]:
@ -224,7 +250,7 @@ __all__ = [
"CompositorConfig",
"CompositorConfigValue",
"CompositorLayerConfigInput",
"CompositorContext",
"CompositorControl",
"CompositorLayerConfig",
"CompositorLayerConfigValue",
"ImportedLayerConfig",

View File

@ -5,7 +5,7 @@
families while keeping concrete reusable layers in ``agenton_collections``.
"""
from agenton.layers.base import Layer, LayerContextSignal, LayerDeps, NoLayerDeps
from agenton.layers.base import Layer, LayerControl, LayerDeps, NoLayerDeps
from agenton.layers.types import (
PlainLayer,
PlainPrompt,
@ -17,8 +17,8 @@ from agenton.layers.types import (
__all__ = [
"Layer",
"LayerContextSignal",
"LayerDeps",
"LayerControl",
"NoLayerDeps",
"PlainLayer",
"PlainPrompt",

View File

@ -11,10 +11,10 @@ inheritance patterns.
implementations should treat ``self.deps`` as unavailable until a compositor or
caller has resolved and bound dependencies.
Layer async contexts use a bool signal to distinguish permanent exits from
temporary exits. A normal first entry runs create logic and a normal exit runs
delete logic; when the signal is set, exit runs temporary-leave logic and the
next entry runs reenter logic.
Layer async entry uses a caller-provided bool control to distinguish permanent
exits from temporary exits. The control is also the external lifecycle state:
reuse a ``tmp_leave`` control to reenter, or pass a fresh control to start from
create logic.
``Layer`` is framework-neutral over prompt and tool item types. Typed families
such as ``agenton.layers.types.PlainLayer`` bind those generic slots to a
@ -72,14 +72,17 @@ class NoLayerDeps(LayerDeps):
@dataclass(slots=True)
class LayerContextSignal:
"""Signal slot exposed inside a layer context.
class LayerControl:
"""Control slot passed into a layer entry context.
Set ``temporary_leave`` before leaving the context to run temporary-leave
logic instead of delete logic. A later entry will then run reenter logic.
``Layer.enter`` requires the caller to provide this object. Set
``tmp_leave`` before leaving the context to run temporary-leave logic
instead of delete logic. Reusing that same control on a later entry will
consume ``tmp_leave`` and run reenter logic; using a fresh control starts
from create logic.
"""
temporary_leave: bool = False
tmp_leave: bool = False
@dataclass(frozen=True, slots=True)
@ -97,13 +100,12 @@ class Layer[DepsT: LayerDeps, PromptT, ToolT](ABC):
properties. They declare required dependencies in the ``DepsT`` container
rather than by accepting dependencies in ``__init__``. The default async
context manager handles create, delete, temporary-leave, and reenter
transitions; layers can override ``context`` when they need to wrap extra
transitions; layers can override ``enter`` when they need to wrap extra
runtime resources.
"""
deps_type: type[DepsT]
deps: DepsT
_temporarily_left: bool
def __init_subclass__(cls) -> None:
super().__init_subclass__()
@ -149,46 +151,43 @@ class Layer[DepsT: LayerDeps, PromptT, ToolT](ABC):
resolved_deps[name] = deps[name]
self.deps = self.deps_type(**resolved_deps)
def context(self) -> AbstractAsyncContextManager[LayerContextSignal]:
"""Return the layer's async context manager.
def enter(self, control: LayerControl) -> AbstractAsyncContextManager[None]:
"""Return the layer's async entry context manager.
The yielded ``LayerContextSignal`` is the signal slot available to code
inside the context. Subclasses can override this to wrap extra async
resources around ``self.lifecycle_context()``.
``control`` is the lifecycle control slot for this entry. Subclasses can
override this to wrap extra async resources around
``self.lifecycle_enter(control)``.
"""
return self.lifecycle_context()
return self.lifecycle_enter(control)
@asynccontextmanager
async def lifecycle_context(self) -> AsyncIterator[LayerContextSignal]:
async def lifecycle_enter(self, control: LayerControl) -> AsyncIterator[None]:
"""Run the default create/reenter and delete/temporary-leave lifecycle."""
signal = LayerContextSignal()
was_temporarily_left = getattr(self, "_temporarily_left", False)
self._temporarily_left = False
if was_temporarily_left:
await self.on_context_reenter(signal)
was_tmp_left = control.tmp_leave
control.tmp_leave = False
if was_tmp_left:
await self.on_context_reenter(control)
else:
await self.on_context_create(signal)
await self.on_context_create(control)
try:
yield signal
yield
finally:
if signal.temporary_leave:
await self.on_context_temporarily_leave(signal)
self._temporarily_left = True
if control.tmp_leave:
await self.on_context_tmp_leave(control)
else:
await self.on_context_delete(signal)
self._temporarily_left = False
await self.on_context_delete(control)
async def on_context_create(self, signal: LayerContextSignal) -> None:
async def on_context_create(self, control: LayerControl) -> None:
"""Run when the layer context is entered from a non-temporary state."""
async def on_context_delete(self, signal: LayerContextSignal) -> None:
"""Run when the layer context exits without a temporary-leave signal."""
async def on_context_delete(self, control: LayerControl) -> None:
"""Run when the layer context exits without ``tmp_leave`` set."""
async def on_context_temporarily_leave(self, signal: LayerContextSignal) -> None:
"""Run when the layer context exits with ``temporary_leave`` set."""
async def on_context_tmp_leave(self, control: LayerControl) -> None:
"""Run when the layer context exits with ``tmp_leave`` set."""
async def on_context_reenter(self, signal: LayerContextSignal) -> None:
async def on_context_reenter(self, control: LayerControl) -> None:
"""Run when the layer context enters after a temporary leave."""
@property

View File

@ -0,0 +1 @@

View File

@ -0,0 +1,99 @@
import asyncio
from collections import OrderedDict
from dataclasses import dataclass, field
from typing_extensions import override
from agenton.compositor import Compositor, CompositorControl
from agenton.layers import LayerControl, NoLayerDeps, PlainLayer
@dataclass(slots=True)
class TraceLayer(PlainLayer[NoLayerDeps]):
"""Layer that records lifecycle events observable to tests."""
events: list[str] = field(default_factory=list)
@override
async def on_context_create(self, control: LayerControl) -> None:
self.events.append("create")
@override
async def on_context_tmp_leave(self, control: LayerControl) -> None:
self.events.append("tmp_leave")
@override
async def on_context_reenter(self, control: LayerControl) -> None:
self.events.append("reenter")
@override
async def on_context_delete(self, control: LayerControl) -> None:
self.events.append("delete")
def test_compositor_enter_creates_control_and_applies_tmp_leave_to_all_layers() -> None:
first_layer = TraceLayer()
second_layer = TraceLayer()
compositor: Compositor[str, object] = Compositor(
layers=OrderedDict(
[
("first", first_layer),
("second", second_layer),
]
)
)
compositor_control = CompositorControl(compositor.layers)
async def run() -> None:
async with compositor.enter(compositor_control) as control:
assert control is compositor_control
assert list(control.layer_controls) == ["first", "second"]
control.tmp_leave = True
async with compositor.enter(compositor_control):
pass
asyncio.run(run())
assert first_layer.events == ["create", "tmp_leave", "reenter", "delete"]
assert second_layer.events == ["create", "tmp_leave", "reenter", "delete"]
def test_compositor_enter_does_not_store_tmp_leave_on_layer() -> None:
layer = TraceLayer()
compositor: Compositor[str, object] = Compositor(
layers=OrderedDict([("trace", layer)])
)
async def run() -> None:
async with compositor.enter() as control:
control.tmp_leave = True
async with compositor.enter():
pass
asyncio.run(run())
assert layer.events == ["create", "tmp_leave", "create", "delete"]
def test_compositor_enter_rejects_control_with_mismatched_layer_names() -> None:
layer = TraceLayer()
compositor: Compositor[str, object] = Compositor(
layers=OrderedDict([("trace", layer)])
)
compositor_control = CompositorControl(["other"])
async def run() -> None:
async with compositor.enter(compositor_control):
pass
try:
asyncio.run(run())
except ValueError as e:
assert str(e) == (
"CompositorControl layer names must match compositor layers in order. "
"Expected [trace], got [other]."
)
else:
raise AssertionError("Expected ValueError.")