mirror of https://github.com/langgenius/dify.git
feat(telemetry): implement gateway routing and enqueue logic
Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-Claude) Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
This commit is contained in:
parent
752b01ae91
commit
51b0c5c89c
|
|
@ -168,3 +168,84 @@ basedpyright <files> # 0 erro
|
|||
Task 3 (gateway implementation) will wire `TelemetryGateway.emit()` to call `process_enterprise_telemetry.delay()`.
|
||||
Once Task 3 completes, handlers can optionally be updated to call gateway directly instead of task.
|
||||
|
||||
## [2026-02-06] Task 3: TelemetryGateway Implementation
|
||||
|
||||
### Implementation Decisions
|
||||
|
||||
**Gateway Architecture:**
|
||||
- `TelemetryGateway.emit(case, context, payload, trace_manager)` as main entry point
|
||||
- Routes based on `CASE_ROUTING[case].signal_type`:
|
||||
- `trace` → TraceQueueManager.add_trace_task()
|
||||
- `metric_log` → process_enterprise_telemetry.delay()
|
||||
- Feature flag `ENTERPRISE_TELEMETRY_GATEWAY_ENABLED` gates new vs legacy behavior
|
||||
|
||||
**CE Eligibility:**
|
||||
- Trace cases with `ce_eligible=False` dropped when enterprise disabled
|
||||
- CE-eligible traces (WORKFLOW_RUN, MESSAGE_RUN) always processed
|
||||
- Enterprise-only traces (NODE_EXECUTION, DRAFT_NODE_EXECUTION, PROMPT_GENERATION) require EE
|
||||
|
||||
**Payload Sizing:**
|
||||
- Threshold: 1MB (PAYLOAD_SIZE_THRESHOLD_BYTES)
|
||||
- Large payloads stored to `telemetry/{tenant_id}/{event_id}.json`
|
||||
- Storage failures fall back to inline payload (best-effort)
|
||||
|
||||
**Legacy Path:**
|
||||
- When gateway disabled, mimics original TelemetryFacade behavior
|
||||
- Only processes trace cases, metric/log cases dropped
|
||||
- Imports deferred until after route check to avoid circular imports
|
||||
|
||||
### Testing Patterns
|
||||
|
||||
**Circular Import Workaround:**
|
||||
- `ops_trace_manager` has deep import chain causing circular imports in tests
|
||||
- Solution: `mock_ops_trace_manager` fixture patches `sys.modules` before import
|
||||
- Trace routing tests require fixture; metric/log tests don't
|
||||
|
||||
**Mock Parameter Naming:**
|
||||
- Prefixed unused mock params with `_` (e.g., `_mock_ee_enabled`)
|
||||
- Ruff PT019 warnings are style hints, not errors
|
||||
|
||||
**Coverage:**
|
||||
- 38 tests total
|
||||
- Feature flag on/off paths
|
||||
- Trace routing (CE-eligible and enterprise-only)
|
||||
- Metric/log routing
|
||||
- Large payload storage and fallback
|
||||
- Legacy path behavior
|
||||
|
||||
### Files Modified/Created
|
||||
|
||||
- `enterprise/telemetry/gateway.py` (350 lines, expanded from 27)
|
||||
- Added TelemetryGateway class
|
||||
- Added CASE_TO_TRACE_TASK_NAME mapping
|
||||
- Added is_gateway_enabled() and _is_enterprise_telemetry_enabled()
|
||||
- Added module-level emit() convenience function
|
||||
- `tests/unit_tests/enterprise/telemetry/test_gateway.py` (NEW, 422 lines, 38 tests)
|
||||
|
||||
### Verification
|
||||
|
||||
```bash
|
||||
pytest tests/unit_tests/enterprise/telemetry/test_gateway.py -v # 38 passed
|
||||
ruff check --fix <files> # 22 PT019 warnings (style)
|
||||
basedpyright <files> # 0 errors, 0 warnings
|
||||
```
|
||||
|
||||
### Key Insights
|
||||
|
||||
**Import Deferral:**
|
||||
- Legacy path must defer `ops_trace_manager` import until after route check
|
||||
- Otherwise metric/log cases trigger circular import chain
|
||||
- Pattern: check signal_type first, then import if needed
|
||||
|
||||
**TelemetryCase to TraceTaskName Mapping:**
|
||||
- WORKFLOW_RUN → "workflow"
|
||||
- MESSAGE_RUN → "message"
|
||||
- NODE_EXECUTION → "node_execution"
|
||||
- DRAFT_NODE_EXECUTION → "draft_node_execution"
|
||||
- PROMPT_GENERATION → "prompt_generation"
|
||||
|
||||
**Feature Flag Design:**
|
||||
- Default OFF for safe rollout
|
||||
- Env var: ENTERPRISE_TELEMETRY_GATEWAY_ENABLED
|
||||
- Truthy values: "true", "1", "yes" (case-insensitive)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,12 +1,36 @@
|
|||
"""Telemetry gateway routing configuration.
|
||||
"""Telemetry gateway routing configuration and implementation.
|
||||
|
||||
This module defines the routing table that maps telemetry cases to their
|
||||
processing routes (trace vs metric/log) and customer engagement eligibility.
|
||||
It also provides the TelemetryGateway class that routes events to the
|
||||
appropriate processing path.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from enterprise.telemetry.contracts import CaseRoute, TelemetryCase
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from enterprise.telemetry.contracts import CaseRoute, TelemetryCase, TelemetryEnvelope
|
||||
from extensions.ext_storage import storage
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
PAYLOAD_SIZE_THRESHOLD_BYTES = 1 * 1024 * 1024
|
||||
|
||||
CASE_TO_TRACE_TASK_NAME: dict[TelemetryCase, str] = {
|
||||
TelemetryCase.WORKFLOW_RUN: "workflow",
|
||||
TelemetryCase.MESSAGE_RUN: "message",
|
||||
TelemetryCase.NODE_EXECUTION: "node_execution",
|
||||
TelemetryCase.DRAFT_NODE_EXECUTION: "draft_node_execution",
|
||||
TelemetryCase.PROMPT_GENERATION: "prompt_generation",
|
||||
}
|
||||
|
||||
CASE_ROUTING: dict[TelemetryCase, CaseRoute] = {
|
||||
TelemetryCase.WORKFLOW_RUN: CaseRoute(signal_type="trace", ce_eligible=True),
|
||||
|
|
@ -24,3 +48,286 @@ CASE_ROUTING: dict[TelemetryCase, CaseRoute] = {
|
|||
TelemetryCase.DATASET_RETRIEVAL: CaseRoute(signal_type="metric_log", ce_eligible=False),
|
||||
TelemetryCase.GENERATE_NAME: CaseRoute(signal_type="metric_log", ce_eligible=False),
|
||||
}
|
||||
|
||||
|
||||
def is_gateway_enabled() -> bool:
|
||||
"""Check if the telemetry gateway is enabled via feature flag.
|
||||
|
||||
Returns:
|
||||
True if ENTERPRISE_TELEMETRY_GATEWAY_ENABLED is set to a truthy value.
|
||||
"""
|
||||
return os.environ.get("ENTERPRISE_TELEMETRY_GATEWAY_ENABLED", "").lower() in ("true", "1", "yes")
|
||||
|
||||
|
||||
def _is_enterprise_telemetry_enabled() -> bool:
|
||||
"""Check if enterprise telemetry is enabled.
|
||||
|
||||
Wraps the check from core.telemetry to handle import failures gracefully.
|
||||
"""
|
||||
try:
|
||||
from enterprise.telemetry.exporter import is_enterprise_telemetry_enabled
|
||||
|
||||
return is_enterprise_telemetry_enabled()
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
class TelemetryGateway:
|
||||
"""Gateway for routing telemetry events to appropriate processing paths.
|
||||
|
||||
Routes trace-shaped events to TraceQueueManager and metric/log events
|
||||
to the enterprise telemetry Celery queue. Handles CE eligibility checks,
|
||||
large payload storage, and feature flag gating.
|
||||
"""
|
||||
|
||||
def emit(
|
||||
self,
|
||||
case: TelemetryCase,
|
||||
context: dict[str, Any],
|
||||
payload: dict[str, Any],
|
||||
trace_manager: TraceQueueManager | None = None,
|
||||
) -> None:
|
||||
"""Emit a telemetry event through the gateway.
|
||||
|
||||
Routes the event based on its case type:
|
||||
- trace: Routes to TraceQueueManager for existing trace pipeline
|
||||
- metric_log: Routes to enterprise telemetry Celery task
|
||||
|
||||
Args:
|
||||
case: The telemetry case type.
|
||||
context: Event context containing tenant_id, app_id, user_id.
|
||||
payload: The event payload data.
|
||||
trace_manager: Optional TraceQueueManager for trace routing.
|
||||
"""
|
||||
if not is_gateway_enabled():
|
||||
self._emit_legacy(case, context, payload, trace_manager)
|
||||
return
|
||||
|
||||
route = CASE_ROUTING.get(case)
|
||||
if route is None:
|
||||
logger.warning("Unknown telemetry case: %s, dropping event", case)
|
||||
return
|
||||
|
||||
if route.signal_type == "trace":
|
||||
self._emit_trace(case, context, payload, route, trace_manager)
|
||||
else:
|
||||
self._emit_metric_log(case, context, payload)
|
||||
|
||||
def _emit_legacy(
|
||||
self,
|
||||
case: TelemetryCase,
|
||||
context: dict[str, Any],
|
||||
payload: dict[str, Any],
|
||||
trace_manager: TraceQueueManager | None,
|
||||
) -> None:
|
||||
"""Emit using legacy path (TelemetryFacade behavior).
|
||||
|
||||
Used when gateway feature flag is disabled.
|
||||
"""
|
||||
route = CASE_ROUTING.get(case)
|
||||
if route is None or route.signal_type != "trace":
|
||||
return
|
||||
|
||||
trace_task_name_str = CASE_TO_TRACE_TASK_NAME.get(case)
|
||||
if trace_task_name_str is None:
|
||||
return
|
||||
|
||||
if not route.ce_eligible and not _is_enterprise_telemetry_enabled():
|
||||
return
|
||||
|
||||
from core.ops.entities.trace_entity import TraceTaskName
|
||||
from core.ops.ops_trace_manager import (
|
||||
TraceQueueManager as LocalTraceQueueManager,
|
||||
)
|
||||
from core.ops.ops_trace_manager import (
|
||||
TraceTask,
|
||||
)
|
||||
|
||||
try:
|
||||
trace_task_name = TraceTaskName(trace_task_name_str)
|
||||
except ValueError:
|
||||
logger.warning("Invalid trace task name: %s", trace_task_name_str)
|
||||
return
|
||||
|
||||
queue_manager = trace_manager or LocalTraceQueueManager(
|
||||
app_id=context.get("app_id"),
|
||||
user_id=context.get("user_id"),
|
||||
)
|
||||
|
||||
queue_manager.add_trace_task(
|
||||
TraceTask(
|
||||
trace_task_name,
|
||||
**payload,
|
||||
)
|
||||
)
|
||||
|
||||
def _emit_trace(
|
||||
self,
|
||||
case: TelemetryCase,
|
||||
context: dict[str, Any],
|
||||
payload: dict[str, Any],
|
||||
route: CaseRoute,
|
||||
trace_manager: TraceQueueManager | None,
|
||||
) -> None:
|
||||
"""Emit a trace-shaped event to TraceQueueManager.
|
||||
|
||||
Args:
|
||||
case: The telemetry case type.
|
||||
context: Event context.
|
||||
payload: The event payload.
|
||||
route: Routing configuration for this case.
|
||||
trace_manager: Optional TraceQueueManager.
|
||||
"""
|
||||
from core.ops.entities.trace_entity import TraceTaskName
|
||||
from core.ops.ops_trace_manager import (
|
||||
TraceQueueManager as LocalTraceQueueManager,
|
||||
)
|
||||
from core.ops.ops_trace_manager import (
|
||||
TraceTask,
|
||||
)
|
||||
|
||||
if not route.ce_eligible and not _is_enterprise_telemetry_enabled():
|
||||
logger.debug(
|
||||
"Dropping enterprise-only trace event: case=%s (EE disabled)",
|
||||
case,
|
||||
)
|
||||
return
|
||||
|
||||
trace_task_name_str = CASE_TO_TRACE_TASK_NAME.get(case)
|
||||
if trace_task_name_str is None:
|
||||
logger.warning("No TraceTaskName mapping for case: %s", case)
|
||||
return
|
||||
|
||||
try:
|
||||
trace_task_name = TraceTaskName(trace_task_name_str)
|
||||
except ValueError:
|
||||
logger.warning("Invalid trace task name: %s", trace_task_name_str)
|
||||
return
|
||||
|
||||
queue_manager = trace_manager or LocalTraceQueueManager(
|
||||
app_id=context.get("app_id"),
|
||||
user_id=context.get("user_id"),
|
||||
)
|
||||
|
||||
queue_manager.add_trace_task(
|
||||
TraceTask(
|
||||
trace_task_name,
|
||||
**payload,
|
||||
)
|
||||
)
|
||||
logger.debug(
|
||||
"Enqueued trace task: case=%s, app_id=%s",
|
||||
case,
|
||||
context.get("app_id"),
|
||||
)
|
||||
|
||||
def _emit_metric_log(
|
||||
self,
|
||||
case: TelemetryCase,
|
||||
context: dict[str, Any],
|
||||
payload: dict[str, Any],
|
||||
) -> None:
|
||||
"""Emit a metric/log event to the enterprise telemetry Celery queue.
|
||||
|
||||
Args:
|
||||
case: The telemetry case type.
|
||||
context: Event context containing tenant_id.
|
||||
payload: The event payload.
|
||||
"""
|
||||
from tasks.enterprise_telemetry_task import process_enterprise_telemetry
|
||||
|
||||
tenant_id = context.get("tenant_id", "")
|
||||
event_id = str(uuid.uuid4())
|
||||
|
||||
payload_for_envelope, payload_ref = self._handle_payload_sizing(payload, tenant_id, event_id)
|
||||
|
||||
envelope = TelemetryEnvelope(
|
||||
case=case,
|
||||
tenant_id=tenant_id,
|
||||
event_id=event_id,
|
||||
payload=payload_for_envelope,
|
||||
metadata={"payload_ref": payload_ref} if payload_ref else None,
|
||||
)
|
||||
|
||||
process_enterprise_telemetry.delay(envelope.model_dump_json())
|
||||
logger.debug(
|
||||
"Enqueued metric/log event: case=%s, tenant_id=%s, event_id=%s",
|
||||
case,
|
||||
tenant_id,
|
||||
event_id,
|
||||
)
|
||||
|
||||
def _handle_payload_sizing(
|
||||
self,
|
||||
payload: dict[str, Any],
|
||||
tenant_id: str,
|
||||
event_id: str,
|
||||
) -> tuple[dict[str, Any], str | None]:
|
||||
"""Handle large payload storage.
|
||||
|
||||
If payload exceeds threshold, stores to shared storage and returns
|
||||
a reference. Otherwise returns payload as-is.
|
||||
|
||||
Args:
|
||||
payload: The event payload.
|
||||
tenant_id: Tenant identifier for storage path.
|
||||
event_id: Event identifier for storage path.
|
||||
|
||||
Returns:
|
||||
Tuple of (payload_for_envelope, payload_ref).
|
||||
If stored, payload_for_envelope is empty and payload_ref is set.
|
||||
Otherwise, payload_for_envelope is the original payload and
|
||||
payload_ref is None.
|
||||
"""
|
||||
try:
|
||||
payload_json = json.dumps(payload)
|
||||
payload_size = len(payload_json.encode("utf-8"))
|
||||
except (TypeError, ValueError):
|
||||
logger.warning("Failed to serialize payload for sizing: event_id=%s", event_id)
|
||||
return payload, None
|
||||
|
||||
if payload_size <= PAYLOAD_SIZE_THRESHOLD_BYTES:
|
||||
return payload, None
|
||||
|
||||
storage_key = f"telemetry/{tenant_id}/{event_id}.json"
|
||||
try:
|
||||
storage.save(storage_key, payload_json.encode("utf-8"))
|
||||
logger.debug("Stored large payload to storage: key=%s, size=%d", storage_key, payload_size)
|
||||
return {}, storage_key
|
||||
except Exception:
|
||||
logger.warning("Failed to store large payload, inlining instead: event_id=%s", event_id, exc_info=True)
|
||||
return payload, None
|
||||
|
||||
|
||||
_gateway: TelemetryGateway | None = None
|
||||
|
||||
|
||||
def get_gateway() -> TelemetryGateway:
|
||||
"""Get the module-level gateway instance.
|
||||
|
||||
Returns:
|
||||
The singleton TelemetryGateway instance.
|
||||
"""
|
||||
global _gateway
|
||||
if _gateway is None:
|
||||
_gateway = TelemetryGateway()
|
||||
return _gateway
|
||||
|
||||
|
||||
def emit(
|
||||
case: TelemetryCase,
|
||||
context: dict[str, Any],
|
||||
payload: dict[str, Any],
|
||||
trace_manager: TraceQueueManager | None = None,
|
||||
) -> None:
|
||||
"""Emit a telemetry event through the gateway.
|
||||
|
||||
Convenience function that uses the module-level gateway instance.
|
||||
|
||||
Args:
|
||||
case: The telemetry case type.
|
||||
context: Event context containing tenant_id, app_id, user_id.
|
||||
payload: The event payload data.
|
||||
trace_manager: Optional TraceQueueManager for trace routing.
|
||||
"""
|
||||
get_gateway().emit(case, context, payload, trace_manager)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,435 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from enterprise.telemetry.contracts import TelemetryCase, TelemetryEnvelope
|
||||
from enterprise.telemetry.gateway import (
|
||||
CASE_ROUTING,
|
||||
CASE_TO_TRACE_TASK_NAME,
|
||||
PAYLOAD_SIZE_THRESHOLD_BYTES,
|
||||
TelemetryGateway,
|
||||
emit,
|
||||
get_gateway,
|
||||
is_gateway_enabled,
|
||||
)
|
||||
|
||||
|
||||
class TestIsGatewayEnabled:
|
||||
@pytest.mark.parametrize(
|
||||
("env_value", "expected"),
|
||||
[
|
||||
("true", True),
|
||||
("True", True),
|
||||
("TRUE", True),
|
||||
("1", True),
|
||||
("yes", True),
|
||||
("YES", True),
|
||||
("false", False),
|
||||
("False", False),
|
||||
("0", False),
|
||||
("no", False),
|
||||
("", False),
|
||||
],
|
||||
)
|
||||
def test_feature_flag_values(self, env_value: str, expected: bool) -> None:
|
||||
with patch.dict("os.environ", {"ENTERPRISE_TELEMETRY_GATEWAY_ENABLED": env_value}):
|
||||
assert is_gateway_enabled() is expected
|
||||
|
||||
def test_missing_env_var(self) -> None:
|
||||
with patch.dict("os.environ", {}, clear=True):
|
||||
assert is_gateway_enabled() is False
|
||||
|
||||
|
||||
class TestCaseRoutingTable:
|
||||
def test_all_cases_have_routing(self) -> None:
|
||||
for case in TelemetryCase:
|
||||
assert case in CASE_ROUTING, f"Missing routing for {case}"
|
||||
|
||||
def test_trace_cases(self) -> None:
|
||||
trace_cases = [
|
||||
TelemetryCase.WORKFLOW_RUN,
|
||||
TelemetryCase.MESSAGE_RUN,
|
||||
TelemetryCase.NODE_EXECUTION,
|
||||
TelemetryCase.DRAFT_NODE_EXECUTION,
|
||||
TelemetryCase.PROMPT_GENERATION,
|
||||
]
|
||||
for case in trace_cases:
|
||||
assert CASE_ROUTING[case].signal_type == "trace", f"{case} should be trace"
|
||||
|
||||
def test_metric_log_cases(self) -> None:
|
||||
metric_log_cases = [
|
||||
TelemetryCase.APP_CREATED,
|
||||
TelemetryCase.APP_UPDATED,
|
||||
TelemetryCase.APP_DELETED,
|
||||
TelemetryCase.FEEDBACK_CREATED,
|
||||
TelemetryCase.TOOL_EXECUTION,
|
||||
TelemetryCase.MODERATION_CHECK,
|
||||
TelemetryCase.SUGGESTED_QUESTION,
|
||||
TelemetryCase.DATASET_RETRIEVAL,
|
||||
TelemetryCase.GENERATE_NAME,
|
||||
]
|
||||
for case in metric_log_cases:
|
||||
assert CASE_ROUTING[case].signal_type == "metric_log", f"{case} should be metric_log"
|
||||
|
||||
def test_ce_eligible_cases(self) -> None:
|
||||
ce_eligible_cases = [TelemetryCase.WORKFLOW_RUN, TelemetryCase.MESSAGE_RUN]
|
||||
for case in ce_eligible_cases:
|
||||
assert CASE_ROUTING[case].ce_eligible is True, f"{case} should be CE eligible"
|
||||
|
||||
def test_enterprise_only_cases(self) -> None:
|
||||
enterprise_only_cases = [
|
||||
TelemetryCase.NODE_EXECUTION,
|
||||
TelemetryCase.DRAFT_NODE_EXECUTION,
|
||||
TelemetryCase.PROMPT_GENERATION,
|
||||
]
|
||||
for case in enterprise_only_cases:
|
||||
assert CASE_ROUTING[case].ce_eligible is False, f"{case} should be enterprise-only"
|
||||
|
||||
def test_trace_cases_have_task_name_mapping(self) -> None:
|
||||
trace_cases = [c for c in TelemetryCase if CASE_ROUTING[c].signal_type == "trace"]
|
||||
for case in trace_cases:
|
||||
assert case in CASE_TO_TRACE_TASK_NAME, f"Missing TraceTaskName mapping for {case}"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_ops_trace_manager():
|
||||
mock_module = MagicMock()
|
||||
mock_trace_task_class = MagicMock()
|
||||
mock_trace_task_class.return_value = MagicMock()
|
||||
mock_module.TraceTask = mock_trace_task_class
|
||||
mock_module.TraceQueueManager = MagicMock()
|
||||
|
||||
mock_trace_entity = MagicMock()
|
||||
mock_trace_task_name = MagicMock()
|
||||
mock_trace_task_name.return_value = "workflow"
|
||||
mock_trace_entity.TraceTaskName = mock_trace_task_name
|
||||
|
||||
with (
|
||||
patch.dict(sys.modules, {"core.ops.ops_trace_manager": mock_module}),
|
||||
patch.dict(sys.modules, {"core.ops.entities.trace_entity": mock_trace_entity}),
|
||||
):
|
||||
yield mock_module, mock_trace_entity
|
||||
|
||||
|
||||
class TestTelemetryGatewayTraceRouting:
|
||||
@pytest.fixture
|
||||
def gateway(self) -> TelemetryGateway:
|
||||
return TelemetryGateway()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_trace_manager(self) -> MagicMock:
|
||||
return MagicMock()
|
||||
|
||||
@patch("enterprise.telemetry.gateway.is_gateway_enabled", return_value=True)
|
||||
@patch("enterprise.telemetry.gateway._is_enterprise_telemetry_enabled", return_value=True)
|
||||
def test_trace_case_routes_to_trace_manager(
|
||||
self,
|
||||
_mock_ee_enabled: MagicMock,
|
||||
_mock_gateway_enabled: MagicMock,
|
||||
gateway: TelemetryGateway,
|
||||
mock_trace_manager: MagicMock,
|
||||
mock_ops_trace_manager: tuple[MagicMock, MagicMock],
|
||||
) -> None:
|
||||
context = {"app_id": "app-123", "user_id": "user-456", "tenant_id": "tenant-789"}
|
||||
payload = {"workflow_run_id": "run-abc"}
|
||||
|
||||
gateway.emit(TelemetryCase.WORKFLOW_RUN, context, payload, mock_trace_manager)
|
||||
|
||||
mock_trace_manager.add_trace_task.assert_called_once()
|
||||
|
||||
@patch("enterprise.telemetry.gateway.is_gateway_enabled", return_value=True)
|
||||
@patch("enterprise.telemetry.gateway._is_enterprise_telemetry_enabled", return_value=False)
|
||||
def test_ce_eligible_trace_enqueued_when_ee_disabled(
|
||||
self,
|
||||
_mock_ee_enabled: MagicMock,
|
||||
_mock_gateway_enabled: MagicMock,
|
||||
gateway: TelemetryGateway,
|
||||
mock_trace_manager: MagicMock,
|
||||
mock_ops_trace_manager: tuple[MagicMock, MagicMock],
|
||||
) -> None:
|
||||
context = {"app_id": "app-123", "user_id": "user-456"}
|
||||
payload = {"workflow_run_id": "run-abc"}
|
||||
|
||||
gateway.emit(TelemetryCase.WORKFLOW_RUN, context, payload, mock_trace_manager)
|
||||
|
||||
mock_trace_manager.add_trace_task.assert_called_once()
|
||||
|
||||
@patch("enterprise.telemetry.gateway.is_gateway_enabled", return_value=True)
|
||||
@patch("enterprise.telemetry.gateway._is_enterprise_telemetry_enabled", return_value=False)
|
||||
def test_enterprise_only_trace_dropped_when_ee_disabled(
|
||||
self,
|
||||
_mock_ee_enabled: MagicMock,
|
||||
_mock_gateway_enabled: MagicMock,
|
||||
gateway: TelemetryGateway,
|
||||
mock_trace_manager: MagicMock,
|
||||
mock_ops_trace_manager: tuple[MagicMock, MagicMock],
|
||||
) -> None:
|
||||
context = {"app_id": "app-123", "user_id": "user-456"}
|
||||
payload = {"node_id": "node-abc"}
|
||||
|
||||
gateway.emit(TelemetryCase.NODE_EXECUTION, context, payload, mock_trace_manager)
|
||||
|
||||
mock_trace_manager.add_trace_task.assert_not_called()
|
||||
|
||||
@patch("enterprise.telemetry.gateway.is_gateway_enabled", return_value=True)
|
||||
@patch("enterprise.telemetry.gateway._is_enterprise_telemetry_enabled", return_value=True)
|
||||
def test_enterprise_only_trace_enqueued_when_ee_enabled(
|
||||
self,
|
||||
_mock_ee_enabled: MagicMock,
|
||||
_mock_gateway_enabled: MagicMock,
|
||||
gateway: TelemetryGateway,
|
||||
mock_trace_manager: MagicMock,
|
||||
mock_ops_trace_manager: tuple[MagicMock, MagicMock],
|
||||
) -> None:
|
||||
context = {"app_id": "app-123", "user_id": "user-456"}
|
||||
payload = {"node_id": "node-abc"}
|
||||
|
||||
gateway.emit(TelemetryCase.NODE_EXECUTION, context, payload, mock_trace_manager)
|
||||
|
||||
mock_trace_manager.add_trace_task.assert_called_once()
|
||||
|
||||
|
||||
class TestTelemetryGatewayMetricLogRouting:
|
||||
@pytest.fixture
|
||||
def gateway(self) -> TelemetryGateway:
|
||||
return TelemetryGateway()
|
||||
|
||||
@patch("enterprise.telemetry.gateway.is_gateway_enabled", return_value=True)
|
||||
@patch("tasks.enterprise_telemetry_task.process_enterprise_telemetry.delay")
|
||||
def test_metric_case_routes_to_celery_task(
|
||||
self,
|
||||
mock_delay: MagicMock,
|
||||
_mock_gateway_enabled: MagicMock,
|
||||
gateway: TelemetryGateway,
|
||||
) -> None:
|
||||
context = {"tenant_id": "tenant-123"}
|
||||
payload = {"app_id": "app-abc", "name": "My App"}
|
||||
|
||||
gateway.emit(TelemetryCase.APP_CREATED, context, payload)
|
||||
|
||||
mock_delay.assert_called_once()
|
||||
envelope_json = mock_delay.call_args[0][0]
|
||||
envelope = TelemetryEnvelope.model_validate_json(envelope_json)
|
||||
assert envelope.case == TelemetryCase.APP_CREATED
|
||||
assert envelope.tenant_id == "tenant-123"
|
||||
assert envelope.payload["app_id"] == "app-abc"
|
||||
|
||||
@patch("enterprise.telemetry.gateway.is_gateway_enabled", return_value=True)
|
||||
@patch("tasks.enterprise_telemetry_task.process_enterprise_telemetry.delay")
|
||||
def test_envelope_has_unique_event_id(
|
||||
self,
|
||||
mock_delay: MagicMock,
|
||||
_mock_gateway_enabled: MagicMock,
|
||||
gateway: TelemetryGateway,
|
||||
) -> None:
|
||||
context = {"tenant_id": "tenant-123"}
|
||||
payload = {"app_id": "app-abc"}
|
||||
|
||||
gateway.emit(TelemetryCase.APP_CREATED, context, payload)
|
||||
gateway.emit(TelemetryCase.APP_CREATED, context, payload)
|
||||
|
||||
assert mock_delay.call_count == 2
|
||||
envelope1 = TelemetryEnvelope.model_validate_json(mock_delay.call_args_list[0][0][0])
|
||||
envelope2 = TelemetryEnvelope.model_validate_json(mock_delay.call_args_list[1][0][0])
|
||||
assert envelope1.event_id != envelope2.event_id
|
||||
|
||||
|
||||
class TestTelemetryGatewayPayloadSizing:
|
||||
@pytest.fixture
|
||||
def gateway(self) -> TelemetryGateway:
|
||||
return TelemetryGateway()
|
||||
|
||||
@patch("enterprise.telemetry.gateway.is_gateway_enabled", return_value=True)
|
||||
@patch("tasks.enterprise_telemetry_task.process_enterprise_telemetry.delay")
|
||||
def test_small_payload_inlined(
|
||||
self,
|
||||
mock_delay: MagicMock,
|
||||
_mock_gateway_enabled: MagicMock,
|
||||
gateway: TelemetryGateway,
|
||||
) -> None:
|
||||
context = {"tenant_id": "tenant-123"}
|
||||
payload = {"key": "small_value"}
|
||||
|
||||
gateway.emit(TelemetryCase.APP_CREATED, context, payload)
|
||||
|
||||
envelope_json = mock_delay.call_args[0][0]
|
||||
envelope = TelemetryEnvelope.model_validate_json(envelope_json)
|
||||
assert envelope.payload == payload
|
||||
assert envelope.metadata is None
|
||||
|
||||
@patch("enterprise.telemetry.gateway.is_gateway_enabled", return_value=True)
|
||||
@patch("enterprise.telemetry.gateway.storage")
|
||||
@patch("tasks.enterprise_telemetry_task.process_enterprise_telemetry.delay")
|
||||
def test_large_payload_stored(
|
||||
self,
|
||||
mock_delay: MagicMock,
|
||||
mock_storage: MagicMock,
|
||||
_mock_gateway_enabled: MagicMock,
|
||||
gateway: TelemetryGateway,
|
||||
) -> None:
|
||||
context = {"tenant_id": "tenant-123"}
|
||||
large_value = "x" * (PAYLOAD_SIZE_THRESHOLD_BYTES + 1000)
|
||||
payload = {"key": large_value}
|
||||
|
||||
gateway.emit(TelemetryCase.APP_CREATED, context, payload)
|
||||
|
||||
mock_storage.save.assert_called_once()
|
||||
storage_key = mock_storage.save.call_args[0][0]
|
||||
assert storage_key.startswith("telemetry/tenant-123/")
|
||||
|
||||
envelope_json = mock_delay.call_args[0][0]
|
||||
envelope = TelemetryEnvelope.model_validate_json(envelope_json)
|
||||
assert envelope.payload == {}
|
||||
assert envelope.metadata is not None
|
||||
assert envelope.metadata["payload_ref"] == storage_key
|
||||
|
||||
@patch("enterprise.telemetry.gateway.is_gateway_enabled", return_value=True)
|
||||
@patch("enterprise.telemetry.gateway.storage")
|
||||
@patch("tasks.enterprise_telemetry_task.process_enterprise_telemetry.delay")
|
||||
def test_large_payload_fallback_on_storage_error(
|
||||
self,
|
||||
mock_delay: MagicMock,
|
||||
mock_storage: MagicMock,
|
||||
_mock_gateway_enabled: MagicMock,
|
||||
gateway: TelemetryGateway,
|
||||
) -> None:
|
||||
mock_storage.save.side_effect = Exception("Storage failure")
|
||||
context = {"tenant_id": "tenant-123"}
|
||||
large_value = "x" * (PAYLOAD_SIZE_THRESHOLD_BYTES + 1000)
|
||||
payload = {"key": large_value}
|
||||
|
||||
gateway.emit(TelemetryCase.APP_CREATED, context, payload)
|
||||
|
||||
envelope_json = mock_delay.call_args[0][0]
|
||||
envelope = TelemetryEnvelope.model_validate_json(envelope_json)
|
||||
assert envelope.payload == payload
|
||||
assert envelope.metadata is None
|
||||
|
||||
|
||||
class TestTelemetryGatewayFeatureFlag:
|
||||
@pytest.fixture
|
||||
def gateway(self) -> TelemetryGateway:
|
||||
return TelemetryGateway()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_trace_manager(self) -> MagicMock:
|
||||
return MagicMock()
|
||||
|
||||
@patch("enterprise.telemetry.gateway.is_gateway_enabled", return_value=False)
|
||||
@patch("enterprise.telemetry.gateway._is_enterprise_telemetry_enabled", return_value=True)
|
||||
def test_legacy_path_used_when_flag_disabled(
|
||||
self,
|
||||
_mock_ee_enabled: MagicMock,
|
||||
_mock_gateway_enabled: MagicMock,
|
||||
gateway: TelemetryGateway,
|
||||
mock_trace_manager: MagicMock,
|
||||
mock_ops_trace_manager: tuple[MagicMock, MagicMock],
|
||||
) -> None:
|
||||
context = {"app_id": "app-123", "user_id": "user-456"}
|
||||
payload = {"workflow_run_id": "run-abc"}
|
||||
|
||||
gateway.emit(TelemetryCase.WORKFLOW_RUN, context, payload, mock_trace_manager)
|
||||
|
||||
mock_trace_manager.add_trace_task.assert_called_once()
|
||||
|
||||
@patch("enterprise.telemetry.gateway.is_gateway_enabled", return_value=False)
|
||||
@patch("tasks.enterprise_telemetry_task.process_enterprise_telemetry.delay")
|
||||
def test_metric_log_not_processed_via_legacy_path(
|
||||
self,
|
||||
mock_delay: MagicMock,
|
||||
_mock_gateway_enabled: MagicMock,
|
||||
gateway: TelemetryGateway,
|
||||
) -> None:
|
||||
context = {"tenant_id": "tenant-123"}
|
||||
payload = {"app_id": "app-abc"}
|
||||
|
||||
gateway.emit(TelemetryCase.APP_CREATED, context, payload)
|
||||
|
||||
mock_delay.assert_not_called()
|
||||
|
||||
|
||||
class TestTelemetryGatewayLegacyPath:
|
||||
@pytest.fixture
|
||||
def gateway(self) -> TelemetryGateway:
|
||||
return TelemetryGateway()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_trace_manager(self) -> MagicMock:
|
||||
return MagicMock()
|
||||
|
||||
@patch("enterprise.telemetry.gateway.is_gateway_enabled", return_value=False)
|
||||
@patch("enterprise.telemetry.gateway._is_enterprise_telemetry_enabled", return_value=False)
|
||||
def test_legacy_ce_eligible_enqueued_when_ee_disabled(
|
||||
self,
|
||||
_mock_ee_enabled: MagicMock,
|
||||
_mock_gateway_enabled: MagicMock,
|
||||
gateway: TelemetryGateway,
|
||||
mock_trace_manager: MagicMock,
|
||||
mock_ops_trace_manager: tuple[MagicMock, MagicMock],
|
||||
) -> None:
|
||||
context = {"app_id": "app-123", "user_id": "user-456"}
|
||||
payload = {"workflow_run_id": "run-abc"}
|
||||
|
||||
gateway.emit(TelemetryCase.WORKFLOW_RUN, context, payload, mock_trace_manager)
|
||||
|
||||
mock_trace_manager.add_trace_task.assert_called_once()
|
||||
|
||||
@patch("enterprise.telemetry.gateway.is_gateway_enabled", return_value=False)
|
||||
@patch("enterprise.telemetry.gateway._is_enterprise_telemetry_enabled", return_value=False)
|
||||
def test_legacy_enterprise_only_dropped_when_ee_disabled(
|
||||
self,
|
||||
_mock_ee_enabled: MagicMock,
|
||||
_mock_gateway_enabled: MagicMock,
|
||||
gateway: TelemetryGateway,
|
||||
mock_trace_manager: MagicMock,
|
||||
mock_ops_trace_manager: tuple[MagicMock, MagicMock],
|
||||
) -> None:
|
||||
context = {"app_id": "app-123", "user_id": "user-456"}
|
||||
payload = {"node_id": "node-abc"}
|
||||
|
||||
gateway.emit(TelemetryCase.NODE_EXECUTION, context, payload, mock_trace_manager)
|
||||
|
||||
mock_trace_manager.add_trace_task.assert_not_called()
|
||||
|
||||
|
||||
class TestModuleLevelFunctions:
|
||||
def test_get_gateway_returns_singleton(self) -> None:
|
||||
gateway1 = get_gateway()
|
||||
gateway2 = get_gateway()
|
||||
assert gateway1 is gateway2
|
||||
|
||||
@patch("enterprise.telemetry.gateway.is_gateway_enabled", return_value=True)
|
||||
@patch("enterprise.telemetry.gateway._is_enterprise_telemetry_enabled", return_value=True)
|
||||
def test_emit_function_uses_gateway(
|
||||
self,
|
||||
_mock_ee_enabled: MagicMock,
|
||||
_mock_gateway_enabled: MagicMock,
|
||||
mock_ops_trace_manager: tuple[MagicMock, MagicMock],
|
||||
) -> None:
|
||||
mock_trace_manager = MagicMock()
|
||||
context = {"app_id": "app-123", "user_id": "user-456"}
|
||||
payload = {"workflow_run_id": "run-abc"}
|
||||
|
||||
emit(TelemetryCase.WORKFLOW_RUN, context, payload, mock_trace_manager)
|
||||
|
||||
mock_trace_manager.add_trace_task.assert_called_once()
|
||||
|
||||
|
||||
class TestTraceTaskNameMapping:
|
||||
def test_workflow_run_mapping(self) -> None:
|
||||
assert CASE_TO_TRACE_TASK_NAME[TelemetryCase.WORKFLOW_RUN] == "workflow"
|
||||
|
||||
def test_message_run_mapping(self) -> None:
|
||||
assert CASE_TO_TRACE_TASK_NAME[TelemetryCase.MESSAGE_RUN] == "message"
|
||||
|
||||
def test_node_execution_mapping(self) -> None:
|
||||
assert CASE_TO_TRACE_TASK_NAME[TelemetryCase.NODE_EXECUTION] == "node_execution"
|
||||
|
||||
def test_draft_node_execution_mapping(self) -> None:
|
||||
assert CASE_TO_TRACE_TASK_NAME[TelemetryCase.DRAFT_NODE_EXECUTION] == "draft_node_execution"
|
||||
|
||||
def test_prompt_generation_mapping(self) -> None:
|
||||
assert CASE_TO_TRACE_TASK_NAME[TelemetryCase.PROMPT_GENERATION] == "prompt_generation"
|
||||
Loading…
Reference in New Issue