feat: trace document retrieval (#37283)

This commit is contained in:
Yunlu Wen 2026-06-11 10:39:59 +08:00 committed by GitHub
parent 2a46a7d91d
commit 84490179b0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 35 additions and 8 deletions

View File

@ -1,10 +1,12 @@
import concurrent.futures
import functools
import logging
from collections.abc import Sequence
from collections.abc import Callable, Sequence
from concurrent.futures import ThreadPoolExecutor
from typing import Any, NotRequired, TypedDict
from flask import Flask, current_app
from opentelemetry import context as otel_context
from sqlalchemy import select
from sqlalchemy.orm import Session, load_only
@ -25,6 +27,7 @@ from core.rag.rerank.rerank_type import RerankMode
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.signature import sign_upload_file_preview_url
from extensions.ext_database import db
from extensions.otel import trace_span
from graphon.model_runtime.entities.model_entities import ModelType
from models.dataset import (
ChildChunk,
@ -90,9 +93,24 @@ default_retrieval_model: DefaultRetrievalModelDict = {
logger = logging.getLogger(__name__)
def _propagate_otel_context[**P, R](func: Callable[P, R]) -> Callable[P, R]:
captured_context = otel_context.get_current()
@functools.wraps(func)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
token = otel_context.attach(captured_context)
try:
return func(*args, **kwargs)
finally:
otel_context.detach(token)
return wrapper
class RetrievalService:
# Cache precompiled regular expressions to avoid repeated compilation
@classmethod
@trace_span()
def retrieve(
cls,
retrieval_method: RetrievalMethod,
@ -122,7 +140,7 @@ class RetrievalService:
if query:
futures.append(
executor.submit(
retrieval_service._retrieve,
_propagate_otel_context(retrieval_service._retrieve),
flask_app=current_app._get_current_object(), # type: ignore
retrieval_method=retrieval_method,
dataset=dataset,
@ -142,7 +160,7 @@ class RetrievalService:
for attachment_id in attachment_ids:
futures.append(
executor.submit(
retrieval_service._retrieve,
_propagate_otel_context(retrieval_service._retrieve),
flask_app=current_app._get_current_object(), # type: ignore
retrieval_method=retrieval_method,
dataset=dataset,
@ -264,6 +282,7 @@ class RetrievalService:
return session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
@classmethod
@trace_span()
def keyword_search(
cls,
flask_app: Flask,
@ -291,6 +310,7 @@ class RetrievalService:
exceptions.append(str(e))
@classmethod
@trace_span()
def embedding_search(
cls,
flask_app: Flask,
@ -392,6 +412,7 @@ class RetrievalService:
exceptions.append(str(e))
@classmethod
@trace_span()
def full_text_index_search(
cls,
flask_app: Flask,
@ -754,6 +775,7 @@ class RetrievalService:
db.session.rollback()
raise e
@trace_span()
def _retrieve(
self,
flask_app: Flask,
@ -780,7 +802,7 @@ class RetrievalService:
if retrieval_method == RetrievalMethod.KEYWORD_SEARCH and query:
futures.append(
executor.submit(
self.keyword_search,
_propagate_otel_context(self.keyword_search),
flask_app=current_app._get_current_object(), # type: ignore
dataset_id=dataset.id,
query=query,
@ -794,7 +816,7 @@ class RetrievalService:
if query:
futures.append(
executor.submit(
self.embedding_search,
_propagate_otel_context(self.embedding_search),
flask_app=current_app._get_current_object(), # type: ignore
dataset_id=dataset.id,
query=query,
@ -811,7 +833,7 @@ class RetrievalService:
if attachment_id:
futures.append(
executor.submit(
self.embedding_search,
_propagate_otel_context(self.embedding_search),
flask_app=current_app._get_current_object(), # type: ignore
dataset_id=dataset.id,
query=attachment_id,
@ -828,7 +850,7 @@ class RetrievalService:
if RetrievalMethod.is_support_fulltext_search(retrieval_method) and query:
futures.append(
executor.submit(
self.full_text_index_search,
_propagate_otel_context(self.full_text_index_search),
flask_app=current_app._get_current_object(), # type: ignore
dataset_id=dataset.id,
query=query,

View File

@ -18,6 +18,7 @@ from core.rag.models.document import Document
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
from extensions.otel import trace_span
from graphon.model_runtime.entities.model_entities import ModelType
from models.dataset import Dataset, Whitelist
from models.model import UploadFile
@ -244,6 +245,10 @@ class Vector:
def search_by_vector(self, query: str, **kwargs: Any) -> list[Document]:
query_vector = self._embeddings.embed_query(query)
return self._search_by_vector_traced(query_vector, **kwargs)
@trace_span()
def _search_by_vector_traced(self, query_vector: list[float], **kwargs) -> list[Document]:
return self._vector_processor.search_by_vector(query_vector, **kwargs)
def search_by_file(self, file_id: str, **kwargs: Any) -> list[Document]:
@ -260,7 +265,7 @@ class Vector:
"file_id": file_id,
}
)
return self._vector_processor.search_by_vector(multimodal_vector, **kwargs)
return self._search_by_vector_traced(multimodal_vector, **kwargs)
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
return self._vector_processor.search_by_full_text(query, **kwargs)