diff --git a/api/.importlinter b/api/.importlinter
index 6e15f06a5c..14a66f2ff9 100644
--- a/api/.importlinter
+++ b/api/.importlinter
@@ -37,7 +37,6 @@
 type = layers
 layers =
     graph_engine
    response_coordinator
-    output_registry
 containers =
     core.workflow.graph_engine
diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py
index b627ccc634..8ac27143e3 100644
--- a/api/core/workflow/graph_engine/graph_engine.py
+++ b/api/core/workflow/graph_engine/graph_engine.py
@@ -35,7 +35,6 @@ from .event_management import EventCollector, EventEmitter, EventHandlerRegistry
 from .graph_traversal import BranchHandler, EdgeProcessor, NodeReadinessChecker, SkipPropagator
 from .layers.base import Layer
 from .orchestration import Dispatcher, ExecutionCoordinator
-from .output_registry import OutputRegistry
 from .protocols.command_channel import CommandChannel
 from .response_coordinator import ResponseStreamCoordinator
 from .state_management import UnifiedStateManager
@@ -122,8 +121,9 @@ class GraphEngine:
         self.state_manager = UnifiedStateManager(self.graph, self.ready_queue)
 
         # Response coordination
-        self.output_registry = OutputRegistry(self.graph_runtime_state.variable_pool)
-        self.response_coordinator = ResponseStreamCoordinator(registry=self.output_registry, graph=self.graph)
+        self.response_coordinator = ResponseStreamCoordinator(
+            variable_pool=self.graph_runtime_state.variable_pool, graph=self.graph
+        )
 
         # Event management
         self.event_collector = EventCollector()
