diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py
index 21eac29f218..85eb06045ac 100644
--- a/api/core/rag/datasource/retrieval_service.py
+++ b/api/core/rag/datasource/retrieval_service.py
@@ -1,10 +1,12 @@
 import concurrent.futures
+import functools
 import logging
-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
 from concurrent.futures import ThreadPoolExecutor
 from typing import Any, NotRequired, TypedDict
 
 from flask import Flask, current_app
+from opentelemetry import context as otel_context
 from sqlalchemy import select
 from sqlalchemy.orm import Session, load_only
 
@@ -25,6 +27,7 @@ from core.rag.rerank.rerank_type import RerankMode
 from core.rag.retrieval.retrieval_methods import RetrievalMethod
 from core.tools.signature import sign_upload_file_preview_url
 from extensions.ext_database import db
+from extensions.otel import trace_span
 from graphon.model_runtime.entities.model_entities import ModelType
 from models.dataset import (
     ChildChunk,
@@ -90,9 +93,32 @@ default_retrieval_model: DefaultRetrievalModelDict = {
 logger = logging.getLogger(__name__)
 
 
+def _propagate_otel_context[**P, R](func: Callable[P, R]) -> Callable[P, R]:
+    """Bind the submitting thread's OpenTelemetry context to ``func``.
+
+    The current context is captured when this helper is called and is
+    attached around every invocation of the returned wrapper, then detached
+    again. It is used for callables handed to ``executor.submit`` so that
+    spans created inside the worker can join the caller's trace.
+    """
+    captured_context = otel_context.get_current()
+
+    @functools.wraps(func)
+    def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
+        token = otel_context.attach(captured_context)
+        try:
+            return func(*args, **kwargs)
+        finally:
+            # Always restore the worker thread's previous context, even on error.
+            otel_context.detach(token)
+
+    return wrapper
+
+
 class RetrievalService:
     # Cache precompiled regular expressions to avoid repeated compilation
     @classmethod
+    @trace_span()
     def retrieve(
         cls,
         retrieval_method: RetrievalMethod,
@@ -122,7 +148,7 @@ class RetrievalService:
             if query:
                 futures.append(
                     executor.submit(
-                        retrieval_service._retrieve,
+                        _propagate_otel_context(retrieval_service._retrieve),
                         flask_app=current_app._get_current_object(),  # type: ignore
                         retrieval_method=retrieval_method,
                         dataset=dataset,
@@ -142,7 +168,7 @@ class RetrievalService:
                 for attachment_id in attachment_ids:
                     futures.append(
                         executor.submit(
-                            retrieval_service._retrieve,
+                            _propagate_otel_context(retrieval_service._retrieve),
                             flask_app=current_app._get_current_object(),  # type: ignore
                             retrieval_method=retrieval_method,
                             dataset=dataset,
@@ -264,6 +290,7 @@ class RetrievalService:
             return session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
 
     @classmethod
+    @trace_span()
     def keyword_search(
         cls,
         flask_app: Flask,
@@ -291,6 +318,7 @@ class RetrievalService:
                 exceptions.append(str(e))
 
     @classmethod
+    @trace_span()
     def embedding_search(
         cls,
         flask_app: Flask,
@@ -392,6 +420,7 @@ class RetrievalService:
                 exceptions.append(str(e))
 
     @classmethod
+    @trace_span()
     def full_text_index_search(
         cls,
         flask_app: Flask,
@@ -754,6 +783,7 @@ class RetrievalService:
             db.session.rollback()
             raise e
 
+    @trace_span()
     def _retrieve(
         self,
         flask_app: Flask,
@@ -780,7 +810,7 @@ class RetrievalService:
                 if retrieval_method == RetrievalMethod.KEYWORD_SEARCH and query:
                     futures.append(
                         executor.submit(
-                            self.keyword_search,
+                            _propagate_otel_context(self.keyword_search),
                             flask_app=current_app._get_current_object(),  # type: ignore
                             dataset_id=dataset.id,
                             query=query,
@@ -794,7 +824,7 @@ class RetrievalService:
                     if query:
                         futures.append(
                             executor.submit(
-                                self.embedding_search,
+                                _propagate_otel_context(self.embedding_search),
                                 flask_app=current_app._get_current_object(),  # type: ignore
                                 dataset_id=dataset.id,
                                 query=query,
@@ -811,7 +841,7 @@ class RetrievalService:
                     if attachment_id:
                         futures.append(
                             executor.submit(
-                                self.embedding_search,
+                                _propagate_otel_context(self.embedding_search),
                                 flask_app=current_app._get_current_object(),  # type: ignore
                                 dataset_id=dataset.id,
                                 query=attachment_id,
@@ -828,7 +858,7 @@ class RetrievalService:
                 if RetrievalMethod.is_support_fulltext_search(retrieval_method) and query:
                     futures.append(
                         executor.submit(
-                            self.full_text_index_search,
+                            _propagate_otel_context(self.full_text_index_search),
                             flask_app=current_app._get_current_object(),  # type: ignore
                             dataset_id=dataset.id,
                             query=query,
diff --git a/api/core/rag/datasource/vdb/vector_factory.py b/api/core/rag/datasource/vdb/vector_factory.py
index cd73bb9b1ac..4d65951d9a9 100644
--- a/api/core/rag/datasource/vdb/vector_factory.py
+++ b/api/core/rag/datasource/vdb/vector_factory.py
@@ -18,6 +18,7 @@ from core.rag.models.document import Document
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from extensions.ext_storage import storage
+from extensions.otel import trace_span
 from graphon.model_runtime.entities.model_entities import ModelType
 from models.dataset import Dataset, Whitelist
 from models.model import UploadFile
@@ -244,6 +245,15 @@ class Vector:
 
     def search_by_vector(self, query: str, **kwargs: Any) -> list[Document]:
         query_vector = self._embeddings.embed_query(query)
+        return self._search_by_vector_traced(query_vector, **kwargs)
+
+    @trace_span()
+    def _search_by_vector_traced(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
+        """Query the vector store with a precomputed vector (wrapped by ``trace_span``).
+
+        Embedding is done by the callers, so only the
+        ``_vector_processor.search_by_vector`` call runs under this decorator.
+        """
         return self._vector_processor.search_by_vector(query_vector, **kwargs)
 
     def search_by_file(self, file_id: str, **kwargs: Any) -> list[Document]:
@@ -260,7 +270,7 @@ class Vector:
                 "file_id": file_id,
             }
         )
-        return self._vector_processor.search_by_vector(multimodal_vector, **kwargs)
+        return self._search_by_vector_traced(multimodal_vector, **kwargs)
 
     def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
         return self._vector_processor.search_by_full_text(query, **kwargs)