mirror of
https://github.com/langgenius/dify.git
synced 2026-05-09 21:28:25 +08:00
refactor the design of tmp leave control
This commit is contained in:
parent
5a7eb7fdb6
commit
5dfd318907
@ -9,7 +9,7 @@ from inspect import signature
|
||||
from typing_extensions import override
|
||||
|
||||
from agenton.compositor import Compositor, CompositorLayerConfig
|
||||
from agenton.layers import LayerContextSignal, LayerDeps, NoLayerDeps, PlainLayer
|
||||
from agenton.layers import LayerControl, LayerDeps, NoLayerDeps, PlainLayer
|
||||
from agenton_collections.plain import DynamicToolsLayer, ObjectLayer, ToolsLayer, with_object
|
||||
|
||||
|
||||
@ -41,19 +41,19 @@ class TraceLayer(PlainLayer[NoLayerDeps]):
|
||||
events: list[str] = field(default_factory=list)
|
||||
|
||||
@override
|
||||
async def on_context_create(self, signal: LayerContextSignal) -> None:
|
||||
async def on_context_create(self, control: LayerControl) -> None:
|
||||
self.events.append("create")
|
||||
|
||||
@override
|
||||
async def on_context_temporarily_leave(self, signal: LayerContextSignal) -> None:
|
||||
self.events.append("temporary_leave")
|
||||
async def on_context_tmp_leave(self, control: LayerControl) -> None:
|
||||
self.events.append("tmp_leave")
|
||||
|
||||
@override
|
||||
async def on_context_reenter(self, signal: LayerContextSignal) -> None:
|
||||
async def on_context_reenter(self, control: LayerControl) -> None:
|
||||
self.events.append("reenter")
|
||||
|
||||
@override
|
||||
async def on_context_delete(self, signal: LayerContextSignal) -> None:
|
||||
async def on_context_delete(self, control: LayerControl) -> None:
|
||||
self.events.append("delete")
|
||||
|
||||
|
||||
@ -128,9 +128,9 @@ async def main() -> None:
|
||||
print(f"- {tool.__name__}{signature(tool)}")
|
||||
print([tool("layer composition") for tool in compositor.tools])
|
||||
|
||||
async with compositor.context() as context:
|
||||
context.temporary_leave = True
|
||||
async with compositor.context():
|
||||
async with compositor.enter() as lifecycle_control:
|
||||
lifecycle_control.tmp_leave = True
|
||||
async with compositor.enter(lifecycle_control):
|
||||
pass
|
||||
print("\nLifecycle:", trace.events)
|
||||
|
||||
|
||||
@ -12,14 +12,16 @@ layer names as values. Prompt aggregation depends on insertion order: prefix
|
||||
prompts are collected from first to last layer, while suffix prompts are
|
||||
collected in reverse.
|
||||
|
||||
``Compositor.context`` enters layer contexts in compositor order and exits them
|
||||
in reverse order through ``AsyncExitStack``. It yields per-layer lifecycle
|
||||
signals so callers can mark individual layers, or all layers, as temporarily
|
||||
leaving.
|
||||
``Compositor.enter`` enters layers in compositor order and exits them in
|
||||
reverse order through ``AsyncExitStack``. It accepts an optional
|
||||
``CompositorControl`` whose keys must match the compositor layer names. When
|
||||
omitted, one is created from the compositor's layer names. Reuse the same
|
||||
``CompositorControl`` after setting ``tmp_leave`` to reenter those layer
|
||||
contexts.
|
||||
"""
|
||||
|
||||
from collections import OrderedDict
|
||||
from collections.abc import AsyncIterator
|
||||
from collections.abc import AsyncIterator, Iterable
|
||||
from contextlib import AsyncExitStack, asynccontextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from importlib import import_module
|
||||
@ -28,7 +30,7 @@ from typing import TYPE_CHECKING, Annotated, Any, Mapping, cast
|
||||
from pydantic import AfterValidator, BaseModel, ConfigDict, Field, JsonValue
|
||||
from typing_extensions import Self
|
||||
|
||||
from agenton.layers.base import Layer, LayerContextSignal
|
||||
from agenton.layers.base import Layer, LayerControl
|
||||
|
||||
|
||||
class ImportedLayerConfig(BaseModel):
|
||||
@ -127,21 +129,27 @@ def _validate_compositor_config_input(value: CompositorConfigValue) -> Composito
|
||||
return _validate_config_model_input(CompositorConfig, value)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class CompositorContext:
|
||||
"""Signal slots for layer contexts entered by a compositor."""
|
||||
class CompositorControl:
|
||||
"""External controls for layer entry contexts entered by a compositor."""
|
||||
|
||||
signals: OrderedDict[str, LayerContextSignal]
|
||||
__slots__ = ("layer_controls",)
|
||||
|
||||
layer_controls: OrderedDict[str, LayerControl]
|
||||
|
||||
def __init__(self, layer_names: Iterable[str]) -> None:
|
||||
self.layer_controls = OrderedDict(
|
||||
(layer_name, LayerControl()) for layer_name in layer_names
|
||||
)
|
||||
|
||||
@property
|
||||
def temporary_leave(self) -> bool:
|
||||
"""Whether any entered layer is currently marked for temporary leave."""
|
||||
return any(signal.temporary_leave for signal in self.signals.values())
|
||||
def tmp_leave(self) -> bool:
|
||||
"""Whether any entered layer control is marked for temporary leave."""
|
||||
return any(control.tmp_leave for control in self.layer_controls.values())
|
||||
|
||||
@temporary_leave.setter
|
||||
def temporary_leave(self, value: bool) -> None:
|
||||
for signal in self.signals.values():
|
||||
signal.temporary_leave = value
|
||||
@tmp_leave.setter
|
||||
def tmp_leave(self, value: bool) -> None:
|
||||
for control in self.layer_controls.values():
|
||||
control.tmp_leave = value
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
@ -192,15 +200,33 @@ class Compositor[PromptT, ToolT]:
|
||||
self._deps_bound = True
|
||||
|
||||
@asynccontextmanager
|
||||
async def context(self) -> AsyncIterator[CompositorContext]:
|
||||
"""Enter each layer context in order and yield their signal slots."""
|
||||
async def enter(
|
||||
self,
|
||||
control: CompositorControl | None = None,
|
||||
) -> AsyncIterator[CompositorControl]:
|
||||
"""Enter each layer context in order and yield compositor control."""
|
||||
if not self._deps_bound:
|
||||
raise RuntimeError("Compositor deps must be bound before entering context.")
|
||||
signals: OrderedDict[str, LayerContextSignal] = OrderedDict()
|
||||
|
||||
if control is None:
|
||||
control = CompositorControl(self.layers)
|
||||
self._validate_control(control)
|
||||
|
||||
async with AsyncExitStack() as stack:
|
||||
for layer_name, layer in self.layers.items():
|
||||
signals[layer_name] = await stack.enter_async_context(layer.context())
|
||||
yield CompositorContext(signals=signals)
|
||||
await stack.enter_async_context(layer.enter(control.layer_controls[layer_name]))
|
||||
yield control
|
||||
|
||||
def _validate_control(self, control: CompositorControl) -> None:
|
||||
expected_layer_names = tuple(self.layers)
|
||||
actual_layer_names = tuple(control.layer_controls)
|
||||
if actual_layer_names != expected_layer_names:
|
||||
expected = ", ".join(expected_layer_names)
|
||||
actual = ", ".join(actual_layer_names)
|
||||
raise ValueError(
|
||||
"CompositorControl layer names must match compositor layers in order. "
|
||||
f"Expected [{expected}], got [{actual}]."
|
||||
)
|
||||
|
||||
@property
|
||||
def prompts(self) -> list[PromptT]:
|
||||
@ -224,7 +250,7 @@ __all__ = [
|
||||
"CompositorConfig",
|
||||
"CompositorConfigValue",
|
||||
"CompositorLayerConfigInput",
|
||||
"CompositorContext",
|
||||
"CompositorControl",
|
||||
"CompositorLayerConfig",
|
||||
"CompositorLayerConfigValue",
|
||||
"ImportedLayerConfig",
|
||||
|
||||
@ -5,7 +5,7 @@
|
||||
families while keeping concrete reusable layers in ``agenton_collections``.
|
||||
"""
|
||||
|
||||
from agenton.layers.base import Layer, LayerContextSignal, LayerDeps, NoLayerDeps
|
||||
from agenton.layers.base import Layer, LayerControl, LayerDeps, NoLayerDeps
|
||||
from agenton.layers.types import (
|
||||
PlainLayer,
|
||||
PlainPrompt,
|
||||
@ -17,8 +17,8 @@ from agenton.layers.types import (
|
||||
|
||||
__all__ = [
|
||||
"Layer",
|
||||
"LayerContextSignal",
|
||||
"LayerDeps",
|
||||
"LayerControl",
|
||||
"NoLayerDeps",
|
||||
"PlainLayer",
|
||||
"PlainPrompt",
|
||||
|
||||
@ -11,10 +11,10 @@ inheritance patterns.
|
||||
implementations should treat ``self.deps`` as unavailable until a compositor or
|
||||
caller has resolved and bound dependencies.
|
||||
|
||||
Layer async contexts use a bool signal to distinguish permanent exits from
|
||||
temporary exits. A normal first entry runs create logic and a normal exit runs
|
||||
delete logic; when the signal is set, exit runs temporary-leave logic and the
|
||||
next entry runs reenter logic.
|
||||
Layer async entry uses a caller-provided bool control to distinguish permanent
|
||||
exits from temporary exits. The control is also the external lifecycle state:
|
||||
reuse a ``tmp_leave`` control to reenter, or pass a fresh control to start from
|
||||
create logic.
|
||||
|
||||
``Layer`` is framework-neutral over prompt and tool item types. Typed families
|
||||
such as ``agenton.layers.types.PlainLayer`` bind those generic slots to a
|
||||
@ -72,14 +72,17 @@ class NoLayerDeps(LayerDeps):
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class LayerContextSignal:
|
||||
"""Signal slot exposed inside a layer context.
|
||||
class LayerControl:
|
||||
"""Control slot passed into a layer entry context.
|
||||
|
||||
Set ``temporary_leave`` before leaving the context to run temporary-leave
|
||||
logic instead of delete logic. A later entry will then run reenter logic.
|
||||
``Layer.enter`` requires the caller to provide this object. Set
|
||||
``tmp_leave`` before leaving the context to run temporary-leave logic
|
||||
instead of delete logic. Reusing that same control on a later entry will
|
||||
consume ``tmp_leave`` and run reenter logic; using a fresh control starts
|
||||
from create logic.
|
||||
"""
|
||||
|
||||
temporary_leave: bool = False
|
||||
tmp_leave: bool = False
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
@ -97,13 +100,12 @@ class Layer[DepsT: LayerDeps, PromptT, ToolT](ABC):
|
||||
properties. They declare required dependencies in the ``DepsT`` container
|
||||
rather than by accepting dependencies in ``__init__``. The default async
|
||||
context manager handles create, delete, temporary-leave, and reenter
|
||||
transitions; layers can override ``context`` when they need to wrap extra
|
||||
transitions; layers can override ``enter`` when they need to wrap extra
|
||||
runtime resources.
|
||||
"""
|
||||
|
||||
deps_type: type[DepsT]
|
||||
deps: DepsT
|
||||
_temporarily_left: bool
|
||||
|
||||
def __init_subclass__(cls) -> None:
|
||||
super().__init_subclass__()
|
||||
@ -149,46 +151,43 @@ class Layer[DepsT: LayerDeps, PromptT, ToolT](ABC):
|
||||
resolved_deps[name] = deps[name]
|
||||
self.deps = self.deps_type(**resolved_deps)
|
||||
|
||||
def context(self) -> AbstractAsyncContextManager[LayerContextSignal]:
|
||||
"""Return the layer's async context manager.
|
||||
def enter(self, control: LayerControl) -> AbstractAsyncContextManager[None]:
|
||||
"""Return the layer's async entry context manager.
|
||||
|
||||
The yielded ``LayerContextSignal`` is the signal slot available to code
|
||||
inside the context. Subclasses can override this to wrap extra async
|
||||
resources around ``self.lifecycle_context()``.
|
||||
``control`` is the lifecycle control slot for this entry. Subclasses can
|
||||
override this to wrap extra async resources around
|
||||
``self.lifecycle_enter(control)``.
|
||||
"""
|
||||
return self.lifecycle_context()
|
||||
return self.lifecycle_enter(control)
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifecycle_context(self) -> AsyncIterator[LayerContextSignal]:
|
||||
async def lifecycle_enter(self, control: LayerControl) -> AsyncIterator[None]:
|
||||
"""Run the default create/reenter and delete/temporary-leave lifecycle."""
|
||||
signal = LayerContextSignal()
|
||||
was_temporarily_left = getattr(self, "_temporarily_left", False)
|
||||
self._temporarily_left = False
|
||||
if was_temporarily_left:
|
||||
await self.on_context_reenter(signal)
|
||||
was_tmp_left = control.tmp_leave
|
||||
control.tmp_leave = False
|
||||
if was_tmp_left:
|
||||
await self.on_context_reenter(control)
|
||||
else:
|
||||
await self.on_context_create(signal)
|
||||
await self.on_context_create(control)
|
||||
|
||||
try:
|
||||
yield signal
|
||||
yield
|
||||
finally:
|
||||
if signal.temporary_leave:
|
||||
await self.on_context_temporarily_leave(signal)
|
||||
self._temporarily_left = True
|
||||
if control.tmp_leave:
|
||||
await self.on_context_tmp_leave(control)
|
||||
else:
|
||||
await self.on_context_delete(signal)
|
||||
self._temporarily_left = False
|
||||
await self.on_context_delete(control)
|
||||
|
||||
async def on_context_create(self, signal: LayerContextSignal) -> None:
|
||||
async def on_context_create(self, control: LayerControl) -> None:
|
||||
"""Run when the layer context is entered from a non-temporary state."""
|
||||
|
||||
async def on_context_delete(self, signal: LayerContextSignal) -> None:
|
||||
"""Run when the layer context exits without a temporary-leave signal."""
|
||||
async def on_context_delete(self, control: LayerControl) -> None:
|
||||
"""Run when the layer context exits without ``tmp_leave`` set."""
|
||||
|
||||
async def on_context_temporarily_leave(self, signal: LayerContextSignal) -> None:
|
||||
"""Run when the layer context exits with ``temporary_leave`` set."""
|
||||
async def on_context_tmp_leave(self, control: LayerControl) -> None:
|
||||
"""Run when the layer context exits with ``tmp_leave`` set."""
|
||||
|
||||
async def on_context_reenter(self, signal: LayerContextSignal) -> None:
|
||||
async def on_context_reenter(self, control: LayerControl) -> None:
|
||||
"""Run when the layer context enters after a temporary leave."""
|
||||
|
||||
@property
|
||||
|
||||
1
dify-agent/tests/unit/agenton/compositor/__init__.py
Normal file
1
dify-agent/tests/unit/agenton/compositor/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
|
||||
99
dify-agent/tests/unit/agenton/compositor/test_enter.py
Normal file
99
dify-agent/tests/unit/agenton/compositor/test_enter.py
Normal file
@ -0,0 +1,99 @@
|
||||
import asyncio
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from agenton.compositor import Compositor, CompositorControl
|
||||
from agenton.layers import LayerControl, NoLayerDeps, PlainLayer
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class TraceLayer(PlainLayer[NoLayerDeps]):
|
||||
"""Layer that records lifecycle events observable to tests."""
|
||||
|
||||
events: list[str] = field(default_factory=list)
|
||||
|
||||
@override
|
||||
async def on_context_create(self, control: LayerControl) -> None:
|
||||
self.events.append("create")
|
||||
|
||||
@override
|
||||
async def on_context_tmp_leave(self, control: LayerControl) -> None:
|
||||
self.events.append("tmp_leave")
|
||||
|
||||
@override
|
||||
async def on_context_reenter(self, control: LayerControl) -> None:
|
||||
self.events.append("reenter")
|
||||
|
||||
@override
|
||||
async def on_context_delete(self, control: LayerControl) -> None:
|
||||
self.events.append("delete")
|
||||
|
||||
|
||||
def test_compositor_enter_creates_control_and_applies_tmp_leave_to_all_layers() -> None:
|
||||
first_layer = TraceLayer()
|
||||
second_layer = TraceLayer()
|
||||
compositor: Compositor[str, object] = Compositor(
|
||||
layers=OrderedDict(
|
||||
[
|
||||
("first", first_layer),
|
||||
("second", second_layer),
|
||||
]
|
||||
)
|
||||
)
|
||||
compositor_control = CompositorControl(compositor.layers)
|
||||
|
||||
async def run() -> None:
|
||||
async with compositor.enter(compositor_control) as control:
|
||||
assert control is compositor_control
|
||||
assert list(control.layer_controls) == ["first", "second"]
|
||||
control.tmp_leave = True
|
||||
|
||||
async with compositor.enter(compositor_control):
|
||||
pass
|
||||
|
||||
asyncio.run(run())
|
||||
|
||||
assert first_layer.events == ["create", "tmp_leave", "reenter", "delete"]
|
||||
assert second_layer.events == ["create", "tmp_leave", "reenter", "delete"]
|
||||
|
||||
|
||||
def test_compositor_enter_does_not_store_tmp_leave_on_layer() -> None:
|
||||
layer = TraceLayer()
|
||||
compositor: Compositor[str, object] = Compositor(
|
||||
layers=OrderedDict([("trace", layer)])
|
||||
)
|
||||
|
||||
async def run() -> None:
|
||||
async with compositor.enter() as control:
|
||||
control.tmp_leave = True
|
||||
|
||||
async with compositor.enter():
|
||||
pass
|
||||
|
||||
asyncio.run(run())
|
||||
|
||||
assert layer.events == ["create", "tmp_leave", "create", "delete"]
|
||||
|
||||
|
||||
def test_compositor_enter_rejects_control_with_mismatched_layer_names() -> None:
|
||||
layer = TraceLayer()
|
||||
compositor: Compositor[str, object] = Compositor(
|
||||
layers=OrderedDict([("trace", layer)])
|
||||
)
|
||||
compositor_control = CompositorControl(["other"])
|
||||
|
||||
async def run() -> None:
|
||||
async with compositor.enter(compositor_control):
|
||||
pass
|
||||
|
||||
try:
|
||||
asyncio.run(run())
|
||||
except ValueError as e:
|
||||
assert str(e) == (
|
||||
"CompositorControl layer names must match compositor layers in order. "
|
||||
"Expected [trace], got [other]."
|
||||
)
|
||||
else:
|
||||
raise AssertionError("Expected ValueError.")
|
||||
Loading…
Reference in New Issue
Block a user