diff --git a/api/core/workflow/graph_engine/output_registry/__init__.py b/api/core/workflow/graph_engine/output_registry/__init__.py
deleted file mode 100644
index a65a62ec53..0000000000
--- a/api/core/workflow/graph_engine/output_registry/__init__.py
+++ /dev/null
@@ -1,10 +0,0 @@
-"""
-OutputRegistry - Thread-safe storage for node outputs (streams and scalars)
-
-This component provides thread-safe storage and retrieval of node outputs,
-supporting both scalar values and streaming chunks with proper state management.
-"""
-
-from .registry import OutputRegistry
-
-__all__ = ["OutputRegistry"]
diff --git a/api/core/workflow/graph_engine/output_registry/registry.py b/api/core/workflow/graph_engine/output_registry/registry.py
deleted file mode 100644
index 29eefa5abe..0000000000
--- a/api/core/workflow/graph_engine/output_registry/registry.py
+++ /dev/null
@@ -1,148 +0,0 @@
-"""
-Main OutputRegistry implementation.
-
-This module contains the public OutputRegistry class that provides
-thread-safe storage for node outputs.
-"""
-
-from collections.abc import Sequence
-from threading import RLock
-from typing import TYPE_CHECKING, Any, Union, final
-
-from core.variables import Segment
-from core.workflow.entities.variable_pool import VariablePool
-
-from .stream import Stream
-
-if TYPE_CHECKING:
-    from core.workflow.graph_events import NodeRunStreamChunkEvent
-
-
-@final
-class OutputRegistry:
-    """
-    Thread-safe registry for storing and retrieving node outputs.
-
-    Supports both scalar values and streaming chunks with proper state management.
-    All operations are thread-safe using internal locking.
- """ - - def __init__(self, variable_pool: VariablePool) -> None: - """Initialize empty registry with thread-safe storage.""" - self._lock = RLock() - self._scalars = variable_pool - self._streams: dict[tuple[str, ...], Stream] = {} - - def _selector_to_key(self, selector: Sequence[str]) -> tuple[str, ...]: - """Convert selector list to tuple key for internal storage.""" - return tuple(selector) - - def set_scalar( - self, selector: Sequence[str], value: Union[str, int, float, bool, dict[str, Any], list[Any]] - ) -> None: - """ - Set a scalar value for the given selector. - - Args: - selector: List of strings identifying the output location - value: The scalar value to store - """ - with self._lock: - self._scalars.add(selector, value) - - def get_scalar(self, selector: Sequence[str]) -> "Segment | None": - """ - Get a scalar value for the given selector. - - Args: - selector: List of strings identifying the output location - - Returns: - The stored Variable object, or None if not found - """ - with self._lock: - return self._scalars.get(selector) - - def append_chunk(self, selector: Sequence[str], event: "NodeRunStreamChunkEvent") -> None: - """ - Append a NodeRunStreamChunkEvent to the stream for the given selector. - - Args: - selector: List of strings identifying the stream location - event: The NodeRunStreamChunkEvent to append - - Raises: - ValueError: If the stream is already closed - """ - key = self._selector_to_key(selector) - with self._lock: - if key not in self._streams: - self._streams[key] = Stream() - - try: - self._streams[key].append(event) - except ValueError: - raise ValueError(f"Stream {'.'.join(selector)} is already closed") - - def pop_chunk(self, selector: Sequence[str]) -> "NodeRunStreamChunkEvent | None": - """ - Pop the next unread NodeRunStreamChunkEvent from the stream. - - Args: - selector: List of strings identifying the stream location - - Returns: - The next event, or None if no unread events available - """ - key = self._selector_to_key(selector) - with self._lock: - if key not in self._streams: - return None - - return self._streams[key].pop_next() - - def has_unread(self, selector: Sequence[str]) -> bool: - """ - Check if the stream has unread events. - - Args: - selector: List of strings identifying the stream location - - Returns: - True if there are unread events, False otherwise - """ - key = self._selector_to_key(selector) - with self._lock: - if key not in self._streams: - return False - - return self._streams[key].has_unread() - - def close_stream(self, selector: Sequence[str]) -> None: - """ - Mark a stream as closed (no more chunks can be appended). - - Args: - selector: List of strings identifying the stream location - """ - key = self._selector_to_key(selector) - with self._lock: - if key not in self._streams: - self._streams[key] = Stream() - self._streams[key].close() - - def stream_closed(self, selector: Sequence[str]) -> bool: - """ - Check if a stream is closed. 
-
-        Args:
-            selector: List of strings identifying the stream location
-
-        Returns:
-            True if the stream is closed, False otherwise
-        """
-        key = self._selector_to_key(selector)
-        with self._lock:
-            if key not in self._streams:
-                return False
-            return self._streams[key].is_closed
diff --git a/api/core/workflow/graph_engine/output_registry/stream.py b/api/core/workflow/graph_engine/output_registry/stream.py
deleted file mode 100644
index 8a99b56d1f..0000000000
--- a/api/core/workflow/graph_engine/output_registry/stream.py
+++ /dev/null
@@ -1,70 +0,0 @@
-"""
-Internal stream implementation for OutputRegistry.
-
-This module contains the private Stream class used internally by OutputRegistry
-to manage streaming data chunks.
-"""
-
-from typing import TYPE_CHECKING, final
-
-if TYPE_CHECKING:
-    from core.workflow.graph_events import NodeRunStreamChunkEvent
-
-
-@final
-class Stream:
-    """
-    A stream that holds NodeRunStreamChunkEvent objects and tracks read position.
-
-    This class encapsulates stream-specific data and operations,
-    including event storage, read position tracking, and closed state.
-
-    Note: This is an internal class not exposed in the public API.
-    """
-
-    def __init__(self) -> None:
-        """Initialize an empty stream."""
-        self.events: list[NodeRunStreamChunkEvent] = []
-        self.read_position: int = 0
-        self.is_closed: bool = False
-
-    def append(self, event: "NodeRunStreamChunkEvent") -> None:
-        """
-        Append a NodeRunStreamChunkEvent to the stream.
-
-        Args:
-            event: The NodeRunStreamChunkEvent to append
-
-        Raises:
-            ValueError: If the stream is already closed
-        """
-        if self.is_closed:
-            raise ValueError("Cannot append to a closed stream")
-        self.events.append(event)
-
-    def pop_next(self) -> "NodeRunStreamChunkEvent | None":
-        """
-        Pop the next unread NodeRunStreamChunkEvent from the stream.
-
-        Returns:
-            The next event, or None if no unread events available
-        """
-        if self.read_position >= len(self.events):
-            return None
-
-        event = self.events[self.read_position]
-        self.read_position += 1
-        return event
-
-    def has_unread(self) -> bool:
-        """
-        Check if the stream has unread events.
-
-        Returns:
-            True if there are unread events, False otherwise
-        """
-        return self.read_position < len(self.events)
-
-    def close(self) -> None:
-        """Mark the stream as closed (no more chunks can be appended)."""
-        self.is_closed = True
diff --git a/api/core/workflow/graph_engine/response_coordinator/coordinator.py b/api/core/workflow/graph_engine/response_coordinator/coordinator.py
index 1fb58852d2..a7b77bdf4a 100644
--- a/api/core/workflow/graph_engine/response_coordinator/coordinator.py
+++ b/api/core/workflow/graph_engine/response_coordinator/coordinator.py
@@ -12,12 +12,12 @@ from threading import RLock
 from typing import TypeAlias, final
 from uuid import uuid4
 
