From 51b0c5c89cb22cb2afbae65134b25ca2412e6f76 Mon Sep 17 00:00:00 2001 From: GareArc Date: Thu, 5 Feb 2026 19:17:08 -0800 Subject: [PATCH] feat(telemetry): implement gateway routing and enqueue logic Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-Claude) Co-authored-by: Sisyphus --- .../learnings.md | 81 ++++ api/enterprise/telemetry/gateway.py | 311 ++++++++++++- .../enterprise/telemetry/test_gateway.py | 435 ++++++++++++++++++ 3 files changed, 825 insertions(+), 2 deletions(-) create mode 100644 api/tests/unit_tests/enterprise/telemetry/test_gateway.py diff --git a/api/.sisyphus/notepads/enterprise-telemetry-gateway-refactor/learnings.md b/api/.sisyphus/notepads/enterprise-telemetry-gateway-refactor/learnings.md index 25b4f75059..9b4223902b 100644 --- a/api/.sisyphus/notepads/enterprise-telemetry-gateway-refactor/learnings.md +++ b/api/.sisyphus/notepads/enterprise-telemetry-gateway-refactor/learnings.md @@ -168,3 +168,84 @@ basedpyright # 0 erro Task 3 (gateway implementation) will wire `TelemetryGateway.emit()` to call `process_enterprise_telemetry.delay()`. Once Task 3 completes, handlers can optionally be updated to call gateway directly instead of task. 
+## [2026-02-06] Task 3: TelemetryGateway Implementation + +### Implementation Decisions + +**Gateway Architecture:** +- `TelemetryGateway.emit(case, context, payload, trace_manager)` as main entry point +- Routes based on `CASE_ROUTING[case].signal_type`: + - `trace` → TraceQueueManager.add_trace_task() + - `metric_log` → process_enterprise_telemetry.delay() +- Feature flag `ENTERPRISE_TELEMETRY_GATEWAY_ENABLED` gates new vs legacy behavior + +**CE Eligibility:** +- Trace cases with `ce_eligible=False` dropped when enterprise disabled +- CE-eligible traces (WORKFLOW_RUN, MESSAGE_RUN) always processed +- Enterprise-only traces (NODE_EXECUTION, DRAFT_NODE_EXECUTION, PROMPT_GENERATION) require EE + +**Payload Sizing:** +- Threshold: 1 MiB (PAYLOAD_SIZE_THRESHOLD_BYTES) +- Large payloads stored to `telemetry/{tenant_id}/{event_id}.json` +- Storage failures fall back to inline payload (best-effort) + +**Legacy Path:** +- When gateway disabled, mimics original TelemetryFacade behavior +- Only processes trace cases, metric/log cases dropped +- Imports deferred until after route check to avoid circular imports + +### Testing Patterns + +**Circular Import Workaround:** +- `ops_trace_manager` has deep import chain causing circular imports in tests +- Solution: `mock_ops_trace_manager` fixture patches `sys.modules` before import +- Trace routing tests require fixture; metric/log tests don't + +**Mock Parameter Naming:** +- Prefixed unused mock params with `_` (e.g., `_mock_ee_enabled`) +- Ruff PT019 warnings are style hints, not errors + +**Coverage:** +- 38 tests total +- Feature flag on/off paths +- Trace routing (CE-eligible and enterprise-only) +- Metric/log routing +- Large payload storage and fallback +- Legacy path behavior + +### Files Modified/Created + +- `enterprise/telemetry/gateway.py` (333 lines, expanded from 26) + - Added TelemetryGateway class + - Added CASE_TO_TRACE_TASK_NAME mapping + - Added is_gateway_enabled() and _is_enterprise_telemetry_enabled() + - 
Added module-level emit() convenience function +- `tests/unit_tests/enterprise/telemetry/test_gateway.py` (NEW, 435 lines, 38 tests) + +### Verification + +```bash +pytest tests/unit_tests/enterprise/telemetry/test_gateway.py -v # 38 passed +ruff check --fix # 22 PT019 warnings (style) +basedpyright # 0 errors, 0 warnings +``` + +### Key Insights + +**Import Deferral:** +- Legacy path must defer `ops_trace_manager` import until after route check +- Otherwise metric/log cases trigger circular import chain +- Pattern: check signal_type first, then import if needed + +**TelemetryCase to TraceTaskName Mapping:** +- WORKFLOW_RUN → "workflow" +- MESSAGE_RUN → "message" +- NODE_EXECUTION → "node_execution" +- DRAFT_NODE_EXECUTION → "draft_node_execution" +- PROMPT_GENERATION → "prompt_generation" + +**Feature Flag Design:** +- Default OFF for safe rollout +- Env var: ENTERPRISE_TELEMETRY_GATEWAY_ENABLED +- Truthy values: "true", "1", "yes" (case-insensitive) + diff --git a/api/enterprise/telemetry/gateway.py b/api/enterprise/telemetry/gateway.py index d04222fd9c..104d3fc94c 100644 --- a/api/enterprise/telemetry/gateway.py +++ b/api/enterprise/telemetry/gateway.py @@ -1,12 +1,36 @@ -"""Telemetry gateway routing configuration. +"""Telemetry gateway routing configuration and implementation. This module defines the routing table that maps telemetry cases to their processing routes (trace vs metric/log) and customer engagement eligibility. +It also provides the TelemetryGateway class that routes events to the +appropriate processing path. 
""" from __future__ import annotations -from enterprise.telemetry.contracts import CaseRoute, TelemetryCase +import json +import logging +import os +import uuid +from typing import TYPE_CHECKING, Any + +from enterprise.telemetry.contracts import CaseRoute, TelemetryCase, TelemetryEnvelope +from extensions.ext_storage import storage + +if TYPE_CHECKING: + from core.ops.ops_trace_manager import TraceQueueManager + +logger = logging.getLogger(__name__) + +PAYLOAD_SIZE_THRESHOLD_BYTES = 1 * 1024 * 1024 + +CASE_TO_TRACE_TASK_NAME: dict[TelemetryCase, str] = { + TelemetryCase.WORKFLOW_RUN: "workflow", + TelemetryCase.MESSAGE_RUN: "message", + TelemetryCase.NODE_EXECUTION: "node_execution", + TelemetryCase.DRAFT_NODE_EXECUTION: "draft_node_execution", + TelemetryCase.PROMPT_GENERATION: "prompt_generation", +} CASE_ROUTING: dict[TelemetryCase, CaseRoute] = { TelemetryCase.WORKFLOW_RUN: CaseRoute(signal_type="trace", ce_eligible=True), @@ -24,3 +48,286 @@ CASE_ROUTING: dict[TelemetryCase, CaseRoute] = { TelemetryCase.DATASET_RETRIEVAL: CaseRoute(signal_type="metric_log", ce_eligible=False), TelemetryCase.GENERATE_NAME: CaseRoute(signal_type="metric_log", ce_eligible=False), } + + +def is_gateway_enabled() -> bool: + """Check if the telemetry gateway is enabled via feature flag. + + Returns: + True if ENTERPRISE_TELEMETRY_GATEWAY_ENABLED is set to a truthy value. + """ + return os.environ.get("ENTERPRISE_TELEMETRY_GATEWAY_ENABLED", "").lower() in ("true", "1", "yes") + + +def _is_enterprise_telemetry_enabled() -> bool: + """Check if enterprise telemetry is enabled. + + Wraps the check from core.telemetry to handle import failures gracefully. + """ + try: + from enterprise.telemetry.exporter import is_enterprise_telemetry_enabled + + return is_enterprise_telemetry_enabled() + except Exception: + return False + + +class TelemetryGateway: + """Gateway for routing telemetry events to appropriate processing paths. 
+ + Routes trace-shaped events to TraceQueueManager and metric/log events + to the enterprise telemetry Celery queue. Handles CE eligibility checks, + large payload storage, and feature flag gating. + """ + + def emit( + self, + case: TelemetryCase, + context: dict[str, Any], + payload: dict[str, Any], + trace_manager: TraceQueueManager | None = None, + ) -> None: + """Emit a telemetry event through the gateway. + + Routes the event based on its case type: + - trace: Routes to TraceQueueManager for existing trace pipeline + - metric_log: Routes to enterprise telemetry Celery task + + Args: + case: The telemetry case type. + context: Event context containing tenant_id, app_id, user_id. + payload: The event payload data. + trace_manager: Optional TraceQueueManager for trace routing. + """ + if not is_gateway_enabled(): + self._emit_legacy(case, context, payload, trace_manager) + return + + route = CASE_ROUTING.get(case) + if route is None: + logger.warning("Unknown telemetry case: %s, dropping event", case) + return + + if route.signal_type == "trace": + self._emit_trace(case, context, payload, route, trace_manager) + else: + self._emit_metric_log(case, context, payload) + + def _emit_legacy( + self, + case: TelemetryCase, + context: dict[str, Any], + payload: dict[str, Any], + trace_manager: TraceQueueManager | None, + ) -> None: + """Emit using legacy path (TelemetryFacade behavior). + + Used when gateway feature flag is disabled. 
+ """ + route = CASE_ROUTING.get(case) + if route is None or route.signal_type != "trace": + return + + trace_task_name_str = CASE_TO_TRACE_TASK_NAME.get(case) + if trace_task_name_str is None: + return + + if not route.ce_eligible and not _is_enterprise_telemetry_enabled(): + return + + from core.ops.entities.trace_entity import TraceTaskName + from core.ops.ops_trace_manager import ( + TraceQueueManager as LocalTraceQueueManager, + ) + from core.ops.ops_trace_manager import ( + TraceTask, + ) + + try: + trace_task_name = TraceTaskName(trace_task_name_str) + except ValueError: + logger.warning("Invalid trace task name: %s", trace_task_name_str) + return + + queue_manager = trace_manager or LocalTraceQueueManager( + app_id=context.get("app_id"), + user_id=context.get("user_id"), + ) + + queue_manager.add_trace_task( + TraceTask( + trace_task_name, + **payload, + ) + ) + + def _emit_trace( + self, + case: TelemetryCase, + context: dict[str, Any], + payload: dict[str, Any], + route: CaseRoute, + trace_manager: TraceQueueManager | None, + ) -> None: + """Emit a trace-shaped event to TraceQueueManager. + + Args: + case: The telemetry case type. + context: Event context. + payload: The event payload. + route: Routing configuration for this case. + trace_manager: Optional TraceQueueManager. 
+ """ + from core.ops.entities.trace_entity import TraceTaskName + from core.ops.ops_trace_manager import ( + TraceQueueManager as LocalTraceQueueManager, + ) + from core.ops.ops_trace_manager import ( + TraceTask, + ) + + if not route.ce_eligible and not _is_enterprise_telemetry_enabled(): + logger.debug( + "Dropping enterprise-only trace event: case=%s (EE disabled)", + case, + ) + return + + trace_task_name_str = CASE_TO_TRACE_TASK_NAME.get(case) + if trace_task_name_str is None: + logger.warning("No TraceTaskName mapping for case: %s", case) + return + + try: + trace_task_name = TraceTaskName(trace_task_name_str) + except ValueError: + logger.warning("Invalid trace task name: %s", trace_task_name_str) + return + + queue_manager = trace_manager or LocalTraceQueueManager( + app_id=context.get("app_id"), + user_id=context.get("user_id"), + ) + + queue_manager.add_trace_task( + TraceTask( + trace_task_name, + **payload, + ) + ) + logger.debug( + "Enqueued trace task: case=%s, app_id=%s", + case, + context.get("app_id"), + ) + + def _emit_metric_log( + self, + case: TelemetryCase, + context: dict[str, Any], + payload: dict[str, Any], + ) -> None: + """Emit a metric/log event to the enterprise telemetry Celery queue. + + Args: + case: The telemetry case type. + context: Event context containing tenant_id. + payload: The event payload. 
+ """ + from tasks.enterprise_telemetry_task import process_enterprise_telemetry + + tenant_id = context.get("tenant_id", "") + event_id = str(uuid.uuid4()) + + payload_for_envelope, payload_ref = self._handle_payload_sizing(payload, tenant_id, event_id) + + envelope = TelemetryEnvelope( + case=case, + tenant_id=tenant_id, + event_id=event_id, + payload=payload_for_envelope, + metadata={"payload_ref": payload_ref} if payload_ref else None, + ) + + process_enterprise_telemetry.delay(envelope.model_dump_json()) + logger.debug( + "Enqueued metric/log event: case=%s, tenant_id=%s, event_id=%s", + case, + tenant_id, + event_id, + ) + + def _handle_payload_sizing( + self, + payload: dict[str, Any], + tenant_id: str, + event_id: str, + ) -> tuple[dict[str, Any], str | None]: + """Handle large payload storage. + + If payload exceeds threshold, stores to shared storage and returns + a reference. Otherwise returns payload as-is. + + Args: + payload: The event payload. + tenant_id: Tenant identifier for storage path. + event_id: Event identifier for storage path. + + Returns: + Tuple of (payload_for_envelope, payload_ref). + If stored, payload_for_envelope is empty and payload_ref is set. + Otherwise, payload_for_envelope is the original payload and + payload_ref is None. 
+ """ + try: + payload_json = json.dumps(payload) + payload_size = len(payload_json.encode("utf-8")) + except (TypeError, ValueError): + logger.warning("Failed to serialize payload for sizing: event_id=%s", event_id) + return payload, None + + if payload_size <= PAYLOAD_SIZE_THRESHOLD_BYTES: + return payload, None + + storage_key = f"telemetry/{tenant_id}/{event_id}.json" + try: + storage.save(storage_key, payload_json.encode("utf-8")) + logger.debug("Stored large payload to storage: key=%s, size=%d", storage_key, payload_size) + return {}, storage_key + except Exception: + logger.warning("Failed to store large payload, inlining instead: event_id=%s", event_id, exc_info=True) + return payload, None + + +_gateway: TelemetryGateway | None = None + + +def get_gateway() -> TelemetryGateway: + """Get the module-level gateway instance. + + Returns: + The singleton TelemetryGateway instance. + """ + global _gateway + if _gateway is None: + _gateway = TelemetryGateway() + return _gateway + + +def emit( + case: TelemetryCase, + context: dict[str, Any], + payload: dict[str, Any], + trace_manager: TraceQueueManager | None = None, +) -> None: + """Emit a telemetry event through the gateway. + + Convenience function that uses the module-level gateway instance. + + Args: + case: The telemetry case type. + context: Event context containing tenant_id, app_id, user_id. + payload: The event payload data. + trace_manager: Optional TraceQueueManager for trace routing. 
+ """ + get_gateway().emit(case, context, payload, trace_manager) diff --git a/api/tests/unit_tests/enterprise/telemetry/test_gateway.py b/api/tests/unit_tests/enterprise/telemetry/test_gateway.py new file mode 100644 index 0000000000..4041ee424b --- /dev/null +++ b/api/tests/unit_tests/enterprise/telemetry/test_gateway.py @@ -0,0 +1,435 @@ +from __future__ import annotations + +import sys +from unittest.mock import MagicMock, patch + +import pytest + +from enterprise.telemetry.contracts import TelemetryCase, TelemetryEnvelope +from enterprise.telemetry.gateway import ( + CASE_ROUTING, + CASE_TO_TRACE_TASK_NAME, + PAYLOAD_SIZE_THRESHOLD_BYTES, + TelemetryGateway, + emit, + get_gateway, + is_gateway_enabled, +) + + +class TestIsGatewayEnabled: + @pytest.mark.parametrize( + ("env_value", "expected"), + [ + ("true", True), + ("True", True), + ("TRUE", True), + ("1", True), + ("yes", True), + ("YES", True), + ("false", False), + ("False", False), + ("0", False), + ("no", False), + ("", False), + ], + ) + def test_feature_flag_values(self, env_value: str, expected: bool) -> None: + with patch.dict("os.environ", {"ENTERPRISE_TELEMETRY_GATEWAY_ENABLED": env_value}): + assert is_gateway_enabled() is expected + + def test_missing_env_var(self) -> None: + with patch.dict("os.environ", {}, clear=True): + assert is_gateway_enabled() is False + + +class TestCaseRoutingTable: + def test_all_cases_have_routing(self) -> None: + for case in TelemetryCase: + assert case in CASE_ROUTING, f"Missing routing for {case}" + + def test_trace_cases(self) -> None: + trace_cases = [ + TelemetryCase.WORKFLOW_RUN, + TelemetryCase.MESSAGE_RUN, + TelemetryCase.NODE_EXECUTION, + TelemetryCase.DRAFT_NODE_EXECUTION, + TelemetryCase.PROMPT_GENERATION, + ] + for case in trace_cases: + assert CASE_ROUTING[case].signal_type == "trace", f"{case} should be trace" + + def test_metric_log_cases(self) -> None: + metric_log_cases = [ + TelemetryCase.APP_CREATED, + TelemetryCase.APP_UPDATED, + 
TelemetryCase.APP_DELETED, + TelemetryCase.FEEDBACK_CREATED, + TelemetryCase.TOOL_EXECUTION, + TelemetryCase.MODERATION_CHECK, + TelemetryCase.SUGGESTED_QUESTION, + TelemetryCase.DATASET_RETRIEVAL, + TelemetryCase.GENERATE_NAME, + ] + for case in metric_log_cases: + assert CASE_ROUTING[case].signal_type == "metric_log", f"{case} should be metric_log" + + def test_ce_eligible_cases(self) -> None: + ce_eligible_cases = [TelemetryCase.WORKFLOW_RUN, TelemetryCase.MESSAGE_RUN] + for case in ce_eligible_cases: + assert CASE_ROUTING[case].ce_eligible is True, f"{case} should be CE eligible" + + def test_enterprise_only_cases(self) -> None: + enterprise_only_cases = [ + TelemetryCase.NODE_EXECUTION, + TelemetryCase.DRAFT_NODE_EXECUTION, + TelemetryCase.PROMPT_GENERATION, + ] + for case in enterprise_only_cases: + assert CASE_ROUTING[case].ce_eligible is False, f"{case} should be enterprise-only" + + def test_trace_cases_have_task_name_mapping(self) -> None: + trace_cases = [c for c in TelemetryCase if CASE_ROUTING[c].signal_type == "trace"] + for case in trace_cases: + assert case in CASE_TO_TRACE_TASK_NAME, f"Missing TraceTaskName mapping for {case}" + + +@pytest.fixture +def mock_ops_trace_manager(): + mock_module = MagicMock() + mock_trace_task_class = MagicMock() + mock_trace_task_class.return_value = MagicMock() + mock_module.TraceTask = mock_trace_task_class + mock_module.TraceQueueManager = MagicMock() + + mock_trace_entity = MagicMock() + mock_trace_task_name = MagicMock() + mock_trace_task_name.return_value = "workflow" + mock_trace_entity.TraceTaskName = mock_trace_task_name + + with ( + patch.dict(sys.modules, {"core.ops.ops_trace_manager": mock_module}), + patch.dict(sys.modules, {"core.ops.entities.trace_entity": mock_trace_entity}), + ): + yield mock_module, mock_trace_entity + + +class TestTelemetryGatewayTraceRouting: + @pytest.fixture + def gateway(self) -> TelemetryGateway: + return TelemetryGateway() + + @pytest.fixture + def mock_trace_manager(self) -> 
MagicMock: + return MagicMock() + + @patch("enterprise.telemetry.gateway.is_gateway_enabled", return_value=True) + @patch("enterprise.telemetry.gateway._is_enterprise_telemetry_enabled", return_value=True) + def test_trace_case_routes_to_trace_manager( + self, + _mock_ee_enabled: MagicMock, + _mock_gateway_enabled: MagicMock, + gateway: TelemetryGateway, + mock_trace_manager: MagicMock, + mock_ops_trace_manager: tuple[MagicMock, MagicMock], + ) -> None: + context = {"app_id": "app-123", "user_id": "user-456", "tenant_id": "tenant-789"} + payload = {"workflow_run_id": "run-abc"} + + gateway.emit(TelemetryCase.WORKFLOW_RUN, context, payload, mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_called_once() + + @patch("enterprise.telemetry.gateway.is_gateway_enabled", return_value=True) + @patch("enterprise.telemetry.gateway._is_enterprise_telemetry_enabled", return_value=False) + def test_ce_eligible_trace_enqueued_when_ee_disabled( + self, + _mock_ee_enabled: MagicMock, + _mock_gateway_enabled: MagicMock, + gateway: TelemetryGateway, + mock_trace_manager: MagicMock, + mock_ops_trace_manager: tuple[MagicMock, MagicMock], + ) -> None: + context = {"app_id": "app-123", "user_id": "user-456"} + payload = {"workflow_run_id": "run-abc"} + + gateway.emit(TelemetryCase.WORKFLOW_RUN, context, payload, mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_called_once() + + @patch("enterprise.telemetry.gateway.is_gateway_enabled", return_value=True) + @patch("enterprise.telemetry.gateway._is_enterprise_telemetry_enabled", return_value=False) + def test_enterprise_only_trace_dropped_when_ee_disabled( + self, + _mock_ee_enabled: MagicMock, + _mock_gateway_enabled: MagicMock, + gateway: TelemetryGateway, + mock_trace_manager: MagicMock, + mock_ops_trace_manager: tuple[MagicMock, MagicMock], + ) -> None: + context = {"app_id": "app-123", "user_id": "user-456"} + payload = {"node_id": "node-abc"} + + gateway.emit(TelemetryCase.NODE_EXECUTION, context, 
payload, mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_not_called() + + @patch("enterprise.telemetry.gateway.is_gateway_enabled", return_value=True) + @patch("enterprise.telemetry.gateway._is_enterprise_telemetry_enabled", return_value=True) + def test_enterprise_only_trace_enqueued_when_ee_enabled( + self, + _mock_ee_enabled: MagicMock, + _mock_gateway_enabled: MagicMock, + gateway: TelemetryGateway, + mock_trace_manager: MagicMock, + mock_ops_trace_manager: tuple[MagicMock, MagicMock], + ) -> None: + context = {"app_id": "app-123", "user_id": "user-456"} + payload = {"node_id": "node-abc"} + + gateway.emit(TelemetryCase.NODE_EXECUTION, context, payload, mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_called_once() + + +class TestTelemetryGatewayMetricLogRouting: + @pytest.fixture + def gateway(self) -> TelemetryGateway: + return TelemetryGateway() + + @patch("enterprise.telemetry.gateway.is_gateway_enabled", return_value=True) + @patch("tasks.enterprise_telemetry_task.process_enterprise_telemetry.delay") + def test_metric_case_routes_to_celery_task( + self, + mock_delay: MagicMock, + _mock_gateway_enabled: MagicMock, + gateway: TelemetryGateway, + ) -> None: + context = {"tenant_id": "tenant-123"} + payload = {"app_id": "app-abc", "name": "My App"} + + gateway.emit(TelemetryCase.APP_CREATED, context, payload) + + mock_delay.assert_called_once() + envelope_json = mock_delay.call_args[0][0] + envelope = TelemetryEnvelope.model_validate_json(envelope_json) + assert envelope.case == TelemetryCase.APP_CREATED + assert envelope.tenant_id == "tenant-123" + assert envelope.payload["app_id"] == "app-abc" + + @patch("enterprise.telemetry.gateway.is_gateway_enabled", return_value=True) + @patch("tasks.enterprise_telemetry_task.process_enterprise_telemetry.delay") + def test_envelope_has_unique_event_id( + self, + mock_delay: MagicMock, + _mock_gateway_enabled: MagicMock, + gateway: TelemetryGateway, + ) -> None: + context = {"tenant_id": 
"tenant-123"} + payload = {"app_id": "app-abc"} + + gateway.emit(TelemetryCase.APP_CREATED, context, payload) + gateway.emit(TelemetryCase.APP_CREATED, context, payload) + + assert mock_delay.call_count == 2 + envelope1 = TelemetryEnvelope.model_validate_json(mock_delay.call_args_list[0][0][0]) + envelope2 = TelemetryEnvelope.model_validate_json(mock_delay.call_args_list[1][0][0]) + assert envelope1.event_id != envelope2.event_id + + +class TestTelemetryGatewayPayloadSizing: + @pytest.fixture + def gateway(self) -> TelemetryGateway: + return TelemetryGateway() + + @patch("enterprise.telemetry.gateway.is_gateway_enabled", return_value=True) + @patch("tasks.enterprise_telemetry_task.process_enterprise_telemetry.delay") + def test_small_payload_inlined( + self, + mock_delay: MagicMock, + _mock_gateway_enabled: MagicMock, + gateway: TelemetryGateway, + ) -> None: + context = {"tenant_id": "tenant-123"} + payload = {"key": "small_value"} + + gateway.emit(TelemetryCase.APP_CREATED, context, payload) + + envelope_json = mock_delay.call_args[0][0] + envelope = TelemetryEnvelope.model_validate_json(envelope_json) + assert envelope.payload == payload + assert envelope.metadata is None + + @patch("enterprise.telemetry.gateway.is_gateway_enabled", return_value=True) + @patch("enterprise.telemetry.gateway.storage") + @patch("tasks.enterprise_telemetry_task.process_enterprise_telemetry.delay") + def test_large_payload_stored( + self, + mock_delay: MagicMock, + mock_storage: MagicMock, + _mock_gateway_enabled: MagicMock, + gateway: TelemetryGateway, + ) -> None: + context = {"tenant_id": "tenant-123"} + large_value = "x" * (PAYLOAD_SIZE_THRESHOLD_BYTES + 1000) + payload = {"key": large_value} + + gateway.emit(TelemetryCase.APP_CREATED, context, payload) + + mock_storage.save.assert_called_once() + storage_key = mock_storage.save.call_args[0][0] + assert storage_key.startswith("telemetry/tenant-123/") + + envelope_json = mock_delay.call_args[0][0] + envelope = 
TelemetryEnvelope.model_validate_json(envelope_json) + assert envelope.payload == {} + assert envelope.metadata is not None + assert envelope.metadata["payload_ref"] == storage_key + + @patch("enterprise.telemetry.gateway.is_gateway_enabled", return_value=True) + @patch("enterprise.telemetry.gateway.storage") + @patch("tasks.enterprise_telemetry_task.process_enterprise_telemetry.delay") + def test_large_payload_fallback_on_storage_error( + self, + mock_delay: MagicMock, + mock_storage: MagicMock, + _mock_gateway_enabled: MagicMock, + gateway: TelemetryGateway, + ) -> None: + mock_storage.save.side_effect = Exception("Storage failure") + context = {"tenant_id": "tenant-123"} + large_value = "x" * (PAYLOAD_SIZE_THRESHOLD_BYTES + 1000) + payload = {"key": large_value} + + gateway.emit(TelemetryCase.APP_CREATED, context, payload) + + envelope_json = mock_delay.call_args[0][0] + envelope = TelemetryEnvelope.model_validate_json(envelope_json) + assert envelope.payload == payload + assert envelope.metadata is None + + +class TestTelemetryGatewayFeatureFlag: + @pytest.fixture + def gateway(self) -> TelemetryGateway: + return TelemetryGateway() + + @pytest.fixture + def mock_trace_manager(self) -> MagicMock: + return MagicMock() + + @patch("enterprise.telemetry.gateway.is_gateway_enabled", return_value=False) + @patch("enterprise.telemetry.gateway._is_enterprise_telemetry_enabled", return_value=True) + def test_legacy_path_used_when_flag_disabled( + self, + _mock_ee_enabled: MagicMock, + _mock_gateway_enabled: MagicMock, + gateway: TelemetryGateway, + mock_trace_manager: MagicMock, + mock_ops_trace_manager: tuple[MagicMock, MagicMock], + ) -> None: + context = {"app_id": "app-123", "user_id": "user-456"} + payload = {"workflow_run_id": "run-abc"} + + gateway.emit(TelemetryCase.WORKFLOW_RUN, context, payload, mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_called_once() + + @patch("enterprise.telemetry.gateway.is_gateway_enabled", return_value=False) + 
@patch("tasks.enterprise_telemetry_task.process_enterprise_telemetry.delay") + def test_metric_log_not_processed_via_legacy_path( + self, + mock_delay: MagicMock, + _mock_gateway_enabled: MagicMock, + gateway: TelemetryGateway, + ) -> None: + context = {"tenant_id": "tenant-123"} + payload = {"app_id": "app-abc"} + + gateway.emit(TelemetryCase.APP_CREATED, context, payload) + + mock_delay.assert_not_called() + + +class TestTelemetryGatewayLegacyPath: + @pytest.fixture + def gateway(self) -> TelemetryGateway: + return TelemetryGateway() + + @pytest.fixture + def mock_trace_manager(self) -> MagicMock: + return MagicMock() + + @patch("enterprise.telemetry.gateway.is_gateway_enabled", return_value=False) + @patch("enterprise.telemetry.gateway._is_enterprise_telemetry_enabled", return_value=False) + def test_legacy_ce_eligible_enqueued_when_ee_disabled( + self, + _mock_ee_enabled: MagicMock, + _mock_gateway_enabled: MagicMock, + gateway: TelemetryGateway, + mock_trace_manager: MagicMock, + mock_ops_trace_manager: tuple[MagicMock, MagicMock], + ) -> None: + context = {"app_id": "app-123", "user_id": "user-456"} + payload = {"workflow_run_id": "run-abc"} + + gateway.emit(TelemetryCase.WORKFLOW_RUN, context, payload, mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_called_once() + + @patch("enterprise.telemetry.gateway.is_gateway_enabled", return_value=False) + @patch("enterprise.telemetry.gateway._is_enterprise_telemetry_enabled", return_value=False) + def test_legacy_enterprise_only_dropped_when_ee_disabled( + self, + _mock_ee_enabled: MagicMock, + _mock_gateway_enabled: MagicMock, + gateway: TelemetryGateway, + mock_trace_manager: MagicMock, + mock_ops_trace_manager: tuple[MagicMock, MagicMock], + ) -> None: + context = {"app_id": "app-123", "user_id": "user-456"} + payload = {"node_id": "node-abc"} + + gateway.emit(TelemetryCase.NODE_EXECUTION, context, payload, mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_not_called() + + +class 
TestModuleLevelFunctions: + def test_get_gateway_returns_singleton(self) -> None: + gateway1 = get_gateway() + gateway2 = get_gateway() + assert gateway1 is gateway2 + + @patch("enterprise.telemetry.gateway.is_gateway_enabled", return_value=True) + @patch("enterprise.telemetry.gateway._is_enterprise_telemetry_enabled", return_value=True) + def test_emit_function_uses_gateway( + self, + _mock_ee_enabled: MagicMock, + _mock_gateway_enabled: MagicMock, + mock_ops_trace_manager: tuple[MagicMock, MagicMock], + ) -> None: + mock_trace_manager = MagicMock() + context = {"app_id": "app-123", "user_id": "user-456"} + payload = {"workflow_run_id": "run-abc"} + + emit(TelemetryCase.WORKFLOW_RUN, context, payload, mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_called_once() + + +class TestTraceTaskNameMapping: + def test_workflow_run_mapping(self) -> None: + assert CASE_TO_TRACE_TASK_NAME[TelemetryCase.WORKFLOW_RUN] == "workflow" + + def test_message_run_mapping(self) -> None: + assert CASE_TO_TRACE_TASK_NAME[TelemetryCase.MESSAGE_RUN] == "message" + + def test_node_execution_mapping(self) -> None: + assert CASE_TO_TRACE_TASK_NAME[TelemetryCase.NODE_EXECUTION] == "node_execution" + + def test_draft_node_execution_mapping(self) -> None: + assert CASE_TO_TRACE_TASK_NAME[TelemetryCase.DRAFT_NODE_EXECUTION] == "draft_node_execution" + + def test_prompt_generation_mapping(self) -> None: + assert CASE_TO_TRACE_TASK_NAME[TelemetryCase.PROMPT_GENERATION] == "prompt_generation"