mirror of
https://github.com/langgenius/dify.git
synced 2026-05-03 15:57:06 +08:00
Fix: surface workflow container LLM usage (#27021)
This commit is contained in:
parent
2bcf96565a
commit
4a6398fc1f
@ -72,6 +72,19 @@ default_retrieval_model: dict[str, Any] = {
|
|||||||
class DatasetRetrieval:
|
class DatasetRetrieval:
|
||||||
def __init__(self, application_generate_entity=None):
|
def __init__(self, application_generate_entity=None):
|
||||||
self.application_generate_entity = application_generate_entity
|
self.application_generate_entity = application_generate_entity
|
||||||
|
self._llm_usage = LLMUsage.empty_usage()
|
||||||
|
|
||||||
|
@property
def llm_usage(self) -> LLMUsage:
    """Return a snapshot of the LLM usage recorded on this retriever.

    The value handed back is produced with ``model_copy()``, so it is a
    separate object from the internally stored ``_llm_usage`` accumulator.
    """
    snapshot = self._llm_usage.model_copy()
    return snapshot
|
||||||
|
|
||||||
|
def _record_usage(self, usage: LLMUsage | None) -> None:
    """Fold ``usage`` into the usage accumulated on this retriever.

    Args:
        usage: Usage reported by a single LLM invocation. ``None`` and
            usages whose ``total_tokens`` is zero or negative are ignored.

    The first meaningful usage replaces the empty accumulator; subsequent
    ones are combined with it through ``LLMUsage.plus``.
    """
    if usage is None or usage.total_tokens <= 0:
        return

    if self._llm_usage.total_tokens == 0:
        # Keep a private copy instead of aliasing the caller's object, so a
        # later change to ``usage`` by its owner cannot alter the recorded
        # total (mirrors how the workflow usage mixin stores its first usage).
        self._llm_usage = usage.model_copy()
    else:
        self._llm_usage = self._llm_usage.plus(usage)
|
||||||
|
|
||||||
def retrieve(
|
def retrieve(
|
||||||
self,
|
self,
|
||||||
@ -312,15 +325,18 @@ class DatasetRetrieval:
|
|||||||
)
|
)
|
||||||
tools.append(message_tool)
|
tools.append(message_tool)
|
||||||
dataset_id = None
|
dataset_id = None
|
||||||
|
router_usage = LLMUsage.empty_usage()
|
||||||
if planning_strategy == PlanningStrategy.REACT_ROUTER:
|
if planning_strategy == PlanningStrategy.REACT_ROUTER:
|
||||||
react_multi_dataset_router = ReactMultiDatasetRouter()
|
react_multi_dataset_router = ReactMultiDatasetRouter()
|
||||||
dataset_id = react_multi_dataset_router.invoke(
|
dataset_id, router_usage = react_multi_dataset_router.invoke(
|
||||||
query, tools, model_config, model_instance, user_id, tenant_id
|
query, tools, model_config, model_instance, user_id, tenant_id
|
||||||
)
|
)
|
||||||
|
|
||||||
elif planning_strategy == PlanningStrategy.ROUTER:
|
elif planning_strategy == PlanningStrategy.ROUTER:
|
||||||
function_call_router = FunctionCallMultiDatasetRouter()
|
function_call_router = FunctionCallMultiDatasetRouter()
|
||||||
dataset_id = function_call_router.invoke(query, tools, model_config, model_instance)
|
dataset_id, router_usage = function_call_router.invoke(query, tools, model_config, model_instance)
|
||||||
|
|
||||||
|
self._record_usage(router_usage)
|
||||||
|
|
||||||
if dataset_id:
|
if dataset_id:
|
||||||
# get retrieval model config
|
# get retrieval model config
|
||||||
@ -983,7 +999,8 @@ class DatasetRetrieval:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# handle invoke result
|
# handle invoke result
|
||||||
result_text, _ = self._handle_invoke_result(invoke_result=invoke_result)
|
result_text, usage = self._handle_invoke_result(invoke_result=invoke_result)
|
||||||
|
self._record_usage(usage)
|
||||||
|
|
||||||
result_text_json = parse_and_check_json_markdown(result_text, [])
|
result_text_json = parse_and_check_json_markdown(result_text, [])
|
||||||
automatic_metadata_filters = []
|
automatic_metadata_filters = []
|
||||||
|
|||||||
@ -2,7 +2,7 @@ from typing import Union
|
|||||||
|
|
||||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||||
from core.model_manager import ModelInstance
|
from core.model_manager import ModelInstance
|
||||||
from core.model_runtime.entities.llm_entities import LLMResult
|
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
|
||||||
from core.model_runtime.entities.message_entities import PromptMessageTool, SystemPromptMessage, UserPromptMessage
|
from core.model_runtime.entities.message_entities import PromptMessageTool, SystemPromptMessage, UserPromptMessage
|
||||||
|
|
||||||
|
|
||||||
@ -13,15 +13,15 @@ class FunctionCallMultiDatasetRouter:
|
|||||||
dataset_tools: list[PromptMessageTool],
|
dataset_tools: list[PromptMessageTool],
|
||||||
model_config: ModelConfigWithCredentialsEntity,
|
model_config: ModelConfigWithCredentialsEntity,
|
||||||
model_instance: ModelInstance,
|
model_instance: ModelInstance,
|
||||||
) -> Union[str, None]:
|
) -> tuple[Union[str, None], LLMUsage]:
|
||||||
"""Given input, decided what to do.
|
"""Given input, decided what to do.
|
||||||
Returns:
|
Returns:
|
||||||
Action specifying what tool to use.
|
Action specifying what tool to use.
|
||||||
"""
|
"""
|
||||||
if len(dataset_tools) == 0:
|
if len(dataset_tools) == 0:
|
||||||
return None
|
return None, LLMUsage.empty_usage()
|
||||||
elif len(dataset_tools) == 1:
|
elif len(dataset_tools) == 1:
|
||||||
return dataset_tools[0].name
|
return dataset_tools[0].name, LLMUsage.empty_usage()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
prompt_messages = [
|
prompt_messages = [
|
||||||
@ -34,9 +34,10 @@ class FunctionCallMultiDatasetRouter:
|
|||||||
stream=False,
|
stream=False,
|
||||||
model_parameters={"temperature": 0.2, "top_p": 0.3, "max_tokens": 1500},
|
model_parameters={"temperature": 0.2, "top_p": 0.3, "max_tokens": 1500},
|
||||||
)
|
)
|
||||||
|
usage = result.usage or LLMUsage.empty_usage()
|
||||||
if result.message.tool_calls:
|
if result.message.tool_calls:
|
||||||
# get retrieval model config
|
# get retrieval model config
|
||||||
return result.message.tool_calls[0].function.name
|
return result.message.tool_calls[0].function.name, usage
|
||||||
return None
|
return None, usage
|
||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None, LLMUsage.empty_usage()
|
||||||
|
|||||||
@ -58,15 +58,15 @@ class ReactMultiDatasetRouter:
|
|||||||
model_instance: ModelInstance,
|
model_instance: ModelInstance,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
tenant_id: str,
|
tenant_id: str,
|
||||||
) -> Union[str, None]:
|
) -> tuple[Union[str, None], LLMUsage]:
|
||||||
"""Given input, decided what to do.
|
"""Given input, decided what to do.
|
||||||
Returns:
|
Returns:
|
||||||
Action specifying what tool to use.
|
Action specifying what tool to use.
|
||||||
"""
|
"""
|
||||||
if len(dataset_tools) == 0:
|
if len(dataset_tools) == 0:
|
||||||
return None
|
return None, LLMUsage.empty_usage()
|
||||||
elif len(dataset_tools) == 1:
|
elif len(dataset_tools) == 1:
|
||||||
return dataset_tools[0].name
|
return dataset_tools[0].name, LLMUsage.empty_usage()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return self._react_invoke(
|
return self._react_invoke(
|
||||||
@ -78,7 +78,7 @@ class ReactMultiDatasetRouter:
|
|||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None, LLMUsage.empty_usage()
|
||||||
|
|
||||||
def _react_invoke(
|
def _react_invoke(
|
||||||
self,
|
self,
|
||||||
@ -91,7 +91,7 @@ class ReactMultiDatasetRouter:
|
|||||||
prefix: str = PREFIX,
|
prefix: str = PREFIX,
|
||||||
suffix: str = SUFFIX,
|
suffix: str = SUFFIX,
|
||||||
format_instructions: str = FORMAT_INSTRUCTIONS,
|
format_instructions: str = FORMAT_INSTRUCTIONS,
|
||||||
) -> Union[str, None]:
|
) -> tuple[Union[str, None], LLMUsage]:
|
||||||
prompt: Union[list[ChatModelMessage], CompletionModelPromptTemplate]
|
prompt: Union[list[ChatModelMessage], CompletionModelPromptTemplate]
|
||||||
if model_config.mode == "chat":
|
if model_config.mode == "chat":
|
||||||
prompt = self.create_chat_prompt(
|
prompt = self.create_chat_prompt(
|
||||||
@ -120,7 +120,7 @@ class ReactMultiDatasetRouter:
|
|||||||
memory=None,
|
memory=None,
|
||||||
model_config=model_config,
|
model_config=model_config,
|
||||||
)
|
)
|
||||||
result_text, _ = self._invoke_llm(
|
result_text, usage = self._invoke_llm(
|
||||||
completion_param=model_config.parameters,
|
completion_param=model_config.parameters,
|
||||||
model_instance=model_instance,
|
model_instance=model_instance,
|
||||||
prompt_messages=prompt_messages,
|
prompt_messages=prompt_messages,
|
||||||
@ -131,8 +131,8 @@ class ReactMultiDatasetRouter:
|
|||||||
output_parser = StructuredChatOutputParser()
|
output_parser = StructuredChatOutputParser()
|
||||||
react_decision = output_parser.parse(result_text)
|
react_decision = output_parser.parse(result_text)
|
||||||
if isinstance(react_decision, ReactAction):
|
if isinstance(react_decision, ReactAction):
|
||||||
return react_decision.tool
|
return react_decision.tool, usage
|
||||||
return None
|
return None, usage
|
||||||
|
|
||||||
def _invoke_llm(
|
def _invoke_llm(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@ -1,13 +1,14 @@
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from collections.abc import Generator
|
from collections.abc import Generator, Mapping, Sequence
|
||||||
from typing import Any
|
from typing import Any, cast
|
||||||
|
|
||||||
from flask import has_request_context
|
from flask import has_request_context
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
|
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
|
||||||
|
from core.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
|
||||||
from core.tools.__base.tool import Tool
|
from core.tools.__base.tool import Tool
|
||||||
from core.tools.__base.tool_runtime import ToolRuntime
|
from core.tools.__base.tool_runtime import ToolRuntime
|
||||||
from core.tools.entities.tool_entities import (
|
from core.tools.entities.tool_entities import (
|
||||||
@ -49,6 +50,7 @@ class WorkflowTool(Tool):
|
|||||||
self.workflow_entities = workflow_entities
|
self.workflow_entities = workflow_entities
|
||||||
self.workflow_call_depth = workflow_call_depth
|
self.workflow_call_depth = workflow_call_depth
|
||||||
self.label = label
|
self.label = label
|
||||||
|
self._latest_usage = LLMUsage.empty_usage()
|
||||||
|
|
||||||
super().__init__(entity=entity, runtime=runtime)
|
super().__init__(entity=entity, runtime=runtime)
|
||||||
|
|
||||||
@ -84,10 +86,11 @@ class WorkflowTool(Tool):
|
|||||||
assert self.runtime.invoke_from is not None
|
assert self.runtime.invoke_from is not None
|
||||||
|
|
||||||
user = self._resolve_user(user_id=user_id)
|
user = self._resolve_user(user_id=user_id)
|
||||||
|
|
||||||
if user is None:
|
if user is None:
|
||||||
raise ToolInvokeError("User not found")
|
raise ToolInvokeError("User not found")
|
||||||
|
|
||||||
|
self._latest_usage = LLMUsage.empty_usage()
|
||||||
|
|
||||||
result = generator.generate(
|
result = generator.generate(
|
||||||
app_model=app,
|
app_model=app,
|
||||||
workflow=workflow,
|
workflow=workflow,
|
||||||
@ -111,9 +114,68 @@ class WorkflowTool(Tool):
|
|||||||
for file in files:
|
for file in files:
|
||||||
yield self.create_file_message(file) # type: ignore
|
yield self.create_file_message(file) # type: ignore
|
||||||
|
|
||||||
|
self._latest_usage = self._derive_usage_from_result(data)
|
||||||
|
|
||||||
yield self.create_text_message(json.dumps(outputs, ensure_ascii=False))
|
yield self.create_text_message(json.dumps(outputs, ensure_ascii=False))
|
||||||
yield self.create_json_message(outputs)
|
yield self.create_json_message(outputs)
|
||||||
|
|
||||||
|
@property
def latest_usage(self) -> LLMUsage:
    """Return the usage currently held in ``_latest_usage``.

    The stored object itself is returned (no copy is made). The attribute
    starts out as ``LLMUsage.empty_usage()`` and is reassigned during tool
    invocation from the workflow result.
    """
    recorded_usage = self._latest_usage
    return recorded_usage
|
||||||
|
|
||||||
|
@classmethod
def _derive_usage_from_result(cls, data: Mapping[str, Any]) -> LLMUsage:
    """Build an ``LLMUsage`` from a workflow result payload.

    Resolution order:
      1. A usage mapping located by ``_extract_usage_dict`` is converted
         with ``LLMUsage.from_metadata``.
      2. Otherwise the top-level ``total_tokens`` / ``total_price`` /
         ``currency`` keys of ``data`` are gathered into a metadata dict.
      3. If nothing usable is found, ``LLMUsage.empty_usage()`` is returned.

    Args:
        data: Result payload of the workflow run.

    Returns:
        The derived usage, or an empty usage when no figures are available.
    """
    embedded_usage = cls._extract_usage_dict(data)
    if embedded_usage is not None:
        return LLMUsage.from_metadata(cast(LLMUsageMetadata, dict(embedded_usage)))

    raw_tokens = data.get("total_tokens")
    raw_price = data.get("total_price")
    if raw_tokens is None and raw_price is None:
        return LLMUsage.empty_usage()

    collected: dict[str, Any] = {}
    if raw_tokens is not None:
        try:
            collected["total_tokens"] = int(str(raw_tokens))
        except (TypeError, ValueError):
            # Best effort: a token count that cannot be parsed is left out.
            pass
    if raw_price is not None:
        collected["total_price"] = str(raw_price)
    raw_currency = data.get("currency")
    if raw_currency is not None:
        collected["currency"] = raw_currency

    if collected:
        return LLMUsage.from_metadata(cast(LLMUsageMetadata, collected))
    return LLMUsage.empty_usage()
|
||||||
|
|
||||||
|
@classmethod
def _extract_usage_dict(cls, payload: Mapping[str, Any]) -> Mapping[str, Any] | None:
    """Locate the first usage mapping inside ``payload``.

    Lookup order:
      1. ``payload["usage"]`` when it is a mapping.
      2. ``payload["metadata"]["usage"]`` when both are mappings.
      3. A depth-first search, in iteration order, through every mapping
         value and through mapping items of every non-string sequence value.

    Args:
        payload: Mapping to search.

    Returns:
        The first usage mapping found, or ``None`` when there is none.
    """
    direct = payload.get("usage")
    if isinstance(direct, Mapping):
        return direct

    metadata = payload.get("metadata")
    if isinstance(metadata, Mapping):
        nested = metadata.get("usage")
        if isinstance(nested, Mapping):
            return nested

    for child in payload.values():
        if isinstance(child, Mapping):
            candidates = iter((child,))
        elif isinstance(child, Sequence) and not isinstance(child, (str, bytes, bytearray)):
            # Only mapping items of a sequence are descended into.
            candidates = (entry for entry in child if isinstance(entry, Mapping))
        else:
            continue

        for candidate in candidates:
            hit = cls._extract_usage_dict(candidate)
            if hit is not None:
                return hit

    return None
|
||||||
|
|
||||||
def fork_tool_runtime(self, runtime: ToolRuntime) -> "WorkflowTool":
|
def fork_tool_runtime(self, runtime: ToolRuntime) -> "WorkflowTool":
|
||||||
"""
|
"""
|
||||||
fork a new tool with metadata
|
fork a new tool with metadata
|
||||||
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
from .entities import BaseIterationNodeData, BaseIterationState, BaseLoopNodeData, BaseLoopState, BaseNodeData
|
from .entities import BaseIterationNodeData, BaseIterationState, BaseLoopNodeData, BaseLoopState, BaseNodeData
|
||||||
|
from .usage_tracking_mixin import LLMUsageTrackingMixin
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"BaseIterationNodeData",
|
"BaseIterationNodeData",
|
||||||
@ -6,4 +7,5 @@ __all__ = [
|
|||||||
"BaseLoopNodeData",
|
"BaseLoopNodeData",
|
||||||
"BaseLoopState",
|
"BaseLoopState",
|
||||||
"BaseNodeData",
|
"BaseNodeData",
|
||||||
|
"LLMUsageTrackingMixin",
|
||||||
]
|
]
|
||||||
|
|||||||
28
api/core/workflow/nodes/base/usage_tracking_mixin.py
Normal file
28
api/core/workflow/nodes/base/usage_tracking_mixin.py
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||||
|
from core.workflow.runtime import GraphRuntimeState
|
||||||
|
|
||||||
|
|
||||||
|
class LLMUsageTrackingMixin:
    """Helpers for workflow nodes that merge LLM usage and report it to the graph runtime state."""

    # Supplied by the class this mixin is combined with; read and updated by
    # ``_accumulate_usage``.
    graph_runtime_state: GraphRuntimeState

    @staticmethod
    def _merge_usage(current: LLMUsage, new_usage: LLMUsage | None) -> LLMUsage:
        """Combine ``new_usage`` into ``current`` and return the result.

        ``current`` is returned unchanged when ``new_usage`` is ``None`` or
        has no positive token count. When ``current`` itself has zero total
        tokens, ``new_usage`` is returned as-is; otherwise the two are added
        with ``LLMUsage.plus``.
        """
        if new_usage is None:
            return current
        if new_usage.total_tokens <= 0:
            return current
        return new_usage if current.total_tokens == 0 else current.plus(new_usage)

    def _accumulate_usage(self, usage: LLMUsage) -> None:
        """Add ``usage`` to ``graph_runtime_state.llm_usage``.

        Usages with a zero or negative token count are ignored. If the
        runtime state holds no tokens yet, a copy of ``usage`` is stored;
        otherwise the existing value is replaced by its sum with ``usage``.
        """
        if usage.total_tokens <= 0:
            return

        existing = self.graph_runtime_state.llm_usage
        self.graph_runtime_state.llm_usage = (
            usage.model_copy() if existing.total_tokens == 0 else existing.plus(usage)
        )
|
||||||
@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Any, NewType, cast
|
|||||||
from flask import Flask, current_app
|
from flask import Flask, current_app
|
||||||
from typing_extensions import TypeIs
|
from typing_extensions import TypeIs
|
||||||
|
|
||||||
|
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||||
from core.variables import IntegerVariable, NoneSegment
|
from core.variables import IntegerVariable, NoneSegment
|
||||||
from core.variables.segments import ArrayAnySegment, ArraySegment
|
from core.variables.segments import ArrayAnySegment, ArraySegment
|
||||||
from core.variables.variables import VariableUnion
|
from core.variables.variables import VariableUnion
|
||||||
@ -34,6 +35,7 @@ from core.workflow.node_events import (
|
|||||||
NodeRunResult,
|
NodeRunResult,
|
||||||
StreamCompletedEvent,
|
StreamCompletedEvent,
|
||||||
)
|
)
|
||||||
|
from core.workflow.nodes.base import LLMUsageTrackingMixin
|
||||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||||
from core.workflow.nodes.base.node import Node
|
from core.workflow.nodes.base.node import Node
|
||||||
from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
|
from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
|
||||||
@ -58,7 +60,7 @@ logger = logging.getLogger(__name__)
|
|||||||
EmptyArraySegment = NewType("EmptyArraySegment", ArraySegment)
|
EmptyArraySegment = NewType("EmptyArraySegment", ArraySegment)
|
||||||
|
|
||||||
|
|
||||||
class IterationNode(Node):
|
class IterationNode(LLMUsageTrackingMixin, Node):
|
||||||
"""
|
"""
|
||||||
Iteration Node.
|
Iteration Node.
|
||||||
"""
|
"""
|
||||||
@ -118,6 +120,7 @@ class IterationNode(Node):
|
|||||||
started_at = naive_utc_now()
|
started_at = naive_utc_now()
|
||||||
iter_run_map: dict[str, float] = {}
|
iter_run_map: dict[str, float] = {}
|
||||||
outputs: list[object] = []
|
outputs: list[object] = []
|
||||||
|
usage_accumulator = [LLMUsage.empty_usage()]
|
||||||
|
|
||||||
yield IterationStartedEvent(
|
yield IterationStartedEvent(
|
||||||
start_at=started_at,
|
start_at=started_at,
|
||||||
@ -130,22 +133,27 @@ class IterationNode(Node):
|
|||||||
iterator_list_value=iterator_list_value,
|
iterator_list_value=iterator_list_value,
|
||||||
outputs=outputs,
|
outputs=outputs,
|
||||||
iter_run_map=iter_run_map,
|
iter_run_map=iter_run_map,
|
||||||
|
usage_accumulator=usage_accumulator,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self._accumulate_usage(usage_accumulator[0])
|
||||||
yield from self._handle_iteration_success(
|
yield from self._handle_iteration_success(
|
||||||
started_at=started_at,
|
started_at=started_at,
|
||||||
inputs=inputs,
|
inputs=inputs,
|
||||||
outputs=outputs,
|
outputs=outputs,
|
||||||
iterator_list_value=iterator_list_value,
|
iterator_list_value=iterator_list_value,
|
||||||
iter_run_map=iter_run_map,
|
iter_run_map=iter_run_map,
|
||||||
|
usage=usage_accumulator[0],
|
||||||
)
|
)
|
||||||
except IterationNodeError as e:
|
except IterationNodeError as e:
|
||||||
|
self._accumulate_usage(usage_accumulator[0])
|
||||||
yield from self._handle_iteration_failure(
|
yield from self._handle_iteration_failure(
|
||||||
started_at=started_at,
|
started_at=started_at,
|
||||||
inputs=inputs,
|
inputs=inputs,
|
||||||
outputs=outputs,
|
outputs=outputs,
|
||||||
iterator_list_value=iterator_list_value,
|
iterator_list_value=iterator_list_value,
|
||||||
iter_run_map=iter_run_map,
|
iter_run_map=iter_run_map,
|
||||||
|
usage=usage_accumulator[0],
|
||||||
error=e,
|
error=e,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -196,6 +204,7 @@ class IterationNode(Node):
|
|||||||
iterator_list_value: Sequence[object],
|
iterator_list_value: Sequence[object],
|
||||||
outputs: list[object],
|
outputs: list[object],
|
||||||
iter_run_map: dict[str, float],
|
iter_run_map: dict[str, float],
|
||||||
|
usage_accumulator: list[LLMUsage],
|
||||||
) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]:
|
) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]:
|
||||||
if self._node_data.is_parallel:
|
if self._node_data.is_parallel:
|
||||||
# Parallel mode execution
|
# Parallel mode execution
|
||||||
@ -203,6 +212,7 @@ class IterationNode(Node):
|
|||||||
iterator_list_value=iterator_list_value,
|
iterator_list_value=iterator_list_value,
|
||||||
outputs=outputs,
|
outputs=outputs,
|
||||||
iter_run_map=iter_run_map,
|
iter_run_map=iter_run_map,
|
||||||
|
usage_accumulator=usage_accumulator,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Sequential mode execution
|
# Sequential mode execution
|
||||||
@ -228,6 +238,9 @@ class IterationNode(Node):
|
|||||||
|
|
||||||
# Update the total tokens from this iteration
|
# Update the total tokens from this iteration
|
||||||
self.graph_runtime_state.total_tokens += graph_engine.graph_runtime_state.total_tokens
|
self.graph_runtime_state.total_tokens += graph_engine.graph_runtime_state.total_tokens
|
||||||
|
usage_accumulator[0] = self._merge_usage(
|
||||||
|
usage_accumulator[0], graph_engine.graph_runtime_state.llm_usage
|
||||||
|
)
|
||||||
iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds()
|
iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds()
|
||||||
|
|
||||||
def _execute_parallel_iterations(
|
def _execute_parallel_iterations(
|
||||||
@ -235,6 +248,7 @@ class IterationNode(Node):
|
|||||||
iterator_list_value: Sequence[object],
|
iterator_list_value: Sequence[object],
|
||||||
outputs: list[object],
|
outputs: list[object],
|
||||||
iter_run_map: dict[str, float],
|
iter_run_map: dict[str, float],
|
||||||
|
usage_accumulator: list[LLMUsage],
|
||||||
) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]:
|
) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]:
|
||||||
# Initialize outputs list with None values to maintain order
|
# Initialize outputs list with None values to maintain order
|
||||||
outputs.extend([None] * len(iterator_list_value))
|
outputs.extend([None] * len(iterator_list_value))
|
||||||
@ -245,7 +259,16 @@ class IterationNode(Node):
|
|||||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||||
# Submit all iteration tasks
|
# Submit all iteration tasks
|
||||||
future_to_index: dict[
|
future_to_index: dict[
|
||||||
Future[tuple[datetime, list[GraphNodeEventBase], object | None, int, dict[str, VariableUnion]]],
|
Future[
|
||||||
|
tuple[
|
||||||
|
datetime,
|
||||||
|
list[GraphNodeEventBase],
|
||||||
|
object | None,
|
||||||
|
int,
|
||||||
|
dict[str, VariableUnion],
|
||||||
|
LLMUsage,
|
||||||
|
]
|
||||||
|
],
|
||||||
int,
|
int,
|
||||||
] = {}
|
] = {}
|
||||||
for index, item in enumerate(iterator_list_value):
|
for index, item in enumerate(iterator_list_value):
|
||||||
@ -264,7 +287,14 @@ class IterationNode(Node):
|
|||||||
index = future_to_index[future]
|
index = future_to_index[future]
|
||||||
try:
|
try:
|
||||||
result = future.result()
|
result = future.result()
|
||||||
iter_start_at, events, output_value, tokens_used, conversation_snapshot = result
|
(
|
||||||
|
iter_start_at,
|
||||||
|
events,
|
||||||
|
output_value,
|
||||||
|
tokens_used,
|
||||||
|
conversation_snapshot,
|
||||||
|
iteration_usage,
|
||||||
|
) = result
|
||||||
|
|
||||||
# Update outputs at the correct index
|
# Update outputs at the correct index
|
||||||
outputs[index] = output_value
|
outputs[index] = output_value
|
||||||
@ -276,6 +306,8 @@ class IterationNode(Node):
|
|||||||
self.graph_runtime_state.total_tokens += tokens_used
|
self.graph_runtime_state.total_tokens += tokens_used
|
||||||
iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds()
|
iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds()
|
||||||
|
|
||||||
|
usage_accumulator[0] = self._merge_usage(usage_accumulator[0], iteration_usage)
|
||||||
|
|
||||||
# Sync conversation variables after iteration completion
|
# Sync conversation variables after iteration completion
|
||||||
self._sync_conversation_variables_from_snapshot(conversation_snapshot)
|
self._sync_conversation_variables_from_snapshot(conversation_snapshot)
|
||||||
|
|
||||||
@ -303,7 +335,7 @@ class IterationNode(Node):
|
|||||||
item: object,
|
item: object,
|
||||||
flask_app: Flask,
|
flask_app: Flask,
|
||||||
context_vars: contextvars.Context,
|
context_vars: contextvars.Context,
|
||||||
) -> tuple[datetime, list[GraphNodeEventBase], object | None, int, dict[str, VariableUnion]]:
|
) -> tuple[datetime, list[GraphNodeEventBase], object | None, int, dict[str, VariableUnion], LLMUsage]:
|
||||||
"""Execute a single iteration in parallel mode and return results."""
|
"""Execute a single iteration in parallel mode and return results."""
|
||||||
with preserve_flask_contexts(flask_app=flask_app, context_vars=context_vars):
|
with preserve_flask_contexts(flask_app=flask_app, context_vars=context_vars):
|
||||||
iter_start_at = datetime.now(UTC).replace(tzinfo=None)
|
iter_start_at = datetime.now(UTC).replace(tzinfo=None)
|
||||||
@ -332,6 +364,7 @@ class IterationNode(Node):
|
|||||||
output_value,
|
output_value,
|
||||||
graph_engine.graph_runtime_state.total_tokens,
|
graph_engine.graph_runtime_state.total_tokens,
|
||||||
conversation_snapshot,
|
conversation_snapshot,
|
||||||
|
graph_engine.graph_runtime_state.llm_usage,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _handle_iteration_success(
|
def _handle_iteration_success(
|
||||||
@ -341,6 +374,8 @@ class IterationNode(Node):
|
|||||||
outputs: list[object],
|
outputs: list[object],
|
||||||
iterator_list_value: Sequence[object],
|
iterator_list_value: Sequence[object],
|
||||||
iter_run_map: dict[str, float],
|
iter_run_map: dict[str, float],
|
||||||
|
*,
|
||||||
|
usage: LLMUsage,
|
||||||
) -> Generator[NodeEventBase, None, None]:
|
) -> Generator[NodeEventBase, None, None]:
|
||||||
# Flatten the list of lists if all outputs are lists
|
# Flatten the list of lists if all outputs are lists
|
||||||
flattened_outputs = self._flatten_outputs_if_needed(outputs)
|
flattened_outputs = self._flatten_outputs_if_needed(outputs)
|
||||||
@ -351,7 +386,9 @@ class IterationNode(Node):
|
|||||||
outputs={"output": flattened_outputs},
|
outputs={"output": flattened_outputs},
|
||||||
steps=len(iterator_list_value),
|
steps=len(iterator_list_value),
|
||||||
metadata={
|
metadata={
|
||||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
|
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
|
||||||
|
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
|
||||||
|
WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
|
||||||
WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map,
|
WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
@ -362,8 +399,11 @@ class IterationNode(Node):
|
|||||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||||
outputs={"output": flattened_outputs},
|
outputs={"output": flattened_outputs},
|
||||||
metadata={
|
metadata={
|
||||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
|
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
|
||||||
|
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
|
||||||
|
WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
|
||||||
},
|
},
|
||||||
|
llm_usage=usage,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -400,6 +440,8 @@ class IterationNode(Node):
|
|||||||
outputs: list[object],
|
outputs: list[object],
|
||||||
iterator_list_value: Sequence[object],
|
iterator_list_value: Sequence[object],
|
||||||
iter_run_map: dict[str, float],
|
iter_run_map: dict[str, float],
|
||||||
|
*,
|
||||||
|
usage: LLMUsage,
|
||||||
error: IterationNodeError,
|
error: IterationNodeError,
|
||||||
) -> Generator[NodeEventBase, None, None]:
|
) -> Generator[NodeEventBase, None, None]:
|
||||||
# Flatten the list of lists if all outputs are lists (even in failure case)
|
# Flatten the list of lists if all outputs are lists (even in failure case)
|
||||||
@ -411,7 +453,9 @@ class IterationNode(Node):
|
|||||||
outputs={"output": flattened_outputs},
|
outputs={"output": flattened_outputs},
|
||||||
steps=len(iterator_list_value),
|
steps=len(iterator_list_value),
|
||||||
metadata={
|
metadata={
|
||||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
|
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
|
||||||
|
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
|
||||||
|
WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
|
||||||
WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map,
|
WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map,
|
||||||
},
|
},
|
||||||
error=str(error),
|
error=str(error),
|
||||||
@ -420,6 +464,12 @@ class IterationNode(Node):
|
|||||||
node_run_result=NodeRunResult(
|
node_run_result=NodeRunResult(
|
||||||
status=WorkflowNodeExecutionStatus.FAILED,
|
status=WorkflowNodeExecutionStatus.FAILED,
|
||||||
error=str(error),
|
error=str(error),
|
||||||
|
metadata={
|
||||||
|
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
|
||||||
|
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
|
||||||
|
WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
|
||||||
|
},
|
||||||
|
llm_usage=usage,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -15,14 +15,11 @@ from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEnti
|
|||||||
from core.entities.agent_entities import PlanningStrategy
|
from core.entities.agent_entities import PlanningStrategy
|
||||||
from core.entities.model_entities import ModelStatus
|
from core.entities.model_entities import ModelStatus
|
||||||
from core.model_manager import ModelInstance, ModelManager
|
from core.model_manager import ModelInstance, ModelManager
|
||||||
from core.model_runtime.entities.message_entities import (
|
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||||
PromptMessageRole,
|
from core.model_runtime.entities.message_entities import PromptMessageRole
|
||||||
)
|
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
|
||||||
from core.model_runtime.entities.model_entities import (
|
|
||||||
ModelFeature,
|
|
||||||
ModelType,
|
|
||||||
)
|
|
||||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||||
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
from core.prompt.simple_prompt_transform import ModelMode
|
from core.prompt.simple_prompt_transform import ModelMode
|
||||||
from core.rag.datasource.retrieval_service import RetrievalService
|
from core.rag.datasource.retrieval_service import RetrievalService
|
||||||
from core.rag.entities.metadata_entities import Condition, MetadataCondition
|
from core.rag.entities.metadata_entities import Condition, MetadataCondition
|
||||||
@ -33,8 +30,14 @@ from core.variables import (
|
|||||||
)
|
)
|
||||||
from core.variables.segments import ArrayObjectSegment
|
from core.variables.segments import ArrayObjectSegment
|
||||||
from core.workflow.entities import GraphInitParams
|
from core.workflow.entities import GraphInitParams
|
||||||
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
|
from core.workflow.enums import (
|
||||||
|
ErrorStrategy,
|
||||||
|
NodeType,
|
||||||
|
WorkflowNodeExecutionMetadataKey,
|
||||||
|
WorkflowNodeExecutionStatus,
|
||||||
|
)
|
||||||
from core.workflow.node_events import ModelInvokeCompletedEvent, NodeRunResult
|
from core.workflow.node_events import ModelInvokeCompletedEvent, NodeRunResult
|
||||||
|
from core.workflow.nodes.base import LLMUsageTrackingMixin
|
||||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||||
from core.workflow.nodes.base.node import Node
|
from core.workflow.nodes.base.node import Node
|
||||||
from core.workflow.nodes.knowledge_retrieval.template_prompts import (
|
from core.workflow.nodes.knowledge_retrieval.template_prompts import (
|
||||||
@ -80,7 +83,7 @@ default_retrieval_model = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class KnowledgeRetrievalNode(Node):
|
class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node):
|
||||||
node_type = NodeType.KNOWLEDGE_RETRIEVAL
|
node_type = NodeType.KNOWLEDGE_RETRIEVAL
|
||||||
|
|
||||||
_node_data: KnowledgeRetrievalNodeData
|
_node_data: KnowledgeRetrievalNodeData
|
||||||
@ -182,14 +185,21 @@ class KnowledgeRetrievalNode(Node):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# retrieve knowledge
|
# retrieve knowledge
|
||||||
|
usage = LLMUsage.empty_usage()
|
||||||
try:
|
try:
|
||||||
results = self._fetch_dataset_retriever(node_data=self._node_data, query=query)
|
results, usage = self._fetch_dataset_retriever(node_data=self._node_data, query=query)
|
||||||
outputs = {"result": ArrayObjectSegment(value=results)}
|
outputs = {"result": ArrayObjectSegment(value=results)}
|
||||||
return NodeRunResult(
|
return NodeRunResult(
|
||||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||||
inputs=variables,
|
inputs=variables,
|
||||||
process_data={},
|
process_data={"usage": jsonable_encoder(usage)},
|
||||||
outputs=outputs, # type: ignore
|
outputs=outputs, # type: ignore
|
||||||
|
metadata={
|
||||||
|
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
|
||||||
|
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
|
||||||
|
WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
|
||||||
|
},
|
||||||
|
llm_usage=usage,
|
||||||
)
|
)
|
||||||
|
|
||||||
except KnowledgeRetrievalNodeError as e:
|
except KnowledgeRetrievalNodeError as e:
|
||||||
@ -199,6 +209,7 @@ class KnowledgeRetrievalNode(Node):
|
|||||||
inputs=variables,
|
inputs=variables,
|
||||||
error=str(e),
|
error=str(e),
|
||||||
error_type=type(e).__name__,
|
error_type=type(e).__name__,
|
||||||
|
llm_usage=usage,
|
||||||
)
|
)
|
||||||
# Temporary handle all exceptions from DatasetRetrieval class here.
|
# Temporary handle all exceptions from DatasetRetrieval class here.
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -207,11 +218,15 @@ class KnowledgeRetrievalNode(Node):
|
|||||||
inputs=variables,
|
inputs=variables,
|
||||||
error=str(e),
|
error=str(e),
|
||||||
error_type=type(e).__name__,
|
error_type=type(e).__name__,
|
||||||
|
llm_usage=usage,
|
||||||
)
|
)
|
||||||
finally:
|
finally:
|
||||||
db.session.close()
|
db.session.close()
|
||||||
|
|
||||||
def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query: str) -> list[dict[str, Any]]:
|
def _fetch_dataset_retriever(
|
||||||
|
self, node_data: KnowledgeRetrievalNodeData, query: str
|
||||||
|
) -> tuple[list[dict[str, Any]], LLMUsage]:
|
||||||
|
usage = LLMUsage.empty_usage()
|
||||||
available_datasets = []
|
available_datasets = []
|
||||||
dataset_ids = node_data.dataset_ids
|
dataset_ids = node_data.dataset_ids
|
||||||
|
|
||||||
@ -245,9 +260,10 @@ class KnowledgeRetrievalNode(Node):
|
|||||||
if not dataset:
|
if not dataset:
|
||||||
continue
|
continue
|
||||||
available_datasets.append(dataset)
|
available_datasets.append(dataset)
|
||||||
metadata_filter_document_ids, metadata_condition = self._get_metadata_filter_condition(
|
metadata_filter_document_ids, metadata_condition, metadata_usage = self._get_metadata_filter_condition(
|
||||||
[dataset.id for dataset in available_datasets], query, node_data
|
[dataset.id for dataset in available_datasets], query, node_data
|
||||||
)
|
)
|
||||||
|
usage = self._merge_usage(usage, metadata_usage)
|
||||||
all_documents = []
|
all_documents = []
|
||||||
dataset_retrieval = DatasetRetrieval()
|
dataset_retrieval = DatasetRetrieval()
|
||||||
if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
|
if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
|
||||||
@ -330,6 +346,8 @@ class KnowledgeRetrievalNode(Node):
|
|||||||
metadata_filter_document_ids=metadata_filter_document_ids,
|
metadata_filter_document_ids=metadata_filter_document_ids,
|
||||||
metadata_condition=metadata_condition,
|
metadata_condition=metadata_condition,
|
||||||
)
|
)
|
||||||
|
usage = self._merge_usage(usage, dataset_retrieval.llm_usage)
|
||||||
|
|
||||||
dify_documents = [item for item in all_documents if item.provider == "dify"]
|
dify_documents = [item for item in all_documents if item.provider == "dify"]
|
||||||
external_documents = [item for item in all_documents if item.provider == "external"]
|
external_documents = [item for item in all_documents if item.provider == "external"]
|
||||||
retrieval_resource_list = []
|
retrieval_resource_list = []
|
||||||
@ -406,11 +424,12 @@ class KnowledgeRetrievalNode(Node):
|
|||||||
)
|
)
|
||||||
for position, item in enumerate(retrieval_resource_list, start=1):
|
for position, item in enumerate(retrieval_resource_list, start=1):
|
||||||
item["metadata"]["position"] = position
|
item["metadata"]["position"] = position
|
||||||
return retrieval_resource_list
|
return retrieval_resource_list, usage
|
||||||
|
|
||||||
def _get_metadata_filter_condition(
|
def _get_metadata_filter_condition(
|
||||||
self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData
|
self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData
|
||||||
) -> tuple[dict[str, list[str]] | None, MetadataCondition | None]:
|
) -> tuple[dict[str, list[str]] | None, MetadataCondition | None, LLMUsage]:
|
||||||
|
usage = LLMUsage.empty_usage()
|
||||||
document_query = db.session.query(Document).where(
|
document_query = db.session.query(Document).where(
|
||||||
Document.dataset_id.in_(dataset_ids),
|
Document.dataset_id.in_(dataset_ids),
|
||||||
Document.indexing_status == "completed",
|
Document.indexing_status == "completed",
|
||||||
@ -420,9 +439,12 @@ class KnowledgeRetrievalNode(Node):
|
|||||||
filters: list[Any] = []
|
filters: list[Any] = []
|
||||||
metadata_condition = None
|
metadata_condition = None
|
||||||
if node_data.metadata_filtering_mode == "disabled":
|
if node_data.metadata_filtering_mode == "disabled":
|
||||||
return None, None
|
return None, None, usage
|
||||||
elif node_data.metadata_filtering_mode == "automatic":
|
elif node_data.metadata_filtering_mode == "automatic":
|
||||||
automatic_metadata_filters = self._automatic_metadata_filter_func(dataset_ids, query, node_data)
|
automatic_metadata_filters, automatic_usage = self._automatic_metadata_filter_func(
|
||||||
|
dataset_ids, query, node_data
|
||||||
|
)
|
||||||
|
usage = self._merge_usage(usage, automatic_usage)
|
||||||
if automatic_metadata_filters:
|
if automatic_metadata_filters:
|
||||||
conditions = []
|
conditions = []
|
||||||
for sequence, filter in enumerate(automatic_metadata_filters):
|
for sequence, filter in enumerate(automatic_metadata_filters):
|
||||||
@ -496,11 +518,12 @@ class KnowledgeRetrievalNode(Node):
|
|||||||
metadata_filter_document_ids = defaultdict(list) if documents else None # type: ignore
|
metadata_filter_document_ids = defaultdict(list) if documents else None # type: ignore
|
||||||
for document in documents:
|
for document in documents:
|
||||||
metadata_filter_document_ids[document.dataset_id].append(document.id) # type: ignore
|
metadata_filter_document_ids[document.dataset_id].append(document.id) # type: ignore
|
||||||
return metadata_filter_document_ids, metadata_condition
|
return metadata_filter_document_ids, metadata_condition, usage
|
||||||
|
|
||||||
def _automatic_metadata_filter_func(
|
def _automatic_metadata_filter_func(
|
||||||
self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData
|
self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData
|
||||||
) -> list[dict[str, Any]]:
|
) -> tuple[list[dict[str, Any]], LLMUsage]:
|
||||||
|
usage = LLMUsage.empty_usage()
|
||||||
# get all metadata field
|
# get all metadata field
|
||||||
stmt = select(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids))
|
stmt = select(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids))
|
||||||
metadata_fields = db.session.scalars(stmt).all()
|
metadata_fields = db.session.scalars(stmt).all()
|
||||||
@ -548,6 +571,7 @@ class KnowledgeRetrievalNode(Node):
|
|||||||
for event in generator:
|
for event in generator:
|
||||||
if isinstance(event, ModelInvokeCompletedEvent):
|
if isinstance(event, ModelInvokeCompletedEvent):
|
||||||
result_text = event.text
|
result_text = event.text
|
||||||
|
usage = self._merge_usage(usage, event.usage)
|
||||||
break
|
break
|
||||||
|
|
||||||
result_text_json = parse_and_check_json_markdown(result_text, [])
|
result_text_json = parse_and_check_json_markdown(result_text, [])
|
||||||
@ -564,8 +588,8 @@ class KnowledgeRetrievalNode(Node):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
return []
|
return [], usage
|
||||||
return automatic_metadata_filters
|
return automatic_metadata_filters, usage
|
||||||
|
|
||||||
def _process_metadata_filter_func(
|
def _process_metadata_filter_func(
|
||||||
self, sequence: int, condition: str, metadata_name: str, value: Any, filters: list[Any]
|
self, sequence: int, condition: str, metadata_name: str, value: Any, filters: list[Any]
|
||||||
|
|||||||
@ -5,6 +5,7 @@ from collections.abc import Callable, Generator, Mapping, Sequence
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import TYPE_CHECKING, Any, Literal, cast
|
from typing import TYPE_CHECKING, Any, Literal, cast
|
||||||
|
|
||||||
|
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||||
from core.variables import Segment, SegmentType
|
from core.variables import Segment, SegmentType
|
||||||
from core.workflow.enums import (
|
from core.workflow.enums import (
|
||||||
ErrorStrategy,
|
ErrorStrategy,
|
||||||
@ -27,6 +28,7 @@ from core.workflow.node_events import (
|
|||||||
NodeRunResult,
|
NodeRunResult,
|
||||||
StreamCompletedEvent,
|
StreamCompletedEvent,
|
||||||
)
|
)
|
||||||
|
from core.workflow.nodes.base import LLMUsageTrackingMixin
|
||||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||||
from core.workflow.nodes.base.node import Node
|
from core.workflow.nodes.base.node import Node
|
||||||
from core.workflow.nodes.loop.entities import LoopNodeData, LoopVariableData
|
from core.workflow.nodes.loop.entities import LoopNodeData, LoopVariableData
|
||||||
@ -40,7 +42,7 @@ if TYPE_CHECKING:
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class LoopNode(Node):
|
class LoopNode(LLMUsageTrackingMixin, Node):
|
||||||
"""
|
"""
|
||||||
Loop Node.
|
Loop Node.
|
||||||
"""
|
"""
|
||||||
@ -117,6 +119,7 @@ class LoopNode(Node):
|
|||||||
|
|
||||||
loop_duration_map: dict[str, float] = {}
|
loop_duration_map: dict[str, float] = {}
|
||||||
single_loop_variable_map: dict[str, dict[str, Any]] = {} # single loop variable output
|
single_loop_variable_map: dict[str, dict[str, Any]] = {} # single loop variable output
|
||||||
|
loop_usage = LLMUsage.empty_usage()
|
||||||
|
|
||||||
# Start Loop event
|
# Start Loop event
|
||||||
yield LoopStartedEvent(
|
yield LoopStartedEvent(
|
||||||
@ -163,6 +166,9 @@ class LoopNode(Node):
|
|||||||
# Update the total tokens from this iteration
|
# Update the total tokens from this iteration
|
||||||
cost_tokens += graph_engine.graph_runtime_state.total_tokens
|
cost_tokens += graph_engine.graph_runtime_state.total_tokens
|
||||||
|
|
||||||
|
# Accumulate usage from the sub-graph execution
|
||||||
|
loop_usage = self._merge_usage(loop_usage, graph_engine.graph_runtime_state.llm_usage)
|
||||||
|
|
||||||
# Collect loop variable values after iteration
|
# Collect loop variable values after iteration
|
||||||
single_loop_variable = {}
|
single_loop_variable = {}
|
||||||
for key, selector in loop_variable_selectors.items():
|
for key, selector in loop_variable_selectors.items():
|
||||||
@ -189,6 +195,7 @@ class LoopNode(Node):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.graph_runtime_state.total_tokens += cost_tokens
|
self.graph_runtime_state.total_tokens += cost_tokens
|
||||||
|
self._accumulate_usage(loop_usage)
|
||||||
# Loop completed successfully
|
# Loop completed successfully
|
||||||
yield LoopSucceededEvent(
|
yield LoopSucceededEvent(
|
||||||
start_at=start_at,
|
start_at=start_at,
|
||||||
@ -196,7 +203,9 @@ class LoopNode(Node):
|
|||||||
outputs=self._node_data.outputs,
|
outputs=self._node_data.outputs,
|
||||||
steps=loop_count,
|
steps=loop_count,
|
||||||
metadata={
|
metadata={
|
||||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: cost_tokens,
|
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens,
|
||||||
|
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: loop_usage.total_price,
|
||||||
|
WorkflowNodeExecutionMetadataKey.CURRENCY: loop_usage.currency,
|
||||||
"completed_reason": "loop_break" if reach_break_condition else "loop_completed",
|
"completed_reason": "loop_break" if reach_break_condition else "loop_completed",
|
||||||
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
|
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
|
||||||
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
|
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
|
||||||
@ -207,22 +216,28 @@ class LoopNode(Node):
|
|||||||
node_run_result=NodeRunResult(
|
node_run_result=NodeRunResult(
|
||||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||||
metadata={
|
metadata={
|
||||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
|
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens,
|
||||||
|
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: loop_usage.total_price,
|
||||||
|
WorkflowNodeExecutionMetadataKey.CURRENCY: loop_usage.currency,
|
||||||
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
|
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
|
||||||
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
|
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
|
||||||
},
|
},
|
||||||
outputs=self._node_data.outputs,
|
outputs=self._node_data.outputs,
|
||||||
inputs=inputs,
|
inputs=inputs,
|
||||||
|
llm_usage=loop_usage,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
self._accumulate_usage(loop_usage)
|
||||||
yield LoopFailedEvent(
|
yield LoopFailedEvent(
|
||||||
start_at=start_at,
|
start_at=start_at,
|
||||||
inputs=inputs,
|
inputs=inputs,
|
||||||
steps=loop_count,
|
steps=loop_count,
|
||||||
metadata={
|
metadata={
|
||||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
|
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens,
|
||||||
|
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: loop_usage.total_price,
|
||||||
|
WorkflowNodeExecutionMetadataKey.CURRENCY: loop_usage.currency,
|
||||||
"completed_reason": "error",
|
"completed_reason": "error",
|
||||||
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
|
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
|
||||||
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
|
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
|
||||||
@ -235,10 +250,13 @@ class LoopNode(Node):
|
|||||||
status=WorkflowNodeExecutionStatus.FAILED,
|
status=WorkflowNodeExecutionStatus.FAILED,
|
||||||
error=str(e),
|
error=str(e),
|
||||||
metadata={
|
metadata={
|
||||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
|
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens,
|
||||||
|
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: loop_usage.total_price,
|
||||||
|
WorkflowNodeExecutionMetadataKey.CURRENCY: loop_usage.currency,
|
||||||
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
|
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
|
||||||
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
|
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
|
||||||
},
|
},
|
||||||
|
llm_usage=loop_usage,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -6,10 +6,13 @@ from sqlalchemy.orm import Session
|
|||||||
|
|
||||||
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
|
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
|
||||||
from core.file import File, FileTransferMethod
|
from core.file import File, FileTransferMethod
|
||||||
|
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||||
|
from core.tools.__base.tool import Tool
|
||||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
|
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
|
||||||
from core.tools.errors import ToolInvokeError
|
from core.tools.errors import ToolInvokeError
|
||||||
from core.tools.tool_engine import ToolEngine
|
from core.tools.tool_engine import ToolEngine
|
||||||
from core.tools.utils.message_transformer import ToolFileMessageTransformer
|
from core.tools.utils.message_transformer import ToolFileMessageTransformer
|
||||||
|
from core.tools.workflow_as_tool.tool import WorkflowTool
|
||||||
from core.variables.segments import ArrayAnySegment, ArrayFileSegment
|
from core.variables.segments import ArrayAnySegment, ArrayFileSegment
|
||||||
from core.variables.variables import ArrayAnyVariable
|
from core.variables.variables import ArrayAnyVariable
|
||||||
from core.workflow.enums import (
|
from core.workflow.enums import (
|
||||||
@ -136,13 +139,14 @@ class ToolNode(Node):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# convert tool messages
|
# convert tool messages
|
||||||
yield from self._transform_message(
|
_ = yield from self._transform_message(
|
||||||
messages=message_stream,
|
messages=message_stream,
|
||||||
tool_info=tool_info,
|
tool_info=tool_info,
|
||||||
parameters_for_log=parameters_for_log,
|
parameters_for_log=parameters_for_log,
|
||||||
user_id=self.user_id,
|
user_id=self.user_id,
|
||||||
tenant_id=self.tenant_id,
|
tenant_id=self.tenant_id,
|
||||||
node_id=self._node_id,
|
node_id=self._node_id,
|
||||||
|
tool_runtime=tool_runtime,
|
||||||
)
|
)
|
||||||
except ToolInvokeError as e:
|
except ToolInvokeError as e:
|
||||||
yield StreamCompletedEvent(
|
yield StreamCompletedEvent(
|
||||||
@ -236,7 +240,8 @@ class ToolNode(Node):
|
|||||||
user_id: str,
|
user_id: str,
|
||||||
tenant_id: str,
|
tenant_id: str,
|
||||||
node_id: str,
|
node_id: str,
|
||||||
) -> Generator:
|
tool_runtime: Tool,
|
||||||
|
) -> Generator[NodeEventBase, None, LLMUsage]:
|
||||||
"""
|
"""
|
||||||
Convert ToolInvokeMessages into tuple[plain_text, files]
|
Convert ToolInvokeMessages into tuple[plain_text, files]
|
||||||
"""
|
"""
|
||||||
@ -424,17 +429,34 @@ class ToolNode(Node):
|
|||||||
is_final=True,
|
is_final=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
usage = self._extract_tool_usage(tool_runtime)
|
||||||
|
|
||||||
|
metadata: dict[WorkflowNodeExecutionMetadataKey, Any] = {
|
||||||
|
WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info,
|
||||||
|
}
|
||||||
|
if usage.total_tokens > 0:
|
||||||
|
metadata[WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS] = usage.total_tokens
|
||||||
|
metadata[WorkflowNodeExecutionMetadataKey.TOTAL_PRICE] = usage.total_price
|
||||||
|
metadata[WorkflowNodeExecutionMetadataKey.CURRENCY] = usage.currency
|
||||||
|
|
||||||
yield StreamCompletedEvent(
|
yield StreamCompletedEvent(
|
||||||
node_run_result=NodeRunResult(
|
node_run_result=NodeRunResult(
|
||||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||||
outputs={"text": text, "files": ArrayFileSegment(value=files), "json": json_output, **variables},
|
outputs={"text": text, "files": ArrayFileSegment(value=files), "json": json_output, **variables},
|
||||||
metadata={
|
metadata=metadata,
|
||||||
WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info,
|
|
||||||
},
|
|
||||||
inputs=parameters_for_log,
|
inputs=parameters_for_log,
|
||||||
|
llm_usage=usage,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
return usage
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _extract_tool_usage(tool_runtime: Tool) -> LLMUsage:
|
||||||
|
if isinstance(tool_runtime, WorkflowTool):
|
||||||
|
return tool_runtime.latest_usage
|
||||||
|
return LLMUsage.empty_usage()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _extract_variable_selector_to_variable_mapping(
|
def _extract_variable_selector_to_variable_mapping(
|
||||||
cls,
|
cls,
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user