Fix: surface workflow container LLM usage (#27021)

This commit is contained in:
-LAN- 2025-10-21 16:05:26 +08:00 committed by GitHub
parent 2bcf96565a
commit 4a6398fc1f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 283 additions and 59 deletions

View File

@ -72,6 +72,19 @@ default_retrieval_model: dict[str, Any] = {
class DatasetRetrieval: class DatasetRetrieval:
def __init__(self, application_generate_entity=None): def __init__(self, application_generate_entity=None):
self.application_generate_entity = application_generate_entity self.application_generate_entity = application_generate_entity
self._llm_usage = LLMUsage.empty_usage()
@property
def llm_usage(self) -> LLMUsage:
    """Snapshot of the LLM usage accumulated by this retrieval so far.

    Returns a copy so callers cannot mutate the internal accumulator.
    """
    return self._llm_usage.model_copy()
def _record_usage(self, usage: LLMUsage | None) -> None:
    """Fold *usage* into the running accumulator, ignoring missing or empty usage.

    The first non-empty usage replaces the zero-valued initial accumulator
    outright; subsequent usages are combined via ``plus``.
    """
    if usage is None or usage.total_tokens <= 0:
        return
    previous = self._llm_usage
    self._llm_usage = usage if previous.total_tokens == 0 else previous.plus(usage)
def retrieve( def retrieve(
self, self,
@ -312,15 +325,18 @@ class DatasetRetrieval:
) )
tools.append(message_tool) tools.append(message_tool)
dataset_id = None dataset_id = None
router_usage = LLMUsage.empty_usage()
if planning_strategy == PlanningStrategy.REACT_ROUTER: if planning_strategy == PlanningStrategy.REACT_ROUTER:
react_multi_dataset_router = ReactMultiDatasetRouter() react_multi_dataset_router = ReactMultiDatasetRouter()
dataset_id = react_multi_dataset_router.invoke( dataset_id, router_usage = react_multi_dataset_router.invoke(
query, tools, model_config, model_instance, user_id, tenant_id query, tools, model_config, model_instance, user_id, tenant_id
) )
elif planning_strategy == PlanningStrategy.ROUTER: elif planning_strategy == PlanningStrategy.ROUTER:
function_call_router = FunctionCallMultiDatasetRouter() function_call_router = FunctionCallMultiDatasetRouter()
dataset_id = function_call_router.invoke(query, tools, model_config, model_instance) dataset_id, router_usage = function_call_router.invoke(query, tools, model_config, model_instance)
self._record_usage(router_usage)
if dataset_id: if dataset_id:
# get retrieval model config # get retrieval model config
@ -983,7 +999,8 @@ class DatasetRetrieval:
) )
# handle invoke result # handle invoke result
result_text, _ = self._handle_invoke_result(invoke_result=invoke_result) result_text, usage = self._handle_invoke_result(invoke_result=invoke_result)
self._record_usage(usage)
result_text_json = parse_and_check_json_markdown(result_text, []) result_text_json = parse_and_check_json_markdown(result_text, [])
automatic_metadata_filters = [] automatic_metadata_filters = []

View File

@ -2,7 +2,7 @@ from typing import Union
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.model_manager import ModelInstance from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMResult from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.entities.message_entities import PromptMessageTool, SystemPromptMessage, UserPromptMessage from core.model_runtime.entities.message_entities import PromptMessageTool, SystemPromptMessage, UserPromptMessage
@ -13,15 +13,15 @@ class FunctionCallMultiDatasetRouter:
dataset_tools: list[PromptMessageTool], dataset_tools: list[PromptMessageTool],
model_config: ModelConfigWithCredentialsEntity, model_config: ModelConfigWithCredentialsEntity,
model_instance: ModelInstance, model_instance: ModelInstance,
) -> Union[str, None]: ) -> tuple[Union[str, None], LLMUsage]:
"""Given input, decided what to do. """Given input, decided what to do.
Returns: Returns:
Action specifying what tool to use. Action specifying what tool to use.
""" """
if len(dataset_tools) == 0: if len(dataset_tools) == 0:
return None return None, LLMUsage.empty_usage()
elif len(dataset_tools) == 1: elif len(dataset_tools) == 1:
return dataset_tools[0].name return dataset_tools[0].name, LLMUsage.empty_usage()
try: try:
prompt_messages = [ prompt_messages = [
@ -34,9 +34,10 @@ class FunctionCallMultiDatasetRouter:
stream=False, stream=False,
model_parameters={"temperature": 0.2, "top_p": 0.3, "max_tokens": 1500}, model_parameters={"temperature": 0.2, "top_p": 0.3, "max_tokens": 1500},
) )
usage = result.usage or LLMUsage.empty_usage()
if result.message.tool_calls: if result.message.tool_calls:
# get retrieval model config # get retrieval model config
return result.message.tool_calls[0].function.name return result.message.tool_calls[0].function.name, usage
return None return None, usage
except Exception: except Exception:
return None return None, LLMUsage.empty_usage()

View File

