fix: address arize phoenix trace review feedback

This commit is contained in:
Blackoutta 2026-04-21 16:53:33 +08:00
parent 2126a5a11b
commit 9c6ae362b3
2 changed files with 151 additions and 23 deletions

View File

@ -2,6 +2,7 @@ import hashlib
import json
import logging
import os
import traceback
from datetime import datetime, timedelta
from typing import Any, Union, cast
from urllib.parse import urlparse
@ -20,7 +21,10 @@ from opentelemetry.sdk import trace as trace_sdk
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
from opentelemetry.sdk.trace.id_generator import RandomIdGenerator
from opentelemetry.trace import SpanContext, TraceFlags, TraceState
from opentelemetry.semconv.attributes import exception_attributes
from opentelemetry.trace import Span, SpanContext, Status, StatusCode, TraceFlags, TraceState, use_span
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
from opentelemetry.util.types import AttributeValue
from sqlalchemy import select, text
from core.ops.base_trace_instance import BaseTraceInstance
@ -75,22 +79,11 @@ def setup_tracer(arize_phoenix_config: ArizeConfig | PhoenixConfig) -> tuple[tra
"api_key": arize_phoenix_config.api_key or "",
"authorization": f"Bearer {arize_phoenix_config.api_key or ''}",
}
# Test connectivity first
try:
# Create a test exporter with short timeout
test_exporter = HttpOTLPSpanExporter(
endpoint=phoenix_endpoint,
headers=phoenix_headers,
timeout=5,
)
# Try to export an empty batch to test connectivity
test_exporter.export([])
logger.info("[Arize/Phoenix] Connectivity test successful")
test_exporter.timeout = 30
exporter = test_exporter
except Exception as connectivity_error:
logger.warning("[Arize/Phoenix] Connectivity test failed: %s, using shorter timeout", str(connectivity_error))
raise
exporter = HttpOTLPSpanExporter(
endpoint=phoenix_endpoint,
headers=phoenix_headers,
timeout=30,
)
attributes = {
"openinference.project.name": arize_phoenix_config.project or "",
@ -119,6 +112,54 @@ def datetime_to_nanos(dt: datetime | None) -> int:
return int(dt.timestamp() * 1_000_000_000)
def error_to_string(error: Exception | str | None) -> str:
"""Convert an error to a string with traceback information for Arize/Phoenix."""
error_message = "Empty Stack Trace"
if error:
if isinstance(error, Exception):
string_stacktrace = "".join(traceback.format_exception(error))
error_message = f"{error.__class__.__name__}: {error}\n\n{string_stacktrace}"
else:
error_message = str(error)
return error_message
def set_span_status(current_span: Span, error: Exception | str | None = None):
    """Mark *current_span* OK, or ERROR with exception details when *error* is given.

    Real exceptions are recorded via the span's native ``record_exception``;
    string errors get a hand-built OTel "exception" event instead.
    """
    if not error:
        current_span.set_status(Status(StatusCode.OK))
        return

    error_string = error_to_string(error)
    current_span.set_status(Status(StatusCode.ERROR, error_string))

    if isinstance(error, Exception):
        # Exceptions carry their own traceback; let OpenTelemetry record them.
        current_span.record_exception(error)
        return

    # Plain-string error: synthesize the standard OTel exception event fields.
    exception_message = str(error) or repr(error)
    event_attributes: dict[str, AttributeValue] = {
        exception_attributes.EXCEPTION_TYPE: error.__class__.__name__,
        exception_attributes.EXCEPTION_MESSAGE: exception_message,
        exception_attributes.EXCEPTION_ESCAPED: False,
        exception_attributes.EXCEPTION_STACKTRACE: error_string,
    }
    current_span.add_event(name="exception", attributes=event_attributes)
def safe_json_dumps(obj: Any) -> str:
    """Serialize *obj* to JSON for Arize/Phoenix, stringifying anything that
    the default encoder cannot handle and keeping non-ASCII text intact."""
    return json.dumps(obj, ensure_ascii=False, default=str)
def wrap_span_metadata(metadata, **kwargs):
    """Stamp common Dify metadata onto *metadata* (mutated in place) and return it.

    Every trace entity type shares the ``created_from`` marker; any extra
    keyword arguments are merged in afterwards and may override it.
    """
    metadata["created_from"] = "Dify"
    for key, value in kwargs.items():
        metadata[key] = value
    return metadata
def string_to_trace_id128(string: str | None) -> int:
"""
Convert any input string into a stable 128-bit integer trace ID.
@ -155,20 +196,31 @@ def string_to_span_id64(string: str | None) -> int:
return int.from_bytes(digest, byteorder="big")
# Workflow node types that map onto a specific OpenInference span kind.
# Any node type not listed here falls back to CHAIN (see _get_node_span_kind).
_NODE_TYPE_TO_SPAN_KIND: dict[str, OpenInferenceSpanKindValues] = {
    "llm": OpenInferenceSpanKindValues.LLM,
    "knowledge-retrieval": OpenInferenceSpanKindValues.RETRIEVER,
    "tool": OpenInferenceSpanKindValues.TOOL,
    "agent": OpenInferenceSpanKindValues.AGENT,
}
def _get_node_span_kind(node_type: str) -> OpenInferenceSpanKindValues:
    """Map a workflow node type to its OpenInference span kind, CHAIN when unmapped."""
    try:
        return _NODE_TYPE_TO_SPAN_KIND[node_type]
    except KeyError:
        return OpenInferenceSpanKindValues.CHAIN
class ArizePhoenixDataTrace(BaseTraceInstance):
def __init__(
    self,
    arize_phoenix_config: ArizeConfig | PhoenixConfig,
):
    """Set up a trace instance bound to one Arize or Phoenix project.

    Builds the OTLP tracer/processor pair via ``setup_tracer`` and prepares
    the W3C trace-context propagator used to parent spans under per-trace
    root spans.

    :param arize_phoenix_config: provider config (endpoint, API key, project).
    """
    super().__init__(arize_phoenix_config)
    # Removed leftover debug scaffolding: the previous version called
    # logging.basicConfig() and forced the *root* logger to DEBUG, mutating
    # process-global logging configuration as a constructor side effect.
    self.arize_phoenix_config = arize_phoenix_config
    self.tracer, self.processor = setup_tracer(arize_phoenix_config)
    self.project = arize_phoenix_config.project
    self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
    self.propagator = TraceContextTextMapPropagator()
    # Dify trace IDs that already have a root span (see ensure_root_span).
    self.dify_trace_ids: set[str] = set()
def trace(self, trace_info: BaseTraceInfo):
logger.info("[Arize/Phoenix] Trace: %s", trace_info)
@ -673,6 +725,10 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
llm_attributes.update(self._construct_llm_attributes(process_data.get("prompts", [])))
node_span.set_attributes(llm_attributes)
finally:
if node_execution.status == WorkflowNodeExecutionStatus.FAILED:
set_span_status(node_span, node_execution.error)
else:
set_span_status(node_span)
node_span.end(end_time=datetime_to_nanos(finished_at))
except AttributeError as e:
@ -691,6 +747,10 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
logger.error(f"[Arize/Phoenix] Workflow tracing failed: {e}", exc_info=True)
raise ValueError(f"[Arize/Phoenix] Workflow trace failed: {str(e)}")
finally:
if trace_info.error:
set_span_status(workflow_span, trace_info.error)
else:
set_span_status(workflow_span)
workflow_span.end(end_time=datetime_to_nanos(trace_info.end_time))
def message_trace(self, trace_info: MessageTraceInfo):
@ -816,8 +876,16 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
},
)
finally:
if trace_info.error:
set_span_status(llm_span, trace_info.error)
else:
set_span_status(llm_span)
llm_span.end(end_time=datetime_to_nanos(trace_info.end_time))
finally:
if trace_info.error:
set_span_status(message_span, trace_info.error)
else:
set_span_status(message_span)
message_span.end(end_time=datetime_to_nanos(trace_info.end_time))
def moderation_trace(self, trace_info: ModerationTraceInfo):
@ -874,6 +942,10 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
},
)
finally:
if trace_info.message_data.error:
set_span_status(span, trace_info.message_data.error)
else:
set_span_status(span)
span.end(end_time=datetime_to_nanos(trace_info.end_time))
def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo):
@ -928,6 +1000,10 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
},
)
finally:
if trace_info.error:
set_span_status(span, trace_info.error)
else:
set_span_status(span)
span.end(end_time=datetime_to_nanos(end_time))
def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo):
@ -983,6 +1059,10 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
},
)
finally:
if trace_info.message_data.error:
set_span_status(span, trace_info.message_data.error)
else:
set_span_status(span)
span.end(end_time=datetime_to_nanos(end_time))
def tool_trace(self, trace_info: ToolTraceInfo):
@ -1041,6 +1121,10 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
},
)
finally:
if trace_info.error:
set_span_status(span, trace_info.error)
else:
set_span_status(span)
span.end(end_time=datetime_to_nanos(trace_info.end_time))
def generate_name_trace(self, trace_info: GenerateNameTraceInfo):
@ -1092,6 +1176,10 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
},
)
finally:
if trace_info.message_data.error:
set_span_status(span, trace_info.message_data.error)
else:
set_span_status(span)
span.end(end_time=datetime_to_nanos(trace_info.end_time))
def api_check(self):
@ -1103,6 +1191,23 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
logger.info("[Arize/Phoenix] API check failed: %s", str(e), exc_info=True)
raise ValueError(f"[Arize/Phoenix] API check failed: {str(e)}")
def ensure_root_span(self, dify_trace_id: str | None):
    """Lazily create one root "Dify" span per Dify trace ID.

    The root span's context is injected into ``self.carrier`` so subsequent
    spans of the same trace can be parented under it; the ID is then
    remembered so the root span is never created twice.
    """
    trace_key = str(dify_trace_id)
    if trace_key in self.dify_trace_ids:
        return

    self.carrier: dict[str, str] = {}
    root_span = self.tracer.start_span(name="Dify")
    root_span.set_attribute(SpanAttributes.OPENINFERENCE_SPAN_KIND, OpenInferenceSpanKindValues.CHAIN.value)
    root_span.set_attribute("dify_project_name", str(self.project))
    root_span.set_attribute("dify_trace_id", trace_key)
    # Capture the root span's context into the carrier without ending it yet.
    with use_span(root_span, end_on_exit=False):
        self.propagator.inject(carrier=self.carrier)
    set_span_status(root_span)
    root_span.end()
    self.dify_trace_ids.add(trace_key)
def get_project_url(self):
try:
if self.arize_phoenix_config.endpoint == "https://otlp.arize.com":
@ -1209,8 +1314,8 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
exec_id = exec_node.id
node_mapping[exec_id] = node_id
# Store execution context
execution_context[exec_id] = {
# Store execution context keyed by the same node identifier used in node_spans.
execution_context[node_id] = {
'node_type': exec_node.node_type,
'status': exec_node.status,
'index': getattr(exec_node, 'index', 0),

View File

@ -396,3 +396,26 @@ def test_api_check_success(trace_instance):
def test_ensure_root_span_basic(trace_instance):
    # First call with a new Dify trace ID must register that ID, which is the
    # mechanism ensure_root_span uses to create the root span only once.
    trace_instance.ensure_root_span("tid")
    assert "tid" in trace_instance.dify_trace_ids
def test_find_logical_parent_span_uses_matching_node_context_keys(trace_instance):
    """Regression test: node_spans and execution_context must be keyed by the
    same node identifier, so the logical-parent lookup resolves via node IDs
    (not execution IDs)."""
    parent_span = MagicMock()
    child_execution = MagicMock()
    child_execution.id = "exec-child"
    child_execution.node_id = "child-node"
    child_execution.index = 3

    # Both mappings use the node ID ("parent-node") as the key, mirroring how
    # the tracing code now stores execution context keyed by node identifier.
    logical_parent = trace_instance._find_logical_parent_span(
        child_execution,
        node_spans={"parent-node": parent_span},
        execution_context={
            "parent-node": {
                "index": 1,
                "node_type": "tool",
                "status": "succeeded",
                "created_at": _dt(),
            }
        },
    )

    assert logical_parent is parent_span