refactor the design of tmp leave control

This commit is contained in:
盐粒 Yanli 2026-04-29 04:09:32 +08:00
parent 5a7eb7fdb6
commit 5dfd318907
6 changed files with 195 additions and 70 deletions

View File

@ -9,7 +9,7 @@ from inspect import signature
from typing_extensions import override from typing_extensions import override
from agenton.compositor import Compositor, CompositorLayerConfig from agenton.compositor import Compositor, CompositorLayerConfig
from agenton.layers import LayerContextSignal, LayerDeps, NoLayerDeps, PlainLayer from agenton.layers import LayerControl, LayerDeps, NoLayerDeps, PlainLayer
from agenton_collections.plain import DynamicToolsLayer, ObjectLayer, ToolsLayer, with_object from agenton_collections.plain import DynamicToolsLayer, ObjectLayer, ToolsLayer, with_object
@ -41,19 +41,19 @@ class TraceLayer(PlainLayer[NoLayerDeps]):
events: list[str] = field(default_factory=list) events: list[str] = field(default_factory=list)
@override @override
async def on_context_create(self, signal: LayerContextSignal) -> None: async def on_context_create(self, control: LayerControl) -> None:
self.events.append("create") self.events.append("create")
@override @override
async def on_context_temporarily_leave(self, signal: LayerContextSignal) -> None: async def on_context_tmp_leave(self, control: LayerControl) -> None:
self.events.append("temporary_leave") self.events.append("tmp_leave")
@override @override
async def on_context_reenter(self, signal: LayerContextSignal) -> None: async def on_context_reenter(self, control: LayerControl) -> None:
self.events.append("reenter") self.events.append("reenter")
@override @override
async def on_context_delete(self, signal: LayerContextSignal) -> None: async def on_context_delete(self, control: LayerControl) -> None:
self.events.append("delete") self.events.append("delete")
@ -128,9 +128,9 @@ async def main() -> None:
print(f"- {tool.__name__}{signature(tool)}") print(f"- {tool.__name__}{signature(tool)}")
print([tool("layer composition") for tool in compositor.tools]) print([tool("layer composition") for tool in compositor.tools])
async with compositor.context() as context: async with compositor.enter() as lifecycle_control:
context.temporary_leave = True lifecycle_control.tmp_leave = True
async with compositor.context(): async with compositor.enter(lifecycle_control):
pass pass
print("\nLifecycle:", trace.events) print("\nLifecycle:", trace.events)

View File

