From 9c6ae362b37c7ce91d140fdf1cebf5ff730efb4f Mon Sep 17 00:00:00 2001 From: Blackoutta Date: Tue, 21 Apr 2026 16:53:33 +0800 Subject: [PATCH] fix: address arize phoenix trace review feedback --- .../arize_phoenix_trace.py | 151 +++++++++++++++--- .../test_arize_phoenix_trace.py | 23 +++ 2 files changed, 151 insertions(+), 23 deletions(-) diff --git a/api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/arize_phoenix_trace.py b/api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/arize_phoenix_trace.py index 7fb23482bf..f4e68d273e 100644 --- a/api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/arize_phoenix_trace.py +++ b/api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/arize_phoenix_trace.py @@ -2,6 +2,7 @@ import hashlib import json import logging import os +import traceback from datetime import datetime, timedelta from typing import Any, Union, cast from urllib.parse import urlparse @@ -20,7 +21,10 @@ from opentelemetry.sdk import trace as trace_sdk from opentelemetry.sdk.resources import Resource from opentelemetry.sdk.trace.export import SimpleSpanProcessor from opentelemetry.sdk.trace.id_generator import RandomIdGenerator -from opentelemetry.trace import SpanContext, TraceFlags, TraceState +from opentelemetry.semconv.attributes import exception_attributes +from opentelemetry.trace import Span, SpanContext, Status, StatusCode, TraceFlags, TraceState, use_span +from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator +from opentelemetry.util.types import AttributeValue from sqlalchemy import select, text from core.ops.base_trace_instance import BaseTraceInstance @@ -75,22 +79,11 @@ def setup_tracer(arize_phoenix_config: ArizeConfig | PhoenixConfig) -> tuple[tra "api_key": arize_phoenix_config.api_key or "", "authorization": f"Bearer {arize_phoenix_config.api_key or ''}", } - # Test connectivity first - try: - # Create a test exporter with short timeout - 
test_exporter = HttpOTLPSpanExporter( - endpoint=phoenix_endpoint, - headers=phoenix_headers, - timeout=5, - ) - # Try to export an empty batch to test connectivity - test_exporter.export([]) - logger.info("[Arize/Phoenix] Connectivity test successful") - test_exporter.timeout = 30 - exporter = test_exporter - except Exception as connectivity_error: - logger.warning("[Arize/Phoenix] Connectivity test failed: %s, using shorter timeout", str(connectivity_error)) - raise + exporter = HttpOTLPSpanExporter( + endpoint=phoenix_endpoint, + headers=phoenix_headers, + timeout=30, + ) attributes = { "openinference.project.name": arize_phoenix_config.project or "", @@ -119,6 +112,54 @@ def datetime_to_nanos(dt: datetime | None) -> int: return int(dt.timestamp() * 1_000_000_000) +def error_to_string(error: Exception | str | None) -> str: + """Convert an error to a string with traceback information for Arize/Phoenix.""" + error_message = "Empty Stack Trace" + if error: + if isinstance(error, Exception): + string_stacktrace = "".join(traceback.format_exception(error)) + error_message = f"{error.__class__.__name__}: {error}\n\n{string_stacktrace}" + else: + error_message = str(error) + return error_message + + +def set_span_status(current_span: Span, error: Exception | str | None = None): + """Set the status of the current span based on the presence of an error for Arize/Phoenix.""" + if error: + error_string = error_to_string(error) + current_span.set_status(Status(StatusCode.ERROR, error_string)) + + if isinstance(error, Exception): + current_span.record_exception(error) + else: + exception_type = error.__class__.__name__ + exception_message = str(error) + if not exception_message: + exception_message = repr(error) + attributes: dict[str, AttributeValue] = { + exception_attributes.EXCEPTION_TYPE: exception_type, + exception_attributes.EXCEPTION_MESSAGE: exception_message, + exception_attributes.EXCEPTION_ESCAPED: False, + exception_attributes.EXCEPTION_STACKTRACE: error_string, 
+ } + current_span.add_event(name="exception", attributes=attributes) + else: + current_span.set_status(Status(StatusCode.OK)) + + +def safe_json_dumps(obj: Any) -> str: + """A convenience wrapper to ensure that any object can be safely encoded for Arize/Phoenix.""" + return json.dumps(obj, default=str, ensure_ascii=False) + + +def wrap_span_metadata(metadata, **kwargs): + """Add common metadata to all trace entity types for Arize/Phoenix.""" + metadata["created_from"] = "Dify" + metadata.update(kwargs) + return metadata + + def string_to_trace_id128(string: str | None) -> int: """ Convert any input string into a stable 128-bit integer trace ID. @@ -155,20 +196,31 @@ def string_to_span_id64(string: str | None) -> int: return int.from_bytes(digest, byteorder="big") +_NODE_TYPE_TO_SPAN_KIND: dict[str, OpenInferenceSpanKindValues] = { + "llm": OpenInferenceSpanKindValues.LLM, + "knowledge-retrieval": OpenInferenceSpanKindValues.RETRIEVER, + "tool": OpenInferenceSpanKindValues.TOOL, + "agent": OpenInferenceSpanKindValues.AGENT, +} + + +def _get_node_span_kind(node_type: str) -> OpenInferenceSpanKindValues: + """Return the OpenInference span kind for a given workflow node type.""" + return _NODE_TYPE_TO_SPAN_KIND.get(node_type, OpenInferenceSpanKindValues.CHAIN) + + class ArizePhoenixDataTrace(BaseTraceInstance): def __init__( self, arize_phoenix_config: ArizeConfig | PhoenixConfig, ): super().__init__(arize_phoenix_config) - import logging - - logging.basicConfig() - logging.getLogger().setLevel(logging.DEBUG) self.arize_phoenix_config = arize_phoenix_config self.tracer, self.processor = setup_tracer(arize_phoenix_config) self.project = arize_phoenix_config.project self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001") + self.propagator = TraceContextTextMapPropagator() + self.dify_trace_ids: set[str] = set() def trace(self, trace_info: BaseTraceInfo): logger.info("[Arize/Phoenix] Trace: %s", trace_info) @@ -673,6 +725,10 @@ class 
ArizePhoenixDataTrace(BaseTraceInstance): llm_attributes.update(self._construct_llm_attributes(process_data.get("prompts", []))) node_span.set_attributes(llm_attributes) finally: + if node_execution.status == WorkflowNodeExecutionStatus.FAILED: + set_span_status(node_span, node_execution.error) + else: + set_span_status(node_span) node_span.end(end_time=datetime_to_nanos(finished_at)) except AttributeError as e: @@ -691,6 +747,10 @@ class ArizePhoenixDataTrace(BaseTraceInstance): logger.error(f"[Arize/Phoenix] Workflow tracing failed: {e}", exc_info=True) raise ValueError(f"[Arize/Phoenix] Workflow trace failed: {str(e)}") finally: + if trace_info.error: + set_span_status(workflow_span, trace_info.error) + else: + set_span_status(workflow_span) workflow_span.end(end_time=datetime_to_nanos(trace_info.end_time)) def message_trace(self, trace_info: MessageTraceInfo): @@ -816,8 +876,16 @@ class ArizePhoenixDataTrace(BaseTraceInstance): }, ) finally: + if trace_info.error: + set_span_status(llm_span, trace_info.error) + else: + set_span_status(llm_span) llm_span.end(end_time=datetime_to_nanos(trace_info.end_time)) finally: + if trace_info.error: + set_span_status(message_span, trace_info.error) + else: + set_span_status(message_span) message_span.end(end_time=datetime_to_nanos(trace_info.end_time)) def moderation_trace(self, trace_info: ModerationTraceInfo): @@ -874,6 +942,10 @@ class ArizePhoenixDataTrace(BaseTraceInstance): }, ) finally: + if trace_info.message_data.error: + set_span_status(span, trace_info.message_data.error) + else: + set_span_status(span) span.end(end_time=datetime_to_nanos(trace_info.end_time)) def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo): @@ -928,6 +1000,10 @@ class ArizePhoenixDataTrace(BaseTraceInstance): }, ) finally: + if trace_info.error: + set_span_status(span, trace_info.error) + else: + set_span_status(span) span.end(end_time=datetime_to_nanos(end_time)) def dataset_retrieval_trace(self, trace_info: 
DatasetRetrievalTraceInfo): @@ -983,6 +1059,10 @@ class ArizePhoenixDataTrace(BaseTraceInstance): }, ) finally: + if trace_info.message_data.error: + set_span_status(span, trace_info.message_data.error) + else: + set_span_status(span) span.end(end_time=datetime_to_nanos(end_time)) def tool_trace(self, trace_info: ToolTraceInfo): @@ -1041,6 +1121,10 @@ class ArizePhoenixDataTrace(BaseTraceInstance): }, ) finally: + if trace_info.error: + set_span_status(span, trace_info.error) + else: + set_span_status(span) span.end(end_time=datetime_to_nanos(trace_info.end_time)) def generate_name_trace(self, trace_info: GenerateNameTraceInfo): @@ -1092,6 +1176,10 @@ class ArizePhoenixDataTrace(BaseTraceInstance): }, ) finally: + if trace_info.message_data.error: + set_span_status(span, trace_info.message_data.error) + else: + set_span_status(span) span.end(end_time=datetime_to_nanos(trace_info.end_time)) def api_check(self): @@ -1103,6 +1191,23 @@ class ArizePhoenixDataTrace(BaseTraceInstance): logger.info("[Arize/Phoenix] API check failed: %s", str(e), exc_info=True) raise ValueError(f"[Arize/Phoenix] API check failed: {str(e)}") + def ensure_root_span(self, dify_trace_id: str | None): + """Ensure a unique root span exists for the given Dify trace ID.""" + if str(dify_trace_id) not in self.dify_trace_ids: + self.carrier: dict[str, str] = {} + + root_span = self.tracer.start_span(name="Dify") + root_span.set_attribute(SpanAttributes.OPENINFERENCE_SPAN_KIND, OpenInferenceSpanKindValues.CHAIN.value) + root_span.set_attribute("dify_project_name", str(self.project)) + root_span.set_attribute("dify_trace_id", str(dify_trace_id)) + + with use_span(root_span, end_on_exit=False): + self.propagator.inject(carrier=self.carrier) + + set_span_status(root_span) + root_span.end() + self.dify_trace_ids.add(str(dify_trace_id)) + def get_project_url(self): try: if self.arize_phoenix_config.endpoint == "https://otlp.arize.com": @@ -1209,8 +1314,8 @@ class ArizePhoenixDataTrace(BaseTraceInstance): 
exec_id = exec_node.id node_mapping[exec_id] = node_id - # Store execution context - execution_context[exec_id] = { + # Store execution context keyed by the same node identifier used in node_spans. + execution_context[node_id] = { 'node_type': exec_node.node_type, 'status': exec_node.status, 'index': getattr(exec_node, 'index', 0), diff --git a/api/providers/trace/trace-arize-phoenix/tests/unit_tests/arize_phoenix_trace/test_arize_phoenix_trace.py b/api/providers/trace/trace-arize-phoenix/tests/unit_tests/arize_phoenix_trace/test_arize_phoenix_trace.py index b0691a87ea..26bb58e3cb 100644 --- a/api/providers/trace/trace-arize-phoenix/tests/unit_tests/arize_phoenix_trace/test_arize_phoenix_trace.py +++ b/api/providers/trace/trace-arize-phoenix/tests/unit_tests/arize_phoenix_trace/test_arize_phoenix_trace.py @@ -396,3 +396,26 @@ def test_api_check_success(trace_instance): def test_ensure_root_span_basic(trace_instance): trace_instance.ensure_root_span("tid") assert "tid" in trace_instance.dify_trace_ids + + +def test_find_logical_parent_span_uses_matching_node_context_keys(trace_instance): + parent_span = MagicMock() + child_execution = MagicMock() + child_execution.id = "exec-child" + child_execution.node_id = "child-node" + child_execution.index = 3 + + logical_parent = trace_instance._find_logical_parent_span( + child_execution, + node_spans={"parent-node": parent_span}, + execution_context={ + "parent-node": { + "index": 1, + "node_type": "tool", + "status": "succeeded", + "created_at": _dt(), + } + }, + ) + + assert logical_parent is parent_span