mirror of
https://github.com/langgenius/dify.git
synced 2026-05-09 12:59:18 +08:00
fix(swagger): add util to convert BaseModel to schema for query params (#35959)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
5ebeb34feb
commit
d5ad6aedc0
@ -6,7 +6,9 @@ These helpers keep that translation centralized so models registered through
|
||||
`register_schema_models` emit resolvable Swagger 2.0 references.
|
||||
"""
|
||||
|
||||
from collections.abc import Mapping
|
||||
from enum import StrEnum
|
||||
from typing import Any, NotRequired, TypedDict
|
||||
|
||||
from flask_restx import Namespace
|
||||
from pydantic import BaseModel, TypeAdapter
|
||||
@ -14,6 +16,26 @@ from pydantic import BaseModel, TypeAdapter
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
QueryParamDoc = TypedDict(
|
||||
"QueryParamDoc",
|
||||
{
|
||||
"in": NotRequired[str],
|
||||
"type": NotRequired[str],
|
||||
"items": NotRequired[dict[str, object]],
|
||||
"required": NotRequired[bool],
|
||||
"description": NotRequired[str],
|
||||
"enum": NotRequired[list[object]],
|
||||
"default": NotRequired[object],
|
||||
"minimum": NotRequired[int | float],
|
||||
"maximum": NotRequired[int | float],
|
||||
"minLength": NotRequired[int],
|
||||
"maxLength": NotRequired[int],
|
||||
"minItems": NotRequired[int],
|
||||
"maxItems": NotRequired[int],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def _register_json_schema(namespace: Namespace, name: str, schema: dict) -> None:
|
||||
"""Register a JSON schema and promote any nested Pydantic `$defs`."""
|
||||
|
||||
@ -69,9 +91,104 @@ def register_enum_models(namespace: Namespace, *models: type[StrEnum]) -> None:
|
||||
)
|
||||
|
||||
|
||||
def query_params_from_model(model: type[BaseModel]) -> dict[str, QueryParamDoc]:
|
||||
"""Build Flask-RESTX query parameter docs from a flat Pydantic model.
|
||||
|
||||
`Namespace.expect()` treats Pydantic schema models as request bodies, so GET
|
||||
endpoints should keep runtime validation on the Pydantic model and feed this
|
||||
derived mapping to `Namespace.doc(params=...)` for Swagger documentation.
|
||||
"""
|
||||
|
||||
schema = model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
properties = schema.get("properties", {})
|
||||
if not isinstance(properties, Mapping):
|
||||
return {}
|
||||
|
||||
required = schema.get("required", [])
|
||||
required_names = set(required) if isinstance(required, list) else set()
|
||||
|
||||
params: dict[str, QueryParamDoc] = {}
|
||||
for name, property_schema in properties.items():
|
||||
if not isinstance(name, str) or not isinstance(property_schema, Mapping):
|
||||
continue
|
||||
|
||||
params[name] = _query_param_from_property(property_schema, required=name in required_names)
|
||||
|
||||
return params
|
||||
|
||||
|
||||
def _query_param_from_property(property_schema: Mapping[str, Any], *, required: bool) -> QueryParamDoc:
|
||||
param_schema = _nullable_property_schema(property_schema)
|
||||
param_doc: QueryParamDoc = {"in": "query", "required": required}
|
||||
|
||||
description = param_schema.get("description")
|
||||
if isinstance(description, str):
|
||||
param_doc["description"] = description
|
||||
|
||||
schema_type = param_schema.get("type")
|
||||
if isinstance(schema_type, str) and schema_type in {"array", "boolean", "integer", "number", "string"}:
|
||||
param_doc["type"] = schema_type
|
||||
if schema_type == "array":
|
||||
items = param_schema.get("items")
|
||||
if isinstance(items, Mapping):
|
||||
item_type = items.get("type")
|
||||
if isinstance(item_type, str):
|
||||
param_doc["items"] = {"type": item_type}
|
||||
|
||||
enum = param_schema.get("enum")
|
||||
if isinstance(enum, list):
|
||||
param_doc["enum"] = enum
|
||||
|
||||
default = param_schema.get("default")
|
||||
if default is not None:
|
||||
param_doc["default"] = default
|
||||
|
||||
minimum = param_schema.get("minimum")
|
||||
if isinstance(minimum, int | float):
|
||||
param_doc["minimum"] = minimum
|
||||
|
||||
maximum = param_schema.get("maximum")
|
||||
if isinstance(maximum, int | float):
|
||||
param_doc["maximum"] = maximum
|
||||
|
||||
min_length = param_schema.get("minLength")
|
||||
if isinstance(min_length, int):
|
||||
param_doc["minLength"] = min_length
|
||||
|
||||
max_length = param_schema.get("maxLength")
|
||||
if isinstance(max_length, int):
|
||||
param_doc["maxLength"] = max_length
|
||||
|
||||
min_items = param_schema.get("minItems")
|
||||
if isinstance(min_items, int):
|
||||
param_doc["minItems"] = min_items
|
||||
|
||||
max_items = param_schema.get("maxItems")
|
||||
if isinstance(max_items, int):
|
||||
param_doc["maxItems"] = max_items
|
||||
|
||||
return param_doc
|
||||
|
||||
|
||||
def _nullable_property_schema(property_schema: Mapping[str, Any]) -> Mapping[str, Any]:
|
||||
any_of = property_schema.get("anyOf")
|
||||
if not isinstance(any_of, list):
|
||||
return property_schema
|
||||
|
||||
non_null_candidates = [
|
||||
candidate for candidate in any_of if isinstance(candidate, Mapping) and candidate.get("type") != "null"
|
||||
]
|
||||
|
||||
if len(non_null_candidates) == 1:
|
||||
return {**property_schema, **non_null_candidates[0]}
|
||||
|
||||
return property_schema
|
||||
|
||||
|
||||
__all__ = [
|
||||
"DEFAULT_REF_TEMPLATE_SWAGGER_2_0",
|
||||
"get_or_create_model",
|
||||
"query_params_from_model",
|
||||
"register_enum_models",
|
||||
"register_schema_model",
|
||||
"register_schema_models",
|
||||
|
||||
@ -5,7 +5,7 @@ from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, computed_field, field_validator
|
||||
|
||||
from constants.languages import languages
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.common.schema import query_params_from_model, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from fields.base import ResponseModel
|
||||
@ -15,7 +15,7 @@ from services.recommended_app_service import RecommendedAppService
|
||||
|
||||
|
||||
class RecommendedAppsQuery(BaseModel):
|
||||
language: str | None = Field(default=None)
|
||||
language: str | None = Field(default=None, description="Language code for recommended app localization")
|
||||
|
||||
|
||||
class RecommendedAppInfoResponse(ResponseModel):
|
||||
@ -74,7 +74,7 @@ register_schema_models(
|
||||
|
||||
@console_ns.route("/explore/apps")
|
||||
class RecommendedAppListApi(Resource):
|
||||
@console_ns.expect(console_ns.models[RecommendedAppsQuery.__name__])
|
||||
@console_ns.doc(params=query_params_from_model(RecommendedAppsQuery))
|
||||
@console_ns.response(200, "Success", console_ns.models[RecommendedAppListResponse.__name__])
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
||||
@ -5507,7 +5507,7 @@ Delete an API key for a dataset
|
||||
|
||||
| Name | Located in | Description | Required | Schema |
|
||||
| ---- | ---------- | ----------- | -------- | ------ |
|
||||
| payload | body | | Yes | [RecommendedAppsQuery](#recommendedappsquery) |
|
||||
| language | query | Language code for recommended app localization | No | string |
|
||||
|
||||
##### Responses
|
||||
|
||||
@ -13289,7 +13289,7 @@ Default value types for form inputs.
|
||||
|
||||
| Name | Type | Description | Required |
|
||||
| ---- | ---- | ----------- | -------- |
|
||||
| language | | | No |
|
||||
| language | | Language code for recommended app localization | No |
|
||||
|
||||
#### RelatedAppList
|
||||
|
||||
|
||||
@ -1,10 +1,11 @@
|
||||
import sys
|
||||
from enum import StrEnum
|
||||
from typing import Literal
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask_restx import Namespace
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class UserModel(BaseModel):
|
||||
@ -25,6 +26,27 @@ class ParentModel(BaseModel):
|
||||
child: ChildModel
|
||||
|
||||
|
||||
class StatusEnum(StrEnum):
|
||||
ACTIVE = "active"
|
||||
INACTIVE = "inactive"
|
||||
|
||||
|
||||
class PriorityEnum(StrEnum):
|
||||
HIGH = "high"
|
||||
LOW = "low"
|
||||
|
||||
|
||||
class QueryModel(BaseModel):
|
||||
model_config = ConfigDict(populate_by_name=True)
|
||||
|
||||
page: int = Field(default=1, ge=1, le=100, description="Page number")
|
||||
keyword: str | None = Field(default=None, min_length=1, max_length=50, description="Search keyword")
|
||||
status: Literal["active", "inactive"] | None = Field(default=None, description="Status filter")
|
||||
app_id: str = Field(..., alias="appId", description="Application ID")
|
||||
tag_ids: list[str] = Field(default_factory=list, min_length=1, max_length=3, description="Tag IDs")
|
||||
ambiguous: int | str | None = Field(default=None, description="Ambiguous query parameter")
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_console_ns():
|
||||
"""Mock the console_ns to avoid circular imports during test collection."""
|
||||
@ -124,16 +146,6 @@ def test_register_schema_models_calls_register_schema_model(monkeypatch: pytest.
|
||||
]
|
||||
|
||||
|
||||
class StatusEnum(StrEnum):
|
||||
ACTIVE = "active"
|
||||
INACTIVE = "inactive"
|
||||
|
||||
|
||||
class PriorityEnum(StrEnum):
|
||||
HIGH = "high"
|
||||
LOW = "low"
|
||||
|
||||
|
||||
def test_get_or_create_model_returns_existing_model(mock_console_ns):
|
||||
from controllers.common.schema import get_or_create_model
|
||||
|
||||
@ -211,3 +223,54 @@ def test_register_enum_models_uses_correct_ref_template():
|
||||
|
||||
# Verify the schema contains enum values
|
||||
assert "enum" in schema or "anyOf" in schema
|
||||
|
||||
|
||||
def test_query_params_from_model_builds_flask_restx_doc_params():
|
||||
from controllers.common.schema import query_params_from_model
|
||||
|
||||
params = query_params_from_model(QueryModel)
|
||||
|
||||
assert params["page"] == {
|
||||
"in": "query",
|
||||
"required": False,
|
||||
"description": "Page number",
|
||||
"type": "integer",
|
||||
"default": 1,
|
||||
"minimum": 1,
|
||||
"maximum": 100,
|
||||
}
|
||||
assert params["keyword"] == {
|
||||
"in": "query",
|
||||
"required": False,
|
||||
"description": "Search keyword",
|
||||
"type": "string",
|
||||
"minLength": 1,
|
||||
"maxLength": 50,
|
||||
}
|
||||
assert params["status"] == {
|
||||
"in": "query",
|
||||
"required": False,
|
||||
"description": "Status filter",
|
||||
"type": "string",
|
||||
"enum": ["active", "inactive"],
|
||||
}
|
||||
assert params["appId"] == {
|
||||
"in": "query",
|
||||
"required": True,
|
||||
"description": "Application ID",
|
||||
"type": "string",
|
||||
}
|
||||
assert params["tag_ids"] == {
|
||||
"in": "query",
|
||||
"required": False,
|
||||
"description": "Tag IDs",
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"minItems": 1,
|
||||
"maxItems": 3,
|
||||
}
|
||||
assert params["ambiguous"] == {
|
||||
"in": "query",
|
||||
"required": False,
|
||||
"description": "Ambiguous query parameter",
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user