mirror of
https://github.com/langgenius/dify.git
synced 2026-05-13 08:57:28 +08:00
refactor the design of tmp leave control
This commit is contained in:
parent
5a7eb7fdb6
commit
5dfd318907
@ -9,7 +9,7 @@ from inspect import signature
|
|||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
from agenton.compositor import Compositor, CompositorLayerConfig
|
from agenton.compositor import Compositor, CompositorLayerConfig
|
||||||
from agenton.layers import LayerContextSignal, LayerDeps, NoLayerDeps, PlainLayer
|
from agenton.layers import LayerControl, LayerDeps, NoLayerDeps, PlainLayer
|
||||||
from agenton_collections.plain import DynamicToolsLayer, ObjectLayer, ToolsLayer, with_object
|
from agenton_collections.plain import DynamicToolsLayer, ObjectLayer, ToolsLayer, with_object
|
||||||
|
|
||||||
|
|
||||||
@ -41,19 +41,19 @@ class TraceLayer(PlainLayer[NoLayerDeps]):
|
|||||||
events: list[str] = field(default_factory=list)
|
events: list[str] = field(default_factory=list)
|
||||||
|
|
||||||
@override
|
@override
|
||||||
async def on_context_create(self, signal: LayerContextSignal) -> None:
|
async def on_context_create(self, control: LayerControl) -> None:
|
||||||
self.events.append("create")
|
self.events.append("create")
|
||||||
|
|
||||||
@override
|
@override
|
||||||
async def on_context_temporarily_leave(self, signal: LayerContextSignal) -> None:
|
async def on_context_tmp_leave(self, control: LayerControl) -> None:
|
||||||
self.events.append("temporary_leave")
|
self.events.append("tmp_leave")
|
||||||
|
|
||||||
@override
|
@override
|
||||||
async def on_context_reenter(self, signal: LayerContextSignal) -> None:
|
async def on_context_reenter(self, control: LayerControl) -> None:
|
||||||
self.events.append("reenter")
|
self.events.append("reenter")
|
||||||
|
|
||||||
@override
|
@override
|
||||||
async def on_context_delete(self, signal: LayerContextSignal) -> None:
|
async def on_context_delete(self, control: LayerControl) -> None:
|
||||||
self.events.append("delete")
|
self.events.append("delete")
|
||||||
|
|
||||||
|
|
||||||
@ -128,9 +128,9 @@ async def main() -> None:
|
|||||||
print(f"- {tool.__name__}{signature(tool)}")
|
print(f"- {tool.__name__}{signature(tool)}")
|
||||||
print([tool("layer composition") for tool in compositor.tools])
|
print([tool("layer composition") for tool in compositor.tools])
|
||||||
|
|
||||||
async with compositor.context() as context:
|
async with compositor.enter() as lifecycle_control:
|
||||||
context.temporary_leave = True
|
lifecycle_control.tmp_leave = True
|
||||||
async with compositor.context():
|
async with compositor.enter(lifecycle_control):
|
||||||
pass
|
pass
|
||||||
print("\nLifecycle:", trace.events)
|
print("\nLifecycle:", trace.events)
|
||||||
|
|
||||||
|
|||||||
@ -12,14 +12,16 @@ layer names as values. Prompt aggregation depends on insertion order: prefix
|
|||||||
prompts are collected from first to last layer, while suffix prompts are
|
prompts are collected from first to last layer, while suffix prompts are
|
||||||
collected in reverse.
|
collected in reverse.
|
||||||
|
|
||||||
``Compositor.context`` enters layer contexts in compositor order and exits them
|
``Compositor.enter`` enters layers in compositor order and exits them in
|
||||||
in reverse order through ``AsyncExitStack``. It yields per-layer lifecycle
|
reverse order through ``AsyncExitStack``. It accepts an optional
|
||||||
signals so callers can mark individual layers, or all layers, as temporarily
|
``CompositorControl`` whose keys must match the compositor layer names. When
|
||||||
leaving.
|
omitted, one is created from the compositor's layer names. Reuse the same
|
||||||
|
``CompositorControl`` after setting ``tmp_leave`` to reenter those layer
|
||||||
|
contexts.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from collections.abc import AsyncIterator
|
from collections.abc import AsyncIterator, Iterable
|
||||||
from contextlib import AsyncExitStack, asynccontextmanager
|
from contextlib import AsyncExitStack, asynccontextmanager
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from importlib import import_module
|
from importlib import import_module
|
||||||
@ -28,7 +30,7 @@ from typing import TYPE_CHECKING, Annotated, Any, Mapping, cast
|
|||||||
from pydantic import AfterValidator, BaseModel, ConfigDict, Field, JsonValue
|
from pydantic import AfterValidator, BaseModel, ConfigDict, Field, JsonValue
|
||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
|
|
||||||
from agenton.layers.base import Layer, LayerContextSignal
|
from agenton.layers.base import Layer, LayerControl
|
||||||
|
|
||||||
|
|
||||||
class ImportedLayerConfig(BaseModel):
|
class ImportedLayerConfig(BaseModel):
|
||||||
@ -127,21 +129,27 @@ def _validate_compositor_config_input(value: CompositorConfigValue) -> Composito
|
|||||||
return _validate_config_model_input(CompositorConfig, value)
|
return _validate_config_model_input(CompositorConfig, value)
|
||||||
|
|
||||||
|
|
||||||
@dataclass(slots=True)
|
class CompositorControl:
|
||||||
class CompositorContext:
|
"""External controls for layer entry contexts entered by a compositor."""
|
||||||
"""Signal slots for layer contexts entered by a compositor."""
|
|
||||||
|
|
||||||
signals: OrderedDict[str, LayerContextSignal]
|
__slots__ = ("layer_controls",)
|
||||||
|
|
||||||
|
layer_controls: OrderedDict[str, LayerControl]
|
||||||
|
|
||||||
|
def __init__(self, layer_names: Iterable[str]) -> None:
|
||||||
|
self.layer_controls = OrderedDict(
|
||||||
|
(layer_name, LayerControl()) for layer_name in layer_names
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def temporary_leave(self) -> bool:
|
def tmp_leave(self) -> bool:
|
||||||
"""Whether any entered layer is currently marked for temporary leave."""
|
"""Whether any entered layer control is marked for temporary leave."""
|
||||||
return any(signal.temporary_leave for signal in self.signals.values())
|
return any(control.tmp_leave for control in self.layer_controls.values())
|
||||||
|
|
||||||
@temporary_leave.setter
|
@tmp_leave.setter
|
||||||
def temporary_leave(self, value: bool) -> None:
|
def tmp_leave(self, value: bool) -> None:
|
||||||
for signal in self.signals.values():
|
for control in self.layer_controls.values():
|
||||||
signal.temporary_leave = value
|
control.tmp_leave = value
|
||||||
|
|
||||||
|
|
||||||
@dataclass(kw_only=True)
|
@dataclass(kw_only=True)
|
||||||
@ -192,15 +200,33 @@ class Compositor[PromptT, ToolT]:
|
|||||||
self._deps_bound = True
|
self._deps_bound = True
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def context(self) -> AsyncIterator[CompositorContext]:
|
async def enter(
|
||||||
"""Enter each layer context in order and yield their signal slots."""
|
self,
|
||||||
|
control: CompositorControl | None = None,
|
||||||
|
) -> AsyncIterator[CompositorControl]:
|
||||||
|
"""Enter each layer context in order and yield compositor control."""
|
||||||
if not self._deps_bound:
|
if not self._deps_bound:
|
||||||
raise RuntimeError("Compositor deps must be bound before entering context.")
|
raise RuntimeError("Compositor deps must be bound before entering context.")
|
||||||
signals: OrderedDict[str, LayerContextSignal] = OrderedDict()
|
|
||||||
|
if control is None:
|
||||||
|
control = CompositorControl(self.layers)
|
||||||
|
self._validate_control(control)
|
||||||
|
|
||||||
async with AsyncExitStack() as stack:
|
async with AsyncExitStack() as stack:
|
||||||
for layer_name, layer in self.layers.items():
|
for layer_name, layer in self.layers.items():
|
||||||
signals[layer_name] = await stack.enter_async_context(layer.context())
|
await stack.enter_async_context(layer.enter(control.layer_controls[layer_name]))
|
||||||
yield CompositorContext(signals=signals)
|
yield control
|
||||||
|
|
||||||
|
def _validate_control(self, control: CompositorControl) -> None:
|
||||||
|
expected_layer_names = tuple(self.layers)
|
||||||
|
actual_layer_names = tuple(control.layer_controls)
|
||||||
|
if actual_layer_names != expected_layer_names:
|
||||||
|
expected = ", ".join(expected_layer_names)
|
||||||
|
actual = ", ".join(actual_layer_names)
|
||||||
|
raise ValueError(
|
||||||
|
"CompositorControl layer names must match compositor layers in order. "
|
||||||
|
f"Expected [{expected}], got [{actual}]."
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def prompts(self) -> list[PromptT]:
|
def prompts(self) -> list[PromptT]:
|
||||||
@ -224,7 +250,7 @@ __all__ = [
|
|||||||
"CompositorConfig",
|
"CompositorConfig",
|
||||||
"CompositorConfigValue",
|
"CompositorConfigValue",
|
||||||
"CompositorLayerConfigInput",
|
"CompositorLayerConfigInput",
|
||||||
"CompositorContext",
|
"CompositorControl",
|
||||||
"CompositorLayerConfig",
|
"CompositorLayerConfig",
|
||||||
"CompositorLayerConfigValue",
|
"CompositorLayerConfigValue",
|
||||||
"ImportedLayerConfig",
|
"ImportedLayerConfig",
|
||||||
|
|||||||
@ -5,7 +5,7 @@
|
|||||||
families while keeping concrete reusable layers in ``agenton_collections``.
|
families while keeping concrete reusable layers in ``agenton_collections``.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from agenton.layers.base import Layer, LayerContextSignal, LayerDeps, NoLayerDeps
|
from agenton.layers.base import Layer, LayerControl, LayerDeps, NoLayerDeps
|
||||||
from agenton.layers.types import (
|
from agenton.layers.types import (
|
||||||
PlainLayer,
|
PlainLayer,
|
||||||
PlainPrompt,
|
PlainPrompt,
|
||||||
@ -17,8 +17,8 @@ from agenton.layers.types import (
|
|||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Layer",
|
"Layer",
|
||||||
"LayerContextSignal",
|
|
||||||
"LayerDeps",
|
"LayerDeps",
|
||||||
|
"LayerControl",
|
||||||
"NoLayerDeps",
|
"NoLayerDeps",
|
||||||
"PlainLayer",
|
"PlainLayer",
|
||||||
"PlainPrompt",
|
"PlainPrompt",
|
||||||
|
|||||||
@ -11,10 +11,10 @@ inheritance patterns.
|
|||||||
implementations should treat ``self.deps`` as unavailable until a compositor or
|
implementations should treat ``self.deps`` as unavailable until a compositor or
|
||||||
caller has resolved and bound dependencies.
|
caller has resolved and bound dependencies.
|
||||||
|
|
||||||
Layer async contexts use a bool signal to distinguish permanent exits from
|
Layer async entry uses a caller-provided bool control to distinguish permanent
|
||||||
temporary exits. A normal first entry runs create logic and a normal exit runs
|
exits from temporary exits. The control is also the external lifecycle state:
|
||||||
delete logic; when the signal is set, exit runs temporary-leave logic and the
|
reuse a ``tmp_leave`` control to reenter, or pass a fresh control to start from
|
||||||
next entry runs reenter logic.
|
create logic.
|
||||||
|
|
||||||
``Layer`` is framework-neutral over prompt and tool item types. Typed families
|
``Layer`` is framework-neutral over prompt and tool item types. Typed families
|
||||||
such as ``agenton.layers.types.PlainLayer`` bind those generic slots to a
|
such as ``agenton.layers.types.PlainLayer`` bind those generic slots to a
|
||||||
@ -72,14 +72,17 @@ class NoLayerDeps(LayerDeps):
|
|||||||
|
|
||||||
|
|
||||||
@dataclass(slots=True)
|
@dataclass(slots=True)
|
||||||
class LayerContextSignal:
|
class LayerControl:
|
||||||
"""Signal slot exposed inside a layer context.
|
"""Control slot passed into a layer entry context.
|
||||||
|
|
||||||
Set ``temporary_leave`` before leaving the context to run temporary-leave
|
``Layer.enter`` requires the caller to provide this object. Set
|
||||||
logic instead of delete logic. A later entry will then run reenter logic.
|
``tmp_leave`` before leaving the context to run temporary-leave logic
|
||||||
|
instead of delete logic. Reusing that same control on a later entry will
|
||||||
|
consume ``tmp_leave`` and run reenter logic; using a fresh control starts
|
||||||
|
from create logic.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
temporary_leave: bool = False
|
tmp_leave: bool = False
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True, slots=True)
|
@dataclass(frozen=True, slots=True)
|
||||||
@ -97,13 +100,12 @@ class Layer[DepsT: LayerDeps, PromptT, ToolT](ABC):
|
|||||||
properties. They declare required dependencies in the ``DepsT`` container
|
properties. They declare required dependencies in the ``DepsT`` container
|
||||||
rather than by accepting dependencies in ``__init__``. The default async
|
rather than by accepting dependencies in ``__init__``. The default async
|
||||||
context manager handles create, delete, temporary-leave, and reenter
|
context manager handles create, delete, temporary-leave, and reenter
|
||||||
transitions; layers can override ``context`` when they need to wrap extra
|
transitions; layers can override ``enter`` when they need to wrap extra
|
||||||
runtime resources.
|
runtime resources.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
deps_type: type[DepsT]
|
deps_type: type[DepsT]
|
||||||
deps: DepsT
|
deps: DepsT
|
||||||
_temporarily_left: bool
|
|
||||||
|
|
||||||
def __init_subclass__(cls) -> None:
|
def __init_subclass__(cls) -> None:
|
||||||
super().__init_subclass__()
|
super().__init_subclass__()
|
||||||
@ -149,46 +151,43 @@ class Layer[DepsT: LayerDeps, PromptT, ToolT](ABC):
|
|||||||
resolved_deps[name] = deps[name]
|
resolved_deps[name] = deps[name]
|
||||||
self.deps = self.deps_type(**resolved_deps)
|
self.deps = self.deps_type(**resolved_deps)
|
||||||
|
|
||||||
def context(self) -> AbstractAsyncContextManager[LayerContextSignal]:
|
def enter(self, control: LayerControl) -> AbstractAsyncContextManager[None]:
|
||||||
"""Return the layer's async context manager.
|
"""Return the layer's async entry context manager.
|
||||||
|
|
||||||
The yielded ``LayerContextSignal`` is the signal slot available to code
|
``control`` is the lifecycle control slot for this entry. Subclasses can
|
||||||
inside the context. Subclasses can override this to wrap extra async
|
override this to wrap extra async resources around
|
||||||
resources around ``self.lifecycle_context()``.
|
``self.lifecycle_enter(control)``.
|
||||||
"""
|
"""
|
||||||
return self.lifecycle_context()
|
return self.lifecycle_enter(control)
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifecycle_context(self) -> AsyncIterator[LayerContextSignal]:
|
async def lifecycle_enter(self, control: LayerControl) -> AsyncIterator[None]:
|
||||||
"""Run the default create/reenter and delete/temporary-leave lifecycle."""
|
"""Run the default create/reenter and delete/temporary-leave lifecycle."""
|
||||||
signal = LayerContextSignal()
|
was_tmp_left = control.tmp_leave
|
||||||
was_temporarily_left = getattr(self, "_temporarily_left", False)
|
control.tmp_leave = False
|
||||||
self._temporarily_left = False
|
if was_tmp_left:
|
||||||
if was_temporarily_left:
|
await self.on_context_reenter(control)
|
||||||
await self.on_context_reenter(signal)
|
|
||||||
else:
|
else:
|
||||||
await self.on_context_create(signal)
|
await self.on_context_create(control)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
yield signal
|
yield
|
||||||
finally:
|
finally:
|
||||||
if signal.temporary_leave:
|
if control.tmp_leave:
|
||||||
await self.on_context_temporarily_leave(signal)
|
await self.on_context_tmp_leave(control)
|
||||||
self._temporarily_left = True
|
|
||||||
else:
|
else:
|
||||||
await self.on_context_delete(signal)
|
await self.on_context_delete(control)
|
||||||
self._temporarily_left = False
|
|
||||||
|
|
||||||
async def on_context_create(self, signal: LayerContextSignal) -> None:
|
async def on_context_create(self, control: LayerControl) -> None:
|
||||||
"""Run when the layer context is entered from a non-temporary state."""
|
"""Run when the layer context is entered from a non-temporary state."""
|
||||||
|
|
||||||
async def on_context_delete(self, signal: LayerContextSignal) -> None:
|
async def on_context_delete(self, control: LayerControl) -> None:
|
||||||
"""Run when the layer context exits without a temporary-leave signal."""
|
"""Run when the layer context exits without ``tmp_leave`` set."""
|
||||||
|
|
||||||
async def on_context_temporarily_leave(self, signal: LayerContextSignal) -> None:
|
async def on_context_tmp_leave(self, control: LayerControl) -> None:
|
||||||
"""Run when the layer context exits with ``temporary_leave`` set."""
|
"""Run when the layer context exits with ``tmp_leave`` set."""
|
||||||
|
|
||||||
async def on_context_reenter(self, signal: LayerContextSignal) -> None:
|
async def on_context_reenter(self, control: LayerControl) -> None:
|
||||||
"""Run when the layer context enters after a temporary leave."""
|
"""Run when the layer context enters after a temporary leave."""
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|||||||
1
dify-agent/tests/unit/agenton/compositor/__init__.py
Normal file
1
dify-agent/tests/unit/agenton/compositor/__init__.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
|
||||||
99
dify-agent/tests/unit/agenton/compositor/test_enter.py
Normal file
99
dify-agent/tests/unit/agenton/compositor/test_enter.py
Normal file
@ -0,0 +1,99 @@
|
|||||||
|
import asyncio
|
||||||
|
from collections import OrderedDict
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
|
from agenton.compositor import Compositor, CompositorControl
|
||||||
|
from agenton.layers import LayerControl, NoLayerDeps, PlainLayer
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
|
||||||
|
class TraceLayer(PlainLayer[NoLayerDeps]):
|
||||||
|
"""Layer that records lifecycle events observable to tests."""
|
||||||
|
|
||||||
|
events: list[str] = field(default_factory=list)
|
||||||
|
|
||||||
|
@override
|
||||||
|
async def on_context_create(self, control: LayerControl) -> None:
|
||||||
|
self.events.append("create")
|
||||||
|
|
||||||
|
@override
|
||||||
|
async def on_context_tmp_leave(self, control: LayerControl) -> None:
|
||||||
|
self.events.append("tmp_leave")
|
||||||
|
|
||||||
|
@override
|
||||||
|
async def on_context_reenter(self, control: LayerControl) -> None:
|
||||||
|
self.events.append("reenter")
|
||||||
|
|
||||||
|
@override
|
||||||
|
async def on_context_delete(self, control: LayerControl) -> None:
|
||||||
|
self.events.append("delete")
|
||||||
|
|
||||||
|
|
||||||
|
def test_compositor_enter_creates_control_and_applies_tmp_leave_to_all_layers() -> None:
|
||||||
|
first_layer = TraceLayer()
|
||||||
|
second_layer = TraceLayer()
|
||||||
|
compositor: Compositor[str, object] = Compositor(
|
||||||
|
layers=OrderedDict(
|
||||||
|
[
|
||||||
|
("first", first_layer),
|
||||||
|
("second", second_layer),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
compositor_control = CompositorControl(compositor.layers)
|
||||||
|
|
||||||
|
async def run() -> None:
|
||||||
|
async with compositor.enter(compositor_control) as control:
|
||||||
|
assert control is compositor_control
|
||||||
|
assert list(control.layer_controls) == ["first", "second"]
|
||||||
|
control.tmp_leave = True
|
||||||
|
|
||||||
|
async with compositor.enter(compositor_control):
|
||||||
|
pass
|
||||||
|
|
||||||
|
asyncio.run(run())
|
||||||
|
|
||||||
|
assert first_layer.events == ["create", "tmp_leave", "reenter", "delete"]
|
||||||
|
assert second_layer.events == ["create", "tmp_leave", "reenter", "delete"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_compositor_enter_does_not_store_tmp_leave_on_layer() -> None:
|
||||||
|
layer = TraceLayer()
|
||||||
|
compositor: Compositor[str, object] = Compositor(
|
||||||
|
layers=OrderedDict([("trace", layer)])
|
||||||
|
)
|
||||||
|
|
||||||
|
async def run() -> None:
|
||||||
|
async with compositor.enter() as control:
|
||||||
|
control.tmp_leave = True
|
||||||
|
|
||||||
|
async with compositor.enter():
|
||||||
|
pass
|
||||||
|
|
||||||
|
asyncio.run(run())
|
||||||
|
|
||||||
|
assert layer.events == ["create", "tmp_leave", "create", "delete"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_compositor_enter_rejects_control_with_mismatched_layer_names() -> None:
|
||||||
|
layer = TraceLayer()
|
||||||
|
compositor: Compositor[str, object] = Compositor(
|
||||||
|
layers=OrderedDict([("trace", layer)])
|
||||||
|
)
|
||||||
|
compositor_control = CompositorControl(["other"])
|
||||||
|
|
||||||
|
async def run() -> None:
|
||||||
|
async with compositor.enter(compositor_control):
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
asyncio.run(run())
|
||||||
|
except ValueError as e:
|
||||||
|
assert str(e) == (
|
||||||
|
"CompositorControl layer names must match compositor layers in order. "
|
||||||
|
"Expected [trace], got [other]."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise AssertionError("Expected ValueError.")
|
||||||
Loading…
Reference in New Issue
Block a user