From 00cb1c26a1472d2512173453fef8a7f6669c35de Mon Sep 17 00:00:00 2001 From: Shaun Date: Tue, 29 Jul 2025 19:34:46 +0800 Subject: [PATCH] refactor: pass external_trace_id to message trace (#23089) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/controllers/service_api/app/completion.py | 3 ++ api/core/ops/aliyun_trace/aliyun_trace.py | 28 ++++++++++++--- .../aliyun_trace/data_exporter/traceclient.py | 12 +++++-- .../arize_phoenix_trace.py | 28 ++++++++------- api/core/ops/entities/trace_entity.py | 1 + api/core/ops/langfuse_trace/langfuse_trace.py | 19 +++++----- .../ops/langsmith_trace/langsmith_trace.py | 17 +++++---- api/core/ops/opik_trace/opik_trace.py | 19 +++++----- api/core/ops/ops_trace_manager.py | 15 +++++--- api/core/ops/weave_trace/weave_trace.py | 35 +++++++++++++------ 10 files changed, 115 insertions(+), 62 deletions(-) diff --git a/api/controllers/service_api/app/completion.py b/api/controllers/service_api/app/completion.py index 7762672494..edc66cc5e9 100644 --- a/api/controllers/service_api/app/completion.py +++ b/api/controllers/service_api/app/completion.py @@ -47,6 +47,9 @@ class CompletionApi(Resource): parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json") args = parser.parse_args() + external_trace_id = get_external_trace_id(request) + if external_trace_id: + args["external_trace_id"] = external_trace_id streaming = args["response_mode"] == "streaming" diff --git a/api/core/ops/aliyun_trace/aliyun_trace.py b/api/core/ops/aliyun_trace/aliyun_trace.py index af0e38f7ef..06050619e9 100644 --- a/api/core/ops/aliyun_trace/aliyun_trace.py +++ b/api/core/ops/aliyun_trace/aliyun_trace.py @@ -10,6 +10,7 @@ from sqlalchemy.orm import Session, sessionmaker from core.ops.aliyun_trace.data_exporter.traceclient import ( TraceClient, convert_datetime_to_nanoseconds, + convert_string_to_id, convert_to_span_id, convert_to_trace_id, generate_span_id, @@ -101,8 +102,9 @@ class AliyunDataTrace(BaseTraceInstance): raise ValueError(f"Aliyun get run url failed: {str(e)}") def workflow_trace(self, trace_info: WorkflowTraceInfo): - external_trace_id = trace_info.metadata.get("external_trace_id") - trace_id = external_trace_id or convert_to_trace_id(trace_info.workflow_run_id) + trace_id = convert_to_trace_id(trace_info.workflow_run_id) + if trace_info.trace_id: + trace_id = convert_string_to_id(trace_info.trace_id) workflow_span_id = convert_to_span_id(trace_info.workflow_run_id, "workflow") self.add_workflow_span(trace_id, workflow_span_id, trace_info) @@ -130,6 +132,9 @@ class AliyunDataTrace(BaseTraceInstance): status = Status(StatusCode.ERROR, trace_info.error) trace_id = convert_to_trace_id(message_id) + if trace_info.trace_id: + trace_id = convert_string_to_id(trace_info.trace_id) + message_span_id = convert_to_span_id(message_id, "message") message_span = SpanData( trace_id=trace_id, @@ -186,9 +191,13 @@ class AliyunDataTrace(BaseTraceInstance): return message_id = trace_info.message_id + trace_id = convert_to_trace_id(message_id) + if trace_info.trace_id: + trace_id = convert_string_to_id(trace_info.trace_id) + documents_data = extract_retrieval_documents(trace_info.documents) dataset_retrieval_span = SpanData( - trace_id=convert_to_trace_id(message_id), + trace_id=trace_id, parent_span_id=convert_to_span_id(message_id, "message"), span_id=generate_span_id(), name="dataset_retrieval", @@ -214,8 +223,12 @@ class AliyunDataTrace(BaseTraceInstance): if trace_info.error: status = Status(StatusCode.ERROR, trace_info.error) + trace_id = convert_to_trace_id(message_id) + if trace_info.trace_id: + trace_id = convert_string_to_id(trace_info.trace_id) + tool_span = SpanData( - trace_id=convert_to_trace_id(message_id), + trace_id=trace_id, parent_span_id=convert_to_span_id(message_id, "message"), span_id=generate_span_id(), name=trace_info.tool_name, @@ -451,8 +464,13 @@ class AliyunDataTrace(BaseTraceInstance): status: Status = Status(StatusCode.OK) if trace_info.error: status = Status(StatusCode.ERROR, trace_info.error) + + trace_id = convert_to_trace_id(message_id) + if trace_info.trace_id: + trace_id = convert_string_to_id(trace_info.trace_id) + suggested_question_span = SpanData( - trace_id=convert_to_trace_id(message_id), + trace_id=trace_id, parent_span_id=convert_to_span_id(message_id, "message"), span_id=convert_to_span_id(message_id, "suggested_question"), name="suggested_question", diff --git a/api/core/ops/aliyun_trace/data_exporter/traceclient.py b/api/core/ops/aliyun_trace/data_exporter/traceclient.py index 934ce95a64..bd19c8a503 100644 --- a/api/core/ops/aliyun_trace/data_exporter/traceclient.py +++ b/api/core/ops/aliyun_trace/data_exporter/traceclient.py @@ -181,15 +181,21 @@ def convert_to_trace_id(uuid_v4: Optional[str]) -> int: raise ValueError(f"Invalid UUID input: {e}") +def convert_string_to_id(string: Optional[str]) -> int: + if not string: + return generate_span_id() + hash_bytes = hashlib.sha256(string.encode("utf-8")).digest() + id = int.from_bytes(hash_bytes[:8], byteorder="big", signed=False) + return id + + def convert_to_span_id(uuid_v4: Optional[str], span_type: str) -> int: try: uuid_obj = uuid.UUID(uuid_v4) except Exception as e: raise ValueError(f"Invalid UUID input: {e}") combined_key = f"{uuid_obj.hex}-{span_type}" - hash_bytes = hashlib.sha256(combined_key.encode("utf-8")).digest() - span_id = int.from_bytes(hash_bytes[:8], byteorder="big", signed=False) - return span_id + return convert_string_to_id(combined_key) def convert_datetime_to_nanoseconds(start_time_a: Optional[datetime]) -> Optional[int]: diff --git a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py b/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py index f252a022d8..a20f2485c8 100644 --- a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py +++ b/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py @@ -91,16 +91,21 @@ def datetime_to_nanos(dt: Optional[datetime]) -> int: return int(dt.timestamp() * 1_000_000_000) -def uuid_to_trace_id(string: Optional[str]) -> int: - """Convert UUID string to a valid trace ID (16-byte integer).""" +def string_to_trace_id128(string: Optional[str]) -> int: + """ + Convert any input string into a stable 128-bit integer trace ID. + + This uses SHA-256 hashing and takes the first 16 bytes (128 bits) of the digest. + It's suitable for generating consistent, unique identifiers from strings. + """ if string is None: string = "" hash_object = hashlib.sha256(string.encode()) - # Take the first 16 bytes (128 bits) of the hash + # Take the first 16 bytes (128 bits) of the hash digest digest = hash_object.digest()[:16] - # Convert to integer (128 bits) + # Convert to a 128-bit integer return int.from_bytes(digest, byteorder="big") @@ -153,8 +158,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance): } workflow_metadata.update(trace_info.metadata) - external_trace_id = trace_info.metadata.get("external_trace_id") - trace_id = external_trace_id or uuid_to_trace_id(trace_info.workflow_run_id) + trace_id = string_to_trace_id128(trace_info.trace_id or trace_info.workflow_run_id) span_id = RandomIdGenerator().generate_span_id() context = SpanContext( trace_id=trace_id, @@ -310,7 +314,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance): SpanAttributes.SESSION_ID: trace_info.message_data.conversation_id, } - trace_id = uuid_to_trace_id(trace_info.message_id) + trace_id = string_to_trace_id128(trace_info.trace_id or trace_info.message_id) message_span_id = RandomIdGenerator().generate_span_id() span_context = SpanContext( trace_id=trace_id, @@ -406,7 +410,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance): } metadata.update(trace_info.metadata) - trace_id = uuid_to_trace_id(trace_info.message_id) + trace_id = string_to_trace_id128(trace_info.message_id) span_id = RandomIdGenerator().generate_span_id() context = SpanContext( trace_id=trace_id, @@ -468,7 +472,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance): } metadata.update(trace_info.metadata) - trace_id = uuid_to_trace_id(trace_info.message_id) + trace_id = string_to_trace_id128(trace_info.message_id) span_id = RandomIdGenerator().generate_span_id() context = SpanContext( trace_id=trace_id, @@ -521,7 +525,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance): } metadata.update(trace_info.metadata) - trace_id = uuid_to_trace_id(trace_info.message_id) + trace_id = string_to_trace_id128(trace_info.message_id) span_id = RandomIdGenerator().generate_span_id() context = SpanContext( trace_id=trace_id, @@ -568,7 +572,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance): "tool_config": json.dumps(trace_info.tool_config, ensure_ascii=False), } - trace_id = uuid_to_trace_id(trace_info.message_id) + trace_id = string_to_trace_id128(trace_info.message_id) tool_span_id = RandomIdGenerator().generate_span_id() logger.info("[Arize/Phoenix] Creating tool trace with trace_id: %s, span_id: %s", trace_id, tool_span_id) @@ -629,7 +633,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance): } metadata.update(trace_info.metadata) - trace_id = uuid_to_trace_id(trace_info.message_id) + trace_id = string_to_trace_id128(trace_info.message_id) span_id = RandomIdGenerator().generate_span_id() context = SpanContext( trace_id=trace_id, diff --git a/api/core/ops/entities/trace_entity.py b/api/core/ops/entities/trace_entity.py index 151fa2aaf4..3bad5c92fb 100644 --- a/api/core/ops/entities/trace_entity.py +++ b/api/core/ops/entities/trace_entity.py @@ -14,6 +14,7 @@ class BaseTraceInfo(BaseModel): start_time: Optional[datetime] = None end_time: Optional[datetime] = None metadata: dict[str, Any] + trace_id: Optional[str] = None @field_validator("inputs", "outputs") @classmethod diff --git a/api/core/ops/langfuse_trace/langfuse_trace.py b/api/core/ops/langfuse_trace/langfuse_trace.py index d356e735ee..3a03d9f4fe 100644 --- a/api/core/ops/langfuse_trace/langfuse_trace.py +++ b/api/core/ops/langfuse_trace/langfuse_trace.py @@ -67,14 +67,13 @@ class LangFuseDataTrace(BaseTraceInstance): self.generate_name_trace(trace_info) def workflow_trace(self, trace_info: WorkflowTraceInfo): - external_trace_id = trace_info.metadata.get("external_trace_id") - trace_id = external_trace_id or trace_info.workflow_run_id + trace_id = trace_info.trace_id or trace_info.workflow_run_id user_id = trace_info.metadata.get("user_id") metadata = trace_info.metadata metadata["workflow_app_log_id"] = trace_info.workflow_app_log_id if trace_info.message_id: - trace_id = external_trace_id or trace_info.message_id + trace_id = trace_info.trace_id or trace_info.message_id name = TraceTaskName.MESSAGE_TRACE.value trace_data = LangfuseTrace( id=trace_id, @@ -250,8 +249,10 @@ class LangFuseDataTrace(BaseTraceInstance): user_id = end_user_data.session_id metadata["user_id"] = user_id + trace_id = trace_info.trace_id or message_id + trace_data = LangfuseTrace( - id=message_id, + id=trace_id, user_id=user_id, name=TraceTaskName.MESSAGE_TRACE.value, input={ @@ -285,7 +286,7 @@ class LangFuseDataTrace(BaseTraceInstance): langfuse_generation_data = LangfuseGeneration( name="llm", - trace_id=message_id, + trace_id=trace_id, start_time=trace_info.start_time, end_time=trace_info.end_time, model=message_data.model_id, @@ -311,7 +312,7 @@ class LangFuseDataTrace(BaseTraceInstance): "preset_response": trace_info.preset_response, "inputs": trace_info.inputs, }, - trace_id=trace_info.message_id, + trace_id=trace_info.trace_id or trace_info.message_id, start_time=trace_info.start_time or trace_info.message_data.created_at, end_time=trace_info.end_time or trace_info.message_data.created_at, metadata=trace_info.metadata, @@ -334,7 +335,7 @@ class LangFuseDataTrace(BaseTraceInstance): name=TraceTaskName.SUGGESTED_QUESTION_TRACE.value, input=trace_info.inputs, output=str(trace_info.suggested_question), - trace_id=trace_info.message_id, + trace_id=trace_info.trace_id or trace_info.message_id, start_time=trace_info.start_time, end_time=trace_info.end_time, metadata=trace_info.metadata, @@ -352,7 +353,7 @@ class LangFuseDataTrace(BaseTraceInstance): name=TraceTaskName.DATASET_RETRIEVAL_TRACE.value, input=trace_info.inputs, output={"documents": trace_info.documents}, - trace_id=trace_info.message_id, + trace_id=trace_info.trace_id or trace_info.message_id, start_time=trace_info.start_time or trace_info.message_data.created_at, end_time=trace_info.end_time or trace_info.message_data.updated_at, metadata=trace_info.metadata, @@ -365,7 +366,7 @@ class LangFuseDataTrace(BaseTraceInstance): name=trace_info.tool_name, input=trace_info.tool_inputs, output=trace_info.tool_outputs, - trace_id=trace_info.message_id, + trace_id=trace_info.trace_id or trace_info.message_id, start_time=trace_info.start_time, end_time=trace_info.end_time, metadata=trace_info.metadata, diff --git a/api/core/ops/langsmith_trace/langsmith_trace.py b/api/core/ops/langsmith_trace/langsmith_trace.py index fb3f6ecf0d..f9e5128e89 100644 --- a/api/core/ops/langsmith_trace/langsmith_trace.py +++ b/api/core/ops/langsmith_trace/langsmith_trace.py @@ -65,8 +65,7 @@ class LangSmithDataTrace(BaseTraceInstance): self.generate_name_trace(trace_info) def workflow_trace(self, trace_info: WorkflowTraceInfo): - external_trace_id = trace_info.metadata.get("external_trace_id") - trace_id = external_trace_id or trace_info.message_id or trace_info.workflow_run_id + trace_id = trace_info.trace_id or trace_info.message_id or trace_info.workflow_run_id if trace_info.start_time is None: trace_info.start_time = datetime.now() message_dotted_order = ( @@ -290,7 +289,7 @@ class LangSmithDataTrace(BaseTraceInstance): reference_example_id=None, input_attachments={}, output_attachments={}, - trace_id=None, + trace_id=trace_info.trace_id, dotted_order=None, parent_run_id=None, ) @@ -319,7 +318,7 @@ class LangSmithDataTrace(BaseTraceInstance): reference_example_id=None, input_attachments={}, output_attachments={}, - trace_id=None, + trace_id=trace_info.trace_id, dotted_order=None, id=str(uuid.uuid4()), ) @@ -351,7 +350,7 @@ class LangSmithDataTrace(BaseTraceInstance): reference_example_id=None, input_attachments={}, output_attachments={}, - trace_id=None, + trace_id=trace_info.trace_id, dotted_order=None, error="", file_list=[], @@ -381,7 +380,7 @@ class LangSmithDataTrace(BaseTraceInstance): reference_example_id=None, input_attachments={}, output_attachments={}, - trace_id=None, + trace_id=trace_info.trace_id, dotted_order=None, error="", file_list=[], @@ -410,7 +409,7 @@ class LangSmithDataTrace(BaseTraceInstance): reference_example_id=None, input_attachments={}, output_attachments={}, - trace_id=None, + trace_id=trace_info.trace_id, dotted_order=None, error="", file_list=[], @@ -440,7 +439,7 @@ class LangSmithDataTrace(BaseTraceInstance): reference_example_id=None, input_attachments={}, output_attachments={}, - trace_id=None, + trace_id=trace_info.trace_id, dotted_order=None, error=trace_info.error or "", ) @@ -465,7 +464,7 @@ class LangSmithDataTrace(BaseTraceInstance): reference_example_id=None, input_attachments={}, output_attachments={}, - trace_id=None, + trace_id=trace_info.trace_id, dotted_order=None, error="", file_list=[], diff --git a/api/core/ops/opik_trace/opik_trace.py b/api/core/ops/opik_trace/opik_trace.py index 1e52f28350..dd6a424ddb 100644 --- a/api/core/ops/opik_trace/opik_trace.py +++ b/api/core/ops/opik_trace/opik_trace.py @@ -96,8 +96,7 @@ class OpikDataTrace(BaseTraceInstance): self.generate_name_trace(trace_info) def workflow_trace(self, trace_info: WorkflowTraceInfo): - external_trace_id = trace_info.metadata.get("external_trace_id") - dify_trace_id = external_trace_id or trace_info.workflow_run_id + dify_trace_id = trace_info.trace_id or trace_info.workflow_run_id opik_trace_id = prepare_opik_uuid(trace_info.start_time, dify_trace_id) workflow_metadata = wrap_metadata( trace_info.metadata, message_id=trace_info.message_id, workflow_app_log_id=trace_info.workflow_app_log_id @@ -105,7 +104,7 @@ class OpikDataTrace(BaseTraceInstance): root_span_id = None if trace_info.message_id: - dify_trace_id = external_trace_id or trace_info.message_id + dify_trace_id = trace_info.trace_id or trace_info.message_id opik_trace_id = prepare_opik_uuid(trace_info.start_time, dify_trace_id) trace_data = { @@ -276,7 +275,7 @@ class OpikDataTrace(BaseTraceInstance): return metadata = trace_info.metadata - message_id = trace_info.message_id + dify_trace_id = trace_info.trace_id or trace_info.message_id user_id = message_data.from_account_id metadata["user_id"] = user_id @@ -291,7 +290,7 @@ class OpikDataTrace(BaseTraceInstance): metadata["end_user_id"] = end_user_id trace_data = { - "id": prepare_opik_uuid(trace_info.start_time, message_id), + "id": prepare_opik_uuid(trace_info.start_time, dify_trace_id), "name": TraceTaskName.MESSAGE_TRACE.value, "start_time": trace_info.start_time, "end_time": trace_info.end_time, @@ -330,7 +329,7 @@ class OpikDataTrace(BaseTraceInstance): start_time = trace_info.start_time or trace_info.message_data.created_at span_data = { - "trace_id": prepare_opik_uuid(start_time, trace_info.message_id), + "trace_id": prepare_opik_uuid(start_time, trace_info.trace_id or trace_info.message_id), "name": TraceTaskName.MODERATION_TRACE.value, "type": "tool", "start_time": start_time, @@ -356,7 +355,7 @@ class OpikDataTrace(BaseTraceInstance): start_time = trace_info.start_time or message_data.created_at span_data = { - "trace_id": prepare_opik_uuid(start_time, trace_info.message_id), + "trace_id": prepare_opik_uuid(start_time, trace_info.trace_id or trace_info.message_id), "name": TraceTaskName.SUGGESTED_QUESTION_TRACE.value, "type": "tool", "start_time": start_time, @@ -376,7 +375,7 @@ class OpikDataTrace(BaseTraceInstance): start_time = trace_info.start_time or trace_info.message_data.created_at span_data = { - "trace_id": prepare_opik_uuid(start_time, trace_info.message_id), + "trace_id": prepare_opik_uuid(start_time, trace_info.trace_id or trace_info.message_id), "name": TraceTaskName.DATASET_RETRIEVAL_TRACE.value, "type": "tool", "start_time": start_time, @@ -391,7 +390,7 @@ class OpikDataTrace(BaseTraceInstance): def tool_trace(self, trace_info: ToolTraceInfo): span_data = { - "trace_id": prepare_opik_uuid(trace_info.start_time, trace_info.message_id), + "trace_id": prepare_opik_uuid(trace_info.start_time, trace_info.trace_id or trace_info.message_id), "name": trace_info.tool_name, "type": "tool", "start_time": trace_info.start_time, @@ -406,7 +405,7 @@ class OpikDataTrace(BaseTraceInstance): def generate_name_trace(self, trace_info: GenerateNameTraceInfo): trace_data = { - "id": prepare_opik_uuid(trace_info.start_time, trace_info.message_id), + "id": prepare_opik_uuid(trace_info.start_time, trace_info.trace_id or trace_info.message_id), "name": TraceTaskName.GENERATE_NAME_TRACE.value, "start_time": trace_info.start_time, "end_time": trace_info.end_time, diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py index 91cdc937a6..a607c76beb 100644 --- a/api/core/ops/ops_trace_manager.py +++ b/api/core/ops/ops_trace_manager.py @@ -407,6 +407,7 @@ class TraceTask: def __init__( self, trace_type: Any, + trace_id: Optional[str] = None, message_id: Optional[str] = None, workflow_execution: Optional[WorkflowExecution] = None, conversation_id: Optional[str] = None, @@ -424,6 +425,9 @@ class TraceTask: self.app_id = None self.kwargs = kwargs + external_trace_id = kwargs.get("external_trace_id") + if external_trace_id: + self.trace_id = external_trace_id def execute(self): return self.preprocess() @@ -520,11 +524,8 @@ class TraceTask: "app_id": workflow_run.app_id, } - external_trace_id = self.kwargs.get("external_trace_id") - if external_trace_id: - metadata["external_trace_id"] = external_trace_id - workflow_trace_info = WorkflowTraceInfo( + trace_id=self.trace_id, workflow_data=workflow_run.to_dict(), conversation_id=conversation_id, workflow_id=workflow_id, @@ -584,6 +585,7 @@ class TraceTask: message_tokens = message_data.message_tokens message_trace_info = MessageTraceInfo( + trace_id=self.trace_id, message_id=message_id, message_data=message_data.to_dict(), conversation_model=conversation_mode, @@ -627,6 +629,7 @@ class TraceTask: workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None moderation_trace_info = ModerationTraceInfo( + trace_id=self.trace_id, message_id=workflow_app_log_id or message_id, inputs=inputs, message_data=message_data.to_dict(), @@ -667,6 +670,7 @@ class TraceTask: workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None suggested_question_trace_info = SuggestedQuestionTraceInfo( + trace_id=self.trace_id, message_id=workflow_app_log_id or message_id, message_data=message_data.to_dict(), inputs=message_data.message, @@ -708,6 +712,7 @@ class TraceTask: } dataset_retrieval_trace_info = DatasetRetrievalTraceInfo( + trace_id=self.trace_id, message_id=message_id, inputs=message_data.query or message_data.inputs, documents=[doc.model_dump() for doc in documents] if documents else [], @@ -772,6 +777,7 @@ class TraceTask: ) tool_trace_info = ToolTraceInfo( + trace_id=self.trace_id, message_id=message_id, message_data=message_data.to_dict(), tool_name=tool_name, @@ -807,6 +813,7 @@ class TraceTask: } generate_name_trace_info = GenerateNameTraceInfo( + trace_id=self.trace_id, conversation_id=conversation_id, inputs=inputs, outputs=generate_conversation_name, diff --git a/api/core/ops/weave_trace/weave_trace.py b/api/core/ops/weave_trace/weave_trace.py index 470601b17a..8089860481 100644 --- a/api/core/ops/weave_trace/weave_trace.py +++ b/api/core/ops/weave_trace/weave_trace.py @@ -87,8 +87,7 @@ class WeaveDataTrace(BaseTraceInstance): self.generate_name_trace(trace_info) def workflow_trace(self, trace_info: WorkflowTraceInfo): - external_trace_id = trace_info.metadata.get("external_trace_id") - trace_id = external_trace_id or trace_info.message_id or trace_info.workflow_run_id + trace_id = trace_info.trace_id or trace_info.message_id or trace_info.workflow_run_id if trace_info.start_time is None: trace_info.start_time = datetime.now() @@ -245,8 +244,12 @@ class WeaveDataTrace(BaseTraceInstance): attributes["start_time"] = trace_info.start_time attributes["end_time"] = trace_info.end_time attributes["tags"] = ["message", str(trace_info.conversation_mode)] + + trace_id = trace_info.trace_id or message_id + attributes["trace_id"] = trace_id + message_run = WeaveTraceModel( - id=message_id, + id=trace_id, op=str(TraceTaskName.MESSAGE_TRACE.value), input_tokens=trace_info.message_tokens, output_tokens=trace_info.answer_tokens, @@ -274,7 +277,7 @@ class WeaveDataTrace(BaseTraceInstance): ) self.start_call( llm_run, - parent_run_id=message_id, + parent_run_id=trace_id, ) self.finish_call(llm_run) self.finish_call(message_run) @@ -289,6 +292,9 @@ class WeaveDataTrace(BaseTraceInstance): attributes["start_time"] = trace_info.start_time or trace_info.message_data.created_at attributes["end_time"] = trace_info.end_time or trace_info.message_data.updated_at + trace_id = trace_info.trace_id or trace_info.message_id + attributes["trace_id"] = trace_id + moderation_run = WeaveTraceModel( id=str(uuid.uuid4()), op=str(TraceTaskName.MODERATION_TRACE.value), @@ -303,7 +309,7 @@ class WeaveDataTrace(BaseTraceInstance): exception=getattr(trace_info, "error", None), file_list=[], ) - self.start_call(moderation_run, parent_run_id=trace_info.message_id) + self.start_call(moderation_run, parent_run_id=trace_id) self.finish_call(moderation_run) def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo): @@ -316,6 +322,9 @@ class WeaveDataTrace(BaseTraceInstance): attributes["start_time"] = (trace_info.start_time or message_data.created_at,) attributes["end_time"] = (trace_info.end_time or message_data.updated_at,) + trace_id = trace_info.trace_id or trace_info.message_id + attributes["trace_id"] = trace_id + suggested_question_run = WeaveTraceModel( id=str(uuid.uuid4()), op=str(TraceTaskName.SUGGESTED_QUESTION_TRACE.value), @@ -326,7 +335,7 @@ class WeaveDataTrace(BaseTraceInstance): file_list=[], ) - self.start_call(suggested_question_run, parent_run_id=trace_info.message_id) + self.start_call(suggested_question_run, parent_run_id=trace_id) self.finish_call(suggested_question_run) def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo): @@ -338,6 +347,9 @@ class WeaveDataTrace(BaseTraceInstance): attributes["start_time"] = (trace_info.start_time or trace_info.message_data.created_at,) attributes["end_time"] = (trace_info.end_time or trace_info.message_data.updated_at,) + trace_id = trace_info.trace_id or trace_info.message_id + attributes["trace_id"] = trace_id + dataset_retrieval_run = WeaveTraceModel( id=str(uuid.uuid4()), op=str(TraceTaskName.DATASET_RETRIEVAL_TRACE.value), @@ -348,7 +360,7 @@ class WeaveDataTrace(BaseTraceInstance): file_list=[], ) - self.start_call(dataset_retrieval_run, parent_run_id=trace_info.message_id) + self.start_call(dataset_retrieval_run, parent_run_id=trace_id) self.finish_call(dataset_retrieval_run) def tool_trace(self, trace_info: ToolTraceInfo): @@ -357,6 +369,11 @@ class WeaveDataTrace(BaseTraceInstance): attributes["start_time"] = trace_info.start_time attributes["end_time"] = trace_info.end_time + message_id = trace_info.message_id or getattr(trace_info, "conversation_id", None) + message_id = message_id or None + trace_id = trace_info.trace_id or message_id + attributes["trace_id"] = trace_id + tool_run = WeaveTraceModel( id=str(uuid.uuid4()), op=trace_info.tool_name, @@ -366,9 +383,7 @@ class WeaveDataTrace(BaseTraceInstance): attributes=attributes, exception=trace_info.error, ) - message_id = trace_info.message_id or getattr(trace_info, "conversation_id", None) - message_id = message_id or None - self.start_call(tool_run, parent_run_id=message_id) + self.start_call(tool_run, parent_run_id=trace_id) self.finish_call(tool_run) def generate_name_trace(self, trace_info: GenerateNameTraceInfo):