mirror of https://github.com/langgenius/dify.git
feat: Enhances OpenTelemetry node parsers (#30706)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent febc9b930d
commit eca26a9b9b
@@ -8,7 +8,7 @@ intercept and respond to GraphEngine events.
 from abc import ABC, abstractmethod
 
 from core.workflow.graph_engine.protocols.command_channel import CommandChannel
-from core.workflow.graph_events import GraphEngineEvent
+from core.workflow.graph_events import GraphEngineEvent, GraphNodeEventBase
 from core.workflow.nodes.base.node import Node
 from core.workflow.runtime import ReadOnlyGraphRuntimeState
 
@@ -98,7 +98,7 @@ class GraphEngineLayer(ABC):
         """
         pass
 
-    def on_node_run_start(self, node: Node) -> None:  # noqa: B027
+    def on_node_run_start(self, node: Node) -> None:
         """
        Called immediately before a node begins execution.
@@ -109,9 +109,11 @@ class GraphEngineLayer(ABC):
         Args:
             node: The node instance about to be executed
         """
-        pass
+        return
 
-    def on_node_run_end(self, node: Node, error: Exception | None) -> None:  # noqa: B027
+    def on_node_run_end(
+        self, node: Node, error: Exception | None, result_event: GraphNodeEventBase | None = None
+    ) -> None:
         """
         Called after a node finishes execution.
 
@@ -121,5 +123,6 @@ class GraphEngineLayer(ABC):
         Args:
             node: The node instance that just finished execution
             error: Exception instance if the node failed, otherwise None
+            result_event: The final result event from node execution (succeeded/failed/paused), if any
         """
-        pass
+        return
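Taken together, the base-layer hunks above mean any layer subclass can now receive the node's final result directly. A minimal sketch of a custom layer using the new hook (the class name and print are illustrative; the imports mirror the diff):

from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_events import GraphNodeEventBase
from core.workflow.nodes.base.node import Node


class LoggingLayer(GraphEngineLayer):
    """Hypothetical layer: inspect each node's outputs when a result event arrives."""

    def on_node_run_end(
        self, node: Node, error: Exception | None, result_event: GraphNodeEventBase | None = None
    ) -> None:
        # result_event is None when the node raised before emitting a terminal event
        if result_event is not None and result_event.node_run_result:
            print(node.id, result_event.node_run_result.outputs)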
@@ -1,61 +0,0 @@
-"""
-Node-level OpenTelemetry parser interfaces and defaults.
-"""
-
-import json
-from typing import Protocol
-
-from opentelemetry.trace import Span
-from opentelemetry.trace.status import Status, StatusCode
-
-from core.workflow.nodes.base.node import Node
-from core.workflow.nodes.tool.entities import ToolNodeData
-
-
-class NodeOTelParser(Protocol):
-    """Parser interface for node-specific OpenTelemetry enrichment."""
-
-    def parse(self, *, node: Node, span: "Span", error: Exception | None) -> None: ...
-
-
-class DefaultNodeOTelParser:
-    """Fallback parser used when no node-specific parser is registered."""
-
-    def parse(self, *, node: Node, span: "Span", error: Exception | None) -> None:
-        span.set_attribute("node.id", node.id)
-        if node.execution_id:
-            span.set_attribute("node.execution_id", node.execution_id)
-        if hasattr(node, "node_type") and node.node_type:
-            span.set_attribute("node.type", node.node_type.value)
-
-        if error:
-            span.record_exception(error)
-            span.set_status(Status(StatusCode.ERROR, str(error)))
-        else:
-            span.set_status(Status(StatusCode.OK))
-
-
-class ToolNodeOTelParser:
-    """Parser for tool nodes that captures tool-specific metadata."""
-
-    def __init__(self) -> None:
-        self._delegate = DefaultNodeOTelParser()
-
-    def parse(self, *, node: Node, span: "Span", error: Exception | None) -> None:
-        self._delegate.parse(node=node, span=span, error=error)
-
-        tool_data = getattr(node, "_node_data", None)
-        if not isinstance(tool_data, ToolNodeData):
-            return
-
-        span.set_attribute("tool.provider.id", tool_data.provider_id)
-        span.set_attribute("tool.provider.type", tool_data.provider_type.value)
-        span.set_attribute("tool.provider.name", tool_data.provider_name)
-        span.set_attribute("tool.name", tool_data.tool_name)
-        span.set_attribute("tool.label", tool_data.tool_label)
-        if tool_data.plugin_unique_identifier:
-            span.set_attribute("tool.plugin.id", tool_data.plugin_unique_identifier)
-        if tool_data.credential_id:
-            span.set_attribute("tool.credential.id", tool_data.credential_id)
-        if tool_data.tool_configurations:
-            span.set_attribute("tool.config", json.dumps(tool_data.tool_configurations, ensure_ascii=False))
@@ -18,12 +18,15 @@ from typing_extensions import override
 from configs import dify_config
 from core.workflow.enums import NodeType
 from core.workflow.graph_engine.layers.base import GraphEngineLayer
-from core.workflow.graph_engine.layers.node_parsers import (
+from core.workflow.graph_events import GraphNodeEventBase
+from core.workflow.nodes.base.node import Node
+from extensions.otel.parser import (
     DefaultNodeOTelParser,
+    LLMNodeOTelParser,
     NodeOTelParser,
+    RetrievalNodeOTelParser,
     ToolNodeOTelParser,
 )
-from core.workflow.nodes.base.node import Node
 from extensions.otel.runtime import is_instrument_flag_enabled
 
 logger = logging.getLogger(__name__)
@@ -72,6 +75,8 @@ class ObservabilityLayer(GraphEngineLayer):
         """Initialize parser registry for node types."""
         self._parsers = {
             NodeType.TOOL: ToolNodeOTelParser(),
+            NodeType.LLM: LLMNodeOTelParser(),
+            NodeType.KNOWLEDGE_RETRIEVAL: RetrievalNodeOTelParser(),
         }
 
     def _get_parser(self, node: Node) -> NodeOTelParser:
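The `_get_parser` body itself is outside this diff; presumably it resolves the registry by node type and falls back to the default parser, roughly:

    def _get_parser(self, node: Node) -> NodeOTelParser:
        # Hypothetical sketch -- the real lookup is not shown in this hunk.
        return self._parsers.get(getattr(node, "node_type", None), DefaultNodeOTelParser())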
@@ -119,7 +124,9 @@ class ObservabilityLayer(GraphEngineLayer):
             logger.warning("Failed to create OpenTelemetry span for node %s: %s", node.id, e)
 
     @override
-    def on_node_run_end(self, node: Node, error: Exception | None) -> None:
+    def on_node_run_end(
+        self, node: Node, error: Exception | None, result_event: GraphNodeEventBase | None = None
+    ) -> None:
         """
         Called when a node finishes execution.
 
@@ -139,7 +146,7 @@ class ObservabilityLayer(GraphEngineLayer):
         span = node_context.span
         parser = self._get_parser(node)
         try:
-            parser.parse(node=node, span=span, error=error)
+            parser.parse(node=node, span=span, error=error, result_event=result_event)
             span.end()
         finally:
             token = node_context.token
@@ -17,7 +17,7 @@ from typing_extensions import override
 from core.workflow.context import IExecutionContext
 from core.workflow.graph import Graph
 from core.workflow.graph_engine.layers.base import GraphEngineLayer
-from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent
+from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent, is_node_result_event
 from core.workflow.nodes.base.node import Node
 
 from .ready_queue import ReadyQueue
@@ -131,6 +131,7 @@ class Worker(threading.Thread):
             node.ensure_execution_id()
 
             error: Exception | None = None
+            result_event: GraphNodeEventBase | None = None
 
             # Execute the node with preserved context if execution context is provided
             if self._execution_context is not None:
@@ -140,22 +141,26 @@ class Worker(threading.Thread):
                     node_events = node.run()
                     for event in node_events:
                         self._event_queue.put(event)
+                        if is_node_result_event(event):
+                            result_event = event
                 except Exception as exc:
                     error = exc
                     raise
                 finally:
-                    self._invoke_node_run_end_hooks(node, error)
+                    self._invoke_node_run_end_hooks(node, error, result_event)
             else:
                 self._invoke_node_run_start_hooks(node)
                 try:
                     node_events = node.run()
                     for event in node_events:
                         self._event_queue.put(event)
+                        if is_node_result_event(event):
+                            result_event = event
                 except Exception as exc:
                     error = exc
                     raise
                 finally:
-                    self._invoke_node_run_end_hooks(node, error)
+                    self._invoke_node_run_end_hooks(node, error, result_event)
 
     def _invoke_node_run_start_hooks(self, node: Node) -> None:
         """Invoke on_node_run_start hooks for all layers."""
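Both loop bodies above share one pattern: every event is still forwarded to the queue, and the last terminal event is remembered for the end hooks. As an illustrative standalone helper (not in the diff; the callable parameter `put` is a stand-in for the queue's put method):

from collections.abc import Callable, Iterable

from core.workflow.graph_events import GraphNodeEventBase, is_node_result_event


def capture_result_event(
    events: Iterable[GraphNodeEventBase], put: Callable[[GraphNodeEventBase], None]
) -> GraphNodeEventBase | None:
    """Forward every event unchanged; keep the last terminal one for the hooks."""
    result_event = None
    for event in events:
        put(event)                       # events still reach consumers as before
        if is_node_result_event(event):  # succeeded / failed / pause requested
            result_event = event         # the last terminal event wins
    return result_event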
@@ -166,11 +171,13 @@ class Worker(threading.Thread):
                 # Silently ignore layer errors to prevent disrupting node execution
                 continue
 
-    def _invoke_node_run_end_hooks(self, node: Node, error: Exception | None) -> None:
+    def _invoke_node_run_end_hooks(
+        self, node: Node, error: Exception | None, result_event: GraphNodeEventBase | None = None
+    ) -> None:
         """Invoke on_node_run_end hooks for all layers."""
         for layer in self._layers:
             try:
-                layer.on_node_run_end(node, error)
+                layer.on_node_run_end(node, error, result_event)
             except Exception:
                 # Silently ignore layer errors to prevent disrupting node execution
                 continue
@@ -44,6 +44,7 @@ from .node import (
     NodeRunStartedEvent,
     NodeRunStreamChunkEvent,
     NodeRunSucceededEvent,
+    is_node_result_event,
 )
 
 __all__ = [
@@ -73,4 +74,5 @@ __all__ = [
     "NodeRunStartedEvent",
     "NodeRunStreamChunkEvent",
     "NodeRunSucceededEvent",
+    "is_node_result_event",
 ]
@@ -56,3 +56,26 @@ class NodeRunRetryEvent(NodeRunStartedEvent):
 
 class NodeRunPauseRequestedEvent(GraphNodeEventBase):
     reason: PauseReason = Field(..., description="pause reason")
+
+
+def is_node_result_event(event: GraphNodeEventBase) -> bool:
+    """
+    Check if an event is a final result event from node execution.
+
+    A result event indicates the completion of a node execution and contains
+    runtime information such as inputs, outputs, or error details.
+
+    Args:
+        event: The event to check
+
+    Returns:
+        True if the event is a node result event (succeeded/failed/paused), False otherwise
+    """
+    return isinstance(
+        event,
+        (
+            NodeRunSucceededEvent,
+            NodeRunFailedEvent,
+            NodeRunPauseRequestedEvent,
+        ),
+    )
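Mirroring the test fixture added later in this commit, a terminal event satisfies the predicate (the literal field values below are illustrative):

from datetime import datetime

from core.workflow.enums import NodeType
from core.workflow.graph_events import NodeRunSucceededEvent, is_node_result_event
from core.workflow.node_events.base import NodeRunResult

event = NodeRunSucceededEvent(
    id="exec-1",
    node_id="node-1",
    node_type=NodeType.LLM,
    start_at=datetime.now(),
    node_run_result=NodeRunResult(inputs={}, outputs={}, process_data={}, metadata={}),
)
assert is_node_result_event(event)  # terminal event -> True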
@@ -0,0 +1,20 @@
+"""
+OpenTelemetry node parsers for workflow nodes.
+
+This module provides parsers that extract node-specific metadata and set
+OpenTelemetry span attributes according to semantic conventions.
+"""
+
+from extensions.otel.parser.base import DefaultNodeOTelParser, NodeOTelParser, safe_json_dumps
+from extensions.otel.parser.llm import LLMNodeOTelParser
+from extensions.otel.parser.retrieval import RetrievalNodeOTelParser
+from extensions.otel.parser.tool import ToolNodeOTelParser
+
+__all__ = [
+    "DefaultNodeOTelParser",
+    "LLMNodeOTelParser",
+    "NodeOTelParser",
+    "RetrievalNodeOTelParser",
+    "ToolNodeOTelParser",
+    "safe_json_dumps",
+]
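Downstream code can then import everything from the package root rather than the submodules, e.g.:

from extensions.otel.parser import DefaultNodeOTelParser, LLMNodeOTelParser, safe_json_dumps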
@@ -0,0 +1,117 @@
+"""
+Base parser interface and utilities for OpenTelemetry node parsers.
+"""
+
+import json
+from typing import Any, Protocol
+
+from opentelemetry.trace import Span
+from opentelemetry.trace.status import Status, StatusCode
+from pydantic import BaseModel
+
+from core.file.models import File
+from core.variables import Segment
+from core.workflow.enums import NodeType
+from core.workflow.graph_events import GraphNodeEventBase
+from core.workflow.nodes.base.node import Node
+from extensions.otel.semconv.gen_ai import ChainAttributes, GenAIAttributes
+
+
+def safe_json_dumps(obj: Any, ensure_ascii: bool = False) -> str:
+    """
+    Safely serialize objects to JSON, handling non-serializable types.
+
+    Handles:
+    - Segment types (ArrayFileSegment, FileSegment, etc.) - converts to their value
+    - File objects - converts to dict using to_dict()
+    - BaseModel objects - converts using model_dump()
+    - Other types - falls back to str() representation
+
+    Args:
+        obj: Object to serialize
+        ensure_ascii: Whether to ensure ASCII encoding
+
+    Returns:
+        JSON string representation of the object
+    """
+
+    def _convert_value(value: Any) -> Any:
+        """Recursively convert non-serializable values."""
+        if value is None:
+            return None
+        if isinstance(value, (bool, int, float, str)):
+            return value
+        if isinstance(value, Segment):
+            # Convert Segment to its underlying value
+            return _convert_value(value.value)
+        if isinstance(value, File):
+            # Convert File to dict
+            return value.to_dict()
+        if isinstance(value, BaseModel):
+            # Convert Pydantic model to dict
+            return _convert_value(value.model_dump(mode="json"))
+        if isinstance(value, dict):
+            return {k: _convert_value(v) for k, v in value.items()}
+        if isinstance(value, (list, tuple)):
+            return [_convert_value(item) for item in value]
+        # Fallback to string representation for unknown types
+        return str(value)
+
+    try:
+        converted = _convert_value(obj)
+        return json.dumps(converted, ensure_ascii=ensure_ascii)
+    except (TypeError, ValueError) as e:
+        # If conversion still fails, return error message as string
+        return json.dumps(
+            {"error": f"Failed to serialize: {type(obj).__name__}", "message": str(e)}, ensure_ascii=ensure_ascii
+        )
+
+
+class NodeOTelParser(Protocol):
+    """Parser interface for node-specific OpenTelemetry enrichment."""
+
+    def parse(
+        self, *, node: Node, span: "Span", error: Exception | None, result_event: GraphNodeEventBase | None = None
+    ) -> None: ...
+
+
+class DefaultNodeOTelParser:
+    """Fallback parser used when no node-specific parser is registered."""
+
+    def parse(
+        self, *, node: Node, span: "Span", error: Exception | None, result_event: GraphNodeEventBase | None = None
+    ) -> None:
+        span.set_attribute("node.id", node.id)
+        if node.execution_id:
+            span.set_attribute("node.execution_id", node.execution_id)
+        if hasattr(node, "node_type") and node.node_type:
+            span.set_attribute("node.type", node.node_type.value)
+
+        span.set_attribute(GenAIAttributes.FRAMEWORK, "dify")
+
+        node_type = getattr(node, "node_type", None)
+        if isinstance(node_type, NodeType):
+            if node_type == NodeType.LLM:
+                span.set_attribute(GenAIAttributes.SPAN_KIND, "LLM")
+            elif node_type == NodeType.KNOWLEDGE_RETRIEVAL:
+                span.set_attribute(GenAIAttributes.SPAN_KIND, "RETRIEVER")
+            elif node_type == NodeType.TOOL:
+                span.set_attribute(GenAIAttributes.SPAN_KIND, "TOOL")
+            else:
+                span.set_attribute(GenAIAttributes.SPAN_KIND, "TASK")
+        else:
+            span.set_attribute(GenAIAttributes.SPAN_KIND, "TASK")
+
+        # Extract inputs and outputs from result_event
+        if result_event and result_event.node_run_result:
+            node_run_result = result_event.node_run_result
+            if node_run_result.inputs:
+                span.set_attribute(ChainAttributes.INPUT_VALUE, safe_json_dumps(node_run_result.inputs))
+            if node_run_result.outputs:
+                span.set_attribute(ChainAttributes.OUTPUT_VALUE, safe_json_dumps(node_run_result.outputs))
+
+        if error:
+            span.record_exception(error)
+            span.set_status(Status(StatusCode.ERROR, str(error)))
+        else:
+            span.set_status(Status(StatusCode.OK))
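A quick illustration of `safe_json_dumps` on plain data; Segment, File, and BaseModel inputs are unwrapped analogously, and anything unknown falls back to str():

from extensions.otel.parser.base import safe_json_dumps


class Opaque:  # stand-in for a type json.dumps would reject
    pass

print(safe_json_dumps({"n": 3, "items": (1, 2)}))
# {"n": 3, "items": [1, 2]}
print(safe_json_dumps({"obj": Opaque()}))
# {"obj": "<...Opaque object at 0x...>"}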
@@ -0,0 +1,155 @@
+"""
+Parser for LLM nodes that captures LLM-specific metadata.
+"""
+
+import logging
+from collections.abc import Mapping
+from typing import Any
+
+from opentelemetry.trace import Span
+
+from core.workflow.graph_events import GraphNodeEventBase
+from core.workflow.nodes.base.node import Node
+from extensions.otel.parser.base import DefaultNodeOTelParser, safe_json_dumps
+from extensions.otel.semconv.gen_ai import LLMAttributes
+
+logger = logging.getLogger(__name__)
+
+
+def _format_input_messages(process_data: Mapping[str, Any]) -> str:
+    """
+    Format input messages from process_data for LLM spans.
+
+    Args:
+        process_data: Process data containing prompts
+
+    Returns:
+        JSON string of formatted input messages
+    """
+    try:
+        if not isinstance(process_data, dict):
+            return safe_json_dumps([])
+
+        prompts = process_data.get("prompts", [])
+        if not prompts:
+            return safe_json_dumps([])
+
+        valid_roles = {"system", "user", "assistant", "tool"}
+        input_messages = []
+        for prompt in prompts:
+            if not isinstance(prompt, dict):
+                continue
+
+            role = prompt.get("role", "")
+            text = prompt.get("text", "")
+
+            if not role or role not in valid_roles:
+                continue
+
+            if text:
+                message = {"role": role, "parts": [{"type": "text", "content": text}]}
+                input_messages.append(message)
+
+        return safe_json_dumps(input_messages)
+    except Exception as e:
+        logger.warning("Failed to format input messages: %s", e, exc_info=True)
+        return safe_json_dumps([])
+
+
+def _format_output_messages(outputs: Mapping[str, Any]) -> str:
+    """
+    Format output messages from outputs for LLM spans.
+
+    Args:
+        outputs: Output data containing text and finish_reason
+
+    Returns:
+        JSON string of formatted output messages
+    """
+    try:
+        if not isinstance(outputs, dict):
+            return safe_json_dumps([])
+
+        text = outputs.get("text", "")
+        finish_reason = outputs.get("finish_reason", "")
+
+        if not text:
+            return safe_json_dumps([])
+
+        valid_finish_reasons = {"stop", "length", "content_filter", "tool_call", "error"}
+        if finish_reason not in valid_finish_reasons:
+            finish_reason = "stop"
+
+        output_message = {
+            "role": "assistant",
+            "parts": [{"type": "text", "content": text}],
+            "finish_reason": finish_reason,
+        }
+
+        return safe_json_dumps([output_message])
+    except Exception as e:
+        logger.warning("Failed to format output messages: %s", e, exc_info=True)
+        return safe_json_dumps([])
+
+
+class LLMNodeOTelParser:
+    """Parser for LLM nodes that captures LLM-specific metadata."""
+
+    def __init__(self) -> None:
+        self._delegate = DefaultNodeOTelParser()
+
+    def parse(
+        self, *, node: Node, span: "Span", error: Exception | None, result_event: GraphNodeEventBase | None = None
+    ) -> None:
+        self._delegate.parse(node=node, span=span, error=error, result_event=result_event)
+
+        if not result_event or not result_event.node_run_result:
+            return
+
+        node_run_result = result_event.node_run_result
+        process_data = node_run_result.process_data or {}
+        outputs = node_run_result.outputs or {}
+
+        # Extract usage data (from process_data or outputs)
+        usage_data = process_data.get("usage") or outputs.get("usage") or {}
+
+        # Model and provider information
+        model_name = process_data.get("model_name") or ""
+        model_provider = process_data.get("model_provider") or ""
+
+        if model_name:
+            span.set_attribute(LLMAttributes.REQUEST_MODEL, model_name)
+        if model_provider:
+            span.set_attribute(LLMAttributes.PROVIDER_NAME, model_provider)
+
+        # Token usage
+        if usage_data:
+            prompt_tokens = usage_data.get("prompt_tokens", 0)
+            completion_tokens = usage_data.get("completion_tokens", 0)
+            total_tokens = usage_data.get("total_tokens", 0)
+
+            span.set_attribute(LLMAttributes.USAGE_INPUT_TOKENS, prompt_tokens)
+            span.set_attribute(LLMAttributes.USAGE_OUTPUT_TOKENS, completion_tokens)
+            span.set_attribute(LLMAttributes.USAGE_TOTAL_TOKENS, total_tokens)
+
+        # Prompts and completion
+        prompts = process_data.get("prompts", [])
+        if prompts:
+            prompts_json = safe_json_dumps(prompts)
+            span.set_attribute(LLMAttributes.PROMPT, prompts_json)
+
+        text_output = str(outputs.get("text", ""))
+        if text_output:
+            span.set_attribute(LLMAttributes.COMPLETION, text_output)
+
+        # Finish reason
+        finish_reason = outputs.get("finish_reason") or ""
+        if finish_reason:
+            span.set_attribute(LLMAttributes.RESPONSE_FINISH_REASON, finish_reason)
+
+        # Structured input/output messages
+        gen_ai_input_message = _format_input_messages(process_data)
+        gen_ai_output_message = _format_output_messages(outputs)
+
+        span.set_attribute(LLMAttributes.INPUT_MESSAGE, gen_ai_input_message)
+        span.set_attribute(LLMAttributes.OUTPUT_MESSAGE, gen_ai_output_message)
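To make the message shaping concrete, here is what the two helpers return for a minimal prompt/output pair (values illustrative):

from extensions.otel.parser.llm import _format_input_messages, _format_output_messages

print(_format_input_messages({"prompts": [{"role": "user", "text": "Hi"}]}))
# [{"role": "user", "parts": [{"type": "text", "content": "Hi"}]}]
print(_format_output_messages({"text": "Hello!", "finish_reason": "stop"}))
# [{"role": "assistant", "parts": [{"type": "text", "content": "Hello!"}], "finish_reason": "stop"}]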
@@ -0,0 +1,105 @@
+"""
+Parser for knowledge retrieval nodes that captures retrieval-specific metadata.
+"""
+
+import logging
+from collections.abc import Sequence
+from typing import Any
+
+from opentelemetry.trace import Span
+
+from core.variables import Segment
+from core.workflow.graph_events import GraphNodeEventBase
+from core.workflow.nodes.base.node import Node
+from extensions.otel.parser.base import DefaultNodeOTelParser, safe_json_dumps
+from extensions.otel.semconv.gen_ai import RetrieverAttributes
+
+logger = logging.getLogger(__name__)
+
+
+def _format_retrieval_documents(retrieval_documents: list[Any]) -> list:
+    """
+    Format retrieval documents for semantic conventions.
+
+    Args:
+        retrieval_documents: List of retrieval document dictionaries
+
+    Returns:
+        List of formatted semantic documents
+    """
+    try:
+        if not isinstance(retrieval_documents, list):
+            return []
+
+        semantic_documents = []
+        for doc in retrieval_documents:
+            if not isinstance(doc, dict):
+                continue
+
+            metadata = doc.get("metadata", {})
+            content = doc.get("content", "")
+            title = doc.get("title", "")
+            score = metadata.get("score", 0.0)
+            document_id = metadata.get("document_id", "")
+
+            semantic_metadata = {}
+            if title:
+                semantic_metadata["title"] = title
+            if metadata.get("source"):
+                semantic_metadata["source"] = metadata["source"]
+            elif metadata.get("_source"):
+                semantic_metadata["source"] = metadata["_source"]
+            if metadata.get("doc_metadata"):
+                doc_metadata = metadata["doc_metadata"]
+                if isinstance(doc_metadata, dict):
+                    semantic_metadata.update(doc_metadata)
+
+            semantic_doc = {
+                "document": {"content": content, "metadata": semantic_metadata, "score": score, "id": document_id}
+            }
+            semantic_documents.append(semantic_doc)
+
+        return semantic_documents
+    except Exception as e:
+        logger.warning("Failed to format retrieval documents: %s", e, exc_info=True)
+        return []
+
+
+class RetrievalNodeOTelParser:
+    """Parser for knowledge retrieval nodes that captures retrieval-specific metadata."""
+
+    def __init__(self) -> None:
+        self._delegate = DefaultNodeOTelParser()
+
+    def parse(
+        self, *, node: Node, span: "Span", error: Exception | None, result_event: GraphNodeEventBase | None = None
+    ) -> None:
+        self._delegate.parse(node=node, span=span, error=error, result_event=result_event)
+
+        if not result_event or not result_event.node_run_result:
+            return
+
+        node_run_result = result_event.node_run_result
+        inputs = node_run_result.inputs or {}
+        outputs = node_run_result.outputs or {}
+
+        # Extract query from inputs
+        query = str(inputs.get("query", "")) if inputs else ""
+        if query:
+            span.set_attribute(RetrieverAttributes.QUERY, query)
+
+        # Extract and format retrieval documents from outputs
+        result_value = outputs.get("result") if outputs else None
+        retrieval_documents: list[Any] = []
+        if result_value:
+            value_to_check = result_value
+            if isinstance(result_value, Segment):
+                value_to_check = result_value.value
+
+            if isinstance(value_to_check, (list, Sequence)):
+                retrieval_documents = list(value_to_check)
+
+        if retrieval_documents:
+            semantic_retrieval_documents = _format_retrieval_documents(retrieval_documents)
+            semantic_retrieval_documents_json = safe_json_dumps(semantic_retrieval_documents)
+            span.set_attribute(RetrieverAttributes.DOCUMENT, semantic_retrieval_documents_json)
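For reference, a single retrieval document is reshaped like this (illustrative values):

from extensions.otel.parser.retrieval import _format_retrieval_documents

doc = {
    "content": "some passage",
    "title": "Doc A",
    "metadata": {"score": 0.92, "document_id": "doc-1", "source": "kb"},
}
print(_format_retrieval_documents([doc]))
# [{"document": {"content": "some passage", "metadata": {"title": "Doc A", "source": "kb"},
#                "score": 0.92, "id": "doc-1"}}]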
@@ -0,0 +1,47 @@
+"""
+Parser for tool nodes that captures tool-specific metadata.
+"""
+
+from opentelemetry.trace import Span
+
+from core.workflow.enums import WorkflowNodeExecutionMetadataKey
+from core.workflow.graph_events import GraphNodeEventBase
+from core.workflow.nodes.base.node import Node
+from core.workflow.nodes.tool.entities import ToolNodeData
+from extensions.otel.parser.base import DefaultNodeOTelParser, safe_json_dumps
+from extensions.otel.semconv.gen_ai import ToolAttributes
+
+
+class ToolNodeOTelParser:
+    """Parser for tool nodes that captures tool-specific metadata."""
+
+    def __init__(self) -> None:
+        self._delegate = DefaultNodeOTelParser()
+
+    def parse(
+        self, *, node: Node, span: "Span", error: Exception | None, result_event: GraphNodeEventBase | None = None
+    ) -> None:
+        self._delegate.parse(node=node, span=span, error=error, result_event=result_event)
+
+        tool_data = getattr(node, "_node_data", None)
+        if not isinstance(tool_data, ToolNodeData):
+            return
+
+        span.set_attribute(ToolAttributes.TOOL_NAME, node.title)
+        span.set_attribute(ToolAttributes.TOOL_TYPE, tool_data.provider_type.value)
+
+        # Extract tool info from metadata (consistent with aliyun_trace)
+        tool_info = {}
+        if result_event and result_event.node_run_result:
+            node_run_result = result_event.node_run_result
+            if node_run_result.metadata:
+                tool_info = node_run_result.metadata.get(WorkflowNodeExecutionMetadataKey.TOOL_INFO, {})
+
+        if tool_info:
+            span.set_attribute(ToolAttributes.TOOL_DESCRIPTION, safe_json_dumps(tool_info))
+
+        if result_event and result_event.node_run_result and result_event.node_run_result.inputs:
+            span.set_attribute(ToolAttributes.TOOL_CALL_ARGUMENTS, safe_json_dumps(result_event.node_run_result.inputs))
+
+        if result_event and result_event.node_run_result and result_event.node_run_result.outputs:
+            span.set_attribute(ToolAttributes.TOOL_CALL_RESULT, safe_json_dumps(result_event.node_run_result.outputs))
@@ -1,6 +1,13 @@
 """Semantic convention shortcuts for Dify-specific spans."""
 
 from .dify import DifySpanAttributes
-from .gen_ai import GenAIAttributes
+from .gen_ai import ChainAttributes, GenAIAttributes, LLMAttributes, RetrieverAttributes, ToolAttributes
 
-__all__ = ["DifySpanAttributes", "GenAIAttributes"]
+__all__ = [
+    "ChainAttributes",
+    "DifySpanAttributes",
+    "GenAIAttributes",
+    "LLMAttributes",
+    "RetrieverAttributes",
+    "ToolAttributes",
+]
@@ -62,3 +62,37 @@ class ToolAttributes:
 
     TOOL_CALL_RESULT = "gen_ai.tool.call.result"
     """Tool invocation result."""
+
+
+class LLMAttributes:
+    """LLM operation attribute keys."""
+
+    REQUEST_MODEL = "gen_ai.request.model"
+    """Model identifier."""
+
+    PROVIDER_NAME = "gen_ai.provider.name"
+    """Provider name."""
+
+    USAGE_INPUT_TOKENS = "gen_ai.usage.input_tokens"
+    """Number of input tokens."""
+
+    USAGE_OUTPUT_TOKENS = "gen_ai.usage.output_tokens"
+    """Number of output tokens."""
+
+    USAGE_TOTAL_TOKENS = "gen_ai.usage.total_tokens"
+    """Total number of tokens."""
+
+    PROMPT = "gen_ai.prompt"
+    """Prompt text."""
+
+    COMPLETION = "gen_ai.completion"
+    """Completion text."""
+
+    RESPONSE_FINISH_REASON = "gen_ai.response.finish_reason"
+    """Finish reason for the response."""
+
+    INPUT_MESSAGE = "gen_ai.input.messages"
+    """Input messages in structured format."""
+
+    OUTPUT_MESSAGE = "gen_ai.output.messages"
+    """Output messages in structured format."""
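These constants keep attribute strings out of the parsers; a sketch of direct use with the OpenTelemetry SDK (the span name and values are illustrative):

from opentelemetry import trace

from extensions.otel.semconv.gen_ai import LLMAttributes

tracer = trace.get_tracer(__name__)
with tracer.start_as_current_span("llm-node") as span:
    span.set_attribute(LLMAttributes.REQUEST_MODEL, "gpt-4")
    span.set_attribute(LLMAttributes.USAGE_TOTAL_TOKENS, 30)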
@@ -99,3 +99,38 @@ def mock_is_instrument_flag_enabled_true():
     """Mock is_instrument_flag_enabled to return True."""
     with patch("core.workflow.graph_engine.layers.observability.is_instrument_flag_enabled", return_value=True):
         yield
+
+
+@pytest.fixture
+def mock_retrieval_node():
+    """Create a mock Knowledge Retrieval Node."""
+    node = MagicMock()
+    node.id = "test-retrieval-node-id"
+    node.title = "Retrieval Node"
+    node.execution_id = "test-retrieval-execution-id"
+    node.node_type = NodeType.KNOWLEDGE_RETRIEVAL
+    return node
+
+
+@pytest.fixture
+def mock_result_event():
+    """Create a mock result event with NodeRunResult."""
+    from datetime import datetime
+
+    from core.workflow.graph_events.node import NodeRunSucceededEvent
+    from core.workflow.node_events.base import NodeRunResult
+
+    node_run_result = NodeRunResult(
+        inputs={"query": "test query"},
+        outputs={"result": [{"content": "test content", "metadata": {}}]},
+        process_data={},
+        metadata={},
+    )
+
+    return NodeRunSucceededEvent(
+        id="test-execution-id",
+        node_id="test-node-id",
+        node_type=NodeType.LLM,
+        start_at=datetime.now(),
+        node_run_result=node_run_result,
+    )
@@ -4,7 +4,8 @@ Tests for ObservabilityLayer.
 Test coverage:
 - Initialization and enable/disable logic
 - Node span lifecycle (start, end, error handling)
-- Parser integration (default and tool-specific)
+- Parser integration (default, tool, LLM, and retrieval parsers)
+- Result event parameter extraction (inputs/outputs)
 - Graph lifecycle management
 - Disabled mode behavior
 """
@@ -134,9 +135,101 @@ class TestObservabilityLayerParserIntegration:
         assert len(spans) == 1
         attrs = spans[0].attributes
         assert attrs["node.id"] == mock_tool_node.id
-        assert attrs["tool.provider.id"] == mock_tool_node._node_data.provider_id
-        assert attrs["tool.provider.type"] == mock_tool_node._node_data.provider_type.value
-        assert attrs["tool.name"] == mock_tool_node._node_data.tool_name
+        assert attrs["gen_ai.tool.name"] == mock_tool_node.title
+        assert attrs["gen_ai.tool.type"] == mock_tool_node._node_data.provider_type.value
+
+    @patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True)
+    @pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false")
+    def test_llm_parser_used_for_llm_node(
+        self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_llm_node, mock_result_event
+    ):
+        """Test that LLM parser is used for LLM nodes and extracts LLM-specific attributes."""
+        from core.workflow.node_events.base import NodeRunResult
+
+        mock_result_event.node_run_result = NodeRunResult(
+            inputs={},
+            outputs={"text": "test completion", "finish_reason": "stop"},
+            process_data={
+                "model_name": "gpt-4",
+                "model_provider": "openai",
+                "usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30},
+                "prompts": [{"role": "user", "text": "test prompt"}],
+            },
+            metadata={},
+        )
+
+        layer = ObservabilityLayer()
+        layer.on_graph_start()
+
+        layer.on_node_run_start(mock_llm_node)
+        layer.on_node_run_end(mock_llm_node, None, mock_result_event)
+
+        spans = memory_span_exporter.get_finished_spans()
+        assert len(spans) == 1
+        attrs = spans[0].attributes
+        assert attrs["node.id"] == mock_llm_node.id
+        assert attrs["gen_ai.request.model"] == "gpt-4"
+        assert attrs["gen_ai.provider.name"] == "openai"
+        assert attrs["gen_ai.usage.input_tokens"] == 10
+        assert attrs["gen_ai.usage.output_tokens"] == 20
+        assert attrs["gen_ai.usage.total_tokens"] == 30
+        assert attrs["gen_ai.completion"] == "test completion"
+        assert attrs["gen_ai.response.finish_reason"] == "stop"
+
+    @patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True)
+    @pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false")
+    def test_retrieval_parser_used_for_retrieval_node(
+        self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_retrieval_node, mock_result_event
+    ):
+        """Test that retrieval parser is used for retrieval nodes and extracts retrieval-specific attributes."""
+        from core.workflow.node_events.base import NodeRunResult
+
+        mock_result_event.node_run_result = NodeRunResult(
+            inputs={"query": "test query"},
+            outputs={"result": [{"content": "test content", "metadata": {"score": 0.9, "document_id": "doc1"}}]},
+            process_data={},
+            metadata={},
+        )
+
+        layer = ObservabilityLayer()
+        layer.on_graph_start()
+
+        layer.on_node_run_start(mock_retrieval_node)
+        layer.on_node_run_end(mock_retrieval_node, None, mock_result_event)
+
+        spans = memory_span_exporter.get_finished_spans()
+        assert len(spans) == 1
+        attrs = spans[0].attributes
+        assert attrs["node.id"] == mock_retrieval_node.id
+        assert attrs["retrieval.query"] == "test query"
+        assert "retrieval.document" in attrs
+
+    @patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True)
+    @pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false")
+    def test_result_event_extracts_inputs_and_outputs(
+        self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_start_node, mock_result_event
+    ):
+        """Test that result_event parameter allows parsers to extract inputs and outputs."""
+        from core.workflow.node_events.base import NodeRunResult
+
+        mock_result_event.node_run_result = NodeRunResult(
+            inputs={"input_key": "input_value"},
+            outputs={"output_key": "output_value"},
+            process_data={},
+            metadata={},
+        )
+
+        layer = ObservabilityLayer()
+        layer.on_graph_start()
+
+        layer.on_node_run_start(mock_start_node)
+        layer.on_node_run_end(mock_start_node, None, mock_result_event)
+
+        spans = memory_span_exporter.get_finished_spans()
+        assert len(spans) == 1
+        attrs = spans[0].attributes
+        assert "input.value" in attrs
+        assert "output.value" in attrs
 
 
 class TestObservabilityLayerGraphLifecycle: