"""Tests for ``DifyShellLayer``: workspace lifecycle, tool wiring and runtime-state validation.

All shellctl traffic goes through ``FakeShellctlClient``, an in-memory stand-in that
records every call and delegates results to per-test handler callbacks.
"""

import asyncio
from collections.abc import Callable
import secrets
import time
from dataclasses import dataclass

import pytest

from agenton.compositor import Compositor, LayerNode, LayerProvider
from agenton.layers import LifecycleState
from dify_agent.layers.shell import DIFY_SHELL_LAYER_TYPE_ID, DifyShellLayerConfig
from dify_agent.layers.shell.layer import DifyShellLayer, DifyShellRuntimeState, ShellctlClientFactory
from shell_session_manager.shellctl.shared import DeleteJobResponse, JobResult, JobStatusName, JobStatusView


def _job_result(
    job_id: str,
    *,
    status: JobStatusName = JobStatusName.RUNNING,
    done: bool = False,
    exit_code: int | None = None,
    output: str = "",
    offset: int = 0,
    truncated: bool = False,
    output_path: str = "/tmp/output.log",
) -> JobResult:
    """Build a ``JobResult`` whose defaults describe a still-running job with no output.

    Only ``job_id`` is required; every other field can be overridden by keyword.
    """
    return JobResult(
        job_id=job_id,
        status=status,
        done=done,
        exit_code=exit_code,
        output=output,
        offset=offset,
        truncated=truncated,
        output_path=output_path,
    )


def _job_status(
    job_id: str,
    *,
    status: JobStatusName = JobStatusName.RUNNING,
    done: bool = False,
    exit_code: int | None = None,
    offset: int = 0,
) -> JobStatusView:
    """Build a ``JobStatusView`` with fixed, deterministic timestamps.

    ``ended_at`` is set only when ``done`` is true; otherwise it is ``None``.
    """
    return JobStatusView(
        job_id=job_id,
        status=status,
        done=done,
        exit_code=exit_code,
        created_at="2026-05-28T12:00:00Z",
        started_at="2026-05-28T12:00:01Z",
        ended_at="2026-05-28T12:00:02Z" if done else None,
        offset=offset,
    )


def _assert_error_observation(result: object, *, job_id: str | None = None, includes: str | None = None) -> None:
    """Assert that a tool call returned an error observation instead of raising.

    The observation must be a dict whose ``"error"`` value is a non-empty string.

    Args:
        result: Value returned by the tool call.
        job_id: If ``None``, the dict must not contain a ``"job_id"`` key; otherwise
            ``result["job_id"]`` must equal this value.
        includes: Optional substring that must appear in the error message.
    """
    assert isinstance(result, dict)
    assert isinstance(result.get("error"), str)
    assert result["error"]
    if job_id is None:
        assert "job_id" not in result
    else:
        assert result.get("job_id") == job_id
    if includes is not None:
        assert includes in result["error"]


@dataclass(slots=True)
class RunCall:
    """Arguments recorded for one ``FakeShellctlClient.run`` call."""

    script: str
    cwd: str | None
    timeout: float


@dataclass(slots=True)
class WaitCall:
    """Arguments recorded for one ``FakeShellctlClient.wait`` call."""

    job_id: str
    offset: int
    timeout: float


@dataclass(slots=True)
class InputCall:
    """Arguments recorded for one ``FakeShellctlClient.input`` call."""

    job_id: str
    text: str
    offset: int
    timeout: float


@dataclass(slots=True)
class TerminateCall:
    """Arguments recorded for one ``FakeShellctlClient.terminate`` call."""

    job_id: str
    grace_seconds: float


@dataclass(slots=True)
class DeleteCall:
    """Arguments recorded for one ``FakeShellctlClient.delete`` call."""

    job_id: str
    force: bool
    grace_seconds: float | None


