mirror of
https://github.com/langgenius/dify.git
synced 2026-05-10 05:56:31 +08:00
refactor(openapi): drop legacy per-mode bearer routes
Removes /openapi/v1/apps/<id>/{chat-messages,completion-messages,
workflows/run}. Bearer surface for runs is now the unified /run route
(api-3). Service-API /v1/* per-mode routes (app-key auth) untouched.
Also deletes the corresponding unit test files
(test_chat_messages.py, test_completion_messages.py, test_workflow_run.py)
which targeted the removed handlers; coverage of the unified route lives
in tests/unit_tests/controllers/openapi/test_app_run_dispatch.py and
tests/integration_tests/controllers/openapi/test_app_run.py.
This commit is contained in:
parent
4bc1046f14
commit
fb7b8dc151
@ -21,12 +21,9 @@ from . import (
|
||||
app_run,
|
||||
apps,
|
||||
apps_permitted,
|
||||
chat_messages,
|
||||
completion_messages,
|
||||
index,
|
||||
oauth_device,
|
||||
oauth_device_sso,
|
||||
workflow_run,
|
||||
workspaces,
|
||||
)
|
||||
|
||||
@ -35,12 +32,9 @@ __all__ = [
|
||||
"app_run",
|
||||
"apps",
|
||||
"apps_permitted",
|
||||
"chat_messages",
|
||||
"completion_messages",
|
||||
"index",
|
||||
"oauth_device",
|
||||
"oauth_device_sso",
|
||||
"workflow_run",
|
||||
"workspaces",
|
||||
]
|
||||
|
||||
|
||||
@ -1,177 +0,0 @@
|
||||
"""POST /openapi/v1/apps/<app_id>/chat-messages — port of
|
||||
service_api/app/completion.py:ChatApi.
|
||||
|
||||
Differences from service_api:
|
||||
- App is in URL path, not header.
|
||||
- One decorator: @OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN).
|
||||
- Request body has no `user` field (Model 2: identity is the bearer).
|
||||
- Typed Request and Response models.
|
||||
- invoke_from = InvokeFrom.OPENAPI.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Literal
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, ValidationError, field_validator
|
||||
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi._audit import emit_app_run
|
||||
from controllers.openapi._models import MessageMetadata
|
||||
from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE
|
||||
from controllers.service_api.app.error import (
|
||||
AppUnavailableError,
|
||||
CompletionRequestError,
|
||||
ConversationCompletedError,
|
||||
NotChatAppError,
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError,
|
||||
ProviderQuotaExceededError,
|
||||
)
|
||||
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.errors.error import (
|
||||
ModelCurrentlyNotSupportError,
|
||||
ProviderTokenNotInitError,
|
||||
QuotaExceededError,
|
||||
)
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from libs import helper
|
||||
from libs.helper import UUIDStrOrEmpty
|
||||
from libs.oauth_bearer import Scope
|
||||
from models.model import App, AppMode
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.app import (
|
||||
IsDraftWorkflowError,
|
||||
WorkflowIdFormatError,
|
||||
WorkflowNotFoundError,
|
||||
)
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ChatMessageRequest(BaseModel):
|
||||
inputs: dict[str, Any]
|
||||
query: str
|
||||
files: list[dict[str, Any]] | None = None
|
||||
response_mode: Literal["blocking", "streaming"] | None = None
|
||||
conversation_id: UUIDStrOrEmpty | None = Field(default=None)
|
||||
auto_generate_name: bool = Field(default=True)
|
||||
workflow_id: str | None = Field(default=None)
|
||||
|
||||
@field_validator("conversation_id", mode="before")
|
||||
@classmethod
|
||||
def normalize_conversation_id(cls, value: str | UUID | None) -> str | None:
|
||||
if isinstance(value, str):
|
||||
value = value.strip()
|
||||
if not value:
|
||||
return None
|
||||
try:
|
||||
return helper.uuid_value(value)
|
||||
except ValueError as exc:
|
||||
raise ValueError("conversation_id must be a valid UUID") from exc
|
||||
|
||||
|
||||
class ChatMessageResponse(BaseModel):
|
||||
event: str
|
||||
task_id: str
|
||||
id: str
|
||||
message_id: str
|
||||
conversation_id: str
|
||||
mode: str
|
||||
answer: str
|
||||
metadata: MessageMetadata = Field(default_factory=MessageMetadata)
|
||||
created_at: int
|
||||
|
||||
|
||||
def _unpack_app(app_model):
|
||||
return app_model
|
||||
|
||||
|
||||
def _unpack_caller(caller):
|
||||
return caller
|
||||
|
||||
|
||||
@openapi_ns.route("/apps/<string:app_id>/chat-messages")
|
||||
class ChatMessagesApi(Resource):
|
||||
@OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN)
|
||||
def post(self, app_id: str, app_model: App, caller, caller_kind: str):
|
||||
app = _unpack_app(app_model)
|
||||
if AppMode.value_of(app.mode) not in {
|
||||
AppMode.CHAT,
|
||||
AppMode.AGENT_CHAT,
|
||||
AppMode.ADVANCED_CHAT,
|
||||
}:
|
||||
raise NotChatAppError()
|
||||
|
||||
body = request.get_json(silent=True) or {}
|
||||
body.pop("user", None)
|
||||
try:
|
||||
payload = ChatMessageRequest.model_validate(body)
|
||||
except ValidationError as exc:
|
||||
raise BadRequest(str(exc))
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
streaming = payload.response_mode == "streaming"
|
||||
|
||||
try:
|
||||
response = AppGenerateService.generate(
|
||||
app_model=app,
|
||||
user=_unpack_caller(caller),
|
||||
args=args,
|
||||
invoke_from=InvokeFrom.OPENAPI,
|
||||
streaming=streaming,
|
||||
)
|
||||
except WorkflowNotFoundError as ex:
|
||||
raise NotFound(str(ex))
|
||||
except (IsDraftWorkflowError, WorkflowIdFormatError) as ex:
|
||||
raise BadRequest(str(ex))
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
except services.errors.conversation.ConversationCompletedError:
|
||||
raise ConversationCompletedError()
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logger.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except InvokeRateLimitError as ex:
|
||||
raise InvokeRateLimitHttpError(ex.description)
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
except ValueError:
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
emit_app_run(
|
||||
app_id=app.id,
|
||||
tenant_id=app.tenant_id,
|
||||
caller_kind=caller_kind,
|
||||
mode=str(app.mode),
|
||||
)
|
||||
|
||||
if streaming:
|
||||
return helper.compact_generate_response(response)
|
||||
|
||||
# Some upstream paths (and tests) return (body, status); production
|
||||
# generate returns Mapping. Accept both, then validate.
|
||||
if isinstance(response, tuple):
|
||||
body_dict: Any = response[0] # pyright: ignore[reportArgumentType]
|
||||
else:
|
||||
body_dict = response
|
||||
if not isinstance(body_dict, Mapping):
|
||||
raise InternalServerError("blocking generate returned non-mapping response")
|
||||
return ChatMessageResponse.model_validate(dict(body_dict)).model_dump(mode="json"), 200
|
||||
@ -1,132 +0,0 @@
|
||||
"""POST /openapi/v1/apps/<app_id>/completion-messages — port of
|
||||
service_api/app/completion.py:CompletionApi."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Literal
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi._audit import emit_app_run
|
||||
from controllers.openapi._models import MessageMetadata
|
||||
from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE
|
||||
from controllers.service_api.app.error import (
|
||||
AppUnavailableError,
|
||||
CompletionRequestError,
|
||||
ConversationCompletedError,
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError,
|
||||
ProviderQuotaExceededError,
|
||||
)
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.errors.error import (
|
||||
ModelCurrentlyNotSupportError,
|
||||
ProviderTokenNotInitError,
|
||||
QuotaExceededError,
|
||||
)
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from libs import helper
|
||||
from libs.oauth_bearer import Scope
|
||||
from models.model import App, AppMode
|
||||
from services.app_generate_service import AppGenerateService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CompletionMessageRequest(BaseModel):
|
||||
inputs: dict[str, Any]
|
||||
query: str = Field(default="")
|
||||
files: list[dict[str, Any]] | None = None
|
||||
response_mode: Literal["blocking", "streaming"] | None = None
|
||||
|
||||
|
||||
class CompletionMessageResponse(BaseModel):
|
||||
event: str
|
||||
task_id: str
|
||||
id: str
|
||||
message_id: str
|
||||
mode: str
|
||||
answer: str
|
||||
metadata: MessageMetadata = Field(default_factory=MessageMetadata)
|
||||
created_at: int
|
||||
|
||||
|
||||
def _unpack_app(app_model):
|
||||
return app_model
|
||||
|
||||
|
||||
def _unpack_caller(caller):
|
||||
return caller
|
||||
|
||||
|
||||
@openapi_ns.route("/apps/<string:app_id>/completion-messages")
|
||||
class CompletionMessagesApi(Resource):
|
||||
@OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN)
|
||||
def post(self, app_id: str, app_model: App, caller, caller_kind: str):
|
||||
app = _unpack_app(app_model)
|
||||
if AppMode.value_of(app.mode) != AppMode.COMPLETION:
|
||||
raise AppUnavailableError()
|
||||
|
||||
body = request.get_json(silent=True) or {}
|
||||
body.pop("user", None)
|
||||
try:
|
||||
payload = CompletionMessageRequest.model_validate(body)
|
||||
except ValidationError as exc:
|
||||
raise BadRequest(str(exc))
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
args["auto_generate_name"] = False
|
||||
streaming = payload.response_mode == "streaming"
|
||||
|
||||
try:
|
||||
response = AppGenerateService.generate(
|
||||
app_model=app,
|
||||
user=_unpack_caller(caller),
|
||||
args=args,
|
||||
invoke_from=InvokeFrom.OPENAPI,
|
||||
streaming=streaming,
|
||||
)
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
except services.errors.conversation.ConversationCompletedError:
|
||||
raise ConversationCompletedError()
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logger.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
except ValueError:
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
emit_app_run(
|
||||
app_id=app.id,
|
||||
tenant_id=app.tenant_id,
|
||||
caller_kind=caller_kind,
|
||||
mode=str(app.mode),
|
||||
)
|
||||
|
||||
if streaming:
|
||||
return helper.compact_generate_response(response)
|
||||
|
||||
if isinstance(response, tuple):
|
||||
body_dict: Any = response[0] # pyright: ignore[reportArgumentType]
|
||||
else:
|
||||
body_dict = response
|
||||
if not isinstance(body_dict, Mapping):
|
||||
raise InternalServerError("blocking generate returned non-mapping response")
|
||||
return CompletionMessageResponse.model_validate(dict(body_dict)).model_dump(mode="json"), 200
|
||||
@ -1,140 +0,0 @@
|
||||
"""POST /openapi/v1/apps/<app_id>/workflows/run — port of
|
||||
service_api/app/workflow.py:WorkflowRunApi."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Literal
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
|
||||
|
||||
from controllers.common.controller_schemas import WorkflowRunPayload as WorkflowRunPayloadBase
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi._audit import emit_app_run
|
||||
from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE
|
||||
from controllers.service_api.app.error import (
|
||||
CompletionRequestError,
|
||||
NotWorkflowAppError,
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError,
|
||||
ProviderQuotaExceededError,
|
||||
)
|
||||
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.errors.error import (
|
||||
ModelCurrentlyNotSupportError,
|
||||
ProviderTokenNotInitError,
|
||||
QuotaExceededError,
|
||||
)
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from libs import helper
|
||||
from libs.oauth_bearer import Scope
|
||||
from models.model import App, AppMode
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.app import (
|
||||
IsDraftWorkflowError,
|
||||
WorkflowIdFormatError,
|
||||
WorkflowNotFoundError,
|
||||
)
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkflowRunRequest(WorkflowRunPayloadBase):
|
||||
response_mode: Literal["blocking", "streaming"] | None = None
|
||||
|
||||
|
||||
class WorkflowRunData(BaseModel):
|
||||
id: str
|
||||
workflow_id: str
|
||||
status: str
|
||||
outputs: dict[str, Any] = {}
|
||||
error: str | None = None
|
||||
elapsed_time: float | None = None
|
||||
total_tokens: int | None = None
|
||||
total_steps: int | None = None
|
||||
created_at: int | None = None
|
||||
finished_at: int | None = None
|
||||
|
||||
|
||||
class WorkflowRunResponse(BaseModel):
|
||||
workflow_run_id: str
|
||||
task_id: str
|
||||
data: WorkflowRunData
|
||||
|
||||
|
||||
def _unpack_app(app_model):
|
||||
return app_model
|
||||
|
||||
|
||||
def _unpack_caller(caller):
|
||||
return caller
|
||||
|
||||
|
||||
@openapi_ns.route("/apps/<string:app_id>/workflows/run")
|
||||
class WorkflowRunApi(Resource):
|
||||
@OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN)
|
||||
def post(self, app_id: str, app_model: App, caller, caller_kind: str):
|
||||
app = _unpack_app(app_model)
|
||||
if AppMode.value_of(app.mode) != AppMode.WORKFLOW:
|
||||
raise NotWorkflowAppError()
|
||||
|
||||
body = request.get_json(silent=True) or {}
|
||||
body.pop("user", None)
|
||||
try:
|
||||
payload = WorkflowRunRequest.model_validate(body)
|
||||
except ValidationError as exc:
|
||||
raise BadRequest(str(exc))
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
streaming = payload.response_mode == "streaming"
|
||||
|
||||
try:
|
||||
response = AppGenerateService.generate(
|
||||
app_model=app,
|
||||
user=_unpack_caller(caller),
|
||||
args=args,
|
||||
invoke_from=InvokeFrom.OPENAPI,
|
||||
streaming=streaming,
|
||||
)
|
||||
except WorkflowNotFoundError as ex:
|
||||
raise NotFound(str(ex))
|
||||
except (IsDraftWorkflowError, WorkflowIdFormatError) as ex:
|
||||
raise BadRequest(str(ex))
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except InvokeRateLimitError as ex:
|
||||
raise InvokeRateLimitHttpError(ex.description)
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
except ValueError:
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
emit_app_run(
|
||||
app_id=app.id,
|
||||
tenant_id=app.tenant_id,
|
||||
caller_kind=caller_kind,
|
||||
mode=str(app.mode),
|
||||
)
|
||||
|
||||
if streaming:
|
||||
return helper.compact_generate_response(response)
|
||||
|
||||
if isinstance(response, tuple):
|
||||
body_dict: Any = response[0] # pyright: ignore[reportArgumentType]
|
||||
else:
|
||||
body_dict = response
|
||||
if not isinstance(body_dict, Mapping):
|
||||
raise InternalServerError("blocking generate returned non-mapping response")
|
||||
return WorkflowRunResponse.model_validate(dict(body_dict)).model_dump(mode="json"), 200
|
||||
@ -1,89 +0,0 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
from flask import Flask
|
||||
from flask_restx import Api
|
||||
|
||||
|
||||
def _client():
|
||||
from controllers.openapi import (
|
||||
chat_messages, # noqa: F401
|
||||
openapi_ns,
|
||||
)
|
||||
|
||||
app = Flask(__name__)
|
||||
api = Api(app)
|
||||
api.add_namespace(openapi_ns, path="/openapi/v1")
|
||||
return app.test_client()
|
||||
|
||||
|
||||
@patch("controllers.openapi.chat_messages.AppGenerateService")
|
||||
def test_chat_dispatches_and_returns_response_model(svc, bypass_pipeline):
|
||||
svc.generate.return_value = (
|
||||
{
|
||||
"event": "message",
|
||||
"task_id": "tk1",
|
||||
"id": "m1",
|
||||
"message_id": "m1",
|
||||
"conversation_id": "c1",
|
||||
"mode": "chat",
|
||||
"answer": "hi",
|
||||
"metadata": {},
|
||||
"created_at": 1700000000,
|
||||
},
|
||||
200,
|
||||
)
|
||||
fake = SimpleNamespace(mode="chat", id="app1", tenant_id="t1")
|
||||
with (
|
||||
patch("controllers.openapi.chat_messages._unpack_app", return_value=fake),
|
||||
patch("controllers.openapi.chat_messages._unpack_caller", return_value=SimpleNamespace()),
|
||||
):
|
||||
r = _client().post("/openapi/v1/apps/app1/chat-messages", json={"query": "hi", "inputs": {}})
|
||||
assert r.status_code == 200
|
||||
body = r.get_json()
|
||||
assert body["conversation_id"] == "c1"
|
||||
assert body["answer"] == "hi"
|
||||
assert svc.generate.call_args.kwargs["invoke_from"].value == "openapi"
|
||||
|
||||
|
||||
@patch("controllers.openapi.chat_messages.AppGenerateService")
|
||||
def test_chat_strips_user_field_from_body(svc, bypass_pipeline):
|
||||
svc.generate.return_value = (
|
||||
{
|
||||
"event": "message",
|
||||
"task_id": "tk1",
|
||||
"id": "m1",
|
||||
"message_id": "m1",
|
||||
"conversation_id": "c1",
|
||||
"mode": "chat",
|
||||
"answer": "hi",
|
||||
"metadata": {},
|
||||
"created_at": 1700000000,
|
||||
},
|
||||
200,
|
||||
)
|
||||
fake = SimpleNamespace(mode="chat", id="app1", tenant_id="t1")
|
||||
with (
|
||||
patch("controllers.openapi.chat_messages._unpack_app", return_value=fake),
|
||||
patch("controllers.openapi.chat_messages._unpack_caller", return_value=SimpleNamespace()),
|
||||
):
|
||||
_client().post(
|
||||
"/openapi/v1/apps/app1/chat-messages",
|
||||
json={"query": "hi", "inputs": {}, "user": "spoof@x.com"},
|
||||
)
|
||||
args = svc.generate.call_args.kwargs["args"]
|
||||
assert "user" not in args
|
||||
|
||||
|
||||
def test_chat_rejects_non_chat_mode(bypass_pipeline):
|
||||
fake = SimpleNamespace(mode="completion")
|
||||
with patch("controllers.openapi.chat_messages._unpack_app", return_value=fake):
|
||||
r = _client().post("/openapi/v1/apps/app1/chat-messages", json={"query": "hi", "inputs": {}})
|
||||
assert r.status_code in (400, 403)
|
||||
|
||||
|
||||
def test_chat_rejects_invalid_body(bypass_pipeline):
|
||||
fake = SimpleNamespace(mode="chat", id="app1", tenant_id="t1")
|
||||
with patch("controllers.openapi.chat_messages._unpack_app", return_value=fake):
|
||||
r = _client().post("/openapi/v1/apps/app1/chat-messages", json={"query": "hi"})
|
||||
assert r.status_code in (400, 422)
|
||||
@ -1,57 +0,0 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
from flask import Flask
|
||||
from flask_restx import Api
|
||||
|
||||
|
||||
def _client():
|
||||
from controllers.openapi import (
|
||||
completion_messages, # noqa: F401
|
||||
openapi_ns,
|
||||
)
|
||||
|
||||
app = Flask(__name__)
|
||||
api = Api(app)
|
||||
api.add_namespace(openapi_ns, path="/openapi/v1")
|
||||
return app.test_client()
|
||||
|
||||
|
||||
@patch("controllers.openapi.completion_messages.AppGenerateService")
|
||||
def test_completion_returns_response_model(svc, bypass_pipeline):
|
||||
svc.generate.return_value = (
|
||||
{
|
||||
"event": "message",
|
||||
"task_id": "tk",
|
||||
"id": "m1",
|
||||
"message_id": "m1",
|
||||
"mode": "completion",
|
||||
"answer": "ok",
|
||||
"metadata": {},
|
||||
"created_at": 1700000000,
|
||||
},
|
||||
200,
|
||||
)
|
||||
fake = SimpleNamespace(mode="completion", id="app1", tenant_id="t1")
|
||||
with (
|
||||
patch("controllers.openapi.completion_messages._unpack_app", return_value=fake),
|
||||
patch("controllers.openapi.completion_messages._unpack_caller", return_value=SimpleNamespace()),
|
||||
):
|
||||
r = _client().post(
|
||||
"/openapi/v1/apps/app1/completion-messages",
|
||||
json={"inputs": {"x": 1}, "query": "hi"},
|
||||
)
|
||||
assert r.status_code == 200
|
||||
body = r.get_json()
|
||||
assert body["answer"] == "ok"
|
||||
assert svc.generate.call_args.kwargs["invoke_from"].value == "openapi"
|
||||
|
||||
|
||||
def test_completion_rejects_chat_mode(bypass_pipeline):
|
||||
fake = SimpleNamespace(mode="chat")
|
||||
with patch("controllers.openapi.completion_messages._unpack_app", return_value=fake):
|
||||
r = _client().post(
|
||||
"/openapi/v1/apps/app1/completion-messages",
|
||||
json={"inputs": {}, "query": "hi"},
|
||||
)
|
||||
assert r.status_code in (400, 403)
|
||||
@ -1,53 +0,0 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
from flask import Flask
|
||||
from flask_restx import Api
|
||||
|
||||
|
||||
def _client():
|
||||
from controllers.openapi import (
|
||||
openapi_ns,
|
||||
workflow_run, # noqa: F401
|
||||
)
|
||||
|
||||
app = Flask(__name__)
|
||||
api = Api(app)
|
||||
api.add_namespace(openapi_ns, path="/openapi/v1")
|
||||
return app.test_client()
|
||||
|
||||
|
||||
@patch("controllers.openapi.workflow_run.AppGenerateService")
|
||||
def test_workflow_run_returns_response_model(svc, bypass_pipeline):
|
||||
svc.generate.return_value = (
|
||||
{
|
||||
"workflow_run_id": "wr1",
|
||||
"task_id": "tk",
|
||||
"data": {
|
||||
"id": "wr1",
|
||||
"workflow_id": "wf1",
|
||||
"status": "succeeded",
|
||||
"outputs": {"result": "ok"},
|
||||
"elapsed_time": 1.0,
|
||||
},
|
||||
},
|
||||
200,
|
||||
)
|
||||
fake = SimpleNamespace(mode="workflow", id="app1", tenant_id="t1")
|
||||
with (
|
||||
patch("controllers.openapi.workflow_run._unpack_app", return_value=fake),
|
||||
patch("controllers.openapi.workflow_run._unpack_caller", return_value=SimpleNamespace()),
|
||||
):
|
||||
r = _client().post("/openapi/v1/apps/app1/workflows/run", json={"inputs": {"x": 1}})
|
||||
assert r.status_code == 200
|
||||
body = r.get_json()
|
||||
assert body["workflow_run_id"] == "wr1"
|
||||
assert body["data"]["status"] == "succeeded"
|
||||
assert svc.generate.call_args.kwargs["invoke_from"].value == "openapi"
|
||||
|
||||
|
||||
def test_workflow_run_rejects_non_workflow(bypass_pipeline):
|
||||
fake = SimpleNamespace(mode="chat")
|
||||
with patch("controllers.openapi.workflow_run._unpack_app", return_value=fake):
|
||||
r = _client().post("/openapi/v1/apps/app1/workflows/run", json={"inputs": {}})
|
||||
assert r.status_code in (400, 403)
|
||||
Loading…
Reference in New Issue
Block a user