use the contextvar instead of flask.g

This commit is contained in:
Byron Wang 2025-12-25 15:57:28 +08:00
parent c37d4b765f
commit 26d12b34ce
No known key found for this signature in database
GPG Key ID: 335E934E215AD579
8 changed files with 184 additions and 54 deletions

View File

@ -25,7 +25,10 @@ def create_flask_app_with_configs() -> DifyApp:
# add before request hook
@dify_app.before_request
def before_request():
# add a unique identifier to each request
# Initialize logging context for this request
from core.logging.context import init_request_context
init_request_context()
RecyclableContextVar.increment_thread_recycles()
# add after request hook for injecting trace headers from OpenTelemetry span context

View File

@ -133,8 +133,8 @@ def generate_traceparent_header() -> str | None:
"""
Generate a W3C traceparent header from the current context.
Uses OpenTelemetry context if available, otherwise generates new IDs
based on the Flask request context.
Uses OpenTelemetry context if available, otherwise uses the
ContextVar-based trace_id from the logging context.
Format: {version}-{trace_id}-{span_id}-{flags}
Example: 00-5b8aa5a2d2c872e8321cf37308d69df2-051581bf3bb55c45-01
@ -151,17 +151,10 @@ def generate_traceparent_header() -> str | None:
if trace_id and span_id:
return f"00-{trace_id}-{span_id}-01"
# Fallback: generate new trace context
try:
import flask
# Fallback: use ContextVar-based trace_id or generate new one
from core.logging.context import get_trace_id as get_logging_trace_id
if flask.has_request_context() and hasattr(flask.g, "request_id"):
# Derive trace_id from request_id for consistency
trace_id = uuid.uuid5(uuid.NAMESPACE_DNS, flask.g.request_id).hex
else:
trace_id = uuid.uuid4().hex
except Exception:
trace_id = uuid.uuid4().hex
trace_id = get_logging_trace_id() or uuid.uuid4().hex
# Generate a new span_id (16 hex chars)
span_id = uuid.uuid4().hex[:16]

View File

@ -1,5 +1,11 @@
"""Structured logging components for Dify."""
from core.logging.context import (
clear_request_context,
get_request_id,
get_trace_id,
init_request_context,
)
from core.logging.filters import IdentityContextFilter, TraceContextFilter
from core.logging.structured_formatter import StructuredJSONFormatter
@ -7,4 +13,8 @@ __all__ = [
"IdentityContextFilter",
"StructuredJSONFormatter",
"TraceContextFilter",
"clear_request_context",
"get_request_id",
"get_trace_id",
"init_request_context",
]

View File

@ -0,0 +1,35 @@
"""Request context for logging - framework agnostic.
This module provides request-scoped context variables for logging,
using Python's contextvars for thread-safe and async-safe storage.
"""
import uuid
from contextvars import ContextVar
# Context variables holding the logging IDs for the current execution context.
# Both default to the empty string, so reads never raise LookupError when
# nothing has been stored yet.
_request_id: ContextVar[str] = ContextVar("log_request_id", default="")
_trace_id: ContextVar[str] = ContextVar("log_trace_id", default="")
def get_request_id() -> str:
    """Return the request ID stored for the current context.

    Returns:
        The value held by the ``_request_id`` context variable. This is the
        empty string when no request ID has been stored in this context.
    """
    current_request_id = _request_id.get()
    return current_request_id
def get_trace_id() -> str:
    """Return the fallback trace ID stored for the current context.

    This is the trace ID intended for use when OTEL does not supply one.

    Returns:
        The value held by the ``_trace_id`` context variable. This is the
        empty string when no trace ID has been stored in this context.
    """
    current_trace_id = _trace_id.get()
    return current_trace_id
def init_request_context() -> None:
    """Create fresh request and trace IDs for the current context.

    Intended to be called at the start of each request. The request ID is
    the first 10 hex characters of a random UUID4; the trace ID is the
    32-character hex form of a UUID5 derived from that request ID, so the
    two values are always consistent with each other.
    """
    new_request_id = uuid.uuid4().hex[:10]
    _request_id.set(new_request_id)
    # Deriving the trace ID with uuid5 makes it reproducible from the request ID.
    _trace_id.set(uuid.uuid5(uuid.NAMESPACE_DNS, new_request_id).hex)
def clear_request_context() -> None:
    """Reset the request and trace IDs to empty strings.

    May be called at the end of a request; calling it is optional.
    """
    for context_var in (_request_id, _trace_id):
        context_var.set("")

View File

@ -2,31 +2,32 @@
import contextlib
import logging
import uuid
import flask
from core.logging.context import get_request_id, get_trace_id
class TraceContextFilter(logging.Filter):
"""
Filter that adds trace_id and span_id to log records.
Integrates with OpenTelemetry when available, falls back to request_id.
Integrates with OpenTelemetry when available, falls back to ContextVar-based trace_id.
"""
def filter(self, record: logging.LogRecord) -> bool:
# Get trace context from OpenTelemetry
trace_id, span_id = self._get_otel_context()
# Set trace_id (fallback to request_id if no OTEL context)
# Set trace_id (fallback to ContextVar if no OTEL context)
if trace_id:
record.trace_id = trace_id
else:
record.trace_id = self._get_or_create_request_trace_id()
record.trace_id = get_trace_id()
record.span_id = span_id or ""
# For backward compatibility, also set req_id
record.req_id = self._get_request_id()
record.req_id = get_request_id()
return True
@ -45,28 +46,6 @@ class TraceContextFilter(logging.Filter):
return trace_id, span_id
return "", ""
def _get_request_id(self) -> str:
"""Get request ID from Flask context."""
if flask.has_request_context():
if hasattr(flask.g, "request_id"):
return flask.g.request_id
flask.g.request_id = uuid.uuid4().hex[:10]
return flask.g.request_id
return ""
def _get_or_create_request_trace_id(self) -> str:
"""Get or create a trace_id derived from request context."""
if flask.has_request_context():
if hasattr(flask.g, "_trace_id"):
return flask.g._trace_id
# Derive trace_id from request_id for consistency
request_id = self._get_request_id()
if request_id:
# Generate a 32-char hex trace_id from request_id
flask.g._trace_id = uuid.uuid5(uuid.NAMESPACE_DNS, request_id).hex
return flask.g._trace_id
return ""
class IdentityContextFilter(logging.Filter):
"""

