reduce the dupicate with in tests

This commit is contained in:
Byron Wang 2025-12-26 10:30:35 +08:00
parent ed583f7e4b
commit 58283221cc
No known key found for this signature in database
GPG Key ID: 335E934E215AD579
2 changed files with 52 additions and 108 deletions

View File

@ -2,9 +2,11 @@ import logging
import time
from opentelemetry.trace import get_current_span
from opentelemetry.trace.span import INVALID_SPAN_ID, INVALID_TRACE_ID
from configs import dify_config
from contexts.wrapper import RecyclableContextVar
from core.logging.context import init_request_context
from dify_app import DifyApp
logger = logging.getLogger(__name__)
@ -26,8 +28,6 @@ def create_flask_app_with_configs() -> DifyApp:
@dify_app.before_request
def before_request():
# Initialize logging context for this request
from core.logging.context import init_request_context
init_request_context()
RecyclableContextVar.increment_thread_recycles()
@ -36,8 +36,6 @@ def create_flask_app_with_configs() -> DifyApp:
@dify_app.after_request
def add_trace_headers(response):
try:
from opentelemetry.trace.span import INVALID_SPAN_ID, INVALID_TRACE_ID
span = get_current_span()
ctx = span.get_span_context() if span else None

View File

@ -3,9 +3,24 @@
import logging
from unittest import mock
import pytest
@pytest.fixture
def log_record():
return logging.LogRecord(
name="test",
level=logging.INFO,
pathname="",
lineno=0,
msg="test",
args=(),
exc_info=None,
)
class TestTraceContextFilter:
def test_sets_empty_trace_id_without_context(self):
def test_sets_empty_trace_id_without_context(self, log_record):
from core.logging.context import clear_request_context
from core.logging.filters import TraceContextFilter
@ -13,27 +28,17 @@ class TestTraceContextFilter:
clear_request_context()
filter = TraceContextFilter()
record = logging.LogRecord(
name="test",
level=logging.INFO,
pathname="",
lineno=0,
msg="test",
args=(),
exc_info=None,
)
result = filter.filter(record)
result = filter.filter(log_record)
assert result is True
assert hasattr(record, "trace_id")
assert hasattr(record, "span_id")
assert hasattr(record, "req_id")
assert hasattr(log_record, "trace_id")
assert hasattr(log_record, "span_id")
assert hasattr(log_record, "req_id")
# Without context, IDs should be empty
assert record.trace_id == ""
assert record.req_id == ""
assert log_record.trace_id == ""
assert log_record.req_id == ""
def test_sets_trace_id_from_context(self):
def test_sets_trace_id_from_context(self, log_record):
"""Test that trace_id and req_id are set from ContextVar when initialized."""
from core.logging.context import init_request_context
from core.logging.filters import TraceContextFilter
@ -42,42 +47,22 @@ class TestTraceContextFilter:
init_request_context()
filter = TraceContextFilter()
record = logging.LogRecord(
name="test",
level=logging.INFO,
pathname="",
lineno=0,
msg="test",
args=(),
exc_info=None,
)
filter.filter(record)
filter.filter(log_record)
# With context initialized, IDs should be set
assert record.trace_id != ""
assert len(record.trace_id) == 32
assert record.req_id != ""
assert len(record.req_id) == 10
assert log_record.trace_id != ""
assert len(log_record.trace_id) == 32
assert log_record.req_id != ""
assert len(log_record.req_id) == 10
def test_filter_always_returns_true(self):
def test_filter_always_returns_true(self, log_record):
from core.logging.filters import TraceContextFilter
filter = TraceContextFilter()
record = logging.LogRecord(
name="test",
level=logging.INFO,
pathname="",
lineno=0,
msg="test",
args=(),
exc_info=None,
)
result = filter.filter(record)
result = filter.filter(log_record)
assert result is True
def test_sets_trace_id_from_otel_when_available(self):
def test_sets_trace_id_from_otel_when_available(self, log_record):
from core.logging.filters import TraceContextFilter
mock_span = mock.MagicMock()
@ -86,83 +71,44 @@ class TestTraceContextFilter:
mock_context.span_id = 0x051581BF3BB55C45
mock_span.get_span_context.return_value = mock_context
with mock.patch("opentelemetry.trace.get_current_span", return_value=mock_span):
with (
mock.patch("opentelemetry.trace.span.INVALID_TRACE_ID", 0),
mock.patch("opentelemetry.trace.span.INVALID_SPAN_ID", 0),
):
filter = TraceContextFilter()
record = logging.LogRecord(
name="test",
level=logging.INFO,
pathname="",
lineno=0,
msg="test",
args=(),
exc_info=None,
)
with (
mock.patch("opentelemetry.trace.get_current_span", return_value=mock_span),
mock.patch("opentelemetry.trace.span.INVALID_TRACE_ID", 0),
mock.patch("opentelemetry.trace.span.INVALID_SPAN_ID", 0),
):
filter = TraceContextFilter()
filter.filter(log_record)
filter.filter(record)
assert record.trace_id == "5b8aa5a2d2c872e8321cf37308d69df2"
assert record.span_id == "051581bf3bb55c45"
assert log_record.trace_id == "5b8aa5a2d2c872e8321cf37308d69df2"
assert log_record.span_id == "051581bf3bb55c45"
class TestIdentityContextFilter:
def test_sets_empty_identity_without_request_context(self):
def test_sets_empty_identity_without_request_context(self, log_record):
from core.logging.filters import IdentityContextFilter
filter = IdentityContextFilter()
record = logging.LogRecord(
name="test",
level=logging.INFO,
pathname="",
lineno=0,
msg="test",
args=(),
exc_info=None,
)
result = filter.filter(record)
result = filter.filter(log_record)
assert result is True
assert record.tenant_id == ""
assert record.user_id == ""
assert record.user_type == ""
assert log_record.tenant_id == ""
assert log_record.user_id == ""
assert log_record.user_type == ""
def test_filter_always_returns_true(self):
def test_filter_always_returns_true(self, log_record):
from core.logging.filters import IdentityContextFilter
filter = IdentityContextFilter()
record = logging.LogRecord(
name="test",
level=logging.INFO,
pathname="",
lineno=0,
msg="test",
args=(),
exc_info=None,
)
result = filter.filter(record)
result = filter.filter(log_record)
assert result is True
def test_handles_exception_gracefully(self):
def test_handles_exception_gracefully(self, log_record):
from core.logging.filters import IdentityContextFilter
filter = IdentityContextFilter()
record = logging.LogRecord(
name="test",
level=logging.INFO,
pathname="",
lineno=0,
msg="test",
args=(),
exc_info=None,
)
# Should not raise even if something goes wrong
with mock.patch("core.logging.filters.flask.has_request_context", side_effect=Exception("Test error")):
result = filter.filter(record)
result = filter.filter(log_record)
assert result is True
assert record.tenant_id == ""
assert log_record.tenant_id == ""