Merge remote-tracking branch 'origin/feat/evaluation' into feat/evaluation
commit 87dd0d80e7
@@ -1,10 +1,13 @@
import json
from typing import Any, cast
from urllib.parse import quote

from flask import request
from flask import Response, request
from flask_restx import Resource, fields, marshal, marshal_with
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import select
from werkzeug.exceptions import Forbidden, NotFound
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, Forbidden, NotFound

import services
from configs import dify_config
@@ -24,6 +27,7 @@ from controllers.console.wraps import (
    setup_required,
)
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.evaluation.entities.evaluation_entity import EvaluationCategory, EvaluationConfigData, EvaluationRunRequest
from core.indexing_runner import IndexingRunner
from core.model_runtime.entities.model_entities import ModelType
from core.provider_manager import ProviderManager
@@ -32,6 +36,7 @@ from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
from extensions.ext_storage import storage
from fields.app_fields import app_detail_kernel_fields, related_app_list
from fields.dataset_fields import (
    content_fields,
@@ -52,11 +57,18 @@ from fields.dataset_fields import (
)
from fields.document_fields import document_status_fields
from libs.login import current_account_with_tenant, login_required
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
from models import ApiToken, Dataset, Document, DocumentSegment, EvaluationRun, EvaluationTargetType, UploadFile
from models.dataset import DatasetPermissionEnum
from models.provider_ids import ModelProviderID
from services.api_token_service import ApiTokenCache
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
from services.errors.evaluation import (
    EvaluationDatasetInvalidError,
    EvaluationFrameworkNotConfiguredError,
    EvaluationMaxConcurrentRunsError,
    EvaluationNotFoundError,
)
from services.evaluation_service import EvaluationService

# Register models for flask_restx to avoid dict type issues in Swagger
dataset_base_model = get_or_create_model("DatasetBase", dataset_fields)
@@ -948,3 +960,429 @@ class DatasetAutoDisableLogApi(Resource):
        if dataset is None:
            raise NotFound("Dataset not found.")
        return DatasetService.get_dataset_auto_disable_logs(dataset_id_str), 200


# ---- Knowledge Base Retrieval Evaluation ----


def _serialize_dataset_evaluation_run(run: EvaluationRun) -> dict[str, Any]:
    return {
        "id": run.id,
        "tenant_id": run.tenant_id,
        "target_type": run.target_type,
        "target_id": run.target_id,
        "evaluation_config_id": run.evaluation_config_id,
        "status": run.status,
        "dataset_file_id": run.dataset_file_id,
        "result_file_id": run.result_file_id,
        "total_items": run.total_items,
        "completed_items": run.completed_items,
        "failed_items": run.failed_items,
        "progress": run.progress,
        "metrics_summary": json.loads(run.metrics_summary) if run.metrics_summary else {},
        "error": run.error,
        "created_by": run.created_by,
        "started_at": int(run.started_at.timestamp()) if run.started_at else None,
        "completed_at": int(run.completed_at.timestamp()) if run.completed_at else None,
        "created_at": int(run.created_at.timestamp()) if run.created_at else None,
    }


def _serialize_dataset_evaluation_run_item(item: Any) -> dict[str, Any]:
    return {
        "id": item.id,
        "item_index": item.item_index,
        "inputs": item.inputs_dict,
        "expected_output": item.expected_output,
        "actual_output": item.actual_output,
        "metrics": item.metrics_list,
        "judgment": item.judgment_dict,
        "metadata": item.metadata_dict,
        "error": item.error,
        "overall_score": item.overall_score,
    }


@console_ns.route("/datasets/<uuid:dataset_id>/evaluation/template/download")
class DatasetEvaluationTemplateDownloadApi(Resource):
    @console_ns.doc("download_dataset_evaluation_template")
    @console_ns.response(200, "Template file streamed as XLSX attachment")
    @console_ns.response(403, "Permission denied")
    @console_ns.response(404, "Dataset not found")
    @setup_required
    @login_required
    @account_initialization_required
    def post(self, dataset_id):
        """Download evaluation dataset template for knowledge base retrieval."""
        current_user, _ = current_account_with_tenant()
        dataset_id_str = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id_str)
        if dataset is None:
            raise NotFound("Dataset not found.")
        try:
            DatasetService.check_dataset_permission(dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e))

        xlsx_content, filename = EvaluationService.generate_retrieval_dataset_template()
        encoded_filename = quote(filename)
        response = Response(
            xlsx_content,
            mimetype="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
        )
        response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}"
        response.headers["Content-Length"] = str(len(xlsx_content))
        return response
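
For orientation, a console client might fetch this template roughly as follows. This is a sketch only: the /console/api prefix, host, and bearer-token header are assumptions about the console API conventions and are not taken from this diff.

import requests

BASE = "http://localhost:5001/console/api"  # assumed console API prefix
dataset_id = "<dataset-uuid>"  # placeholder

resp = requests.post(
    f"{BASE}/datasets/{dataset_id}/evaluation/template/download",
    headers={"Authorization": "Bearer <console-access-token>"},  # assumed auth scheme
)
resp.raise_for_status()
# The response streams the XLSX as an attachment (see the Content-Disposition header set above).
with open("retrieval-evaluation-dataset.xlsx", "wb") as f:
    f.write(resp.content)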


@console_ns.route("/datasets/<uuid:dataset_id>/evaluation")
class DatasetEvaluationDetailApi(Resource):
    @console_ns.doc("get_dataset_evaluation_config")
    @console_ns.response(200, "Evaluation configuration retrieved")
    @console_ns.response(403, "Permission denied")
    @console_ns.response(404, "Dataset not found")
    @setup_required
    @login_required
    @account_initialization_required
    def get(self, dataset_id):
        """Get evaluation configuration for the knowledge base."""
        current_user, current_tenant_id = current_account_with_tenant()
        dataset_id_str = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id_str)
        if dataset is None:
            raise NotFound("Dataset not found.")
        try:
            DatasetService.check_dataset_permission(dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e))

        with Session(db.engine, expire_on_commit=False) as session:
            config = EvaluationService.get_evaluation_config(
                session, current_tenant_id, "dataset", dataset_id_str
            )

        if config is None:
            return {
                "evaluation_model": None,
                "evaluation_model_provider": None,
                "metrics_config": None,
                "judgement_conditions": None,
            }

        return {
            "evaluation_model": config.evaluation_model,
            "evaluation_model_provider": config.evaluation_model_provider,
            "metrics_config": config.metrics_config_dict,
            "judgement_conditions": config.judgement_conditions_dict,
        }

    @console_ns.doc("save_dataset_evaluation_config")
    @console_ns.response(200, "Evaluation configuration saved")
    @console_ns.response(403, "Permission denied")
    @console_ns.response(404, "Dataset not found")
    @setup_required
    @login_required
    @account_initialization_required
    def put(self, dataset_id):
        """Save evaluation configuration for the knowledge base."""
        current_user, current_tenant_id = current_account_with_tenant()
        dataset_id_str = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id_str)
        if dataset is None:
            raise NotFound("Dataset not found.")
        try:
            DatasetService.check_dataset_permission(dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e))

        body = request.get_json(force=True)
        try:
            config_data = EvaluationConfigData.model_validate(body)
        except Exception as e:
            raise BadRequest(f"Invalid request body: {e}")

        with Session(db.engine, expire_on_commit=False) as session:
            config = EvaluationService.save_evaluation_config(
                session=session,
                tenant_id=current_tenant_id,
                target_type="dataset",
                target_id=dataset_id_str,
                account_id=str(current_user.id),
                data=config_data,
            )

        return {
            "evaluation_model": config.evaluation_model,
            "evaluation_model_provider": config.evaluation_model_provider,
            "metrics_config": config.metrics_config_dict,
            "judgement_conditions": config.judgement_conditions_dict,
        }
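
A plausible request body for the PUT handler above is sketched below. The exact schema is defined by EvaluationConfigData, which is not part of this diff, so everything inside metrics_config and judgement_conditions is an illustrative assumption; only the four top-level keys mirror the response shape.

payload = {
    "evaluation_model_provider": "openai",  # assumed provider/model identifiers
    "evaluation_model": "gpt-4o-mini",
    "metrics_config": {"default_metric": {"metric": "context_precision"}},  # assumed shape
    "judgement_conditions": [{"metric": "context_precision", "operator": ">=", "value": 0.7}],  # assumed shape
}
# PUT /datasets/<dataset_id>/evaluation with this JSON returns the saved configuration
# echoed back in the same four-field shape as the GET handler above.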


@console_ns.route("/datasets/<uuid:dataset_id>/evaluation/run")
class DatasetEvaluationRunApi(Resource):
    @console_ns.doc("start_dataset_evaluation_run")
    @console_ns.response(200, "Evaluation run started")
    @console_ns.response(400, "Invalid request")
    @console_ns.response(403, "Permission denied")
    @console_ns.response(404, "Dataset not found")
    @setup_required
    @login_required
    @account_initialization_required
    def post(self, dataset_id):
        """Start an evaluation run for the knowledge base retrieval."""
        current_user, current_tenant_id = current_account_with_tenant()
        dataset_id_str = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id_str)
        if dataset is None:
            raise NotFound("Dataset not found.")
        try:
            DatasetService.check_dataset_permission(dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e))

        body = request.get_json(force=True)
        if not body:
            raise BadRequest("Request body is required.")

        try:
            run_request = EvaluationRunRequest.model_validate(body)
        except Exception as e:
            raise BadRequest(f"Invalid request body: {e}")

        upload_file = (
            db.session.query(UploadFile).filter_by(id=run_request.file_id, tenant_id=current_tenant_id).first()
        )
        if not upload_file:
            raise NotFound("Dataset file not found.")

        try:
            dataset_content = storage.load_once(upload_file.key)
        except Exception:
            raise BadRequest("Failed to read dataset file.")

        if not dataset_content:
            raise BadRequest("Dataset file is empty.")

        try:
            with Session(db.engine, expire_on_commit=False) as session:
                evaluation_run = EvaluationService.start_evaluation_run(
                    session=session,
                    tenant_id=current_tenant_id,
                    target_type=EvaluationTargetType.KNOWLEDGE_BASE,
                    target_id=dataset_id_str,
                    account_id=str(current_user.id),
                    dataset_file_content=dataset_content,
                    run_request=run_request,
                )
                return _serialize_dataset_evaluation_run(evaluation_run), 200
        except EvaluationFrameworkNotConfiguredError as e:
            return {"message": str(e.description)}, 400
        except EvaluationNotFoundError as e:
            return {"message": str(e.description)}, 404
        except EvaluationMaxConcurrentRunsError as e:
            return {"message": str(e.description)}, 429
        except EvaluationDatasetInvalidError as e:
            return {"message": str(e.description)}, 400
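
End to end, starting a run means uploading the filled template first and then posting its file id to this endpoint. A minimal sketch (only file_id is visible on EvaluationRunRequest in this diff; the upload endpoint, path prefix, and auth header are assumptions):

import requests

BASE = "http://localhost:5001/console/api"  # assumed console API prefix
headers = {"Authorization": "Bearer <console-access-token>"}  # assumed auth scheme
dataset_id = "<dataset-uuid>"  # placeholder

run = requests.post(
    f"{BASE}/datasets/{dataset_id}/evaluation/run",
    headers=headers,
    json={"file_id": "<upload-file-id>"},  # id of the uploaded, filled XLSX
).json()
print(run["status"], run["progress"])  # fields from _serialize_dataset_evaluation_run
# 400: framework not configured or dataset file invalid, 404: not found, 429: too many concurrent runs.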


@console_ns.route("/datasets/<uuid:dataset_id>/evaluation/logs")
class DatasetEvaluationLogsApi(Resource):
    @console_ns.doc("get_dataset_evaluation_logs")
    @console_ns.response(200, "Evaluation logs retrieved")
    @console_ns.response(403, "Permission denied")
    @console_ns.response(404, "Dataset not found")
    @setup_required
    @login_required
    @account_initialization_required
    def get(self, dataset_id):
        """Get evaluation run history for the knowledge base."""
        current_user, current_tenant_id = current_account_with_tenant()
        dataset_id_str = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id_str)
        if dataset is None:
            raise NotFound("Dataset not found.")
        try:
            DatasetService.check_dataset_permission(dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e))

        page = request.args.get("page", 1, type=int)
        page_size = request.args.get("page_size", 20, type=int)

        with Session(db.engine, expire_on_commit=False) as session:
            runs, total = EvaluationService.get_evaluation_runs(
                session=session,
                tenant_id=current_tenant_id,
                target_type="dataset",
                target_id=dataset_id_str,
                page=page,
                page_size=page_size,
            )

        return {
            "data": [_serialize_dataset_evaluation_run(run) for run in runs],
            "total": total,
            "page": page,
            "page_size": page_size,
        }


@console_ns.route("/datasets/<uuid:dataset_id>/evaluation/runs/<uuid:run_id>")
class DatasetEvaluationRunDetailApi(Resource):
    @console_ns.doc("get_dataset_evaluation_run_detail")
    @console_ns.response(200, "Evaluation run detail retrieved")
    @console_ns.response(403, "Permission denied")
    @console_ns.response(404, "Dataset or run not found")
    @setup_required
    @login_required
    @account_initialization_required
    def get(self, dataset_id, run_id):
        """Get evaluation run detail including per-item results."""
        current_user, current_tenant_id = current_account_with_tenant()
        dataset_id_str = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id_str)
        if dataset is None:
            raise NotFound("Dataset not found.")
        try:
            DatasetService.check_dataset_permission(dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e))

        run_id_str = str(run_id)
        page = request.args.get("page", 1, type=int)
        page_size = request.args.get("page_size", 50, type=int)

        try:
            with Session(db.engine, expire_on_commit=False) as session:
                run = EvaluationService.get_evaluation_run_detail(
                    session=session,
                    tenant_id=current_tenant_id,
                    run_id=run_id_str,
                )
                items, total_items = EvaluationService.get_evaluation_run_items(
                    session=session,
                    run_id=run_id_str,
                    page=page,
                    page_size=page_size,
                )
                return {
                    "run": _serialize_dataset_evaluation_run(run),
                    "items": {
                        "data": [_serialize_dataset_evaluation_run_item(item) for item in items],
                        "total": total_items,
                        "page": page,
                        "page_size": page_size,
                    },
                }
        except EvaluationNotFoundError as e:
            return {"message": str(e.description)}, 404


@console_ns.route("/datasets/<uuid:dataset_id>/evaluation/runs/<uuid:run_id>/cancel")
class DatasetEvaluationRunCancelApi(Resource):
    @console_ns.doc("cancel_dataset_evaluation_run")
    @console_ns.response(200, "Evaluation run cancelled")
    @console_ns.response(403, "Permission denied")
    @console_ns.response(404, "Dataset or run not found")
    @setup_required
    @login_required
    @account_initialization_required
    def post(self, dataset_id, run_id):
        """Cancel a running knowledge base evaluation."""
        current_user, current_tenant_id = current_account_with_tenant()
        dataset_id_str = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id_str)
        if dataset is None:
            raise NotFound("Dataset not found.")
        try:
            DatasetService.check_dataset_permission(dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e))

        run_id_str = str(run_id)
        try:
            with Session(db.engine, expire_on_commit=False) as session:
                run = EvaluationService.cancel_evaluation_run(
                    session=session,
                    tenant_id=current_tenant_id,
                    run_id=run_id_str,
                )
                return _serialize_dataset_evaluation_run(run)
        except EvaluationNotFoundError as e:
            return {"message": str(e.description)}, 404
        except ValueError as e:
            return {"message": str(e)}, 400


@console_ns.route("/datasets/<uuid:dataset_id>/evaluation/metrics")
class DatasetEvaluationMetricsApi(Resource):
    @console_ns.doc("get_dataset_evaluation_metrics")
    @console_ns.response(200, "Available retrieval metrics retrieved")
    @console_ns.response(403, "Permission denied")
    @console_ns.response(404, "Dataset not found")
    @setup_required
    @login_required
    @account_initialization_required
    def get(self, dataset_id):
        """Get available evaluation metrics for knowledge base retrieval."""
        current_user, _ = current_account_with_tenant()
        dataset_id_str = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id_str)
        if dataset is None:
            raise NotFound("Dataset not found.")
        try:
            DatasetService.check_dataset_permission(dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e))

        return {
            "metrics": EvaluationService.get_supported_metrics(EvaluationCategory.KNOWLEDGE_BASE)
        }


@console_ns.route("/datasets/<uuid:dataset_id>/evaluation/files/<uuid:file_id>")
class DatasetEvaluationFileDownloadApi(Resource):
    @console_ns.doc("download_dataset_evaluation_file")
    @console_ns.response(200, "File download URL generated")
    @console_ns.response(403, "Permission denied")
    @console_ns.response(404, "Dataset or file not found")
    @setup_required
    @login_required
    @account_initialization_required
    def get(self, dataset_id, file_id):
        """Download evaluation test file or result file for the knowledge base."""
        from core.workflow.file import helpers as file_helpers

        current_user, current_tenant_id = current_account_with_tenant()
        dataset_id_str = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id_str)
        if dataset is None:
            raise NotFound("Dataset not found.")
        try:
            DatasetService.check_dataset_permission(dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e))

        file_id_str = str(file_id)
        with Session(db.engine, expire_on_commit=False) as session:
            stmt = select(UploadFile).where(
                UploadFile.id == file_id_str,
                UploadFile.tenant_id == current_tenant_id,
            )
            upload_file = session.execute(stmt).scalar_one_or_none()

        if not upload_file:
            raise NotFound("File not found.")

        download_url = file_helpers.get_signed_file_url(upload_file_id=upload_file.id, as_attachment=True)

        return {
            "id": upload_file.id,
            "name": upload_file.name,
            "size": upload_file.size,
            "extension": upload_file.extension,
            "mime_type": upload_file.mime_type,
            "created_at": int(upload_file.created_at.timestamp()) if upload_file.created_at else None,
            "download_url": download_url,
        }

@@ -26,7 +26,7 @@ from extensions.ext_database import db
from extensions.ext_storage import storage
from libs.helper import TimestampField
from libs.login import current_account_with_tenant, login_required
from models import App
from models import App, Dataset
from models.model import UploadFile
from models.snippet import CustomizedSnippet
from services.errors.evaluation import (
@@ -150,7 +150,7 @@ def get_evaluation_target(view_func: Callable[P, R]):
        del kwargs["evaluate_target_type"]
        del kwargs["evaluate_target_id"]

        target: Union[App, CustomizedSnippet] | None = None
        target: Union[App, CustomizedSnippet, Dataset] | None = None

        if target_type == "app":
            target = db.session.query(App).where(App.id == target_id, App.tenant_id == current_tenant_id).first()
@@ -160,6 +160,8 @@ def get_evaluation_target(view_func: Callable[P, R]):
                .where(CustomizedSnippet.id == target_id, CustomizedSnippet.tenant_id == current_tenant_id)
                .first()
            )
        elif target_type == "knowledge":
            target = db.session.query(Dataset).where(Dataset.id == target_id, Dataset.tenant_id == current_tenant_id).first()

        if not target:
            raise NotFound(f"{str(target_type)} not found")
@@ -330,7 +332,7 @@ class EvaluationRunApi(Resource):
    @account_initialization_required
    @get_evaluation_target
    @edit_permission_required
    def post(self, target: Union[App, CustomizedSnippet], target_type: str):
    def post(self, target: Union[App, CustomizedSnippet, Dataset], target_type: str):
        """
        Start an evaluation run.


@@ -36,6 +36,7 @@ from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.graph_engine.manager import GraphEngineManager
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from factories import variable_factory
from libs import helper
from libs.helper import TimestampField
@@ -535,6 +536,6 @@ class SnippetWorkflowTaskStopApi(Resource):
        AppQueueManager.set_stop_flag_no_user_check(task_id)

        # New graph engine command channel mechanism
        GraphEngineManager.send_stop_command(task_id)
        GraphEngineManager(redis_client).send_stop_command(task_id)

        return {"result": "success"}

@@ -22,7 +22,7 @@ class BaseEvaluationInstance(ABC):
    def evaluate_llm(
        self,
        items: list[EvaluationItemInput],
        metric_name: str,
        metric_names: list[str],
        model_provider: str,
        model_name: str,
        tenant_id: str,
@@ -34,7 +34,7 @@ class BaseEvaluationInstance(ABC):
    def evaluate_retrieval(
        self,
        items: list[EvaluationItemInput],
        metric_name: str,
        metric_names: list[str],
        model_provider: str,
        model_name: str,
        tenant_id: str,
@@ -46,7 +46,7 @@ class BaseEvaluationInstance(ABC):
    def evaluate_agent(
        self,
        items: list[EvaluationItemInput],
        metric_name: str,
        metric_names: list[str],
        model_provider: str,
        model_name: str,
        tenant_id: str,

@@ -12,7 +12,7 @@ class EvaluationCategory(StrEnum):
    AGENT = "agent"
    WORKFLOW = "workflow"
    SNIPPET = "snippet"
    RETRIEVAL_TEST = "retrieval_test"
    KNOWLEDGE_BASE = "knowledge_base"


class EvaluationMetricName(StrEnum):
@@ -21,6 +21,65 @@ class EvaluationMetricName(StrEnum):
    Each framework maps these names to its own internal implementation.
    A framework that does not support a given metric should log a warning
    and skip it rather than raising an error.

    ── LLM / general text-quality metrics ──────────────────────────────────
    FAITHFULNESS
        Measures whether every claim in the model's response is grounded in
        the provided retrieved context. A high score means the answer
        contains no hallucinated content — each statement can be traced back
        to a passage in the context.
        Required fields: user_input, response, retrieved_contexts.

    ANSWER_RELEVANCY
        Measures how well the model's response addresses the user's question.
        A high score means the answer stays on-topic; a low score indicates
        irrelevant content or a failure to answer the actual question.
        Required fields: user_input, response.

    ANSWER_CORRECTNESS
        Measures the factual accuracy and completeness of the model's answer
        relative to a ground-truth reference. It combines semantic similarity
        with key-fact coverage, so both meaning and content matter.
        Required fields: user_input, response, reference (expected_output).

    SEMANTIC_SIMILARITY
        Measures the cosine similarity between the model's response and the
        reference answer in an embedding space. It evaluates whether the two
        texts convey the same meaning, independent of factual correctness.
        Required fields: response, reference (expected_output).

    ── Retrieval-quality metrics ────────────────────────────────────────────
    CONTEXT_PRECISION
        Measures the proportion of retrieved context chunks that are actually
        relevant to the question (precision). A high score means the retrieval
        pipeline returns little noise.
        Required fields: user_input, reference, retrieved_contexts.

    CONTEXT_RECALL
        Measures the proportion of ground-truth information that is covered by
        the retrieved context chunks (recall). A high score means the retrieval
        pipeline does not miss important supporting evidence.
        Required fields: user_input, reference, retrieved_contexts.

    CONTEXT_RELEVANCE
        Measures how relevant each individual retrieved chunk is to the query.
        Similar to CONTEXT_PRECISION but evaluated at the chunk level rather
        than against a reference answer.
        Required fields: user_input, retrieved_contexts.

    ── Agent-quality metrics ────────────────────────────────────────────────
    TOOL_CORRECTNESS
        Measures the correctness of the tool calls made by the agent during
        task execution — both the choice of tool and the arguments passed.
        A high score means the agent's tool-use strategy matches the expected
        behavior.
        Required fields: actual tool calls vs. expected tool calls.

    TASK_COMPLETION
        Measures whether the agent ultimately achieves the user's stated goal.
        It evaluates the reasoning chain, intermediate steps, and final output
        holistically; a high score means the task was fully accomplished.
        Required fields: user_input, actual_output.
    """

    # LLM / general text-quality metrics
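
To make the retrieval metrics described above concrete, here is a toy, set-based illustration of the two proportions. It is only a sketch of the arithmetic: the actual frameworks (RAGAS, DeepEval) judge relevance and coverage with an LLM rather than by exact membership, and the chunk and fact names below are invented.

retrieved_contexts = ["chunk A", "chunk B", "chunk C", "chunk D"]
relevant_retrieved = {"chunk A", "chunk C"}  # chunks judged relevant to the question
reference_facts = {"fact 1", "fact 2", "fact 3"}  # ground-truth claims in the reference answer
covered_facts = {"fact 1", "fact 3"}  # reference claims supported by some retrieved chunk

# CONTEXT_PRECISION: share of retrieved chunks that are relevant (high = little noise).
context_precision = len(relevant_retrieved) / len(retrieved_contexts)  # 2 / 4 = 0.5

# CONTEXT_RECALL: share of ground-truth information covered by retrieval (high = nothing important missed).
context_recall = len(covered_facts) / len(reference_facts)  # 2 / 3 ≈ 0.67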

@@ -41,21 +100,21 @@

# Per-category canonical metric lists used by get_supported_metrics().
LLM_METRIC_NAMES: list[EvaluationMetricName] = [
    EvaluationMetricName.FAITHFULNESS,
    EvaluationMetricName.ANSWER_RELEVANCY,
    EvaluationMetricName.ANSWER_CORRECTNESS,
    EvaluationMetricName.SEMANTIC_SIMILARITY,
    EvaluationMetricName.FAITHFULNESS,  # Every claim is grounded in context; no hallucinations
    EvaluationMetricName.ANSWER_RELEVANCY,  # Response stays on-topic and addresses the question
    EvaluationMetricName.ANSWER_CORRECTNESS,  # Factual accuracy and completeness vs. reference
    EvaluationMetricName.SEMANTIC_SIMILARITY,  # Semantic closeness to the reference answer
]

RETRIEVAL_METRIC_NAMES: list[EvaluationMetricName] = [
    EvaluationMetricName.CONTEXT_PRECISION,
    EvaluationMetricName.CONTEXT_RECALL,
    EvaluationMetricName.CONTEXT_RELEVANCE,
    EvaluationMetricName.CONTEXT_PRECISION,  # Fraction of retrieved chunks that are relevant (precision)
    EvaluationMetricName.CONTEXT_RECALL,  # Fraction of ground-truth info covered by retrieval (recall)
    EvaluationMetricName.CONTEXT_RELEVANCE,  # Per-chunk relevance to the query
]

AGENT_METRIC_NAMES: list[EvaluationMetricName] = [
    EvaluationMetricName.TOOL_CORRECTNESS,
    EvaluationMetricName.TASK_COMPLETION,
    EvaluationMetricName.TOOL_CORRECTNESS,  # Correct tool selection and arguments
    EvaluationMetricName.TASK_COMPLETION,  # Whether the agent fully achieves the user's goal
]

WORKFLOW_METRIC_NAMES: list[EvaluationMetricName] = [

@@ -62,47 +62,47 @@ class DeepEvalEvaluator(BaseEvaluationInstance):
    def evaluate_llm(
        self,
        items: list[EvaluationItemInput],
        metric_name: str,
        metric_names: list[str],
        model_provider: str,
        model_name: str,
        tenant_id: str,
    ) -> list[EvaluationItemResult]:
        return self._evaluate(items, metric_name, model_provider, model_name, tenant_id, EvaluationCategory.LLM)
        return self._evaluate(items, metric_names, model_provider, model_name, tenant_id, EvaluationCategory.LLM)

    def evaluate_retrieval(
        self,
        items: list[EvaluationItemInput],
        metric_name: str,
        metric_names: list[str],
        model_provider: str,
        model_name: str,
        tenant_id: str,
    ) -> list[EvaluationItemResult]:
        return self._evaluate(items, metric_name, model_provider, model_name, tenant_id, EvaluationCategory.RETRIEVAL)
        return self._evaluate(items, metric_names, model_provider, model_name, tenant_id, EvaluationCategory.RETRIEVAL)

    def evaluate_agent(
        self,
        items: list[EvaluationItemInput],
        metric_name: str,
        metric_names: list[str],
        model_provider: str,
        model_name: str,
        tenant_id: str,
    ) -> list[EvaluationItemResult]:
        return self._evaluate(items, metric_name, model_provider, model_name, tenant_id, EvaluationCategory.AGENT)
        return self._evaluate(items, metric_names, model_provider, model_name, tenant_id, EvaluationCategory.AGENT)

    def evaluate_workflow(
        self,
        items: list[EvaluationItemInput],
        metric_name: str,
        metric_names: list[str],
        model_provider: str,
        model_name: str,
        tenant_id: str,
    ) -> list[EvaluationItemResult]:
        return self._evaluate(items, metric_name, model_provider, model_name, tenant_id, EvaluationCategory.WORKFLOW)
        return self._evaluate(items, metric_names, model_provider, model_name, tenant_id, EvaluationCategory.WORKFLOW)

    def _evaluate(
        self,
        items: list[EvaluationItemInput],
        metric_name: str,
        metric_names: list[str],
        model_provider: str,
        model_name: str,
        tenant_id: str,
@@ -110,7 +110,7 @@ class DeepEvalEvaluator(BaseEvaluationInstance):
    ) -> list[EvaluationItemResult]:
        """Core evaluation logic using DeepEval."""
        model_wrapper = DifyModelWrapper(model_provider, model_name, tenant_id)
        requested_metrics = [metric_name] if metric_name else self.get_supported_metrics(category)
        requested_metrics = metric_names or self.get_supported_metrics(category)

        try:
            return self._evaluate_with_deepeval(items, requested_metrics, category)

@@ -55,47 +55,47 @@ class RagasEvaluator(BaseEvaluationInstance):
    def evaluate_llm(
        self,
        items: list[EvaluationItemInput],
        metric_name: str,
        metric_names: list[str],
        model_provider: str,
        model_name: str,
        tenant_id: str,
    ) -> list[EvaluationItemResult]:
        return self._evaluate(items, metric_name, model_provider, model_name, tenant_id, EvaluationCategory.LLM)
        return self._evaluate(items, metric_names, model_provider, model_name, tenant_id, EvaluationCategory.LLM)

    def evaluate_retrieval(
        self,
        items: list[EvaluationItemInput],
        metric_name: str,
        metric_names: list[str],
        model_provider: str,
        model_name: str,
        tenant_id: str,
    ) -> list[EvaluationItemResult]:
        return self._evaluate(items, metric_name, model_provider, model_name, tenant_id, EvaluationCategory.RETRIEVAL)
        return self._evaluate(items, metric_names, model_provider, model_name, tenant_id, EvaluationCategory.RETRIEVAL)

    def evaluate_agent(
        self,
        items: list[EvaluationItemInput],
        metric_name: str,
        metric_names: list[str],
        model_provider: str,
        model_name: str,
        tenant_id: str,
    ) -> list[EvaluationItemResult]:
        return self._evaluate(items, metric_name, model_provider, model_name, tenant_id, EvaluationCategory.AGENT)
        return self._evaluate(items, metric_names, model_provider, model_name, tenant_id, EvaluationCategory.AGENT)

    def evaluate_workflow(
        self,
        items: list[EvaluationItemInput],
        metric_name: str,
        metric_names: list[str],
        model_provider: str,
        model_name: str,
        tenant_id: str,
    ) -> list[EvaluationItemResult]:
        return self._evaluate(items, metric_name, model_provider, model_name, tenant_id, EvaluationCategory.WORKFLOW)
        return self._evaluate(items, metric_names, model_provider, model_name, tenant_id, EvaluationCategory.WORKFLOW)

    def _evaluate(
        self,
        items: list[EvaluationItemInput],
        metric_name: str,
        metric_names: list[str],
        model_provider: str,
        model_name: str,
        tenant_id: str,
@@ -103,7 +103,7 @@ class RagasEvaluator(BaseEvaluationInstance):
    ) -> list[EvaluationItemResult]:
        """Core evaluation logic using RAGAS."""
        model_wrapper = DifyModelWrapper(model_provider, model_name, tenant_id)
        requested_metrics = [metric_name] if metric_name else self.get_supported_metrics(category)
        requested_metrics = metric_names or self.get_supported_metrics(category)

        try:
            return self._evaluate_with_ragas(items, requested_metrics, model_wrapper, category)

@@ -84,7 +84,7 @@ class AgentEvaluationRunner(BaseEvaluationRunner):
            raise ValueError("Default metric is required for agent evaluation")
        merged_items = self._merge_results_into_items(node_run_result_list)
        return self.evaluation_instance.evaluate_agent(
            merged_items, default_metric.metric, model_provider, model_name, tenant_id
            merged_items, [default_metric.metric], model_provider, model_name, tenant_id
        )

    @staticmethod

@@ -21,6 +21,7 @@ from core.evaluation.base_evaluation_instance import BaseEvaluationInstance
from core.evaluation.entities.evaluation_entity import (
    CustomizedMetrics,
    DefaultMetric,
    EvaluationDatasetInput,
    EvaluationItemResult,
)
from core.evaluation.entities.judgment_entity import JudgmentConfig
@@ -66,6 +67,7 @@ class BaseEvaluationRunner(ABC):
        model_name: str = "",
        node_run_result_mapping_list: list[dict[str, NodeRunResult]] | None = None,
        judgment_config: JudgmentConfig | None = None,
        input_list: list[EvaluationDatasetInput] | None = None,
    ) -> list[EvaluationItemResult]:
        """Orchestrate target execution + metric evaluation + judgment for all items."""
        evaluation_run = self.session.query(EvaluationRun).filter_by(id=evaluation_run_id).first()
@@ -127,14 +129,15 @@ class BaseEvaluationRunner(ABC):
        )

        # Phase 4: Persist individual items
        dataset_items = input_list or []
        for result in results:
            item_input = next((item for item in evaluation_run.input_list if item.index == result.index), None)
            item_input = next((item for item in dataset_items if item.index == result.index), None)
            run_item = EvaluationRunItem(
                evaluation_run_id=evaluation_run_id,
                item_index=result.index,
                inputs=json.dumps(item_input.inputs) if item_input else None,
                expected_output=item_input.expected_output if item_input else None,
                context=json.dumps(item_input.context) if item_input and item_input.context else None,
                context=json.dumps(item_input.context) if item_input and getattr(item_input, "context", None) else None,
                actual_output=result.actual_output,
                metrics=json.dumps([m.model_dump() for m in result.metrics]) if result.metrics else None,
                judgment=json.dumps(result.judgment.model_dump()) if result.judgment else None,

@@ -41,7 +41,7 @@ class LLMEvaluationRunner(BaseEvaluationRunner):
            raise ValueError("Default metric is required for LLM evaluation")
        merged_items = self._merge_results_into_items(node_run_result_list)
        return self.evaluation_instance.evaluate_llm(
            merged_items, default_metric.metric, model_provider, model_name, tenant_id
            merged_items, [default_metric.metric], model_provider, model_name, tenant_id
        )

    @staticmethod

@@ -35,8 +35,6 @@ class RetrievalEvaluationRunner(BaseEvaluationRunner):
        """Compute retrieval evaluation metrics."""
        if not node_run_result_list:
            return []
        if not default_metric:
            raise ValueError("Default metric is required for retrieval evaluation")

        merged_items = []
        for i, node_result in enumerate(node_run_result_list):
@@ -58,7 +56,7 @@ class RetrievalEvaluationRunner(BaseEvaluationRunner):
            )

        return self.evaluation_instance.evaluate_retrieval(
            merged_items, default_metric.metric, model_provider, model_name, tenant_id
            merged_items, [default_metric.metric] if default_metric else [], model_provider, model_name, tenant_id
        )

    @staticmethod

@@ -115,7 +115,7 @@ class SnippetEvaluationRunner(BaseEvaluationRunner):
            raise ValueError("Default metric is required for snippet evaluation")
        merged_items = self._merge_results_into_items(node_run_result_list)
        return self.evaluation_instance.evaluate_workflow(
            merged_items, default_metric.metric, model_provider, model_name, tenant_id
            merged_items, [default_metric.metric], model_provider, model_name, tenant_id
        )

    @staticmethod

@@ -40,7 +40,7 @@ class WorkflowEvaluationRunner(BaseEvaluationRunner):
            raise ValueError("Default metric is required for workflow evaluation")
        merged_items = self._merge_results_into_items(node_run_result_list)
        return self.evaluation_instance.evaluate_workflow(
            merged_items, default_metric.metric, model_provider, model_name, tenant_id
            merged_items, [default_metric.metric], model_provider, model_name, tenant_id
        )

    @staticmethod

@@ -26,7 +26,7 @@ class EvaluationRunStatus(StrEnum):
class EvaluationTargetType(StrEnum):
    APP = "app"
    SNIPPETS = "snippets"

    KNOWLEDGE_BASE = "knowledge_base"

class EvaluationConfiguration(Base):
    """Stores evaluation configuration for each target (App or Snippet)."""
@@ -105,7 +105,6 @@ class EvaluationRun(Base):
    error: Mapped[str | None] = mapped_column(Text, nullable=True)

    celery_task_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
    metrics_summary: Mapped[str | None] = mapped_column(LongText, nullable=True)

    created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
    started_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)

@@ -209,6 +209,57 @@ class EvaluationService:

        return output.getvalue()

    @classmethod
    def generate_retrieval_dataset_template(cls) -> tuple[bytes, str]:
        """Generate evaluation dataset XLSX template for knowledge base retrieval.

        The template contains three columns: ``index``, ``query``, and
        ``expected_output``. Callers upload a filled copy and start an
        evaluation run with ``target_type="dataset"``.

        :returns: (xlsx_content_bytes, filename)
        """
        wb = Workbook()
        ws = wb.active
        if ws is None:
            ws = wb.create_sheet("Evaluation Dataset")
        ws.title = "Evaluation Dataset"

        header_font = Font(bold=True, color="FFFFFF")
        header_fill = PatternFill(start_color="4472C4", end_color="4472C4", fill_type="solid")
        header_alignment = Alignment(horizontal="center", vertical="center")
        thin_border = Border(
            left=Side(style="thin"),
            right=Side(style="thin"),
            top=Side(style="thin"),
            bottom=Side(style="thin"),
        )

        headers = ["index", "query", "expected_output"]
        for col_idx, header in enumerate(headers, start=1):
            cell = ws.cell(row=1, column=col_idx, value=header)
            cell.font = header_font
            cell.fill = header_fill
            cell.alignment = header_alignment
            cell.border = thin_border

        ws.column_dimensions["A"].width = 10
        ws.column_dimensions["B"].width = 30
        ws.column_dimensions["C"].width = 30

        # Add one sample row
        for col_idx in range(1, len(headers) + 1):
            cell = ws.cell(row=2, column=col_idx, value="")
            cell.border = thin_border
            if col_idx == 1:
                cell.value = 1
                cell.alignment = Alignment(horizontal="center")

        output = io.BytesIO()
        wb.save(output)
        output.seek(0)
        return output.getvalue(), "retrieval-evaluation-dataset.xlsx"
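
As a rough usage sketch, a caller could take the bytes returned here, fill in rows with openpyxl, and upload the result before starting a run. The two example rows are invented; only the three column names come from this method.

import io

from openpyxl import load_workbook

xlsx_bytes, filename = EvaluationService.generate_retrieval_dataset_template()

wb = load_workbook(io.BytesIO(xlsx_bytes))
ws = wb["Evaluation Dataset"]
# Overwrite the empty sample row and append further items: index, query, expected_output.
ws["A2"], ws["B2"], ws["C2"] = 1, "What is Dify?", "Dify is an open-source LLM application platform."
ws.append([2, "How are uploaded documents indexed?", "They are split into segments and embedded."])
wb.save("filled-" + filename)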

    # ---- Evaluation Configuration CRUD ----

    @classmethod
@@ -757,3 +808,86 @@ class EvaluationService:

        wb.close()
        return items

    @classmethod
    def execute_retrieval_test_targets(
        cls,
        dataset_id: str,
        account_id: str,
        input_list: list[EvaluationDatasetInput],
        max_workers: int = 5,
    ) -> list[NodeRunResult]:
        """Run hit testing against a knowledge base for every input item in parallel.

        Each item must supply a ``query`` key in its ``inputs`` dict. The
        retrieved segments are normalised into the same ``NodeRunResult`` format
        that :class:`RetrievalEvaluationRunner` expects:

        .. code-block:: python

            NodeRunResult(
                inputs={"query": "..."},
                outputs={"result": [{"content": "...", "score": ...}, ...]},
            )

        :returns: Ordered list of ``NodeRunResult`` — one per input item.
            If retrieval fails for an item the result has an empty ``result``
            list so the runner can still persist a (metric-less) row.
        """
        from concurrent.futures import ThreadPoolExecutor

        from flask import current_app

        flask_app = current_app._get_current_object()  # type: ignore

        def _worker(item: EvaluationDatasetInput) -> NodeRunResult:
            with flask_app.app_context():
                from extensions.ext_database import db as flask_db
                from models.account import Account
                from models.dataset import Dataset
                from services.hit_testing_service import HitTestingService

                dataset = flask_db.session.query(Dataset).filter_by(id=dataset_id).first()
                if not dataset:
                    raise ValueError(f"Dataset {dataset_id} not found")

                account = flask_db.session.query(Account).filter_by(id=account_id).first()
                if not account:
                    raise ValueError(f"Account {account_id} not found")

                query = str(item.inputs.get("query", ""))
                response = HitTestingService.retrieve(
                    dataset=dataset,
                    query=query,
                    account=account,
                    retrieval_model=None,  # Use dataset's configured retrieval model
                    external_retrieval_model={},
                    limit=10,
                )

                records = response.get("records", [])
                result_list = [
                    {
                        "content": r.get("segment", {}).get("content", "") or r.get("content", ""),
                        "score": r.get("score"),
                    }
                    for r in records
                    if r.get("segment", {}).get("content") or r.get("content")
                ]

                return NodeRunResult(
                    inputs={"query": query},
                    outputs={"result": result_list},
                )

        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = [executor.submit(_worker, item) for item in input_list]
            results: list[NodeRunResult] = []
            for item, future in zip(input_list, futures):
                try:
                    results.append(future.result())
                except Exception:
                    logger.exception("Retrieval test failed for item %d (dataset=%s)", item.index, dataset_id)
                    results.append(NodeRunResult(inputs={}, outputs={"result": []}))

        return results
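
A minimal invocation sketch follows. The EvaluationDatasetInput keyword arguments (index, inputs) are assumptions inferred from how those fields are read elsewhere in this diff, and the UUIDs are placeholders; a Flask application context is required because the worker opens database sessions.

items = [
    EvaluationDatasetInput(index=1, inputs={"query": "What is Dify?"}),  # assumed constructor
    EvaluationDatasetInput(index=2, inputs={"query": "How is retrieval configured?"}),
]
node_results = EvaluationService.execute_retrieval_test_targets(
    dataset_id="<dataset-uuid>",
    account_id="<account-uuid>",
    input_list=items,
    max_workers=2,
)
# Each NodeRunResult carries the original {"query": ...} inputs and a "result" list of
# {"content", "score"} dicts, ready to be fed into RetrievalEvaluationRunner.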

@@ -80,20 +80,27 @@ def _execute_evaluation(session: Any, run_data: EvaluationRunData) -> None:
    if evaluation_instance is None:
        raise ValueError("Evaluation framework not configured")

    evaluation_service = EvaluationService()
    node_run_result_mapping_list: list[dict[str, NodeRunResult]] = evaluation_service.execute_targets(
        tenant_id=run_data.tenant_id,
        target_type=run_data.target_type,
        target_id=run_data.target_id,
        input_list=run_data.input_list,
    )

    results: list[EvaluationItemResult] = _execute_evaluation_runner(
        session=session,
        run_data=run_data,
        evaluation_instance=evaluation_instance,
        node_run_result_mapping_list=node_run_result_mapping_list,
    )
    if run_data.target_type == "dataset":
        results: list[EvaluationItemResult] = _execute_retrieval_test(
            session=session,
            evaluation_run=evaluation_run,
            run_data=run_data,
            evaluation_instance=evaluation_instance,
        )
    else:
        evaluation_service = EvaluationService()
        node_run_result_mapping_list: list[dict[str, NodeRunResult]] = evaluation_service.execute_targets(
            tenant_id=run_data.tenant_id,
            target_type=run_data.target_type,
            target_id=run_data.target_id,
            input_list=run_data.input_list,
        )
        results = _execute_evaluation_runner(
            session=session,
            run_data=run_data,
            evaluation_instance=evaluation_instance,
            node_run_result_mapping_list=node_run_result_mapping_list,
        )

    # Compute summary metrics
    metrics_summary = _compute_metrics_summary(results, run_data.judgment_config)
@@ -148,6 +155,7 @@ def _execute_evaluation_runner(
            model_name=run_data.evaluation_model,
            node_run_result_list=node_run_result_list,
            judgment_config=run_data.judgment_config,
            input_list=run_data.input_list,
        )
    )
    if customized_metrics:
@@ -163,6 +171,7 @@ def _execute_evaluation_runner(
                node_run_result_list=None,
                node_run_result_mapping_list=node_run_result_mapping_list,
                judgment_config=run_data.judgment_config,
                input_list=run_data.input_list,
            )
        )
    return results
@@ -177,7 +186,7 @@ def _create_runner(
    match category:
        case EvaluationCategory.LLM:
            return LLMEvaluationRunner(evaluation_instance, session)
        case EvaluationCategory.RETRIEVAL:
        case EvaluationCategory.RETRIEVAL | EvaluationCategory.KNOWLEDGE_BASE:
            return RetrievalEvaluationRunner(evaluation_instance, session)
        case EvaluationCategory.AGENT:
            return AgentEvaluationRunner(evaluation_instance, session)
@@ -189,6 +198,43 @@ def _create_runner(
    raise ValueError(f"Unknown evaluation category: {category}")


def _execute_retrieval_test(
    session: Any,
    evaluation_run: EvaluationRun,
    run_data: EvaluationRunData,
    evaluation_instance: BaseEvaluationInstance,
) -> list[EvaluationItemResult]:
    """Execute knowledge base retrieval for all items, then evaluate metrics.

    Unlike the workflow-based path, there are no workflow nodes to traverse.
    Hit testing is run directly for each dataset item and the results are fed
    straight into :class:`RetrievalEvaluationRunner`.
    """
    node_run_result_list = EvaluationService.execute_retrieval_test_targets(
        dataset_id=run_data.target_id,
        account_id=evaluation_run.created_by,
        input_list=run_data.input_list,
    )

    results: list[EvaluationItemResult] = []
    runner = RetrievalEvaluationRunner(evaluation_instance, session)
    results.extend(
        runner.run(
            evaluation_run_id=run_data.evaluation_run_id,
            tenant_id=run_data.tenant_id,
            target_id=run_data.target_id,
            target_type=run_data.target_type,
            default_metric=None,
            model_provider=run_data.evaluation_model_provider,
            model_name=run_data.evaluation_model,
            node_run_result_list=node_run_result_list,
            judgment_config=run_data.judgment_config,
            input_list=run_data.input_list,
        )
    )
    return results


def _mark_run_failed(session: Any, run_id: str, error: str) -> None:
    """Mark an evaluation run as failed."""
    try:
@@ -216,7 +262,9 @@ def _compute_metrics_summary(
    summary: dict[str, Any] = {}

    if judgment_config is not None and judgment_config.conditions:
        evaluated_results: list[EvaluationItemResult] = [result for result in results if result.error is None and result.metrics]
        evaluated_results: list[EvaluationItemResult] = [
            result for result in results if result.error is None and result.metrics
        ]
        passed_items = sum(1 for result in evaluated_results if result.judgment.passed)
        evaluated_items = len(evaluated_results)
        summary["_judgment"] = {