diff --git a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py b/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py
index 03d2d75372..347992fa0d 100644
--- a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py
+++ b/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py
@@ -1,21 +1,22 @@
-import hashlib
 import json
 import logging
 import os
+import traceback
 from datetime import datetime, timedelta
 from typing import Any, Union, cast
 from urllib.parse import urlparse
 
-from openinference.semconv.trace import OpenInferenceSpanKindValues, SpanAttributes
-from opentelemetry import trace
+from openinference.semconv.trace import OpenInferenceMimeTypeValues, OpenInferenceSpanKindValues, SpanAttributes
 from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter as GrpcOTLPSpanExporter
 from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter as HttpOTLPSpanExporter
 from opentelemetry.sdk import trace as trace_sdk
 from opentelemetry.sdk.resources import Resource
 from opentelemetry.sdk.trace.export import SimpleSpanProcessor
-from opentelemetry.sdk.trace.id_generator import RandomIdGenerator
-from opentelemetry.trace import SpanContext, TraceFlags, TraceState
-from sqlalchemy import select
+from opentelemetry.semconv.trace import SpanAttributes as OTELSpanAttributes
+from opentelemetry.trace import Span, Status, StatusCode, set_span_in_context, use_span
+from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
+from opentelemetry.util.types import AttributeValue
+from sqlalchemy.orm import sessionmaker
 
 from core.ops.base_trace_instance import BaseTraceInstance
 from core.ops.entities.config_entity import ArizeConfig, PhoenixConfig
@@ -30,9 +31,10 @@ from core.ops.entities.trace_entity import (
     TraceTaskName,
     WorkflowTraceInfo,
 )
+from core.repositories import DifyCoreRepositoryFactory
 from extensions.ext_database import db
 from models.model import EndUser, MessageFile
-from models.workflow import WorkflowNodeExecutionModel
+from models.workflow import WorkflowNodeExecutionTriggeredFrom
 
 logger = logging.getLogger(__name__)
 
@@ -99,22 +101,45 @@ def datetime_to_nanos(dt: datetime | None) -> int:
     return int(dt.timestamp() * 1_000_000_000)
 
 
-def string_to_trace_id128(string: str | None) -> int:
-    """
-    Convert any input string into a stable 128-bit integer trace ID.
+def error_to_string(error: Exception | str | None) -> str:
+    """Convert an error to a string with traceback information."""
+    error_message = "Empty Stack Trace"
+    if error:
+        if isinstance(error, Exception):
+            string_stacktrace = "".join(traceback.format_exception(error))
+            error_message = f"{error.__class__.__name__}: {error}\n\n{string_stacktrace}"
+        else:
+            error_message = str(error)
+    return error_message
 
-    This uses SHA-256 hashing and takes the first 16 bytes (128 bits) of the digest.
-    It's suitable for generating consistent, unique identifiers from strings.
-    """
-    if string is None:
-        string = ""
-    hash_object = hashlib.sha256(string.encode())
-    # Take the first 16 bytes (128 bits) of the hash digest
-    digest = hash_object.digest()[:16]
 
+def set_span_status(current_span: Span, error: Exception | str | None = None):
+    """Set the status of the current span based on the presence of an error."""
+    if error:
+        error_string = error_to_string(error)
+        current_span.set_status(Status(StatusCode.ERROR, error_string))
 
-    # Convert to a 128-bit integer
-    return int.from_bytes(digest, byteorder="big")
+        if isinstance(error, Exception):
+            current_span.record_exception(error)
+        else:
+            exception_type = error.__class__.__name__
+            exception_message = str(error)
+            if not exception_message:
+                exception_message = repr(error)
+            attributes: dict[str, AttributeValue] = {
+                OTELSpanAttributes.EXCEPTION_TYPE: exception_type,
+                OTELSpanAttributes.EXCEPTION_MESSAGE: exception_message,
+                OTELSpanAttributes.EXCEPTION_ESCAPED: False,
+                OTELSpanAttributes.EXCEPTION_STACKTRACE: error_string,
+            }
+            current_span.add_event(name="exception", attributes=attributes)
    else:
+        current_span.set_status(Status(StatusCode.OK))
+
+
+def safe_json_dumps(obj: Any) -> str:
+    """A convenience wrapper around `json.dumps` that can safely encode any object."""
+    return json.dumps(obj, default=str, ensure_ascii=False)
 
 
 class ArizePhoenixDataTrace(BaseTraceInstance):
@@ -131,9 +156,12 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
         self.tracer, self.processor = setup_tracer(arize_phoenix_config)
         self.project = arize_phoenix_config.project
         self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
+        self.propagator = TraceContextTextMapPropagator()
+        self.dify_trace_ids: set[str] = set()
 
     def trace(self, trace_info: BaseTraceInfo):
-        logger.info("[Arize/Phoenix] Trace: %s", trace_info)
+        logger.info("[Arize/Phoenix] Trace Entity Info: %s", trace_info)
+        logger.info("[Arize/Phoenix] Trace Entity Type: %s", type(trace_info))
         try:
             if isinstance(trace_info, WorkflowTraceInfo):
                 self.workflow_trace(trace_info)
@@ -151,7 +179,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
                 self.generate_name_trace(trace_info)
 
         except Exception as e:
-            logger.error("[Arize/Phoenix] Error in the trace: %s", str(e), exc_info=True)
+            logger.error("[Arize/Phoenix] Trace Entity Error: %s", str(e), exc_info=True)
             raise
 
     def workflow_trace(self, trace_info: WorkflowTraceInfo):
@@ -166,15 +194,9 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
         }
         workflow_metadata.update(trace_info.metadata)
 
-        trace_id = string_to_trace_id128(trace_info.trace_id or trace_info.workflow_run_id)
-        span_id = RandomIdGenerator().generate_span_id()
-        context = SpanContext(
-            trace_id=trace_id,
-            span_id=span_id,
-            is_remote=False,
-            trace_flags=TraceFlags(TraceFlags.SAMPLED),
-            trace_state=TraceState(),
-        )
+        dify_trace_id = trace_info.trace_id or trace_info.message_id or trace_info.workflow_run_id
+        self.ensure_root_span(dify_trace_id)
+        root_span_context = self.propagator.extract(carrier=self.carrier)
 
         workflow_span = self.tracer.start_span(
             name=TraceTaskName.WORKFLOW_TRACE.value,
@@ -186,31 +208,58 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
                 SpanAttributes.SESSION_ID: trace_info.conversation_id or "",
             },
             start_time=datetime_to_nanos(trace_info.start_time),
-            context=trace.set_span_in_context(trace.NonRecordingSpan(context)),
+            context=root_span_context,
+        )
+
+        # Fetch this workflow run's node executions through the repository
+        session_factory = sessionmaker(bind=db.engine)
+
+        # Find the app's creator account
+        app_id = trace_info.metadata.get("app_id")
+        if not app_id:
+            raise ValueError("No app_id found in trace_info metadata")
+
+        service_account = self.get_service_account_with_tenant(app_id)
+
+        workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
+            session_factory=session_factory,
+            user=service_account,
+            app_id=app_id,
+            triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
+        )
+
+        # Get all executions for this workflow run
+        workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run(
+            workflow_run_id=trace_info.workflow_run_id
         )
 
         try:
-            # Process workflow nodes
-            for node_execution in self._get_workflow_nodes(trace_info.workflow_run_id):
+            for node_execution in workflow_node_executions:
+                tenant_id = trace_info.tenant_id  # use the value from trace_info
+                app_id = trace_info.metadata.get("app_id")  # use the value from trace_info
+
+                inputs_value = node_execution.inputs or {}
+                outputs_value = node_execution.outputs or {}
+
+                created_at = node_execution.created_at or datetime.now()
                 elapsed_time = node_execution.elapsed_time
                 finished_at = created_at + timedelta(seconds=elapsed_time)
 
-                process_data = json.loads(node_execution.process_data) if node_execution.process_data else {}
+                process_data = node_execution.process_data or {}
+                execution_metadata = node_execution.metadata or {}
+                node_metadata = {str(k): v for k, v in execution_metadata.items()}
 
-                node_metadata = {
-                    "node_id": node_execution.id,
-                    "node_type": node_execution.node_type,
-                    "node_status": node_execution.status,
-                    "tenant_id": node_execution.tenant_id,
-                    "app_id": node_execution.app_id,
-                    "app_name": node_execution.title,
-                    "status": node_execution.status,
-                    "level": "ERROR" if node_execution.status != "succeeded" else "DEFAULT",
-                }
-
-                if node_execution.execution_metadata:
-                    node_metadata.update(json.loads(node_execution.execution_metadata))
+                node_metadata.update(
+                    {
+                        "node_id": node_execution.id,
+                        "node_type": node_execution.node_type,
+                        "node_status": node_execution.status,
+                        "tenant_id": tenant_id,
+                        "app_id": app_id,
+                        "app_name": node_execution.title,
+                        "status": node_execution.status,
+                        "level": "ERROR" if node_execution.status == "failed" else "DEFAULT",
+                    }
+                )
 
                 # Determine the correct span kind based on node type
                 span_kind = OpenInferenceSpanKindValues.CHAIN
@@ -223,8 +272,9 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
                     if model:
                         node_metadata["ls_model_name"] = model
 
-                    outputs = json.loads(node_execution.outputs).get("usage", {}) if "outputs" in node_execution else {}
-                    usage_data = process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
+                    usage_data = (
+                        process_data.get("usage", {}) if "usage" in process_data else outputs_value.get("usage", {})
+                    )
                    if usage_data:
                         node_metadata["total_tokens"] = usage_data.get("total_tokens", 0)
                         node_metadata["prompt_tokens"] = usage_data.get("prompt_tokens", 0)
@@ -236,17 +286,20 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
                 else:
                     span_kind = OpenInferenceSpanKindValues.CHAIN
 
+                workflow_span_context = set_span_in_context(workflow_span)
                 node_span = self.tracer.start_span(
                     name=node_execution.node_type,
                     attributes={
-                        SpanAttributes.INPUT_VALUE: node_execution.inputs or "{}",
-                        SpanAttributes.OUTPUT_VALUE: node_execution.outputs or "{}",
+                        SpanAttributes.INPUT_VALUE: safe_json_dumps(inputs_value),
+                        SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
+                        SpanAttributes.OUTPUT_VALUE: safe_json_dumps(outputs_value),
+                        SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
                         SpanAttributes.OPENINFERENCE_SPAN_KIND: span_kind.value,
-                        SpanAttributes.METADATA: json.dumps(node_metadata, ensure_ascii=False),
+                        SpanAttributes.METADATA: safe_json_dumps(node_metadata),
                         SpanAttributes.SESSION_ID: trace_info.conversation_id or "",
                     },
                     start_time=datetime_to_nanos(created_at),
-                    context=trace.set_span_in_context(trace.NonRecordingSpan(context)),
+                    context=workflow_span_context,
                 )
 
                 try:
@@ -260,11 +313,8 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
                         llm_attributes[SpanAttributes.LLM_PROVIDER] = provider
                     if model:
                         llm_attributes[SpanAttributes.LLM_MODEL_NAME] = model
-                    outputs = (
-                        json.loads(node_execution.outputs).get("usage", {}) if "outputs" in node_execution else {}
-                    )
                     usage_data = (
-                        process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
+                        process_data.get("usage", {}) if "usage" in process_data else outputs_value.get("usage", {})
                     )
                     if usage_data:
                         llm_attributes[SpanAttributes.LLM_TOKEN_COUNT_TOTAL] = usage_data.get("total_tokens", 0)
@@ -275,8 +325,16 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
                             llm_attributes.update(self._construct_llm_attributes(process_data.get("prompts", [])))
                         node_span.set_attributes(llm_attributes)
                 finally:
+                    if node_execution.status == "failed":
+                        set_span_status(node_span, node_execution.error)
+                    else:
+                        set_span_status(node_span)
                     node_span.end(end_time=datetime_to_nanos(finished_at))
         finally:
+            if trace_info.error:
+                set_span_status(workflow_span, trace_info.error)
+            else:
+                set_span_status(workflow_span)
             workflow_span.end(end_time=datetime_to_nanos(trace_info.end_time))
 
     def message_trace(self, trace_info: MessageTraceInfo):
@@ -322,34 +380,18 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
             SpanAttributes.SESSION_ID: trace_info.message_data.conversation_id,
         }
 
-        trace_id = string_to_trace_id128(trace_info.trace_id or trace_info.message_id)
-        message_span_id = RandomIdGenerator().generate_span_id()
-        span_context = SpanContext(
-            trace_id=trace_id,
-            span_id=message_span_id,
-            is_remote=False,
-            trace_flags=TraceFlags(TraceFlags.SAMPLED),
-            trace_state=TraceState(),
-        )
+        dify_trace_id = trace_info.trace_id or trace_info.message_id
+        self.ensure_root_span(dify_trace_id)
+        root_span_context = self.propagator.extract(carrier=self.carrier)
 
         message_span = self.tracer.start_span(
             name=TraceTaskName.MESSAGE_TRACE.value,
             attributes=attributes,
             start_time=datetime_to_nanos(trace_info.start_time),
-            context=trace.set_span_in_context(trace.NonRecordingSpan(span_context)),
+            context=root_span_context,
         )
 
         try:
-            if trace_info.error:
-                message_span.add_event(
-                    "exception",
-                    attributes={
-                        "exception.message": trace_info.error,
-                        "exception.type": "Error",
-                        "exception.stacktrace": trace_info.error,
-                    },
-                )
-
             # Convert outputs to string based on type
             if isinstance(trace_info.outputs, dict | list):
                 outputs_str = json.dumps(trace_info.outputs, ensure_ascii=False)
@@ -383,26 +425,26 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
             if model_params := metadata_dict.get("model_parameters"):
                 llm_attributes[SpanAttributes.LLM_INVOCATION_PARAMETERS] = json.dumps(model_params)
 
+            message_span_context = set_span_in_context(message_span)
             llm_span = self.tracer.start_span(
                 name="llm",
                 attributes=llm_attributes,
                 start_time=datetime_to_nanos(trace_info.start_time),
-                context=trace.set_span_in_context(trace.NonRecordingSpan(span_context)),
+                context=message_span_context,
             )
 
             try:
-                if trace_info.error:
-                    llm_span.add_event(
-                        "exception",
-                        attributes={
-                            "exception.message": trace_info.error,
-                            "exception.type": "Error",
-                            "exception.stacktrace": trace_info.error,
-                        },
-                    )
+                if trace_info.message_data.error:
+                    set_span_status(llm_span, trace_info.message_data.error)
+                else:
+                    set_span_status(llm_span)
             finally:
                 llm_span.end(end_time=datetime_to_nanos(trace_info.end_time))
         finally:
+            if trace_info.error:
+                set_span_status(message_span, trace_info.error)
+            else:
+                set_span_status(message_span)
             message_span.end(end_time=datetime_to_nanos(trace_info.end_time))
 
     def moderation_trace(self, trace_info: ModerationTraceInfo):
@@ -418,15 +460,9 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
         }
         metadata.update(trace_info.metadata)
 
-        trace_id = string_to_trace_id128(trace_info.message_id)
-        span_id = RandomIdGenerator().generate_span_id()
-        context = SpanContext(
-            trace_id=trace_id,
-            span_id=span_id,
-            is_remote=False,
-            trace_flags=TraceFlags(TraceFlags.SAMPLED),
-            trace_state=TraceState(),
-        )
+        dify_trace_id = trace_info.trace_id or trace_info.message_id
+        self.ensure_root_span(dify_trace_id)
+        root_span_context = self.propagator.extract(carrier=self.carrier)
 
         span = self.tracer.start_span(
             name=TraceTaskName.MODERATION_TRACE.value,
@@ -445,19 +481,14 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
                 SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False),
             },
             start_time=datetime_to_nanos(trace_info.start_time),
-            context=trace.set_span_in_context(trace.NonRecordingSpan(context)),
+            context=root_span_context,
         )
 
         try:
             if trace_info.message_data.error:
-                span.add_event(
-                    "exception",
-                    attributes={
-                        "exception.message": trace_info.message_data.error,
-                        "exception.type": "Error",
-                        "exception.stacktrace": trace_info.message_data.error,
-                    },
-                )
+                set_span_status(span, trace_info.message_data.error)
+            else:
+                set_span_status(span)
         finally:
             span.end(end_time=datetime_to_nanos(trace_info.end_time))
 
@@ -480,15 +511,9 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
         }
         metadata.update(trace_info.metadata)
 
-        trace_id = string_to_trace_id128(trace_info.message_id)
-        span_id = RandomIdGenerator().generate_span_id()
-        context = SpanContext(
-            trace_id=trace_id,
-            span_id=span_id,
-            is_remote=False,
-            trace_flags=TraceFlags(TraceFlags.SAMPLED),
-            trace_state=TraceState(),
-        )
+        dify_trace_id = trace_info.trace_id or trace_info.message_id
+        self.ensure_root_span(dify_trace_id)
+        root_span_context = self.propagator.extract(carrier=self.carrier)
 
         span = self.tracer.start_span(
             name=TraceTaskName.SUGGESTED_QUESTION_TRACE.value,
@@ -499,19 +524,14 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
                 SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False),
             },
             start_time=datetime_to_nanos(start_time),
-            context=trace.set_span_in_context(trace.NonRecordingSpan(context)),
+            context=root_span_context,
        )
 
         try:
             if trace_info.error:
-                span.add_event(
-                    "exception",
-                    attributes={
-                        "exception.message": trace_info.error,
-                        "exception.type": "Error",
-                        "exception.stacktrace": trace_info.error,
-                    },
-                )
+                set_span_status(span, trace_info.error)
+            else:
+                set_span_status(span)
         finally:
             span.end(end_time=datetime_to_nanos(end_time))
 
@@ -533,15 +553,9 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
         }
         metadata.update(trace_info.metadata)
 
-        trace_id = string_to_trace_id128(trace_info.message_id)
-        span_id = RandomIdGenerator().generate_span_id()
-        context = SpanContext(
-            trace_id=trace_id,
-            span_id=span_id,
-            is_remote=False,
-            trace_flags=TraceFlags(TraceFlags.SAMPLED),
-            trace_state=TraceState(),
-        )
+        dify_trace_id = trace_info.trace_id or trace_info.message_id
+        self.ensure_root_span(dify_trace_id)
+        root_span_context = self.propagator.extract(carrier=self.carrier)
 
         span = self.tracer.start_span(
             name=TraceTaskName.DATASET_RETRIEVAL_TRACE.value,
@@ -554,19 +568,14 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
                 "end_time": end_time.isoformat() if end_time else "",
             },
             start_time=datetime_to_nanos(start_time),
-            context=trace.set_span_in_context(trace.NonRecordingSpan(context)),
+            context=root_span_context,
         )
 
         try:
             if trace_info.message_data.error:
-                span.add_event(
-                    "exception",
-                    attributes={
-                        "exception.message": trace_info.message_data.error,
-                        "exception.type": "Error",
-                        "exception.stacktrace": trace_info.message_data.error,
-                    },
-                )
+                set_span_status(span, trace_info.message_data.error)
+            else:
+                set_span_status(span)
         finally:
             span.end(end_time=datetime_to_nanos(end_time))
 
@@ -580,20 +589,9 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
             "tool_config": json.dumps(trace_info.tool_config, ensure_ascii=False),
         }
 
-        trace_id = string_to_trace_id128(trace_info.message_id)
-        tool_span_id = RandomIdGenerator().generate_span_id()
-        logger.info("[Arize/Phoenix] Creating tool trace with trace_id: %s, span_id: %s", trace_id, tool_span_id)
-
-        # Create span context with the same trace_id as the parent
-        # todo: Create with the appropriate parent span context, so that the tool span is
-        # a child of the appropriate span (e.g. message span)
-        span_context = SpanContext(
-            trace_id=trace_id,
-            span_id=tool_span_id,
-            is_remote=False,
-            trace_flags=TraceFlags(TraceFlags.SAMPLED),
-            trace_state=TraceState(),
-        )
+        dify_trace_id = trace_info.trace_id or trace_info.message_id
+        self.ensure_root_span(dify_trace_id)
+        root_span_context = self.propagator.extract(carrier=self.carrier)
 
         tool_params_str = (
             json.dumps(trace_info.tool_parameters, ensure_ascii=False)
@@ -612,19 +610,14 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
                 SpanAttributes.TOOL_PARAMETERS: tool_params_str,
             },
             start_time=datetime_to_nanos(trace_info.start_time),
-            context=trace.set_span_in_context(trace.NonRecordingSpan(span_context)),
+            context=root_span_context,
         )
 
         try:
             if trace_info.error:
-                span.add_event(
-                    "exception",
-                    attributes={
-                        "exception.message": trace_info.error,
-                        "exception.type": "Error",
-                        "exception.stacktrace": trace_info.error,
-                    },
-                )
+                set_span_status(span, trace_info.error)
+            else:
+                set_span_status(span)
         finally:
             span.end(end_time=datetime_to_nanos(trace_info.end_time))
 
@@ -641,15 +634,9 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
         }
         metadata.update(trace_info.metadata)
 
-        trace_id = string_to_trace_id128(trace_info.message_id)
-        span_id = RandomIdGenerator().generate_span_id()
-        context = SpanContext(
-            trace_id=trace_id,
-            span_id=span_id,
-            is_remote=False,
-            trace_flags=TraceFlags(TraceFlags.SAMPLED),
-            trace_state=TraceState(),
-        )
+        dify_trace_id = trace_info.trace_id or trace_info.message_id or trace_info.conversation_id
+        self.ensure_root_span(dify_trace_id)
+        root_span_context = self.propagator.extract(carrier=self.carrier)
 
         span = self.tracer.start_span(
             name=TraceTaskName.GENERATE_NAME_TRACE.value,
@@ -663,22 +650,34 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
                 "end_time": trace_info.end_time.isoformat() if trace_info.end_time else "",
             },
             start_time=datetime_to_nanos(trace_info.start_time),
-            context=trace.set_span_in_context(trace.NonRecordingSpan(context)),
+            context=root_span_context,
         )
 
         try:
             if trace_info.message_data.error:
-                span.add_event(
-                    "exception",
-                    attributes={
-                        "exception.message": trace_info.message_data.error,
-                        "exception.type": "Error",
-                        "exception.stacktrace": trace_info.message_data.error,
-                    },
-                )
+                set_span_status(span, trace_info.message_data.error)
+            else:
+                set_span_status(span)
         finally:
             span.end(end_time=datetime_to_nanos(trace_info.end_time))
 
+    def ensure_root_span(self, dify_trace_id: str | None):
+        """Ensure a unique root span exists for the given Dify trace ID."""
+        if str(dify_trace_id) not in self.dify_trace_ids:
+            self.carrier: dict[str, str] = {}
+
+            root_span = self.tracer.start_span(name="Dify")
+            root_span.set_attribute(SpanAttributes.OPENINFERENCE_SPAN_KIND, OpenInferenceSpanKindValues.CHAIN.value)
+            root_span.set_attribute("dify_project_name", str(self.project))
+            root_span.set_attribute("dify_trace_id", str(dify_trace_id))
+
+            with use_span(root_span, end_on_exit=False):
+                self.propagator.inject(carrier=self.carrier)
+
+            set_span_status(root_span)
+            root_span.end()
+            self.dify_trace_ids.add(str(dify_trace_id))
+
     def api_check(self):
         try:
             with self.tracer.start_span("api_check") as span:
@@ -698,26 +697,6 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
             logger.info("[Arize/Phoenix] Get run url failed: %s", str(e), exc_info=True)
             raise ValueError(f"[Arize/Phoenix] Get run url failed: {str(e)}")
 
-    def _get_workflow_nodes(self, workflow_run_id: str):
-        """Helper method to get workflow nodes"""
-        workflow_nodes = db.session.scalars(
-            select(
-                WorkflowNodeExecutionModel.id,
-                WorkflowNodeExecutionModel.tenant_id,
-                WorkflowNodeExecutionModel.app_id,
-                WorkflowNodeExecutionModel.title,
-                WorkflowNodeExecutionModel.node_type,
-                WorkflowNodeExecutionModel.status,
-                WorkflowNodeExecutionModel.inputs,
-                WorkflowNodeExecutionModel.outputs,
-                WorkflowNodeExecutionModel.created_at,
-                WorkflowNodeExecutionModel.elapsed_time,
-                WorkflowNodeExecutionModel.process_data,
-                WorkflowNodeExecutionModel.execution_metadata,
-            ).where(WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id)
-        ).all()
-        return workflow_nodes
-
     def _construct_llm_attributes(self, prompts: dict | list | str | None) -> dict[str, str]:
         """Helper method to construct LLM attributes with passed prompts."""
         attributes = {}
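
# --- Illustrative sketch (not part of the patch) ---
# A minimal, self-contained demo of the root-span pattern this patch adopts:
# start one root span per Dify trace ID, serialize its context into a carrier
# dict with TraceContextTextMapPropagator.inject(), and parent later spans by
# extract()-ing that carrier. Only public OpenTelemetry APIs are used; the
# names `provider`, `demo_tracer`, and `carrier` are local to this sketch,
# not part of Dify.
from opentelemetry.sdk import trace as trace_sdk
from opentelemetry.trace import use_span
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator

provider = trace_sdk.TracerProvider()  # no exporter attached; spans are simply dropped
demo_tracer = provider.get_tracer("demo")
propagator = TraceContextTextMapPropagator()

# Root span: ended immediately, but its context lives on inside the carrier.
carrier: dict[str, str] = {}
root_span = demo_tracer.start_span(name="Dify")
with use_span(root_span, end_on_exit=False):
    propagator.inject(carrier=carrier)  # writes a W3C `traceparent` entry
root_span.end()

# A span started from the extracted context becomes a child of the root span
# and shares its trace ID, even though the root span has already ended.
child_context = propagator.extract(carrier=carrier)
child_span = demo_tracer.start_span(name="workflow", context=child_context)
child_span.end()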
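
# --- Illustrative sketch (not part of the patch) ---
# How the new set_span_status helper handles the two error shapes it accepts:
# an Exception is attached via record_exception(), a bare string becomes a
# synthetic "exception" event, and no error yields StatusCode.OK. Assumes the
# patched module is importable under this path and reuses `demo_tracer` from
# the sketch above.
from core.ops.arize_phoenix_trace.arize_phoenix_trace import set_span_status

try:
    raise ValueError("boom")
except ValueError as exc:
    failed_span = demo_tracer.start_span(name="failed-node")
    set_span_status(failed_span, exc)  # ERROR status + recorded exception
    failed_span.end()

string_error_span = demo_tracer.start_span(name="failed-node-string-error")
set_span_status(string_error_span, "node failed: timeout")  # ERROR + synthetic event
string_error_span.end()

ok_span = demo_tracer.start_span(name="ok-node")
set_span_status(ok_span)  # StatusCode.OK, no exception event
ok_span.end()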