refactor(api): migrate service api workflow responses from marshal_with to BaseModel (#35195)

Co-authored-by: ai-hpc <ai-hpc@users.noreply.github.com>
This commit is contained in:
NVIDIAN 2026-04-14 12:50:59 -07:00 committed by GitHub
parent 1c3cba281a
commit b65a5fcd97
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 176 additions and 40 deletions

View File

@ -1,13 +1,15 @@
import logging import logging
from collections.abc import Mapping
from datetime import datetime
from typing import Literal from typing import Literal
from dateutil.parser import isoparse from dateutil.parser import isoparse
from flask import request from flask import request
from flask_restx import Namespace, Resource, fields from flask_restx import Resource, fields
from graphon.enums import WorkflowExecutionStatus from graphon.enums import WorkflowExecutionStatus
from graphon.graph_engine.manager import GraphEngineManager from graphon.graph_engine.manager import GraphEngineManager
from graphon.model_runtime.errors.invoke import InvokeError from graphon.model_runtime.errors.invoke import InvokeError
from pydantic import BaseModel, Field from pydantic import BaseModel, Field, field_validator
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
@ -33,9 +35,10 @@ from core.errors.error import (
from core.helper.trace_id_helper import get_external_trace_id from core.helper.trace_id_helper import get_external_trace_id
from extensions.ext_database import db from extensions.ext_database import db
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from fields.workflow_app_log_fields import build_workflow_app_log_pagination_model from fields.base import ResponseModel
from fields.end_user_fields import SimpleEndUser
from fields.member_fields import SimpleAccount
from libs import helper from libs import helper
from libs.helper import OptionalTimestampField, TimestampField
from models.model import App, AppMode, EndUser from models.model import App, AppMode, EndUser
from models.workflow import WorkflowRun from models.workflow import WorkflowRun
from repositories.factory import DifyAPIRepositoryFactory from repositories.factory import DifyAPIRepositoryFactory
@ -65,38 +68,142 @@ class WorkflowLogQuery(BaseModel):
register_schema_models(service_api_ns, WorkflowRunPayload, WorkflowLogQuery) register_schema_models(service_api_ns, WorkflowRunPayload, WorkflowLogQuery)
def _to_timestamp(value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return int(value.timestamp())
return value
def _enum_value(value):
return getattr(value, "value", value)
class WorkflowRunStatusField(fields.Raw): class WorkflowRunStatusField(fields.Raw):
def output(self, key, obj: WorkflowRun, **kwargs): def output(self, key, obj: WorkflowRun, **kwargs):
return obj.status.value return _enum_value(obj.status)
class WorkflowRunOutputsField(fields.Raw): class WorkflowRunOutputsField(fields.Raw):
def output(self, key, obj: WorkflowRun, **kwargs): def output(self, key, obj: WorkflowRun, **kwargs):
if obj.status == WorkflowExecutionStatus.PAUSED: status = _enum_value(obj.status)
if status == WorkflowExecutionStatus.PAUSED.value:
return {} return {}
outputs = obj.outputs_dict outputs = obj.outputs_dict
return outputs or {} return outputs or {}
workflow_run_fields = { class WorkflowRunResponse(ResponseModel):
"id": fields.String, id: str
"workflow_id": fields.String, workflow_id: str
"status": WorkflowRunStatusField, status: str
"inputs": fields.Raw, inputs: dict | list | str | int | float | bool | None = None
"outputs": WorkflowRunOutputsField, outputs: dict = Field(default_factory=dict)
"error": fields.String, error: str | None = None
"total_steps": fields.Integer, total_steps: int | None = None
"total_tokens": fields.Integer, total_tokens: int | None = None
"created_at": TimestampField, created_at: int | None = None
"finished_at": OptionalTimestampField, finished_at: int | None = None
"elapsed_time": fields.Float, elapsed_time: float | int | None = None
}
@field_validator("created_at", "finished_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
return _to_timestamp(value)
def build_workflow_run_model(api_or_ns: Namespace): class WorkflowRunForLogResponse(ResponseModel):
"""Build the workflow run model for the API or Namespace.""" id: str
return api_or_ns.model("WorkflowRun", workflow_run_fields) version: str | None = None
status: str | None = None
triggered_from: str | None = None
error: str | None = None
elapsed_time: float | int | None = None
total_tokens: int | None = None
total_steps: int | None = None
created_at: int | None = None
finished_at: int | None = None
exceptions_count: int | None = None
@field_validator("status", "triggered_from", mode="before")
@classmethod
def _normalize_enum(cls, value):
return _enum_value(value)
@field_validator("created_at", "finished_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
return _to_timestamp(value)
class WorkflowAppLogPartialResponse(ResponseModel):
id: str
workflow_run: WorkflowRunForLogResponse | None = None
details: dict | list | str | int | float | bool | None = None
created_from: str | None = None
created_by_role: str | None = None
created_by_account: SimpleAccount | None = None
created_by_end_user: SimpleEndUser | None = None
created_at: int | None = None
@field_validator("created_from", "created_by_role", mode="before")
@classmethod
def _normalize_enum(cls, value):
return _enum_value(value)
@field_validator("created_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
return _to_timestamp(value)
class WorkflowAppLogPaginationResponse(ResponseModel):
page: int
limit: int
total: int
has_more: bool
data: list[WorkflowAppLogPartialResponse]
register_schema_models(
service_api_ns,
WorkflowRunResponse,
WorkflowRunForLogResponse,
WorkflowAppLogPartialResponse,
WorkflowAppLogPaginationResponse,
)
def _serialize_workflow_run(workflow_run: WorkflowRun) -> dict:
status = _enum_value(workflow_run.status)
raw_outputs = workflow_run.outputs_dict
if status == WorkflowExecutionStatus.PAUSED.value or raw_outputs is None:
outputs: dict = {}
elif isinstance(raw_outputs, dict):
outputs = raw_outputs
elif isinstance(raw_outputs, Mapping):
outputs = dict(raw_outputs)
else:
outputs = {}
return WorkflowRunResponse.model_validate(
{
"id": workflow_run.id,
"workflow_id": workflow_run.workflow_id,
"status": status,
"inputs": workflow_run.inputs,
"outputs": outputs,
"error": workflow_run.error,
"total_steps": workflow_run.total_steps,
"total_tokens": workflow_run.total_tokens,
"created_at": workflow_run.created_at,
"finished_at": workflow_run.finished_at,
"elapsed_time": workflow_run.elapsed_time,
}
).model_dump(mode="json")
def _serialize_workflow_log_pagination(pagination) -> dict:
return WorkflowAppLogPaginationResponse.model_validate(pagination, from_attributes=True).model_dump(mode="json")
@service_api_ns.route("/workflows/run/<string:workflow_run_id>") @service_api_ns.route("/workflows/run/<string:workflow_run_id>")
@ -112,7 +219,11 @@ class WorkflowRunDetailApi(Resource):
} }
) )
@validate_app_token @validate_app_token
@service_api_ns.marshal_with(build_workflow_run_model(service_api_ns)) @service_api_ns.response(
200,
"Workflow run details retrieved successfully",
service_api_ns.models[WorkflowRunResponse.__name__],
)
def get(self, app_model: App, workflow_run_id: str): def get(self, app_model: App, workflow_run_id: str):
"""Get a workflow task running detail. """Get a workflow task running detail.
@ -133,7 +244,7 @@ class WorkflowRunDetailApi(Resource):
) )
if not workflow_run: if not workflow_run:
raise NotFound("Workflow run not found.") raise NotFound("Workflow run not found.")
return workflow_run return _serialize_workflow_run(workflow_run)
@service_api_ns.route("/workflows/run") @service_api_ns.route("/workflows/run")
@ -299,7 +410,11 @@ class WorkflowAppLogApi(Resource):
} }
) )
@validate_app_token @validate_app_token
@service_api_ns.marshal_with(build_workflow_app_log_pagination_model(service_api_ns)) @service_api_ns.response(
200,
"Logs retrieved successfully",
service_api_ns.models[WorkflowAppLogPaginationResponse.__name__],
)
def get(self, app_model: App): def get(self, app_model: App):
"""Get workflow app logs. """Get workflow app logs.
@ -327,4 +442,4 @@ class WorkflowAppLogApi(Resource):
created_by_account=args.created_by_account, created_by_account=args.created_by_account,
) )
return workflow_app_log_pagination return _serialize_workflow_log_pagination(workflow_app_log_pagination)

