mirror of
https://github.com/langgenius/dify.git
synced 2026-04-21 23:38:53 +08:00
refactor(api): migrate service api workflow responses from marshal_with to BaseModel (#35195)
Co-authored-by: ai-hpc <ai-hpc@users.noreply.github.com>
This commit is contained in:
parent
1c3cba281a
commit
b65a5fcd97
@ -1,13 +1,15 @@
|
|||||||
import logging
|
import logging
|
||||||
|
from collections.abc import Mapping
|
||||||
|
from datetime import datetime
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
from dateutil.parser import isoparse
|
from dateutil.parser import isoparse
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_restx import Namespace, Resource, fields
|
from flask_restx import Resource, fields
|
||||||
from graphon.enums import WorkflowExecutionStatus
|
from graphon.enums import WorkflowExecutionStatus
|
||||||
from graphon.graph_engine.manager import GraphEngineManager
|
from graphon.graph_engine.manager import GraphEngineManager
|
||||||
from graphon.model_runtime.errors.invoke import InvokeError
|
from graphon.model_runtime.errors.invoke import InvokeError
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field, field_validator
|
||||||
from sqlalchemy.orm import sessionmaker
|
from sqlalchemy.orm import sessionmaker
|
||||||
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
|
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
|
||||||
|
|
||||||
@ -33,9 +35,10 @@ from core.errors.error import (
|
|||||||
from core.helper.trace_id_helper import get_external_trace_id
|
from core.helper.trace_id_helper import get_external_trace_id
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from extensions.ext_redis import redis_client
|
from extensions.ext_redis import redis_client
|
||||||
from fields.workflow_app_log_fields import build_workflow_app_log_pagination_model
|
from fields.base import ResponseModel
|
||||||
|
from fields.end_user_fields import SimpleEndUser
|
||||||
|
from fields.member_fields import SimpleAccount
|
||||||
from libs import helper
|
from libs import helper
|
||||||
from libs.helper import OptionalTimestampField, TimestampField
|
|
||||||
from models.model import App, AppMode, EndUser
|
from models.model import App, AppMode, EndUser
|
||||||
from models.workflow import WorkflowRun
|
from models.workflow import WorkflowRun
|
||||||
from repositories.factory import DifyAPIRepositoryFactory
|
from repositories.factory import DifyAPIRepositoryFactory
|
||||||
@ -65,38 +68,142 @@ class WorkflowLogQuery(BaseModel):
|
|||||||
register_schema_models(service_api_ns, WorkflowRunPayload, WorkflowLogQuery)
|
register_schema_models(service_api_ns, WorkflowRunPayload, WorkflowLogQuery)
|
||||||
|
|
||||||
|
|
||||||
|
def _to_timestamp(value: datetime | int | None) -> int | None:
|
||||||
|
if isinstance(value, datetime):
|
||||||
|
return int(value.timestamp())
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
def _enum_value(value):
|
||||||
|
return getattr(value, "value", value)
|
||||||
|
|
||||||
|
|
||||||
class WorkflowRunStatusField(fields.Raw):
|
class WorkflowRunStatusField(fields.Raw):
|
||||||
def output(self, key, obj: WorkflowRun, **kwargs):
|
def output(self, key, obj: WorkflowRun, **kwargs):
|
||||||
return obj.status.value
|
return _enum_value(obj.status)
|
||||||
|
|
||||||
|
|
||||||
class WorkflowRunOutputsField(fields.Raw):
|
class WorkflowRunOutputsField(fields.Raw):
|
||||||
def output(self, key, obj: WorkflowRun, **kwargs):
|
def output(self, key, obj: WorkflowRun, **kwargs):
|
||||||
if obj.status == WorkflowExecutionStatus.PAUSED:
|
status = _enum_value(obj.status)
|
||||||
|
if status == WorkflowExecutionStatus.PAUSED.value:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
outputs = obj.outputs_dict
|
outputs = obj.outputs_dict
|
||||||
return outputs or {}
|
return outputs or {}
|
||||||
|
|
||||||
|
|
||||||
workflow_run_fields = {
|
class WorkflowRunResponse(ResponseModel):
|
||||||
"id": fields.String,
|
id: str
|
||||||
"workflow_id": fields.String,
|
workflow_id: str
|
||||||
"status": WorkflowRunStatusField,
|
status: str
|
||||||
"inputs": fields.Raw,
|
inputs: dict | list | str | int | float | bool | None = None
|
||||||
"outputs": WorkflowRunOutputsField,
|
outputs: dict = Field(default_factory=dict)
|
||||||
"error": fields.String,
|
error: str | None = None
|
||||||
"total_steps": fields.Integer,
|
total_steps: int | None = None
|
||||||
"total_tokens": fields.Integer,
|
total_tokens: int | None = None
|
||||||
"created_at": TimestampField,
|
created_at: int | None = None
|
||||||
"finished_at": OptionalTimestampField,
|
finished_at: int | None = None
|
||||||
"elapsed_time": fields.Float,
|
elapsed_time: float | int | None = None
|
||||||
}
|
|
||||||
|
@field_validator("created_at", "finished_at", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||||
|
return _to_timestamp(value)
|
||||||
|
|
||||||
|
|
||||||
def build_workflow_run_model(api_or_ns: Namespace):
|
class WorkflowRunForLogResponse(ResponseModel):
|
||||||
"""Build the workflow run model for the API or Namespace."""
|
id: str
|
||||||
return api_or_ns.model("WorkflowRun", workflow_run_fields)
|
version: str | None = None
|
||||||
|
status: str | None = None
|
||||||
|
triggered_from: str | None = None
|
||||||
|
error: str | None = None
|
||||||
|
elapsed_time: float | int | None = None
|
||||||
|
total_tokens: int | None = None
|
||||||
|
total_steps: int | None = None
|
||||||
|
created_at: int | None = None
|
||||||
|
finished_at: int | None = None
|
||||||
|
exceptions_count: int | None = None
|
||||||
|
|
||||||
|
@field_validator("status", "triggered_from", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def _normalize_enum(cls, value):
|
||||||
|
return _enum_value(value)
|
||||||
|
|
||||||
|
@field_validator("created_at", "finished_at", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||||
|
return _to_timestamp(value)
|
||||||
|
|
||||||
|
|
||||||
|
class WorkflowAppLogPartialResponse(ResponseModel):
|
||||||
|
id: str
|
||||||
|
workflow_run: WorkflowRunForLogResponse | None = None
|
||||||
|
details: dict | list | str | int | float | bool | None = None
|
||||||
|
created_from: str | None = None
|
||||||
|
created_by_role: str | None = None
|
||||||
|
created_by_account: SimpleAccount | None = None
|
||||||
|
created_by_end_user: SimpleEndUser | None = None
|
||||||
|
created_at: int | None = None
|
||||||
|
|
||||||
|
@field_validator("created_from", "created_by_role", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def _normalize_enum(cls, value):
|
||||||
|
return _enum_value(value)
|
||||||
|
|
||||||
|
@field_validator("created_at", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||||
|
return _to_timestamp(value)
|
||||||
|
|
||||||
|
|
||||||
|
class WorkflowAppLogPaginationResponse(ResponseModel):
|
||||||
|
page: int
|
||||||
|
limit: int
|
||||||
|
total: int
|
||||||
|
has_more: bool
|
||||||
|
data: list[WorkflowAppLogPartialResponse]
|
||||||
|
|
||||||
|
|
||||||
|
register_schema_models(
|
||||||
|
service_api_ns,
|
||||||
|
WorkflowRunResponse,
|
||||||
|
WorkflowRunForLogResponse,
|
||||||
|
WorkflowAppLogPartialResponse,
|
||||||
|
WorkflowAppLogPaginationResponse,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _serialize_workflow_run(workflow_run: WorkflowRun) -> dict:
|
||||||
|
status = _enum_value(workflow_run.status)
|
||||||
|
raw_outputs = workflow_run.outputs_dict
|
||||||
|
if status == WorkflowExecutionStatus.PAUSED.value or raw_outputs is None:
|
||||||
|
outputs: dict = {}
|
||||||
|
elif isinstance(raw_outputs, dict):
|
||||||
|
outputs = raw_outputs
|
||||||
|
elif isinstance(raw_outputs, Mapping):
|
||||||
|
outputs = dict(raw_outputs)
|
||||||
|
else:
|
||||||
|
outputs = {}
|
||||||
|
return WorkflowRunResponse.model_validate(
|
||||||
|
{
|
||||||
|
"id": workflow_run.id,
|
||||||
|
"workflow_id": workflow_run.workflow_id,
|
||||||
|
"status": status,
|
||||||
|
"inputs": workflow_run.inputs,
|
||||||
|
"outputs": outputs,
|
||||||
|
"error": workflow_run.error,
|
||||||
|
"total_steps": workflow_run.total_steps,
|
||||||
|
"total_tokens": workflow_run.total_tokens,
|
||||||
|
"created_at": workflow_run.created_at,
|
||||||
|
"finished_at": workflow_run.finished_at,
|
||||||
|
"elapsed_time": workflow_run.elapsed_time,
|
||||||
|
}
|
||||||
|
).model_dump(mode="json")
|
||||||
|
|
||||||
|
|
||||||
|
def _serialize_workflow_log_pagination(pagination) -> dict:
|
||||||
|
return WorkflowAppLogPaginationResponse.model_validate(pagination, from_attributes=True).model_dump(mode="json")
|
||||||
|
|
||||||
|
|
||||||
@service_api_ns.route("/workflows/run/<string:workflow_run_id>")
|
@service_api_ns.route("/workflows/run/<string:workflow_run_id>")
|
||||||
@ -112,7 +219,11 @@ class WorkflowRunDetailApi(Resource):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
@validate_app_token
|
@validate_app_token
|
||||||
@service_api_ns.marshal_with(build_workflow_run_model(service_api_ns))
|
@service_api_ns.response(
|
||||||
|
200,
|
||||||
|
"Workflow run details retrieved successfully",
|
||||||
|
service_api_ns.models[WorkflowRunResponse.__name__],
|
||||||
|
)
|
||||||
def get(self, app_model: App, workflow_run_id: str):
|
def get(self, app_model: App, workflow_run_id: str):
|
||||||
"""Get a workflow task running detail.
|
"""Get a workflow task running detail.
|
||||||
|
|
||||||
@ -133,7 +244,7 @@ class WorkflowRunDetailApi(Resource):
|
|||||||
)
|
)
|
||||||
if not workflow_run:
|
if not workflow_run:
|
||||||
raise NotFound("Workflow run not found.")
|
raise NotFound("Workflow run not found.")
|
||||||
return workflow_run
|
return _serialize_workflow_run(workflow_run)
|
||||||
|
|
||||||
|
|
||||||
@service_api_ns.route("/workflows/run")
|
@service_api_ns.route("/workflows/run")
|
||||||
@ -299,7 +410,11 @@ class WorkflowAppLogApi(Resource):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
@validate_app_token
|
@validate_app_token
|
||||||
@service_api_ns.marshal_with(build_workflow_app_log_pagination_model(service_api_ns))
|
@service_api_ns.response(
|
||||||
|
200,
|
||||||
|
"Logs retrieved successfully",
|
||||||
|
service_api_ns.models[WorkflowAppLogPaginationResponse.__name__],
|
||||||
|
)
|
||||||
def get(self, app_model: App):
|
def get(self, app_model: App):
|
||||||
"""Get workflow app logs.
|
"""Get workflow app logs.
|
||||||
|
|
||||||
@ -327,4 +442,4 @@ class WorkflowAppLogApi(Resource):
|
|||||||
created_by_account=args.created_by_account,
|
created_by_account=args.created_by_account,
|
||||||
)
|
)
|
||||||
|
|
||||||
return workflow_app_log_pagination
|
return _serialize_workflow_log_pagination(workflow_app_log_pagination)
|
||||||
|
|||||||
@ -15,6 +15,7 @@ Focus on:
|
|||||||
|
|
||||||
import sys
|
import sys
|
||||||
import uuid
|
import uuid
|
||||||
|
from datetime import UTC, datetime
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
from unittest.mock import Mock, patch
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
@ -43,6 +44,22 @@ from services.errors.llm import InvokeRateLimitError
|
|||||||
from services.workflow_app_service import WorkflowAppService
|
from services.workflow_app_service import WorkflowAppService
|
||||||
|
|
||||||
|
|
||||||
|
def _make_mock_workflow_run(run_id: str = "run-1"):
|
||||||
|
run = Mock()
|
||||||
|
run.id = run_id
|
||||||
|
run.workflow_id = "wf-1"
|
||||||
|
run.status = WorkflowExecutionStatus.SUCCEEDED
|
||||||
|
run.inputs = {"input": "value"}
|
||||||
|
run.outputs_dict = {"output": "value"}
|
||||||
|
run.error = None
|
||||||
|
run.total_steps = 1
|
||||||
|
run.total_tokens = 10
|
||||||
|
run.created_at = datetime(2026, 1, 1, tzinfo=UTC)
|
||||||
|
run.finished_at = datetime(2026, 1, 1, tzinfo=UTC)
|
||||||
|
run.elapsed_time = 0.1
|
||||||
|
return run
|
||||||
|
|
||||||
|
|
||||||
class TestWorkflowRunPayload:
|
class TestWorkflowRunPayload:
|
||||||
"""Test suite for WorkflowRunPayload Pydantic model."""
|
"""Test suite for WorkflowRunPayload Pydantic model."""
|
||||||
|
|
||||||
@ -359,7 +376,7 @@ class TestWorkflowRunDetailApi:
|
|||||||
handler(api, app_model=app_model, workflow_run_id="run")
|
handler(api, app_model=app_model, workflow_run_id="run")
|
||||||
|
|
||||||
def test_success(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
def test_success(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
run = SimpleNamespace(id="run")
|
run = _make_mock_workflow_run(run_id="run")
|
||||||
repo = SimpleNamespace(get_workflow_run_by_id=lambda **_kwargs: run)
|
repo = SimpleNamespace(get_workflow_run_by_id=lambda **_kwargs: run)
|
||||||
workflow_module = sys.modules["controllers.service_api.app.workflow"]
|
workflow_module = sys.modules["controllers.service_api.app.workflow"]
|
||||||
monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object()))
|
monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object()))
|
||||||
@ -373,7 +390,10 @@ class TestWorkflowRunDetailApi:
|
|||||||
handler = _unwrap(api.get)
|
handler = _unwrap(api.get)
|
||||||
app_model = SimpleNamespace(mode=AppMode.WORKFLOW.value, tenant_id="t1", id="a1")
|
app_model = SimpleNamespace(mode=AppMode.WORKFLOW.value, tenant_id="t1", id="a1")
|
||||||
|
|
||||||
assert handler(api, app_model=app_model, workflow_run_id="run") == run
|
result = handler(api, app_model=app_model, workflow_run_id="run")
|
||||||
|
assert result["id"] == "run"
|
||||||
|
assert result["workflow_id"] == "wf-1"
|
||||||
|
assert result["status"] == "succeeded"
|
||||||
|
|
||||||
|
|
||||||
class TestWorkflowRunApi:
|
class TestWorkflowRunApi:
|
||||||
@ -490,7 +510,7 @@ class TestWorkflowAppLogApi:
|
|||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
WorkflowAppService,
|
WorkflowAppService,
|
||||||
"get_paginate_workflow_app_logs",
|
"get_paginate_workflow_app_logs",
|
||||||
lambda *_args, **_kwargs: {"items": [], "total": 0},
|
lambda *_args, **_kwargs: {"page": 1, "limit": 20, "total": 0, "has_more": False, "data": []},
|
||||||
)
|
)
|
||||||
|
|
||||||
api = WorkflowAppLogApi()
|
api = WorkflowAppLogApi()
|
||||||
@ -500,7 +520,7 @@ class TestWorkflowAppLogApi:
|
|||||||
with app.test_request_context("/workflows/logs", method="GET"):
|
with app.test_request_context("/workflows/logs", method="GET"):
|
||||||
response = handler(api, app_model=app_model)
|
response = handler(api, app_model=app_model)
|
||||||
|
|
||||||
assert response == {"items": [], "total": 0}
|
assert response == {"page": 1, "limit": 20, "total": 0, "has_more": False, "data": []}
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
@ -527,9 +547,8 @@ def mock_workflow_app():
|
|||||||
class TestWorkflowRunDetailApiGet:
|
class TestWorkflowRunDetailApiGet:
|
||||||
"""Test suite for WorkflowRunDetailApi.get() endpoint.
|
"""Test suite for WorkflowRunDetailApi.get() endpoint.
|
||||||
|
|
||||||
``get`` is wrapped by ``@validate_app_token`` (preserves ``__wrapped__``)
|
``get`` is wrapped by ``@validate_app_token`` (preserves ``__wrapped__``),
|
||||||
and ``@service_api_ns.marshal_with``. We call the unwrapped method
|
and we call the unwrapped method directly in tests.
|
||||||
directly; ``marshal_with`` is a no-op when calling directly.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@patch("controllers.service_api.app.workflow.DifyAPIRepositoryFactory")
|
@patch("controllers.service_api.app.workflow.DifyAPIRepositoryFactory")
|
||||||
@ -542,9 +561,7 @@ class TestWorkflowRunDetailApiGet:
|
|||||||
mock_workflow_app,
|
mock_workflow_app,
|
||||||
):
|
):
|
||||||
"""Test successful workflow run detail retrieval."""
|
"""Test successful workflow run detail retrieval."""
|
||||||
mock_run = Mock()
|
mock_run = _make_mock_workflow_run(run_id="run-1")
|
||||||
mock_run.id = "run-1"
|
|
||||||
mock_run.status = "succeeded"
|
|
||||||
mock_repo = Mock()
|
mock_repo = Mock()
|
||||||
mock_repo.get_workflow_run_by_id.return_value = mock_run
|
mock_repo.get_workflow_run_by_id.return_value = mock_run
|
||||||
mock_repo_factory.create_api_workflow_run_repository.return_value = mock_repo
|
mock_repo_factory.create_api_workflow_run_repository.return_value = mock_repo
|
||||||
@ -558,7 +575,8 @@ class TestWorkflowRunDetailApiGet:
|
|||||||
api = WorkflowRunDetailApi()
|
api = WorkflowRunDetailApi()
|
||||||
result = _unwrap(api.get)(api, app_model=mock_workflow_app, workflow_run_id=mock_run.id)
|
result = _unwrap(api.get)(api, app_model=mock_workflow_app, workflow_run_id=mock_run.id)
|
||||||
|
|
||||||
assert result == mock_run
|
assert result["id"] == mock_run.id
|
||||||
|
assert result["status"] == "succeeded"
|
||||||
|
|
||||||
@patch("controllers.service_api.app.workflow.db")
|
@patch("controllers.service_api.app.workflow.db")
|
||||||
def test_get_workflow_run_wrong_app_mode(self, mock_db, app):
|
def test_get_workflow_run_wrong_app_mode(self, mock_db, app):
|
||||||
@ -622,8 +640,7 @@ class TestWorkflowTaskStopApiPost:
|
|||||||
class TestWorkflowAppLogApiGet:
|
class TestWorkflowAppLogApiGet:
|
||||||
"""Test suite for WorkflowAppLogApi.get() endpoint.
|
"""Test suite for WorkflowAppLogApi.get() endpoint.
|
||||||
|
|
||||||
``get`` is wrapped by ``@validate_app_token`` and
|
``get`` is wrapped by ``@validate_app_token``.
|
||||||
``@service_api_ns.marshal_with``.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@patch("controllers.service_api.app.workflow.WorkflowAppService")
|
@patch("controllers.service_api.app.workflow.WorkflowAppService")
|
||||||
@ -637,6 +654,10 @@ class TestWorkflowAppLogApiGet:
|
|||||||
):
|
):
|
||||||
"""Test successful workflow log retrieval."""
|
"""Test successful workflow log retrieval."""
|
||||||
mock_pagination = Mock()
|
mock_pagination = Mock()
|
||||||
|
mock_pagination.page = 1
|
||||||
|
mock_pagination.limit = 20
|
||||||
|
mock_pagination.total = 0
|
||||||
|
mock_pagination.has_more = False
|
||||||
mock_pagination.data = []
|
mock_pagination.data = []
|
||||||
mock_svc_instance = Mock()
|
mock_svc_instance = Mock()
|
||||||
mock_svc_instance.get_paginate_workflow_app_logs.return_value = mock_pagination
|
mock_svc_instance.get_paginate_workflow_app_logs.return_value = mock_pagination
|
||||||
@ -661,4 +682,4 @@ class TestWorkflowAppLogApiGet:
|
|||||||
api = WorkflowAppLogApi()
|
api = WorkflowAppLogApi()
|
||||||
result = _unwrap(api.get)(api, app_model=mock_workflow_app)
|
result = _unwrap(api.get)(api, app_model=mock_workflow_app)
|
||||||
|
|
||||||
assert result == mock_pagination
|
assert result == {"page": 1, "limit": 20, "total": 0, "has_more": False, "data": []}
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user