diff --git a/api/app_factory.py b/api/app_factory.py index d93f51d4a3..7b45736d95 100644 --- a/api/app_factory.py +++ b/api/app_factory.py @@ -25,7 +25,10 @@ def create_flask_app_with_configs() -> DifyApp: # add before request hook @dify_app.before_request def before_request(): - # add an unique identifier to each request + # Initialize logging context for this request + from core.logging.context import init_request_context + + init_request_context() RecyclableContextVar.increment_thread_recycles() # add after request hook for injecting trace headers from OpenTelemetry span context diff --git a/api/core/helper/trace_id_helper.py b/api/core/helper/trace_id_helper.py index 26fef55d52..e827859109 100644 --- a/api/core/helper/trace_id_helper.py +++ b/api/core/helper/trace_id_helper.py @@ -133,8 +133,8 @@ def generate_traceparent_header() -> str | None: """ Generate a W3C traceparent header from the current context. - Uses OpenTelemetry context if available, otherwise generates new IDs - based on the Flask request context. + Uses OpenTelemetry context if available, otherwise uses the + ContextVar-based trace_id from the logging context. Format: {version}-{trace_id}-{span_id}-{flags} Example: 00-5b8aa5a2d2c872e8321cf37308d69df2-051581bf3bb55c45-01 @@ -151,17 +151,10 @@ def generate_traceparent_header() -> str | None: if trace_id and span_id: return f"00-{trace_id}-{span_id}-01" - # Fallback: generate new trace context - try: - import flask + # Fallback: use ContextVar-based trace_id or generate new one + from core.logging.context import get_trace_id as get_logging_trace_id - if flask.has_request_context() and hasattr(flask.g, "request_id"): - # Derive trace_id from request_id for consistency - trace_id = uuid.uuid5(uuid.NAMESPACE_DNS, flask.g.request_id).hex - else: - trace_id = uuid.uuid4().hex - except Exception: - trace_id = uuid.uuid4().hex + trace_id = get_logging_trace_id() or uuid.uuid4().hex # Generate a new span_id (16 hex chars) span_id = uuid.uuid4().hex[:16] diff --git a/api/core/logging/__init__.py b/api/core/logging/__init__.py index 66d5643251..db046cc9fa 100644 --- a/api/core/logging/__init__.py +++ b/api/core/logging/__init__.py @@ -1,5 +1,11 @@ """Structured logging components for Dify.""" +from core.logging.context import ( + clear_request_context, + get_request_id, + get_trace_id, + init_request_context, +) from core.logging.filters import IdentityContextFilter, TraceContextFilter from core.logging.structured_formatter import StructuredJSONFormatter @@ -7,4 +13,8 @@ __all__ = [ "IdentityContextFilter", "StructuredJSONFormatter", "TraceContextFilter", + "clear_request_context", + "get_request_id", + "get_trace_id", + "init_request_context", ] diff --git a/api/core/logging/context.py b/api/core/logging/context.py new file mode 100644 index 0000000000..18633a0b05 --- /dev/null +++ b/api/core/logging/context.py @@ -0,0 +1,35 @@ +"""Request context for logging - framework agnostic. + +This module provides request-scoped context variables for logging, +using Python's contextvars for thread-safe and async-safe storage. +""" + +import uuid +from contextvars import ContextVar + +_request_id: ContextVar[str] = ContextVar("log_request_id", default="") +_trace_id: ContextVar[str] = ContextVar("log_trace_id", default="") + + +def get_request_id() -> str: + """Get current request ID (10 hex chars).""" + return _request_id.get() + + +def get_trace_id() -> str: + """Get fallback trace ID when OTEL is unavailable (32 hex chars).""" + return _trace_id.get() + + +def init_request_context() -> None: + """Initialize request context. Call at start of each request.""" + req_id = uuid.uuid4().hex[:10] + trace_id = uuid.uuid5(uuid.NAMESPACE_DNS, req_id).hex + _request_id.set(req_id) + _trace_id.set(trace_id) + + +def clear_request_context() -> None: + """Clear request context. Call at end of request (optional).""" + _request_id.set("") + _trace_id.set("") diff --git a/api/core/logging/filters.py b/api/core/logging/filters.py index 9330d32068..1e8aa8d566 100644 --- a/api/core/logging/filters.py +++ b/api/core/logging/filters.py @@ -2,31 +2,32 @@ import contextlib import logging -import uuid import flask +from core.logging.context import get_request_id, get_trace_id + class TraceContextFilter(logging.Filter): """ Filter that adds trace_id and span_id to log records. - Integrates with OpenTelemetry when available, falls back to request_id. + Integrates with OpenTelemetry when available, falls back to ContextVar-based trace_id. """ def filter(self, record: logging.LogRecord) -> bool: # Get trace context from OpenTelemetry trace_id, span_id = self._get_otel_context() - # Set trace_id (fallback to request_id if no OTEL context) + # Set trace_id (fallback to ContextVar if no OTEL context) if trace_id: record.trace_id = trace_id else: - record.trace_id = self._get_or_create_request_trace_id() + record.trace_id = get_trace_id() record.span_id = span_id or "" # For backward compatibility, also set req_id - record.req_id = self._get_request_id() + record.req_id = get_request_id() return True @@ -45,28 +46,6 @@ class TraceContextFilter(logging.Filter): return trace_id, span_id return "", "" - def _get_request_id(self) -> str: - """Get request ID from Flask context.""" - if flask.has_request_context(): - if hasattr(flask.g, "request_id"): - return flask.g.request_id - flask.g.request_id = uuid.uuid4().hex[:10] - return flask.g.request_id - return "" - - def _get_or_create_request_trace_id(self) -> str: - """Get or create a trace_id derived from request context.""" - if flask.has_request_context(): - if hasattr(flask.g, "_trace_id"): - return flask.g._trace_id - # Derive trace_id from request_id for consistency - request_id = self._get_request_id() - if request_id: - # Generate a 32-char hex trace_id from request_id - flask.g._trace_id = uuid.uuid5(uuid.NAMESPACE_DNS, request_id).hex - return flask.g._trace_id - return "" - class IdentityContextFilter(logging.Filter): """ diff --git a/api/extensions/ext_logging.py b/api/extensions/ext_logging.py index 3bf01a3bdc..978a40c503 100644 --- a/api/extensions/ext_logging.py +++ b/api/extensions/ext_logging.py @@ -3,11 +3,8 @@ import logging import os import sys -import uuid from logging.handlers import RotatingFileHandler -import flask - from configs import dify_config from dify_app import DifyApp @@ -106,13 +103,13 @@ class _TextFormatter(logging.Formatter): def get_request_id() -> str: - """Get or create request ID for current request context.""" - if flask.has_request_context(): - if getattr(flask.g, "request_id", None): - return flask.g.request_id - flask.g.request_id = uuid.uuid4().hex[:10] - return flask.g.request_id - return "" + """Get request ID for current request context. + + Deprecated: Use core.logging.context.get_request_id() directly. + """ + from core.logging.context import get_request_id as _get_request_id + + return _get_request_id() # Backward compatibility aliases @@ -120,11 +117,11 @@ class RequestIdFilter(logging.Filter): """Deprecated: Use TraceContextFilter from core.logging.filters instead.""" def filter(self, record: logging.LogRecord) -> bool: - from core.helper.trace_id_helper import get_trace_id_from_otel_context + from core.logging.context import get_request_id as _get_request_id + from core.logging.context import get_trace_id as _get_trace_id - trace_id = get_trace_id_from_otel_context() or "" - record.req_id = get_request_id() if flask.has_request_context() else "" - record.trace_id = trace_id + record.req_id = _get_request_id() + record.trace_id = _get_trace_id() return True diff --git a/api/tests/unit_tests/core/logging/test_context.py b/api/tests/unit_tests/core/logging/test_context.py new file mode 100644 index 0000000000..f388a3a0b9 --- /dev/null +++ b/api/tests/unit_tests/core/logging/test_context.py @@ -0,0 +1,79 @@ +"""Tests for logging context module.""" + +import uuid + +from core.logging.context import ( + clear_request_context, + get_request_id, + get_trace_id, + init_request_context, +) + + +class TestLoggingContext: + """Tests for the logging context functions.""" + + def test_init_creates_request_id(self): + """init_request_context should create a 10-char request ID.""" + init_request_context() + request_id = get_request_id() + assert len(request_id) == 10 + assert all(c in "0123456789abcdef" for c in request_id) + + def test_init_creates_trace_id(self): + """init_request_context should create a 32-char trace ID.""" + init_request_context() + trace_id = get_trace_id() + assert len(trace_id) == 32 + assert all(c in "0123456789abcdef" for c in trace_id) + + def test_trace_id_derived_from_request_id(self): + """trace_id should be deterministically derived from request_id.""" + init_request_context() + request_id = get_request_id() + trace_id = get_trace_id() + + # Verify trace_id is derived using uuid5 + expected_trace = uuid.uuid5(uuid.NAMESPACE_DNS, request_id).hex + assert trace_id == expected_trace + + def test_clear_resets_context(self): + """clear_request_context should reset both IDs to empty strings.""" + init_request_context() + assert get_request_id() != "" + assert get_trace_id() != "" + + clear_request_context() + assert get_request_id() == "" + assert get_trace_id() == "" + + def test_default_values_are_empty(self): + """Default values should be empty strings before init.""" + clear_request_context() + assert get_request_id() == "" + assert get_trace_id() == "" + + def test_multiple_inits_create_different_ids(self): + """Each init should create new unique IDs.""" + init_request_context() + first_request_id = get_request_id() + first_trace_id = get_trace_id() + + init_request_context() + second_request_id = get_request_id() + second_trace_id = get_trace_id() + + assert first_request_id != second_request_id + assert first_trace_id != second_trace_id + + def test_context_isolation(self): + """Context should be isolated per-call (no thread leakage in same thread).""" + init_request_context() + id1 = get_request_id() + + # Simulate another request + init_request_context() + id2 = get_request_id() + + # IDs should be different + assert id1 != id2 diff --git a/api/tests/unit_tests/core/logging/test_filters.py b/api/tests/unit_tests/core/logging/test_filters.py index 733c552454..1ebab3308e 100644 --- a/api/tests/unit_tests/core/logging/test_filters.py +++ b/api/tests/unit_tests/core/logging/test_filters.py @@ -6,8 +6,12 @@ from unittest import mock class TestTraceContextFilter: def test_sets_empty_trace_id_without_context(self): + from core.logging.context import clear_request_context from core.logging.filters import TraceContextFilter + # Ensure no context is set + clear_request_context() + filter = TraceContextFilter() record = logging.LogRecord( name="test", @@ -25,6 +29,36 @@ class TestTraceContextFilter: assert hasattr(record, "trace_id") assert hasattr(record, "span_id") assert hasattr(record, "req_id") + # Without context, IDs should be empty + assert record.trace_id == "" + assert record.req_id == "" + + def test_sets_trace_id_from_context(self): + """Test that trace_id and req_id are set from ContextVar when initialized.""" + from core.logging.context import init_request_context + from core.logging.filters import TraceContextFilter + + # Initialize context (no Flask needed!) + init_request_context() + + filter = TraceContextFilter() + record = logging.LogRecord( + name="test", + level=logging.INFO, + pathname="", + lineno=0, + msg="test", + args=(), + exc_info=None, + ) + + filter.filter(record) + + # With context initialized, IDs should be set + assert record.trace_id != "" + assert len(record.trace_id) == 32 + assert record.req_id != "" + assert len(record.req_id) == 10 def test_filter_always_returns_true(self): from core.logging.filters import TraceContextFilter