From 58283221cc01cdd1ec93807f159a3f5c78f54757 Mon Sep 17 00:00:00 2001 From: Byron Wang Date: Fri, 26 Dec 2025 10:30:35 +0800 Subject: [PATCH] reduce the dupicate with in tests --- api/app_factory.py | 6 +- .../unit_tests/core/logging/test_filters.py | 154 ++++++------------ 2 files changed, 52 insertions(+), 108 deletions(-) diff --git a/api/app_factory.py b/api/app_factory.py index 7b45736d95..f827842d68 100644 --- a/api/app_factory.py +++ b/api/app_factory.py @@ -2,9 +2,11 @@ import logging import time from opentelemetry.trace import get_current_span +from opentelemetry.trace.span import INVALID_SPAN_ID, INVALID_TRACE_ID from configs import dify_config from contexts.wrapper import RecyclableContextVar +from core.logging.context import init_request_context from dify_app import DifyApp logger = logging.getLogger(__name__) @@ -26,8 +28,6 @@ def create_flask_app_with_configs() -> DifyApp: @dify_app.before_request def before_request(): # Initialize logging context for this request - from core.logging.context import init_request_context - init_request_context() RecyclableContextVar.increment_thread_recycles() @@ -36,8 +36,6 @@ def create_flask_app_with_configs() -> DifyApp: @dify_app.after_request def add_trace_headers(response): try: - from opentelemetry.trace.span import INVALID_SPAN_ID, INVALID_TRACE_ID - span = get_current_span() ctx = span.get_span_context() if span else None diff --git a/api/tests/unit_tests/core/logging/test_filters.py b/api/tests/unit_tests/core/logging/test_filters.py index 1ebab3308e..b66ad111d5 100644 --- a/api/tests/unit_tests/core/logging/test_filters.py +++ b/api/tests/unit_tests/core/logging/test_filters.py @@ -3,9 +3,24 @@ import logging from unittest import mock +import pytest + + +@pytest.fixture +def log_record(): + return logging.LogRecord( + name="test", + level=logging.INFO, + pathname="", + lineno=0, + msg="test", + args=(), + exc_info=None, + ) + class TestTraceContextFilter: - def test_sets_empty_trace_id_without_context(self): + def test_sets_empty_trace_id_without_context(self, log_record): from core.logging.context import clear_request_context from core.logging.filters import TraceContextFilter @@ -13,27 +28,17 @@ class TestTraceContextFilter: clear_request_context() filter = TraceContextFilter() - record = logging.LogRecord( - name="test", - level=logging.INFO, - pathname="", - lineno=0, - msg="test", - args=(), - exc_info=None, - ) - - result = filter.filter(record) + result = filter.filter(log_record) assert result is True - assert hasattr(record, "trace_id") - assert hasattr(record, "span_id") - assert hasattr(record, "req_id") + assert hasattr(log_record, "trace_id") + assert hasattr(log_record, "span_id") + assert hasattr(log_record, "req_id") # Without context, IDs should be empty - assert record.trace_id == "" - assert record.req_id == "" + assert log_record.trace_id == "" + assert log_record.req_id == "" - def test_sets_trace_id_from_context(self): + def test_sets_trace_id_from_context(self, log_record): """Test that trace_id and req_id are set from ContextVar when initialized.""" from core.logging.context import init_request_context from core.logging.filters import TraceContextFilter @@ -42,42 +47,22 @@ class TestTraceContextFilter: init_request_context() filter = TraceContextFilter() - record = logging.LogRecord( - name="test", - level=logging.INFO, - pathname="", - lineno=0, - msg="test", - args=(), - exc_info=None, - ) - - filter.filter(record) + filter.filter(log_record) # With context initialized, IDs should be set - assert record.trace_id != "" - assert len(record.trace_id) == 32 - assert record.req_id != "" - assert len(record.req_id) == 10 + assert log_record.trace_id != "" + assert len(log_record.trace_id) == 32 + assert log_record.req_id != "" + assert len(log_record.req_id) == 10 - def test_filter_always_returns_true(self): + def test_filter_always_returns_true(self, log_record): from core.logging.filters import TraceContextFilter filter = TraceContextFilter() - record = logging.LogRecord( - name="test", - level=logging.INFO, - pathname="", - lineno=0, - msg="test", - args=(), - exc_info=None, - ) - - result = filter.filter(record) + result = filter.filter(log_record) assert result is True - def test_sets_trace_id_from_otel_when_available(self): + def test_sets_trace_id_from_otel_when_available(self, log_record): from core.logging.filters import TraceContextFilter mock_span = mock.MagicMock() @@ -86,83 +71,44 @@ class TestTraceContextFilter: mock_context.span_id = 0x051581BF3BB55C45 mock_span.get_span_context.return_value = mock_context - with mock.patch("opentelemetry.trace.get_current_span", return_value=mock_span): - with ( - mock.patch("opentelemetry.trace.span.INVALID_TRACE_ID", 0), - mock.patch("opentelemetry.trace.span.INVALID_SPAN_ID", 0), - ): - filter = TraceContextFilter() - record = logging.LogRecord( - name="test", - level=logging.INFO, - pathname="", - lineno=0, - msg="test", - args=(), - exc_info=None, - ) + with ( + mock.patch("opentelemetry.trace.get_current_span", return_value=mock_span), + mock.patch("opentelemetry.trace.span.INVALID_TRACE_ID", 0), + mock.patch("opentelemetry.trace.span.INVALID_SPAN_ID", 0), + ): + filter = TraceContextFilter() + filter.filter(log_record) - filter.filter(record) - - assert record.trace_id == "5b8aa5a2d2c872e8321cf37308d69df2" - assert record.span_id == "051581bf3bb55c45" + assert log_record.trace_id == "5b8aa5a2d2c872e8321cf37308d69df2" + assert log_record.span_id == "051581bf3bb55c45" class TestIdentityContextFilter: - def test_sets_empty_identity_without_request_context(self): + def test_sets_empty_identity_without_request_context(self, log_record): from core.logging.filters import IdentityContextFilter filter = IdentityContextFilter() - record = logging.LogRecord( - name="test", - level=logging.INFO, - pathname="", - lineno=0, - msg="test", - args=(), - exc_info=None, - ) - - result = filter.filter(record) + result = filter.filter(log_record) assert result is True - assert record.tenant_id == "" - assert record.user_id == "" - assert record.user_type == "" + assert log_record.tenant_id == "" + assert log_record.user_id == "" + assert log_record.user_type == "" - def test_filter_always_returns_true(self): + def test_filter_always_returns_true(self, log_record): from core.logging.filters import IdentityContextFilter filter = IdentityContextFilter() - record = logging.LogRecord( - name="test", - level=logging.INFO, - pathname="", - lineno=0, - msg="test", - args=(), - exc_info=None, - ) - - result = filter.filter(record) + result = filter.filter(log_record) assert result is True - def test_handles_exception_gracefully(self): + def test_handles_exception_gracefully(self, log_record): from core.logging.filters import IdentityContextFilter filter = IdentityContextFilter() - record = logging.LogRecord( - name="test", - level=logging.INFO, - pathname="", - lineno=0, - msg="test", - args=(), - exc_info=None, - ) # Should not raise even if something goes wrong with mock.patch("core.logging.filters.flask.has_request_context", side_effect=Exception("Test error")): - result = filter.filter(record) + result = filter.filter(log_record) assert result is True - assert record.tenant_id == "" + assert log_record.tenant_id == ""