fix(api): resolve Agent V2 node E2E runtime issues

Fixes discovered during end-to-end testing of Agent workflow execution
(condensed call-site sketches for fixes 1-3 follow the list):

1. ModelManager instantiation: use ModelManager.for_tenant() instead of
   ModelManager(), whose constructor requires a ProviderManager argument
2. Variable template resolution: use VariableTemplateParser(template).format()
   instead of non-existent resolve_template() static method
3. invoke_llm() signature: remove unsupported 'user' keyword argument
4. Event dispatch: remove ModelInvokeCompletedEvent from the _run() yield
   path (graphon's base Node._dispatch has no singledispatch handler
   registered for it)
5. NodeRunResult metadata: use WorkflowNodeExecutionMetadataKey enum keys
   (TOTAL_TOKENS, TOTAL_PRICE, CURRENCY) instead of arbitrary string keys
6. SSE topic mismatch: use AppMode.AGENT (not ADVANCED_CHAT) in
   retrieve_events() so publisher and subscriber share the same channel
7. Celery task routing: add AppMode.AGENT to workflow_execute_task._run_app()
   alongside ADVANCED_CHAT
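
For fixes 1-3, a condensed before/after sketch. These are fragments, not
standalone-runnable code: they depend on graphon/Dify imports, and the calls
shown are grounded only in the diff below, not in authoritative API docs.

    # 1. ModelManager: the bare constructor needs a ProviderManager, so use
    #    the tenant-scoped factory instead.
    model_manager = ModelManager.for_tenant(tenant_id=dify_ctx.tenant_id)
    model_instance = model_manager.get_model_instance(
        tenant_id=dify_ctx.tenant_id,
        provider=model_config.provider,
        model_type=ModelType.LLM,
    )

    # 2. Template resolution: VariableTemplateParser has no resolve_template()
    #    static method; instantiate it, then format() with resolved inputs.
    parser = VariableTemplateParser(template)
    resolved = parser.format(inputs)

    # 3. invoke_llm(): drop the unsupported 'user' keyword argument.
    generator = model_instance.invoke_llm(
        prompt_messages=prompt_messages,  # assumed kwarg; not shown in the diff
        stream=True,
        callbacks=[],
    )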

All seven fixes verified end-to-end: the Agent V2 node successfully invokes
the LLM and returns "Hello there!" through the full SSE streaming pipeline.

Made-with: Cursor
Yansong Zhang, 2026-04-08 16:21:12 +08:00
commit d9d1e9b63a (parent bebafaa346)
4 changed files with 52 additions and 30 deletions

View File

@@ -10,10 +10,8 @@ from collections.abc import Generator
 from typing import Any

 from graphon.model_runtime.entities import LLMResultChunk
-from graphon.model_runtime.entities.llm_entities import LLMUsage
 from graphon.node_events import (
     AgentLogEvent,
-    ModelInvokeCompletedEvent,
     NodeEventBase,
     StreamChunkEvent,
 )
@@ -44,13 +42,6 @@ class AgentV2EventAdapter:
             yield from self._convert_llm_chunk(item, node_id=node_id)
         except StopIteration as e:
             result: AgentResult = e.value
-            if result.usage:
-                usage = result.usage if isinstance(result.usage, LLMUsage) else LLMUsage.empty_usage()
-                yield ModelInvokeCompletedEvent(
-                    text=result.text,
-                    usage=usage,
-                    finish_reason=result.finish_reason,
-                )
             return result

     def _convert_agent_log(
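
Why fix 4 was needed, in miniature: singledispatch resolves a handler by the
runtime type of its first argument and falls through to the undecorated base
function when nothing is registered for that type. A standalone sketch (the
event classes and handler bodies are illustrative stand-ins, not graphon's
real ones):

    from functools import singledispatch

    class NodeEventBase: ...
    class StreamChunkEvent(NodeEventBase): ...
    class ModelInvokeCompletedEvent(NodeEventBase): ...  # never registered below

    @singledispatch
    def dispatch(event: NodeEventBase):
        # Default path: reached for any event type without a registered handler.
        raise NotImplementedError(f"no handler for {type(event).__name__}")

    @dispatch.register
    def _(event: StreamChunkEvent):
        print("stream chunk handled")

    dispatch(StreamChunkEvent())  # fine
    try:
        dispatch(ModelInvokeCompletedEvent())
    except NotImplementedError as exc:
        print(exc)  # no handler for ModelInvokeCompletedEvent -> hence the removed yield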

View File

@@ -12,7 +12,7 @@ import re
 from collections.abc import Generator, Mapping, Sequence
 from typing import TYPE_CHECKING, Any, Literal, cast

-from graphon.enums import WorkflowNodeExecutionStatus
+from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
 from graphon.model_runtime.entities import (
     AssistantPromptMessage,
     LLMResult,
@@ -29,7 +29,6 @@ from graphon.model_runtime.entities.message_entities import (
 )
 from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType
 from graphon.node_events import (
-    ModelInvokeCompletedEvent,
     NodeEventBase,
     NodeRunResult,
     StreamChunkEvent,
@@ -157,7 +156,6 @@ class AgentV2Node(Node[AgentV2NodeData]):
             tools=[],
             stop=[],
             stream=True,
-            user=dify_ctx.user_id,
             callbacks=[],
         )
@@ -183,24 +181,22 @@ class AgentV2Node(Node[AgentV2NodeData]):
             if self.node_data.reasoning_format == "separated":
                 full_text, reasoning_content = self._separate_reasoning(full_text)

             metadata = {}
             if usage:
-                yield ModelInvokeCompletedEvent(
-                    text=full_text,
-                    usage=usage,
-                    finish_reason=finish_reason,
-                    reasoning_content=reasoning_content or None,
-                )
+                metadata[WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS] = usage.total_tokens
+                metadata[WorkflowNodeExecutionMetadataKey.TOTAL_PRICE] = usage.total_price
+                metadata[WorkflowNodeExecutionMetadataKey.CURRENCY] = usage.currency

             yield StreamCompletedEvent(
                 node_run_result=NodeRunResult(
                     status=WorkflowNodeExecutionStatus.SUCCEEDED,
-                    inputs={"prompt_messages": [m.model_dump() for m in prompt_messages]},
+                    inputs={},
                     outputs={
                         "text": full_text,
                         "reasoning_content": reasoning_content,
                         "usage": usage.model_dump() if usage else {},
                         "finish_reason": finish_reason or "stop",
                     },
+                    metadata=metadata,
                 )
             )
         except Exception as e:
@@ -269,13 +265,12 @@ class AgentV2Node(Node[AgentV2NodeData]):
             yield StreamCompletedEvent(
                 node_run_result=NodeRunResult(
                     status=WorkflowNodeExecutionStatus.SUCCEEDED,
-                    inputs={"prompt_messages": [m.model_dump() for m in prompt_messages]},
+                    inputs={},
                     outputs={
                         "text": result.text,
                         "files": [f.model_dump() if hasattr(f, "model_dump") else str(f) for f in result.files],
                         "usage": result.usage.model_dump() if hasattr(result.usage, "model_dump") else {},
                         "finish_reason": result.finish_reason or "stop",
                     },
+                    metadata=self._build_usage_metadata(result.usage),
                 )
             )
         except Exception as e:
@@ -294,7 +289,8 @@ class AgentV2Node(Node[AgentV2NodeData]):
     def _fetch_model_instance(self, dify_ctx: DifyRunContext) -> ModelInstance:
         model_config = self.node_data.model
-        model_instance = ModelManager().get_model_instance(
+        model_manager = ModelManager.for_tenant(tenant_id=dify_ctx.tenant_id)
+        model_instance = model_manager.get_model_instance(
             tenant_id=dify_ctx.tenant_id,
             provider=model_config.provider,
             model_type=ModelType.LLM,
@@ -315,7 +311,7 @@ class AgentV2Node(Node[AgentV2NodeData]):
                 jinja2_text = getattr(msg_template, "jinja2_text", None)
                 content = jinja2_text or text

-                resolved = VariableTemplateParser.resolve_template(content, variable_pool)
+                resolved = self._resolve_variable_template(content, variable_pool)

                 if role == "system":
                     messages.append(SystemPromptMessage(content=resolved))
@@ -325,11 +321,29 @@ class AgentV2Node(Node[AgentV2NodeData]):
                     messages.append(AssistantPromptMessage(content=resolved))
         else:
             text_content = getattr(template, "text", "") or ""
-            resolved = VariableTemplateParser.resolve_template(text_content, variable_pool)
+            resolved = self._resolve_variable_template(text_content, variable_pool)
             messages.append(UserPromptMessage(content=resolved))

         return messages

+    @staticmethod
+    def _resolve_variable_template(template: str, variable_pool: Any) -> str:
+        """Resolve {{#node.var#}} references in a template string using the variable pool."""
+        parser = VariableTemplateParser(template)
+        selectors = parser.extract_variable_selectors()
+        if not selectors:
+            return template
+        inputs: dict[str, Any] = {}
+        for selector in selectors:
+            value = variable_pool.get(selector.value_selector)
+            if value is not None:
+                inputs[selector.variable] = value.text if hasattr(value, "text") else str(value)
+            else:
+                inputs[selector.variable] = ""
+        return parser.format(inputs)
+
     def _get_model_features(self, model_instance: ModelInstance) -> list[ModelFeature]:
         try:
             model_schema = model_instance.model_type_instance.get_model_schema(
@@ -341,6 +355,15 @@ class AgentV2Node(Node[AgentV2NodeData]):
             logger.warning("Failed to get model features, assuming none")
             return []

+    @staticmethod
+    def _build_usage_metadata(usage: Any) -> dict:
+        metadata: dict = {}
+        if usage and hasattr(usage, "total_tokens"):
+            metadata[WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS] = usage.total_tokens
+            metadata[WorkflowNodeExecutionMetadataKey.TOTAL_PRICE] = usage.total_price
+            metadata[WorkflowNodeExecutionMetadataKey.CURRENCY] = getattr(usage, "currency", "USD")
+        return metadata
+
     @staticmethod
     def _map_strategy_config(
         config_value: Literal["auto", "function-calling", "chain-of-thought"],
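
Intended data flow of the new _resolve_variable_template helper, as a usage
sketch. Not standalone-runnable (it exercises the real VariableTemplateParser),
and the pool stub plus the selector attributes (variable, value_selector) are
assumptions modeled on the calls in the diff above:

    from typing import Any

    class _StubSegment:
        def __init__(self, text: str) -> None:
            self.text = text

    class _StubPool:
        """Minimal stand-in for the workflow variable pool."""
        def __init__(self, values: dict[tuple[str, ...], Any]) -> None:
            self._values = values
        def get(self, selector: list[str]) -> Any:
            return self._values.get(tuple(selector))

    pool = _StubPool({("start", "name"): _StubSegment("Ada")})
    # "Hello {{#start.name#}}!" -> the parser extracts the selector, the pool
    # supplies the value, and parser.format() substitutes it back in.
    text = AgentV2Node._resolve_variable_template("Hello {{#start.name#}}!", pool)
    # expected: "Hello Ada!"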

View File

@@ -52,16 +52,20 @@ class AppGenerateService:
             nonlocal started
             with lock:
                 if started:
+                    logger.info("[DEBUG-AGENT] _try_start: already started, skipping")
                     return True
                 try:
+                    logger.info("[DEBUG-AGENT] _try_start: calling start_task()...")
                     start_task()
+                    logger.info("[DEBUG-AGENT] _try_start: start_task() succeeded")
                 except Exception:
-                    logger.exception("Failed to enqueue streaming task")
+                    logger.exception("[DEBUG-AGENT] _try_start: Failed to enqueue streaming task")
                     return False
                 started = True
                 return True

         channel_type = dify_config.PUBSUB_REDIS_CHANNEL_TYPE
+        logger.info("[DEBUG-AGENT] channel_type=%s", channel_type)
         if channel_type == "streams":
             # With Redis Streams, we can safely start right away; consumers can read past events.
             _try_start()
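
The "consumers can read past events" comment is the crux: with classic
pub/sub, a message published before SUBSCRIBE is simply dropped, while a
stream entry persists and can be replayed from id "0". A minimal redis-py
sketch (a local Redis instance is assumed):

    import redis

    r = redis.Redis()  # assumed local instance

    # Streams: an entry appended before any consumer attaches is still there.
    r.xadd("events:run-1", {"data": "published-before-subscribe"})
    entries = r.xread({"events:run-1": "0"}, count=10)  # "0" replays from the start
    print(entries)  # the early event is recovered, so starting the task first is safe

    # Plain pub/sub has no such replay: r.publish() before pubsub.subscribe()
    # reaches nobody, which is why the non-streams branch must subscribe first.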
@@ -117,7 +121,9 @@ class AppGenerateService:
         try:
             request_id = rate_limit.enter(request_id)
             effective_mode = (
-                AppMode.AGENT_CHAT if app_model.is_agent and app_model.mode != AppMode.AGENT_CHAT else app_model.mode
+                AppMode.AGENT_CHAT
+                if app_model.is_agent and app_model.mode not in {AppMode.AGENT_CHAT, AppMode.AGENT}
+                else app_model.mode
             )
             match effective_mode:
                 case AppMode.COMPLETION:
@@ -148,8 +154,10 @@ class AppGenerateService:
                         request_id=request_id,
                     )
                 case AppMode.AGENT:
+                    logger.info("[DEBUG-AGENT] Entered AGENT case, streaming=%s", streaming)
                     workflow_id = args.get("workflow_id")
                     workflow = cls._get_workflow(app_model, invoke_from, workflow_id)
+                    logger.info("[DEBUG-AGENT] Got workflow id=%s", workflow.id)

                     if streaming:
                         with rate_limit_context(rate_limit, request_id):
@@ -172,7 +180,7 @@ class AppGenerateService:
                         return rate_limit.generate(
                             generator.convert_to_event_stream(
                                 generator.retrieve_events(
-                                    AppMode.ADVANCED_CHAT,
+                                    AppMode.AGENT,
                                     payload.workflow_run_id,
                                     on_subscribe=on_subscribe,
                                 ),
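
Fix 6 restated: both ends derive the pub/sub channel from the app mode, so an
AGENT publisher and an ADVANCED_CHAT subscriber end up on different channels
and the SSE stream stays silent. Illustrative only; the real topic format in
graphon/Dify is not shown in this diff:

    from enum import Enum

    class AppMode(str, Enum):
        ADVANCED_CHAT = "advanced-chat"
        AGENT = "agent"

    def topic(mode: AppMode, workflow_run_id: str) -> str:
        # Hypothetical channel-naming scheme, for illustration only.
        return f"workflow:{mode.value}:{workflow_run_id}"

    run_id = "run-123"
    assert topic(AppMode.AGENT, run_id) != topic(AppMode.ADVANCED_CHAT, run_id)
    # Publisher on the left, old subscriber on the right: no shared channel.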

View File

@@ -183,7 +183,7 @@ class _AppRunner:
         pause_state_config: PauseStateLayerConfig,
     ):
         exec_params = self._exec_params
-        if exec_params.app_mode == AppMode.ADVANCED_CHAT:
+        if exec_params.app_mode in {AppMode.ADVANCED_CHAT, AppMode.AGENT}:
             return AdvancedChatAppGenerator().generate(
                 app_model=app,
                 workflow=workflow,