Byron.wang 2025-12-29 15:43:28 +08:00 committed by GitHub
commit 4941b2981e
20 changed files with 1184 additions and 113 deletions

View File

@@ -2,9 +2,11 @@ import logging
import time
from opentelemetry.trace import get_current_span
from opentelemetry.trace.span import INVALID_SPAN_ID, INVALID_TRACE_ID
from configs import dify_config
from contexts.wrapper import RecyclableContextVar
from core.logging.context import init_request_context
from dify_app import DifyApp
logger = logging.getLogger(__name__)
@@ -25,28 +27,35 @@ def create_flask_app_with_configs() -> DifyApp:
# add before request hook
@dify_app.before_request
def before_request():
# add a unique identifier to each request
# Initialize logging context for this request
init_request_context()
RecyclableContextVar.increment_thread_recycles()
# add after request hook for injecting X-Trace-Id header from OpenTelemetry span context
# add after request hook for injecting trace headers from OpenTelemetry span context
# Only adds headers when OTEL is enabled and the span context is valid
@dify_app.after_request
def add_trace_id_header(response):
def add_trace_headers(response):
try:
span = get_current_span()
ctx = span.get_span_context() if span else None
if ctx and ctx.is_valid:
trace_id_hex = format(ctx.trace_id, "032x")
# Avoid duplicates if some middleware added it
if "X-Trace-Id" not in response.headers:
response.headers["X-Trace-Id"] = trace_id_hex
if not ctx or not ctx.is_valid:
return response
# Inject trace headers from OTEL context
if ctx.trace_id != INVALID_TRACE_ID and "X-Trace-Id" not in response.headers:
response.headers["X-Trace-Id"] = format(ctx.trace_id, "032x")
if ctx.span_id != INVALID_SPAN_ID and "X-Span-Id" not in response.headers:
response.headers["X-Span-Id"] = format(ctx.span_id, "016x")
except Exception:
# Never break the response due to tracing header injection
logger.warning("Failed to add trace ID to response header", exc_info=True)
logger.warning("Failed to add trace headers to response", exc_info=True)
return response
# Capture the decorator's return value to avoid pyright reportUnusedFunction
_ = before_request
_ = add_trace_id_header
_ = add_trace_headers
return dify_app
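As a rough illustration, the new headers can be observed from a test client; this is a minimal sketch, assuming the factory above, an enabled OTEL setup, a hypothetical /health endpoint, and an assumed import path:
# Hypothetical check of the headers added by add_trace_headers.
from app_factory import create_flask_app_with_configs  # assumed import path

def check_trace_headers():
    app = create_flask_app_with_configs()
    with app.test_client() as client:
        resp = client.get("/health")  # hypothetical endpoint
        trace_id = resp.headers.get("X-Trace-Id")
        span_id = resp.headers.get("X-Span-Id")
        # Headers are only present when a valid OTEL span context exists.
        if trace_id is not None:
            assert len(trace_id) == 32  # 128-bit trace id, 32 hex chars
        if span_id is not None:
            assert len(span_id) == 16  # 64-bit span id, 16 hex chars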

View File

@@ -587,6 +587,11 @@ class LoggingConfig(BaseSettings):
default="INFO",
)
LOG_OUTPUT_FORMAT: Literal["text", "json"] = Field(
description="Log output format: 'text' for human-readable, 'json' for structured JSON logs.",
default="text",
)
LOG_FILE: str | None = Field(
description="File path for log output.",
default=None,

View File

@@ -88,7 +88,41 @@ def _get_user_provided_host_header(headers: dict | None) -> str | None:
return None
def _inject_trace_headers(headers: dict | None) -> dict:
"""
Inject W3C traceparent header for distributed tracing.
When OTEL is enabled, HTTPXClientInstrumentor handles trace propagation automatically.
When OTEL is disabled, we manually inject the traceparent header.
"""
if headers is None:
headers = {}
# Skip if already present (case-insensitive check)
for key in headers:
if key.lower() == "traceparent":
return headers
# Skip if OTEL is enabled - HTTPXClientInstrumentor handles this automatically
if dify_config.ENABLE_OTEL:
return headers
# Generate and inject traceparent for non-OTEL scenarios
try:
from core.helper.trace_id_helper import generate_traceparent_header
traceparent = generate_traceparent_header()
if traceparent:
headers["traceparent"] = traceparent
except Exception:
# Log at debug level and continue; tracing must never break the request
logger.debug("Failed to generate traceparent header", exc_info=True)
return headers
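A quick sketch of the intended behavior of this helper, with made-up header values and assuming ENABLE_OTEL is off so the manual injection path runs:
def _demo_inject_trace_headers():
    # 1) A pre-existing traceparent (any casing) is left untouched.
    existing = {"Traceparent": "00-" + "ab" * 16 + "-" + "cd" * 8 + "-01"}
    assert _inject_trace_headers(dict(existing)) == existing
    # 2) With ENABLE_OTEL disabled, a fresh traceparent is injected.
    out = _inject_trace_headers({})
    assert out["traceparent"].startswith("00-")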
def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
# Convert requests-style allow_redirects to httpx-style follow_redirects
if "allow_redirects" in kwargs:
allow_redirects = kwargs.pop("allow_redirects")
if "follow_redirects" not in kwargs:
@@ -106,18 +140,21 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
verify_option = kwargs.pop("ssl_verify", dify_config.HTTP_REQUEST_NODE_SSL_VERIFY)
client = _get_ssrf_client(verify_option)
# Inject traceparent header for distributed tracing (when OTEL is not enabled)
headers = kwargs.get("headers") or {}
headers = _inject_trace_headers(headers)
kwargs["headers"] = headers
# Preserve user-provided Host header
# When using a forward proxy, httpx may override the Host header based on the URL.
# We extract and preserve any explicitly set Host header to support virtual hosting.
headers = kwargs.get("headers", {})
user_provided_host = _get_user_provided_host_header(headers)
retries = 0
while retries <= max_retries:
try:
# Build the request manually to preserve the Host header
# httpx may override the Host header when using a proxy, so we use
# the request API to explicitly set headers before sending
# Preserve the user-provided Host header
# httpx may override the Host header when using a proxy
headers = {k: v for k, v in headers.items() if k.lower() != "host"}
if user_provided_host is not None:
headers["host"] = user_provided_host

View File

@@ -103,3 +103,60 @@ def parse_traceparent_header(traceparent: str) -> str | None:
if len(parts) == 4 and len(parts[1]) == 32:
return parts[1]
return None
def get_span_id_from_otel_context() -> str | None:
"""
Retrieve the current span ID from the active OpenTelemetry trace context.
Returns:
A 16-character hex string representing the span ID, or None if not available.
"""
try:
from opentelemetry.trace import get_current_span
from opentelemetry.trace.span import INVALID_SPAN_ID
span = get_current_span()
if not span:
return None
span_context = span.get_span_context()
if not span_context or span_context.span_id == INVALID_SPAN_ID:
return None
return f"{span_context.span_id:016x}"
except Exception:
return None
def generate_traceparent_header() -> str | None:
"""
Generate a W3C traceparent header from the current context.
Uses OpenTelemetry context if available, otherwise uses the
ContextVar-based trace_id from the logging context.
Format: {version}-{trace_id}-{span_id}-{flags}
Example: 00-5b8aa5a2d2c872e8321cf37308d69df2-051581bf3bb55c45-01
Returns:
A valid traceparent header string, or None if generation fails.
"""
import uuid
# Try OTEL context first
trace_id = get_trace_id_from_otel_context()
span_id = get_span_id_from_otel_context()
if trace_id and span_id:
return f"00-{trace_id}-{span_id}-01"
# Fallback: use ContextVar-based trace_id or generate new one
from core.logging.context import get_trace_id as get_logging_trace_id
trace_id = get_logging_trace_id() or uuid.uuid4().hex
# Generate a new span_id (16 hex chars)
span_id = uuid.uuid4().hex[:16]
return f"00-{trace_id}-{span_id}-01"

View File

@@ -0,0 +1,20 @@
"""Structured logging components for Dify."""
from core.logging.context import (
clear_request_context,
get_request_id,
get_trace_id,
init_request_context,
)
from core.logging.filters import IdentityContextFilter, TraceContextFilter
from core.logging.structured_formatter import StructuredJSONFormatter
__all__ = [
"IdentityContextFilter",
"StructuredJSONFormatter",
"TraceContextFilter",
"clear_request_context",
"get_request_id",
"get_trace_id",
"init_request_context",
]

View File

@@ -0,0 +1,35 @@
"""Request context for logging - framework agnostic.
This module provides request-scoped context variables for logging,
using Python's contextvars for thread-safe and async-safe storage.
"""
import uuid
from contextvars import ContextVar
_request_id: ContextVar[str] = ContextVar("log_request_id", default="")
_trace_id: ContextVar[str] = ContextVar("log_trace_id", default="")
def get_request_id() -> str:
"""Get current request ID (10 hex chars)."""
return _request_id.get()
def get_trace_id() -> str:
"""Get fallback trace ID when OTEL is unavailable (32 hex chars)."""
return _trace_id.get()
def init_request_context() -> None:
"""Initialize request context. Call at start of each request."""
req_id = uuid.uuid4().hex[:10]
trace_id = uuid.uuid5(uuid.NAMESPACE_DNS, req_id).hex
_request_id.set(req_id)
_trace_id.set(trace_id)
def clear_request_context() -> None:
"""Clear request context. Call at end of request (optional)."""
_request_id.set("")
_trace_id.set("")

View File

@@ -0,0 +1,94 @@
"""Logging filters for structured logging."""
import contextlib
import logging
import flask
from core.logging.context import get_request_id, get_trace_id
class TraceContextFilter(logging.Filter):
"""
Filter that adds trace_id and span_id to log records.
Integrates with OpenTelemetry when available and falls back to the ContextVar-based trace_id otherwise.
"""
def filter(self, record: logging.LogRecord) -> bool:
# Get trace context from OpenTelemetry
trace_id, span_id = self._get_otel_context()
# Set trace_id (fallback to ContextVar if no OTEL context)
if trace_id:
record.trace_id = trace_id
else:
record.trace_id = get_trace_id()
record.span_id = span_id or ""
# For backward compatibility, also set req_id
record.req_id = get_request_id()
return True
def _get_otel_context(self) -> tuple[str, str]:
"""Extract trace_id and span_id from OpenTelemetry context."""
with contextlib.suppress(Exception):
from opentelemetry.trace import get_current_span
from opentelemetry.trace.span import INVALID_SPAN_ID, INVALID_TRACE_ID
span = get_current_span()
if span and span.get_span_context():
ctx = span.get_span_context()
if ctx.is_valid and ctx.trace_id != INVALID_TRACE_ID:
trace_id = f"{ctx.trace_id:032x}"
span_id = f"{ctx.span_id:016x}" if ctx.span_id != INVALID_SPAN_ID else ""
return trace_id, span_id
return "", ""
class IdentityContextFilter(logging.Filter):
"""
Filter that adds user identity context to log records.
Extracts tenant_id, user_id, and user_type from Flask-Login current_user.
"""
def filter(self, record: logging.LogRecord) -> bool:
identity = self._extract_identity()
record.tenant_id = identity.get("tenant_id", "")
record.user_id = identity.get("user_id", "")
record.user_type = identity.get("user_type", "")
return True
def _extract_identity(self) -> dict[str, str]:
"""Extract identity from current_user if in request context."""
try:
if not flask.has_request_context():
return {}
from flask_login import current_user
# Check if user is authenticated using the proxy
if not current_user.is_authenticated:
return {}
# Access the underlying user object
user = current_user
from models import Account
from models.model import EndUser
identity: dict[str, str] = {}
if isinstance(user, Account):
if user.current_tenant_id:
identity["tenant_id"] = user.current_tenant_id
identity["user_id"] = user.id
identity["user_type"] = "account"
elif isinstance(user, EndUser):
identity["tenant_id"] = user.tenant_id
identity["user_id"] = user.id
identity["user_type"] = user.type or "end_user"
return identity
except Exception:
return {}
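For illustration, a small sketch wiring both filters onto a handler so every record carries the trace and identity fields; the logger name and format string here are made up:
import logging
from core.logging.context import init_request_context
from core.logging.filters import IdentityContextFilter, TraceContextFilter

def _demo_filters():
    handler = logging.StreamHandler()
    handler.addFilter(TraceContextFilter())
    handler.addFilter(IdentityContextFilter())
    handler.setFormatter(logging.Formatter("%(trace_id)s %(req_id)s %(message)s"))
    log = logging.getLogger("demo.filters")
    log.addHandler(handler)
    log.propagate = False
    init_request_context()          # provides the fallback trace_id/req_id
    log.warning("hello")            # prints "<32-hex> <10-hex> hello"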

View File

@@ -0,0 +1,107 @@
"""Structured JSON log formatter for Dify."""
import logging
import traceback
from datetime import UTC, datetime
from typing import Any
import orjson
from configs import dify_config
class StructuredJSONFormatter(logging.Formatter):
"""
JSON log formatter following the specified schema:
{
"ts": "ISO 8601 UTC",
"severity": "INFO|ERROR|WARN|DEBUG",
"service": "service name",
"caller": "file:line",
"trace_id": "hex 32",
"span_id": "hex 16",
"identity": { "tenant_id", "user_id", "user_type" },
"message": "log message",
"attributes": { ... },
"stack_trace": "..."
}
"""
SEVERITY_MAP: dict[int, str] = {
logging.DEBUG: "DEBUG",
logging.INFO: "INFO",
logging.WARNING: "WARN",
logging.ERROR: "ERROR",
logging.CRITICAL: "ERROR",
}
def __init__(self, service_name: str | None = None):
super().__init__()
self._service_name = service_name or dify_config.APPLICATION_NAME
def format(self, record: logging.LogRecord) -> str:
log_dict = self._build_log_dict(record)
try:
return orjson.dumps(log_dict).decode("utf-8")
except TypeError:
# Fallback: convert non-serializable objects to string
import json
return json.dumps(log_dict, default=str, ensure_ascii=False)
def _build_log_dict(self, record: logging.LogRecord) -> dict[str, Any]:
# Core fields
log_dict: dict[str, Any] = {
"ts": datetime.now(UTC).isoformat(timespec="milliseconds").replace("+00:00", "Z"),
"severity": self.SEVERITY_MAP.get(record.levelno, "INFO"),
"service": self._service_name,
"caller": f"{record.filename}:{record.lineno}",
"message": record.getMessage(),
}
# Trace context (from TraceContextFilter)
trace_id = getattr(record, "trace_id", "")
span_id = getattr(record, "span_id", "")
if trace_id:
log_dict["trace_id"] = trace_id
if span_id:
log_dict["span_id"] = span_id
# Identity context (from IdentityContextFilter)
identity = self._extract_identity(record)
if identity:
log_dict["identity"] = identity
# Dynamic attributes
attributes = getattr(record, "attributes", None)
if attributes:
log_dict["attributes"] = attributes
# Stack trace for errors with exceptions
if record.exc_info and record.levelno >= logging.ERROR:
log_dict["stack_trace"] = self._format_exception(record.exc_info)
return log_dict
def _extract_identity(self, record: logging.LogRecord) -> dict[str, str] | None:
tenant_id = getattr(record, "tenant_id", None)
user_id = getattr(record, "user_id", None)
user_type = getattr(record, "user_type", None)
if not any([tenant_id, user_id, user_type]):
return None
identity: dict[str, str] = {}
if tenant_id:
identity["tenant_id"] = tenant_id
if user_id:
identity["user_id"] = user_id
if user_type:
identity["user_type"] = user_type
return identity
def _format_exception(self, exc_info: tuple[Any, ...]) -> str:
if exc_info and exc_info[0] is not None:
return "".join(traceback.format_exception(*exc_info))
return ""

View File

@@ -103,6 +103,9 @@ class BasePluginClient:
prepared_headers["X-Api-Key"] = dify_config.PLUGIN_DAEMON_KEY
prepared_headers.setdefault("Accept-Encoding", "gzip, deflate, br")
# Inject traceparent header for distributed tracing
self._inject_trace_headers(prepared_headers)
prepared_data: bytes | dict[str, Any] | str | None = (
data if isinstance(data, (bytes, str, dict)) or data is None else None
)
@@ -114,6 +117,28 @@
return str(url), prepared_headers, prepared_data, params, files
def _inject_trace_headers(self, headers: dict[str, str]) -> None:
"""
Inject W3C traceparent header for distributed tracing.
This ensures trace context is propagated to the plugin daemon even if
HTTPXClientInstrumentor doesn't cover module-level httpx functions.
"""
import contextlib
# Skip if already present (case-insensitive check)
for key in headers:
if key.lower() == "traceparent":
return
# Inject traceparent - works as fallback when OTEL instrumentation doesn't cover this call
with contextlib.suppress(Exception):
from core.helper.trace_id_helper import generate_traceparent_header
traceparent = generate_traceparent_header()
if traceparent:
headers["traceparent"] = traceparent
def _stream_request(
self,
method: str,

View File

@@ -47,7 +47,11 @@ def _get_celery_ssl_options() -> dict[str, Any] | None:
def init_app(app: DifyApp) -> Celery:
class FlaskTask(Task):
def __call__(self, *args: object, **kwargs: object) -> object:
from core.logging.context import init_request_context
with app.app_context():
# Initialize logging context for this task (similar to before_request in Flask)
init_request_context()
return self.run(*args, **kwargs)
broker_transport_options = {}
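Because FlaskTask.__call__ now calls init_request_context(), each task execution starts with a fresh logging context; a hypothetical task illustrating that (the decorator usage assumes the task is bound to this Celery app so FlaskTask is its base class):
from celery import shared_task
from core.logging.context import get_request_id

@shared_task
def demo_task() -> str:
    # By the time run() executes, a fresh request-scoped logging context has
    # been initialized, so this returns a new 10-char hex request id.
    return get_request_id()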

View File

@@ -1,18 +1,19 @@
"""Logging extension for Dify Flask application."""
import logging
import os
import sys
import uuid
from logging.handlers import RotatingFileHandler
import flask
from configs import dify_config
from core.helper.trace_id_helper import get_trace_id_from_otel_context
from dify_app import DifyApp
def init_app(app: DifyApp):
"""Initialize logging with support for text or JSON format."""
log_handlers: list[logging.Handler] = []
# File handler
log_file = dify_config.LOG_FILE
if log_file:
log_dir = os.path.dirname(log_file)
@@ -25,27 +26,53 @@ def init_app(app: DifyApp):
)
)
# Always add StreamHandler to log to console
# Console handler
sh = logging.StreamHandler(sys.stdout)
log_handlers.append(sh)
# Apply RequestIdFilter to all handlers
for handler in log_handlers:
handler.addFilter(RequestIdFilter())
# Apply filters to all handlers
from core.logging.filters import IdentityContextFilter, TraceContextFilter
for handler in log_handlers:
handler.addFilter(TraceContextFilter())
handler.addFilter(IdentityContextFilter())
# Configure formatter based on format type
formatter = _create_formatter()
for handler in log_handlers:
handler.setFormatter(formatter)
# Configure root logger
logging.basicConfig(
level=dify_config.LOG_LEVEL,
format=dify_config.LOG_FORMAT,
datefmt=dify_config.LOG_DATEFORMAT,
handlers=log_handlers,
force=True,
)
# Apply RequestIdFormatter to all handlers
apply_request_id_formatter()
# Disable propagation for noisy loggers to avoid duplicate logs
logging.getLogger("sqlalchemy.engine").propagate = False
# Apply timezone if specified (only for text format)
if dify_config.LOG_OUTPUT_FORMAT == "text":
_apply_timezone(log_handlers)
def _create_formatter() -> logging.Formatter:
"""Create appropriate formatter based on configuration."""
if dify_config.LOG_OUTPUT_FORMAT == "json":
from core.logging.structured_formatter import StructuredJSONFormatter
return StructuredJSONFormatter()
else:
# Text format - use existing pattern with backward compatible formatter
return _TextFormatter(
fmt=dify_config.LOG_FORMAT,
datefmt=dify_config.LOG_DATEFORMAT,
)
def _apply_timezone(handlers: list[logging.Handler]):
"""Apply timezone conversion to text formatters."""
log_tz = dify_config.LOG_TZ
if log_tz:
from datetime import datetime
@@ -57,34 +84,51 @@ def init_app(app: DifyApp):
def time_converter(seconds):
return datetime.fromtimestamp(seconds, tz=timezone).timetuple()
for handler in logging.root.handlers:
for handler in handlers:
if handler.formatter:
handler.formatter.converter = time_converter
handler.formatter.converter = time_converter # type: ignore[attr-defined]
def get_request_id():
if getattr(flask.g, "request_id", None):
return flask.g.request_id
class _TextFormatter(logging.Formatter):
"""Text formatter that ensures trace_id and req_id are always present."""
new_uuid = uuid.uuid4().hex[:10]
flask.g.request_id = new_uuid
return new_uuid
def format(self, record: logging.LogRecord) -> str:
if not hasattr(record, "req_id"):
record.req_id = ""
if not hasattr(record, "trace_id"):
record.trace_id = ""
if not hasattr(record, "span_id"):
record.span_id = ""
return super().format(record)
def get_request_id() -> str:
"""Get request ID for current request context.
Deprecated: Use core.logging.context.get_request_id() directly.
"""
from core.logging.context import get_request_id as _get_request_id
return _get_request_id()
# Backward compatibility aliases
class RequestIdFilter(logging.Filter):
# This is a logging filter that makes the request ID available for use in
# the logging format. Note that we're checking if we're in a request
# context, as we may want to log things before Flask is fully loaded.
def filter(self, record):
trace_id = get_trace_id_from_otel_context() or ""
record.req_id = get_request_id() if flask.has_request_context() else ""
record.trace_id = trace_id
"""Deprecated: Use TraceContextFilter from core.logging.filters instead."""
def filter(self, record: logging.LogRecord) -> bool:
from core.logging.context import get_request_id as _get_request_id
from core.logging.context import get_trace_id as _get_trace_id
record.req_id = _get_request_id()
record.trace_id = _get_trace_id()
return True
class RequestIdFormatter(logging.Formatter):
def format(self, record):
"""Deprecated: Use _TextFormatter instead."""
def format(self, record: logging.LogRecord) -> str:
if not hasattr(record, "req_id"):
record.req_id = ""
if not hasattr(record, "trace_id"):
@@ -93,6 +137,7 @@ class RequestIdFormatter(logging.Formatter):
def apply_request_id_formatter():
"""Deprecated: Formatter is now applied in init_app."""
for handler in logging.root.handlers:
if handler.formatter:
handler.formatter = RequestIdFormatter(dify_config.LOG_FORMAT, dify_config.LOG_DATEFORMAT)
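A small sketch of why _TextFormatter pre-fills the fields: a format string that references %(req_id)s or %(trace_id)s would otherwise fail for records that never passed through the filters (the format string below is illustrative, not the configured LOG_FORMAT; the module path is assumed):
import logging
from extensions.ext_logging import _TextFormatter  # assumed module path

def _demo_text_formatter():
    fmt = _TextFormatter(fmt="[%(trace_id)s] [%(req_id)s] %(levelname)s %(message)s")
    record = logging.LogRecord(
        name="demo", level=logging.INFO, pathname="demo.py", lineno=1,
        msg="hello", args=(), exc_info=None,
    )
    # No filter ran, so the missing fields are injected as empty strings.
    assert fmt.format(record) == "[] [] INFO hello"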

View File

@@ -19,26 +19,43 @@ logger = logging.getLogger(__name__)
class ExceptionLoggingHandler(logging.Handler):
"""
Handler that records exceptions to the current OpenTelemetry span.
Unlike creating a new span, this records exceptions on the existing span
to maintain trace context consistency throughout the request lifecycle.
"""
def emit(self, record: logging.LogRecord):
with contextlib.suppress(Exception):
if record.exc_info:
tracer = get_tracer_provider().get_tracer("dify.exception.logging")
with tracer.start_as_current_span(
"log.exception",
attributes={
"log.level": record.levelname,
"log.message": record.getMessage(),
"log.logger": record.name,
"log.file.path": record.pathname,
"log.file.line": record.lineno,
},
) as span:
span.set_status(StatusCode.ERROR)
if record.exc_info[1]:
span.record_exception(record.exc_info[1])
span.set_attribute("exception.message", str(record.exc_info[1]))
if record.exc_info[0]:
span.set_attribute("exception.type", record.exc_info[0].__name__)
if not record.exc_info:
return
from opentelemetry.trace import get_current_span
span = get_current_span()
if not span or not span.is_recording():
return
# Record exception on the current span instead of creating a new one
span.set_status(StatusCode.ERROR, record.getMessage())
# Add log context as span events/attributes
span.add_event(
"log.exception",
attributes={
"log.level": record.levelname,
"log.message": record.getMessage(),
"log.logger": record.name,
"log.file.path": record.pathname,
"log.file.line": record.lineno,
},
)
if record.exc_info[1]:
span.record_exception(record.exc_info[1])
if record.exc_info[0]:
span.set_attribute("exception.type", record.exc_info[0].__name__)
def instrument_exception_logging() -> None:

View File

@@ -1,5 +1,4 @@
import re
import sys
from collections.abc import Mapping
from typing import Any
@@ -109,11 +108,8 @@ def register_external_error_handlers(api: Api):
data.setdefault("code", "unknown")
data.setdefault("status", status_code)
# Log stack
exc_info: Any = sys.exc_info()
if exc_info[1] is None:
exc_info = (None, None, None)
current_app.log_exception(exc_info)
# Note: Exception logging is handled automatically by the Flask/Flask-RESTX framework;
# the explicit log_exception call was removed to avoid duplicate log entries
return data, status_code

View File

@@ -14,12 +14,12 @@ def test_successful_request(mock_get_client):
mock_client = MagicMock()
mock_response = MagicMock()
mock_response.status_code = 200
mock_client.send.return_value = mock_response
mock_client.request.return_value = mock_response
mock_get_client.return_value = mock_client
response = make_request("GET", "http://example.com")
assert response.status_code == 200
mock_client.request.assert_called_once()
@patch("core.helper.ssrf_proxy._get_ssrf_client")
@@ -27,7 +27,6 @@ def test_retry_exceed_max_retries(mock_get_client):
mock_client = MagicMock()
mock_response = MagicMock()
mock_response.status_code = 500
mock_client.send.return_value = mock_response
mock_client.request.return_value = mock_response
mock_get_client.return_value = mock_client
@@ -72,34 +71,12 @@ class TestGetUserProvidedHostHeader:
assert result in ("first.com", "second.com")
@patch("core.helper.ssrf_proxy._get_ssrf_client")
def test_host_header_preservation_without_user_header(mock_get_client):
"""Test that when no Host header is provided, the default behavior is maintained."""
mock_client = MagicMock()
mock_request = MagicMock()
mock_request.headers = {}
mock_response = MagicMock()
mock_response.status_code = 200
mock_client.send.return_value = mock_response
mock_client.request.return_value = mock_response
mock_get_client.return_value = mock_client
response = make_request("GET", "http://example.com")
assert response.status_code == 200
# Host should not be set if not provided by user
assert "Host" not in mock_request.headers or mock_request.headers.get("Host") is None
@patch("core.helper.ssrf_proxy._get_ssrf_client")
def test_host_header_preservation_with_user_header(mock_get_client):
"""Test that user-provided Host header is preserved in the request."""
mock_client = MagicMock()
mock_request = MagicMock()
mock_request.headers = {}
mock_response = MagicMock()
mock_response.status_code = 200
mock_client.send.return_value = mock_response
mock_client.request.return_value = mock_response
mock_get_client.return_value = mock_client
@@ -107,3 +84,93 @@ def test_host_header_preservation_with_user_header(mock_get_client):
response = make_request("GET", "http://example.com", headers={"Host": custom_host})
assert response.status_code == 200
# Verify client.request was called with the host header preserved (lowercase)
call_kwargs = mock_client.request.call_args.kwargs
assert call_kwargs["headers"]["host"] == custom_host
@patch("core.helper.ssrf_proxy._get_ssrf_client")
@pytest.mark.parametrize("host_key", ["host", "HOST", "Host"])
def test_host_header_preservation_case_insensitive(mock_get_client, host_key):
"""Test that Host header is preserved regardless of case."""
mock_client = MagicMock()
mock_response = MagicMock()
mock_response.status_code = 200
mock_client.request.return_value = mock_response
mock_get_client.return_value = mock_client
response = make_request("GET", "http://example.com", headers={host_key: "api.example.com"})
assert response.status_code == 200
# Host header should be normalized to lowercase "host"
call_kwargs = mock_client.request.call_args.kwargs
assert call_kwargs["headers"]["host"] == "api.example.com"
class TestFollowRedirectsParameter:
"""Tests for follow_redirects parameter handling.
These tests verify that follow_redirects is correctly passed to client.request().
"""
@patch("core.helper.ssrf_proxy._get_ssrf_client")
def test_follow_redirects_passed_to_request(self, mock_get_client):
"""Verify follow_redirects IS passed to client.request()."""
mock_client = MagicMock()
mock_response = MagicMock()
mock_response.status_code = 200
mock_client.request.return_value = mock_response
mock_get_client.return_value = mock_client
make_request("GET", "http://example.com", follow_redirects=True)
# Verify follow_redirects was passed to request
call_kwargs = mock_client.request.call_args.kwargs
assert call_kwargs.get("follow_redirects") is True
@patch("core.helper.ssrf_proxy._get_ssrf_client")
def test_allow_redirects_converted_to_follow_redirects(self, mock_get_client):
"""Verify allow_redirects (requests-style) is converted to follow_redirects (httpx-style)."""
mock_client = MagicMock()
mock_response = MagicMock()
mock_response.status_code = 200
mock_client.request.return_value = mock_response
mock_get_client.return_value = mock_client
# Use allow_redirects (requests-style parameter)
make_request("GET", "http://example.com", allow_redirects=True)
# Verify it was converted to follow_redirects
call_kwargs = mock_client.request.call_args.kwargs
assert call_kwargs.get("follow_redirects") is True
assert "allow_redirects" not in call_kwargs
@patch("core.helper.ssrf_proxy._get_ssrf_client")
def test_follow_redirects_not_set_when_not_specified(self, mock_get_client):
"""Verify follow_redirects is not in kwargs when not specified (httpx default behavior)."""
mock_client = MagicMock()
mock_response = MagicMock()
mock_response.status_code = 200
mock_client.request.return_value = mock_response
mock_get_client.return_value = mock_client
make_request("GET", "http://example.com")
# follow_redirects should not be in kwargs, letting httpx use its default
call_kwargs = mock_client.request.call_args.kwargs
assert "follow_redirects" not in call_kwargs
@patch("core.helper.ssrf_proxy._get_ssrf_client")
def test_follow_redirects_takes_precedence_over_allow_redirects(self, mock_get_client):
"""Verify follow_redirects takes precedence when both are specified."""
mock_client = MagicMock()
mock_response = MagicMock()
mock_response.status_code = 200
mock_client.request.return_value = mock_response
mock_get_client.return_value = mock_client
# Both specified - follow_redirects should take precedence
make_request("GET", "http://example.com", allow_redirects=False, follow_redirects=True)
call_kwargs = mock_client.request.call_args.kwargs
assert call_kwargs.get("follow_redirects") is True

View File

@@ -0,0 +1,79 @@
"""Tests for logging context module."""
import uuid
from core.logging.context import (
clear_request_context,
get_request_id,
get_trace_id,
init_request_context,
)
class TestLoggingContext:
"""Tests for the logging context functions."""
def test_init_creates_request_id(self):
"""init_request_context should create a 10-char request ID."""
init_request_context()
request_id = get_request_id()
assert len(request_id) == 10
assert all(c in "0123456789abcdef" for c in request_id)
def test_init_creates_trace_id(self):
"""init_request_context should create a 32-char trace ID."""
init_request_context()
trace_id = get_trace_id()
assert len(trace_id) == 32
assert all(c in "0123456789abcdef" for c in trace_id)
def test_trace_id_derived_from_request_id(self):
"""trace_id should be deterministically derived from request_id."""
init_request_context()
request_id = get_request_id()
trace_id = get_trace_id()
# Verify trace_id is derived using uuid5
expected_trace = uuid.uuid5(uuid.NAMESPACE_DNS, request_id).hex
assert trace_id == expected_trace
def test_clear_resets_context(self):
"""clear_request_context should reset both IDs to empty strings."""
init_request_context()
assert get_request_id() != ""
assert get_trace_id() != ""
clear_request_context()
assert get_request_id() == ""
assert get_trace_id() == ""
def test_default_values_are_empty(self):
"""Default values should be empty strings before init."""
clear_request_context()
assert get_request_id() == ""
assert get_trace_id() == ""
def test_multiple_inits_create_different_ids(self):
"""Each init should create new unique IDs."""
init_request_context()
first_request_id = get_request_id()
first_trace_id = get_trace_id()
init_request_context()
second_request_id = get_request_id()
second_trace_id = get_trace_id()
assert first_request_id != second_request_id
assert first_trace_id != second_trace_id
def test_context_isolation(self):
"""Context should be isolated per-call (no thread leakage in same thread)."""
init_request_context()
id1 = get_request_id()
# Simulate another request
init_request_context()
id2 = get_request_id()
# IDs should be different
assert id1 != id2

View File

@@ -0,0 +1,114 @@
"""Tests for logging filters."""
import logging
from unittest import mock
import pytest
@pytest.fixture
def log_record():
return logging.LogRecord(
name="test",
level=logging.INFO,
pathname="",
lineno=0,
msg="test",
args=(),
exc_info=None,
)
class TestTraceContextFilter:
def test_sets_empty_trace_id_without_context(self, log_record):
from core.logging.context import clear_request_context
from core.logging.filters import TraceContextFilter
# Ensure no context is set
clear_request_context()
filter = TraceContextFilter()
result = filter.filter(log_record)
assert result is True
assert hasattr(log_record, "trace_id")
assert hasattr(log_record, "span_id")
assert hasattr(log_record, "req_id")
# Without context, IDs should be empty
assert log_record.trace_id == ""
assert log_record.req_id == ""
def test_sets_trace_id_from_context(self, log_record):
"""Test that trace_id and req_id are set from ContextVar when initialized."""
from core.logging.context import init_request_context
from core.logging.filters import TraceContextFilter
# Initialize context (no Flask needed!)
init_request_context()
filter = TraceContextFilter()
filter.filter(log_record)
# With context initialized, IDs should be set
assert log_record.trace_id != ""
assert len(log_record.trace_id) == 32
assert log_record.req_id != ""
assert len(log_record.req_id) == 10
def test_filter_always_returns_true(self, log_record):
from core.logging.filters import TraceContextFilter
filter = TraceContextFilter()
result = filter.filter(log_record)
assert result is True
def test_sets_trace_id_from_otel_when_available(self, log_record):
from core.logging.filters import TraceContextFilter
mock_span = mock.MagicMock()
mock_context = mock.MagicMock()
mock_context.trace_id = 0x5B8AA5A2D2C872E8321CF37308D69DF2
mock_context.span_id = 0x051581BF3BB55C45
mock_span.get_span_context.return_value = mock_context
with (
mock.patch("opentelemetry.trace.get_current_span", return_value=mock_span),
mock.patch("opentelemetry.trace.span.INVALID_TRACE_ID", 0),
mock.patch("opentelemetry.trace.span.INVALID_SPAN_ID", 0),
):
filter = TraceContextFilter()
filter.filter(log_record)
assert log_record.trace_id == "5b8aa5a2d2c872e8321cf37308d69df2"
assert log_record.span_id == "051581bf3bb55c45"
class TestIdentityContextFilter:
def test_sets_empty_identity_without_request_context(self, log_record):
from core.logging.filters import IdentityContextFilter
filter = IdentityContextFilter()
result = filter.filter(log_record)
assert result is True
assert log_record.tenant_id == ""
assert log_record.user_id == ""
assert log_record.user_type == ""
def test_filter_always_returns_true(self, log_record):
from core.logging.filters import IdentityContextFilter
filter = IdentityContextFilter()
result = filter.filter(log_record)
assert result is True
def test_handles_exception_gracefully(self, log_record):
from core.logging.filters import IdentityContextFilter
filter = IdentityContextFilter()
# Should not raise even if something goes wrong
with mock.patch("core.logging.filters.flask.has_request_context", side_effect=Exception("Test error")):
result = filter.filter(log_record)
assert result is True
assert log_record.tenant_id == ""

View File

@@ -0,0 +1,267 @@
"""Tests for structured JSON formatter."""
import logging
import sys
import orjson
class TestStructuredJSONFormatter:
def test_basic_log_format(self):
from core.logging.structured_formatter import StructuredJSONFormatter
formatter = StructuredJSONFormatter(service_name="test-service")
record = logging.LogRecord(
name="test",
level=logging.INFO,
pathname="test.py",
lineno=42,
msg="Test message",
args=(),
exc_info=None,
)
output = formatter.format(record)
log_dict = orjson.loads(output)
assert log_dict["severity"] == "INFO"
assert log_dict["service"] == "test-service"
assert log_dict["caller"] == "test.py:42"
assert log_dict["message"] == "Test message"
assert "ts" in log_dict
assert log_dict["ts"].endswith("Z")
def test_severity_mapping(self):
from core.logging.structured_formatter import StructuredJSONFormatter
formatter = StructuredJSONFormatter()
test_cases = [
(logging.DEBUG, "DEBUG"),
(logging.INFO, "INFO"),
(logging.WARNING, "WARN"),
(logging.ERROR, "ERROR"),
(logging.CRITICAL, "ERROR"),
]
for level, expected_severity in test_cases:
record = logging.LogRecord(
name="test",
level=level,
pathname="test.py",
lineno=1,
msg="Test",
args=(),
exc_info=None,
)
output = formatter.format(record)
log_dict = orjson.loads(output)
assert log_dict["severity"] == expected_severity, f"Level {level} should map to {expected_severity}"
def test_error_with_stack_trace(self):
from core.logging.structured_formatter import StructuredJSONFormatter
formatter = StructuredJSONFormatter()
try:
raise ValueError("Test error")
except ValueError:
exc_info = sys.exc_info()
record = logging.LogRecord(
name="test",
level=logging.ERROR,
pathname="test.py",
lineno=10,
msg="Error occurred",
args=(),
exc_info=exc_info,
)
output = formatter.format(record)
log_dict = orjson.loads(output)
assert log_dict["severity"] == "ERROR"
assert "stack_trace" in log_dict
assert "ValueError: Test error" in log_dict["stack_trace"]
def test_no_stack_trace_for_info(self):
from core.logging.structured_formatter import StructuredJSONFormatter
formatter = StructuredJSONFormatter()
try:
raise ValueError("Test error")
except ValueError:
exc_info = sys.exc_info()
record = logging.LogRecord(
name="test",
level=logging.INFO,
pathname="test.py",
lineno=10,
msg="Info message",
args=(),
exc_info=exc_info,
)
output = formatter.format(record)
log_dict = orjson.loads(output)
assert "stack_trace" not in log_dict
def test_trace_context_included(self):
from core.logging.structured_formatter import StructuredJSONFormatter
formatter = StructuredJSONFormatter()
record = logging.LogRecord(
name="test",
level=logging.INFO,
pathname="test.py",
lineno=1,
msg="Test",
args=(),
exc_info=None,
)
record.trace_id = "5b8aa5a2d2c872e8321cf37308d69df2"
record.span_id = "051581bf3bb55c45"
output = formatter.format(record)
log_dict = orjson.loads(output)
assert log_dict["trace_id"] == "5b8aa5a2d2c872e8321cf37308d69df2"
assert log_dict["span_id"] == "051581bf3bb55c45"
def test_identity_context_included(self):
from core.logging.structured_formatter import StructuredJSONFormatter
formatter = StructuredJSONFormatter()
record = logging.LogRecord(
name="test",
level=logging.INFO,
pathname="test.py",
lineno=1,
msg="Test",
args=(),
exc_info=None,
)
record.tenant_id = "t-global-corp"
record.user_id = "u-admin-007"
record.user_type = "admin"
output = formatter.format(record)
log_dict = orjson.loads(output)
assert "identity" in log_dict
assert log_dict["identity"]["tenant_id"] == "t-global-corp"
assert log_dict["identity"]["user_id"] == "u-admin-007"
assert log_dict["identity"]["user_type"] == "admin"
def test_no_identity_when_empty(self):
from core.logging.structured_formatter import StructuredJSONFormatter
formatter = StructuredJSONFormatter()
record = logging.LogRecord(
name="test",
level=logging.INFO,
pathname="test.py",
lineno=1,
msg="Test",
args=(),
exc_info=None,
)
output = formatter.format(record)
log_dict = orjson.loads(output)
assert "identity" not in log_dict
def test_attributes_included(self):
from core.logging.structured_formatter import StructuredJSONFormatter
formatter = StructuredJSONFormatter()
record = logging.LogRecord(
name="test",
level=logging.INFO,
pathname="test.py",
lineno=1,
msg="Test",
args=(),
exc_info=None,
)
record.attributes = {"order_id": "ord-123", "amount": 99.99}
output = formatter.format(record)
log_dict = orjson.loads(output)
assert log_dict["attributes"]["order_id"] == "ord-123"
assert log_dict["attributes"]["amount"] == 99.99
def test_message_with_args(self):
from core.logging.structured_formatter import StructuredJSONFormatter
formatter = StructuredJSONFormatter()
record = logging.LogRecord(
name="test",
level=logging.INFO,
pathname="test.py",
lineno=1,
msg="User %s logged in from %s",
args=("john", "192.168.1.1"),
exc_info=None,
)
output = formatter.format(record)
log_dict = orjson.loads(output)
assert log_dict["message"] == "User john logged in from 192.168.1.1"
def test_timestamp_format(self):
from core.logging.structured_formatter import StructuredJSONFormatter
formatter = StructuredJSONFormatter()
record = logging.LogRecord(
name="test",
level=logging.INFO,
pathname="test.py",
lineno=1,
msg="Test",
args=(),
exc_info=None,
)
output = formatter.format(record)
log_dict = orjson.loads(output)
# Verify ISO 8601 format with Z suffix
ts = log_dict["ts"]
assert ts.endswith("Z")
assert "T" in ts
# Should have milliseconds
assert "." in ts
def test_fallback_for_non_serializable_attributes(self):
from core.logging.structured_formatter import StructuredJSONFormatter
formatter = StructuredJSONFormatter()
record = logging.LogRecord(
name="test",
level=logging.INFO,
pathname="test.py",
lineno=1,
msg="Test with non-serializable",
args=(),
exc_info=None,
)
# Sets and arbitrary objects are not serializable by orjson
record.attributes = {"items": {1, 2, 3}, "custom": object()}
# Should not raise, fallback to json.dumps with default=str
output = formatter.format(record)
# Verify it's valid JSON; parse with stdlib json since the fallback path produced it with json.dumps
import json
log_dict = json.loads(output)
assert log_dict["message"] == "Test with non-serializable"
assert "attributes" in log_dict

View File

@@ -0,0 +1,102 @@
"""Tests for trace helper functions."""
import re
from unittest import mock
class TestGetSpanIdFromOtelContext:
def test_returns_none_without_span(self):
from core.helper.trace_id_helper import get_span_id_from_otel_context
with mock.patch("opentelemetry.trace.get_current_span", return_value=None):
result = get_span_id_from_otel_context()
assert result is None
def test_returns_span_id_when_available(self):
from core.helper.trace_id_helper import get_span_id_from_otel_context
mock_span = mock.MagicMock()
mock_context = mock.MagicMock()
mock_context.span_id = 0x051581BF3BB55C45
mock_span.get_span_context.return_value = mock_context
with mock.patch("opentelemetry.trace.get_current_span", return_value=mock_span):
with mock.patch("opentelemetry.trace.span.INVALID_SPAN_ID", 0):
result = get_span_id_from_otel_context()
assert result == "051581bf3bb55c45"
def test_returns_none_on_exception(self):
from core.helper.trace_id_helper import get_span_id_from_otel_context
with mock.patch("opentelemetry.trace.get_current_span", side_effect=Exception("Test error")):
result = get_span_id_from_otel_context()
assert result is None
class TestGenerateTraceparentHeader:
def test_generates_valid_format(self):
from core.helper.trace_id_helper import generate_traceparent_header
with mock.patch("opentelemetry.trace.get_current_span", return_value=None):
result = generate_traceparent_header()
assert result is not None
# Format: 00-{trace_id}-{span_id}-01
parts = result.split("-")
assert len(parts) == 4
assert parts[0] == "00" # version
assert len(parts[1]) == 32 # trace_id (32 hex chars)
assert len(parts[2]) == 16 # span_id (16 hex chars)
assert parts[3] == "01" # flags
def test_uses_otel_context_when_available(self):
from core.helper.trace_id_helper import generate_traceparent_header
mock_span = mock.MagicMock()
mock_context = mock.MagicMock()
mock_context.trace_id = 0x5B8AA5A2D2C872E8321CF37308D69DF2
mock_context.span_id = 0x051581BF3BB55C45
mock_span.get_span_context.return_value = mock_context
with mock.patch("opentelemetry.trace.get_current_span", return_value=mock_span):
with (
mock.patch("opentelemetry.trace.span.INVALID_TRACE_ID", 0),
mock.patch("opentelemetry.trace.span.INVALID_SPAN_ID", 0),
):
result = generate_traceparent_header()
assert result == "00-5b8aa5a2d2c872e8321cf37308d69df2-051581bf3bb55c45-01"
def test_generates_hex_only_values(self):
from core.helper.trace_id_helper import generate_traceparent_header
with mock.patch("opentelemetry.trace.get_current_span", return_value=None):
result = generate_traceparent_header()
parts = result.split("-")
# All parts should be valid hex
assert re.match(r"^[0-9a-f]+$", parts[1])
assert re.match(r"^[0-9a-f]+$", parts[2])
class TestParseTraceparentHeader:
def test_parses_valid_traceparent(self):
from core.helper.trace_id_helper import parse_traceparent_header
traceparent = "00-5b8aa5a2d2c872e8321cf37308d69df2-051581bf3bb55c45-01"
result = parse_traceparent_header(traceparent)
assert result == "5b8aa5a2d2c872e8321cf37308d69df2"
def test_returns_none_for_invalid_format(self):
from core.helper.trace_id_helper import parse_traceparent_header
# Wrong number of parts
assert parse_traceparent_header("00-abc-def") is None
# Wrong trace_id length
assert parse_traceparent_header("00-abc-def-01") is None
def test_returns_none_for_empty_string(self):
from core.helper.trace_id_helper import parse_traceparent_header
assert parse_traceparent_header("") is None

View File

@@ -99,29 +99,20 @@ def test_external_api_json_message_and_bad_request_rewrite():
assert res.get_json()["message"] == "Invalid JSON payload received or JSON payload is empty."
def test_external_api_param_mapping_and_quota_and_exc_info_none():
# Force exc_info() to return (None,None,None) only during request
import libs.external_api as ext
def test_external_api_param_mapping_and_quota():
app = _create_api_app()
client = app.test_client()
orig_exc_info = ext.sys.exc_info
try:
ext.sys.exc_info = lambda: (None, None, None)
# Param errors mapping payload path
res = client.get("/api/param-errors")
assert res.status_code == 400
data = res.get_json()
assert data["code"] == "invalid_param"
assert data["params"] == "field"
app = _create_api_app()
client = app.test_client()
# Param errors mapping payload path
res = client.get("/api/param-errors")
assert res.status_code == 400
data = res.get_json()
assert data["code"] == "invalid_param"
assert data["params"] == "field"
# Quota path — depending on Flask-RESTX internals it may be handled
res = client.get("/api/quota")
assert res.status_code in (400, 429)
finally:
ext.sys.exc_info = orig_exc_info # type: ignore[assignment]
# Quota path — depending on Flask-RESTX internals it may be handled
res = client.get("/api/quota")
assert res.status_code in (400, 429)
def test_unauthorized_and_force_logout_clears_cookies():