From 7ef139caddf1d7f16ea62352479e121a3780fc93 Mon Sep 17 00:00:00 2001 From: GareArc Date: Wed, 4 Mar 2026 16:59:37 -0800 Subject: [PATCH] Squash merge 1.12.1-otel-ee into release/e-1.12.1 --- api/.ruff.toml | 8 +- api/README.md | 3 +- api/app_factory.py | 2 + api/configs/app_config.py | 4 +- api/configs/enterprise/__init__.py | 46 + api/controllers/console/app/generator.py | 70 +- api/controllers/console/app/ops_trace.py | 15 +- api/core/agent/base_agent_runner.py | 2 +- .../advanced_chat/generate_task_pipeline.py | 36 +- api/core/app/apps/workflow/app_generator.py | 5 +- .../easy_ui_based_generate_task_pipeline.py | 22 +- api/core/app/workflow/layers/persistence.py | 139 ++- .../agent_tool_callback_handler.py | 35 +- api/core/llm_generator/entities.py | 2 + api/core/llm_generator/llm_generator.py | 565 ++++++++--- api/core/logging/filters.py | 32 +- api/core/moderation/input_moderation.py | 25 +- api/core/ops/entities/trace_entity.py | 151 ++- api/core/ops/langfuse_trace/langfuse_trace.py | 57 +- .../ops/langsmith_trace/langsmith_trace.py | 40 +- api/core/ops/ops_trace_manager.py | 516 +++++++++- api/core/rag/retrieval/dataset_retrieval.py | 24 +- api/core/telemetry/__init__.py | 43 + api/core/telemetry/events.py | 21 + api/core/telemetry/gateway.py | 229 +++++ api/core/tools/workflow_as_tool/tool.py | 7 +- api/core/workflow/enums.py | 2 + api/core/workflow/nodes/llm/node.py | 2 + api/core/workflow/nodes/tool/tool_node.py | 18 + api/enterprise/__init__.py | 0 api/enterprise/telemetry/DATA_DICTIONARY.md | 522 ++++++++++ api/enterprise/telemetry/README.md | 116 +++ api/enterprise/telemetry/__init__.py | 0 api/enterprise/telemetry/contracts.py | 73 ++ api/enterprise/telemetry/draft_trace.py | 77 ++ api/enterprise/telemetry/enterprise_trace.py | 938 ++++++++++++++++++ api/enterprise/telemetry/entities/__init__.py | 121 +++ api/enterprise/telemetry/event_handlers.py | 99 ++ api/enterprise/telemetry/exporter.py | 284 ++++++ api/enterprise/telemetry/id_generator.py | 76 ++ api/enterprise/telemetry/metric_handler.py | 381 +++++++ api/enterprise/telemetry/telemetry_log.py | 122 +++ api/events/app_event.py | 6 + api/events/feedback_event.py | 4 + api/extensions/ext_celery.py | 2 + api/extensions/ext_enterprise_telemetry.py | 50 + api/extensions/ext_otel.py | 20 +- api/extensions/otel/parser/__init__.py | 3 +- api/extensions/otel/parser/base.py | 26 +- api/extensions/otel/parser/llm.py | 35 +- api/extensions/otel/parser/retrieval.py | 38 +- api/extensions/otel/parser/tool.py | 16 +- api/extensions/otel/semconv/dify.py | 12 + api/models/model.py | 8 +- api/services/app_service.py | 4 +- .../enterprise/account_deletion_sync.py | 4 +- api/services/message_service.py | 21 +- api/services/ops_service.py | 21 +- api/services/workflow_service.py | 24 +- api/tasks/enterprise_telemetry_task.py | 52 + api/tasks/ops_trace_task.py | 28 +- .../core/ops/test_trace_queue_manager.py | 200 ++++ .../unit_tests/core/telemetry/test_facade.py | 181 ++++ .../telemetry/test_gateway_integration.py | 225 +++++ api/tests/unit_tests/enterprise/__init__.py | 0 .../enterprise/telemetry/__init__.py | 0 .../enterprise/telemetry/test_contracts.py | 230 +++++ .../telemetry/test_event_handlers.py | 121 +++ .../enterprise/telemetry/test_exporter.py | 263 +++++ .../enterprise/telemetry/test_gateway.py | 272 +++++ .../telemetry/test_metric_handler.py | 507 ++++++++++ .../tasks/test_enterprise_telemetry_task.py | 69 ++ 72 files changed, 7000 insertions(+), 372 deletions(-) create mode 100644 api/core/telemetry/__init__.py create mode 100644 api/core/telemetry/events.py create mode 100644 api/core/telemetry/gateway.py create mode 100644 api/enterprise/__init__.py create mode 100644 api/enterprise/telemetry/DATA_DICTIONARY.md create mode 100644 api/enterprise/telemetry/README.md create mode 100644 api/enterprise/telemetry/__init__.py create mode 100644 api/enterprise/telemetry/contracts.py create mode 100644 api/enterprise/telemetry/draft_trace.py create mode 100644 api/enterprise/telemetry/enterprise_trace.py create mode 100644 api/enterprise/telemetry/entities/__init__.py create mode 100644 api/enterprise/telemetry/event_handlers.py create mode 100644 api/enterprise/telemetry/exporter.py create mode 100644 api/enterprise/telemetry/id_generator.py create mode 100644 api/enterprise/telemetry/metric_handler.py create mode 100644 api/enterprise/telemetry/telemetry_log.py create mode 100644 api/events/feedback_event.py create mode 100644 api/extensions/ext_enterprise_telemetry.py create mode 100644 api/tasks/enterprise_telemetry_task.py create mode 100644 api/tests/unit_tests/core/ops/test_trace_queue_manager.py create mode 100644 api/tests/unit_tests/core/telemetry/test_facade.py create mode 100644 api/tests/unit_tests/core/telemetry/test_gateway_integration.py create mode 100644 api/tests/unit_tests/enterprise/__init__.py create mode 100644 api/tests/unit_tests/enterprise/telemetry/__init__.py create mode 100644 api/tests/unit_tests/enterprise/telemetry/test_contracts.py create mode 100644 api/tests/unit_tests/enterprise/telemetry/test_event_handlers.py create mode 100644 api/tests/unit_tests/enterprise/telemetry/test_exporter.py create mode 100644 api/tests/unit_tests/enterprise/telemetry/test_gateway.py create mode 100644 api/tests/unit_tests/enterprise/telemetry/test_metric_handler.py create mode 100644 api/tests/unit_tests/tasks/test_enterprise_telemetry_task.py diff --git a/api/.ruff.toml b/api/.ruff.toml index 3301452ad9..64a461443b 100644 --- a/api/.ruff.toml +++ b/api/.ruff.toml @@ -106,10 +106,10 @@ ignore = [ "N803", # invalid-argument-name ] "tests/*" = [ - "F811", # redefined-while-unused - "T201", # allow print in tests, - "S110", # allow ignoring exceptions in tests code (currently) - + "F811", # redefined-while-unused + "T201", # allow print in tests, + "S110", # allow ignoring exceptions in tests code (currently) + "PT019", # @patch-injected params look like unused fixtures ] "controllers/console/explore/trial.py" = ["TID251"] "controllers/console/human_input_form.py" = ["TID251"] diff --git a/api/README.md b/api/README.md index 9d89b490b0..9f247da8f0 100644 --- a/api/README.md +++ b/api/README.md @@ -122,7 +122,8 @@ These commands assume you start from the repository root. ```bash cd api - uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention + # Note: enterprise_telemetry queue is only used in Enterprise Edition + uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,enterprise_telemetry ``` 1. Optional: start Celery Beat (scheduled tasks, in a new terminal). diff --git a/api/app_factory.py b/api/app_factory.py index dcbc821687..11568f139f 100644 --- a/api/app_factory.py +++ b/api/app_factory.py @@ -81,6 +81,7 @@ def initialize_extensions(app: DifyApp): ext_commands, ext_compress, ext_database, + ext_enterprise_telemetry, ext_fastopenapi, ext_forward_refs, ext_hosting_provider, @@ -131,6 +132,7 @@ def initialize_extensions(app: DifyApp): ext_commands, ext_fastopenapi, ext_otel, + ext_enterprise_telemetry, ext_request_logging, ext_session_factory, ] diff --git a/api/configs/app_config.py b/api/configs/app_config.py index d3b1cf9d5b..831f0a49e0 100644 --- a/api/configs/app_config.py +++ b/api/configs/app_config.py @@ -8,7 +8,7 @@ from pydantic_settings import BaseSettings, PydanticBaseSettingsSource, Settings from libs.file_utils import search_file_upwards from .deploy import DeploymentConfig -from .enterprise import EnterpriseFeatureConfig +from .enterprise import EnterpriseFeatureConfig, EnterpriseTelemetryConfig from .extra import ExtraServiceConfig from .feature import FeatureConfig from .middleware import MiddlewareConfig @@ -73,6 +73,8 @@ class DifyConfig( # Enterprise feature configs # **Before using, please contact business@dify.ai by email to inquire about licensing matters.** EnterpriseFeatureConfig, + # Enterprise telemetry configs + EnterpriseTelemetryConfig, ): model_config = SettingsConfigDict( # read from dotenv format config file diff --git a/api/configs/enterprise/__init__.py b/api/configs/enterprise/__init__.py index eda6345e14..11a71e1537 100644 --- a/api/configs/enterprise/__init__.py +++ b/api/configs/enterprise/__init__.py @@ -18,3 +18,49 @@ class EnterpriseFeatureConfig(BaseSettings): description="Allow customization of the enterprise logo.", default=False, ) + + +class EnterpriseTelemetryConfig(BaseSettings): + """ + Configuration for enterprise telemetry. + """ + + ENTERPRISE_TELEMETRY_ENABLED: bool = Field( + description="Enable enterprise telemetry collection (also requires ENTERPRISE_ENABLED=true).", + default=False, + ) + + ENTERPRISE_OTLP_ENDPOINT: str = Field( + description="Enterprise OTEL collector endpoint.", + default="", + ) + + ENTERPRISE_OTLP_HEADERS: str = Field( + description="Auth headers for OTLP export (key=value,key2=value2).", + default="", + ) + + ENTERPRISE_OTLP_PROTOCOL: str = Field( + description="OTLP protocol: 'http' or 'grpc' (default: http).", + default="http", + ) + + ENTERPRISE_OTLP_API_KEY: str = Field( + description="Bearer token for enterprise OTLP export authentication.", + default="", + ) + + ENTERPRISE_INCLUDE_CONTENT: bool = Field( + description="Include input/output content in traces (privacy toggle).", + default=True, + ) + + ENTERPRISE_SERVICE_NAME: str = Field( + description="Service name for OTEL resource.", + default="dify", + ) + + ENTERPRISE_OTEL_SAMPLING_RATE: float = Field( + description="Sampling rate for enterprise traces (0.0 to 1.0, default 1.0 = 100%).", + default=1.0, + ) diff --git a/api/controllers/console/app/generator.py b/api/controllers/console/app/generator.py index 1ac55b5e8d..95329ab213 100644 --- a/api/controllers/console/app/generator.py +++ b/api/controllers/console/app/generator.py @@ -35,6 +35,7 @@ class InstructionGeneratePayload(BaseModel): instruction: str = Field(..., description="Instruction for generation") model_config_data: ModelConfig = Field(..., alias="model_config", description="Model configuration") ideal_output: str = Field(default="", description="Expected ideal output") + app_id: str | None = Field(default=None, description="App ID for prompt generation tracing") class InstructionTemplatePayload(BaseModel): @@ -50,7 +51,6 @@ reg(RuleCodeGeneratePayload) reg(RuleStructuredOutputPayload) reg(InstructionGeneratePayload) reg(InstructionTemplatePayload) -reg(ModelConfig) @console_ns.route("/rule-generate") @@ -66,10 +66,17 @@ class RuleGenerateApi(Resource): @account_initialization_required def post(self): args = RuleGeneratePayload.model_validate(console_ns.payload) - _, current_tenant_id = current_account_with_tenant() + account, current_tenant_id = current_account_with_tenant() try: - rules = LLMGenerator.generate_rule_config(tenant_id=current_tenant_id, args=args) + rules = LLMGenerator.generate_rule_config( + tenant_id=current_tenant_id, + instruction=args.instruction, + model_config=args.model_config_data, + no_variable=args.no_variable, + user_id=account.id, + app_id=args.app_id, + ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) except QuotaExceededError: @@ -95,12 +102,16 @@ class RuleCodeGenerateApi(Resource): @account_initialization_required def post(self): args = RuleCodeGeneratePayload.model_validate(console_ns.payload) - _, current_tenant_id = current_account_with_tenant() + account, current_tenant_id = current_account_with_tenant() try: code_result = LLMGenerator.generate_code( tenant_id=current_tenant_id, - args=args, + instruction=args.instruction, + model_config=args.model_config_data, + code_language=args.code_language, + user_id=account.id, + app_id=args.app_id, ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) @@ -127,12 +138,15 @@ class RuleStructuredOutputGenerateApi(Resource): @account_initialization_required def post(self): args = RuleStructuredOutputPayload.model_validate(console_ns.payload) - _, current_tenant_id = current_account_with_tenant() + account, current_tenant_id = current_account_with_tenant() try: structured_output = LLMGenerator.generate_structured_output( tenant_id=current_tenant_id, - args=args, + instruction=args.instruction, + model_config=args.model_config_data, + user_id=account.id, + app_id=args.app_id, ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) @@ -159,14 +173,14 @@ class InstructionGenerateApi(Resource): @account_initialization_required def post(self): args = InstructionGeneratePayload.model_validate(console_ns.payload) - _, current_tenant_id = current_account_with_tenant() + account, current_tenant_id = current_account_with_tenant() + app_id = args.app_id or args.flow_id providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider] code_provider: type[CodeNodeProvider] | None = next( (p for p in providers if p.is_accept_language(args.language)), None ) code_template = code_provider.get_default_code() if code_provider else "" try: - # Generate from nothing for a workflow node if (args.current in (code_template, "")) and args.node_id != "": app = db.session.query(App).where(App.id == args.flow_id).first() if not app: @@ -183,33 +197,33 @@ class InstructionGenerateApi(Resource): case "llm": return LLMGenerator.generate_rule_config( current_tenant_id, - args=RuleGeneratePayload( - instruction=args.instruction, - model_config=args.model_config_data, - no_variable=True, - ), + instruction=args.instruction, + model_config=args.model_config_data, + no_variable=True, + user_id=account.id, + app_id=app_id, ) case "agent": return LLMGenerator.generate_rule_config( current_tenant_id, - args=RuleGeneratePayload( - instruction=args.instruction, - model_config=args.model_config_data, - no_variable=True, - ), + instruction=args.instruction, + model_config=args.model_config_data, + no_variable=True, + user_id=account.id, + app_id=app_id, ) case "code": return LLMGenerator.generate_code( tenant_id=current_tenant_id, - args=RuleCodeGeneratePayload( - instruction=args.instruction, - model_config=args.model_config_data, - code_language=args.language, - ), + instruction=args.instruction, + model_config=args.model_config_data, + code_language=args.language, + user_id=account.id, + app_id=app_id, ) case _: return {"error": f"invalid node type: {node_type}"} - if args.node_id == "" and args.current != "": # For legacy app without a workflow + if args.node_id == "" and args.current != "": return LLMGenerator.instruction_modify_legacy( tenant_id=current_tenant_id, flow_id=args.flow_id, @@ -217,8 +231,10 @@ class InstructionGenerateApi(Resource): instruction=args.instruction, model_config=args.model_config_data, ideal_output=args.ideal_output, + user_id=account.id, + app_id=app_id, ) - if args.node_id != "" and args.current != "": # For workflow node + if args.node_id != "" and args.current != "": return LLMGenerator.instruction_modify_workflow( tenant_id=current_tenant_id, flow_id=args.flow_id, @@ -228,6 +244,8 @@ class InstructionGenerateApi(Resource): model_config=args.model_config_data, ideal_output=args.ideal_output, workflow_service=WorkflowService(), + user_id=account.id, + app_id=app_id, ) return {"error": "incompatible parameters"}, 400 except ProviderTokenNotInitError as ex: diff --git a/api/controllers/console/app/ops_trace.py b/api/controllers/console/app/ops_trace.py index cbcf513162..c5622c7006 100644 --- a/api/controllers/console/app/ops_trace.py +++ b/api/controllers/console/app/ops_trace.py @@ -1,6 +1,7 @@ from typing import Any from flask import request +from flask_login import current_user from flask_restx import Resource, fields from pydantic import BaseModel, Field from werkzeug.exceptions import BadRequest @@ -77,7 +78,10 @@ class TraceAppConfigApi(Resource): try: result = OpsService.create_tracing_app_config( - app_id=app_id, tracing_provider=args.tracing_provider, tracing_config=args.tracing_config + app_id=app_id, + tracing_provider=args.tracing_provider, + tracing_config=args.tracing_config, + account_id=current_user.id, ) if not result: raise TracingConfigIsExist() @@ -102,7 +106,10 @@ class TraceAppConfigApi(Resource): try: result = OpsService.update_tracing_app_config( - app_id=app_id, tracing_provider=args.tracing_provider, tracing_config=args.tracing_config + app_id=app_id, + tracing_provider=args.tracing_provider, + tracing_config=args.tracing_config, + account_id=current_user.id, ) if not result: raise TracingConfigNotExist() @@ -124,7 +131,9 @@ class TraceAppConfigApi(Resource): args = TraceProviderQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore try: - result = OpsService.delete_tracing_app_config(app_id=app_id, tracing_provider=args.tracing_provider) + result = OpsService.delete_tracing_app_config( + app_id=app_id, tracing_provider=args.tracing_provider, account_id=current_user.id + ) if not result: raise TracingConfigNotExist() return {"result": "success"}, 204 diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index 3c6d36afe4..9765d7f41c 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -79,7 +79,7 @@ class BaseAgentRunner(AppRunner): self.model_instance = model_instance # init callback - self.agent_callback = DifyAgentCallbackHandler() + self.agent_callback = DifyAgentCallbackHandler(tenant_id=tenant_id) # init dataset tools hit_callback = DatasetIndexToolCallbackHandler( queue_manager=queue_manager, diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index da1e9f19b6..d8123593ec 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -63,6 +63,8 @@ from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk from core.model_runtime.entities.llm_entities import LLMUsage from core.model_runtime.utils.encoders import jsonable_encoder from core.ops.ops_trace_manager import TraceQueueManager +from core.telemetry import TelemetryContext, TelemetryEvent, TraceTaskName +from core.telemetry import emit as telemetry_emit from core.workflow.enums import WorkflowExecutionStatus from core.workflow.nodes import NodeType from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory @@ -564,7 +566,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): **kwargs, ) -> Generator[StreamResponse, None, None]: """Handle stop events.""" - _ = trace_manager resolved_state = None if self._workflow_run_id: resolved_state = self._resolve_graph_runtime_state(graph_runtime_state) @@ -579,8 +580,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): ) with self._database_session() as session: - # Save message - self._save_message(session=session, graph_runtime_state=resolved_state) + self._save_message(session=session, graph_runtime_state=resolved_state, trace_manager=trace_manager) yield workflow_finish_resp elif event.stopped_by in ( @@ -589,8 +589,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): ): # When hitting input-moderation or annotation-reply, the workflow will not start with self._database_session() as session: - # Save message - self._save_message(session=session) + self._save_message(session=session, trace_manager=trace_manager) yield self._message_end_to_stream_response() @@ -599,6 +598,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): event: QueueAdvancedChatMessageEndEvent, *, graph_runtime_state: GraphRuntimeState | None = None, + trace_manager: TraceQueueManager | None = None, **kwargs, ) -> Generator[StreamResponse, None, None]: """Handle advanced chat message end events.""" @@ -616,7 +616,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): # Save message with self._database_session() as session: - self._save_message(session=session, graph_runtime_state=resolved_state) + self._save_message(session=session, graph_runtime_state=resolved_state, trace_manager=trace_manager) yield self._message_end_to_stream_response() @@ -770,7 +770,13 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): if self._conversation_name_generate_thread: logger.debug("Conversation name generation running as daemon thread") - def _save_message(self, *, session: Session, graph_runtime_state: GraphRuntimeState | None = None): + def _save_message( + self, + *, + session: Session, + graph_runtime_state: GraphRuntimeState | None = None, + trace_manager: TraceQueueManager | None = None, + ): message = self._get_message(session=session) # If there are assistant files, remove markdown image links from answer @@ -826,6 +832,22 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): ] session.add_all(message_files) + if trace_manager: + telemetry_emit( + TelemetryEvent( + name=TraceTaskName.MESSAGE_TRACE, + context=TelemetryContext( + tenant_id=self._application_generate_entity.app_config.tenant_id, + app_id=self._application_generate_entity.app_config.app_id, + ), + payload={ + "conversation_id": str(message.conversation_id), + "message_id": str(message.id), + }, + ), + trace_manager=trace_manager, + ) + def _seed_graph_runtime_state_from_queue_manager(self) -> None: """Bootstrap the cached runtime state from the queue manager when present.""" candidate = self._base_task_pipeline.queue_manager.graph_runtime_state diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index ee205ed153..5d04ae56e0 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -147,9 +147,12 @@ class WorkflowAppGenerator(BaseAppGenerator): inputs: Mapping[str, Any] = args["inputs"] - extras = { + extras: dict[str, Any] = { **extract_external_trace_id_from_args(args), } + parent_trace_context = args.get("_parent_trace_context") + if parent_trace_context: + extras["parent_trace_context"] = parent_trace_context workflow_run_id = str(uuid.uuid4()) # FIXME (Yeuoly): we need to remove the SKIP_PREPARE_USER_INPUTS_KEY from the args # trigger shouldn't prepare user inputs diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py index 833f32fc7d..08d3dec770 100644 --- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py @@ -54,10 +54,11 @@ from core.model_runtime.entities.message_entities import ( TextPromptMessageContent, ) from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.ops.entities.trace_entity import TraceTaskName -from core.ops.ops_trace_manager import TraceQueueManager, TraceTask +from core.ops.ops_trace_manager import TraceQueueManager from core.prompt.utils.prompt_message_util import PromptMessageUtil from core.prompt.utils.prompt_template_parser import PromptTemplateParser +from core.telemetry import TelemetryContext, TelemetryEvent, TraceTaskName +from core.telemetry import emit as telemetry_emit from core.tools.signature import sign_tool_file from events.message_event import message_was_created from extensions.ext_database import db @@ -412,10 +413,19 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): message.message_metadata = self._task_state.metadata.model_dump_json() if trace_manager: - trace_manager.add_trace_task( - TraceTask( - TraceTaskName.MESSAGE_TRACE, conversation_id=self._conversation_id, message_id=self._message_id - ) + telemetry_emit( + TelemetryEvent( + name=TraceTaskName.MESSAGE_TRACE, + context=TelemetryContext( + tenant_id=self._application_generate_entity.app_config.tenant_id, + app_id=self._application_generate_entity.app_config.app_id, + ), + payload={ + "conversation_id": self._conversation_id, + "message_id": self._message_id, + }, + ), + trace_manager=trace_manager, ) message_was_created.send( diff --git a/api/core/app/workflow/layers/persistence.py b/api/core/app/workflow/layers/persistence.py index 41052b4f52..fd7c19a71d 100644 --- a/api/core/app/workflow/layers/persistence.py +++ b/api/core/app/workflow/layers/persistence.py @@ -15,8 +15,7 @@ from datetime import datetime from typing import Any, Union from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity -from core.ops.entities.trace_entity import TraceTaskName -from core.ops.ops_trace_manager import TraceQueueManager, TraceTask +from core.ops.ops_trace_manager import TraceQueueManager from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID from core.workflow.entities import WorkflowExecution, WorkflowNodeExecution from core.workflow.enums import ( @@ -373,6 +372,7 @@ class WorkflowPersistenceLayer(GraphEngineLayer): self._workflow_node_execution_repository.save(domain_execution) self._workflow_node_execution_repository.save_execution_data(domain_execution) + self._enqueue_node_trace_task(domain_execution) def _fail_running_node_executions(self, *, error_message: str) -> None: now = naive_utc_now() @@ -390,17 +390,138 @@ class WorkflowPersistenceLayer(GraphEngineLayer): conversation_id = self._system_variables().get(SystemVariableKey.CONVERSATION_ID.value) external_trace_id = None + parent_trace_context = None if isinstance(self._application_generate_entity, (WorkflowAppGenerateEntity, AdvancedChatAppGenerateEntity)): external_trace_id = self._application_generate_entity.extras.get("external_trace_id") + parent_trace_context = self._application_generate_entity.extras.get("parent_trace_context") - trace_task = TraceTask( - TraceTaskName.WORKFLOW_TRACE, - workflow_execution=execution, - conversation_id=conversation_id, - user_id=self._trace_manager.user_id, - external_trace_id=external_trace_id, + from core.telemetry import TelemetryContext, TelemetryEvent, TraceTaskName + from core.telemetry import emit as telemetry_emit + + telemetry_emit( + TelemetryEvent( + name=TraceTaskName.WORKFLOW_TRACE, + context=TelemetryContext( + tenant_id=self._application_generate_entity.app_config.tenant_id, + user_id=self._trace_manager.user_id, + app_id=self._application_generate_entity.app_config.app_id, + ), + payload={ + "workflow_execution": execution, + "conversation_id": conversation_id, + "user_id": self._trace_manager.user_id, + "external_trace_id": external_trace_id, + "parent_trace_context": parent_trace_context, + }, + ), + trace_manager=self._trace_manager, + ) + + def _enqueue_node_trace_task(self, domain_execution: WorkflowNodeExecution) -> None: + if not self._trace_manager: + return + + execution = self._get_workflow_execution() + meta = domain_execution.metadata or {} + + parent_trace_context = None + if isinstance(self._application_generate_entity, (WorkflowAppGenerateEntity, AdvancedChatAppGenerateEntity)): + parent_trace_context = self._application_generate_entity.extras.get("parent_trace_context") + + node_data: dict[str, Any] = { + "workflow_id": domain_execution.workflow_id, + "workflow_execution_id": execution.id_, + "tenant_id": self._application_generate_entity.app_config.tenant_id, + "app_id": self._application_generate_entity.app_config.app_id, + "node_execution_id": domain_execution.id, + "node_id": domain_execution.node_id, + "node_type": str(domain_execution.node_type.value), + "title": domain_execution.title, + "status": str(domain_execution.status.value), + "error": domain_execution.error, + "elapsed_time": domain_execution.elapsed_time, + "index": domain_execution.index, + "predecessor_node_id": domain_execution.predecessor_node_id, + "created_at": domain_execution.created_at, + "finished_at": domain_execution.finished_at, + "total_tokens": meta.get(WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS, 0), + "prompt_tokens": meta.get(WorkflowNodeExecutionMetadataKey.PROMPT_TOKENS), + "completion_tokens": meta.get(WorkflowNodeExecutionMetadataKey.COMPLETION_TOKENS), + "total_price": meta.get(WorkflowNodeExecutionMetadataKey.TOTAL_PRICE, 0.0), + "currency": meta.get(WorkflowNodeExecutionMetadataKey.CURRENCY), + "tool_name": (meta.get(WorkflowNodeExecutionMetadataKey.TOOL_INFO) or {}).get("tool_name") + if isinstance(meta.get(WorkflowNodeExecutionMetadataKey.TOOL_INFO), dict) + else None, + "iteration_id": meta.get(WorkflowNodeExecutionMetadataKey.ITERATION_ID), + "iteration_index": meta.get(WorkflowNodeExecutionMetadataKey.ITERATION_INDEX), + "loop_id": meta.get(WorkflowNodeExecutionMetadataKey.LOOP_ID), + "loop_index": meta.get(WorkflowNodeExecutionMetadataKey.LOOP_INDEX), + "parallel_id": meta.get(WorkflowNodeExecutionMetadataKey.PARALLEL_ID), + "node_inputs": dict(domain_execution.inputs) if domain_execution.inputs else None, + "node_outputs": dict(domain_execution.outputs) if domain_execution.outputs else None, + "process_data": dict(domain_execution.process_data) if domain_execution.process_data else None, + } + node_data["invoke_from"] = self._application_generate_entity.invoke_from.value + node_data["user_id"] = self._system_variables().get(SystemVariableKey.USER_ID.value) + + # Extract model info from process_data — LLM nodes store provider/model there, + if domain_execution.process_data: + if mp := domain_execution.process_data.get("model_provider"): + node_data["model_provider"] = mp + if mn := domain_execution.process_data.get("model_name"): + node_data["model_name"] = mn + + if domain_execution.node_type.value == "knowledge-retrieval" and domain_execution.outputs: + results = domain_execution.outputs.get("result") or [] + dataset_ids: list[str] = [] + dataset_names: list[str] = [] + for doc in results: + if not isinstance(doc, dict): + continue + doc_meta = doc.get("metadata") or {} + did = doc_meta.get("dataset_id") + dname = doc_meta.get("dataset_name") + if did and did not in dataset_ids: + dataset_ids.append(did) + if dname and dname not in dataset_names: + dataset_names.append(dname) + if dataset_ids: + node_data["dataset_ids"] = dataset_ids + if dataset_names: + node_data["dataset_names"] = dataset_names + + tool_info = meta.get(WorkflowNodeExecutionMetadataKey.TOOL_INFO) + if isinstance(tool_info, dict): + plugin_id = tool_info.get("plugin_unique_identifier") + if plugin_id: + node_data["plugin_name"] = plugin_id + credential_id = tool_info.get("credential_id") + if credential_id: + node_data["credential_id"] = credential_id + node_data["credential_provider_type"] = tool_info.get("provider_type") + + conversation_id = self._system_variables().get(SystemVariableKey.CONVERSATION_ID.value) + if conversation_id: + node_data["conversation_id"] = conversation_id + + if parent_trace_context: + node_data["parent_trace_context"] = parent_trace_context + + from core.telemetry import TelemetryContext, TelemetryEvent, TraceTaskName + from core.telemetry import emit as telemetry_emit + + telemetry_emit( + TelemetryEvent( + name=TraceTaskName.NODE_EXECUTION_TRACE, + context=TelemetryContext( + tenant_id=node_data.get("tenant_id"), + user_id=node_data.get("user_id"), + app_id=node_data.get("app_id"), + ), + payload={"node_execution_data": node_data}, + ), + trace_manager=self._trace_manager, ) - self._trace_manager.add_trace_task(trace_task) def _system_variables(self) -> Mapping[str, Any]: runtime_state = self.graph_runtime_state diff --git a/api/core/callback_handler/agent_tool_callback_handler.py b/api/core/callback_handler/agent_tool_callback_handler.py index 6591b08a7e..e1c5f4ac4b 100644 --- a/api/core/callback_handler/agent_tool_callback_handler.py +++ b/api/core/callback_handler/agent_tool_callback_handler.py @@ -4,8 +4,9 @@ from typing import Any, TextIO, Union from pydantic import BaseModel from configs import dify_config -from core.ops.entities.trace_entity import TraceTaskName -from core.ops.ops_trace_manager import TraceQueueManager, TraceTask +from core.ops.ops_trace_manager import TraceQueueManager +from core.telemetry import TelemetryContext, TelemetryEvent, TraceTaskName +from core.telemetry import emit as telemetry_emit from core.tools.entities.tool_entities import ToolInvokeMessage _TEXT_COLOR_MAPPING = { @@ -36,13 +37,15 @@ class DifyAgentCallbackHandler(BaseModel): color: str | None = "" current_loop: int = 1 + tenant_id: str | None = None - def __init__(self, color: str | None = None): + def __init__(self, color: str | None = None, tenant_id: str | None = None): super().__init__() """Initialize callback handler.""" # use a specific color is not specified self.color = color or "green" self.current_loop = 1 + self.tenant_id = tenant_id def on_tool_start( self, @@ -71,15 +74,23 @@ class DifyAgentCallbackHandler(BaseModel): print_text("\n") if trace_manager: - trace_manager.add_trace_task( - TraceTask( - TraceTaskName.TOOL_TRACE, - message_id=message_id, - tool_name=tool_name, - tool_inputs=tool_inputs, - tool_outputs=tool_outputs, - timer=timer, - ) + telemetry_emit( + TelemetryEvent( + name=TraceTaskName.TOOL_TRACE, + context=TelemetryContext( + tenant_id=self.tenant_id, + app_id=trace_manager.app_id, + user_id=trace_manager.user_id, + ), + payload={ + "message_id": message_id, + "tool_name": tool_name, + "tool_inputs": tool_inputs, + "tool_outputs": tool_outputs, + "timer": timer, + }, + ), + trace_manager=trace_manager, ) def on_tool_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any): diff --git a/api/core/llm_generator/entities.py b/api/core/llm_generator/entities.py index 3bb8d2c899..6573bcbe95 100644 --- a/api/core/llm_generator/entities.py +++ b/api/core/llm_generator/entities.py @@ -9,6 +9,7 @@ class RuleGeneratePayload(BaseModel): instruction: str = Field(..., description="Rule generation instruction") model_config_data: ModelConfig = Field(..., alias="model_config", description="Model configuration") no_variable: bool = Field(default=False, description="Whether to exclude variables") + app_id: str | None = Field(default=None, description="App ID for prompt generation tracing") class RuleCodeGeneratePayload(RuleGeneratePayload): @@ -18,3 +19,4 @@ class RuleCodeGeneratePayload(RuleGeneratePayload): class RuleStructuredOutputPayload(BaseModel): instruction: str = Field(..., description="Structured output generation instruction") model_config_data: ModelConfig = Field(..., alias="model_config", description="Model configuration") + app_id: str | None = Field(default=None, description="App ID for prompt generation tracing") diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py index 5b2c640265..0093b17cba 100644 --- a/api/core/llm_generator/llm_generator.py +++ b/api/core/llm_generator/llm_generator.py @@ -7,7 +7,6 @@ from typing import Protocol, cast import json_repair from core.app.app_config.entities import ModelConfig -from core.llm_generator.entities import RuleCodeGeneratePayload, RuleGeneratePayload, RuleStructuredOutputPayload from core.llm_generator.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser from core.llm_generator.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser from core.llm_generator.prompts import ( @@ -27,10 +26,11 @@ from core.model_runtime.entities.llm_entities import LLMResult from core.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError -from core.ops.entities.trace_entity import TraceTaskName -from core.ops.ops_trace_manager import TraceQueueManager, TraceTask +from core.ops.entities.trace_entity import OperationType from core.ops.utils import measure_time from core.prompt.utils.prompt_template_parser import PromptTemplateParser +from core.telemetry import TelemetryContext, TelemetryEvent, TraceTaskName +from core.telemetry import emit as telemetry_emit from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey from extensions.ext_database import db from extensions.ext_storage import storage @@ -74,7 +74,7 @@ class LLMGenerator: prompt_messages=list(prompts), model_parameters={"max_tokens": 500, "temperature": 1}, stream=False ) answer = response.message.get_text_content() - if answer == "": + if answer is None: return "" try: result_dict = json.loads(answer) @@ -96,15 +96,17 @@ class LLMGenerator: name = name[:75] + "..." # get tracing instance - trace_manager = TraceQueueManager(app_id=app_id) - trace_manager.add_trace_task( - TraceTask( - TraceTaskName.GENERATE_NAME_TRACE, - conversation_id=conversation_id, - generate_conversation_name=name, - inputs=prompt, - timer=timer, - tenant_id=tenant_id, + telemetry_emit( + TelemetryEvent( + name=TraceTaskName.GENERATE_NAME_TRACE, + context=TelemetryContext(tenant_id=tenant_id, app_id=app_id), + payload={ + "conversation_id": conversation_id, + "generate_conversation_name": name, + "inputs": prompt, + "timer": timer, + "tenant_id": tenant_id, + }, ) ) @@ -153,19 +155,27 @@ class LLMGenerator: return questions @classmethod - def generate_rule_config(cls, tenant_id: str, args: RuleGeneratePayload): + def generate_rule_config( + cls, + tenant_id: str, + instruction: str, + model_config: ModelConfig, + no_variable: bool, + user_id: str | None = None, + app_id: str | None = None, + ): output_parser = RuleConfigGeneratorOutputParser() error = "" error_step = "" rule_config = {"prompt": "", "variables": [], "opening_statement": "", "error": ""} - model_parameters = args.model_config_data.completion_params - if args.no_variable: + model_parameters = model_config.completion_params + if no_variable: prompt_template = PromptTemplateParser(WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE) prompt_generate = prompt_template.format( inputs={ - "TASK_DESCRIPTION": args.instruction, + "TASK_DESCRIPTION": instruction, }, remove_template_variables=False, ) @@ -177,26 +187,45 @@ class LLMGenerator: model_instance = model_manager.get_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, - provider=args.model_config_data.provider, - model=args.model_config_data.name, + provider=model_config.provider, + model=model_config.name, ) - try: - response: LLMResult = model_instance.invoke_llm( - prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False - ) + llm_result = None + with measure_time() as timer: + try: + llm_result = model_instance.invoke_llm( + prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False + ) - rule_config["prompt"] = response.message.get_text_content() + rule_config["prompt"] = llm_result.message.get_text_content() or "" - except InvokeError as e: - error = str(e) - error_step = "generate rule config" - except Exception as e: - logger.exception("Failed to generate rule config, model: %s", args.model_config_data.name) - rule_config["error"] = str(e) + except InvokeError as e: + error = str(e) + error_step = "generate rule config" + except Exception as e: + logger.exception("Failed to generate rule config, model: %s", model_config.name) + rule_config["error"] = str(e) + error = str(e) rule_config["error"] = f"Failed to {error_step}. Error: {error}" if error else "" + if user_id: + prompt_value = rule_config.get("prompt", "") + generated_output = str(prompt_value) if prompt_value else "" + cls._emit_prompt_generation_trace( + tenant_id=tenant_id, + user_id=user_id, + app_id=app_id, + operation_type=OperationType.RULE_GENERATE, + instruction=instruction, + generated_output=generated_output, + llm_result=llm_result, + model_config=model_config, + timer=timer, + error=error or None, + ) + return rule_config # get rule config prompt, parameter and statement @@ -211,7 +240,7 @@ class LLMGenerator: # format the prompt_generate_prompt prompt_generate_prompt = prompt_template.format( inputs={ - "TASK_DESCRIPTION": args.instruction, + "TASK_DESCRIPTION": instruction, }, remove_template_variables=False, ) @@ -222,84 +251,125 @@ class LLMGenerator: model_instance = model_manager.get_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, - provider=args.model_config_data.provider, - model=args.model_config_data.name, + provider=model_config.provider, + model=model_config.name, ) - try: + llm_result = None + with measure_time() as timer: try: - # the first step to generate the task prompt - prompt_content: LLMResult = model_instance.invoke_llm( - prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False + try: + # the first step to generate the task prompt + prompt_content: LLMResult = model_instance.invoke_llm( + prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False + ) + llm_result = prompt_content + except InvokeError as e: + error = str(e) + error_step = "generate prefix prompt" + rule_config["error"] = f"Failed to {error_step}. Error: {error}" if error else "" + + if user_id: + cls._emit_prompt_generation_trace( + tenant_id=tenant_id, + user_id=user_id, + app_id=app_id, + operation_type=OperationType.RULE_GENERATE, + instruction=instruction, + generated_output="", + llm_result=llm_result, + model_config=model_config, + timer=timer, + error=error, + ) + + return rule_config + + rule_config["prompt"] = prompt_content.message.get_text_content() or "" + + if not isinstance(prompt_content.message.content, str): + raise NotImplementedError("prompt content is not a string") + parameter_generate_prompt = parameter_template.format( + inputs={ + "INPUT_TEXT": prompt_content.message.content, + }, + remove_template_variables=False, ) - except InvokeError as e: - error = str(e) - error_step = "generate prefix prompt" - rule_config["error"] = f"Failed to {error_step}. Error: {error}" if error else "" + parameter_messages = [UserPromptMessage(content=parameter_generate_prompt)] - return rule_config - - rule_config["prompt"] = prompt_content.message.get_text_content() - - parameter_generate_prompt = parameter_template.format( - inputs={ - "INPUT_TEXT": prompt_content.message.get_text_content(), - }, - remove_template_variables=False, - ) - parameter_messages = [UserPromptMessage(content=parameter_generate_prompt)] - - # the second step to generate the task_parameter and task_statement - statement_generate_prompt = statement_template.format( - inputs={ - "TASK_DESCRIPTION": args.instruction, - "INPUT_TEXT": prompt_content.message.get_text_content(), - }, - remove_template_variables=False, - ) - statement_messages = [UserPromptMessage(content=statement_generate_prompt)] - - try: - parameter_content: LLMResult = model_instance.invoke_llm( - prompt_messages=list(parameter_messages), model_parameters=model_parameters, stream=False + # the second step to generate the task_parameter and task_statement + statement_generate_prompt = statement_template.format( + inputs={ + "TASK_DESCRIPTION": instruction, + "INPUT_TEXT": prompt_content.message.content, + }, + remove_template_variables=False, ) - rule_config["variables"] = re.findall(r'"\s*([^"]+)\s*"', parameter_content.message.get_text_content()) - except InvokeError as e: - error = str(e) - error_step = "generate variables" + statement_messages = [UserPromptMessage(content=statement_generate_prompt)] - try: - statement_content: LLMResult = model_instance.invoke_llm( - prompt_messages=list(statement_messages), model_parameters=model_parameters, stream=False - ) - rule_config["opening_statement"] = statement_content.message.get_text_content() - except InvokeError as e: - error = str(e) - error_step = "generate conversation opener" + try: + parameter_content: LLMResult = model_instance.invoke_llm( + prompt_messages=list(parameter_messages), model_parameters=model_parameters, stream=False + ) + rule_config["variables"] = re.findall( + r'"\s*([^"]+)\s*"', prompt_content.message.get_text_content() or "" + ) + except InvokeError as e: + error = str(e) + error_step = "generate variables" - except Exception as e: - logger.exception("Failed to generate rule config, model: %s", args.model_config_data.name) - rule_config["error"] = str(e) + try: + statement_content: LLMResult = model_instance.invoke_llm( + prompt_messages=list(statement_messages), model_parameters=model_parameters, stream=False + ) + rule_config["opening_statement"] = statement_content.message.get_text_content() or "" + except InvokeError as e: + error = str(e) + error_step = "generate conversation opener" + + except Exception as e: + logger.exception("Failed to generate rule config, model: %s", model_config.name) + rule_config["error"] = str(e) + error = str(e) rule_config["error"] = f"Failed to {error_step}. Error: {error}" if error else "" + if user_id: + generated_output = rule_config.get("prompt", "") + cls._emit_prompt_generation_trace( + tenant_id=tenant_id, + user_id=user_id, + app_id=app_id, + operation_type=OperationType.RULE_GENERATE, + instruction=instruction, + generated_output=str(generated_output) if generated_output else "", + llm_result=llm_result, + model_config=model_config, + timer=timer, + error=error or None, + ) + return rule_config @classmethod def generate_code( cls, tenant_id: str, - args: RuleCodeGeneratePayload, + instruction: str, + model_config: ModelConfig, + code_language: str = "javascript", + user_id: str | None = None, + app_id: str | None = None, ): - if args.code_language == "python": + if code_language == "python": prompt_template = PromptTemplateParser(PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE) else: prompt_template = PromptTemplateParser(JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE) prompt = prompt_template.format( inputs={ - "INSTRUCTION": args.instruction, - "CODE_LANGUAGE": args.code_language, + "INSTRUCTION": instruction, + "CODE_LANGUAGE": code_language, }, remove_template_variables=False, ) @@ -308,28 +378,49 @@ class LLMGenerator: model_instance = model_manager.get_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, - provider=args.model_config_data.provider, - model=args.model_config_data.name, + provider=model_config.provider, + model=model_config.name, ) prompt_messages = [UserPromptMessage(content=prompt)] - model_parameters = args.model_config_data.completion_params - try: - response: LLMResult = model_instance.invoke_llm( - prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False + model_parameters = model_config.completion_params + + llm_result = None + error = None + with measure_time() as timer: + try: + llm_result = model_instance.invoke_llm( + prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False + ) + + generated_code = llm_result.message.get_text_content() or "" + result = {"code": generated_code, "language": code_language, "error": ""} + + except InvokeError as e: + error = str(e) + result = {"code": "", "language": code_language, "error": f"Failed to generate code. Error: {error}"} + except Exception as e: + logger.exception( + "Failed to invoke LLM model, model: %s, language: %s", model_config.name, code_language + ) + error = str(e) + result = {"code": "", "language": code_language, "error": f"An unexpected error occurred: {str(e)}"} + + if user_id: + cls._emit_prompt_generation_trace( + tenant_id=tenant_id, + user_id=user_id, + app_id=app_id, + operation_type=OperationType.CODE_GENERATE, + instruction=instruction, + generated_output=result.get("code", ""), + llm_result=llm_result, + model_config=model_config, + timer=timer, + error=error, ) - generated_code = response.message.get_text_content() - return {"code": generated_code, "language": args.code_language, "error": ""} - - except InvokeError as e: - error = str(e) - return {"code": "", "language": args.code_language, "error": f"Failed to generate code. Error: {error}"} - except Exception as e: - logger.exception( - "Failed to invoke LLM model, model: %s, language: %s", args.model_config_data.name, args.code_language - ) - return {"code": "", "language": args.code_language, "error": f"An unexpected error occurred: {str(e)}"} + return result @classmethod def generate_qa_document(cls, tenant_id: str, query, document_language: str): @@ -355,49 +446,81 @@ class LLMGenerator: raise TypeError("Expected LLMResult when stream=False") response = result - answer = response.message.get_text_content() + answer = response.message.get_text_content() or "" return answer.strip() @classmethod - def generate_structured_output(cls, tenant_id: str, args: RuleStructuredOutputPayload): + def generate_structured_output( + cls, + tenant_id: str, + instruction: str, + model_config: ModelConfig, + user_id: str | None = None, + app_id: str | None = None, + ): model_manager = ModelManager() model_instance = model_manager.get_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, - provider=args.model_config_data.provider, - model=args.model_config_data.name, + provider=model_config.provider, + model=model_config.name, ) prompt_messages = [ SystemPromptMessage(content=SYSTEM_STRUCTURED_OUTPUT_GENERATE), - UserPromptMessage(content=args.instruction), + UserPromptMessage(content=instruction), ] - model_parameters = args.model_config_data.completion_params + model_parameters = model_config.completion_params - try: - response: LLMResult = model_instance.invoke_llm( - prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False + llm_result = None + error = None + result = {"output": "", "error": ""} + + with measure_time() as timer: + try: + llm_result = model_instance.invoke_llm( + prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False + ) + + raw_content = llm_result.message.content + + if not isinstance(raw_content, str): + raise ValueError(f"LLM response content must be a string, got: {type(raw_content)}") + + try: + parsed_content = json.loads(raw_content) + except json.JSONDecodeError: + parsed_content = json_repair.loads(raw_content) + + if not isinstance(parsed_content, dict | list): + raise ValueError(f"Failed to parse structured output from llm: {raw_content}") + + generated_json_schema = json.dumps(parsed_content, indent=2, ensure_ascii=False) + result = {"output": generated_json_schema, "error": ""} + + except InvokeError as e: + error = str(e) + result = {"output": "", "error": f"Failed to generate JSON Schema. Error: {error}"} + except Exception as e: + logger.exception("Failed to invoke LLM model, model: %s", model_config.name) + error = str(e) + result = {"output": "", "error": f"An unexpected error occurred: {str(e)}"} + + if user_id: + cls._emit_prompt_generation_trace( + tenant_id=tenant_id, + user_id=user_id, + app_id=app_id, + operation_type=OperationType.STRUCTURED_OUTPUT, + instruction=instruction, + generated_output=result.get("output", ""), + llm_result=llm_result, + model_config=model_config, + timer=timer, + error=error, ) - raw_content = response.message.get_text_content() - - try: - parsed_content = json.loads(raw_content) - except json.JSONDecodeError: - parsed_content = json_repair.loads(raw_content) - - if not isinstance(parsed_content, dict | list): - raise ValueError(f"Failed to parse structured output from llm: {raw_content}") - - generated_json_schema = json.dumps(parsed_content, indent=2, ensure_ascii=False) - return {"output": generated_json_schema, "error": ""} - - except InvokeError as e: - error = str(e) - return {"output": "", "error": f"Failed to generate JSON Schema. Error: {error}"} - except Exception as e: - logger.exception("Failed to invoke LLM model, model: %s", args.model_config_data.name) - return {"output": "", "error": f"An unexpected error occurred: {str(e)}"} + return result @staticmethod def instruction_modify_legacy( @@ -407,12 +530,14 @@ class LLMGenerator: instruction: str, model_config: ModelConfig, ideal_output: str | None, + user_id: str | None = None, + app_id: str | None = None, ): last_run: Message | None = ( db.session.query(Message).where(Message.app_id == flow_id).order_by(Message.created_at.desc()).first() ) if not last_run: - return LLMGenerator.__instruction_modify_common( + result = LLMGenerator.__instruction_modify_common( tenant_id=tenant_id, model_config=model_config, last_run=None, @@ -421,22 +546,28 @@ class LLMGenerator: instruction=instruction, node_type="llm", ideal_output=ideal_output, + user_id=user_id, + app_id=app_id, ) - last_run_dict = { - "query": last_run.query, - "answer": last_run.answer, - "error": last_run.error, - } - return LLMGenerator.__instruction_modify_common( - tenant_id=tenant_id, - model_config=model_config, - last_run=last_run_dict, - current=current, - error_message=str(last_run.error), - instruction=instruction, - node_type="llm", - ideal_output=ideal_output, - ) + else: + last_run_dict = { + "query": last_run.query, + "answer": last_run.answer, + "error": last_run.error, + } + result = LLMGenerator.__instruction_modify_common( + tenant_id=tenant_id, + model_config=model_config, + last_run=last_run_dict, + current=current, + error_message=str(last_run.error), + instruction=instruction, + node_type="llm", + ideal_output=ideal_output, + user_id=user_id, + app_id=app_id, + ) + return result @staticmethod def instruction_modify_workflow( @@ -448,6 +579,8 @@ class LLMGenerator: model_config: ModelConfig, ideal_output: str | None, workflow_service: WorkflowServiceInterface, + user_id: str | None = None, + app_id: str | None = None, ): session = db.session() @@ -478,6 +611,8 @@ class LLMGenerator: instruction=instruction, node_type=node_type, ideal_output=ideal_output, + user_id=user_id, + app_id=app_id, ) def agent_log_of(node_execution: WorkflowNodeExecutionModel) -> Sequence: @@ -511,6 +646,8 @@ class LLMGenerator: instruction=instruction, node_type=last_run.node_type, ideal_output=ideal_output, + user_id=user_id, + app_id=app_id, ) @staticmethod @@ -523,6 +660,8 @@ class LLMGenerator: instruction: str, node_type: str, ideal_output: str | None, + user_id: str | None = None, + app_id: str | None = None, ): LAST_RUN = "{{#last_run#}}" CURRENT = "{{#current#}}" @@ -562,24 +701,120 @@ class LLMGenerator: ] model_parameters = {"temperature": 0.4} - try: - response: LLMResult = model_instance.invoke_llm( - prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False + llm_result = None + error = None + result = {} + + with measure_time() as timer: + try: + llm_result = model_instance.invoke_llm( + prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False + ) + + generated_raw = llm_result.message.get_text_content() + first_brace = generated_raw.find("{") + last_brace = generated_raw.rfind("}") + if first_brace == -1 or last_brace == -1 or last_brace < first_brace: + raise ValueError(f"Could not find a valid JSON object in response: {generated_raw}") + json_str = generated_raw[first_brace : last_brace + 1] + data = json_repair.loads(json_str) + if not isinstance(data, dict): + raise TypeError(f"Expected a JSON object, but got {type(data).__name__}") + result = data + except InvokeError as e: + error = str(e) + result = {"error": f"Failed to generate code. Error: {error}"} + except Exception as e: + logger.exception("Failed to invoke LLM model, model: %s", json.dumps(model_config.name), exc_info=True) + error = str(e) + result = {"error": f"An unexpected error occurred: {str(e)}"} + + if user_id: + generated_output = "" + if isinstance(result, dict): + for key in ["prompt", "code", "output", "modified"]: + if result.get(key): + generated_output = str(result[key]) + break + + LLMGenerator._emit_prompt_generation_trace( + tenant_id=tenant_id, + user_id=user_id, + app_id=app_id, + operation_type=OperationType.INSTRUCTION_MODIFY, + instruction=instruction, + generated_output=generated_output, + llm_result=llm_result, + model_config=model_config, + timer=timer, + error=error, ) - generated_raw = response.message.get_text_content() - first_brace = generated_raw.find("{") - last_brace = generated_raw.rfind("}") - if first_brace == -1 or last_brace == -1 or last_brace < first_brace: - raise ValueError(f"Could not find a valid JSON object in response: {generated_raw}") - json_str = generated_raw[first_brace : last_brace + 1] - data = json_repair.loads(json_str) - if not isinstance(data, dict): - raise TypeError(f"Expected a JSON object, but got {type(data).__name__}") - return data - except InvokeError as e: - error = str(e) - return {"error": f"Failed to generate code. Error: {error}"} - except Exception as e: - logger.exception("Failed to invoke LLM model, model: %s", json.dumps(model_config.name), exc_info=True) - return {"error": f"An unexpected error occurred: {str(e)}"} + return result + + @classmethod + def _emit_prompt_generation_trace( + cls, + tenant_id: str, + user_id: str, + app_id: str | None, + operation_type: OperationType, + instruction: str, + generated_output: str, + llm_result: LLMResult | None, + model_config: ModelConfig | None = None, + timer=None, + error: str | None = None, + ): + if llm_result: + prompt_tokens = llm_result.usage.prompt_tokens + completion_tokens = llm_result.usage.completion_tokens + total_tokens = llm_result.usage.total_tokens + model_name = llm_result.model + # Extract provider from model_config if available, otherwise fall back to parsing model name + if model_config and model_config.provider: + model_provider = model_config.provider + else: + model_provider = model_name.split("/")[0] if "/" in model_name else "" + latency = llm_result.usage.latency + total_price = float(llm_result.usage.total_price) if llm_result.usage.total_price else None + currency = llm_result.usage.currency + else: + prompt_tokens = 0 + completion_tokens = 0 + total_tokens = 0 + model_provider = model_config.provider if model_config else "" + model_name = model_config.name if model_config else "" + latency = 0.0 + if timer: + start_time = timer.get("start") + end_time = timer.get("end") + if start_time and end_time: + latency = (end_time - start_time).total_seconds() + total_price = None + currency = None + + telemetry_emit( + TelemetryEvent( + name=TraceTaskName.PROMPT_GENERATION_TRACE, + context=TelemetryContext(tenant_id=tenant_id, user_id=user_id, app_id=app_id), + payload={ + "tenant_id": tenant_id, + "user_id": user_id, + "app_id": app_id, + "operation_type": operation_type, + "instruction": instruction, + "generated_output": generated_output, + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "total_tokens": total_tokens, + "model_provider": model_provider, + "model_name": model_name, + "latency": latency, + "total_price": total_price, + "currency": currency, + "timer": timer, + "error": error, + }, + ) + ) diff --git a/api/core/logging/filters.py b/api/core/logging/filters.py index 1e8aa8d566..bc816eb66b 100644 --- a/api/core/logging/filters.py +++ b/api/core/logging/filters.py @@ -15,16 +15,23 @@ class TraceContextFilter(logging.Filter): """ def filter(self, record: logging.LogRecord) -> bool: - # Get trace context from OpenTelemetry - trace_id, span_id = self._get_otel_context() + # Preserve explicit trace_id set by the caller (e.g. emit_metric_only_event) + existing_trace_id = getattr(record, "trace_id", "") + if not existing_trace_id: + # Get trace context from OpenTelemetry + trace_id, span_id = self._get_otel_context() - # Set trace_id (fallback to ContextVar if no OTEL context) - if trace_id: - record.trace_id = trace_id + # Set trace_id (fallback to ContextVar if no OTEL context) + if trace_id: + record.trace_id = trace_id + else: + record.trace_id = get_trace_id() + + record.span_id = span_id or "" else: - record.trace_id = get_trace_id() - - record.span_id = span_id or "" + # Keep existing trace_id; only fill span_id if missing + if not getattr(record, "span_id", ""): + record.span_id = "" # For backward compatibility, also set req_id record.req_id = get_request_id() @@ -55,9 +62,12 @@ class IdentityContextFilter(logging.Filter): def filter(self, record: logging.LogRecord) -> bool: identity = self._extract_identity() - record.tenant_id = identity.get("tenant_id", "") - record.user_id = identity.get("user_id", "") - record.user_type = identity.get("user_type", "") + if not getattr(record, "tenant_id", ""): + record.tenant_id = identity.get("tenant_id", "") + if not getattr(record, "user_id", ""): + record.user_id = identity.get("user_id", "") + if not getattr(record, "user_type", ""): + record.user_type = identity.get("user_type", "") return True def _extract_identity(self) -> dict[str, str]: diff --git a/api/core/moderation/input_moderation.py b/api/core/moderation/input_moderation.py index 21dc58f16f..4afe706a62 100644 --- a/api/core/moderation/input_moderation.py +++ b/api/core/moderation/input_moderation.py @@ -5,9 +5,10 @@ from typing import Any from core.app.app_config.entities import AppConfig from core.moderation.base import ModerationAction, ModerationError from core.moderation.factory import ModerationFactory -from core.ops.entities.trace_entity import TraceTaskName -from core.ops.ops_trace_manager import TraceQueueManager, TraceTask +from core.ops.ops_trace_manager import TraceQueueManager from core.ops.utils import measure_time +from core.telemetry import TelemetryContext, TelemetryEvent, TraceTaskName +from core.telemetry import emit as telemetry_emit logger = logging.getLogger(__name__) @@ -49,14 +50,18 @@ class InputModeration: moderation_result = moderation_factory.moderation_for_inputs(inputs, query) if trace_manager: - trace_manager.add_trace_task( - TraceTask( - TraceTaskName.MODERATION_TRACE, - message_id=message_id, - moderation_result=moderation_result, - inputs=inputs, - timer=timer, - ) + telemetry_emit( + TelemetryEvent( + name=TraceTaskName.MODERATION_TRACE, + context=TelemetryContext(tenant_id=tenant_id, app_id=app_id), + payload={ + "message_id": message_id, + "moderation_result": moderation_result, + "inputs": inputs, + "timer": timer, + }, + ), + trace_manager=trace_manager, ) if not moderation_result.flagged: diff --git a/api/core/ops/entities/trace_entity.py b/api/core/ops/entities/trace_entity.py index 50a2cdea63..a6ca1b098b 100644 --- a/api/core/ops/entities/trace_entity.py +++ b/api/core/ops/entities/trace_entity.py @@ -9,8 +9,8 @@ from pydantic import BaseModel, ConfigDict, field_serializer, field_validator class BaseTraceInfo(BaseModel): message_id: str | None = None message_data: Any | None = None - inputs: Union[str, dict[str, Any], list] | None = None - outputs: Union[str, dict[str, Any], list] | None = None + inputs: Union[str, dict[str, Any], list[Any]] | None = None + outputs: Union[str, dict[str, Any], list[Any]] | None = None start_time: datetime | None = None end_time: datetime | None = None metadata: dict[str, Any] @@ -18,7 +18,7 @@ class BaseTraceInfo(BaseModel): @field_validator("inputs", "outputs") @classmethod - def ensure_type(cls, v): + def ensure_type(cls, v: str | dict[str, Any] | list[Any] | None) -> str | dict[str, Any] | list[Any] | None: if v is None: return None if isinstance(v, str | dict | list): @@ -27,6 +27,48 @@ class BaseTraceInfo(BaseModel): model_config = ConfigDict(protected_namespaces=()) + @property + def resolved_trace_id(self) -> str | None: + """Get trace_id with intelligent fallback. + + Priority: + 1. External trace_id (from X-Trace-Id header) + 2. workflow_run_id (if this trace type has it) + 3. message_id (as final fallback) + """ + if self.trace_id: + return self.trace_id + + # Try workflow_run_id (only exists on workflow-related traces) + workflow_run_id = getattr(self, "workflow_run_id", None) + if workflow_run_id: + return workflow_run_id + + # Final fallback to message_id + return str(self.message_id) if self.message_id else None + + @property + def resolved_parent_context(self) -> tuple[str | None, str | None]: + """Resolve cross-workflow parent linking from metadata. + + Extracts typed parent IDs from the untyped ``parent_trace_context`` + metadata dict (set by tool_node when invoking nested workflows). + + Returns: + (trace_correlation_override, parent_span_id_source) where + trace_correlation_override is the outer workflow_run_id and + parent_span_id_source is the outer node_execution_id. + """ + parent_ctx = self.metadata.get("parent_trace_context") + if not isinstance(parent_ctx, dict): + return None, None + trace_override = parent_ctx.get("parent_workflow_run_id") + parent_span = parent_ctx.get("parent_node_execution_id") + return ( + trace_override if isinstance(trace_override, str) else None, + parent_span if isinstance(parent_span, str) else None, + ) + @field_serializer("start_time", "end_time") def serialize_datetime(self, dt: datetime | None) -> str | None: if dt is None: @@ -48,10 +90,14 @@ class WorkflowTraceInfo(BaseTraceInfo): workflow_run_version: str error: str | None = None total_tokens: int + prompt_tokens: int | None = None + completion_tokens: int | None = None file_list: list[str] query: str metadata: dict[str, Any] + invoked_by: str | None = None + class MessageTraceInfo(BaseTraceInfo): conversation_model: str @@ -59,7 +105,7 @@ class MessageTraceInfo(BaseTraceInfo): answer_tokens: int total_tokens: int error: str | None = None - file_list: Union[str, dict[str, Any], list] | None = None + file_list: Union[str, dict[str, Any], list[Any]] | None = None message_file_data: Any | None = None conversation_mode: str gen_ai_server_time_to_first_token: float | None = None @@ -106,7 +152,7 @@ class ToolTraceInfo(BaseTraceInfo): tool_config: dict[str, Any] time_cost: Union[int, float] tool_parameters: dict[str, Any] - file_url: Union[str, None, list] = None + file_url: Union[str, None, list[str]] = None class GenerateNameTraceInfo(BaseTraceInfo): @@ -114,6 +160,79 @@ class GenerateNameTraceInfo(BaseTraceInfo): tenant_id: str +class PromptGenerationTraceInfo(BaseTraceInfo): + """Trace information for prompt generation operations (rule-generate, code-generate, etc.).""" + + tenant_id: str + user_id: str + app_id: str | None = None + + operation_type: str + instruction: str + + prompt_tokens: int + completion_tokens: int + total_tokens: int + + model_provider: str + model_name: str + + latency: float + + total_price: float | None = None + currency: str | None = None + + error: str | None = None + + model_config = ConfigDict(protected_namespaces=()) + + +class WorkflowNodeTraceInfo(BaseTraceInfo): + workflow_id: str + workflow_run_id: str + tenant_id: str + node_execution_id: str + node_id: str + node_type: str + title: str + + status: str + error: str | None = None + elapsed_time: float + + index: int + predecessor_node_id: str | None = None + + total_tokens: int = 0 + total_price: float = 0.0 + currency: str | None = None + + model_provider: str | None = None + model_name: str | None = None + prompt_tokens: int | None = None + completion_tokens: int | None = None + + tool_name: str | None = None + + iteration_id: str | None = None + iteration_index: int | None = None + loop_id: str | None = None + loop_index: int | None = None + parallel_id: str | None = None + + node_inputs: Mapping[str, Any] | None = None + node_outputs: Mapping[str, Any] | None = None + process_data: Mapping[str, Any] | None = None + + invoked_by: str | None = None + + model_config = ConfigDict(protected_namespaces=()) + + +class DraftNodeExecutionTrace(WorkflowNodeTraceInfo): + pass + + class TaskData(BaseModel): app_id: str trace_info_type: str @@ -128,16 +247,38 @@ trace_info_info_map = { "DatasetRetrievalTraceInfo": DatasetRetrievalTraceInfo, "ToolTraceInfo": ToolTraceInfo, "GenerateNameTraceInfo": GenerateNameTraceInfo, + "PromptGenerationTraceInfo": PromptGenerationTraceInfo, + "WorkflowNodeTraceInfo": WorkflowNodeTraceInfo, + "DraftNodeExecutionTrace": DraftNodeExecutionTrace, } +class OperationType(StrEnum): + """Operation type for token metric labels. + + Used as a metric attribute on ``dify.tokens.input`` / ``dify.tokens.output`` + counters so consumers can break down token usage by operation. + """ + + WORKFLOW = "workflow" + NODE_EXECUTION = "node_execution" + MESSAGE = "message" + RULE_GENERATE = "rule_generate" + CODE_GENERATE = "code_generate" + STRUCTURED_OUTPUT = "structured_output" + INSTRUCTION_MODIFY = "instruction_modify" + + class TraceTaskName(StrEnum): CONVERSATION_TRACE = "conversation" WORKFLOW_TRACE = "workflow" + DRAFT_NODE_EXECUTION_TRACE = "draft_node_execution" MESSAGE_TRACE = "message" MODERATION_TRACE = "moderation" SUGGESTED_QUESTION_TRACE = "suggested_question" DATASET_RETRIEVAL_TRACE = "dataset_retrieval" TOOL_TRACE = "tool" GENERATE_NAME_TRACE = "generate_conversation_name" + PROMPT_GENERATION_TRACE = "prompt_generation" DATASOURCE_TRACE = "datasource" + NODE_EXECUTION_TRACE = "node_execution" diff --git a/api/core/ops/langfuse_trace/langfuse_trace.py b/api/core/ops/langfuse_trace/langfuse_trace.py index 4de4f403ce..422a121311 100644 --- a/api/core/ops/langfuse_trace/langfuse_trace.py +++ b/api/core/ops/langfuse_trace/langfuse_trace.py @@ -3,6 +3,7 @@ import os from datetime import datetime, timedelta from langfuse import Langfuse +from sqlalchemy import select from sqlalchemy.orm import sessionmaker from core.ops.base_trace_instance import BaseTraceInstance @@ -30,7 +31,7 @@ from core.ops.utils import filter_none_values from core.repositories import DifyCoreRepositoryFactory from core.workflow.enums import NodeType from extensions.ext_database import db -from models import EndUser, WorkflowNodeExecutionTriggeredFrom +from models import EndUser, Message, WorkflowNodeExecutionTriggeredFrom from models.enums import MessageStatus logger = logging.getLogger(__name__) @@ -71,7 +72,50 @@ class LangFuseDataTrace(BaseTraceInstance): metadata = trace_info.metadata metadata["workflow_app_log_id"] = trace_info.workflow_app_log_id - if trace_info.message_id: + # Check for parent_trace_context to detect nested workflow + parent_trace_context = trace_info.metadata.get("parent_trace_context") + + if parent_trace_context: + # Nested workflow: create span under outer trace + outer_trace_id = parent_trace_context.get("trace_id") + parent_node_execution_id = parent_trace_context.get("parent_node_execution_id") + parent_conversation_id = parent_trace_context.get("parent_conversation_id") + parent_workflow_run_id = parent_trace_context.get("parent_workflow_run_id") + + # Resolve outer trace_id: try message_id lookup first, fallback to workflow_run_id + if parent_conversation_id: + session_factory = sessionmaker(bind=db.engine) + with session_factory() as session: + message_data_stmt = select(Message.id).where( + Message.conversation_id == parent_conversation_id, + Message.workflow_run_id == parent_workflow_run_id, + ) + resolved_message_id = session.scalar(message_data_stmt) + if resolved_message_id: + outer_trace_id = resolved_message_id + else: + outer_trace_id = parent_workflow_run_id + else: + outer_trace_id = parent_workflow_run_id + + # Create inner workflow span under outer trace + workflow_span_data = LangfuseSpan( + id=trace_info.workflow_run_id, + name=TraceTaskName.WORKFLOW_TRACE, + input=dict(trace_info.workflow_run_inputs), + output=dict(trace_info.workflow_run_outputs), + trace_id=outer_trace_id, + parent_observation_id=parent_node_execution_id, + start_time=trace_info.start_time, + end_time=trace_info.end_time, + metadata=metadata, + level=LevelEnum.DEFAULT if trace_info.error == "" else LevelEnum.ERROR, + status_message=trace_info.error or "", + ) + self.add_span(langfuse_span_data=workflow_span_data) + # Use outer_trace_id for all node spans/generations + trace_id = outer_trace_id + elif trace_info.message_id: trace_id = trace_info.trace_id or trace_info.message_id name = TraceTaskName.MESSAGE_TRACE trace_data = LangfuseTrace( @@ -174,6 +218,11 @@ class LangFuseDataTrace(BaseTraceInstance): } ) + # Determine parent_observation_id for nested workflows + node_parent_observation_id = None + if parent_trace_context or trace_info.message_id: + node_parent_observation_id = trace_info.workflow_run_id + # add generation span if process_data and process_data.get("model_mode") == "chat": total_token = metadata.get("total_tokens", 0) @@ -206,7 +255,7 @@ class LangFuseDataTrace(BaseTraceInstance): metadata=metadata, level=(LevelEnum.DEFAULT if status == "succeeded" else LevelEnum.ERROR), status_message=trace_info.error or "", - parent_observation_id=trace_info.workflow_run_id if trace_info.message_id else None, + parent_observation_id=node_parent_observation_id, usage=generation_usage, ) @@ -225,7 +274,7 @@ class LangFuseDataTrace(BaseTraceInstance): metadata=metadata, level=(LevelEnum.DEFAULT if status == "succeeded" else LevelEnum.ERROR), status_message=trace_info.error or "", - parent_observation_id=trace_info.workflow_run_id if trace_info.message_id else None, + parent_observation_id=node_parent_observation_id, ) self.add_span(langfuse_span_data=span_data) diff --git a/api/core/ops/langsmith_trace/langsmith_trace.py b/api/core/ops/langsmith_trace/langsmith_trace.py index 8b8117b24c..7ca51e10ef 100644 --- a/api/core/ops/langsmith_trace/langsmith_trace.py +++ b/api/core/ops/langsmith_trace/langsmith_trace.py @@ -6,6 +6,7 @@ from typing import cast from langsmith import Client from langsmith.schemas import RunBase +from sqlalchemy import select from sqlalchemy.orm import sessionmaker from core.ops.base_trace_instance import BaseTraceInstance @@ -30,7 +31,7 @@ from core.ops.utils import filter_none_values, generate_dotted_order from core.repositories import DifyCoreRepositoryFactory from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey from extensions.ext_database import db -from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom +from models import EndUser, Message, MessageFile, WorkflowNodeExecutionTriggeredFrom logger = logging.getLogger(__name__) @@ -64,7 +65,35 @@ class LangSmithDataTrace(BaseTraceInstance): self.generate_name_trace(trace_info) def workflow_trace(self, trace_info: WorkflowTraceInfo): - trace_id = trace_info.trace_id or trace_info.message_id or trace_info.workflow_run_id + # Check for parent_trace_context for cross-workflow linking + parent_trace_context = trace_info.metadata.get("parent_trace_context") + + if parent_trace_context: + # Inner workflow: resolve outer trace_id and link to parent node + outer_trace_id = parent_trace_context.get("parent_workflow_run_id") + + # Try to resolve message_id from conversation_id if available + if parent_trace_context.get("parent_conversation_id"): + try: + session_factory = sessionmaker(bind=db.engine) + with session_factory() as session: + message_data_stmt = select(Message.id).where( + Message.conversation_id == parent_trace_context["parent_conversation_id"], + Message.workflow_run_id == parent_trace_context["parent_workflow_run_id"], + ) + resolved_message_id = session.scalar(message_data_stmt) + if resolved_message_id: + outer_trace_id = resolved_message_id + except Exception as e: + logger.debug("Failed to resolve message_id from conversation_id: %s", str(e)) + + trace_id = outer_trace_id + parent_run_id = parent_trace_context.get("parent_node_execution_id") + else: + # Outer workflow: existing behavior + trace_id = trace_info.trace_id or trace_info.message_id or trace_info.workflow_run_id + parent_run_id = trace_info.message_id or None + if trace_info.start_time is None: trace_info.start_time = datetime.now() message_dotted_order = ( @@ -78,7 +107,8 @@ class LangSmithDataTrace(BaseTraceInstance): metadata = trace_info.metadata metadata["workflow_app_log_id"] = trace_info.workflow_app_log_id - if trace_info.message_id: + # Only create message_run for outer workflows (no parent_trace_context) + if trace_info.message_id and not parent_trace_context: message_run = LangSmithRunModel( id=trace_info.message_id, name=TraceTaskName.MESSAGE_TRACE, @@ -121,9 +151,9 @@ class LangSmithDataTrace(BaseTraceInstance): }, error=trace_info.error, tags=["workflow"], - parent_run_id=trace_info.message_id or None, + parent_run_id=parent_run_id, trace_id=trace_id, - dotted_order=workflow_dotted_order, + dotted_order=None if parent_trace_context else workflow_dotted_order, serialized=None, events=[], session_id=None, diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py index 84f5bf5512..758ee3d494 100644 --- a/api/core/ops/ops_trace_manager.py +++ b/api/core/ops/ops_trace_manager.py @@ -21,19 +21,26 @@ from core.ops.entities.config_entity import ( ) from core.ops.entities.trace_entity import ( DatasetRetrievalTraceInfo, + DraftNodeExecutionTrace, GenerateNameTraceInfo, MessageTraceInfo, ModerationTraceInfo, + PromptGenerationTraceInfo, SuggestedQuestionTraceInfo, TaskData, ToolTraceInfo, TraceTaskName, + WorkflowNodeTraceInfo, WorkflowTraceInfo, ) from core.ops.utils import get_message_data from extensions.ext_database import db from extensions.ext_storage import storage +from models.account import Tenant +from models.dataset import Dataset from models.model import App, AppModelConfig, Conversation, Message, MessageFile, TraceAppConfig +from models.provider import Provider, ProviderCredential, ProviderModel, ProviderModelCredential, ProviderType +from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider from models.workflow import WorkflowAppLog from tasks.ops_trace_task import process_trace_tasks @@ -43,6 +50,139 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +def _lookup_app_and_workspace_names(app_id: str | None, tenant_id: str | None) -> tuple[str, str]: + """Return (app_name, workspace_name) for the given IDs. Falls back to empty strings.""" + app_name = "" + workspace_name = "" + if not app_id and not tenant_id: + return app_name, workspace_name + with Session(db.engine) as session: + if app_id: + name = session.scalar(select(App.name).where(App.id == app_id)) + if name: + app_name = name + if tenant_id: + name = session.scalar(select(Tenant.name).where(Tenant.id == tenant_id)) + if name: + workspace_name = name + return app_name, workspace_name + + +_PROVIDER_TYPE_TO_MODEL: dict[str, type] = { + "builtin": BuiltinToolProvider, + "plugin": BuiltinToolProvider, + "api": ApiToolProvider, + "workflow": WorkflowToolProvider, + "mcp": MCPToolProvider, +} + + +def _lookup_credential_name(credential_id: str | None, provider_type: str | None) -> str: + if not credential_id: + return "" + model_cls = _PROVIDER_TYPE_TO_MODEL.get(provider_type or "") + if not model_cls: + return "" + with Session(db.engine) as session: + name = session.scalar(select(model_cls.name).where(model_cls.id == credential_id)) + return str(name) if name else "" + + +def _lookup_llm_credential_info( + tenant_id: str | None, provider: str | None, model: str | None, model_type: str | None = "llm" +) -> tuple[str | None, str]: + """ + Lookup LLM credential ID and name for the given provider and model. + Returns (credential_id, credential_name). + + Handles async timing issues gracefully - if credential is deleted between lookups, + returns the ID but empty name rather than failing. + """ + if not tenant_id or not provider: + return None, "" + + try: + with Session(db.engine) as session: + # Try to find provider-level or model-level configuration + provider_record = session.scalar( + select(Provider).where( + Provider.tenant_id == tenant_id, + Provider.provider_name == provider, + Provider.provider_type == ProviderType.CUSTOM, + ) + ) + + if not provider_record: + return None, "" + + # Check if there's a model-specific config + credential_id = None + credential_name = "" + is_model_level = False + + if model and provider_record.credential_id: + # Try model-level first + model_record = session.scalar( + select(ProviderModel).where( + ProviderModel.tenant_id == tenant_id, + ProviderModel.provider_name == provider, + ProviderModel.model_name == model, + ProviderModel.model_type == model_type, + ) + ) + + if model_record and model_record.credential_id: + credential_id = model_record.credential_id + is_model_level = True + + if not credential_id and provider_record.credential_id: + # Fall back to provider-level credential + credential_id = provider_record.credential_id + is_model_level = False + + # Lookup credential_name if we have credential_id + if credential_id: + try: + if is_model_level: + # Query ProviderModelCredential + cred_name = session.scalar( + select(ProviderModelCredential.credential_name).where( + ProviderModelCredential.id == credential_id + ) + ) + else: + # Query ProviderCredential + cred_name = session.scalar( + select(ProviderCredential.credential_name).where(ProviderCredential.id == credential_id) + ) + + if cred_name: + credential_name = str(cred_name) + except Exception as e: + # Credential might have been deleted between lookups (async timing) + # Return ID but empty name rather than failing + logger.warning( + "Failed to lookup credential name for credential_id=%s (provider=%s, model=%s): %s", + credential_id, + provider, + model, + str(e), + ) + + return credential_id, credential_name + except Exception as e: + # Database query failed or other unexpected error + # Return empty rather than propagating error to telemetry emission + logger.warning( + "Failed to lookup LLM credential info for tenant_id=%s, provider=%s, model=%s: %s", + tenant_id, + provider, + model, + str(e), + ) + return None, "" + + class OpsTraceProviderConfigMap(collections.UserDict[str, dict[str, Any]]): def __getitem__(self, provider: str) -> dict[str, Any]: match provider: @@ -317,6 +457,10 @@ class OpsTraceManager: if app_id is None: return None + # Handle storage_id format (tenant-{uuid}) - not a real app_id + if isinstance(app_id, str) and app_id.startswith("tenant-"): + return None + app: App | None = db.session.query(App).where(App.id == app_id).first() if app is None: @@ -479,6 +623,56 @@ class TraceTask: cls._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker) return cls._workflow_run_repo + @classmethod + def _get_user_id_from_metadata(cls, metadata: dict[str, Any]) -> str: + """Extract user ID from metadata, prioritizing end_user over account. + + Returns the actual user ID (end_user or account) who invoked the workflow, + regardless of invoke_from context. + """ + # Priority 1: End user (external users via API/WebApp) + if user_id := metadata.get("from_end_user_id"): + return f"end_user:{user_id}" + + # Priority 2: Account user (internal users via console/debugger) + if user_id := metadata.get("from_account_id"): + return f"account:{user_id}" + + # Priority 3: User (internal users via console/debugger) + if user_id := metadata.get("user_id"): + return f"user:{user_id}" + + return "anonymous" + + @classmethod + def _calculate_workflow_token_split(cls, workflow_run_id: str, tenant_id: str) -> tuple[int, int]: + from core.workflow.enums import WorkflowNodeExecutionMetadataKey + from models.workflow import WorkflowNodeExecutionModel + + with Session(db.engine) as session: + node_executions = session.scalars( + select(WorkflowNodeExecutionModel).where( + WorkflowNodeExecutionModel.tenant_id == tenant_id, + WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id, + ) + ).all() + + total_prompt = 0 + total_completion = 0 + + for node_exec in node_executions: + metadata = node_exec.execution_metadata_dict + + prompt = metadata.get(WorkflowNodeExecutionMetadataKey.PROMPT_TOKENS) + if prompt is not None: + total_prompt += prompt + + completion = metadata.get(WorkflowNodeExecutionMetadataKey.COMPLETION_TOKENS) + if completion is not None: + total_completion += completion + + return (total_prompt, total_completion) + def __init__( self, trace_type: Any, @@ -499,6 +693,8 @@ class TraceTask: self.app_id = None self.trace_id = None self.kwargs = kwargs + if user_id is not None and "user_id" not in self.kwargs: + self.kwargs["user_id"] = user_id external_trace_id = kwargs.get("external_trace_id") if external_trace_id: self.trace_id = external_trace_id @@ -512,7 +708,7 @@ class TraceTask: TraceTaskName.WORKFLOW_TRACE: lambda: self.workflow_trace( workflow_run_id=self.workflow_run_id, conversation_id=self.conversation_id, user_id=self.user_id ), - TraceTaskName.MESSAGE_TRACE: lambda: self.message_trace(message_id=self.message_id), + TraceTaskName.MESSAGE_TRACE: lambda: self.message_trace(message_id=self.message_id, **self.kwargs), TraceTaskName.MODERATION_TRACE: lambda: self.moderation_trace( message_id=self.message_id, timer=self.timer, **self.kwargs ), @@ -528,6 +724,9 @@ class TraceTask: TraceTaskName.GENERATE_NAME_TRACE: lambda: self.generate_name_trace( conversation_id=self.conversation_id, timer=self.timer, **self.kwargs ), + TraceTaskName.PROMPT_GENERATION_TRACE: lambda: self.prompt_generation_trace(**self.kwargs), + TraceTaskName.NODE_EXECUTION_TRACE: lambda: self.node_execution_trace(**self.kwargs), + TraceTaskName.DRAFT_NODE_EXECUTION_TRACE: lambda: self.draft_node_execution_trace(**self.kwargs), } return preprocess_map.get(self.trace_type, lambda: None)() @@ -563,6 +762,10 @@ class TraceTask: total_tokens = workflow_run.total_tokens + prompt_tokens, completion_tokens = self._calculate_workflow_token_split( + workflow_run_id=workflow_run_id, tenant_id=tenant_id + ) + file_list = workflow_run_inputs.get("sys.file") or [] query = workflow_run_inputs.get("query") or workflow_run_inputs.get("sys.query") or "" @@ -583,7 +786,14 @@ class TraceTask: ) message_id = session.scalar(message_data_stmt) - metadata = { + from core.telemetry.gateway import is_enterprise_telemetry_enabled + + if is_enterprise_telemetry_enabled(): + app_name, workspace_name = _lookup_app_and_workspace_names(workflow_run.app_id, tenant_id) + else: + app_name, workspace_name = "", "" + + metadata: dict[str, Any] = { "workflow_id": workflow_id, "conversation_id": conversation_id, "workflow_run_id": workflow_run_id, @@ -596,8 +806,14 @@ class TraceTask: "triggered_from": workflow_run.triggered_from, "user_id": user_id, "app_id": workflow_run.app_id, + "app_name": app_name, + "workspace_name": workspace_name, } + parent_trace_context = self.kwargs.get("parent_trace_context") + if parent_trace_context: + metadata["parent_trace_context"] = parent_trace_context + workflow_trace_info = WorkflowTraceInfo( trace_id=self.trace_id, workflow_data=workflow_run.to_dict(), @@ -612,6 +828,8 @@ class TraceTask: workflow_run_version=workflow_run_version, error=error, total_tokens=total_tokens, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, file_list=file_list, query=query, metadata=metadata, @@ -619,10 +837,11 @@ class TraceTask: message_id=message_id, start_time=workflow_run.created_at, end_time=workflow_run.finished_at, + invoked_by=self._get_user_id_from_metadata(metadata), ) return workflow_trace_info - def message_trace(self, message_id: str | None): + def message_trace(self, message_id: str | None, **kwargs): if not message_id: return {} message_data = get_message_data(message_id) @@ -645,6 +864,19 @@ class TraceTask: streaming_metrics = self._extract_streaming_metrics(message_data) + tenant_id = "" + with Session(db.engine) as session: + tid = session.scalar(select(App.tenant_id).where(App.id == message_data.app_id)) + if tid: + tenant_id = str(tid) + + from core.telemetry.gateway import is_enterprise_telemetry_enabled + + if is_enterprise_telemetry_enabled(): + app_name, workspace_name = _lookup_app_and_workspace_names(message_data.app_id, tenant_id) + else: + app_name, workspace_name = "", "" + metadata = { "conversation_id": message_data.conversation_id, "ls_provider": message_data.model_provider, @@ -656,7 +888,14 @@ class TraceTask: "workflow_run_id": message_data.workflow_run_id, "from_source": message_data.from_source, "message_id": message_id, + "tenant_id": tenant_id, + "app_id": message_data.app_id, + "user_id": message_data.from_end_user_id or message_data.from_account_id, + "app_name": app_name, + "workspace_name": workspace_name, } + if node_execution_id := kwargs.get("node_execution_id"): + metadata["node_execution_id"] = node_execution_id message_tokens = message_data.message_tokens @@ -673,7 +912,9 @@ class TraceTask: outputs=message_data.answer, file_list=file_list, start_time=created_at, - end_time=created_at + timedelta(seconds=message_data.provider_response_latency), + end_time=message_data.updated_at + if message_data.updated_at and message_data.updated_at > created_at + else created_at + timedelta(seconds=message_data.provider_response_latency), metadata=metadata, message_file_data=message_file_data, conversation_mode=conversation_mode, @@ -698,6 +939,8 @@ class TraceTask: "preset_response": moderation_result.preset_response, "query": moderation_result.query, } + if node_execution_id := kwargs.get("node_execution_id"): + metadata["node_execution_id"] = node_execution_id # get workflow_app_log_id workflow_app_log_id = None @@ -739,6 +982,8 @@ class TraceTask: "workflow_run_id": message_data.workflow_run_id, "from_source": message_data.from_source, } + if node_execution_id := kwargs.get("node_execution_id"): + metadata["node_execution_id"] = node_execution_id # get workflow_app_log_id workflow_app_log_id = None @@ -778,6 +1023,52 @@ class TraceTask: if not message_data: return {} + tenant_id = "" + with Session(db.engine) as session: + tid = session.scalar(select(App.tenant_id).where(App.id == message_data.app_id)) + if tid: + tenant_id = str(tid) + + from core.telemetry.gateway import is_enterprise_telemetry_enabled + + if is_enterprise_telemetry_enabled(): + app_name, workspace_name = _lookup_app_and_workspace_names(message_data.app_id, tenant_id) + else: + app_name, workspace_name = "", "" + + doc_list = [doc.model_dump() for doc in documents] if documents else [] + dataset_ids: set[str] = set() + for doc in doc_list: + doc_meta = doc.get("metadata") or {} + did = doc_meta.get("dataset_id") + if did: + dataset_ids.add(did) + + embedding_models: dict[str, dict[str, str]] = {} + if dataset_ids: + with Session(db.engine) as session: + rows = session.execute( + select(Dataset.id, Dataset.embedding_model, Dataset.embedding_model_provider).where( + Dataset.id.in_(list(dataset_ids)) + ) + ).all() + for row in rows: + embedding_models[str(row[0])] = { + "embedding_model": row[1] or "", + "embedding_model_provider": row[2] or "", + } + + # Extract rerank model info from retrieval_model kwargs + rerank_model_provider = "" + rerank_model_name = "" + if "retrieval_model" in kwargs: + retrieval_model = kwargs["retrieval_model"] + if isinstance(retrieval_model, dict): + reranking_model = retrieval_model.get("reranking_model") + if isinstance(reranking_model, dict): + rerank_model_provider = reranking_model.get("reranking_provider_name", "") + rerank_model_name = reranking_model.get("reranking_model_name", "") + metadata = { "message_id": message_id, "ls_provider": message_data.model_provider, @@ -788,13 +1079,23 @@ class TraceTask: "agent_based": message_data.agent_based, "workflow_run_id": message_data.workflow_run_id, "from_source": message_data.from_source, + "tenant_id": tenant_id, + "app_id": message_data.app_id, + "user_id": message_data.from_end_user_id or message_data.from_account_id, + "app_name": app_name, + "workspace_name": workspace_name, + "embedding_models": embedding_models, + "rerank_model_provider": rerank_model_provider, + "rerank_model_name": rerank_model_name, } + if node_execution_id := kwargs.get("node_execution_id"): + metadata["node_execution_id"] = node_execution_id dataset_retrieval_trace_info = DatasetRetrievalTraceInfo( trace_id=self.trace_id, message_id=message_id, inputs=message_data.query or message_data.inputs, - documents=[doc.model_dump() for doc in documents] if documents else [], + documents=doc_list, start_time=timer.get("start"), end_time=timer.get("end"), metadata=metadata, @@ -837,6 +1138,10 @@ class TraceTask: "error": error, "tool_parameters": tool_parameters, } + if message_data.workflow_run_id: + metadata["workflow_run_id"] = message_data.workflow_run_id + if node_execution_id := kwargs.get("node_execution_id"): + metadata["node_execution_id"] = node_execution_id file_url = "" message_file_data = db.session.query(MessageFile).filter_by(message_id=message_id).first() @@ -891,6 +1196,8 @@ class TraceTask: "conversation_id": conversation_id, "tenant_id": tenant_id, } + if node_execution_id := kwargs.get("node_execution_id"): + metadata["node_execution_id"] = node_execution_id generate_name_trace_info = GenerateNameTraceInfo( trace_id=self.trace_id, @@ -905,6 +1212,182 @@ class TraceTask: return generate_name_trace_info + def prompt_generation_trace(self, **kwargs) -> PromptGenerationTraceInfo | dict: + tenant_id = kwargs.get("tenant_id", "") + user_id = kwargs.get("user_id", "") + app_id = kwargs.get("app_id") + operation_type = kwargs.get("operation_type", "") + instruction = kwargs.get("instruction", "") + generated_output = kwargs.get("generated_output", "") + + prompt_tokens = kwargs.get("prompt_tokens", 0) + completion_tokens = kwargs.get("completion_tokens", 0) + total_tokens = kwargs.get("total_tokens", 0) + + model_provider = kwargs.get("model_provider", "") + model_name = kwargs.get("model_name", "") + + latency = kwargs.get("latency", 0.0) + + timer = kwargs.get("timer") + start_time = timer.get("start") if timer else None + end_time = timer.get("end") if timer else None + + total_price = kwargs.get("total_price") + currency = kwargs.get("currency") + + error = kwargs.get("error") + + app_name = None + workspace_name = None + if app_id: + app_name, workspace_name = _lookup_app_and_workspace_names(app_id, tenant_id) + + metadata = { + "tenant_id": tenant_id, + "user_id": user_id, + "app_id": app_id or "", + "app_name": app_name, + "workspace_name": workspace_name, + "operation_type": operation_type, + "model_provider": model_provider, + "model_name": model_name, + } + if node_execution_id := kwargs.get("node_execution_id"): + metadata["node_execution_id"] = node_execution_id + + return PromptGenerationTraceInfo( + trace_id=self.trace_id, + inputs=instruction, + outputs=generated_output, + start_time=start_time, + end_time=end_time, + metadata=metadata, + tenant_id=tenant_id, + user_id=user_id, + app_id=app_id, + operation_type=operation_type, + instruction=instruction, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=total_tokens, + model_provider=model_provider, + model_name=model_name, + latency=latency, + total_price=total_price, + currency=currency, + error=error, + ) + + def node_execution_trace(self, **kwargs) -> WorkflowNodeTraceInfo | dict: + node_data: dict = kwargs.get("node_execution_data", {}) + if not node_data: + return {} + + from core.telemetry.gateway import is_enterprise_telemetry_enabled + + if is_enterprise_telemetry_enabled(): + app_name, workspace_name = _lookup_app_and_workspace_names( + node_data.get("app_id"), node_data.get("tenant_id") + ) + else: + app_name, workspace_name = "", "" + + # Try tool credential lookup first + credential_id = node_data.get("credential_id") + if is_enterprise_telemetry_enabled(): + credential_name = _lookup_credential_name(credential_id, node_data.get("credential_provider_type")) + # If no credential_id found (e.g., LLM nodes), try LLM credential lookup + if not credential_id: + llm_cred_id, llm_cred_name = _lookup_llm_credential_info( + tenant_id=node_data.get("tenant_id"), + provider=node_data.get("model_provider"), + model=node_data.get("model_name"), + model_type="llm", + ) + if llm_cred_id: + credential_id = llm_cred_id + credential_name = llm_cred_name + else: + credential_name = "" + metadata: dict[str, Any] = { + "tenant_id": node_data.get("tenant_id"), + "app_id": node_data.get("app_id"), + "app_name": app_name, + "workspace_name": workspace_name, + "user_id": node_data.get("user_id"), + "invoke_from": node_data.get("invoke_from"), + "credential_id": node_data.get("credential_id"), + "credential_name": credential_name, + "dataset_ids": node_data.get("dataset_ids"), + "dataset_names": node_data.get("dataset_names"), + "plugin_name": node_data.get("plugin_name"), + } + + parent_trace_context = node_data.get("parent_trace_context") + if parent_trace_context: + metadata["parent_trace_context"] = parent_trace_context + + message_id: str | None = None + conversation_id = node_data.get("conversation_id") + workflow_execution_id = node_data.get("workflow_execution_id") + if conversation_id and workflow_execution_id and not parent_trace_context: + with Session(db.engine) as session: + msg_id = session.scalar( + select(Message.id).where( + Message.conversation_id == conversation_id, + Message.workflow_run_id == workflow_execution_id, + ) + ) + if msg_id: + message_id = str(msg_id) + metadata["message_id"] = message_id + if conversation_id: + metadata["conversation_id"] = conversation_id + + return WorkflowNodeTraceInfo( + trace_id=self.trace_id, + message_id=message_id, + start_time=node_data.get("created_at"), + end_time=node_data.get("finished_at"), + metadata=metadata, + workflow_id=node_data.get("workflow_id", ""), + workflow_run_id=node_data.get("workflow_execution_id", ""), + tenant_id=node_data.get("tenant_id", ""), + node_execution_id=node_data.get("node_execution_id", ""), + node_id=node_data.get("node_id", ""), + node_type=node_data.get("node_type", ""), + title=node_data.get("title", ""), + status=node_data.get("status", ""), + error=node_data.get("error"), + elapsed_time=node_data.get("elapsed_time", 0.0), + index=node_data.get("index", 0), + predecessor_node_id=node_data.get("predecessor_node_id"), + total_tokens=node_data.get("total_tokens", 0), + total_price=node_data.get("total_price", 0.0), + currency=node_data.get("currency"), + model_provider=node_data.get("model_provider"), + model_name=node_data.get("model_name"), + prompt_tokens=node_data.get("prompt_tokens"), + completion_tokens=node_data.get("completion_tokens"), + tool_name=node_data.get("tool_name"), + iteration_id=node_data.get("iteration_id"), + iteration_index=node_data.get("iteration_index"), + loop_id=node_data.get("loop_id"), + loop_index=node_data.get("loop_index"), + parallel_id=node_data.get("parallel_id"), + node_inputs=node_data.get("node_inputs"), + node_outputs=node_data.get("node_outputs"), + process_data=node_data.get("process_data"), + invoked_by=self._get_user_id_from_metadata(metadata), + ) + + def draft_node_execution_trace(self, **kwargs) -> DraftNodeExecutionTrace | dict: + node_trace = self.node_execution_trace(**kwargs) + if not node_trace or not isinstance(node_trace, WorkflowNodeTraceInfo): + return node_trace + return DraftNodeExecutionTrace(**node_trace.model_dump()) + def _extract_streaming_metrics(self, message_data) -> dict: if not message_data.message_metadata: return {} @@ -938,13 +1421,17 @@ class TraceQueueManager: self.user_id = user_id self.trace_instance = OpsTraceManager.get_ops_trace_instance(app_id) self.flask_app = current_app._get_current_object() # type: ignore + + from core.telemetry.gateway import is_enterprise_telemetry_enabled + + self._enterprise_telemetry_enabled = is_enterprise_telemetry_enabled() if trace_manager_timer is None: self.start_timer() def add_trace_task(self, trace_task: TraceTask): global trace_manager_timer, trace_manager_queue try: - if self.trace_instance: + if self._enterprise_telemetry_enabled or self.trace_instance: trace_task.app_id = self.app_id trace_manager_queue.put(trace_task) except Exception: @@ -980,20 +1467,27 @@ class TraceQueueManager: def send_to_celery(self, tasks: list[TraceTask]): with self.flask_app.app_context(): for task in tasks: - if task.app_id is None: - continue + storage_id = task.app_id + if storage_id is None: + tenant_id = task.kwargs.get("tenant_id") + if tenant_id: + storage_id = f"tenant-{tenant_id}" + else: + logger.warning("Skipping trace without app_id or tenant_id, trace_type: %s", task.trace_type) + continue + file_id = uuid4().hex trace_info = task.execute() task_data = TaskData( - app_id=task.app_id, + app_id=storage_id, trace_info_type=type(trace_info).__name__, trace_info=trace_info.model_dump() if trace_info else None, ) - file_path = f"{OPS_FILE_PATH}{task.app_id}/{file_id}.json" + file_path = f"{OPS_FILE_PATH}{storage_id}/{file_id}.json" storage.save(file_path, task_data.model_dump_json().encode("utf-8")) file_info = { "file_id": file_id, - "app_id": task.app_id, + "app_id": storage_id, } process_trace_tasks.delay(file_info) # type: ignore diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index 541c241ae5..33884378ce 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -27,8 +27,7 @@ from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool from core.model_runtime.entities.model_entities import ModelFeature, ModelType from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.ops.entities.trace_entity import TraceTaskName -from core.ops.ops_trace_manager import TraceQueueManager, TraceTask +from core.ops.ops_trace_manager import TraceQueueManager from core.ops.utils import measure_time from core.prompt.advanced_prompt_transform import AdvancedPromptTransform from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate @@ -56,6 +55,8 @@ from core.rag.retrieval.template_prompts import ( METADATA_FILTER_USER_PROMPT_2, METADATA_FILTER_USER_PROMPT_3, ) +from core.telemetry import TelemetryContext, TelemetryEvent, TraceTaskName +from core.telemetry import emit as telemetry_emit from core.tools.signature import sign_upload_file from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool from extensions.ext_database import db @@ -728,10 +729,21 @@ class DatasetRetrieval: self.application_generate_entity.trace_manager if self.application_generate_entity else None ) if trace_manager: - trace_manager.add_trace_task( - TraceTask( - TraceTaskName.DATASET_RETRIEVAL_TRACE, message_id=message_id, documents=documents, timer=timer - ) + app_config = self.application_generate_entity.app_config if self.application_generate_entity else None + telemetry_emit( + TelemetryEvent( + name=TraceTaskName.DATASET_RETRIEVAL_TRACE, + context=TelemetryContext( + tenant_id=app_config.tenant_id if app_config else None, + app_id=app_config.app_id if app_config else None, + ), + payload={ + "message_id": message_id, + "documents": documents, + "timer": timer, + }, + ), + trace_manager=trace_manager, ) def _on_query( diff --git a/api/core/telemetry/__init__.py b/api/core/telemetry/__init__.py new file mode 100644 index 0000000000..ae4f53f3b7 --- /dev/null +++ b/api/core/telemetry/__init__.py @@ -0,0 +1,43 @@ +"""Telemetry facade. + +Thin public API for emitting telemetry events. All routing logic +lives in ``core.telemetry.gateway`` which is shared by both CE and EE. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from core.ops.entities.trace_entity import TraceTaskName +from core.telemetry.events import TelemetryContext, TelemetryEvent +from core.telemetry.gateway import emit as gateway_emit +from core.telemetry.gateway import get_trace_task_to_case + +if TYPE_CHECKING: + from core.ops.ops_trace_manager import TraceQueueManager + + +def emit(event: TelemetryEvent, trace_manager: TraceQueueManager | None = None) -> None: + """Emit a telemetry event. + + Translates the ``TelemetryEvent`` (keyed by ``TraceTaskName``) into a + ``TelemetryCase`` and delegates to ``core.telemetry.gateway.emit()``. + """ + case = get_trace_task_to_case().get(event.name) + if case is None: + return + + context: dict[str, object] = { + "tenant_id": event.context.tenant_id, + "user_id": event.context.user_id, + "app_id": event.context.app_id, + } + gateway_emit(case, context, event.payload, trace_manager) + + +__all__ = [ + "TelemetryContext", + "TelemetryEvent", + "TraceTaskName", + "emit", +] diff --git a/api/core/telemetry/events.py b/api/core/telemetry/events.py new file mode 100644 index 0000000000..35ace47510 --- /dev/null +++ b/api/core/telemetry/events.py @@ -0,0 +1,21 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from core.ops.entities.trace_entity import TraceTaskName + + +@dataclass(frozen=True) +class TelemetryContext: + tenant_id: str | None = None + user_id: str | None = None + app_id: str | None = None + + +@dataclass(frozen=True) +class TelemetryEvent: + name: TraceTaskName + context: TelemetryContext + payload: dict[str, Any] diff --git a/api/core/telemetry/gateway.py b/api/core/telemetry/gateway.py new file mode 100644 index 0000000000..b2b8d3d470 --- /dev/null +++ b/api/core/telemetry/gateway.py @@ -0,0 +1,229 @@ +"""Telemetry gateway — single routing layer for all editions. + +Maps ``TelemetryCase`` → ``CaseRoute`` and dispatches events to either +the CE/EE trace pipeline (``TraceQueueManager``) or the enterprise-only +metric/log Celery queue. + +This module lives in ``core/`` so both CE and EE share one routing table +and one ``emit()`` entry point. No separate enterprise gateway module is +needed — enterprise-specific dispatch (Celery task, payload offloading) +is handled here behind lazy imports that no-op in CE. +""" + +from __future__ import annotations + +import json +import logging +import uuid +from typing import TYPE_CHECKING, Any + +from core.ops.entities.trace_entity import TraceTaskName +from extensions.ext_storage import storage + +if TYPE_CHECKING: + from core.ops.ops_trace_manager import TraceQueueManager + from enterprise.telemetry.contracts import TelemetryCase + +logger = logging.getLogger(__name__) + +PAYLOAD_SIZE_THRESHOLD_BYTES = 1 * 1024 * 1024 + +# --------------------------------------------------------------------------- +# Routing table — authoritative mapping for all editions +# --------------------------------------------------------------------------- + +_case_to_trace_task: dict | None = None +_case_routing: dict | None = None + + +def _get_case_to_trace_task() -> dict: + global _case_to_trace_task + if _case_to_trace_task is None: + from enterprise.telemetry.contracts import TelemetryCase + + _case_to_trace_task = { + TelemetryCase.WORKFLOW_RUN: TraceTaskName.WORKFLOW_TRACE, + TelemetryCase.MESSAGE_RUN: TraceTaskName.MESSAGE_TRACE, + TelemetryCase.NODE_EXECUTION: TraceTaskName.NODE_EXECUTION_TRACE, + TelemetryCase.DRAFT_NODE_EXECUTION: TraceTaskName.DRAFT_NODE_EXECUTION_TRACE, + TelemetryCase.PROMPT_GENERATION: TraceTaskName.PROMPT_GENERATION_TRACE, + TelemetryCase.TOOL_EXECUTION: TraceTaskName.TOOL_TRACE, + TelemetryCase.MODERATION_CHECK: TraceTaskName.MODERATION_TRACE, + TelemetryCase.SUGGESTED_QUESTION: TraceTaskName.SUGGESTED_QUESTION_TRACE, + TelemetryCase.DATASET_RETRIEVAL: TraceTaskName.DATASET_RETRIEVAL_TRACE, + TelemetryCase.GENERATE_NAME: TraceTaskName.GENERATE_NAME_TRACE, + } + return _case_to_trace_task + + +def get_trace_task_to_case() -> dict: + """Return TraceTaskName → TelemetryCase (inverse of _get_case_to_trace_task).""" + return {v: k for k, v in _get_case_to_trace_task().items()} + + +def _get_case_routing() -> dict: + global _case_routing + if _case_routing is None: + from enterprise.telemetry.contracts import CaseRoute, SignalType, TelemetryCase + + _case_routing = { + # TRACE — CE-eligible (flow in both CE and EE) + TelemetryCase.WORKFLOW_RUN: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=True), + TelemetryCase.MESSAGE_RUN: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=True), + TelemetryCase.TOOL_EXECUTION: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=True), + TelemetryCase.MODERATION_CHECK: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=True), + TelemetryCase.SUGGESTED_QUESTION: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=True), + TelemetryCase.DATASET_RETRIEVAL: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=True), + TelemetryCase.GENERATE_NAME: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=True), + # TRACE — enterprise-only + TelemetryCase.NODE_EXECUTION: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=False), + TelemetryCase.DRAFT_NODE_EXECUTION: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=False), + TelemetryCase.PROMPT_GENERATION: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=False), + # METRIC_LOG — enterprise-only (signal-driven, not trace) + TelemetryCase.APP_CREATED: CaseRoute(signal_type=SignalType.METRIC_LOG, ce_eligible=False), + TelemetryCase.APP_UPDATED: CaseRoute(signal_type=SignalType.METRIC_LOG, ce_eligible=False), + TelemetryCase.APP_DELETED: CaseRoute(signal_type=SignalType.METRIC_LOG, ce_eligible=False), + TelemetryCase.FEEDBACK_CREATED: CaseRoute(signal_type=SignalType.METRIC_LOG, ce_eligible=False), + } + return _case_routing + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def is_enterprise_telemetry_enabled() -> bool: + try: + from enterprise.telemetry.exporter import is_enterprise_telemetry_enabled + + return is_enterprise_telemetry_enabled() + except Exception: + return False + + +def _handle_payload_sizing( + payload: dict[str, Any], + tenant_id: str, + event_id: str, +) -> tuple[dict[str, Any], str | None]: + """Inline or offload payload based on size. + + Returns ``(payload_for_envelope, storage_key | None)``. Payloads + exceeding ``PAYLOAD_SIZE_THRESHOLD_BYTES`` are written to object + storage and replaced with an empty dict in the envelope. + """ + try: + payload_json = json.dumps(payload) + payload_size = len(payload_json.encode("utf-8")) + except (TypeError, ValueError): + logger.warning("Failed to serialize payload for sizing: event_id=%s", event_id) + return payload, None + + if payload_size <= PAYLOAD_SIZE_THRESHOLD_BYTES: + return payload, None + + storage_key = f"telemetry/{tenant_id}/{event_id}.json" + try: + storage.save(storage_key, payload_json.encode("utf-8")) + logger.debug("Stored large payload to storage: key=%s, size=%d", storage_key, payload_size) + return {}, storage_key + except Exception: + logger.warning("Failed to store large payload, inlining instead: event_id=%s", event_id, exc_info=True) + return payload, None + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + + +def emit( + case: TelemetryCase, + context: dict[str, Any], + payload: dict[str, Any], + trace_manager: TraceQueueManager | None = None, +) -> None: + """Route a telemetry event to the correct pipeline. + + TRACE events are enqueued into ``TraceQueueManager`` (works in both CE + and EE). Enterprise-only traces are silently dropped when EE is + disabled. + + METRIC_LOG events are dispatched to the enterprise Celery queue; + silently dropped when enterprise telemetry is unavailable. + """ + route = _get_case_routing().get(case) + if route is None: + logger.warning("Unknown telemetry case: %s, dropping event", case) + return + + if not route.ce_eligible and not is_enterprise_telemetry_enabled(): + logger.debug("Dropping EE-only event: case=%s (EE disabled)", case) + return + + if route.signal_type == "trace": + _emit_trace(case, context, payload, trace_manager) + else: + _emit_metric_log(case, context, payload) + + +def _emit_trace( + case: TelemetryCase, + context: dict[str, Any], + payload: dict[str, Any], + trace_manager: TraceQueueManager | None, +) -> None: + from core.ops.ops_trace_manager import TraceQueueManager as LocalTraceQueueManager + from core.ops.ops_trace_manager import TraceTask + + trace_task_name = _get_case_to_trace_task().get(case) + if trace_task_name is None: + logger.warning("No TraceTaskName mapping for case: %s", case) + return + + queue_manager = trace_manager or LocalTraceQueueManager( + app_id=context.get("app_id"), + user_id=context.get("user_id"), + ) + queue_manager.add_trace_task(TraceTask(trace_task_name, **payload)) + logger.debug("Enqueued trace task: case=%s, app_id=%s", case, context.get("app_id")) + + +def _emit_metric_log( + case: TelemetryCase, + context: dict[str, Any], + payload: dict[str, Any], +) -> None: + """Build envelope and dispatch to enterprise Celery queue. + + No-ops when the enterprise telemetry task is not importable (CE mode). + """ + try: + from tasks.enterprise_telemetry_task import process_enterprise_telemetry + except ImportError: + logger.debug("Enterprise metric/log dispatch unavailable, dropping: case=%s", case) + return + + tenant_id = context.get("tenant_id", "") + event_id = str(uuid.uuid4()) + + payload_for_envelope, payload_ref = _handle_payload_sizing(payload, tenant_id, event_id) + + from enterprise.telemetry.contracts import TelemetryEnvelope + + envelope = TelemetryEnvelope( + case=case, + tenant_id=tenant_id, + event_id=event_id, + payload=payload_for_envelope, + metadata={"payload_ref": payload_ref} if payload_ref else None, + ) + + process_enterprise_telemetry.delay(envelope.model_dump_json()) + logger.debug( + "Enqueued metric/log event: case=%s, tenant_id=%s, event_id=%s", + case, + tenant_id, + event_id, + ) diff --git a/api/core/tools/workflow_as_tool/tool.py b/api/core/tools/workflow_as_tool/tool.py index 9c1ceff145..0106f60c0d 100644 --- a/api/core/tools/workflow_as_tool/tool.py +++ b/api/core/tools/workflow_as_tool/tool.py @@ -50,6 +50,7 @@ class WorkflowTool(Tool): self.workflow_call_depth = workflow_call_depth self.label = label self._latest_usage = LLMUsage.empty_usage() + self.parent_trace_context: dict[str, str] | None = None super().__init__(entity=entity, runtime=runtime) @@ -90,11 +91,15 @@ class WorkflowTool(Tool): self._latest_usage = LLMUsage.empty_usage() + args: dict[str, Any] = {"inputs": tool_parameters, "files": files} + if self.parent_trace_context: + args["_parent_trace_context"] = self.parent_trace_context + result = generator.generate( app_model=app, workflow=workflow, user=user, - args={"inputs": tool_parameters, "files": files}, + args=args, invoke_from=self.runtime.invoke_from, streaming=False, call_depth=self.workflow_call_depth + 1, diff --git a/api/core/workflow/enums.py b/api/core/workflow/enums.py index bb3b13e8c6..938a2f5e21 100644 --- a/api/core/workflow/enums.py +++ b/api/core/workflow/enums.py @@ -232,6 +232,8 @@ class WorkflowNodeExecutionMetadataKey(StrEnum): """ TOTAL_TOKENS = "total_tokens" + PROMPT_TOKENS = "prompt_tokens" + COMPLETION_TOKENS = "completion_tokens" TOTAL_PRICE = "total_price" CURRENCY = "currency" TOOL_INFO = "tool_info" diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index beccf79344..92e9439acc 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -322,6 +322,8 @@ class LLMNode(Node[LLMNodeData]): outputs=outputs, metadata={ WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens, + WorkflowNodeExecutionMetadataKey.PROMPT_TOKENS: usage.prompt_tokens, + WorkflowNodeExecutionMetadataKey.COMPLETION_TOKENS: usage.completion_tokens, WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price, WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency, }, diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index 60d76db9b6..1d88249fc8 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -60,7 +60,9 @@ class ToolNode(Node[ToolNodeData]): tool_info = { "provider_type": self.node_data.provider_type.value, "provider_id": self.node_data.provider_id, + "tool_name": self.node_data.tool_name, "plugin_unique_identifier": self.node_data.plugin_unique_identifier, + "credential_id": self.node_data.credential_id, } # get tool runtime @@ -105,6 +107,20 @@ class ToolNode(Node[ToolNodeData]): # get conversation id conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID]) + from core.tools.workflow_as_tool.tool import WorkflowTool + + if isinstance(tool_runtime, WorkflowTool): + workflow_run_id_var = self.graph_runtime_state.variable_pool.get( + ["sys", SystemVariableKey.WORKFLOW_EXECUTION_ID] + ) + tool_runtime.parent_trace_context = { + "trace_id": str(workflow_run_id_var.text) if workflow_run_id_var else "", + "parent_node_execution_id": self.execution_id, + "parent_workflow_run_id": str(workflow_run_id_var.text) if workflow_run_id_var else "", + "parent_app_id": self.app_id, + "parent_conversation_id": conversation_id.text if conversation_id else None, + } + try: message_stream = ToolEngine.generic_invoke( tool=tool_runtime, @@ -431,6 +447,8 @@ class ToolNode(Node[ToolNodeData]): } if isinstance(usage.total_tokens, int) and usage.total_tokens > 0: metadata[WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS] = usage.total_tokens + metadata[WorkflowNodeExecutionMetadataKey.PROMPT_TOKENS] = usage.prompt_tokens + metadata[WorkflowNodeExecutionMetadataKey.COMPLETION_TOKENS] = usage.completion_tokens metadata[WorkflowNodeExecutionMetadataKey.TOTAL_PRICE] = usage.total_price metadata[WorkflowNodeExecutionMetadataKey.CURRENCY] = usage.currency diff --git a/api/enterprise/__init__.py b/api/enterprise/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/enterprise/telemetry/DATA_DICTIONARY.md b/api/enterprise/telemetry/DATA_DICTIONARY.md new file mode 100644 index 0000000000..3053bee5ac --- /dev/null +++ b/api/enterprise/telemetry/DATA_DICTIONARY.md @@ -0,0 +1,522 @@ +# Dify Enterprise Telemetry Data Dictionary + +Quick reference for all telemetry signals emitted by Dify Enterprise. For configuration and architecture details, see [README.md](./README.md). + +## Resource Attributes + +Attached to every signal (Span, Metric, Log). + +| Attribute | Type | Example | +|-----------|------|---------| +| `service.name` | string | `dify` | +| `host.name` | string | `dify-api-7f8b` | + +## Traces (Spans) + +### `dify.workflow.run` + +| Attribute | Type | Description | +|-----------|------|-------------| +| `dify.trace_id` | string | Business trace ID (Workflow Run ID) | +| `dify.tenant_id` | string | Tenant identifier | +| `dify.app_id` | string | Application identifier | +| `dify.workflow.id` | string | Workflow definition ID | +| `dify.workflow.run_id` | string | Unique ID for this run | +| `dify.workflow.status` | string | `succeeded`, `failed`, `stopped`, etc. | +| `dify.workflow.error` | string | Error message if failed | +| `dify.workflow.elapsed_time` | float | Total execution time (seconds) | +| `dify.invoke_from` | string | `api`, `webapp`, `debug` | +| `dify.conversation.id` | string | Conversation ID (optional) | +| `dify.message.id` | string | Message ID (optional) | +| `dify.invoked_by` | string | User ID who triggered the run | +| `gen_ai.usage.total_tokens` | int | Total tokens across all nodes (optional) | +| `gen_ai.user.id` | string | End-user identifier (optional) | +| `dify.parent.trace_id` | string | Parent workflow trace ID (optional) | +| `dify.parent.workflow.run_id` | string | Parent workflow run ID (optional) | +| `dify.parent.node.execution_id` | string | Parent node execution ID (optional) | +| `dify.parent.app.id` | string | Parent app ID (optional) | + +### `dify.node.execution` + +| Attribute | Type | Description | +|-----------|------|-------------| +| `dify.trace_id` | string | Business trace ID | +| `dify.tenant_id` | string | Tenant identifier | +| `dify.app_id` | string | Application identifier | +| `dify.workflow.id` | string | Workflow definition ID | +| `dify.workflow.run_id` | string | Workflow Run ID | +| `dify.message.id` | string | Message ID (optional) | +| `dify.conversation.id` | string | Conversation ID (optional) | +| `dify.node.execution_id` | string | Unique node execution ID | +| `dify.node.id` | string | Node ID in workflow graph | +| `dify.node.type` | string | Node type (see appendix) | +| `dify.node.title` | string | Display title | +| `dify.node.status` | string | `succeeded`, `failed` | +| `dify.node.error` | string | Error message if failed | +| `dify.node.elapsed_time` | float | Execution time (seconds) | +| `dify.node.index` | int | Execution order index | +| `dify.node.predecessor_node_id` | string | Triggering node ID | +| `dify.node.iteration_id` | string | Iteration ID (optional) | +| `dify.node.loop_id` | string | Loop ID (optional) | +| `dify.node.parallel_id` | string | Parallel branch ID (optional) | +| `dify.node.invoked_by` | string | User ID who triggered execution | +| `gen_ai.usage.input_tokens` | int | Prompt tokens (LLM nodes only) | +| `gen_ai.usage.output_tokens` | int | Completion tokens (LLM nodes only) | +| `gen_ai.usage.total_tokens` | int | Total tokens (LLM nodes only) | +| `gen_ai.request.model` | string | LLM model name (LLM nodes only) | +| `gen_ai.provider.name` | string | LLM provider name (LLM nodes only) | +| `gen_ai.user.id` | string | End-user identifier (optional) | + +### `dify.node.execution.draft` + +Same attributes as `dify.node.execution`. Emitted during Preview/Debug runs. + +## Counters + +All counters are cumulative and emitted at 100% accuracy. + +### Token Counters + +| Metric | Unit | Description | +|--------|------|-------------| +| `dify.tokens.total` | `{token}` | Total tokens consumed | +| `dify.tokens.input` | `{token}` | Input (prompt) tokens | +| `dify.tokens.output` | `{token}` | Output (completion) tokens | + +**Labels:** +- `tenant_id`, `app_id`, `operation_type`, `model_provider`, `model_name`, `node_type` (if node_execution) + +⚠️ **Warning:** `dify.tokens.total` at workflow level includes all node tokens. Filter by `operation_type` to avoid double-counting. + +#### Token Hierarchy & Query Patterns + +Token metrics are emitted at multiple layers. Understanding the hierarchy prevents double-counting: + +``` +App-level total +├── workflow ← sum of all node_execution tokens (DO NOT add both) +│ └── node_execution ← per-node breakdown +├── message ← independent (non-workflow chat apps only) +├── rule_generate ← independent helper LLM call +├── code_generate ← independent helper LLM call +├── structured_output ← independent helper LLM call +└── instruction_modify← independent helper LLM call +``` + +**Key rule:** `workflow` tokens already include all `node_execution` tokens. Never sum both. + +**Available labels on token metrics:** `tenant_id`, `app_id`, `operation_type`, `model_provider`, `model_name`, `node_type`. +App name is only available on span attributes (`dify.app.name`), not metric labels — use `app_id` for metric queries. + +**Common queries** (PromQL): + +```promql +# ── Totals ────────────────────────────────────────────────── +# App-level total (exclude node_execution to avoid double-counting) +sum by (app_id) (dify_tokens_total{operation_type!="node_execution"}) + +# Single app total +sum (dify_tokens_total{app_id="", operation_type!="node_execution"}) + +# Per-tenant totals +sum by (tenant_id) (dify_tokens_total{operation_type!="node_execution"}) + +# ── Drill-down ────────────────────────────────────────────── +# Workflow-level tokens for an app +sum (dify_tokens_total{app_id="", operation_type="workflow"}) + +# Node-level breakdown within an app +sum by (node_type) (dify_tokens_total{app_id="", operation_type="node_execution"}) + +# Model breakdown for an app +sum by (model_provider, model_name) (dify_tokens_total{app_id=""}) + +# Input vs output per model +sum by (model_name) (dify_tokens_input_total{app_id=""}) +sum by (model_name) (dify_tokens_output_total{app_id=""}) + +# ── Rates ─────────────────────────────────────────────────── +# Token consumption rate (per hour) +sum(rate(dify_tokens_total{operation_type!="node_execution"}[1h])) + +# Per-app consumption rate +sum by (app_id) (rate(dify_tokens_total{operation_type!="node_execution"}[1h])) +``` + +**Finding `app_id` from app name** (trace query — Tempo / Jaeger): + +``` +{ resource.dify.app.name = "My Chatbot" } | select(resource.dify.app.id) +``` + +### Request Counters + +| Metric | Unit | Description | +|--------|------|-------------| +| `dify.requests.total` | `{request}` | Total operations count | + +**Labels by type:** + +| `type` | Additional Labels | +|--------|-------------------| +| `workflow` | `tenant_id`, `app_id`, `status`, `invoke_from` | +| `node` | `tenant_id`, `app_id`, `node_type`, `model_provider`, `model_name`, `status` | +| `draft_node` | `tenant_id`, `app_id`, `node_type`, `model_provider`, `model_name`, `status` | +| `message` | `tenant_id`, `app_id`, `model_provider`, `model_name`, `status`, `invoke_from` | +| `tool` | `tenant_id`, `app_id`, `tool_name` | +| `moderation` | `tenant_id`, `app_id` | +| `suggested_question` | `tenant_id`, `app_id`, `model_provider`, `model_name` | +| `dataset_retrieval` | `tenant_id`, `app_id` | +| `generate_name` | `tenant_id`, `app_id` | +| `prompt_generation` | `tenant_id`, `app_id`, `operation_type`, `model_provider`, `model_name`, `status` | + +### Error Counters + +| Metric | Unit | Description | +|--------|------|-------------| +| `dify.errors.total` | `{error}` | Total failed operations | + +**Labels by type:** + +| `type` | Additional Labels | +|--------|-------------------| +| `workflow` | `tenant_id`, `app_id` | +| `node` | `tenant_id`, `app_id`, `node_type`, `model_provider`, `model_name` | +| `draft_node` | `tenant_id`, `app_id`, `node_type`, `model_provider`, `model_name` | +| `message` | `tenant_id`, `app_id`, `model_provider`, `model_name` | +| `tool` | `tenant_id`, `app_id`, `tool_name` | +| `prompt_generation` | `tenant_id`, `app_id`, `operation_type`, `model_provider`, `model_name` | + +### Other Counters + +| Metric | Unit | Labels | +|--------|------|--------| +| `dify.feedback.total` | `{feedback}` | `tenant_id`, `app_id`, `rating` | +| `dify.dataset.retrievals.total` | `{retrieval}` | `tenant_id`, `app_id`, `dataset_id`, `embedding_model_provider`, `embedding_model`, `rerank_model_provider`, `rerank_model` | +| `dify.app.created.total` | `{app}` | `tenant_id`, `app_id`, `mode` | +| `dify.app.updated.total` | `{app}` | `tenant_id`, `app_id` | +| `dify.app.deleted.total` | `{app}` | `tenant_id`, `app_id` | + +## Histograms + +| Metric | Unit | Labels | +|--------|------|--------| +| `dify.workflow.duration` | `s` | `tenant_id`, `app_id`, `status` | +| `dify.node.duration` | `s` | `tenant_id`, `app_id`, `node_type`, `model_provider`, `model_name`, `plugin_name` | +| `dify.message.duration` | `s` | `tenant_id`, `app_id`, `model_provider`, `model_name` | +| `dify.message.time_to_first_token` | `s` | `tenant_id`, `app_id`, `model_provider`, `model_name` | +| `dify.tool.duration` | `s` | `tenant_id`, `app_id`, `tool_name` | +| `dify.prompt_generation.duration` | `s` | `tenant_id`, `app_id`, `operation_type`, `model_provider`, `model_name` | + +## Structured Logs + +### Span Companion Logs + +Logs that accompany spans. Signal type: `span_detail` + +#### `dify.workflow.run` Companion Log + +**Common attributes:** All span attributes (see Traces section) plus: + +| Additional Attribute | Type | Always Present | Description | +|---------------------|------|----------------|-------------| +| `dify.app.name` | string | No | Application display name | +| `dify.workspace.name` | string | No | Workspace display name | +| `dify.workflow.version` | string | Yes | Workflow definition version | +| `dify.workflow.inputs` | string/JSON | Yes | Input parameters (content-gated) | +| `dify.workflow.outputs` | string/JSON | Yes | Output results (content-gated) | +| `dify.workflow.query` | string | No | User query text (content-gated) | + +**Event attributes:** +- `dify.event.name`: `"dify.workflow.run"` +- `dify.event.signal`: `"span_detail"` +- `trace_id`, `span_id`, `tenant_id`, `user_id` + +#### `dify.node.execution` and `dify.node.execution.draft` Companion Logs + +**Common attributes:** All span attributes (see Traces section) plus: + +| Additional Attribute | Type | Always Present | Description | +|---------------------|------|----------------|-------------| +| `dify.app.name` | string | No | Application display name | +| `dify.workspace.name` | string | No | Workspace display name | +| `dify.invoke_from` | string | No | Invocation source | +| `gen_ai.tool.name` | string | No | Tool name (tool nodes only) | +| `dify.node.total_price` | float | No | Cost (LLM nodes only) | +| `dify.node.currency` | string | No | Currency code (LLM nodes only) | +| `dify.node.iteration_index` | int | No | Iteration index (iteration nodes) | +| `dify.node.loop_index` | int | No | Loop index (loop nodes) | +| `dify.plugin.name` | string | No | Plugin name (tool/knowledge nodes) | +| `dify.credential.name` | string | No | Credential name (plugin nodes) | +| `dify.credential.id` | string | No | Credential ID (plugin nodes) | +| `dify.dataset.ids` | JSON array | No | Dataset IDs (knowledge nodes) | +| `dify.dataset.names` | JSON array | No | Dataset names (knowledge nodes) | +| `dify.node.inputs` | string/JSON | Yes | Node inputs (content-gated) | +| `dify.node.outputs` | string/JSON | Yes | Node outputs (content-gated) | +| `dify.node.process_data` | string/JSON | No | Processing data (content-gated) | + +**Event attributes:** +- `dify.event.name`: `"dify.node.execution"` or `"dify.node.execution.draft"` +- `dify.event.signal`: `"span_detail"` +- `trace_id`, `span_id`, `tenant_id`, `user_id` + +### Standalone Logs + +Logs without structural spans. Signal type: `metric_only` + +#### `dify.message.run` + +| Attribute | Type | Description | +|-----------|------|-------------| +| `dify.event.name` | string | `"dify.message.run"` | +| `dify.event.signal` | string | `"metric_only"` | +| `trace_id` | string | OTEL trace ID (32-char hex) | +| `span_id` | string | OTEL span ID (16-char hex) | +| `tenant_id` | string | Tenant identifier | +| `user_id` | string | User identifier (optional) | +| `dify.app_id` | string | Application identifier | +| `dify.message.id` | string | Message identifier | +| `dify.conversation.id` | string | Conversation ID (optional) | +| `dify.workflow.run_id` | string | Workflow run ID (optional) | +| `dify.invoke_from` | string | `service-api`, `web-app`, `debugger`, `explore` | +| `gen_ai.provider.name` | string | LLM provider | +| `gen_ai.request.model` | string | LLM model | +| `gen_ai.usage.input_tokens` | int | Input tokens | +| `gen_ai.usage.output_tokens` | int | Output tokens | +| `gen_ai.usage.total_tokens` | int | Total tokens | +| `dify.message.status` | string | `succeeded`, `failed` | +| `dify.message.error` | string | Error message (if failed) | +| `dify.message.duration` | float | Duration (seconds) | +| `dify.message.time_to_first_token` | float | TTFT (seconds) | +| `dify.message.inputs` | string/JSON | Inputs (content-gated) | +| `dify.message.outputs` | string/JSON | Outputs (content-gated) | + +#### `dify.tool.execution` + +| Attribute | Type | Description | +|-----------|------|-------------| +| `dify.event.name` | string | `"dify.tool.execution"` | +| `dify.event.signal` | string | `"metric_only"` | +| `trace_id` | string | OTEL trace ID | +| `span_id` | string | OTEL span ID | +| `tenant_id` | string | Tenant identifier | +| `dify.app_id` | string | Application identifier | +| `dify.message.id` | string | Message identifier | +| `dify.tool.name` | string | Tool name | +| `dify.tool.duration` | float | Duration (seconds) | +| `dify.tool.status` | string | `succeeded`, `failed` | +| `dify.tool.error` | string | Error message (if failed) | +| `dify.tool.inputs` | string/JSON | Inputs (content-gated) | +| `dify.tool.outputs` | string/JSON | Outputs (content-gated) | +| `dify.tool.parameters` | string/JSON | Parameters (content-gated) | +| `dify.tool.config` | string/JSON | Configuration (content-gated) | + +#### `dify.moderation.check` + +| Attribute | Type | Description | +|-----------|------|-------------| +| `dify.event.name` | string | `"dify.moderation.check"` | +| `dify.event.signal` | string | `"metric_only"` | +| `trace_id` | string | OTEL trace ID | +| `span_id` | string | OTEL span ID | +| `tenant_id` | string | Tenant identifier | +| `dify.app_id` | string | Application identifier | +| `dify.message.id` | string | Message identifier | +| `dify.moderation.type` | string | `input`, `output` | +| `dify.moderation.action` | string | `pass`, `block`, `flag` | +| `dify.moderation.flagged` | boolean | Whether flagged | +| `dify.moderation.categories` | JSON array | Flagged categories | +| `dify.moderation.query` | string | Content (content-gated) | + +#### `dify.suggested_question.generation` + +| Attribute | Type | Description | +|-----------|------|-------------| +| `dify.event.name` | string | `"dify.suggested_question.generation"` | +| `dify.event.signal` | string | `"metric_only"` | +| `trace_id` | string | OTEL trace ID | +| `span_id` | string | OTEL span ID | +| `tenant_id` | string | Tenant identifier | +| `dify.app_id` | string | Application identifier | +| `dify.message.id` | string | Message identifier | +| `dify.suggested_question.count` | int | Number of questions | +| `dify.suggested_question.duration` | float | Duration (seconds) | +| `dify.suggested_question.status` | string | `succeeded`, `failed` | +| `dify.suggested_question.error` | string | Error message (if failed) | +| `dify.suggested_question.questions` | JSON array | Questions (content-gated) | + +#### `dify.dataset.retrieval` + +| Attribute | Type | Description | +|-----------|------|-------------| +| `dify.event.name` | string | `"dify.dataset.retrieval"` | +| `dify.event.signal` | string | `"metric_only"` | +| `trace_id` | string | OTEL trace ID | +| `span_id` | string | OTEL span ID | +| `tenant_id` | string | Tenant identifier | +| `dify.app_id` | string | Application identifier | +| `dify.message.id` | string | Message identifier | +| `dify.dataset.id` | string | Dataset identifier | +| `dify.dataset.name` | string | Dataset name | +| `dify.dataset.embedding_providers` | JSON array | Embedding model providers (one per dataset) | +| `dify.dataset.embedding_models` | JSON array | Embedding models (one per dataset) | +| `dify.retrieval.rerank_provider` | string | Rerank model provider | +| `dify.retrieval.rerank_model` | string | Rerank model name | +| `dify.retrieval.query` | string | Search query (content-gated) | +| `dify.retrieval.document_count` | int | Documents retrieved | +| `dify.retrieval.duration` | float | Duration (seconds) | +| `dify.retrieval.status` | string | `succeeded`, `failed` | +| `dify.retrieval.error` | string | Error message (if failed) | +| `dify.dataset.documents` | JSON array | Documents (content-gated) | + +#### `dify.generate_name.execution` + +| Attribute | Type | Description | +|-----------|------|-------------| +| `dify.event.name` | string | `"dify.generate_name.execution"` | +| `dify.event.signal` | string | `"metric_only"` | +| `trace_id` | string | OTEL trace ID | +| `span_id` | string | OTEL span ID | +| `tenant_id` | string | Tenant identifier | +| `dify.app_id` | string | Application identifier | +| `dify.conversation.id` | string | Conversation identifier | +| `dify.generate_name.duration` | float | Duration (seconds) | +| `dify.generate_name.status` | string | `succeeded`, `failed` | +| `dify.generate_name.error` | string | Error message (if failed) | +| `dify.generate_name.inputs` | string/JSON | Inputs (content-gated) | +| `dify.generate_name.outputs` | string | Generated name (content-gated) | + +#### `dify.prompt_generation.execution` + +| Attribute | Type | Description | +|-----------|------|-------------| +| `dify.event.name` | string | `"dify.prompt_generation.execution"` | +| `dify.event.signal` | string | `"metric_only"` | +| `trace_id` | string | OTEL trace ID | +| `span_id` | string | OTEL span ID | +| `tenant_id` | string | Tenant identifier | +| `dify.app_id` | string | Application identifier | +| `dify.prompt_generation.operation_type` | string | Operation type (see appendix) | +| `gen_ai.provider.name` | string | LLM provider | +| `gen_ai.request.model` | string | LLM model | +| `gen_ai.usage.input_tokens` | int | Input tokens | +| `gen_ai.usage.output_tokens` | int | Output tokens | +| `gen_ai.usage.total_tokens` | int | Total tokens | +| `dify.prompt_generation.duration` | float | Duration (seconds) | +| `dify.prompt_generation.status` | string | `succeeded`, `failed` | +| `dify.prompt_generation.error` | string | Error message (if failed) | +| `dify.prompt_generation.instruction` | string | Instruction (content-gated) | +| `dify.prompt_generation.output` | string/JSON | Output (content-gated) | + +#### `dify.app.created` + +| Attribute | Type | Description | +|-----------|------|-------------| +| `dify.event.name` | string | `"dify.app.created"` | +| `dify.event.signal` | string | `"metric_only"` | +| `tenant_id` | string | Tenant identifier | +| `dify.app_id` | string | Application identifier | +| `dify.app.mode` | string | `chat`, `completion`, `agent-chat`, `workflow` | +| `dify.app.created_at` | string | Timestamp (ISO 8601) | + +#### `dify.app.updated` + +| Attribute | Type | Description | +|-----------|------|-------------| +| `dify.event.name` | string | `"dify.app.updated"` | +| `dify.event.signal` | string | `"metric_only"` | +| `tenant_id` | string | Tenant identifier | +| `dify.app_id` | string | Application identifier | +| `dify.app.updated_at` | string | Timestamp (ISO 8601) | + +#### `dify.app.deleted` + +| Attribute | Type | Description | +|-----------|------|-------------| +| `dify.event.name` | string | `"dify.app.deleted"` | +| `dify.event.signal` | string | `"metric_only"` | +| `tenant_id` | string | Tenant identifier | +| `dify.app_id` | string | Application identifier | +| `dify.app.deleted_at` | string | Timestamp (ISO 8601) | + +#### `dify.feedback.created` + +| Attribute | Type | Description | +|-----------|------|-------------| +| `dify.event.name` | string | `"dify.feedback.created"` | +| `dify.event.signal` | string | `"metric_only"` | +| `trace_id` | string | OTEL trace ID | +| `span_id` | string | OTEL span ID | +| `tenant_id` | string | Tenant identifier | +| `dify.app_id` | string | Application identifier | +| `dify.message.id` | string | Message identifier | +| `dify.feedback.rating` | string | `like`, `dislike`, `null` | +| `dify.feedback.content` | string | Feedback text (content-gated) | +| `dify.feedback.created_at` | string | Timestamp (ISO 8601) | + +#### `dify.telemetry.rehydration_failed` + +Diagnostic event for telemetry system health monitoring. + +| Attribute | Type | Description | +|-----------|------|-------------| +| `dify.event.name` | string | `"dify.telemetry.rehydration_failed"` | +| `dify.event.signal` | string | `"metric_only"` | +| `tenant_id` | string | Tenant identifier | +| `dify.telemetry.error` | string | Error message | +| `dify.telemetry.payload_type` | string | Payload type (see appendix) | +| `dify.telemetry.correlation_id` | string | Correlation ID | + +## Content-Gated Attributes + +When `ENTERPRISE_INCLUDE_CONTENT=false`, these attributes are replaced with reference strings (`ref:{id_type}={uuid}`). + +| Attribute | Signal | +|-----------|--------| +| `dify.workflow.inputs` | `dify.workflow.run` | +| `dify.workflow.outputs` | `dify.workflow.run` | +| `dify.workflow.query` | `dify.workflow.run` | +| `dify.node.inputs` | `dify.node.execution` | +| `dify.node.outputs` | `dify.node.execution` | +| `dify.node.process_data` | `dify.node.execution` | +| `dify.message.inputs` | `dify.message.run` | +| `dify.message.outputs` | `dify.message.run` | +| `dify.tool.inputs` | `dify.tool.execution` | +| `dify.tool.outputs` | `dify.tool.execution` | +| `dify.tool.parameters` | `dify.tool.execution` | +| `dify.tool.config` | `dify.tool.execution` | +| `dify.moderation.query` | `dify.moderation.check` | +| `dify.suggested_question.questions` | `dify.suggested_question.generation` | +| `dify.retrieval.query` | `dify.dataset.retrieval` | +| `dify.dataset.documents` | `dify.dataset.retrieval` | +| `dify.generate_name.inputs` | `dify.generate_name.execution` | +| `dify.generate_name.outputs` | `dify.generate_name.execution` | +| `dify.prompt_generation.instruction` | `dify.prompt_generation.execution` | +| `dify.prompt_generation.output` | `dify.prompt_generation.execution` | +| `dify.feedback.content` | `dify.feedback.created` | + +## Appendix + +### Operation Types + +- `workflow`, `node_execution`, `message`, `rule_generate`, `code_generate`, `structured_output`, `instruction_modify` + +### Node Types + +- `start`, `end`, `answer`, `llm`, `knowledge-retrieval`, `knowledge-index`, `if-else`, `code`, `template-transform`, `question-classifier`, `http-request`, `tool`, `datasource`, `variable-aggregator`, `loop`, `iteration`, `parameter-extractor`, `assigner`, `document-extractor`, `list-operator`, `agent`, `trigger-webhook`, `trigger-schedule`, `trigger-plugin`, `human-input` + +### Workflow Statuses + +- `running`, `succeeded`, `failed`, `stopped`, `partial-succeeded`, `paused` + +### Payload Types + +- `workflow`, `node`, `message`, `tool`, `moderation`, `suggested_question`, `dataset_retrieval`, `generate_name`, `prompt_generation`, `app`, `feedback` + +### Null Value Behavior + +**Spans:** Attributes with `null` values are omitted. + +**Logs:** Attributes with `null` values appear as `null` in JSON. + +**Content-Gated:** Replaced with reference strings, not set to `null`. diff --git a/api/enterprise/telemetry/README.md b/api/enterprise/telemetry/README.md new file mode 100644 index 0000000000..2c8ca988e1 --- /dev/null +++ b/api/enterprise/telemetry/README.md @@ -0,0 +1,116 @@ +# Dify Enterprise Telemetry + +This document provides an overview of the Dify Enterprise OpenTelemetry (OTEL) exporter and how to configure it for integration with observability stacks like Prometheus, Grafana, Jaeger, or Honeycomb. + +## Overview + +Dify Enterprise uses a "slim span + rich companion log" architecture to provide high-fidelity observability without overwhelming trace storage. + +- **Traces (Spans)**: Capture the structure, identity, and timing of high-level operations (Workflows and Nodes). +- **Structured Logs**: Provide deep context (inputs, outputs, metadata) for every event, correlated to spans via `trace_id` and `span_id`. +- **Metrics**: Provide 100% accurate counters and histograms for usage, performance, and error tracking. + +### Signal Architecture + +```mermaid +graph TD + A[Workflow Run] -->|Span| B(dify.workflow.run) + A -->|Log| C(dify.workflow.run detail) + B ---|trace_id| C + + D[Node Execution] -->|Span| E(dify.node.execution) + D -->|Log| F(dify.node.execution detail) + E ---|span_id| F + + G[Message/Tool/etc] -->|Log| H(dify.* event) + G -->|Metric| I(dify.* counter/histogram) +``` + +## Configuration + +The Enterprise OTEL exporter is configured via environment variables. + +| Variable | Description | Default | +|----------|-------------|---------| +| `ENTERPRISE_ENABLED` | Master switch for all enterprise features. | `false` | +| `ENTERPRISE_TELEMETRY_ENABLED` | Master switch for enterprise telemetry. | `false` | +| `ENTERPRISE_OTLP_ENDPOINT` | OTLP collector endpoint (e.g., `http://otel-collector:4318`). | - | +| `ENTERPRISE_OTLP_HEADERS` | Custom headers for OTLP requests (e.g., `x-scope-orgid=tenant1`). | - | +| `ENTERPRISE_OTLP_PROTOCOL` | OTLP transport protocol (`http` or `grpc`). | `http` | +| `ENTERPRISE_OTLP_API_KEY` | Bearer token for authentication. | - | +| `ENTERPRISE_INCLUDE_CONTENT` | Whether to include sensitive content (inputs/outputs) in logs. | `true` | +| `ENTERPRISE_SERVICE_NAME` | Service name reported to OTEL. | `dify` | +| `ENTERPRISE_OTEL_SAMPLING_RATE` | Sampling rate for traces (0.0 to 1.0). Metrics are always 100%. | `1.0` | + +## Correlation Model + +Dify uses deterministic ID generation to ensure signals are correlated across different services and asynchronous tasks. + +### ID Generation Rules +- `trace_id`: Derived from the correlation ID (workflow_run_id or node_execution_id for drafts) using `int(UUID(correlation_id))` +- `span_id`: Derived from the source ID using `SHA256(source_id)[:8]` + +### Scenario A: Simple Workflow +A single workflow run with multiple nodes. All spans and logs share the same `trace_id` (derived from `workflow_run_id`). + +``` +trace_id = UUID(workflow_run_id) +├── [root span] dify.workflow.run (span_id = hash(workflow_run_id)) +│ ├── [child] dify.node.execution - "Start" (span_id = hash(node_exec_id_1)) +│ ├── [child] dify.node.execution - "LLM" (span_id = hash(node_exec_id_2)) +│ └── [child] dify.node.execution - "End" (span_id = hash(node_exec_id_3)) +``` + +### Scenario B: Nested Sub-Workflow +A workflow calling another workflow via a Tool or Sub-workflow node. The child workflow's spans are linked to the parent via `parent_span_id`. Both workflows share the same trace_id. + +``` +trace_id = UUID(outer_workflow_run_id) ← shared across both workflows +├── [root] dify.workflow.run (outer) (span_id = hash(outer_workflow_run_id)) +│ ├── dify.node.execution - "Start Node" +│ ├── dify.node.execution - "Tool Node" (triggers sub-workflow) +│ │ └── [child] dify.workflow.run (inner) (span_id = hash(inner_workflow_run_id)) +│ │ ├── dify.node.execution - "Inner Start" +│ │ └── dify.node.execution - "Inner End" +│ └── dify.node.execution - "End Node" +``` + +**Key attributes for nested workflows:** +- Inner workflow's `dify.parent.trace_id` = outer `workflow_run_id` +- Inner workflow's `dify.parent.node.execution_id` = tool node's `execution_id` +- Inner workflow's `dify.parent.workflow.run_id` = outer `workflow_run_id` +- Inner workflow's `dify.parent.app.id` = outer `app_id` + +### Scenario C: Draft Node Execution +A single node run in isolation (debugger/preview mode). It creates its own trace where the node span is the root. + +``` +trace_id = UUID(node_execution_id) ← own trace, NOT part of any workflow +└── dify.node.execution.draft (span_id = hash(node_execution_id)) +``` + +**Key difference:** Draft executions use `node_execution_id` as the correlation_id, so they are NOT children of any workflow trace. + +## Content Gating + +When `ENTERPRISE_INCLUDE_CONTENT` is set to `false`, sensitive content attributes (inputs, outputs, queries) are replaced with reference strings (e.g., `ref:workflow_run_id=...`) to prevent data leakage to the OTEL collector. + +**Reference String Format:** + +``` +ref:{id_type}={uuid} +``` + +**Examples:** + +``` +ref:workflow_run_id=550e8400-e29b-41d4-a716-446655440000 +ref:node_execution_id=660e8400-e29b-41d4-a716-446655440001 +ref:message_id=770e8400-e29b-41d4-a716-446655440002 +``` + +To retrieve actual content when gating is enabled, query the Dify database using the provided UUID. + +## Reference + +For a complete list of telemetry signals, attributes, and data structures, see [DATA_DICTIONARY.md](./DATA_DICTIONARY.md). diff --git a/api/enterprise/telemetry/__init__.py b/api/enterprise/telemetry/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/enterprise/telemetry/contracts.py b/api/enterprise/telemetry/contracts.py new file mode 100644 index 0000000000..91398cb8cb --- /dev/null +++ b/api/enterprise/telemetry/contracts.py @@ -0,0 +1,73 @@ +"""Telemetry gateway contracts and data structures. + +This module defines the envelope format for telemetry events and the routing +configuration that determines how each event type is processed. +""" + +from __future__ import annotations + +from enum import StrEnum +from typing import Any + +from pydantic import BaseModel, ConfigDict + + +class TelemetryCase(StrEnum): + """Enumeration of all known telemetry event cases.""" + + WORKFLOW_RUN = "workflow_run" + NODE_EXECUTION = "node_execution" + DRAFT_NODE_EXECUTION = "draft_node_execution" + MESSAGE_RUN = "message_run" + TOOL_EXECUTION = "tool_execution" + MODERATION_CHECK = "moderation_check" + SUGGESTED_QUESTION = "suggested_question" + DATASET_RETRIEVAL = "dataset_retrieval" + GENERATE_NAME = "generate_name" + PROMPT_GENERATION = "prompt_generation" + APP_CREATED = "app_created" + APP_UPDATED = "app_updated" + APP_DELETED = "app_deleted" + FEEDBACK_CREATED = "feedback_created" + + +class SignalType(StrEnum): + """Signal routing type for telemetry cases.""" + + TRACE = "trace" + METRIC_LOG = "metric_log" + + +class CaseRoute(BaseModel): + """Routing configuration for a telemetry case. + + Attributes: + signal_type: The type of signal (trace or metric_log). + ce_eligible: Whether this case is eligible for community edition tracing. + """ + + signal_type: SignalType + ce_eligible: bool + + +class TelemetryEnvelope(BaseModel): + """Envelope for telemetry events. + + Attributes: + case: The telemetry case type. + tenant_id: The tenant identifier. + event_id: Unique event identifier for deduplication. + payload: The main event payload (inline for small payloads, + empty when offloaded to storage via ``payload_ref``). + metadata: Optional metadata dictionary. When the gateway + offloads a large payload to object storage, this contains + ``{"payload_ref": ""}``. + """ + + model_config = ConfigDict(extra="forbid", use_enum_values=False) + + case: TelemetryCase + tenant_id: str + event_id: str + payload: dict[str, Any] + metadata: dict[str, Any] | None = None diff --git a/api/enterprise/telemetry/draft_trace.py b/api/enterprise/telemetry/draft_trace.py new file mode 100644 index 0000000000..ea8088695e --- /dev/null +++ b/api/enterprise/telemetry/draft_trace.py @@ -0,0 +1,77 @@ +from __future__ import annotations + +from collections.abc import Mapping +from typing import Any + +from core.telemetry import TelemetryContext, TelemetryEvent, TraceTaskName +from core.telemetry import emit as telemetry_emit +from core.workflow.enums import WorkflowNodeExecutionMetadataKey +from models.workflow import WorkflowNodeExecutionModel + + +def enqueue_draft_node_execution_trace( + *, + execution: WorkflowNodeExecutionModel, + outputs: Mapping[str, Any] | None, + workflow_execution_id: str | None, + user_id: str, +) -> None: + node_data = _build_node_execution_data( + execution=execution, + outputs=outputs, + workflow_execution_id=workflow_execution_id, + ) + telemetry_emit( + TelemetryEvent( + name=TraceTaskName.DRAFT_NODE_EXECUTION_TRACE, + context=TelemetryContext( + tenant_id=execution.tenant_id, + user_id=user_id, + app_id=execution.app_id, + ), + payload={"node_execution_data": node_data}, + ) + ) + + +def _build_node_execution_data( + *, + execution: WorkflowNodeExecutionModel, + outputs: Mapping[str, Any] | None, + workflow_execution_id: str | None, +) -> dict[str, Any]: + metadata = execution.execution_metadata_dict + node_outputs = outputs if outputs is not None else execution.outputs_dict + execution_id = workflow_execution_id or execution.workflow_run_id or execution.id + + return { + "workflow_id": execution.workflow_id, + "workflow_execution_id": execution_id, + "tenant_id": execution.tenant_id, + "app_id": execution.app_id, + "node_execution_id": execution.id, + "node_id": execution.node_id, + "node_type": execution.node_type, + "title": execution.title, + "status": execution.status, + "error": execution.error, + "elapsed_time": execution.elapsed_time, + "index": execution.index, + "predecessor_node_id": execution.predecessor_node_id, + "created_at": execution.created_at, + "finished_at": execution.finished_at, + "total_tokens": metadata.get(WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS, 0), + "total_price": metadata.get(WorkflowNodeExecutionMetadataKey.TOTAL_PRICE, 0.0), + "currency": metadata.get(WorkflowNodeExecutionMetadataKey.CURRENCY), + "tool_name": (metadata.get(WorkflowNodeExecutionMetadataKey.TOOL_INFO) or {}).get("tool_name") + if isinstance(metadata.get(WorkflowNodeExecutionMetadataKey.TOOL_INFO), dict) + else None, + "iteration_id": metadata.get(WorkflowNodeExecutionMetadataKey.ITERATION_ID), + "iteration_index": metadata.get(WorkflowNodeExecutionMetadataKey.ITERATION_INDEX), + "loop_id": metadata.get(WorkflowNodeExecutionMetadataKey.LOOP_ID), + "loop_index": metadata.get(WorkflowNodeExecutionMetadataKey.LOOP_INDEX), + "parallel_id": metadata.get(WorkflowNodeExecutionMetadataKey.PARALLEL_ID), + "node_inputs": execution.inputs_dict, + "node_outputs": node_outputs, + "process_data": execution.process_data_dict, + } diff --git a/api/enterprise/telemetry/enterprise_trace.py b/api/enterprise/telemetry/enterprise_trace.py new file mode 100644 index 0000000000..a0a7d67f82 --- /dev/null +++ b/api/enterprise/telemetry/enterprise_trace.py @@ -0,0 +1,938 @@ +"""Enterprise trace handler — duck-typed, NOT a BaseTraceInstance subclass. + +Invoked directly in the Celery task, not through OpsTraceManager dispatch. +Only requires a matching ``trace(trace_info)`` method signature. + +Signal strategy: +- **Traces (spans)**: workflow run, node execution, draft node execution only. +- **Metrics + structured logs**: all other event types. + +Token metric labels (unified structure): +All token metrics (dify.tokens.input, dify.tokens.output, dify.tokens.total) use the +same label set for consistent filtering and aggregation: +- tenant_id: Tenant identifier +- app_id: Application identifier +- operation_type: Source of token usage (workflow | node_execution | message | rule_generate | etc.) +- model_provider: LLM provider name (empty string if not applicable) +- model_name: LLM model name (empty string if not applicable) +- node_type: Workflow node type (empty string if not node_execution) + +This unified structure allows filtering by operation_type to separate: +- Workflow-level aggregates (operation_type=workflow) +- Individual node executions (operation_type=node_execution) +- Direct message calls (operation_type=message) +- Prompt generation operations (operation_type=rule_generate, code_generate, etc.) + +Without this, tokens are double-counted when querying totals (workflow totals include +node totals, since workflow.total_tokens is the sum of all node tokens). +""" + +from __future__ import annotations + +import json +import logging +from typing import Any, cast + +from opentelemetry.util.types import AttributeValue + +from core.ops.entities.trace_entity import ( + BaseTraceInfo, + DatasetRetrievalTraceInfo, + DraftNodeExecutionTrace, + GenerateNameTraceInfo, + MessageTraceInfo, + ModerationTraceInfo, + OperationType, + PromptGenerationTraceInfo, + SuggestedQuestionTraceInfo, + ToolTraceInfo, + WorkflowNodeTraceInfo, + WorkflowTraceInfo, +) +from enterprise.telemetry.entities import ( + EnterpriseTelemetryCounter, + EnterpriseTelemetryEvent, + EnterpriseTelemetryHistogram, + EnterpriseTelemetrySpan, + TokenMetricLabels, +) +from enterprise.telemetry.telemetry_log import emit_metric_only_event, emit_telemetry_log + +logger = logging.getLogger(__name__) + + +class EnterpriseOtelTrace: + """Duck-typed enterprise trace handler. + + ``*_trace`` methods emit spans (workflow/node only) or structured logs + (all other events), plus metrics at 100 % accuracy. + """ + + def __init__(self) -> None: + from extensions.ext_enterprise_telemetry import get_enterprise_exporter + + exporter = get_enterprise_exporter() + if exporter is None: + raise RuntimeError("EnterpriseOtelTrace instantiated but exporter is not initialized") + self._exporter = exporter + + def trace(self, trace_info: BaseTraceInfo) -> None: + if isinstance(trace_info, WorkflowTraceInfo): + self._workflow_trace(trace_info) + elif isinstance(trace_info, MessageTraceInfo): + self._message_trace(trace_info) + elif isinstance(trace_info, ToolTraceInfo): + self._tool_trace(trace_info) + elif isinstance(trace_info, DraftNodeExecutionTrace): + self._draft_node_execution_trace(trace_info) + elif isinstance(trace_info, WorkflowNodeTraceInfo): + self._node_execution_trace(trace_info) + elif isinstance(trace_info, ModerationTraceInfo): + self._moderation_trace(trace_info) + elif isinstance(trace_info, SuggestedQuestionTraceInfo): + self._suggested_question_trace(trace_info) + elif isinstance(trace_info, DatasetRetrievalTraceInfo): + self._dataset_retrieval_trace(trace_info) + elif isinstance(trace_info, GenerateNameTraceInfo): + self._generate_name_trace(trace_info) + elif isinstance(trace_info, PromptGenerationTraceInfo): + self._prompt_generation_trace(trace_info) + + def _common_attrs(self, trace_info: BaseTraceInfo) -> dict[str, Any]: + metadata = self._metadata(trace_info) + tenant_id, app_id, user_id = self._context_ids(trace_info, metadata) + return { + "dify.trace_id": trace_info.resolved_trace_id, + "dify.tenant_id": tenant_id, + "dify.app_id": app_id, + "dify.app.name": metadata.get("app_name"), + "dify.workspace.name": metadata.get("workspace_name"), + "gen_ai.user.id": user_id, + "dify.message.id": trace_info.message_id, + } + + def _metadata(self, trace_info: BaseTraceInfo) -> dict[str, Any]: + return trace_info.metadata + + def _context_ids( + self, + trace_info: BaseTraceInfo, + metadata: dict[str, Any], + ) -> tuple[str | None, str | None, str | None]: + tenant_id = getattr(trace_info, "tenant_id", None) or metadata.get("tenant_id") + app_id = getattr(trace_info, "app_id", None) or metadata.get("app_id") + user_id = getattr(trace_info, "user_id", None) or metadata.get("user_id") + return tenant_id, app_id, user_id + + def _labels(self, **values: AttributeValue) -> dict[str, AttributeValue]: + return dict(values) + + def _safe_payload_value(self, value: Any) -> str | dict[str, Any] | list[object] | None: + if isinstance(value, str): + return value + if isinstance(value, dict): + return cast(dict[str, Any], value) + if isinstance(value, list): + items: list[object] = [] + for item in cast(list[object], value): + items.append(item) + return items + return None + + def _content_or_ref(self, value: Any, ref: str) -> Any: + if self._exporter.include_content: + return self._maybe_json(value) + return ref + + def _maybe_json(self, value: Any) -> str | None: + if value is None: + return None + if isinstance(value, str): + return value + try: + return json.dumps(value, default=str) + except (TypeError, ValueError): + return str(value) + + # ------------------------------------------------------------------ + # SPAN-emitting handlers (workflow, node execution, draft node) + # ------------------------------------------------------------------ + + def _workflow_trace(self, info: WorkflowTraceInfo) -> None: + metadata = self._metadata(info) + tenant_id, app_id, user_id = self._context_ids(info, metadata) + # -- Span attrs: identity + structure + status + timing + gen_ai scalars -- + span_attrs: dict[str, Any] = { + "dify.trace_id": info.resolved_trace_id, + "dify.tenant_id": tenant_id, + "dify.app_id": app_id, + "dify.workflow.id": info.workflow_id, + "dify.workflow.run_id": info.workflow_run_id, + "dify.workflow.status": info.workflow_run_status, + "dify.workflow.error": info.error, + "dify.workflow.elapsed_time": info.workflow_run_elapsed_time, + "dify.invoke_from": metadata.get("triggered_from"), + "dify.conversation.id": info.conversation_id, + "dify.message.id": info.message_id, + "dify.invoked_by": info.invoked_by, + "gen_ai.usage.total_tokens": info.total_tokens, + "gen_ai.user.id": user_id, + } + + trace_correlation_override, parent_span_id_source = info.resolved_parent_context + + parent_ctx = metadata.get("parent_trace_context") + if isinstance(parent_ctx, dict): + parent_ctx_dict = cast(dict[str, Any], parent_ctx) + span_attrs["dify.parent.trace_id"] = parent_ctx_dict.get("trace_id") + span_attrs["dify.parent.node.execution_id"] = parent_ctx_dict.get("parent_node_execution_id") + span_attrs["dify.parent.workflow.run_id"] = parent_ctx_dict.get("parent_workflow_run_id") + span_attrs["dify.parent.app.id"] = parent_ctx_dict.get("parent_app_id") + + self._exporter.export_span( + EnterpriseTelemetrySpan.WORKFLOW_RUN, + span_attrs, + correlation_id=info.workflow_run_id, + span_id_source=info.workflow_run_id, + start_time=info.start_time, + end_time=info.end_time, + trace_correlation_override=trace_correlation_override, + parent_span_id_source=parent_span_id_source, + ) + + # -- Companion log: ALL attrs (span + detail) for full picture -- + log_attrs: dict[str, Any] = {**span_attrs} + log_attrs.update( + { + "dify.app.name": metadata.get("app_name"), + "dify.workspace.name": metadata.get("workspace_name"), + "gen_ai.user.id": user_id, + "gen_ai.usage.total_tokens": info.total_tokens, + "dify.workflow.version": info.workflow_run_version, + } + ) + + ref = f"ref:workflow_run_id={info.workflow_run_id}" + log_attrs["dify.workflow.inputs"] = self._content_or_ref(info.workflow_run_inputs, ref) + log_attrs["dify.workflow.outputs"] = self._content_or_ref(info.workflow_run_outputs, ref) + log_attrs["dify.workflow.query"] = self._content_or_ref(info.query, ref) + + emit_telemetry_log( + event_name=EnterpriseTelemetryEvent.WORKFLOW_RUN, + attributes=log_attrs, + signal="span_detail", + trace_id_source=info.workflow_run_id, + span_id_source=info.workflow_run_id, + tenant_id=tenant_id, + user_id=user_id, + ) + + # -- Metrics -- + labels = self._labels( + tenant_id=tenant_id or "", + app_id=app_id or "", + ) + token_labels = TokenMetricLabels( + tenant_id=tenant_id or "", + app_id=app_id or "", + operation_type=OperationType.WORKFLOW, + model_provider="", + model_name="", + node_type="", + ).to_dict() + self._exporter.increment_counter(EnterpriseTelemetryCounter.TOKENS, info.total_tokens, token_labels) + if info.prompt_tokens is not None and info.prompt_tokens > 0: + self._exporter.increment_counter(EnterpriseTelemetryCounter.INPUT_TOKENS, info.prompt_tokens, token_labels) + if info.completion_tokens is not None and info.completion_tokens > 0: + self._exporter.increment_counter( + EnterpriseTelemetryCounter.OUTPUT_TOKENS, info.completion_tokens, token_labels + ) + invoke_from = metadata.get("triggered_from", "") + self._exporter.increment_counter( + EnterpriseTelemetryCounter.REQUESTS, + 1, + self._labels( + **labels, + type="workflow", + status=info.workflow_run_status, + invoke_from=invoke_from, + ), + ) + # Prefer wall-clock timestamps over the elapsed_time field: elapsed_time defaults + # to 0 in the DB and can be stale if the Celery write races with the trace task. + # start_time = workflow_run.created_at, end_time = workflow_run.finished_at. + if info.start_time and info.end_time: + workflow_duration = (info.end_time - info.start_time).total_seconds() + elif info.workflow_run_elapsed_time: + workflow_duration = float(info.workflow_run_elapsed_time) + else: + workflow_duration = 0.0 + self._exporter.record_histogram( + EnterpriseTelemetryHistogram.WORKFLOW_DURATION, + workflow_duration, + self._labels( + **labels, + status=info.workflow_run_status, + ), + ) + + if info.error: + self._exporter.increment_counter( + EnterpriseTelemetryCounter.ERRORS, + 1, + self._labels( + **labels, + type="workflow", + ), + ) + + def _node_execution_trace(self, info: WorkflowNodeTraceInfo) -> None: + self._emit_node_execution_trace(info, EnterpriseTelemetrySpan.NODE_EXECUTION, "node") + + def _draft_node_execution_trace(self, info: DraftNodeExecutionTrace) -> None: + self._emit_node_execution_trace( + info, + EnterpriseTelemetrySpan.DRAFT_NODE_EXECUTION, + "draft_node", + correlation_id_override=info.node_execution_id, + trace_correlation_override_param=info.workflow_run_id, + ) + + def _emit_node_execution_trace( + self, + info: WorkflowNodeTraceInfo, + span_name: EnterpriseTelemetrySpan, + request_type: str, + correlation_id_override: str | None = None, + trace_correlation_override_param: str | None = None, + ) -> None: + metadata = self._metadata(info) + tenant_id, app_id, user_id = self._context_ids(info, metadata) + # -- Span attrs: identity + structure + status + timing + gen_ai scalars -- + span_attrs: dict[str, Any] = { + "dify.trace_id": info.resolved_trace_id, + "dify.tenant_id": tenant_id, + "dify.app_id": app_id, + "dify.workflow.id": info.workflow_id, + "dify.workflow.run_id": info.workflow_run_id, + "dify.message.id": info.message_id, + "dify.conversation.id": metadata.get("conversation_id"), + "dify.node.execution_id": info.node_execution_id, + "dify.node.id": info.node_id, + "dify.node.type": info.node_type, + "dify.node.title": info.title, + "dify.node.status": info.status, + "dify.node.error": info.error, + "dify.node.elapsed_time": info.elapsed_time, + "dify.node.index": info.index, + "dify.node.predecessor_node_id": info.predecessor_node_id, + "dify.node.iteration_id": info.iteration_id, + "dify.node.loop_id": info.loop_id, + "dify.node.parallel_id": info.parallel_id, + "dify.node.invoked_by": info.invoked_by, + "gen_ai.usage.input_tokens": info.prompt_tokens, + "gen_ai.usage.output_tokens": info.completion_tokens, + "gen_ai.usage.total_tokens": info.total_tokens, + "gen_ai.request.model": info.model_name, + "gen_ai.provider.name": info.model_provider, + "gen_ai.user.id": user_id, + } + + resolved_override, _ = info.resolved_parent_context + trace_correlation_override = trace_correlation_override_param or resolved_override + + effective_correlation_id = correlation_id_override or info.workflow_run_id + self._exporter.export_span( + span_name, + span_attrs, + correlation_id=effective_correlation_id, + span_id_source=info.node_execution_id, + start_time=info.start_time, + end_time=info.end_time, + trace_correlation_override=trace_correlation_override, + ) + + # -- Companion log: ALL attrs (span + detail) -- + log_attrs: dict[str, Any] = {**span_attrs} + log_attrs.update( + { + "dify.app.name": metadata.get("app_name"), + "dify.workspace.name": metadata.get("workspace_name"), + "dify.invoke_from": metadata.get("invoke_from"), + "gen_ai.user.id": user_id, + "gen_ai.usage.total_tokens": info.total_tokens, + "dify.node.total_price": info.total_price, + "dify.node.currency": info.currency, + "gen_ai.provider.name": info.model_provider, + "gen_ai.request.model": info.model_name, + "gen_ai.tool.name": info.tool_name, + "dify.node.iteration_index": info.iteration_index, + "dify.node.loop_index": info.loop_index, + "dify.plugin.name": metadata.get("plugin_name"), + "dify.credential.name": metadata.get("credential_name"), + "dify.credential.id": metadata.get("credential_id"), + "dify.dataset.ids": self._maybe_json(metadata.get("dataset_ids")), + "dify.dataset.names": self._maybe_json(metadata.get("dataset_names")), + } + ) + + ref = f"ref:node_execution_id={info.node_execution_id}" + log_attrs["dify.node.inputs"] = self._content_or_ref(info.node_inputs, ref) + log_attrs["dify.node.outputs"] = self._content_or_ref(info.node_outputs, ref) + log_attrs["dify.node.process_data"] = self._content_or_ref(info.process_data, ref) + + emit_telemetry_log( + event_name=span_name.value, + attributes=log_attrs, + signal="span_detail", + trace_id_source=info.workflow_run_id, + span_id_source=info.node_execution_id, + tenant_id=tenant_id, + user_id=user_id, + ) + + # -- Metrics -- + labels = self._labels( + tenant_id=tenant_id or "", + app_id=app_id or "", + node_type=info.node_type, + model_provider=info.model_provider or "", + ) + if info.total_tokens: + token_labels = TokenMetricLabels( + tenant_id=tenant_id or "", + app_id=app_id or "", + operation_type=OperationType.NODE_EXECUTION, + model_provider=info.model_provider or "", + model_name=info.model_name or "", + node_type=info.node_type, + ).to_dict() + self._exporter.increment_counter(EnterpriseTelemetryCounter.TOKENS, info.total_tokens, token_labels) + if info.prompt_tokens is not None and info.prompt_tokens > 0: + self._exporter.increment_counter( + EnterpriseTelemetryCounter.INPUT_TOKENS, info.prompt_tokens, token_labels + ) + if info.completion_tokens is not None and info.completion_tokens > 0: + self._exporter.increment_counter( + EnterpriseTelemetryCounter.OUTPUT_TOKENS, info.completion_tokens, token_labels + ) + self._exporter.increment_counter( + EnterpriseTelemetryCounter.REQUESTS, + 1, + self._labels( + **labels, + type=request_type, + status=info.status, + model_name=info.model_name or "", + ), + ) + duration_labels = dict(labels) + duration_labels["model_name"] = info.model_name or "" + plugin_name = metadata.get("plugin_name") + if plugin_name and info.node_type in {"tool", "knowledge-retrieval"}: + duration_labels["plugin_name"] = plugin_name + self._exporter.record_histogram(EnterpriseTelemetryHistogram.NODE_DURATION, info.elapsed_time, duration_labels) + + if info.error: + self._exporter.increment_counter( + EnterpriseTelemetryCounter.ERRORS, + 1, + self._labels( + **labels, + type=request_type, + model_name=info.model_name or "", + ), + ) + + # ------------------------------------------------------------------ + # METRIC-ONLY handlers (structured log + counters/histograms) + # ------------------------------------------------------------------ + + def _message_trace(self, info: MessageTraceInfo) -> None: + metadata = self._metadata(info) + tenant_id, app_id, user_id = self._context_ids(info, metadata) + attrs = self._common_attrs(info) + attrs.update( + { + "dify.invoke_from": metadata.get("from_source"), + "dify.conversation.id": metadata.get("conversation_id"), + "dify.conversation.mode": info.conversation_mode, + "gen_ai.provider.name": metadata.get("ls_provider"), + "gen_ai.request.model": metadata.get("ls_model_name"), + "gen_ai.usage.input_tokens": info.message_tokens, + "gen_ai.usage.output_tokens": info.answer_tokens, + "gen_ai.usage.total_tokens": info.total_tokens, + "dify.message.status": metadata.get("status"), + "dify.message.error": info.error, + "dify.message.from_source": metadata.get("from_source"), + "dify.message.from_end_user_id": metadata.get("from_end_user_id"), + "dify.message.from_account_id": metadata.get("from_account_id"), + "dify.streaming": info.is_streaming_request, + "dify.message.time_to_first_token": info.gen_ai_server_time_to_first_token, + "dify.message.streaming_duration": info.llm_streaming_time_to_generate, + "dify.workflow.run_id": metadata.get("workflow_run_id"), + } + ) + node_execution_id = metadata.get("node_execution_id") + if node_execution_id: + attrs["dify.node.execution_id"] = node_execution_id + + ref = f"ref:message_id={info.message_id}" + inputs = self._safe_payload_value(info.inputs) + outputs = self._safe_payload_value(info.outputs) + attrs["dify.message.inputs"] = self._content_or_ref(inputs, ref) + attrs["dify.message.outputs"] = self._content_or_ref(outputs, ref) + + emit_metric_only_event( + event_name=EnterpriseTelemetryEvent.MESSAGE_RUN, + attributes=attrs, + trace_id_source=metadata.get("workflow_run_id") or str(info.message_id) if info.message_id else None, + span_id_source=node_execution_id, + tenant_id=tenant_id, + user_id=user_id, + ) + + labels = self._labels( + tenant_id=tenant_id or "", + app_id=app_id or "", + model_provider=metadata.get("ls_provider") or "", + model_name=metadata.get("ls_model_name") or "", + ) + token_labels = TokenMetricLabels( + tenant_id=tenant_id or "", + app_id=app_id or "", + operation_type=OperationType.MESSAGE, + model_provider=metadata.get("ls_provider") or "", + model_name=metadata.get("ls_model_name") or "", + node_type="", + ).to_dict() + self._exporter.increment_counter(EnterpriseTelemetryCounter.TOKENS, info.total_tokens, token_labels) + if info.message_tokens > 0: + self._exporter.increment_counter(EnterpriseTelemetryCounter.INPUT_TOKENS, info.message_tokens, token_labels) + if info.answer_tokens > 0: + self._exporter.increment_counter(EnterpriseTelemetryCounter.OUTPUT_TOKENS, info.answer_tokens, token_labels) + invoke_from = metadata.get("from_source", "") + self._exporter.increment_counter( + EnterpriseTelemetryCounter.REQUESTS, + 1, + self._labels( + **labels, + type="message", + status=metadata.get("status", ""), + invoke_from=invoke_from, + ), + ) + + if info.start_time and info.end_time: + duration = (info.end_time - info.start_time).total_seconds() + self._exporter.record_histogram(EnterpriseTelemetryHistogram.MESSAGE_DURATION, duration, labels) + + if info.gen_ai_server_time_to_first_token is not None: + self._exporter.record_histogram( + EnterpriseTelemetryHistogram.MESSAGE_TTFT, info.gen_ai_server_time_to_first_token, labels + ) + + if info.error: + self._exporter.increment_counter( + EnterpriseTelemetryCounter.ERRORS, + 1, + self._labels( + **labels, + type="message", + ), + ) + + def _tool_trace(self, info: ToolTraceInfo) -> None: + metadata = self._metadata(info) + tenant_id, app_id, user_id = self._context_ids(info, metadata) + attrs = self._common_attrs(info) + attrs.update( + { + "gen_ai.tool.name": info.tool_name, + "dify.tool.time_cost": info.time_cost, + "dify.tool.error": info.error, + "dify.workflow.run_id": metadata.get("workflow_run_id"), + } + ) + node_execution_id = metadata.get("node_execution_id") + if node_execution_id: + attrs["dify.node.execution_id"] = node_execution_id + + ref = f"ref:message_id={info.message_id}" + attrs["dify.tool.inputs"] = self._content_or_ref(info.tool_inputs, ref) + attrs["dify.tool.outputs"] = self._content_or_ref(info.tool_outputs, ref) + attrs["dify.tool.parameters"] = self._content_or_ref(info.tool_parameters, ref) + attrs["dify.tool.config"] = self._content_or_ref(info.tool_config, ref) + + emit_metric_only_event( + event_name=EnterpriseTelemetryEvent.TOOL_EXECUTION, + attributes=attrs, + trace_id_source=info.resolved_trace_id, + span_id_source=node_execution_id, + tenant_id=tenant_id, + user_id=user_id, + ) + + labels = self._labels( + tenant_id=tenant_id or "", + app_id=app_id or "", + tool_name=info.tool_name, + ) + self._exporter.increment_counter( + EnterpriseTelemetryCounter.REQUESTS, + 1, + self._labels( + **labels, + type="tool", + ), + ) + self._exporter.record_histogram(EnterpriseTelemetryHistogram.TOOL_DURATION, float(info.time_cost), labels) + + if info.error: + self._exporter.increment_counter( + EnterpriseTelemetryCounter.ERRORS, + 1, + self._labels( + **labels, + type="tool", + ), + ) + + def _moderation_trace(self, info: ModerationTraceInfo) -> None: + metadata = self._metadata(info) + tenant_id, app_id, user_id = self._context_ids(info, metadata) + attrs = self._common_attrs(info) + attrs.update( + { + "dify.moderation.flagged": info.flagged, + "dify.moderation.action": info.action, + "dify.moderation.preset_response": info.preset_response, + "dify.workflow.run_id": metadata.get("workflow_run_id"), + } + ) + node_execution_id = metadata.get("node_execution_id") + if node_execution_id: + attrs["dify.node.execution_id"] = node_execution_id + + attrs["dify.moderation.query"] = self._content_or_ref( + info.query, + f"ref:message_id={info.message_id}", + ) + + emit_metric_only_event( + event_name=EnterpriseTelemetryEvent.MODERATION_CHECK, + attributes=attrs, + trace_id_source=info.resolved_trace_id, + span_id_source=node_execution_id, + tenant_id=tenant_id, + user_id=user_id, + ) + + labels = self._labels( + tenant_id=tenant_id or "", + app_id=app_id or "", + ) + self._exporter.increment_counter( + EnterpriseTelemetryCounter.REQUESTS, + 1, + self._labels( + **labels, + type="moderation", + ), + ) + + def _suggested_question_trace(self, info: SuggestedQuestionTraceInfo) -> None: + metadata = self._metadata(info) + tenant_id, app_id, user_id = self._context_ids(info, metadata) + attrs = self._common_attrs(info) + attrs.update( + { + "gen_ai.usage.total_tokens": info.total_tokens, + "dify.suggested_question.status": info.status, + "dify.suggested_question.error": info.error, + "gen_ai.provider.name": info.model_provider, + "gen_ai.request.model": info.model_id, + "dify.suggested_question.count": len(info.suggested_question), + "dify.workflow.run_id": metadata.get("workflow_run_id"), + } + ) + node_execution_id = metadata.get("node_execution_id") + if node_execution_id: + attrs["dify.node.execution_id"] = node_execution_id + + attrs["dify.suggested_question.questions"] = self._content_or_ref( + info.suggested_question, + f"ref:message_id={info.message_id}", + ) + + emit_metric_only_event( + event_name=EnterpriseTelemetryEvent.SUGGESTED_QUESTION_GENERATION, + attributes=attrs, + trace_id_source=info.resolved_trace_id, + span_id_source=node_execution_id, + tenant_id=tenant_id, + user_id=user_id, + ) + + labels = self._labels( + tenant_id=tenant_id or "", + app_id=app_id or "", + ) + self._exporter.increment_counter( + EnterpriseTelemetryCounter.REQUESTS, + 1, + self._labels( + **labels, + type="suggested_question", + model_provider=info.model_provider or "", + model_name=info.model_id or "", + ), + ) + + def _dataset_retrieval_trace(self, info: DatasetRetrievalTraceInfo) -> None: + metadata = self._metadata(info) + tenant_id, app_id, user_id = self._context_ids(info, metadata) + attrs = self._common_attrs(info) + attrs["dify.dataset.error"] = info.error + attrs["dify.workflow.run_id"] = metadata.get("workflow_run_id") + node_execution_id = metadata.get("node_execution_id") + if node_execution_id: + attrs["dify.node.execution_id"] = node_execution_id + + docs: list[dict[str, Any]] = [] + documents_any: Any = info.documents + documents_list: list[Any] = cast(list[Any], documents_any) if isinstance(documents_any, list) else [] + for entry in documents_list: + if isinstance(entry, dict): + entry_dict: dict[str, Any] = cast(dict[str, Any], entry) + docs.append(entry_dict) + dataset_ids: list[str] = [] + dataset_names: list[str] = [] + structured_docs: list[dict[str, Any]] = [] + for doc in docs: + meta_raw = doc.get("metadata") + meta: dict[str, Any] = cast(dict[str, Any], meta_raw) if isinstance(meta_raw, dict) else {} + did = meta.get("dataset_id") + dname = meta.get("dataset_name") + if did and did not in dataset_ids: + dataset_ids.append(did) + if dname and dname not in dataset_names: + dataset_names.append(dname) + structured_docs.append( + { + "dataset_id": did, + "document_id": meta.get("document_id"), + "segment_id": meta.get("segment_id"), + "score": meta.get("score"), + } + ) + + attrs["dify.dataset.ids"] = self._maybe_json(dataset_ids) + attrs["dify.dataset.names"] = self._maybe_json(dataset_names) + attrs["dify.retrieval.document_count"] = len(docs) + + embedding_models_raw: Any = metadata.get("embedding_models") + embedding_models: dict[str, Any] = ( + cast(dict[str, Any], embedding_models_raw) if isinstance(embedding_models_raw, dict) else {} + ) + if embedding_models: + providers: list[str] = [] + models: list[str] = [] + for ds_info in embedding_models.values(): + if isinstance(ds_info, dict): + ds_info_dict: dict[str, Any] = cast(dict[str, Any], ds_info) + p = ds_info_dict.get("embedding_model_provider", "") + m = ds_info_dict.get("embedding_model", "") + if p and p not in providers: + providers.append(p) + if m and m not in models: + models.append(m) + attrs["dify.dataset.embedding_providers"] = self._maybe_json(providers) + attrs["dify.dataset.embedding_models"] = self._maybe_json(models) + + # Add rerank model to logs + rerank_provider = metadata.get("rerank_model_provider", "") + rerank_model = metadata.get("rerank_model_name", "") + if rerank_provider or rerank_model: + attrs["dify.retrieval.rerank_provider"] = rerank_provider + attrs["dify.retrieval.rerank_model"] = rerank_model + + ref = f"ref:message_id={info.message_id}" + retrieval_inputs = self._safe_payload_value(info.inputs) + attrs["dify.retrieval.query"] = self._content_or_ref(retrieval_inputs, ref) + attrs["dify.dataset.documents"] = self._content_or_ref(structured_docs, ref) + + emit_metric_only_event( + event_name=EnterpriseTelemetryEvent.DATASET_RETRIEVAL, + attributes=attrs, + trace_id_source=metadata.get("workflow_run_id") or str(info.message_id) if info.message_id else None, + span_id_source=node_execution_id or (str(info.message_id) if info.message_id else None), + tenant_id=tenant_id, + user_id=user_id, + ) + + labels = self._labels( + tenant_id=tenant_id or "", + app_id=app_id or "", + ) + self._exporter.increment_counter( + EnterpriseTelemetryCounter.REQUESTS, + 1, + self._labels( + **labels, + type="dataset_retrieval", + ), + ) + + for did in dataset_ids: + # Get embedding model for this specific dataset + ds_embedding_info = embedding_models.get(did, {}) + embedding_provider = ds_embedding_info.get("embedding_model_provider", "") + embedding_model = ds_embedding_info.get("embedding_model", "") + + # Get rerank model (same for all datasets in this retrieval) + rerank_provider = metadata.get("rerank_model_provider", "") + rerank_model = metadata.get("rerank_model_name", "") + + self._exporter.increment_counter( + EnterpriseTelemetryCounter.DATASET_RETRIEVALS, + 1, + self._labels( + **labels, + dataset_id=did, + embedding_model_provider=embedding_provider, + embedding_model=embedding_model, + rerank_model_provider=rerank_provider, + rerank_model=rerank_model, + ), + ) + + def _generate_name_trace(self, info: GenerateNameTraceInfo) -> None: + metadata = self._metadata(info) + tenant_id, app_id, user_id = self._context_ids(info, metadata) + attrs = self._common_attrs(info) + attrs["dify.conversation.id"] = info.conversation_id + node_execution_id = metadata.get("node_execution_id") + if node_execution_id: + attrs["dify.node.execution_id"] = node_execution_id + + ref = f"ref:conversation_id={info.conversation_id}" + inputs = self._safe_payload_value(info.inputs) + outputs = self._safe_payload_value(info.outputs) + attrs["dify.generate_name.inputs"] = self._content_or_ref(inputs, ref) + attrs["dify.generate_name.outputs"] = self._content_or_ref(outputs, ref) + + emit_metric_only_event( + event_name=EnterpriseTelemetryEvent.GENERATE_NAME_EXECUTION, + attributes=attrs, + trace_id_source=info.resolved_trace_id, + span_id_source=node_execution_id, + tenant_id=tenant_id, + user_id=user_id, + ) + + labels = self._labels( + tenant_id=tenant_id or "", + app_id=app_id or "", + ) + self._exporter.increment_counter( + EnterpriseTelemetryCounter.REQUESTS, + 1, + self._labels( + **labels, + type="generate_name", + ), + ) + + def _prompt_generation_trace(self, info: PromptGenerationTraceInfo) -> None: + metadata = self._metadata(info) + tenant_id, app_id, user_id = self._context_ids(info, metadata) + attrs = { + "dify.trace_id": info.resolved_trace_id, + "dify.tenant_id": tenant_id, + "dify.user.id": user_id, + "dify.app.id": app_id or "", + "dify.app.name": metadata.get("app_name"), + "dify.workspace.name": metadata.get("workspace_name"), + "dify.operation.type": info.operation_type, + "gen_ai.provider.name": info.model_provider, + "gen_ai.request.model": info.model_name, + "gen_ai.usage.input_tokens": info.prompt_tokens, + "gen_ai.usage.output_tokens": info.completion_tokens, + "gen_ai.usage.total_tokens": info.total_tokens, + "dify.prompt_generation.latency": info.latency, + "dify.prompt_generation.error": info.error, + } + node_execution_id = metadata.get("node_execution_id") + if node_execution_id: + attrs["dify.node.execution_id"] = node_execution_id + + if info.total_price is not None: + attrs["dify.prompt_generation.total_price"] = info.total_price + attrs["dify.prompt_generation.currency"] = info.currency + + ref = f"ref:trace_id={info.trace_id}" + outputs = self._safe_payload_value(info.outputs) + attrs["dify.prompt_generation.instruction"] = self._content_or_ref(info.instruction, ref) + attrs["dify.prompt_generation.output"] = self._content_or_ref(outputs, ref) + + emit_metric_only_event( + event_name=EnterpriseTelemetryEvent.PROMPT_GENERATION_EXECUTION, + attributes=attrs, + trace_id_source=info.resolved_trace_id, + span_id_source=node_execution_id, + tenant_id=tenant_id, + user_id=user_id, + ) + + token_labels = TokenMetricLabels( + tenant_id=tenant_id or "", + app_id=app_id or "", + operation_type=info.operation_type, + model_provider=info.model_provider, + model_name=info.model_name, + node_type="", + ).to_dict() + + labels = self._labels( + tenant_id=tenant_id or "", + app_id=app_id or "", + operation_type=info.operation_type, + model_provider=info.model_provider, + model_name=info.model_name, + ) + + self._exporter.increment_counter(EnterpriseTelemetryCounter.TOKENS, info.total_tokens, token_labels) + if info.prompt_tokens > 0: + self._exporter.increment_counter(EnterpriseTelemetryCounter.INPUT_TOKENS, info.prompt_tokens, token_labels) + if info.completion_tokens > 0: + self._exporter.increment_counter( + EnterpriseTelemetryCounter.OUTPUT_TOKENS, info.completion_tokens, token_labels + ) + + status = "failed" if info.error else "success" + self._exporter.increment_counter( + EnterpriseTelemetryCounter.REQUESTS, + 1, + self._labels( + **labels, + type="prompt_generation", + status=status, + ), + ) + + self._exporter.record_histogram( + EnterpriseTelemetryHistogram.PROMPT_GENERATION_DURATION, + info.latency, + labels, + ) + + if info.error: + self._exporter.increment_counter( + EnterpriseTelemetryCounter.ERRORS, + 1, + self._labels( + **labels, + type="prompt_generation", + ), + ) diff --git a/api/enterprise/telemetry/entities/__init__.py b/api/enterprise/telemetry/entities/__init__.py new file mode 100644 index 0000000000..4a9bd3dbf8 --- /dev/null +++ b/api/enterprise/telemetry/entities/__init__.py @@ -0,0 +1,121 @@ +from enum import StrEnum +from typing import cast + +from opentelemetry.util.types import AttributeValue +from pydantic import BaseModel, ConfigDict + + +class EnterpriseTelemetrySpan(StrEnum): + WORKFLOW_RUN = "dify.workflow.run" + NODE_EXECUTION = "dify.node.execution" + DRAFT_NODE_EXECUTION = "dify.node.execution.draft" + + +class EnterpriseTelemetryEvent(StrEnum): + """Event names for enterprise telemetry logs.""" + + APP_CREATED = "dify.app.created" + APP_UPDATED = "dify.app.updated" + APP_DELETED = "dify.app.deleted" + FEEDBACK_CREATED = "dify.feedback.created" + WORKFLOW_RUN = "dify.workflow.run" + MESSAGE_RUN = "dify.message.run" + TOOL_EXECUTION = "dify.tool.execution" + MODERATION_CHECK = "dify.moderation.check" + SUGGESTED_QUESTION_GENERATION = "dify.suggested_question.generation" + DATASET_RETRIEVAL = "dify.dataset.retrieval" + GENERATE_NAME_EXECUTION = "dify.generate_name.execution" + PROMPT_GENERATION_EXECUTION = "dify.prompt_generation.execution" + REHYDRATION_FAILED = "dify.telemetry.rehydration_failed" + + +class EnterpriseTelemetryCounter(StrEnum): + TOKENS = "tokens" + INPUT_TOKENS = "input_tokens" + OUTPUT_TOKENS = "output_tokens" + REQUESTS = "requests" + ERRORS = "errors" + FEEDBACK = "feedback" + DATASET_RETRIEVALS = "dataset_retrievals" + APP_CREATED = "app_created" + APP_UPDATED = "app_updated" + APP_DELETED = "app_deleted" + + +class EnterpriseTelemetryHistogram(StrEnum): + WORKFLOW_DURATION = "workflow_duration" + NODE_DURATION = "node_duration" + MESSAGE_DURATION = "message_duration" + MESSAGE_TTFT = "message_ttft" + TOOL_DURATION = "tool_duration" + PROMPT_GENERATION_DURATION = "prompt_generation_duration" + + +class TokenMetricLabels(BaseModel): + """Unified label structure for all dify.token.* metrics. + + All token counters (dify.tokens.input, dify.tokens.output, dify.tokens.total) MUST + use this exact label set to ensure consistent filtering and aggregation across + different operation types. + + Attributes: + tenant_id: Tenant identifier. + app_id: Application identifier. + operation_type: Source of token usage (workflow | node_execution | message | + rule_generate | code_generate | structured_output | instruction_modify). + model_provider: LLM provider name. Empty string if not applicable (e.g., workflow-level). + model_name: LLM model name. Empty string if not applicable (e.g., workflow-level). + node_type: Workflow node type. Empty string unless operation_type=node_execution. + + Usage: + labels = TokenMetricLabels( + tenant_id="tenant-123", + app_id="app-456", + operation_type=OperationType.WORKFLOW, + model_provider="", + model_name="", + node_type="", + ) + exporter.increment_counter( + EnterpriseTelemetryCounter.INPUT_TOKENS, + 100, + labels.to_dict() + ) + + Design rationale: + Without this unified structure, tokens get double-counted when querying totals + because workflow.total_tokens is already the sum of all node tokens. The + operation_type label allows filtering to separate workflow-level aggregates from + node-level detail, while keeping the same label cardinality for consistent queries. + """ + + tenant_id: str + app_id: str + operation_type: str + model_provider: str + model_name: str + node_type: str + + model_config = ConfigDict(extra="forbid", frozen=True) + + def to_dict(self) -> dict[str, AttributeValue]: + return cast( + dict[str, AttributeValue], + { + "tenant_id": self.tenant_id, + "app_id": self.app_id, + "operation_type": self.operation_type, + "model_provider": self.model_provider, + "model_name": self.model_name, + "node_type": self.node_type, + }, + ) + + +__all__ = [ + "EnterpriseTelemetryCounter", + "EnterpriseTelemetryEvent", + "EnterpriseTelemetryHistogram", + "EnterpriseTelemetrySpan", + "TokenMetricLabels", +] diff --git a/api/enterprise/telemetry/event_handlers.py b/api/enterprise/telemetry/event_handlers.py new file mode 100644 index 0000000000..167cde2cd8 --- /dev/null +++ b/api/enterprise/telemetry/event_handlers.py @@ -0,0 +1,99 @@ +"""Blinker signal handlers for enterprise telemetry. + +Registered at import time via ``@signal.connect`` decorators. +Import must happen during ``ext_enterprise_telemetry.init_app()`` to +ensure handlers fire. Each handler delegates to ``core.telemetry.gateway`` +which handles routing, EE-gating, and dispatch. + +All handlers are best-effort: exceptions are caught and logged so that +telemetry failures never break user-facing operations. +""" + +from __future__ import annotations + +import logging + +from events.app_event import app_was_created, app_was_deleted, app_was_updated +from events.feedback_event import feedback_was_created + +logger = logging.getLogger(__name__) + +__all__ = [ + "_handle_app_created", + "_handle_app_deleted", + "_handle_app_updated", + "_handle_feedback_created", +] + + +@app_was_created.connect +def _handle_app_created(sender: object, **kwargs: object) -> None: + try: + from core.telemetry.gateway import emit as gateway_emit + from enterprise.telemetry.contracts import TelemetryCase + + gateway_emit( + case=TelemetryCase.APP_CREATED, + context={"tenant_id": str(getattr(sender, "tenant_id", "") or "")}, + payload={ + "app_id": getattr(sender, "id", None), + "mode": getattr(sender, "mode", None), + }, + ) + except Exception: + logger.warning("Failed to emit app_created telemetry", exc_info=True) + + +@app_was_deleted.connect +def _handle_app_deleted(sender: object, **kwargs: object) -> None: + try: + from core.telemetry.gateway import emit as gateway_emit + from enterprise.telemetry.contracts import TelemetryCase + + gateway_emit( + case=TelemetryCase.APP_DELETED, + context={"tenant_id": str(getattr(sender, "tenant_id", "") or "")}, + payload={"app_id": getattr(sender, "id", None)}, + ) + except Exception: + logger.warning("Failed to emit app_deleted telemetry", exc_info=True) + + +@app_was_updated.connect +def _handle_app_updated(sender: object, **kwargs: object) -> None: + try: + from core.telemetry.gateway import emit as gateway_emit + from enterprise.telemetry.contracts import TelemetryCase + + gateway_emit( + case=TelemetryCase.APP_UPDATED, + context={"tenant_id": str(getattr(sender, "tenant_id", "") or "")}, + payload={"app_id": getattr(sender, "id", None)}, + ) + except Exception: + logger.warning("Failed to emit app_updated telemetry", exc_info=True) + + +@feedback_was_created.connect +def _handle_feedback_created(sender: object, **kwargs: object) -> None: + try: + from core.telemetry.gateway import emit as gateway_emit + from enterprise.telemetry.contracts import TelemetryCase + + tenant_id = str(kwargs.get("tenant_id", "") or "") + gateway_emit( + case=TelemetryCase.FEEDBACK_CREATED, + context={"tenant_id": tenant_id}, + payload={ + "message_id": getattr(sender, "message_id", None), + "app_id": getattr(sender, "app_id", None), + "conversation_id": getattr(sender, "conversation_id", None), + "from_end_user_id": getattr(sender, "from_end_user_id", None), + "from_account_id": getattr(sender, "from_account_id", None), + "rating": getattr(sender, "rating", None), + "from_source": getattr(sender, "from_source", None), + "content": getattr(sender, "content", None), + }, + ) + except Exception: + logger.warning("Failed to emit feedback_created telemetry", exc_info=True) diff --git a/api/enterprise/telemetry/exporter.py b/api/enterprise/telemetry/exporter.py new file mode 100644 index 0000000000..6276853dc1 --- /dev/null +++ b/api/enterprise/telemetry/exporter.py @@ -0,0 +1,284 @@ +"""Enterprise OTEL exporter — shared by EnterpriseOtelTrace, event handlers, and direct instrumentation. + +Uses dedicated TracerProvider and MeterProvider instances (configurable sampling, +independent from ext_otel.py infrastructure). + +Initialized once during Flask extension init (single-threaded via ext_enterprise_telemetry.py). +Accessed via ``ext_enterprise_telemetry.get_enterprise_exporter()`` from any thread/process. +""" + +import logging +import socket +import uuid +from datetime import datetime +from typing import Any, cast + +from opentelemetry import trace +from opentelemetry.context import Context +from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import OTLPMetricExporter as GRPCMetricExporter +from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter as GRPCSpanExporter +from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter as HTTPMetricExporter +from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter as HTTPSpanExporter +from opentelemetry.sdk.metrics import MeterProvider +from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader +from opentelemetry.sdk.resources import Resource +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import BatchSpanProcessor +from opentelemetry.sdk.trace.sampling import ParentBasedTraceIdRatio +from opentelemetry.semconv.resource import ResourceAttributes +from opentelemetry.trace import SpanContext, TraceFlags +from opentelemetry.util.types import Attributes, AttributeValue + +from configs import dify_config +from enterprise.telemetry.entities import EnterpriseTelemetryCounter, EnterpriseTelemetryHistogram +from enterprise.telemetry.id_generator import ( + CorrelationIdGenerator, + compute_deterministic_span_id, + set_correlation_id, + set_span_id_source, +) + +logger = logging.getLogger(__name__) + + +def is_enterprise_telemetry_enabled() -> bool: + return bool(dify_config.ENTERPRISE_ENABLED and dify_config.ENTERPRISE_TELEMETRY_ENABLED) + + +def _parse_otlp_headers(raw: str) -> dict[str, str]: + """Parse ``key=value,key2=value2`` into a dict.""" + if not raw: + return {} + headers: dict[str, str] = {} + for pair in raw.split(","): + if "=" not in pair: + continue + k, v = pair.split("=", 1) + headers[k.strip()] = v.strip() + return headers + + +def _datetime_to_ns(dt: datetime) -> int: + """Convert a datetime to nanoseconds since epoch (OTEL convention).""" + return int(dt.timestamp() * 1_000_000_000) + + +class _ExporterFactory: + def __init__(self, protocol: str, endpoint: str, headers: dict[str, str], insecure: bool): + self._protocol = protocol + self._endpoint = endpoint + self._headers = headers + self._grpc_headers = tuple(headers.items()) if headers else None + self._http_headers = headers or None + self._insecure = insecure + + def create_trace_exporter(self) -> HTTPSpanExporter | GRPCSpanExporter: + if self._protocol == "grpc": + return GRPCSpanExporter( + endpoint=self._endpoint or None, + headers=self._grpc_headers, + insecure=self._insecure, + ) + trace_endpoint = f"{self._endpoint}/v1/traces" if self._endpoint else "" + return HTTPSpanExporter(endpoint=trace_endpoint or None, headers=self._http_headers) + + def create_metric_exporter(self) -> HTTPMetricExporter | GRPCMetricExporter: + if self._protocol == "grpc": + return GRPCMetricExporter( + endpoint=self._endpoint or None, + headers=self._grpc_headers, + insecure=self._insecure, + ) + metric_endpoint = f"{self._endpoint}/v1/metrics" if self._endpoint else "" + return HTTPMetricExporter(endpoint=metric_endpoint or None, headers=self._http_headers) + + +class EnterpriseExporter: + """Shared OTEL exporter for all enterprise telemetry. + + ``export_span`` creates spans with optional real timestamps, deterministic + span/trace IDs, and cross-workflow parent linking. + ``increment_counter`` / ``record_histogram`` emit OTEL metrics at 100% accuracy. + """ + + def __init__(self, config: object) -> None: + endpoint: str = getattr(config, "ENTERPRISE_OTLP_ENDPOINT", "") + headers_raw: str = getattr(config, "ENTERPRISE_OTLP_HEADERS", "") + protocol: str = (getattr(config, "ENTERPRISE_OTLP_PROTOCOL", "http") or "http").lower() + service_name: str = getattr(config, "ENTERPRISE_SERVICE_NAME", "dify") + sampling_rate: float = getattr(config, "ENTERPRISE_OTEL_SAMPLING_RATE", 1.0) + self.include_content: bool = getattr(config, "ENTERPRISE_INCLUDE_CONTENT", True) + api_key: str = getattr(config, "ENTERPRISE_OTLP_API_KEY", "") + + # Auto-detect TLS: https:// uses secure, everything else is insecure + insecure = not endpoint.startswith("https://") + + resource = Resource( + attributes={ + ResourceAttributes.SERVICE_NAME: service_name, + ResourceAttributes.HOST_NAME: socket.gethostname(), + } + ) + sampler = ParentBasedTraceIdRatio(sampling_rate) + id_generator = CorrelationIdGenerator() + self._tracer_provider = TracerProvider(resource=resource, sampler=sampler, id_generator=id_generator) + + headers = _parse_otlp_headers(headers_raw) + if api_key: + if "authorization" in headers: + logger.warning( + "ENTERPRISE_OTLP_API_KEY is set but ENTERPRISE_OTLP_HEADERS also contains " + "'authorization'; the API key will take precedence." + ) + headers["authorization"] = f"Bearer {api_key}" + factory = _ExporterFactory(protocol, endpoint, headers, insecure=insecure) + + trace_exporter = factory.create_trace_exporter() + self._tracer_provider.add_span_processor(BatchSpanProcessor(trace_exporter)) + self._tracer = self._tracer_provider.get_tracer("dify.enterprise") + + metric_exporter = factory.create_metric_exporter() + self._meter_provider = MeterProvider( + resource=resource, + metric_readers=[PeriodicExportingMetricReader(metric_exporter)], + ) + meter = self._meter_provider.get_meter("dify.enterprise") + self._counters = { + EnterpriseTelemetryCounter.TOKENS: meter.create_counter("dify.tokens.total", unit="{token}"), + EnterpriseTelemetryCounter.INPUT_TOKENS: meter.create_counter("dify.tokens.input", unit="{token}"), + EnterpriseTelemetryCounter.OUTPUT_TOKENS: meter.create_counter("dify.tokens.output", unit="{token}"), + EnterpriseTelemetryCounter.REQUESTS: meter.create_counter("dify.requests.total", unit="{request}"), + EnterpriseTelemetryCounter.ERRORS: meter.create_counter("dify.errors.total", unit="{error}"), + EnterpriseTelemetryCounter.FEEDBACK: meter.create_counter("dify.feedback.total", unit="{feedback}"), + EnterpriseTelemetryCounter.DATASET_RETRIEVALS: meter.create_counter( + "dify.dataset.retrievals.total", unit="{retrieval}" + ), + EnterpriseTelemetryCounter.APP_CREATED: meter.create_counter("dify.app.created.total", unit="{app}"), + EnterpriseTelemetryCounter.APP_UPDATED: meter.create_counter("dify.app.updated.total", unit="{app}"), + EnterpriseTelemetryCounter.APP_DELETED: meter.create_counter("dify.app.deleted.total", unit="{app}"), + } + self._histograms = { + EnterpriseTelemetryHistogram.WORKFLOW_DURATION: meter.create_histogram("dify.workflow.duration", unit="s"), + EnterpriseTelemetryHistogram.NODE_DURATION: meter.create_histogram("dify.node.duration", unit="s"), + EnterpriseTelemetryHistogram.MESSAGE_DURATION: meter.create_histogram("dify.message.duration", unit="s"), + EnterpriseTelemetryHistogram.MESSAGE_TTFT: meter.create_histogram( + "dify.message.time_to_first_token", unit="s" + ), + EnterpriseTelemetryHistogram.TOOL_DURATION: meter.create_histogram("dify.tool.duration", unit="s"), + EnterpriseTelemetryHistogram.PROMPT_GENERATION_DURATION: meter.create_histogram( + "dify.prompt_generation.duration", unit="s" + ), + } + + def export_span( + self, + name: str, + attributes: dict[str, Any], + correlation_id: str | None = None, + span_id_source: str | None = None, + start_time: datetime | None = None, + end_time: datetime | None = None, + trace_correlation_override: str | None = None, + parent_span_id_source: str | None = None, + ) -> None: + """Export an OTEL span with optional deterministic IDs and real timestamps. + + Args: + name: Span operation name. + attributes: Span attributes dict. + correlation_id: Source for trace_id derivation (groups spans in one trace). + span_id_source: Source for deterministic span_id (e.g. workflow_run_id or node_execution_id). + start_time: Real span start time. When None, uses current time. + end_time: Real span end time. When None, span ends immediately. + trace_correlation_override: Override trace_id source (for cross-workflow linking). + When set, trace_id is derived from this instead of ``correlation_id``. + parent_span_id_source: Override parent span_id source (for cross-workflow linking). + When set, parent span_id is derived from this value. When None and + ``correlation_id`` is set, parent is the workflow root span. + """ + effective_trace_correlation = trace_correlation_override or correlation_id + set_correlation_id(effective_trace_correlation) + set_span_id_source(span_id_source) + + try: + parent_context: Context | None = None + # A span is the "root" of its correlation group when span_id_source == correlation_id + # (i.e. a workflow root span). All other spans are children. + if parent_span_id_source: + # Cross-workflow linking: parent is an explicit span (e.g. tool node in outer workflow) + parent_span_id = compute_deterministic_span_id(parent_span_id_source) + try: + parent_trace_id = int(uuid.UUID(effective_trace_correlation)) if effective_trace_correlation else 0 + except (ValueError, AttributeError): + logger.warning( + "Invalid trace correlation UUID for cross-workflow link: %s, span=%s", + effective_trace_correlation, + name, + ) + parent_trace_id = 0 + if parent_trace_id: + parent_span_context = SpanContext( + trace_id=parent_trace_id, + span_id=parent_span_id, + is_remote=True, + trace_flags=TraceFlags(TraceFlags.SAMPLED), + ) + parent_context = trace.set_span_in_context(trace.NonRecordingSpan(parent_span_context)) + elif correlation_id and correlation_id != span_id_source: + # Child span: parent is the correlation-group root (workflow root span) + parent_span_id = compute_deterministic_span_id(correlation_id) + try: + parent_trace_id = int(uuid.UUID(effective_trace_correlation or correlation_id)) + except (ValueError, AttributeError): + logger.warning( + "Invalid trace correlation UUID for child span link: %s, span=%s", + effective_trace_correlation or correlation_id, + name, + ) + parent_trace_id = 0 + if parent_trace_id: + parent_span_context = SpanContext( + trace_id=parent_trace_id, + span_id=parent_span_id, + is_remote=True, + trace_flags=TraceFlags(TraceFlags.SAMPLED), + ) + parent_context = trace.set_span_in_context(trace.NonRecordingSpan(parent_span_context)) + + span_start_time = _datetime_to_ns(start_time) if start_time is not None else None + span_end_on_exit = end_time is None + + with self._tracer.start_as_current_span( + name, + context=parent_context, + start_time=span_start_time, + end_on_exit=span_end_on_exit, + ) as span: + for key, value in attributes.items(): + if value is not None: + span.set_attribute(key, value) + if end_time is not None: + span.end(end_time=_datetime_to_ns(end_time)) + except Exception: + logger.exception("Failed to export span %s", name) + finally: + set_correlation_id(None) + set_span_id_source(None) + + def increment_counter( + self, name: EnterpriseTelemetryCounter, value: int, labels: dict[str, AttributeValue] + ) -> None: + counter = self._counters.get(name) + if counter: + counter.add(value, cast(Attributes, labels)) + + def record_histogram( + self, name: EnterpriseTelemetryHistogram, value: float, labels: dict[str, AttributeValue] + ) -> None: + histogram = self._histograms.get(name) + if histogram: + histogram.record(value, cast(Attributes, labels)) + + def shutdown(self) -> None: + self._tracer_provider.shutdown() + self._meter_provider.shutdown() diff --git a/api/enterprise/telemetry/id_generator.py b/api/enterprise/telemetry/id_generator.py new file mode 100644 index 0000000000..8f4760cac2 --- /dev/null +++ b/api/enterprise/telemetry/id_generator.py @@ -0,0 +1,76 @@ +"""Custom OTEL ID Generator for correlation-based trace/span ID derivation. + +Uses contextvars for thread-safe correlation_id -> trace_id mapping. +When a span_id_source is set, the span_id is derived deterministically +from that value, enabling any span to reference another as parent +without depending on span creation order. +""" + +import random +import uuid +from contextvars import ContextVar +from typing import cast + +from opentelemetry.sdk.trace.id_generator import IdGenerator + +_correlation_id_context: ContextVar[str | None] = ContextVar("correlation_id", default=None) +_span_id_source_context: ContextVar[str | None] = ContextVar("span_id_source", default=None) + + +def set_correlation_id(correlation_id: str | None) -> None: + _correlation_id_context.set(correlation_id) + + +def get_correlation_id() -> str | None: + return _correlation_id_context.get() + + +def set_span_id_source(source_id: str | None) -> None: + """Set the source for deterministic span_id generation. + + When set, ``generate_span_id()`` derives the span_id from this value + (lower 64 bits of the UUID). Pass the ``workflow_run_id`` for workflow + root spans or ``node_execution_id`` for node spans. + """ + _span_id_source_context.set(source_id) + + +def compute_deterministic_span_id(source_id: str) -> int: + """Derive a deterministic span_id from any UUID string. + + Uses the lower 64 bits of the UUID, guaranteeing non-zero output + (OTEL requires span_id != 0). + """ + span_id = cast(int, uuid.UUID(source_id).int) & ((1 << 64) - 1) + return span_id if span_id != 0 else 1 + + +class CorrelationIdGenerator(IdGenerator): + """ID generator that derives trace_id and optionally span_id from context. + + - trace_id: always derived from correlation_id (groups all spans in one trace) + - span_id: derived from span_id_source when set (enables deterministic + parent-child linking), otherwise random + """ + + def generate_trace_id(self) -> int: + correlation_id = _correlation_id_context.get() + if correlation_id: + try: + return cast(int, uuid.UUID(correlation_id).int) + except (ValueError, AttributeError): + pass + return random.getrandbits(128) + + def generate_span_id(self) -> int: + source = _span_id_source_context.get() + if source: + try: + return compute_deterministic_span_id(source) + except (ValueError, AttributeError): + pass + + span_id = random.getrandbits(64) + while span_id == 0: + span_id = random.getrandbits(64) + return span_id diff --git a/api/enterprise/telemetry/metric_handler.py b/api/enterprise/telemetry/metric_handler.py new file mode 100644 index 0000000000..25bec993b7 --- /dev/null +++ b/api/enterprise/telemetry/metric_handler.py @@ -0,0 +1,381 @@ +"""Enterprise metric/log event handler. + +This module processes metric and log telemetry events after they've been +dequeued from the enterprise_telemetry Celery queue. It handles case routing, +idempotency checking, and payload rehydration. +""" + +from __future__ import annotations + +import json +import logging +from typing import Any + +from enterprise.telemetry.contracts import TelemetryCase, TelemetryEnvelope +from extensions.ext_redis import redis_client +from extensions.ext_storage import storage + +logger = logging.getLogger(__name__) + + +class EnterpriseMetricHandler: + """Handler for enterprise metric and log telemetry events. + + Processes envelopes from the enterprise_telemetry queue, routing each + case to the appropriate handler method. Implements idempotency checking + and payload rehydration with fallback. + """ + + def _increment_diagnostic_counter(self, counter_name: str, labels: dict[str, str] | None = None) -> None: + """Increment a diagnostic counter for operational monitoring. + + Args: + counter_name: Name of the counter (e.g., 'processed_total', 'deduped_total'). + labels: Optional labels for the counter. + """ + try: + from extensions.ext_enterprise_telemetry import get_enterprise_exporter + + exporter = get_enterprise_exporter() + if not exporter: + return + + full_counter_name = f"enterprise_telemetry.handler.{counter_name}" + logger.debug( + "Diagnostic counter: %s, labels=%s", + full_counter_name, + labels or {}, + ) + except Exception: + logger.debug("Failed to increment diagnostic counter: %s", counter_name, exc_info=True) + + def handle(self, envelope: TelemetryEnvelope) -> None: + """Main entry point for processing telemetry envelopes. + + Args: + envelope: The telemetry envelope to process. + """ + # Check for duplicate events + if self._is_duplicate(envelope): + logger.debug( + "Skipping duplicate event: tenant_id=%s, event_id=%s", + envelope.tenant_id, + envelope.event_id, + ) + self._increment_diagnostic_counter("deduped_total") + return + + # Route to appropriate handler based on case + case = envelope.case + if case == TelemetryCase.APP_CREATED: + self._on_app_created(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "app_created"}) + elif case == TelemetryCase.APP_UPDATED: + self._on_app_updated(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "app_updated"}) + elif case == TelemetryCase.APP_DELETED: + self._on_app_deleted(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "app_deleted"}) + elif case == TelemetryCase.FEEDBACK_CREATED: + self._on_feedback_created(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "feedback_created"}) + elif case == TelemetryCase.MESSAGE_RUN: + self._on_message_run(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "message_run"}) + elif case == TelemetryCase.TOOL_EXECUTION: + self._on_tool_execution(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "tool_execution"}) + elif case == TelemetryCase.MODERATION_CHECK: + self._on_moderation_check(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "moderation_check"}) + elif case == TelemetryCase.SUGGESTED_QUESTION: + self._on_suggested_question(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "suggested_question"}) + elif case == TelemetryCase.DATASET_RETRIEVAL: + self._on_dataset_retrieval(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "dataset_retrieval"}) + elif case == TelemetryCase.GENERATE_NAME: + self._on_generate_name(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "generate_name"}) + elif case == TelemetryCase.PROMPT_GENERATION: + self._on_prompt_generation(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "prompt_generation"}) + else: + logger.warning( + "Unknown telemetry case: %s (tenant_id=%s, event_id=%s)", + case, + envelope.tenant_id, + envelope.event_id, + ) + + def _is_duplicate(self, envelope: TelemetryEnvelope) -> bool: + """Check if this event has already been processed. + + Uses Redis with TTL for deduplication. Returns True if duplicate, + False if first time seeing this event. + + Args: + envelope: The telemetry envelope to check. + + Returns: + True if this event_id has been seen before, False otherwise. + """ + dedup_key = f"telemetry:dedup:{envelope.tenant_id}:{envelope.event_id}" + + try: + # Atomic set-if-not-exists with 1h TTL + # Returns True if key was set (first time), None if already exists (duplicate) + was_set = redis_client.set(dedup_key, b"1", nx=True, ex=3600) + return was_set is None + except Exception: + # Fail open: if Redis is unavailable, process the event + # (prefer occasional duplicate over lost data) + logger.warning( + "Redis unavailable for deduplication check, processing event anyway: %s", + envelope.event_id, + exc_info=True, + ) + return False + + def _rehydrate(self, envelope: TelemetryEnvelope) -> dict[str, Any]: + """Rehydrate payload from storage reference or inline data. + + If the envelope payload is empty and metadata contains a + ``payload_ref``, the full payload is loaded from object storage + (where the gateway wrote it as JSON). When both the inline + payload and storage resolution fail, a degraded-event marker + is emitted so the gap is observable. + + Args: + envelope: The telemetry envelope containing payload data. + + Returns: + The rehydrated payload dictionary, or ``{}`` on total failure. + """ + payload = envelope.payload + + # Resolve from object storage when the gateway offloaded a large payload. + if not payload and envelope.metadata: + payload_ref = envelope.metadata.get("payload_ref") + if payload_ref: + try: + payload_bytes = storage.load(payload_ref) + payload = json.loads(payload_bytes.decode("utf-8")) + logger.debug("Loaded payload from storage: key=%s", payload_ref) + except Exception: + logger.warning( + "Failed to load payload from storage: key=%s, event_id=%s", + payload_ref, + envelope.event_id, + exc_info=True, + ) + + if not payload: + # Storage resolution failed or no data available — emit degraded event. + logger.error( + "Payload rehydration failed for event_id=%s, tenant_id=%s, case=%s", + envelope.event_id, + envelope.tenant_id, + envelope.case, + ) + from enterprise.telemetry.entities import EnterpriseTelemetryEvent + from enterprise.telemetry.telemetry_log import emit_metric_only_event + + emit_metric_only_event( + event_name=EnterpriseTelemetryEvent.REHYDRATION_FAILED, + attributes={ + "dify.tenant_id": envelope.tenant_id, + "dify.event_id": envelope.event_id, + "dify.case": envelope.case, + "rehydration_failed": True, + }, + tenant_id=envelope.tenant_id, + ) + self._increment_diagnostic_counter("rehydration_failed_total") + return {} + + return payload + + # Stub methods for each metric/log case + # These will be implemented in later tasks with actual emission logic + + def _on_app_created(self, envelope: TelemetryEnvelope) -> None: + """Handle app created event.""" + from enterprise.telemetry.entities import EnterpriseTelemetryCounter, EnterpriseTelemetryEvent + from enterprise.telemetry.telemetry_log import emit_metric_only_event + from extensions.ext_enterprise_telemetry import get_enterprise_exporter + + exporter = get_enterprise_exporter() + if not exporter: + logger.debug("No exporter available for APP_CREATED: event_id=%s", envelope.event_id) + return + + payload = self._rehydrate(envelope) + if not payload: + return + + attrs = { + "dify.app.id": payload.get("app_id"), + "dify.tenant_id": envelope.tenant_id, + "dify.event.id": envelope.event_id, + "dify.app.mode": payload.get("mode"), + } + + emit_metric_only_event( + event_name=EnterpriseTelemetryEvent.APP_CREATED, + attributes=attrs, + tenant_id=envelope.tenant_id, + ) + exporter.increment_counter( + EnterpriseTelemetryCounter.APP_CREATED, + 1, + { + "tenant_id": envelope.tenant_id, + "app_id": str(payload.get("app_id", "")), + "mode": str(payload.get("mode", "")), + }, + ) + + def _on_app_updated(self, envelope: TelemetryEnvelope) -> None: + """Handle app updated event.""" + from enterprise.telemetry.entities import EnterpriseTelemetryCounter, EnterpriseTelemetryEvent + from enterprise.telemetry.telemetry_log import emit_metric_only_event + from extensions.ext_enterprise_telemetry import get_enterprise_exporter + + exporter = get_enterprise_exporter() + if not exporter: + logger.debug("No exporter available for APP_UPDATED: event_id=%s", envelope.event_id) + return + + payload = self._rehydrate(envelope) + if not payload: + return + + attrs = { + "dify.app.id": payload.get("app_id"), + "dify.tenant_id": envelope.tenant_id, + "dify.event.id": envelope.event_id, + } + + emit_metric_only_event( + event_name=EnterpriseTelemetryEvent.APP_UPDATED, + attributes=attrs, + tenant_id=envelope.tenant_id, + ) + exporter.increment_counter( + EnterpriseTelemetryCounter.APP_UPDATED, + 1, + { + "tenant_id": envelope.tenant_id, + "app_id": str(payload.get("app_id", "")), + }, + ) + + def _on_app_deleted(self, envelope: TelemetryEnvelope) -> None: + """Handle app deleted event.""" + from enterprise.telemetry.entities import EnterpriseTelemetryCounter, EnterpriseTelemetryEvent + from enterprise.telemetry.telemetry_log import emit_metric_only_event + from extensions.ext_enterprise_telemetry import get_enterprise_exporter + + exporter = get_enterprise_exporter() + if not exporter: + logger.debug("No exporter available for APP_DELETED: event_id=%s", envelope.event_id) + return + + payload = self._rehydrate(envelope) + if not payload: + return + + attrs = { + "dify.app.id": payload.get("app_id"), + "dify.tenant_id": envelope.tenant_id, + "dify.event.id": envelope.event_id, + } + + emit_metric_only_event( + event_name=EnterpriseTelemetryEvent.APP_DELETED, + attributes=attrs, + tenant_id=envelope.tenant_id, + ) + exporter.increment_counter( + EnterpriseTelemetryCounter.APP_DELETED, + 1, + { + "tenant_id": envelope.tenant_id, + "app_id": str(payload.get("app_id", "")), + }, + ) + + def _on_feedback_created(self, envelope: TelemetryEnvelope) -> None: + """Handle feedback created event.""" + from enterprise.telemetry.entities import EnterpriseTelemetryCounter, EnterpriseTelemetryEvent + from enterprise.telemetry.telemetry_log import emit_metric_only_event + from extensions.ext_enterprise_telemetry import get_enterprise_exporter + + exporter = get_enterprise_exporter() + if not exporter: + logger.debug("No exporter available for FEEDBACK_CREATED: event_id=%s", envelope.event_id) + return + + payload = self._rehydrate(envelope) + if not payload: + return + + include_content = exporter.include_content + attrs: dict = { + "dify.message.id": payload.get("message_id"), + "dify.tenant_id": envelope.tenant_id, + "dify.event.id": envelope.event_id, + "dify.app_id": payload.get("app_id"), + "dify.conversation.id": payload.get("conversation_id"), + "gen_ai.user.id": payload.get("from_end_user_id") or payload.get("from_account_id"), + "dify.feedback.rating": payload.get("rating"), + "dify.feedback.from_source": payload.get("from_source"), + } + if include_content: + attrs["dify.feedback.content"] = payload.get("content") + + user_id = payload.get("from_end_user_id") or payload.get("from_account_id") + emit_metric_only_event( + event_name=EnterpriseTelemetryEvent.FEEDBACK_CREATED, + attributes=attrs, + tenant_id=envelope.tenant_id, + user_id=str(user_id or ""), + ) + exporter.increment_counter( + EnterpriseTelemetryCounter.FEEDBACK, + 1, + { + "tenant_id": envelope.tenant_id, + "app_id": str(payload.get("app_id", "")), + "rating": str(payload.get("rating", "")), + }, + ) + + def _on_message_run(self, envelope: TelemetryEnvelope) -> None: + """Handle message run event (stub).""" + logger.debug("Processing MESSAGE_RUN: event_id=%s", envelope.event_id) + + def _on_tool_execution(self, envelope: TelemetryEnvelope) -> None: + """Handle tool execution event (stub).""" + logger.debug("Processing TOOL_EXECUTION: event_id=%s", envelope.event_id) + + def _on_moderation_check(self, envelope: TelemetryEnvelope) -> None: + """Handle moderation check event (stub).""" + logger.debug("Processing MODERATION_CHECK: event_id=%s", envelope.event_id) + + def _on_suggested_question(self, envelope: TelemetryEnvelope) -> None: + """Handle suggested question event (stub).""" + logger.debug("Processing SUGGESTED_QUESTION: event_id=%s", envelope.event_id) + + def _on_dataset_retrieval(self, envelope: TelemetryEnvelope) -> None: + """Handle dataset retrieval event (stub).""" + logger.debug("Processing DATASET_RETRIEVAL: event_id=%s", envelope.event_id) + + def _on_generate_name(self, envelope: TelemetryEnvelope) -> None: + """Handle generate name event (stub).""" + logger.debug("Processing GENERATE_NAME: event_id=%s", envelope.event_id) + + def _on_prompt_generation(self, envelope: TelemetryEnvelope) -> None: + """Handle prompt generation event (stub).""" + logger.debug("Processing PROMPT_GENERATION: event_id=%s", envelope.event_id) diff --git a/api/enterprise/telemetry/telemetry_log.py b/api/enterprise/telemetry/telemetry_log.py new file mode 100644 index 0000000000..8cce4a9fcd --- /dev/null +++ b/api/enterprise/telemetry/telemetry_log.py @@ -0,0 +1,122 @@ +"""Structured-log emitter for enterprise telemetry events. + +Emits structured JSON log lines correlated with OTEL traces via trace_id. +Picked up by ``StructuredJSONFormatter`` → stdout/Loki/Elastic. +""" + +from __future__ import annotations + +import logging +import uuid +from functools import lru_cache +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from enterprise.telemetry.entities import EnterpriseTelemetryEvent + +logger = logging.getLogger("dify.telemetry") + + +@lru_cache(maxsize=4096) +def compute_trace_id_hex(uuid_str: str | None) -> str: + """Convert a business UUID string to a 32-hex OTEL-compatible trace_id. + + Returns empty string when *uuid_str* is ``None`` or invalid. + """ + if not uuid_str: + return "" + normalized = uuid_str.strip().lower() + if len(normalized) == 32 and all(ch in "0123456789abcdef" for ch in normalized): + return normalized + try: + return f"{uuid.UUID(normalized).int:032x}" + except (ValueError, AttributeError): + return "" + + +@lru_cache(maxsize=4096) +def compute_span_id_hex(uuid_str: str | None) -> str: + if not uuid_str: + return "" + normalized = uuid_str.strip().lower() + if len(normalized) == 16 and all(ch in "0123456789abcdef" for ch in normalized): + return normalized + try: + from enterprise.telemetry.id_generator import compute_deterministic_span_id + + return f"{compute_deterministic_span_id(normalized):016x}" + except (ValueError, AttributeError): + return "" + + +def emit_telemetry_log( + *, + event_name: str | EnterpriseTelemetryEvent, + attributes: dict[str, Any], + signal: str = "metric_only", + trace_id_source: str | None = None, + span_id_source: str | None = None, + tenant_id: str | None = None, + user_id: str | None = None, +) -> None: + """Emit a structured log line for a telemetry event. + + Parameters + ---------- + event_name: + Canonical event name, e.g. ``"dify.workflow.run"``. + attributes: + All event-specific attributes (already built by the caller). + signal: + ``"metric_only"`` for events with no span, ``"span_detail"`` + for detail logs accompanying a slim span. + trace_id_source: + A UUID string (e.g. ``workflow_run_id``) used to derive a 32-hex + trace_id for cross-signal correlation. + tenant_id: + Tenant identifier (for the ``IdentityContextFilter``). + user_id: + User identifier (for the ``IdentityContextFilter``). + """ + if not logger.isEnabledFor(logging.INFO): + return + attrs = { + "dify.event.name": event_name, + "dify.event.signal": signal, + **attributes, + } + + extra: dict[str, Any] = {"attributes": attrs} + + trace_id_hex = compute_trace_id_hex(trace_id_source) + if trace_id_hex: + extra["trace_id"] = trace_id_hex + span_id_hex = compute_span_id_hex(span_id_source) + if span_id_hex: + extra["span_id"] = span_id_hex + if tenant_id: + extra["tenant_id"] = tenant_id + if user_id: + extra["user_id"] = user_id + + logger.info("telemetry.%s", signal, extra=extra) + + +def emit_metric_only_event( + *, + event_name: str | EnterpriseTelemetryEvent, + attributes: dict[str, Any], + trace_id_source: str | None = None, + span_id_source: str | None = None, + tenant_id: str | None = None, + user_id: str | None = None, +) -> None: + emit_telemetry_log( + event_name=event_name, + attributes=attributes, + signal="metric_only", + trace_id_source=trace_id_source, + span_id_source=span_id_source, + tenant_id=tenant_id, + user_id=user_id, + ) diff --git a/api/events/app_event.py b/api/events/app_event.py index f2ce71bbbb..3a0094b77c 100644 --- a/api/events/app_event.py +++ b/api/events/app_event.py @@ -3,6 +3,12 @@ from blinker import signal # sender: app app_was_created = signal("app-was-created") +# sender: app +app_was_deleted = signal("app-was-deleted") + +# sender: app +app_was_updated = signal("app-was-updated") + # sender: app, kwargs: app_model_config app_model_config_was_updated = signal("app-model-config-was-updated") diff --git a/api/events/feedback_event.py b/api/events/feedback_event.py new file mode 100644 index 0000000000..8d91d5c5e5 --- /dev/null +++ b/api/events/feedback_event.py @@ -0,0 +1,4 @@ +from blinker import signal + +# sender: MessageFeedback, kwargs: tenant_id +feedback_was_created = signal("feedback-was-created") diff --git a/api/extensions/ext_celery.py b/api/extensions/ext_celery.py index af983f6d87..9944e768b9 100644 --- a/api/extensions/ext_celery.py +++ b/api/extensions/ext_celery.py @@ -184,6 +184,8 @@ def init_app(app: DifyApp) -> Celery: "task": "schedule.trigger_provider_refresh_task.trigger_provider_refresh", "schedule": timedelta(minutes=dify_config.TRIGGER_PROVIDER_REFRESH_INTERVAL), } + if dify_config.ENTERPRISE_TELEMETRY_ENABLED: + imports.append("tasks.enterprise_telemetry_task") celery_app.conf.update(beat_schedule=beat_schedule, imports=imports) return celery_app diff --git a/api/extensions/ext_enterprise_telemetry.py b/api/extensions/ext_enterprise_telemetry.py new file mode 100644 index 0000000000..f785c00ae0 --- /dev/null +++ b/api/extensions/ext_enterprise_telemetry.py @@ -0,0 +1,50 @@ +"""Flask extension for enterprise telemetry lifecycle management. + +Initializes the EnterpriseExporter singleton during ``create_app()`` +(single-threaded), registers blinker event handlers, and hooks atexit +for graceful shutdown. + +Skipped entirely when ``ENTERPRISE_ENABLED`` and ``ENTERPRISE_TELEMETRY_ENABLED`` +are false (``is_enabled()`` gate). +""" + +from __future__ import annotations + +import atexit +import logging +from typing import TYPE_CHECKING + +from configs import dify_config + +if TYPE_CHECKING: + from dify_app import DifyApp + from enterprise.telemetry.exporter import EnterpriseExporter + +logger = logging.getLogger(__name__) + +_exporter: EnterpriseExporter | None = None + + +def is_enabled() -> bool: + return bool(dify_config.ENTERPRISE_ENABLED and dify_config.ENTERPRISE_TELEMETRY_ENABLED) + + +def init_app(app: DifyApp) -> None: + global _exporter + + if not is_enabled(): + return + + from enterprise.telemetry.exporter import EnterpriseExporter + + _exporter = EnterpriseExporter(dify_config) + atexit.register(_exporter.shutdown) + + # Import to trigger @signal.connect decorator registration + import enterprise.telemetry.event_handlers # noqa: F401 # type: ignore[reportUnusedImport] + + logger.info("Enterprise telemetry initialized") + + +def get_enterprise_exporter() -> EnterpriseExporter | None: + return _exporter diff --git a/api/extensions/ext_otel.py b/api/extensions/ext_otel.py index 40a915e68c..37f881f7ea 100644 --- a/api/extensions/ext_otel.py +++ b/api/extensions/ext_otel.py @@ -59,16 +59,24 @@ def init_app(app: DifyApp): protocol = (dify_config.OTEL_EXPORTER_OTLP_PROTOCOL or "").lower() if dify_config.OTEL_EXPORTER_TYPE == "otlp": if protocol == "grpc": + # Auto-detect TLS: https:// uses secure, everything else is insecure + endpoint = dify_config.OTLP_BASE_ENDPOINT + insecure = not endpoint.startswith("https://") + exporter = GRPCSpanExporter( - endpoint=dify_config.OTLP_BASE_ENDPOINT, + endpoint=endpoint, # Header field names must consist of lowercase letters, check RFC7540 - headers=(("authorization", f"Bearer {dify_config.OTLP_API_KEY}"),), - insecure=True, + headers=( + (("authorization", f"Bearer {dify_config.OTLP_API_KEY}"),) if dify_config.OTLP_API_KEY else None + ), + insecure=insecure, ) metric_exporter = GRPCMetricExporter( - endpoint=dify_config.OTLP_BASE_ENDPOINT, - headers=(("authorization", f"Bearer {dify_config.OTLP_API_KEY}"),), - insecure=True, + endpoint=endpoint, + headers=( + (("authorization", f"Bearer {dify_config.OTLP_API_KEY}"),) if dify_config.OTLP_API_KEY else None + ), + insecure=insecure, ) else: headers = {"Authorization": f"Bearer {dify_config.OTLP_API_KEY}"} if dify_config.OTLP_API_KEY else None diff --git a/api/extensions/otel/parser/__init__.py b/api/extensions/otel/parser/__init__.py index 164db7c275..c671e8b409 100644 --- a/api/extensions/otel/parser/__init__.py +++ b/api/extensions/otel/parser/__init__.py @@ -5,7 +5,7 @@ This module provides parsers that extract node-specific metadata and set OpenTelemetry span attributes according to semantic conventions. """ -from extensions.otel.parser.base import DefaultNodeOTelParser, NodeOTelParser, safe_json_dumps +from extensions.otel.parser.base import DefaultNodeOTelParser, NodeOTelParser, safe_json_dumps, should_include_content from extensions.otel.parser.llm import LLMNodeOTelParser from extensions.otel.parser.retrieval import RetrievalNodeOTelParser from extensions.otel.parser.tool import ToolNodeOTelParser @@ -17,4 +17,5 @@ __all__ = [ "RetrievalNodeOTelParser", "ToolNodeOTelParser", "safe_json_dumps", + "should_include_content", ] diff --git a/api/extensions/otel/parser/base.py b/api/extensions/otel/parser/base.py index f4db26e840..dc443fe8f4 100644 --- a/api/extensions/otel/parser/base.py +++ b/api/extensions/otel/parser/base.py @@ -1,5 +1,10 @@ """ Base parser interface and utilities for OpenTelemetry node parsers. + +Content gating: ``should_include_content()`` controls whether content-bearing +span attributes (inputs, outputs, prompts, completions, documents) are written. +Gate is only active in EE (``ENTERPRISE_ENABLED=True``) when +``ENTERPRISE_INCLUDE_CONTENT=False``; CE behaviour is unchanged. """ import json @@ -9,6 +14,7 @@ from opentelemetry.trace import Span from opentelemetry.trace.status import Status, StatusCode from pydantic import BaseModel +from configs import dify_config from core.file.models import File from core.variables import Segment from core.workflow.enums import NodeType @@ -17,6 +23,17 @@ from core.workflow.nodes.base.node import Node from extensions.otel.semconv.gen_ai import ChainAttributes, GenAIAttributes +def should_include_content() -> bool: + """Return True if content should be written to spans. + + CE (ENTERPRISE_ENABLED=False): always True — no behaviour change. + EE: follows ENTERPRISE_INCLUDE_CONTENT (default True). + """ + if not dify_config.ENTERPRISE_ENABLED: + return True + return dify_config.ENTERPRISE_INCLUDE_CONTENT + + def safe_json_dumps(obj: Any, ensure_ascii: bool = False) -> str: """ Safely serialize objects to JSON, handling non-serializable types. @@ -105,10 +122,11 @@ class DefaultNodeOTelParser: # Extract inputs and outputs from result_event if result_event and result_event.node_run_result: node_run_result = result_event.node_run_result - if node_run_result.inputs: - span.set_attribute(ChainAttributes.INPUT_VALUE, safe_json_dumps(node_run_result.inputs)) - if node_run_result.outputs: - span.set_attribute(ChainAttributes.OUTPUT_VALUE, safe_json_dumps(node_run_result.outputs)) + if should_include_content(): + if node_run_result.inputs: + span.set_attribute(ChainAttributes.INPUT_VALUE, safe_json_dumps(node_run_result.inputs)) + if node_run_result.outputs: + span.set_attribute(ChainAttributes.OUTPUT_VALUE, safe_json_dumps(node_run_result.outputs)) if error: span.record_exception(error) diff --git a/api/extensions/otel/parser/llm.py b/api/extensions/otel/parser/llm.py index 8556974080..2e244b6be3 100644 --- a/api/extensions/otel/parser/llm.py +++ b/api/extensions/otel/parser/llm.py @@ -10,7 +10,7 @@ from opentelemetry.trace import Span from core.workflow.graph_events import GraphNodeEventBase from core.workflow.nodes.base.node import Node -from extensions.otel.parser.base import DefaultNodeOTelParser, safe_json_dumps +from extensions.otel.parser.base import DefaultNodeOTelParser, safe_json_dumps, should_include_content from extensions.otel.semconv.gen_ai import LLMAttributes logger = logging.getLogger(__name__) @@ -132,24 +132,19 @@ class LLMNodeOTelParser: span.set_attribute(LLMAttributes.USAGE_OUTPUT_TOKENS, completion_tokens) span.set_attribute(LLMAttributes.USAGE_TOTAL_TOKENS, total_tokens) - # Prompts and completion - prompts = process_data.get("prompts", []) - if prompts: - prompts_json = safe_json_dumps(prompts) - span.set_attribute(LLMAttributes.PROMPT, prompts_json) + # Prompts and completion — gated by content policy + if should_include_content(): + prompts = process_data.get("prompts", []) + if prompts: + prompts_json = safe_json_dumps(prompts) + span.set_attribute(LLMAttributes.PROMPT, prompts_json) - text_output = str(outputs.get("text", "")) - if text_output: - span.set_attribute(LLMAttributes.COMPLETION, text_output) + text_output = str(outputs.get("text", "")) + if text_output: + span.set_attribute(LLMAttributes.COMPLETION, text_output) - # Finish reason - finish_reason = outputs.get("finish_reason") or "" - if finish_reason: - span.set_attribute(LLMAttributes.RESPONSE_FINISH_REASON, finish_reason) - - # Structured input/output messages - gen_ai_input_message = _format_input_messages(process_data) - gen_ai_output_message = _format_output_messages(outputs) - - span.set_attribute(LLMAttributes.INPUT_MESSAGE, gen_ai_input_message) - span.set_attribute(LLMAttributes.OUTPUT_MESSAGE, gen_ai_output_message) + # Structured input/output messages + gen_ai_input_message = _format_input_messages(process_data) + gen_ai_output_message = _format_output_messages(outputs) + span.set_attribute(LLMAttributes.INPUT_MESSAGE, gen_ai_input_message) + span.set_attribute(LLMAttributes.OUTPUT_MESSAGE, gen_ai_output_message) diff --git a/api/extensions/otel/parser/retrieval.py b/api/extensions/otel/parser/retrieval.py index fc151af691..25738bf18b 100644 --- a/api/extensions/otel/parser/retrieval.py +++ b/api/extensions/otel/parser/retrieval.py @@ -11,7 +11,7 @@ from opentelemetry.trace import Span from core.variables import Segment from core.workflow.graph_events import GraphNodeEventBase from core.workflow.nodes.base.node import Node -from extensions.otel.parser.base import DefaultNodeOTelParser, safe_json_dumps +from extensions.otel.parser.base import DefaultNodeOTelParser, safe_json_dumps, should_include_content from extensions.otel.semconv.gen_ai import RetrieverAttributes logger = logging.getLogger(__name__) @@ -83,23 +83,21 @@ class RetrievalNodeOTelParser: inputs = node_run_result.inputs or {} outputs = node_run_result.outputs or {} - # Extract query from inputs - query = str(inputs.get("query", "")) if inputs else "" - if query: - span.set_attribute(RetrieverAttributes.QUERY, query) + # Query and documents — gated by content policy + if should_include_content(): + query = str(inputs.get("query", "")) if inputs else "" + if query: + span.set_attribute(RetrieverAttributes.QUERY, query) - # Extract and format retrieval documents from outputs - result_value = outputs.get("result") if outputs else None - retrieval_documents: list[Any] = [] - if result_value: - value_to_check = result_value - if isinstance(result_value, Segment): - value_to_check = result_value.value - - if isinstance(value_to_check, (list, Sequence)): - retrieval_documents = list(value_to_check) - - if retrieval_documents: - semantic_retrieval_documents = _format_retrieval_documents(retrieval_documents) - semantic_retrieval_documents_json = safe_json_dumps(semantic_retrieval_documents) - span.set_attribute(RetrieverAttributes.DOCUMENT, semantic_retrieval_documents_json) + result_value = outputs.get("result") if outputs else None + retrieval_documents: list[Any] = [] + if result_value: + value_to_check = result_value + if isinstance(result_value, Segment): + value_to_check = result_value.value + if isinstance(value_to_check, (list, Sequence)): + retrieval_documents = list(value_to_check) + if retrieval_documents: + semantic_retrieval_documents = _format_retrieval_documents(retrieval_documents) + semantic_retrieval_documents_json = safe_json_dumps(semantic_retrieval_documents) + span.set_attribute(RetrieverAttributes.DOCUMENT, semantic_retrieval_documents_json) diff --git a/api/extensions/otel/parser/tool.py b/api/extensions/otel/parser/tool.py index b99180722b..cb394d42b1 100644 --- a/api/extensions/otel/parser/tool.py +++ b/api/extensions/otel/parser/tool.py @@ -8,7 +8,7 @@ from core.workflow.enums import WorkflowNodeExecutionMetadataKey from core.workflow.graph_events import GraphNodeEventBase from core.workflow.nodes.base.node import Node from core.workflow.nodes.tool.entities import ToolNodeData -from extensions.otel.parser.base import DefaultNodeOTelParser, safe_json_dumps +from extensions.otel.parser.base import DefaultNodeOTelParser, safe_json_dumps, should_include_content from extensions.otel.semconv.gen_ai import ToolAttributes @@ -40,8 +40,14 @@ class ToolNodeOTelParser: if tool_info: span.set_attribute(ToolAttributes.TOOL_DESCRIPTION, safe_json_dumps(tool_info)) - if result_event and result_event.node_run_result and result_event.node_run_result.inputs: - span.set_attribute(ToolAttributes.TOOL_CALL_ARGUMENTS, safe_json_dumps(result_event.node_run_result.inputs)) + # Tool call arguments and result — gated by content policy + if should_include_content(): + if result_event and result_event.node_run_result and result_event.node_run_result.inputs: + span.set_attribute( + ToolAttributes.TOOL_CALL_ARGUMENTS, safe_json_dumps(result_event.node_run_result.inputs) + ) - if result_event and result_event.node_run_result and result_event.node_run_result.outputs: - span.set_attribute(ToolAttributes.TOOL_CALL_RESULT, safe_json_dumps(result_event.node_run_result.outputs)) + if result_event and result_event.node_run_result and result_event.node_run_result.outputs: + span.set_attribute( + ToolAttributes.TOOL_CALL_RESULT, safe_json_dumps(result_event.node_run_result.outputs) + ) diff --git a/api/extensions/otel/semconv/dify.py b/api/extensions/otel/semconv/dify.py index a20b9b358d..301ddd11aa 100644 --- a/api/extensions/otel/semconv/dify.py +++ b/api/extensions/otel/semconv/dify.py @@ -21,3 +21,15 @@ class DifySpanAttributes: INVOKE_FROM = "dify.invoke_from" """Invocation source, e.g. SERVICE_API, WEB_APP, DEBUGGER.""" + + INVOKED_BY = "dify.invoked_by" + """Invoked by, e.g. end_user, account, user.""" + + USAGE_INPUT_TOKENS = "gen_ai.usage.input_tokens" + """Number of input tokens (prompt tokens) used.""" + + USAGE_OUTPUT_TOKENS = "gen_ai.usage.output_tokens" + """Number of output tokens (completion tokens) generated.""" + + USAGE_TOTAL_TOKENS = "gen_ai.usage.total_tokens" + """Total number of tokens used.""" diff --git a/api/models/model.py b/api/models/model.py index 429c46bd85..1a259e4621 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -664,15 +664,11 @@ class ExporleBanner(TypeBase): content: Mapped[dict[str, Any]] = mapped_column(sa.JSON, nullable=False) link: Mapped[str] = mapped_column(String(255), nullable=False) sort: Mapped[int] = mapped_column(sa.Integer, nullable=False) - status: Mapped[str] = mapped_column( - sa.String(255), nullable=False, server_default='enabled', default="enabled" - ) + status: Mapped[str] = mapped_column(sa.String(255), nullable=False, server_default="enabled", default="enabled") created_at: Mapped[datetime] = mapped_column( sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False ) - language: Mapped[str] = mapped_column( - String(255), nullable=False, server_default='en-US', default="en-US" - ) + language: Mapped[str] = mapped_column(String(255), nullable=False, server_default="en-US", default="en-US") class OAuthProviderApp(TypeBase): diff --git a/api/services/app_service.py b/api/services/app_service.py index af458ff618..0422b4bab9 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -14,7 +14,7 @@ from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelTy from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.tools.tool_manager import ToolManager from core.tools.utils.configuration import ToolParameterConfigurationManager -from events.app_event import app_was_created +from events.app_event import app_was_created, app_was_deleted from extensions.ext_database import db from libs.datetime_utils import naive_utc_now from libs.login import current_user @@ -340,6 +340,8 @@ class AppService: db.session.delete(app) db.session.commit() + app_was_deleted.send(app) + # clean up web app settings if FeatureService.get_system_features().webapp_auth.enabled: EnterpriseService.WebAppAuth.cleanup_webapp(app.id) diff --git a/api/services/enterprise/account_deletion_sync.py b/api/services/enterprise/account_deletion_sync.py index f8f8189891..c7ff42894d 100644 --- a/api/services/enterprise/account_deletion_sync.py +++ b/api/services/enterprise/account_deletion_sync.py @@ -81,7 +81,7 @@ def sync_workspace_member_removal(workspace_id: str, member_id: str, *, source: bool: True if task was queued (or skipped in community), False if queueing failed """ if not dify_config.ENTERPRISE_ENABLED: - return True + return True return _queue_task(workspace_id=workspace_id, member_id=member_id, source=source) @@ -101,7 +101,7 @@ def sync_account_deletion(account_id: str, *, source: str) -> bool: bool: True if all tasks were queued (or skipped in community), False if any queueing failed """ if not dify_config.ENTERPRISE_ENABLED: - return True + return True # Fetch all workspaces the account belongs to workspace_joins = db.session.query(TenantAccountJoin).filter_by(account_id=account_id).all() diff --git a/api/services/message_service.py b/api/services/message_service.py index a53ca8b22d..26b220edfa 100644 --- a/api/services/message_service.py +++ b/api/services/message_service.py @@ -7,9 +7,10 @@ from core.llm_generator.llm_generator import LLMGenerator from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType -from core.ops.entities.trace_entity import TraceTaskName -from core.ops.ops_trace_manager import TraceQueueManager, TraceTask from core.ops.utils import measure_time +from core.telemetry import TelemetryContext, TelemetryEvent, TraceTaskName +from core.telemetry import emit as telemetry_emit +from events.feedback_event import feedback_was_created from extensions.ext_database import db from libs.infinite_scroll_pagination import InfiniteScrollPagination from models import Account @@ -179,6 +180,9 @@ class MessageService: db.session.commit() + if feedback and rating: + feedback_was_created.send(feedback, tenant_id=app_model.tenant_id) + return feedback @classmethod @@ -294,10 +298,15 @@ class MessageService: questions: list[str] = list(questions_sequence) # get tracing instance - trace_manager = TraceQueueManager(app_id=app_model.id) - trace_manager.add_trace_task( - TraceTask( - TraceTaskName.SUGGESTED_QUESTION_TRACE, message_id=message_id, suggested_question=questions, timer=timer + telemetry_emit( + TelemetryEvent( + name=TraceTaskName.SUGGESTED_QUESTION_TRACE, + context=TelemetryContext(tenant_id=app_model.tenant_id, app_id=app_model.id), + payload={ + "message_id": message_id, + "suggested_question": questions, + "timer": timer, + }, ) ) diff --git a/api/services/ops_service.py b/api/services/ops_service.py index 50ea832085..c1c92b2de8 100644 --- a/api/services/ops_service.py +++ b/api/services/ops_service.py @@ -1,3 +1,4 @@ +import logging from typing import Any from core.ops.entities.config_entity import BaseTracingConfig @@ -5,6 +6,8 @@ from core.ops.ops_trace_manager import OpsTraceManager, provider_config_map from extensions.ext_database import db from models.model import App, TraceAppConfig +logger = logging.getLogger(__name__) + class OpsService: @classmethod @@ -135,12 +138,13 @@ class OpsService: return trace_config_data.to_dict() @classmethod - def create_tracing_app_config(cls, app_id: str, tracing_provider: str, tracing_config: dict): + def create_tracing_app_config(cls, app_id: str, tracing_provider: str, tracing_config: dict, account_id: str): """ Create tracing app config :param app_id: app id :param tracing_provider: tracing provider :param tracing_config: tracing config + :param account_id: account id of the user creating the config :return: """ try: @@ -207,15 +211,19 @@ class OpsService: db.session.add(trace_config_data) db.session.commit() + # Log the creation with modifier information + logger.info("Trace config created: app_id=%s, provider=%s, created_by=%s", app_id, tracing_provider, account_id) + return {"result": "success"} @classmethod - def update_tracing_app_config(cls, app_id: str, tracing_provider: str, tracing_config: dict): + def update_tracing_app_config(cls, app_id: str, tracing_provider: str, tracing_config: dict, account_id: str): """ Update tracing app config :param app_id: app id :param tracing_provider: tracing provider :param tracing_config: tracing config + :param account_id: account id of the user updating the config :return: """ try: @@ -251,14 +259,18 @@ class OpsService: current_trace_config.tracing_config = tracing_config db.session.commit() + # Log the update with modifier information + logger.info("Trace config updated: app_id=%s, provider=%s, updated_by=%s", app_id, tracing_provider, account_id) + return current_trace_config.to_dict() @classmethod - def delete_tracing_app_config(cls, app_id: str, tracing_provider: str): + def delete_tracing_app_config(cls, app_id: str, tracing_provider: str, account_id: str): """ Delete tracing app config :param app_id: app id :param tracing_provider: tracing provider + :param account_id: account id of the user deleting the config :return: """ trace_config = ( @@ -270,6 +282,9 @@ class OpsService: if not trace_config: return None + # Log the deletion with modifier information + logger.info("Trace config deleted: app_id=%s, provider=%s, deleted_by=%s", app_id, tracing_provider, account_id) + db.session.delete(trace_config) db.session.commit() diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 6404136994..e4d0773030 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -647,6 +647,7 @@ class WorkflowService: node_config = draft_workflow.get_node_config_by_id(node_id) node_type = Workflow.get_node_type_from_node_config(node_config) node_data = node_config.get("data", {}) + workflow_execution_id: str | None = None if node_type.is_start_node: with Session(bind=db.engine) as session, session.begin(): draft_var_srv = WorkflowDraftVariableService(session) @@ -672,10 +673,13 @@ class WorkflowService: node_type=node_type, conversation_id=conversation_id, ) + workflow_execution_id = variable_pool.system_variables.workflow_execution_id else: + workflow_execution_id = str(uuid.uuid4()) + system_variable = SystemVariable(workflow_execution_id=workflow_execution_id) variable_pool = VariablePool( - system_variables=SystemVariable.default(), + system_variables=system_variable, user_inputs=user_inputs, environment_variables=draft_workflow.environment_variables, conversation_variables=[], @@ -729,6 +733,15 @@ class WorkflowService: with Session(db.engine) as session: outputs = workflow_node_execution.load_full_outputs(session, storage) + from enterprise.telemetry.draft_trace import enqueue_draft_node_execution_trace + + enqueue_draft_node_execution_trace( + execution=workflow_node_execution, + outputs=outputs, + workflow_execution_id=workflow_execution_id, + user_id=account.id, + ) + with Session(bind=db.engine) as session, session.begin(): draft_var_saver = DraftVariableSaver( session=session, @@ -784,19 +797,20 @@ class WorkflowService: Returns: WorkflowNodeExecution: The execution result """ + created_at = naive_utc_now() node, node_run_result, run_succeeded, error = self._execute_node_safely(invoke_node_fn) + finished_at = naive_utc_now() - # Create base node execution node_execution = WorkflowNodeExecution( - id=str(uuid.uuid4()), + id=node.execution_id or str(uuid.uuid4()), workflow_id="", # Single-step execution has no workflow ID index=1, node_id=node_id, node_type=node.node_type, title=node.title, elapsed_time=time.perf_counter() - start_at, - created_at=naive_utc_now(), - finished_at=naive_utc_now(), + created_at=created_at, + finished_at=finished_at, ) # Populate execution result data diff --git a/api/tasks/enterprise_telemetry_task.py b/api/tasks/enterprise_telemetry_task.py new file mode 100644 index 0000000000..7d5ea7c0a5 --- /dev/null +++ b/api/tasks/enterprise_telemetry_task.py @@ -0,0 +1,52 @@ +"""Celery worker for enterprise metric/log telemetry events. + +This module defines the Celery task that processes telemetry envelopes +from the enterprise_telemetry queue. It deserializes envelopes and +dispatches them to the EnterpriseMetricHandler. +""" + +import json +import logging + +from celery import shared_task + +from enterprise.telemetry.contracts import TelemetryEnvelope +from enterprise.telemetry.metric_handler import EnterpriseMetricHandler + +logger = logging.getLogger(__name__) + + +@shared_task(queue="enterprise_telemetry") +def process_enterprise_telemetry(envelope_json: str) -> None: + """Process enterprise metric/log telemetry envelope. + + This task is enqueued by the TelemetryGateway for metric/log-only + events. It deserializes the envelope and dispatches to the handler. + + Best-effort processing: logs errors but never raises, to avoid + failing user requests due to telemetry issues. + + Args: + envelope_json: JSON-serialized TelemetryEnvelope. + """ + try: + # Deserialize envelope + envelope_dict = json.loads(envelope_json) + envelope = TelemetryEnvelope.model_validate(envelope_dict) + + # Process through handler + handler = EnterpriseMetricHandler() + handler.handle(envelope) + + logger.debug( + "Successfully processed telemetry envelope: tenant_id=%s, event_id=%s, case=%s", + envelope.tenant_id, + envelope.event_id, + envelope.case, + ) + except Exception: + # Best-effort: log and drop on error, never fail user request + logger.warning( + "Failed to process enterprise telemetry envelope, dropping event", + exc_info=True, + ) diff --git a/api/tasks/ops_trace_task.py b/api/tasks/ops_trace_task.py index 72e3b42ca7..3d3a9755a5 100644 --- a/api/tasks/ops_trace_task.py +++ b/api/tasks/ops_trace_task.py @@ -39,12 +39,24 @@ def process_trace_tasks(file_info): trace_info["documents"] = [Document.model_validate(doc) for doc in trace_info["documents"]] try: + trace_type = trace_info_info_map.get(trace_info_type) + if trace_type: + trace_info = trace_type(**trace_info) + + from extensions.ext_enterprise_telemetry import is_enabled as is_ee_telemetry_enabled + + if is_ee_telemetry_enabled(): + from enterprise.telemetry.enterprise_trace import EnterpriseOtelTrace + + try: + EnterpriseOtelTrace().trace(trace_info) + except Exception: + logger.exception("Enterprise trace failed for app_id: %s", app_id) + if trace_instance: with current_app.app_context(): - trace_type = trace_info_info_map.get(trace_info_type) - if trace_type: - trace_info = trace_type(**trace_info) trace_instance.trace(trace_info) + logger.info("Processing trace tasks success, app_id: %s", app_id) except Exception as e: logger.info("error:\n\n\n%s\n\n\n\n", e) @@ -52,4 +64,12 @@ def process_trace_tasks(file_info): redis_client.incr(failed_key) logger.info("Processing trace tasks failed, app_id: %s", app_id) finally: - storage.delete(file_path) + try: + storage.delete(file_path) + except Exception as e: + logger.warning( + "Failed to delete trace file %s for app_id %s: %s", + file_path, + app_id, + e, + ) diff --git a/api/tests/unit_tests/core/ops/test_trace_queue_manager.py b/api/tests/unit_tests/core/ops/test_trace_queue_manager.py new file mode 100644 index 0000000000..44a58ab902 --- /dev/null +++ b/api/tests/unit_tests/core/ops/test_trace_queue_manager.py @@ -0,0 +1,200 @@ +"""Unit tests for TraceQueueManager telemetry guard. + +This test suite verifies that TraceQueueManager correctly drops trace tasks +when telemetry is disabled, proving Bug 1 from code review is a false positive. + +The guard logic moved from persistence.py to TraceQueueManager.add_trace_task() +at line 1282 of ops_trace_manager.py: + if self._enterprise_telemetry_enabled or self.trace_instance: + trace_task.app_id = self.app_id + trace_manager_queue.put(trace_task) + +Tasks are only enqueued if EITHER: +- Enterprise telemetry is enabled (_enterprise_telemetry_enabled=True), OR +- A third-party trace instance (Langfuse, etc.) is configured + +When BOTH are false, tasks are silently dropped (correct behavior). +""" + +import queue +import sys +import types +from unittest.mock import MagicMock, patch + +import pytest + + +@pytest.fixture +def trace_queue_manager_and_task(monkeypatch): + """Fixture to provide TraceQueueManager and TraceTask with delayed imports.""" + module_name = "core.ops.ops_trace_manager" + if module_name not in sys.modules: + ops_stub = types.ModuleType(module_name) + + class StubTraceTask: + def __init__(self, trace_type): + self.trace_type = trace_type + self.app_id = None + + class StubTraceQueueManager: + def __init__(self, app_id=None): + self.app_id = app_id + from core.telemetry.gateway import is_enterprise_telemetry_enabled + + self._enterprise_telemetry_enabled = is_enterprise_telemetry_enabled() + self.trace_instance = StubOpsTraceManager.get_ops_trace_instance(app_id) + + def add_trace_task(self, trace_task): + if self._enterprise_telemetry_enabled or self.trace_instance: + trace_task.app_id = self.app_id + from core.ops.ops_trace_manager import trace_manager_queue + + trace_manager_queue.put(trace_task) + + class StubOpsTraceManager: + @staticmethod + def get_ops_trace_instance(app_id): + return None + + ops_stub.TraceQueueManager = StubTraceQueueManager + ops_stub.TraceTask = StubTraceTask + ops_stub.OpsTraceManager = StubOpsTraceManager + ops_stub.trace_manager_queue = MagicMock(spec=queue.Queue) + monkeypatch.setitem(sys.modules, module_name, ops_stub) + + from core.ops.entities.trace_entity import TraceTaskName + + ops_module = __import__(module_name, fromlist=["TraceQueueManager", "TraceTask"]) + TraceQueueManager = ops_module.TraceQueueManager + TraceTask = ops_module.TraceTask + + return TraceQueueManager, TraceTask, TraceTaskName + + +class TestTraceQueueManagerTelemetryGuard: + """Test TraceQueueManager's telemetry guard in add_trace_task().""" + + def test_task_not_enqueued_when_telemetry_disabled_and_no_trace_instance(self, trace_queue_manager_and_task): + """Verify task is NOT enqueued when telemetry disabled and no trace instance. + + This is the core guard: when _enterprise_telemetry_enabled=False AND + trace_instance=None, the task should be silently dropped. + """ + TraceQueueManager, TraceTask, TraceTaskName = trace_queue_manager_and_task + + mock_queue = MagicMock(spec=queue.Queue) + + trace_task = TraceTask(trace_type=TraceTaskName.WORKFLOW_TRACE) + + with ( + patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False), + patch("core.ops.ops_trace_manager.OpsTraceManager.get_ops_trace_instance", return_value=None), + patch("core.ops.ops_trace_manager.trace_manager_queue", mock_queue), + ): + manager = TraceQueueManager(app_id="test-app-id") + manager.add_trace_task(trace_task) + + mock_queue.put.assert_not_called() + + def test_task_enqueued_when_telemetry_enabled(self, trace_queue_manager_and_task): + """Verify task IS enqueued when enterprise telemetry is enabled. + + When _enterprise_telemetry_enabled=True, the task should be enqueued + regardless of trace_instance state. + """ + TraceQueueManager, TraceTask, TraceTaskName = trace_queue_manager_and_task + + mock_queue = MagicMock(spec=queue.Queue) + + trace_task = TraceTask(trace_type=TraceTaskName.WORKFLOW_TRACE) + + with ( + patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True), + patch("core.ops.ops_trace_manager.OpsTraceManager.get_ops_trace_instance", return_value=None), + patch("core.ops.ops_trace_manager.trace_manager_queue", mock_queue), + ): + manager = TraceQueueManager(app_id="test-app-id") + manager.add_trace_task(trace_task) + + mock_queue.put.assert_called_once() + called_task = mock_queue.put.call_args[0][0] + assert called_task.app_id == "test-app-id" + + def test_task_enqueued_when_trace_instance_configured(self, trace_queue_manager_and_task): + """Verify task IS enqueued when third-party trace instance is configured. + + When trace_instance is not None (e.g., Langfuse configured), the task + should be enqueued even if enterprise telemetry is disabled. + """ + TraceQueueManager, TraceTask, TraceTaskName = trace_queue_manager_and_task + + mock_queue = MagicMock(spec=queue.Queue) + + mock_trace_instance = MagicMock() + + trace_task = TraceTask(trace_type=TraceTaskName.WORKFLOW_TRACE) + + with ( + patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False), + patch( + "core.ops.ops_trace_manager.OpsTraceManager.get_ops_trace_instance", return_value=mock_trace_instance + ), + patch("core.ops.ops_trace_manager.trace_manager_queue", mock_queue), + ): + manager = TraceQueueManager(app_id="test-app-id") + manager.add_trace_task(trace_task) + + mock_queue.put.assert_called_once() + called_task = mock_queue.put.call_args[0][0] + assert called_task.app_id == "test-app-id" + + def test_task_enqueued_when_both_telemetry_and_trace_instance_enabled(self, trace_queue_manager_and_task): + """Verify task IS enqueued when both telemetry and trace instance are enabled. + + When both _enterprise_telemetry_enabled=True AND trace_instance is set, + the task should definitely be enqueued. + """ + TraceQueueManager, TraceTask, TraceTaskName = trace_queue_manager_and_task + + mock_queue = MagicMock(spec=queue.Queue) + + mock_trace_instance = MagicMock() + + trace_task = TraceTask(trace_type=TraceTaskName.WORKFLOW_TRACE) + + with ( + patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True), + patch( + "core.ops.ops_trace_manager.OpsTraceManager.get_ops_trace_instance", return_value=mock_trace_instance + ), + patch("core.ops.ops_trace_manager.trace_manager_queue", mock_queue), + ): + manager = TraceQueueManager(app_id="test-app-id") + manager.add_trace_task(trace_task) + + mock_queue.put.assert_called_once() + called_task = mock_queue.put.call_args[0][0] + assert called_task.app_id == "test-app-id" + + def test_app_id_set_before_enqueue(self, trace_queue_manager_and_task): + """Verify app_id is set on the task before enqueuing. + + The guard logic sets trace_task.app_id = self.app_id before calling + trace_manager_queue.put(trace_task). This test verifies that behavior. + """ + TraceQueueManager, TraceTask, TraceTaskName = trace_queue_manager_and_task + + mock_queue = MagicMock(spec=queue.Queue) + + trace_task = TraceTask(trace_type=TraceTaskName.WORKFLOW_TRACE) + + with ( + patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True), + patch("core.ops.ops_trace_manager.OpsTraceManager.get_ops_trace_instance", return_value=None), + patch("core.ops.ops_trace_manager.trace_manager_queue", mock_queue), + ): + manager = TraceQueueManager(app_id="expected-app-id") + manager.add_trace_task(trace_task) + + called_task = mock_queue.put.call_args[0][0] + assert called_task.app_id == "expected-app-id" diff --git a/api/tests/unit_tests/core/telemetry/test_facade.py b/api/tests/unit_tests/core/telemetry/test_facade.py new file mode 100644 index 0000000000..64c2f6a971 --- /dev/null +++ b/api/tests/unit_tests/core/telemetry/test_facade.py @@ -0,0 +1,181 @@ +"""Unit tests for core.telemetry.emit() routing and enterprise-only filtering.""" + +from __future__ import annotations + +import queue +import sys +import types +from unittest.mock import MagicMock, patch + +import pytest + +from core.ops.entities.trace_entity import TraceTaskName +from core.telemetry.events import TelemetryContext, TelemetryEvent + + +@pytest.fixture +def telemetry_test_setup(monkeypatch): + module_name = "core.ops.ops_trace_manager" + ops_stub = types.ModuleType(module_name) + + class StubTraceTask: + def __init__(self, trace_type, **kwargs): + self.trace_type = trace_type + self.app_id = None + self.kwargs = kwargs + + class StubTraceQueueManager: + def __init__(self, app_id=None, user_id=None): + self.app_id = app_id + self.user_id = user_id + self.trace_instance = StubOpsTraceManager.get_ops_trace_instance(app_id) + + def add_trace_task(self, trace_task): + trace_task.app_id = self.app_id + from core.ops.ops_trace_manager import trace_manager_queue + + trace_manager_queue.put(trace_task) + + class StubOpsTraceManager: + @staticmethod + def get_ops_trace_instance(app_id): + return None + + ops_stub.TraceQueueManager = StubTraceQueueManager + ops_stub.TraceTask = StubTraceTask + ops_stub.OpsTraceManager = StubOpsTraceManager + ops_stub.trace_manager_queue = MagicMock(spec=queue.Queue) + monkeypatch.setitem(sys.modules, module_name, ops_stub) + + from core.telemetry import emit + + return emit, ops_stub.trace_manager_queue + + +class TestTelemetryEmit: + @patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True) + def test_emit_enterprise_trace_creates_trace_task(self, _mock_ee, telemetry_test_setup): + emit_fn, mock_queue = telemetry_test_setup + + event = TelemetryEvent( + name=TraceTaskName.DRAFT_NODE_EXECUTION_TRACE, + context=TelemetryContext( + tenant_id="test-tenant", + user_id="test-user", + app_id="test-app", + ), + payload={"key": "value"}, + ) + + emit_fn(event) + + mock_queue.put.assert_called_once() + called_task = mock_queue.put.call_args[0][0] + assert called_task.trace_type == TraceTaskName.DRAFT_NODE_EXECUTION_TRACE + + def test_emit_community_trace_enqueued(self, telemetry_test_setup): + emit_fn, mock_queue = telemetry_test_setup + + event = TelemetryEvent( + name=TraceTaskName.WORKFLOW_TRACE, + context=TelemetryContext( + tenant_id="test-tenant", + user_id="test-user", + app_id="test-app", + ), + payload={}, + ) + + emit_fn(event) + + mock_queue.put.assert_called_once() + + def test_emit_enterprise_only_trace_dropped_when_ee_disabled(self, telemetry_test_setup): + emit_fn, mock_queue = telemetry_test_setup + + event = TelemetryEvent( + name=TraceTaskName.DRAFT_NODE_EXECUTION_TRACE, + context=TelemetryContext( + tenant_id="test-tenant", + user_id="test-user", + app_id="test-app", + ), + payload={}, + ) + + emit_fn(event) + + mock_queue.put.assert_not_called() + + @patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True) + def test_emit_all_enterprise_only_traces_allowed_when_ee_enabled(self, _mock_ee, telemetry_test_setup): + emit_fn, mock_queue = telemetry_test_setup + + enterprise_only_traces = [ + TraceTaskName.DRAFT_NODE_EXECUTION_TRACE, + TraceTaskName.NODE_EXECUTION_TRACE, + TraceTaskName.PROMPT_GENERATION_TRACE, + ] + + for trace_name in enterprise_only_traces: + mock_queue.reset_mock() + + event = TelemetryEvent( + name=trace_name, + context=TelemetryContext( + tenant_id="test-tenant", + user_id="test-user", + app_id="test-app", + ), + payload={}, + ) + + emit_fn(event) + + mock_queue.put.assert_called_once() + called_task = mock_queue.put.call_args[0][0] + assert called_task.trace_type == trace_name + + @patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True) + def test_emit_passes_name_directly_to_trace_task(self, _mock_ee, telemetry_test_setup): + emit_fn, mock_queue = telemetry_test_setup + + event = TelemetryEvent( + name=TraceTaskName.DRAFT_NODE_EXECUTION_TRACE, + context=TelemetryContext( + tenant_id="test-tenant", + user_id="test-user", + app_id="test-app", + ), + payload={"extra": "data"}, + ) + + emit_fn(event) + + mock_queue.put.assert_called_once() + called_task = mock_queue.put.call_args[0][0] + assert called_task.trace_type == TraceTaskName.DRAFT_NODE_EXECUTION_TRACE + assert isinstance(called_task.trace_type, TraceTaskName) + + @patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True) + def test_emit_with_provided_trace_manager(self, _mock_ee, telemetry_test_setup): + emit_fn, mock_queue = telemetry_test_setup + + mock_trace_manager = MagicMock() + mock_trace_manager.add_trace_task = MagicMock() + + event = TelemetryEvent( + name=TraceTaskName.NODE_EXECUTION_TRACE, + context=TelemetryContext( + tenant_id="test-tenant", + user_id="test-user", + app_id="test-app", + ), + payload={}, + ) + + emit_fn(event, trace_manager=mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_called_once() + called_task = mock_trace_manager.add_trace_task.call_args[0][0] + assert called_task.trace_type == TraceTaskName.NODE_EXECUTION_TRACE diff --git a/api/tests/unit_tests/core/telemetry/test_gateway_integration.py b/api/tests/unit_tests/core/telemetry/test_gateway_integration.py new file mode 100644 index 0000000000..536d4374d6 --- /dev/null +++ b/api/tests/unit_tests/core/telemetry/test_gateway_integration.py @@ -0,0 +1,225 @@ +from __future__ import annotations + +import sys +from unittest.mock import MagicMock, patch + +import pytest + +from core.telemetry.gateway import emit, is_enterprise_telemetry_enabled +from enterprise.telemetry.contracts import TelemetryCase + + +class TestTelemetryCoreExports: + def test_is_enterprise_telemetry_enabled_exported(self) -> None: + from core.telemetry.gateway import is_enterprise_telemetry_enabled as exported_func + + assert callable(exported_func) + + +@pytest.fixture +def mock_ops_trace_manager(): + mock_module = MagicMock() + mock_trace_task_class = MagicMock() + mock_trace_task_class.return_value = MagicMock() + mock_module.TraceTask = mock_trace_task_class + mock_module.TraceQueueManager = MagicMock() + + mock_trace_entity = MagicMock() + mock_trace_task_name = MagicMock() + mock_trace_task_name.return_value = "workflow" + mock_trace_entity.TraceTaskName = mock_trace_task_name + + with ( + patch.dict(sys.modules, {"core.ops.ops_trace_manager": mock_module}), + patch.dict(sys.modules, {"core.ops.entities.trace_entity": mock_trace_entity}), + ): + yield mock_module, mock_trace_entity + + +class TestGatewayIntegrationTraceRouting: + @pytest.fixture + def mock_trace_manager(self) -> MagicMock: + return MagicMock() + + @pytest.mark.usefixtures("mock_ops_trace_manager") + def test_ce_eligible_trace_routed_to_trace_manager( + self, + mock_trace_manager: MagicMock, + ) -> None: + with patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True): + context = {"app_id": "app-123", "user_id": "user-456", "tenant_id": "tenant-789"} + payload = {"workflow_run_id": "run-abc"} + + emit(TelemetryCase.WORKFLOW_RUN, context, payload, mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_called_once() + + @pytest.mark.usefixtures("mock_ops_trace_manager") + def test_ce_eligible_trace_routed_when_ee_disabled( + self, + mock_trace_manager: MagicMock, + ) -> None: + with patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False): + context = {"app_id": "app-123", "user_id": "user-456"} + payload = {"workflow_run_id": "run-abc"} + + emit(TelemetryCase.WORKFLOW_RUN, context, payload, mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_called_once() + + @pytest.mark.usefixtures("mock_ops_trace_manager") + def test_enterprise_only_trace_dropped_when_ee_disabled( + self, + mock_trace_manager: MagicMock, + ) -> None: + with patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False): + context = {"app_id": "app-123", "user_id": "user-456"} + payload = {"node_id": "node-abc"} + + emit(TelemetryCase.NODE_EXECUTION, context, payload, mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_not_called() + + @pytest.mark.usefixtures("mock_ops_trace_manager") + def test_enterprise_only_trace_routed_when_ee_enabled( + self, + mock_trace_manager: MagicMock, + ) -> None: + with patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True): + context = {"app_id": "app-123", "user_id": "user-456"} + payload = {"node_id": "node-abc"} + + emit(TelemetryCase.NODE_EXECUTION, context, payload, mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_called_once() + + +class TestGatewayIntegrationMetricRouting: + @patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True) + def test_metric_case_routes_to_celery_task( + self, + _mock_ee_enabled: MagicMock, + ) -> None: + from enterprise.telemetry.contracts import TelemetryEnvelope + + with patch("tasks.enterprise_telemetry_task.process_enterprise_telemetry.delay") as mock_delay: + context = {"tenant_id": "tenant-123"} + payload = {"app_id": "app-abc", "name": "My App"} + + emit(TelemetryCase.APP_CREATED, context, payload) + + mock_delay.assert_called_once() + envelope_json = mock_delay.call_args[0][0] + envelope = TelemetryEnvelope.model_validate_json(envelope_json) + assert envelope.case == TelemetryCase.APP_CREATED + assert envelope.tenant_id == "tenant-123" + assert envelope.payload["app_id"] == "app-abc" + + @pytest.mark.usefixtures("mock_ops_trace_manager") + @patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True) + def test_tool_execution_trace_routed( + self, + _mock_ee_enabled: MagicMock, + ) -> None: + mock_trace_manager = MagicMock() + context = {"tenant_id": "tenant-123", "app_id": "app-123"} + payload = {"tool_name": "test_tool", "tool_inputs": {}, "tool_outputs": "result"} + + emit(TelemetryCase.TOOL_EXECUTION, context, payload, mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_called_once() + + @pytest.mark.usefixtures("mock_ops_trace_manager") + @patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True) + def test_moderation_check_trace_routed( + self, + _mock_ee_enabled: MagicMock, + ) -> None: + mock_trace_manager = MagicMock() + context = {"tenant_id": "tenant-123", "app_id": "app-123"} + payload = {"message_id": "msg-123", "moderation_result": {"flagged": False}} + + emit(TelemetryCase.MODERATION_CHECK, context, payload, mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_called_once() + + +class TestGatewayIntegrationCEEligibility: + @pytest.fixture + def mock_trace_manager(self) -> MagicMock: + return MagicMock() + + @pytest.mark.usefixtures("mock_ops_trace_manager") + def test_workflow_run_is_ce_eligible( + self, + mock_trace_manager: MagicMock, + ) -> None: + with patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False): + context = {"app_id": "app-123", "user_id": "user-456"} + payload = {"workflow_run_id": "run-abc"} + + emit(TelemetryCase.WORKFLOW_RUN, context, payload, mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_called_once() + + @pytest.mark.usefixtures("mock_ops_trace_manager") + def test_message_run_is_ce_eligible( + self, + mock_trace_manager: MagicMock, + ) -> None: + with patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False): + context = {"app_id": "app-123", "user_id": "user-456"} + payload = {"message_id": "msg-abc", "conversation_id": "conv-123"} + + emit(TelemetryCase.MESSAGE_RUN, context, payload, mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_called_once() + + @pytest.mark.usefixtures("mock_ops_trace_manager") + def test_node_execution_not_ce_eligible( + self, + mock_trace_manager: MagicMock, + ) -> None: + with patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False): + context = {"app_id": "app-123", "user_id": "user-456"} + payload = {"node_id": "node-abc"} + + emit(TelemetryCase.NODE_EXECUTION, context, payload, mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_not_called() + + @pytest.mark.usefixtures("mock_ops_trace_manager") + def test_draft_node_execution_not_ce_eligible( + self, + mock_trace_manager: MagicMock, + ) -> None: + with patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False): + context = {"app_id": "app-123", "user_id": "user-456"} + payload = {"node_execution_data": {}} + + emit(TelemetryCase.DRAFT_NODE_EXECUTION, context, payload, mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_not_called() + + @pytest.mark.usefixtures("mock_ops_trace_manager") + def test_prompt_generation_not_ce_eligible( + self, + mock_trace_manager: MagicMock, + ) -> None: + with patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False): + context = {"app_id": "app-123", "user_id": "user-456", "tenant_id": "tenant-789"} + payload = {"operation_type": "generate", "instruction": "test"} + + emit(TelemetryCase.PROMPT_GENERATION, context, payload, mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_not_called() + + +class TestIsEnterpriseTelemetryEnabled: + def test_returns_false_when_exporter_import_fails(self) -> None: + with patch.dict(sys.modules, {"enterprise.telemetry.exporter": None}): + result = is_enterprise_telemetry_enabled() + assert result is False + + def test_function_is_callable(self) -> None: + assert callable(is_enterprise_telemetry_enabled) diff --git a/api/tests/unit_tests/enterprise/__init__.py b/api/tests/unit_tests/enterprise/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/enterprise/telemetry/__init__.py b/api/tests/unit_tests/enterprise/telemetry/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/enterprise/telemetry/test_contracts.py b/api/tests/unit_tests/enterprise/telemetry/test_contracts.py new file mode 100644 index 0000000000..7453525bfc --- /dev/null +++ b/api/tests/unit_tests/enterprise/telemetry/test_contracts.py @@ -0,0 +1,230 @@ +"""Unit tests for telemetry gateway contracts.""" + +from __future__ import annotations + +import pytest +from pydantic import ValidationError + +from core.telemetry.gateway import CASE_ROUTING +from enterprise.telemetry.contracts import CaseRoute, SignalType, TelemetryCase, TelemetryEnvelope + + +class TestTelemetryCase: + """Tests for TelemetryCase enum.""" + + def test_all_cases_defined(self) -> None: + """Verify all 14 telemetry cases are defined.""" + expected_cases = { + "WORKFLOW_RUN", + "NODE_EXECUTION", + "DRAFT_NODE_EXECUTION", + "MESSAGE_RUN", + "TOOL_EXECUTION", + "MODERATION_CHECK", + "SUGGESTED_QUESTION", + "DATASET_RETRIEVAL", + "GENERATE_NAME", + "PROMPT_GENERATION", + "APP_CREATED", + "APP_UPDATED", + "APP_DELETED", + "FEEDBACK_CREATED", + } + actual_cases = {case.name for case in TelemetryCase} + assert actual_cases == expected_cases + + def test_case_values(self) -> None: + """Verify case enum values are correct.""" + assert TelemetryCase.WORKFLOW_RUN.value == "workflow_run" + assert TelemetryCase.NODE_EXECUTION.value == "node_execution" + assert TelemetryCase.DRAFT_NODE_EXECUTION.value == "draft_node_execution" + assert TelemetryCase.MESSAGE_RUN.value == "message_run" + assert TelemetryCase.TOOL_EXECUTION.value == "tool_execution" + assert TelemetryCase.MODERATION_CHECK.value == "moderation_check" + assert TelemetryCase.SUGGESTED_QUESTION.value == "suggested_question" + assert TelemetryCase.DATASET_RETRIEVAL.value == "dataset_retrieval" + assert TelemetryCase.GENERATE_NAME.value == "generate_name" + assert TelemetryCase.PROMPT_GENERATION.value == "prompt_generation" + assert TelemetryCase.APP_CREATED.value == "app_created" + assert TelemetryCase.APP_UPDATED.value == "app_updated" + assert TelemetryCase.APP_DELETED.value == "app_deleted" + assert TelemetryCase.FEEDBACK_CREATED.value == "feedback_created" + + +class TestCaseRoute: + """Tests for CaseRoute model.""" + + def test_valid_trace_route(self) -> None: + """Verify valid trace route creation.""" + route = CaseRoute(signal_type=SignalType.TRACE, ce_eligible=True) + assert route.signal_type == SignalType.TRACE + assert route.ce_eligible is True + + def test_valid_metric_log_route(self) -> None: + """Verify valid metric_log route creation.""" + route = CaseRoute(signal_type=SignalType.METRIC_LOG, ce_eligible=False) + assert route.signal_type == SignalType.METRIC_LOG + assert route.ce_eligible is False + + def test_invalid_signal_type(self) -> None: + """Verify invalid signal_type is rejected.""" + with pytest.raises(ValidationError): + CaseRoute(signal_type="invalid", ce_eligible=True) + + +class TestTelemetryEnvelope: + """Tests for TelemetryEnvelope model.""" + + def test_valid_envelope_minimal(self) -> None: + """Verify valid minimal envelope creation.""" + envelope = TelemetryEnvelope( + case=TelemetryCase.WORKFLOW_RUN, + tenant_id="tenant-123", + event_id="event-456", + payload={"key": "value"}, + ) + assert envelope.case == TelemetryCase.WORKFLOW_RUN + assert envelope.tenant_id == "tenant-123" + assert envelope.event_id == "event-456" + assert envelope.payload == {"key": "value"} + assert envelope.metadata is None + + def test_valid_envelope_full(self) -> None: + """Verify valid envelope with all fields.""" + metadata = {"payload_ref": "telemetry/tenant-789/event-012.json"} + envelope = TelemetryEnvelope( + case=TelemetryCase.MESSAGE_RUN, + tenant_id="tenant-789", + event_id="event-012", + payload={"message": "hello"}, + metadata=metadata, + ) + assert envelope.case == TelemetryCase.MESSAGE_RUN + assert envelope.tenant_id == "tenant-789" + assert envelope.event_id == "event-012" + assert envelope.payload == {"message": "hello"} + assert envelope.metadata == metadata + + def test_missing_required_case(self) -> None: + """Verify missing case field is rejected.""" + with pytest.raises(ValidationError): + TelemetryEnvelope( + tenant_id="tenant-123", + event_id="event-456", + payload={"key": "value"}, + ) + + def test_missing_required_tenant_id(self) -> None: + """Verify missing tenant_id field is rejected.""" + with pytest.raises(ValidationError): + TelemetryEnvelope( + case=TelemetryCase.WORKFLOW_RUN, + event_id="event-456", + payload={"key": "value"}, + ) + + def test_missing_required_event_id(self) -> None: + """Verify missing event_id field is rejected.""" + with pytest.raises(ValidationError): + TelemetryEnvelope( + case=TelemetryCase.WORKFLOW_RUN, + tenant_id="tenant-123", + payload={"key": "value"}, + ) + + def test_missing_required_payload(self) -> None: + """Verify missing payload field is rejected.""" + with pytest.raises(ValidationError): + TelemetryEnvelope( + case=TelemetryCase.WORKFLOW_RUN, + tenant_id="tenant-123", + event_id="event-456", + ) + + def test_metadata_none(self) -> None: + """Verify metadata can be None.""" + envelope = TelemetryEnvelope( + case=TelemetryCase.WORKFLOW_RUN, + tenant_id="tenant-123", + event_id="event-456", + payload={"key": "value"}, + metadata=None, + ) + assert envelope.metadata is None + + +class TestCaseRouting: + """Tests for CASE_ROUTING table.""" + + def test_all_cases_routed(self) -> None: + """Verify all 14 cases have routing entries.""" + assert len(CASE_ROUTING) == 14 + for case in TelemetryCase: + assert case in CASE_ROUTING + + def test_trace_ce_eligible_cases(self) -> None: + """Verify trace cases with CE eligibility.""" + ce_eligible_trace_cases = { + TelemetryCase.WORKFLOW_RUN, + TelemetryCase.MESSAGE_RUN, + } + for case in ce_eligible_trace_cases: + route = CASE_ROUTING[case] + assert route.signal_type == SignalType.TRACE + assert route.ce_eligible is True + + def test_trace_enterprise_only_cases(self) -> None: + """Verify trace cases that are enterprise-only.""" + enterprise_only_trace_cases = { + TelemetryCase.NODE_EXECUTION, + TelemetryCase.DRAFT_NODE_EXECUTION, + TelemetryCase.PROMPT_GENERATION, + } + for case in enterprise_only_trace_cases: + route = CASE_ROUTING[case] + assert route.signal_type == SignalType.TRACE + assert route.ce_eligible is False + + def test_metric_log_cases(self) -> None: + """Verify metric/log-only cases.""" + metric_log_cases = { + TelemetryCase.APP_CREATED, + TelemetryCase.APP_UPDATED, + TelemetryCase.APP_DELETED, + TelemetryCase.FEEDBACK_CREATED, + } + for case in metric_log_cases: + route = CASE_ROUTING[case] + assert route.signal_type == SignalType.METRIC_LOG + assert route.ce_eligible is False + + def test_routing_table_completeness(self) -> None: + """Verify routing table covers all cases with correct types.""" + trace_cases = { + TelemetryCase.WORKFLOW_RUN, + TelemetryCase.MESSAGE_RUN, + TelemetryCase.NODE_EXECUTION, + TelemetryCase.DRAFT_NODE_EXECUTION, + TelemetryCase.PROMPT_GENERATION, + TelemetryCase.TOOL_EXECUTION, + TelemetryCase.MODERATION_CHECK, + TelemetryCase.SUGGESTED_QUESTION, + TelemetryCase.DATASET_RETRIEVAL, + TelemetryCase.GENERATE_NAME, + } + metric_log_cases = { + TelemetryCase.APP_CREATED, + TelemetryCase.APP_UPDATED, + TelemetryCase.APP_DELETED, + TelemetryCase.FEEDBACK_CREATED, + } + + all_cases = trace_cases | metric_log_cases + assert len(all_cases) == 14 + assert all_cases == set(TelemetryCase) + + for case in trace_cases: + assert CASE_ROUTING[case].signal_type == SignalType.TRACE + + for case in metric_log_cases: + assert CASE_ROUTING[case].signal_type == SignalType.METRIC_LOG diff --git a/api/tests/unit_tests/enterprise/telemetry/test_event_handlers.py b/api/tests/unit_tests/enterprise/telemetry/test_event_handlers.py new file mode 100644 index 0000000000..ad15c9f096 --- /dev/null +++ b/api/tests/unit_tests/enterprise/telemetry/test_event_handlers.py @@ -0,0 +1,121 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from enterprise.telemetry import event_handlers +from enterprise.telemetry.contracts import TelemetryCase + + +@pytest.fixture +def mock_gateway_emit(): + with patch("core.telemetry.gateway.emit") as mock: + yield mock + + +def test_handle_app_created_calls_task(mock_gateway_emit): + sender = MagicMock() + sender.id = "app-123" + sender.tenant_id = "tenant-456" + sender.mode = "chat" + + event_handlers._handle_app_created(sender) + + mock_gateway_emit.assert_called_once_with( + case=TelemetryCase.APP_CREATED, + context={"tenant_id": "tenant-456"}, + payload={"app_id": "app-123", "mode": "chat"}, + ) + + +def test_handle_app_created_no_exporter(mock_gateway_emit): + """Gateway handles exporter availability internally; handler always calls gateway.""" + sender = MagicMock() + sender.id = "app-123" + sender.tenant_id = "tenant-456" + + event_handlers._handle_app_created(sender) + + mock_gateway_emit.assert_called_once() + + +def test_handle_app_updated_calls_task(mock_gateway_emit): + sender = MagicMock() + sender.id = "app-123" + sender.tenant_id = "tenant-456" + + event_handlers._handle_app_updated(sender) + + mock_gateway_emit.assert_called_once_with( + case=TelemetryCase.APP_UPDATED, + context={"tenant_id": "tenant-456"}, + payload={"app_id": "app-123"}, + ) + + +def test_handle_app_deleted_calls_task(mock_gateway_emit): + sender = MagicMock() + sender.id = "app-123" + sender.tenant_id = "tenant-456" + + event_handlers._handle_app_deleted(sender) + + mock_gateway_emit.assert_called_once_with( + case=TelemetryCase.APP_DELETED, + context={"tenant_id": "tenant-456"}, + payload={"app_id": "app-123"}, + ) + + +def test_handle_feedback_created_calls_task(mock_gateway_emit): + sender = MagicMock() + sender.message_id = "msg-123" + sender.app_id = "app-456" + sender.conversation_id = "conv-789" + sender.from_end_user_id = "user-001" + sender.from_account_id = None + sender.rating = "like" + sender.from_source = "api" + sender.content = "Great response!" + + event_handlers._handle_feedback_created(sender, tenant_id="tenant-456") + + mock_gateway_emit.assert_called_once_with( + case=TelemetryCase.FEEDBACK_CREATED, + context={"tenant_id": "tenant-456"}, + payload={ + "message_id": "msg-123", + "app_id": "app-456", + "conversation_id": "conv-789", + "from_end_user_id": "user-001", + "from_account_id": None, + "rating": "like", + "from_source": "api", + "content": "Great response!", + }, + ) + + +def test_handle_feedback_created_no_exporter(mock_gateway_emit): + """Gateway handles exporter availability internally; handler always calls gateway.""" + sender = MagicMock() + sender.message_id = "msg-123" + + event_handlers._handle_feedback_created(sender, tenant_id="tenant-456") + + mock_gateway_emit.assert_called_once() + + +def test_handlers_create_valid_envelopes(mock_gateway_emit): + """Verify handlers pass correct TelemetryCase and payload structure.""" + sender = MagicMock() + sender.id = "app-123" + sender.tenant_id = "tenant-456" + sender.mode = "chat" + + event_handlers._handle_app_created(sender) + + call_kwargs = mock_gateway_emit.call_args[1] + assert call_kwargs["case"] == TelemetryCase.APP_CREATED + assert call_kwargs["context"]["tenant_id"] == "tenant-456" + assert call_kwargs["payload"]["app_id"] == "app-123" + assert call_kwargs["payload"]["mode"] == "chat" diff --git a/api/tests/unit_tests/enterprise/telemetry/test_exporter.py b/api/tests/unit_tests/enterprise/telemetry/test_exporter.py new file mode 100644 index 0000000000..2c367b4118 --- /dev/null +++ b/api/tests/unit_tests/enterprise/telemetry/test_exporter.py @@ -0,0 +1,263 @@ +"""Unit tests for EnterpriseExporter and _ExporterFactory.""" + +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +from configs.enterprise import EnterpriseTelemetryConfig +from enterprise.telemetry.exporter import EnterpriseExporter + + +def test_config_api_key_default_empty(): + """Test that ENTERPRISE_OTLP_API_KEY defaults to empty string.""" + config = EnterpriseTelemetryConfig() + assert config.ENTERPRISE_OTLP_API_KEY == "" + + +@patch("enterprise.telemetry.exporter.GRPCSpanExporter") +@patch("enterprise.telemetry.exporter.GRPCMetricExporter") +def test_api_key_only_injects_bearer_header(mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock) -> None: + """Test that API key alone injects Bearer authorization header.""" + mock_config = SimpleNamespace( + ENTERPRISE_OTLP_ENDPOINT="https://collector.example.com", + ENTERPRISE_OTLP_HEADERS="", + ENTERPRISE_OTLP_PROTOCOL="grpc", + ENTERPRISE_SERVICE_NAME="dify", + ENTERPRISE_OTEL_SAMPLING_RATE=1.0, + ENTERPRISE_INCLUDE_CONTENT=True, + ENTERPRISE_OTLP_API_KEY="test-secret-key", + ) + + EnterpriseExporter(mock_config) + + # Verify span exporter was called with Bearer header + assert mock_span_exporter.call_args is not None + headers = mock_span_exporter.call_args.kwargs.get("headers") + assert headers is not None + assert ("authorization", "Bearer test-secret-key") in headers + + +@patch("enterprise.telemetry.exporter.GRPCSpanExporter") +@patch("enterprise.telemetry.exporter.GRPCMetricExporter") +def test_empty_api_key_no_auth_header(mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock) -> None: + """Test that empty API key does not inject authorization header.""" + mock_config = SimpleNamespace( + ENTERPRISE_OTLP_ENDPOINT="https://collector.example.com", + ENTERPRISE_OTLP_HEADERS="", + ENTERPRISE_OTLP_PROTOCOL="grpc", + ENTERPRISE_SERVICE_NAME="dify", + ENTERPRISE_OTEL_SAMPLING_RATE=1.0, + ENTERPRISE_INCLUDE_CONTENT=True, + ENTERPRISE_OTLP_API_KEY="", + ) + + EnterpriseExporter(mock_config) + + # Verify span exporter was called without authorization header + assert mock_span_exporter.call_args is not None + headers = mock_span_exporter.call_args.kwargs.get("headers") + # Headers should be None or not contain authorization + if headers is not None: + assert not any(key == "authorization" for key, _ in headers) + + +@patch("enterprise.telemetry.exporter.GRPCSpanExporter") +@patch("enterprise.telemetry.exporter.GRPCMetricExporter") +def test_api_key_and_custom_headers_merge(mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock) -> None: + """Test that API key and custom headers are merged correctly.""" + mock_config = SimpleNamespace( + ENTERPRISE_OTLP_ENDPOINT="https://collector.example.com", + ENTERPRISE_OTLP_HEADERS="x-custom=foo", + ENTERPRISE_OTLP_PROTOCOL="grpc", + ENTERPRISE_SERVICE_NAME="dify", + ENTERPRISE_OTEL_SAMPLING_RATE=1.0, + ENTERPRISE_INCLUDE_CONTENT=True, + ENTERPRISE_OTLP_API_KEY="test-key", + ) + + EnterpriseExporter(mock_config) + + # Verify both headers are present + assert mock_span_exporter.call_args is not None + headers = mock_span_exporter.call_args.kwargs.get("headers") + assert headers is not None + assert ("authorization", "Bearer test-key") in headers + assert ("x-custom", "foo") in headers + + +@patch("enterprise.telemetry.exporter.logger") +@patch("enterprise.telemetry.exporter.GRPCSpanExporter") +@patch("enterprise.telemetry.exporter.GRPCMetricExporter") +def test_api_key_overrides_conflicting_header( + mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock, mock_logger: MagicMock +) -> None: + """Test that API key overrides conflicting authorization header and logs warning.""" + mock_config = SimpleNamespace( + ENTERPRISE_OTLP_ENDPOINT="https://collector.example.com", + ENTERPRISE_OTLP_HEADERS="authorization=Basic old", + ENTERPRISE_OTLP_PROTOCOL="grpc", + ENTERPRISE_SERVICE_NAME="dify", + ENTERPRISE_OTEL_SAMPLING_RATE=1.0, + ENTERPRISE_INCLUDE_CONTENT=True, + ENTERPRISE_OTLP_API_KEY="test-key", + ) + + EnterpriseExporter(mock_config) + + # Verify Bearer header takes precedence + assert mock_span_exporter.call_args is not None + headers = mock_span_exporter.call_args.kwargs.get("headers") + assert headers is not None + assert ("authorization", "Bearer test-key") in headers + # Verify old authorization header is not present + assert ("authorization", "Basic old") not in headers + + # Verify warning was logged + mock_logger.warning.assert_called_once() + assert mock_logger.warning.call_args is not None + warning_message = mock_logger.warning.call_args[0][0] + assert "ENTERPRISE_OTLP_API_KEY is set" in warning_message + assert "authorization" in warning_message + + +@patch("enterprise.telemetry.exporter.GRPCSpanExporter") +@patch("enterprise.telemetry.exporter.GRPCMetricExporter") +def test_https_endpoint_uses_secure_grpc(mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock) -> None: + """Test that https:// endpoint enables TLS (insecure=False) for gRPC.""" + mock_config = SimpleNamespace( + ENTERPRISE_OTLP_ENDPOINT="https://collector.example.com", + ENTERPRISE_OTLP_HEADERS="", + ENTERPRISE_OTLP_PROTOCOL="grpc", + ENTERPRISE_SERVICE_NAME="dify", + ENTERPRISE_OTEL_SAMPLING_RATE=1.0, + ENTERPRISE_INCLUDE_CONTENT=True, + ENTERPRISE_OTLP_API_KEY="test-key", + ) + + EnterpriseExporter(mock_config) + + # Verify insecure=False for both exporters (https:// scheme) + assert mock_span_exporter.call_args is not None + assert mock_span_exporter.call_args.kwargs["insecure"] is False + + assert mock_metric_exporter.call_args is not None + assert mock_metric_exporter.call_args.kwargs["insecure"] is False + + +@patch("enterprise.telemetry.exporter.GRPCSpanExporter") +@patch("enterprise.telemetry.exporter.GRPCMetricExporter") +def test_http_endpoint_uses_insecure_grpc(mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock) -> None: + """Test that http:// endpoint uses insecure gRPC (insecure=True).""" + mock_config = SimpleNamespace( + ENTERPRISE_OTLP_ENDPOINT="http://collector.example.com", + ENTERPRISE_OTLP_HEADERS="", + ENTERPRISE_OTLP_PROTOCOL="grpc", + ENTERPRISE_SERVICE_NAME="dify", + ENTERPRISE_OTEL_SAMPLING_RATE=1.0, + ENTERPRISE_INCLUDE_CONTENT=True, + ENTERPRISE_OTLP_API_KEY="", + ) + + EnterpriseExporter(mock_config) + + # Verify insecure=True for both exporters (http:// scheme) + assert mock_span_exporter.call_args is not None + assert mock_span_exporter.call_args.kwargs["insecure"] is True + + assert mock_metric_exporter.call_args is not None + assert mock_metric_exporter.call_args.kwargs["insecure"] is True + + +@patch("enterprise.telemetry.exporter.HTTPSpanExporter") +@patch("enterprise.telemetry.exporter.HTTPMetricExporter") +def test_insecure_not_passed_to_http_exporters(mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock) -> None: + """Test that insecure parameter is not passed to HTTP exporters.""" + mock_config = SimpleNamespace( + ENTERPRISE_OTLP_ENDPOINT="http://collector.example.com", + ENTERPRISE_OTLP_HEADERS="", + ENTERPRISE_OTLP_PROTOCOL="http", + ENTERPRISE_SERVICE_NAME="dify", + ENTERPRISE_OTEL_SAMPLING_RATE=1.0, + ENTERPRISE_INCLUDE_CONTENT=True, + ENTERPRISE_OTLP_API_KEY="test-key", + ) + + EnterpriseExporter(mock_config) + + # Verify insecure kwarg is NOT in HTTP exporter calls + assert mock_span_exporter.call_args is not None + assert "insecure" not in mock_span_exporter.call_args.kwargs + + assert mock_metric_exporter.call_args is not None + assert "insecure" not in mock_metric_exporter.call_args.kwargs + + +@patch("enterprise.telemetry.exporter.GRPCSpanExporter") +@patch("enterprise.telemetry.exporter.GRPCMetricExporter") +def test_api_key_with_special_chars_preserved(mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock) -> None: + """Test that API key with special characters is preserved without mangling.""" + special_key = "abc+def/ghi=jkl==" + mock_config = SimpleNamespace( + ENTERPRISE_OTLP_ENDPOINT="https://collector.example.com", + ENTERPRISE_OTLP_HEADERS="", + ENTERPRISE_OTLP_PROTOCOL="grpc", + ENTERPRISE_SERVICE_NAME="dify", + ENTERPRISE_OTEL_SAMPLING_RATE=1.0, + ENTERPRISE_INCLUDE_CONTENT=True, + ENTERPRISE_OTLP_API_KEY=special_key, + ) + + EnterpriseExporter(mock_config) + + # Verify special characters are preserved in Bearer header + assert mock_span_exporter.call_args is not None + headers = mock_span_exporter.call_args.kwargs.get("headers") + assert headers is not None + assert ("authorization", f"Bearer {special_key}") in headers + + +@patch("enterprise.telemetry.exporter.GRPCSpanExporter") +@patch("enterprise.telemetry.exporter.GRPCMetricExporter") +def test_no_scheme_localhost_uses_insecure(mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock) -> None: + """Test that endpoint without scheme defaults to insecure for localhost.""" + mock_config = SimpleNamespace( + ENTERPRISE_OTLP_ENDPOINT="localhost:4317", + ENTERPRISE_OTLP_HEADERS="", + ENTERPRISE_OTLP_PROTOCOL="grpc", + ENTERPRISE_SERVICE_NAME="dify", + ENTERPRISE_OTEL_SAMPLING_RATE=1.0, + ENTERPRISE_INCLUDE_CONTENT=True, + ENTERPRISE_OTLP_API_KEY="", + ) + + EnterpriseExporter(mock_config) + + # Verify insecure=True for localhost without scheme + assert mock_span_exporter.call_args is not None + assert mock_span_exporter.call_args.kwargs["insecure"] is True + + assert mock_metric_exporter.call_args is not None + assert mock_metric_exporter.call_args.kwargs["insecure"] is True + + +@patch("enterprise.telemetry.exporter.GRPCSpanExporter") +@patch("enterprise.telemetry.exporter.GRPCMetricExporter") +def test_no_scheme_production_uses_insecure(mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock) -> None: + """Test that endpoint without scheme defaults to insecure (not https://).""" + mock_config = SimpleNamespace( + ENTERPRISE_OTLP_ENDPOINT="collector.example.com:4317", + ENTERPRISE_OTLP_HEADERS="", + ENTERPRISE_OTLP_PROTOCOL="grpc", + ENTERPRISE_SERVICE_NAME="dify", + ENTERPRISE_OTEL_SAMPLING_RATE=1.0, + ENTERPRISE_INCLUDE_CONTENT=True, + ENTERPRISE_OTLP_API_KEY="", + ) + + EnterpriseExporter(mock_config) + + # Verify insecure=True for any endpoint without https:// scheme + assert mock_span_exporter.call_args is not None + assert mock_span_exporter.call_args.kwargs["insecure"] is True + + assert mock_metric_exporter.call_args is not None + assert mock_metric_exporter.call_args.kwargs["insecure"] is True diff --git a/api/tests/unit_tests/enterprise/telemetry/test_gateway.py b/api/tests/unit_tests/enterprise/telemetry/test_gateway.py new file mode 100644 index 0000000000..d979dc7336 --- /dev/null +++ b/api/tests/unit_tests/enterprise/telemetry/test_gateway.py @@ -0,0 +1,272 @@ +from __future__ import annotations + +import sys +from unittest.mock import MagicMock, patch + +import pytest + +from core.ops.entities.trace_entity import TraceTaskName +from core.telemetry.gateway import ( + CASE_ROUTING, + CASE_TO_TRACE_TASK, + PAYLOAD_SIZE_THRESHOLD_BYTES, + emit, +) +from enterprise.telemetry.contracts import SignalType, TelemetryCase, TelemetryEnvelope + + +class TestCaseRoutingTable: + def test_all_cases_have_routing(self) -> None: + for case in TelemetryCase: + assert case in CASE_ROUTING, f"Missing routing for {case}" + + def test_trace_cases(self) -> None: + trace_cases = [ + TelemetryCase.WORKFLOW_RUN, + TelemetryCase.MESSAGE_RUN, + TelemetryCase.NODE_EXECUTION, + TelemetryCase.DRAFT_NODE_EXECUTION, + TelemetryCase.PROMPT_GENERATION, + ] + for case in trace_cases: + assert CASE_ROUTING[case].signal_type is SignalType.TRACE, f"{case} should be trace" + + def test_metric_log_cases(self) -> None: + metric_log_cases = [ + TelemetryCase.APP_CREATED, + TelemetryCase.APP_UPDATED, + TelemetryCase.APP_DELETED, + TelemetryCase.FEEDBACK_CREATED, + ] + for case in metric_log_cases: + assert CASE_ROUTING[case].signal_type is SignalType.METRIC_LOG, f"{case} should be metric_log" + + def test_ce_eligible_cases(self) -> None: + ce_eligible_cases = [ + TelemetryCase.WORKFLOW_RUN, + TelemetryCase.MESSAGE_RUN, + TelemetryCase.TOOL_EXECUTION, + TelemetryCase.MODERATION_CHECK, + TelemetryCase.SUGGESTED_QUESTION, + TelemetryCase.DATASET_RETRIEVAL, + TelemetryCase.GENERATE_NAME, + ] + for case in ce_eligible_cases: + assert CASE_ROUTING[case].ce_eligible is True, f"{case} should be CE eligible" + + def test_enterprise_only_cases(self) -> None: + enterprise_only_cases = [ + TelemetryCase.NODE_EXECUTION, + TelemetryCase.DRAFT_NODE_EXECUTION, + TelemetryCase.PROMPT_GENERATION, + ] + for case in enterprise_only_cases: + assert CASE_ROUTING[case].ce_eligible is False, f"{case} should be enterprise-only" + + def test_trace_cases_have_task_name_mapping(self) -> None: + trace_cases = [c for c in TelemetryCase if CASE_ROUTING[c].signal_type is SignalType.TRACE] + for case in trace_cases: + assert case in CASE_TO_TRACE_TASK, f"Missing TraceTaskName mapping for {case}" + + +@pytest.fixture +def mock_ops_trace_manager(): + mock_module = MagicMock() + mock_trace_task_class = MagicMock() + mock_trace_task_class.return_value = MagicMock() + mock_module.TraceTask = mock_trace_task_class + mock_module.TraceQueueManager = MagicMock() + + mock_trace_entity = MagicMock() + mock_trace_task_name = MagicMock() + mock_trace_task_name.return_value = "workflow" + mock_trace_entity.TraceTaskName = mock_trace_task_name + + with ( + patch.dict(sys.modules, {"core.ops.ops_trace_manager": mock_module}), + patch.dict(sys.modules, {"core.ops.entities.trace_entity": mock_trace_entity}), + ): + yield mock_module, mock_trace_entity + + +class TestGatewayTraceRouting: + @pytest.fixture + def mock_trace_manager(self) -> MagicMock: + return MagicMock() + + @patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True) + def test_trace_case_routes_to_trace_manager( + self, + _mock_ee_enabled: MagicMock, + mock_trace_manager: MagicMock, + mock_ops_trace_manager: tuple[MagicMock, MagicMock], + ) -> None: + context = {"app_id": "app-123", "user_id": "user-456", "tenant_id": "tenant-789"} + payload = {"workflow_run_id": "run-abc"} + + emit(TelemetryCase.WORKFLOW_RUN, context, payload, mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_called_once() + + @patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False) + def test_ce_eligible_trace_enqueued_when_ee_disabled( + self, + _mock_ee_enabled: MagicMock, + mock_trace_manager: MagicMock, + mock_ops_trace_manager: tuple[MagicMock, MagicMock], + ) -> None: + context = {"app_id": "app-123", "user_id": "user-456"} + payload = {"workflow_run_id": "run-abc"} + + emit(TelemetryCase.WORKFLOW_RUN, context, payload, mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_called_once() + + @patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False) + def test_enterprise_only_trace_dropped_when_ee_disabled( + self, + _mock_ee_enabled: MagicMock, + mock_trace_manager: MagicMock, + mock_ops_trace_manager: tuple[MagicMock, MagicMock], + ) -> None: + context = {"app_id": "app-123", "user_id": "user-456"} + payload = {"node_id": "node-abc"} + + emit(TelemetryCase.NODE_EXECUTION, context, payload, mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_not_called() + + @patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True) + def test_enterprise_only_trace_enqueued_when_ee_enabled( + self, + _mock_ee_enabled: MagicMock, + mock_trace_manager: MagicMock, + mock_ops_trace_manager: tuple[MagicMock, MagicMock], + ) -> None: + context = {"app_id": "app-123", "user_id": "user-456"} + payload = {"node_id": "node-abc"} + + emit(TelemetryCase.NODE_EXECUTION, context, payload, mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_called_once() + + +class TestGatewayMetricLogRouting: + @patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True) + @patch("tasks.enterprise_telemetry_task.process_enterprise_telemetry.delay") + def test_metric_case_routes_to_celery_task( + self, + mock_delay: MagicMock, + _mock_ee_enabled: MagicMock, + ) -> None: + context = {"tenant_id": "tenant-123"} + payload = {"app_id": "app-abc", "name": "My App"} + + emit(TelemetryCase.APP_CREATED, context, payload) + + mock_delay.assert_called_once() + envelope_json = mock_delay.call_args[0][0] + envelope = TelemetryEnvelope.model_validate_json(envelope_json) + assert envelope.case == TelemetryCase.APP_CREATED + assert envelope.tenant_id == "tenant-123" + assert envelope.payload["app_id"] == "app-abc" + + @patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True) + @patch("tasks.enterprise_telemetry_task.process_enterprise_telemetry.delay") + def test_envelope_has_unique_event_id( + self, + mock_delay: MagicMock, + _mock_ee_enabled: MagicMock, + ) -> None: + context = {"tenant_id": "tenant-123"} + payload = {"app_id": "app-abc"} + + emit(TelemetryCase.APP_CREATED, context, payload) + emit(TelemetryCase.APP_CREATED, context, payload) + + assert mock_delay.call_count == 2 + envelope1 = TelemetryEnvelope.model_validate_json(mock_delay.call_args_list[0][0][0]) + envelope2 = TelemetryEnvelope.model_validate_json(mock_delay.call_args_list[1][0][0]) + assert envelope1.event_id != envelope2.event_id + + +class TestGatewayPayloadSizing: + @patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True) + @patch("tasks.enterprise_telemetry_task.process_enterprise_telemetry.delay") + def test_small_payload_inlined( + self, + mock_delay: MagicMock, + _mock_ee_enabled: MagicMock, + ) -> None: + context = {"tenant_id": "tenant-123"} + payload = {"key": "small_value"} + + emit(TelemetryCase.APP_CREATED, context, payload) + + envelope_json = mock_delay.call_args[0][0] + envelope = TelemetryEnvelope.model_validate_json(envelope_json) + assert envelope.payload == payload + assert envelope.metadata is None + + @patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True) + @patch("core.telemetry.gateway.storage") + @patch("tasks.enterprise_telemetry_task.process_enterprise_telemetry.delay") + def test_large_payload_stored( + self, + mock_delay: MagicMock, + mock_storage: MagicMock, + _mock_ee_enabled: MagicMock, + ) -> None: + context = {"tenant_id": "tenant-123"} + large_value = "x" * (PAYLOAD_SIZE_THRESHOLD_BYTES + 1000) + payload = {"key": large_value} + + emit(TelemetryCase.APP_CREATED, context, payload) + + mock_storage.save.assert_called_once() + storage_key = mock_storage.save.call_args[0][0] + assert storage_key.startswith("telemetry/tenant-123/") + + envelope_json = mock_delay.call_args[0][0] + envelope = TelemetryEnvelope.model_validate_json(envelope_json) + assert envelope.payload == {} + assert envelope.metadata is not None + assert envelope.metadata["payload_ref"] == storage_key + + @patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True) + @patch("core.telemetry.gateway.storage") + @patch("tasks.enterprise_telemetry_task.process_enterprise_telemetry.delay") + def test_large_payload_fallback_on_storage_error( + self, + mock_delay: MagicMock, + mock_storage: MagicMock, + _mock_ee_enabled: MagicMock, + ) -> None: + mock_storage.save.side_effect = Exception("Storage failure") + context = {"tenant_id": "tenant-123"} + large_value = "x" * (PAYLOAD_SIZE_THRESHOLD_BYTES + 1000) + payload = {"key": large_value} + + emit(TelemetryCase.APP_CREATED, context, payload) + + envelope_json = mock_delay.call_args[0][0] + envelope = TelemetryEnvelope.model_validate_json(envelope_json) + assert envelope.payload == payload + assert envelope.metadata is None + + +class TestTraceTaskNameMapping: + def test_workflow_run_mapping(self) -> None: + assert CASE_TO_TRACE_TASK[TelemetryCase.WORKFLOW_RUN] is TraceTaskName.WORKFLOW_TRACE + + def test_message_run_mapping(self) -> None: + assert CASE_TO_TRACE_TASK[TelemetryCase.MESSAGE_RUN] is TraceTaskName.MESSAGE_TRACE + + def test_node_execution_mapping(self) -> None: + assert CASE_TO_TRACE_TASK[TelemetryCase.NODE_EXECUTION] is TraceTaskName.NODE_EXECUTION_TRACE + + def test_draft_node_execution_mapping(self) -> None: + assert CASE_TO_TRACE_TASK[TelemetryCase.DRAFT_NODE_EXECUTION] is TraceTaskName.DRAFT_NODE_EXECUTION_TRACE + + def test_prompt_generation_mapping(self) -> None: + assert CASE_TO_TRACE_TASK[TelemetryCase.PROMPT_GENERATION] is TraceTaskName.PROMPT_GENERATION_TRACE diff --git a/api/tests/unit_tests/enterprise/telemetry/test_metric_handler.py b/api/tests/unit_tests/enterprise/telemetry/test_metric_handler.py new file mode 100644 index 0000000000..19822fd69f --- /dev/null +++ b/api/tests/unit_tests/enterprise/telemetry/test_metric_handler.py @@ -0,0 +1,507 @@ +"""Unit tests for EnterpriseMetricHandler.""" + +import json +from unittest.mock import MagicMock, patch + +import pytest + +from enterprise.telemetry.contracts import TelemetryCase, TelemetryEnvelope +from enterprise.telemetry.metric_handler import EnterpriseMetricHandler + + +@pytest.fixture +def mock_redis(): + with patch("enterprise.telemetry.metric_handler.redis_client") as mock: + yield mock + + +@pytest.fixture +def sample_envelope(): + return TelemetryEnvelope( + case=TelemetryCase.APP_CREATED, + tenant_id="test-tenant", + event_id="test-event-123", + payload={"app_id": "app-123", "name": "Test App"}, + ) + + +def test_dispatch_app_created(sample_envelope, mock_redis): + mock_redis.set.return_value = True + + handler = EnterpriseMetricHandler() + with patch.object(handler, "_on_app_created") as mock_handler: + handler.handle(sample_envelope) + mock_handler.assert_called_once_with(sample_envelope) + + +def test_dispatch_app_updated(mock_redis): + mock_redis.set.return_value = True + envelope = TelemetryEnvelope( + case=TelemetryCase.APP_UPDATED, + tenant_id="test-tenant", + event_id="test-event-456", + payload={}, + ) + + handler = EnterpriseMetricHandler() + with patch.object(handler, "_on_app_updated") as mock_handler: + handler.handle(envelope) + mock_handler.assert_called_once_with(envelope) + + +def test_dispatch_app_deleted(mock_redis): + mock_redis.set.return_value = True + envelope = TelemetryEnvelope( + case=TelemetryCase.APP_DELETED, + tenant_id="test-tenant", + event_id="test-event-789", + payload={}, + ) + + handler = EnterpriseMetricHandler() + with patch.object(handler, "_on_app_deleted") as mock_handler: + handler.handle(envelope) + mock_handler.assert_called_once_with(envelope) + + +def test_dispatch_feedback_created(mock_redis): + mock_redis.set.return_value = True + envelope = TelemetryEnvelope( + case=TelemetryCase.FEEDBACK_CREATED, + tenant_id="test-tenant", + event_id="test-event-abc", + payload={}, + ) + + handler = EnterpriseMetricHandler() + with patch.object(handler, "_on_feedback_created") as mock_handler: + handler.handle(envelope) + mock_handler.assert_called_once_with(envelope) + + +def test_dispatch_message_run(mock_redis): + mock_redis.set.return_value = True + envelope = TelemetryEnvelope( + case=TelemetryCase.MESSAGE_RUN, + tenant_id="test-tenant", + event_id="test-event-msg", + payload={}, + ) + + handler = EnterpriseMetricHandler() + with patch.object(handler, "_on_message_run") as mock_handler: + handler.handle(envelope) + mock_handler.assert_called_once_with(envelope) + + +def test_dispatch_tool_execution(mock_redis): + mock_redis.set.return_value = True + envelope = TelemetryEnvelope( + case=TelemetryCase.TOOL_EXECUTION, + tenant_id="test-tenant", + event_id="test-event-tool", + payload={}, + ) + + handler = EnterpriseMetricHandler() + with patch.object(handler, "_on_tool_execution") as mock_handler: + handler.handle(envelope) + mock_handler.assert_called_once_with(envelope) + + +def test_dispatch_moderation_check(mock_redis): + mock_redis.set.return_value = True + envelope = TelemetryEnvelope( + case=TelemetryCase.MODERATION_CHECK, + tenant_id="test-tenant", + event_id="test-event-mod", + payload={}, + ) + + handler = EnterpriseMetricHandler() + with patch.object(handler, "_on_moderation_check") as mock_handler: + handler.handle(envelope) + mock_handler.assert_called_once_with(envelope) + + +def test_dispatch_suggested_question(mock_redis): + mock_redis.set.return_value = True + envelope = TelemetryEnvelope( + case=TelemetryCase.SUGGESTED_QUESTION, + tenant_id="test-tenant", + event_id="test-event-sq", + payload={}, + ) + + handler = EnterpriseMetricHandler() + with patch.object(handler, "_on_suggested_question") as mock_handler: + handler.handle(envelope) + mock_handler.assert_called_once_with(envelope) + + +def test_dispatch_dataset_retrieval(mock_redis): + mock_redis.set.return_value = True + envelope = TelemetryEnvelope( + case=TelemetryCase.DATASET_RETRIEVAL, + tenant_id="test-tenant", + event_id="test-event-ds", + payload={}, + ) + + handler = EnterpriseMetricHandler() + with patch.object(handler, "_on_dataset_retrieval") as mock_handler: + handler.handle(envelope) + mock_handler.assert_called_once_with(envelope) + + +def test_dispatch_generate_name(mock_redis): + mock_redis.set.return_value = True + envelope = TelemetryEnvelope( + case=TelemetryCase.GENERATE_NAME, + tenant_id="test-tenant", + event_id="test-event-gn", + payload={}, + ) + + handler = EnterpriseMetricHandler() + with patch.object(handler, "_on_generate_name") as mock_handler: + handler.handle(envelope) + mock_handler.assert_called_once_with(envelope) + + +def test_dispatch_prompt_generation(mock_redis): + mock_redis.set.return_value = True + envelope = TelemetryEnvelope( + case=TelemetryCase.PROMPT_GENERATION, + tenant_id="test-tenant", + event_id="test-event-pg", + payload={}, + ) + + handler = EnterpriseMetricHandler() + with patch.object(handler, "_on_prompt_generation") as mock_handler: + handler.handle(envelope) + mock_handler.assert_called_once_with(envelope) + + +def test_all_known_cases_have_handlers(mock_redis): + mock_redis.set.return_value = True + handler = EnterpriseMetricHandler() + + for case in TelemetryCase: + envelope = TelemetryEnvelope( + case=case, + tenant_id="test-tenant", + event_id=f"test-{case.value}", + payload={}, + ) + handler.handle(envelope) + + +def test_idempotency_duplicate(sample_envelope, mock_redis): + mock_redis.set.return_value = None + + handler = EnterpriseMetricHandler() + with patch.object(handler, "_on_app_created") as mock_handler: + handler.handle(sample_envelope) + mock_handler.assert_not_called() + + +def test_idempotency_first_seen(sample_envelope, mock_redis): + mock_redis.set.return_value = True + + handler = EnterpriseMetricHandler() + is_dup = handler._is_duplicate(sample_envelope) + + assert is_dup is False + mock_redis.set.assert_called_once_with( + "telemetry:dedup:test-tenant:test-event-123", + b"1", + nx=True, + ex=3600, + ) + + +def test_idempotency_redis_failure_fails_open(sample_envelope, mock_redis, caplog): + mock_redis.set.side_effect = Exception("Redis unavailable") + + handler = EnterpriseMetricHandler() + is_dup = handler._is_duplicate(sample_envelope) + + assert is_dup is False + assert "Redis unavailable for deduplication check" in caplog.text + + +def test_rehydration_uses_payload(sample_envelope): + handler = EnterpriseMetricHandler() + payload = handler._rehydrate(sample_envelope) + + assert payload == {"app_id": "app-123", "name": "Test App"} + + +def test_rehydration_from_storage(): + """Verify _rehydrate loads payload from object storage via payload_ref.""" + stored_data = {"app_id": "app-stored", "mode": "workflow"} + envelope = TelemetryEnvelope( + case=TelemetryCase.APP_CREATED, + tenant_id="test-tenant", + event_id="test-event-fb", + payload={}, + metadata={"payload_ref": "telemetry/test-tenant/test-event-fb.json"}, + ) + + handler = EnterpriseMetricHandler() + with patch("enterprise.telemetry.metric_handler.storage") as mock_storage: + mock_storage.load.return_value = json.dumps(stored_data).encode("utf-8") + payload = handler._rehydrate(envelope) + + assert payload == stored_data + mock_storage.load.assert_called_once_with("telemetry/test-tenant/test-event-fb.json") + + +def test_rehydration_storage_failure_emits_degraded_event(): + """Verify _rehydrate emits degraded event when storage load fails.""" + envelope = TelemetryEnvelope( + case=TelemetryCase.APP_CREATED, + tenant_id="test-tenant", + event_id="test-event-fail", + payload={}, + metadata={"payload_ref": "telemetry/test-tenant/test-event-fail.json"}, + ) + + handler = EnterpriseMetricHandler() + with ( + patch("enterprise.telemetry.metric_handler.storage") as mock_storage, + patch("enterprise.telemetry.telemetry_log.emit_metric_only_event") as mock_emit, + ): + mock_storage.load.side_effect = Exception("Storage unavailable") + payload = handler._rehydrate(envelope) + + from enterprise.telemetry.entities import EnterpriseTelemetryEvent + + assert payload == {} + mock_emit.assert_called_once() + call_args = mock_emit.call_args + assert call_args[1]["event_name"] == EnterpriseTelemetryEvent.REHYDRATION_FAILED + assert call_args[1]["attributes"]["rehydration_failed"] is True + + +def test_rehydration_emits_degraded_event_on_empty_payload(): + """Verify _rehydrate emits degraded event when payload is empty and no ref exists.""" + envelope = TelemetryEnvelope( + case=TelemetryCase.APP_CREATED, + tenant_id="test-tenant", + event_id="test-event-empty", + payload={}, + ) + + handler = EnterpriseMetricHandler() + with patch("enterprise.telemetry.telemetry_log.emit_metric_only_event") as mock_emit: + payload = handler._rehydrate(envelope) + + from enterprise.telemetry.entities import EnterpriseTelemetryEvent + + assert payload == {} + mock_emit.assert_called_once() + call_args = mock_emit.call_args + assert call_args[1]["event_name"] == EnterpriseTelemetryEvent.REHYDRATION_FAILED + assert call_args[1]["attributes"]["rehydration_failed"] is True + + +def test_on_app_created_emits_correct_event(mock_redis): + mock_redis.set.return_value = True + envelope = TelemetryEnvelope( + case=TelemetryCase.APP_CREATED, + tenant_id="tenant-123", + event_id="event-456", + payload={"app_id": "app-789", "mode": "chat"}, + ) + + handler = EnterpriseMetricHandler() + with ( + patch("extensions.ext_enterprise_telemetry.get_enterprise_exporter") as mock_get_exporter, + patch("enterprise.telemetry.telemetry_log.emit_metric_only_event") as mock_emit, + ): + mock_exporter = MagicMock() + mock_get_exporter.return_value = mock_exporter + + handler._on_app_created(envelope) + + from enterprise.telemetry.entities import EnterpriseTelemetryEvent + + mock_emit.assert_called_once_with( + event_name=EnterpriseTelemetryEvent.APP_CREATED, + attributes={ + "dify.app.id": "app-789", + "dify.tenant_id": "tenant-123", + "dify.event.id": "event-456", + "dify.app.mode": "chat", + }, + tenant_id="tenant-123", + ) + from enterprise.telemetry.entities import EnterpriseTelemetryCounter + + mock_exporter.increment_counter.assert_called_once() + call_args = mock_exporter.increment_counter.call_args + assert call_args[0][0] == EnterpriseTelemetryCounter.APP_CREATED + assert call_args[0][1] == 1 + assert call_args[0][2]["tenant_id"] == "tenant-123" + assert call_args[0][2]["app_id"] == "app-789" + assert call_args[0][2]["mode"] == "chat" + + +def test_on_app_updated_emits_correct_event(mock_redis): + mock_redis.set.return_value = True + envelope = TelemetryEnvelope( + case=TelemetryCase.APP_UPDATED, + tenant_id="tenant-123", + event_id="event-456", + payload={"app_id": "app-789"}, + ) + + handler = EnterpriseMetricHandler() + with ( + patch("extensions.ext_enterprise_telemetry.get_enterprise_exporter") as mock_get_exporter, + patch("enterprise.telemetry.telemetry_log.emit_metric_only_event") as mock_emit, + ): + mock_exporter = MagicMock() + mock_get_exporter.return_value = mock_exporter + + handler._on_app_updated(envelope) + + from enterprise.telemetry.entities import EnterpriseTelemetryEvent + + mock_emit.assert_called_once_with( + event_name=EnterpriseTelemetryEvent.APP_UPDATED, + attributes={ + "dify.app.id": "app-789", + "dify.tenant_id": "tenant-123", + "dify.event.id": "event-456", + }, + tenant_id="tenant-123", + ) + from enterprise.telemetry.entities import EnterpriseTelemetryCounter + + mock_exporter.increment_counter.assert_called_once() + call_args = mock_exporter.increment_counter.call_args + assert call_args[0][0] == EnterpriseTelemetryCounter.APP_UPDATED + assert call_args[0][1] == 1 + assert call_args[0][2]["tenant_id"] == "tenant-123" + assert call_args[0][2]["app_id"] == "app-789" + + +def test_on_app_deleted_emits_correct_event(mock_redis): + mock_redis.set.return_value = True + envelope = TelemetryEnvelope( + case=TelemetryCase.APP_DELETED, + tenant_id="tenant-123", + event_id="event-456", + payload={"app_id": "app-789"}, + ) + + handler = EnterpriseMetricHandler() + with ( + patch("extensions.ext_enterprise_telemetry.get_enterprise_exporter") as mock_get_exporter, + patch("enterprise.telemetry.telemetry_log.emit_metric_only_event") as mock_emit, + ): + mock_exporter = MagicMock() + mock_get_exporter.return_value = mock_exporter + + handler._on_app_deleted(envelope) + + from enterprise.telemetry.entities import EnterpriseTelemetryEvent + + mock_emit.assert_called_once_with( + event_name=EnterpriseTelemetryEvent.APP_DELETED, + attributes={ + "dify.app.id": "app-789", + "dify.tenant_id": "tenant-123", + "dify.event.id": "event-456", + }, + tenant_id="tenant-123", + ) + from enterprise.telemetry.entities import EnterpriseTelemetryCounter + + mock_exporter.increment_counter.assert_called_once() + call_args = mock_exporter.increment_counter.call_args + assert call_args[0][0] == EnterpriseTelemetryCounter.APP_DELETED + assert call_args[0][1] == 1 + assert call_args[0][2]["tenant_id"] == "tenant-123" + assert call_args[0][2]["app_id"] == "app-789" + + +def test_on_feedback_created_emits_correct_event(mock_redis): + mock_redis.set.return_value = True + envelope = TelemetryEnvelope( + case=TelemetryCase.FEEDBACK_CREATED, + tenant_id="tenant-123", + event_id="event-456", + payload={ + "message_id": "msg-001", + "app_id": "app-789", + "conversation_id": "conv-123", + "from_end_user_id": "user-456", + "from_account_id": None, + "rating": "like", + "from_source": "api", + "content": "Great!", + }, + ) + + handler = EnterpriseMetricHandler() + with ( + patch("extensions.ext_enterprise_telemetry.get_enterprise_exporter") as mock_get_exporter, + patch("enterprise.telemetry.telemetry_log.emit_metric_only_event") as mock_emit, + ): + mock_exporter = MagicMock() + mock_exporter.include_content = True + mock_get_exporter.return_value = mock_exporter + + handler._on_feedback_created(envelope) + + mock_emit.assert_called_once() + call_args = mock_emit.call_args + assert call_args[1]["event_name"] == "dify.feedback.created" + assert call_args[1]["attributes"]["dify.message.id"] == "msg-001" + assert call_args[1]["attributes"]["dify.feedback.content"] == "Great!" + assert call_args[1]["tenant_id"] == "tenant-123" + assert call_args[1]["user_id"] == "user-456" + + mock_exporter.increment_counter.assert_called_once() + counter_args = mock_exporter.increment_counter.call_args + assert counter_args[0][2]["app_id"] == "app-789" + assert counter_args[0][2]["rating"] == "like" + + +def test_on_feedback_created_without_content(mock_redis): + mock_redis.set.return_value = True + envelope = TelemetryEnvelope( + case=TelemetryCase.FEEDBACK_CREATED, + tenant_id="tenant-123", + event_id="event-456", + payload={ + "message_id": "msg-001", + "app_id": "app-789", + "conversation_id": "conv-123", + "from_end_user_id": "user-456", + "from_account_id": None, + "rating": "like", + "from_source": "api", + "content": "Great!", + }, + ) + + handler = EnterpriseMetricHandler() + with ( + patch("extensions.ext_enterprise_telemetry.get_enterprise_exporter") as mock_get_exporter, + patch("enterprise.telemetry.telemetry_log.emit_metric_only_event") as mock_emit, + ): + mock_exporter = MagicMock() + mock_exporter.include_content = False + mock_get_exporter.return_value = mock_exporter + + handler._on_feedback_created(envelope) + + mock_emit.assert_called_once() + call_args = mock_emit.call_args + assert "dify.feedback.content" not in call_args[1]["attributes"] diff --git a/api/tests/unit_tests/tasks/test_enterprise_telemetry_task.py b/api/tests/unit_tests/tasks/test_enterprise_telemetry_task.py new file mode 100644 index 0000000000..b48c69a146 --- /dev/null +++ b/api/tests/unit_tests/tasks/test_enterprise_telemetry_task.py @@ -0,0 +1,69 @@ +"""Unit tests for enterprise telemetry Celery task.""" + +import json +from unittest.mock import MagicMock, patch + +import pytest + +from enterprise.telemetry.contracts import TelemetryCase, TelemetryEnvelope +from tasks.enterprise_telemetry_task import process_enterprise_telemetry + + +@pytest.fixture +def sample_envelope_json(): + envelope = TelemetryEnvelope( + case=TelemetryCase.APP_CREATED, + tenant_id="test-tenant", + event_id="test-event-123", + payload={"app_id": "app-123"}, + ) + return envelope.model_dump_json() + + +def test_process_enterprise_telemetry_success(sample_envelope_json): + with patch("tasks.enterprise_telemetry_task.EnterpriseMetricHandler") as mock_handler_class: + mock_handler = MagicMock() + mock_handler_class.return_value = mock_handler + + process_enterprise_telemetry(sample_envelope_json) + + mock_handler.handle.assert_called_once() + call_args = mock_handler.handle.call_args[0][0] + assert isinstance(call_args, TelemetryEnvelope) + assert call_args.case == TelemetryCase.APP_CREATED + assert call_args.tenant_id == "test-tenant" + assert call_args.event_id == "test-event-123" + + +def test_process_enterprise_telemetry_invalid_json(caplog): + invalid_json = "not valid json" + + process_enterprise_telemetry(invalid_json) + + assert "Failed to process enterprise telemetry envelope" in caplog.text + + +def test_process_enterprise_telemetry_handler_exception(sample_envelope_json, caplog): + with patch("tasks.enterprise_telemetry_task.EnterpriseMetricHandler") as mock_handler_class: + mock_handler = MagicMock() + mock_handler.handle.side_effect = Exception("Handler error") + mock_handler_class.return_value = mock_handler + + process_enterprise_telemetry(sample_envelope_json) + + assert "Failed to process enterprise telemetry envelope" in caplog.text + + +def test_process_enterprise_telemetry_validation_error(caplog): + invalid_envelope = json.dumps( + { + "case": "INVALID_CASE", + "tenant_id": "test-tenant", + "event_id": "test-event", + "payload": {}, + } + ) + + process_enterprise_telemetry(invalid_envelope) + + assert "Failed to process enterprise telemetry envelope" in caplog.text