From 4a6398fc1fb18d9ffc37f7634e1028a210562a7b Mon Sep 17 00:00:00 2001 From: -LAN- Date: Tue, 21 Oct 2025 16:05:26 +0800 Subject: [PATCH] Fix: surface workflow container LLM usage (#27021) --- api/core/rag/retrieval/dataset_retrieval.py | 23 ++++++- .../multi_dataset_function_call_router.py | 15 ++-- .../router/multi_dataset_react_route.py | 16 ++--- api/core/tools/workflow_as_tool/tool.py | 68 ++++++++++++++++++- api/core/workflow/nodes/base/__init__.py | 2 + .../nodes/base/usage_tracking_mixin.py | 28 ++++++++ .../nodes/iteration/iteration_node.py | 64 +++++++++++++++-- .../knowledge_retrieval_node.py | 66 ++++++++++++------ api/core/workflow/nodes/loop/loop_node.py | 28 ++++++-- api/core/workflow/nodes/tool/tool_node.py | 32 +++++++-- 10 files changed, 283 insertions(+), 59 deletions(-) create mode 100644 api/core/workflow/nodes/base/usage_tracking_mixin.py diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index 99bbe615fb..45b19f25a0 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -72,6 +72,19 @@ default_retrieval_model: dict[str, Any] = { class DatasetRetrieval: def __init__(self, application_generate_entity=None): self.application_generate_entity = application_generate_entity + self._llm_usage = LLMUsage.empty_usage() + + @property + def llm_usage(self) -> LLMUsage: + return self._llm_usage.model_copy() + + def _record_usage(self, usage: LLMUsage | None) -> None: + if usage is None or usage.total_tokens <= 0: + return + if self._llm_usage.total_tokens == 0: + self._llm_usage = usage + else: + self._llm_usage = self._llm_usage.plus(usage) def retrieve( self, @@ -312,15 +325,18 @@ class DatasetRetrieval: ) tools.append(message_tool) dataset_id = None + router_usage = LLMUsage.empty_usage() if planning_strategy == PlanningStrategy.REACT_ROUTER: react_multi_dataset_router = ReactMultiDatasetRouter() - dataset_id = react_multi_dataset_router.invoke( + dataset_id, router_usage = react_multi_dataset_router.invoke( query, tools, model_config, model_instance, user_id, tenant_id ) elif planning_strategy == PlanningStrategy.ROUTER: function_call_router = FunctionCallMultiDatasetRouter() - dataset_id = function_call_router.invoke(query, tools, model_config, model_instance) + dataset_id, router_usage = function_call_router.invoke(query, tools, model_config, model_instance) + + self._record_usage(router_usage) if dataset_id: # get retrieval model config @@ -983,7 +999,8 @@ class DatasetRetrieval: ) # handle invoke result - result_text, _ = self._handle_invoke_result(invoke_result=invoke_result) + result_text, usage = self._handle_invoke_result(invoke_result=invoke_result) + self._record_usage(usage) result_text_json = parse_and_check_json_markdown(result_text, []) automatic_metadata_filters = [] diff --git a/api/core/rag/retrieval/router/multi_dataset_function_call_router.py b/api/core/rag/retrieval/router/multi_dataset_function_call_router.py index de59c6380e..5f3e1a8cae 100644 --- a/api/core/rag/retrieval/router/multi_dataset_function_call_router.py +++ b/api/core/rag/retrieval/router/multi_dataset_function_call_router.py @@ -2,7 +2,7 @@ from typing import Union from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.model_manager import ModelInstance -from core.model_runtime.entities.llm_entities import LLMResult +from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage from core.model_runtime.entities.message_entities import PromptMessageTool, SystemPromptMessage, UserPromptMessage @@ -13,15 +13,15 @@ class FunctionCallMultiDatasetRouter: dataset_tools: list[PromptMessageTool], model_config: ModelConfigWithCredentialsEntity, model_instance: ModelInstance, - ) -> Union[str, None]: + ) -> tuple[Union[str, None], LLMUsage]: """Given input, decided what to do. Returns: Action specifying what tool to use. """ if len(dataset_tools) == 0: - return None + return None, LLMUsage.empty_usage() elif len(dataset_tools) == 1: - return dataset_tools[0].name + return dataset_tools[0].name, LLMUsage.empty_usage() try: prompt_messages = [ @@ -34,9 +34,10 @@ class FunctionCallMultiDatasetRouter: stream=False, model_parameters={"temperature": 0.2, "top_p": 0.3, "max_tokens": 1500}, ) + usage = result.usage or LLMUsage.empty_usage() if result.message.tool_calls: # get retrieval model config - return result.message.tool_calls[0].function.name - return None + return result.message.tool_calls[0].function.name, usage + return None, usage except Exception: - return None + return None, LLMUsage.empty_usage() diff --git a/api/core/rag/retrieval/router/multi_dataset_react_route.py b/api/core/rag/retrieval/router/multi_dataset_react_route.py index 59d36229b3..8f3bec2704 100644 --- a/api/core/rag/retrieval/router/multi_dataset_react_route.py +++ b/api/core/rag/retrieval/router/multi_dataset_react_route.py @@ -58,15 +58,15 @@ class ReactMultiDatasetRouter: model_instance: ModelInstance, user_id: str, tenant_id: str, - ) -> Union[str, None]: + ) -> tuple[Union[str, None], LLMUsage]: """Given input, decided what to do. Returns: Action specifying what tool to use. """ if len(dataset_tools) == 0: - return None + return None, LLMUsage.empty_usage() elif len(dataset_tools) == 1: - return dataset_tools[0].name + return dataset_tools[0].name, LLMUsage.empty_usage() try: return self._react_invoke( @@ -78,7 +78,7 @@ class ReactMultiDatasetRouter: tenant_id=tenant_id, ) except Exception: - return None + return None, LLMUsage.empty_usage() def _react_invoke( self, @@ -91,7 +91,7 @@ class ReactMultiDatasetRouter: prefix: str = PREFIX, suffix: str = SUFFIX, format_instructions: str = FORMAT_INSTRUCTIONS, - ) -> Union[str, None]: + ) -> tuple[Union[str, None], LLMUsage]: prompt: Union[list[ChatModelMessage], CompletionModelPromptTemplate] if model_config.mode == "chat": prompt = self.create_chat_prompt( @@ -120,7 +120,7 @@ class ReactMultiDatasetRouter: memory=None, model_config=model_config, ) - result_text, _ = self._invoke_llm( + result_text, usage = self._invoke_llm( completion_param=model_config.parameters, model_instance=model_instance, prompt_messages=prompt_messages, @@ -131,8 +131,8 @@ class ReactMultiDatasetRouter: output_parser = StructuredChatOutputParser() react_decision = output_parser.parse(result_text) if isinstance(react_decision, ReactAction): - return react_decision.tool - return None + return react_decision.tool, usage + return None, usage def _invoke_llm( self, diff --git a/api/core/tools/workflow_as_tool/tool.py b/api/core/tools/workflow_as_tool/tool.py index bc5b12c102..2cd46647a0 100644 --- a/api/core/tools/workflow_as_tool/tool.py +++ b/api/core/tools/workflow_as_tool/tool.py @@ -1,13 +1,14 @@ import json import logging -from collections.abc import Generator -from typing import Any +from collections.abc import Generator, Mapping, Sequence +from typing import Any, cast from flask import has_request_context from sqlalchemy import select from sqlalchemy.orm import Session from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod +from core.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata from core.tools.__base.tool import Tool from core.tools.__base.tool_runtime import ToolRuntime from core.tools.entities.tool_entities import ( @@ -49,6 +50,7 @@ class WorkflowTool(Tool): self.workflow_entities = workflow_entities self.workflow_call_depth = workflow_call_depth self.label = label + self._latest_usage = LLMUsage.empty_usage() super().__init__(entity=entity, runtime=runtime) @@ -84,10 +86,11 @@ class WorkflowTool(Tool): assert self.runtime.invoke_from is not None user = self._resolve_user(user_id=user_id) - if user is None: raise ToolInvokeError("User not found") + self._latest_usage = LLMUsage.empty_usage() + result = generator.generate( app_model=app, workflow=workflow, @@ -111,9 +114,68 @@ class WorkflowTool(Tool): for file in files: yield self.create_file_message(file) # type: ignore + self._latest_usage = self._derive_usage_from_result(data) + yield self.create_text_message(json.dumps(outputs, ensure_ascii=False)) yield self.create_json_message(outputs) + @property + def latest_usage(self) -> LLMUsage: + return self._latest_usage + + @classmethod + def _derive_usage_from_result(cls, data: Mapping[str, Any]) -> LLMUsage: + usage_dict = cls._extract_usage_dict(data) + if usage_dict is not None: + return LLMUsage.from_metadata(cast(LLMUsageMetadata, dict(usage_dict))) + + total_tokens = data.get("total_tokens") + total_price = data.get("total_price") + if total_tokens is None and total_price is None: + return LLMUsage.empty_usage() + + usage_metadata: dict[str, Any] = {} + if total_tokens is not None: + try: + usage_metadata["total_tokens"] = int(str(total_tokens)) + except (TypeError, ValueError): + pass + if total_price is not None: + usage_metadata["total_price"] = str(total_price) + currency = data.get("currency") + if currency is not None: + usage_metadata["currency"] = currency + + if not usage_metadata: + return LLMUsage.empty_usage() + + return LLMUsage.from_metadata(cast(LLMUsageMetadata, usage_metadata)) + + @classmethod + def _extract_usage_dict(cls, payload: Mapping[str, Any]) -> Mapping[str, Any] | None: + usage_candidate = payload.get("usage") + if isinstance(usage_candidate, Mapping): + return usage_candidate + + metadata_candidate = payload.get("metadata") + if isinstance(metadata_candidate, Mapping): + usage_candidate = metadata_candidate.get("usage") + if isinstance(usage_candidate, Mapping): + return usage_candidate + + for value in payload.values(): + if isinstance(value, Mapping): + found = cls._extract_usage_dict(value) + if found is not None: + return found + elif isinstance(value, Sequence) and not isinstance(value, (str, bytes, bytearray)): + for item in value: + if isinstance(item, Mapping): + found = cls._extract_usage_dict(item) + if found is not None: + return found + return None + def fork_tool_runtime(self, runtime: ToolRuntime) -> "WorkflowTool": """ fork a new tool with metadata diff --git a/api/core/workflow/nodes/base/__init__.py b/api/core/workflow/nodes/base/__init__.py index 8cf31dc342..f83df0e323 100644 --- a/api/core/workflow/nodes/base/__init__.py +++ b/api/core/workflow/nodes/base/__init__.py @@ -1,4 +1,5 @@ from .entities import BaseIterationNodeData, BaseIterationState, BaseLoopNodeData, BaseLoopState, BaseNodeData +from .usage_tracking_mixin import LLMUsageTrackingMixin __all__ = [ "BaseIterationNodeData", @@ -6,4 +7,5 @@ __all__ = [ "BaseLoopNodeData", "BaseLoopState", "BaseNodeData", + "LLMUsageTrackingMixin", ] diff --git a/api/core/workflow/nodes/base/usage_tracking_mixin.py b/api/core/workflow/nodes/base/usage_tracking_mixin.py new file mode 100644 index 0000000000..d9a0ef8972 --- /dev/null +++ b/api/core/workflow/nodes/base/usage_tracking_mixin.py @@ -0,0 +1,28 @@ +from core.model_runtime.entities.llm_entities import LLMUsage +from core.workflow.runtime import GraphRuntimeState + + +class LLMUsageTrackingMixin: + """Provides shared helpers for merging and recording LLM usage within workflow nodes.""" + + graph_runtime_state: GraphRuntimeState + + @staticmethod + def _merge_usage(current: LLMUsage, new_usage: LLMUsage | None) -> LLMUsage: + """Return a combined usage snapshot, preserving zero-value inputs.""" + if new_usage is None or new_usage.total_tokens <= 0: + return current + if current.total_tokens == 0: + return new_usage + return current.plus(new_usage) + + def _accumulate_usage(self, usage: LLMUsage) -> None: + """Push usage into the graph runtime accumulator for downstream reporting.""" + if usage.total_tokens <= 0: + return + + current_usage = self.graph_runtime_state.llm_usage + if current_usage.total_tokens == 0: + self.graph_runtime_state.llm_usage = usage.model_copy() + else: + self.graph_runtime_state.llm_usage = current_usage.plus(usage) diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index 41060bd569..3a3a2290be 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Any, NewType, cast from flask import Flask, current_app from typing_extensions import TypeIs +from core.model_runtime.entities.llm_entities import LLMUsage from core.variables import IntegerVariable, NoneSegment from core.variables.segments import ArrayAnySegment, ArraySegment from core.variables.variables import VariableUnion @@ -34,6 +35,7 @@ from core.workflow.node_events import ( NodeRunResult, StreamCompletedEvent, ) +from core.workflow.nodes.base import LLMUsageTrackingMixin from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig from core.workflow.nodes.base.node import Node from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData @@ -58,7 +60,7 @@ logger = logging.getLogger(__name__) EmptyArraySegment = NewType("EmptyArraySegment", ArraySegment) -class IterationNode(Node): +class IterationNode(LLMUsageTrackingMixin, Node): """ Iteration Node. """ @@ -118,6 +120,7 @@ class IterationNode(Node): started_at = naive_utc_now() iter_run_map: dict[str, float] = {} outputs: list[object] = [] + usage_accumulator = [LLMUsage.empty_usage()] yield IterationStartedEvent( start_at=started_at, @@ -130,22 +133,27 @@ class IterationNode(Node): iterator_list_value=iterator_list_value, outputs=outputs, iter_run_map=iter_run_map, + usage_accumulator=usage_accumulator, ) + self._accumulate_usage(usage_accumulator[0]) yield from self._handle_iteration_success( started_at=started_at, inputs=inputs, outputs=outputs, iterator_list_value=iterator_list_value, iter_run_map=iter_run_map, + usage=usage_accumulator[0], ) except IterationNodeError as e: + self._accumulate_usage(usage_accumulator[0]) yield from self._handle_iteration_failure( started_at=started_at, inputs=inputs, outputs=outputs, iterator_list_value=iterator_list_value, iter_run_map=iter_run_map, + usage=usage_accumulator[0], error=e, ) @@ -196,6 +204,7 @@ class IterationNode(Node): iterator_list_value: Sequence[object], outputs: list[object], iter_run_map: dict[str, float], + usage_accumulator: list[LLMUsage], ) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]: if self._node_data.is_parallel: # Parallel mode execution @@ -203,6 +212,7 @@ class IterationNode(Node): iterator_list_value=iterator_list_value, outputs=outputs, iter_run_map=iter_run_map, + usage_accumulator=usage_accumulator, ) else: # Sequential mode execution @@ -228,6 +238,9 @@ class IterationNode(Node): # Update the total tokens from this iteration self.graph_runtime_state.total_tokens += graph_engine.graph_runtime_state.total_tokens + usage_accumulator[0] = self._merge_usage( + usage_accumulator[0], graph_engine.graph_runtime_state.llm_usage + ) iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds() def _execute_parallel_iterations( @@ -235,6 +248,7 @@ class IterationNode(Node): iterator_list_value: Sequence[object], outputs: list[object], iter_run_map: dict[str, float], + usage_accumulator: list[LLMUsage], ) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]: # Initialize outputs list with None values to maintain order outputs.extend([None] * len(iterator_list_value)) @@ -245,7 +259,16 @@ class IterationNode(Node): with ThreadPoolExecutor(max_workers=max_workers) as executor: # Submit all iteration tasks future_to_index: dict[ - Future[tuple[datetime, list[GraphNodeEventBase], object | None, int, dict[str, VariableUnion]]], + Future[ + tuple[ + datetime, + list[GraphNodeEventBase], + object | None, + int, + dict[str, VariableUnion], + LLMUsage, + ] + ], int, ] = {} for index, item in enumerate(iterator_list_value): @@ -264,7 +287,14 @@ class IterationNode(Node): index = future_to_index[future] try: result = future.result() - iter_start_at, events, output_value, tokens_used, conversation_snapshot = result + ( + iter_start_at, + events, + output_value, + tokens_used, + conversation_snapshot, + iteration_usage, + ) = result # Update outputs at the correct index outputs[index] = output_value @@ -276,6 +306,8 @@ class IterationNode(Node): self.graph_runtime_state.total_tokens += tokens_used iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds() + usage_accumulator[0] = self._merge_usage(usage_accumulator[0], iteration_usage) + # Sync conversation variables after iteration completion self._sync_conversation_variables_from_snapshot(conversation_snapshot) @@ -303,7 +335,7 @@ class IterationNode(Node): item: object, flask_app: Flask, context_vars: contextvars.Context, - ) -> tuple[datetime, list[GraphNodeEventBase], object | None, int, dict[str, VariableUnion]]: + ) -> tuple[datetime, list[GraphNodeEventBase], object | None, int, dict[str, VariableUnion], LLMUsage]: """Execute a single iteration in parallel mode and return results.""" with preserve_flask_contexts(flask_app=flask_app, context_vars=context_vars): iter_start_at = datetime.now(UTC).replace(tzinfo=None) @@ -332,6 +364,7 @@ class IterationNode(Node): output_value, graph_engine.graph_runtime_state.total_tokens, conversation_snapshot, + graph_engine.graph_runtime_state.llm_usage, ) def _handle_iteration_success( @@ -341,6 +374,8 @@ class IterationNode(Node): outputs: list[object], iterator_list_value: Sequence[object], iter_run_map: dict[str, float], + *, + usage: LLMUsage, ) -> Generator[NodeEventBase, None, None]: # Flatten the list of lists if all outputs are lists flattened_outputs = self._flatten_outputs_if_needed(outputs) @@ -351,7 +386,9 @@ class IterationNode(Node): outputs={"output": flattened_outputs}, steps=len(iterator_list_value), metadata={ - WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens, + WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens, + WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price, + WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency, WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map, }, ) @@ -362,8 +399,11 @@ class IterationNode(Node): status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs={"output": flattened_outputs}, metadata={ - WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens, + WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens, + WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price, + WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency, }, + llm_usage=usage, ) ) @@ -400,6 +440,8 @@ class IterationNode(Node): outputs: list[object], iterator_list_value: Sequence[object], iter_run_map: dict[str, float], + *, + usage: LLMUsage, error: IterationNodeError, ) -> Generator[NodeEventBase, None, None]: # Flatten the list of lists if all outputs are lists (even in failure case) @@ -411,7 +453,9 @@ class IterationNode(Node): outputs={"output": flattened_outputs}, steps=len(iterator_list_value), metadata={ - WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens, + WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens, + WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price, + WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency, WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map, }, error=str(error), @@ -420,6 +464,12 @@ class IterationNode(Node): node_run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, error=str(error), + metadata={ + WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens, + WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price, + WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency, + }, + llm_usage=usage, ) ) diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index ba5134f9e6..4a63900527 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -15,14 +15,11 @@ from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEnti from core.entities.agent_entities import PlanningStrategy from core.entities.model_entities import ModelStatus from core.model_manager import ModelInstance, ModelManager -from core.model_runtime.entities.message_entities import ( - PromptMessageRole, -) -from core.model_runtime.entities.model_entities import ( - ModelFeature, - ModelType, -) +from core.model_runtime.entities.llm_entities import LLMUsage +from core.model_runtime.entities.message_entities import PromptMessageRole +from core.model_runtime.entities.model_entities import ModelFeature, ModelType from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from core.model_runtime.utils.encoders import jsonable_encoder from core.prompt.simple_prompt_transform import ModelMode from core.rag.datasource.retrieval_service import RetrievalService from core.rag.entities.metadata_entities import Condition, MetadataCondition @@ -33,8 +30,14 @@ from core.variables import ( ) from core.variables.segments import ArrayObjectSegment from core.workflow.entities import GraphInitParams -from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus +from core.workflow.enums import ( + ErrorStrategy, + NodeType, + WorkflowNodeExecutionMetadataKey, + WorkflowNodeExecutionStatus, +) from core.workflow.node_events import ModelInvokeCompletedEvent, NodeRunResult +from core.workflow.nodes.base import LLMUsageTrackingMixin from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig from core.workflow.nodes.base.node import Node from core.workflow.nodes.knowledge_retrieval.template_prompts import ( @@ -80,7 +83,7 @@ default_retrieval_model = { } -class KnowledgeRetrievalNode(Node): +class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node): node_type = NodeType.KNOWLEDGE_RETRIEVAL _node_data: KnowledgeRetrievalNodeData @@ -182,14 +185,21 @@ class KnowledgeRetrievalNode(Node): ) # retrieve knowledge + usage = LLMUsage.empty_usage() try: - results = self._fetch_dataset_retriever(node_data=self._node_data, query=query) + results, usage = self._fetch_dataset_retriever(node_data=self._node_data, query=query) outputs = {"result": ArrayObjectSegment(value=results)} return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, - process_data={}, + process_data={"usage": jsonable_encoder(usage)}, outputs=outputs, # type: ignore + metadata={ + WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens, + WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price, + WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency, + }, + llm_usage=usage, ) except KnowledgeRetrievalNodeError as e: @@ -199,6 +209,7 @@ class KnowledgeRetrievalNode(Node): inputs=variables, error=str(e), error_type=type(e).__name__, + llm_usage=usage, ) # Temporary handle all exceptions from DatasetRetrieval class here. except Exception as e: @@ -207,11 +218,15 @@ class KnowledgeRetrievalNode(Node): inputs=variables, error=str(e), error_type=type(e).__name__, + llm_usage=usage, ) finally: db.session.close() - def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query: str) -> list[dict[str, Any]]: + def _fetch_dataset_retriever( + self, node_data: KnowledgeRetrievalNodeData, query: str + ) -> tuple[list[dict[str, Any]], LLMUsage]: + usage = LLMUsage.empty_usage() available_datasets = [] dataset_ids = node_data.dataset_ids @@ -245,9 +260,10 @@ class KnowledgeRetrievalNode(Node): if not dataset: continue available_datasets.append(dataset) - metadata_filter_document_ids, metadata_condition = self._get_metadata_filter_condition( + metadata_filter_document_ids, metadata_condition, metadata_usage = self._get_metadata_filter_condition( [dataset.id for dataset in available_datasets], query, node_data ) + usage = self._merge_usage(usage, metadata_usage) all_documents = [] dataset_retrieval = DatasetRetrieval() if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE: @@ -330,6 +346,8 @@ class KnowledgeRetrievalNode(Node): metadata_filter_document_ids=metadata_filter_document_ids, metadata_condition=metadata_condition, ) + usage = self._merge_usage(usage, dataset_retrieval.llm_usage) + dify_documents = [item for item in all_documents if item.provider == "dify"] external_documents = [item for item in all_documents if item.provider == "external"] retrieval_resource_list = [] @@ -406,11 +424,12 @@ class KnowledgeRetrievalNode(Node): ) for position, item in enumerate(retrieval_resource_list, start=1): item["metadata"]["position"] = position - return retrieval_resource_list + return retrieval_resource_list, usage def _get_metadata_filter_condition( self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData - ) -> tuple[dict[str, list[str]] | None, MetadataCondition | None]: + ) -> tuple[dict[str, list[str]] | None, MetadataCondition | None, LLMUsage]: + usage = LLMUsage.empty_usage() document_query = db.session.query(Document).where( Document.dataset_id.in_(dataset_ids), Document.indexing_status == "completed", @@ -420,9 +439,12 @@ class KnowledgeRetrievalNode(Node): filters: list[Any] = [] metadata_condition = None if node_data.metadata_filtering_mode == "disabled": - return None, None + return None, None, usage elif node_data.metadata_filtering_mode == "automatic": - automatic_metadata_filters = self._automatic_metadata_filter_func(dataset_ids, query, node_data) + automatic_metadata_filters, automatic_usage = self._automatic_metadata_filter_func( + dataset_ids, query, node_data + ) + usage = self._merge_usage(usage, automatic_usage) if automatic_metadata_filters: conditions = [] for sequence, filter in enumerate(automatic_metadata_filters): @@ -496,11 +518,12 @@ class KnowledgeRetrievalNode(Node): metadata_filter_document_ids = defaultdict(list) if documents else None # type: ignore for document in documents: metadata_filter_document_ids[document.dataset_id].append(document.id) # type: ignore - return metadata_filter_document_ids, metadata_condition + return metadata_filter_document_ids, metadata_condition, usage def _automatic_metadata_filter_func( self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData - ) -> list[dict[str, Any]]: + ) -> tuple[list[dict[str, Any]], LLMUsage]: + usage = LLMUsage.empty_usage() # get all metadata field stmt = select(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids)) metadata_fields = db.session.scalars(stmt).all() @@ -548,6 +571,7 @@ class KnowledgeRetrievalNode(Node): for event in generator: if isinstance(event, ModelInvokeCompletedEvent): result_text = event.text + usage = self._merge_usage(usage, event.usage) break result_text_json = parse_and_check_json_markdown(result_text, []) @@ -564,8 +588,8 @@ class KnowledgeRetrievalNode(Node): } ) except Exception: - return [] - return automatic_metadata_filters + return [], usage + return automatic_metadata_filters, usage def _process_metadata_filter_func( self, sequence: int, condition: str, metadata_name: str, value: Any, filters: list[Any] diff --git a/api/core/workflow/nodes/loop/loop_node.py b/api/core/workflow/nodes/loop/loop_node.py index c35195e931..ca39e5aa23 100644 --- a/api/core/workflow/nodes/loop/loop_node.py +++ b/api/core/workflow/nodes/loop/loop_node.py @@ -5,6 +5,7 @@ from collections.abc import Callable, Generator, Mapping, Sequence from datetime import datetime from typing import TYPE_CHECKING, Any, Literal, cast +from core.model_runtime.entities.llm_entities import LLMUsage from core.variables import Segment, SegmentType from core.workflow.enums import ( ErrorStrategy, @@ -27,6 +28,7 @@ from core.workflow.node_events import ( NodeRunResult, StreamCompletedEvent, ) +from core.workflow.nodes.base import LLMUsageTrackingMixin from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig from core.workflow.nodes.base.node import Node from core.workflow.nodes.loop.entities import LoopNodeData, LoopVariableData @@ -40,7 +42,7 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -class LoopNode(Node): +class LoopNode(LLMUsageTrackingMixin, Node): """ Loop Node. """ @@ -117,6 +119,7 @@ class LoopNode(Node): loop_duration_map: dict[str, float] = {} single_loop_variable_map: dict[str, dict[str, Any]] = {} # single loop variable output + loop_usage = LLMUsage.empty_usage() # Start Loop event yield LoopStartedEvent( @@ -163,6 +166,9 @@ class LoopNode(Node): # Update the total tokens from this iteration cost_tokens += graph_engine.graph_runtime_state.total_tokens + # Accumulate usage from the sub-graph execution + loop_usage = self._merge_usage(loop_usage, graph_engine.graph_runtime_state.llm_usage) + # Collect loop variable values after iteration single_loop_variable = {} for key, selector in loop_variable_selectors.items(): @@ -189,6 +195,7 @@ class LoopNode(Node): ) self.graph_runtime_state.total_tokens += cost_tokens + self._accumulate_usage(loop_usage) # Loop completed successfully yield LoopSucceededEvent( start_at=start_at, @@ -196,7 +203,9 @@ class LoopNode(Node): outputs=self._node_data.outputs, steps=loop_count, metadata={ - WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: cost_tokens, + WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens, + WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: loop_usage.total_price, + WorkflowNodeExecutionMetadataKey.CURRENCY: loop_usage.currency, "completed_reason": "loop_break" if reach_break_condition else "loop_completed", WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map, WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map, @@ -207,22 +216,28 @@ class LoopNode(Node): node_run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, metadata={ - WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens, + WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens, + WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: loop_usage.total_price, + WorkflowNodeExecutionMetadataKey.CURRENCY: loop_usage.currency, WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map, WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map, }, outputs=self._node_data.outputs, inputs=inputs, + llm_usage=loop_usage, ) ) except Exception as e: + self._accumulate_usage(loop_usage) yield LoopFailedEvent( start_at=start_at, inputs=inputs, steps=loop_count, metadata={ - WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens, + WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens, + WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: loop_usage.total_price, + WorkflowNodeExecutionMetadataKey.CURRENCY: loop_usage.currency, "completed_reason": "error", WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map, WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map, @@ -235,10 +250,13 @@ class LoopNode(Node): status=WorkflowNodeExecutionStatus.FAILED, error=str(e), metadata={ - WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens, + WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens, + WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: loop_usage.total_price, + WorkflowNodeExecutionMetadataKey.CURRENCY: loop_usage.currency, WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map, WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map, }, + llm_usage=loop_usage, ) ) diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index 2e2c32ac93..69ab6f0718 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -6,10 +6,13 @@ from sqlalchemy.orm import Session from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler from core.file import File, FileTransferMethod +from core.model_runtime.entities.llm_entities import LLMUsage +from core.tools.__base.tool import Tool from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter from core.tools.errors import ToolInvokeError from core.tools.tool_engine import ToolEngine from core.tools.utils.message_transformer import ToolFileMessageTransformer +from core.tools.workflow_as_tool.tool import WorkflowTool from core.variables.segments import ArrayAnySegment, ArrayFileSegment from core.variables.variables import ArrayAnyVariable from core.workflow.enums import ( @@ -136,13 +139,14 @@ class ToolNode(Node): try: # convert tool messages - yield from self._transform_message( + _ = yield from self._transform_message( messages=message_stream, tool_info=tool_info, parameters_for_log=parameters_for_log, user_id=self.user_id, tenant_id=self.tenant_id, node_id=self._node_id, + tool_runtime=tool_runtime, ) except ToolInvokeError as e: yield StreamCompletedEvent( @@ -236,7 +240,8 @@ class ToolNode(Node): user_id: str, tenant_id: str, node_id: str, - ) -> Generator: + tool_runtime: Tool, + ) -> Generator[NodeEventBase, None, LLMUsage]: """ Convert ToolInvokeMessages into tuple[plain_text, files] """ @@ -424,17 +429,34 @@ class ToolNode(Node): is_final=True, ) + usage = self._extract_tool_usage(tool_runtime) + + metadata: dict[WorkflowNodeExecutionMetadataKey, Any] = { + WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info, + } + if usage.total_tokens > 0: + metadata[WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS] = usage.total_tokens + metadata[WorkflowNodeExecutionMetadataKey.TOTAL_PRICE] = usage.total_price + metadata[WorkflowNodeExecutionMetadataKey.CURRENCY] = usage.currency + yield StreamCompletedEvent( node_run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs={"text": text, "files": ArrayFileSegment(value=files), "json": json_output, **variables}, - metadata={ - WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info, - }, + metadata=metadata, inputs=parameters_for_log, + llm_usage=usage, ) ) + return usage + + @staticmethod + def _extract_tool_usage(tool_runtime: Tool) -> LLMUsage: + if isinstance(tool_runtime, WorkflowTool): + return tool_runtime.latest_usage + return LLMUsage.empty_usage() + @classmethod def _extract_variable_selector_to_variable_mapping( cls,