mirror of
https://github.com/langgenius/dify.git
synced 2026-06-12 19:53:38 +08:00
feat: trace document retrieval (#37283)
This commit is contained in:
parent
2a46a7d91d
commit
84490179b0
@ -1,10 +1,12 @@
|
||||
import concurrent.futures
|
||||
import functools
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
from collections.abc import Callable, Sequence
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Any, NotRequired, TypedDict
|
||||
|
||||
from flask import Flask, current_app
|
||||
from opentelemetry import context as otel_context
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session, load_only
|
||||
|
||||
@ -25,6 +27,7 @@ from core.rag.rerank.rerank_type import RerankMode
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.tools.signature import sign_upload_file_preview_url
|
||||
from extensions.ext_database import db
|
||||
from extensions.otel import trace_span
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from models.dataset import (
|
||||
ChildChunk,
|
||||
@ -90,9 +93,24 @@ default_retrieval_model: DefaultRetrievalModelDict = {
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _propagate_otel_context[**P, R](func: Callable[P, R]) -> Callable[P, R]:
|
||||
captured_context = otel_context.get_current()
|
||||
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
token = otel_context.attach(captured_context)
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
finally:
|
||||
otel_context.detach(token)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class RetrievalService:
|
||||
# Cache precompiled regular expressions to avoid repeated compilation
|
||||
@classmethod
|
||||
@trace_span()
|
||||
def retrieve(
|
||||
cls,
|
||||
retrieval_method: RetrievalMethod,
|
||||
@ -122,7 +140,7 @@ class RetrievalService:
|
||||
if query:
|
||||
futures.append(
|
||||
executor.submit(
|
||||
retrieval_service._retrieve,
|
||||
_propagate_otel_context(retrieval_service._retrieve),
|
||||
flask_app=current_app._get_current_object(), # type: ignore
|
||||
retrieval_method=retrieval_method,
|
||||
dataset=dataset,
|
||||
@ -142,7 +160,7 @@ class RetrievalService:
|
||||
for attachment_id in attachment_ids:
|
||||
futures.append(
|
||||
executor.submit(
|
||||
retrieval_service._retrieve,
|
||||
_propagate_otel_context(retrieval_service._retrieve),
|
||||
flask_app=current_app._get_current_object(), # type: ignore
|
||||
retrieval_method=retrieval_method,
|
||||
dataset=dataset,
|
||||
@ -264,6 +282,7 @@ class RetrievalService:
|
||||
return session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
|
||||
|
||||
@classmethod
|
||||
@trace_span()
|
||||
def keyword_search(
|
||||
cls,
|
||||
flask_app: Flask,
|
||||
@ -291,6 +310,7 @@ class RetrievalService:
|
||||
exceptions.append(str(e))
|
||||
|
||||
@classmethod
|
||||
@trace_span()
|
||||
def embedding_search(
|
||||
cls,
|
||||
flask_app: Flask,
|
||||
@ -392,6 +412,7 @@ class RetrievalService:
|
||||
exceptions.append(str(e))
|
||||
|
||||
@classmethod
|
||||
@trace_span()
|
||||
def full_text_index_search(
|
||||
cls,
|
||||
flask_app: Flask,
|
||||
@ -754,6 +775,7 @@ class RetrievalService:
|
||||
db.session.rollback()
|
||||
raise e
|
||||
|
||||
@trace_span()
|
||||
def _retrieve(
|
||||
self,
|
||||
flask_app: Flask,
|
||||
@ -780,7 +802,7 @@ class RetrievalService:
|
||||
if retrieval_method == RetrievalMethod.KEYWORD_SEARCH and query:
|
||||
futures.append(
|
||||
executor.submit(
|
||||
self.keyword_search,
|
||||
_propagate_otel_context(self.keyword_search),
|
||||
flask_app=current_app._get_current_object(), # type: ignore
|
||||
dataset_id=dataset.id,
|
||||
query=query,
|
||||
@ -794,7 +816,7 @@ class RetrievalService:
|
||||
if query:
|
||||
futures.append(
|
||||
executor.submit(
|
||||
self.embedding_search,
|
||||
_propagate_otel_context(self.embedding_search),
|
||||
flask_app=current_app._get_current_object(), # type: ignore
|
||||
dataset_id=dataset.id,
|
||||
query=query,
|
||||
@ -811,7 +833,7 @@ class RetrievalService:
|
||||
if attachment_id:
|
||||
futures.append(
|
||||
executor.submit(
|
||||
self.embedding_search,
|
||||
_propagate_otel_context(self.embedding_search),
|
||||
flask_app=current_app._get_current_object(), # type: ignore
|
||||
dataset_id=dataset.id,
|
||||
query=attachment_id,
|
||||
@ -828,7 +850,7 @@ class RetrievalService:
|
||||
if RetrievalMethod.is_support_fulltext_search(retrieval_method) and query:
|
||||
futures.append(
|
||||
executor.submit(
|
||||
self.full_text_index_search,
|
||||
_propagate_otel_context(self.full_text_index_search),
|
||||
flask_app=current_app._get_current_object(), # type: ignore
|
||||
dataset_id=dataset.id,
|
||||
query=query,
|
||||
|
||||
@ -18,6 +18,7 @@ from core.rag.models.document import Document
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from extensions.ext_storage import storage
|
||||
from extensions.otel import trace_span
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from models.dataset import Dataset, Whitelist
|
||||
from models.model import UploadFile
|
||||
@ -244,6 +245,10 @@ class Vector:
|
||||
|
||||
def search_by_vector(self, query: str, **kwargs: Any) -> list[Document]:
    """Embed ``query`` and search with the resulting vector.

    Args:
        query: Text passed to ``self._embeddings.embed_query``.
        **kwargs: Forwarded unchanged to ``self._search_by_vector_traced``.

    Returns:
        Whatever ``self._search_by_vector_traced`` returns for the embedding.
    """
    embedder = self._embeddings
    embedding = embedder.embed_query(query)
    return self._search_by_vector_traced(embedding, **kwargs)
|
||||
|
||||
@trace_span()
def _search_by_vector_traced(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
    """Run the vector search on an already-computed embedding.

    This is the step wrapped by ``@trace_span()``; callers compute the
    embedding first and pass it in.

    Args:
        query_vector: The embedding to search with.
        **kwargs: Forwarded unchanged to ``self._vector_processor.search_by_vector``.

    Returns:
        The result of ``self._vector_processor.search_by_vector``.
    """
    return self._vector_processor.search_by_vector(query_vector, **kwargs)
|
||||
|
||||
def search_by_file(self, file_id: str, **kwargs: Any) -> list[Document]:
|
||||
@ -260,7 +265,7 @@ class Vector:
|
||||
"file_id": file_id,
|
||||
}
|
||||
)
|
||||
return self._vector_processor.search_by_vector(multimodal_vector, **kwargs)
|
||||
return self._search_by_vector_traced(multimodal_vector, **kwargs)
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
    """Delegate a full-text search to the underlying vector processor.

    Args:
        query: Text passed to the processor's ``search_by_full_text``.
        **kwargs: Forwarded unchanged to the processor.

    Returns:
        The result of ``self._vector_processor.search_by_full_text``.
    """
    processor = self._vector_processor
    return processor.search_by_full_text(query, **kwargs)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user