mirror of
https://github.com/langgenius/dify.git
synced 2026-06-07 16:32:01 +08:00
262 lines
8.0 KiB
Python
262 lines
8.0 KiB
Python
import contextlib
|
|
import re
|
|
from collections.abc import Mapping
|
|
from typing import Any
|
|
|
|
from pydantic import BaseModel, ConfigDict, StrictStr, ValidationError
|
|
from werkzeug.exceptions import BadRequest
|
|
|
|
|
|
class ParentTraceContext(BaseModel):
|
|
"""Typed parent trace context propagated from an outer workflow tool node."""
|
|
|
|
parent_workflow_run_id: StrictStr
|
|
parent_node_execution_id: StrictStr | None = None
|
|
|
|
model_config = ConfigDict(extra="forbid")
|
|
|
|
|
|
def is_valid_trace_id(trace_id: str) -> bool:
|
|
"""
|
|
Check if the trace_id is valid.
|
|
|
|
Requirements: 1-128 characters, only letters, numbers, '-', and '_'.
|
|
"""
|
|
return bool(re.match(r"^[a-zA-Z0-9\-_]{1,128}$", trace_id))
|
|
|
|
|
|
def get_external_trace_id(request: Any) -> str | None:
|
|
"""
|
|
Retrieve the trace_id from the request.
|
|
|
|
Priority:
|
|
1. header ('X-Trace-Id')
|
|
2. parameters
|
|
3. JSON body
|
|
4. Current OpenTelemetry context (if enabled)
|
|
5. OpenTelemetry traceparent header (if present and valid)
|
|
|
|
Returns None if no valid trace_id is provided.
|
|
"""
|
|
trace_id = request.headers.get("X-Trace-Id")
|
|
|
|
if not trace_id:
|
|
trace_id = request.args.get("trace_id")
|
|
|
|
if not trace_id and getattr(request, "is_json", False):
|
|
json_data = getattr(request, "json", None)
|
|
if json_data:
|
|
trace_id = json_data.get("trace_id")
|
|
|
|
if not trace_id:
|
|
trace_id = get_trace_id_from_otel_context()
|
|
|
|
if not trace_id:
|
|
traceparent = request.headers.get("traceparent")
|
|
if traceparent:
|
|
trace_id = parse_traceparent_header(traceparent)
|
|
|
|
if isinstance(trace_id, str) and is_valid_trace_id(trace_id):
|
|
return trace_id
|
|
return None
|
|
|
|
|
|
def extract_external_trace_id_from_args(args: Mapping[str, Any]):
|
|
"""
|
|
Extract 'external_trace_id' from args.
|
|
|
|
Returns a dict suitable for use in extras. Returns an empty dict if not found.
|
|
"""
|
|
trace_id = args.get("external_trace_id")
|
|
if trace_id:
|
|
return {"external_trace_id": trace_id}
|
|
return {}
|
|
|
|
|
|
TRACE_SESSION_ID_HEADER = "X-Trace-Session-Id"
|
|
TRACE_SESSION_ID_ARG = "trace_session_id"
|
|
TRACE_SESSION_ID_MAX_LENGTH = 200
|
|
|
|
|
|
def _validate_trace_session_id(value: Any) -> str:
|
|
if not isinstance(value, str):
|
|
raise BadRequest("trace_session_id must be a string.")
|
|
|
|
normalized = value.strip()
|
|
if not normalized:
|
|
raise BadRequest("trace_session_id must be 1 to 200 characters after trimming.")
|
|
if len(normalized) > TRACE_SESSION_ID_MAX_LENGTH:
|
|
raise BadRequest("trace_session_id must be 1 to 200 characters after trimming.")
|
|
return normalized
|
|
|
|
|
|
def get_trace_session_id(request: Any) -> str | None:
|
|
"""
|
|
Resolve the Service API trace session ID from explicit request inputs.
|
|
|
|
Priority is ``X-Trace-Session-Id`` header, then ``trace_session_id`` query
|
|
parameter, then ``trace_session_id`` JSON body field. Only the resolved
|
|
highest-priority input is validated; lower-priority values are ignored.
|
|
"""
|
|
if TRACE_SESSION_ID_HEADER in request.headers:
|
|
return _validate_trace_session_id(request.headers.get(TRACE_SESSION_ID_HEADER))
|
|
|
|
if TRACE_SESSION_ID_ARG in request.args:
|
|
return _validate_trace_session_id(request.args.get(TRACE_SESSION_ID_ARG))
|
|
|
|
if getattr(request, "is_json", False):
|
|
json_data = getattr(request, "json", None)
|
|
if isinstance(json_data, Mapping) and TRACE_SESSION_ID_ARG in json_data:
|
|
return _validate_trace_session_id(json_data.get(TRACE_SESSION_ID_ARG))
|
|
|
|
return None
|
|
|
|
|
|
def extract_trace_session_id_from_args(args: Mapping[str, Any]) -> dict[str, str]:
|
|
"""
|
|
Extract normalized ``trace_session_id`` from generation args for entity extras.
|
|
"""
|
|
trace_session_id = args.get(TRACE_SESSION_ID_ARG)
|
|
if isinstance(trace_session_id, str):
|
|
normalized = trace_session_id.strip()
|
|
if normalized:
|
|
return {TRACE_SESSION_ID_ARG: normalized}
|
|
return {}
|
|
|
|
|
|
def omit_trace_session_id_from_payload(payload: Any) -> Any:
|
|
"""
|
|
Return a payload copy without transport-level ``trace_session_id``.
|
|
|
|
Controllers validate this field through :func:`get_trace_session_id` so lower-priority
|
|
body values cannot fail DTO validation before header/query priority is applied.
|
|
"""
|
|
if isinstance(payload, Mapping) and TRACE_SESSION_ID_ARG in payload:
|
|
return {key: value for key, value in payload.items() if key != TRACE_SESSION_ID_ARG}
|
|
return payload
|
|
|
|
|
|
def extract_parent_trace_context_from_args(args: Mapping[str, Any]) -> dict[str, ParentTraceContext]:
|
|
"""
|
|
Extract 'parent_trace_context' from args.
|
|
|
|
Returns a dict suitable for use in extras when both parent identifiers exist.
|
|
Returns an empty dict if the context is missing or incomplete.
|
|
"""
|
|
parent_trace_context = args.get("parent_trace_context")
|
|
if isinstance(parent_trace_context, ParentTraceContext):
|
|
context = parent_trace_context
|
|
elif isinstance(parent_trace_context, Mapping):
|
|
try:
|
|
context = ParentTraceContext.model_validate(parent_trace_context)
|
|
except ValidationError:
|
|
return {}
|
|
else:
|
|
return {}
|
|
|
|
if context.parent_node_execution_id is None:
|
|
return {}
|
|
|
|
return {"parent_trace_context": context}
|
|
|
|
|
|
def get_trace_id_from_otel_context() -> str | None:
|
|
"""
|
|
Retrieve the current trace ID from the active OpenTelemetry trace context.
|
|
Returns None if:
|
|
1. OpenTelemetry SDK is not installed or enabled.
|
|
2. There is no active span or trace context.
|
|
"""
|
|
try:
|
|
from opentelemetry.trace import SpanContext, get_current_span
|
|
from opentelemetry.trace.span import INVALID_TRACE_ID
|
|
|
|
span = get_current_span()
|
|
if not span:
|
|
return None
|
|
|
|
span_context: SpanContext = span.get_span_context()
|
|
|
|
if not span_context or span_context.trace_id == INVALID_TRACE_ID:
|
|
return None
|
|
|
|
trace_id_hex = f"{span_context.trace_id:032x}"
|
|
return trace_id_hex
|
|
|
|
except Exception:
|
|
return None
|
|
|
|
|
|
def parse_traceparent_header(traceparent: str) -> str | None:
|
|
"""
|
|
Parse the `traceparent` header to extract the trace_id.
|
|
|
|
Expected format:
|
|
'version-trace_id-span_id-flags'
|
|
|
|
Reference:
|
|
W3C Trace Context Specification: https://www.w3.org/TR/trace-context/
|
|
"""
|
|
with contextlib.suppress(Exception):
|
|
parts = traceparent.split("-")
|
|
if len(parts) == 4 and len(parts[1]) == 32:
|
|
return parts[1]
|
|
return None
|
|
|
|
|
|
def get_span_id_from_otel_context() -> str | None:
|
|
"""
|
|
Retrieve the current span ID from the active OpenTelemetry trace context.
|
|
|
|
Returns:
|
|
A 16-character hex string representing the span ID, or None if not available.
|
|
"""
|
|
try:
|
|
from opentelemetry.trace import get_current_span
|
|
from opentelemetry.trace.span import INVALID_SPAN_ID
|
|
|
|
span = get_current_span()
|
|
if not span:
|
|
return None
|
|
|
|
span_context = span.get_span_context()
|
|
if not span_context or span_context.span_id == INVALID_SPAN_ID:
|
|
return None
|
|
|
|
return f"{span_context.span_id:016x}"
|
|
except Exception:
|
|
return None
|
|
|
|
|
|
def generate_traceparent_header() -> str | None:
|
|
"""
|
|
Generate a W3C traceparent header from the current context.
|
|
|
|
Uses OpenTelemetry context if available, otherwise uses the
|
|
ContextVar-based trace_id from the logging context.
|
|
|
|
Format: {version}-{trace_id}-{span_id}-{flags}
|
|
Example: 00-5b8aa5a2d2c872e8321cf37308d69df2-051581bf3bb55c45-01
|
|
|
|
Returns:
|
|
A valid traceparent header string, or None if generation fails.
|
|
"""
|
|
import uuid
|
|
|
|
# Try OTEL context first
|
|
trace_id = get_trace_id_from_otel_context()
|
|
span_id = get_span_id_from_otel_context()
|
|
|
|
if trace_id and span_id:
|
|
return f"00-{trace_id}-{span_id}-01"
|
|
|
|
# Fallback: use ContextVar-based trace_id or generate new one
|
|
from core.logging.context import get_trace_id as get_logging_trace_id
|
|
|
|
trace_id = get_logging_trace_id() or uuid.uuid4().hex
|
|
|
|
# Generate a new span_id (16 hex chars)
|
|
span_id = uuid.uuid4().hex[:16]
|
|
|
|
return f"00-{trace_id}-{span_id}-01"
|