@ -12,14 +12,16 @@ layer names as values. Prompt aggregation depends on insertion order: prefix
prompts are collected from first to last layer, while suffix prompts are prompts are collected from first to last layer, while suffix prompts are
collected in reverse. collected in reverse.
``Compositor.context`` enters layer contexts in compositor order and exits them ``Compositor.enter`` enters layers in compositor order and exits them in
in reverse order through ``AsyncExitStack``. It yields per-layer lifecycle reverse order through ``AsyncExitStack``. It accepts an optional
signals so callers can mark individual layers, or all layers, as temporarily ``CompositorControl`` whose keys must match the compositor layer names. When
leaving. omitted, one is created from the compositor's layer names. Reuse the same
``CompositorControl`` after setting ``tmp_leave`` to reenter those layer
contexts.
""" """
from collections import OrderedDict from collections import OrderedDict
from collections.abc import AsyncIterator from collections.abc import AsyncIterator, Iterable
from contextlib import AsyncExitStack, asynccontextmanager from contextlib import AsyncExitStack, asynccontextmanager
from dataclasses import dataclass, field from dataclasses import dataclass, field
from importlib import import_module from importlib import import_module
@ -28,7 +30,7 @@ from typing import TYPE_CHECKING, Annotated, Any, Mapping, cast
from pydantic import AfterValidator, BaseModel, ConfigDict, Field, JsonValue from pydantic import AfterValidator, BaseModel, ConfigDict, Field, JsonValue
from typing_extensions import Self from typing_extensions import Self
from agenton.layers.base import Layer, LayerContextSignal from agenton.layers.base import Layer, LayerControl
class ImportedLayerConfig(BaseModel): class ImportedLayerConfig(BaseModel):
@ -127,21 +129,27 @@ def _validate_compositor_config_input(value: CompositorConfigValue) -> Composito
return _validate_config_model_input(CompositorConfig, value) return _validate_config_model_input(CompositorConfig, value)
@dataclass(slots=True) class CompositorControl:
class CompositorContext: """External controls for layer entry contexts entered by a compositor."""
"""Signal slots for layer contexts entered by a compositor."""
signals: OrderedDict[str, LayerContextSignal] __slots__ = ("layer_controls",)
layer_controls: OrderedDict[str, LayerControl]
def __init__(self, layer_names: Iterable[str]) -> None:
self.layer_controls = OrderedDict(
(layer_name, LayerControl()) for layer_name in layer_names
)
@property @property
def temporary_leave(self) -> bool: def tmp_leave(self) -> bool:
"""Whether any entered layer is currently marked for temporary leave.""" """Whether any entered layer control is marked for temporary leave."""
return any(signal.temporary_leave for signal in self.signals.values()) return any(control.tmp_leave for control in self.layer_controls.values())
@temporary_leave.setter @tmp_leave.setter
def temporary_leave(self, value: bool) -> None: def tmp_leave(self, value: bool) -> None:
for signal in self.signals.values(): for control in self.layer_controls.values():
signal.temporary_leave = value control.tmp_leave = value
@dataclass(kw_only=True) @dataclass(kw_only=True)
@ -192,15 +200,33 @@ class Compositor[PromptT, ToolT]:
self._deps_bound = True self._deps_bound = True
@asynccontextmanager @asynccontextmanager
async def context(self) -> AsyncIterator[CompositorContext]: async def enter(
"""Enter each layer context in order and yield their signal slots.""" self,
control: CompositorControl | None = None,
) -> AsyncIterator[CompositorControl]:
"""Enter each layer context in order and yield compositor control."""
if not self._deps_bound: if not self._deps_bound:
raise RuntimeError("Compositor deps must be bound before entering context.") raise RuntimeError("Compositor deps must be bound before entering context.")
signals: OrderedDict[str, LayerContextSignal] = OrderedDict()
if control is None:
control = CompositorControl(self.layers)
self._validate_control(control)
async with AsyncExitStack() as stack: async with AsyncExitStack() as stack:
for layer_name, layer in self.layers.items(): for layer_name, layer in self.layers.items():
signals[layer_name] = await stack.enter_async_context(layer.context()) await stack.enter_async_context(layer.enter(control.layer_controls[layer_name]))
yield CompositorContext(signals=signals) yield control
def _validate_control(self, control: CompositorControl) -> None:
expected_layer_names = tuple(self.layers)
actual_layer_names = tuple(control.layer_controls)
if actual_layer_names != expected_layer_names:
expected = ", ".join(expected_layer_names)
actual = ", ".join(actual_layer_names)
raise ValueError(
"CompositorControl layer names must match compositor layers in order. "
f"Expected [{expected}], got [{actual}]."
)
@property @property
def prompts(self) -> list[PromptT]: def prompts(self) -> list[PromptT]:
@ -224,7 +250,7 @@ __all__ = [
"CompositorConfig", "CompositorConfig",
"CompositorConfigValue", "CompositorConfigValue",
"CompositorLayerConfigInput", "CompositorLayerConfigInput",
"CompositorContext", "CompositorControl",
"CompositorLayerConfig", "CompositorLayerConfig",
"CompositorLayerConfigValue", "CompositorLayerConfigValue",
"ImportedLayerConfig", "ImportedLayerConfig",

View File

@ -5,7 +5,7 @@
families while keeping concrete reusable layers in ``agenton_collections``. families while keeping concrete reusable layers in ``agenton_collections``.
""" """
from agenton.layers.base import Layer, LayerContextSignal, LayerDeps, NoLayerDeps from agenton.layers.base import Layer, LayerControl, LayerDeps, NoLayerDeps
from agenton.layers.types import ( from agenton.layers.types import (
PlainLayer, PlainLayer,
PlainPrompt, PlainPrompt,
@ -17,8 +17,8 @@ from agenton.layers.types import (
__all__ = [ __all__ = [
"Layer", "Layer",
"LayerContextSignal",
"LayerDeps", "LayerDeps",
"LayerControl",
"NoLayerDeps", "NoLayerDeps",
"PlainLayer", "PlainLayer",
"PlainPrompt", "PlainPrompt",

View File

@ -11,10 +11,10 @@ inheritance patterns.
implementations should treat ``self.deps`` as unavailable until a compositor or implementations should treat ``self.deps`` as unavailable until a compositor or
caller has resolved and bound dependencies. caller has resolved and bound dependencies.
Layer async contexts use a bool signal to distinguish permanent exits from Layer async entry uses a caller-provided bool control to distinguish permanent
temporary exits. A normal first entry runs create logic and a normal exit runs exits from temporary exits. The control is also the external lifecycle state:
delete logic; when the signal is set, exit runs temporary-leave logic and the reuse a ``tmp_leave`` control to reenter, or pass a fresh control to start from
next entry runs reenter logic. create logic.
``Layer`` is framework-neutral over prompt and tool item types. Typed families ``Layer`` is framework-neutral over prompt and tool item types. Typed families
such as ``agenton.layers.types.PlainLayer`` bind those generic slots to a such as ``agenton.layers.types.PlainLayer`` bind those generic slots to a
@ -72,14 +72,17 @@ class NoLayerDeps(LayerDeps):
@dataclass(slots=True) @dataclass(slots=True)
class LayerContextSignal: class LayerControl:
"""Signal slot exposed inside a layer context. """Control slot passed into a layer entry context.
Set ``temporary_leave`` before leaving the context to run temporary-leave ``Layer.enter`` requires the caller to provide this object. Set
logic instead of delete logic. A later entry will then run reenter logic. ``tmp_leave`` before leaving the context to run temporary-leave logic
instead of delete logic. Reusing that same control on a later entry will
consume ``tmp_leave`` and run reenter logic; using a fresh control starts
from create logic.
""" """
temporary_leave: bool = False tmp_leave: bool = False
@dataclass(frozen=True, slots=True) @dataclass(frozen=True, slots=True)
@ -97,13 +100,12 @@ class Layer[DepsT: LayerDeps, PromptT, ToolT](ABC):
properties. They declare required dependencies in the ``DepsT`` container properties. They declare required dependencies in the ``DepsT`` container
rather than by accepting dependencies in ``__init__``. The default async rather than by accepting dependencies in ``__init__``. The default async
context manager handles create, delete, temporary-leave, and reenter context manager handles create, delete, temporary-leave, and reenter
transitions; layers can override ``context`` when they need to wrap extra transitions; layers can override ``enter`` when they need to wrap extra
runtime resources. runtime resources.
""" """
deps_type: type[DepsT] deps_type: type[DepsT]
deps: DepsT deps: DepsT
_temporarily_left: bool
def __init_subclass__(cls) -> None: def __init_subclass__(cls) -> None:
super().__init_subclass__() super().__init_subclass__()
@ -149,46 +151,43 @@ class Layer[DepsT: LayerDeps, PromptT, ToolT](ABC):
resolved_deps[name] = deps[name] resolved_deps[name] = deps[name]
self.deps = self.deps_type(**resolved_deps) self.deps = self.deps_type(**resolved_deps)
def context(self) -> AbstractAsyncContextManager[LayerContextSignal]: def enter(self, control: LayerControl) -> AbstractAsyncContextManager[None]:
"""Return the layer's async context manager. """Return the layer's async entry context manager.
The yielded ``LayerContextSignal`` is the signal slot available to code ``control`` is the lifecycle control slot for this entry. Subclasses can
inside the context. Subclasses can override this to wrap extra async override this to wrap extra async resources around
resources around ``self.lifecycle_context()``. ``self.lifecycle_enter(control)``.
""" """
return self.lifecycle_context() return self.lifecycle_enter(control)
@asynccontextmanager @asynccontextmanager
async def lifecycle_context(self) -> AsyncIterator[LayerContextSignal]: async def lifecycle_enter(self, control: LayerControl) -> AsyncIterator[None]:
"""Run the default create/reenter and delete/temporary-leave lifecycle.""" """Run the default create/reenter and delete/temporary-leave lifecycle."""
signal = LayerContextSignal() was_tmp_left = control.tmp_leave
was_temporarily_left = getattr(self, "_temporarily_left", False) control.tmp_leave = False
self._temporarily_left = False if was_tmp_left:
if was_temporarily_left: await self.on_context_reenter(control)
await self.on_context_reenter(signal)
else: else:
await self.on_context_create(signal) await self.on_context_create(control)
try: try:
yield signal yield
finally: finally:
if signal.temporary_leave: if control.tmp_leave:
await self.on_context_temporarily_leave(signal) await self.on_context_tmp_leave(control)
self._temporarily_left = True
else: else:
await self.on_context_delete(signal) await self.on_context_delete(control)
self._temporarily_left = False
async def on_context_create(self, signal: LayerContextSignal) -> None: async def on_context_create(self, control: LayerControl) -> None:
"""Run when the layer context is entered from a non-temporary state.""" """Run when the layer context is entered from a non-temporary state."""
async def on_context_delete(self, signal: LayerContextSignal) -> None: async def on_context_delete(self, control: LayerControl) -> None:
"""Run when the layer context exits without a temporary-leave signal.""" """Run when the layer context exits without ``tmp_leave`` set."""
async def on_context_temporarily_leave(self, signal: LayerContextSignal) -> None: async def on_context_tmp_leave(self, control: LayerControl) -> None:
"""Run when the layer context exits with ``temporary_leave`` set.""" """Run when the layer context exits with ``tmp_leave`` set."""
async def on_context_reenter(self, signal: LayerContextSignal) -> None: async def on_context_reenter(self, control: LayerControl) -> None:
"""Run when the layer context enters after a temporary leave.""" """Run when the layer context enters after a temporary leave."""
@property @property

View File

@ -0,0 +1 @@

View File

@ -0,0 +1,99 @@
import asyncio
from collections import OrderedDict
from dataclasses import dataclass, field
from typing_extensions import override
from agenton.compositor import Compositor, CompositorControl
from agenton.layers import LayerControl, NoLayerDeps, PlainLayer
@dataclass(slots=True)
class TraceLayer(PlainLayer[NoLayerDeps]):
"""Layer that records lifecycle events observable to tests."""
events: list[str] = field(default_factory=list)
@override
async def on_context_create(self, control: LayerControl) -> None:
self.events.append("create")
@override
async def on_context_tmp_leave(self, control: LayerControl) -> None:
self.events.append("tmp_leave")
@override
async def on_context_reenter(self, control: LayerControl) -> None:
self.events.append("reenter")
@override
async def on_context_delete(self, control: LayerControl) -> None:
self.events.append("delete")
def test_compositor_enter_creates_control_and_applies_tmp_leave_to_all_layers() -> None:
first_layer = TraceLayer()
second_layer = TraceLayer()
compositor: Compositor[str, object] = Compositor(
layers=OrderedDict(
[
("first", first_layer),
("second", second_layer),
]
)
)
compositor_control = CompositorControl(compositor.layers)
async def run() -> None:
async with compositor.enter(compositor_control) as control:
assert control is compositor_control
assert list(control.layer_controls) == ["first", "second"]
control.tmp_leave = True
async with compositor.enter(compositor_control):
pass
asyncio.run(run())
assert first_layer.events == ["create", "tmp_leave", "reenter", "delete"]
assert second_layer.events == ["create", "tmp_leave", "reenter", "delete"]
def test_compositor_enter_does_not_store_tmp_leave_on_layer() -> None:
layer = TraceLayer()
compositor: Compositor[str, object] = Compositor(
layers=OrderedDict([("trace", layer)])
)
async def run() -> None:
async with compositor.enter() as control:
control.tmp_leave = True
async with compositor.enter():
pass
asyncio.run(run())
assert layer.events == ["create", "tmp_leave", "create", "delete"]
def test_compositor_enter_rejects_control_with_mismatched_layer_names() -> None:
layer = TraceLayer()
compositor: Compositor[str, object] = Compositor(
layers=OrderedDict([("trace", layer)])
)
compositor_control = CompositorControl(["other"])
async def run() -> None:
async with compositor.enter(compositor_control):
pass
try:
asyncio.run(run())
except ValueError as e:
assert str(e) == (
"CompositorControl layer names must match compositor layers in order. "
"Expected [trace], got [other]."
)
else:
raise AssertionError("Expected ValueError.")