chore: add type to test (#37876)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Asuka Minato 2026-06-25 11:01:09 +08:00 committed by GitHub
parent 31a50a3b20
commit d93989bfc0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 266 additions and 233 deletions

View File

@ -181,29 +181,33 @@ class TestTencentDataTrace:
mock_trace_utils.convert_to_trace_id.return_value = 123
mock_trace_utils.create_link.return_value = "link"
with patch.object(tencent_data_trace, "_get_user_id", return_value="user-1"):
with patch.object(tencent_data_trace, "_process_workflow_nodes") as mock_proc:
with patch.object(tencent_data_trace, "_record_workflow_trace_duration") as mock_dur:
mock_span_builder.build_workflow_spans.return_value = [MagicMock(), MagicMock()]
with (
patch.object(tencent_data_trace, "_get_user_id", return_value="user-1"),
patch.object(tencent_data_trace, "_process_workflow_nodes") as mock_proc,
patch.object(tencent_data_trace, "_record_workflow_trace_duration") as mock_dur,
):
mock_span_builder.build_workflow_spans.return_value = [MagicMock(), MagicMock()]
tencent_data_trace.workflow_trace(trace_info)
tencent_data_trace.workflow_trace(trace_info)
mock_trace_utils.convert_to_trace_id.assert_called_once_with("run-id")
mock_trace_utils.create_link.assert_called_once_with("parent-trace-id")
mock_span_builder.build_workflow_spans.assert_called_once()
assert tencent_data_trace.trace_client.add_span.call_count == 2
mock_proc.assert_called_once_with(trace_info, 123)
mock_dur.assert_called_once_with(trace_info)
mock_trace_utils.convert_to_trace_id.assert_called_once_with("run-id")
mock_trace_utils.create_link.assert_called_once_with("parent-trace-id")
mock_span_builder.build_workflow_spans.assert_called_once()
assert tencent_data_trace.trace_client.add_span.call_count == 2
mock_proc.assert_called_once_with(trace_info, 123)
mock_dur.assert_called_once_with(trace_info)
def test_workflow_trace_exception(self, tencent_data_trace, caplog):
def test_workflow_trace_exception(self, tencent_data_trace, caplog: pytest.LogCaptureFixture):
trace_info = MagicMock(spec=WorkflowTraceInfo)
trace_info.workflow_run_id = "run-id"
with patch(
"dify_trace_tencent.tencent_trace.TencentTraceUtils.convert_to_trace_id", side_effect=Exception("error")
with (
patch(
"dify_trace_tencent.tencent_trace.TencentTraceUtils.convert_to_trace_id", side_effect=Exception("error")
),
caplog.at_level(logging.ERROR),
):
with caplog.at_level(logging.ERROR):
tencent_data_trace.workflow_trace(trace_info)
tencent_data_trace.workflow_trace(trace_info)
assert "[Tencent APM] Failed to process workflow trace" in caplog.text
def test_message_trace(self, tencent_data_trace, mock_trace_utils, mock_span_builder):
@ -214,28 +218,32 @@ class TestTencentDataTrace:
mock_trace_utils.convert_to_trace_id.return_value = 123
mock_trace_utils.create_link.return_value = "link"
with patch.object(tencent_data_trace, "_get_user_id", return_value="user-1"):
with patch.object(tencent_data_trace, "_record_message_llm_metrics") as mock_metrics:
with patch.object(tencent_data_trace, "_record_message_trace_duration") as mock_dur:
mock_span_builder.build_message_span.return_value = MagicMock()
with (
patch.object(tencent_data_trace, "_get_user_id", return_value="user-1"),
patch.object(tencent_data_trace, "_record_message_llm_metrics") as mock_metrics,
patch.object(tencent_data_trace, "_record_message_trace_duration") as mock_dur,
):
mock_span_builder.build_message_span.return_value = MagicMock()
tencent_data_trace.message_trace(trace_info)
tencent_data_trace.message_trace(trace_info)
mock_trace_utils.convert_to_trace_id.assert_called_once_with("msg-id")
mock_trace_utils.create_link.assert_called_once_with("parent-trace-id")
mock_span_builder.build_message_span.assert_called_once()
tencent_data_trace.trace_client.add_span.assert_called_once()
mock_metrics.assert_called_once_with(trace_info)
mock_dur.assert_called_once_with(trace_info)
mock_trace_utils.convert_to_trace_id.assert_called_once_with("msg-id")
mock_trace_utils.create_link.assert_called_once_with("parent-trace-id")
mock_span_builder.build_message_span.assert_called_once()
tencent_data_trace.trace_client.add_span.assert_called_once()
mock_metrics.assert_called_once_with(trace_info)
mock_dur.assert_called_once_with(trace_info)
def test_message_trace_exception(self, tencent_data_trace, caplog):
def test_message_trace_exception(self, tencent_data_trace, caplog: pytest.LogCaptureFixture):
trace_info = MagicMock(spec=MessageTraceInfo)
with patch(
"dify_trace_tencent.tencent_trace.TencentTraceUtils.convert_to_trace_id", side_effect=Exception("error")
with (
patch(
"dify_trace_tencent.tencent_trace.TencentTraceUtils.convert_to_trace_id", side_effect=Exception("error")
),
caplog.at_level(logging.ERROR),
):
with caplog.at_level(logging.ERROR):
tencent_data_trace.message_trace(trace_info)
tencent_data_trace.message_trace(trace_info)
assert "[Tencent APM] Failed to process message trace" in caplog.text
def test_tool_trace(self, tencent_data_trace, mock_trace_utils, mock_span_builder):
@ -259,15 +267,17 @@ class TestTencentDataTrace:
tencent_data_trace.tool_trace(trace_info)
tencent_data_trace.trace_client.add_span.assert_not_called()
def test_tool_trace_exception(self, tencent_data_trace, caplog):
def test_tool_trace_exception(self, tencent_data_trace, caplog: pytest.LogCaptureFixture):
trace_info = MagicMock(spec=ToolTraceInfo)
trace_info.message_id = "msg-id"
with patch(
"dify_trace_tencent.tencent_trace.TencentTraceUtils.convert_to_span_id", side_effect=Exception("error")
with (
patch(
"dify_trace_tencent.tencent_trace.TencentTraceUtils.convert_to_span_id", side_effect=Exception("error")
),
caplog.at_level(logging.ERROR),
):
with caplog.at_level(logging.ERROR):
tencent_data_trace.tool_trace(trace_info)
tencent_data_trace.tool_trace(trace_info)
assert "[Tencent APM] Failed to process tool trace" in caplog.text
def test_dataset_retrieval_trace(self, tencent_data_trace, mock_trace_utils, mock_span_builder):
@ -291,24 +301,28 @@ class TestTencentDataTrace:
tencent_data_trace.dataset_retrieval_trace(trace_info)
tencent_data_trace.trace_client.add_span.assert_not_called()
def test_dataset_retrieval_trace_exception(self, tencent_data_trace, caplog):
def test_dataset_retrieval_trace_exception(self, tencent_data_trace, caplog: pytest.LogCaptureFixture):
trace_info = MagicMock(spec=DatasetRetrievalTraceInfo)
trace_info.message_id = "msg-id"
with patch(
"dify_trace_tencent.tencent_trace.TencentTraceUtils.convert_to_span_id", side_effect=Exception("error")
with (
patch(
"dify_trace_tencent.tencent_trace.TencentTraceUtils.convert_to_span_id", side_effect=Exception("error")
),
caplog.at_level(logging.ERROR),
):
with caplog.at_level(logging.ERROR):
tencent_data_trace.dataset_retrieval_trace(trace_info)
tencent_data_trace.dataset_retrieval_trace(trace_info)
assert "[Tencent APM] Failed to process dataset retrieval trace" in caplog.text
def test_suggested_question_trace(self, tencent_data_trace, caplog):
def test_suggested_question_trace(self, tencent_data_trace, caplog: pytest.LogCaptureFixture):
trace_info = MagicMock(spec=SuggestedQuestionTraceInfo)
with caplog.at_level(logging.INFO):
tencent_data_trace.suggested_question_trace(trace_info)
assert "[Tencent APM] Processing suggested question trace" in caplog.text
def test_suggested_question_trace_exception(self, tencent_data_trace, monkeypatch, caplog):
def test_suggested_question_trace_exception(
self, tencent_data_trace, monkeypatch, caplog: pytest.LogCaptureFixture
):
trace_info = MagicMock(spec=SuggestedQuestionTraceInfo)
target_logger = logging.getLogger("dify_trace_tencent.tencent_trace")
monkeypatch.setattr(target_logger, "info", MagicMock(side_effect=Exception("error")))
@ -328,28 +342,36 @@ class TestTencentDataTrace:
node2.id = "n2"
node2.node_type = BuiltinNodeTypes.TOOL
with patch.object(tencent_data_trace, "_get_workflow_node_executions", return_value=[node1, node2]):
with patch.object(tencent_data_trace, "_build_workflow_node_span", side_effect=["span1", "span2"]):
with patch.object(tencent_data_trace, "_record_llm_metrics") as mock_metrics:
tencent_data_trace._process_workflow_nodes(trace_info, 123)
with (
patch.object(tencent_data_trace, "_get_workflow_node_executions", return_value=[node1, node2]),
patch.object(tencent_data_trace, "_build_workflow_node_span", side_effect=["span1", "span2"]),
patch.object(tencent_data_trace, "_record_llm_metrics") as mock_metrics,
):
tencent_data_trace._process_workflow_nodes(trace_info, 123)
assert tencent_data_trace.trace_client.add_span.call_count == 2
mock_metrics.assert_called_once_with(node1)
assert tencent_data_trace.trace_client.add_span.call_count == 2
mock_metrics.assert_called_once_with(node1)
def test_process_workflow_nodes_node_exception(self, tencent_data_trace, mock_trace_utils, caplog):
def test_process_workflow_nodes_node_exception(
self, tencent_data_trace, mock_trace_utils, caplog: pytest.LogCaptureFixture
):
trace_info = MagicMock(spec=WorkflowTraceInfo)
mock_trace_utils.convert_to_span_id.return_value = 111
node = MagicMock(spec=WorkflowNodeExecution)
node.id = "n1"
with patch.object(tencent_data_trace, "_get_workflow_node_executions", return_value=[node]):
with patch.object(tencent_data_trace, "_build_workflow_node_span", side_effect=Exception("node error")):
with caplog.at_level(logging.ERROR):
tencent_data_trace._process_workflow_nodes(trace_info, 123)
with (
patch.object(tencent_data_trace, "_get_workflow_node_executions", return_value=[node]),
patch.object(tencent_data_trace, "_build_workflow_node_span", side_effect=Exception("node error")),
caplog.at_level(logging.ERROR),
):
tencent_data_trace._process_workflow_nodes(trace_info, 123)
assert "[Tencent APM] Failed to process workflow nodes" in caplog.text
def test_process_workflow_nodes_exception(self, tencent_data_trace, mock_trace_utils, caplog):
def test_process_workflow_nodes_exception(
self, tencent_data_trace, mock_trace_utils, caplog: pytest.LogCaptureFixture
):
trace_info = MagicMock(spec=WorkflowTraceInfo)
mock_trace_utils.convert_to_span_id.side_effect = Exception("outer error")
@ -377,7 +399,9 @@ class TestTencentDataTrace:
assert result == "span"
builder_method.assert_called_once_with(123, 456, trace_info, node)
def test_build_workflow_node_span_exception(self, tencent_data_trace, mock_span_builder, caplog):
def test_build_workflow_node_span_exception(
self, tencent_data_trace, mock_span_builder, caplog: pytest.LogCaptureFixture
):
node = MagicMock(spec=WorkflowNodeExecution)
node.node_type = BuiltinNodeTypes.LLM
node.id = "n1"
@ -419,7 +443,7 @@ class TestTencentDataTrace:
assert results == mock_executions
account.set_tenant_id.assert_called_once_with("tenant-1")
def test_get_workflow_node_executions_no_app_id(self, tencent_data_trace, caplog):
def test_get_workflow_node_executions_no_app_id(self, tencent_data_trace, caplog: pytest.LogCaptureFixture):
trace_info = MagicMock(spec=WorkflowTraceInfo)
trace_info.metadata = {}
@ -428,7 +452,7 @@ class TestTencentDataTrace:
assert results == []
assert len([r for r in caplog.records if r.levelno == logging.ERROR]) >= 1
def test_get_workflow_node_executions_app_not_found(self, tencent_data_trace, caplog):
def test_get_workflow_node_executions_app_not_found(self, tencent_data_trace, caplog: pytest.LogCaptureFixture):
trace_info = MagicMock(spec=WorkflowTraceInfo)
trace_info.metadata = {"app_id": "app-1"}
@ -449,13 +473,15 @@ class TestTencentDataTrace:
trace_info.tenant_id = "tenant-1"
trace_info.metadata = {"user_id": "user-1"}
with patch("dify_trace_tencent.tencent_trace.sessionmaker", side_effect=Exception("Database error")):
with patch("dify_trace_tencent.tencent_trace.db") as mock_db:
mock_db.init_app = MagicMock()
mock_db.engine = MagicMock()
with (
patch("dify_trace_tencent.tencent_trace.sessionmaker", side_effect=Exception("Database error")),
patch("dify_trace_tencent.tencent_trace.db") as mock_db,
):
mock_db.init_app = MagicMock()
mock_db.engine = MagicMock()
user_id = tencent_data_trace._get_user_id(trace_info)
assert user_id == "unknown"
user_id = tencent_data_trace._get_user_id(trace_info)
assert user_id == "unknown"
def test_get_user_id_only_user_id(self, tencent_data_trace):
trace_info = MagicMock(spec=MessageTraceInfo)
@ -471,14 +497,16 @@ class TestTencentDataTrace:
user_id = tencent_data_trace._get_user_id(trace_info)
assert user_id == "anonymous"
def test_get_user_id_exception(self, tencent_data_trace, caplog):
def test_get_user_id_exception(self, tencent_data_trace, caplog: pytest.LogCaptureFixture):
trace_info = MagicMock(spec=WorkflowTraceInfo)
trace_info.tenant_id = "t"
trace_info.metadata = {"user_id": "u"}
with patch("dify_trace_tencent.tencent_trace.sessionmaker", side_effect=Exception("error")):
with caplog.at_level(logging.ERROR):
user_id = tencent_data_trace._get_user_id(trace_info)
with (
patch("dify_trace_tencent.tencent_trace.sessionmaker", side_effect=Exception("error")),
caplog.at_level(logging.ERROR),
):
user_id = tencent_data_trace._get_user_id(trace_info)
assert user_id == "unknown"
assert "[Tencent APM] Failed to get user ID" in caplog.text
@ -514,7 +542,7 @@ class TestTencentDataTrace:
tencent_data_trace.trace_client.record_llm_duration.assert_called_once()
tencent_data_trace.trace_client.record_token_usage.assert_called_once()
def test_record_llm_metrics_exception(self, tencent_data_trace, caplog):
def test_record_llm_metrics_exception(self, tencent_data_trace, caplog: pytest.LogCaptureFixture):
node = MagicMock(spec=WorkflowNodeExecution)
node.process_data = None
node.outputs = None
@ -553,7 +581,7 @@ class TestTencentDataTrace:
tencent_data_trace._record_message_llm_metrics(trace_info)
tencent_data_trace.trace_client.record_llm_duration.assert_called_once()
def test_record_message_llm_metrics_exception(self, tencent_data_trace, caplog):
def test_record_message_llm_metrics_exception(self, tencent_data_trace, caplog: pytest.LogCaptureFixture):
trace_info = MagicMock(spec=MessageTraceInfo)
trace_info.metadata = None
@ -605,7 +633,7 @@ class TestTencentDataTrace:
attributes = kwargs["attributes"] if "attributes" in kwargs else args[1] if len(args) > 1 else {}
assert attributes["has_conversation"] == "false"
def test_record_workflow_trace_duration_exception(self, tencent_data_trace, caplog):
def test_record_workflow_trace_duration_exception(self, tencent_data_trace, caplog: pytest.LogCaptureFixture):
trace_info = MagicMock(spec=WorkflowTraceInfo)
trace_info.start_time = MagicMock() # This might cause total_seconds() to fail if not mocked right
@ -627,7 +655,7 @@ class TestTencentDataTrace:
2.0, {"conversation_mode": "chat", "stream": "true"}
)
def test_record_message_trace_duration_exception(self, tencent_data_trace, caplog):
def test_record_message_trace_duration_exception(self, tencent_data_trace, caplog: pytest.LogCaptureFixture):
trace_info = MagicMock(spec=MessageTraceInfo)
trace_info.start_time = None
@ -647,7 +675,7 @@ class TestTencentDataTrace:
client.shutdown.assert_called_once()
def test_close_exception(self, tencent_data_trace, caplog):
def test_close_exception(self, tencent_data_trace, caplog: pytest.LogCaptureFixture):
tencent_data_trace.trace_client.shutdown.side_effect = Exception("error")
with caplog.at_level(logging.ERROR):
tencent_data_trace.close()

View File

@ -268,9 +268,11 @@ class TestWeaviateVector(unittest.TestCase):
wv._client = MagicMock()
wv._client.collections.exists.side_effect = RuntimeError("create failed")
with patch.object(weaviate_vector_module.logger, "exception") as mock_exception:
with pytest.raises(RuntimeError, match="create failed"):
wv._create_collection()
with (
patch.object(weaviate_vector_module.logger, "exception") as mock_exception,
pytest.raises(RuntimeError, match="create failed"),
):
wv._create_collection()
mock_exception.assert_called_once()
@ -835,9 +837,11 @@ class TestWeaviateVector(unittest.TestCase):
wv._client.collections.use.return_value = mock_col
mock_col.data.delete_by_id.side_effect = FakeUnexpectedStatusCodeError(500)
with patch.object(weaviate_vector_module, "UnexpectedStatusCodeError", FakeUnexpectedStatusCodeError):
with pytest.raises(FakeUnexpectedStatusCodeError, match="status=500"):
wv.delete_by_ids(["bad-id"])
with (
patch.object(weaviate_vector_module, "UnexpectedStatusCodeError", FakeUnexpectedStatusCodeError),
pytest.raises(FakeUnexpectedStatusCodeError, match="status=500"),
):
wv.delete_by_ids(["bad-id"])
def test_json_serializable_converts_datetime(self):
wv = WeaviateVector.__new__(WeaviateVector)

View File

@ -1,5 +1,4 @@
import builtins
from types import SimpleNamespace
from unittest.mock import patch
from flask.views import MethodView as FlaskMethodView
@ -22,7 +21,7 @@ def test_parameters_model_round_trip():
def test_site_icon_url_uses_signed_url_for_image_icon():
site = SimpleNamespace(
site = Site(
title="Example",
chat_color_theme=None,
chat_color_theme_inverted=False,
@ -46,7 +45,7 @@ def test_site_icon_url_uses_signed_url_for_image_icon():
def test_site_icon_url_is_none_for_non_image_icon():
site = SimpleNamespace(
site = Site(
title="Example",
chat_color_theme=None,
chat_color_theme_inverted=False,

View File

@ -2,21 +2,14 @@
from __future__ import annotations
from inspect import unwrap
from types import SimpleNamespace
import pytest
from flask import Flask
from controllers.console.app import workflow as workflow_module
def _unwrap(func):
bound_self = getattr(func, "__self__", None)
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
if bound_self is not None:
return func.__get__(bound_self, bound_self.__class__)
return func
from controllers.console.app.workflow import ConvertToWorkflowApi
class TestConvertToWorkflowApi:
@ -25,9 +18,9 @@ class TestConvertToWorkflowApi:
return workflow_module.ConvertToWorkflowApi()
def test_convert_to_workflow_attaches_permission_keys_when_rbac_enabled(
self, api, app: Flask, monkeypatch: pytest.MonkeyPatch
self, api: ConvertToWorkflowApi, app: Flask, monkeypatch: pytest.MonkeyPatch
) -> None:
method = _unwrap(api.post)
method = unwrap(api.post)
monkeypatch.setattr(
workflow_module,
@ -46,6 +39,7 @@ class TestConvertToWorkflowApi:
json={},
):
response = method(
api,
current_tenant_id="tenant-1",
current_user=SimpleNamespace(id="u1"),
app_model=SimpleNamespace(id="app-1"),

View File

@ -192,7 +192,9 @@ class TestLoginApi:
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
@patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
@patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback")
def test_login_fails_when_rate_limited(self, mock_get_invitation, mock_is_rate_limit, mock_db, app: Flask, caplog):
def test_login_fails_when_rate_limited(
self, mock_get_invitation, mock_is_rate_limit, mock_db, app: Flask, caplog: pytest.LogCaptureFixture
):
"""
Test login rejection when rate limit is exceeded.
@ -222,7 +224,9 @@ class TestLoginApi:
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", True)
@patch("controllers.console.auth.login.BillingService.is_email_in_freeze")
def test_login_fails_when_account_frozen(self, mock_is_frozen, mock_db, app: Flask, caplog):
def test_login_fails_when_account_frozen(
self, mock_is_frozen, mock_db, app: Flask, caplog: pytest.LogCaptureFixture
):
"""
Test login rejection for frozen accounts.
@ -262,7 +266,7 @@ class TestLoginApi:
mock_is_rate_limit,
mock_db,
app: Flask,
caplog,
caplog: pytest.LogCaptureFixture,
):
"""
Test login failure with invalid credentials.
@ -462,7 +466,7 @@ class TestLoginApi:
mock_get_token_data: MagicMock,
mock_db: MagicMock,
app: Flask,
caplog,
caplog: pytest.LogCaptureFixture,
):
mock_get_token_data.return_value = {"email": "User@Example.com", "code": "123456"}
mock_get_account.side_effect = Unauthorized("Account is banned.")

View File

@ -1,3 +1,4 @@
from inspect import unwrap
from types import SimpleNamespace
from unittest.mock import Mock
@ -10,12 +11,6 @@ from models.account import Account, AccountStatus
from services.workflow_draft_variable_service import WorkflowDraftVariableList
def _unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
def _make_account() -> Account:
account = Account(
name="tester",
@ -66,7 +61,7 @@ def test_ensure_snippet_draft_variable_row_allowed_accepts_canvas_node_variable(
def test_conversation_variables_returns_empty_list(app: Flask):
api = module.SnippetConversationVariableCollectionApi()
handler = _unwrap(api.get)
handler = unwrap(api.get)
with app.test_request_context("/"):
result = handler(api, _make_account(), snippet=SimpleNamespace(id="snippet-1"))
@ -76,7 +71,7 @@ def test_conversation_variables_returns_empty_list(app: Flask):
def test_system_variables_returns_empty_list(app: Flask):
api = module.SnippetSystemVariableCollectionApi()
handler = _unwrap(api.get)
handler = unwrap(api.get)
with app.test_request_context("/"):
result = handler(api, _make_account(), snippet=SimpleNamespace(id="snippet-1"))
@ -91,7 +86,7 @@ def test_delete_variable_collection_deletes_current_user_variables(app: Flask, m
db_session.return_value = SimpleNamespace()
monkeypatch.setattr(module.db, "session", db_session)
api = module.SnippetWorkflowVariableCollectionApi()
handler = _unwrap(api.delete)
handler = unwrap(api.delete)
with app.test_request_context("/", method="DELETE"):
response = handler(api, _make_account(), snippet=SimpleNamespace(id="snippet-1"))
@ -109,7 +104,7 @@ def test_variable_collection_get_raises_when_draft_workflow_missing(app: Flask,
)
api = module.SnippetWorkflowVariableCollectionApi()
handler = _unwrap(api.get)
handler = unwrap(api.get)
with app.test_request_context("/?page=1&limit=20"):
with pytest.raises(module.DraftWorkflowNotExist):
@ -140,7 +135,7 @@ def test_node_variable_collection_get_lists_node_variables(app: Flask, monkeypat
)
api = module.SnippetNodeVariableCollectionApi()
handler = _unwrap(api.get)
handler = unwrap(api.get)
with app.test_request_context("/"):
result = handler(api, _make_account(), snippet=SimpleNamespace(id="snippet-1"), node_id="llm-1")
@ -158,7 +153,7 @@ def test_node_variable_collection_delete_deletes_node_variables(app: Flask, monk
monkeypatch.setattr(module.db, "session", db_session)
api = module.SnippetNodeVariableCollectionApi()
handler = _unwrap(api.delete)
handler = unwrap(api.delete)
with app.test_request_context("/", method="DELETE"):
response = handler(api, _make_account(), snippet=SimpleNamespace(id="snippet-1"), node_id="llm-1")
@ -177,7 +172,7 @@ def test_variable_patch_returns_variable_when_no_changes(app: Flask, monkeypatch
monkeypatch.setattr(module, "WorkflowDraftVariableService", Mock(return_value=draft_var_service))
api = module.SnippetVariableApi()
handler = _unwrap(api.patch)
handler = unwrap(api.patch)
with app.test_request_context("/", method="PATCH", json={}):
result = handler(
@ -202,7 +197,7 @@ def test_variable_delete_deletes_variable(app: Flask, monkeypatch: pytest.Monkey
monkeypatch.setattr(module, "WorkflowDraftVariableService", Mock(return_value=draft_var_service))
api = module.SnippetVariableApi()
handler = _unwrap(api.delete)
handler = unwrap(api.delete)
with app.test_request_context("/", method="DELETE"):
response = handler(api, _make_account(), snippet=SimpleNamespace(id="snippet-1"), variable_id="var-1")
@ -230,7 +225,7 @@ def test_variable_reset_returns_no_content_when_reset_result_is_none(app: Flask,
)
api = module.SnippetVariableResetApi()
handler = _unwrap(api.put)
handler = unwrap(api.put)
with app.test_request_context("/", method="PUT"):
response = handler(api, _make_account(), snippet=SimpleNamespace(id="snippet-1"), variable_id="var-1")
@ -260,7 +255,7 @@ def test_environment_variables_returns_workflow_environment_variables(app: Flask
)
api = module.SnippetEnvironmentVariableCollectionApi()
handler = _unwrap(api.get)
handler = unwrap(api.get)
with app.test_request_context("/"):
result = handler(api, _make_account(), snippet=SimpleNamespace(id="snippet-1"))

View File

@ -21,6 +21,7 @@ from core.ops.entities.trace_entity import (
WorkflowNodeTraceInfo,
WorkflowTraceInfo,
)
from enterprise.telemetry.enterprise_trace import EnterpriseOtelTrace
from enterprise.telemetry.entities import (
EnterpriseTelemetryCounter,
EnterpriseTelemetryEvent,
@ -297,43 +298,43 @@ def test_init_succeeds_with_valid_exporter(mock_exporter):
class TestSafePayloadValue:
def test_string_passthrough(self, trace_handler):
def test_string_passthrough(self, trace_handler: EnterpriseOtelTrace):
assert trace_handler._safe_payload_value("hello") == "hello"
def test_dict_passthrough(self, trace_handler):
def test_dict_passthrough(self, trace_handler: EnterpriseOtelTrace):
d = {"key": "val"}
assert trace_handler._safe_payload_value(d) == d
def test_list_passthrough(self, trace_handler):
def test_list_passthrough(self, trace_handler: EnterpriseOtelTrace):
lst = [1, 2, 3]
assert trace_handler._safe_payload_value(lst) == lst
def test_none_returns_none(self, trace_handler):
def test_none_returns_none(self, trace_handler: EnterpriseOtelTrace):
assert trace_handler._safe_payload_value(None) is None
def test_int_returns_none(self, trace_handler):
def test_int_returns_none(self, trace_handler: EnterpriseOtelTrace):
assert trace_handler._safe_payload_value(42) is None
def test_bool_returns_none(self, trace_handler):
def test_bool_returns_none(self, trace_handler: EnterpriseOtelTrace):
assert trace_handler._safe_payload_value(True) is None
class TestMaybeJson:
def test_none_returns_none(self, trace_handler):
def test_none_returns_none(self, trace_handler: EnterpriseOtelTrace):
assert trace_handler._maybe_json(None) is None
def test_string_passthrough(self, trace_handler):
def test_string_passthrough(self, trace_handler: EnterpriseOtelTrace):
assert trace_handler._maybe_json("hello") == "hello"
def test_dict_serialised(self, trace_handler):
def test_dict_serialised(self, trace_handler: EnterpriseOtelTrace):
result = trace_handler._maybe_json({"a": 1})
assert result == json.dumps({"a": 1})
def test_list_serialised(self, trace_handler):
def test_list_serialised(self, trace_handler: EnterpriseOtelTrace):
result = trace_handler._maybe_json([1, 2])
assert result == "[1, 2]"
def test_non_serialisable_falls_back_to_str(self, trace_handler):
def test_non_serialisable_falls_back_to_str(self, trace_handler: EnterpriseOtelTrace):
class Unserializable:
def __repr__(self):
return "Unserializable()"
@ -344,22 +345,22 @@ class TestMaybeJson:
class TestContentOrRef:
def test_returns_content_when_include_content_true(self, trace_handler, mock_exporter):
def test_returns_content_when_include_content_true(self, trace_handler: EnterpriseOtelTrace, mock_exporter):
mock_exporter.include_content = True
result = trace_handler._content_or_ref("actual content", "ref:x=1")
assert result == "actual content"
def test_returns_ref_when_include_content_false(self, trace_handler, mock_exporter):
def test_returns_ref_when_include_content_false(self, trace_handler: EnterpriseOtelTrace, mock_exporter):
mock_exporter.include_content = False
result = trace_handler._content_or_ref("actual content", "ref:x=1")
assert result == "ref:x=1"
def test_dict_serialised_when_include_content_true(self, trace_handler, mock_exporter):
def test_dict_serialised_when_include_content_true(self, trace_handler: EnterpriseOtelTrace, mock_exporter):
mock_exporter.include_content = True
result = trace_handler._content_or_ref({"key": "val"}, "ref:x=1")
assert result == json.dumps({"key": "val"})
def test_none_returns_none_when_include_content_true(self, trace_handler, mock_exporter):
def test_none_returns_none_when_include_content_true(self, trace_handler: EnterpriseOtelTrace, mock_exporter):
mock_exporter.include_content = True
result = trace_handler._content_or_ref(None, "ref:x=1")
assert result is None
@ -371,67 +372,67 @@ class TestContentOrRef:
class TestTraceDispatcher:
def test_dispatches_workflow_trace(self, trace_handler):
def test_dispatches_workflow_trace(self, trace_handler: EnterpriseOtelTrace):
with patch.object(trace_handler, "_workflow_trace") as mock_method:
info = make_workflow_info()
trace_handler.trace(info)
mock_method.assert_called_once_with(info)
def test_dispatches_message_trace(self, trace_handler):
def test_dispatches_message_trace(self, trace_handler: EnterpriseOtelTrace):
with patch.object(trace_handler, "_message_trace") as mock_method:
info = make_message_info()
trace_handler.trace(info)
mock_method.assert_called_once_with(info)
def test_dispatches_tool_trace(self, trace_handler):
def test_dispatches_tool_trace(self, trace_handler: EnterpriseOtelTrace):
with patch.object(trace_handler, "_tool_trace") as mock_method:
info = make_tool_info()
trace_handler.trace(info)
mock_method.assert_called_once_with(info)
def test_dispatches_draft_node_execution_trace(self, trace_handler):
def test_dispatches_draft_node_execution_trace(self, trace_handler: EnterpriseOtelTrace):
with patch.object(trace_handler, "_draft_node_execution_trace") as mock_method:
info = make_draft_node_info()
trace_handler.trace(info)
mock_method.assert_called_once_with(info)
def test_dispatches_node_execution_trace(self, trace_handler):
def test_dispatches_node_execution_trace(self, trace_handler: EnterpriseOtelTrace):
with patch.object(trace_handler, "_node_execution_trace") as mock_method:
info = make_node_info()
trace_handler.trace(info)
mock_method.assert_called_once_with(info)
def test_dispatches_moderation_trace(self, trace_handler):
def test_dispatches_moderation_trace(self, trace_handler: EnterpriseOtelTrace):
with patch.object(trace_handler, "_moderation_trace") as mock_method:
info = make_moderation_info()
trace_handler.trace(info)
mock_method.assert_called_once_with(info)
def test_dispatches_suggested_question_trace(self, trace_handler):
def test_dispatches_suggested_question_trace(self, trace_handler: EnterpriseOtelTrace):
with patch.object(trace_handler, "_suggested_question_trace") as mock_method:
info = make_suggested_question_info()
trace_handler.trace(info)
mock_method.assert_called_once_with(info)
def test_dispatches_dataset_retrieval_trace(self, trace_handler):
def test_dispatches_dataset_retrieval_trace(self, trace_handler: EnterpriseOtelTrace):
with patch.object(trace_handler, "_dataset_retrieval_trace") as mock_method:
info = make_dataset_retrieval_info()
trace_handler.trace(info)
mock_method.assert_called_once_with(info)
def test_dispatches_generate_name_trace(self, trace_handler):
def test_dispatches_generate_name_trace(self, trace_handler: EnterpriseOtelTrace):
with patch.object(trace_handler, "_generate_name_trace") as mock_method:
info = make_generate_name_info()
trace_handler.trace(info)
mock_method.assert_called_once_with(info)
def test_dispatches_prompt_generation_trace(self, trace_handler):
def test_dispatches_prompt_generation_trace(self, trace_handler: EnterpriseOtelTrace):
with patch.object(trace_handler, "_prompt_generation_trace") as mock_method:
info = make_prompt_generation_info()
trace_handler.trace(info)
mock_method.assert_called_once_with(info)
def test_draft_node_dispatched_before_node(self, trace_handler):
def test_draft_node_dispatched_before_node(self, trace_handler: EnterpriseOtelTrace):
"""DraftNodeExecutionTrace is a subclass of WorkflowNodeTraceInfo;
it must be dispatched to _draft_node_execution_trace, not _node_execution_trace."""
with (
@ -450,7 +451,7 @@ class TestTraceDispatcher:
class TestWorkflowTrace:
def test_emits_correct_span_attributes(self, trace_handler, mock_exporter):
def test_emits_correct_span_attributes(self, trace_handler: EnterpriseOtelTrace, mock_exporter):
with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log") as mock_log:
info = make_workflow_info()
trace_handler._workflow_trace(info)
@ -465,7 +466,7 @@ class TestWorkflowTrace:
assert attrs["dify.workflow.status"] == "succeeded"
assert attrs["gen_ai.usage.total_tokens"] == 100
def test_span_timing_passed_correctly(self, trace_handler, mock_exporter):
def test_span_timing_passed_correctly(self, trace_handler: EnterpriseOtelTrace, mock_exporter):
with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"):
info = make_workflow_info()
trace_handler._workflow_trace(info)
@ -474,7 +475,7 @@ class TestWorkflowTrace:
assert span_call[1]["start_time"] == _T0
assert span_call[1]["end_time"] == _T1
def test_emits_companion_log_with_event_name(self, trace_handler, mock_exporter):
def test_emits_companion_log_with_event_name(self, trace_handler: EnterpriseOtelTrace, mock_exporter):
with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log") as mock_log:
trace_handler._workflow_trace(make_workflow_info())
@ -482,7 +483,7 @@ class TestWorkflowTrace:
assert mock_log.call_args[1]["event_name"] == EnterpriseTelemetryEvent.WORKFLOW_RUN
assert mock_log.call_args[1]["tenant_id"] == "tenant-abc"
def test_companion_log_includes_content_when_enabled(self, trace_handler, mock_exporter):
def test_companion_log_includes_content_when_enabled(self, trace_handler: EnterpriseOtelTrace, mock_exporter):
mock_exporter.include_content = True
with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log") as mock_log:
trace_handler._workflow_trace(make_workflow_info())
@ -491,7 +492,7 @@ class TestWorkflowTrace:
assert log_attrs["dify.workflow.inputs"] == json.dumps({"query": "hello"})
assert log_attrs["dify.workflow.outputs"] == json.dumps({"answer": "world"})
def test_companion_log_uses_ref_when_content_disabled(self, trace_handler, mock_exporter):
def test_companion_log_uses_ref_when_content_disabled(self, trace_handler: EnterpriseOtelTrace, mock_exporter):
mock_exporter.include_content = False
with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log") as mock_log:
trace_handler._workflow_trace(make_workflow_info())
@ -500,7 +501,7 @@ class TestWorkflowTrace:
assert log_attrs["dify.workflow.inputs"].startswith("ref:workflow_run_id=")
assert log_attrs["dify.workflow.outputs"].startswith("ref:workflow_run_id=")
def test_increments_token_counter(self, trace_handler, mock_exporter):
def test_increments_token_counter(self, trace_handler: EnterpriseOtelTrace, mock_exporter):
with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"):
trace_handler._workflow_trace(make_workflow_info())
@ -510,7 +511,7 @@ class TestWorkflowTrace:
assert len(token_calls) == 1
assert token_calls[0][0][1] == 100
def test_increments_input_and_output_token_counters(self, trace_handler, mock_exporter):
def test_increments_input_and_output_token_counters(self, trace_handler: EnterpriseOtelTrace, mock_exporter):
with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"):
trace_handler._workflow_trace(make_workflow_info())
@ -519,7 +520,7 @@ class TestWorkflowTrace:
assert EnterpriseTelemetryCounter.INPUT_TOKENS in counter_names
assert EnterpriseTelemetryCounter.OUTPUT_TOKENS in counter_names
def test_no_input_token_counter_when_prompt_tokens_zero(self, trace_handler, mock_exporter):
def test_no_input_token_counter_when_prompt_tokens_zero(self, trace_handler: EnterpriseOtelTrace, mock_exporter):
with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"):
info = make_workflow_info(prompt_tokens=0)
trace_handler._workflow_trace(info)
@ -528,7 +529,7 @@ class TestWorkflowTrace:
counter_names = [c[0][0] for c in all_calls]
assert EnterpriseTelemetryCounter.INPUT_TOKENS not in counter_names
def test_records_workflow_duration_histogram(self, trace_handler, mock_exporter):
def test_records_workflow_duration_histogram(self, trace_handler: EnterpriseOtelTrace, mock_exporter):
with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"):
trace_handler._workflow_trace(make_workflow_info())
@ -537,7 +538,9 @@ class TestWorkflowTrace:
assert hist_call[0][0] == EnterpriseTelemetryHistogram.WORKFLOW_DURATION
assert hist_call[0][1] == pytest.approx(5.0)
def test_duration_falls_back_to_elapsed_time_when_timestamps_missing(self, trace_handler, mock_exporter):
def test_duration_falls_back_to_elapsed_time_when_timestamps_missing(
self, trace_handler: EnterpriseOtelTrace, mock_exporter
):
with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"):
info = make_workflow_info(start_time=None, end_time=None, workflow_run_elapsed_time=7.3)
trace_handler._workflow_trace(info)
@ -545,7 +548,7 @@ class TestWorkflowTrace:
hist_call = mock_exporter.record_histogram.call_args
assert hist_call[0][1] == pytest.approx(7.3)
def test_duration_defaults_to_zero_when_no_timing(self, trace_handler, mock_exporter):
def test_duration_defaults_to_zero_when_no_timing(self, trace_handler: EnterpriseOtelTrace, mock_exporter):
with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"):
info = make_workflow_info(start_time=None, end_time=None, workflow_run_elapsed_time=0)
trace_handler._workflow_trace(info)
@ -553,7 +556,7 @@ class TestWorkflowTrace:
hist_call = mock_exporter.record_histogram.call_args
assert hist_call[0][1] == pytest.approx(0.0)
def test_error_path_increments_error_counter(self, trace_handler, mock_exporter):
def test_error_path_increments_error_counter(self, trace_handler: EnterpriseOtelTrace, mock_exporter):
with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"):
info = make_workflow_info(error="Something went wrong", workflow_run_status="failed")
trace_handler._workflow_trace(info)
@ -563,7 +566,7 @@ class TestWorkflowTrace:
]
assert len(error_calls) == 1
def test_no_error_counter_on_success(self, trace_handler, mock_exporter):
def test_no_error_counter_on_success(self, trace_handler: EnterpriseOtelTrace, mock_exporter):
with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"):
trace_handler._workflow_trace(make_workflow_info())
@ -572,7 +575,7 @@ class TestWorkflowTrace:
]
assert len(error_calls) == 0
def test_parent_trace_context_injected_into_span_attrs(self, trace_handler, mock_exporter):
def test_parent_trace_context_injected_into_span_attrs(self, trace_handler: EnterpriseOtelTrace, mock_exporter):
with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"):
info = make_workflow_info(
metadata={
@ -601,14 +604,14 @@ class TestWorkflowTrace:
class TestNodeExecutionTrace:
def test_emits_span_with_node_execution_span_name(self, trace_handler, mock_exporter):
def test_emits_span_with_node_execution_span_name(self, trace_handler: EnterpriseOtelTrace, mock_exporter):
with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"):
trace_handler._node_execution_trace(make_node_info())
span_call = mock_exporter.export_span.call_args
assert span_call[0][0] == EnterpriseTelemetrySpan.NODE_EXECUTION
def test_span_contains_core_node_attributes(self, trace_handler, mock_exporter):
def test_span_contains_core_node_attributes(self, trace_handler: EnterpriseOtelTrace, mock_exporter):
with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"):
trace_handler._node_execution_trace(make_node_info())
@ -620,7 +623,7 @@ class TestNodeExecutionTrace:
assert attrs["gen_ai.request.model"] == "gpt-4"
assert attrs["gen_ai.provider.name"] == "openai"
def test_increments_token_counters_when_tokens_present(self, trace_handler, mock_exporter):
def test_increments_token_counters_when_tokens_present(self, trace_handler: EnterpriseOtelTrace, mock_exporter):
with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"):
trace_handler._node_execution_trace(make_node_info())
@ -629,7 +632,7 @@ class TestNodeExecutionTrace:
assert EnterpriseTelemetryCounter.INPUT_TOKENS in counter_names
assert EnterpriseTelemetryCounter.OUTPUT_TOKENS in counter_names
def test_no_token_counters_when_total_tokens_zero(self, trace_handler, mock_exporter):
def test_no_token_counters_when_total_tokens_zero(self, trace_handler: EnterpriseOtelTrace, mock_exporter):
with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"):
trace_handler._node_execution_trace(make_node_info(total_tokens=0))
@ -637,7 +640,7 @@ class TestNodeExecutionTrace:
assert EnterpriseTelemetryCounter.TOKENS not in counter_names
assert EnterpriseTelemetryCounter.INPUT_TOKENS not in counter_names
def test_records_node_duration_histogram(self, trace_handler, mock_exporter):
def test_records_node_duration_histogram(self, trace_handler: EnterpriseOtelTrace, mock_exporter):
with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"):
trace_handler._node_execution_trace(make_node_info())
@ -645,7 +648,7 @@ class TestNodeExecutionTrace:
assert hist_call[0][0] == EnterpriseTelemetryHistogram.NODE_DURATION
assert hist_call[0][1] == pytest.approx(2.5)
def test_error_path_increments_error_counter(self, trace_handler, mock_exporter):
def test_error_path_increments_error_counter(self, trace_handler: EnterpriseOtelTrace, mock_exporter):
with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"):
trace_handler._node_execution_trace(make_node_info(error="Node failed", status="failed"))
@ -654,14 +657,16 @@ class TestNodeExecutionTrace:
]
assert len(error_calls) == 1
def test_emits_companion_log_with_span_name_as_event(self, trace_handler, mock_exporter):
def test_emits_companion_log_with_span_name_as_event(self, trace_handler: EnterpriseOtelTrace, mock_exporter):
with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log") as mock_log:
trace_handler._node_execution_trace(make_node_info())
mock_log.assert_called_once()
assert mock_log.call_args[1]["event_name"] == EnterpriseTelemetrySpan.NODE_EXECUTION.value
def test_plugin_name_added_to_duration_labels_for_tool_node(self, trace_handler, mock_exporter):
def test_plugin_name_added_to_duration_labels_for_tool_node(
self, trace_handler: EnterpriseOtelTrace, mock_exporter
):
with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"):
info = make_node_info(
node_type="tool",
@ -677,7 +682,7 @@ class TestNodeExecutionTrace:
duration_labels = hist_call[0][2]
assert duration_labels.get("plugin_name") == "my-plugin"
def test_plugin_name_not_added_for_non_tool_node(self, trace_handler, mock_exporter):
def test_plugin_name_not_added_for_non_tool_node(self, trace_handler: EnterpriseOtelTrace, mock_exporter):
with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"):
info = make_node_info(
node_type="llm",
@ -693,7 +698,9 @@ class TestNodeExecutionTrace:
duration_labels = hist_call[0][2]
assert "plugin_name" not in duration_labels
def test_companion_log_inputs_use_ref_when_content_disabled(self, trace_handler, mock_exporter):
def test_companion_log_inputs_use_ref_when_content_disabled(
self, trace_handler: EnterpriseOtelTrace, mock_exporter
):
mock_exporter.include_content = False
with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log") as mock_log:
trace_handler._node_execution_trace(
@ -711,14 +718,14 @@ class TestNodeExecutionTrace:
class TestDraftNodeExecutionTrace:
def test_uses_draft_span_name(self, trace_handler, mock_exporter):
def test_uses_draft_span_name(self, trace_handler: EnterpriseOtelTrace, mock_exporter):
with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"):
trace_handler._draft_node_execution_trace(make_draft_node_info())
span_call = mock_exporter.export_span.call_args
assert span_call[0][0] == EnterpriseTelemetrySpan.DRAFT_NODE_EXECUTION
def test_correlation_id_is_node_execution_id(self, trace_handler, mock_exporter):
def test_correlation_id_is_node_execution_id(self, trace_handler: EnterpriseOtelTrace, mock_exporter):
with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"):
info = make_draft_node_info()
trace_handler._draft_node_execution_trace(info)
@ -726,7 +733,7 @@ class TestDraftNodeExecutionTrace:
span_call = mock_exporter.export_span.call_args
assert span_call[1]["correlation_id"] == "ne-draft-001"
def test_trace_correlation_override_is_workflow_run_id(self, trace_handler, mock_exporter):
def test_trace_correlation_override_is_workflow_run_id(self, trace_handler: EnterpriseOtelTrace, mock_exporter):
with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"):
info = make_draft_node_info()
trace_handler._draft_node_execution_trace(info)
@ -734,7 +741,7 @@ class TestDraftNodeExecutionTrace:
span_call = mock_exporter.export_span.call_args
assert span_call[1]["trace_correlation_override"] == "run-draft-001"
def test_companion_log_uses_draft_span_name(self, trace_handler, mock_exporter):
def test_companion_log_uses_draft_span_name(self, trace_handler: EnterpriseOtelTrace, mock_exporter):
with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log") as mock_log:
trace_handler._draft_node_execution_trace(make_draft_node_info())
@ -747,34 +754,36 @@ class TestDraftNodeExecutionTrace:
class TestMessageTrace:
def test_emits_event_with_correct_name(self, trace_handler, mock_exporter):
def test_emits_event_with_correct_name(self, trace_handler: EnterpriseOtelTrace, mock_exporter):
with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit:
trace_handler._message_trace(make_message_info())
mock_emit.assert_called_once()
assert mock_emit.call_args[1]["event_name"] == EnterpriseTelemetryEvent.MESSAGE_RUN
def test_emits_correct_tenant_and_user(self, trace_handler, mock_exporter):
def test_emits_correct_tenant_and_user(self, trace_handler: EnterpriseOtelTrace, mock_exporter):
with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit:
trace_handler._message_trace(make_message_info())
assert mock_emit.call_args[1]["tenant_id"] == "tenant-abc"
def test_duration_computed_from_timestamps(self, trace_handler, mock_exporter):
def test_duration_computed_from_timestamps(self, trace_handler: EnterpriseOtelTrace, mock_exporter):
with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit:
trace_handler._message_trace(make_message_info())
attrs = mock_emit.call_args[1]["attributes"]
assert attrs["dify.message.duration"] == pytest.approx(5.0)
def test_no_duration_when_timestamps_missing(self, trace_handler, mock_exporter):
def test_no_duration_when_timestamps_missing(self, trace_handler: EnterpriseOtelTrace, mock_exporter):
with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit:
trace_handler._message_trace(make_message_info(start_time=None, end_time=None))
attrs = mock_emit.call_args[1]["attributes"]
assert "dify.message.duration" not in attrs
def test_records_duration_histogram_when_timestamps_present(self, trace_handler, mock_exporter):
def test_records_duration_histogram_when_timestamps_present(
self, trace_handler: EnterpriseOtelTrace, mock_exporter
):
with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event"):
trace_handler._message_trace(make_message_info())
@ -786,14 +795,14 @@ class TestMessageTrace:
assert len(hist_calls) == 1
assert hist_calls[0][0][1] == pytest.approx(5.0)
def test_no_duration_histogram_when_timestamps_missing(self, trace_handler, mock_exporter):
def test_no_duration_histogram_when_timestamps_missing(self, trace_handler: EnterpriseOtelTrace, mock_exporter):
with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event"):
trace_handler._message_trace(make_message_info(start_time=None, end_time=None))
hist_names = [c[0][0] for c in mock_exporter.record_histogram.call_args_list]
assert EnterpriseTelemetryHistogram.MESSAGE_DURATION not in hist_names
def test_records_ttft_histogram_when_present(self, trace_handler, mock_exporter):
def test_records_ttft_histogram_when_present(self, trace_handler: EnterpriseOtelTrace, mock_exporter):
with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event"):
trace_handler._message_trace(make_message_info(gen_ai_server_time_to_first_token=0.42))
@ -805,14 +814,14 @@ class TestMessageTrace:
assert len(ttft_calls) == 1
assert ttft_calls[0][0][1] == pytest.approx(0.42)
def test_no_ttft_histogram_when_not_present(self, trace_handler, mock_exporter):
def test_no_ttft_histogram_when_not_present(self, trace_handler: EnterpriseOtelTrace, mock_exporter):
with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event"):
trace_handler._message_trace(make_message_info(gen_ai_server_time_to_first_token=None))
hist_names = [c[0][0] for c in mock_exporter.record_histogram.call_args_list]
assert EnterpriseTelemetryHistogram.MESSAGE_TTFT not in hist_names
def test_increments_token_counters(self, trace_handler, mock_exporter):
def test_increments_token_counters(self, trace_handler: EnterpriseOtelTrace, mock_exporter):
with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event"):
trace_handler._message_trace(make_message_info())
@ -821,7 +830,7 @@ class TestMessageTrace:
assert EnterpriseTelemetryCounter.INPUT_TOKENS in counter_names
assert EnterpriseTelemetryCounter.OUTPUT_TOKENS in counter_names
def test_error_path_increments_error_counter(self, trace_handler, mock_exporter):
def test_error_path_increments_error_counter(self, trace_handler: EnterpriseOtelTrace, mock_exporter):
with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event"):
trace_handler._message_trace(make_message_info(error="LLM failed"))
@ -830,7 +839,7 @@ class TestMessageTrace:
]
assert len(error_calls) == 1
def test_inputs_and_outputs_gated_by_include_content(self, trace_handler, mock_exporter):
def test_inputs_and_outputs_gated_by_include_content(self, trace_handler: EnterpriseOtelTrace, mock_exporter):
mock_exporter.include_content = False
with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit:
trace_handler._message_trace(make_message_info())
@ -846,27 +855,27 @@ class TestMessageTrace:
class TestToolTrace:
def test_emits_event_with_correct_name(self, trace_handler, mock_exporter):
def test_emits_event_with_correct_name(self, trace_handler: EnterpriseOtelTrace, mock_exporter):
with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit:
trace_handler._tool_trace(make_tool_info())
assert mock_emit.call_args[1]["event_name"] == EnterpriseTelemetryEvent.TOOL_EXECUTION
def test_status_is_succeeded_on_success(self, trace_handler, mock_exporter):
def test_status_is_succeeded_on_success(self, trace_handler: EnterpriseOtelTrace, mock_exporter):
with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit:
trace_handler._tool_trace(make_tool_info())
attrs = mock_emit.call_args[1]["attributes"]
assert attrs["dify.tool.status"] == "succeeded"
def test_status_is_failed_on_error(self, trace_handler, mock_exporter):
def test_status_is_failed_on_error(self, trace_handler: EnterpriseOtelTrace, mock_exporter):
with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit:
trace_handler._tool_trace(make_tool_info(error="Tool error"))
attrs = mock_emit.call_args[1]["attributes"]
assert attrs["dify.tool.status"] == "failed"
def test_records_tool_duration_histogram(self, trace_handler, mock_exporter):
def test_records_tool_duration_histogram(self, trace_handler: EnterpriseOtelTrace, mock_exporter):
with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event"):
trace_handler._tool_trace(make_tool_info())
@ -874,7 +883,7 @@ class TestToolTrace:
assert hist_call[0][0] == EnterpriseTelemetryHistogram.TOOL_DURATION
assert hist_call[0][1] == pytest.approx(1.5)
def test_error_increments_error_counter(self, trace_handler, mock_exporter):
def test_error_increments_error_counter(self, trace_handler: EnterpriseOtelTrace, mock_exporter):
with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event"):
trace_handler._tool_trace(make_tool_info(error="Tool crashed"))
@ -883,7 +892,7 @@ class TestToolTrace:
]
assert len(error_calls) == 1
def test_inputs_and_outputs_gated_by_include_content(self, trace_handler, mock_exporter):
def test_inputs_and_outputs_gated_by_include_content(self, trace_handler: EnterpriseOtelTrace, mock_exporter):
mock_exporter.include_content = False
with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit:
trace_handler._tool_trace(make_tool_info())
@ -892,7 +901,7 @@ class TestToolTrace:
assert attrs["dify.tool.inputs"].startswith("ref:message_id=")
assert attrs["dify.tool.outputs"].startswith("ref:message_id=")
def test_inputs_present_when_include_content_true(self, trace_handler, mock_exporter):
def test_inputs_present_when_include_content_true(self, trace_handler: EnterpriseOtelTrace, mock_exporter):
mock_exporter.include_content = True
with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit:
trace_handler._tool_trace(make_tool_info())
@ -901,7 +910,7 @@ class TestToolTrace:
assert attrs["dify.tool.inputs"] == json.dumps({"query": "test"})
assert attrs["dify.tool.outputs"] == "search results"
def test_increments_requests_counter(self, trace_handler, mock_exporter):
def test_increments_requests_counter(self, trace_handler: EnterpriseOtelTrace, mock_exporter):
with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event"):
trace_handler._tool_trace(make_tool_info())
@ -918,27 +927,27 @@ class TestToolTrace:
class TestModerationTrace:
def test_emits_event_with_correct_name(self, trace_handler, mock_exporter):
def test_emits_event_with_correct_name(self, trace_handler: EnterpriseOtelTrace, mock_exporter):
with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit:
trace_handler._moderation_trace(make_moderation_info())
assert mock_emit.call_args[1]["event_name"] == EnterpriseTelemetryEvent.MODERATION_CHECK
def test_flagged_true_sets_attribute(self, trace_handler, mock_exporter):
def test_flagged_true_sets_attribute(self, trace_handler: EnterpriseOtelTrace, mock_exporter):
with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit:
trace_handler._moderation_trace(make_moderation_info(flagged=True))
attrs = mock_emit.call_args[1]["attributes"]
assert attrs["dify.moderation.flagged"] is True
def test_flagged_false_sets_attribute(self, trace_handler, mock_exporter):
def test_flagged_false_sets_attribute(self, trace_handler: EnterpriseOtelTrace, mock_exporter):
with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit:
trace_handler._moderation_trace(make_moderation_info(flagged=False))
attrs = mock_emit.call_args[1]["attributes"]
assert attrs["dify.moderation.flagged"] is False
def test_query_gated_by_include_content(self, trace_handler, mock_exporter):
def test_query_gated_by_include_content(self, trace_handler: EnterpriseOtelTrace, mock_exporter):
mock_exporter.include_content = False
with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit:
trace_handler._moderation_trace(make_moderation_info())
@ -946,7 +955,7 @@ class TestModerationTrace:
attrs = mock_emit.call_args[1]["attributes"]
assert attrs["dify.moderation.query"].startswith("ref:message_id=")
def test_query_present_when_include_content_true(self, trace_handler, mock_exporter):
def test_query_present_when_include_content_true(self, trace_handler: EnterpriseOtelTrace, mock_exporter):
mock_exporter.include_content = True
with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit:
trace_handler._moderation_trace(make_moderation_info())
@ -954,7 +963,7 @@ class TestModerationTrace:
attrs = mock_emit.call_args[1]["attributes"]
assert attrs["dify.moderation.query"] == "is this ok?"
def test_increments_requests_counter(self, trace_handler, mock_exporter):
def test_increments_requests_counter(self, trace_handler: EnterpriseOtelTrace, mock_exporter):
with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event"):
trace_handler._moderation_trace(make_moderation_info())
@ -971,48 +980,48 @@ class TestModerationTrace:
class TestSuggestedQuestionTrace:
def test_emits_event_with_correct_name(self, trace_handler, mock_exporter):
def test_emits_event_with_correct_name(self, trace_handler: EnterpriseOtelTrace, mock_exporter):
with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit:
trace_handler._suggested_question_trace(make_suggested_question_info())
assert mock_emit.call_args[1]["event_name"] == EnterpriseTelemetryEvent.SUGGESTED_QUESTION_GENERATION
def test_duration_computed_from_timestamps(self, trace_handler, mock_exporter):
def test_duration_computed_from_timestamps(self, trace_handler: EnterpriseOtelTrace, mock_exporter):
with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit:
trace_handler._suggested_question_trace(make_suggested_question_info())
attrs = mock_emit.call_args[1]["attributes"]
assert attrs["dify.suggested_question.duration"] == pytest.approx(5.0)
def test_duration_is_none_when_timestamps_missing(self, trace_handler, mock_exporter):
def test_duration_is_none_when_timestamps_missing(self, trace_handler: EnterpriseOtelTrace, mock_exporter):
with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit:
trace_handler._suggested_question_trace(make_suggested_question_info(start_time=None, end_time=None))
attrs = mock_emit.call_args[1]["attributes"]
assert attrs["dify.suggested_question.duration"] is None
def test_status_is_failed_when_error_present(self, trace_handler, mock_exporter):
def test_status_is_failed_when_error_present(self, trace_handler: EnterpriseOtelTrace, mock_exporter):
with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit:
trace_handler._suggested_question_trace(make_suggested_question_info(error="Generation failed"))
attrs = mock_emit.call_args[1]["attributes"]
assert attrs["dify.suggested_question.status"] == "failed"
def test_status_falls_back_to_succeeded_when_no_error(self, trace_handler, mock_exporter):
def test_status_falls_back_to_succeeded_when_no_error(self, trace_handler: EnterpriseOtelTrace, mock_exporter):
with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit:
trace_handler._suggested_question_trace(make_suggested_question_info(status=None, error=None))
attrs = mock_emit.call_args[1]["attributes"]
assert attrs["dify.suggested_question.status"] == "succeeded"
def test_question_count_attribute(self, trace_handler, mock_exporter):
def test_question_count_attribute(self, trace_handler: EnterpriseOtelTrace, mock_exporter):
with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit:
trace_handler._suggested_question_trace(make_suggested_question_info())
attrs = mock_emit.call_args[1]["attributes"]
assert attrs["dify.suggested_question.count"] == 2
def test_questions_gated_by_include_content(self, trace_handler, mock_exporter):
def test_questions_gated_by_include_content(self, trace_handler: EnterpriseOtelTrace, mock_exporter):
mock_exporter.include_content = False
with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit:
trace_handler._suggested_question_trace(make_suggested_question_info())
@ -1020,7 +1029,7 @@ class TestSuggestedQuestionTrace:
attrs = mock_emit.call_args[1]["attributes"]
assert attrs["dify.suggested_question.questions"].startswith("ref:message_id=")
def test_increments_requests_counter(self, trace_handler, mock_exporter):
def test_increments_requests_counter(self, trace_handler: EnterpriseOtelTrace, mock_exporter):
with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event"):
trace_handler._suggested_question_trace(make_suggested_question_info())
@ -1037,48 +1046,48 @@ class TestSuggestedQuestionTrace:
class TestDatasetRetrievalTrace:
def test_emits_event_with_correct_name(self, trace_handler, mock_exporter):
def test_emits_event_with_correct_name(self, trace_handler: EnterpriseOtelTrace, mock_exporter):
with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit:
trace_handler._dataset_retrieval_trace(make_dataset_retrieval_info())
assert mock_emit.call_args[1]["event_name"] == EnterpriseTelemetryEvent.DATASET_RETRIEVAL
def test_document_count_attribute(self, trace_handler, mock_exporter):
def test_document_count_attribute(self, trace_handler: EnterpriseOtelTrace, mock_exporter):
with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit:
trace_handler._dataset_retrieval_trace(make_dataset_retrieval_info())
attrs = mock_emit.call_args[1]["attributes"]
assert attrs["dify.retrieval.document_count"] == 1
def test_dataset_ids_extracted(self, trace_handler, mock_exporter):
def test_dataset_ids_extracted(self, trace_handler: EnterpriseOtelTrace, mock_exporter):
with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit:
trace_handler._dataset_retrieval_trace(make_dataset_retrieval_info())
attrs = mock_emit.call_args[1]["attributes"]
assert "ds-001" in attrs["dify.dataset.id"]
def test_empty_documents_has_zero_count(self, trace_handler, mock_exporter):
def test_empty_documents_has_zero_count(self, trace_handler: EnterpriseOtelTrace, mock_exporter):
with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit:
trace_handler._dataset_retrieval_trace(make_dataset_retrieval_info(documents=[]))
attrs = mock_emit.call_args[1]["attributes"]
assert attrs["dify.retrieval.document_count"] == 0
def test_status_succeeded_when_no_error(self, trace_handler, mock_exporter):
def test_status_succeeded_when_no_error(self, trace_handler: EnterpriseOtelTrace, mock_exporter):
with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit:
trace_handler._dataset_retrieval_trace(make_dataset_retrieval_info())
attrs = mock_emit.call_args[1]["attributes"]
assert attrs["dify.retrieval.status"] == "succeeded"
def test_status_failed_when_error_present(self, trace_handler, mock_exporter):
def test_status_failed_when_error_present(self, trace_handler: EnterpriseOtelTrace, mock_exporter):
with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit:
trace_handler._dataset_retrieval_trace(make_dataset_retrieval_info(error="DB error"))
attrs = mock_emit.call_args[1]["attributes"]
assert attrs["dify.retrieval.status"] == "failed"
def test_embedding_model_attributes_set_when_present(self, trace_handler, mock_exporter):
def test_embedding_model_attributes_set_when_present(self, trace_handler: EnterpriseOtelTrace, mock_exporter):
with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit:
trace_handler._dataset_retrieval_trace(make_dataset_retrieval_info())
@ -1086,7 +1095,7 @@ class TestDatasetRetrievalTrace:
assert "dify.dataset.embedding_providers" in attrs
assert "dify.dataset.embedding_models" in attrs
def test_no_embedding_model_attributes_when_not_provided(self, trace_handler, mock_exporter):
def test_no_embedding_model_attributes_when_not_provided(self, trace_handler: EnterpriseOtelTrace, mock_exporter):
with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit:
trace_handler._dataset_retrieval_trace(
make_dataset_retrieval_info(metadata={"app_id": "app-001", "tenant_id": "tenant-abc"})
@ -1096,7 +1105,7 @@ class TestDatasetRetrievalTrace:
assert "dify.dataset.embedding_providers" not in attrs
assert "dify.dataset.embedding_models" not in attrs
def test_rerank_attributes_set_when_present(self, trace_handler, mock_exporter):
def test_rerank_attributes_set_when_present(self, trace_handler: EnterpriseOtelTrace, mock_exporter):
with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit:
trace_handler._dataset_retrieval_trace(
make_dataset_retrieval_info(
@ -1113,7 +1122,7 @@ class TestDatasetRetrievalTrace:
assert attrs["dify.retrieval.rerank_provider"] == "cohere"
assert attrs["dify.retrieval.rerank_model"] == "rerank-english"
def test_no_rerank_attributes_when_not_present(self, trace_handler, mock_exporter):
def test_no_rerank_attributes_when_not_present(self, trace_handler: EnterpriseOtelTrace, mock_exporter):
with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit:
trace_handler._dataset_retrieval_trace(
make_dataset_retrieval_info(metadata={"app_id": "app-001", "tenant_id": "tenant-abc"})
@ -1123,7 +1132,7 @@ class TestDatasetRetrievalTrace:
assert "dify.retrieval.rerank_provider" not in attrs
assert "dify.retrieval.rerank_model" not in attrs
def test_dataset_retrieval_counter_incremented_per_dataset(self, trace_handler, mock_exporter):
def test_dataset_retrieval_counter_incremented_per_dataset(self, trace_handler: EnterpriseOtelTrace, mock_exporter):
with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event"):
trace_handler._dataset_retrieval_trace(make_dataset_retrieval_info())
@ -1135,7 +1144,7 @@ class TestDatasetRetrievalTrace:
assert len(ds_calls) == 1
assert ds_calls[0][0][2]["dataset_id"] == "ds-001"
def test_no_dataset_retrieval_counter_when_no_documents(self, trace_handler, mock_exporter):
def test_no_dataset_retrieval_counter_when_no_documents(self, trace_handler: EnterpriseOtelTrace, mock_exporter):
with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event"):
trace_handler._dataset_retrieval_trace(make_dataset_retrieval_info(documents=[]))
@ -1146,7 +1155,7 @@ class TestDatasetRetrievalTrace:
]
assert len(ds_calls) == 0
def test_query_gated_by_include_content(self, trace_handler, mock_exporter):
def test_query_gated_by_include_content(self, trace_handler: EnterpriseOtelTrace, mock_exporter):
mock_exporter.include_content = False
with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit:
trace_handler._dataset_retrieval_trace(make_dataset_retrieval_info())
@ -1161,34 +1170,34 @@ class TestDatasetRetrievalTrace:
class TestGenerateNameTrace:
def test_emits_event_with_correct_name(self, trace_handler, mock_exporter):
def test_emits_event_with_correct_name(self, trace_handler: EnterpriseOtelTrace, mock_exporter):
with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit:
trace_handler._generate_name_trace(make_generate_name_info())
assert mock_emit.call_args[1]["event_name"] == EnterpriseTelemetryEvent.GENERATE_NAME_EXECUTION
def test_duration_computed_from_timestamps(self, trace_handler, mock_exporter):
def test_duration_computed_from_timestamps(self, trace_handler: EnterpriseOtelTrace, mock_exporter):
with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit:
trace_handler._generate_name_trace(make_generate_name_info())
attrs = mock_emit.call_args[1]["attributes"]
assert attrs["dify.generate_name.duration"] == pytest.approx(5.0)
def test_no_duration_when_timestamps_missing(self, trace_handler, mock_exporter):
def test_no_duration_when_timestamps_missing(self, trace_handler: EnterpriseOtelTrace, mock_exporter):
with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit:
trace_handler._generate_name_trace(make_generate_name_info(start_time=None, end_time=None))
attrs = mock_emit.call_args[1]["attributes"]
assert attrs["dify.generate_name.duration"] is None
def test_status_succeeded_on_success(self, trace_handler, mock_exporter):
def test_status_succeeded_on_success(self, trace_handler: EnterpriseOtelTrace, mock_exporter):
with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit:
trace_handler._generate_name_trace(make_generate_name_info())
attrs = mock_emit.call_args[1]["attributes"]
assert attrs["dify.generate_name.status"] == "succeeded"
def test_status_failed_when_metadata_has_error(self, trace_handler, mock_exporter):
def test_status_failed_when_metadata_has_error(self, trace_handler: EnterpriseOtelTrace, mock_exporter):
with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit:
trace_handler._generate_name_trace(
make_generate_name_info(
@ -1203,7 +1212,7 @@ class TestGenerateNameTrace:
attrs = mock_emit.call_args[1]["attributes"]
assert attrs["dify.generate_name.status"] == "failed"
def test_inputs_and_outputs_gated_by_include_content(self, trace_handler, mock_exporter):
def test_inputs_and_outputs_gated_by_include_content(self, trace_handler: EnterpriseOtelTrace, mock_exporter):
mock_exporter.include_content = False
with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit:
trace_handler._generate_name_trace(make_generate_name_info())
@ -1212,7 +1221,7 @@ class TestGenerateNameTrace:
assert attrs["dify.generate_name.inputs"].startswith("ref:conversation_id=")
assert attrs["dify.generate_name.outputs"].startswith("ref:conversation_id=")
def test_increments_requests_counter(self, trace_handler, mock_exporter):
def test_increments_requests_counter(self, trace_handler: EnterpriseOtelTrace, mock_exporter):
with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event"):
trace_handler._generate_name_trace(make_generate_name_info())
@ -1229,27 +1238,27 @@ class TestGenerateNameTrace:
class TestPromptGenerationTrace:
def test_emits_event_with_correct_name(self, trace_handler, mock_exporter):
def test_emits_event_with_correct_name(self, trace_handler: EnterpriseOtelTrace, mock_exporter):
with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit:
trace_handler._prompt_generation_trace(make_prompt_generation_info())
assert mock_emit.call_args[1]["event_name"] == EnterpriseTelemetryEvent.PROMPT_GENERATION_EXECUTION
def test_status_succeeded_on_success(self, trace_handler, mock_exporter):
def test_status_succeeded_on_success(self, trace_handler: EnterpriseOtelTrace, mock_exporter):
with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit:
trace_handler._prompt_generation_trace(make_prompt_generation_info())
attrs = mock_emit.call_args[1]["attributes"]
assert attrs["dify.prompt_generation.status"] == "succeeded"
def test_status_failed_when_error_present(self, trace_handler, mock_exporter):
def test_status_failed_when_error_present(self, trace_handler: EnterpriseOtelTrace, mock_exporter):
with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit:
trace_handler._prompt_generation_trace(make_prompt_generation_info(error="Generation error"))
attrs = mock_emit.call_args[1]["attributes"]
assert attrs["dify.prompt_generation.status"] == "failed"
def test_token_counters_incremented(self, trace_handler, mock_exporter):
def test_token_counters_incremented(self, trace_handler: EnterpriseOtelTrace, mock_exporter):
with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event"):
trace_handler._prompt_generation_trace(make_prompt_generation_info())
@ -1258,7 +1267,7 @@ class TestPromptGenerationTrace:
assert EnterpriseTelemetryCounter.INPUT_TOKENS in counter_names
assert EnterpriseTelemetryCounter.OUTPUT_TOKENS in counter_names
def test_records_duration_histogram(self, trace_handler, mock_exporter):
def test_records_duration_histogram(self, trace_handler: EnterpriseOtelTrace, mock_exporter):
with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event"):
trace_handler._prompt_generation_trace(make_prompt_generation_info())
@ -1270,7 +1279,7 @@ class TestPromptGenerationTrace:
assert len(hist_calls) == 1
assert hist_calls[0][0][1] == pytest.approx(3.2)
def test_total_price_attribute_set_when_present(self, trace_handler, mock_exporter):
def test_total_price_attribute_set_when_present(self, trace_handler: EnterpriseOtelTrace, mock_exporter):
with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit:
trace_handler._prompt_generation_trace(make_prompt_generation_info(total_price=0.05, currency="USD"))
@ -1278,14 +1287,14 @@ class TestPromptGenerationTrace:
assert attrs["dify.prompt_generation.total_price"] == pytest.approx(0.05)
assert attrs["dify.prompt_generation.currency"] == "USD"
def test_no_total_price_attribute_when_none(self, trace_handler, mock_exporter):
def test_no_total_price_attribute_when_none(self, trace_handler: EnterpriseOtelTrace, mock_exporter):
with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit:
trace_handler._prompt_generation_trace(make_prompt_generation_info(total_price=None))
attrs = mock_emit.call_args[1]["attributes"]
assert "dify.prompt_generation.total_price" not in attrs
def test_error_increments_error_counter(self, trace_handler, mock_exporter):
def test_error_increments_error_counter(self, trace_handler: EnterpriseOtelTrace, mock_exporter):
with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event"):
trace_handler._prompt_generation_trace(make_prompt_generation_info(error="Prompt failed"))
@ -1294,7 +1303,7 @@ class TestPromptGenerationTrace:
]
assert len(error_calls) == 1
def test_no_error_counter_on_success(self, trace_handler, mock_exporter):
def test_no_error_counter_on_success(self, trace_handler: EnterpriseOtelTrace, mock_exporter):
with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event"):
trace_handler._prompt_generation_trace(make_prompt_generation_info())
@ -1303,7 +1312,7 @@ class TestPromptGenerationTrace:
]
assert len(error_calls) == 0
def test_instruction_gated_by_include_content(self, trace_handler, mock_exporter):
def test_instruction_gated_by_include_content(self, trace_handler: EnterpriseOtelTrace, mock_exporter):
mock_exporter.include_content = False
with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit:
trace_handler._prompt_generation_trace(make_prompt_generation_info())
@ -1311,7 +1320,7 @@ class TestPromptGenerationTrace:
attrs = mock_emit.call_args[1]["attributes"]
assert attrs["dify.prompt_generation.instruction"].startswith("ref:trace_id=")
def test_operation_type_label_used_in_token_counters(self, trace_handler, mock_exporter):
def test_operation_type_label_used_in_token_counters(self, trace_handler: EnterpriseOtelTrace, mock_exporter):
with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event"):
trace_handler._prompt_generation_trace(make_prompt_generation_info(operation_type="code_generate"))
@ -1321,7 +1330,7 @@ class TestPromptGenerationTrace:
assert len(token_calls) == 1
assert token_calls[0][0][2]["operation_type"] == "code_generate"
def test_emits_correct_tenant_id(self, trace_handler, mock_exporter):
def test_emits_correct_tenant_id(self, trace_handler: EnterpriseOtelTrace, mock_exporter):
with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit:
trace_handler._prompt_generation_trace(make_prompt_generation_info())