From f316d19be60e95c39cb013f376ab19b823bdbe94 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=9B=90=E7=B2=92=20Yanli?= Date: Thu, 7 May 2026 22:42:07 +0800 Subject: [PATCH] add schema-backed agenton sessions --- dify-agent/README.md | 3 + dify-agent/docs/agenton/README.md | 65 +++ dify-agent/examples/agenton/basics.py | 67 ++-- .../examples/agenton/pydantic_ai_bridge.py | 55 ++- .../examples/agenton/session_snapshot.py | 72 ++++ dify-agent/src/agenton/compositor/__init__.py | 378 +++++++++++++----- dify-agent/src/agenton/layers/__init__.py | 15 +- dify-agent/src/agenton/layers/base.py | 267 ++++++++++--- dify-agent/src/agenton/layers/types.py | 43 +- .../layers/plain/__init__.py | 3 +- .../agenton_collections/layers/plain/basic.py | 37 +- .../compositor/test_builder_snapshot.py | 257 ++++++++++++ .../local/agenton/compositor/test_enter.py | 29 +- .../agenton/layers/test_schema_inference.py | 94 +++++ .../local/examples/test_agenton_examples.py | 8 + 15 files changed, 1147 insertions(+), 246 deletions(-) create mode 100644 dify-agent/docs/agenton/README.md create mode 100644 dify-agent/examples/agenton/session_snapshot.py create mode 100644 dify-agent/tests/local/agenton/compositor/test_builder_snapshot.py create mode 100644 dify-agent/tests/local/agenton/layers/test_schema_inference.py diff --git a/dify-agent/README.md b/dify-agent/README.md index e69de29bb2..08ba651da4 100644 --- a/dify-agent/README.md +++ b/dify-agent/README.md @@ -0,0 +1,3 @@ +# Dify Agent + +Agenton documentation lives in [`docs/agenton/`](docs/agenton/). diff --git a/dify-agent/docs/agenton/README.md b/dify-agent/docs/agenton/README.md new file mode 100644 index 0000000000..874377aa89 --- /dev/null +++ b/dify-agent/docs/agenton/README.md @@ -0,0 +1,65 @@ +# Agenton configuration and sessions + +Agenton composes shared `Layer` instances into a named graph. Treat layer +instances as reusable capability definitions: config and dependency declarations +belong on the layer class or instance, while per-session runtime values belong +on the `LayerControl` created for that layer in a `CompositorSession`. + +## Config, runtime state, and runtime handles + +- **Config** is serializable graph input. Config-constructible layers declare a + `type_id` and a Pydantic `config_type`; builders validate node config before + calling `Layer.from_config(validated_config)`. +- **Runtime state** is serializable per-layer/per-session state. Layers declare a + Pydantic `runtime_state_type`; session snapshots persist this model with + `model_dump(mode="json")`. +- **Runtime handles** are live Python objects such as clients, open files, or + process handles. Layers declare a Pydantic `runtime_handles_type` with + `arbitrary_types_allowed=True`. Handles are never serialized; resume hooks + should rehydrate them from runtime state. + +`Layer.__init_subclass__` infers `deps_type`, `config_type`, +`runtime_state_type`, and `runtime_handles_type` from generic base arguments +when possible. For example, `PlainLayer[NoLayerDeps, MyConfig, MyState, +MyHandles]` automatically installs those Pydantic schemas. Omitted schema slots +default to `EmptyLayerConfig`, `EmptyRuntimeState`, and `EmptyRuntimeHandles`. +Lifecycle hooks can annotate controls as `LayerControl[MyState, MyHandles]` to +get static checking and IDE completion for runtime state and handles. + +## Registry and builder + +Register config-constructible layers manually: + +```python +registry = LayerRegistry() +registry.register_layer(PromptLayer) # uses PromptLayer.type_id == "plain.prompt" +``` + +Use `CompositorBuilder` to mix serializable config nodes with live instances: + +```python +compositor = ( + CompositorBuilder(registry) + .add_config({"layers": [{"name": "prompt", "type": "plain.prompt", "config": {"prefix": "Hi"}}]}) + .add_instance(name="profile", layer=ObjectLayer(profile)) + .build() +) +``` + +Use `.add_instance()` for layers that require Python objects or callables, such +as `ObjectLayer`, `ToolsLayer`, and dynamic tool layers. + +## Session snapshot and restore + +`Compositor.snapshot_session(session)` serializes non-active sessions, including +layer lifecycle state and runtime state. It rejects active sessions because live +handles cannot be snapshotted safely. Restore with +`Compositor.session_from_snapshot(snapshot)`; restored controls validate runtime +state with each layer schema and initialize empty runtime handles. Suspended +sessions resume through `on_context_resume`, where handles should be hydrated +from the restored runtime state. + +Create sessions with `Compositor.new_session()` or +`Compositor.session_from_snapshot()`. `Compositor.enter()` validates that every +session control uses the target layer's runtime state and handle schemas before +any lifecycle hook runs. diff --git a/dify-agent/examples/agenton/basics.py b/dify-agent/examples/agenton/basics.py index 219469b13e..22d501c793 100644 --- a/dify-agent/examples/agenton/basics.py +++ b/dify-agent/examples/agenton/basics.py @@ -8,10 +8,9 @@ from inspect import signature from typing_extensions import override -from agenton.compositor import Compositor, CompositorLayerConfig +from agenton.compositor import CompositorBuilder, LayerRegistry from agenton.layers import LayerControl, LayerDeps, NoLayerDeps, PlainLayer -from agenton.layers.types import PlainPromptType, PlainToolType -from agenton_collections.layers.plain import DynamicToolsLayer, ObjectLayer, ToolsLayer, with_object +from agenton_collections.layers.plain import DynamicToolsLayer, ObjectLayer, PromptLayer, ToolsLayer, with_object @dataclass(frozen=True, slots=True) @@ -75,51 +74,41 @@ async def main() -> None: ) trace = TraceLayer() - compositor = Compositor[PlainPromptType, PlainToolType].from_config( - { - "layers": [ - { - "name": "base_prompt", - "layer": { - "import_path": "agenton_collections.layers.plain:PromptLayer", + registry = LayerRegistry() + registry.register_layer(PromptLayer) + compositor = ( + CompositorBuilder(registry) + .add_config( + { + "layers": [ + { + "name": "base_prompt", + "type": "plain.prompt", "config": { "prefix": "Use config dicts for serializable layers.", "suffix": "Before finalizing, make the result easy to scan.", }, }, - }, - { - "name": "extra_prompt", - "layer": { - "import_path": "agenton_collections.layers.plain:PromptLayer", + { + "name": "extra_prompt", + "type": "plain.prompt", "config": { "prefix": "Use constructed instances for objects, local code, and callables.", }, }, - }, - CompositorLayerConfig( - name="profile", - layer=ObjectLayer[AgentProfile](profile), - ), - CompositorLayerConfig( - name="profile_prompt", - # deps maps dependency field names to layer names only when - # they differ. - # deps={"profile": "profile"}, - layer=ProfilePromptLayer(), - ), - CompositorLayerConfig( - name="tools", - layer=ToolsLayer(tool_entries=(count_words,)), - ), - CompositorLayerConfig( - name="dynamic_tools", - deps={"object_layer": "profile"}, - layer=DynamicToolsLayer[AgentProfile](tool_entries=(write_tagline,)), - ), - CompositorLayerConfig(name="trace", layer=trace), - ] - }, + ] + } + ) + .add_instance(name="profile", layer=ObjectLayer[AgentProfile](profile)) + .add_instance(name="profile_prompt", layer=ProfilePromptLayer()) + .add_instance(name="tools", layer=ToolsLayer(tool_entries=(count_words,))) + .add_instance( + name="dynamic_tools", + deps={"object_layer": "profile"}, + layer=DynamicToolsLayer[AgentProfile](tool_entries=(write_tagline,)), + ) + .add_instance(name="trace", layer=trace) + .build() ) print("Prompts:") diff --git a/dify-agent/examples/agenton/pydantic_ai_bridge.py b/dify-agent/examples/agenton/pydantic_ai_bridge.py index be76bb2e4a..730fceac76 100644 --- a/dify-agent/examples/agenton/pydantic_ai_bridge.py +++ b/dify-agent/examples/agenton/pydantic_ai_bridge.py @@ -12,9 +12,8 @@ from pydantic_ai.messages import BuiltinToolCallPart, ModelMessage, ToolCallPart from pydantic_ai.models.openai import OpenAIChatModel # pyright: ignore[reportDeprecated] from pydantic_ai.models.test import TestModel -from agenton.compositor import Compositor, CompositorLayerConfig -from agenton.layers.types import AllPromptTypes, AllToolTypes, PydanticAIPrompt, PydanticAITool -from agenton_collections.layers.plain import ObjectLayer, ToolsLayer +from agenton.compositor import CompositorBuilder, LayerRegistry +from agenton_collections.layers.plain import ObjectLayer, PromptLayer, ToolsLayer from agenton_collections.layers.pydantic_ai import PydanticAIBridgeLayer from agenton_collections.transformers import PYDANTIC_AI_TRANSFORMERS @@ -55,40 +54,32 @@ async def main() -> None: tool_entries=(write_tagline,), ) - compositor = Compositor[ - PydanticAIPrompt[object], - PydanticAITool[object], - AllPromptTypes, - AllToolTypes, - ].from_config( - { - "layers": [ - { - "name": "base_prompt", - "layer": { - "import_path": "agenton_collections.layers.plain:PromptLayer", + registry = LayerRegistry() + registry.register_layer(PromptLayer) + compositor = ( + CompositorBuilder(registry) + .add_config( + { + "layers": [ + { + "name": "base_prompt", + "type": "plain.prompt", "config": { "prefix": "Use the available tools before answering.", "suffix": "Return concise, inspectable output.", }, }, - }, - CompositorLayerConfig( - name="profile", - layer=ObjectLayer[AgentProfile](profile), - ), - CompositorLayerConfig( - name="plain_tools", - layer=ToolsLayer(tool_entries=(count_words,)), - ), - CompositorLayerConfig( - name="pydantic_ai_bridge", - deps={"object_layer": "profile"}, - layer=pydantic_ai_bridge, - ), - ] - }, - **PYDANTIC_AI_TRANSFORMERS, + ] + } + ) + .add_instance(name="profile", layer=ObjectLayer[AgentProfile](profile)) + .add_instance(name="plain_tools", layer=ToolsLayer(tool_entries=(count_words,))) + .add_instance( + name="pydantic_ai_bridge", + deps={"object_layer": "profile"}, + layer=pydantic_ai_bridge, + ) + .build(**PYDANTIC_AI_TRANSFORMERS) ) async with compositor.enter(): diff --git a/dify-agent/examples/agenton/session_snapshot.py b/dify-agent/examples/agenton/session_snapshot.py new file mode 100644 index 0000000000..d282c02597 --- /dev/null +++ b/dify-agent/examples/agenton/session_snapshot.py @@ -0,0 +1,72 @@ +"""Run with: uv run --project dify-agent python examples/agenton/session_snapshot.py.""" + +from __future__ import annotations + +import asyncio +from collections import OrderedDict +from dataclasses import dataclass +from typing import ClassVar + +from pydantic import BaseModel, ConfigDict +from typing_extensions import override + +from agenton.compositor import Compositor +from agenton.layers import LayerControl, NoLayerDeps, PlainLayer, PlainPromptType, PlainToolType + + +class ConnectionState(BaseModel): + connection_id: str = "demo-connection" + + model_config = ConfigDict(extra="forbid", validate_assignment=True) + + +class ConnectionHandle: + def __init__(self, connection_id: str) -> None: + self.connection_id = connection_id + + +class ConnectionHandles(BaseModel): + connection: ConnectionHandle | None = None + + model_config = ConfigDict(extra="forbid", validate_assignment=True, arbitrary_types_allowed=True) + + +@dataclass(slots=True) +class ConnectionLayer(PlainLayer[NoLayerDeps]): + runtime_state_type: ClassVar[type[BaseModel]] = ConnectionState + runtime_handles_type: ClassVar[type[BaseModel]] = ConnectionHandles + + @override + async def on_context_create(self, control: LayerControl) -> None: + assert isinstance(control.runtime_state, ConnectionState) + assert isinstance(control.runtime_handles, ConnectionHandles) + control.runtime_handles.connection = ConnectionHandle(control.runtime_state.connection_id) + + @override + async def on_context_resume(self, control: LayerControl) -> None: + assert isinstance(control.runtime_state, ConnectionState) + assert isinstance(control.runtime_handles, ConnectionHandles) + control.runtime_handles.connection = ConnectionHandle(f"restored:{control.runtime_state.connection_id}") + + +async def main() -> None: + compositor: Compositor[PlainPromptType, PlainToolType] = Compositor( + layers=OrderedDict([("connection", ConnectionLayer())]) + ) + session = compositor.new_session() + async with compositor.enter(session) as active_session: + active_session.suspend_on_exit() + + snapshot = compositor.snapshot_session(session) + print("Snapshot:", snapshot.model_dump(mode="json")) + + restored = compositor.session_from_snapshot(snapshot) + async with compositor.enter(restored): + handles = restored.layer("connection").runtime_handles + assert isinstance(handles, ConnectionHandles) + assert handles.connection is not None + print("Rehydrated handle:", handles.connection.connection_id) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/dify-agent/src/agenton/compositor/__init__.py b/dify-agent/src/agenton/compositor/__init__.py index 3560f29e9b..cb23e69cc8 100644 --- a/dify-agent/src/agenton/compositor/__init__.py +++ b/dify-agent/src/agenton/compositor/__init__.py @@ -16,6 +16,10 @@ layer names as values. Prompt aggregation depends on insertion order: prefix prompts are collected from first to last layer, while suffix prompts are collected in reverse. +Serializable graph config uses registry type ids rather than import paths. +``CompositorBuilder`` resolves config nodes through ``LayerRegistry`` and can +mix those nodes with live layer instances for Python objects and callables. + ``Compositor.enter`` enters layers in compositor order and exits them in reverse order through ``AsyncExitStack``. It accepts an optional ``CompositorSession`` whose layer controls must match the compositor layer names and order. When @@ -30,13 +34,12 @@ returns those wrapped items unchanged. """ from collections import OrderedDict -from collections.abc import AsyncIterator, Callable, Iterable, Sequence +from collections.abc import AsyncIterator, Callable, Iterable, Mapping as MappingABC, Sequence from contextlib import AsyncExitStack, asynccontextmanager from dataclasses import dataclass, field -from importlib import import_module -from typing import TYPE_CHECKING, Annotated, Any, Generic, Mapping, TypedDict, cast +from typing import Any, Generic, Mapping, TypedDict, cast -from pydantic import AfterValidator, BaseModel, ConfigDict, Field, JsonValue +from pydantic import BaseModel, ConfigDict, Field, JsonValue from typing_extensions import Self, TypeVar from agenton.layers.base import Layer, LayerControl, LifecycleState @@ -48,31 +51,6 @@ LayerPromptT = TypeVar("LayerPromptT", default=AllPromptTypes) LayerToolT = TypeVar("LayerToolT", default=AllToolTypes) -class ImportedLayerConfig(BaseModel): - """Config for constructing one layer from an import path.""" - - import_path: str - config: Any = None - - model_config = ConfigDict(arbitrary_types_allowed=True) - - def create_layer(self) -> Layer[Any, Any, Any]: - """Import the target layer class and create it from config.""" - try: - import_module_name, import_target = self.import_path.rsplit(":", 1) - except ValueError as e: - raise ValueError( - f"Invalid import string '{self.import_path}'. " - "It should be in the format 'module:ClassName'." - ) from e - - layer_t = getattr(import_module(import_module_name), import_target) - if not isinstance(layer_t, type) or not issubclass(layer_t, Layer): - raise TypeError(f"Imported target '{self.import_path}' must be a Layer subclass.") - return layer_t.from_config(config=self.config) - - -LayerSpec = Layer[Any, Any, Any] | ImportedLayerConfig type CompositorTransformer[InputT, OutputT] = Callable[[Sequence[InputT]], Sequence[OutputT]] @@ -98,53 +76,30 @@ def _validate_config_model_input[ModelT: BaseModel]( return model_type.model_validate(value) -class CompositorLayerConfig(BaseModel): - """Config entry for one named layer in a compositor. - - ``layer`` may be either an already constructed layer instance or an - ``ImportedLayerConfig``. Direct instances are already initialized, so config - for imported layers lives inside ``ImportedLayerConfig`` instead of beside - the graph node fields. - """ +class LayerNodeConfig(BaseModel): + """Serializable config for one registry-backed layer node.""" name: str + type: str + config: JsonValue = Field(default_factory=dict) deps: Mapping[str, str] = Field(default_factory=dict) - layer: LayerSpec + metadata: Mapping[str, JsonValue] = Field(default_factory=dict) - model_config = ConfigDict(arbitrary_types_allowed=True) - - def create_layer(self) -> Layer[Any, Any, Any]: - """Create or return the configured layer instance.""" - if isinstance(self.layer, Layer): - return self.layer - return self.layer.create_layer() - - -type CompositorLayerConfigValue = _ConfigModelValue[CompositorLayerConfig] - - -def _validate_layer_config_input(value: CompositorLayerConfigValue) -> CompositorLayerConfig: - return _validate_config_model_input(CompositorLayerConfig, value) - - -type CompositorLayerConfigInput = Annotated[ - CompositorLayerConfigValue, - AfterValidator(_validate_layer_config_input), -] + model_config = ConfigDict(extra="forbid") class CompositorConfig(BaseModel): """Serializable config for constructing a compositor graph. - ``layers`` accepts ready-made ``CompositorLayerConfig`` instances, raw JSON - values, or JSON-encoded strings/bytes. After validation, callers always see - normalized ``CompositorLayerConfig`` objects. + The graph references layer implementations by registry type id. Live Python + objects and callables are intentionally excluded; compose those with + ``CompositorBuilder.add_instance``. """ - if TYPE_CHECKING: - layers: list[CompositorLayerConfig] - else: - layers: list[CompositorLayerConfigInput] + schema_version: int = 1 + layers: list[LayerNodeConfig] + + model_config = ConfigDict(extra="forbid") type CompositorConfigValue = _ConfigModelValue[CompositorConfig] | Mapping[str, object] @@ -154,24 +109,88 @@ def _validate_compositor_config_input(value: CompositorConfigValue) -> Composito return _validate_config_model_input(CompositorConfig, value) +@dataclass(frozen=True, slots=True) +class LayerDescriptor: + """Registry descriptor inferred from a layer class.""" + + type_id: str + layer_type: type[Layer[Any, Any, Any, Any, Any, Any]] + config_type: type[BaseModel] + runtime_state_type: type[BaseModel] + runtime_handles_type: type[BaseModel] + + +class LayerRegistry: + """Manual registry for config-constructible layer classes. + + Registration infers config and runtime schemas from layer class attributes. + A registered layer must have a type id, either declared as ``type_id`` on the + class or supplied to ``register_layer``. + """ + + __slots__ = ("_descriptors",) + + _descriptors: dict[str, LayerDescriptor] + + def __init__(self) -> None: + self._descriptors = {} + + def register_layer( + self, + layer_type: type[Layer[Any, Any, Any, Any, Any, Any]], + *, + type_id: str | None = None, + ) -> None: + """Register ``layer_type`` under its inferred or explicit type id.""" + resolved_type_id = type_id or layer_type.type_id + if resolved_type_id is not None and not isinstance(resolved_type_id, str): + raise TypeError(f"Layer type id for '{layer_type.__qualname__}' must be a string.") + if resolved_type_id is None or not resolved_type_id: + raise ValueError(f"Layer '{layer_type.__qualname__}' must declare a type_id or be registered with one.") + if resolved_type_id in self._descriptors: + raise ValueError(f"Layer type id '{resolved_type_id}' is already registered.") + self._descriptors[resolved_type_id] = LayerDescriptor( + type_id=resolved_type_id, + layer_type=layer_type, + config_type=layer_type.config_type, + runtime_state_type=layer_type.runtime_state_type, + runtime_handles_type=layer_type.runtime_handles_type, + ) + + def resolve(self, type_id: str) -> LayerDescriptor: + """Return the descriptor for ``type_id`` or raise ``KeyError``.""" + try: + return self._descriptors[type_id] + except KeyError as e: + raise KeyError(f"Layer type id '{type_id}' is not registered.") from e + + def descriptors(self) -> Mapping[str, LayerDescriptor]: + """Return registered descriptors keyed by type id.""" + return dict(self._descriptors) + + class CompositorSession: """External lifecycle session for layer contexts entered by a compositor. A session owns one ``LayerControl`` per compositor layer name, preserving - compositor order. Broadcast methods are convenience APIs for setting every - layer's per-entry exit intent; ``layer`` allows explicit per-layer control - when callers need partial suspend/delete behavior. A mixed session with any - closed layer cannot be entered again because compositor entry is all-or-none. + compositor order. Controls must be created from the matching layer schemas; + prefer ``Compositor.new_session`` or ``Compositor.session_from_snapshot`` for + public session construction. Broadcast methods are convenience APIs for + setting every layer's per-entry exit intent; ``layer`` allows explicit + per-layer control when callers need partial suspend/delete behavior. A mixed + session with any closed layer cannot be entered again because compositor + entry is all-or-none. """ __slots__ = ("layer_controls",) layer_controls: OrderedDict[str, LayerControl] - def __init__(self, layer_names: Iterable[str]) -> None: - self.layer_controls = OrderedDict( - (layer_name, LayerControl()) for layer_name in layer_names - ) + def __init__(self, layer_names: Iterable[str] | Mapping[str, LayerControl]) -> None: + if isinstance(layer_names, MappingABC): + self.layer_controls = OrderedDict(layer_names.items()) + return + self.layer_controls = OrderedDict((layer_name, LayerControl()) for layer_name in layer_names) def suspend_on_exit(self) -> None: """Request suspend behavior for every layer when this entry exits.""" @@ -188,6 +207,124 @@ class CompositorSession: return self.layer_controls[name] +class LayerSessionSnapshot(BaseModel): + """Serializable snapshot for one layer control.""" + + name: str + state: LifecycleState + runtime_state: dict[str, JsonValue] + + model_config = ConfigDict(extra="forbid") + + +class CompositorSessionSnapshot(BaseModel): + """Serializable compositor session snapshot. + + Snapshots include runtime state only. Live runtime handles are intentionally + excluded and must be rehydrated by resume hooks using runtime state. + """ + + schema_version: int = 1 + layers: list[LayerSessionSnapshot] + + model_config = ConfigDict(extra="forbid") + + +@dataclass(frozen=True, slots=True) +class _LayerBuildEntry: + name: str + layer: Layer[Any, Any, Any, Any, Any, Any] + deps: Mapping[str, str] + + +class CompositorBuilder: + """Build compositors from registry config nodes and live instances.""" + + __slots__ = ("_registry", "_entries") + + _registry: LayerRegistry + _entries: list[_LayerBuildEntry] + + def __init__(self, registry: LayerRegistry) -> None: + self._registry = registry + self._entries = [] + + def add_config(self, config: CompositorConfigValue) -> Self: + """Add all layers from a serializable compositor config.""" + conf = _validate_compositor_config_input(config) + if conf.schema_version != 1: + raise ValueError(f"Unsupported compositor config schema_version: {conf.schema_version}.") + for layer_conf in conf.layers: + self.add_config_layer( + name=layer_conf.name, + type=layer_conf.type, + config=layer_conf.config, + deps=layer_conf.deps, + ) + return self + + def add_config_layer( + self, + *, + name: str, + type: str, + config: object | None = None, + deps: Mapping[str, str] | None = None, + ) -> Self: + """Resolve, validate, and add one registry-backed layer config node.""" + descriptor = self._registry.resolve(type) + raw_config = {} if config is None else config + validated_config = descriptor.config_type.model_validate(raw_config) + layer = descriptor.layer_type.from_config(cast(Any, validated_config)) + self.add_instance(name=name, layer=layer, deps=deps) + return self + + def add_instance( + self, + *, + name: str, + layer: Layer[Any, Any, Any, Any, Any, Any], + deps: Mapping[str, str] | None = None, + ) -> Self: + """Add a live layer instance, useful for Python objects and callables.""" + self._entries.append(_LayerBuildEntry(name=name, layer=layer, deps=dict(deps or {}))) + return self + + def build[PromptT, ToolT, LayerPromptT, LayerToolT]( + self, + *, + prompt_transformer: CompositorTransformer[LayerPromptT, PromptT] | None = None, + tool_transformer: CompositorTransformer[LayerToolT, ToolT] | None = None, + ) -> "Compositor[PromptT, ToolT, LayerPromptT, LayerToolT]": + """Validate names/dependencies, bind deps, and return a compositor.""" + layers: OrderedDict[str, Layer[Any, Any, Any, Any, Any, Any]] = OrderedDict() + deps_name_mapping: dict[str, Mapping[str, str]] = {} + for entry in self._entries: + if entry.name in layers: + raise ValueError(f"Duplicate layer name '{entry.name}'.") + layers[entry.name] = entry.layer + deps_name_mapping[entry.name] = entry.deps + + layer_names = set(layers) + for layer_name, deps in deps_name_mapping.items(): + declared_deps = layers[layer_name].dependency_names() + unknown_dep_keys = set(deps) - declared_deps + if unknown_dep_keys: + names = ", ".join(sorted(unknown_dep_keys)) + raise ValueError(f"Layer '{layer_name}' declares unknown dependency keys: {names}.") + missing_targets = set(deps.values()) - layer_names + if missing_targets: + names = ", ".join(sorted(missing_targets)) + raise ValueError(f"Layer '{layer_name}' depends on undefined layer names: {names}.") + + return Compositor( + layers=layers, + deps_name_mapping=deps_name_mapping, + prompt_transformer=prompt_transformer, + tool_transformer=tool_transformer, + ) + + @dataclass(kw_only=True) class Compositor(Generic[PromptT, ToolT, LayerPromptT, LayerToolT]): """Framework-neutral ordered layer graph with lifecycle and aggregation. @@ -199,7 +336,7 @@ class Compositor(Generic[PromptT, ToolT, LayerPromptT, LayerToolT]): from exposed item types. """ - layers: OrderedDict[str, Layer[Any, Any, Any]] + layers: OrderedDict[str, Layer[Any, Any, Any, Any, Any, Any]] deps_name_mapping: Mapping[str, Mapping[str, str]] = field(default_factory=dict) prompt_transformer: CompositorTransformer[LayerPromptT, PromptT] | None = None tool_transformer: CompositorTransformer[LayerToolT, ToolT] | None = None @@ -213,19 +350,12 @@ class Compositor(Generic[PromptT, ToolT, LayerPromptT, LayerToolT]): cls, conf: CompositorConfigValue, *, + registry: LayerRegistry, prompt_transformer: CompositorTransformer[LayerPromptT, PromptT] | None = None, tool_transformer: CompositorTransformer[LayerToolT, ToolT] | None = None, - ) -> Self: - """Create layers from config-like input and bind named dependencies.""" - conf = _validate_compositor_config_input(conf) - layers: OrderedDict[str, Layer[Any, Any, Any]] = OrderedDict() - for layer_conf in conf.layers: - layers[layer_conf.name] = layer_conf.create_layer() - - deps_name_mapping = {layer_conf.name: layer_conf.deps for layer_conf in conf.layers} - return cls( - layers=layers, - deps_name_mapping=deps_name_mapping, + ) -> "Compositor[PromptT, ToolT, LayerPromptT, LayerToolT]": + """Create a compositor from registry-backed serializable config.""" + return CompositorBuilder(registry).add_config(conf).build( prompt_transformer=prompt_transformer, tool_transformer=tool_transformer, ) @@ -257,7 +387,60 @@ class Compositor(Generic[PromptT, ToolT, LayerPromptT, LayerToolT]): def new_session(self) -> CompositorSession: """Create a fresh lifecycle session matching this compositor's layer order.""" - return CompositorSession(self.layers) + return CompositorSession( + OrderedDict((layer_name, layer.new_control()) for layer_name, layer in self.layers.items()) + ) + + def snapshot_session(self, session: CompositorSession) -> CompositorSessionSnapshot: + """Serialize non-active session lifecycle state and runtime state. + + Runtime handles are live Python objects and are intentionally excluded. + """ + self._validate_session(session) + active_layers = [name for name, control in session.layer_controls.items() if control.state is LifecycleState.ACTIVE] + if active_layers: + names = ", ".join(active_layers) + raise RuntimeError(f"Cannot snapshot active compositor session layers: {names}.") + return CompositorSessionSnapshot( + layers=[ + LayerSessionSnapshot( + name=name, + state=control.state, + runtime_state=cast(dict[str, JsonValue], control.runtime_state.model_dump(mode="json")), + ) + for name, control in session.layer_controls.items() + ] + ) + + def session_from_snapshot(self, snapshot: CompositorSessionSnapshot | JsonValue | str | bytes) -> CompositorSession: + """Restore a session from a snapshot and reinitialize empty handles.""" + snapshot = _validate_config_model_input(CompositorSessionSnapshot, snapshot) + if snapshot.schema_version != 1: + raise ValueError(f"Unsupported compositor session snapshot schema_version: {snapshot.schema_version}.") + snapshot_layer_names = tuple(layer.name for layer in snapshot.layers) + expected_layer_names = tuple(self.layers) + if snapshot_layer_names != expected_layer_names: + expected = ", ".join(expected_layer_names) + actual = ", ".join(snapshot_layer_names) + raise ValueError( + "CompositorSessionSnapshot layer names must match compositor layers in order. " + f"Expected [{expected}], got [{actual}]." + ) + active_layers = [layer.name for layer in snapshot.layers if layer.state is LifecycleState.ACTIVE] + if active_layers: + names = ", ".join(active_layers) + raise ValueError(f"Cannot restore active compositor session layers from snapshot: {names}.") + controls = OrderedDict( + ( + layer_snapshot.name, + self.layers[layer_snapshot.name].new_control( + state=layer_snapshot.state, + runtime_state=layer_snapshot.runtime_state, + ), + ) + for layer_snapshot in snapshot.layers + ) + return CompositorSession(controls) @asynccontextmanager async def enter( @@ -288,6 +471,18 @@ class Compositor(Generic[PromptT, ToolT, LayerPromptT, LayerToolT]): "CompositorSession layer names must match compositor layers in order. " f"Expected [{expected}], got [{actual}]." ) + for layer_name, layer in self.layers.items(): + control = session.layer_controls[layer_name] + if not isinstance(control.runtime_state, layer.runtime_state_type): + raise TypeError( + f"CompositorSession layer '{layer_name}' runtime_state must be " + f"{layer.runtime_state_type.__name__}, got {type(control.runtime_state).__name__}." + ) + if not isinstance(control.runtime_handles, layer.runtime_handles_type): + raise TypeError( + f"CompositorSession layer '{layer_name}' runtime_handles must be " + f"{layer.runtime_handles_type.__name__}, got {type(control.runtime_handles).__name__}." + ) def _ensure_session_can_enter(self, session: CompositorSession) -> None: """Reject active or closed layer controls before any layer side effects.""" @@ -330,14 +525,15 @@ class Compositor(Generic[PromptT, ToolT, LayerPromptT, LayerToolT]): __all__ = [ "Compositor", + "CompositorBuilder", "CompositorConfig", "CompositorConfigValue", - "CompositorLayerConfigInput", + "CompositorSessionSnapshot", "CompositorSession", "CompositorTransformer", "CompositorTransformerKwargs", - "CompositorLayerConfig", - "CompositorLayerConfigValue", - "ImportedLayerConfig", - "LayerSpec", + "LayerDescriptor", + "LayerNodeConfig", + "LayerRegistry", + "LayerSessionSnapshot", ] diff --git a/dify-agent/src/agenton/layers/__init__.py b/dify-agent/src/agenton/layers/__init__.py index 19cb7cb7b8..b8b62561c2 100644 --- a/dify-agent/src/agenton/layers/__init__.py +++ b/dify-agent/src/agenton/layers/__init__.py @@ -5,7 +5,17 @@ families while keeping concrete reusable layers in ``agenton_collections``. """ -from agenton.layers.base import ExitIntent, Layer, LayerControl, LayerDeps, LifecycleState, NoLayerDeps +from agenton.layers.base import ( + EmptyLayerConfig, + EmptyRuntimeHandles, + EmptyRuntimeState, + ExitIntent, + Layer, + LayerControl, + LayerDeps, + LifecycleState, + NoLayerDeps, +) from agenton.layers.types import ( AllPromptTypes, AllToolTypes, @@ -29,6 +39,9 @@ __all__ = [ "LayerControl", "LifecycleState", "ExitIntent", + "EmptyLayerConfig", + "EmptyRuntimeState", + "EmptyRuntimeHandles", "NoLayerDeps", "PlainLayer", "PlainPrompt", diff --git a/dify-agent/src/agenton/layers/base.py b/dify-agent/src/agenton/layers/base.py index fba690a6d2..0cbaa7e14f 100644 --- a/dify-agent/src/agenton/layers/base.py +++ b/dify-agent/src/agenton/layers/base.py @@ -1,11 +1,13 @@ """Core layer abstractions and typed dependency binding. -Layers declare their dependency shape with ``Layer[DepsT, PromptT, ToolT]``. +Layers declare their dependency shape with ``Layer[DepsT, PromptT, ToolT, ...]``. ``DepsT`` must be a ``LayerDeps`` subclass whose annotated members are concrete ``Layer`` subclasses or modern optional dependencies such as ``SomeLayer | -None``. The base class infers ``deps_type`` from the generic base when possible, -while still allowing subclasses to set ``deps_type`` explicitly for unusual -inheritance patterns. +None``. The optional trailing generic slots declare Pydantic schemas for config, +serializable runtime state, and live runtime handles. The base class infers +``deps_type`` and schema class attributes from the generic base when possible, +while still allowing subclasses to set them explicitly for unusual inheritance +patterns. ``Layer.bind_deps`` is the mutation point for dependency state. Layer implementations should treat ``self.deps`` as unavailable until a compositor or @@ -17,9 +19,10 @@ machine and per-session runtime owner. A fresh control starts in while active or closed controls are rejected to prevent ambiguous nested or post-delete reuse. Exit behavior is selected per entry with ``ExitIntent`` and resets to delete on every successful enter. Layer instances are shared graph and -capability definitions, so session-local ids, handles, clients, and other -runtime values generated by lifecycle hooks belong in -``LayerControl.runtime_state`` rather than on ``self``. +capability definitions, so session-local serializable ids, checkpoints, and +other snapshot data belong in ``LayerControl.runtime_state``; live clients, +connections, and process handles belong in ``LayerControl.runtime_handles``. +Neither category should be stored on ``self`` when it is session-local. ``Layer`` is framework-neutral over prompt and tool item types. The native ``prefix_prompts``, ``suffix_prompts``, and ``tools`` properties are the layer @@ -34,9 +37,18 @@ from contextlib import AbstractAsyncContextManager, asynccontextmanager from dataclasses import dataclass, field from enum import StrEnum from types import UnionType -from typing import Any, Mapping, Sequence, Union, cast, get_args, get_origin, get_type_hints +from typing import Any, ClassVar, Generic, Mapping, Sequence, Union, cast, get_args, get_origin, get_type_hints -from typing_extensions import Self +from pydantic import BaseModel, ConfigDict +from typing_extensions import Self, TypeVar + + +_DepsT = TypeVar("_DepsT", bound="LayerDeps") +_PromptT = TypeVar("_PromptT") +_ToolT = TypeVar("_ToolT") +_ConfigT = TypeVar("_ConfigT", bound=BaseModel, default="EmptyLayerConfig") +_RuntimeStateT = TypeVar("_RuntimeStateT", bound=BaseModel, default="EmptyRuntimeState") +_RuntimeHandlesT = TypeVar("_RuntimeHandlesT", bound=BaseModel, default="EmptyRuntimeHandles") class LayerDeps: @@ -47,7 +59,7 @@ class LayerDeps: are always assigned as attributes; missing optional values become ``None``. """ - def __init__(self, **deps: "Layer[Any, Any, Any] | None") -> None: + def __init__(self, **deps: "Layer[Any, Any, Any, Any, Any, Any] | None") -> None: dep_specs = _get_dep_specs(type(self)) missing_names = {name for name, spec in dep_specs.items() if not spec.optional} - deps.keys() if missing_names: @@ -79,6 +91,28 @@ class NoLayerDeps(LayerDeps): """Dependency container for layers that do not require other layers.""" +class EmptyLayerConfig(BaseModel): + """Default serializable config schema for layers without config.""" + + model_config = ConfigDict(extra="forbid") + + +class EmptyRuntimeState(BaseModel): + """Default serializable per-session runtime state schema.""" + + model_config = ConfigDict(extra="forbid", validate_assignment=True) + + +class EmptyRuntimeHandles(BaseModel): + """Default live per-session runtime handle schema. + + Handles may contain arbitrary Python objects and are intentionally excluded + from session snapshots. + """ + + model_config = ConfigDict(extra="forbid", validate_assignment=True, arbitrary_types_allowed=True) + + class LifecycleState(StrEnum): """Externally observable lifecycle state for a layer control.""" @@ -96,26 +130,31 @@ class ExitIntent(StrEnum): @dataclass(slots=True) -class LayerControl: +class LayerControl(Generic[_RuntimeStateT, _RuntimeHandlesT]): """Stateful control slot passed into a layer entry context. ``Layer.enter`` requires the caller to provide this object. The control owns the layer lifecycle state, the current entry's exit intent, and arbitrary - per-session runtime state. Call ``suspend_on_exit`` before leaving the + per-session runtime state and live handles. Call ``suspend_on_exit`` before leaving the context to make a later entry resume; call ``delete_on_exit`` or do nothing - for the default delete behavior. Store session-local ids and resource - handles in ``runtime_state`` so concurrent or later sessions do not share - mutable runtime data through the layer instance. + for the default delete behavior. Store session-local serializable ids, + checkpoints, and other snapshot data in ``runtime_state``. Store live + clients, connections, process handles, and other non-serializable objects in + ``runtime_handles``. Do not put either kind of session-local data on the + shared layer instance. ``runtime_state`` intentionally persists after suspend and delete. Suspend, resume, and delete hooks can inspect the same values created on entry, and callers may inspect closed-session diagnostics after exit. Reuse is still - governed by ``state``: a closed control cannot be entered again. + governed by ``state``: a closed control cannot be entered again. Runtime + handles are not serialized in snapshots and should be rehydrated from + runtime state in resume hooks. """ state: LifecycleState = LifecycleState.NEW exit_intent: ExitIntent = ExitIntent.DELETE - runtime_state: dict[str, object] = field(default_factory=dict) + runtime_state: _RuntimeStateT = field(default_factory=lambda: cast(_RuntimeStateT, EmptyRuntimeState())) + runtime_handles: _RuntimeHandlesT = field(default_factory=lambda: cast(_RuntimeHandlesT, EmptyRuntimeHandles())) def suspend_on_exit(self) -> None: """Request suspend behavior when the current layer entry exits.""" @@ -130,11 +169,14 @@ class LayerControl: class LayerDepSpec: """Runtime dependency specification derived from a deps annotation.""" - layer_type: type["Layer[Any, Any, Any]"] + layer_type: type["Layer[Any, Any, Any, Any, Any, Any]"] optional: bool = False -class Layer[DepsT: LayerDeps, PromptT, ToolT](ABC): +class Layer( + ABC, + Generic[_DepsT, _PromptT, _ToolT, _ConfigT, _RuntimeStateT, _RuntimeHandlesT], +): """Framework-neutral base class for prompt/tool layers. Subclasses expose optional prompt fragments and tools through typed @@ -147,15 +189,20 @@ class Layer[DepsT: LayerDeps, PromptT, ToolT](ABC): extra runtime resources. """ - deps_type: type[DepsT] - deps: DepsT + deps_type: type[_DepsT] + deps: _DepsT + type_id: ClassVar[str | None] = None + config_type: ClassVar[type[BaseModel]] = EmptyLayerConfig + runtime_state_type: ClassVar[type[BaseModel]] = EmptyRuntimeState + runtime_handles_type: ClassVar[type[BaseModel]] = EmptyRuntimeHandles def __init_subclass__(cls) -> None: super().__init_subclass__() + is_generic_template = _is_generic_layer_template(cls) deps_type = cls.__dict__.get("deps_type") if deps_type is None: deps_type = _infer_deps_type(cls) or getattr(cls, "deps_type", None) - if deps_type is None and _is_generic_layer_template(cls): + if deps_type is None and is_generic_template: return if deps_type is not None: cls.deps_type = deps_type # pyright: ignore[reportAttributeAccessIssue] @@ -164,25 +211,63 @@ class Layer[DepsT: LayerDeps, PromptT, ToolT](ABC): if not isinstance(deps_type, type) or not issubclass(deps_type, LayerDeps): raise TypeError(f"{cls.__name__}.deps_type must be a LayerDeps subclass.") _get_dep_specs(deps_type) + _init_schema_type(cls, "config_type", _infer_schema_type(cls, 3, "config_type"), EmptyLayerConfig) + _init_schema_type( + cls, + "runtime_state_type", + _infer_schema_type(cls, 4, "runtime_state_type"), + EmptyRuntimeState, + ) + _init_schema_type( + cls, + "runtime_handles_type", + _infer_schema_type(cls, 5, "runtime_handles_type"), + EmptyRuntimeHandles, + ) @classmethod - def from_config(cls: type[Self], config: Any) -> Self: - """Create a layer from serialized config. + def from_config(cls: type[Self], config: _ConfigT) -> Self: + """Create a layer from schema-validated serialized config. - Layers are not config-constructible by default. Subclasses that accept - config should override this method and validate dynamic input before - constructing the layer. + Registries/builders validate raw config with ``config_type`` before + calling this method. Layers are not config-constructible by default. + Subclasses that accept config should override this method and consume + the typed Pydantic model for their schema. """ raise TypeError(f"{cls.__name__} cannot be created from config.") - def bind_deps(self, deps: Mapping[str, "Layer[Any, Any, Any] | None"]) -> None: + @classmethod + def dependency_names(cls) -> frozenset[str]: + """Return dependency field names declared by this layer's deps schema.""" + return frozenset(_get_dep_specs(cls.deps_type)) + + def new_control( + self, + *, + state: LifecycleState = LifecycleState.NEW, + runtime_state: object | None = None, + ) -> LayerControl[_RuntimeStateT, _RuntimeHandlesT]: + """Create a schema-validated per-session control for this layer. + + ``runtime_state`` is validated through ``runtime_state_type`` and live + handles are always initialized empty through ``runtime_handles_type``. + """ + raw_runtime_state = {} if runtime_state is None else runtime_state + return LayerControl( + state=state, + exit_intent=ExitIntent.DELETE, + runtime_state=cast(_RuntimeStateT, self.runtime_state_type.model_validate(raw_runtime_state)), + runtime_handles=cast(_RuntimeHandlesT, self.runtime_handles_type.model_validate({})), + ) + + def bind_deps(self, deps: Mapping[str, "Layer[Any, Any, Any, Any, Any, Any] | None"]) -> None: """Bind this layer's declared dependencies from a name-to-layer mapping. The mapping may include more layers than the declared dependency fields. Only names declared by ``deps_type`` are selected and validated. Missing optional deps are bound as ``None``. """ - resolved_deps: dict[str, Layer[Any, Any, Any] | None] = {} + resolved_deps: dict[str, Layer[Any, Any, Any, Any, Any, Any] | None] = {} for name, spec in _get_dep_specs(self.deps_type).items(): if name not in deps: if spec.optional: @@ -194,7 +279,7 @@ class Layer[DepsT: LayerDeps, PromptT, ToolT](ABC): resolved_deps[name] = deps[name] self.deps = self.deps_type(**resolved_deps) - def enter(self, control: LayerControl) -> AbstractAsyncContextManager[None]: + def enter(self, control: LayerControl[_RuntimeStateT, _RuntimeHandlesT]) -> AbstractAsyncContextManager[None]: """Return the layer's async entry context manager. ``control`` is the lifecycle control slot for this entry. Subclasses can @@ -204,7 +289,7 @@ class Layer[DepsT: LayerDeps, PromptT, ToolT](ABC): return self.lifecycle_enter(control) @asynccontextmanager - async def lifecycle_enter(self, control: LayerControl) -> AsyncIterator[None]: + async def lifecycle_enter(self, control: LayerControl[_RuntimeStateT, _RuntimeHandlesT]) -> AsyncIterator[None]: """Run the default explicit lifecycle state machine for one entry.""" if control.state is LifecycleState.NEW: control.exit_intent = ExitIntent.DELETE @@ -233,37 +318,37 @@ class Layer[DepsT: LayerDeps, PromptT, ToolT](ABC): await self.on_context_delete(control) control.state = LifecycleState.CLOSED - async def on_context_create(self, control: LayerControl) -> None: + async def on_context_create(self, control: LayerControl[_RuntimeStateT, _RuntimeHandlesT]) -> None: """Run when the layer context is entered from ``LifecycleState.NEW``.""" - async def on_context_delete(self, control: LayerControl) -> None: + async def on_context_delete(self, control: LayerControl[_RuntimeStateT, _RuntimeHandlesT]) -> None: """Run when the layer context exits with ``ExitIntent.DELETE``.""" - async def on_context_suspend(self, control: LayerControl) -> None: + async def on_context_suspend(self, control: LayerControl[_RuntimeStateT, _RuntimeHandlesT]) -> None: """Run when the layer context exits with ``ExitIntent.SUSPEND``.""" - async def on_context_resume(self, control: LayerControl) -> None: + async def on_context_resume(self, control: LayerControl[_RuntimeStateT, _RuntimeHandlesT]) -> None: """Run when the layer context enters from ``LifecycleState.SUSPENDED``.""" @property - def prefix_prompts(self) -> Sequence[PromptT]: + def prefix_prompts(self) -> Sequence[_PromptT]: return [] @property - def suffix_prompts(self) -> Sequence[PromptT]: + def suffix_prompts(self) -> Sequence[_PromptT]: return [] @property - def tools(self) -> Sequence[ToolT]: + def tools(self) -> Sequence[_ToolT]: return [] @abstractmethod - def wrap_prompt(self, prompt: PromptT) -> object: + def wrap_prompt(self, prompt: _PromptT) -> object: """Wrap a native prompt item for compositor aggregation.""" raise NotImplementedError @abstractmethod - def wrap_tool(self, tool: ToolT) -> object: + def wrap_tool(self, tool: _ToolT) -> object: """Wrap a native tool item for compositor aggregation.""" raise NotImplementedError @@ -297,45 +382,107 @@ def _as_dep_spec(annotation: object) -> LayerDepSpec | None: return LayerDepSpec(layer_type=layer_type) -def _as_layer_type(annotation: object) -> type[Layer[Any, Any, Any]] | None: +def _as_layer_type(annotation: object) -> type[Layer[Any, Any, Any, Any, Any, Any]] | None: runtime_type = get_origin(annotation) or annotation if isinstance(runtime_type, type) and issubclass(runtime_type, Layer): - return cast(type[Layer[Any, Any, Any]], runtime_type) + return cast(type[Layer[Any, Any, Any, Any, Any, Any]], runtime_type) return None -def _infer_deps_type(layer_type: type[Layer[Any, Any, Any]]) -> type[LayerDeps] | None: - return _infer_deps_type_from_bases(layer_type, {}) +def _infer_deps_type(layer_type: type[Layer[Any, Any, Any, Any, Any, Any]]) -> type[LayerDeps] | None: + inferred = _infer_layer_generic_arg(layer_type, 0, {}) + if inferred is None: + return None + return _as_deps_type(inferred) -def _infer_deps_type_from_bases( - layer_type: type[Layer[Any, Any, Any]], +def _infer_schema_type( + layer_type: type[Layer[Any, Any, Any, Any, Any, Any]], + index: int, + attr_name: str, +) -> type[BaseModel] | None: + inferred = _infer_schema_generic_arg(layer_type, attr_name, {}) or _infer_layer_generic_arg(layer_type, index, {}) + if inferred is None: + return None + schema_type = _as_model_type(inferred) + if schema_type is None: + raise TypeError(f"{layer_type.__name__}.{attr_name} must be a Pydantic BaseModel subclass.") + return schema_type + + +def _infer_schema_generic_arg( + layer_type: type[Layer[Any, Any, Any, Any, Any, Any]], + attr_name: str, substitutions: Mapping[object, object], -) -> type[LayerDeps] | None: - """Infer the concrete deps container through generic Layer inheritance. +) -> object | None: + """Infer schema type arguments exposed by typed layer family bases.""" + expected_names = { + "config_type": {"ConfigT", "_ConfigT"}, + "runtime_state_type": {"RuntimeStateT", "_RuntimeStateT"}, + "runtime_handles_type": {"RuntimeHandlesT", "_RuntimeHandlesT"}, + }[attr_name] + for base in getattr(layer_type, "__orig_bases__", ()): + origin = get_origin(base) or base + args = tuple(_substitute_type(arg, substitutions) for arg in get_args(base)) + if not isinstance(origin, type) or not issubclass(origin, Layer): + continue + + params = _generic_params(origin) + for param, arg in zip(params, args): + if getattr(param, "__name__", None) in expected_names: + return arg + + next_substitutions = dict(substitutions) + next_substitutions.update(_generic_arg_substitutions(origin, args)) + inferred = _infer_schema_generic_arg(origin, attr_name, next_substitutions) + if inferred is not None: + return inferred + return None + + +def _infer_layer_generic_arg( + layer_type: type[Layer[Any, Any, Any, Any, Any, Any]], + index: int, + substitutions: Mapping[object, object], +) -> object | None: + """Infer one concrete ``Layer`` generic argument through inheritance. This walks through intermediate generic base classes so subclasses can omit - an explicit ``deps_type`` in common cases such as ``class X(Base[YDeps])``. + explicit class attributes in common cases such as ``class X(Base[YDeps])``. """ for base in getattr(layer_type, "__orig_bases__", ()): origin = get_origin(base) or base args = tuple(_substitute_type(arg, substitutions) for arg in get_args(base)) if origin is Layer: - if not args: + if len(args) <= index: continue - return _as_deps_type(args[0]) + return args[index] if not isinstance(origin, type) or not issubclass(origin, Layer): continue next_substitutions = dict(substitutions) next_substitutions.update(_generic_arg_substitutions(origin, args)) - inferred = _infer_deps_type_from_bases(origin, next_substitutions) + inferred = _infer_layer_generic_arg(origin, index, next_substitutions) if inferred is not None: return inferred return None +def _init_schema_type( + layer_type: type[Layer[Any, Any, Any, Any, Any, Any]], + attr_name: str, + inferred_schema_type: type[BaseModel] | None, + default_schema_type: type[BaseModel], +) -> None: + schema_type = layer_type.__dict__.get(attr_name) + if schema_type is None: + schema_type = inferred_schema_type or getattr(layer_type, attr_name, default_schema_type) + setattr(layer_type, attr_name, schema_type) + if not isinstance(schema_type, type) or not issubclass(schema_type, BaseModel): + raise TypeError(f"{layer_type.__name__}.{attr_name} must be a Pydantic BaseModel subclass.") + + def _substitute_type(value: object, substitutions: Mapping[object, object]) -> object: if value in substitutions: return substitutions[value] @@ -359,10 +506,15 @@ def _substitute_type(value: object, substitutions: Mapping[object, object]) -> o def _generic_arg_substitutions(origin: type[Any], args: Sequence[object]) -> dict[object, object]: + params = _generic_params(origin) + return dict(zip(params, args)) + + +def _generic_params(origin: type[Any]) -> Sequence[object]: params = getattr(origin, "__type_params__", ()) if not params: params = getattr(origin, "__parameters__", ()) - return dict(zip(params, args)) + return params def _as_deps_type(value: object) -> type[LayerDeps] | None: @@ -372,7 +524,14 @@ def _as_deps_type(value: object) -> type[LayerDeps] | None: return None -def _is_generic_layer_template(layer_type: type[Layer[Any, Any, Any]]) -> bool: +def _as_model_type(value: object) -> type[BaseModel] | None: + runtime_type = get_origin(value) or value + if isinstance(runtime_type, type) and issubclass(runtime_type, BaseModel): + return runtime_type + return None + + +def _is_generic_layer_template(layer_type: type[Layer[Any, Any, Any, Any, Any, Any]]) -> bool: return bool(getattr(layer_type, "__type_params__", ())) or bool( getattr(layer_type, "__parameters__", ()) ) diff --git a/dify-agent/src/agenton/layers/types.py b/dify-agent/src/agenton/layers/types.py index 3a705964a6..c1de45b684 100644 --- a/dify-agent/src/agenton/layers/types.py +++ b/dify-agent/src/agenton/layers/types.py @@ -2,7 +2,10 @@ ``Layer`` itself is framework-neutral. This module defines typed layer families that bind its prompt/tool generic slots to concrete contracts, such as ordinary -string prompts with plain callable tools or pydantic-ai prompt/tool shapes. +string prompts with plain callable tools or pydantic-ai prompt/tool shapes. The +families keep the trailing schema generic slots open so concrete layers can have +``config_type``, ``runtime_state_type``, and ``runtime_handles_type`` inferred +from type arguments instead of repeated class attributes. Tagged aggregate aliases cover code paths that can accept any supported prompt/tool family without changing the plain and pydantic-ai layer contracts. Pydantic-ai names are imported for static analysis only, so ``agenton`` can be @@ -14,15 +17,17 @@ from __future__ import annotations from collections.abc import Callable from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Literal +from typing import TYPE_CHECKING, Any, Generic, Literal -from typing_extensions import final, override +from typing_extensions import TypeVar, final, override if TYPE_CHECKING: from pydantic_ai import Tool from pydantic_ai.tools import SystemPromptFunc -from agenton.layers.base import Layer, LayerDeps +from pydantic import BaseModel + +from agenton.layers.base import EmptyLayerConfig, EmptyRuntimeHandles, EmptyRuntimeState, Layer, LayerDeps type PlainPrompt = str type PlainTool = Callable[..., Any] @@ -68,7 +73,17 @@ type AllPromptTypes = PlainPromptType | PydanticAIPromptType[Any] type AllToolTypes = PlainToolType | PydanticAIToolType[Any] -class PlainLayer[DepsT: LayerDeps](Layer[DepsT, PlainPrompt, PlainTool]): +_DepsT = TypeVar("_DepsT", bound=LayerDeps) +_ConfigT = TypeVar("_ConfigT", bound=BaseModel, default=EmptyLayerConfig) +_RuntimeStateT = TypeVar("_RuntimeStateT", bound=BaseModel, default=EmptyRuntimeState) +_RuntimeHandlesT = TypeVar("_RuntimeHandlesT", bound=BaseModel, default=EmptyRuntimeHandles) +_AgentDepsT = TypeVar("_AgentDepsT") + + +class PlainLayer( + Generic[_DepsT, _ConfigT, _RuntimeStateT, _RuntimeHandlesT], + Layer[_DepsT, PlainPrompt, PlainTool, _ConfigT, _RuntimeStateT, _RuntimeHandlesT], +): """Layer base for ordinary string prompts and plain-callable tools.""" @final @@ -82,8 +97,16 @@ class PlainLayer[DepsT: LayerDeps](Layer[DepsT, PlainPrompt, PlainTool]): return PlainToolType(tool) -class PydanticAILayer[DepsT: LayerDeps, AgentDepsT]( - Layer[DepsT, PydanticAIPrompt[AgentDepsT], PydanticAITool[AgentDepsT]] +class PydanticAILayer( + Generic[_DepsT, _AgentDepsT, _ConfigT, _RuntimeStateT, _RuntimeHandlesT], + Layer[ + _DepsT, + PydanticAIPrompt[_AgentDepsT], + PydanticAITool[_AgentDepsT], + _ConfigT, + _RuntimeStateT, + _RuntimeHandlesT, + ], ): """Layer base for pydantic-ai prompt and tool adapters.""" @@ -91,13 +114,13 @@ class PydanticAILayer[DepsT: LayerDeps, AgentDepsT]( @override def wrap_prompt( self, - prompt: PydanticAIPrompt[AgentDepsT], - ) -> PydanticAIPromptType[AgentDepsT]: + prompt: PydanticAIPrompt[_AgentDepsT], + ) -> PydanticAIPromptType[_AgentDepsT]: return PydanticAIPromptType(prompt) @final @override - def wrap_tool(self, tool: PydanticAITool[AgentDepsT]) -> PydanticAIToolType[AgentDepsT]: + def wrap_tool(self, tool: PydanticAITool[_AgentDepsT]) -> PydanticAIToolType[_AgentDepsT]: return PydanticAIToolType(tool) diff --git a/dify-agent/src/agenton_collections/layers/plain/__init__.py b/dify-agent/src/agenton_collections/layers/plain/__init__.py index 76c8da5339..a5b0b75384 100644 --- a/dify-agent/src/agenton_collections/layers/plain/__init__.py +++ b/dify-agent/src/agenton_collections/layers/plain/__init__.py @@ -1,6 +1,6 @@ """Reusable collection layers for the plain layer family.""" -from agenton_collections.layers.plain.basic import ObjectLayer, PromptLayer, ToolsLayer +from agenton_collections.layers.plain.basic import ObjectLayer, PromptLayer, PromptLayerConfig, ToolsLayer from agenton_collections.layers.plain.dynamic_tools import ( DynamicToolsLayer, DynamicToolsLayerDeps, @@ -12,6 +12,7 @@ __all__ = [ "DynamicToolsLayerDeps", "ObjectLayer", "PromptLayer", + "PromptLayerConfig", "ToolsLayer", "with_object", ] diff --git a/dify-agent/src/agenton_collections/layers/plain/basic.py b/dify-agent/src/agenton_collections/layers/plain/basic.py index 1fcb6b51ae..b3523052fa 100644 --- a/dify-agent/src/agenton_collections/layers/plain/basic.py +++ b/dify-agent/src/agenton_collections/layers/plain/basic.py @@ -10,30 +10,46 @@ from collections.abc import Callable, Sequence from dataclasses import dataclass, field from typing import Any -from pydantic import TypeAdapter +from pydantic import BaseModel, ConfigDict, Field from agenton.layers.base import NoLayerDeps from agenton.layers.types import PlainLayer +class PromptLayerConfig(BaseModel): + """Serializable config schema for ``PromptLayer``.""" + + prefix: list[str] | str = Field(default_factory=list) + suffix: list[str] | str = Field(default_factory=list) + + model_config = ConfigDict(extra="forbid") + + @dataclass class ObjectLayer[ObjectT](PlainLayer[NoLayerDeps]): - """Layer that stores one typed object for downstream dependencies.""" + """Layer that stores one typed object for downstream dependencies. + + Object layers are instance-only because arbitrary Python objects are not + serializable graph config. Add them with ``CompositorBuilder.add_instance``. + """ value: ObjectT @dataclass -class PromptLayer(PlainLayer[NoLayerDeps]): +class PromptLayer(PlainLayer[NoLayerDeps, PromptLayerConfig]): """Layer that contributes configured prefix and suffix prompt fragments.""" + type_id = "plain.prompt" + prefix: list[str] | str = field(default_factory=list) suffix: list[str] | str = field(default_factory=list) @classmethod - def from_config(cls, config: Any): - """Validate prompt config against this dataclass.""" - return _PROMPT_LAYER_ADAPTER.validate_python(config) + def from_config(cls, config: BaseModel): + """Create a prompt layer from validated prompt config.""" + validated_config = PromptLayerConfig.model_validate(config) + return cls(prefix=validated_config.prefix, suffix=validated_config.suffix) @property def prefix_prompts(self) -> list[str]: @@ -50,7 +66,11 @@ class PromptLayer(PlainLayer[NoLayerDeps]): @dataclass class ToolsLayer(PlainLayer[NoLayerDeps]): - """Layer that contributes configured plain-callable tools.""" + """Layer that contributes configured plain-callable tools. + + Tool layers are instance-only because Python callables are live objects. Add + them with ``CompositorBuilder.add_instance``. + """ tool_entries: Sequence[Callable[..., Any]] = () @@ -59,10 +79,9 @@ class ToolsLayer(PlainLayer[NoLayerDeps]): return list(self.tool_entries) -_PROMPT_LAYER_ADAPTER = TypeAdapter(PromptLayer) - __all__ = [ "ObjectLayer", + "PromptLayerConfig", "PromptLayer", "ToolsLayer", ] diff --git a/dify-agent/tests/local/agenton/compositor/test_builder_snapshot.py b/dify-agent/tests/local/agenton/compositor/test_builder_snapshot.py new file mode 100644 index 0000000000..7c5c0e2205 --- /dev/null +++ b/dify-agent/tests/local/agenton/compositor/test_builder_snapshot.py @@ -0,0 +1,257 @@ +import asyncio +from collections import OrderedDict +from dataclasses import dataclass + +from pydantic import BaseModel, ConfigDict, ValidationError +from typing_extensions import override + +from agenton.compositor import Compositor, CompositorBuilder, CompositorSession, LayerRegistry +from agenton.layers import EmptyLayerConfig, LayerControl, LayerDeps, NoLayerDeps, PlainLayer, PlainPromptType, PlainToolType +from agenton_collections.layers.plain import ObjectLayer, PromptLayer + + +def test_registry_infers_descriptor_and_rejects_duplicate_or_missing_type_id() -> None: + registry = LayerRegistry() + registry.register_layer(PromptLayer) + + descriptor = registry.resolve("plain.prompt") + assert descriptor.layer_type is PromptLayer + assert descriptor.config_type is PromptLayer.config_type + + try: + registry.register_layer(PromptLayer) + except ValueError as e: + assert str(e) == "Layer type id 'plain.prompt' is already registered." + else: + raise AssertionError("Expected ValueError.") + + try: + registry.register_layer(InstanceOnlyLayer) + except ValueError as e: + assert "must declare a type_id" in str(e) + else: + raise AssertionError("Expected ValueError.") + + try: + registry.register_layer(InstanceOnlyLayer, type_id=123) # pyright: ignore[reportArgumentType] + except TypeError as e: + assert str(e) == "Layer type id for 'InstanceOnlyLayer' must be a string." + else: + raise AssertionError("Expected TypeError.") + + +@dataclass(slots=True) +class InstanceOnlyLayer(PlainLayer[NoLayerDeps]): + pass + + +def test_builder_creates_config_layers_with_typed_validation() -> None: + registry = LayerRegistry() + registry.register_layer(PromptLayer) + + compositor = ( + CompositorBuilder(registry) + .add_config_layer( + name="prompt", + type="plain.prompt", + config={"prefix": "hello", "suffix": ["bye"]}, + ) + .build() + ) + + assert [prompt.value for prompt in compositor.prompts] == ["hello", "bye"] + + try: + CompositorBuilder(registry).add_config_layer( + name="bad", + type="plain.prompt", + config={"unknown": "field"}, + ) + except ValidationError: + pass + else: + raise AssertionError("Expected ValidationError.") + + +class ObjectConsumerDeps(LayerDeps): + obj: ObjectLayer[str] # pyright: ignore[reportUninitializedInstanceVariable] + + +@dataclass(slots=True) +class ObjectConsumerLayer(PlainLayer[ObjectConsumerDeps]): + @property + @override + def prefix_prompts(self) -> list[str]: + return [self.deps.obj.value] + + +def test_builder_mixes_config_and_instances_and_rejects_invalid_deps() -> None: + registry = LayerRegistry() + registry.register_layer(PromptLayer) + + compositor = ( + CompositorBuilder(registry) + .add_config({"layers": [{"name": "prompt", "type": "plain.prompt", "config": {"prefix": "cfg"}}]}) + .add_instance(name="obj", layer=ObjectLayer("instance")) + .add_instance(name="consumer", layer=ObjectConsumerLayer(), deps={"obj": "obj"}) + .build() + ) + + assert [prompt.value for prompt in compositor.prompts] == ["cfg", "instance"] + + try: + CompositorBuilder(registry).add_instance( + name="consumer", + layer=ObjectConsumerLayer(), + deps={"missing_dep_key": "obj"}, + ).build() + except ValueError as e: + assert str(e) == "Layer 'consumer' declares unknown dependency keys: missing_dep_key." + else: + raise AssertionError("Expected ValueError.") + + try: + CompositorBuilder(registry).add_instance( + name="consumer", + layer=ObjectConsumerLayer(), + deps={"obj": "missing_target"}, + ).build() + except ValueError as e: + assert str(e) == "Layer 'consumer' depends on undefined layer names: missing_target." + else: + raise AssertionError("Expected ValueError.") + + +class HandleState(BaseModel): + resource_id: str = "" + + model_config = ConfigDict(extra="forbid", validate_assignment=True) + + +class HandleBox: + def __init__(self, value: str) -> None: + self.value = value + + +class HandleModels(BaseModel): + handle: HandleBox | None = None + + model_config = ConfigDict(extra="forbid", validate_assignment=True, arbitrary_types_allowed=True) + + +@dataclass(slots=True) +class HandleLayer(PlainLayer[NoLayerDeps, EmptyLayerConfig, HandleState, HandleModels]): + created: int = 0 + resumed: int = 0 + + @override + async def on_context_create(self, control: LayerControl[HandleState, HandleModels]) -> None: + self.created += 1 + control.runtime_handles.handle = HandleBox(control.runtime_state.resource_id) + + @override + async def on_context_resume(self, control: LayerControl[HandleState, HandleModels]) -> None: + self.resumed += 1 + control.runtime_handles.handle = HandleBox(f"resumed:{control.runtime_state.resource_id}") + + +def test_new_session_uses_layer_runtime_schemas() -> None: + compositor: Compositor[PlainPromptType, PlainToolType] = Compositor( + layers=OrderedDict([("handle", HandleLayer())]) + ) + session = compositor.new_session() + + assert isinstance(session.layer("handle").runtime_state, HandleState) + assert isinstance(session.layer("handle").runtime_handles, HandleModels) + + +def test_enter_rejects_bad_session_runtime_schemas_before_layer_hooks() -> None: + layer = HandleLayer() + compositor: Compositor[PlainPromptType, PlainToolType] = Compositor(layers=OrderedDict([("handle", layer)])) + bad_session = CompositorSession(OrderedDict([("handle", LayerControl())])) + + async def run() -> None: + async with compositor.enter(bad_session): + pass + + try: + asyncio.run(run()) + except TypeError as e: + assert str(e) == ( + "CompositorSession layer 'handle' runtime_state must be HandleState, " + "got EmptyRuntimeState." + ) + else: + raise AssertionError("Expected TypeError.") + + assert layer.created == 0 + + +def test_snapshot_rejects_active_sessions_and_excludes_handles() -> None: + compositor: Compositor[PlainPromptType, PlainToolType] = Compositor( + layers=OrderedDict([("handle", HandleLayer())]) + ) + session = compositor.session_from_snapshot( + {"layers": [{"name": "handle", "state": "new", "runtime_state": {"resource_id": "abc"}}]} + ) + + async def run() -> None: + async with compositor.enter(session): + try: + compositor.snapshot_session(session) + except RuntimeError as e: + assert str(e) == "Cannot snapshot active compositor session layers: handle." + else: + raise AssertionError("Expected RuntimeError.") + + asyncio.run(run()) + + snapshot = compositor.snapshot_session(session) + assert snapshot.model_dump(mode="json") == { + "schema_version": 1, + "layers": [{"name": "handle", "state": "closed", "runtime_state": {"resource_id": "abc"}}], + } + + +def test_restore_validates_runtime_state_and_resume_rehydrates_handles() -> None: + layer = HandleLayer() + compositor: Compositor[PlainPromptType, PlainToolType] = Compositor(layers=OrderedDict([("handle", layer)])) + + try: + compositor.session_from_snapshot( + {"layers": [{"name": "handle", "state": "suspended", "runtime_state": {"wrong": "field"}}]} + ) + except ValidationError: + pass + else: + raise AssertionError("Expected ValidationError.") + + restored = compositor.session_from_snapshot( + {"layers": [{"name": "handle", "state": "suspended", "runtime_state": {"resource_id": "abc"}}]} + ) + + async def run() -> None: + async with compositor.enter(restored): + control = restored.layer("handle") + assert isinstance(control.runtime_handles, HandleModels) + assert control.runtime_handles.handle is not None + assert control.runtime_handles.handle.value == "resumed:abc" + + asyncio.run(run()) + + assert layer.resumed == 1 + + +def test_session_from_snapshot_rejects_active_layer_state() -> None: + compositor: Compositor[PlainPromptType, PlainToolType] = Compositor( + layers=OrderedDict([("handle", HandleLayer())]) + ) + + try: + compositor.session_from_snapshot( + {"layers": [{"name": "handle", "state": "active", "runtime_state": {"resource_id": "abc"}}]} + ) + except ValueError as e: + assert str(e) == "Cannot restore active compositor session layers from snapshot: handle." + else: + raise AssertionError("Expected ValueError.") diff --git a/dify-agent/tests/local/agenton/compositor/test_enter.py b/dify-agent/tests/local/agenton/compositor/test_enter.py index 9ac9cc8f3b..f6c9a2d67f 100644 --- a/dify-agent/tests/local/agenton/compositor/test_enter.py +++ b/dify-agent/tests/local/agenton/compositor/test_enter.py @@ -4,11 +4,14 @@ from collections.abc import Iterator from dataclasses import dataclass, field from itertools import count +from pydantic import BaseModel, ConfigDict from typing_extensions import override from agenton.compositor import Compositor, CompositorSession from agenton.layers import ( ExitIntent, + EmptyLayerConfig, + EmptyRuntimeHandles, LayerControl, LifecycleState, NoLayerDeps, @@ -239,22 +242,30 @@ def test_failed_resume_keeps_control_reusable_as_suspended() -> None: assert session.layer("trace").state is LifecycleState.CLOSED +class RuntimeState(BaseModel): + runtime_id: int | None = None + resumed_runtime_id: int | None = None + deleted_runtime_id: int | None = None + + model_config = ConfigDict(extra="forbid", validate_assignment=True) + + @dataclass(slots=True) -class RuntimeStateLayer(PlainLayer[NoLayerDeps]): +class RuntimeStateLayer(PlainLayer[NoLayerDeps, EmptyLayerConfig, RuntimeState]): next_id: Iterator[int] = field(default_factory=lambda: count(1)) @override - async def on_context_create(self, control: LayerControl) -> None: + async def on_context_create(self, control: LayerControl[RuntimeState, EmptyRuntimeHandles]) -> None: runtime_id = next(self.next_id) - control.runtime_state["runtime_id"] = runtime_id + control.runtime_state.runtime_id = runtime_id @override - async def on_context_resume(self, control: LayerControl) -> None: - control.runtime_state["resumed_runtime_id"] = control.runtime_state["runtime_id"] + async def on_context_resume(self, control: LayerControl[RuntimeState, EmptyRuntimeHandles]) -> None: + control.runtime_state.resumed_runtime_id = control.runtime_state.runtime_id @override - async def on_context_delete(self, control: LayerControl) -> None: - control.runtime_state["deleted_runtime_id"] = control.runtime_state["runtime_id"] + async def on_context_delete(self, control: LayerControl[RuntimeState, EmptyRuntimeHandles]) -> None: + control.runtime_state.deleted_runtime_id = control.runtime_state.runtime_id def test_runtime_state_is_per_session_and_survives_suspend_resume_delete() -> None: @@ -275,12 +286,12 @@ def test_runtime_state_is_per_session_and_survives_suspend_resume_delete() -> No asyncio.run(run()) - assert first_session.layer("trace").runtime_state == { + assert first_session.layer("trace").runtime_state.model_dump(exclude_none=True) == { "runtime_id": 1, "resumed_runtime_id": 1, "deleted_runtime_id": 1, } - assert second_session.layer("trace").runtime_state == { + assert second_session.layer("trace").runtime_state.model_dump(exclude_none=True) == { "runtime_id": 2, "deleted_runtime_id": 2, } diff --git a/dify-agent/tests/local/agenton/layers/test_schema_inference.py b/dify-agent/tests/local/agenton/layers/test_schema_inference.py new file mode 100644 index 0000000000..45eedf3bc5 --- /dev/null +++ b/dify-agent/tests/local/agenton/layers/test_schema_inference.py @@ -0,0 +1,94 @@ +from dataclasses import dataclass + +from pydantic import BaseModel, ConfigDict + +from agenton.compositor import LayerRegistry +from agenton.layers import EmptyLayerConfig, EmptyRuntimeHandles, EmptyRuntimeState, LayerControl, NoLayerDeps, PlainLayer + + +class InferredConfig(BaseModel): + value: str = "configured" + + model_config = ConfigDict(extra="forbid") + + +class InferredState(BaseModel): + count: int = 0 + + model_config = ConfigDict(extra="forbid", validate_assignment=True) + + +class InferredHandles(BaseModel): + token: object | None = None + + model_config = ConfigDict(extra="forbid", validate_assignment=True, arbitrary_types_allowed=True) + + +@dataclass(slots=True) +class GenericSchemaLayer(PlainLayer[NoLayerDeps, InferredConfig, InferredState, InferredHandles]): + type_id = "test.generic-schema" + + async def on_context_create(self, control: LayerControl[InferredState, InferredHandles]) -> None: + control.runtime_state.count += 1 + control.runtime_handles.token = object() + + +@dataclass(slots=True) +class DefaultSchemaLayer(PlainLayer[NoLayerDeps]): + type_id = "test.default-schema" + + +def test_layer_infers_config_runtime_state_and_handles_from_generics() -> None: + layer = GenericSchemaLayer() + control = layer.new_control(runtime_state={"count": 3}) + + assert GenericSchemaLayer.config_type is InferredConfig + assert GenericSchemaLayer.runtime_state_type is InferredState + assert GenericSchemaLayer.runtime_handles_type is InferredHandles + assert isinstance(control.runtime_state, InferredState) + assert control.runtime_state.count == 3 + assert isinstance(control.runtime_handles, InferredHandles) + + +def test_layer_uses_empty_schema_defaults_when_omitted() -> None: + layer = DefaultSchemaLayer() + control = layer.new_control() + + assert DefaultSchemaLayer.config_type is EmptyLayerConfig + assert DefaultSchemaLayer.runtime_state_type is EmptyRuntimeState + assert DefaultSchemaLayer.runtime_handles_type is EmptyRuntimeHandles + assert isinstance(control.runtime_state, EmptyRuntimeState) + assert isinstance(control.runtime_handles, EmptyRuntimeHandles) + + +def test_invalid_declared_schema_type_is_rejected_clearly() -> None: + try: + + class InvalidSchemaLayer(PlainLayer[NoLayerDeps]): + config_type = dict # pyright: ignore[reportAssignmentType] + + except TypeError as e: + assert str(e) == "InvalidSchemaLayer.config_type must be a Pydantic BaseModel subclass." + else: + raise AssertionError("Expected TypeError.") + + try: + + class InvalidGenericSchemaLayer(PlainLayer[NoLayerDeps, dict[str, object]]): # pyright: ignore[reportInvalidTypeArguments] + pass + + except TypeError as e: + assert str(e) == "InvalidGenericSchemaLayer.config_type must be a Pydantic BaseModel subclass." + else: + raise AssertionError("Expected TypeError.") + + +def test_registry_descriptor_uses_inferred_schema_types() -> None: + registry = LayerRegistry() + registry.register_layer(GenericSchemaLayer) + + descriptor = registry.resolve("test.generic-schema") + + assert descriptor.config_type is InferredConfig + assert descriptor.runtime_state_type is InferredState + assert descriptor.runtime_handles_type is InferredHandles diff --git a/dify-agent/tests/local/examples/test_agenton_examples.py b/dify-agent/tests/local/examples/test_agenton_examples.py index bf190f9226..30cfcc538f 100644 --- a/dify-agent/tests/local/examples/test_agenton_examples.py +++ b/dify-agent/tests/local/examples/test_agenton_examples.py @@ -38,3 +38,11 @@ def test_agenton_pydantic_ai_example_smoke() -> None: assert "ToolCallPart: count_words(" in result.stdout assert "ToolCallPart: write_tagline(" in result.stdout assert "TextPart:" in result.stdout + + +def test_agenton_session_snapshot_example_smoke() -> None: + result = _run_example("examples/agenton/session_snapshot.py") + + assert result.returncode == 0, result.stderr + assert "Snapshot:" in result.stdout + assert "Rehydrated handle: restored:demo-connection" in result.stdout