Fix(workflow): Prevent token overcount caused by loop/iteration (#28406)

This commit is contained in:
Jax 2025-11-25 09:56:59 +08:00 committed by GitHub
parent 6bd114285c
commit eed38c8b2a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 2 additions and 12 deletions

View File

@@ -237,8 +237,7 @@ class IterationNode(LLMUsageTrackingMixin, Node):
)
)
# Update the total tokens from this iteration
self.graph_runtime_state.total_tokens += graph_engine.graph_runtime_state.total_tokens
# Accumulate usage from this iteration
usage_accumulator[0] = self._merge_usage(
usage_accumulator[0], graph_engine.graph_runtime_state.llm_usage
)
@@ -265,7 +264,6 @@ class IterationNode(LLMUsageTrackingMixin, Node):
datetime,
list[GraphNodeEventBase],
object | None,
int,
dict[str, VariableUnion],
LLMUsage,
]
@@ -292,7 +290,6 @@ class IterationNode(LLMUsageTrackingMixin, Node):
iter_start_at,
events,
output_value,
tokens_used,
conversation_snapshot,
iteration_usage,
) = result
@@ -304,7 +301,6 @@ class IterationNode(LLMUsageTrackingMixin, Node):
yield from events
# Update tokens and timing
self.graph_runtime_state.total_tokens += tokens_used
iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds()
usage_accumulator[0] = self._merge_usage(usage_accumulator[0], iteration_usage)
@@ -336,7 +332,7 @@ class IterationNode(LLMUsageTrackingMixin, Node):
item: object,
flask_app: Flask,
context_vars: contextvars.Context,
) -> tuple[datetime, list[GraphNodeEventBase], object | None, int, dict[str, VariableUnion], LLMUsage]:
) -> tuple[datetime, list[GraphNodeEventBase], object | None, dict[str, VariableUnion], LLMUsage]:
"""Execute a single iteration in parallel mode and return results."""
with preserve_flask_contexts(flask_app=flask_app, context_vars=context_vars):
iter_start_at = datetime.now(UTC).replace(tzinfo=None)
@@ -363,7 +359,6 @@ class IterationNode(LLMUsageTrackingMixin, Node):
iter_start_at,
events,
output_value,
graph_engine.graph_runtime_state.total_tokens,
conversation_snapshot,
graph_engine.graph_runtime_state.llm_usage,
)

View File

@@ -140,7 +140,6 @@ class LoopNode(LLMUsageTrackingMixin, Node):
if reach_break_condition:
loop_count = 0
cost_tokens = 0
for i in range(loop_count):
graph_engine = self._create_graph_engine(start_at=start_at, root_node_id=root_node_id)
@@ -163,9 +162,6 @@ class LoopNode(LLMUsageTrackingMixin, Node):
# For other outputs, just update
self.graph_runtime_state.set_output(key, value)
# Update the total tokens from this iteration
cost_tokens += graph_engine.graph_runtime_state.total_tokens
# Accumulate usage from the sub-graph execution
loop_usage = self._merge_usage(loop_usage, graph_engine.graph_runtime_state.llm_usage)
@@ -194,7 +190,6 @@ class LoopNode(LLMUsageTrackingMixin, Node):
pre_loop_output=self._node_data.outputs,
)
self.graph_runtime_state.total_tokens += cost_tokens
self._accumulate_usage(loop_usage)
# Loop completed successfully
yield LoopSucceededEvent(