mirror of https://github.com/langgenius/dify.git
use the contextvar instead of flask.g
This commit is contained in:
parent
c37d4b765f
commit
26d12b34ce
|
|
@ -25,7 +25,10 @@ def create_flask_app_with_configs() -> DifyApp:
|
|||
# add before request hook
|
||||
@dify_app.before_request
|
||||
def before_request():
|
||||
# add an unique identifier to each request
|
||||
# Initialize logging context for this request
|
||||
from core.logging.context import init_request_context
|
||||
|
||||
init_request_context()
|
||||
RecyclableContextVar.increment_thread_recycles()
|
||||
|
||||
# add after request hook for injecting trace headers from OpenTelemetry span context
|
||||
|
|
|
|||
|
|
@ -133,8 +133,8 @@ def generate_traceparent_header() -> str | None:
|
|||
"""
|
||||
Generate a W3C traceparent header from the current context.
|
||||
|
||||
Uses OpenTelemetry context if available, otherwise generates new IDs
|
||||
based on the Flask request context.
|
||||
Uses OpenTelemetry context if available, otherwise uses the
|
||||
ContextVar-based trace_id from the logging context.
|
||||
|
||||
Format: {version}-{trace_id}-{span_id}-{flags}
|
||||
Example: 00-5b8aa5a2d2c872e8321cf37308d69df2-051581bf3bb55c45-01
|
||||
|
|
@ -151,17 +151,10 @@ def generate_traceparent_header() -> str | None:
|
|||
if trace_id and span_id:
|
||||
return f"00-{trace_id}-{span_id}-01"
|
||||
|
||||
# Fallback: generate new trace context
|
||||
try:
|
||||
import flask
|
||||
# Fallback: use ContextVar-based trace_id or generate new one
|
||||
from core.logging.context import get_trace_id as get_logging_trace_id
|
||||
|
||||
if flask.has_request_context() and hasattr(flask.g, "request_id"):
|
||||
# Derive trace_id from request_id for consistency
|
||||
trace_id = uuid.uuid5(uuid.NAMESPACE_DNS, flask.g.request_id).hex
|
||||
else:
|
||||
trace_id = uuid.uuid4().hex
|
||||
except Exception:
|
||||
trace_id = uuid.uuid4().hex
|
||||
trace_id = get_logging_trace_id() or uuid.uuid4().hex
|
||||
|
||||
# Generate a new span_id (16 hex chars)
|
||||
span_id = uuid.uuid4().hex[:16]
|
||||
|
|
|
|||
|
|
@ -1,5 +1,11 @@
|
|||
"""Structured logging components for Dify."""
|
||||
|
||||
from core.logging.context import (
|
||||
clear_request_context,
|
||||
get_request_id,
|
||||
get_trace_id,
|
||||
init_request_context,
|
||||
)
|
||||
from core.logging.filters import IdentityContextFilter, TraceContextFilter
|
||||
from core.logging.structured_formatter import StructuredJSONFormatter
|
||||
|
||||
|
|
@ -7,4 +13,8 @@ __all__ = [
|
|||
"IdentityContextFilter",
|
||||
"StructuredJSONFormatter",
|
||||
"TraceContextFilter",
|
||||
"clear_request_context",
|
||||
"get_request_id",
|
||||
"get_trace_id",
|
||||
"init_request_context",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,35 @@
|
|||
"""Request context for logging - framework agnostic.
|
||||
|
||||
This module provides request-scoped context variables for logging,
|
||||
using Python's contextvars for thread-safe and async-safe storage.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from contextvars import ContextVar
|
||||
|
||||
_request_id: ContextVar[str] = ContextVar("log_request_id", default="")
|
||||
_trace_id: ContextVar[str] = ContextVar("log_trace_id", default="")
|
||||
|
||||
|
||||
def get_request_id() -> str:
|
||||
"""Get current request ID (10 hex chars)."""
|
||||
return _request_id.get()
|
||||
|
||||
|
||||
def get_trace_id() -> str:
|
||||
"""Get fallback trace ID when OTEL is unavailable (32 hex chars)."""
|
||||
return _trace_id.get()
|
||||
|
||||
|
||||
def init_request_context() -> None:
|
||||
"""Initialize request context. Call at start of each request."""
|
||||
req_id = uuid.uuid4().hex[:10]
|
||||
trace_id = uuid.uuid5(uuid.NAMESPACE_DNS, req_id).hex
|
||||
_request_id.set(req_id)
|
||||
_trace_id.set(trace_id)
|
||||
|
||||
|
||||
def clear_request_context() -> None:
|
||||
"""Clear request context. Call at end of request (optional)."""
|
||||
_request_id.set("")
|
||||
_trace_id.set("")
|
||||
|
|
@ -2,31 +2,32 @@
|
|||
|
||||
import contextlib
|
||||
import logging
|
||||
import uuid
|
||||
|
||||
import flask
|
||||
|
||||
from core.logging.context import get_request_id, get_trace_id
|
||||
|
||||
|
||||
class TraceContextFilter(logging.Filter):
|
||||
"""
|
||||
Filter that adds trace_id and span_id to log records.
|
||||
Integrates with OpenTelemetry when available, falls back to request_id.
|
||||
Integrates with OpenTelemetry when available, falls back to ContextVar-based trace_id.
|
||||
"""
|
||||
|
||||
def filter(self, record: logging.LogRecord) -> bool:
|
||||
# Get trace context from OpenTelemetry
|
||||
trace_id, span_id = self._get_otel_context()
|
||||
|
||||
# Set trace_id (fallback to request_id if no OTEL context)
|
||||
# Set trace_id (fallback to ContextVar if no OTEL context)
|
||||
if trace_id:
|
||||
record.trace_id = trace_id
|
||||
else:
|
||||
record.trace_id = self._get_or_create_request_trace_id()
|
||||
record.trace_id = get_trace_id()
|
||||
|
||||
record.span_id = span_id or ""
|
||||
|
||||
# For backward compatibility, also set req_id
|
||||
record.req_id = self._get_request_id()
|
||||
record.req_id = get_request_id()
|
||||
|
||||
return True
|
||||
|
||||
|
|
@ -45,28 +46,6 @@ class TraceContextFilter(logging.Filter):
|
|||
return trace_id, span_id
|
||||
return "", ""
|
||||
|
||||
def _get_request_id(self) -> str:
|
||||
"""Get request ID from Flask context."""
|
||||
if flask.has_request_context():
|
||||
if hasattr(flask.g, "request_id"):
|
||||
return flask.g.request_id
|
||||
flask.g.request_id = uuid.uuid4().hex[:10]
|
||||
return flask.g.request_id
|
||||
return ""
|
||||
|
||||
def _get_or_create_request_trace_id(self) -> str:
|
||||
"""Get or create a trace_id derived from request context."""
|
||||
if flask.has_request_context():
|
||||
if hasattr(flask.g, "_trace_id"):
|
||||
return flask.g._trace_id
|
||||
# Derive trace_id from request_id for consistency
|
||||
request_id = self._get_request_id()
|
||||
if request_id:
|
||||
# Generate a 32-char hex trace_id from request_id
|
||||
flask.g._trace_id = uuid.uuid5(uuid.NAMESPACE_DNS, request_id).hex
|
||||
return flask.g._trace_id
|
||||
return ""
|
||||
|
||||
|
||||
class IdentityContextFilter(logging.Filter):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -3,11 +3,8 @@
|
|||
import logging
|
||||
import os
|
||||
import sys
|
||||
import uuid
|
||||
from logging.handlers import RotatingFileHandler
|
||||
|
||||
import flask
|
||||
|
||||
from configs import dify_config
|
||||
from dify_app import DifyApp
|
||||
|
||||
|
|
@ -106,13 +103,13 @@ class _TextFormatter(logging.Formatter):
|
|||
|
||||
|
||||
def get_request_id() -> str:
|
||||
"""Get or create request ID for current request context."""
|
||||
if flask.has_request_context():
|
||||
if getattr(flask.g, "request_id", None):
|
||||
return flask.g.request_id
|
||||
flask.g.request_id = uuid.uuid4().hex[:10]
|
||||
return flask.g.request_id
|
||||
return ""
|
||||
"""Get request ID for current request context.
|
||||
|
||||
Deprecated: Use core.logging.context.get_request_id() directly.
|
||||
"""
|
||||
from core.logging.context import get_request_id as _get_request_id
|
||||
|
||||
return _get_request_id()
|
||||
|
||||
|
||||
# Backward compatibility aliases
|
||||
|
|
@ -120,11 +117,11 @@ class RequestIdFilter(logging.Filter):
|
|||
"""Deprecated: Use TraceContextFilter from core.logging.filters instead."""
|
||||
|
||||
def filter(self, record: logging.LogRecord) -> bool:
|
||||
from core.helper.trace_id_helper import get_trace_id_from_otel_context
|
||||
from core.logging.context import get_request_id as _get_request_id
|
||||
from core.logging.context import get_trace_id as _get_trace_id
|
||||
|
||||
trace_id = get_trace_id_from_otel_context() or ""
|
||||
record.req_id = get_request_id() if flask.has_request_context() else ""
|
||||
record.trace_id = trace_id
|
||||
record.req_id = _get_request_id()
|
||||
record.trace_id = _get_trace_id()
|
||||
return True
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,79 @@
|
|||
"""Tests for logging context module."""
|
||||
|
||||
import uuid
|
||||
|
||||
from core.logging.context import (
|
||||
clear_request_context,
|
||||
get_request_id,
|
||||
get_trace_id,
|
||||
init_request_context,
|
||||
)
|
||||
|
||||
|
||||
class TestLoggingContext:
|
||||
"""Tests for the logging context functions."""
|
||||
|
||||
def test_init_creates_request_id(self):
|
||||
"""init_request_context should create a 10-char request ID."""
|
||||
init_request_context()
|
||||
request_id = get_request_id()
|
||||
assert len(request_id) == 10
|
||||
assert all(c in "0123456789abcdef" for c in request_id)
|
||||
|
||||
def test_init_creates_trace_id(self):
|
||||
"""init_request_context should create a 32-char trace ID."""
|
||||
init_request_context()
|
||||
trace_id = get_trace_id()
|
||||
assert len(trace_id) == 32
|
||||
assert all(c in "0123456789abcdef" for c in trace_id)
|
||||
|
||||
def test_trace_id_derived_from_request_id(self):
|
||||
"""trace_id should be deterministically derived from request_id."""
|
||||
init_request_context()
|
||||
request_id = get_request_id()
|
||||
trace_id = get_trace_id()
|
||||
|
||||
# Verify trace_id is derived using uuid5
|
||||
expected_trace = uuid.uuid5(uuid.NAMESPACE_DNS, request_id).hex
|
||||
assert trace_id == expected_trace
|
||||
|
||||
def test_clear_resets_context(self):
|
||||
"""clear_request_context should reset both IDs to empty strings."""
|
||||
init_request_context()
|
||||
assert get_request_id() != ""
|
||||
assert get_trace_id() != ""
|
||||
|
||||
clear_request_context()
|
||||
assert get_request_id() == ""
|
||||
assert get_trace_id() == ""
|
||||
|
||||
def test_default_values_are_empty(self):
|
||||
"""Default values should be empty strings before init."""
|
||||
clear_request_context()
|
||||
assert get_request_id() == ""
|
||||
assert get_trace_id() == ""
|
||||
|
||||
def test_multiple_inits_create_different_ids(self):
|
||||
"""Each init should create new unique IDs."""
|
||||
init_request_context()
|
||||
first_request_id = get_request_id()
|
||||
first_trace_id = get_trace_id()
|
||||
|
||||
init_request_context()
|
||||
second_request_id = get_request_id()
|
||||
second_trace_id = get_trace_id()
|
||||
|
||||
assert first_request_id != second_request_id
|
||||
assert first_trace_id != second_trace_id
|
||||
|
||||
def test_context_isolation(self):
|
||||
"""Context should be isolated per-call (no thread leakage in same thread)."""
|
||||
init_request_context()
|
||||
id1 = get_request_id()
|
||||
|
||||
# Simulate another request
|
||||
init_request_context()
|
||||
id2 = get_request_id()
|
||||
|
||||
# IDs should be different
|
||||
assert id1 != id2
|
||||
|
|
@ -6,8 +6,12 @@ from unittest import mock
|
|||
|
||||
class TestTraceContextFilter:
|
||||
def test_sets_empty_trace_id_without_context(self):
|
||||
from core.logging.context import clear_request_context
|
||||
from core.logging.filters import TraceContextFilter
|
||||
|
||||
# Ensure no context is set
|
||||
clear_request_context()
|
||||
|
||||
filter = TraceContextFilter()
|
||||
record = logging.LogRecord(
|
||||
name="test",
|
||||
|
|
@ -25,6 +29,36 @@ class TestTraceContextFilter:
|
|||
assert hasattr(record, "trace_id")
|
||||
assert hasattr(record, "span_id")
|
||||
assert hasattr(record, "req_id")
|
||||
# Without context, IDs should be empty
|
||||
assert record.trace_id == ""
|
||||
assert record.req_id == ""
|
||||
|
||||
def test_sets_trace_id_from_context(self):
|
||||
"""Test that trace_id and req_id are set from ContextVar when initialized."""
|
||||
from core.logging.context import init_request_context
|
||||
from core.logging.filters import TraceContextFilter
|
||||
|
||||
# Initialize context (no Flask needed!)
|
||||
init_request_context()
|
||||
|
||||
filter = TraceContextFilter()
|
||||
record = logging.LogRecord(
|
||||
name="test",
|
||||
level=logging.INFO,
|
||||
pathname="",
|
||||
lineno=0,
|
||||
msg="test",
|
||||
args=(),
|
||||
exc_info=None,
|
||||
)
|
||||
|
||||
filter.filter(record)
|
||||
|
||||
# With context initialized, IDs should be set
|
||||
assert record.trace_id != ""
|
||||
assert len(record.trace_id) == 32
|
||||
assert record.req_id != ""
|
||||
assert len(record.req_id) == 10
|
||||
|
||||
def test_filter_always_returns_true(self):
|
||||
from core.logging.filters import TraceContextFilter
|
||||
|
|
|
|||
Loading…
Reference in New Issue