diff --git a/api/core/app/apps/agent_app/app_runner.py b/api/core/app/apps/agent_app/app_runner.py index 03c6f3e410c..f3f809e4b43 100644 --- a/api/core/app/apps/agent_app/app_runner.py +++ b/api/core/app/apps/agent_app/app_runner.py @@ -72,13 +72,48 @@ def publish_text_answer( both the backend-produced answer and short-circuited answers (moderation / annotation reply) share the exact same persistence + SSE path. """ + publish_text_delta( + queue_manager=queue_manager, + model_name=model_name, + delta=answer, + user_query=user_query, + ) + publish_message_end( + queue_manager=queue_manager, + model_name=model_name, + answer=answer, + user_query=user_query, + ) + + +def publish_text_delta( + *, + queue_manager: AppQueueManager, + model_name: str, + delta: str, + user_query: str | None = None, +) -> None: + """Publish one assistant text delta through the EasyUI chat pipeline.""" + if not delta: + return prompt_messages = _prompt_messages_from_query(user_query) chunk = LLMResultChunk( model=model_name, prompt_messages=prompt_messages, - delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=answer)), + delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=delta)), ) queue_manager.publish(QueueLLMChunkEvent(chunk=chunk), PublishFrom.APPLICATION_MANAGER) + + +def publish_message_end( + *, + queue_manager: AppQueueManager, + model_name: str, + answer: str, + user_query: str | None = None, +) -> None: + """Publish the terminal assistant result without emitting another delta.""" + prompt_messages = _prompt_messages_from_query(user_query) queue_manager.publish( QueueMessageEndEvent( llm_result=LLMResult( @@ -151,7 +186,12 @@ class AgentAppRunner: ) create_response = self._agent_backend_client.create_run(runtime.request) - terminal = self._consume_stream(create_response.run_id, queue_manager=queue_manager) + terminal, streamed_answer = self._consume_stream( + create_response.run_id, + queue_manager=queue_manager, + model_name=model_name, + query=query, + ) if isinstance(terminal, AgentBackendDeferredToolCallInternalEvent): # ENG-635: the agent asked a human. End this turn with the question and @@ -175,7 +215,13 @@ class AgentAppRunner: raise AgentBackendError(str(error)) answer = self._extract_answer(terminal.output) - self._publish_answer(queue_manager=queue_manager, model_name=model_name, answer=answer, query=query) + self._publish_terminal_answer( + queue_manager=queue_manager, + model_name=model_name, + answer=answer, + query=query, + streamed_answer=streamed_answer, + ) self._save_session( scope=scope, backend_run_id=terminal.run_id, @@ -272,8 +318,16 @@ class AgentAppRunner: parts.append(args.markdown) return "\n\n".join(parts) - def _consume_stream(self, run_id: str, *, queue_manager: AppQueueManager): + def _consume_stream( + self, + run_id: str, + *, + queue_manager: AppQueueManager, + model_name: str, + query: str | None, + ): terminal = None + streamed_answer_parts: list[str] = [] for public_event in self._agent_backend_client.stream_events(run_id): if queue_manager.is_stopped(): self._cancel_run(run_id) @@ -286,16 +340,23 @@ class AgentAppRunner: AgentBackendInternalEventType.RUN_STARTED, AgentBackendInternalEventType.STREAM_EVENT, ): - # Stream deltas are accumulated by the backend into the - # terminal output; token-level forwarding is an S3 refinement. if isinstance(internal_event, AgentBackendStreamInternalEvent): + text_delta = self._extract_stream_text_delta(internal_event) + if text_delta: + streamed_answer_parts.append(text_delta) + publish_text_delta( + queue_manager=queue_manager, + model_name=model_name, + delta=text_delta, + user_query=query, + ) continue continue terminal = internal_event break if terminal is not None: break - return terminal + return terminal, "".join(streamed_answer_parts) def _cancel_run(self, run_id: str) -> None: try: @@ -310,6 +371,35 @@ class AgentAppRunner: # task pipeline streams the chunk over SSE and persists the message. publish_text_answer(queue_manager=queue_manager, model_name=model_name, answer=answer, user_query=query) + def _publish_terminal_answer( + self, + *, + queue_manager: AppQueueManager, + model_name: str, + answer: str, + query: str | None, + streamed_answer: str, + ) -> None: + """Finish a successful streamed turn without duplicating the final text.""" + if not streamed_answer: + self._publish_answer(queue_manager=queue_manager, model_name=model_name, answer=answer, query=query) + return + + if answer.startswith(streamed_answer): + publish_text_delta( + queue_manager=queue_manager, + model_name=model_name, + delta=answer[len(streamed_answer) :], + user_query=query, + ) + elif answer != streamed_answer: + logger.warning( + "Agent App streamed answer does not match terminal output; " + "using terminal output for message persistence." + ) + + publish_message_end(queue_manager=queue_manager, model_name=model_name, answer=answer, user_query=query) + def _save_session( self, *, @@ -357,5 +447,27 @@ class AgentAppRunner: return json.dumps(output, ensure_ascii=False) return json.dumps(output, ensure_ascii=False) + @staticmethod + def _extract_stream_text_delta(event: AgentBackendStreamInternalEvent) -> str | None: + data = event.data + if not isinstance(data, dict): + return None -__all__ = ["AgentAppRunner", "publish_text_answer"] + if data.get("event_kind") == "part_delta": + delta = data.get("delta") + if isinstance(delta, dict) and delta.get("part_delta_kind") == "text": + content_delta = delta.get("content_delta") + if isinstance(content_delta, str): + return content_delta + + if data.get("event_kind") == "part_start": + part = data.get("part") + if isinstance(part, dict) and part.get("part_kind") == "text": + content = part.get("content") + if isinstance(content, str): + return content + + return None + + +__all__ = ["AgentAppRunner", "publish_message_end", "publish_text_answer", "publish_text_delta"] diff --git a/api/tests/unit_tests/core/app/apps/agent_app/test_app_runner.py b/api/tests/unit_tests/core/app/apps/agent_app/test_app_runner.py index e696d4aaa0b..13f11d7a1c3 100644 --- a/api/tests/unit_tests/core/app/apps/agent_app/test_app_runner.py +++ b/api/tests/unit_tests/core/app/apps/agent_app/test_app_runner.py @@ -4,6 +4,8 @@ saved, using the deterministic fake backend client (no live stack).""" from __future__ import annotations +from collections.abc import Iterator +from datetime import UTC, datetime from types import SimpleNamespace from typing import Any, override from unittest.mock import MagicMock @@ -11,7 +13,17 @@ from unittest.mock import MagicMock import pytest from agenton.compositor import CompositorSessionSnapshot from dify_agent.layers.ask_human import AskHumanToolResult -from dify_agent.protocol import CancelRunRequest, CancelRunResponse, RuntimeLayerSpec +from dify_agent.protocol import ( + CancelRunRequest, + CancelRunResponse, + PydanticAIStreamRunEvent, + RunEvent, + RunStartedEvent, + RunSucceededEvent, + RunSucceededEventData, + RuntimeLayerSpec, +) +from pydantic_ai.messages import PartDeltaEvent, PartStartEvent, TextPart, TextPartDelta from clients.agent_backend import ( AgentBackendError, @@ -67,6 +79,58 @@ class _RecordingFakeAgentBackendRunClient(FakeAgentBackendRunClient): return super().cancel_run(run_id, request=request) +class _StreamingFakeAgentBackendRunClient(FakeAgentBackendRunClient): + @override + def stream_events(self, run_id: str, *, after: str | None = None) -> Iterator[RunEvent]: + del after + created_at = datetime(2026, 1, 1, tzinfo=UTC) + yield RunStartedEvent(id="1-0", run_id=run_id, created_at=created_at) + yield PydanticAIStreamRunEvent( + id="2-0", + run_id=run_id, + created_at=created_at, + data=PartDeltaEvent(index=0, delta=TextPartDelta(content_delta="hello ")), + ) + yield PydanticAIStreamRunEvent( + id="3-0", + run_id=run_id, + created_at=created_at, + data=PartDeltaEvent(index=0, delta=TextPartDelta(content_delta="agent")), + ) + yield RunSucceededEvent( + id="4-0", + run_id=run_id, + created_at=created_at, + data=RunSucceededEventData( + output={"text": "hello agent"}, + session_snapshot=CompositorSessionSnapshot(layers=[]), + ), + ) + + +class _StreamingPartStartFakeAgentBackendRunClient(FakeAgentBackendRunClient): + @override + def stream_events(self, run_id: str, *, after: str | None = None) -> Iterator[RunEvent]: + del after + created_at = datetime(2026, 1, 1, tzinfo=UTC) + yield RunStartedEvent(id="1-0", run_id=run_id, created_at=created_at) + yield PydanticAIStreamRunEvent( + id="2-0", + run_id=run_id, + created_at=created_at, + data=PartStartEvent(index=0, part=TextPart(content="hello")), + ) + yield RunSucceededEvent( + id="3-0", + run_id=run_id, + created_at=created_at, + data=RunSucceededEventData( + output={"text": "hello agent"}, + session_snapshot=CompositorSessionSnapshot(layers=[]), + ), + ) + + class _FakeSessionStore: def __init__( self, @@ -165,9 +229,13 @@ def _message_end(qm: _FakeQueueManager) -> QueueMessageEndEvent: def _saved_user_query(qm: _FakeQueueManager) -> str: - prompt_messages = _message_end(qm).llm_result.prompt_messages + llm_result = _message_end(qm).llm_result + assert llm_result is not None + prompt_messages = llm_result.prompt_messages assert len(prompt_messages) == 1 - return prompt_messages[0].content + content = prompt_messages[0].content + assert isinstance(content, str) + return content def test_successful_turn_publishes_chunk_and_message_end_and_saves_session(): @@ -204,6 +272,35 @@ def test_successful_turn_publishes_chunk_and_message_end_and_saves_session(): ] +def test_successful_turn_forwards_agent_backend_stream_text_deltas_without_duplicate_terminal_chunk(): + client = _StreamingFakeAgentBackendRunClient() + store = _FakeSessionStore() + qm = _FakeQueueManager() + + _run(_runner(client, store), qm) + + chunk_events = [e for e in qm.events if isinstance(e, QueueLLMChunkEvent)] + end_events = [e for e in qm.events if isinstance(e, QueueMessageEndEvent)] + assert [event.chunk.delta.message.content for event in chunk_events] == ["hello ", "agent"] + assert len(end_events) == 1 + assert end_events[0].llm_result.message.content == "hello agent" + assert store.saved + + +def test_successful_turn_forwards_part_start_text_and_publishes_missing_terminal_suffix(): + client = _StreamingPartStartFakeAgentBackendRunClient() + store = _FakeSessionStore() + qm = _FakeQueueManager() + + _run(_runner(client, store), qm) + + chunk_events = [e for e in qm.events if isinstance(e, QueueLLMChunkEvent)] + end_events = [e for e in qm.events if isinstance(e, QueueMessageEndEvent)] + assert [event.chunk.delta.message.content for event in chunk_events] == ["hello", " agent"] + assert len(end_events) == 1 + assert end_events[0].llm_result.message.content == "hello agent" + + def test_prior_session_snapshot_is_threaded_into_request(): prior = CompositorSessionSnapshot(layers=[]) client = FakeAgentBackendRunClient()