mirror of
https://github.com/langgenius/dify.git
synced 2026-04-16 02:16:57 +08:00
refactor(api): migrate service api workflow responses from marshal_with to BaseModel (#35195)
Co-authored-by: ai-hpc <ai-hpc@users.noreply.github.com>
This commit is contained in:
parent
1c3cba281a
commit
b65a5fcd97
@ -1,13 +1,15 @@
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
from datetime import datetime
|
||||
from typing import Literal
|
||||
|
||||
from dateutil.parser import isoparse
|
||||
from flask import request
|
||||
from flask_restx import Namespace, Resource, fields
|
||||
from flask_restx import Resource, fields
|
||||
from graphon.enums import WorkflowExecutionStatus
|
||||
from graphon.graph_engine.manager import GraphEngineManager
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
|
||||
|
||||
@ -33,9 +35,10 @@ from core.errors.error import (
|
||||
from core.helper.trace_id_helper import get_external_trace_id
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from fields.workflow_app_log_fields import build_workflow_app_log_pagination_model
|
||||
from fields.base import ResponseModel
|
||||
from fields.end_user_fields import SimpleEndUser
|
||||
from fields.member_fields import SimpleAccount
|
||||
from libs import helper
|
||||
from libs.helper import OptionalTimestampField, TimestampField
|
||||
from models.model import App, AppMode, EndUser
|
||||
from models.workflow import WorkflowRun
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
@ -65,38 +68,142 @@ class WorkflowLogQuery(BaseModel):
|
||||
register_schema_models(service_api_ns, WorkflowRunPayload, WorkflowLogQuery)
|
||||
|
||||
|
||||
def _to_timestamp(value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return int(value.timestamp())
|
||||
return value
|
||||
|
||||
|
||||
def _enum_value(value):
|
||||
return getattr(value, "value", value)
|
||||
|
||||
|
||||
class WorkflowRunStatusField(fields.Raw):
|
||||
def output(self, key, obj: WorkflowRun, **kwargs):
|
||||
return obj.status.value
|
||||
return _enum_value(obj.status)
|
||||
|
||||
|
||||
class WorkflowRunOutputsField(fields.Raw):
|
||||
def output(self, key, obj: WorkflowRun, **kwargs):
|
||||
if obj.status == WorkflowExecutionStatus.PAUSED:
|
||||
status = _enum_value(obj.status)
|
||||
if status == WorkflowExecutionStatus.PAUSED.value:
|
||||
return {}
|
||||
|
||||
outputs = obj.outputs_dict
|
||||
return outputs or {}
|
||||
|
||||
|
||||
workflow_run_fields = {
|
||||
"id": fields.String,
|
||||
"workflow_id": fields.String,
|
||||
"status": WorkflowRunStatusField,
|
||||
"inputs": fields.Raw,
|
||||
"outputs": WorkflowRunOutputsField,
|
||||
"error": fields.String,
|
||||
"total_steps": fields.Integer,
|
||||
"total_tokens": fields.Integer,
|
||||
"created_at": TimestampField,
|
||||
"finished_at": OptionalTimestampField,
|
||||
"elapsed_time": fields.Float,
|
||||
}
|
||||
class WorkflowRunResponse(ResponseModel):
|
||||
id: str
|
||||
workflow_id: str
|
||||
status: str
|
||||
inputs: dict | list | str | int | float | bool | None = None
|
||||
outputs: dict = Field(default_factory=dict)
|
||||
error: str | None = None
|
||||
total_steps: int | None = None
|
||||
total_tokens: int | None = None
|
||||
created_at: int | None = None
|
||||
finished_at: int | None = None
|
||||
elapsed_time: float | int | None = None
|
||||
|
||||
@field_validator("created_at", "finished_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return _to_timestamp(value)
|
||||
|
||||
|
||||
def build_workflow_run_model(api_or_ns: Namespace):
|
||||
"""Build the workflow run model for the API or Namespace."""
|
||||
return api_or_ns.model("WorkflowRun", workflow_run_fields)
|
||||
class WorkflowRunForLogResponse(ResponseModel):
|
||||
id: str
|
||||
version: str | None = None
|
||||
status: str | None = None
|
||||
triggered_from: str | None = None
|
||||
error: str | None = None
|
||||
elapsed_time: float | int | None = None
|
||||
total_tokens: int | None = None
|
||||
total_steps: int | None = None
|
||||
created_at: int | None = None
|
||||
finished_at: int | None = None
|
||||
exceptions_count: int | None = None
|
||||
|
||||
@field_validator("status", "triggered_from", mode="before")
|
||||
@classmethod
|
||||
def _normalize_enum(cls, value):
|
||||
return _enum_value(value)
|
||||
|
||||
@field_validator("created_at", "finished_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return _to_timestamp(value)
|
||||
|
||||
|
||||
class WorkflowAppLogPartialResponse(ResponseModel):
|
||||
id: str
|
||||
workflow_run: WorkflowRunForLogResponse | None = None
|
||||
details: dict | list | str | int | float | bool | None = None
|
||||
created_from: str | None = None
|
||||
created_by_role: str | None = None
|
||||
created_by_account: SimpleAccount | None = None
|
||||
created_by_end_user: SimpleEndUser | None = None
|
||||
created_at: int | None = None
|
||||
|
||||
@field_validator("created_from", "created_by_role", mode="before")
|
||||
@classmethod
|
||||
def _normalize_enum(cls, value):
|
||||
return _enum_value(value)
|
||||
|
||||
@field_validator("created_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return _to_timestamp(value)
|
||||
|
||||
|
||||
class WorkflowAppLogPaginationResponse(ResponseModel):
|
||||
page: int
|
||||
limit: int
|
||||
total: int
|
||||
has_more: bool
|
||||
data: list[WorkflowAppLogPartialResponse]
|
||||
|
||||
|
||||
register_schema_models(
|
||||
service_api_ns,
|
||||
WorkflowRunResponse,
|
||||
WorkflowRunForLogResponse,
|
||||
WorkflowAppLogPartialResponse,
|
||||
WorkflowAppLogPaginationResponse,
|
||||
)
|
||||
|
||||
|
||||
def _serialize_workflow_run(workflow_run: WorkflowRun) -> dict:
|
||||
status = _enum_value(workflow_run.status)
|
||||
raw_outputs = workflow_run.outputs_dict
|
||||
if status == WorkflowExecutionStatus.PAUSED.value or raw_outputs is None:
|
||||
outputs: dict = {}
|
||||
elif isinstance(raw_outputs, dict):
|
||||
outputs = raw_outputs
|
||||
elif isinstance(raw_outputs, Mapping):
|
||||
outputs = dict(raw_outputs)
|
||||
else:
|
||||
outputs = {}
|
||||
return WorkflowRunResponse.model_validate(
|
||||
{
|
||||
"id": workflow_run.id,
|
||||
"workflow_id": workflow_run.workflow_id,
|
||||
"status": status,
|
||||
"inputs": workflow_run.inputs,
|
||||
"outputs": outputs,
|
||||
"error": workflow_run.error,
|
||||
"total_steps": workflow_run.total_steps,
|
||||
"total_tokens": workflow_run.total_tokens,
|
||||
"created_at": workflow_run.created_at,
|
||||
"finished_at": workflow_run.finished_at,
|
||||
"elapsed_time": workflow_run.elapsed_time,
|
||||
}
|
||||
).model_dump(mode="json")
|
||||
|
||||
|
||||
def _serialize_workflow_log_pagination(pagination) -> dict:
|
||||
return WorkflowAppLogPaginationResponse.model_validate(pagination, from_attributes=True).model_dump(mode="json")
|
||||
|
||||
|
||||
@service_api_ns.route("/workflows/run/<string:workflow_run_id>")
|
||||
@ -112,7 +219,11 @@ class WorkflowRunDetailApi(Resource):
|
||||
}
|
||||
)
|
||||
@validate_app_token
|
||||
@service_api_ns.marshal_with(build_workflow_run_model(service_api_ns))
|
||||
@service_api_ns.response(
|
||||
200,
|
||||
"Workflow run details retrieved successfully",
|
||||
service_api_ns.models[WorkflowRunResponse.__name__],
|
||||
)
|
||||
def get(self, app_model: App, workflow_run_id: str):
|
||||
"""Get a workflow task running detail.
|
||||
|
||||
@ -133,7 +244,7 @@ class WorkflowRunDetailApi(Resource):
|
||||
)
|
||||
if not workflow_run:
|
||||
raise NotFound("Workflow run not found.")
|
||||
return workflow_run
|
||||
return _serialize_workflow_run(workflow_run)
|
||||
|
||||
|
||||
@service_api_ns.route("/workflows/run")
|
||||
@ -299,7 +410,11 @@ class WorkflowAppLogApi(Resource):
|
||||
}
|
||||
)
|
||||
@validate_app_token
|
||||
@service_api_ns.marshal_with(build_workflow_app_log_pagination_model(service_api_ns))
|
||||
@service_api_ns.response(
|
||||
200,
|
||||
"Logs retrieved successfully",
|
||||
service_api_ns.models[WorkflowAppLogPaginationResponse.__name__],
|
||||
)
|
||||
def get(self, app_model: App):
|
||||
"""Get workflow app logs.
|
||||
|
||||
@ -327,4 +442,4 @@ class WorkflowAppLogApi(Resource):
|
||||
created_by_account=args.created_by_account,
|
||||
)
|
||||
|
||||
return workflow_app_log_pagination
|
||||
return _serialize_workflow_log_pagination(workflow_app_log_pagination)
|
||||
|
||||
@ -15,6 +15,7 @@ Focus on:
|
||||
|
||||
import sys
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
@ -43,6 +44,22 @@ from services.errors.llm import InvokeRateLimitError
|
||||
from services.workflow_app_service import WorkflowAppService
|
||||
|
||||
|
||||
def _make_mock_workflow_run(run_id: str = "run-1"):
|
||||
run = Mock()
|
||||
run.id = run_id
|
||||
run.workflow_id = "wf-1"
|
||||
run.status = WorkflowExecutionStatus.SUCCEEDED
|
||||
run.inputs = {"input": "value"}
|
||||
run.outputs_dict = {"output": "value"}
|
||||
run.error = None
|
||||
run.total_steps = 1
|
||||
run.total_tokens = 10
|
||||
run.created_at = datetime(2026, 1, 1, tzinfo=UTC)
|
||||
run.finished_at = datetime(2026, 1, 1, tzinfo=UTC)
|
||||
run.elapsed_time = 0.1
|
||||
return run
|
||||
|
||||
|
||||
class TestWorkflowRunPayload:
|
||||
"""Test suite for WorkflowRunPayload Pydantic model."""
|
||||
|
||||
@ -359,7 +376,7 @@ class TestWorkflowRunDetailApi:
|
||||
handler(api, app_model=app_model, workflow_run_id="run")
|
||||
|
||||
def test_success(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
run = SimpleNamespace(id="run")
|
||||
run = _make_mock_workflow_run(run_id="run")
|
||||
repo = SimpleNamespace(get_workflow_run_by_id=lambda **_kwargs: run)
|
||||
workflow_module = sys.modules["controllers.service_api.app.workflow"]
|
||||
monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object()))
|
||||
@ -373,7 +390,10 @@ class TestWorkflowRunDetailApi:
|
||||
handler = _unwrap(api.get)
|
||||
app_model = SimpleNamespace(mode=AppMode.WORKFLOW.value, tenant_id="t1", id="a1")
|
||||
|
||||
assert handler(api, app_model=app_model, workflow_run_id="run") == run
|
||||
result = handler(api, app_model=app_model, workflow_run_id="run")
|
||||
assert result["id"] == "run"
|
||||
assert result["workflow_id"] == "wf-1"
|
||||
assert result["status"] == "succeeded"
|
||||
|
||||
|
||||
class TestWorkflowRunApi:
|
||||
@ -490,7 +510,7 @@ class TestWorkflowAppLogApi:
|
||||
monkeypatch.setattr(
|
||||
WorkflowAppService,
|
||||
"get_paginate_workflow_app_logs",
|
||||
lambda *_args, **_kwargs: {"items": [], "total": 0},
|
||||
lambda *_args, **_kwargs: {"page": 1, "limit": 20, "total": 0, "has_more": False, "data": []},
|
||||
)
|
||||
|
||||
api = WorkflowAppLogApi()
|
||||
@ -500,7 +520,7 @@ class TestWorkflowAppLogApi:
|
||||
with app.test_request_context("/workflows/logs", method="GET"):
|
||||
response = handler(api, app_model=app_model)
|
||||
|
||||
assert response == {"items": [], "total": 0}
|
||||
assert response == {"page": 1, "limit": 20, "total": 0, "has_more": False, "data": []}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
@ -527,9 +547,8 @@ def mock_workflow_app():
|
||||
class TestWorkflowRunDetailApiGet:
|
||||
"""Test suite for WorkflowRunDetailApi.get() endpoint.
|
||||
|
||||
``get`` is wrapped by ``@validate_app_token`` (preserves ``__wrapped__``)
|
||||
and ``@service_api_ns.marshal_with``. We call the unwrapped method
|
||||
directly; ``marshal_with`` is a no-op when calling directly.
|
||||
``get`` is wrapped by ``@validate_app_token`` (preserves ``__wrapped__``),
|
||||
and we call the unwrapped method directly in tests.
|
||||
"""
|
||||
|
||||
@patch("controllers.service_api.app.workflow.DifyAPIRepositoryFactory")
|
||||
@ -542,9 +561,7 @@ class TestWorkflowRunDetailApiGet:
|
||||
mock_workflow_app,
|
||||
):
|
||||
"""Test successful workflow run detail retrieval."""
|
||||
mock_run = Mock()
|
||||
mock_run.id = "run-1"
|
||||
mock_run.status = "succeeded"
|
||||
mock_run = _make_mock_workflow_run(run_id="run-1")
|
||||
mock_repo = Mock()
|
||||
mock_repo.get_workflow_run_by_id.return_value = mock_run
|
||||
mock_repo_factory.create_api_workflow_run_repository.return_value = mock_repo
|
||||
@ -558,7 +575,8 @@ class TestWorkflowRunDetailApiGet:
|
||||
api = WorkflowRunDetailApi()
|
||||
result = _unwrap(api.get)(api, app_model=mock_workflow_app, workflow_run_id=mock_run.id)
|
||||
|
||||
assert result == mock_run
|
||||
assert result["id"] == mock_run.id
|
||||
assert result["status"] == "succeeded"
|
||||
|
||||
@patch("controllers.service_api.app.workflow.db")
|
||||
def test_get_workflow_run_wrong_app_mode(self, mock_db, app):
|
||||
@ -622,8 +640,7 @@ class TestWorkflowTaskStopApiPost:
|
||||
class TestWorkflowAppLogApiGet:
|
||||
"""Test suite for WorkflowAppLogApi.get() endpoint.
|
||||
|
||||
``get`` is wrapped by ``@validate_app_token`` and
|
||||
``@service_api_ns.marshal_with``.
|
||||
``get`` is wrapped by ``@validate_app_token``.
|
||||
"""
|
||||
|
||||
@patch("controllers.service_api.app.workflow.WorkflowAppService")
|
||||
@ -637,6 +654,10 @@ class TestWorkflowAppLogApiGet:
|
||||
):
|
||||
"""Test successful workflow log retrieval."""
|
||||
mock_pagination = Mock()
|
||||
mock_pagination.page = 1
|
||||
mock_pagination.limit = 20
|
||||
mock_pagination.total = 0
|
||||
mock_pagination.has_more = False
|
||||
mock_pagination.data = []
|
||||
mock_svc_instance = Mock()
|
||||
mock_svc_instance.get_paginate_workflow_app_logs.return_value = mock_pagination
|
||||
@ -661,4 +682,4 @@ class TestWorkflowAppLogApiGet:
|
||||
api = WorkflowAppLogApi()
|
||||
result = _unwrap(api.get)(api, app_model=mock_workflow_app)
|
||||
|
||||
assert result == mock_pagination
|
||||
assert result == {"page": 1, "limit": 20, "total": 0, "has_more": False, "data": []}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user