@ -58,15 +58,15 @@ class ReactMultiDatasetRouter:
model_instance: ModelInstance, model_instance: ModelInstance,
user_id: str, user_id: str,
tenant_id: str, tenant_id: str,
) -> Union[str, None]: ) -> tuple[Union[str, None], LLMUsage]:
"""Given input, decided what to do. """Given input, decided what to do.
Returns: Returns:
Action specifying what tool to use. Action specifying what tool to use.
""" """
if len(dataset_tools) == 0: if len(dataset_tools) == 0:
return None return None, LLMUsage.empty_usage()
elif len(dataset_tools) == 1: elif len(dataset_tools) == 1:
return dataset_tools[0].name return dataset_tools[0].name, LLMUsage.empty_usage()
try: try:
return self._react_invoke( return self._react_invoke(
@ -78,7 +78,7 @@ class ReactMultiDatasetRouter:
tenant_id=tenant_id, tenant_id=tenant_id,
) )
except Exception: except Exception:
return None return None, LLMUsage.empty_usage()
def _react_invoke( def _react_invoke(
self, self,
@ -91,7 +91,7 @@ class ReactMultiDatasetRouter:
prefix: str = PREFIX, prefix: str = PREFIX,
suffix: str = SUFFIX, suffix: str = SUFFIX,
format_instructions: str = FORMAT_INSTRUCTIONS, format_instructions: str = FORMAT_INSTRUCTIONS,
) -> Union[str, None]: ) -> tuple[Union[str, None], LLMUsage]:
prompt: Union[list[ChatModelMessage], CompletionModelPromptTemplate] prompt: Union[list[ChatModelMessage], CompletionModelPromptTemplate]
if model_config.mode == "chat": if model_config.mode == "chat":
prompt = self.create_chat_prompt( prompt = self.create_chat_prompt(
@ -120,7 +120,7 @@ class ReactMultiDatasetRouter:
memory=None, memory=None,
model_config=model_config, model_config=model_config,
) )
result_text, _ = self._invoke_llm( result_text, usage = self._invoke_llm(
completion_param=model_config.parameters, completion_param=model_config.parameters,
model_instance=model_instance, model_instance=model_instance,
prompt_messages=prompt_messages, prompt_messages=prompt_messages,
@ -131,8 +131,8 @@ class ReactMultiDatasetRouter:
output_parser = StructuredChatOutputParser() output_parser = StructuredChatOutputParser()
react_decision = output_parser.parse(result_text) react_decision = output_parser.parse(result_text)
if isinstance(react_decision, ReactAction): if isinstance(react_decision, ReactAction):
return react_decision.tool return react_decision.tool, usage
return None return None, usage
def _invoke_llm( def _invoke_llm(
self, self,

View File

@ -1,13 +1,14 @@
import json import json
import logging import logging
from collections.abc import Generator from collections.abc import Generator, Mapping, Sequence
from typing import Any from typing import Any, cast
from flask import has_request_context from flask import has_request_context
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
from core.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
from core.tools.__base.tool import Tool from core.tools.__base.tool import Tool
from core.tools.__base.tool_runtime import ToolRuntime from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.entities.tool_entities import ( from core.tools.entities.tool_entities import (
@ -49,6 +50,7 @@ class WorkflowTool(Tool):
self.workflow_entities = workflow_entities self.workflow_entities = workflow_entities
self.workflow_call_depth = workflow_call_depth self.workflow_call_depth = workflow_call_depth
self.label = label self.label = label
self._latest_usage = LLMUsage.empty_usage()
super().__init__(entity=entity, runtime=runtime) super().__init__(entity=entity, runtime=runtime)
@ -84,10 +86,11 @@ class WorkflowTool(Tool):
assert self.runtime.invoke_from is not None assert self.runtime.invoke_from is not None
user = self._resolve_user(user_id=user_id) user = self._resolve_user(user_id=user_id)
if user is None: if user is None:
raise ToolInvokeError("User not found") raise ToolInvokeError("User not found")
self._latest_usage = LLMUsage.empty_usage()
result = generator.generate( result = generator.generate(
app_model=app, app_model=app,
workflow=workflow, workflow=workflow,
@ -111,9 +114,68 @@ class WorkflowTool(Tool):
for file in files: for file in files:
yield self.create_file_message(file) # type: ignore yield self.create_file_message(file) # type: ignore
self._latest_usage = self._derive_usage_from_result(data)
yield self.create_text_message(json.dumps(outputs, ensure_ascii=False)) yield self.create_text_message(json.dumps(outputs, ensure_ascii=False))
yield self.create_json_message(outputs) yield self.create_json_message(outputs)
@property
def latest_usage(self) -> LLMUsage:
    """LLM usage derived from the most recent workflow invocation.

    Reset to an empty usage before each run; populated from the workflow
    result data after the run completes.
    """
    return self._latest_usage
@classmethod
def _derive_usage_from_result(cls, data: Mapping[str, Any]) -> LLMUsage:
    """Build an ``LLMUsage`` from a workflow result payload.

    Prefers a nested ``usage`` mapping anywhere in *data* (see
    ``_extract_usage_dict``); otherwise falls back to top-level
    ``total_tokens`` / ``total_price`` / ``currency`` fields. Returns an
    empty usage when nothing usable is present.
    """
    usage_dict = cls._extract_usage_dict(data)
    if usage_dict is not None:
        return LLMUsage.from_metadata(cast(LLMUsageMetadata, dict(usage_dict)))
    total_tokens = data.get("total_tokens")
    total_price = data.get("total_price")
    if total_tokens is None and total_price is None:
        return LLMUsage.empty_usage()
    usage_metadata: dict[str, Any] = {}
    if total_tokens is not None:
        try:
            # str() first so non-int inputs (e.g. "12") still convert;
            # unparseable values are silently dropped rather than failing the run.
            usage_metadata["total_tokens"] = int(str(total_tokens))
        except (TypeError, ValueError):
            pass
    if total_price is not None:
        usage_metadata["total_price"] = str(total_price)
    currency = data.get("currency")
    if currency is not None:
        usage_metadata["currency"] = currency
    # total_tokens may have been dropped above, leaving nothing to report.
    if not usage_metadata:
        return LLMUsage.empty_usage()
    return LLMUsage.from_metadata(cast(LLMUsageMetadata, usage_metadata))
@classmethod
def _extract_usage_dict(cls, payload: Mapping[str, Any]) -> Mapping[str, Any] | None:
    """Find the first ``usage`` mapping nested anywhere inside *payload*.

    Search order: a direct ``usage`` key, then ``metadata.usage``, then a
    depth-first descent into nested mappings and sequences (strings and
    byte-likes are excluded from the sequence walk). Returns ``None`` when
    no usage mapping is found.
    """
    usage_candidate = payload.get("usage")
    if isinstance(usage_candidate, Mapping):
        return usage_candidate
    metadata_candidate = payload.get("metadata")
    if isinstance(metadata_candidate, Mapping):
        usage_candidate = metadata_candidate.get("usage")
        if isinstance(usage_candidate, Mapping):
            return usage_candidate
    # Depth-first recursion: first match wins.
    for value in payload.values():
        if isinstance(value, Mapping):
            found = cls._extract_usage_dict(value)
            if found is not None:
                return found
        elif isinstance(value, Sequence) and not isinstance(value, (str, bytes, bytearray)):
            for item in value:
                if isinstance(item, Mapping):
                    found = cls._extract_usage_dict(item)
                    if found is not None:
                        return found
    return None
def fork_tool_runtime(self, runtime: ToolRuntime) -> "WorkflowTool": def fork_tool_runtime(self, runtime: ToolRuntime) -> "WorkflowTool":
""" """
fork a new tool with metadata fork a new tool with metadata

View File

@ -1,4 +1,5 @@
from .entities import BaseIterationNodeData, BaseIterationState, BaseLoopNodeData, BaseLoopState, BaseNodeData from .entities import BaseIterationNodeData, BaseIterationState, BaseLoopNodeData, BaseLoopState, BaseNodeData
from .usage_tracking_mixin import LLMUsageTrackingMixin
__all__ = [ __all__ = [
"BaseIterationNodeData", "BaseIterationNodeData",
@ -6,4 +7,5 @@ __all__ = [
"BaseLoopNodeData", "BaseLoopNodeData",
"BaseLoopState", "BaseLoopState",
"BaseNodeData", "BaseNodeData",
"LLMUsageTrackingMixin",
] ]

View File

@ -0,0 +1,28 @@
from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.runtime import GraphRuntimeState
class LLMUsageTrackingMixin:
    """Shared helpers for merging and recording LLM usage within workflow nodes."""

    # Supplied by the Node subclass this mixin is combined with.
    graph_runtime_state: GraphRuntimeState

    @staticmethod
    def _merge_usage(current: LLMUsage, new_usage: LLMUsage | None) -> LLMUsage:
        """Return a combined usage snapshot, preserving zero-value inputs."""
        if new_usage is None or new_usage.total_tokens <= 0:
            return current
        return new_usage if current.total_tokens == 0 else current.plus(new_usage)

    def _accumulate_usage(self, usage: LLMUsage) -> None:
        """Push usage into the graph runtime accumulator for downstream reporting."""
        if usage.total_tokens <= 0:
            return
        state = self.graph_runtime_state
        existing = state.llm_usage
        if existing.total_tokens == 0:
            # Copy so later mutations of the caller's object don't leak in.
            state.llm_usage = usage.model_copy()
        else:
            state.llm_usage = existing.plus(usage)

View File

@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Any, NewType, cast
from flask import Flask, current_app from flask import Flask, current_app
from typing_extensions import TypeIs from typing_extensions import TypeIs
from core.model_runtime.entities.llm_entities import LLMUsage
from core.variables import IntegerVariable, NoneSegment from core.variables import IntegerVariable, NoneSegment
from core.variables.segments import ArrayAnySegment, ArraySegment from core.variables.segments import ArrayAnySegment, ArraySegment
from core.variables.variables import VariableUnion from core.variables.variables import VariableUnion
@ -34,6 +35,7 @@ from core.workflow.node_events import (
NodeRunResult, NodeRunResult,
StreamCompletedEvent, StreamCompletedEvent,
) )
from core.workflow.nodes.base import LLMUsageTrackingMixin
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node from core.workflow.nodes.base.node import Node
from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
@ -58,7 +60,7 @@ logger = logging.getLogger(__name__)
EmptyArraySegment = NewType("EmptyArraySegment", ArraySegment) EmptyArraySegment = NewType("EmptyArraySegment", ArraySegment)
class IterationNode(Node): class IterationNode(LLMUsageTrackingMixin, Node):
""" """
Iteration Node. Iteration Node.
""" """
@ -118,6 +120,7 @@ class IterationNode(Node):
started_at = naive_utc_now() started_at = naive_utc_now()
iter_run_map: dict[str, float] = {} iter_run_map: dict[str, float] = {}
outputs: list[object] = [] outputs: list[object] = []
usage_accumulator = [LLMUsage.empty_usage()]
yield IterationStartedEvent( yield IterationStartedEvent(
start_at=started_at, start_at=started_at,
@ -130,22 +133,27 @@ class IterationNode(Node):
iterator_list_value=iterator_list_value, iterator_list_value=iterator_list_value,
outputs=outputs, outputs=outputs,
iter_run_map=iter_run_map, iter_run_map=iter_run_map,
usage_accumulator=usage_accumulator,
) )
self._accumulate_usage(usage_accumulator[0])
yield from self._handle_iteration_success( yield from self._handle_iteration_success(
started_at=started_at, started_at=started_at,
inputs=inputs, inputs=inputs,
outputs=outputs, outputs=outputs,
iterator_list_value=iterator_list_value, iterator_list_value=iterator_list_value,
iter_run_map=iter_run_map, iter_run_map=iter_run_map,
usage=usage_accumulator[0],
) )
except IterationNodeError as e: except IterationNodeError as e:
self._accumulate_usage(usage_accumulator[0])
yield from self._handle_iteration_failure( yield from self._handle_iteration_failure(
started_at=started_at, started_at=started_at,
inputs=inputs, inputs=inputs,
outputs=outputs, outputs=outputs,
iterator_list_value=iterator_list_value, iterator_list_value=iterator_list_value,
iter_run_map=iter_run_map, iter_run_map=iter_run_map,
usage=usage_accumulator[0],
error=e, error=e,
) )
@ -196,6 +204,7 @@ class IterationNode(Node):
iterator_list_value: Sequence[object], iterator_list_value: Sequence[object],
outputs: list[object], outputs: list[object],
iter_run_map: dict[str, float], iter_run_map: dict[str, float],
usage_accumulator: list[LLMUsage],
) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]: ) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]:
if self._node_data.is_parallel: if self._node_data.is_parallel:
# Parallel mode execution # Parallel mode execution
@ -203,6 +212,7 @@ class IterationNode(Node):
iterator_list_value=iterator_list_value, iterator_list_value=iterator_list_value,
outputs=outputs, outputs=outputs,
iter_run_map=iter_run_map, iter_run_map=iter_run_map,
usage_accumulator=usage_accumulator,
) )
else: else:
# Sequential mode execution # Sequential mode execution
@ -228,6 +238,9 @@ class IterationNode(Node):
# Update the total tokens from this iteration # Update the total tokens from this iteration
self.graph_runtime_state.total_tokens += graph_engine.graph_runtime_state.total_tokens self.graph_runtime_state.total_tokens += graph_engine.graph_runtime_state.total_tokens
usage_accumulator[0] = self._merge_usage(
usage_accumulator[0], graph_engine.graph_runtime_state.llm_usage
)
iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds() iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds()
def _execute_parallel_iterations( def _execute_parallel_iterations(
@ -235,6 +248,7 @@ class IterationNode(Node):
iterator_list_value: Sequence[object], iterator_list_value: Sequence[object],
outputs: list[object], outputs: list[object],
iter_run_map: dict[str, float], iter_run_map: dict[str, float],
usage_accumulator: list[LLMUsage],
) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]: ) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]:
# Initialize outputs list with None values to maintain order # Initialize outputs list with None values to maintain order
outputs.extend([None] * len(iterator_list_value)) outputs.extend([None] * len(iterator_list_value))
@ -245,7 +259,16 @@ class IterationNode(Node):
with ThreadPoolExecutor(max_workers=max_workers) as executor: with ThreadPoolExecutor(max_workers=max_workers) as executor:
# Submit all iteration tasks # Submit all iteration tasks
future_to_index: dict[ future_to_index: dict[
Future[tuple[datetime, list[GraphNodeEventBase], object | None, int, dict[str, VariableUnion]]], Future[
tuple[
datetime,
list[GraphNodeEventBase],
object | None,
int,
dict[str, VariableUnion],
LLMUsage,
]
],
int, int,
] = {} ] = {}
for index, item in enumerate(iterator_list_value): for index, item in enumerate(iterator_list_value):
@ -264,7 +287,14 @@ class IterationNode(Node):
index = future_to_index[future] index = future_to_index[future]
try: try:
result = future.result() result = future.result()
iter_start_at, events, output_value, tokens_used, conversation_snapshot = result (
iter_start_at,
events,
output_value,
tokens_used,
conversation_snapshot,
iteration_usage,
) = result
# Update outputs at the correct index # Update outputs at the correct index
outputs[index] = output_value outputs[index] = output_value
@ -276,6 +306,8 @@ class IterationNode(Node):
self.graph_runtime_state.total_tokens += tokens_used self.graph_runtime_state.total_tokens += tokens_used
iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds() iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds()
usage_accumulator[0] = self._merge_usage(usage_accumulator[0], iteration_usage)
# Sync conversation variables after iteration completion # Sync conversation variables after iteration completion
self._sync_conversation_variables_from_snapshot(conversation_snapshot) self._sync_conversation_variables_from_snapshot(conversation_snapshot)
@ -303,7 +335,7 @@ class IterationNode(Node):
item: object, item: object,
flask_app: Flask, flask_app: Flask,
context_vars: contextvars.Context, context_vars: contextvars.Context,
) -> tuple[datetime, list[GraphNodeEventBase], object | None, int, dict[str, VariableUnion]]: ) -> tuple[datetime, list[GraphNodeEventBase], object | None, int, dict[str, VariableUnion], LLMUsage]:
"""Execute a single iteration in parallel mode and return results.""" """Execute a single iteration in parallel mode and return results."""
with preserve_flask_contexts(flask_app=flask_app, context_vars=context_vars): with preserve_flask_contexts(flask_app=flask_app, context_vars=context_vars):
iter_start_at = datetime.now(UTC).replace(tzinfo=None) iter_start_at = datetime.now(UTC).replace(tzinfo=None)
@ -332,6 +364,7 @@ class IterationNode(Node):
output_value, output_value,
graph_engine.graph_runtime_state.total_tokens, graph_engine.graph_runtime_state.total_tokens,
conversation_snapshot, conversation_snapshot,
graph_engine.graph_runtime_state.llm_usage,
) )
def _handle_iteration_success( def _handle_iteration_success(
@ -341,6 +374,8 @@ class IterationNode(Node):
outputs: list[object], outputs: list[object],
iterator_list_value: Sequence[object], iterator_list_value: Sequence[object],
iter_run_map: dict[str, float], iter_run_map: dict[str, float],
*,
usage: LLMUsage,
) -> Generator[NodeEventBase, None, None]: ) -> Generator[NodeEventBase, None, None]:
# Flatten the list of lists if all outputs are lists # Flatten the list of lists if all outputs are lists
flattened_outputs = self._flatten_outputs_if_needed(outputs) flattened_outputs = self._flatten_outputs_if_needed(outputs)
@ -351,7 +386,9 @@ class IterationNode(Node):
outputs={"output": flattened_outputs}, outputs={"output": flattened_outputs},
steps=len(iterator_list_value), steps=len(iterator_list_value),
metadata={ metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens, WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map, WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map,
}, },
) )
@ -362,8 +399,11 @@ class IterationNode(Node):
status=WorkflowNodeExecutionStatus.SUCCEEDED, status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={"output": flattened_outputs}, outputs={"output": flattened_outputs},
metadata={ metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens, WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
}, },
llm_usage=usage,
) )
) )
@ -400,6 +440,8 @@ class IterationNode(Node):
outputs: list[object], outputs: list[object],
iterator_list_value: Sequence[object], iterator_list_value: Sequence[object],
iter_run_map: dict[str, float], iter_run_map: dict[str, float],
*,
usage: LLMUsage,
error: IterationNodeError, error: IterationNodeError,
) -> Generator[NodeEventBase, None, None]: ) -> Generator[NodeEventBase, None, None]:
# Flatten the list of lists if all outputs are lists (even in failure case) # Flatten the list of lists if all outputs are lists (even in failure case)
@ -411,7 +453,9 @@ class IterationNode(Node):
outputs={"output": flattened_outputs}, outputs={"output": flattened_outputs},
steps=len(iterator_list_value), steps=len(iterator_list_value),
metadata={ metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens, WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map, WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map,
}, },
error=str(error), error=str(error),
@ -420,6 +464,12 @@ class IterationNode(Node):
node_run_result=NodeRunResult( node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, status=WorkflowNodeExecutionStatus.FAILED,
error=str(error), error=str(error),
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
},
llm_usage=usage,
) )
) )

