diff --git a/dify-agent/examples/agenton_basics.py b/dify-agent/examples/agenton_basics.py index a78d979d92..ea02d6931a 100644 --- a/dify-agent/examples/agenton_basics.py +++ b/dify-agent/examples/agenton_basics.py @@ -9,7 +9,7 @@ from inspect import signature from typing_extensions import override from agenton.compositor import Compositor, CompositorLayerConfig -from agenton.layers import LayerContextSignal, LayerDeps, NoLayerDeps, PlainLayer +from agenton.layers import LayerControl, LayerDeps, NoLayerDeps, PlainLayer from agenton_collections.plain import DynamicToolsLayer, ObjectLayer, ToolsLayer, with_object @@ -41,19 +41,19 @@ class TraceLayer(PlainLayer[NoLayerDeps]): events: list[str] = field(default_factory=list) @override - async def on_context_create(self, signal: LayerContextSignal) -> None: + async def on_context_create(self, control: LayerControl) -> None: self.events.append("create") @override - async def on_context_temporarily_leave(self, signal: LayerContextSignal) -> None: - self.events.append("temporary_leave") + async def on_context_tmp_leave(self, control: LayerControl) -> None: + self.events.append("tmp_leave") @override - async def on_context_reenter(self, signal: LayerContextSignal) -> None: + async def on_context_reenter(self, control: LayerControl) -> None: self.events.append("reenter") @override - async def on_context_delete(self, signal: LayerContextSignal) -> None: + async def on_context_delete(self, control: LayerControl) -> None: self.events.append("delete") @@ -128,9 +128,9 @@ async def main() -> None: print(f"- {tool.__name__}{signature(tool)}") print([tool("layer composition") for tool in compositor.tools]) - async with compositor.context() as context: - context.temporary_leave = True - async with compositor.context(): + async with compositor.enter() as lifecycle_control: + lifecycle_control.tmp_leave = True + async with compositor.enter(lifecycle_control): pass print("\nLifecycle:", trace.events) diff --git a/dify-agent/src/agenton/compositor/__init__.py b/dify-agent/src/agenton/compositor/__init__.py index 4ff76fec4f..5f0ac9201e 100644 --- a/dify-agent/src/agenton/compositor/__init__.py +++ b/dify-agent/src/agenton/compositor/__init__.py @@ -12,14 +12,16 @@ layer names as values. Prompt aggregation depends on insertion order: prefix prompts are collected from first to last layer, while suffix prompts are collected in reverse. -``Compositor.context`` enters layer contexts in compositor order and exits them -in reverse order through ``AsyncExitStack``. It yields per-layer lifecycle -signals so callers can mark individual layers, or all layers, as temporarily -leaving. +``Compositor.enter`` enters layers in compositor order and exits them in +reverse order through ``AsyncExitStack``. It accepts an optional +``CompositorControl`` whose keys must match the compositor layer names. When +omitted, one is created from the compositor's layer names. Reuse the same +``CompositorControl`` after setting ``tmp_leave`` to reenter those layer +contexts. """ from collections import OrderedDict -from collections.abc import AsyncIterator +from collections.abc import AsyncIterator, Iterable from contextlib import AsyncExitStack, asynccontextmanager from dataclasses import dataclass, field from importlib import import_module @@ -28,7 +30,7 @@ from typing import TYPE_CHECKING, Annotated, Any, Mapping, cast from pydantic import AfterValidator, BaseModel, ConfigDict, Field, JsonValue from typing_extensions import Self -from agenton.layers.base import Layer, LayerContextSignal +from agenton.layers.base import Layer, LayerControl class ImportedLayerConfig(BaseModel): @@ -127,21 +129,27 @@ def _validate_compositor_config_input(value: CompositorConfigValue) -> Composito return _validate_config_model_input(CompositorConfig, value) -@dataclass(slots=True) -class CompositorContext: - """Signal slots for layer contexts entered by a compositor.""" +class CompositorControl: + """External controls for layer entry contexts entered by a compositor.""" - signals: OrderedDict[str, LayerContextSignal] + __slots__ = ("layer_controls",) + + layer_controls: OrderedDict[str, LayerControl] + + def __init__(self, layer_names: Iterable[str]) -> None: + self.layer_controls = OrderedDict( + (layer_name, LayerControl()) for layer_name in layer_names + ) @property - def temporary_leave(self) -> bool: - """Whether any entered layer is currently marked for temporary leave.""" - return any(signal.temporary_leave for signal in self.signals.values()) + def tmp_leave(self) -> bool: + """Whether any entered layer control is marked for temporary leave.""" + return any(control.tmp_leave for control in self.layer_controls.values()) - @temporary_leave.setter - def temporary_leave(self, value: bool) -> None: - for signal in self.signals.values(): - signal.temporary_leave = value + @tmp_leave.setter + def tmp_leave(self, value: bool) -> None: + for control in self.layer_controls.values(): + control.tmp_leave = value @dataclass(kw_only=True) @@ -192,15 +200,33 @@ class Compositor[PromptT, ToolT]: self._deps_bound = True @asynccontextmanager - async def context(self) -> AsyncIterator[CompositorContext]: - """Enter each layer context in order and yield their signal slots.""" + async def enter( + self, + control: CompositorControl | None = None, + ) -> AsyncIterator[CompositorControl]: + """Enter each layer context in order and yield compositor control.""" if not self._deps_bound: raise RuntimeError("Compositor deps must be bound before entering context.") - signals: OrderedDict[str, LayerContextSignal] = OrderedDict() + + if control is None: + control = CompositorControl(self.layers) + self._validate_control(control) + async with AsyncExitStack() as stack: for layer_name, layer in self.layers.items(): - signals[layer_name] = await stack.enter_async_context(layer.context()) - yield CompositorContext(signals=signals) + await stack.enter_async_context(layer.enter(control.layer_controls[layer_name])) + yield control + + def _validate_control(self, control: CompositorControl) -> None: + expected_layer_names = tuple(self.layers) + actual_layer_names = tuple(control.layer_controls) + if actual_layer_names != expected_layer_names: + expected = ", ".join(expected_layer_names) + actual = ", ".join(actual_layer_names) + raise ValueError( + "CompositorControl layer names must match compositor layers in order. " + f"Expected [{expected}], got [{actual}]." + ) @property def prompts(self) -> list[PromptT]: @@ -224,7 +250,7 @@ __all__ = [ "CompositorConfig", "CompositorConfigValue", "CompositorLayerConfigInput", - "CompositorContext", + "CompositorControl", "CompositorLayerConfig", "CompositorLayerConfigValue", "ImportedLayerConfig", diff --git a/dify-agent/src/agenton/layers/__init__.py b/dify-agent/src/agenton/layers/__init__.py index 94481c3d2f..47c92108d3 100644 --- a/dify-agent/src/agenton/layers/__init__.py +++ b/dify-agent/src/agenton/layers/__init__.py @@ -5,7 +5,7 @@ families while keeping concrete reusable layers in ``agenton_collections``. """ -from agenton.layers.base import Layer, LayerContextSignal, LayerDeps, NoLayerDeps +from agenton.layers.base import Layer, LayerControl, LayerDeps, NoLayerDeps from agenton.layers.types import ( PlainLayer, PlainPrompt, @@ -17,8 +17,8 @@ from agenton.layers.types import ( __all__ = [ "Layer", - "LayerContextSignal", "LayerDeps", + "LayerControl", "NoLayerDeps", "PlainLayer", "PlainPrompt", diff --git a/dify-agent/src/agenton/layers/base.py b/dify-agent/src/agenton/layers/base.py index 1263b5a6f3..48023ff337 100644 --- a/dify-agent/src/agenton/layers/base.py +++ b/dify-agent/src/agenton/layers/base.py @@ -11,10 +11,10 @@ inheritance patterns. implementations should treat ``self.deps`` as unavailable until a compositor or caller has resolved and bound dependencies. -Layer async contexts use a bool signal to distinguish permanent exits from -temporary exits. A normal first entry runs create logic and a normal exit runs -delete logic; when the signal is set, exit runs temporary-leave logic and the -next entry runs reenter logic. +Layer async entry uses a caller-provided bool control to distinguish permanent +exits from temporary exits. The control is also the external lifecycle state: +reuse a ``tmp_leave`` control to reenter, or pass a fresh control to start from +create logic. ``Layer`` is framework-neutral over prompt and tool item types. Typed families such as ``agenton.layers.types.PlainLayer`` bind those generic slots to a @@ -72,14 +72,17 @@ class NoLayerDeps(LayerDeps): @dataclass(slots=True) -class LayerContextSignal: - """Signal slot exposed inside a layer context. +class LayerControl: + """Control slot passed into a layer entry context. - Set ``temporary_leave`` before leaving the context to run temporary-leave - logic instead of delete logic. A later entry will then run reenter logic. + ``Layer.enter`` requires the caller to provide this object. Set + ``tmp_leave`` before leaving the context to run temporary-leave logic + instead of delete logic. Reusing that same control on a later entry will + consume ``tmp_leave`` and run reenter logic; using a fresh control starts + from create logic. """ - temporary_leave: bool = False + tmp_leave: bool = False @dataclass(frozen=True, slots=True) @@ -97,13 +100,12 @@ class Layer[DepsT: LayerDeps, PromptT, ToolT](ABC): properties. They declare required dependencies in the ``DepsT`` container rather than by accepting dependencies in ``__init__``. The default async context manager handles create, delete, temporary-leave, and reenter - transitions; layers can override ``context`` when they need to wrap extra + transitions; layers can override ``enter`` when they need to wrap extra runtime resources. """ deps_type: type[DepsT] deps: DepsT - _temporarily_left: bool def __init_subclass__(cls) -> None: super().__init_subclass__() @@ -149,46 +151,43 @@ class Layer[DepsT: LayerDeps, PromptT, ToolT](ABC): resolved_deps[name] = deps[name] self.deps = self.deps_type(**resolved_deps) - def context(self) -> AbstractAsyncContextManager[LayerContextSignal]: - """Return the layer's async context manager. + def enter(self, control: LayerControl) -> AbstractAsyncContextManager[None]: + """Return the layer's async entry context manager. - The yielded ``LayerContextSignal`` is the signal slot available to code - inside the context. Subclasses can override this to wrap extra async - resources around ``self.lifecycle_context()``. + ``control`` is the lifecycle control slot for this entry. Subclasses can + override this to wrap extra async resources around + ``self.lifecycle_enter(control)``. """ - return self.lifecycle_context() + return self.lifecycle_enter(control) @asynccontextmanager - async def lifecycle_context(self) -> AsyncIterator[LayerContextSignal]: + async def lifecycle_enter(self, control: LayerControl) -> AsyncIterator[None]: """Run the default create/reenter and delete/temporary-leave lifecycle.""" - signal = LayerContextSignal() - was_temporarily_left = getattr(self, "_temporarily_left", False) - self._temporarily_left = False - if was_temporarily_left: - await self.on_context_reenter(signal) + was_tmp_left = control.tmp_leave + control.tmp_leave = False + if was_tmp_left: + await self.on_context_reenter(control) else: - await self.on_context_create(signal) + await self.on_context_create(control) try: - yield signal + yield finally: - if signal.temporary_leave: - await self.on_context_temporarily_leave(signal) - self._temporarily_left = True + if control.tmp_leave: + await self.on_context_tmp_leave(control) else: - await self.on_context_delete(signal) - self._temporarily_left = False + await self.on_context_delete(control) - async def on_context_create(self, signal: LayerContextSignal) -> None: + async def on_context_create(self, control: LayerControl) -> None: """Run when the layer context is entered from a non-temporary state.""" - async def on_context_delete(self, signal: LayerContextSignal) -> None: - """Run when the layer context exits without a temporary-leave signal.""" + async def on_context_delete(self, control: LayerControl) -> None: + """Run when the layer context exits without ``tmp_leave`` set.""" - async def on_context_temporarily_leave(self, signal: LayerContextSignal) -> None: - """Run when the layer context exits with ``temporary_leave`` set.""" + async def on_context_tmp_leave(self, control: LayerControl) -> None: + """Run when the layer context exits with ``tmp_leave`` set.""" - async def on_context_reenter(self, signal: LayerContextSignal) -> None: + async def on_context_reenter(self, control: LayerControl) -> None: """Run when the layer context enters after a temporary leave.""" @property diff --git a/dify-agent/tests/unit/agenton/compositor/__init__.py b/dify-agent/tests/unit/agenton/compositor/__init__.py new file mode 100644 index 0000000000..8b13789179 --- /dev/null +++ b/dify-agent/tests/unit/agenton/compositor/__init__.py @@ -0,0 +1 @@ + diff --git a/dify-agent/tests/unit/agenton/compositor/test_enter.py b/dify-agent/tests/unit/agenton/compositor/test_enter.py new file mode 100644 index 0000000000..208c541b0c --- /dev/null +++ b/dify-agent/tests/unit/agenton/compositor/test_enter.py @@ -0,0 +1,99 @@ +import asyncio +from collections import OrderedDict +from dataclasses import dataclass, field + +from typing_extensions import override + +from agenton.compositor import Compositor, CompositorControl +from agenton.layers import LayerControl, NoLayerDeps, PlainLayer + + +@dataclass(slots=True) +class TraceLayer(PlainLayer[NoLayerDeps]): + """Layer that records lifecycle events observable to tests.""" + + events: list[str] = field(default_factory=list) + + @override + async def on_context_create(self, control: LayerControl) -> None: + self.events.append("create") + + @override + async def on_context_tmp_leave(self, control: LayerControl) -> None: + self.events.append("tmp_leave") + + @override + async def on_context_reenter(self, control: LayerControl) -> None: + self.events.append("reenter") + + @override + async def on_context_delete(self, control: LayerControl) -> None: + self.events.append("delete") + + +def test_compositor_enter_creates_control_and_applies_tmp_leave_to_all_layers() -> None: + first_layer = TraceLayer() + second_layer = TraceLayer() + compositor: Compositor[str, object] = Compositor( + layers=OrderedDict( + [ + ("first", first_layer), + ("second", second_layer), + ] + ) + ) + compositor_control = CompositorControl(compositor.layers) + + async def run() -> None: + async with compositor.enter(compositor_control) as control: + assert control is compositor_control + assert list(control.layer_controls) == ["first", "second"] + control.tmp_leave = True + + async with compositor.enter(compositor_control): + pass + + asyncio.run(run()) + + assert first_layer.events == ["create", "tmp_leave", "reenter", "delete"] + assert second_layer.events == ["create", "tmp_leave", "reenter", "delete"] + + +def test_compositor_enter_does_not_store_tmp_leave_on_layer() -> None: + layer = TraceLayer() + compositor: Compositor[str, object] = Compositor( + layers=OrderedDict([("trace", layer)]) + ) + + async def run() -> None: + async with compositor.enter() as control: + control.tmp_leave = True + + async with compositor.enter(): + pass + + asyncio.run(run()) + + assert layer.events == ["create", "tmp_leave", "create", "delete"] + + +def test_compositor_enter_rejects_control_with_mismatched_layer_names() -> None: + layer = TraceLayer() + compositor: Compositor[str, object] = Compositor( + layers=OrderedDict([("trace", layer)]) + ) + compositor_control = CompositorControl(["other"]) + + async def run() -> None: + async with compositor.enter(compositor_control): + pass + + try: + asyncio.run(run()) + except ValueError as e: + assert str(e) == ( + "CompositorControl layer names must match compositor layers in order. " + "Expected [trace], got [other]." + ) + else: + raise AssertionError("Expected ValueError.")