diff --git a/api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_tencent_trace.py b/api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_tencent_trace.py index 2bdf7cd1862..4d4898b1173 100644 --- a/api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_tencent_trace.py +++ b/api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_tencent_trace.py @@ -181,29 +181,33 @@ class TestTencentDataTrace: mock_trace_utils.convert_to_trace_id.return_value = 123 mock_trace_utils.create_link.return_value = "link" - with patch.object(tencent_data_trace, "_get_user_id", return_value="user-1"): - with patch.object(tencent_data_trace, "_process_workflow_nodes") as mock_proc: - with patch.object(tencent_data_trace, "_record_workflow_trace_duration") as mock_dur: - mock_span_builder.build_workflow_spans.return_value = [MagicMock(), MagicMock()] + with ( + patch.object(tencent_data_trace, "_get_user_id", return_value="user-1"), + patch.object(tencent_data_trace, "_process_workflow_nodes") as mock_proc, + patch.object(tencent_data_trace, "_record_workflow_trace_duration") as mock_dur, + ): + mock_span_builder.build_workflow_spans.return_value = [MagicMock(), MagicMock()] - tencent_data_trace.workflow_trace(trace_info) + tencent_data_trace.workflow_trace(trace_info) - mock_trace_utils.convert_to_trace_id.assert_called_once_with("run-id") - mock_trace_utils.create_link.assert_called_once_with("parent-trace-id") - mock_span_builder.build_workflow_spans.assert_called_once() - assert tencent_data_trace.trace_client.add_span.call_count == 2 - mock_proc.assert_called_once_with(trace_info, 123) - mock_dur.assert_called_once_with(trace_info) + mock_trace_utils.convert_to_trace_id.assert_called_once_with("run-id") + mock_trace_utils.create_link.assert_called_once_with("parent-trace-id") + mock_span_builder.build_workflow_spans.assert_called_once() + assert tencent_data_trace.trace_client.add_span.call_count == 2 + mock_proc.assert_called_once_with(trace_info, 123) + mock_dur.assert_called_once_with(trace_info) - def test_workflow_trace_exception(self, tencent_data_trace, caplog): + def test_workflow_trace_exception(self, tencent_data_trace, caplog: pytest.LogCaptureFixture): trace_info = MagicMock(spec=WorkflowTraceInfo) trace_info.workflow_run_id = "run-id" - with patch( - "dify_trace_tencent.tencent_trace.TencentTraceUtils.convert_to_trace_id", side_effect=Exception("error") + with ( + patch( + "dify_trace_tencent.tencent_trace.TencentTraceUtils.convert_to_trace_id", side_effect=Exception("error") + ), + caplog.at_level(logging.ERROR), ): - with caplog.at_level(logging.ERROR): - tencent_data_trace.workflow_trace(trace_info) + tencent_data_trace.workflow_trace(trace_info) assert "[Tencent APM] Failed to process workflow trace" in caplog.text def test_message_trace(self, tencent_data_trace, mock_trace_utils, mock_span_builder): @@ -214,28 +218,32 @@ class TestTencentDataTrace: mock_trace_utils.convert_to_trace_id.return_value = 123 mock_trace_utils.create_link.return_value = "link" - with patch.object(tencent_data_trace, "_get_user_id", return_value="user-1"): - with patch.object(tencent_data_trace, "_record_message_llm_metrics") as mock_metrics: - with patch.object(tencent_data_trace, "_record_message_trace_duration") as mock_dur: - mock_span_builder.build_message_span.return_value = MagicMock() + with ( + patch.object(tencent_data_trace, "_get_user_id", return_value="user-1"), + patch.object(tencent_data_trace, "_record_message_llm_metrics") as mock_metrics, + patch.object(tencent_data_trace, "_record_message_trace_duration") as mock_dur, + ): + mock_span_builder.build_message_span.return_value = MagicMock() - tencent_data_trace.message_trace(trace_info) + tencent_data_trace.message_trace(trace_info) - mock_trace_utils.convert_to_trace_id.assert_called_once_with("msg-id") - mock_trace_utils.create_link.assert_called_once_with("parent-trace-id") - mock_span_builder.build_message_span.assert_called_once() - tencent_data_trace.trace_client.add_span.assert_called_once() - mock_metrics.assert_called_once_with(trace_info) - mock_dur.assert_called_once_with(trace_info) + mock_trace_utils.convert_to_trace_id.assert_called_once_with("msg-id") + mock_trace_utils.create_link.assert_called_once_with("parent-trace-id") + mock_span_builder.build_message_span.assert_called_once() + tencent_data_trace.trace_client.add_span.assert_called_once() + mock_metrics.assert_called_once_with(trace_info) + mock_dur.assert_called_once_with(trace_info) - def test_message_trace_exception(self, tencent_data_trace, caplog): + def test_message_trace_exception(self, tencent_data_trace, caplog: pytest.LogCaptureFixture): trace_info = MagicMock(spec=MessageTraceInfo) - with patch( - "dify_trace_tencent.tencent_trace.TencentTraceUtils.convert_to_trace_id", side_effect=Exception("error") + with ( + patch( + "dify_trace_tencent.tencent_trace.TencentTraceUtils.convert_to_trace_id", side_effect=Exception("error") + ), + caplog.at_level(logging.ERROR), ): - with caplog.at_level(logging.ERROR): - tencent_data_trace.message_trace(trace_info) + tencent_data_trace.message_trace(trace_info) assert "[Tencent APM] Failed to process message trace" in caplog.text def test_tool_trace(self, tencent_data_trace, mock_trace_utils, mock_span_builder): @@ -259,15 +267,17 @@ class TestTencentDataTrace: tencent_data_trace.tool_trace(trace_info) tencent_data_trace.trace_client.add_span.assert_not_called() - def test_tool_trace_exception(self, tencent_data_trace, caplog): + def test_tool_trace_exception(self, tencent_data_trace, caplog: pytest.LogCaptureFixture): trace_info = MagicMock(spec=ToolTraceInfo) trace_info.message_id = "msg-id" - with patch( - "dify_trace_tencent.tencent_trace.TencentTraceUtils.convert_to_span_id", side_effect=Exception("error") + with ( + patch( + "dify_trace_tencent.tencent_trace.TencentTraceUtils.convert_to_span_id", side_effect=Exception("error") + ), + caplog.at_level(logging.ERROR), ): - with caplog.at_level(logging.ERROR): - tencent_data_trace.tool_trace(trace_info) + tencent_data_trace.tool_trace(trace_info) assert "[Tencent APM] Failed to process tool trace" in caplog.text def test_dataset_retrieval_trace(self, tencent_data_trace, mock_trace_utils, mock_span_builder): @@ -291,24 +301,28 @@ class TestTencentDataTrace: tencent_data_trace.dataset_retrieval_trace(trace_info) tencent_data_trace.trace_client.add_span.assert_not_called() - def test_dataset_retrieval_trace_exception(self, tencent_data_trace, caplog): + def test_dataset_retrieval_trace_exception(self, tencent_data_trace, caplog: pytest.LogCaptureFixture): trace_info = MagicMock(spec=DatasetRetrievalTraceInfo) trace_info.message_id = "msg-id" - with patch( - "dify_trace_tencent.tencent_trace.TencentTraceUtils.convert_to_span_id", side_effect=Exception("error") + with ( + patch( + "dify_trace_tencent.tencent_trace.TencentTraceUtils.convert_to_span_id", side_effect=Exception("error") + ), + caplog.at_level(logging.ERROR), ): - with caplog.at_level(logging.ERROR): - tencent_data_trace.dataset_retrieval_trace(trace_info) + tencent_data_trace.dataset_retrieval_trace(trace_info) assert "[Tencent APM] Failed to process dataset retrieval trace" in caplog.text - def test_suggested_question_trace(self, tencent_data_trace, caplog): + def test_suggested_question_trace(self, tencent_data_trace, caplog: pytest.LogCaptureFixture): trace_info = MagicMock(spec=SuggestedQuestionTraceInfo) with caplog.at_level(logging.INFO): tencent_data_trace.suggested_question_trace(trace_info) assert "[Tencent APM] Processing suggested question trace" in caplog.text - def test_suggested_question_trace_exception(self, tencent_data_trace, monkeypatch, caplog): + def test_suggested_question_trace_exception( + self, tencent_data_trace, monkeypatch, caplog: pytest.LogCaptureFixture + ): trace_info = MagicMock(spec=SuggestedQuestionTraceInfo) target_logger = logging.getLogger("dify_trace_tencent.tencent_trace") monkeypatch.setattr(target_logger, "info", MagicMock(side_effect=Exception("error"))) @@ -328,28 +342,36 @@ class TestTencentDataTrace: node2.id = "n2" node2.node_type = BuiltinNodeTypes.TOOL - with patch.object(tencent_data_trace, "_get_workflow_node_executions", return_value=[node1, node2]): - with patch.object(tencent_data_trace, "_build_workflow_node_span", side_effect=["span1", "span2"]): - with patch.object(tencent_data_trace, "_record_llm_metrics") as mock_metrics: - tencent_data_trace._process_workflow_nodes(trace_info, 123) + with ( + patch.object(tencent_data_trace, "_get_workflow_node_executions", return_value=[node1, node2]), + patch.object(tencent_data_trace, "_build_workflow_node_span", side_effect=["span1", "span2"]), + patch.object(tencent_data_trace, "_record_llm_metrics") as mock_metrics, + ): + tencent_data_trace._process_workflow_nodes(trace_info, 123) - assert tencent_data_trace.trace_client.add_span.call_count == 2 - mock_metrics.assert_called_once_with(node1) + assert tencent_data_trace.trace_client.add_span.call_count == 2 + mock_metrics.assert_called_once_with(node1) - def test_process_workflow_nodes_node_exception(self, tencent_data_trace, mock_trace_utils, caplog): + def test_process_workflow_nodes_node_exception( + self, tencent_data_trace, mock_trace_utils, caplog: pytest.LogCaptureFixture + ): trace_info = MagicMock(spec=WorkflowTraceInfo) mock_trace_utils.convert_to_span_id.return_value = 111 node = MagicMock(spec=WorkflowNodeExecution) node.id = "n1" - with patch.object(tencent_data_trace, "_get_workflow_node_executions", return_value=[node]): - with patch.object(tencent_data_trace, "_build_workflow_node_span", side_effect=Exception("node error")): - with caplog.at_level(logging.ERROR): - tencent_data_trace._process_workflow_nodes(trace_info, 123) + with ( + patch.object(tencent_data_trace, "_get_workflow_node_executions", return_value=[node]), + patch.object(tencent_data_trace, "_build_workflow_node_span", side_effect=Exception("node error")), + caplog.at_level(logging.ERROR), + ): + tencent_data_trace._process_workflow_nodes(trace_info, 123) assert "[Tencent APM] Failed to process workflow nodes" in caplog.text - def test_process_workflow_nodes_exception(self, tencent_data_trace, mock_trace_utils, caplog): + def test_process_workflow_nodes_exception( + self, tencent_data_trace, mock_trace_utils, caplog: pytest.LogCaptureFixture + ): trace_info = MagicMock(spec=WorkflowTraceInfo) mock_trace_utils.convert_to_span_id.side_effect = Exception("outer error") @@ -377,7 +399,9 @@ class TestTencentDataTrace: assert result == "span" builder_method.assert_called_once_with(123, 456, trace_info, node) - def test_build_workflow_node_span_exception(self, tencent_data_trace, mock_span_builder, caplog): + def test_build_workflow_node_span_exception( + self, tencent_data_trace, mock_span_builder, caplog: pytest.LogCaptureFixture + ): node = MagicMock(spec=WorkflowNodeExecution) node.node_type = BuiltinNodeTypes.LLM node.id = "n1" @@ -419,7 +443,7 @@ class TestTencentDataTrace: assert results == mock_executions account.set_tenant_id.assert_called_once_with("tenant-1") - def test_get_workflow_node_executions_no_app_id(self, tencent_data_trace, caplog): + def test_get_workflow_node_executions_no_app_id(self, tencent_data_trace, caplog: pytest.LogCaptureFixture): trace_info = MagicMock(spec=WorkflowTraceInfo) trace_info.metadata = {} @@ -428,7 +452,7 @@ class TestTencentDataTrace: assert results == [] assert len([r for r in caplog.records if r.levelno == logging.ERROR]) >= 1 - def test_get_workflow_node_executions_app_not_found(self, tencent_data_trace, caplog): + def test_get_workflow_node_executions_app_not_found(self, tencent_data_trace, caplog: pytest.LogCaptureFixture): trace_info = MagicMock(spec=WorkflowTraceInfo) trace_info.metadata = {"app_id": "app-1"} @@ -449,13 +473,15 @@ class TestTencentDataTrace: trace_info.tenant_id = "tenant-1" trace_info.metadata = {"user_id": "user-1"} - with patch("dify_trace_tencent.tencent_trace.sessionmaker", side_effect=Exception("Database error")): - with patch("dify_trace_tencent.tencent_trace.db") as mock_db: - mock_db.init_app = MagicMock() - mock_db.engine = MagicMock() + with ( + patch("dify_trace_tencent.tencent_trace.sessionmaker", side_effect=Exception("Database error")), + patch("dify_trace_tencent.tencent_trace.db") as mock_db, + ): + mock_db.init_app = MagicMock() + mock_db.engine = MagicMock() - user_id = tencent_data_trace._get_user_id(trace_info) - assert user_id == "unknown" + user_id = tencent_data_trace._get_user_id(trace_info) + assert user_id == "unknown" def test_get_user_id_only_user_id(self, tencent_data_trace): trace_info = MagicMock(spec=MessageTraceInfo) @@ -471,14 +497,16 @@ class TestTencentDataTrace: user_id = tencent_data_trace._get_user_id(trace_info) assert user_id == "anonymous" - def test_get_user_id_exception(self, tencent_data_trace, caplog): + def test_get_user_id_exception(self, tencent_data_trace, caplog: pytest.LogCaptureFixture): trace_info = MagicMock(spec=WorkflowTraceInfo) trace_info.tenant_id = "t" trace_info.metadata = {"user_id": "u"} - with patch("dify_trace_tencent.tencent_trace.sessionmaker", side_effect=Exception("error")): - with caplog.at_level(logging.ERROR): - user_id = tencent_data_trace._get_user_id(trace_info) + with ( + patch("dify_trace_tencent.tencent_trace.sessionmaker", side_effect=Exception("error")), + caplog.at_level(logging.ERROR), + ): + user_id = tencent_data_trace._get_user_id(trace_info) assert user_id == "unknown" assert "[Tencent APM] Failed to get user ID" in caplog.text @@ -514,7 +542,7 @@ class TestTencentDataTrace: tencent_data_trace.trace_client.record_llm_duration.assert_called_once() tencent_data_trace.trace_client.record_token_usage.assert_called_once() - def test_record_llm_metrics_exception(self, tencent_data_trace, caplog): + def test_record_llm_metrics_exception(self, tencent_data_trace, caplog: pytest.LogCaptureFixture): node = MagicMock(spec=WorkflowNodeExecution) node.process_data = None node.outputs = None @@ -553,7 +581,7 @@ class TestTencentDataTrace: tencent_data_trace._record_message_llm_metrics(trace_info) tencent_data_trace.trace_client.record_llm_duration.assert_called_once() - def test_record_message_llm_metrics_exception(self, tencent_data_trace, caplog): + def test_record_message_llm_metrics_exception(self, tencent_data_trace, caplog: pytest.LogCaptureFixture): trace_info = MagicMock(spec=MessageTraceInfo) trace_info.metadata = None @@ -605,7 +633,7 @@ class TestTencentDataTrace: attributes = kwargs["attributes"] if "attributes" in kwargs else args[1] if len(args) > 1 else {} assert attributes["has_conversation"] == "false" - def test_record_workflow_trace_duration_exception(self, tencent_data_trace, caplog): + def test_record_workflow_trace_duration_exception(self, tencent_data_trace, caplog: pytest.LogCaptureFixture): trace_info = MagicMock(spec=WorkflowTraceInfo) trace_info.start_time = MagicMock() # This might cause total_seconds() to fail if not mocked right @@ -627,7 +655,7 @@ class TestTencentDataTrace: 2.0, {"conversation_mode": "chat", "stream": "true"} ) - def test_record_message_trace_duration_exception(self, tencent_data_trace, caplog): + def test_record_message_trace_duration_exception(self, tencent_data_trace, caplog: pytest.LogCaptureFixture): trace_info = MagicMock(spec=MessageTraceInfo) trace_info.start_time = None @@ -647,7 +675,7 @@ class TestTencentDataTrace: client.shutdown.assert_called_once() - def test_close_exception(self, tencent_data_trace, caplog): + def test_close_exception(self, tencent_data_trace, caplog: pytest.LogCaptureFixture): tencent_data_trace.trace_client.shutdown.side_effect = Exception("error") with caplog.at_level(logging.ERROR): tencent_data_trace.close() diff --git a/api/providers/vdb/vdb-weaviate/tests/unit_tests/test_weaviate_vector.py b/api/providers/vdb/vdb-weaviate/tests/unit_tests/test_weaviate_vector.py index 8aea8634fbe..79c06ea6028 100644 --- a/api/providers/vdb/vdb-weaviate/tests/unit_tests/test_weaviate_vector.py +++ b/api/providers/vdb/vdb-weaviate/tests/unit_tests/test_weaviate_vector.py @@ -268,9 +268,11 @@ class TestWeaviateVector(unittest.TestCase): wv._client = MagicMock() wv._client.collections.exists.side_effect = RuntimeError("create failed") - with patch.object(weaviate_vector_module.logger, "exception") as mock_exception: - with pytest.raises(RuntimeError, match="create failed"): - wv._create_collection() + with ( + patch.object(weaviate_vector_module.logger, "exception") as mock_exception, + pytest.raises(RuntimeError, match="create failed"), + ): + wv._create_collection() mock_exception.assert_called_once() @@ -835,9 +837,11 @@ class TestWeaviateVector(unittest.TestCase): wv._client.collections.use.return_value = mock_col mock_col.data.delete_by_id.side_effect = FakeUnexpectedStatusCodeError(500) - with patch.object(weaviate_vector_module, "UnexpectedStatusCodeError", FakeUnexpectedStatusCodeError): - with pytest.raises(FakeUnexpectedStatusCodeError, match="status=500"): - wv.delete_by_ids(["bad-id"]) + with ( + patch.object(weaviate_vector_module, "UnexpectedStatusCodeError", FakeUnexpectedStatusCodeError), + pytest.raises(FakeUnexpectedStatusCodeError, match="status=500"), + ): + wv.delete_by_ids(["bad-id"]) def test_json_serializable_converts_datetime(self): wv = WeaviateVector.__new__(WeaviateVector) diff --git a/api/tests/unit_tests/controllers/common/test_fields.py b/api/tests/unit_tests/controllers/common/test_fields.py index d4dc13127d6..595b3d93356 100644 --- a/api/tests/unit_tests/controllers/common/test_fields.py +++ b/api/tests/unit_tests/controllers/common/test_fields.py @@ -1,5 +1,4 @@ import builtins -from types import SimpleNamespace from unittest.mock import patch from flask.views import MethodView as FlaskMethodView @@ -22,7 +21,7 @@ def test_parameters_model_round_trip(): def test_site_icon_url_uses_signed_url_for_image_icon(): - site = SimpleNamespace( + site = Site( title="Example", chat_color_theme=None, chat_color_theme_inverted=False, @@ -46,7 +45,7 @@ def test_site_icon_url_uses_signed_url_for_image_icon(): def test_site_icon_url_is_none_for_non_image_icon(): - site = SimpleNamespace( + site = Site( title="Example", chat_color_theme=None, chat_color_theme_inverted=False, diff --git a/api/tests/unit_tests/controllers/console/app/test_workflow_convert_api.py b/api/tests/unit_tests/controllers/console/app/test_workflow_convert_api.py index dd254a31f63..942698db3bb 100644 --- a/api/tests/unit_tests/controllers/console/app/test_workflow_convert_api.py +++ b/api/tests/unit_tests/controllers/console/app/test_workflow_convert_api.py @@ -2,21 +2,14 @@ from __future__ import annotations +from inspect import unwrap from types import SimpleNamespace import pytest from flask import Flask from controllers.console.app import workflow as workflow_module - - -def _unwrap(func): - bound_self = getattr(func, "__self__", None) - while hasattr(func, "__wrapped__"): - func = func.__wrapped__ - if bound_self is not None: - return func.__get__(bound_self, bound_self.__class__) - return func +from controllers.console.app.workflow import ConvertToWorkflowApi class TestConvertToWorkflowApi: @@ -25,9 +18,9 @@ class TestConvertToWorkflowApi: return workflow_module.ConvertToWorkflowApi() def test_convert_to_workflow_attaches_permission_keys_when_rbac_enabled( - self, api, app: Flask, monkeypatch: pytest.MonkeyPatch + self, api: ConvertToWorkflowApi, app: Flask, monkeypatch: pytest.MonkeyPatch ) -> None: - method = _unwrap(api.post) + method = unwrap(api.post) monkeypatch.setattr( workflow_module, @@ -46,6 +39,7 @@ class TestConvertToWorkflowApi: json={}, ): response = method( + api, current_tenant_id="tenant-1", current_user=SimpleNamespace(id="u1"), app_model=SimpleNamespace(id="app-1"), diff --git a/api/tests/unit_tests/controllers/console/auth/test_login_logout.py b/api/tests/unit_tests/controllers/console/auth/test_login_logout.py index 2c845673cd1..42495d7f63c 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_login_logout.py +++ b/api/tests/unit_tests/controllers/console/auth/test_login_logout.py @@ -192,7 +192,9 @@ class TestLoginApi: @patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False) @patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit") @patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback") - def test_login_fails_when_rate_limited(self, mock_get_invitation, mock_is_rate_limit, mock_db, app: Flask, caplog): + def test_login_fails_when_rate_limited( + self, mock_get_invitation, mock_is_rate_limit, mock_db, app: Flask, caplog: pytest.LogCaptureFixture + ): """ Test login rejection when rate limit is exceeded. @@ -222,7 +224,9 @@ class TestLoginApi: @patch("controllers.console.wraps.db") @patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", True) @patch("controllers.console.auth.login.BillingService.is_email_in_freeze") - def test_login_fails_when_account_frozen(self, mock_is_frozen, mock_db, app: Flask, caplog): + def test_login_fails_when_account_frozen( + self, mock_is_frozen, mock_db, app: Flask, caplog: pytest.LogCaptureFixture + ): """ Test login rejection for frozen accounts. @@ -262,7 +266,7 @@ class TestLoginApi: mock_is_rate_limit, mock_db, app: Flask, - caplog, + caplog: pytest.LogCaptureFixture, ): """ Test login failure with invalid credentials. @@ -462,7 +466,7 @@ class TestLoginApi: mock_get_token_data: MagicMock, mock_db: MagicMock, app: Flask, - caplog, + caplog: pytest.LogCaptureFixture, ): mock_get_token_data.return_value = {"email": "User@Example.com", "code": "123456"} mock_get_account.side_effect = Unauthorized("Account is banned.") diff --git a/api/tests/unit_tests/controllers/console/snippets/test_snippet_workflow_draft_variable.py b/api/tests/unit_tests/controllers/console/snippets/test_snippet_workflow_draft_variable.py index e6bee6fe1d3..03a6fdb0d60 100644 --- a/api/tests/unit_tests/controllers/console/snippets/test_snippet_workflow_draft_variable.py +++ b/api/tests/unit_tests/controllers/console/snippets/test_snippet_workflow_draft_variable.py @@ -1,3 +1,4 @@ +from inspect import unwrap from types import SimpleNamespace from unittest.mock import Mock @@ -10,12 +11,6 @@ from models.account import Account, AccountStatus from services.workflow_draft_variable_service import WorkflowDraftVariableList -def _unwrap(func): - while hasattr(func, "__wrapped__"): - func = func.__wrapped__ - return func - - def _make_account() -> Account: account = Account( name="tester", @@ -66,7 +61,7 @@ def test_ensure_snippet_draft_variable_row_allowed_accepts_canvas_node_variable( def test_conversation_variables_returns_empty_list(app: Flask): api = module.SnippetConversationVariableCollectionApi() - handler = _unwrap(api.get) + handler = unwrap(api.get) with app.test_request_context("/"): result = handler(api, _make_account(), snippet=SimpleNamespace(id="snippet-1")) @@ -76,7 +71,7 @@ def test_conversation_variables_returns_empty_list(app: Flask): def test_system_variables_returns_empty_list(app: Flask): api = module.SnippetSystemVariableCollectionApi() - handler = _unwrap(api.get) + handler = unwrap(api.get) with app.test_request_context("/"): result = handler(api, _make_account(), snippet=SimpleNamespace(id="snippet-1")) @@ -91,7 +86,7 @@ def test_delete_variable_collection_deletes_current_user_variables(app: Flask, m db_session.return_value = SimpleNamespace() monkeypatch.setattr(module.db, "session", db_session) api = module.SnippetWorkflowVariableCollectionApi() - handler = _unwrap(api.delete) + handler = unwrap(api.delete) with app.test_request_context("/", method="DELETE"): response = handler(api, _make_account(), snippet=SimpleNamespace(id="snippet-1")) @@ -109,7 +104,7 @@ def test_variable_collection_get_raises_when_draft_workflow_missing(app: Flask, ) api = module.SnippetWorkflowVariableCollectionApi() - handler = _unwrap(api.get) + handler = unwrap(api.get) with app.test_request_context("/?page=1&limit=20"): with pytest.raises(module.DraftWorkflowNotExist): @@ -140,7 +135,7 @@ def test_node_variable_collection_get_lists_node_variables(app: Flask, monkeypat ) api = module.SnippetNodeVariableCollectionApi() - handler = _unwrap(api.get) + handler = unwrap(api.get) with app.test_request_context("/"): result = handler(api, _make_account(), snippet=SimpleNamespace(id="snippet-1"), node_id="llm-1") @@ -158,7 +153,7 @@ def test_node_variable_collection_delete_deletes_node_variables(app: Flask, monk monkeypatch.setattr(module.db, "session", db_session) api = module.SnippetNodeVariableCollectionApi() - handler = _unwrap(api.delete) + handler = unwrap(api.delete) with app.test_request_context("/", method="DELETE"): response = handler(api, _make_account(), snippet=SimpleNamespace(id="snippet-1"), node_id="llm-1") @@ -177,7 +172,7 @@ def test_variable_patch_returns_variable_when_no_changes(app: Flask, monkeypatch monkeypatch.setattr(module, "WorkflowDraftVariableService", Mock(return_value=draft_var_service)) api = module.SnippetVariableApi() - handler = _unwrap(api.patch) + handler = unwrap(api.patch) with app.test_request_context("/", method="PATCH", json={}): result = handler( @@ -202,7 +197,7 @@ def test_variable_delete_deletes_variable(app: Flask, monkeypatch: pytest.Monkey monkeypatch.setattr(module, "WorkflowDraftVariableService", Mock(return_value=draft_var_service)) api = module.SnippetVariableApi() - handler = _unwrap(api.delete) + handler = unwrap(api.delete) with app.test_request_context("/", method="DELETE"): response = handler(api, _make_account(), snippet=SimpleNamespace(id="snippet-1"), variable_id="var-1") @@ -230,7 +225,7 @@ def test_variable_reset_returns_no_content_when_reset_result_is_none(app: Flask, ) api = module.SnippetVariableResetApi() - handler = _unwrap(api.put) + handler = unwrap(api.put) with app.test_request_context("/", method="PUT"): response = handler(api, _make_account(), snippet=SimpleNamespace(id="snippet-1"), variable_id="var-1") @@ -260,7 +255,7 @@ def test_environment_variables_returns_workflow_environment_variables(app: Flask ) api = module.SnippetEnvironmentVariableCollectionApi() - handler = _unwrap(api.get) + handler = unwrap(api.get) with app.test_request_context("/"): result = handler(api, _make_account(), snippet=SimpleNamespace(id="snippet-1")) diff --git a/api/tests/unit_tests/enterprise/telemetry/test_enterprise_trace.py b/api/tests/unit_tests/enterprise/telemetry/test_enterprise_trace.py index 1ce9581aa1c..24c905c75c2 100644 --- a/api/tests/unit_tests/enterprise/telemetry/test_enterprise_trace.py +++ b/api/tests/unit_tests/enterprise/telemetry/test_enterprise_trace.py @@ -21,6 +21,7 @@ from core.ops.entities.trace_entity import ( WorkflowNodeTraceInfo, WorkflowTraceInfo, ) +from enterprise.telemetry.enterprise_trace import EnterpriseOtelTrace from enterprise.telemetry.entities import ( EnterpriseTelemetryCounter, EnterpriseTelemetryEvent, @@ -297,43 +298,43 @@ def test_init_succeeds_with_valid_exporter(mock_exporter): class TestSafePayloadValue: - def test_string_passthrough(self, trace_handler): + def test_string_passthrough(self, trace_handler: EnterpriseOtelTrace): assert trace_handler._safe_payload_value("hello") == "hello" - def test_dict_passthrough(self, trace_handler): + def test_dict_passthrough(self, trace_handler: EnterpriseOtelTrace): d = {"key": "val"} assert trace_handler._safe_payload_value(d) == d - def test_list_passthrough(self, trace_handler): + def test_list_passthrough(self, trace_handler: EnterpriseOtelTrace): lst = [1, 2, 3] assert trace_handler._safe_payload_value(lst) == lst - def test_none_returns_none(self, trace_handler): + def test_none_returns_none(self, trace_handler: EnterpriseOtelTrace): assert trace_handler._safe_payload_value(None) is None - def test_int_returns_none(self, trace_handler): + def test_int_returns_none(self, trace_handler: EnterpriseOtelTrace): assert trace_handler._safe_payload_value(42) is None - def test_bool_returns_none(self, trace_handler): + def test_bool_returns_none(self, trace_handler: EnterpriseOtelTrace): assert trace_handler._safe_payload_value(True) is None class TestMaybeJson: - def test_none_returns_none(self, trace_handler): + def test_none_returns_none(self, trace_handler: EnterpriseOtelTrace): assert trace_handler._maybe_json(None) is None - def test_string_passthrough(self, trace_handler): + def test_string_passthrough(self, trace_handler: EnterpriseOtelTrace): assert trace_handler._maybe_json("hello") == "hello" - def test_dict_serialised(self, trace_handler): + def test_dict_serialised(self, trace_handler: EnterpriseOtelTrace): result = trace_handler._maybe_json({"a": 1}) assert result == json.dumps({"a": 1}) - def test_list_serialised(self, trace_handler): + def test_list_serialised(self, trace_handler: EnterpriseOtelTrace): result = trace_handler._maybe_json([1, 2]) assert result == "[1, 2]" - def test_non_serialisable_falls_back_to_str(self, trace_handler): + def test_non_serialisable_falls_back_to_str(self, trace_handler: EnterpriseOtelTrace): class Unserializable: def __repr__(self): return "Unserializable()" @@ -344,22 +345,22 @@ class TestMaybeJson: class TestContentOrRef: - def test_returns_content_when_include_content_true(self, trace_handler, mock_exporter): + def test_returns_content_when_include_content_true(self, trace_handler: EnterpriseOtelTrace, mock_exporter): mock_exporter.include_content = True result = trace_handler._content_or_ref("actual content", "ref:x=1") assert result == "actual content" - def test_returns_ref_when_include_content_false(self, trace_handler, mock_exporter): + def test_returns_ref_when_include_content_false(self, trace_handler: EnterpriseOtelTrace, mock_exporter): mock_exporter.include_content = False result = trace_handler._content_or_ref("actual content", "ref:x=1") assert result == "ref:x=1" - def test_dict_serialised_when_include_content_true(self, trace_handler, mock_exporter): + def test_dict_serialised_when_include_content_true(self, trace_handler: EnterpriseOtelTrace, mock_exporter): mock_exporter.include_content = True result = trace_handler._content_or_ref({"key": "val"}, "ref:x=1") assert result == json.dumps({"key": "val"}) - def test_none_returns_none_when_include_content_true(self, trace_handler, mock_exporter): + def test_none_returns_none_when_include_content_true(self, trace_handler: EnterpriseOtelTrace, mock_exporter): mock_exporter.include_content = True result = trace_handler._content_or_ref(None, "ref:x=1") assert result is None @@ -371,67 +372,67 @@ class TestContentOrRef: class TestTraceDispatcher: - def test_dispatches_workflow_trace(self, trace_handler): + def test_dispatches_workflow_trace(self, trace_handler: EnterpriseOtelTrace): with patch.object(trace_handler, "_workflow_trace") as mock_method: info = make_workflow_info() trace_handler.trace(info) mock_method.assert_called_once_with(info) - def test_dispatches_message_trace(self, trace_handler): + def test_dispatches_message_trace(self, trace_handler: EnterpriseOtelTrace): with patch.object(trace_handler, "_message_trace") as mock_method: info = make_message_info() trace_handler.trace(info) mock_method.assert_called_once_with(info) - def test_dispatches_tool_trace(self, trace_handler): + def test_dispatches_tool_trace(self, trace_handler: EnterpriseOtelTrace): with patch.object(trace_handler, "_tool_trace") as mock_method: info = make_tool_info() trace_handler.trace(info) mock_method.assert_called_once_with(info) - def test_dispatches_draft_node_execution_trace(self, trace_handler): + def test_dispatches_draft_node_execution_trace(self, trace_handler: EnterpriseOtelTrace): with patch.object(trace_handler, "_draft_node_execution_trace") as mock_method: info = make_draft_node_info() trace_handler.trace(info) mock_method.assert_called_once_with(info) - def test_dispatches_node_execution_trace(self, trace_handler): + def test_dispatches_node_execution_trace(self, trace_handler: EnterpriseOtelTrace): with patch.object(trace_handler, "_node_execution_trace") as mock_method: info = make_node_info() trace_handler.trace(info) mock_method.assert_called_once_with(info) - def test_dispatches_moderation_trace(self, trace_handler): + def test_dispatches_moderation_trace(self, trace_handler: EnterpriseOtelTrace): with patch.object(trace_handler, "_moderation_trace") as mock_method: info = make_moderation_info() trace_handler.trace(info) mock_method.assert_called_once_with(info) - def test_dispatches_suggested_question_trace(self, trace_handler): + def test_dispatches_suggested_question_trace(self, trace_handler: EnterpriseOtelTrace): with patch.object(trace_handler, "_suggested_question_trace") as mock_method: info = make_suggested_question_info() trace_handler.trace(info) mock_method.assert_called_once_with(info) - def test_dispatches_dataset_retrieval_trace(self, trace_handler): + def test_dispatches_dataset_retrieval_trace(self, trace_handler: EnterpriseOtelTrace): with patch.object(trace_handler, "_dataset_retrieval_trace") as mock_method: info = make_dataset_retrieval_info() trace_handler.trace(info) mock_method.assert_called_once_with(info) - def test_dispatches_generate_name_trace(self, trace_handler): + def test_dispatches_generate_name_trace(self, trace_handler: EnterpriseOtelTrace): with patch.object(trace_handler, "_generate_name_trace") as mock_method: info = make_generate_name_info() trace_handler.trace(info) mock_method.assert_called_once_with(info) - def test_dispatches_prompt_generation_trace(self, trace_handler): + def test_dispatches_prompt_generation_trace(self, trace_handler: EnterpriseOtelTrace): with patch.object(trace_handler, "_prompt_generation_trace") as mock_method: info = make_prompt_generation_info() trace_handler.trace(info) mock_method.assert_called_once_with(info) - def test_draft_node_dispatched_before_node(self, trace_handler): + def test_draft_node_dispatched_before_node(self, trace_handler: EnterpriseOtelTrace): """DraftNodeExecutionTrace is a subclass of WorkflowNodeTraceInfo; it must be dispatched to _draft_node_execution_trace, not _node_execution_trace.""" with ( @@ -450,7 +451,7 @@ class TestTraceDispatcher: class TestWorkflowTrace: - def test_emits_correct_span_attributes(self, trace_handler, mock_exporter): + def test_emits_correct_span_attributes(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log") as mock_log: info = make_workflow_info() trace_handler._workflow_trace(info) @@ -465,7 +466,7 @@ class TestWorkflowTrace: assert attrs["dify.workflow.status"] == "succeeded" assert attrs["gen_ai.usage.total_tokens"] == 100 - def test_span_timing_passed_correctly(self, trace_handler, mock_exporter): + def test_span_timing_passed_correctly(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"): info = make_workflow_info() trace_handler._workflow_trace(info) @@ -474,7 +475,7 @@ class TestWorkflowTrace: assert span_call[1]["start_time"] == _T0 assert span_call[1]["end_time"] == _T1 - def test_emits_companion_log_with_event_name(self, trace_handler, mock_exporter): + def test_emits_companion_log_with_event_name(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log") as mock_log: trace_handler._workflow_trace(make_workflow_info()) @@ -482,7 +483,7 @@ class TestWorkflowTrace: assert mock_log.call_args[1]["event_name"] == EnterpriseTelemetryEvent.WORKFLOW_RUN assert mock_log.call_args[1]["tenant_id"] == "tenant-abc" - def test_companion_log_includes_content_when_enabled(self, trace_handler, mock_exporter): + def test_companion_log_includes_content_when_enabled(self, trace_handler: EnterpriseOtelTrace, mock_exporter): mock_exporter.include_content = True with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log") as mock_log: trace_handler._workflow_trace(make_workflow_info()) @@ -491,7 +492,7 @@ class TestWorkflowTrace: assert log_attrs["dify.workflow.inputs"] == json.dumps({"query": "hello"}) assert log_attrs["dify.workflow.outputs"] == json.dumps({"answer": "world"}) - def test_companion_log_uses_ref_when_content_disabled(self, trace_handler, mock_exporter): + def test_companion_log_uses_ref_when_content_disabled(self, trace_handler: EnterpriseOtelTrace, mock_exporter): mock_exporter.include_content = False with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log") as mock_log: trace_handler._workflow_trace(make_workflow_info()) @@ -500,7 +501,7 @@ class TestWorkflowTrace: assert log_attrs["dify.workflow.inputs"].startswith("ref:workflow_run_id=") assert log_attrs["dify.workflow.outputs"].startswith("ref:workflow_run_id=") - def test_increments_token_counter(self, trace_handler, mock_exporter): + def test_increments_token_counter(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"): trace_handler._workflow_trace(make_workflow_info()) @@ -510,7 +511,7 @@ class TestWorkflowTrace: assert len(token_calls) == 1 assert token_calls[0][0][1] == 100 - def test_increments_input_and_output_token_counters(self, trace_handler, mock_exporter): + def test_increments_input_and_output_token_counters(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"): trace_handler._workflow_trace(make_workflow_info()) @@ -519,7 +520,7 @@ class TestWorkflowTrace: assert EnterpriseTelemetryCounter.INPUT_TOKENS in counter_names assert EnterpriseTelemetryCounter.OUTPUT_TOKENS in counter_names - def test_no_input_token_counter_when_prompt_tokens_zero(self, trace_handler, mock_exporter): + def test_no_input_token_counter_when_prompt_tokens_zero(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"): info = make_workflow_info(prompt_tokens=0) trace_handler._workflow_trace(info) @@ -528,7 +529,7 @@ class TestWorkflowTrace: counter_names = [c[0][0] for c in all_calls] assert EnterpriseTelemetryCounter.INPUT_TOKENS not in counter_names - def test_records_workflow_duration_histogram(self, trace_handler, mock_exporter): + def test_records_workflow_duration_histogram(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"): trace_handler._workflow_trace(make_workflow_info()) @@ -537,7 +538,9 @@ class TestWorkflowTrace: assert hist_call[0][0] == EnterpriseTelemetryHistogram.WORKFLOW_DURATION assert hist_call[0][1] == pytest.approx(5.0) - def test_duration_falls_back_to_elapsed_time_when_timestamps_missing(self, trace_handler, mock_exporter): + def test_duration_falls_back_to_elapsed_time_when_timestamps_missing( + self, trace_handler: EnterpriseOtelTrace, mock_exporter + ): with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"): info = make_workflow_info(start_time=None, end_time=None, workflow_run_elapsed_time=7.3) trace_handler._workflow_trace(info) @@ -545,7 +548,7 @@ class TestWorkflowTrace: hist_call = mock_exporter.record_histogram.call_args assert hist_call[0][1] == pytest.approx(7.3) - def test_duration_defaults_to_zero_when_no_timing(self, trace_handler, mock_exporter): + def test_duration_defaults_to_zero_when_no_timing(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"): info = make_workflow_info(start_time=None, end_time=None, workflow_run_elapsed_time=0) trace_handler._workflow_trace(info) @@ -553,7 +556,7 @@ class TestWorkflowTrace: hist_call = mock_exporter.record_histogram.call_args assert hist_call[0][1] == pytest.approx(0.0) - def test_error_path_increments_error_counter(self, trace_handler, mock_exporter): + def test_error_path_increments_error_counter(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"): info = make_workflow_info(error="Something went wrong", workflow_run_status="failed") trace_handler._workflow_trace(info) @@ -563,7 +566,7 @@ class TestWorkflowTrace: ] assert len(error_calls) == 1 - def test_no_error_counter_on_success(self, trace_handler, mock_exporter): + def test_no_error_counter_on_success(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"): trace_handler._workflow_trace(make_workflow_info()) @@ -572,7 +575,7 @@ class TestWorkflowTrace: ] assert len(error_calls) == 0 - def test_parent_trace_context_injected_into_span_attrs(self, trace_handler, mock_exporter): + def test_parent_trace_context_injected_into_span_attrs(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"): info = make_workflow_info( metadata={ @@ -601,14 +604,14 @@ class TestWorkflowTrace: class TestNodeExecutionTrace: - def test_emits_span_with_node_execution_span_name(self, trace_handler, mock_exporter): + def test_emits_span_with_node_execution_span_name(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"): trace_handler._node_execution_trace(make_node_info()) span_call = mock_exporter.export_span.call_args assert span_call[0][0] == EnterpriseTelemetrySpan.NODE_EXECUTION - def test_span_contains_core_node_attributes(self, trace_handler, mock_exporter): + def test_span_contains_core_node_attributes(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"): trace_handler._node_execution_trace(make_node_info()) @@ -620,7 +623,7 @@ class TestNodeExecutionTrace: assert attrs["gen_ai.request.model"] == "gpt-4" assert attrs["gen_ai.provider.name"] == "openai" - def test_increments_token_counters_when_tokens_present(self, trace_handler, mock_exporter): + def test_increments_token_counters_when_tokens_present(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"): trace_handler._node_execution_trace(make_node_info()) @@ -629,7 +632,7 @@ class TestNodeExecutionTrace: assert EnterpriseTelemetryCounter.INPUT_TOKENS in counter_names assert EnterpriseTelemetryCounter.OUTPUT_TOKENS in counter_names - def test_no_token_counters_when_total_tokens_zero(self, trace_handler, mock_exporter): + def test_no_token_counters_when_total_tokens_zero(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"): trace_handler._node_execution_trace(make_node_info(total_tokens=0)) @@ -637,7 +640,7 @@ class TestNodeExecutionTrace: assert EnterpriseTelemetryCounter.TOKENS not in counter_names assert EnterpriseTelemetryCounter.INPUT_TOKENS not in counter_names - def test_records_node_duration_histogram(self, trace_handler, mock_exporter): + def test_records_node_duration_histogram(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"): trace_handler._node_execution_trace(make_node_info()) @@ -645,7 +648,7 @@ class TestNodeExecutionTrace: assert hist_call[0][0] == EnterpriseTelemetryHistogram.NODE_DURATION assert hist_call[0][1] == pytest.approx(2.5) - def test_error_path_increments_error_counter(self, trace_handler, mock_exporter): + def test_error_path_increments_error_counter(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"): trace_handler._node_execution_trace(make_node_info(error="Node failed", status="failed")) @@ -654,14 +657,16 @@ class TestNodeExecutionTrace: ] assert len(error_calls) == 1 - def test_emits_companion_log_with_span_name_as_event(self, trace_handler, mock_exporter): + def test_emits_companion_log_with_span_name_as_event(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log") as mock_log: trace_handler._node_execution_trace(make_node_info()) mock_log.assert_called_once() assert mock_log.call_args[1]["event_name"] == EnterpriseTelemetrySpan.NODE_EXECUTION.value - def test_plugin_name_added_to_duration_labels_for_tool_node(self, trace_handler, mock_exporter): + def test_plugin_name_added_to_duration_labels_for_tool_node( + self, trace_handler: EnterpriseOtelTrace, mock_exporter + ): with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"): info = make_node_info( node_type="tool", @@ -677,7 +682,7 @@ class TestNodeExecutionTrace: duration_labels = hist_call[0][2] assert duration_labels.get("plugin_name") == "my-plugin" - def test_plugin_name_not_added_for_non_tool_node(self, trace_handler, mock_exporter): + def test_plugin_name_not_added_for_non_tool_node(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"): info = make_node_info( node_type="llm", @@ -693,7 +698,9 @@ class TestNodeExecutionTrace: duration_labels = hist_call[0][2] assert "plugin_name" not in duration_labels - def test_companion_log_inputs_use_ref_when_content_disabled(self, trace_handler, mock_exporter): + def test_companion_log_inputs_use_ref_when_content_disabled( + self, trace_handler: EnterpriseOtelTrace, mock_exporter + ): mock_exporter.include_content = False with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log") as mock_log: trace_handler._node_execution_trace( @@ -711,14 +718,14 @@ class TestNodeExecutionTrace: class TestDraftNodeExecutionTrace: - def test_uses_draft_span_name(self, trace_handler, mock_exporter): + def test_uses_draft_span_name(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"): trace_handler._draft_node_execution_trace(make_draft_node_info()) span_call = mock_exporter.export_span.call_args assert span_call[0][0] == EnterpriseTelemetrySpan.DRAFT_NODE_EXECUTION - def test_correlation_id_is_node_execution_id(self, trace_handler, mock_exporter): + def test_correlation_id_is_node_execution_id(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"): info = make_draft_node_info() trace_handler._draft_node_execution_trace(info) @@ -726,7 +733,7 @@ class TestDraftNodeExecutionTrace: span_call = mock_exporter.export_span.call_args assert span_call[1]["correlation_id"] == "ne-draft-001" - def test_trace_correlation_override_is_workflow_run_id(self, trace_handler, mock_exporter): + def test_trace_correlation_override_is_workflow_run_id(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"): info = make_draft_node_info() trace_handler._draft_node_execution_trace(info) @@ -734,7 +741,7 @@ class TestDraftNodeExecutionTrace: span_call = mock_exporter.export_span.call_args assert span_call[1]["trace_correlation_override"] == "run-draft-001" - def test_companion_log_uses_draft_span_name(self, trace_handler, mock_exporter): + def test_companion_log_uses_draft_span_name(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log") as mock_log: trace_handler._draft_node_execution_trace(make_draft_node_info()) @@ -747,34 +754,36 @@ class TestDraftNodeExecutionTrace: class TestMessageTrace: - def test_emits_event_with_correct_name(self, trace_handler, mock_exporter): + def test_emits_event_with_correct_name(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: trace_handler._message_trace(make_message_info()) mock_emit.assert_called_once() assert mock_emit.call_args[1]["event_name"] == EnterpriseTelemetryEvent.MESSAGE_RUN - def test_emits_correct_tenant_and_user(self, trace_handler, mock_exporter): + def test_emits_correct_tenant_and_user(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: trace_handler._message_trace(make_message_info()) assert mock_emit.call_args[1]["tenant_id"] == "tenant-abc" - def test_duration_computed_from_timestamps(self, trace_handler, mock_exporter): + def test_duration_computed_from_timestamps(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: trace_handler._message_trace(make_message_info()) attrs = mock_emit.call_args[1]["attributes"] assert attrs["dify.message.duration"] == pytest.approx(5.0) - def test_no_duration_when_timestamps_missing(self, trace_handler, mock_exporter): + def test_no_duration_when_timestamps_missing(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: trace_handler._message_trace(make_message_info(start_time=None, end_time=None)) attrs = mock_emit.call_args[1]["attributes"] assert "dify.message.duration" not in attrs - def test_records_duration_histogram_when_timestamps_present(self, trace_handler, mock_exporter): + def test_records_duration_histogram_when_timestamps_present( + self, trace_handler: EnterpriseOtelTrace, mock_exporter + ): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event"): trace_handler._message_trace(make_message_info()) @@ -786,14 +795,14 @@ class TestMessageTrace: assert len(hist_calls) == 1 assert hist_calls[0][0][1] == pytest.approx(5.0) - def test_no_duration_histogram_when_timestamps_missing(self, trace_handler, mock_exporter): + def test_no_duration_histogram_when_timestamps_missing(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event"): trace_handler._message_trace(make_message_info(start_time=None, end_time=None)) hist_names = [c[0][0] for c in mock_exporter.record_histogram.call_args_list] assert EnterpriseTelemetryHistogram.MESSAGE_DURATION not in hist_names - def test_records_ttft_histogram_when_present(self, trace_handler, mock_exporter): + def test_records_ttft_histogram_when_present(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event"): trace_handler._message_trace(make_message_info(gen_ai_server_time_to_first_token=0.42)) @@ -805,14 +814,14 @@ class TestMessageTrace: assert len(ttft_calls) == 1 assert ttft_calls[0][0][1] == pytest.approx(0.42) - def test_no_ttft_histogram_when_not_present(self, trace_handler, mock_exporter): + def test_no_ttft_histogram_when_not_present(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event"): trace_handler._message_trace(make_message_info(gen_ai_server_time_to_first_token=None)) hist_names = [c[0][0] for c in mock_exporter.record_histogram.call_args_list] assert EnterpriseTelemetryHistogram.MESSAGE_TTFT not in hist_names - def test_increments_token_counters(self, trace_handler, mock_exporter): + def test_increments_token_counters(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event"): trace_handler._message_trace(make_message_info()) @@ -821,7 +830,7 @@ class TestMessageTrace: assert EnterpriseTelemetryCounter.INPUT_TOKENS in counter_names assert EnterpriseTelemetryCounter.OUTPUT_TOKENS in counter_names - def test_error_path_increments_error_counter(self, trace_handler, mock_exporter): + def test_error_path_increments_error_counter(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event"): trace_handler._message_trace(make_message_info(error="LLM failed")) @@ -830,7 +839,7 @@ class TestMessageTrace: ] assert len(error_calls) == 1 - def test_inputs_and_outputs_gated_by_include_content(self, trace_handler, mock_exporter): + def test_inputs_and_outputs_gated_by_include_content(self, trace_handler: EnterpriseOtelTrace, mock_exporter): mock_exporter.include_content = False with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: trace_handler._message_trace(make_message_info()) @@ -846,27 +855,27 @@ class TestMessageTrace: class TestToolTrace: - def test_emits_event_with_correct_name(self, trace_handler, mock_exporter): + def test_emits_event_with_correct_name(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: trace_handler._tool_trace(make_tool_info()) assert mock_emit.call_args[1]["event_name"] == EnterpriseTelemetryEvent.TOOL_EXECUTION - def test_status_is_succeeded_on_success(self, trace_handler, mock_exporter): + def test_status_is_succeeded_on_success(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: trace_handler._tool_trace(make_tool_info()) attrs = mock_emit.call_args[1]["attributes"] assert attrs["dify.tool.status"] == "succeeded" - def test_status_is_failed_on_error(self, trace_handler, mock_exporter): + def test_status_is_failed_on_error(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: trace_handler._tool_trace(make_tool_info(error="Tool error")) attrs = mock_emit.call_args[1]["attributes"] assert attrs["dify.tool.status"] == "failed" - def test_records_tool_duration_histogram(self, trace_handler, mock_exporter): + def test_records_tool_duration_histogram(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event"): trace_handler._tool_trace(make_tool_info()) @@ -874,7 +883,7 @@ class TestToolTrace: assert hist_call[0][0] == EnterpriseTelemetryHistogram.TOOL_DURATION assert hist_call[0][1] == pytest.approx(1.5) - def test_error_increments_error_counter(self, trace_handler, mock_exporter): + def test_error_increments_error_counter(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event"): trace_handler._tool_trace(make_tool_info(error="Tool crashed")) @@ -883,7 +892,7 @@ class TestToolTrace: ] assert len(error_calls) == 1 - def test_inputs_and_outputs_gated_by_include_content(self, trace_handler, mock_exporter): + def test_inputs_and_outputs_gated_by_include_content(self, trace_handler: EnterpriseOtelTrace, mock_exporter): mock_exporter.include_content = False with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: trace_handler._tool_trace(make_tool_info()) @@ -892,7 +901,7 @@ class TestToolTrace: assert attrs["dify.tool.inputs"].startswith("ref:message_id=") assert attrs["dify.tool.outputs"].startswith("ref:message_id=") - def test_inputs_present_when_include_content_true(self, trace_handler, mock_exporter): + def test_inputs_present_when_include_content_true(self, trace_handler: EnterpriseOtelTrace, mock_exporter): mock_exporter.include_content = True with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: trace_handler._tool_trace(make_tool_info()) @@ -901,7 +910,7 @@ class TestToolTrace: assert attrs["dify.tool.inputs"] == json.dumps({"query": "test"}) assert attrs["dify.tool.outputs"] == "search results" - def test_increments_requests_counter(self, trace_handler, mock_exporter): + def test_increments_requests_counter(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event"): trace_handler._tool_trace(make_tool_info()) @@ -918,27 +927,27 @@ class TestToolTrace: class TestModerationTrace: - def test_emits_event_with_correct_name(self, trace_handler, mock_exporter): + def test_emits_event_with_correct_name(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: trace_handler._moderation_trace(make_moderation_info()) assert mock_emit.call_args[1]["event_name"] == EnterpriseTelemetryEvent.MODERATION_CHECK - def test_flagged_true_sets_attribute(self, trace_handler, mock_exporter): + def test_flagged_true_sets_attribute(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: trace_handler._moderation_trace(make_moderation_info(flagged=True)) attrs = mock_emit.call_args[1]["attributes"] assert attrs["dify.moderation.flagged"] is True - def test_flagged_false_sets_attribute(self, trace_handler, mock_exporter): + def test_flagged_false_sets_attribute(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: trace_handler._moderation_trace(make_moderation_info(flagged=False)) attrs = mock_emit.call_args[1]["attributes"] assert attrs["dify.moderation.flagged"] is False - def test_query_gated_by_include_content(self, trace_handler, mock_exporter): + def test_query_gated_by_include_content(self, trace_handler: EnterpriseOtelTrace, mock_exporter): mock_exporter.include_content = False with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: trace_handler._moderation_trace(make_moderation_info()) @@ -946,7 +955,7 @@ class TestModerationTrace: attrs = mock_emit.call_args[1]["attributes"] assert attrs["dify.moderation.query"].startswith("ref:message_id=") - def test_query_present_when_include_content_true(self, trace_handler, mock_exporter): + def test_query_present_when_include_content_true(self, trace_handler: EnterpriseOtelTrace, mock_exporter): mock_exporter.include_content = True with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: trace_handler._moderation_trace(make_moderation_info()) @@ -954,7 +963,7 @@ class TestModerationTrace: attrs = mock_emit.call_args[1]["attributes"] assert attrs["dify.moderation.query"] == "is this ok?" - def test_increments_requests_counter(self, trace_handler, mock_exporter): + def test_increments_requests_counter(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event"): trace_handler._moderation_trace(make_moderation_info()) @@ -971,48 +980,48 @@ class TestModerationTrace: class TestSuggestedQuestionTrace: - def test_emits_event_with_correct_name(self, trace_handler, mock_exporter): + def test_emits_event_with_correct_name(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: trace_handler._suggested_question_trace(make_suggested_question_info()) assert mock_emit.call_args[1]["event_name"] == EnterpriseTelemetryEvent.SUGGESTED_QUESTION_GENERATION - def test_duration_computed_from_timestamps(self, trace_handler, mock_exporter): + def test_duration_computed_from_timestamps(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: trace_handler._suggested_question_trace(make_suggested_question_info()) attrs = mock_emit.call_args[1]["attributes"] assert attrs["dify.suggested_question.duration"] == pytest.approx(5.0) - def test_duration_is_none_when_timestamps_missing(self, trace_handler, mock_exporter): + def test_duration_is_none_when_timestamps_missing(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: trace_handler._suggested_question_trace(make_suggested_question_info(start_time=None, end_time=None)) attrs = mock_emit.call_args[1]["attributes"] assert attrs["dify.suggested_question.duration"] is None - def test_status_is_failed_when_error_present(self, trace_handler, mock_exporter): + def test_status_is_failed_when_error_present(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: trace_handler._suggested_question_trace(make_suggested_question_info(error="Generation failed")) attrs = mock_emit.call_args[1]["attributes"] assert attrs["dify.suggested_question.status"] == "failed" - def test_status_falls_back_to_succeeded_when_no_error(self, trace_handler, mock_exporter): + def test_status_falls_back_to_succeeded_when_no_error(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: trace_handler._suggested_question_trace(make_suggested_question_info(status=None, error=None)) attrs = mock_emit.call_args[1]["attributes"] assert attrs["dify.suggested_question.status"] == "succeeded" - def test_question_count_attribute(self, trace_handler, mock_exporter): + def test_question_count_attribute(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: trace_handler._suggested_question_trace(make_suggested_question_info()) attrs = mock_emit.call_args[1]["attributes"] assert attrs["dify.suggested_question.count"] == 2 - def test_questions_gated_by_include_content(self, trace_handler, mock_exporter): + def test_questions_gated_by_include_content(self, trace_handler: EnterpriseOtelTrace, mock_exporter): mock_exporter.include_content = False with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: trace_handler._suggested_question_trace(make_suggested_question_info()) @@ -1020,7 +1029,7 @@ class TestSuggestedQuestionTrace: attrs = mock_emit.call_args[1]["attributes"] assert attrs["dify.suggested_question.questions"].startswith("ref:message_id=") - def test_increments_requests_counter(self, trace_handler, mock_exporter): + def test_increments_requests_counter(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event"): trace_handler._suggested_question_trace(make_suggested_question_info()) @@ -1037,48 +1046,48 @@ class TestSuggestedQuestionTrace: class TestDatasetRetrievalTrace: - def test_emits_event_with_correct_name(self, trace_handler, mock_exporter): + def test_emits_event_with_correct_name(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: trace_handler._dataset_retrieval_trace(make_dataset_retrieval_info()) assert mock_emit.call_args[1]["event_name"] == EnterpriseTelemetryEvent.DATASET_RETRIEVAL - def test_document_count_attribute(self, trace_handler, mock_exporter): + def test_document_count_attribute(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: trace_handler._dataset_retrieval_trace(make_dataset_retrieval_info()) attrs = mock_emit.call_args[1]["attributes"] assert attrs["dify.retrieval.document_count"] == 1 - def test_dataset_ids_extracted(self, trace_handler, mock_exporter): + def test_dataset_ids_extracted(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: trace_handler._dataset_retrieval_trace(make_dataset_retrieval_info()) attrs = mock_emit.call_args[1]["attributes"] assert "ds-001" in attrs["dify.dataset.id"] - def test_empty_documents_has_zero_count(self, trace_handler, mock_exporter): + def test_empty_documents_has_zero_count(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: trace_handler._dataset_retrieval_trace(make_dataset_retrieval_info(documents=[])) attrs = mock_emit.call_args[1]["attributes"] assert attrs["dify.retrieval.document_count"] == 0 - def test_status_succeeded_when_no_error(self, trace_handler, mock_exporter): + def test_status_succeeded_when_no_error(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: trace_handler._dataset_retrieval_trace(make_dataset_retrieval_info()) attrs = mock_emit.call_args[1]["attributes"] assert attrs["dify.retrieval.status"] == "succeeded" - def test_status_failed_when_error_present(self, trace_handler, mock_exporter): + def test_status_failed_when_error_present(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: trace_handler._dataset_retrieval_trace(make_dataset_retrieval_info(error="DB error")) attrs = mock_emit.call_args[1]["attributes"] assert attrs["dify.retrieval.status"] == "failed" - def test_embedding_model_attributes_set_when_present(self, trace_handler, mock_exporter): + def test_embedding_model_attributes_set_when_present(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: trace_handler._dataset_retrieval_trace(make_dataset_retrieval_info()) @@ -1086,7 +1095,7 @@ class TestDatasetRetrievalTrace: assert "dify.dataset.embedding_providers" in attrs assert "dify.dataset.embedding_models" in attrs - def test_no_embedding_model_attributes_when_not_provided(self, trace_handler, mock_exporter): + def test_no_embedding_model_attributes_when_not_provided(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: trace_handler._dataset_retrieval_trace( make_dataset_retrieval_info(metadata={"app_id": "app-001", "tenant_id": "tenant-abc"}) @@ -1096,7 +1105,7 @@ class TestDatasetRetrievalTrace: assert "dify.dataset.embedding_providers" not in attrs assert "dify.dataset.embedding_models" not in attrs - def test_rerank_attributes_set_when_present(self, trace_handler, mock_exporter): + def test_rerank_attributes_set_when_present(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: trace_handler._dataset_retrieval_trace( make_dataset_retrieval_info( @@ -1113,7 +1122,7 @@ class TestDatasetRetrievalTrace: assert attrs["dify.retrieval.rerank_provider"] == "cohere" assert attrs["dify.retrieval.rerank_model"] == "rerank-english" - def test_no_rerank_attributes_when_not_present(self, trace_handler, mock_exporter): + def test_no_rerank_attributes_when_not_present(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: trace_handler._dataset_retrieval_trace( make_dataset_retrieval_info(metadata={"app_id": "app-001", "tenant_id": "tenant-abc"}) @@ -1123,7 +1132,7 @@ class TestDatasetRetrievalTrace: assert "dify.retrieval.rerank_provider" not in attrs assert "dify.retrieval.rerank_model" not in attrs - def test_dataset_retrieval_counter_incremented_per_dataset(self, trace_handler, mock_exporter): + def test_dataset_retrieval_counter_incremented_per_dataset(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event"): trace_handler._dataset_retrieval_trace(make_dataset_retrieval_info()) @@ -1135,7 +1144,7 @@ class TestDatasetRetrievalTrace: assert len(ds_calls) == 1 assert ds_calls[0][0][2]["dataset_id"] == "ds-001" - def test_no_dataset_retrieval_counter_when_no_documents(self, trace_handler, mock_exporter): + def test_no_dataset_retrieval_counter_when_no_documents(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event"): trace_handler._dataset_retrieval_trace(make_dataset_retrieval_info(documents=[])) @@ -1146,7 +1155,7 @@ class TestDatasetRetrievalTrace: ] assert len(ds_calls) == 0 - def test_query_gated_by_include_content(self, trace_handler, mock_exporter): + def test_query_gated_by_include_content(self, trace_handler: EnterpriseOtelTrace, mock_exporter): mock_exporter.include_content = False with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: trace_handler._dataset_retrieval_trace(make_dataset_retrieval_info()) @@ -1161,34 +1170,34 @@ class TestDatasetRetrievalTrace: class TestGenerateNameTrace: - def test_emits_event_with_correct_name(self, trace_handler, mock_exporter): + def test_emits_event_with_correct_name(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: trace_handler._generate_name_trace(make_generate_name_info()) assert mock_emit.call_args[1]["event_name"] == EnterpriseTelemetryEvent.GENERATE_NAME_EXECUTION - def test_duration_computed_from_timestamps(self, trace_handler, mock_exporter): + def test_duration_computed_from_timestamps(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: trace_handler._generate_name_trace(make_generate_name_info()) attrs = mock_emit.call_args[1]["attributes"] assert attrs["dify.generate_name.duration"] == pytest.approx(5.0) - def test_no_duration_when_timestamps_missing(self, trace_handler, mock_exporter): + def test_no_duration_when_timestamps_missing(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: trace_handler._generate_name_trace(make_generate_name_info(start_time=None, end_time=None)) attrs = mock_emit.call_args[1]["attributes"] assert attrs["dify.generate_name.duration"] is None - def test_status_succeeded_on_success(self, trace_handler, mock_exporter): + def test_status_succeeded_on_success(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: trace_handler._generate_name_trace(make_generate_name_info()) attrs = mock_emit.call_args[1]["attributes"] assert attrs["dify.generate_name.status"] == "succeeded" - def test_status_failed_when_metadata_has_error(self, trace_handler, mock_exporter): + def test_status_failed_when_metadata_has_error(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: trace_handler._generate_name_trace( make_generate_name_info( @@ -1203,7 +1212,7 @@ class TestGenerateNameTrace: attrs = mock_emit.call_args[1]["attributes"] assert attrs["dify.generate_name.status"] == "failed" - def test_inputs_and_outputs_gated_by_include_content(self, trace_handler, mock_exporter): + def test_inputs_and_outputs_gated_by_include_content(self, trace_handler: EnterpriseOtelTrace, mock_exporter): mock_exporter.include_content = False with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: trace_handler._generate_name_trace(make_generate_name_info()) @@ -1212,7 +1221,7 @@ class TestGenerateNameTrace: assert attrs["dify.generate_name.inputs"].startswith("ref:conversation_id=") assert attrs["dify.generate_name.outputs"].startswith("ref:conversation_id=") - def test_increments_requests_counter(self, trace_handler, mock_exporter): + def test_increments_requests_counter(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event"): trace_handler._generate_name_trace(make_generate_name_info()) @@ -1229,27 +1238,27 @@ class TestGenerateNameTrace: class TestPromptGenerationTrace: - def test_emits_event_with_correct_name(self, trace_handler, mock_exporter): + def test_emits_event_with_correct_name(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: trace_handler._prompt_generation_trace(make_prompt_generation_info()) assert mock_emit.call_args[1]["event_name"] == EnterpriseTelemetryEvent.PROMPT_GENERATION_EXECUTION - def test_status_succeeded_on_success(self, trace_handler, mock_exporter): + def test_status_succeeded_on_success(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: trace_handler._prompt_generation_trace(make_prompt_generation_info()) attrs = mock_emit.call_args[1]["attributes"] assert attrs["dify.prompt_generation.status"] == "succeeded" - def test_status_failed_when_error_present(self, trace_handler, mock_exporter): + def test_status_failed_when_error_present(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: trace_handler._prompt_generation_trace(make_prompt_generation_info(error="Generation error")) attrs = mock_emit.call_args[1]["attributes"] assert attrs["dify.prompt_generation.status"] == "failed" - def test_token_counters_incremented(self, trace_handler, mock_exporter): + def test_token_counters_incremented(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event"): trace_handler._prompt_generation_trace(make_prompt_generation_info()) @@ -1258,7 +1267,7 @@ class TestPromptGenerationTrace: assert EnterpriseTelemetryCounter.INPUT_TOKENS in counter_names assert EnterpriseTelemetryCounter.OUTPUT_TOKENS in counter_names - def test_records_duration_histogram(self, trace_handler, mock_exporter): + def test_records_duration_histogram(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event"): trace_handler._prompt_generation_trace(make_prompt_generation_info()) @@ -1270,7 +1279,7 @@ class TestPromptGenerationTrace: assert len(hist_calls) == 1 assert hist_calls[0][0][1] == pytest.approx(3.2) - def test_total_price_attribute_set_when_present(self, trace_handler, mock_exporter): + def test_total_price_attribute_set_when_present(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: trace_handler._prompt_generation_trace(make_prompt_generation_info(total_price=0.05, currency="USD")) @@ -1278,14 +1287,14 @@ class TestPromptGenerationTrace: assert attrs["dify.prompt_generation.total_price"] == pytest.approx(0.05) assert attrs["dify.prompt_generation.currency"] == "USD" - def test_no_total_price_attribute_when_none(self, trace_handler, mock_exporter): + def test_no_total_price_attribute_when_none(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: trace_handler._prompt_generation_trace(make_prompt_generation_info(total_price=None)) attrs = mock_emit.call_args[1]["attributes"] assert "dify.prompt_generation.total_price" not in attrs - def test_error_increments_error_counter(self, trace_handler, mock_exporter): + def test_error_increments_error_counter(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event"): trace_handler._prompt_generation_trace(make_prompt_generation_info(error="Prompt failed")) @@ -1294,7 +1303,7 @@ class TestPromptGenerationTrace: ] assert len(error_calls) == 1 - def test_no_error_counter_on_success(self, trace_handler, mock_exporter): + def test_no_error_counter_on_success(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event"): trace_handler._prompt_generation_trace(make_prompt_generation_info()) @@ -1303,7 +1312,7 @@ class TestPromptGenerationTrace: ] assert len(error_calls) == 0 - def test_instruction_gated_by_include_content(self, trace_handler, mock_exporter): + def test_instruction_gated_by_include_content(self, trace_handler: EnterpriseOtelTrace, mock_exporter): mock_exporter.include_content = False with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: trace_handler._prompt_generation_trace(make_prompt_generation_info()) @@ -1311,7 +1320,7 @@ class TestPromptGenerationTrace: attrs = mock_emit.call_args[1]["attributes"] assert attrs["dify.prompt_generation.instruction"].startswith("ref:trace_id=") - def test_operation_type_label_used_in_token_counters(self, trace_handler, mock_exporter): + def test_operation_type_label_used_in_token_counters(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event"): trace_handler._prompt_generation_trace(make_prompt_generation_info(operation_type="code_generate")) @@ -1321,7 +1330,7 @@ class TestPromptGenerationTrace: assert len(token_calls) == 1 assert token_calls[0][0][2]["operation_type"] == "code_generate" - def test_emits_correct_tenant_id(self, trace_handler, mock_exporter): + def test_emits_correct_tenant_id(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: trace_handler._prompt_generation_trace(make_prompt_generation_info())