View File

@ -15,14 +15,11 @@ from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEnti
from core.entities.agent_entities import PlanningStrategy from core.entities.agent_entities import PlanningStrategy
from core.entities.model_entities import ModelStatus from core.entities.model_entities import ModelStatus
from core.model_manager import ModelInstance, ModelManager from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.message_entities import ( from core.model_runtime.entities.llm_entities import LLMUsage
PromptMessageRole, from core.model_runtime.entities.message_entities import PromptMessageRole
) from core.model_runtime.entities.model_entities import ModelFeature, ModelType
from core.model_runtime.entities.model_entities import (
ModelFeature,
ModelType,
)
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.simple_prompt_transform import ModelMode from core.prompt.simple_prompt_transform import ModelMode
from core.rag.datasource.retrieval_service import RetrievalService from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.entities.metadata_entities import Condition, MetadataCondition from core.rag.entities.metadata_entities import Condition, MetadataCondition
@ -33,8 +30,14 @@ from core.variables import (
) )
from core.variables.segments import ArrayObjectSegment from core.variables.segments import ArrayObjectSegment
from core.workflow.entities import GraphInitParams from core.workflow.entities import GraphInitParams
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus from core.workflow.enums import (
ErrorStrategy,
NodeType,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from core.workflow.node_events import ModelInvokeCompletedEvent, NodeRunResult from core.workflow.node_events import ModelInvokeCompletedEvent, NodeRunResult
from core.workflow.nodes.base import LLMUsageTrackingMixin
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node from core.workflow.nodes.base.node import Node
from core.workflow.nodes.knowledge_retrieval.template_prompts import ( from core.workflow.nodes.knowledge_retrieval.template_prompts import (
@ -80,7 +83,7 @@ default_retrieval_model = {
} }
class KnowledgeRetrievalNode(Node): class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node):
node_type = NodeType.KNOWLEDGE_RETRIEVAL node_type = NodeType.KNOWLEDGE_RETRIEVAL
_node_data: KnowledgeRetrievalNodeData _node_data: KnowledgeRetrievalNodeData
@ -182,14 +185,21 @@ class KnowledgeRetrievalNode(Node):
) )
# retrieve knowledge # retrieve knowledge
usage = LLMUsage.empty_usage()
try: try:
results = self._fetch_dataset_retriever(node_data=self._node_data, query=query) results, usage = self._fetch_dataset_retriever(node_data=self._node_data, query=query)
outputs = {"result": ArrayObjectSegment(value=results)} outputs = {"result": ArrayObjectSegment(value=results)}
return NodeRunResult( return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED, status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=variables, inputs=variables,
process_data={}, process_data={"usage": jsonable_encoder(usage)},
outputs=outputs, # type: ignore outputs=outputs, # type: ignore
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
},
llm_usage=usage,
) )
except KnowledgeRetrievalNodeError as e: except KnowledgeRetrievalNodeError as e:
@ -199,6 +209,7 @@ class KnowledgeRetrievalNode(Node):
inputs=variables, inputs=variables,
error=str(e), error=str(e),
error_type=type(e).__name__, error_type=type(e).__name__,
llm_usage=usage,
) )
# Temporary handle all exceptions from DatasetRetrieval class here. # Temporary handle all exceptions from DatasetRetrieval class here.
except Exception as e: except Exception as e:
@ -207,11 +218,15 @@ class KnowledgeRetrievalNode(Node):
inputs=variables, inputs=variables,
error=str(e), error=str(e),
error_type=type(e).__name__, error_type=type(e).__name__,
llm_usage=usage,
) )
finally: finally:
db.session.close() db.session.close()
def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query: str) -> list[dict[str, Any]]: def _fetch_dataset_retriever(
self, node_data: KnowledgeRetrievalNodeData, query: str
) -> tuple[list[dict[str, Any]], LLMUsage]:
usage = LLMUsage.empty_usage()
available_datasets = [] available_datasets = []
dataset_ids = node_data.dataset_ids dataset_ids = node_data.dataset_ids
@ -245,9 +260,10 @@ class KnowledgeRetrievalNode(Node):
if not dataset: if not dataset:
continue continue
available_datasets.append(dataset) available_datasets.append(dataset)
metadata_filter_document_ids, metadata_condition = self._get_metadata_filter_condition( metadata_filter_document_ids, metadata_condition, metadata_usage = self._get_metadata_filter_condition(
[dataset.id for dataset in available_datasets], query, node_data [dataset.id for dataset in available_datasets], query, node_data
) )
usage = self._merge_usage(usage, metadata_usage)
all_documents = [] all_documents = []
dataset_retrieval = DatasetRetrieval() dataset_retrieval = DatasetRetrieval()
if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE: if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
@ -330,6 +346,8 @@ class KnowledgeRetrievalNode(Node):
metadata_filter_document_ids=metadata_filter_document_ids, metadata_filter_document_ids=metadata_filter_document_ids,
metadata_condition=metadata_condition, metadata_condition=metadata_condition,
) )
usage = self._merge_usage(usage, dataset_retrieval.llm_usage)
dify_documents = [item for item in all_documents if item.provider == "dify"] dify_documents = [item for item in all_documents if item.provider == "dify"]
external_documents = [item for item in all_documents if item.provider == "external"] external_documents = [item for item in all_documents if item.provider == "external"]
retrieval_resource_list = [] retrieval_resource_list = []
@ -406,11 +424,12 @@ class KnowledgeRetrievalNode(Node):
) )
for position, item in enumerate(retrieval_resource_list, start=1): for position, item in enumerate(retrieval_resource_list, start=1):
item["metadata"]["position"] = position item["metadata"]["position"] = position
return retrieval_resource_list return retrieval_resource_list, usage
def _get_metadata_filter_condition( def _get_metadata_filter_condition(
self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData
) -> tuple[dict[str, list[str]] | None, MetadataCondition | None]: ) -> tuple[dict[str, list[str]] | None, MetadataCondition | None, LLMUsage]:
usage = LLMUsage.empty_usage()
document_query = db.session.query(Document).where( document_query = db.session.query(Document).where(
Document.dataset_id.in_(dataset_ids), Document.dataset_id.in_(dataset_ids),
Document.indexing_status == "completed", Document.indexing_status == "completed",
@ -420,9 +439,12 @@ class KnowledgeRetrievalNode(Node):
filters: list[Any] = [] filters: list[Any] = []
metadata_condition = None metadata_condition = None
if node_data.metadata_filtering_mode == "disabled": if node_data.metadata_filtering_mode == "disabled":
return None, None return None, None, usage
elif node_data.metadata_filtering_mode == "automatic": elif node_data.metadata_filtering_mode == "automatic":
automatic_metadata_filters = self._automatic_metadata_filter_func(dataset_ids, query, node_data) automatic_metadata_filters, automatic_usage = self._automatic_metadata_filter_func(
dataset_ids, query, node_data
)
usage = self._merge_usage(usage, automatic_usage)
if automatic_metadata_filters: if automatic_metadata_filters:
conditions = [] conditions = []
for sequence, filter in enumerate(automatic_metadata_filters): for sequence, filter in enumerate(automatic_metadata_filters):
@ -496,11 +518,12 @@ class KnowledgeRetrievalNode(Node):
metadata_filter_document_ids = defaultdict(list) if documents else None # type: ignore metadata_filter_document_ids = defaultdict(list) if documents else None # type: ignore
for document in documents: for document in documents:
metadata_filter_document_ids[document.dataset_id].append(document.id) # type: ignore metadata_filter_document_ids[document.dataset_id].append(document.id) # type: ignore
return metadata_filter_document_ids, metadata_condition return metadata_filter_document_ids, metadata_condition, usage
def _automatic_metadata_filter_func( def _automatic_metadata_filter_func(
self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData
) -> list[dict[str, Any]]: ) -> tuple[list[dict[str, Any]], LLMUsage]:
usage = LLMUsage.empty_usage()
# get all metadata field # get all metadata field
stmt = select(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids)) stmt = select(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids))
metadata_fields = db.session.scalars(stmt).all() metadata_fields = db.session.scalars(stmt).all()
@ -548,6 +571,7 @@ class KnowledgeRetrievalNode(Node):
for event in generator: for event in generator:
if isinstance(event, ModelInvokeCompletedEvent): if isinstance(event, ModelInvokeCompletedEvent):
result_text = event.text result_text = event.text
usage = self._merge_usage(usage, event.usage)
break break
result_text_json = parse_and_check_json_markdown(result_text, []) result_text_json = parse_and_check_json_markdown(result_text, [])
@ -564,8 +588,8 @@ class KnowledgeRetrievalNode(Node):
} }
) )
except Exception: except Exception:
return [] return [], usage
return automatic_metadata_filters return automatic_metadata_filters, usage
def _process_metadata_filter_func( def _process_metadata_filter_func(
self, sequence: int, condition: str, metadata_name: str, value: Any, filters: list[Any] self, sequence: int, condition: str, metadata_name: str, value: Any, filters: list[Any]

View File

@ -5,6 +5,7 @@ from collections.abc import Callable, Generator, Mapping, Sequence
from datetime import datetime from datetime import datetime
from typing import TYPE_CHECKING, Any, Literal, cast from typing import TYPE_CHECKING, Any, Literal, cast
from core.model_runtime.entities.llm_entities import LLMUsage
from core.variables import Segment, SegmentType from core.variables import Segment, SegmentType
from core.workflow.enums import ( from core.workflow.enums import (
ErrorStrategy, ErrorStrategy,
@ -27,6 +28,7 @@ from core.workflow.node_events import (
NodeRunResult, NodeRunResult,
StreamCompletedEvent, StreamCompletedEvent,
) )
from core.workflow.nodes.base import LLMUsageTrackingMixin
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node from core.workflow.nodes.base.node import Node
from core.workflow.nodes.loop.entities import LoopNodeData, LoopVariableData from core.workflow.nodes.loop.entities import LoopNodeData, LoopVariableData
@ -40,7 +42,7 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class LoopNode(Node): class LoopNode(LLMUsageTrackingMixin, Node):
""" """
Loop Node. Loop Node.
""" """
@ -117,6 +119,7 @@ class LoopNode(Node):
loop_duration_map: dict[str, float] = {} loop_duration_map: dict[str, float] = {}
single_loop_variable_map: dict[str, dict[str, Any]] = {} # single loop variable output single_loop_variable_map: dict[str, dict[str, Any]] = {} # single loop variable output
loop_usage = LLMUsage.empty_usage()
# Start Loop event # Start Loop event
yield LoopStartedEvent( yield LoopStartedEvent(
@ -163,6 +166,9 @@ class LoopNode(Node):
# Update the total tokens from this iteration # Update the total tokens from this iteration
cost_tokens += graph_engine.graph_runtime_state.total_tokens cost_tokens += graph_engine.graph_runtime_state.total_tokens
# Accumulate usage from the sub-graph execution
loop_usage = self._merge_usage(loop_usage, graph_engine.graph_runtime_state.llm_usage)
# Collect loop variable values after iteration # Collect loop variable values after iteration
single_loop_variable = {} single_loop_variable = {}
for key, selector in loop_variable_selectors.items(): for key, selector in loop_variable_selectors.items():
@ -189,6 +195,7 @@ class LoopNode(Node):
) )
self.graph_runtime_state.total_tokens += cost_tokens self.graph_runtime_state.total_tokens += cost_tokens
self._accumulate_usage(loop_usage)
# Loop completed successfully # Loop completed successfully
yield LoopSucceededEvent( yield LoopSucceededEvent(
start_at=start_at, start_at=start_at,
@ -196,7 +203,9 @@ class LoopNode(Node):
outputs=self._node_data.outputs, outputs=self._node_data.outputs,
steps=loop_count, steps=loop_count,
metadata={ metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: cost_tokens, WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: loop_usage.total_price,
WorkflowNodeExecutionMetadataKey.CURRENCY: loop_usage.currency,
"completed_reason": "loop_break" if reach_break_condition else "loop_completed", "completed_reason": "loop_break" if reach_break_condition else "loop_completed",
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map, WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map, WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
@ -207,22 +216,28 @@ class LoopNode(Node):
node_run_result=NodeRunResult( node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED, status=WorkflowNodeExecutionStatus.SUCCEEDED,
metadata={ metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens, WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: loop_usage.total_price,
WorkflowNodeExecutionMetadataKey.CURRENCY: loop_usage.currency,
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map, WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map, WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
}, },
outputs=self._node_data.outputs, outputs=self._node_data.outputs,
inputs=inputs, inputs=inputs,
llm_usage=loop_usage,
) )
) )
except Exception as e: except Exception as e:
self._accumulate_usage(loop_usage)
yield LoopFailedEvent( yield LoopFailedEvent(
start_at=start_at, start_at=start_at,
inputs=inputs, inputs=inputs,
steps=loop_count, steps=loop_count,
metadata={ metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens, WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: loop_usage.total_price,
WorkflowNodeExecutionMetadataKey.CURRENCY: loop_usage.currency,
"completed_reason": "error", "completed_reason": "error",
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map, WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map, WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
@ -235,10 +250,13 @@ class LoopNode(Node):
status=WorkflowNodeExecutionStatus.FAILED, status=WorkflowNodeExecutionStatus.FAILED,
error=str(e), error=str(e),
metadata={ metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens, WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: loop_usage.total_price,
WorkflowNodeExecutionMetadataKey.CURRENCY: loop_usage.currency,
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map, WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map, WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
}, },
llm_usage=loop_usage,
) )
) )