+from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.enums import NodeExecutionType, NodeState
 from core.workflow.graph import Graph
 from core.workflow.graph_events import NodeRunStreamChunkEvent, NodeRunSucceededEvent
 from core.workflow.nodes.base.template import TextSegment, VariableSegment
 
-from ..output_registry import OutputRegistry
 from .path import Path
 from .session import ResponseSession
 
@@ -36,20 +36,25 @@ class ResponseStreamCoordinator:
     Ensures ordered streaming of responses based on
    upstream node outputs and constants.
""" - def __init__(self, registry: OutputRegistry, graph: "Graph") -> None: + def __init__(self, variable_pool: "VariablePool", graph: "Graph") -> None: """ - Initialize coordinator with output registry. + Initialize coordinator with variable pool. Args: - registry: OutputRegistry instance for accessing node outputs + variable_pool: VariablePool instance for accessing node variables graph: Graph instance for looking up node information """ - self.registry = registry + self.variable_pool = variable_pool self.graph = graph self.active_session: ResponseSession | None = None self.waiting_sessions: deque[ResponseSession] = deque() self.lock = RLock() + # Internal stream management (replacing OutputRegistry) + self._stream_buffers: dict[tuple[str, ...], list[NodeRunStreamChunkEvent]] = {} + self._stream_positions: dict[tuple[str, ...], int] = {} + self._closed_streams: set[tuple[str, ...]] = set() + # Track response nodes self._response_nodes: set[NodeID] = set() @@ -256,15 +261,15 @@ class ResponseStreamCoordinator: ) -> Sequence[NodeRunStreamChunkEvent]: with self.lock: if isinstance(event, NodeRunStreamChunkEvent): - self.registry.append_chunk(event.selector, event) + self._append_stream_chunk(event.selector, event) if event.is_final: - self.registry.close_stream(event.selector) + self._close_stream(event.selector) return self.try_flush() else: # Skip cause we share the same variable pool. # # for variable_name, variable_value in event.node_run_result.outputs.items(): - # self.registry.set_scalar((event.node_id, variable_name), variable_value) + # self.variable_pool.add((event.node_id, variable_name), variable_value) return self.try_flush() return [] @@ -327,8 +332,8 @@ class ResponseStreamCoordinator: execution_id = self._get_or_create_execution_id(output_node_id) # Stream all available chunks - while self.registry.has_unread(segment.selector): - if event := self.registry.pop_chunk(segment.selector): + while self._has_unread_stream(segment.selector): + if event := self._pop_stream_chunk(segment.selector): # For special selectors, we need to update the event to use # the active response node's information if self.active_session and source_selector_prefix not in self.graph.nodes: @@ -349,12 +354,12 @@ class ResponseStreamCoordinator: events.append(event) # Check if this is the last chunk by looking ahead - stream_closed = self.registry.stream_closed(segment.selector) + stream_closed = self._is_stream_closed(segment.selector) # Check if stream is closed to determine if segment is complete if stream_closed: is_complete = True - elif value := self.registry.get_scalar(segment.selector): + elif value := self.variable_pool.get(segment.selector): # Process scalar value is_last_segment = bool( self.active_session and self.active_session.index == len(self.active_session.template.segments) - 1 @@ -464,3 +469,93 @@ class ResponseStreamCoordinator: events = self.try_flush() return events + + # ============= Internal Stream Management Methods ============= + + def _append_stream_chunk(self, selector: Sequence[str], event: NodeRunStreamChunkEvent) -> None: + """ + Append a stream chunk to the internal buffer. 
+
+        Args:
+            selector: List of strings identifying the stream location
+            event: The NodeRunStreamChunkEvent to append
+
+        Raises:
+            ValueError: If the stream is already closed
+        """
+        key = tuple(selector)
+
+        if key in self._closed_streams:
+            raise ValueError(f"Stream {'.'.join(selector)} is already closed")
+
+        if key not in self._stream_buffers:
+            self._stream_buffers[key] = []
+            self._stream_positions[key] = 0
+
+        self._stream_buffers[key].append(event)
+
+    def _pop_stream_chunk(self, selector: Sequence[str]) -> NodeRunStreamChunkEvent | None:
+        """
+        Pop the next unread stream chunk from the buffer.
+
+        Args:
+            selector: List of strings identifying the stream location
+
+        Returns:
+            The next event, or None if no unread events available
+        """
+        key = tuple(selector)
+
+        if key not in self._stream_buffers:
+            return None
+
+        position = self._stream_positions.get(key, 0)
+        buffer = self._stream_buffers[key]
+
+        if position >= len(buffer):
+            return None
+
+        event = buffer[position]
+        self._stream_positions[key] = position + 1
+        return event
+
+    def _has_unread_stream(self, selector: Sequence[str]) -> bool:
+        """
+        Check if the stream has unread events.
+
+        Args:
+            selector: List of strings identifying the stream location
+
+        Returns:
+            True if there are unread events, False otherwise
+        """
+        key = tuple(selector)
+
+        if key not in self._stream_buffers:
+            return False
+
+        position = self._stream_positions.get(key, 0)
+        return position < len(self._stream_buffers[key])
+
+    def _close_stream(self, selector: Sequence[str]) -> None:
+        """
+        Mark a stream as closed (no more chunks can be appended).
+
+        Args:
+            selector: List of strings identifying the stream location
+        """
+        key = tuple(selector)
+        self._closed_streams.add(key)
+
+    def _is_stream_closed(self, selector: Sequence[str]) -> bool:
+        """
+        Check if a stream is closed.
+
+        Args:
+            selector: List of strings identifying the stream location
+
+        Returns:
+            True if the stream is closed, False otherwise
+        """
+        key = tuple(selector)
+        return key in self._closed_streams
diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_output_registry.py b/api/tests/unit_tests/core/workflow/graph_engine/test_output_registry.py
deleted file mode 100644
index d27f610fe6..0000000000
--- a/api/tests/unit_tests/core/workflow/graph_engine/test_output_registry.py
+++ /dev/null
@@ -1,135 +0,0 @@
-from uuid import uuid4
-
-import pytest
-
-from core.workflow.entities.variable_pool import VariablePool
-from core.workflow.enums import NodeType
-from core.workflow.graph_engine.output_registry import OutputRegistry
-from core.workflow.graph_events import NodeRunStreamChunkEvent
-
-
-class TestOutputRegistry:
-    def test_scalar_operations(self):
-        variable_pool = VariablePool()
-        registry = OutputRegistry(variable_pool)
-
-        # Test setting and getting scalar
-        registry.set_scalar(["node1", "output"], "test_value")
-
-        segment = registry.get_scalar(["node1", "output"])
-        assert segment
-        assert segment.text == "test_value"
-
-        # Test getting non-existent scalar
-        assert registry.get_scalar(["non_existent"]) is None
-
-    def test_stream_operations(self):
-        variable_pool = VariablePool()
-        registry = OutputRegistry(variable_pool)
-
-        # Create test events
-        event1 = NodeRunStreamChunkEvent(
-            id=str(uuid4()),
-            node_id="node1",
-            node_type=NodeType.LLM,
-            selector=["node1", "stream"],
-            chunk="chunk1",
-            is_final=False,
-        )
-        event2 = NodeRunStreamChunkEvent(
-            id=str(uuid4()),
-            node_id="node1",
-            node_type=NodeType.LLM,
-            selector=["node1", "stream"],
-            chunk="chunk2",
-            is_final=True,
-        )
-
-        # Test appending events
-        registry.append_chunk(["node1", "stream"], event1)
-        registry.append_chunk(["node1", "stream"], event2)
-
-        # Test has_unread
-        assert registry.has_unread(["node1", "stream"]) is True
-
-        # Test popping events
-        popped_event1 = registry.pop_chunk(["node1", "stream"])
-        assert popped_event1 == event1
-        assert popped_event1.chunk == "chunk1"
-
-        popped_event2 = registry.pop_chunk(["node1", "stream"])
-        assert popped_event2 == event2
-        assert popped_event2.chunk == "chunk2"
-
-        assert registry.pop_chunk(["node1", "stream"]) is None
-
-        # Test has_unread after popping all
-        assert registry.has_unread(["node1", "stream"]) is False
-
-    def test_stream_closing(self):
-        variable_pool = VariablePool()
-        registry = OutputRegistry(variable_pool)
-
-        # Test stream is not closed initially
-        assert registry.stream_closed(["node1", "stream"]) is False
-
-        # Test closing stream
-        registry.close_stream(["node1", "stream"])
-        assert registry.stream_closed(["node1", "stream"]) is True
-
-        # Test appending to closed stream raises error
-        event = NodeRunStreamChunkEvent(
-            id=str(uuid4()),
-            node_id="node1",
-            node_type=NodeType.LLM,
-            selector=["node1", "stream"],
-            chunk="chunk",
-            is_final=False,
-        )
-        with pytest.raises(ValueError, match="Stream node1.stream is already closed"):
-            registry.append_chunk(["node1", "stream"], event)
-
-    def test_thread_safety(self):
-        import threading
-
-        variable_pool = VariablePool()
-        registry = OutputRegistry(variable_pool)
-        results = []
-
-        def append_chunks(thread_id: int):
-            for i in range(100):
-                event = NodeRunStreamChunkEvent(
-                    id=str(uuid4()),
-                    node_id="test_node",
-                    node_type=NodeType.LLM,
-                    selector=["stream"],
-                    chunk=f"thread{thread_id}_chunk{i}",
-                    is_final=False,
-                )
-                registry.append_chunk(["stream"], event)
-
-        # Start multiple threads
-        threads = []
-        for i in range(5):
-            thread = threading.Thread(target=append_chunks, args=(i,))
-            threads.append(thread)
-            thread.start()
-
-        # Wait for threads
-        for thread in threads:
-            thread.join()
-
-        # Verify all events are present
-        events = []
-        while True:
-            event = registry.pop_chunk(["stream"])
-            if event is None:
-                break
-            events.append(event)
-
-        assert len(events) == 500  # 5 threads * 100 events each
-        # Verify the events have the expected chunk content format
-        chunk_texts = [e.chunk for e in events]
-        for i in range(5):
-            for j in range(100):
-                assert f"thread{i}_chunk{j}" in chunk_texts
diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_response_coordinator.py b/api/tests/unit_tests/core/workflow/graph_engine/test_response_coordinator.py
deleted file mode 100644
index eadadfb8c8..0000000000
--- a/api/tests/unit_tests/core/workflow/graph_engine/test_response_coordinator.py
+++ /dev/null
@@ -1,347 +0,0 @@
-"""Test cases for ResponseStreamCoordinator."""
-
-from unittest.mock import Mock
-
-from core.variables import StringSegment
-from core.workflow.entities.variable_pool import VariablePool
-from core.workflow.enums import NodeState, NodeType
-from core.workflow.graph import Graph
-from core.workflow.graph_engine.output_registry import OutputRegistry
-from core.workflow.graph_engine.response_coordinator import ResponseStreamCoordinator
-from core.workflow.graph_engine.response_coordinator.session import ResponseSession
-from core.workflow.nodes.answer.answer_node import AnswerNode
-from core.workflow.nodes.base.node import Node
-from core.workflow.nodes.base.template import Template, TextSegment, VariableSegment
-
-
-class TestResponseStreamCoordinator:
-    """Test cases for ResponseStreamCoordinator."""
-
-    def test_skip_variable_segment_from_skipped_node(self):
-        """Test that VariableSegments from skipped nodes are properly skipped during try_flush."""
-        # Create mock graph
-        graph = Mock(spec=Graph)
-
-        # Create mock nodes
-        skipped_node = Mock(spec=Node)
-        skipped_node.id = "skipped_node"
-        skipped_node.state = NodeState.SKIPPED
-        skipped_node.node_type = NodeType.LLM
-
-        active_node = Mock(spec=Node)
-        active_node.id = "active_node"
-        active_node.state = NodeState.TAKEN
-        active_node.node_type = NodeType.LLM
-
-        response_node = Mock(spec=AnswerNode)
-        response_node.id = "response_node"
-        response_node.node_type = NodeType.ANSWER
-
-        # Set up graph nodes dictionary
-        graph.nodes = {"skipped_node": skipped_node, "active_node": active_node, "response_node": response_node}
-
-        # Create output registry with variable pool
-        variable_pool = VariablePool()
-        registry = OutputRegistry(variable_pool)
-
-        # Add some test data to registry for the active node
-        registry.set_scalar(("active_node", "output"), StringSegment(value="Active output"))
-
-        # Create RSC instance
-        rsc = ResponseStreamCoordinator(registry=registry, graph=graph)
-
-        # Create template with segments from both skipped and active nodes
-        template = Template(
-            segments=[
-                VariableSegment(selector=["skipped_node", "output"]),
-                TextSegment(text=" - "),
-                VariableSegment(selector=["active_node", "output"]),
-            ]
-        )
-
-        # Create and set active session
-        session = ResponseSession(node_id="response_node", template=template, index=0)
-        rsc.active_session = session
-
-        # Execute try_flush
-        events = rsc.try_flush()
-
-        # Verify that:
-        # 1. The skipped node's variable segment was skipped (index advanced)
-        # 2. The text segment was processed
-        # 3. The active node's variable segment was processed
-        assert len(events) == 2  # TextSegment + VariableSegment from active_node
-
-        # Check that the first event is the text segment
-        assert events[0].chunk == " - "
-
-        # Check that the second event is from the active node
-        assert events[1].chunk == "Active output"
-        assert events[1].selector == ["active_node", "output"]
-
-        # Session should be complete
-        assert session.is_complete()
-
-    def test_process_variable_segment_from_non_skipped_node(self):
-        """Test that VariableSegments from non-skipped nodes are processed normally."""
-        # Create mock graph
-        graph = Mock(spec=Graph)
-
-        # Create mock nodes
-        active_node1 = Mock(spec=Node)
-        active_node1.id = "node1"
-        active_node1.state = NodeState.TAKEN
-        active_node1.node_type = NodeType.LLM
-
-        active_node2 = Mock(spec=Node)
-        active_node2.id = "node2"
-        active_node2.state = NodeState.TAKEN
-        active_node2.node_type = NodeType.LLM
-
-        response_node = Mock(spec=AnswerNode)
-        response_node.id = "response_node"
-        response_node.node_type = NodeType.ANSWER
-
-        # Set up graph nodes dictionary
-        graph.nodes = {"node1": active_node1, "node2": active_node2, "response_node": response_node}
-
-        # Create output registry with variable pool
-        variable_pool = VariablePool()
-        registry = OutputRegistry(variable_pool)
-
-        # Add test data to registry
-        registry.set_scalar(("node1", "output"), StringSegment(value="Output 1"))
-        registry.set_scalar(("node2", "output"), StringSegment(value="Output 2"))
-
-        # Create RSC instance
-        rsc = ResponseStreamCoordinator(registry=registry, graph=graph)
-
-        # Create template with segments from active nodes
-        template = Template(
-            segments=[
-                VariableSegment(selector=["node1", "output"]),
-                TextSegment(text=" | "),
-                VariableSegment(selector=["node2", "output"]),
-            ]
-        )
-
-        # Create and set active session
-        session = ResponseSession(node_id="response_node", template=template, index=0)
-        rsc.active_session = session
-
-        # Execute try_flush
-        events = rsc.try_flush()
-
-        # Verify all segments were processed
-        assert len(events) == 3
-
-        # Check events in order
-        assert events[0].chunk == "Output 1"
-        assert events[0].selector == ["node1", "output"]
-
-        assert events[1].chunk == " | "
-
-        assert events[2].chunk == "Output 2"
-        assert events[2].selector == ["node2", "output"]
-
-        # Session should be complete
-        assert session.is_complete()
-
-    def test_mixed_skipped_and_active_nodes(self):
-        """Test processing with a mix of skipped and active nodes."""
-        # Create mock graph
-        graph = Mock(spec=Graph)
-
-        # Create mock nodes with various states
-        skipped_node1 = Mock(spec=Node)
-        skipped_node1.id = "skip1"
-        skipped_node1.state = NodeState.SKIPPED
-        skipped_node1.node_type = NodeType.LLM
-
-        active_node = Mock(spec=Node)
-        active_node.id = "active"
-        active_node.state = NodeState.TAKEN
-        active_node.node_type = NodeType.LLM
-
-        skipped_node2 = Mock(spec=Node)
-        skipped_node2.id = "skip2"
-        skipped_node2.state = NodeState.SKIPPED
-        skipped_node2.node_type = NodeType.LLM
-
-        response_node = Mock(spec=AnswerNode)
-        response_node.id = "response_node"
-        response_node.node_type = NodeType.ANSWER
-
-        # Set up graph nodes dictionary
-        graph.nodes = {
-            "skip1": skipped_node1,
-            "active": active_node,
-            "skip2": skipped_node2,
-            "response_node": response_node,
-        }
-
-        # Create output registry with variable pool
-        variable_pool = VariablePool()
-        registry = OutputRegistry(variable_pool)
-
-        # Add data only for active node
-        registry.set_scalar(("active", "result"), StringSegment(value="Active Result"))
-
-        # Create RSC instance
-        rsc = ResponseStreamCoordinator(registry=registry, graph=graph)
-
-        # Create template with mixed segments
-        template = Template(
-            segments=[
-                TextSegment(text="Start: "),
-                VariableSegment(selector=["skip1", "output"]),
-                VariableSegment(selector=["active", "result"]),
-                VariableSegment(selector=["skip2", "output"]),
-                TextSegment(text=" :End"),
-            ]
-        )
-
-        # Create and set active session
-        session = ResponseSession(node_id="response_node", template=template, index=0)
-        rsc.active_session = session
-
-        # Execute try_flush
-        events = rsc.try_flush()
-
-        # Should have: "Start: ", "Active Result", " :End"
-        assert len(events) == 3
-
-        assert events[0].chunk == "Start: "
-        assert events[1].chunk == "Active Result"
-        assert events[1].selector == ["active", "result"]
-        assert events[2].chunk == " :End"
-
-        # Session should be complete
-        assert session.is_complete()
-
-    def test_all_variable_segments_skipped(self):
-        """Test when all VariableSegments are from skipped nodes."""
-        # Create mock graph
-        graph = Mock(spec=Graph)
-
-        # Create all skipped nodes
-        skipped_node1 = Mock(spec=Node)
-        skipped_node1.id = "skip1"
-        skipped_node1.state = NodeState.SKIPPED
-        skipped_node1.node_type = NodeType.LLM
-
-        skipped_node2 = Mock(spec=Node)
-        skipped_node2.id = "skip2"
-        skipped_node2.state = NodeState.SKIPPED
-        skipped_node2.node_type = NodeType.LLM
-
-        response_node = Mock(spec=AnswerNode)
-        response_node.id = "response_node"
-        response_node.node_type = NodeType.ANSWER
-
-        # Set up graph nodes dictionary
-        graph.nodes = {"skip1": skipped_node1, "skip2": skipped_node2, "response_node": response_node}
-
-        # Create output registry (empty since nodes are skipped) with variable pool
-        variable_pool = VariablePool()
-        registry = OutputRegistry(variable_pool)
-
-        # Create RSC instance
-        rsc = ResponseStreamCoordinator(registry=registry, graph=graph)
-
-        # Create template with only skipped segments
-        template = Template(
-            segments=[
-                VariableSegment(selector=["skip1", "output"]),
-                VariableSegment(selector=["skip2", "output"]),
-                TextSegment(text="Final text"),
-            ]
-        )
-
-        # Create and set active session
-        session = ResponseSession(node_id="response_node", template=template, index=0)
-        rsc.active_session = session
-
-        # Execute try_flush
-        events = rsc.try_flush()
-
-        # Should only have the final text segment
-        assert len(events) == 1
-        assert events[0].chunk == "Final text"
-
-        # Session should be complete
-        assert session.is_complete()
-
-    def test_special_prefix_selectors(self):
-        """Test that special prefix selectors (sys, env, conversation) are handled correctly."""
-        # Create mock graph
-        graph = Mock(spec=Graph)
-
-        # Create response node
-        response_node = Mock(spec=AnswerNode)
-        response_node.id = "response_node"
-        response_node.node_type = NodeType.ANSWER
-
-        # Set up graph nodes dictionary (no sys, env, conversation nodes)
-        graph.nodes = {"response_node": response_node}
-
-        # Create output registry with special selector data and variable pool
-        variable_pool = VariablePool()
-        registry = OutputRegistry(variable_pool)
-        registry.set_scalar(("sys", "user_id"), StringSegment(value="user123"))
-        registry.set_scalar(("env", "api_key"), StringSegment(value="key456"))
-        registry.set_scalar(("conversation", "id"), StringSegment(value="conv789"))
-
-        # Create RSC instance
-        rsc = ResponseStreamCoordinator(registry=registry, graph=graph)
-
-        # Create template with special selectors
-        template = Template(
-            segments=[
-                TextSegment(text="User: "),
-                VariableSegment(selector=["sys", "user_id"]),
-                TextSegment(text=", API: "),
-                VariableSegment(selector=["env", "api_key"]),
-                TextSegment(text=", Conv: "),
-                VariableSegment(selector=["conversation", "id"]),
-            ]
-        )
-
-        # Create and set active session
-        session = ResponseSession(node_id="response_node", template=template, index=0)
-        rsc.active_session = session
-
-        # Execute try_flush
-        events = rsc.try_flush()
-
-        # Should have all segments processed
-        assert len(events) == 6
-
-        # Check text segments
-        assert events[0].chunk == "User: "
-        assert events[0].node_id == "response_node"
-
-        # Check sys selector - should use response node's info
-        assert events[1].chunk == "user123"
-        assert events[1].selector == ["sys", "user_id"]
-        assert events[1].node_id == "response_node"
-        assert events[1].node_type == NodeType.ANSWER
-
-        assert events[2].chunk == ", API: "
-
-        # Check env selector - should use response node's info
-        assert events[3].chunk == "key456"
-        assert events[3].selector == ["env", "api_key"]
-        assert events[3].node_id == "response_node"
-        assert events[3].node_type == NodeType.ANSWER
-
-        assert events[4].chunk == ", Conv: "
-
-        # Check conversation selector - should use response node's info
-        assert events[5].chunk == "conv789"
-        assert events[5].selector == ["conversation", "id"]
-        assert events[5].node_id == "response_node"
-        assert events[5].node_type == NodeType.ANSWER
-
-        # Session should be complete
-        assert session.is_complete()