feat: add structured logging and inject trace IDs into all requests

Byron Wang 2025-12-24 21:04:37 +08:00
parent f439e081b5
commit 183b38ecb0
No known key found for this signature in database
GPG Key ID: 335E934E215AD579
17 changed files with 1075 additions and 81 deletions

View File

@ -28,25 +28,33 @@ def create_flask_app_with_configs() -> DifyApp:
# add a unique identifier to each request
RecyclableContextVar.increment_thread_recycles()
# add after request hook for injecting X-Trace-Id header from OpenTelemetry span context
# add after request hook for injecting trace headers from OpenTelemetry span context
# Only adds headers when OTEL is enabled and has valid context
@dify_app.after_request
def add_trace_id_header(response):
def add_trace_headers(response):
try:
from opentelemetry.trace.span import INVALID_SPAN_ID, INVALID_TRACE_ID
span = get_current_span()
ctx = span.get_span_context() if span else None
if ctx and ctx.is_valid:
trace_id_hex = format(ctx.trace_id, "032x")
# Avoid duplicates if some middleware added it
if "X-Trace-Id" not in response.headers:
response.headers["X-Trace-Id"] = trace_id_hex
if not ctx or not ctx.is_valid:
return response
# Inject trace headers from OTEL context
if ctx.trace_id != INVALID_TRACE_ID and "X-Trace-Id" not in response.headers:
response.headers["X-Trace-Id"] = format(ctx.trace_id, "032x")
if ctx.span_id != INVALID_SPAN_ID and "X-Span-Id" not in response.headers:
response.headers["X-Span-Id"] = format(ctx.span_id, "016x")
except Exception:
# Never break the response due to tracing header injection
logger.warning("Failed to add trace ID to response header", exc_info=True)
logger.warning("Failed to add trace headers to response", exc_info=True)
return response
# Capture the decorator's return value to avoid pyright reportUnusedFunction
_ = before_request
_ = add_trace_id_header
_ = add_trace_headers
return dify_app
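
A minimal sketch of how the new headers surface to a caller, assuming the factory lives in app_factory (module name inferred from the hunk above) and using a hypothetical endpoint; the hex lengths match the format() calls in the hook:

from app_factory import create_flask_app_with_configs  # module path assumed

app = create_flask_app_with_configs()
with app.test_client() as client:
    resp = client.get("/health")  # hypothetical endpoint
    # Both headers appear only when OTEL is enabled and the span context is valid.
    print(resp.headers.get("X-Trace-Id"))  # 32 lowercase hex chars, or None
    print(resp.headers.get("X-Span-Id"))   # 16 lowercase hex chars, or None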

View File

@ -587,6 +587,11 @@ class LoggingConfig(BaseSettings):
default="INFO",
)
LOG_OUTPUT_FORMAT: Literal["text", "json"] = Field(
description="Log output format: 'text' for human-readable, 'json' for structured JSON logs.",
default="text",
)
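
A brief sketch of how the new setting is consumed, mirroring the other LoggingConfig fields (the environment variable name matches the field name):

from configs import dify_config  # pydantic-settings instance used throughout this diff

# Set LOG_OUTPUT_FORMAT=json in the environment to switch to structured output;
# the default "text" keeps the existing human-readable format.
if dify_config.LOG_OUTPUT_FORMAT == "json":
    print("structured JSON logging enabled")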
LOG_FILE: str | None = Field(
description="File path for log output.",
default=None,

View File

@ -88,12 +88,48 @@ def _get_user_provided_host_header(headers: dict | None) -> str | None:
return None
def _inject_trace_headers(headers: dict | None) -> dict:
"""
Inject W3C traceparent header for distributed tracing.
When OTEL is enabled, HTTPXClientInstrumentor handles trace propagation automatically.
When OTEL is disabled, we manually inject the traceparent header.
"""
if headers is None:
headers = {}
# Skip if already present (case-insensitive check)
for key in headers:
if key.lower() == "traceparent":
return headers
# Skip if OTEL is enabled - HTTPXClientInstrumentor handles this automatically
if dify_config.ENABLE_OTEL:
return headers
# Generate and inject traceparent for non-OTEL scenarios
try:
from core.helper.trace_id_helper import generate_traceparent_header
traceparent = generate_traceparent_header()
if traceparent:
headers["traceparent"] = traceparent
except Exception:
# Silently ignore errors to avoid breaking requests
logger.debug("Failed to generate traceparent header", exc_info=True)
return headers
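
A minimal behaviour sketch for _inject_trace_headers, assuming ENABLE_OTEL is false and the helper is called from inside this module:

headers = _inject_trace_headers({"Content-Type": "application/json"})
assert "traceparent" in headers  # e.g. "00-<32 hex chars>-<16 hex chars>-01"

# A caller-supplied header wins; the presence check above is case-insensitive.
preset = {"Traceparent": "00-5b8aa5a2d2c872e8321cf37308d69df2-051581bf3bb55c45-01"}
assert _inject_trace_headers(preset) is preset  # returned unchanged, nothing injected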
def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
if "allow_redirects" in kwargs:
allow_redirects = kwargs.pop("allow_redirects")
if "follow_redirects" not in kwargs:
kwargs["follow_redirects"] = allow_redirects
# Extract follow_redirects - it's a send() parameter, not a build_request() parameter
follow_redirects = kwargs.pop("follow_redirects", False)
if "timeout" not in kwargs:
kwargs["timeout"] = httpx.Timeout(
timeout=dify_config.SSRF_DEFAULT_TIME_OUT,
@ -106,10 +142,14 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
verify_option = kwargs.pop("ssl_verify", dify_config.HTTP_REQUEST_NODE_SSL_VERIFY)
client = _get_ssrf_client(verify_option)
# Inject traceparent header for distributed tracing (when OTEL is not enabled)
headers = kwargs.get("headers") or {}
headers = _inject_trace_headers(headers)
kwargs["headers"] = headers
# Preserve user-provided Host header
# When using a forward proxy, httpx may override the Host header based on the URL.
# We extract and preserve any explicitly set Host header to support virtual hosting.
headers = kwargs.get("headers", {})
user_provided_host = _get_user_provided_host_header(headers)
retries = 0
@ -124,7 +164,7 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
if user_provided_host is not None:
request.headers["Host"] = user_provided_host
response = client.send(request)
response = client.send(request, follow_redirects=follow_redirects)
# Check for SSRF protection by Squid proxy
if response.status_code in (401, 403):

View File

@ -103,3 +103,67 @@ def parse_traceparent_header(traceparent: str) -> str | None:
if len(parts) == 4 and len(parts[1]) == 32:
return parts[1]
return None
def get_span_id_from_otel_context() -> str | None:
"""
Retrieve the current span ID from the active OpenTelemetry trace context.
Returns:
A 16-character hex string representing the span ID, or None if not available.
"""
try:
from opentelemetry.trace import get_current_span
from opentelemetry.trace.span import INVALID_SPAN_ID
span = get_current_span()
if not span:
return None
span_context = span.get_span_context()
if not span_context or span_context.span_id == INVALID_SPAN_ID:
return None
return f"{span_context.span_id:016x}"
except Exception:
return None
def generate_traceparent_header() -> str | None:
"""
Generate a W3C traceparent header from the current context.
Uses OpenTelemetry context if available, otherwise generates new IDs
based on the Flask request context.
Format: {version}-{trace_id}-{span_id}-{flags}
Example: 00-5b8aa5a2d2c872e8321cf37308d69df2-051581bf3bb55c45-01
Returns:
A valid traceparent header string, or None if generation fails.
"""
import uuid
# Try OTEL context first
trace_id = get_trace_id_from_otel_context()
span_id = get_span_id_from_otel_context()
if trace_id and span_id:
return f"00-{trace_id}-{span_id}-01"
# Fallback: generate new trace context
try:
import flask
if flask.has_request_context() and hasattr(flask.g, "request_id"):
# Derive trace_id from request_id for consistency
trace_id = uuid.uuid5(uuid.NAMESPACE_DNS, flask.g.request_id).hex
else:
trace_id = uuid.uuid4().hex
except Exception:
trace_id = uuid.uuid4().hex
# Generate a new span_id (16 hex chars)
span_id = uuid.uuid4().hex[:16]
return f"00-{trace_id}-{span_id}-01"

View File

@ -0,0 +1,10 @@
"""Structured logging components for Dify."""
from core.logging.filters import IdentityContextFilter, TraceContextFilter
from core.logging.structured_formatter import StructuredJSONFormatter
__all__ = [
"IdentityContextFilter",
"StructuredJSONFormatter",
"TraceContextFilter",
]

api/core/logging/filters.py (new file, 115 lines added)
View File

@ -0,0 +1,115 @@
"""Logging filters for structured logging."""
import contextlib
import logging
import uuid
import flask
class TraceContextFilter(logging.Filter):
"""
Filter that adds trace_id and span_id to log records.
Integrates with OpenTelemetry when available, falls back to request_id.
"""
def filter(self, record: logging.LogRecord) -> bool:
# Get trace context from OpenTelemetry
trace_id, span_id = self._get_otel_context()
# Set trace_id (fallback to request_id if no OTEL context)
if trace_id:
record.trace_id = trace_id
else:
record.trace_id = self._get_or_create_request_trace_id()
record.span_id = span_id or ""
# For backward compatibility, also set req_id
record.req_id = self._get_request_id()
return True
def _get_otel_context(self) -> tuple[str, str]:
"""Extract trace_id and span_id from OpenTelemetry context."""
with contextlib.suppress(Exception):
from opentelemetry.trace import get_current_span
from opentelemetry.trace.span import INVALID_SPAN_ID, INVALID_TRACE_ID
span = get_current_span()
if span and span.get_span_context():
ctx = span.get_span_context()
if ctx.is_valid and ctx.trace_id != INVALID_TRACE_ID:
trace_id = f"{ctx.trace_id:032x}"
span_id = f"{ctx.span_id:016x}" if ctx.span_id != INVALID_SPAN_ID else ""
return trace_id, span_id
return "", ""
def _get_request_id(self) -> str:
"""Get request ID from Flask context."""
if flask.has_request_context():
if hasattr(flask.g, "request_id"):
return flask.g.request_id
flask.g.request_id = uuid.uuid4().hex[:10]
return flask.g.request_id
return ""
def _get_or_create_request_trace_id(self) -> str:
"""Get or create a trace_id derived from request context."""
if flask.has_request_context():
if hasattr(flask.g, "_trace_id"):
return flask.g._trace_id
# Derive trace_id from request_id for consistency
request_id = self._get_request_id()
if request_id:
# Generate a 32-char hex trace_id from request_id
flask.g._trace_id = uuid.uuid5(uuid.NAMESPACE_DNS, request_id).hex
return flask.g._trace_id
return ""
class IdentityContextFilter(logging.Filter):
"""
Filter that adds user identity context to log records.
Extracts tenant_id, user_id, and user_type from Flask-Login current_user.
"""
def filter(self, record: logging.LogRecord) -> bool:
identity = self._extract_identity()
record.tenant_id = identity.get("tenant_id", "")
record.user_id = identity.get("user_id", "")
record.user_type = identity.get("user_type", "")
return True
def _extract_identity(self) -> dict[str, str]:
"""Extract identity from current_user if in request context."""
try:
if not flask.has_request_context():
return {}
from flask_login import current_user
# Check if user is authenticated using the proxy
if not current_user.is_authenticated:
return {}
# Access the underlying user object
user = current_user
from models import Account
from models.model import EndUser
identity: dict[str, str] = {}
if isinstance(user, Account):
if user.current_tenant_id:
identity["tenant_id"] = user.current_tenant_id
identity["user_id"] = user.id
identity["user_type"] = "account"
elif isinstance(user, EndUser):
identity["tenant_id"] = user.tenant_id
identity["user_id"] = user.id
identity["user_type"] = user.type or "end_user"
return identity
except Exception:
return {}
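
A minimal wiring sketch for the filters above; the format string is illustrative and only uses fields the filters set:

import logging
import sys

from core.logging.filters import IdentityContextFilter, TraceContextFilter

handler = logging.StreamHandler(sys.stdout)
handler.addFilter(TraceContextFilter())      # sets trace_id, span_id, req_id
handler.addFilter(IdentityContextFilter())   # sets tenant_id, user_id, user_type
handler.setFormatter(logging.Formatter("%(asctime)s [%(trace_id)s] %(levelname)s %(message)s"))

demo = logging.getLogger("demo")
demo.addHandler(handler)
demo.warning("hello")  # trace_id and identity fields are empty outside a request/OTEL context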

View File

@ -0,0 +1,101 @@
"""Structured JSON log formatter for Dify."""
import logging
import traceback
from datetime import UTC, datetime
from typing import Any
import orjson
from configs import dify_config
class StructuredJSONFormatter(logging.Formatter):
"""
JSON log formatter following the specified schema:
{
"ts": "ISO 8601 UTC",
"severity": "INFO|ERROR|WARN|DEBUG",
"service": "service name",
"caller": "file:line",
"trace_id": "hex 32",
"span_id": "hex 16",
"identity": { "tenant_id", "user_id", "user_type" },
"message": "log message",
"attributes": { ... },
"stack_trace": "..."
}
"""
SEVERITY_MAP: dict[int, str] = {
logging.DEBUG: "DEBUG",
logging.INFO: "INFO",
logging.WARNING: "WARN",
logging.ERROR: "ERROR",
logging.CRITICAL: "ERROR",
}
def __init__(self, service_name: str | None = None):
super().__init__()
self._service_name = service_name or dify_config.APPLICATION_NAME
def format(self, record: logging.LogRecord) -> str:
log_dict = self._build_log_dict(record)
return orjson.dumps(log_dict).decode("utf-8")
def _build_log_dict(self, record: logging.LogRecord) -> dict[str, Any]:
# Core fields
log_dict: dict[str, Any] = {
"ts": datetime.now(UTC).isoformat(timespec="milliseconds").replace("+00:00", "Z"),
"severity": self.SEVERITY_MAP.get(record.levelno, "INFO"),
"service": self._service_name,
"caller": f"{record.filename}:{record.lineno}",
"message": record.getMessage(),
}
# Trace context (from TraceContextFilter)
trace_id = getattr(record, "trace_id", "") or ""
span_id = getattr(record, "span_id", "") or ""
if trace_id:
log_dict["trace_id"] = trace_id
if span_id:
log_dict["span_id"] = span_id
# Identity context (from IdentityContextFilter)
identity = self._extract_identity(record)
if identity:
log_dict["identity"] = identity
# Dynamic attributes
attributes = getattr(record, "attributes", None)
if attributes:
log_dict["attributes"] = attributes
# Stack trace for errors with exceptions
if record.exc_info and record.levelno >= logging.ERROR:
log_dict["stack_trace"] = self._format_exception(record.exc_info)
return log_dict
def _extract_identity(self, record: logging.LogRecord) -> dict[str, str] | None:
tenant_id = getattr(record, "tenant_id", None)
user_id = getattr(record, "user_id", None)
user_type = getattr(record, "user_type", None)
if not any([tenant_id, user_id, user_type]):
return None
identity: dict[str, str] = {}
if tenant_id:
identity["tenant_id"] = tenant_id
if user_id:
identity["user_id"] = user_id
if user_type:
identity["user_type"] = user_type
return identity
def _format_exception(self, exc_info: tuple[Any, ...]) -> str:
if exc_info and exc_info[0] is not None:
return "".join(traceback.format_exception(*exc_info))
return ""

View File

@ -103,6 +103,9 @@ class BasePluginClient:
prepared_headers["X-Api-Key"] = dify_config.PLUGIN_DAEMON_KEY
prepared_headers.setdefault("Accept-Encoding", "gzip, deflate, br")
# Inject traceparent header for distributed tracing
self._inject_trace_headers(prepared_headers)
prepared_data: bytes | dict[str, Any] | str | None = (
data if isinstance(data, (bytes, str, dict)) or data is None else None
)
@ -114,6 +117,28 @@ class BasePluginClient:
return str(url), prepared_headers, prepared_data, params, files
def _inject_trace_headers(self, headers: dict[str, str]) -> None:
"""
Inject W3C traceparent header for distributed tracing.
This ensures trace context is propagated to the plugin daemon even if
HTTPXClientInstrumentor doesn't cover module-level httpx functions.
"""
import contextlib
# Skip if already present (case-insensitive check)
for key in headers:
if key.lower() == "traceparent":
return
# Inject traceparent - works as fallback when OTEL instrumentation doesn't cover this call
with contextlib.suppress(Exception):
from core.helper.trace_id_helper import generate_traceparent_header
traceparent = generate_traceparent_header()
if traceparent:
headers["traceparent"] = traceparent
def _stream_request(
self,
method: str,

View File

@ -1,3 +1,5 @@
"""Logging extension for Dify Flask application."""
import logging
import os
import sys
@ -7,12 +9,14 @@ from logging.handlers import RotatingFileHandler
import flask
from configs import dify_config
from core.helper.trace_id_helper import get_trace_id_from_otel_context
from dify_app import DifyApp
def init_app(app: DifyApp):
"""Initialize logging with support for text or JSON format."""
log_handlers: list[logging.Handler] = []
# File handler
log_file = dify_config.LOG_FILE
if log_file:
log_dir = os.path.dirname(log_file)
@ -25,27 +29,53 @@ def init_app(app: DifyApp):
)
)
# Always add StreamHandler to log to console
# Console handler
sh = logging.StreamHandler(sys.stdout)
log_handlers.append(sh)
# Apply RequestIdFilter to all handlers
for handler in log_handlers:
handler.addFilter(RequestIdFilter())
# Apply filters to all handlers
from core.logging.filters import IdentityContextFilter, TraceContextFilter
for handler in log_handlers:
handler.addFilter(TraceContextFilter())
handler.addFilter(IdentityContextFilter())
# Configure formatter based on format type
formatter = _create_formatter()
for handler in log_handlers:
handler.setFormatter(formatter)
# Configure root logger
logging.basicConfig(
level=dify_config.LOG_LEVEL,
format=dify_config.LOG_FORMAT,
datefmt=dify_config.LOG_DATEFORMAT,
handlers=log_handlers,
force=True,
)
# Apply RequestIdFormatter to all handlers
apply_request_id_formatter()
# Disable propagation for noisy loggers to avoid duplicate logs
logging.getLogger("sqlalchemy.engine").propagate = False
# Apply timezone if specified (only for text format)
if dify_config.LOG_OUTPUT_FORMAT == "text":
_apply_timezone(log_handlers)
def _create_formatter() -> logging.Formatter:
"""Create appropriate formatter based on configuration."""
if dify_config.LOG_OUTPUT_FORMAT == "json":
from core.logging.structured_formatter import StructuredJSONFormatter
return StructuredJSONFormatter()
else:
# Text format - use existing pattern with backward compatible formatter
return _TextFormatter(
fmt=dify_config.LOG_FORMAT,
datefmt=dify_config.LOG_DATEFORMAT,
)
def _apply_timezone(handlers: list[logging.Handler]):
"""Apply timezone conversion to text formatters."""
log_tz = dify_config.LOG_TZ
if log_tz:
from datetime import datetime
@ -57,26 +87,41 @@ def init_app(app: DifyApp):
def time_converter(seconds):
return datetime.fromtimestamp(seconds, tz=timezone).timetuple()
for handler in logging.root.handlers:
for handler in handlers:
if handler.formatter:
handler.formatter.converter = time_converter
handler.formatter.converter = time_converter # type: ignore[attr-defined]
def get_request_id():
if getattr(flask.g, "request_id", None):
class _TextFormatter(logging.Formatter):
"""Text formatter that ensures trace_id and req_id are always present."""
def format(self, record: logging.LogRecord) -> str:
if not hasattr(record, "req_id"):
record.req_id = ""
if not hasattr(record, "trace_id"):
record.trace_id = ""
if not hasattr(record, "span_id"):
record.span_id = ""
return super().format(record)
def get_request_id() -> str:
"""Get or create request ID for current request context."""
if flask.has_request_context():
if getattr(flask.g, "request_id", None):
return flask.g.request_id
flask.g.request_id = uuid.uuid4().hex[:10]
return flask.g.request_id
new_uuid = uuid.uuid4().hex[:10]
flask.g.request_id = new_uuid
return new_uuid
return ""
# Backward compatibility aliases
class RequestIdFilter(logging.Filter):
# This is a logging filter that makes the request ID available for use in
# the logging format. Note that we're checking if we're in a request
# context, as we may want to log things before Flask is fully loaded.
def filter(self, record):
"""Deprecated: Use TraceContextFilter from core.logging.filters instead."""
def filter(self, record: logging.LogRecord) -> bool:
from core.helper.trace_id_helper import get_trace_id_from_otel_context
trace_id = get_trace_id_from_otel_context() or ""
record.req_id = get_request_id() if flask.has_request_context() else ""
record.trace_id = trace_id
@ -84,7 +129,9 @@ class RequestIdFilter(logging.Filter):
class RequestIdFormatter(logging.Formatter):
def format(self, record):
"""Deprecated: Use _TextFormatter instead."""
def format(self, record: logging.LogRecord) -> str:
if not hasattr(record, "req_id"):
record.req_id = ""
if not hasattr(record, "trace_id"):
@ -93,6 +140,7 @@ class RequestIdFormatter(logging.Formatter):
def apply_request_id_formatter():
"""Deprecated: Formatter is now applied in init_app."""
for handler in logging.root.handlers:
if handler.formatter:
handler.formatter = RequestIdFormatter(dify_config.LOG_FORMAT, dify_config.LOG_DATEFORMAT)
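
For the text path, a format string that uses the fields _TextFormatter guarantees; the exact default LOG_FORMAT is configuration, so this value is illustrative:

# Illustrative LOG_FORMAT for text output; req_id and trace_id default to "" when absent.
LOG_FORMAT = "%(asctime)s.%(msecs)03d %(levelname)s [%(req_id)s] [%(trace_id)s] %(name)s: %(message)s"
# With LOG_OUTPUT_FORMAT=json this format string is ignored and
# StructuredJSONFormatter emits one JSON object per line instead.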

View File

@ -19,26 +19,43 @@ logger = logging.getLogger(__name__)
class ExceptionLoggingHandler(logging.Handler):
"""
Handler that records exceptions to the current OpenTelemetry span.
Unlike creating a new span, this records exceptions on the existing span
to maintain trace context consistency throughout the request lifecycle.
"""
def emit(self, record: logging.LogRecord):
with contextlib.suppress(Exception):
if record.exc_info:
tracer = get_tracer_provider().get_tracer("dify.exception.logging")
with tracer.start_as_current_span(
"log.exception",
attributes={
"log.level": record.levelname,
"log.message": record.getMessage(),
"log.logger": record.name,
"log.file.path": record.pathname,
"log.file.line": record.lineno,
},
) as span:
span.set_status(StatusCode.ERROR)
if record.exc_info[1]:
span.record_exception(record.exc_info[1])
span.set_attribute("exception.message", str(record.exc_info[1]))
if record.exc_info[0]:
span.set_attribute("exception.type", record.exc_info[0].__name__)
if not record.exc_info:
return
from opentelemetry.trace import get_current_span
span = get_current_span()
if not span or not span.is_recording():
return
# Record exception on the current span instead of creating a new one
span.set_status(StatusCode.ERROR, record.getMessage())
# Add log context as span events/attributes
span.add_event(
"log.exception",
attributes={
"log.level": record.levelname,
"log.message": record.getMessage(),
"log.logger": record.name,
"log.file.path": record.pathname,
"log.file.line": record.lineno,
},
)
if record.exc_info[1]:
span.record_exception(record.exc_info[1])
if record.exc_info[0]:
span.set_attribute("exception.type", record.exc_info[0].__name__)
def instrument_exception_logging() -> None:

View File

@ -1,5 +1,4 @@
import re
import sys
from collections.abc import Mapping
from typing import Any
@ -109,11 +108,8 @@ def register_external_error_handlers(api: Api):
data.setdefault("code", "unknown")
data.setdefault("status", status_code)
# Log stack
exc_info: Any = sys.exc_info()
if exc_info[1] is None:
exc_info = (None, None, None)
current_app.log_exception(exc_info)
# Note: Exception logging is handled by Flask/Flask-RESTX framework automatically
# Explicit log_exception call removed to avoid duplicate log entries
return data, status_code

View File

@ -143,7 +143,7 @@ def test_host_header_preservation_with_user_header(mock_get_client):
mock_client.build_request.assert_called_once()
# Verify the Host header was set on the request object
assert mock_request.headers.get("Host") == custom_host
mock_client.send.assert_called_once_with(mock_request)
mock_client.send.assert_called_once_with(mock_request, follow_redirects=False)
@patch("core.helper.ssrf_proxy._get_ssrf_client")
@ -160,3 +160,101 @@ def test_host_header_preservation_case_insensitive(mock_get_client, host_key):
mock_get_client.return_value = mock_client
response = make_request("GET", "http://example.com", headers={host_key: "api.example.com"})
assert mock_request.headers.get("Host") == "api.example.com"
class TestFollowRedirectsParameter:
"""Tests for follow_redirects parameter handling.
These tests verify that follow_redirects is passed to send(), not build_request().
This is critical because httpx.Client.build_request() does not accept follow_redirects.
"""
@patch("core.helper.ssrf_proxy._get_ssrf_client")
def test_follow_redirects_not_passed_to_build_request(self, mock_get_client):
"""Verify follow_redirects is NOT passed to build_request()."""
mock_client = MagicMock()
mock_request = MagicMock()
mock_request.headers = {}
mock_response = MagicMock()
mock_response.status_code = 200
mock_client.send.return_value = mock_response
mock_client.build_request.return_value = mock_request
mock_get_client.return_value = mock_client
make_request("GET", "http://example.com", follow_redirects=True)
# Verify follow_redirects was NOT passed to build_request
call_kwargs = mock_client.build_request.call_args.kwargs
assert "follow_redirects" not in call_kwargs, "follow_redirects should not be passed to build_request()"
@patch("core.helper.ssrf_proxy._get_ssrf_client")
def test_follow_redirects_passed_to_send(self, mock_get_client):
"""Verify follow_redirects IS passed to send()."""
mock_client = MagicMock()
mock_request = MagicMock()
mock_request.headers = {}
mock_response = MagicMock()
mock_response.status_code = 200
mock_client.send.return_value = mock_response
mock_client.build_request.return_value = mock_request
mock_get_client.return_value = mock_client
make_request("GET", "http://example.com", follow_redirects=True)
# Verify follow_redirects WAS passed to send
mock_client.send.assert_called_once_with(mock_request, follow_redirects=True)
@patch("core.helper.ssrf_proxy._get_ssrf_client")
def test_allow_redirects_converted_to_follow_redirects(self, mock_get_client):
"""Verify allow_redirects (requests-style) is converted to follow_redirects (httpx-style)."""
mock_client = MagicMock()
mock_request = MagicMock()
mock_request.headers = {}
mock_response = MagicMock()
mock_response.status_code = 200
mock_client.send.return_value = mock_response
mock_client.build_request.return_value = mock_request
mock_get_client.return_value = mock_client
# Use allow_redirects (requests-style parameter)
make_request("GET", "http://example.com", allow_redirects=True)
# Verify it was converted to follow_redirects for send()
mock_client.send.assert_called_once_with(mock_request, follow_redirects=True)
# Verify allow_redirects was NOT passed to build_request
call_kwargs = mock_client.build_request.call_args.kwargs
assert "allow_redirects" not in call_kwargs
@patch("core.helper.ssrf_proxy._get_ssrf_client")
def test_follow_redirects_default_is_false(self, mock_get_client):
"""Verify follow_redirects defaults to False when not specified."""
mock_client = MagicMock()
mock_request = MagicMock()
mock_request.headers = {}
mock_response = MagicMock()
mock_response.status_code = 200
mock_client.send.return_value = mock_response
mock_client.build_request.return_value = mock_request
mock_get_client.return_value = mock_client
make_request("GET", "http://example.com")
# Verify default is False
mock_client.send.assert_called_once_with(mock_request, follow_redirects=False)
@patch("core.helper.ssrf_proxy._get_ssrf_client")
def test_follow_redirects_takes_precedence_over_allow_redirects(self, mock_get_client):
"""Verify follow_redirects takes precedence when both are specified."""
mock_client = MagicMock()
mock_request = MagicMock()
mock_request.headers = {}
mock_response = MagicMock()
mock_response.status_code = 200
mock_client.send.return_value = mock_response
mock_client.build_request.return_value = mock_request
mock_get_client.return_value = mock_client
# Both specified - follow_redirects should take precedence
make_request("GET", "http://example.com", allow_redirects=False, follow_redirects=True)
mock_client.send.assert_called_once_with(mock_request, follow_redirects=True)

View File

@ -0,0 +1,134 @@
"""Tests for logging filters."""
import logging
from unittest import mock
class TestTraceContextFilter:
def test_sets_empty_trace_id_without_context(self):
from core.logging.filters import TraceContextFilter
filter = TraceContextFilter()
record = logging.LogRecord(
name="test",
level=logging.INFO,
pathname="",
lineno=0,
msg="test",
args=(),
exc_info=None,
)
result = filter.filter(record)
assert result is True
assert hasattr(record, "trace_id")
assert hasattr(record, "span_id")
assert hasattr(record, "req_id")
def test_filter_always_returns_true(self):
from core.logging.filters import TraceContextFilter
filter = TraceContextFilter()
record = logging.LogRecord(
name="test",
level=logging.INFO,
pathname="",
lineno=0,
msg="test",
args=(),
exc_info=None,
)
result = filter.filter(record)
assert result is True
def test_sets_trace_id_from_otel_when_available(self):
from core.logging.filters import TraceContextFilter
mock_span = mock.MagicMock()
mock_context = mock.MagicMock()
mock_context.trace_id = 0x5B8AA5A2D2C872E8321CF37308D69DF2
mock_context.span_id = 0x051581BF3BB55C45
mock_span.get_span_context.return_value = mock_context
with mock.patch("opentelemetry.trace.get_current_span", return_value=mock_span):
with (
mock.patch("opentelemetry.trace.span.INVALID_TRACE_ID", 0),
mock.patch("opentelemetry.trace.span.INVALID_SPAN_ID", 0),
):
filter = TraceContextFilter()
record = logging.LogRecord(
name="test",
level=logging.INFO,
pathname="",
lineno=0,
msg="test",
args=(),
exc_info=None,
)
filter.filter(record)
assert record.trace_id == "5b8aa5a2d2c872e8321cf37308d69df2"
assert record.span_id == "051581bf3bb55c45"
class TestIdentityContextFilter:
def test_sets_empty_identity_without_request_context(self):
from core.logging.filters import IdentityContextFilter
filter = IdentityContextFilter()
record = logging.LogRecord(
name="test",
level=logging.INFO,
pathname="",
lineno=0,
msg="test",
args=(),
exc_info=None,
)
result = filter.filter(record)
assert result is True
assert record.tenant_id == ""
assert record.user_id == ""
assert record.user_type == ""
def test_filter_always_returns_true(self):
from core.logging.filters import IdentityContextFilter
filter = IdentityContextFilter()
record = logging.LogRecord(
name="test",
level=logging.INFO,
pathname="",
lineno=0,
msg="test",
args=(),
exc_info=None,
)
result = filter.filter(record)
assert result is True
def test_handles_exception_gracefully(self):
from core.logging.filters import IdentityContextFilter
filter = IdentityContextFilter()
record = logging.LogRecord(
name="test",
level=logging.INFO,
pathname="",
lineno=0,
msg="test",
args=(),
exc_info=None,
)
# Should not raise even if something goes wrong
with mock.patch("core.logging.filters.flask.has_request_context", side_effect=Exception("Test error")):
result = filter.filter(record)
assert result is True
assert record.tenant_id == ""

View File

@ -0,0 +1,240 @@
"""Tests for structured JSON formatter."""
import json
import logging
import sys
class TestStructuredJSONFormatter:
def test_basic_log_format(self):
from core.logging.structured_formatter import StructuredJSONFormatter
formatter = StructuredJSONFormatter(service_name="test-service")
record = logging.LogRecord(
name="test",
level=logging.INFO,
pathname="test.py",
lineno=42,
msg="Test message",
args=(),
exc_info=None,
)
output = formatter.format(record)
log_dict = json.loads(output)
assert log_dict["severity"] == "INFO"
assert log_dict["service"] == "test-service"
assert log_dict["caller"] == "test.py:42"
assert log_dict["message"] == "Test message"
assert "ts" in log_dict
assert log_dict["ts"].endswith("Z")
def test_severity_mapping(self):
from core.logging.structured_formatter import StructuredJSONFormatter
formatter = StructuredJSONFormatter()
test_cases = [
(logging.DEBUG, "DEBUG"),
(logging.INFO, "INFO"),
(logging.WARNING, "WARN"),
(logging.ERROR, "ERROR"),
(logging.CRITICAL, "ERROR"),
]
for level, expected_severity in test_cases:
record = logging.LogRecord(
name="test",
level=level,
pathname="test.py",
lineno=1,
msg="Test",
args=(),
exc_info=None,
)
output = formatter.format(record)
log_dict = json.loads(output)
assert log_dict["severity"] == expected_severity, f"Level {level} should map to {expected_severity}"
def test_error_with_stack_trace(self):
from core.logging.structured_formatter import StructuredJSONFormatter
formatter = StructuredJSONFormatter()
try:
raise ValueError("Test error")
except ValueError:
exc_info = sys.exc_info()
record = logging.LogRecord(
name="test",
level=logging.ERROR,
pathname="test.py",
lineno=10,
msg="Error occurred",
args=(),
exc_info=exc_info,
)
output = formatter.format(record)
log_dict = json.loads(output)
assert log_dict["severity"] == "ERROR"
assert "stack_trace" in log_dict
assert "ValueError: Test error" in log_dict["stack_trace"]
def test_no_stack_trace_for_info(self):
from core.logging.structured_formatter import StructuredJSONFormatter
formatter = StructuredJSONFormatter()
try:
raise ValueError("Test error")
except ValueError:
exc_info = sys.exc_info()
record = logging.LogRecord(
name="test",
level=logging.INFO,
pathname="test.py",
lineno=10,
msg="Info message",
args=(),
exc_info=exc_info,
)
output = formatter.format(record)
log_dict = json.loads(output)
assert "stack_trace" not in log_dict
def test_trace_context_included(self):
from core.logging.structured_formatter import StructuredJSONFormatter
formatter = StructuredJSONFormatter()
record = logging.LogRecord(
name="test",
level=logging.INFO,
pathname="test.py",
lineno=1,
msg="Test",
args=(),
exc_info=None,
)
record.trace_id = "5b8aa5a2d2c872e8321cf37308d69df2"
record.span_id = "051581bf3bb55c45"
output = formatter.format(record)
log_dict = json.loads(output)
assert log_dict["trace_id"] == "5b8aa5a2d2c872e8321cf37308d69df2"
assert log_dict["span_id"] == "051581bf3bb55c45"
def test_identity_context_included(self):
from core.logging.structured_formatter import StructuredJSONFormatter
formatter = StructuredJSONFormatter()
record = logging.LogRecord(
name="test",
level=logging.INFO,
pathname="test.py",
lineno=1,
msg="Test",
args=(),
exc_info=None,
)
record.tenant_id = "t-global-corp"
record.user_id = "u-admin-007"
record.user_type = "admin"
output = formatter.format(record)
log_dict = json.loads(output)
assert "identity" in log_dict
assert log_dict["identity"]["tenant_id"] == "t-global-corp"
assert log_dict["identity"]["user_id"] == "u-admin-007"
assert log_dict["identity"]["user_type"] == "admin"
def test_no_identity_when_empty(self):
from core.logging.structured_formatter import StructuredJSONFormatter
formatter = StructuredJSONFormatter()
record = logging.LogRecord(
name="test",
level=logging.INFO,
pathname="test.py",
lineno=1,
msg="Test",
args=(),
exc_info=None,
)
output = formatter.format(record)
log_dict = json.loads(output)
assert "identity" not in log_dict
def test_attributes_included(self):
from core.logging.structured_formatter import StructuredJSONFormatter
formatter = StructuredJSONFormatter()
record = logging.LogRecord(
name="test",
level=logging.INFO,
pathname="test.py",
lineno=1,
msg="Test",
args=(),
exc_info=None,
)
record.attributes = {"order_id": "ord-123", "amount": 99.99}
output = formatter.format(record)
log_dict = json.loads(output)
assert log_dict["attributes"]["order_id"] == "ord-123"
assert log_dict["attributes"]["amount"] == 99.99
def test_message_with_args(self):
from core.logging.structured_formatter import StructuredJSONFormatter
formatter = StructuredJSONFormatter()
record = logging.LogRecord(
name="test",
level=logging.INFO,
pathname="test.py",
lineno=1,
msg="User %s logged in from %s",
args=("john", "192.168.1.1"),
exc_info=None,
)
output = formatter.format(record)
log_dict = json.loads(output)
assert log_dict["message"] == "User john logged in from 192.168.1.1"
def test_timestamp_format(self):
from core.logging.structured_formatter import StructuredJSONFormatter
formatter = StructuredJSONFormatter()
record = logging.LogRecord(
name="test",
level=logging.INFO,
pathname="test.py",
lineno=1,
msg="Test",
args=(),
exc_info=None,
)
output = formatter.format(record)
log_dict = json.loads(output)
# Verify ISO 8601 format with Z suffix
ts = log_dict["ts"]
assert ts.endswith("Z")
assert "T" in ts
# Should have milliseconds
assert "." in ts

View File

@ -0,0 +1,102 @@
"""Tests for trace helper functions."""
import re
from unittest import mock
class TestGetSpanIdFromOtelContext:
def test_returns_none_without_span(self):
from core.helper.trace_id_helper import get_span_id_from_otel_context
with mock.patch("opentelemetry.trace.get_current_span", return_value=None):
result = get_span_id_from_otel_context()
assert result is None
def test_returns_span_id_when_available(self):
from core.helper.trace_id_helper import get_span_id_from_otel_context
mock_span = mock.MagicMock()
mock_context = mock.MagicMock()
mock_context.span_id = 0x051581BF3BB55C45
mock_span.get_span_context.return_value = mock_context
with mock.patch("opentelemetry.trace.get_current_span", return_value=mock_span):
with mock.patch("opentelemetry.trace.span.INVALID_SPAN_ID", 0):
result = get_span_id_from_otel_context()
assert result == "051581bf3bb55c45"
def test_returns_none_on_exception(self):
from core.helper.trace_id_helper import get_span_id_from_otel_context
with mock.patch("opentelemetry.trace.get_current_span", side_effect=Exception("Test error")):
result = get_span_id_from_otel_context()
assert result is None
class TestGenerateTraceparentHeader:
def test_generates_valid_format(self):
from core.helper.trace_id_helper import generate_traceparent_header
with mock.patch("opentelemetry.trace.get_current_span", return_value=None):
result = generate_traceparent_header()
assert result is not None
# Format: 00-{trace_id}-{span_id}-01
parts = result.split("-")
assert len(parts) == 4
assert parts[0] == "00" # version
assert len(parts[1]) == 32 # trace_id (32 hex chars)
assert len(parts[2]) == 16 # span_id (16 hex chars)
assert parts[3] == "01" # flags
def test_uses_otel_context_when_available(self):
from core.helper.trace_id_helper import generate_traceparent_header
mock_span = mock.MagicMock()
mock_context = mock.MagicMock()
mock_context.trace_id = 0x5B8AA5A2D2C872E8321CF37308D69DF2
mock_context.span_id = 0x051581BF3BB55C45
mock_span.get_span_context.return_value = mock_context
with mock.patch("opentelemetry.trace.get_current_span", return_value=mock_span):
with (
mock.patch("opentelemetry.trace.span.INVALID_TRACE_ID", 0),
mock.patch("opentelemetry.trace.span.INVALID_SPAN_ID", 0),
):
result = generate_traceparent_header()
assert result == "00-5b8aa5a2d2c872e8321cf37308d69df2-051581bf3bb55c45-01"
def test_generates_hex_only_values(self):
from core.helper.trace_id_helper import generate_traceparent_header
with mock.patch("opentelemetry.trace.get_current_span", return_value=None):
result = generate_traceparent_header()
parts = result.split("-")
# All parts should be valid hex
assert re.match(r"^[0-9a-f]+$", parts[1])
assert re.match(r"^[0-9a-f]+$", parts[2])
class TestParseTraceparentHeader:
def test_parses_valid_traceparent(self):
from core.helper.trace_id_helper import parse_traceparent_header
traceparent = "00-5b8aa5a2d2c872e8321cf37308d69df2-051581bf3bb55c45-01"
result = parse_traceparent_header(traceparent)
assert result == "5b8aa5a2d2c872e8321cf37308d69df2"
def test_returns_none_for_invalid_format(self):
from core.helper.trace_id_helper import parse_traceparent_header
# Wrong number of parts
assert parse_traceparent_header("00-abc-def") is None
# Wrong trace_id length
assert parse_traceparent_header("00-abc-def-01") is None
def test_returns_none_for_empty_string(self):
from core.helper.trace_id_helper import parse_traceparent_header
assert parse_traceparent_header("") is None

View File

@ -99,29 +99,20 @@ def test_external_api_json_message_and_bad_request_rewrite():
assert res.get_json()["message"] == "Invalid JSON payload received or JSON payload is empty."
def test_external_api_param_mapping_and_quota_and_exc_info_none():
# Force exc_info() to return (None,None,None) only during request
import libs.external_api as ext
def test_external_api_param_mapping_and_quota():
app = _create_api_app()
client = app.test_client()
orig_exc_info = ext.sys.exc_info
try:
ext.sys.exc_info = lambda: (None, None, None)
# Param errors mapping payload path
res = client.get("/api/param-errors")
assert res.status_code == 400
data = res.get_json()
assert data["code"] == "invalid_param"
assert data["params"] == "field"
app = _create_api_app()
client = app.test_client()
# Param errors mapping payload path
res = client.get("/api/param-errors")
assert res.status_code == 400
data = res.get_json()
assert data["code"] == "invalid_param"
assert data["params"] == "field"
# Quota path — depending on Flask-RESTX internals it may be handled
res = client.get("/api/quota")
assert res.status_code in (400, 429)
finally:
ext.sys.exc_info = orig_exc_info # type: ignore[assignment]
# Quota path — depending on Flask-RESTX internals it may be handled
res = client.get("/api/quota")
assert res.status_code in (400, 429)
def test_unauthorized_and_force_logout_clears_cookies():