View File

@ -3,11 +3,8 @@
import logging
import os
import sys
import uuid
from logging.handlers import RotatingFileHandler
import flask
from configs import dify_config
from dify_app import DifyApp
@ -106,13 +103,13 @@ class _TextFormatter(logging.Formatter):
def get_request_id() -> str:
"""Get or create request ID for current request context."""
if flask.has_request_context():
if getattr(flask.g, "request_id", None):
return flask.g.request_id
flask.g.request_id = uuid.uuid4().hex[:10]
return flask.g.request_id
return ""
"""Get request ID for current request context.
Deprecated: Use core.logging.context.get_request_id() directly.
"""
from core.logging.context import get_request_id as _get_request_id
return _get_request_id()
# Backward compatibility aliases
@ -120,11 +117,11 @@ class RequestIdFilter(logging.Filter):
"""Deprecated: Use TraceContextFilter from core.logging.filters instead."""
def filter(self, record: logging.LogRecord) -> bool:
from core.helper.trace_id_helper import get_trace_id_from_otel_context
from core.logging.context import get_request_id as _get_request_id
from core.logging.context import get_trace_id as _get_trace_id
trace_id = get_trace_id_from_otel_context() or ""
record.req_id = get_request_id() if flask.has_request_context() else ""
record.trace_id = trace_id
record.req_id = _get_request_id()
record.trace_id = _get_trace_id()
return True

View File

