mirror of
https://github.com/langgenius/dify.git
synced 2026-05-09 21:28:25 +08:00
fix(openapi): /run swallowed HTTP errors as 500
Explicit re-raise list at L230-238 only covered UnprocessableEntity + NotChatAppError/NotWorkflowAppError. Other HTTPException subclasses raised inside handlers (NotFound, BadRequest, ConversationCompletedError, ProviderNotInitializeError, ProviderQuotaExceededError, ...) hit `except Exception` and got squashed to 500. Replace with `except HTTPException: raise`. Refactor bundled: collapse 3x try/except ladder into _translate_service_errors() ctxmgr, inline single-call constraint enforcers, drop wasted dict() copy in _unpack_blocking, trim module docstring and stale spec doc reference. -60 net lines.
This commit is contained in:
parent
0c568623d7
commit
8e2ab1367b
@ -108,5 +108,5 @@ class WorkflowRunData(BaseModel):
|
||||
class WorkflowRunResponse(BaseModel):
|
||||
workflow_run_id: str
|
||||
task_id: str
|
||||
mode: Literal["workflow"] = "workflow" # echoed for CLI per-mode rendering — see endpoints.md L154
|
||||
mode: Literal["workflow"] = "workflow"
|
||||
data: WorkflowRunData
|
||||
|
||||
@ -1,22 +1,17 @@
|
||||
"""POST /openapi/v1/apps/<app_id>/run — mode-agnostic runner.
|
||||
|
||||
Server reads ``apps.mode`` after AppResolver and dispatches via
|
||||
_DISPATCH to the per-mode helper. Per-mode constraints (e.g. chat-family
|
||||
requires ``query``; workflow rejects ``query``) are enforced inside
|
||||
the helper, post-resolve, since ``mode`` is not in the request body.
|
||||
"""
|
||||
"""POST /openapi/v1/apps/<app_id>/run — mode-agnostic runner."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Callable, Mapping
|
||||
from collections.abc import Callable, Iterator, Mapping
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Literal
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, ValidationError, field_validator
|
||||
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound, UnprocessableEntity
|
||||
from werkzeug.exceptions import BadRequest, HTTPException, InternalServerError, NotFound, UnprocessableEntity
|
||||
|
||||
import services
|
||||
from controllers.openapi import openapi_ns
|
||||
@ -31,8 +26,6 @@ from controllers.service_api.app.error import (
|
||||
AppUnavailableError,
|
||||
CompletionRequestError,
|
||||
ConversationCompletedError,
|
||||
NotChatAppError,
|
||||
NotWorkflowAppError,
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError,
|
||||
ProviderQuotaExceededError,
|
||||
@ -82,24 +75,39 @@ class AppRunRequest(BaseModel):
|
||||
raise ValueError("conversation_id must be a valid UUID") from exc
|
||||
|
||||
|
||||
def _enforce_chat_constraint(payload: AppRunRequest) -> None:
|
||||
if not payload.query or not payload.query.strip():
|
||||
raise UnprocessableEntity("query_required_for_chat")
|
||||
|
||||
|
||||
def _enforce_workflow_constraint(payload: AppRunRequest) -> None:
|
||||
if payload.query is not None:
|
||||
raise UnprocessableEntity("query_not_supported_for_workflow")
|
||||
@contextmanager
|
||||
def _translate_service_errors() -> Iterator[None]:
|
||||
try:
|
||||
yield
|
||||
except WorkflowNotFoundError as ex:
|
||||
raise NotFound(str(ex))
|
||||
except (IsDraftWorkflowError, WorkflowIdFormatError) as ex:
|
||||
raise BadRequest(str(ex))
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
except services.errors.conversation.ConversationCompletedError:
|
||||
raise ConversationCompletedError()
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logger.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except InvokeRateLimitError as ex:
|
||||
raise InvokeRateLimitHttpError(ex.description)
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
|
||||
|
||||
def _unpack_blocking(response: Any) -> Mapping[str, Any]:
|
||||
if isinstance(response, tuple):
|
||||
body_dict: Any = response[0]
|
||||
else:
|
||||
body_dict = response
|
||||
if not isinstance(body_dict, Mapping):
|
||||
response = response[0]
|
||||
if not isinstance(response, Mapping):
|
||||
raise InternalServerError("blocking generate returned non-mapping response")
|
||||
return dict(body_dict)
|
||||
return response
|
||||
|
||||
|
||||
def _generate(app: App, caller: Any, args: dict[str, Any], streaming: bool):
|
||||
@ -113,91 +121,36 @@ def _generate(app: App, caller: Any, args: dict[str, Any], streaming: bool):
|
||||
|
||||
|
||||
def _run_chat(app: App, caller: Any, payload: AppRunRequest, streaming: bool):
|
||||
_enforce_chat_constraint(payload)
|
||||
if not payload.query or not payload.query.strip():
|
||||
raise UnprocessableEntity("query_required_for_chat")
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
try:
|
||||
with _translate_service_errors():
|
||||
response = _generate(app, caller, args, streaming)
|
||||
except WorkflowNotFoundError as ex:
|
||||
raise NotFound(str(ex))
|
||||
except (IsDraftWorkflowError, WorkflowIdFormatError) as ex:
|
||||
raise BadRequest(str(ex))
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
except services.errors.conversation.ConversationCompletedError:
|
||||
raise ConversationCompletedError()
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logger.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except InvokeRateLimitError as ex:
|
||||
raise InvokeRateLimitHttpError(ex.description)
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
|
||||
if streaming:
|
||||
return response, None
|
||||
body = _unpack_blocking(response)
|
||||
return None, ChatMessageResponse.model_validate(body).model_dump(mode="json")
|
||||
return None, ChatMessageResponse.model_validate(_unpack_blocking(response)).model_dump(mode="json")
|
||||
|
||||
|
||||
def _run_completion(app: App, caller: Any, payload: AppRunRequest, streaming: bool):
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
# Completion mode disables auto-naming + tolerates absent query (legacy parity).
|
||||
args["auto_generate_name"] = False
|
||||
args.setdefault("query", "")
|
||||
try:
|
||||
with _translate_service_errors():
|
||||
response = _generate(app, caller, args, streaming)
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
except services.errors.conversation.ConversationCompletedError:
|
||||
raise ConversationCompletedError()
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logger.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
|
||||
if streaming:
|
||||
return response, None
|
||||
body = _unpack_blocking(response)
|
||||
return None, CompletionMessageResponse.model_validate(body).model_dump(mode="json")
|
||||
return None, CompletionMessageResponse.model_validate(_unpack_blocking(response)).model_dump(mode="json")
|
||||
|
||||
|
||||
def _run_workflow(app: App, caller: Any, payload: AppRunRequest, streaming: bool):
|
||||
_enforce_workflow_constraint(payload)
|
||||
if payload.query is not None:
|
||||
raise UnprocessableEntity("query_not_supported_for_workflow")
|
||||
args = payload.model_dump(exclude={"query", "conversation_id", "auto_generate_name"}, exclude_none=True)
|
||||
try:
|
||||
with _translate_service_errors():
|
||||
response = _generate(app, caller, args, streaming)
|
||||
except WorkflowNotFoundError as ex:
|
||||
raise NotFound(str(ex))
|
||||
except (IsDraftWorkflowError, WorkflowIdFormatError) as ex:
|
||||
raise BadRequest(str(ex))
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except InvokeRateLimitError as ex:
|
||||
raise InvokeRateLimitHttpError(ex.description)
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
|
||||
if streaming:
|
||||
return response, None
|
||||
body = _unpack_blocking(response)
|
||||
return None, WorkflowRunResponse.model_validate(body).model_dump(mode="json")
|
||||
return None, WorkflowRunResponse.model_validate(_unpack_blocking(response)).model_dump(mode="json")
|
||||
|
||||
|
||||
_DISPATCH: dict[AppMode, Callable[[App, Any, AppRunRequest, bool], tuple[Any, dict[str, Any] | None]]] = {
|
||||
@ -220,25 +173,25 @@ class AppRunApi(Resource):
|
||||
except ValidationError as exc:
|
||||
raise UnprocessableEntity(exc.json())
|
||||
|
||||
mode = app_model.mode
|
||||
handler = _DISPATCH.get(mode)
|
||||
handler = _DISPATCH.get(app_model.mode)
|
||||
if handler is None:
|
||||
raise UnprocessableEntity("mode_not_runnable")
|
||||
|
||||
streaming = payload.response_mode == "streaming"
|
||||
# Preserve specific HTTPException codes that the catch-all would otherwise mask.
|
||||
try:
|
||||
stream_obj, blocking_body = handler(app_model, caller, payload, streaming)
|
||||
except UnprocessableEntity:
|
||||
raise
|
||||
except (NotChatAppError, NotWorkflowAppError):
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
emit_app_run(app_id=app_model.id, tenant_id=app_model.tenant_id,
|
||||
caller_kind=caller_kind, mode=str(app_model.mode))
|
||||
emit_app_run(
|
||||
app_id=app_model.id,
|
||||
tenant_id=app_model.tenant_id,
|
||||
caller_kind=caller_kind,
|
||||
mode=str(app_model.mode),
|
||||
)
|
||||
|
||||
if streaming:
|
||||
return helper.compact_generate_response(stream_obj)
|
||||
|
||||
@ -195,7 +195,6 @@ def test_run_with_insufficient_scope_returns_403(
|
||||
|
||||
def _stub_authenticate(self, token: str):
|
||||
ctx = real_authenticate(self, token)
|
||||
# Return a copy with empty scopes — frozen dataclass requires replace.
|
||||
from dataclasses import replace
|
||||
|
||||
return replace(ctx, scopes=frozenset())
|
||||
|
||||
@ -1,11 +1,9 @@
|
||||
import pytest
|
||||
from werkzeug.exceptions import InternalServerError, UnprocessableEntity
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
from controllers.openapi.app_run import (
|
||||
_DISPATCH,
|
||||
AppRunRequest,
|
||||
_enforce_chat_constraint,
|
||||
_enforce_workflow_constraint,
|
||||
_unpack_blocking,
|
||||
)
|
||||
from models.model import AppMode
|
||||
@ -16,16 +14,6 @@ def test_dispatch_covers_runnable_modes():
|
||||
assert set(_DISPATCH) == runnable
|
||||
|
||||
|
||||
def test_chat_constraint_requires_query():
|
||||
with pytest.raises(UnprocessableEntity, match="query_required_for_chat"):
|
||||
_enforce_chat_constraint(AppRunRequest(inputs={}))
|
||||
|
||||
|
||||
def test_workflow_constraint_rejects_query():
|
||||
with pytest.raises(UnprocessableEntity, match="query_not_supported_for_workflow"):
|
||||
_enforce_workflow_constraint(AppRunRequest(inputs={}, query="hi"))
|
||||
|
||||
|
||||
def test_unpack_blocking_passes_through_mapping():
|
||||
assert _unpack_blocking({"a": 1}) == {"a": 1}
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user