class FakeShellctlClient:
    """In-memory shellctl client double that records calls and delegates to handlers.

    Each operation does three things, in order:

    1. Appends a typed call record to its ``*_calls`` list.
    2. Appends an ``(operation, detail)`` tuple to ``events``, which lets tests
       check ordering across operations.
    3. Returns whatever the matching handler returns.

    If no handler is configured, ``run``/``wait``/``input``/``terminate`` raise
    ``AssertionError``; ``delete`` instead returns a plain ``DeleteJobResponse``.
    """

    run_calls: list[RunCall]
    wait_calls: list[WaitCall]
    input_calls: list[InputCall]
    terminate_calls: list[TerminateCall]
    delete_calls: list[DeleteCall]
    # Chronological log across all operations. The second element is the script
    # for "run", the job id for job operations, and "client" for "close".
    events: list[tuple[str, str]]
    # Flipped to True by close(); tests use it to see when the layer released the client.
    closed: bool

    def __init__(
        self,
        *,
        run_handler: Callable[[str, str | None, float], JobResult] | None = None,
        wait_handler: Callable[[str, int, float], JobResult] | None = None,
        input_handler: Callable[[str, str, int, float], JobResult] | None = None,
        terminate_handler: Callable[[str, float], JobStatusView] | None = None,
        delete_handler: Callable[[str, bool, float | None], DeleteJobResponse] | None = None,
    ) -> None:
        """Store the optional per-operation handlers and start with empty call logs."""
        self._run_handler = run_handler
        self._wait_handler = wait_handler
        self._input_handler = input_handler
        self._terminate_handler = terminate_handler
        self._delete_handler = delete_handler
        self.run_calls = []
        self.wait_calls = []
        self.input_calls = []
        self.terminate_calls = []
        self.delete_calls = []
        self.events = []
        self.closed = False

    async def run(self, script: str, *, cwd: str | None = None, timeout: float = 10.0) -> JobResult:
        """Record a run request and return the run handler's result.

        Raises:
            AssertionError: If no run handler was configured.
        """
        self.run_calls.append(RunCall(script=script, cwd=cwd, timeout=timeout))
        self.events.append(("run", script))
        if self._run_handler is None:
            raise AssertionError("Unexpected run() call")
        return self._run_handler(script, cwd, timeout)

    async def wait(self, job_id: str, *, offset: int, timeout: float = 10.0) -> JobResult:
        """Record a wait request and return the wait handler's result.

        Raises:
            AssertionError: If no wait handler was configured.
        """
        self.wait_calls.append(WaitCall(job_id=job_id, offset=offset, timeout=timeout))
        self.events.append(("wait", job_id))
        if self._wait_handler is None:
            raise AssertionError("Unexpected wait() call")
        return self._wait_handler(job_id, offset, timeout)

    async def input(self, job_id: str, text: str, *, offset: int, timeout: float = 10.0) -> JobResult:
        """Record an input request and return the input handler's result.

        Raises:
            AssertionError: If no input handler was configured.
        """
        self.input_calls.append(InputCall(job_id=job_id, text=text, offset=offset, timeout=timeout))
        self.events.append(("input", job_id))
        if self._input_handler is None:
            raise AssertionError("Unexpected input() call")
        return self._input_handler(job_id, text, offset, timeout)

    async def terminate(self, job_id: str, grace_seconds: float = 2.0) -> JobStatusView:
        """Record a terminate request and return the terminate handler's result.

        Raises:
            AssertionError: If no terminate handler was configured.
        """
        self.terminate_calls.append(TerminateCall(job_id=job_id, grace_seconds=grace_seconds))
        self.events.append(("terminate", job_id))
        if self._terminate_handler is None:
            raise AssertionError("Unexpected terminate() call")
        return self._terminate_handler(job_id, grace_seconds)

    async def delete(
        self,
        job_id: str,
        *,
        force: bool = False,
        grace_seconds: float | None = None,
    ) -> DeleteJobResponse:
        """Record a delete request.

        Unlike the other operations, a missing handler is allowed: the call then
        returns a bare ``DeleteJobResponse`` for ``job_id``.
        """
        self.delete_calls.append(DeleteCall(job_id=job_id, force=force, grace_seconds=grace_seconds))
        self.events.append(("delete", job_id))
        if self._delete_handler is None:
            return DeleteJobResponse(job_id=job_id)
        return self._delete_handler(job_id, force, grace_seconds)

    async def close(self) -> None:
        """Mark the client closed and log a ("close", "client") event."""
        self.closed = True
        self.events.append(("close", "client"))