View File

@ -15,6 +15,7 @@ Focus on:
import sys import sys
import uuid import uuid
from datetime import UTC, datetime
from types import SimpleNamespace from types import SimpleNamespace
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
@ -43,6 +44,22 @@ from services.errors.llm import InvokeRateLimitError
from services.workflow_app_service import WorkflowAppService from services.workflow_app_service import WorkflowAppService
def _make_mock_workflow_run(run_id: str = "run-1"):
run = Mock()
run.id = run_id
run.workflow_id = "wf-1"
run.status = WorkflowExecutionStatus.SUCCEEDED
run.inputs = {"input": "value"}
run.outputs_dict = {"output": "value"}
run.error = None
run.total_steps = 1
run.total_tokens = 10
run.created_at = datetime(2026, 1, 1, tzinfo=UTC)
run.finished_at = datetime(2026, 1, 1, tzinfo=UTC)
run.elapsed_time = 0.1
return run
class TestWorkflowRunPayload: class TestWorkflowRunPayload:
"""Test suite for WorkflowRunPayload Pydantic model.""" """Test suite for WorkflowRunPayload Pydantic model."""
@ -359,7 +376,7 @@ class TestWorkflowRunDetailApi:
handler(api, app_model=app_model, workflow_run_id="run") handler(api, app_model=app_model, workflow_run_id="run")
def test_success(self, monkeypatch: pytest.MonkeyPatch) -> None: def test_success(self, monkeypatch: pytest.MonkeyPatch) -> None:
run = SimpleNamespace(id="run") run = _make_mock_workflow_run(run_id="run")
repo = SimpleNamespace(get_workflow_run_by_id=lambda **_kwargs: run) repo = SimpleNamespace(get_workflow_run_by_id=lambda **_kwargs: run)
workflow_module = sys.modules["controllers.service_api.app.workflow"] workflow_module = sys.modules["controllers.service_api.app.workflow"]
monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object())) monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object()))
@ -373,7 +390,10 @@ class TestWorkflowRunDetailApi:
handler = _unwrap(api.get) handler = _unwrap(api.get)
app_model = SimpleNamespace(mode=AppMode.WORKFLOW.value, tenant_id="t1", id="a1") app_model = SimpleNamespace(mode=AppMode.WORKFLOW.value, tenant_id="t1", id="a1")
assert handler(api, app_model=app_model, workflow_run_id="run") == run result = handler(api, app_model=app_model, workflow_run_id="run")
assert result["id"] == "run"
assert result["workflow_id"] == "wf-1"
assert result["status"] == "succeeded"
class TestWorkflowRunApi: class TestWorkflowRunApi:
@ -490,7 +510,7 @@ class TestWorkflowAppLogApi:
monkeypatch.setattr( monkeypatch.setattr(
WorkflowAppService, WorkflowAppService,
"get_paginate_workflow_app_logs", "get_paginate_workflow_app_logs",
lambda *_args, **_kwargs: {"items": [], "total": 0}, lambda *_args, **_kwargs: {"page": 1, "limit": 20, "total": 0, "has_more": False, "data": []},
) )
api = WorkflowAppLogApi() api = WorkflowAppLogApi()
@ -500,7 +520,7 @@ class TestWorkflowAppLogApi:
with app.test_request_context("/workflows/logs", method="GET"): with app.test_request_context("/workflows/logs", method="GET"):
response = handler(api, app_model=app_model) response = handler(api, app_model=app_model)
assert response == {"items": [], "total": 0} assert response == {"page": 1, "limit": 20, "total": 0, "has_more": False, "data": []}
# ============================================================================= # =============================================================================
@ -527,9 +547,8 @@ def mock_workflow_app():
class TestWorkflowRunDetailApiGet: class TestWorkflowRunDetailApiGet:
"""Test suite for WorkflowRunDetailApi.get() endpoint. """Test suite for WorkflowRunDetailApi.get() endpoint.
``get`` is wrapped by ``@validate_app_token`` (preserves ``__wrapped__``) ``get`` is wrapped by ``@validate_app_token`` (preserves ``__wrapped__``),
and ``@service_api_ns.marshal_with``. We call the unwrapped method and we call the unwrapped method directly in tests.
directly; ``marshal_with`` is a no-op when calling directly.
""" """
@patch("controllers.service_api.app.workflow.DifyAPIRepositoryFactory") @patch("controllers.service_api.app.workflow.DifyAPIRepositoryFactory")
@ -542,9 +561,7 @@ class TestWorkflowRunDetailApiGet:
mock_workflow_app, mock_workflow_app,
): ):
"""Test successful workflow run detail retrieval.""" """Test successful workflow run detail retrieval."""
mock_run = Mock() mock_run = _make_mock_workflow_run(run_id="run-1")
mock_run.id = "run-1"
mock_run.status = "succeeded"
mock_repo = Mock() mock_repo = Mock()
mock_repo.get_workflow_run_by_id.return_value = mock_run mock_repo.get_workflow_run_by_id.return_value = mock_run
mock_repo_factory.create_api_workflow_run_repository.return_value = mock_repo mock_repo_factory.create_api_workflow_run_repository.return_value = mock_repo
@ -558,7 +575,8 @@ class TestWorkflowRunDetailApiGet:
api = WorkflowRunDetailApi() api = WorkflowRunDetailApi()
result = _unwrap(api.get)(api, app_model=mock_workflow_app, workflow_run_id=mock_run.id) result = _unwrap(api.get)(api, app_model=mock_workflow_app, workflow_run_id=mock_run.id)
assert result == mock_run assert result["id"] == mock_run.id
assert result["status"] == "succeeded"
@patch("controllers.service_api.app.workflow.db") @patch("controllers.service_api.app.workflow.db")
def test_get_workflow_run_wrong_app_mode(self, mock_db, app): def test_get_workflow_run_wrong_app_mode(self, mock_db, app):
@ -622,8 +640,7 @@ class TestWorkflowTaskStopApiPost:
class TestWorkflowAppLogApiGet: class TestWorkflowAppLogApiGet:
"""Test suite for WorkflowAppLogApi.get() endpoint. """Test suite for WorkflowAppLogApi.get() endpoint.
``get`` is wrapped by ``@validate_app_token`` and ``get`` is wrapped by ``@validate_app_token``.
``@service_api_ns.marshal_with``.
""" """
@patch("controllers.service_api.app.workflow.WorkflowAppService") @patch("controllers.service_api.app.workflow.WorkflowAppService")
@ -637,6 +654,10 @@ class TestWorkflowAppLogApiGet:
): ):
"""Test successful workflow log retrieval.""" """Test successful workflow log retrieval."""
mock_pagination = Mock() mock_pagination = Mock()
mock_pagination.page = 1
mock_pagination.limit = 20
mock_pagination.total = 0
mock_pagination.has_more = False
mock_pagination.data = [] mock_pagination.data = []
mock_svc_instance = Mock() mock_svc_instance = Mock()
mock_svc_instance.get_paginate_workflow_app_logs.return_value = mock_pagination mock_svc_instance.get_paginate_workflow_app_logs.return_value = mock_pagination
@ -661,4 +682,4 @@ class TestWorkflowAppLogApiGet:
api = WorkflowAppLogApi() api = WorkflowAppLogApi()
result = _unwrap(api.get)(api, app_model=mock_workflow_app) result = _unwrap(api.get)(api, app_model=mock_workflow_app)
assert result == mock_pagination assert result == {"page": 1, "limit": 20, "total": 0, "has_more": False, "data": []}