mirror of https://github.com/langgenius/dify.git
Fix(workflow): Prevent token overcount caused by loop/iteration (#28406)
This commit is contained in:
parent
6bd114285c
commit
eed38c8b2a
|
|
@ -237,8 +237,7 @@ class IterationNode(LLMUsageTrackingMixin, Node):
|
|||
)
|
||||
)
|
||||
|
||||
# Update the total tokens from this iteration
|
||||
self.graph_runtime_state.total_tokens += graph_engine.graph_runtime_state.total_tokens
|
||||
# Accumulate usage from this iteration
|
||||
usage_accumulator[0] = self._merge_usage(
|
||||
usage_accumulator[0], graph_engine.graph_runtime_state.llm_usage
|
||||
)
|
||||
|
|
@ -265,7 +264,6 @@ class IterationNode(LLMUsageTrackingMixin, Node):
|
|||
datetime,
|
||||
list[GraphNodeEventBase],
|
||||
object | None,
|
||||
int,
|
||||
dict[str, VariableUnion],
|
||||
LLMUsage,
|
||||
]
|
||||
|
|
@ -292,7 +290,6 @@ class IterationNode(LLMUsageTrackingMixin, Node):
|
|||
iter_start_at,
|
||||
events,
|
||||
output_value,
|
||||
tokens_used,
|
||||
conversation_snapshot,
|
||||
iteration_usage,
|
||||
) = result
|
||||
|
|
@ -304,7 +301,6 @@ class IterationNode(LLMUsageTrackingMixin, Node):
|
|||
yield from events
|
||||
|
||||
# Update tokens and timing
|
||||
self.graph_runtime_state.total_tokens += tokens_used
|
||||
iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds()
|
||||
|
||||
usage_accumulator[0] = self._merge_usage(usage_accumulator[0], iteration_usage)
|
||||
|
|
@ -336,7 +332,7 @@ class IterationNode(LLMUsageTrackingMixin, Node):
|
|||
item: object,
|
||||
flask_app: Flask,
|
||||
context_vars: contextvars.Context,
|
||||
) -> tuple[datetime, list[GraphNodeEventBase], object | None, int, dict[str, VariableUnion], LLMUsage]:
|
||||
) -> tuple[datetime, list[GraphNodeEventBase], object | None, dict[str, VariableUnion], LLMUsage]:
|
||||
"""Execute a single iteration in parallel mode and return results."""
|
||||
with preserve_flask_contexts(flask_app=flask_app, context_vars=context_vars):
|
||||
iter_start_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
|
|
@ -363,7 +359,6 @@ class IterationNode(LLMUsageTrackingMixin, Node):
|
|||
iter_start_at,
|
||||
events,
|
||||
output_value,
|
||||
graph_engine.graph_runtime_state.total_tokens,
|
||||
conversation_snapshot,
|
||||
graph_engine.graph_runtime_state.llm_usage,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -140,7 +140,6 @@ class LoopNode(LLMUsageTrackingMixin, Node):
|
|||
|
||||
if reach_break_condition:
|
||||
loop_count = 0
|
||||
cost_tokens = 0
|
||||
|
||||
for i in range(loop_count):
|
||||
graph_engine = self._create_graph_engine(start_at=start_at, root_node_id=root_node_id)
|
||||
|
|
@ -163,9 +162,6 @@ class LoopNode(LLMUsageTrackingMixin, Node):
|
|||
# For other outputs, just update
|
||||
self.graph_runtime_state.set_output(key, value)
|
||||
|
||||
# Update the total tokens from this iteration
|
||||
cost_tokens += graph_engine.graph_runtime_state.total_tokens
|
||||
|
||||
# Accumulate usage from the sub-graph execution
|
||||
loop_usage = self._merge_usage(loop_usage, graph_engine.graph_runtime_state.llm_usage)
|
||||
|
||||
|
|
@ -194,7 +190,6 @@ class LoopNode(LLMUsageTrackingMixin, Node):
|
|||
pre_loop_output=self._node_data.outputs,
|
||||
)
|
||||
|
||||
self.graph_runtime_state.total_tokens += cost_tokens
|
||||
self._accumulate_usage(loop_usage)
|
||||
# Loop completed successfully
|
||||
yield LoopSucceededEvent(
|
||||
|
|
|
|||
Loading…
Reference in New Issue