def _shell_layer(*, client_factory: ShellctlClientFactory) -> DifyShellLayer:
    """Build a ``DifyShellLayer`` with default config.

    The layer points at ``http://shellctl`` and obtains its client from ``client_factory``.
    """
    return DifyShellLayer.from_config_with_settings(
        DifyShellLayerConfig(),
        shellctl_entrypoint="http://shellctl",
        shellctl_client_factory=client_factory,
    )


def _shell_provider(*, client_factory: ShellctlClientFactory) -> LayerProvider[DifyShellLayer]:
    """Build a ``LayerProvider`` that creates ``DifyShellLayer`` instances for a ``Compositor``.

    Each created layer validates its config, points at ``http://shellctl``, and
    obtains its client from ``client_factory``.
    """
    return LayerProvider.from_factory(
        layer_type=DifyShellLayer,
        create=lambda config: DifyShellLayer.from_config_with_settings(
            DifyShellLayerConfig.model_validate(config),
            shellctl_entrypoint="http://shellctl",
            shellctl_client_factory=client_factory,
        ),
    )


def test_shell_type_id_constant_matches_implementation_class() -> None:
    """The exported type-id constant must equal ``DifyShellLayer.type_id``."""
    assert DIFY_SHELL_LAYER_TYPE_ID == DifyShellLayer.type_id


