mirror of https://github.com/langgenius/dify.git
reduce the dupicate with in tests
This commit is contained in:
parent
ed583f7e4b
commit
58283221cc
|
|
@ -2,9 +2,11 @@ import logging
|
|||
import time
|
||||
|
||||
from opentelemetry.trace import get_current_span
|
||||
from opentelemetry.trace.span import INVALID_SPAN_ID, INVALID_TRACE_ID
|
||||
|
||||
from configs import dify_config
|
||||
from contexts.wrapper import RecyclableContextVar
|
||||
from core.logging.context import init_request_context
|
||||
from dify_app import DifyApp
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -26,8 +28,6 @@ def create_flask_app_with_configs() -> DifyApp:
|
|||
@dify_app.before_request
|
||||
def before_request():
|
||||
# Initialize logging context for this request
|
||||
from core.logging.context import init_request_context
|
||||
|
||||
init_request_context()
|
||||
RecyclableContextVar.increment_thread_recycles()
|
||||
|
||||
|
|
@ -36,8 +36,6 @@ def create_flask_app_with_configs() -> DifyApp:
|
|||
@dify_app.after_request
|
||||
def add_trace_headers(response):
|
||||
try:
|
||||
from opentelemetry.trace.span import INVALID_SPAN_ID, INVALID_TRACE_ID
|
||||
|
||||
span = get_current_span()
|
||||
ctx = span.get_span_context() if span else None
|
||||
|
||||
|
|
|
|||
|
|
@ -3,9 +3,24 @@
|
|||
import logging
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def log_record():
|
||||
return logging.LogRecord(
|
||||
name="test",
|
||||
level=logging.INFO,
|
||||
pathname="",
|
||||
lineno=0,
|
||||
msg="test",
|
||||
args=(),
|
||||
exc_info=None,
|
||||
)
|
||||
|
||||
|
||||
class TestTraceContextFilter:
|
||||
def test_sets_empty_trace_id_without_context(self):
|
||||
def test_sets_empty_trace_id_without_context(self, log_record):
|
||||
from core.logging.context import clear_request_context
|
||||
from core.logging.filters import TraceContextFilter
|
||||
|
||||
|
|
@ -13,27 +28,17 @@ class TestTraceContextFilter:
|
|||
clear_request_context()
|
||||
|
||||
filter = TraceContextFilter()
|
||||
record = logging.LogRecord(
|
||||
name="test",
|
||||
level=logging.INFO,
|
||||
pathname="",
|
||||
lineno=0,
|
||||
msg="test",
|
||||
args=(),
|
||||
exc_info=None,
|
||||
)
|
||||
|
||||
result = filter.filter(record)
|
||||
result = filter.filter(log_record)
|
||||
|
||||
assert result is True
|
||||
assert hasattr(record, "trace_id")
|
||||
assert hasattr(record, "span_id")
|
||||
assert hasattr(record, "req_id")
|
||||
assert hasattr(log_record, "trace_id")
|
||||
assert hasattr(log_record, "span_id")
|
||||
assert hasattr(log_record, "req_id")
|
||||
# Without context, IDs should be empty
|
||||
assert record.trace_id == ""
|
||||
assert record.req_id == ""
|
||||
assert log_record.trace_id == ""
|
||||
assert log_record.req_id == ""
|
||||
|
||||
def test_sets_trace_id_from_context(self):
|
||||
def test_sets_trace_id_from_context(self, log_record):
|
||||
"""Test that trace_id and req_id are set from ContextVar when initialized."""
|
||||
from core.logging.context import init_request_context
|
||||
from core.logging.filters import TraceContextFilter
|
||||
|
|
@ -42,42 +47,22 @@ class TestTraceContextFilter:
|
|||
init_request_context()
|
||||
|
||||
filter = TraceContextFilter()
|
||||
record = logging.LogRecord(
|
||||
name="test",
|
||||
level=logging.INFO,
|
||||
pathname="",
|
||||
lineno=0,
|
||||
msg="test",
|
||||
args=(),
|
||||
exc_info=None,
|
||||
)
|
||||
|
||||
filter.filter(record)
|
||||
filter.filter(log_record)
|
||||
|
||||
# With context initialized, IDs should be set
|
||||
assert record.trace_id != ""
|
||||
assert len(record.trace_id) == 32
|
||||
assert record.req_id != ""
|
||||
assert len(record.req_id) == 10
|
||||
assert log_record.trace_id != ""
|
||||
assert len(log_record.trace_id) == 32
|
||||
assert log_record.req_id != ""
|
||||
assert len(log_record.req_id) == 10
|
||||
|
||||
def test_filter_always_returns_true(self):
|
||||
def test_filter_always_returns_true(self, log_record):
|
||||
from core.logging.filters import TraceContextFilter
|
||||
|
||||
filter = TraceContextFilter()
|
||||
record = logging.LogRecord(
|
||||
name="test",
|
||||
level=logging.INFO,
|
||||
pathname="",
|
||||
lineno=0,
|
||||
msg="test",
|
||||
args=(),
|
||||
exc_info=None,
|
||||
)
|
||||
|
||||
result = filter.filter(record)
|
||||
result = filter.filter(log_record)
|
||||
assert result is True
|
||||
|
||||
def test_sets_trace_id_from_otel_when_available(self):
|
||||
def test_sets_trace_id_from_otel_when_available(self, log_record):
|
||||
from core.logging.filters import TraceContextFilter
|
||||
|
||||
mock_span = mock.MagicMock()
|
||||
|
|
@ -86,83 +71,44 @@ class TestTraceContextFilter:
|
|||
mock_context.span_id = 0x051581BF3BB55C45
|
||||
mock_span.get_span_context.return_value = mock_context
|
||||
|
||||
with mock.patch("opentelemetry.trace.get_current_span", return_value=mock_span):
|
||||
with (
|
||||
mock.patch("opentelemetry.trace.span.INVALID_TRACE_ID", 0),
|
||||
mock.patch("opentelemetry.trace.span.INVALID_SPAN_ID", 0),
|
||||
):
|
||||
filter = TraceContextFilter()
|
||||
record = logging.LogRecord(
|
||||
name="test",
|
||||
level=logging.INFO,
|
||||
pathname="",
|
||||
lineno=0,
|
||||
msg="test",
|
||||
args=(),
|
||||
exc_info=None,
|
||||
)
|
||||
with (
|
||||
mock.patch("opentelemetry.trace.get_current_span", return_value=mock_span),
|
||||
mock.patch("opentelemetry.trace.span.INVALID_TRACE_ID", 0),
|
||||
mock.patch("opentelemetry.trace.span.INVALID_SPAN_ID", 0),
|
||||
):
|
||||
filter = TraceContextFilter()
|
||||
filter.filter(log_record)
|
||||
|
||||
filter.filter(record)
|
||||
|
||||
assert record.trace_id == "5b8aa5a2d2c872e8321cf37308d69df2"
|
||||
assert record.span_id == "051581bf3bb55c45"
|
||||
assert log_record.trace_id == "5b8aa5a2d2c872e8321cf37308d69df2"
|
||||
assert log_record.span_id == "051581bf3bb55c45"
|
||||
|
||||
|
||||
class TestIdentityContextFilter:
|
||||
def test_sets_empty_identity_without_request_context(self):
|
||||
def test_sets_empty_identity_without_request_context(self, log_record):
|
||||
from core.logging.filters import IdentityContextFilter
|
||||
|
||||
filter = IdentityContextFilter()
|
||||
record = logging.LogRecord(
|
||||
name="test",
|
||||
level=logging.INFO,
|
||||
pathname="",
|
||||
lineno=0,
|
||||
msg="test",
|
||||
args=(),
|
||||
exc_info=None,
|
||||
)
|
||||
|
||||
result = filter.filter(record)
|
||||
result = filter.filter(log_record)
|
||||
|
||||
assert result is True
|
||||
assert record.tenant_id == ""
|
||||
assert record.user_id == ""
|
||||
assert record.user_type == ""
|
||||
assert log_record.tenant_id == ""
|
||||
assert log_record.user_id == ""
|
||||
assert log_record.user_type == ""
|
||||
|
||||
def test_filter_always_returns_true(self):
|
||||
def test_filter_always_returns_true(self, log_record):
|
||||
from core.logging.filters import IdentityContextFilter
|
||||
|
||||
filter = IdentityContextFilter()
|
||||
record = logging.LogRecord(
|
||||
name="test",
|
||||
level=logging.INFO,
|
||||
pathname="",
|
||||
lineno=0,
|
||||
msg="test",
|
||||
args=(),
|
||||
exc_info=None,
|
||||
)
|
||||
|
||||
result = filter.filter(record)
|
||||
result = filter.filter(log_record)
|
||||
assert result is True
|
||||
|
||||
def test_handles_exception_gracefully(self):
|
||||
def test_handles_exception_gracefully(self, log_record):
|
||||
from core.logging.filters import IdentityContextFilter
|
||||
|
||||
filter = IdentityContextFilter()
|
||||
record = logging.LogRecord(
|
||||
name="test",
|
||||
level=logging.INFO,
|
||||
pathname="",
|
||||
lineno=0,
|
||||
msg="test",
|
||||
args=(),
|
||||
exc_info=None,
|
||||
)
|
||||
|
||||
# Should not raise even if something goes wrong
|
||||
with mock.patch("core.logging.filters.flask.has_request_context", side_effect=Exception("Test error")):
|
||||
result = filter.filter(record)
|
||||
result = filter.filter(log_record)
|
||||
assert result is True
|
||||
assert record.tenant_id == ""
|
||||
assert log_record.tenant_id == ""
|
||||
|
|
|
|||
Loading…
Reference in New Issue