View File

@ -6,10 +6,13 @@ from sqlalchemy.orm import Session
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
from core.file import File, FileTransferMethod from core.file import File, FileTransferMethod
from core.model_runtime.entities.llm_entities import LLMUsage
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
from core.tools.errors import ToolInvokeError from core.tools.errors import ToolInvokeError
from core.tools.tool_engine import ToolEngine from core.tools.tool_engine import ToolEngine
from core.tools.utils.message_transformer import ToolFileMessageTransformer from core.tools.utils.message_transformer import ToolFileMessageTransformer
from core.tools.workflow_as_tool.tool import WorkflowTool
from core.variables.segments import ArrayAnySegment, ArrayFileSegment from core.variables.segments import ArrayAnySegment, ArrayFileSegment
from core.variables.variables import ArrayAnyVariable from core.variables.variables import ArrayAnyVariable
from core.workflow.enums import ( from core.workflow.enums import (
@ -136,13 +139,14 @@ class ToolNode(Node):
try: try:
# convert tool messages # convert tool messages
yield from self._transform_message( _ = yield from self._transform_message(
messages=message_stream, messages=message_stream,
tool_info=tool_info, tool_info=tool_info,
parameters_for_log=parameters_for_log, parameters_for_log=parameters_for_log,
user_id=self.user_id, user_id=self.user_id,
tenant_id=self.tenant_id, tenant_id=self.tenant_id,
node_id=self._node_id, node_id=self._node_id,
tool_runtime=tool_runtime,
) )
except ToolInvokeError as e: except ToolInvokeError as e:
yield StreamCompletedEvent( yield StreamCompletedEvent(
@ -236,7 +240,8 @@ class ToolNode(Node):
user_id: str, user_id: str,
tenant_id: str, tenant_id: str,
node_id: str, node_id: str,
) -> Generator: tool_runtime: Tool,
) -> Generator[NodeEventBase, None, LLMUsage]:
""" """
Convert ToolInvokeMessages into tuple[plain_text, files] Convert ToolInvokeMessages into tuple[plain_text, files]
""" """
@ -424,17 +429,34 @@ class ToolNode(Node):
is_final=True, is_final=True,
) )
usage = self._extract_tool_usage(tool_runtime)
metadata: dict[WorkflowNodeExecutionMetadataKey, Any] = {
WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info,
}
if usage.total_tokens > 0:
metadata[WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS] = usage.total_tokens
metadata[WorkflowNodeExecutionMetadataKey.TOTAL_PRICE] = usage.total_price
metadata[WorkflowNodeExecutionMetadataKey.CURRENCY] = usage.currency
yield StreamCompletedEvent( yield StreamCompletedEvent(
node_run_result=NodeRunResult( node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED, status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={"text": text, "files": ArrayFileSegment(value=files), "json": json_output, **variables}, outputs={"text": text, "files": ArrayFileSegment(value=files), "json": json_output, **variables},
metadata={ metadata=metadata,
WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info,
},
inputs=parameters_for_log, inputs=parameters_for_log,
llm_usage=usage,
) )
) )
return usage
@staticmethod
def _extract_tool_usage(tool_runtime: Tool) -> LLMUsage:
if isinstance(tool_runtime, WorkflowTool):
return tool_runtime.latest_usage
return LLMUsage.empty_usage()
@classmethod @classmethod
def _extract_variable_selector_to_variable_mapping( def _extract_variable_selector_to_variable_mapping(
cls, cls,