mirror of https://github.com/langgenius/dify.git
feat: add structured logging and keep injecting the trace id into all requests
This commit is contained in:
parent
f439e081b5
commit
183b38ecb0
@@ -28,25 +28,33 @@ def create_flask_app_with_configs() -> DifyApp:
        # add an unique identifier to each request
        RecyclableContextVar.increment_thread_recycles()

    # add after request hook for injecting X-Trace-Id header from OpenTelemetry span context
    # add after request hook for injecting trace headers from OpenTelemetry span context
    # Only adds headers when OTEL is enabled and has valid context
    @dify_app.after_request
    def add_trace_id_header(response):
    def add_trace_headers(response):
        try:
            from opentelemetry.trace.span import INVALID_SPAN_ID, INVALID_TRACE_ID

            span = get_current_span()
            ctx = span.get_span_context() if span else None
            if ctx and ctx.is_valid:
                trace_id_hex = format(ctx.trace_id, "032x")
                # Avoid duplicates if some middleware added it
                if "X-Trace-Id" not in response.headers:
                    response.headers["X-Trace-Id"] = trace_id_hex

            if not ctx or not ctx.is_valid:
                return response

            # Inject trace headers from OTEL context
            if ctx.trace_id != INVALID_TRACE_ID and "X-Trace-Id" not in response.headers:
                response.headers["X-Trace-Id"] = format(ctx.trace_id, "032x")
            if ctx.span_id != INVALID_SPAN_ID and "X-Span-Id" not in response.headers:
                response.headers["X-Span-Id"] = format(ctx.span_id, "016x")

        except Exception:
            # Never break the response due to tracing header injection
            logger.warning("Failed to add trace ID to response header", exc_info=True)
            logger.warning("Failed to add trace headers to response", exc_info=True)
        return response

    # Capture the decorator's return value to avoid pyright reportUnusedFunction
    _ = before_request
    _ = add_trace_id_header
    _ = add_trace_headers

    return dify_app
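As a quick sanity check on the hook above, here is a minimal client-side sketch (not part of this commit): assuming the API is reachable at http://localhost:5001 and OTEL is enabled so a span is active for the request, the response should carry the injected correlation headers.

import httpx

# The URL and path are assumptions for illustration only.
resp = httpx.get("http://localhost:5001/health")
# Hex-encoded by the hook: 32 chars for the trace id, 16 for the span id.
print(resp.headers.get("X-Trace-Id"))
print(resp.headers.get("X-Span-Id"))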
@@ -587,6 +587,11 @@ class LoggingConfig(BaseSettings):
        default="INFO",
    )

    LOG_OUTPUT_FORMAT: Literal["text", "json"] = Field(
        description="Log output format: 'text' for human-readable, 'json' for structured JSON logs.",
        default="text",
    )

    LOG_FILE: str | None = Field(
        description="File path for log output.",
        default=None,
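For reference, an illustrative way to drive these settings through environment variables; the variable names come from the Field definitions above, while the log file path is an assumption rather than a project default.

# illustrative .env values
LOG_LEVEL=INFO
LOG_OUTPUT_FORMAT=json
LOG_FILE=/app/logs/server.log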
@@ -88,12 +88,48 @@ def _get_user_provided_host_header(headers: dict | None) -> str | None:
    return None


def _inject_trace_headers(headers: dict | None) -> dict:
    """
    Inject W3C traceparent header for distributed tracing.

    When OTEL is enabled, HTTPXClientInstrumentor handles trace propagation automatically.
    When OTEL is disabled, we manually inject the traceparent header.
    """
    if headers is None:
        headers = {}

    # Skip if already present (case-insensitive check)
    for key in headers:
        if key.lower() == "traceparent":
            return headers

    # Skip if OTEL is enabled - HTTPXClientInstrumentor handles this automatically
    if dify_config.ENABLE_OTEL:
        return headers

    # Generate and inject traceparent for non-OTEL scenarios
    try:
        from core.helper.trace_id_helper import generate_traceparent_header

        traceparent = generate_traceparent_header()
        if traceparent:
            headers["traceparent"] = traceparent
    except Exception:
        # Silently ignore errors to avoid breaking requests
        logger.debug("Failed to generate traceparent header", exc_info=True)

    return headers


def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
    if "allow_redirects" in kwargs:
        allow_redirects = kwargs.pop("allow_redirects")
        if "follow_redirects" not in kwargs:
            kwargs["follow_redirects"] = allow_redirects

    # Extract follow_redirects - it's a send() parameter, not build_request() parameter
    follow_redirects = kwargs.pop("follow_redirects", False)

    if "timeout" not in kwargs:
        kwargs["timeout"] = httpx.Timeout(
            timeout=dify_config.SSRF_DEFAULT_TIME_OUT,

@@ -106,10 +142,14 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
    verify_option = kwargs.pop("ssl_verify", dify_config.HTTP_REQUEST_NODE_SSL_VERIFY)
    client = _get_ssrf_client(verify_option)

    # Inject traceparent header for distributed tracing (when OTEL is not enabled)
    headers = kwargs.get("headers") or {}
    headers = _inject_trace_headers(headers)
    kwargs["headers"] = headers

    # Preserve user-provided Host header
    # When using a forward proxy, httpx may override the Host header based on the URL.
    # We extract and preserve any explicitly set Host header to support virtual hosting.
    headers = kwargs.get("headers", {})
    user_provided_host = _get_user_provided_host_header(headers)

    retries = 0

@@ -124,7 +164,7 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
        if user_provided_host is not None:
            request.headers["Host"] = user_provided_host

        response = client.send(request)
        response = client.send(request, follow_redirects=follow_redirects)

        # Check for SSRF protection by Squid proxy
        if response.status_code in (401, 403):
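A hedged usage sketch of the path above (the target URL is illustrative, and this assumes it runs inside the API project where core.helper.ssrf_proxy is importable): with ENABLE_OTEL off, make_request() is expected to attach the traceparent header itself.

from core.helper import ssrf_proxy

response = ssrf_proxy.make_request(
    "GET",
    "http://example.com/resource",  # illustrative URL
    allow_redirects=True,  # requests-style flag, translated to httpx's follow_redirects for send()
)
print(response.status_code)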
@@ -103,3 +103,67 @@ def parse_traceparent_header(traceparent: str) -> str | None:
    if len(parts) == 4 and len(parts[1]) == 32:
        return parts[1]
    return None


def get_span_id_from_otel_context() -> str | None:
    """
    Retrieve the current span ID from the active OpenTelemetry trace context.

    Returns:
        A 16-character hex string representing the span ID, or None if not available.
    """
    try:
        from opentelemetry.trace import get_current_span
        from opentelemetry.trace.span import INVALID_SPAN_ID

        span = get_current_span()
        if not span:
            return None

        span_context = span.get_span_context()
        if not span_context or span_context.span_id == INVALID_SPAN_ID:
            return None

        return f"{span_context.span_id:016x}"
    except Exception:
        return None


def generate_traceparent_header() -> str | None:
    """
    Generate a W3C traceparent header from the current context.

    Uses OpenTelemetry context if available, otherwise generates new IDs
    based on the Flask request context.

    Format: {version}-{trace_id}-{span_id}-{flags}
    Example: 00-5b8aa5a2d2c872e8321cf37308d69df2-051581bf3bb55c45-01

    Returns:
        A valid traceparent header string, or None if generation fails.
    """
    import uuid

    # Try OTEL context first
    trace_id = get_trace_id_from_otel_context()
    span_id = get_span_id_from_otel_context()

    if trace_id and span_id:
        return f"00-{trace_id}-{span_id}-01"

    # Fallback: generate new trace context
    try:
        import flask

        if flask.has_request_context() and hasattr(flask.g, "request_id"):
            # Derive trace_id from request_id for consistency
            trace_id = uuid.uuid5(uuid.NAMESPACE_DNS, flask.g.request_id).hex
        else:
            trace_id = uuid.uuid4().hex
    except Exception:
        trace_id = uuid.uuid4().hex

    # Generate a new span_id (16 hex chars)
    span_id = uuid.uuid4().hex[:16]

    return f"00-{trace_id}-{span_id}-01"
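A small round-trip sketch using only the helpers shown above (assuming it runs inside the API project): generate a traceparent, then parse the trace id back out of it.

from core.helper.trace_id_helper import generate_traceparent_header, parse_traceparent_header

traceparent = generate_traceparent_header()  # e.g. "00-<32 hex>-<16 hex>-01"
if traceparent:
    trace_id = parse_traceparent_header(traceparent)
    # parse_traceparent_header returns the 32-char hex trace id, or None on malformed input
    assert trace_id is not None and len(trace_id) == 32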
@@ -0,0 +1,10 @@
"""Structured logging components for Dify."""

from core.logging.filters import IdentityContextFilter, TraceContextFilter
from core.logging.structured_formatter import StructuredJSONFormatter

__all__ = [
    "IdentityContextFilter",
    "StructuredJSONFormatter",
    "TraceContextFilter",
]
@@ -0,0 +1,115 @@
"""Logging filters for structured logging."""

import contextlib
import logging
import uuid

import flask


class TraceContextFilter(logging.Filter):
    """
    Filter that adds trace_id and span_id to log records.
    Integrates with OpenTelemetry when available, falls back to request_id.
    """

    def filter(self, record: logging.LogRecord) -> bool:
        # Get trace context from OpenTelemetry
        trace_id, span_id = self._get_otel_context()

        # Set trace_id (fallback to request_id if no OTEL context)
        if trace_id:
            record.trace_id = trace_id
        else:
            record.trace_id = self._get_or_create_request_trace_id()

        record.span_id = span_id or ""

        # For backward compatibility, also set req_id
        record.req_id = self._get_request_id()

        return True

    def _get_otel_context(self) -> tuple[str, str]:
        """Extract trace_id and span_id from OpenTelemetry context."""
        with contextlib.suppress(Exception):
            from opentelemetry.trace import get_current_span
            from opentelemetry.trace.span import INVALID_SPAN_ID, INVALID_TRACE_ID

            span = get_current_span()
            if span and span.get_span_context():
                ctx = span.get_span_context()
                if ctx.is_valid and ctx.trace_id != INVALID_TRACE_ID:
                    trace_id = f"{ctx.trace_id:032x}"
                    span_id = f"{ctx.span_id:016x}" if ctx.span_id != INVALID_SPAN_ID else ""
                    return trace_id, span_id
        return "", ""

    def _get_request_id(self) -> str:
        """Get request ID from Flask context."""
        if flask.has_request_context():
            if hasattr(flask.g, "request_id"):
                return flask.g.request_id
            flask.g.request_id = uuid.uuid4().hex[:10]
            return flask.g.request_id
        return ""

    def _get_or_create_request_trace_id(self) -> str:
        """Get or create a trace_id derived from request context."""
        if flask.has_request_context():
            if hasattr(flask.g, "_trace_id"):
                return flask.g._trace_id
            # Derive trace_id from request_id for consistency
            request_id = self._get_request_id()
            if request_id:
                # Generate a 32-char hex trace_id from request_id
                flask.g._trace_id = uuid.uuid5(uuid.NAMESPACE_DNS, request_id).hex
                return flask.g._trace_id
        return ""


class IdentityContextFilter(logging.Filter):
    """
    Filter that adds user identity context to log records.
    Extracts tenant_id, user_id, and user_type from Flask-Login current_user.
    """

    def filter(self, record: logging.LogRecord) -> bool:
        identity = self._extract_identity()
        record.tenant_id = identity.get("tenant_id", "")
        record.user_id = identity.get("user_id", "")
        record.user_type = identity.get("user_type", "")
        return True

    def _extract_identity(self) -> dict[str, str]:
        """Extract identity from current_user if in request context."""
        try:
            if not flask.has_request_context():
                return {}
            from flask_login import current_user

            # Check if user is authenticated using the proxy
            if not current_user.is_authenticated:
                return {}

            # Access the underlying user object
            user = current_user

            from models import Account
            from models.model import EndUser

            identity: dict[str, str] = {}

            if isinstance(user, Account):
                if user.current_tenant_id:
                    identity["tenant_id"] = user.current_tenant_id
                identity["user_id"] = user.id
                identity["user_type"] = "account"
            elif isinstance(user, EndUser):
                identity["tenant_id"] = user.tenant_id
                identity["user_id"] = user.id
                identity["user_type"] = user.type or "end_user"

            return identity
        except Exception:
            return {}
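A hedged wiring sketch (not part of this commit): attach the new filters to a plain stdlib handler so that %(trace_id)s, %(span_id)s and %(req_id)s are always available to a text format string; outside a request and without an active span they resolve to empty strings.

import logging

from core.logging.filters import IdentityContextFilter, TraceContextFilter

handler = logging.StreamHandler()
handler.addFilter(TraceContextFilter())
handler.addFilter(IdentityContextFilter())
# The format string is illustrative; the filters guarantee the extra fields exist.
handler.setFormatter(logging.Formatter("%(asctime)s [%(trace_id)s] %(levelname)s %(message)s"))
logging.getLogger("demo").addHandler(handler)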
@@ -0,0 +1,101 @@
"""Structured JSON log formatter for Dify."""

import logging
import traceback
from datetime import UTC, datetime
from typing import Any

import orjson

from configs import dify_config


class StructuredJSONFormatter(logging.Formatter):
    """
    JSON log formatter following the specified schema:
    {
        "ts": "ISO 8601 UTC",
        "severity": "INFO|ERROR|WARN|DEBUG",
        "service": "service name",
        "caller": "file:line",
        "trace_id": "hex 32",
        "span_id": "hex 16",
        "identity": { "tenant_id", "user_id", "user_type" },
        "message": "log message",
        "attributes": { ... },
        "stack_trace": "..."
    }
    """

    SEVERITY_MAP: dict[int, str] = {
        logging.DEBUG: "DEBUG",
        logging.INFO: "INFO",
        logging.WARNING: "WARN",
        logging.ERROR: "ERROR",
        logging.CRITICAL: "ERROR",
    }

    def __init__(self, service_name: str | None = None):
        super().__init__()
        self._service_name = service_name or dify_config.APPLICATION_NAME

    def format(self, record: logging.LogRecord) -> str:
        log_dict = self._build_log_dict(record)
        return orjson.dumps(log_dict).decode("utf-8")

    def _build_log_dict(self, record: logging.LogRecord) -> dict[str, Any]:
        # Core fields
        log_dict: dict[str, Any] = {
            "ts": datetime.now(UTC).isoformat(timespec="milliseconds").replace("+00:00", "Z"),
            "severity": self.SEVERITY_MAP.get(record.levelno, "INFO"),
            "service": self._service_name,
            "caller": f"{record.filename}:{record.lineno}",
            "message": record.getMessage(),
        }

        # Trace context (from TraceContextFilter)
        trace_id = getattr(record, "trace_id", "") or ""
        span_id = getattr(record, "span_id", "") or ""

        if trace_id:
            log_dict["trace_id"] = trace_id
        if span_id:
            log_dict["span_id"] = span_id

        # Identity context (from IdentityContextFilter)
        identity = self._extract_identity(record)
        if identity:
            log_dict["identity"] = identity

        # Dynamic attributes
        attributes = getattr(record, "attributes", None)
        if attributes:
            log_dict["attributes"] = attributes

        # Stack trace for errors with exceptions
        if record.exc_info and record.levelno >= logging.ERROR:
            log_dict["stack_trace"] = self._format_exception(record.exc_info)

        return log_dict

    def _extract_identity(self, record: logging.LogRecord) -> dict[str, str] | None:
        tenant_id = getattr(record, "tenant_id", None)
        user_id = getattr(record, "user_id", None)
        user_type = getattr(record, "user_type", None)

        if not any([tenant_id, user_id, user_type]):
            return None

        identity: dict[str, str] = {}
        if tenant_id:
            identity["tenant_id"] = tenant_id
        if user_id:
            identity["user_id"] = user_id
        if user_type:
            identity["user_type"] = user_type
        return identity

    def _format_exception(self, exc_info: tuple[Any, ...]) -> str:
        if exc_info and exc_info[0] is not None:
            return "".join(traceback.format_exception(*exc_info))
        return ""
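A quick sketch of what a formatted record looks like; the service name and the sample output line are illustrative, and the import assumes the API project's environment where configs is importable.

import logging

from core.logging.structured_formatter import StructuredJSONFormatter

formatter = StructuredJSONFormatter(service_name="dify-api")  # service name is illustrative
record = logging.LogRecord("demo", logging.INFO, "app.py", 10, "request handled", (), None)
print(formatter.format(record))
# Roughly: {"ts":"2024-01-01T00:00:00.000Z","severity":"INFO","service":"dify-api","caller":"app.py:10","message":"request handled"}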
@@ -103,6 +103,9 @@ class BasePluginClient:
            prepared_headers["X-Api-Key"] = dify_config.PLUGIN_DAEMON_KEY
        prepared_headers.setdefault("Accept-Encoding", "gzip, deflate, br")

        # Inject traceparent header for distributed tracing
        self._inject_trace_headers(prepared_headers)

        prepared_data: bytes | dict[str, Any] | str | None = (
            data if isinstance(data, (bytes, str, dict)) or data is None else None
        )

@@ -114,6 +117,28 @@ class BasePluginClient:

        return str(url), prepared_headers, prepared_data, params, files

    def _inject_trace_headers(self, headers: dict[str, str]) -> None:
        """
        Inject W3C traceparent header for distributed tracing.

        This ensures trace context is propagated to plugin daemon even if
        HTTPXClientInstrumentor doesn't cover module-level httpx functions.
        """
        import contextlib

        # Skip if already present (case-insensitive check)
        for key in headers:
            if key.lower() == "traceparent":
                return

        # Inject traceparent - works as fallback when OTEL instrumentation doesn't cover this call
        with contextlib.suppress(Exception):
            from core.helper.trace_id_helper import generate_traceparent_header

            traceparent = generate_traceparent_header()
            if traceparent:
                headers["traceparent"] = traceparent

    def _stream_request(
        self,
        method: str,
@@ -1,3 +1,5 @@
"""Logging extension for Dify Flask application."""

import logging
import os
import sys

@@ -7,12 +9,14 @@ from logging.handlers import RotatingFileHandler
import flask

from configs import dify_config
from core.helper.trace_id_helper import get_trace_id_from_otel_context
from dify_app import DifyApp


def init_app(app: DifyApp):
    """Initialize logging with support for text or JSON format."""
    log_handlers: list[logging.Handler] = []

    # File handler
    log_file = dify_config.LOG_FILE
    if log_file:
        log_dir = os.path.dirname(log_file)

@@ -25,27 +29,53 @@ def init_app(app: DifyApp):
            )
        )

    # Always add StreamHandler to log to console
    # Console handler
    sh = logging.StreamHandler(sys.stdout)
    log_handlers.append(sh)

    # Apply RequestIdFilter to all handlers
    for handler in log_handlers:
        handler.addFilter(RequestIdFilter())
    # Apply filters to all handlers
    from core.logging.filters import IdentityContextFilter, TraceContextFilter

    for handler in log_handlers:
        handler.addFilter(TraceContextFilter())
        handler.addFilter(IdentityContextFilter())

    # Configure formatter based on format type
    formatter = _create_formatter()
    for handler in log_handlers:
        handler.setFormatter(formatter)

    # Configure root logger
    logging.basicConfig(
        level=dify_config.LOG_LEVEL,
        format=dify_config.LOG_FORMAT,
        datefmt=dify_config.LOG_DATEFORMAT,
        handlers=log_handlers,
        force=True,
    )

    # Apply RequestIdFormatter to all handlers
    apply_request_id_formatter()

    # Disable propagation for noisy loggers to avoid duplicate logs
    logging.getLogger("sqlalchemy.engine").propagate = False

    # Apply timezone if specified (only for text format)
    if dify_config.LOG_OUTPUT_FORMAT == "text":
        _apply_timezone(log_handlers)


def _create_formatter() -> logging.Formatter:
    """Create appropriate formatter based on configuration."""
    if dify_config.LOG_OUTPUT_FORMAT == "json":
        from core.logging.structured_formatter import StructuredJSONFormatter

        return StructuredJSONFormatter()
    else:
        # Text format - use existing pattern with backward compatible formatter
        return _TextFormatter(
            fmt=dify_config.LOG_FORMAT,
            datefmt=dify_config.LOG_DATEFORMAT,
        )


def _apply_timezone(handlers: list[logging.Handler]):
    """Apply timezone conversion to text formatters."""
    log_tz = dify_config.LOG_TZ
    if log_tz:
        from datetime import datetime

@@ -57,26 +87,41 @@ def init_app(app: DifyApp):
        def time_converter(seconds):
            return datetime.fromtimestamp(seconds, tz=timezone).timetuple()

        for handler in logging.root.handlers:
        for handler in handlers:
            if handler.formatter:
                handler.formatter.converter = time_converter
                handler.formatter.converter = time_converter  # type: ignore[attr-defined]


def get_request_id():
    if getattr(flask.g, "request_id", None):
class _TextFormatter(logging.Formatter):
    """Text formatter that ensures trace_id and req_id are always present."""

    def format(self, record: logging.LogRecord) -> str:
        if not hasattr(record, "req_id"):
            record.req_id = ""
        if not hasattr(record, "trace_id"):
            record.trace_id = ""
        if not hasattr(record, "span_id"):
            record.span_id = ""
        return super().format(record)


def get_request_id() -> str:
    """Get or create request ID for current request context."""
    if flask.has_request_context():
        if getattr(flask.g, "request_id", None):
            return flask.g.request_id
        flask.g.request_id = uuid.uuid4().hex[:10]
        return flask.g.request_id

    new_uuid = uuid.uuid4().hex[:10]
    flask.g.request_id = new_uuid

    return new_uuid
    return ""


# Backward compatibility aliases
class RequestIdFilter(logging.Filter):
    # This is a logging filter that makes the request ID available for use in
    # the logging format. Note that we're checking if we're in a request
    # context, as we may want to log things before Flask is fully loaded.
    def filter(self, record):
    """Deprecated: Use TraceContextFilter from core.logging.filters instead."""

    def filter(self, record: logging.LogRecord) -> bool:
        from core.helper.trace_id_helper import get_trace_id_from_otel_context

        trace_id = get_trace_id_from_otel_context() or ""
        record.req_id = get_request_id() if flask.has_request_context() else ""
        record.trace_id = trace_id

@@ -84,7 +129,9 @@ class RequestIdFilter(logging.Filter):


class RequestIdFormatter(logging.Formatter):
    def format(self, record):
    """Deprecated: Use _TextFormatter instead."""

    def format(self, record: logging.LogRecord) -> str:
        if not hasattr(record, "req_id"):
            record.req_id = ""
        if not hasattr(record, "trace_id"):

@@ -93,6 +140,7 @@ class RequestIdFormatter(logging.Formatter):


def apply_request_id_formatter():
    """Deprecated: Formatter is now applied in init_app."""
    for handler in logging.root.handlers:
        if handler.formatter:
            handler.formatter = RequestIdFormatter(dify_config.LOG_FORMAT, dify_config.LOG_DATEFORMAT)
@@ -19,26 +19,43 @@ logger = logging.getLogger(__name__)


class ExceptionLoggingHandler(logging.Handler):
    """
    Handler that records exceptions to the current OpenTelemetry span.

    Unlike creating a new span, this records exceptions on the existing span
    to maintain trace context consistency throughout the request lifecycle.
    """

    def emit(self, record: logging.LogRecord):
        with contextlib.suppress(Exception):
            if record.exc_info:
                tracer = get_tracer_provider().get_tracer("dify.exception.logging")
                with tracer.start_as_current_span(
                    "log.exception",
                    attributes={
                        "log.level": record.levelname,
                        "log.message": record.getMessage(),
                        "log.logger": record.name,
                        "log.file.path": record.pathname,
                        "log.file.line": record.lineno,
                    },
                ) as span:
                    span.set_status(StatusCode.ERROR)
                    if record.exc_info[1]:
                        span.record_exception(record.exc_info[1])
                        span.set_attribute("exception.message", str(record.exc_info[1]))
                    if record.exc_info[0]:
                        span.set_attribute("exception.type", record.exc_info[0].__name__)
            if not record.exc_info:
                return

            from opentelemetry.trace import get_current_span

            span = get_current_span()
            if not span or not span.is_recording():
                return

            # Record exception on the current span instead of creating a new one
            span.set_status(StatusCode.ERROR, record.getMessage())

            # Add log context as span events/attributes
            span.add_event(
                "log.exception",
                attributes={
                    "log.level": record.levelname,
                    "log.message": record.getMessage(),
                    "log.logger": record.name,
                    "log.file.path": record.pathname,
                    "log.file.line": record.lineno,
                },
            )

            if record.exc_info[1]:
                span.record_exception(record.exc_info[1])
            if record.exc_info[0]:
                span.set_attribute("exception.type", record.exc_info[0].__name__)


def instrument_exception_logging() -> None:
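The same pattern in isolation, as a hedged sketch rather than the handler itself: record an exception on whatever span is already active instead of opening a new span.

from opentelemetry.trace import get_current_span
from opentelemetry.trace.status import StatusCode

try:
    1 / 0
except ZeroDivisionError as exc:
    span = get_current_span()
    # get_current_span() returns a non-recording span when no trace is active
    if span.is_recording():
        span.set_status(StatusCode.ERROR, "division failed")
        span.record_exception(exc)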
@@ -1,5 +1,4 @@
import re
import sys
from collections.abc import Mapping
from typing import Any


@@ -109,11 +108,8 @@ def register_external_error_handlers(api: Api):
        data.setdefault("code", "unknown")
        data.setdefault("status", status_code)

        # Log stack
        exc_info: Any = sys.exc_info()
        if exc_info[1] is None:
            exc_info = (None, None, None)
        current_app.log_exception(exc_info)
        # Note: Exception logging is handled by Flask/Flask-RESTX framework automatically
        # Explicit log_exception call removed to avoid duplicate log entries

        return data, status_code
@@ -143,7 +143,7 @@ def test_host_header_preservation_with_user_header(mock_get_client):
    mock_client.build_request.assert_called_once()
    # Verify the Host header was set on the request object
    assert mock_request.headers.get("Host") == custom_host
    mock_client.send.assert_called_once_with(mock_request)
    mock_client.send.assert_called_once_with(mock_request, follow_redirects=False)


@patch("core.helper.ssrf_proxy._get_ssrf_client")

@@ -160,3 +160,101 @@ def test_host_header_preservation_case_insensitive(mock_get_client, host_key):
    mock_get_client.return_value = mock_client
    response = make_request("GET", "http://example.com", headers={host_key: "api.example.com"})
    assert mock_request.headers.get("Host") == "api.example.com"


class TestFollowRedirectsParameter:
    """Tests for follow_redirects parameter handling.

    These tests verify that follow_redirects is passed to send(), not build_request().
    This is critical because httpx.Client.build_request() does not accept follow_redirects.
    """

    @patch("core.helper.ssrf_proxy._get_ssrf_client")
    def test_follow_redirects_not_passed_to_build_request(self, mock_get_client):
        """Verify follow_redirects is NOT passed to build_request()."""
        mock_client = MagicMock()
        mock_request = MagicMock()
        mock_request.headers = {}
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_client.send.return_value = mock_response
        mock_client.build_request.return_value = mock_request
        mock_get_client.return_value = mock_client

        make_request("GET", "http://example.com", follow_redirects=True)

        # Verify follow_redirects was NOT passed to build_request
        call_kwargs = mock_client.build_request.call_args.kwargs
        assert "follow_redirects" not in call_kwargs, "follow_redirects should not be passed to build_request()"

    @patch("core.helper.ssrf_proxy._get_ssrf_client")
    def test_follow_redirects_passed_to_send(self, mock_get_client):
        """Verify follow_redirects IS passed to send()."""
        mock_client = MagicMock()
        mock_request = MagicMock()
        mock_request.headers = {}
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_client.send.return_value = mock_response
        mock_client.build_request.return_value = mock_request
        mock_get_client.return_value = mock_client

        make_request("GET", "http://example.com", follow_redirects=True)

        # Verify follow_redirects WAS passed to send
        mock_client.send.assert_called_once_with(mock_request, follow_redirects=True)

    @patch("core.helper.ssrf_proxy._get_ssrf_client")
    def test_allow_redirects_converted_to_follow_redirects(self, mock_get_client):
        """Verify allow_redirects (requests-style) is converted to follow_redirects (httpx-style)."""
        mock_client = MagicMock()
        mock_request = MagicMock()
        mock_request.headers = {}
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_client.send.return_value = mock_response
        mock_client.build_request.return_value = mock_request
        mock_get_client.return_value = mock_client

        # Use allow_redirects (requests-style parameter)
        make_request("GET", "http://example.com", allow_redirects=True)

        # Verify it was converted to follow_redirects for send()
        mock_client.send.assert_called_once_with(mock_request, follow_redirects=True)
        # Verify allow_redirects was NOT passed to build_request
        call_kwargs = mock_client.build_request.call_args.kwargs
        assert "allow_redirects" not in call_kwargs

    @patch("core.helper.ssrf_proxy._get_ssrf_client")
    def test_follow_redirects_default_is_false(self, mock_get_client):
        """Verify follow_redirects defaults to False when not specified."""
        mock_client = MagicMock()
        mock_request = MagicMock()
        mock_request.headers = {}
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_client.send.return_value = mock_response
        mock_client.build_request.return_value = mock_request
        mock_get_client.return_value = mock_client

        make_request("GET", "http://example.com")

        # Verify default is False
        mock_client.send.assert_called_once_with(mock_request, follow_redirects=False)

    @patch("core.helper.ssrf_proxy._get_ssrf_client")
    def test_follow_redirects_takes_precedence_over_allow_redirects(self, mock_get_client):
        """Verify follow_redirects takes precedence when both are specified."""
        mock_client = MagicMock()
        mock_request = MagicMock()
        mock_request.headers = {}
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_client.send.return_value = mock_response
        mock_client.build_request.return_value = mock_request
        mock_get_client.return_value = mock_client

        # Both specified - follow_redirects should take precedence
        make_request("GET", "http://example.com", allow_redirects=False, follow_redirects=True)

        mock_client.send.assert_called_once_with(mock_request, follow_redirects=True)
@@ -0,0 +1,134 @@
"""Tests for logging filters."""

import logging
from unittest import mock


class TestTraceContextFilter:
    def test_sets_empty_trace_id_without_context(self):
        from core.logging.filters import TraceContextFilter

        filter = TraceContextFilter()
        record = logging.LogRecord(
            name="test",
            level=logging.INFO,
            pathname="",
            lineno=0,
            msg="test",
            args=(),
            exc_info=None,
        )

        result = filter.filter(record)

        assert result is True
        assert hasattr(record, "trace_id")
        assert hasattr(record, "span_id")
        assert hasattr(record, "req_id")

    def test_filter_always_returns_true(self):
        from core.logging.filters import TraceContextFilter

        filter = TraceContextFilter()
        record = logging.LogRecord(
            name="test",
            level=logging.INFO,
            pathname="",
            lineno=0,
            msg="test",
            args=(),
            exc_info=None,
        )

        result = filter.filter(record)
        assert result is True

    def test_sets_trace_id_from_otel_when_available(self):
        from core.logging.filters import TraceContextFilter

        mock_span = mock.MagicMock()
        mock_context = mock.MagicMock()
        mock_context.trace_id = 0x5B8AA5A2D2C872E8321CF37308D69DF2
        mock_context.span_id = 0x051581BF3BB55C45
        mock_span.get_span_context.return_value = mock_context

        with mock.patch("opentelemetry.trace.get_current_span", return_value=mock_span):
            with (
                mock.patch("opentelemetry.trace.span.INVALID_TRACE_ID", 0),
                mock.patch("opentelemetry.trace.span.INVALID_SPAN_ID", 0),
            ):
                filter = TraceContextFilter()
                record = logging.LogRecord(
                    name="test",
                    level=logging.INFO,
                    pathname="",
                    lineno=0,
                    msg="test",
                    args=(),
                    exc_info=None,
                )

                filter.filter(record)

                assert record.trace_id == "5b8aa5a2d2c872e8321cf37308d69df2"
                assert record.span_id == "051581bf3bb55c45"


class TestIdentityContextFilter:
    def test_sets_empty_identity_without_request_context(self):
        from core.logging.filters import IdentityContextFilter

        filter = IdentityContextFilter()
        record = logging.LogRecord(
            name="test",
            level=logging.INFO,
            pathname="",
            lineno=0,
            msg="test",
            args=(),
            exc_info=None,
        )

        result = filter.filter(record)

        assert result is True
        assert record.tenant_id == ""
        assert record.user_id == ""
        assert record.user_type == ""

    def test_filter_always_returns_true(self):
        from core.logging.filters import IdentityContextFilter

        filter = IdentityContextFilter()
        record = logging.LogRecord(
            name="test",
            level=logging.INFO,
            pathname="",
            lineno=0,
            msg="test",
            args=(),
            exc_info=None,
        )

        result = filter.filter(record)
        assert result is True

    def test_handles_exception_gracefully(self):
        from core.logging.filters import IdentityContextFilter

        filter = IdentityContextFilter()
        record = logging.LogRecord(
            name="test",
            level=logging.INFO,
            pathname="",
            lineno=0,
            msg="test",
            args=(),
            exc_info=None,
        )

        # Should not raise even if something goes wrong
        with mock.patch("core.logging.filters.flask.has_request_context", side_effect=Exception("Test error")):
            result = filter.filter(record)
            assert result is True
            assert record.tenant_id == ""
@@ -0,0 +1,240 @@
"""Tests for structured JSON formatter."""

import json
import logging
import sys


class TestStructuredJSONFormatter:
    def test_basic_log_format(self):
        from core.logging.structured_formatter import StructuredJSONFormatter

        formatter = StructuredJSONFormatter(service_name="test-service")
        record = logging.LogRecord(
            name="test",
            level=logging.INFO,
            pathname="test.py",
            lineno=42,
            msg="Test message",
            args=(),
            exc_info=None,
        )

        output = formatter.format(record)
        log_dict = json.loads(output)

        assert log_dict["severity"] == "INFO"
        assert log_dict["service"] == "test-service"
        assert log_dict["caller"] == "test.py:42"
        assert log_dict["message"] == "Test message"
        assert "ts" in log_dict
        assert log_dict["ts"].endswith("Z")

    def test_severity_mapping(self):
        from core.logging.structured_formatter import StructuredJSONFormatter

        formatter = StructuredJSONFormatter()

        test_cases = [
            (logging.DEBUG, "DEBUG"),
            (logging.INFO, "INFO"),
            (logging.WARNING, "WARN"),
            (logging.ERROR, "ERROR"),
            (logging.CRITICAL, "ERROR"),
        ]

        for level, expected_severity in test_cases:
            record = logging.LogRecord(
                name="test",
                level=level,
                pathname="test.py",
                lineno=1,
                msg="Test",
                args=(),
                exc_info=None,
            )
            output = formatter.format(record)
            log_dict = json.loads(output)
            assert log_dict["severity"] == expected_severity, f"Level {level} should map to {expected_severity}"

    def test_error_with_stack_trace(self):
        from core.logging.structured_formatter import StructuredJSONFormatter

        formatter = StructuredJSONFormatter()

        try:
            raise ValueError("Test error")
        except ValueError:
            exc_info = sys.exc_info()

        record = logging.LogRecord(
            name="test",
            level=logging.ERROR,
            pathname="test.py",
            lineno=10,
            msg="Error occurred",
            args=(),
            exc_info=exc_info,
        )

        output = formatter.format(record)
        log_dict = json.loads(output)

        assert log_dict["severity"] == "ERROR"
        assert "stack_trace" in log_dict
        assert "ValueError: Test error" in log_dict["stack_trace"]

    def test_no_stack_trace_for_info(self):
        from core.logging.structured_formatter import StructuredJSONFormatter

        formatter = StructuredJSONFormatter()

        try:
            raise ValueError("Test error")
        except ValueError:
            exc_info = sys.exc_info()

        record = logging.LogRecord(
            name="test",
            level=logging.INFO,
            pathname="test.py",
            lineno=10,
            msg="Info message",
            args=(),
            exc_info=exc_info,
        )

        output = formatter.format(record)
        log_dict = json.loads(output)

        assert "stack_trace" not in log_dict

    def test_trace_context_included(self):
        from core.logging.structured_formatter import StructuredJSONFormatter

        formatter = StructuredJSONFormatter()
        record = logging.LogRecord(
            name="test",
            level=logging.INFO,
            pathname="test.py",
            lineno=1,
            msg="Test",
            args=(),
            exc_info=None,
        )
        record.trace_id = "5b8aa5a2d2c872e8321cf37308d69df2"
        record.span_id = "051581bf3bb55c45"

        output = formatter.format(record)
        log_dict = json.loads(output)

        assert log_dict["trace_id"] == "5b8aa5a2d2c872e8321cf37308d69df2"
        assert log_dict["span_id"] == "051581bf3bb55c45"

    def test_identity_context_included(self):
        from core.logging.structured_formatter import StructuredJSONFormatter

        formatter = StructuredJSONFormatter()
        record = logging.LogRecord(
            name="test",
            level=logging.INFO,
            pathname="test.py",
            lineno=1,
            msg="Test",
            args=(),
            exc_info=None,
        )
        record.tenant_id = "t-global-corp"
        record.user_id = "u-admin-007"
        record.user_type = "admin"

        output = formatter.format(record)
        log_dict = json.loads(output)

        assert "identity" in log_dict
        assert log_dict["identity"]["tenant_id"] == "t-global-corp"
        assert log_dict["identity"]["user_id"] == "u-admin-007"
        assert log_dict["identity"]["user_type"] == "admin"

    def test_no_identity_when_empty(self):
        from core.logging.structured_formatter import StructuredJSONFormatter

        formatter = StructuredJSONFormatter()
        record = logging.LogRecord(
            name="test",
            level=logging.INFO,
            pathname="test.py",
            lineno=1,
            msg="Test",
            args=(),
            exc_info=None,
        )

        output = formatter.format(record)
        log_dict = json.loads(output)

        assert "identity" not in log_dict

    def test_attributes_included(self):
        from core.logging.structured_formatter import StructuredJSONFormatter

        formatter = StructuredJSONFormatter()
        record = logging.LogRecord(
            name="test",
            level=logging.INFO,
            pathname="test.py",
            lineno=1,
            msg="Test",
            args=(),
            exc_info=None,
        )
        record.attributes = {"order_id": "ord-123", "amount": 99.99}

        output = formatter.format(record)
        log_dict = json.loads(output)

        assert log_dict["attributes"]["order_id"] == "ord-123"
        assert log_dict["attributes"]["amount"] == 99.99

    def test_message_with_args(self):
        from core.logging.structured_formatter import StructuredJSONFormatter

        formatter = StructuredJSONFormatter()
        record = logging.LogRecord(
            name="test",
            level=logging.INFO,
            pathname="test.py",
            lineno=1,
            msg="User %s logged in from %s",
            args=("john", "192.168.1.1"),
            exc_info=None,
        )

        output = formatter.format(record)
        log_dict = json.loads(output)

        assert log_dict["message"] == "User john logged in from 192.168.1.1"

    def test_timestamp_format(self):
        from core.logging.structured_formatter import StructuredJSONFormatter

        formatter = StructuredJSONFormatter()
        record = logging.LogRecord(
            name="test",
            level=logging.INFO,
            pathname="test.py",
            lineno=1,
            msg="Test",
            args=(),
            exc_info=None,
        )

        output = formatter.format(record)
        log_dict = json.loads(output)

        # Verify ISO 8601 format with Z suffix
        ts = log_dict["ts"]
        assert ts.endswith("Z")
        assert "T" in ts
        # Should have milliseconds
        assert "." in ts
@@ -0,0 +1,102 @@
"""Tests for trace helper functions."""

import re
from unittest import mock


class TestGetSpanIdFromOtelContext:
    def test_returns_none_without_span(self):
        from core.helper.trace_id_helper import get_span_id_from_otel_context

        with mock.patch("opentelemetry.trace.get_current_span", return_value=None):
            result = get_span_id_from_otel_context()
            assert result is None

    def test_returns_span_id_when_available(self):
        from core.helper.trace_id_helper import get_span_id_from_otel_context

        mock_span = mock.MagicMock()
        mock_context = mock.MagicMock()
        mock_context.span_id = 0x051581BF3BB55C45
        mock_span.get_span_context.return_value = mock_context

        with mock.patch("opentelemetry.trace.get_current_span", return_value=mock_span):
            with mock.patch("opentelemetry.trace.span.INVALID_SPAN_ID", 0):
                result = get_span_id_from_otel_context()
                assert result == "051581bf3bb55c45"

    def test_returns_none_on_exception(self):
        from core.helper.trace_id_helper import get_span_id_from_otel_context

        with mock.patch("opentelemetry.trace.get_current_span", side_effect=Exception("Test error")):
            result = get_span_id_from_otel_context()
            assert result is None


class TestGenerateTraceparentHeader:
    def test_generates_valid_format(self):
        from core.helper.trace_id_helper import generate_traceparent_header

        with mock.patch("opentelemetry.trace.get_current_span", return_value=None):
            result = generate_traceparent_header()

        assert result is not None
        # Format: 00-{trace_id}-{span_id}-01
        parts = result.split("-")
        assert len(parts) == 4
        assert parts[0] == "00"  # version
        assert len(parts[1]) == 32  # trace_id (32 hex chars)
        assert len(parts[2]) == 16  # span_id (16 hex chars)
        assert parts[3] == "01"  # flags

    def test_uses_otel_context_when_available(self):
        from core.helper.trace_id_helper import generate_traceparent_header

        mock_span = mock.MagicMock()
        mock_context = mock.MagicMock()
        mock_context.trace_id = 0x5B8AA5A2D2C872E8321CF37308D69DF2
        mock_context.span_id = 0x051581BF3BB55C45
        mock_span.get_span_context.return_value = mock_context

        with mock.patch("opentelemetry.trace.get_current_span", return_value=mock_span):
            with (
                mock.patch("opentelemetry.trace.span.INVALID_TRACE_ID", 0),
                mock.patch("opentelemetry.trace.span.INVALID_SPAN_ID", 0),
            ):
                result = generate_traceparent_header()

        assert result == "00-5b8aa5a2d2c872e8321cf37308d69df2-051581bf3bb55c45-01"

    def test_generates_hex_only_values(self):
        from core.helper.trace_id_helper import generate_traceparent_header

        with mock.patch("opentelemetry.trace.get_current_span", return_value=None):
            result = generate_traceparent_header()

        parts = result.split("-")
        # All parts should be valid hex
        assert re.match(r"^[0-9a-f]+$", parts[1])
        assert re.match(r"^[0-9a-f]+$", parts[2])


class TestParseTraceparentHeader:
    def test_parses_valid_traceparent(self):
        from core.helper.trace_id_helper import parse_traceparent_header

        traceparent = "00-5b8aa5a2d2c872e8321cf37308d69df2-051581bf3bb55c45-01"
        result = parse_traceparent_header(traceparent)

        assert result == "5b8aa5a2d2c872e8321cf37308d69df2"

    def test_returns_none_for_invalid_format(self):
        from core.helper.trace_id_helper import parse_traceparent_header

        # Wrong number of parts
        assert parse_traceparent_header("00-abc-def") is None
        # Wrong trace_id length
        assert parse_traceparent_header("00-abc-def-01") is None

    def test_returns_none_for_empty_string(self):
        from core.helper.trace_id_helper import parse_traceparent_header

        assert parse_traceparent_header("") is None
@@ -99,29 +99,20 @@ def test_external_api_json_message_and_bad_request_rewrite():
    assert res.get_json()["message"] == "Invalid JSON payload received or JSON payload is empty."


def test_external_api_param_mapping_and_quota_and_exc_info_none():
    # Force exc_info() to return (None,None,None) only during request
    import libs.external_api as ext
def test_external_api_param_mapping_and_quota():
    app = _create_api_app()
    client = app.test_client()

    orig_exc_info = ext.sys.exc_info
    try:
        ext.sys.exc_info = lambda: (None, None, None)
    # Param errors mapping payload path
    res = client.get("/api/param-errors")
    assert res.status_code == 400
    data = res.get_json()
    assert data["code"] == "invalid_param"
    assert data["params"] == "field"

        app = _create_api_app()
        client = app.test_client()

        # Param errors mapping payload path
        res = client.get("/api/param-errors")
        assert res.status_code == 400
        data = res.get_json()
        assert data["code"] == "invalid_param"
        assert data["params"] == "field"

        # Quota path — depending on Flask-RESTX internals it may be handled
        res = client.get("/api/quota")
        assert res.status_code in (400, 429)
    finally:
        ext.sys.exc_info = orig_exc_info  # type: ignore[assignment]
    # Quota path — depending on Flask-RESTX internals it may be handled
    res = client.get("/api/quota")
    assert res.status_code in (400, 429)


def test_unauthorized_and_force_logout_clears_cookies():