From 29bfa33d599161f8a81fad724bbb80b03aa8c6e2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=91=86=E8=90=8C=E9=97=B7=E6=B2=B9=E7=93=B6?= <253605712@qq.com> Date: Mon, 13 Apr 2026 14:21:58 +0800 Subject: [PATCH] feat: support ttft report to langfuse (#33344) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- api/core/ops/langfuse_trace/langfuse_trace.py | 34 ++++- .../core/ops/test_langfuse_trace.py | 137 ++++++++++++++++++ .../services/test_external_dataset_service.py | 2 +- .../services/test_message_service.py | 16 +- 4 files changed, 183 insertions(+), 6 deletions(-) create mode 100644 api/tests/unit_tests/core/ops/test_langfuse_trace.py diff --git a/api/core/ops/langfuse_trace/langfuse_trace.py b/api/core/ops/langfuse_trace/langfuse_trace.py index 9be2ce1bdf..d53aa84aed 100644 --- a/api/core/ops/langfuse_trace/langfuse_trace.py +++ b/api/core/ops/langfuse_trace/langfuse_trace.py @@ -59,6 +59,24 @@ class LangFuseDataTrace(BaseTraceInstance): ) self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001") + @staticmethod + def _get_completion_start_time( + start_time: datetime | None, time_to_first_token: float | int | None + ) -> datetime | None: + """Convert a relative TTFT value in seconds into Langfuse's absolute completion start time.""" + if start_time is None or time_to_first_token is None: + return None + + try: + ttft_seconds = float(time_to_first_token) + except (TypeError, ValueError): + return None + + if ttft_seconds < 0: + return None + + return start_time + timedelta(seconds=ttft_seconds) + def trace(self, trace_info: BaseTraceInfo): if isinstance(trace_info, WorkflowTraceInfo): self.workflow_trace(trace_info) @@ -189,10 +207,18 @@ class LangFuseDataTrace(BaseTraceInstance): total_token = metadata.get("total_tokens", 0) prompt_tokens = 0 completion_tokens = 0 + completion_start_time = None try: - usage_data = process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {}) + usage_data = process_data.get("usage") + if not isinstance(usage_data, dict): + usage_data = outputs.get("usage") + if not isinstance(usage_data, dict): + usage_data = {} prompt_tokens = usage_data.get("prompt_tokens", 0) completion_tokens = usage_data.get("completion_tokens", 0) + completion_start_time = self._get_completion_start_time( + created_at, usage_data.get("time_to_first_token") + ) except Exception: logger.error("Failed to extract usage", exc_info=True) @@ -210,6 +236,7 @@ class LangFuseDataTrace(BaseTraceInstance): trace_id=trace_id, model=process_data.get("model_name"), start_time=created_at, + completion_start_time=completion_start_time, end_time=finished_at, input=inputs, output=outputs, @@ -290,11 +317,16 @@ class LangFuseDataTrace(BaseTraceInstance): unit=UnitEnum.TOKENS, totalCost=message_data.total_price, ) + completion_start_time = self._get_completion_start_time( + trace_info.start_time, + trace_info.gen_ai_server_time_to_first_token, + ) langfuse_generation_data = LangfuseGeneration( name="llm", trace_id=trace_id, start_time=trace_info.start_time, + completion_start_time=completion_start_time, end_time=trace_info.end_time, model=message_data.model_id, input=trace_info.inputs, diff --git a/api/tests/unit_tests/core/ops/test_langfuse_trace.py b/api/tests/unit_tests/core/ops/test_langfuse_trace.py new file mode 100644 index 0000000000..f8951d2b4a --- /dev/null +++ b/api/tests/unit_tests/core/ops/test_langfuse_trace.py @@ -0,0 +1,137 @@ +"""Tests for Langfuse TTFT reporting support.""" + +from datetime import datetime, timedelta +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +from graphon.enums import BuiltinNodeTypes + +from core.ops.entities.config_entity import LangfuseConfig +from core.ops.entities.trace_entity import MessageTraceInfo, WorkflowTraceInfo +from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace + + +def _create_trace_instance() -> LangFuseDataTrace: + with patch("core.ops.langfuse_trace.langfuse_trace.Langfuse", autospec=True): + return LangFuseDataTrace( + LangfuseConfig( + public_key="public-key", + secret_key="secret-key", + host="https://cloud.langfuse.com", + ) + ) + + +class TestLangFuseDataTraceCompletionStartTime: + def test_message_trace_reports_completion_start_time(self): + trace = _create_trace_instance() + start_time = datetime(2026, 3, 11, 13, 0, 0) + trace_info = MessageTraceInfo( + trace_id="trace-123", + message_id="message-123", + message_data=SimpleNamespace( + id="message-123", + from_account_id="account-1", + from_end_user_id=None, + conversation_id="conversation-1", + model_id="gpt-4o-mini", + answer="hi there", + status="normal", + error="", + total_price=0.12, + provider_response_latency=3.5, + ), + conversation_model="chat", + message_tokens=10, + answer_tokens=20, + total_tokens=30, + error="", + inputs="hello", + outputs="hi there", + file_list=[], + start_time=start_time, + end_time=start_time + timedelta(seconds=3.5), + metadata={}, + message_file_data=None, + conversation_mode="chat", + gen_ai_server_time_to_first_token=1.2, + llm_streaming_time_to_generate=2.3, + is_streaming_request=True, + ) + + with patch.object(trace, "add_trace"), patch.object(trace, "add_generation") as add_generation: + trace.message_trace(trace_info) + + generation = add_generation.call_args.args[0] + assert generation.completion_start_time == start_time + timedelta(seconds=1.2) + + def test_workflow_trace_reports_completion_start_time_from_llm_usage(self): + trace = _create_trace_instance() + start_time = datetime(2026, 3, 11, 13, 0, 0) + node_execution = SimpleNamespace( + id="node-exec-1", + title="Chat LLM", + node_type=BuiltinNodeTypes.LLM, + status="succeeded", + process_data={ + "model_mode": "chat", + "model_name": "gpt-4o-mini", + "usage": { + "prompt_tokens": 10, + "completion_tokens": 20, + "time_to_first_token": 1.2, + }, + }, + inputs={"question": "hello"}, + outputs={"text": "hi there"}, + created_at=start_time, + elapsed_time=3.5, + metadata={}, + ) + trace_info = WorkflowTraceInfo( + trace_id="trace-123", + workflow_data={}, + conversation_id=None, + workflow_app_log_id=None, + workflow_id="workflow-1", + tenant_id="tenant-1", + workflow_run_id="workflow-run-1", + workflow_run_elapsed_time=3.5, + workflow_run_status="succeeded", + workflow_run_inputs={"question": "hello"}, + workflow_run_outputs={"answer": "hi there"}, + workflow_run_version="1", + error="", + total_tokens=30, + file_list=[], + query="hello", + metadata={"app_id": "app-1", "user_id": "user-1"}, + start_time=start_time, + end_time=start_time + timedelta(seconds=3.5), + ) + repository = MagicMock() + repository.get_by_workflow_execution.return_value = [node_execution] + + with ( + patch.object(trace, "add_trace"), + patch.object(trace, "add_span"), + patch.object(trace, "add_generation") as add_generation, + patch.object(trace, "get_service_account_with_tenant", return_value=MagicMock()), + patch("core.ops.langfuse_trace.langfuse_trace.db", MagicMock()), + patch( + "core.ops.langfuse_trace.langfuse_trace.DifyCoreRepositoryFactory.create_workflow_node_execution_repository", + return_value=repository, + ), + ): + trace.workflow_trace(trace_info) + + generation = add_generation.call_args.kwargs["langfuse_generation_data"] + assert generation.completion_start_time == start_time + timedelta(seconds=1.2) + + def test_ignores_invalid_ttft_values(self): + trace = _create_trace_instance() + start_time = datetime(2026, 3, 11, 13, 0, 0) + + assert trace._get_completion_start_time(start_time, None) is None + assert trace._get_completion_start_time(start_time, -1) is None + assert trace._get_completion_start_time(start_time, "invalid") is None diff --git a/api/tests/unit_tests/services/test_external_dataset_service.py b/api/tests/unit_tests/services/test_external_dataset_service.py index b802f6931f..9c1a92b4d9 100644 --- a/api/tests/unit_tests/services/test_external_dataset_service.py +++ b/api/tests/unit_tests/services/test_external_dataset_service.py @@ -1702,7 +1702,7 @@ class TestExternalDatasetServiceFetchRetrieval: mock_process.return_value = mock_response # Act & Assert - with pytest.raises(Exception, match=""): + with pytest.raises(ValueError): ExternalDatasetService.fetch_external_knowledge_retrieval( "tenant-123", "dataset-123", "query", {"top_k": 5} ) diff --git a/api/tests/unit_tests/services/test_message_service.py b/api/tests/unit_tests/services/test_message_service.py index b6e990ebe0..969132cfd8 100644 --- a/api/tests/unit_tests/services/test_message_service.py +++ b/api/tests/unit_tests/services/test_message_service.py @@ -131,9 +131,12 @@ class TestMessageServicePaginationByFirstId: assert result.has_more is False # Test 03: Basic pagination without first_id (desc order) + @patch("services.message_service._create_execution_extra_content_repository") @patch("services.message_service.db") @patch("services.message_service.ConversationService") - def test_pagination_by_first_id_without_first_id_desc(self, mock_conversation_service, mock_db, factory): + def test_pagination_by_first_id_without_first_id_desc( + self, mock_conversation_service, mock_db, mock_create_repo, factory + ): """Test basic pagination without first_id in descending order.""" # Arrange app = factory.create_app_mock() @@ -171,9 +174,12 @@ class TestMessageServicePaginationByFirstId: assert result.data[0].id == "msg-000" # Test 04: Basic pagination without first_id (asc order) + @patch("services.message_service._create_execution_extra_content_repository") @patch("services.message_service.db") @patch("services.message_service.ConversationService") - def test_pagination_by_first_id_without_first_id_asc(self, mock_conversation_service, mock_db, factory): + def test_pagination_by_first_id_without_first_id_asc( + self, mock_conversation_service, mock_db, mock_create_repo, factory + ): """Test basic pagination without first_id in ascending order.""" # Arrange app = factory.create_app_mock() @@ -211,9 +217,10 @@ class TestMessageServicePaginationByFirstId: assert result.data[4].id == "msg-000" # Test 05: Pagination with first_id + @patch("services.message_service._create_execution_extra_content_repository") @patch("services.message_service.db") @patch("services.message_service.ConversationService") - def test_pagination_by_first_id_with_first_id(self, mock_conversation_service, mock_db, factory): + def test_pagination_by_first_id_with_first_id(self, mock_conversation_service, mock_db, mock_create_repo, factory): """Test pagination with first_id to get messages before a specific message.""" # Arrange app = factory.create_app_mock() @@ -278,9 +285,10 @@ class TestMessageServicePaginationByFirstId: ) # Test 07: Has_more flag when results exceed limit + @patch("services.message_service._create_execution_extra_content_repository") @patch("services.message_service.db") @patch("services.message_service.ConversationService") - def test_pagination_by_first_id_has_more_true(self, mock_conversation_service, mock_db, factory): + def test_pagination_by_first_id_has_more_true(self, mock_conversation_service, mock_db, mock_create_repo, factory): """Test has_more flag is True when results exceed limit.""" # Arrange app = factory.create_app_mock()