mirror of https://github.com/langgenius/dify.git
Merge 2f67c5aa75 into 2c919efa69
This commit is contained in:
commit
4941b2981e
|
|
@ -2,9 +2,11 @@ import logging
|
||||||
import time
|
import time
|
||||||
|
|
||||||
from opentelemetry.trace import get_current_span
|
from opentelemetry.trace import get_current_span
|
||||||
|
from opentelemetry.trace.span import INVALID_SPAN_ID, INVALID_TRACE_ID
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from contexts.wrapper import RecyclableContextVar
|
from contexts.wrapper import RecyclableContextVar
|
||||||
|
from core.logging.context import init_request_context
|
||||||
from dify_app import DifyApp
|
from dify_app import DifyApp
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -25,28 +27,35 @@ def create_flask_app_with_configs() -> DifyApp:
|
||||||
# add before request hook
|
# add before request hook
|
||||||
@dify_app.before_request
|
@dify_app.before_request
|
||||||
def before_request():
|
def before_request():
|
||||||
# add an unique identifier to each request
|
# Initialize logging context for this request
|
||||||
|
init_request_context()
|
||||||
RecyclableContextVar.increment_thread_recycles()
|
RecyclableContextVar.increment_thread_recycles()
|
||||||
|
|
||||||
# add after request hook for injecting X-Trace-Id header from OpenTelemetry span context
|
# add after request hook for injecting trace headers from OpenTelemetry span context
|
||||||
|
# Only adds headers when OTEL is enabled and has valid context
|
||||||
@dify_app.after_request
|
@dify_app.after_request
|
||||||
def add_trace_id_header(response):
|
def add_trace_headers(response):
|
||||||
try:
|
try:
|
||||||
span = get_current_span()
|
span = get_current_span()
|
||||||
ctx = span.get_span_context() if span else None
|
ctx = span.get_span_context() if span else None
|
||||||
if ctx and ctx.is_valid:
|
|
||||||
trace_id_hex = format(ctx.trace_id, "032x")
|
if not ctx or not ctx.is_valid:
|
||||||
# Avoid duplicates if some middleware added it
|
return response
|
||||||
if "X-Trace-Id" not in response.headers:
|
|
||||||
response.headers["X-Trace-Id"] = trace_id_hex
|
# Inject trace headers from OTEL context
|
||||||
|
if ctx.trace_id != INVALID_TRACE_ID and "X-Trace-Id" not in response.headers:
|
||||||
|
response.headers["X-Trace-Id"] = format(ctx.trace_id, "032x")
|
||||||
|
if ctx.span_id != INVALID_SPAN_ID and "X-Span-Id" not in response.headers:
|
||||||
|
response.headers["X-Span-Id"] = format(ctx.span_id, "016x")
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
# Never break the response due to tracing header injection
|
# Never break the response due to tracing header injection
|
||||||
logger.warning("Failed to add trace ID to response header", exc_info=True)
|
logger.warning("Failed to add trace headers to response", exc_info=True)
|
||||||
return response
|
return response
|
||||||
|
|
||||||
# Capture the decorator's return value to avoid pyright reportUnusedFunction
|
# Capture the decorator's return value to avoid pyright reportUnusedFunction
|
||||||
_ = before_request
|
_ = before_request
|
||||||
_ = add_trace_id_header
|
_ = add_trace_headers
|
||||||
|
|
||||||
return dify_app
|
return dify_app
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -587,6 +587,11 @@ class LoggingConfig(BaseSettings):
|
||||||
default="INFO",
|
default="INFO",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
LOG_OUTPUT_FORMAT: Literal["text", "json"] = Field(
|
||||||
|
description="Log output format: 'text' for human-readable, 'json' for structured JSON logs.",
|
||||||
|
default="text",
|
||||||
|
)
|
||||||
|
|
||||||
LOG_FILE: str | None = Field(
|
LOG_FILE: str | None = Field(
|
||||||
description="File path for log output.",
|
description="File path for log output.",
|
||||||
default=None,
|
default=None,
|
||||||
|
|
|
||||||
|
|
@ -88,7 +88,41 @@ def _get_user_provided_host_header(headers: dict | None) -> str | None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _inject_trace_headers(headers: dict | None) -> dict:
|
||||||
|
"""
|
||||||
|
Inject W3C traceparent header for distributed tracing.
|
||||||
|
|
||||||
|
When OTEL is enabled, HTTPXClientInstrumentor handles trace propagation automatically.
|
||||||
|
When OTEL is disabled, we manually inject the traceparent header.
|
||||||
|
"""
|
||||||
|
if headers is None:
|
||||||
|
headers = {}
|
||||||
|
|
||||||
|
# Skip if already present (case-insensitive check)
|
||||||
|
for key in headers:
|
||||||
|
if key.lower() == "traceparent":
|
||||||
|
return headers
|
||||||
|
|
||||||
|
# Skip if OTEL is enabled - HTTPXClientInstrumentor handles this automatically
|
||||||
|
if dify_config.ENABLE_OTEL:
|
||||||
|
return headers
|
||||||
|
|
||||||
|
# Generate and inject traceparent for non-OTEL scenarios
|
||||||
|
try:
|
||||||
|
from core.helper.trace_id_helper import generate_traceparent_header
|
||||||
|
|
||||||
|
traceparent = generate_traceparent_header()
|
||||||
|
if traceparent:
|
||||||
|
headers["traceparent"] = traceparent
|
||||||
|
except Exception:
|
||||||
|
# Silently ignore errors to avoid breaking requests
|
||||||
|
logger.debug("Failed to generate traceparent header", exc_info=True)
|
||||||
|
|
||||||
|
return headers
|
||||||
|
|
||||||
|
|
||||||
def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
|
def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
|
||||||
|
# Convert requests-style allow_redirects to httpx-style follow_redirects
|
||||||
if "allow_redirects" in kwargs:
|
if "allow_redirects" in kwargs:
|
||||||
allow_redirects = kwargs.pop("allow_redirects")
|
allow_redirects = kwargs.pop("allow_redirects")
|
||||||
if "follow_redirects" not in kwargs:
|
if "follow_redirects" not in kwargs:
|
||||||
|
|
@ -106,18 +140,21 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
|
||||||
verify_option = kwargs.pop("ssl_verify", dify_config.HTTP_REQUEST_NODE_SSL_VERIFY)
|
verify_option = kwargs.pop("ssl_verify", dify_config.HTTP_REQUEST_NODE_SSL_VERIFY)
|
||||||
client = _get_ssrf_client(verify_option)
|
client = _get_ssrf_client(verify_option)
|
||||||
|
|
||||||
|
# Inject traceparent header for distributed tracing (when OTEL is not enabled)
|
||||||
|
headers = kwargs.get("headers") or {}
|
||||||
|
headers = _inject_trace_headers(headers)
|
||||||
|
kwargs["headers"] = headers
|
||||||
|
|
||||||
# Preserve user-provided Host header
|
# Preserve user-provided Host header
|
||||||
# When using a forward proxy, httpx may override the Host header based on the URL.
|
# When using a forward proxy, httpx may override the Host header based on the URL.
|
||||||
# We extract and preserve any explicitly set Host header to support virtual hosting.
|
# We extract and preserve any explicitly set Host header to support virtual hosting.
|
||||||
headers = kwargs.get("headers", {})
|
|
||||||
user_provided_host = _get_user_provided_host_header(headers)
|
user_provided_host = _get_user_provided_host_header(headers)
|
||||||
|
|
||||||
retries = 0
|
retries = 0
|
||||||
while retries <= max_retries:
|
while retries <= max_retries:
|
||||||
try:
|
try:
|
||||||
# Build the request manually to preserve the Host header
|
# Preserve the user-provided Host header
|
||||||
# httpx may override the Host header when using a proxy, so we use
|
# httpx may override the Host header when using a proxy
|
||||||
# the request API to explicitly set headers before sending
|
|
||||||
headers = {k: v for k, v in headers.items() if k.lower() != "host"}
|
headers = {k: v for k, v in headers.items() if k.lower() != "host"}
|
||||||
if user_provided_host is not None:
|
if user_provided_host is not None:
|
||||||
headers["host"] = user_provided_host
|
headers["host"] = user_provided_host
|
||||||
|
|
|
||||||
|
|
@ -103,3 +103,60 @@ def parse_traceparent_header(traceparent: str) -> str | None:
|
||||||
if len(parts) == 4 and len(parts[1]) == 32:
|
if len(parts) == 4 and len(parts[1]) == 32:
|
||||||
return parts[1]
|
return parts[1]
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def get_span_id_from_otel_context() -> str | None:
|
||||||
|
"""
|
||||||
|
Retrieve the current span ID from the active OpenTelemetry trace context.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A 16-character hex string representing the span ID, or None if not available.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from opentelemetry.trace import get_current_span
|
||||||
|
from opentelemetry.trace.span import INVALID_SPAN_ID
|
||||||
|
|
||||||
|
span = get_current_span()
|
||||||
|
if not span:
|
||||||
|
return None
|
||||||
|
|
||||||
|
span_context = span.get_span_context()
|
||||||
|
if not span_context or span_context.span_id == INVALID_SPAN_ID:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return f"{span_context.span_id:016x}"
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def generate_traceparent_header() -> str | None:
|
||||||
|
"""
|
||||||
|
Generate a W3C traceparent header from the current context.
|
||||||
|
|
||||||
|
Uses OpenTelemetry context if available, otherwise uses the
|
||||||
|
ContextVar-based trace_id from the logging context.
|
||||||
|
|
||||||
|
Format: {version}-{trace_id}-{span_id}-{flags}
|
||||||
|
Example: 00-5b8aa5a2d2c872e8321cf37308d69df2-051581bf3bb55c45-01
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A valid traceparent header string, or None if generation fails.
|
||||||
|
"""
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
# Try OTEL context first
|
||||||
|
trace_id = get_trace_id_from_otel_context()
|
||||||
|
span_id = get_span_id_from_otel_context()
|
||||||
|
|
||||||
|
if trace_id and span_id:
|
||||||
|
return f"00-{trace_id}-{span_id}-01"
|
||||||
|
|
||||||
|
# Fallback: use ContextVar-based trace_id or generate new one
|
||||||
|
from core.logging.context import get_trace_id as get_logging_trace_id
|
||||||
|
|
||||||
|
trace_id = get_logging_trace_id() or uuid.uuid4().hex
|
||||||
|
|
||||||
|
# Generate a new span_id (16 hex chars)
|
||||||
|
span_id = uuid.uuid4().hex[:16]
|
||||||
|
|
||||||
|
return f"00-{trace_id}-{span_id}-01"
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,20 @@
|
||||||
|
"""Structured logging components for Dify."""
|
||||||
|
|
||||||
|
from core.logging.context import (
|
||||||
|
clear_request_context,
|
||||||
|
get_request_id,
|
||||||
|
get_trace_id,
|
||||||
|
init_request_context,
|
||||||
|
)
|
||||||
|
from core.logging.filters import IdentityContextFilter, TraceContextFilter
|
||||||
|
from core.logging.structured_formatter import StructuredJSONFormatter
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"IdentityContextFilter",
|
||||||
|
"StructuredJSONFormatter",
|
||||||
|
"TraceContextFilter",
|
||||||
|
"clear_request_context",
|
||||||
|
"get_request_id",
|
||||||
|
"get_trace_id",
|
||||||
|
"init_request_context",
|
||||||
|
]
|
||||||
|
|
@ -0,0 +1,35 @@
|
||||||
|
"""Request context for logging - framework agnostic.
|
||||||
|
|
||||||
|
This module provides request-scoped context variables for logging,
|
||||||
|
using Python's contextvars for thread-safe and async-safe storage.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from contextvars import ContextVar
|
||||||
|
|
||||||
|
_request_id: ContextVar[str] = ContextVar("log_request_id", default="")
|
||||||
|
_trace_id: ContextVar[str] = ContextVar("log_trace_id", default="")
|
||||||
|
|
||||||
|
|
||||||
|
def get_request_id() -> str:
|
||||||
|
"""Get current request ID (10 hex chars)."""
|
||||||
|
return _request_id.get()
|
||||||
|
|
||||||
|
|
||||||
|
def get_trace_id() -> str:
|
||||||
|
"""Get fallback trace ID when OTEL is unavailable (32 hex chars)."""
|
||||||
|
return _trace_id.get()
|
||||||
|
|
||||||
|
|
||||||
|
def init_request_context() -> None:
|
||||||
|
"""Initialize request context. Call at start of each request."""
|
||||||
|
req_id = uuid.uuid4().hex[:10]
|
||||||
|
trace_id = uuid.uuid5(uuid.NAMESPACE_DNS, req_id).hex
|
||||||
|
_request_id.set(req_id)
|
||||||
|
_trace_id.set(trace_id)
|
||||||
|
|
||||||
|
|
||||||
|
def clear_request_context() -> None:
|
||||||
|
"""Clear request context. Call at end of request (optional)."""
|
||||||
|
_request_id.set("")
|
||||||
|
_trace_id.set("")
|
||||||
|
|
@ -0,0 +1,94 @@
|
||||||
|
"""Logging filters for structured logging."""
|
||||||
|
|
||||||
|
import contextlib
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import flask
|
||||||
|
|
||||||
|
from core.logging.context import get_request_id, get_trace_id
|
||||||
|
|
||||||
|
|
||||||
|
class TraceContextFilter(logging.Filter):
|
||||||
|
"""
|
||||||
|
Filter that adds trace_id and span_id to log records.
|
||||||
|
Integrates with OpenTelemetry when available, falls back to ContextVar-based trace_id.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def filter(self, record: logging.LogRecord) -> bool:
|
||||||
|
# Get trace context from OpenTelemetry
|
||||||
|
trace_id, span_id = self._get_otel_context()
|
||||||
|
|
||||||
|
# Set trace_id (fallback to ContextVar if no OTEL context)
|
||||||
|
if trace_id:
|
||||||
|
record.trace_id = trace_id
|
||||||
|
else:
|
||||||
|
record.trace_id = get_trace_id()
|
||||||
|
|
||||||
|
record.span_id = span_id or ""
|
||||||
|
|
||||||
|
# For backward compatibility, also set req_id
|
||||||
|
record.req_id = get_request_id()
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _get_otel_context(self) -> tuple[str, str]:
|
||||||
|
"""Extract trace_id and span_id from OpenTelemetry context."""
|
||||||
|
with contextlib.suppress(Exception):
|
||||||
|
from opentelemetry.trace import get_current_span
|
||||||
|
from opentelemetry.trace.span import INVALID_SPAN_ID, INVALID_TRACE_ID
|
||||||
|
|
||||||
|
span = get_current_span()
|
||||||
|
if span and span.get_span_context():
|
||||||
|
ctx = span.get_span_context()
|
||||||
|
if ctx.is_valid and ctx.trace_id != INVALID_TRACE_ID:
|
||||||
|
trace_id = f"{ctx.trace_id:032x}"
|
||||||
|
span_id = f"{ctx.span_id:016x}" if ctx.span_id != INVALID_SPAN_ID else ""
|
||||||
|
return trace_id, span_id
|
||||||
|
return "", ""
|
||||||
|
|
||||||
|
|
||||||
|
class IdentityContextFilter(logging.Filter):
|
||||||
|
"""
|
||||||
|
Filter that adds user identity context to log records.
|
||||||
|
Extracts tenant_id, user_id, and user_type from Flask-Login current_user.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def filter(self, record: logging.LogRecord) -> bool:
|
||||||
|
identity = self._extract_identity()
|
||||||
|
record.tenant_id = identity.get("tenant_id", "")
|
||||||
|
record.user_id = identity.get("user_id", "")
|
||||||
|
record.user_type = identity.get("user_type", "")
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _extract_identity(self) -> dict[str, str]:
|
||||||
|
"""Extract identity from current_user if in request context."""
|
||||||
|
try:
|
||||||
|
if not flask.has_request_context():
|
||||||
|
return {}
|
||||||
|
from flask_login import current_user
|
||||||
|
|
||||||
|
# Check if user is authenticated using the proxy
|
||||||
|
if not current_user.is_authenticated:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
# Access the underlying user object
|
||||||
|
user = current_user
|
||||||
|
|
||||||
|
from models import Account
|
||||||
|
from models.model import EndUser
|
||||||
|
|
||||||
|
identity: dict[str, str] = {}
|
||||||
|
|
||||||
|
if isinstance(user, Account):
|
||||||
|
if user.current_tenant_id:
|
||||||
|
identity["tenant_id"] = user.current_tenant_id
|
||||||
|
identity["user_id"] = user.id
|
||||||
|
identity["user_type"] = "account"
|
||||||
|
elif isinstance(user, EndUser):
|
||||||
|
identity["tenant_id"] = user.tenant_id
|
||||||
|
identity["user_id"] = user.id
|
||||||
|
identity["user_type"] = user.type or "end_user"
|
||||||
|
|
||||||
|
return identity
|
||||||
|
except Exception:
|
||||||
|
return {}
|
||||||
|
|
@ -0,0 +1,107 @@
|
||||||
|
"""Structured JSON log formatter for Dify."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import traceback
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import orjson
|
||||||
|
|
||||||
|
from configs import dify_config
|
||||||
|
|
||||||
|
|
||||||
|
class StructuredJSONFormatter(logging.Formatter):
|
||||||
|
"""
|
||||||
|
JSON log formatter following the specified schema:
|
||||||
|
{
|
||||||
|
"ts": "ISO 8601 UTC",
|
||||||
|
"severity": "INFO|ERROR|WARN|DEBUG",
|
||||||
|
"service": "service name",
|
||||||
|
"caller": "file:line",
|
||||||
|
"trace_id": "hex 32",
|
||||||
|
"span_id": "hex 16",
|
||||||
|
"identity": { "tenant_id", "user_id", "user_type" },
|
||||||
|
"message": "log message",
|
||||||
|
"attributes": { ... },
|
||||||
|
"stack_trace": "..."
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
SEVERITY_MAP: dict[int, str] = {
|
||||||
|
logging.DEBUG: "DEBUG",
|
||||||
|
logging.INFO: "INFO",
|
||||||
|
logging.WARNING: "WARN",
|
||||||
|
logging.ERROR: "ERROR",
|
||||||
|
logging.CRITICAL: "ERROR",
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(self, service_name: str | None = None):
|
||||||
|
super().__init__()
|
||||||
|
self._service_name = service_name or dify_config.APPLICATION_NAME
|
||||||
|
|
||||||
|
def format(self, record: logging.LogRecord) -> str:
|
||||||
|
log_dict = self._build_log_dict(record)
|
||||||
|
try:
|
||||||
|
return orjson.dumps(log_dict).decode("utf-8")
|
||||||
|
except TypeError:
|
||||||
|
# Fallback: convert non-serializable objects to string
|
||||||
|
import json
|
||||||
|
|
||||||
|
return json.dumps(log_dict, default=str, ensure_ascii=False)
|
||||||
|
|
||||||
|
def _build_log_dict(self, record: logging.LogRecord) -> dict[str, Any]:
|
||||||
|
# Core fields
|
||||||
|
log_dict: dict[str, Any] = {
|
||||||
|
"ts": datetime.now(UTC).isoformat(timespec="milliseconds").replace("+00:00", "Z"),
|
||||||
|
"severity": self.SEVERITY_MAP.get(record.levelno, "INFO"),
|
||||||
|
"service": self._service_name,
|
||||||
|
"caller": f"{record.filename}:{record.lineno}",
|
||||||
|
"message": record.getMessage(),
|
||||||
|
}
|
||||||
|
|
||||||
|
# Trace context (from TraceContextFilter)
|
||||||
|
trace_id = getattr(record, "trace_id", "")
|
||||||
|
span_id = getattr(record, "span_id", "")
|
||||||
|
|
||||||
|
if trace_id:
|
||||||
|
log_dict["trace_id"] = trace_id
|
||||||
|
if span_id:
|
||||||
|
log_dict["span_id"] = span_id
|
||||||
|
|
||||||
|
# Identity context (from IdentityContextFilter)
|
||||||
|
identity = self._extract_identity(record)
|
||||||
|
if identity:
|
||||||
|
log_dict["identity"] = identity
|
||||||
|
|
||||||
|
# Dynamic attributes
|
||||||
|
attributes = getattr(record, "attributes", None)
|
||||||
|
if attributes:
|
||||||
|
log_dict["attributes"] = attributes
|
||||||
|
|
||||||
|
# Stack trace for errors with exceptions
|
||||||
|
if record.exc_info and record.levelno >= logging.ERROR:
|
||||||
|
log_dict["stack_trace"] = self._format_exception(record.exc_info)
|
||||||
|
|
||||||
|
return log_dict
|
||||||
|
|
||||||
|
def _extract_identity(self, record: logging.LogRecord) -> dict[str, str] | None:
|
||||||
|
tenant_id = getattr(record, "tenant_id", None)
|
||||||
|
user_id = getattr(record, "user_id", None)
|
||||||
|
user_type = getattr(record, "user_type", None)
|
||||||
|
|
||||||
|
if not any([tenant_id, user_id, user_type]):
|
||||||
|
return None
|
||||||
|
|
||||||
|
identity: dict[str, str] = {}
|
||||||
|
if tenant_id:
|
||||||
|
identity["tenant_id"] = tenant_id
|
||||||
|
if user_id:
|
||||||
|
identity["user_id"] = user_id
|
||||||
|
if user_type:
|
||||||
|
identity["user_type"] = user_type
|
||||||
|
return identity
|
||||||
|
|
||||||
|
def _format_exception(self, exc_info: tuple[Any, ...]) -> str:
|
||||||
|
if exc_info and exc_info[0] is not None:
|
||||||
|
return "".join(traceback.format_exception(*exc_info))
|
||||||
|
return ""
|
||||||
|
|
@ -103,6 +103,9 @@ class BasePluginClient:
|
||||||
prepared_headers["X-Api-Key"] = dify_config.PLUGIN_DAEMON_KEY
|
prepared_headers["X-Api-Key"] = dify_config.PLUGIN_DAEMON_KEY
|
||||||
prepared_headers.setdefault("Accept-Encoding", "gzip, deflate, br")
|
prepared_headers.setdefault("Accept-Encoding", "gzip, deflate, br")
|
||||||
|
|
||||||
|
# Inject traceparent header for distributed tracing
|
||||||
|
self._inject_trace_headers(prepared_headers)
|
||||||
|
|
||||||
prepared_data: bytes | dict[str, Any] | str | None = (
|
prepared_data: bytes | dict[str, Any] | str | None = (
|
||||||
data if isinstance(data, (bytes, str, dict)) or data is None else None
|
data if isinstance(data, (bytes, str, dict)) or data is None else None
|
||||||
)
|
)
|
||||||
|
|
@ -114,6 +117,28 @@ class BasePluginClient:
|
||||||
|
|
||||||
return str(url), prepared_headers, prepared_data, params, files
|
return str(url), prepared_headers, prepared_data, params, files
|
||||||
|
|
||||||
|
def _inject_trace_headers(self, headers: dict[str, str]) -> None:
|
||||||
|
"""
|
||||||
|
Inject W3C traceparent header for distributed tracing.
|
||||||
|
|
||||||
|
This ensures trace context is propagated to plugin daemon even if
|
||||||
|
HTTPXClientInstrumentor doesn't cover module-level httpx functions.
|
||||||
|
"""
|
||||||
|
import contextlib
|
||||||
|
|
||||||
|
# Skip if already present (case-insensitive check)
|
||||||
|
for key in headers:
|
||||||
|
if key.lower() == "traceparent":
|
||||||
|
return
|
||||||
|
|
||||||
|
# Inject traceparent - works as fallback when OTEL instrumentation doesn't cover this call
|
||||||
|
with contextlib.suppress(Exception):
|
||||||
|
from core.helper.trace_id_helper import generate_traceparent_header
|
||||||
|
|
||||||
|
traceparent = generate_traceparent_header()
|
||||||
|
if traceparent:
|
||||||
|
headers["traceparent"] = traceparent
|
||||||
|
|
||||||
def _stream_request(
|
def _stream_request(
|
||||||
self,
|
self,
|
||||||
method: str,
|
method: str,
|
||||||
|
|
|
||||||
|
|
@ -47,7 +47,11 @@ def _get_celery_ssl_options() -> dict[str, Any] | None:
|
||||||
def init_app(app: DifyApp) -> Celery:
|
def init_app(app: DifyApp) -> Celery:
|
||||||
class FlaskTask(Task):
|
class FlaskTask(Task):
|
||||||
def __call__(self, *args: object, **kwargs: object) -> object:
|
def __call__(self, *args: object, **kwargs: object) -> object:
|
||||||
|
from core.logging.context import init_request_context
|
||||||
|
|
||||||
with app.app_context():
|
with app.app_context():
|
||||||
|
# Initialize logging context for this task (similar to before_request in Flask)
|
||||||
|
init_request_context()
|
||||||
return self.run(*args, **kwargs)
|
return self.run(*args, **kwargs)
|
||||||
|
|
||||||
broker_transport_options = {}
|
broker_transport_options = {}
|
||||||
|
|
|
||||||
|
|
@ -1,18 +1,19 @@
|
||||||
|
"""Logging extension for Dify Flask application."""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import uuid
|
|
||||||
from logging.handlers import RotatingFileHandler
|
from logging.handlers import RotatingFileHandler
|
||||||
|
|
||||||
import flask
|
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from core.helper.trace_id_helper import get_trace_id_from_otel_context
|
|
||||||
from dify_app import DifyApp
|
from dify_app import DifyApp
|
||||||
|
|
||||||
|
|
||||||
def init_app(app: DifyApp):
|
def init_app(app: DifyApp):
|
||||||
|
"""Initialize logging with support for text or JSON format."""
|
||||||
log_handlers: list[logging.Handler] = []
|
log_handlers: list[logging.Handler] = []
|
||||||
|
|
||||||
|
# File handler
|
||||||
log_file = dify_config.LOG_FILE
|
log_file = dify_config.LOG_FILE
|
||||||
if log_file:
|
if log_file:
|
||||||
log_dir = os.path.dirname(log_file)
|
log_dir = os.path.dirname(log_file)
|
||||||
|
|
@ -25,27 +26,53 @@ def init_app(app: DifyApp):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Always add StreamHandler to log to console
|
# Console handler
|
||||||
sh = logging.StreamHandler(sys.stdout)
|
sh = logging.StreamHandler(sys.stdout)
|
||||||
log_handlers.append(sh)
|
log_handlers.append(sh)
|
||||||
|
|
||||||
# Apply RequestIdFilter to all handlers
|
# Apply filters to all handlers
|
||||||
for handler in log_handlers:
|
from core.logging.filters import IdentityContextFilter, TraceContextFilter
|
||||||
handler.addFilter(RequestIdFilter())
|
|
||||||
|
|
||||||
|
for handler in log_handlers:
|
||||||
|
handler.addFilter(TraceContextFilter())
|
||||||
|
handler.addFilter(IdentityContextFilter())
|
||||||
|
|
||||||
|
# Configure formatter based on format type
|
||||||
|
formatter = _create_formatter()
|
||||||
|
for handler in log_handlers:
|
||||||
|
handler.setFormatter(formatter)
|
||||||
|
|
||||||
|
# Configure root logger
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
level=dify_config.LOG_LEVEL,
|
level=dify_config.LOG_LEVEL,
|
||||||
format=dify_config.LOG_FORMAT,
|
|
||||||
datefmt=dify_config.LOG_DATEFORMAT,
|
|
||||||
handlers=log_handlers,
|
handlers=log_handlers,
|
||||||
force=True,
|
force=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Apply RequestIdFormatter to all handlers
|
|
||||||
apply_request_id_formatter()
|
|
||||||
|
|
||||||
# Disable propagation for noisy loggers to avoid duplicate logs
|
# Disable propagation for noisy loggers to avoid duplicate logs
|
||||||
logging.getLogger("sqlalchemy.engine").propagate = False
|
logging.getLogger("sqlalchemy.engine").propagate = False
|
||||||
|
|
||||||
|
# Apply timezone if specified (only for text format)
|
||||||
|
if dify_config.LOG_OUTPUT_FORMAT == "text":
|
||||||
|
_apply_timezone(log_handlers)
|
||||||
|
|
||||||
|
|
||||||
|
def _create_formatter() -> logging.Formatter:
|
||||||
|
"""Create appropriate formatter based on configuration."""
|
||||||
|
if dify_config.LOG_OUTPUT_FORMAT == "json":
|
||||||
|
from core.logging.structured_formatter import StructuredJSONFormatter
|
||||||
|
|
||||||
|
return StructuredJSONFormatter()
|
||||||
|
else:
|
||||||
|
# Text format - use existing pattern with backward compatible formatter
|
||||||
|
return _TextFormatter(
|
||||||
|
fmt=dify_config.LOG_FORMAT,
|
||||||
|
datefmt=dify_config.LOG_DATEFORMAT,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _apply_timezone(handlers: list[logging.Handler]):
|
||||||
|
"""Apply timezone conversion to text formatters."""
|
||||||
log_tz = dify_config.LOG_TZ
|
log_tz = dify_config.LOG_TZ
|
||||||
if log_tz:
|
if log_tz:
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
@ -57,34 +84,51 @@ def init_app(app: DifyApp):
|
||||||
def time_converter(seconds):
|
def time_converter(seconds):
|
||||||
return datetime.fromtimestamp(seconds, tz=timezone).timetuple()
|
return datetime.fromtimestamp(seconds, tz=timezone).timetuple()
|
||||||
|
|
||||||
for handler in logging.root.handlers:
|
for handler in handlers:
|
||||||
if handler.formatter:
|
if handler.formatter:
|
||||||
handler.formatter.converter = time_converter
|
handler.formatter.converter = time_converter # type: ignore[attr-defined]
|
||||||
|
|
||||||
|
|
||||||
def get_request_id():
|
class _TextFormatter(logging.Formatter):
|
||||||
if getattr(flask.g, "request_id", None):
|
"""Text formatter that ensures trace_id and req_id are always present."""
|
||||||
return flask.g.request_id
|
|
||||||
|
|
||||||
new_uuid = uuid.uuid4().hex[:10]
|
def format(self, record: logging.LogRecord) -> str:
|
||||||
flask.g.request_id = new_uuid
|
if not hasattr(record, "req_id"):
|
||||||
|
record.req_id = ""
|
||||||
return new_uuid
|
if not hasattr(record, "trace_id"):
|
||||||
|
record.trace_id = ""
|
||||||
|
if not hasattr(record, "span_id"):
|
||||||
|
record.span_id = ""
|
||||||
|
return super().format(record)
|
||||||
|
|
||||||
|
|
||||||
|
def get_request_id() -> str:
|
||||||
|
"""Get request ID for current request context.
|
||||||
|
|
||||||
|
Deprecated: Use core.logging.context.get_request_id() directly.
|
||||||
|
"""
|
||||||
|
from core.logging.context import get_request_id as _get_request_id
|
||||||
|
|
||||||
|
return _get_request_id()
|
||||||
|
|
||||||
|
|
||||||
|
# Backward compatibility aliases
|
||||||
class RequestIdFilter(logging.Filter):
|
class RequestIdFilter(logging.Filter):
|
||||||
# This is a logging filter that makes the request ID available for use in
|
"""Deprecated: Use TraceContextFilter from core.logging.filters instead."""
|
||||||
# the logging format. Note that we're checking if we're in a request
|
|
||||||
# context, as we may want to log things before Flask is fully loaded.
|
def filter(self, record: logging.LogRecord) -> bool:
|
||||||
def filter(self, record):
|
from core.logging.context import get_request_id as _get_request_id
|
||||||
trace_id = get_trace_id_from_otel_context() or ""
|
from core.logging.context import get_trace_id as _get_trace_id
|
||||||
record.req_id = get_request_id() if flask.has_request_context() else ""
|
|
||||||
record.trace_id = trace_id
|
record.req_id = _get_request_id()
|
||||||
|
record.trace_id = _get_trace_id()
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
class RequestIdFormatter(logging.Formatter):
|
class RequestIdFormatter(logging.Formatter):
|
||||||
def format(self, record):
|
"""Deprecated: Use _TextFormatter instead."""
|
||||||
|
|
||||||
|
def format(self, record: logging.LogRecord) -> str:
|
||||||
if not hasattr(record, "req_id"):
|
if not hasattr(record, "req_id"):
|
||||||
record.req_id = ""
|
record.req_id = ""
|
||||||
if not hasattr(record, "trace_id"):
|
if not hasattr(record, "trace_id"):
|
||||||
|
|
@ -93,6 +137,7 @@ class RequestIdFormatter(logging.Formatter):
|
||||||
|
|
||||||
|
|
||||||
def apply_request_id_formatter():
|
def apply_request_id_formatter():
|
||||||
|
"""Deprecated: Formatter is now applied in init_app."""
|
||||||
for handler in logging.root.handlers:
|
for handler in logging.root.handlers:
|
||||||
if handler.formatter:
|
if handler.formatter:
|
||||||
handler.formatter = RequestIdFormatter(dify_config.LOG_FORMAT, dify_config.LOG_DATEFORMAT)
|
handler.formatter = RequestIdFormatter(dify_config.LOG_FORMAT, dify_config.LOG_DATEFORMAT)
|
||||||
|
|
|
||||||
|
|
@ -19,11 +19,29 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class ExceptionLoggingHandler(logging.Handler):
|
class ExceptionLoggingHandler(logging.Handler):
|
||||||
|
"""
|
||||||
|
Handler that records exceptions to the current OpenTelemetry span.
|
||||||
|
|
||||||
|
Unlike creating a new span, this records exceptions on the existing span
|
||||||
|
to maintain trace context consistency throughout the request lifecycle.
|
||||||
|
"""
|
||||||
|
|
||||||
def emit(self, record: logging.LogRecord):
|
def emit(self, record: logging.LogRecord):
|
||||||
with contextlib.suppress(Exception):
|
with contextlib.suppress(Exception):
|
||||||
if record.exc_info:
|
if not record.exc_info:
|
||||||
tracer = get_tracer_provider().get_tracer("dify.exception.logging")
|
return
|
||||||
with tracer.start_as_current_span(
|
|
||||||
|
from opentelemetry.trace import get_current_span
|
||||||
|
|
||||||
|
span = get_current_span()
|
||||||
|
if not span or not span.is_recording():
|
||||||
|
return
|
||||||
|
|
||||||
|
# Record exception on the current span instead of creating a new one
|
||||||
|
span.set_status(StatusCode.ERROR, record.getMessage())
|
||||||
|
|
||||||
|
# Add log context as span events/attributes
|
||||||
|
span.add_event(
|
||||||
"log.exception",
|
"log.exception",
|
||||||
attributes={
|
attributes={
|
||||||
"log.level": record.levelname,
|
"log.level": record.levelname,
|
||||||
|
|
@ -32,11 +50,10 @@ class ExceptionLoggingHandler(logging.Handler):
|
||||||
"log.file.path": record.pathname,
|
"log.file.path": record.pathname,
|
||||||
"log.file.line": record.lineno,
|
"log.file.line": record.lineno,
|
||||||
},
|
},
|
||||||
) as span:
|
)
|
||||||
span.set_status(StatusCode.ERROR)
|
|
||||||
if record.exc_info[1]:
|
if record.exc_info[1]:
|
||||||
span.record_exception(record.exc_info[1])
|
span.record_exception(record.exc_info[1])
|
||||||
span.set_attribute("exception.message", str(record.exc_info[1]))
|
|
||||||
if record.exc_info[0]:
|
if record.exc_info[0]:
|
||||||
span.set_attribute("exception.type", record.exc_info[0].__name__)
|
span.set_attribute("exception.type", record.exc_info[0].__name__)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,4 @@
|
||||||
import re
|
import re
|
||||||
import sys
|
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
|
@ -109,11 +108,8 @@ def register_external_error_handlers(api: Api):
|
||||||
data.setdefault("code", "unknown")
|
data.setdefault("code", "unknown")
|
||||||
data.setdefault("status", status_code)
|
data.setdefault("status", status_code)
|
||||||
|
|
||||||
# Log stack
|
# Note: Exception logging is handled by Flask/Flask-RESTX framework automatically
|
||||||
exc_info: Any = sys.exc_info()
|
# Explicit log_exception call removed to avoid duplicate log entries
|
||||||
if exc_info[1] is None:
|
|
||||||
exc_info = (None, None, None)
|
|
||||||
current_app.log_exception(exc_info)
|
|
||||||
|
|
||||||
return data, status_code
|
return data, status_code
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -14,12 +14,12 @@ def test_successful_request(mock_get_client):
|
||||||
mock_client = MagicMock()
|
mock_client = MagicMock()
|
||||||
mock_response = MagicMock()
|
mock_response = MagicMock()
|
||||||
mock_response.status_code = 200
|
mock_response.status_code = 200
|
||||||
mock_client.send.return_value = mock_response
|
|
||||||
mock_client.request.return_value = mock_response
|
mock_client.request.return_value = mock_response
|
||||||
mock_get_client.return_value = mock_client
|
mock_get_client.return_value = mock_client
|
||||||
|
|
||||||
response = make_request("GET", "http://example.com")
|
response = make_request("GET", "http://example.com")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
|
mock_client.request.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
@patch("core.helper.ssrf_proxy._get_ssrf_client")
|
@patch("core.helper.ssrf_proxy._get_ssrf_client")
|
||||||
|
|
@ -27,7 +27,6 @@ def test_retry_exceed_max_retries(mock_get_client):
|
||||||
mock_client = MagicMock()
|
mock_client = MagicMock()
|
||||||
mock_response = MagicMock()
|
mock_response = MagicMock()
|
||||||
mock_response.status_code = 500
|
mock_response.status_code = 500
|
||||||
mock_client.send.return_value = mock_response
|
|
||||||
mock_client.request.return_value = mock_response
|
mock_client.request.return_value = mock_response
|
||||||
mock_get_client.return_value = mock_client
|
mock_get_client.return_value = mock_client
|
||||||
|
|
||||||
|
|
@ -72,34 +71,12 @@ class TestGetUserProvidedHostHeader:
|
||||||
assert result in ("first.com", "second.com")
|
assert result in ("first.com", "second.com")
|
||||||
|
|
||||||
|
|
||||||
@patch("core.helper.ssrf_proxy._get_ssrf_client")
|
|
||||||
def test_host_header_preservation_without_user_header(mock_get_client):
|
|
||||||
"""Test that when no Host header is provided, the default behavior is maintained."""
|
|
||||||
mock_client = MagicMock()
|
|
||||||
mock_request = MagicMock()
|
|
||||||
mock_request.headers = {}
|
|
||||||
mock_response = MagicMock()
|
|
||||||
mock_response.status_code = 200
|
|
||||||
mock_client.send.return_value = mock_response
|
|
||||||
mock_client.request.return_value = mock_response
|
|
||||||
mock_get_client.return_value = mock_client
|
|
||||||
|
|
||||||
response = make_request("GET", "http://example.com")
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
# Host should not be set if not provided by user
|
|
||||||
assert "Host" not in mock_request.headers or mock_request.headers.get("Host") is None
|
|
||||||
|
|
||||||
|
|
||||||
@patch("core.helper.ssrf_proxy._get_ssrf_client")
|
@patch("core.helper.ssrf_proxy._get_ssrf_client")
|
||||||
def test_host_header_preservation_with_user_header(mock_get_client):
|
def test_host_header_preservation_with_user_header(mock_get_client):
|
||||||
"""Test that user-provided Host header is preserved in the request."""
|
"""Test that user-provided Host header is preserved in the request."""
|
||||||
mock_client = MagicMock()
|
mock_client = MagicMock()
|
||||||
mock_request = MagicMock()
|
|
||||||
mock_request.headers = {}
|
|
||||||
mock_response = MagicMock()
|
mock_response = MagicMock()
|
||||||
mock_response.status_code = 200
|
mock_response.status_code = 200
|
||||||
mock_client.send.return_value = mock_response
|
|
||||||
mock_client.request.return_value = mock_response
|
mock_client.request.return_value = mock_response
|
||||||
mock_get_client.return_value = mock_client
|
mock_get_client.return_value = mock_client
|
||||||
|
|
||||||
|
|
@ -107,3 +84,93 @@ def test_host_header_preservation_with_user_header(mock_get_client):
|
||||||
response = make_request("GET", "http://example.com", headers={"Host": custom_host})
|
response = make_request("GET", "http://example.com", headers={"Host": custom_host})
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
|
# Verify client.request was called with the host header preserved (lowercase)
|
||||||
|
call_kwargs = mock_client.request.call_args.kwargs
|
||||||
|
assert call_kwargs["headers"]["host"] == custom_host
|
||||||
|
|
||||||
|
|
||||||
|
@patch("core.helper.ssrf_proxy._get_ssrf_client")
|
||||||
|
@pytest.mark.parametrize("host_key", ["host", "HOST", "Host"])
|
||||||
|
def test_host_header_preservation_case_insensitive(mock_get_client, host_key):
|
||||||
|
"""Test that Host header is preserved regardless of case."""
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.status_code = 200
|
||||||
|
mock_client.request.return_value = mock_response
|
||||||
|
mock_get_client.return_value = mock_client
|
||||||
|
|
||||||
|
response = make_request("GET", "http://example.com", headers={host_key: "api.example.com"})
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
# Host header should be normalized to lowercase "host"
|
||||||
|
call_kwargs = mock_client.request.call_args.kwargs
|
||||||
|
assert call_kwargs["headers"]["host"] == "api.example.com"
|
||||||
|
|
||||||
|
|
||||||
|
class TestFollowRedirectsParameter:
|
||||||
|
"""Tests for follow_redirects parameter handling.
|
||||||
|
|
||||||
|
These tests verify that follow_redirects is correctly passed to client.request().
|
||||||
|
"""
|
||||||
|
|
||||||
|
@patch("core.helper.ssrf_proxy._get_ssrf_client")
|
||||||
|
def test_follow_redirects_passed_to_request(self, mock_get_client):
|
||||||
|
"""Verify follow_redirects IS passed to client.request()."""
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.status_code = 200
|
||||||
|
mock_client.request.return_value = mock_response
|
||||||
|
mock_get_client.return_value = mock_client
|
||||||
|
|
||||||
|
make_request("GET", "http://example.com", follow_redirects=True)
|
||||||
|
|
||||||
|
# Verify follow_redirects was passed to request
|
||||||
|
call_kwargs = mock_client.request.call_args.kwargs
|
||||||
|
assert call_kwargs.get("follow_redirects") is True
|
||||||
|
|
||||||
|
@patch("core.helper.ssrf_proxy._get_ssrf_client")
|
||||||
|
def test_allow_redirects_converted_to_follow_redirects(self, mock_get_client):
|
||||||
|
"""Verify allow_redirects (requests-style) is converted to follow_redirects (httpx-style)."""
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.status_code = 200
|
||||||
|
mock_client.request.return_value = mock_response
|
||||||
|
mock_get_client.return_value = mock_client
|
||||||
|
|
||||||
|
# Use allow_redirects (requests-style parameter)
|
||||||
|
make_request("GET", "http://example.com", allow_redirects=True)
|
||||||
|
|
||||||
|
# Verify it was converted to follow_redirects
|
||||||
|
call_kwargs = mock_client.request.call_args.kwargs
|
||||||
|
assert call_kwargs.get("follow_redirects") is True
|
||||||
|
assert "allow_redirects" not in call_kwargs
|
||||||
|
|
||||||
|
@patch("core.helper.ssrf_proxy._get_ssrf_client")
|
||||||
|
def test_follow_redirects_not_set_when_not_specified(self, mock_get_client):
|
||||||
|
"""Verify follow_redirects is not in kwargs when not specified (httpx default behavior)."""
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.status_code = 200
|
||||||
|
mock_client.request.return_value = mock_response
|
||||||
|
mock_get_client.return_value = mock_client
|
||||||
|
|
||||||
|
make_request("GET", "http://example.com")
|
||||||
|
|
||||||
|
# follow_redirects should not be in kwargs, letting httpx use its default
|
||||||
|
call_kwargs = mock_client.request.call_args.kwargs
|
||||||
|
assert "follow_redirects" not in call_kwargs
|
||||||
|
|
||||||
|
@patch("core.helper.ssrf_proxy._get_ssrf_client")
|
||||||
|
def test_follow_redirects_takes_precedence_over_allow_redirects(self, mock_get_client):
|
||||||
|
"""Verify follow_redirects takes precedence when both are specified."""
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.status_code = 200
|
||||||
|
mock_client.request.return_value = mock_response
|
||||||
|
mock_get_client.return_value = mock_client
|
||||||
|
|
||||||
|
# Both specified - follow_redirects should take precedence
|
||||||
|
make_request("GET", "http://example.com", allow_redirects=False, follow_redirects=True)
|
||||||
|
|
||||||
|
call_kwargs = mock_client.request.call_args.kwargs
|
||||||
|
assert call_kwargs.get("follow_redirects") is True
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,79 @@
|
||||||
|
"""Tests for logging context module."""
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from core.logging.context import (
|
||||||
|
clear_request_context,
|
||||||
|
get_request_id,
|
||||||
|
get_trace_id,
|
||||||
|
init_request_context,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestLoggingContext:
|
||||||
|
"""Tests for the logging context functions."""
|
||||||
|
|
||||||
|
def test_init_creates_request_id(self):
|
||||||
|
"""init_request_context should create a 10-char request ID."""
|
||||||
|
init_request_context()
|
||||||
|
request_id = get_request_id()
|
||||||
|
assert len(request_id) == 10
|
||||||
|
assert all(c in "0123456789abcdef" for c in request_id)
|
||||||
|
|
||||||
|
def test_init_creates_trace_id(self):
|
||||||
|
"""init_request_context should create a 32-char trace ID."""
|
||||||
|
init_request_context()
|
||||||
|
trace_id = get_trace_id()
|
||||||
|
assert len(trace_id) == 32
|
||||||
|
assert all(c in "0123456789abcdef" for c in trace_id)
|
||||||
|
|
||||||
|
def test_trace_id_derived_from_request_id(self):
|
||||||
|
"""trace_id should be deterministically derived from request_id."""
|
||||||
|
init_request_context()
|
||||||
|
request_id = get_request_id()
|
||||||
|
trace_id = get_trace_id()
|
||||||
|
|
||||||
|
# Verify trace_id is derived using uuid5
|
||||||
|
expected_trace = uuid.uuid5(uuid.NAMESPACE_DNS, request_id).hex
|
||||||
|
assert trace_id == expected_trace
|
||||||
|
|
||||||
|
def test_clear_resets_context(self):
|
||||||
|
"""clear_request_context should reset both IDs to empty strings."""
|
||||||
|
init_request_context()
|
||||||
|
assert get_request_id() != ""
|
||||||
|
assert get_trace_id() != ""
|
||||||
|
|
||||||
|
clear_request_context()
|
||||||
|
assert get_request_id() == ""
|
||||||
|
assert get_trace_id() == ""
|
||||||
|
|
||||||
|
def test_default_values_are_empty(self):
|
||||||
|
"""Default values should be empty strings before init."""
|
||||||
|
clear_request_context()
|
||||||
|
assert get_request_id() == ""
|
||||||
|
assert get_trace_id() == ""
|
||||||
|
|
||||||
|
def test_multiple_inits_create_different_ids(self):
|
||||||
|
"""Each init should create new unique IDs."""
|
||||||
|
init_request_context()
|
||||||
|
first_request_id = get_request_id()
|
||||||
|
first_trace_id = get_trace_id()
|
||||||
|
|
||||||
|
init_request_context()
|
||||||
|
second_request_id = get_request_id()
|
||||||
|
second_trace_id = get_trace_id()
|
||||||
|
|
||||||
|
assert first_request_id != second_request_id
|
||||||
|
assert first_trace_id != second_trace_id
|
||||||
|
|
||||||
|
def test_context_isolation(self):
|
||||||
|
"""Context should be isolated per-call (no thread leakage in same thread)."""
|
||||||
|
init_request_context()
|
||||||
|
id1 = get_request_id()
|
||||||
|
|
||||||
|
# Simulate another request
|
||||||
|
init_request_context()
|
||||||
|
id2 = get_request_id()
|
||||||
|
|
||||||
|
# IDs should be different
|
||||||
|
assert id1 != id2
|
||||||
|
|
@ -0,0 +1,114 @@
|
||||||
|
"""Tests for logging filters."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from unittest import mock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def log_record():
|
||||||
|
return logging.LogRecord(
|
||||||
|
name="test",
|
||||||
|
level=logging.INFO,
|
||||||
|
pathname="",
|
||||||
|
lineno=0,
|
||||||
|
msg="test",
|
||||||
|
args=(),
|
||||||
|
exc_info=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestTraceContextFilter:
|
||||||
|
def test_sets_empty_trace_id_without_context(self, log_record):
|
||||||
|
from core.logging.context import clear_request_context
|
||||||
|
from core.logging.filters import TraceContextFilter
|
||||||
|
|
||||||
|
# Ensure no context is set
|
||||||
|
clear_request_context()
|
||||||
|
|
||||||
|
filter = TraceContextFilter()
|
||||||
|
result = filter.filter(log_record)
|
||||||
|
|
||||||
|
assert result is True
|
||||||
|
assert hasattr(log_record, "trace_id")
|
||||||
|
assert hasattr(log_record, "span_id")
|
||||||
|
assert hasattr(log_record, "req_id")
|
||||||
|
# Without context, IDs should be empty
|
||||||
|
assert log_record.trace_id == ""
|
||||||
|
assert log_record.req_id == ""
|
||||||
|
|
||||||
|
def test_sets_trace_id_from_context(self, log_record):
|
||||||
|
"""Test that trace_id and req_id are set from ContextVar when initialized."""
|
||||||
|
from core.logging.context import init_request_context
|
||||||
|
from core.logging.filters import TraceContextFilter
|
||||||
|
|
||||||
|
# Initialize context (no Flask needed!)
|
||||||
|
init_request_context()
|
||||||
|
|
||||||
|
filter = TraceContextFilter()
|
||||||
|
filter.filter(log_record)
|
||||||
|
|
||||||
|
# With context initialized, IDs should be set
|
||||||
|
assert log_record.trace_id != ""
|
||||||
|
assert len(log_record.trace_id) == 32
|
||||||
|
assert log_record.req_id != ""
|
||||||
|
assert len(log_record.req_id) == 10
|
||||||
|
|
||||||
|
def test_filter_always_returns_true(self, log_record):
|
||||||
|
from core.logging.filters import TraceContextFilter
|
||||||
|
|
||||||
|
filter = TraceContextFilter()
|
||||||
|
result = filter.filter(log_record)
|
||||||
|
assert result is True
|
||||||
|
|
||||||
|
def test_sets_trace_id_from_otel_when_available(self, log_record):
|
||||||
|
from core.logging.filters import TraceContextFilter
|
||||||
|
|
||||||
|
mock_span = mock.MagicMock()
|
||||||
|
mock_context = mock.MagicMock()
|
||||||
|
mock_context.trace_id = 0x5B8AA5A2D2C872E8321CF37308D69DF2
|
||||||
|
mock_context.span_id = 0x051581BF3BB55C45
|
||||||
|
mock_span.get_span_context.return_value = mock_context
|
||||||
|
|
||||||
|
with (
|
||||||
|
mock.patch("opentelemetry.trace.get_current_span", return_value=mock_span),
|
||||||
|
mock.patch("opentelemetry.trace.span.INVALID_TRACE_ID", 0),
|
||||||
|
mock.patch("opentelemetry.trace.span.INVALID_SPAN_ID", 0),
|
||||||
|
):
|
||||||
|
filter = TraceContextFilter()
|
||||||
|
filter.filter(log_record)
|
||||||
|
|
||||||
|
assert log_record.trace_id == "5b8aa5a2d2c872e8321cf37308d69df2"
|
||||||
|
assert log_record.span_id == "051581bf3bb55c45"
|
||||||
|
|
||||||
|
|
||||||
|
class TestIdentityContextFilter:
|
||||||
|
def test_sets_empty_identity_without_request_context(self, log_record):
|
||||||
|
from core.logging.filters import IdentityContextFilter
|
||||||
|
|
||||||
|
filter = IdentityContextFilter()
|
||||||
|
result = filter.filter(log_record)
|
||||||
|
|
||||||
|
assert result is True
|
||||||
|
assert log_record.tenant_id == ""
|
||||||
|
assert log_record.user_id == ""
|
||||||
|
assert log_record.user_type == ""
|
||||||
|
|
||||||
|
def test_filter_always_returns_true(self, log_record):
|
||||||
|
from core.logging.filters import IdentityContextFilter
|
||||||
|
|
||||||
|
filter = IdentityContextFilter()
|
||||||
|
result = filter.filter(log_record)
|
||||||
|
assert result is True
|
||||||
|
|
||||||
|
def test_handles_exception_gracefully(self, log_record):
|
||||||
|
from core.logging.filters import IdentityContextFilter
|
||||||
|
|
||||||
|
filter = IdentityContextFilter()
|
||||||
|
|
||||||
|
# Should not raise even if something goes wrong
|
||||||
|
with mock.patch("core.logging.filters.flask.has_request_context", side_effect=Exception("Test error")):
|
||||||
|
result = filter.filter(log_record)
|
||||||
|
assert result is True
|
||||||
|
assert log_record.tenant_id == ""
|
||||||
|
|
@ -0,0 +1,267 @@
|
||||||
|
"""Tests for structured JSON formatter."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import orjson
|
||||||
|
|
||||||
|
|
||||||
|
class TestStructuredJSONFormatter:
|
||||||
|
def test_basic_log_format(self):
|
||||||
|
from core.logging.structured_formatter import StructuredJSONFormatter
|
||||||
|
|
||||||
|
formatter = StructuredJSONFormatter(service_name="test-service")
|
||||||
|
record = logging.LogRecord(
|
||||||
|
name="test",
|
||||||
|
level=logging.INFO,
|
||||||
|
pathname="test.py",
|
||||||
|
lineno=42,
|
||||||
|
msg="Test message",
|
||||||
|
args=(),
|
||||||
|
exc_info=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
output = formatter.format(record)
|
||||||
|
log_dict = orjson.loads(output)
|
||||||
|
|
||||||
|
assert log_dict["severity"] == "INFO"
|
||||||
|
assert log_dict["service"] == "test-service"
|
||||||
|
assert log_dict["caller"] == "test.py:42"
|
||||||
|
assert log_dict["message"] == "Test message"
|
||||||
|
assert "ts" in log_dict
|
||||||
|
assert log_dict["ts"].endswith("Z")
|
||||||
|
|
||||||
|
def test_severity_mapping(self):
|
||||||
|
from core.logging.structured_formatter import StructuredJSONFormatter
|
||||||
|
|
||||||
|
formatter = StructuredJSONFormatter()
|
||||||
|
|
||||||
|
test_cases = [
|
||||||
|
(logging.DEBUG, "DEBUG"),
|
||||||
|
(logging.INFO, "INFO"),
|
||||||
|
(logging.WARNING, "WARN"),
|
||||||
|
(logging.ERROR, "ERROR"),
|
||||||
|
(logging.CRITICAL, "ERROR"),
|
||||||
|
]
|
||||||
|
|
||||||
|
for level, expected_severity in test_cases:
|
||||||
|
record = logging.LogRecord(
|
||||||
|
name="test",
|
||||||
|
level=level,
|
||||||
|
pathname="test.py",
|
||||||
|
lineno=1,
|
||||||
|
msg="Test",
|
||||||
|
args=(),
|
||||||
|
exc_info=None,
|
||||||
|
)
|
||||||
|
output = formatter.format(record)
|
||||||
|
log_dict = orjson.loads(output)
|
||||||
|
assert log_dict["severity"] == expected_severity, f"Level {level} should map to {expected_severity}"
|
||||||
|
|
||||||
|
def test_error_with_stack_trace(self):
|
||||||
|
from core.logging.structured_formatter import StructuredJSONFormatter
|
||||||
|
|
||||||
|
formatter = StructuredJSONFormatter()
|
||||||
|
|
||||||
|
try:
|
||||||
|
raise ValueError("Test error")
|
||||||
|
except ValueError:
|
||||||
|
exc_info = sys.exc_info()
|
||||||
|
|
||||||
|
record = logging.LogRecord(
|
||||||
|
name="test",
|
||||||
|
level=logging.ERROR,
|
||||||
|
pathname="test.py",
|
||||||
|
lineno=10,
|
||||||
|
msg="Error occurred",
|
||||||
|
args=(),
|
||||||
|
exc_info=exc_info,
|
||||||
|
)
|
||||||
|
|
||||||
|
output = formatter.format(record)
|
||||||
|
log_dict = orjson.loads(output)
|
||||||
|
|
||||||
|
assert log_dict["severity"] == "ERROR"
|
||||||
|
assert "stack_trace" in log_dict
|
||||||
|
assert "ValueError: Test error" in log_dict["stack_trace"]
|
||||||
|
|
||||||
|
def test_no_stack_trace_for_info(self):
|
||||||
|
from core.logging.structured_formatter import StructuredJSONFormatter
|
||||||
|
|
||||||
|
formatter = StructuredJSONFormatter()
|
||||||
|
|
||||||
|
try:
|
||||||
|
raise ValueError("Test error")
|
||||||
|
except ValueError:
|
||||||
|
exc_info = sys.exc_info()
|
||||||
|
|
||||||
|
record = logging.LogRecord(
|
||||||
|
name="test",
|
||||||
|
level=logging.INFO,
|
||||||
|
pathname="test.py",
|
||||||
|
lineno=10,
|
||||||
|
msg="Info message",
|
||||||
|
args=(),
|
||||||
|
exc_info=exc_info,
|
||||||
|
)
|
||||||
|
|
||||||
|
output = formatter.format(record)
|
||||||
|
log_dict = orjson.loads(output)
|
||||||
|
|
||||||
|
assert "stack_trace" not in log_dict
|
||||||
|
|
||||||
|
def test_trace_context_included(self):
|
||||||
|
from core.logging.structured_formatter import StructuredJSONFormatter
|
||||||
|
|
||||||
|
formatter = StructuredJSONFormatter()
|
||||||
|
record = logging.LogRecord(
|
||||||
|
name="test",
|
||||||
|
level=logging.INFO,
|
||||||
|
pathname="test.py",
|
||||||
|
lineno=1,
|
||||||
|
msg="Test",
|
||||||
|
args=(),
|
||||||
|
exc_info=None,
|
||||||
|
)
|
||||||
|
record.trace_id = "5b8aa5a2d2c872e8321cf37308d69df2"
|
||||||
|
record.span_id = "051581bf3bb55c45"
|
||||||
|
|
||||||
|
output = formatter.format(record)
|
||||||
|
log_dict = orjson.loads(output)
|
||||||
|
|
||||||
|
assert log_dict["trace_id"] == "5b8aa5a2d2c872e8321cf37308d69df2"
|
||||||
|
assert log_dict["span_id"] == "051581bf3bb55c45"
|
||||||
|
|
||||||
|
def test_identity_context_included(self):
|
||||||
|
from core.logging.structured_formatter import StructuredJSONFormatter
|
||||||
|
|
||||||
|
formatter = StructuredJSONFormatter()
|
||||||
|
record = logging.LogRecord(
|
||||||
|
name="test",
|
||||||
|
level=logging.INFO,
|
||||||
|
pathname="test.py",
|
||||||
|
lineno=1,
|
||||||
|
msg="Test",
|
||||||
|
args=(),
|
||||||
|
exc_info=None,
|
||||||
|
)
|
||||||
|
record.tenant_id = "t-global-corp"
|
||||||
|
record.user_id = "u-admin-007"
|
||||||
|
record.user_type = "admin"
|
||||||
|
|
||||||
|
output = formatter.format(record)
|
||||||
|
log_dict = orjson.loads(output)
|
||||||
|
|
||||||
|
assert "identity" in log_dict
|
||||||
|
assert log_dict["identity"]["tenant_id"] == "t-global-corp"
|
||||||
|
assert log_dict["identity"]["user_id"] == "u-admin-007"
|
||||||
|
assert log_dict["identity"]["user_type"] == "admin"
|
||||||
|
|
||||||
|
def test_no_identity_when_empty(self):
|
||||||
|
from core.logging.structured_formatter import StructuredJSONFormatter
|
||||||
|
|
||||||
|
formatter = StructuredJSONFormatter()
|
||||||
|
record = logging.LogRecord(
|
||||||
|
name="test",
|
||||||
|
level=logging.INFO,
|
||||||
|
pathname="test.py",
|
||||||
|
lineno=1,
|
||||||
|
msg="Test",
|
||||||
|
args=(),
|
||||||
|
exc_info=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
output = formatter.format(record)
|
||||||
|
log_dict = orjson.loads(output)
|
||||||
|
|
||||||
|
assert "identity" not in log_dict
|
||||||
|
|
||||||
|
def test_attributes_included(self):
|
||||||
|
from core.logging.structured_formatter import StructuredJSONFormatter
|
||||||
|
|
||||||
|
formatter = StructuredJSONFormatter()
|
||||||
|
record = logging.LogRecord(
|
||||||
|
name="test",
|
||||||
|
level=logging.INFO,
|
||||||
|
pathname="test.py",
|
||||||
|
lineno=1,
|
||||||
|
msg="Test",
|
||||||
|
args=(),
|
||||||
|
exc_info=None,
|
||||||
|
)
|
||||||
|
record.attributes = {"order_id": "ord-123", "amount": 99.99}
|
||||||
|
|
||||||
|
output = formatter.format(record)
|
||||||
|
log_dict = orjson.loads(output)
|
||||||
|
|
||||||
|
assert log_dict["attributes"]["order_id"] == "ord-123"
|
||||||
|
assert log_dict["attributes"]["amount"] == 99.99
|
||||||
|
|
||||||
|
def test_message_with_args(self):
|
||||||
|
from core.logging.structured_formatter import StructuredJSONFormatter
|
||||||
|
|
||||||
|
formatter = StructuredJSONFormatter()
|
||||||
|
record = logging.LogRecord(
|
||||||
|
name="test",
|
||||||
|
level=logging.INFO,
|
||||||
|
pathname="test.py",
|
||||||
|
lineno=1,
|
||||||
|
msg="User %s logged in from %s",
|
||||||
|
args=("john", "192.168.1.1"),
|
||||||
|
exc_info=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
output = formatter.format(record)
|
||||||
|
log_dict = orjson.loads(output)
|
||||||
|
|
||||||
|
assert log_dict["message"] == "User john logged in from 192.168.1.1"
|
||||||
|
|
||||||
|
def test_timestamp_format(self):
|
||||||
|
from core.logging.structured_formatter import StructuredJSONFormatter
|
||||||
|
|
||||||
|
formatter = StructuredJSONFormatter()
|
||||||
|
record = logging.LogRecord(
|
||||||
|
name="test",
|
||||||
|
level=logging.INFO,
|
||||||
|
pathname="test.py",
|
||||||
|
lineno=1,
|
||||||
|
msg="Test",
|
||||||
|
args=(),
|
||||||
|
exc_info=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
output = formatter.format(record)
|
||||||
|
log_dict = orjson.loads(output)
|
||||||
|
|
||||||
|
# Verify ISO 8601 format with Z suffix
|
||||||
|
ts = log_dict["ts"]
|
||||||
|
assert ts.endswith("Z")
|
||||||
|
assert "T" in ts
|
||||||
|
# Should have milliseconds
|
||||||
|
assert "." in ts
|
||||||
|
|
||||||
|
def test_fallback_for_non_serializable_attributes(self):
|
||||||
|
from core.logging.structured_formatter import StructuredJSONFormatter
|
||||||
|
|
||||||
|
formatter = StructuredJSONFormatter()
|
||||||
|
record = logging.LogRecord(
|
||||||
|
name="test",
|
||||||
|
level=logging.INFO,
|
||||||
|
pathname="test.py",
|
||||||
|
lineno=1,
|
||||||
|
msg="Test with non-serializable",
|
||||||
|
args=(),
|
||||||
|
exc_info=None,
|
||||||
|
)
|
||||||
|
# Set is not serializable by orjson
|
||||||
|
record.attributes = {"items": {1, 2, 3}, "custom": object()}
|
||||||
|
|
||||||
|
# Should not raise, fallback to json.dumps with default=str
|
||||||
|
output = formatter.format(record)
|
||||||
|
|
||||||
|
# Verify it's valid JSON (parsed by stdlib json since orjson may fail)
|
||||||
|
import json
|
||||||
|
|
||||||
|
log_dict = json.loads(output)
|
||||||
|
assert log_dict["message"] == "Test with non-serializable"
|
||||||
|
assert "attributes" in log_dict
|
||||||
|
|
@ -0,0 +1,102 @@
|
||||||
|
"""Tests for trace helper functions."""
|
||||||
|
|
||||||
|
import re
|
||||||
|
from unittest import mock
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetSpanIdFromOtelContext:
|
||||||
|
def test_returns_none_without_span(self):
|
||||||
|
from core.helper.trace_id_helper import get_span_id_from_otel_context
|
||||||
|
|
||||||
|
with mock.patch("opentelemetry.trace.get_current_span", return_value=None):
|
||||||
|
result = get_span_id_from_otel_context()
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_returns_span_id_when_available(self):
|
||||||
|
from core.helper.trace_id_helper import get_span_id_from_otel_context
|
||||||
|
|
||||||
|
mock_span = mock.MagicMock()
|
||||||
|
mock_context = mock.MagicMock()
|
||||||
|
mock_context.span_id = 0x051581BF3BB55C45
|
||||||
|
mock_span.get_span_context.return_value = mock_context
|
||||||
|
|
||||||
|
with mock.patch("opentelemetry.trace.get_current_span", return_value=mock_span):
|
||||||
|
with mock.patch("opentelemetry.trace.span.INVALID_SPAN_ID", 0):
|
||||||
|
result = get_span_id_from_otel_context()
|
||||||
|
assert result == "051581bf3bb55c45"
|
||||||
|
|
||||||
|
def test_returns_none_on_exception(self):
|
||||||
|
from core.helper.trace_id_helper import get_span_id_from_otel_context
|
||||||
|
|
||||||
|
with mock.patch("opentelemetry.trace.get_current_span", side_effect=Exception("Test error")):
|
||||||
|
result = get_span_id_from_otel_context()
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestGenerateTraceparentHeader:
|
||||||
|
def test_generates_valid_format(self):
|
||||||
|
from core.helper.trace_id_helper import generate_traceparent_header
|
||||||
|
|
||||||
|
with mock.patch("opentelemetry.trace.get_current_span", return_value=None):
|
||||||
|
result = generate_traceparent_header()
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
# Format: 00-{trace_id}-{span_id}-01
|
||||||
|
parts = result.split("-")
|
||||||
|
assert len(parts) == 4
|
||||||
|
assert parts[0] == "00" # version
|
||||||
|
assert len(parts[1]) == 32 # trace_id (32 hex chars)
|
||||||
|
assert len(parts[2]) == 16 # span_id (16 hex chars)
|
||||||
|
assert parts[3] == "01" # flags
|
||||||
|
|
||||||
|
def test_uses_otel_context_when_available(self):
|
||||||
|
from core.helper.trace_id_helper import generate_traceparent_header
|
||||||
|
|
||||||
|
mock_span = mock.MagicMock()
|
||||||
|
mock_context = mock.MagicMock()
|
||||||
|
mock_context.trace_id = 0x5B8AA5A2D2C872E8321CF37308D69DF2
|
||||||
|
mock_context.span_id = 0x051581BF3BB55C45
|
||||||
|
mock_span.get_span_context.return_value = mock_context
|
||||||
|
|
||||||
|
with mock.patch("opentelemetry.trace.get_current_span", return_value=mock_span):
|
||||||
|
with (
|
||||||
|
mock.patch("opentelemetry.trace.span.INVALID_TRACE_ID", 0),
|
||||||
|
mock.patch("opentelemetry.trace.span.INVALID_SPAN_ID", 0),
|
||||||
|
):
|
||||||
|
result = generate_traceparent_header()
|
||||||
|
|
||||||
|
assert result == "00-5b8aa5a2d2c872e8321cf37308d69df2-051581bf3bb55c45-01"
|
||||||
|
|
||||||
|
def test_generates_hex_only_values(self):
|
||||||
|
from core.helper.trace_id_helper import generate_traceparent_header
|
||||||
|
|
||||||
|
with mock.patch("opentelemetry.trace.get_current_span", return_value=None):
|
||||||
|
result = generate_traceparent_header()
|
||||||
|
|
||||||
|
parts = result.split("-")
|
||||||
|
# All parts should be valid hex
|
||||||
|
assert re.match(r"^[0-9a-f]+$", parts[1])
|
||||||
|
assert re.match(r"^[0-9a-f]+$", parts[2])
|
||||||
|
|
||||||
|
|
||||||
|
class TestParseTraceparentHeader:
|
||||||
|
def test_parses_valid_traceparent(self):
|
||||||
|
from core.helper.trace_id_helper import parse_traceparent_header
|
||||||
|
|
||||||
|
traceparent = "00-5b8aa5a2d2c872e8321cf37308d69df2-051581bf3bb55c45-01"
|
||||||
|
result = parse_traceparent_header(traceparent)
|
||||||
|
|
||||||
|
assert result == "5b8aa5a2d2c872e8321cf37308d69df2"
|
||||||
|
|
||||||
|
def test_returns_none_for_invalid_format(self):
|
||||||
|
from core.helper.trace_id_helper import parse_traceparent_header
|
||||||
|
|
||||||
|
# Wrong number of parts
|
||||||
|
assert parse_traceparent_header("00-abc-def") is None
|
||||||
|
# Wrong trace_id length
|
||||||
|
assert parse_traceparent_header("00-abc-def-01") is None
|
||||||
|
|
||||||
|
def test_returns_none_for_empty_string(self):
|
||||||
|
from core.helper.trace_id_helper import parse_traceparent_header
|
||||||
|
|
||||||
|
assert parse_traceparent_header("") is None
|
||||||
|
|
@ -99,14 +99,7 @@ def test_external_api_json_message_and_bad_request_rewrite():
|
||||||
assert res.get_json()["message"] == "Invalid JSON payload received or JSON payload is empty."
|
assert res.get_json()["message"] == "Invalid JSON payload received or JSON payload is empty."
|
||||||
|
|
||||||
|
|
||||||
def test_external_api_param_mapping_and_quota_and_exc_info_none():
|
def test_external_api_param_mapping_and_quota():
|
||||||
# Force exc_info() to return (None,None,None) only during request
|
|
||||||
import libs.external_api as ext
|
|
||||||
|
|
||||||
orig_exc_info = ext.sys.exc_info
|
|
||||||
try:
|
|
||||||
ext.sys.exc_info = lambda: (None, None, None)
|
|
||||||
|
|
||||||
app = _create_api_app()
|
app = _create_api_app()
|
||||||
client = app.test_client()
|
client = app.test_client()
|
||||||
|
|
||||||
|
|
@ -120,8 +113,6 @@ def test_external_api_param_mapping_and_quota_and_exc_info_none():
|
||||||
# Quota path — depending on Flask-RESTX internals it may be handled
|
# Quota path — depending on Flask-RESTX internals it may be handled
|
||||||
res = client.get("/api/quota")
|
res = client.get("/api/quota")
|
||||||
assert res.status_code in (400, 429)
|
assert res.status_code in (400, 429)
|
||||||
finally:
|
|
||||||
ext.sys.exc_info = orig_exc_info # type: ignore[assignment]
|
|
||||||
|
|
||||||
|
|
||||||
def test_unauthorized_and_force_logout_clears_cookies():
|
def test_unauthorized_and_force_logout_clears_cookies():
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue