mirror of
https://github.com/langgenius/dify.git
synced 2026-05-11 14:58:23 +08:00
evaluation runtime
This commit is contained in:
parent
4e593df662
commit
13c0d6eddb
@ -20,7 +20,7 @@ from controllers.console.wraps import (
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
)
|
||||
from core.evaluation.entities.evaluation_entity import EvaluationCategory, EvaluationRunRequest
|
||||
from core.evaluation.entities.evaluation_entity import EvaluationCategory, EvaluationConfigData, EvaluationRunRequest
|
||||
from core.workflow.file import helpers as file_helpers
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
@ -261,7 +261,12 @@ class EvaluationDetailApi(Resource):
|
||||
Save evaluation configuration for the target.
|
||||
"""
|
||||
current_account, current_tenant_id = current_account_with_tenant()
|
||||
data = request.get_json(force=True)
|
||||
body = request.get_json(force=True)
|
||||
|
||||
try:
|
||||
config_data = EvaluationConfigData.model_validate(body)
|
||||
except Exception as e:
|
||||
raise BadRequest(f"Invalid request body: {e}")
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
config = EvaluationService.save_evaluation_config(
|
||||
@ -270,7 +275,7 @@ class EvaluationDetailApi(Resource):
|
||||
target_type=target_type,
|
||||
target_id=str(target.id),
|
||||
account_id=str(current_account.id),
|
||||
data=data,
|
||||
data=config_data,
|
||||
)
|
||||
|
||||
return {
|
||||
|
||||
@ -65,9 +65,8 @@ class CustomizedMetrics(BaseModel):
|
||||
output_fields: list[CustomizedMetricOutputField]
|
||||
|
||||
|
||||
class EvaluationRunRequest(BaseModel):
|
||||
"""Request body for starting an evaluation run."""
|
||||
file_id: str
|
||||
class EvaluationConfigData(BaseModel):
|
||||
"""Structured data for saving evaluation configuration."""
|
||||
evaluation_model: str = ""
|
||||
evaluation_model_provider: str = ""
|
||||
default_metrics: list[DefaultMetric] = Field(default_factory=list)
|
||||
@ -75,6 +74,11 @@ class EvaluationRunRequest(BaseModel):
|
||||
judgment_config: JudgmentConfig | None = None
|
||||
|
||||
|
||||
class EvaluationRunRequest(EvaluationConfigData):
|
||||
"""Request body for starting an evaluation run."""
|
||||
file_id: str
|
||||
|
||||
|
||||
class EvaluationRunData(BaseModel):
|
||||
"""Serializable data for Celery task."""
|
||||
evaluation_run_id: str
|
||||
@ -84,6 +88,7 @@ class EvaluationRunData(BaseModel):
|
||||
evaluation_category: EvaluationCategory
|
||||
evaluation_model_provider: str
|
||||
evaluation_model: str
|
||||
metrics_config: dict[str, Any] = Field(default_factory=dict)
|
||||
default_metrics: list[dict[str, Any]] = Field(default_factory=list)
|
||||
customized_metrics: dict[str, Any] | None = None
|
||||
judgment_config: JudgmentConfig | None = None
|
||||
items: list[EvaluationItemInput]
|
||||
|
||||
@ -12,6 +12,7 @@ from configs import dify_config
|
||||
from core.evaluation.entities.evaluation_entity import (
|
||||
DefaultMetric,
|
||||
EvaluationCategory,
|
||||
EvaluationConfigData,
|
||||
EvaluationItemInput,
|
||||
EvaluationRunData,
|
||||
EvaluationRunRequest,
|
||||
@ -224,7 +225,7 @@ class EvaluationService:
|
||||
target_type: str,
|
||||
target_id: str,
|
||||
account_id: str,
|
||||
data: dict[str, Any],
|
||||
data: EvaluationConfigData,
|
||||
) -> EvaluationConfiguration:
|
||||
config = cls.get_evaluation_config(session, tenant_id, target_type, target_id)
|
||||
if config is None:
|
||||
@ -237,10 +238,15 @@ class EvaluationService:
|
||||
)
|
||||
session.add(config)
|
||||
|
||||
config.evaluation_model_provider = data.get("evaluation_model_provider")
|
||||
config.evaluation_model = data.get("evaluation_model")
|
||||
config.metrics_config = json.dumps(data.get("metrics_config", {}))
|
||||
config.judgement_conditions = json.dumps(data.get("judgement_conditions", {}))
|
||||
config.evaluation_model_provider = data.evaluation_model_provider
|
||||
config.evaluation_model = data.evaluation_model
|
||||
config.metrics_config = json.dumps({
|
||||
"default_metrics": [m.model_dump() for m in data.default_metrics],
|
||||
"customized_metrics": data.customized_metrics.model_dump() if data.customized_metrics else None,
|
||||
})
|
||||
config.judgement_conditions = json.dumps(
|
||||
data.judgment_config.model_dump() if data.judgment_config else {}
|
||||
)
|
||||
config.updated_by = account_id
|
||||
session.commit()
|
||||
session.refresh(config)
|
||||
@ -272,13 +278,6 @@ class EvaluationService:
|
||||
# Derive evaluation_category from default_metrics node types
|
||||
evaluation_category = cls._resolve_evaluation_category(run_request.default_metrics)
|
||||
|
||||
# Build metrics_config from default_metrics and customized_metrics
|
||||
metrics_config: dict[str, Any] = {
|
||||
"default_metrics": [m.model_dump() for m in run_request.default_metrics],
|
||||
}
|
||||
if run_request.customized_metrics is not None:
|
||||
metrics_config["customized_metrics"] = run_request.customized_metrics.model_dump()
|
||||
|
||||
# Save as latest EvaluationConfiguration
|
||||
config = cls.save_evaluation_config(
|
||||
session=session,
|
||||
@ -286,14 +285,7 @@ class EvaluationService:
|
||||
target_type=target_type,
|
||||
target_id=target_id,
|
||||
account_id=account_id,
|
||||
data={
|
||||
"evaluation_model_provider": run_request.evaluation_model_provider,
|
||||
"evaluation_model": run_request.evaluation_model,
|
||||
"metrics_config": metrics_config,
|
||||
"judgement_conditions": (
|
||||
run_request.judgment_config.model_dump() if run_request.judgment_config else {}
|
||||
),
|
||||
},
|
||||
data=run_request,
|
||||
)
|
||||
|
||||
# Check concurrent run limit
|
||||
@ -338,7 +330,10 @@ class EvaluationService:
|
||||
evaluation_category=evaluation_category,
|
||||
evaluation_model_provider=run_request.evaluation_model_provider,
|
||||
evaluation_model=run_request.evaluation_model,
|
||||
metrics_config=metrics_config,
|
||||
default_metrics=[m.model_dump() for m in run_request.default_metrics],
|
||||
customized_metrics=(
|
||||
run_request.customized_metrics.model_dump() if run_request.customized_metrics else None
|
||||
),
|
||||
judgment_config=run_request.judgment_config,
|
||||
items=items,
|
||||
)
|
||||
|
||||
@ -82,7 +82,8 @@ def _execute_evaluation(session: Any, run_data: EvaluationRunData) -> None:
|
||||
target_id=run_data.target_id,
|
||||
target_type=run_data.target_type,
|
||||
items=run_data.items,
|
||||
metrics_config=run_data.metrics_config,
|
||||
default_metrics=run_data.default_metrics,
|
||||
customized_metrics=run_data.customized_metrics,
|
||||
model_provider=run_data.evaluation_model_provider,
|
||||
model_name=run_data.evaluation_model,
|
||||
)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user