def test_shell_layer_create_generates_5_plus_2_hex_session_id_and_retries_workspace_collision(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """on_context_create builds a 7-char session id and retries when mkdir collides.

    Setup: ``time.time`` is pinned to ``0x12345F`` and ``secrets.token_hex`` yields
    ``"aa"`` then ``"bb"``.

    - The first mkdir (for ``2345faa``) exits 17 and must be retried.
    - The second mkdir (for ``2345fbb``) is still running, so it is waited on once.

    Both internal jobs must end up tracked. The client must stay open until the
    resource context exits.
    """
    # Successive token_hex() calls return these suffixes, in order.
    random_suffixes = iter(["aa", "bb"])
    monkeypatch.setattr(time, "time", lambda: 0x12345F)
    monkeypatch.setattr(secrets, "token_hex", lambda nbytes: next(random_suffixes))

    def run_handler(script: str, cwd: str | None, timeout: float) -> JobResult:
        # Workspace creation runs with no cwd and a 30-second timeout.
        assert cwd is None
        assert timeout == 30.0
        if "2345faa" in script:
            # First candidate: mkdir finishes with a non-zero exit -> treated as a collision.
            return _job_result("mkdir-collision", status=JobStatusName.EXITED, done=True, exit_code=17)
        if "2345fbb" in script:
            # Second candidate: still running, so the layer must wait() from offset 4.
            return _job_result("mkdir-success", status=JobStatusName.RUNNING, done=False, offset=4)
        raise AssertionError(f"Unexpected script: {script}")

    def wait_handler(job_id: str, offset: int, timeout: float) -> JobResult:
        assert job_id == "mkdir-success"
        assert offset == 4
        assert timeout == 30.0
        return _job_result("mkdir-success", status=JobStatusName.EXITED, done=True, exit_code=0, offset=8)

    client = FakeShellctlClient(run_handler=run_handler, wait_handler=wait_handler)
    layer = _shell_layer(client_factory=lambda _entrypoint: client)

    async def scenario() -> None:
        async with layer.resource_context():
            await layer.on_context_create()
            assert client.closed is False

    asyncio.run(scenario())

    assert layer.runtime_state.session_id == "2345fbb"
    assert layer.runtime_state.workspace_cwd == "~/workspace/2345fbb"
    # Both internal mkdir jobs are tracked, each with the offset from its last observed result.
    assert layer.runtime_state.job_ids == ["mkdir-collision", "mkdir-success"]
    assert layer.runtime_state.job_offsets == {"mkdir-collision": 0, "mkdir-success": 8}
    # Plain `mkdir`, not `mkdir -p`. Presumably this is so an existing directory
    # fails and surfaces as a collision instead of being reused.
    assert 'mkdir "$HOME/workspace/2345fbb"' in client.run_calls[1].script
    assert 'mkdir -p "$HOME/workspace/2345fbb"' not in client.run_calls[1].script
    assert client.closed is True


def test_shell_layer_suspend_leaves_client_open_until_resource_context_exits() -> None:
    """on_context_suspend must not close the client; leaving resource_context does."""
    # No handlers: any run/wait/input/terminate call would raise AssertionError.
    client = FakeShellctlClient()
    layer = _shell_layer(client_factory=lambda _entrypoint: client)
    layer.runtime_state = DifyShellRuntimeState(session_id="abc12ff", workspace_cwd="~/workspace/abc12ff")

    async def scenario() -> None:
        async with layer.resource_context():
            await layer.on_context_suspend()
            assert client.closed is False

    asyncio.run(scenario())
    assert client.closed is True


def test_shell_layer_suspend_and_resume_reuse_state_with_fresh_clients() -> None:
    """Suspending via the Compositor snapshots runtime state; resuming restores it.

    Each ``compositor.enter`` obtains a new client from the factory.

    - First run: the workspace mkdir succeeds immediately, and the test then
      injects an extra "user-job" into the tracking state.
    - Resumed run (from the snapshot): the session id, workspace and both
      tracked jobs with their offsets must be unchanged.

    Each client is closed when its own run exits.
    """
    first_client = FakeShellctlClient(
        run_handler=lambda _script, _cwd, _timeout: _job_result(
            "mkdir-job",
            status=JobStatusName.EXITED,
            done=True,
            exit_code=0,
        )
    )
    # No handlers: a run/wait/input/terminate call while resuming would raise AssertionError.
    second_client = FakeShellctlClient()
    created_entrypoints: list[str] = []
    clients = iter([first_client, second_client])

    def factory(entrypoint: str) -> FakeShellctlClient:
        # Returns first_client, then second_client, recording the entrypoint each time.
        created_entrypoints.append(entrypoint)
        return next(clients)

    compositor = Compositor([LayerNode("shell", _shell_provider(client_factory=factory))])

    async def scenario() -> None:
        async with compositor.enter(configs={"shell": DifyShellLayerConfig()}) as run:
            shell_layer = run.get_layer("shell", DifyShellLayer)
            initial_session_id = shell_layer.runtime_state.session_id
            assert initial_session_id is not None
            assert shell_layer.runtime_state.workspace_cwd == f"~/workspace/{initial_session_id}"
            # Simulate a job started by a tool call so the snapshot carries it.
            shell_layer.runtime_state.job_ids = [*shell_layer.runtime_state.job_ids, "user-job"]
            shell_layer.runtime_state.job_offsets = {
                **shell_layer.runtime_state.job_offsets,
                "user-job": 42,
            }
            assert first_client.closed is False
            run.suspend_layer_on_exit("shell")

        assert run.session_snapshot is not None
        assert first_client.closed is True
        assert run.session_snapshot.layers[0].lifecycle_state is LifecycleState.SUSPENDED

        async with compositor.enter(
            configs={"shell": DifyShellLayerConfig()},
            session_snapshot=run.session_snapshot,
        ) as resumed_run:
            resumed_shell = resumed_run.get_layer("shell", DifyShellLayer)
            assert second_client.closed is False
            assert resumed_shell.runtime_state.session_id == initial_session_id
            assert resumed_shell.runtime_state.workspace_cwd == f"~/workspace/{initial_session_id}"
            assert set(resumed_shell.runtime_state.job_ids) == {"mkdir-job", "user-job"}
            assert resumed_shell.runtime_state.job_offsets == {"mkdir-job": 0, "user-job": 42}
            resumed_run.suspend_layer_on_exit("shell")

        assert second_client.closed is True

    asyncio.run(scenario())
    # One client per enter(), both for the configured entrypoint.
    assert created_entrypoints == ["http://shellctl", "http://shellctl"]


def test_shell_layer_delete_removes_workspace_then_force_deletes_tracked_jobs_and_closes_client() -> None:
    """on_context_delete removes the workspace, then force-deletes every tracked job.

    The ``rm -rf`` job is still running when first returned, so the layer must wait
    for it. Every job deletion must come after that wait, and the cleanup job
    itself is also force-deleted. Tracking state ends up empty, and the client
    closes when the resource context exits.
    """

    def run_handler(script: str, cwd: str | None, timeout: float) -> JobResult:
        assert script == 'rm -rf -- "$HOME/workspace/abc12ff"'
        assert cwd is None
        assert timeout == 30.0
        return _job_result("cleanup-job", status=JobStatusName.RUNNING, done=False, offset=3)

    def wait_handler(job_id: str, offset: int, timeout: float) -> JobResult:
        assert job_id == "cleanup-job"
        assert offset == 3
        assert timeout == 30.0
        return _job_result("cleanup-job", status=JobStatusName.EXITED, done=True, exit_code=0, offset=5)

    client = FakeShellctlClient(run_handler=run_handler, wait_handler=wait_handler)
    layer = _shell_layer(client_factory=lambda _entrypoint: client)

    async def scenario() -> None:
        async with layer.resource_context():
            layer.runtime_state = DifyShellRuntimeState(session_id="abc12ff", workspace_cwd="~/workspace/abc12ff")
            layer.runtime_state.job_ids = ["user-job", "mkdir-job"]
            layer.runtime_state.job_offsets = {"user-job": 9, "mkdir-job": 1}
            await layer.on_context_delete()
            assert client.closed is False

    asyncio.run(scenario())
    # Workspace removal (run + wait) happens before anything else.
    assert client.events[:2] == [("run", 'rm -rf -- "$HOME/workspace/abc12ff"'), ("wait", "cleanup-job")]
    assert {call.job_id for call in client.delete_calls} == {"user-job", "mkdir-job", "cleanup-job"}
    # Every delete must come after the cleanup job was waited on.
    assert all(
        client.events.index(("delete", call.job_id)) > client.events.index(("wait", "cleanup-job"))
        for call in client.delete_calls
    )
    assert all(call.force is True for call in client.delete_calls)
    assert layer.runtime_state.job_ids == []
    assert layer.runtime_state.job_offsets == {}
    assert client.closed is True


def test_shell_layer_create_failure_force_deletes_internal_jobs_before_reraising() -> None:
    """A failed workspace mkdir raises RuntimeError and force-deletes the mkdir job.

    The mkdir job exits with code 1. The test expects a RuntimeError matching
    "Failed to create shell workspace". Afterwards the failed job must have been
    force-deleted, tracking state cleared, and the client closed.
    """
    client = FakeShellctlClient(
        run_handler=lambda _script, _cwd, _timeout: _job_result(
            "mkdir-failed",
            status=JobStatusName.EXITED,
            done=True,
            exit_code=1,
        )
    )
    layer = _shell_layer(client_factory=lambda _entrypoint: client)

    async def scenario() -> None:
        with pytest.raises(RuntimeError, match="Failed to create shell workspace"):
            async with layer.resource_context():
                await layer.on_context_create()

    asyncio.run(scenario())
    assert [call.job_id for call in client.delete_calls] == ["mkdir-failed"]
    assert all(call.force is True for call in client.delete_calls)
    assert layer.runtime_state.job_ids == []
    assert layer.runtime_state.job_offsets == {}
    assert client.closed is True


def test_shell_layer_tools_map_inputs_to_shellctl_calls_and_maintain_offsets() -> None:
    """Each shell tool forwards its arguments to shellctl and updates the tracked offset.

    The handlers pin the exact arguments each call must receive:

    - ``shell.run`` uses the session workspace as its cwd.
    - ``shell.wait`` resumes from the offset ``shell.run`` returned (10).
    - ``shell.input`` resumes from the offset ``shell.wait`` returned (18).
    - ``shell.interrupt`` maps to ``terminate``.

    No tool schema may expose an ``offset`` parameter.
    """

    def run_handler(script: str, cwd: str | None, timeout: float) -> JobResult:
        assert script == "pwd"
        # The layer runs user scripts inside the session workspace.
        assert cwd == "~/workspace/abc12ff"
        assert timeout == 2.5
        return _job_result(
            "user-job",
            status=JobStatusName.RUNNING,
            done=False,
            offset=10,
            output="/home/test\n",
        )

    def wait_handler(job_id: str, offset: int, timeout: float) -> JobResult:
        assert job_id == "user-job"
        # Offset comes from the layer's tracking state, not from the tool arguments.
        assert offset == 10
        assert timeout == 4.0
        return _job_result(
            "user-job",
            status=JobStatusName.RUNNING,
            done=False,
            offset=18,
            output="more\n",
        )

    def input_handler(job_id: str, text: str, offset: int, timeout: float) -> JobResult:
        assert job_id == "user-job"
        assert text == "ls\n"
        assert offset == 18
        assert timeout == 5.0
        return _job_result(
            "user-job",
            status=JobStatusName.EXITED,
            done=True,
            exit_code=0,
            offset=22,
            output="file.txt\n",
        )

    def terminate_handler(job_id: str, grace_seconds: float) -> JobStatusView:
        assert job_id == "user-job"
        assert grace_seconds == 1.5
        return _job_status(
            "user-job",
            status=JobStatusName.TERMINATED,
            done=True,
            exit_code=130,
            offset=22,
        )

    client = FakeShellctlClient(
        run_handler=run_handler,
        wait_handler=wait_handler,
        input_handler=input_handler,
        terminate_handler=terminate_handler,
    )
    layer = _shell_layer(client_factory=lambda _entrypoint: client)
    tools = {tool.name: tool for tool in layer.tools}

    async def scenario() -> None:
        async with layer.resource_context():
            layer.runtime_state = DifyShellRuntimeState(session_id="abc12ff", workspace_cwd="~/workspace/abc12ff")
            run_tool_def = await tools["shell.run"].prepare_tool_def(None)  # pyright: ignore[reportArgumentType]
            wait_tool_def = await tools["shell.wait"].prepare_tool_def(None)  # pyright: ignore[reportArgumentType]
            input_tool_def = await tools["shell.input"].prepare_tool_def(None)  # pyright: ignore[reportArgumentType]
            interrupt_tool_def = await tools["shell.interrupt"].prepare_tool_def(None)  # pyright: ignore[reportArgumentType]
            # Calls run in sequence; each handler checks the offset left by the previous call.
            run_result = await tools["shell.run"].function_schema.call(
                {"script": "pwd", "timeout": 2.5},
                None,  # pyright: ignore[reportArgumentType]
            )
            wait_result = await tools["shell.wait"].function_schema.call(
                {"job_id": "user-job", "timeout": 4.0},
                None,  # pyright: ignore[reportArgumentType]
            )
            input_result = await tools["shell.input"].function_schema.call(
                {"job_id": "user-job", "text": "ls\n", "timeout": 5.0},
                None,  # pyright: ignore[reportArgumentType]
            )
            interrupt_result = await tools["shell.interrupt"].function_schema.call(
                {"job_id": "user-job", "grace_seconds": 1.5},
                None,  # pyright: ignore[reportArgumentType]
            )
            assert run_tool_def is not None
            assert wait_tool_def is not None
            assert input_tool_def is not None
            assert interrupt_tool_def is not None
            # Offsets are managed internally and must not appear in any tool's parameter schema.
            assert "offset" not in run_tool_def.parameters_json_schema.get("properties", {})
            assert "offset" not in wait_tool_def.parameters_json_schema.get("properties", {})
            assert "offset" not in input_tool_def.parameters_json_schema.get("properties", {})
            assert "offset" not in interrupt_tool_def.parameters_json_schema.get("properties", {})
            assert set(tools) == {"shell.run", "shell.wait", "shell.input", "shell.interrupt"}
            assert run_result["job_id"] == "user-job"
            assert run_result["offset"] == 10
            assert wait_result["offset"] == 18
            assert input_result["offset"] == 22
            assert interrupt_result == {
                "job_id": "user-job",
                "status": "terminated",
                "done": True,
                "exit_code": 130,
                "offset": 22,
            }
            assert client.closed is False

    asyncio.run(scenario())
    # The job stays tracked, with the final offset (22).
    assert layer.runtime_state.job_ids == ["user-job"]
    assert layer.runtime_state.job_offsets == {"user-job": 22}
    assert client.closed is True


def test_shell_layer_tools_reject_untracked_job_ids_without_shellctl_calls() -> None:
    """Tools return an error observation for unknown job ids and never call shellctl."""
    # No handlers: any wait/input/terminate call would raise AssertionError.
    client = FakeShellctlClient()
    layer = _shell_layer(client_factory=lambda _entrypoint: client)
    tools = {tool.name: tool for tool in layer.tools}

    async def scenario() -> None:
        async with layer.resource_context():
            layer.runtime_state = DifyShellRuntimeState(session_id="abc12ff", workspace_cwd="~/workspace/abc12ff")
            wait_result = await tools["shell.wait"].function_schema.call(
                {"job_id": "missing-job"},
                None,  # pyright: ignore[reportArgumentType]
            )
            input_result = await tools["shell.input"].function_schema.call(
                {"job_id": "missing-job", "text": "hello"},
                None,  # pyright: ignore[reportArgumentType]
            )
            interrupt_result = await tools["shell.interrupt"].function_schema.call(
                {"job_id": "missing-job"},
                None,  # pyright: ignore[reportArgumentType]
            )
            _assert_error_observation(wait_result, job_id="missing-job")
            _assert_error_observation(input_result, job_id="missing-job")
            _assert_error_observation(interrupt_result, job_id="missing-job")

    asyncio.run(scenario())
    assert client.wait_calls == []
    assert client.input_calls == []
    assert client.terminate_calls == []


def test_shell_layer_hooks_and_tools_fail_clearly_outside_active_resource_context() -> None:
    """Without an active resource_context, hooks raise and tools return an error.

    ``on_context_suspend`` raises a RuntimeError mentioning ``resource_context``.
    ``shell.run`` instead returns an error observation containing that text and
    makes no shellctl call.
    """
    client = FakeShellctlClient()
    layer = _shell_layer(client_factory=lambda _entrypoint: client)
    layer.runtime_state = DifyShellRuntimeState(session_id="abc12ff", workspace_cwd="~/workspace/abc12ff")
    tools = {tool.name: tool for tool in layer.tools}

    async def scenario() -> None:
        with pytest.raises(RuntimeError, match="resource_context"):
            await layer.on_context_suspend()
        # Tools report the problem as an observation instead of raising.
        run_result = await tools["shell.run"].function_schema.call(
            {"script": "pwd"},
            None,  # pyright: ignore[reportArgumentType]
        )
        _assert_error_observation(run_result, includes="resource_context")

    asyncio.run(scenario())
    assert client.run_calls == []


def test_shell_runtime_state_rejects_unsafe_resumed_workspace_identity() -> None:
    """Runtime-state validation rejects unsafe or inconsistent workspace identities.

    Two inputs are rejected: a path-traversal ``session_id``, and a
    ``workspace_cwd`` that does not correspond to the given ``session_id``.
    """
    with pytest.raises(ValueError, match="session_id must match"):
        _ = DifyShellRuntimeState.model_validate(
            {
                "session_id": "../../tmp",
                "workspace_cwd": "~/workspace/../../tmp",
                "job_ids": [],
                "job_offsets": {},
            }
        )

    with pytest.raises(ValueError, match="workspace_cwd must equal"):
        _ = DifyShellRuntimeState.model_validate(
            {
                "session_id": "abc12ff",
                "workspace_cwd": "~/workspace/def34aa",
                "job_ids": [],
                "job_offsets": {},
            }
        )


def test_shell_runtime_state_treats_job_ids_as_opaque_strings_and_rejects_unknown_offset_keys() -> None:
    """Job ids are stored verbatim, even with quotes and spaces.

    ``job_offsets`` may only reference ids that appear in ``job_ids``.
    """
    state = DifyShellRuntimeState.model_validate(
        {
            "session_id": "abc12ff",
            "workspace_cwd": "~/workspace/abc12ff",
            "job_ids": ['job"bad with spaces'],
            "job_offsets": {'job"bad with spaces': 0},
        }
    )
    assert state.job_ids == ['job"bad with spaces']
    assert state.job_offsets == {'job"bad with spaces': 0}

    with pytest.raises(ValueError, match="unknown job ids"):
        _ = DifyShellRuntimeState.model_validate(
            {
                "session_id": "abc12ff",
                "workspace_cwd": "~/workspace/abc12ff",
                "job_ids": ["job-1"],
                "job_offsets": {"job-2": 3},
            }
        )