diff --git a/api/core/ops/entities/trace_entity.py b/api/core/ops/entities/trace_entity.py index 2133b15663..fa2ec3f21f 100644 --- a/api/core/ops/entities/trace_entity.py +++ b/api/core/ops/entities/trace_entity.py @@ -48,6 +48,8 @@ class WorkflowTraceInfo(BaseTraceInfo): workflow_run_version: str error: str | None = None total_tokens: int + prompt_tokens: int | None = None + completion_tokens: int | None = None file_list: list[str] query: str metadata: dict[str, Any] diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py index 8aad70ed01..7a876fb54f 100644 --- a/api/core/ops/ops_trace_manager.py +++ b/api/core/ops/ops_trace_manager.py @@ -541,6 +541,35 @@ class TraceTask: return "anonymous" + @classmethod + def _calculate_workflow_token_split(cls, workflow_run_id: str, tenant_id: str) -> tuple[int, int]: + from core.workflow.enums import WorkflowNodeExecutionMetadataKey + from models.workflow import WorkflowNodeExecutionModel + + with Session(db.engine) as session: + node_executions = session.scalars( + select(WorkflowNodeExecutionModel).where( + WorkflowNodeExecutionModel.tenant_id == tenant_id, + WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id, + ) + ).all() + + total_prompt = 0 + total_completion = 0 + + for node_exec in node_executions: + metadata = node_exec.execution_metadata_dict + + prompt = metadata.get(WorkflowNodeExecutionMetadataKey.PROMPT_TOKENS) + if prompt is not None: + total_prompt += prompt + + completion = metadata.get(WorkflowNodeExecutionMetadataKey.COMPLETION_TOKENS) + if completion is not None: + total_completion += completion + + return (total_prompt, total_completion) + def __init__( self, trace_type: Any, @@ -627,6 +656,10 @@ class TraceTask: total_tokens = workflow_run.total_tokens + prompt_tokens, completion_tokens = self._calculate_workflow_token_split( + workflow_run_id=workflow_run_id, tenant_id=tenant_id + ) + file_list = workflow_run_inputs.get("sys.file") or [] query = workflow_run_inputs.get("query") or workflow_run_inputs.get("sys.query") or "" @@ -684,6 +717,8 @@ class TraceTask: workflow_run_version=workflow_run_version, error=error, total_tokens=total_tokens, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, file_list=file_list, query=query, metadata=metadata, diff --git a/api/core/workflow/enums.py b/api/core/workflow/enums.py index bb3b13e8c6..938a2f5e21 100644 --- a/api/core/workflow/enums.py +++ b/api/core/workflow/enums.py @@ -232,6 +232,8 @@ class WorkflowNodeExecutionMetadataKey(StrEnum): """ TOTAL_TOKENS = "total_tokens" + PROMPT_TOKENS = "prompt_tokens" + COMPLETION_TOKENS = "completion_tokens" TOTAL_PRICE = "total_price" CURRENCY = "currency" TOOL_INFO = "tool_info" diff --git a/api/core/workflow/graph_engine/layers/persistence.py b/api/core/workflow/graph_engine/layers/persistence.py index 46b6f12b38..a57ffbf12f 100644 --- a/api/core/workflow/graph_engine/layers/persistence.py +++ b/api/core/workflow/graph_engine/layers/persistence.py @@ -437,6 +437,8 @@ class WorkflowPersistenceLayer(GraphEngineLayer): "created_at": domain_execution.created_at, "finished_at": domain_execution.finished_at, "total_tokens": meta.get(WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS, 0), + "prompt_tokens": meta.get(WorkflowNodeExecutionMetadataKey.PROMPT_TOKENS), + "completion_tokens": meta.get(WorkflowNodeExecutionMetadataKey.COMPLETION_TOKENS), "total_price": meta.get(WorkflowNodeExecutionMetadataKey.TOTAL_PRICE, 0.0), "currency": meta.get(WorkflowNodeExecutionMetadataKey.CURRENCY), "tool_name": (meta.get(WorkflowNodeExecutionMetadataKey.TOOL_INFO) or {}).get("tool_name") diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index dfb55dcd80..1600c94e2a 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -322,6 +322,8 @@ class LLMNode(Node[LLMNodeData]): outputs=outputs, metadata={ WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens, + WorkflowNodeExecutionMetadataKey.PROMPT_TOKENS: usage.prompt_tokens, + WorkflowNodeExecutionMetadataKey.COMPLETION_TOKENS: usage.completion_tokens, WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price, WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency, }, diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index f80e2f8f87..735855d979 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -446,6 +446,8 @@ class ToolNode(Node[ToolNodeData]): } if isinstance(usage.total_tokens, int) and usage.total_tokens > 0: metadata[WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS] = usage.total_tokens + metadata[WorkflowNodeExecutionMetadataKey.PROMPT_TOKENS] = usage.prompt_tokens + metadata[WorkflowNodeExecutionMetadataKey.COMPLETION_TOKENS] = usage.completion_tokens metadata[WorkflowNodeExecutionMetadataKey.TOTAL_PRICE] = usage.total_price metadata[WorkflowNodeExecutionMetadataKey.CURRENCY] = usage.currency diff --git a/api/enterprise/telemetry/enterprise_trace.py b/api/enterprise/telemetry/enterprise_trace.py index faa2ff02b0..a500ec7424 100644 --- a/api/enterprise/telemetry/enterprise_trace.py +++ b/api/enterprise/telemetry/enterprise_trace.py @@ -145,6 +145,8 @@ class EnterpriseOtelTrace: "dify.workspace.name": info.metadata.get("workspace_name"), "gen_ai.user.id": info.metadata.get("user_id"), "gen_ai.usage.total_tokens": info.total_tokens, + "gen_ai.usage.input_tokens": info.prompt_tokens, + "gen_ai.usage.output_tokens": info.completion_tokens, "dify.workflow.version": info.workflow_run_version, } ) diff --git a/api/extensions/otel/semconv/dify.py b/api/extensions/otel/semconv/dify.py index 9ffc58fde9..301ddd11aa 100644 --- a/api/extensions/otel/semconv/dify.py +++ b/api/extensions/otel/semconv/dify.py @@ -24,3 +24,12 @@ class DifySpanAttributes: INVOKED_BY = "dify.invoked_by" """Invoked by, e.g. end_user, account, user.""" + + USAGE_INPUT_TOKENS = "gen_ai.usage.input_tokens" + """Number of input tokens (prompt tokens) used.""" + + USAGE_OUTPUT_TOKENS = "gen_ai.usage.output_tokens" + """Number of output tokens (completion tokens) generated.""" + + USAGE_TOTAL_TOKENS = "gen_ai.usage.total_tokens" + """Total number of tokens used."""