@ -0,0 +1,79 @@
"""Tests for logging context module."""
import uuid
from core.logging.context import (
clear_request_context,
get_request_id,
get_trace_id,
init_request_context,
)
class TestLoggingContext:
    """Tests for the logging context functions."""

    def test_init_creates_request_id(self):
        """init_request_context should create a 10-char lowercase-hex request ID."""
        init_request_context()
        request_id = get_request_id()

        assert len(request_id) == 10
        assert all(c in "0123456789abcdef" for c in request_id)

    def test_init_creates_trace_id(self):
        """init_request_context should create a 32-char lowercase-hex trace ID."""
        init_request_context()
        trace_id = get_trace_id()

        assert len(trace_id) == 32
        assert all(c in "0123456789abcdef" for c in trace_id)

    def test_trace_id_derived_from_request_id(self):
        """trace_id should be deterministically derived from request_id."""
        init_request_context()
        request_id = get_request_id()
        trace_id = get_trace_id()

        # Verify trace_id is derived using uuid5
        expected_trace = uuid.uuid5(uuid.NAMESPACE_DNS, request_id).hex
        assert trace_id == expected_trace

    def test_clear_resets_context(self):
        """clear_request_context should reset both IDs to empty strings."""
        init_request_context()
        assert get_request_id() != ""
        assert get_trace_id() != ""

        clear_request_context()
        assert get_request_id() == ""
        assert get_trace_id() == ""

    def test_default_values_are_empty(self):
        """Both IDs should read as empty strings in a context where nothing was set."""
        import contextvars

        # A brand-new Context holds no values, so reads inside it return each
        # ContextVar's declared default rather than whatever an earlier test
        # (or clear_request_context) stored in the current context.
        fresh_context = contextvars.Context()

        assert fresh_context.run(get_request_id) == ""
        assert fresh_context.run(get_trace_id) == ""

    def test_multiple_inits_create_different_ids(self):
        """Each init should create new unique IDs."""
        init_request_context()
        first_request_id = get_request_id()
        first_trace_id = get_trace_id()

        init_request_context()
        second_request_id = get_request_id()
        second_trace_id = get_trace_id()

        assert first_request_id != second_request_id
        assert first_trace_id != second_trace_id

    def test_context_isolation(self):
        """IDs initialized inside a copied context must not leak into the caller's context."""
        import contextvars

        init_request_context()
        outer_request_id = get_request_id()
        outer_trace_id = get_trace_id()

        def simulate_other_request():
            # Runs inside a copy of the current context, the way a separate
            # request/task would; its set() calls only affect that copy.
            init_request_context()
            return get_request_id(), get_trace_id()

        inner_request_id, inner_trace_id = contextvars.copy_context().run(simulate_other_request)

        # The simulated request received its own IDs...
        assert inner_request_id != outer_request_id
        assert inner_trace_id != outer_trace_id
        # ...and the caller's IDs are unchanged afterwards.
        assert get_request_id() == outer_request_id
        assert get_trace_id() == outer_trace_id

View File

@ -6,8 +6,12 @@ from unittest import mock
class TestTraceContextFilter:
def test_sets_empty_trace_id_without_context(self):
from core.logging.context import clear_request_context
from core.logging.filters import TraceContextFilter
# Ensure no context is set
clear_request_context()
filter = TraceContextFilter()
record = logging.LogRecord(
name="test",
@ -25,6 +29,36 @@ class TestTraceContextFilter:
assert hasattr(record, "trace_id")
assert hasattr(record, "span_id")
assert hasattr(record, "req_id")
# Without context, IDs should be empty
assert record.trace_id == ""
assert record.req_id == ""
def test_sets_trace_id_from_context(self):
"""Test that trace_id and req_id are set from ContextVar when initialized."""
from core.logging.context import init_request_context
from core.logging.filters import TraceContextFilter
# Initialize context (no Flask needed!)
init_request_context()
filter = TraceContextFilter()
record = logging.LogRecord(
name="test",
level=logging.INFO,
pathname="",
lineno=0,
msg="test",
args=(),
exc_info=None,
)
filter.filter(record)
# With context initialized, IDs should be set
assert record.trace_id != ""
assert len(record.trace_id) == 32
assert record.req_id != ""
assert len(record.req_id) == 10
def test_filter_always_returns_true(self):
from core.logging.filters import TraceContextFilter