mirror of https://github.com/langgenius/dify.git
refactor: remove all reqparser (#29289)
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Stephen Zhou <38493346+hyoban@users.noreply.github.com>
This commit is contained in:
parent
b8cb5f5ea2
commit
7828508b30
|
|
@ -53,6 +53,7 @@ select = [
|
|||
"S301", # suspicious-pickle-usage, disallow use of `pickle` and its wrappers.
|
||||
"S302", # suspicious-marshal-usage, disallow use of `marshal` module
|
||||
"S311", # suspicious-non-cryptographic-random-usage,
|
||||
"TID", # flake8-tidy-imports
|
||||
|
||||
]
|
||||
|
||||
|
|
@ -88,6 +89,7 @@ ignore = [
|
|||
"SIM113", # enumerate-for-loop
|
||||
"SIM117", # multiple-with-statements
|
||||
"SIM210", # if-expr-with-true-false
|
||||
"TID252", # allow relative imports from parent modules
|
||||
]
|
||||
|
||||
[lint.per-file-ignores]
|
||||
|
|
@ -109,10 +111,20 @@ ignore = [
|
|||
"S110", # allow ignoring exceptions in tests code (currently)
|
||||
|
||||
]
|
||||
"controllers/console/explore/trial.py" = ["TID251"]
|
||||
"controllers/console/human_input_form.py" = ["TID251"]
|
||||
"controllers/web/human_input_form.py" = ["TID251"]
|
||||
|
||||
[lint.pyflakes]
|
||||
allowed-unused-imports = [
|
||||
"_pytest.monkeypatch",
|
||||
"tests.integration_tests",
|
||||
"tests.unit_tests",
|
||||
]
|
||||
|
||||
[lint.flake8-tidy-imports]
|
||||
|
||||
[lint.flake8-tidy-imports.banned-api."flask_restx.reqparse"]
|
||||
msg = "Use Pydantic payload/query models instead of reqparse."
|
||||
|
||||
[lint.flake8-tidy-imports.banned-api."flask_restx.reqparse.RequestParser"]
|
||||
msg = "Use Pydantic payload/query models instead of reqparse."
|
||||
|
|
|
|||
|
|
@ -1,10 +1,9 @@
|
|||
import json
|
||||
import logging
|
||||
from typing import Any, Literal, cast
|
||||
from uuid import UUID
|
||||
|
||||
from flask import abort, request
|
||||
from flask_restx import Resource, marshal_with, reqparse # type: ignore
|
||||
from flask_restx import Resource, marshal_with # type: ignore
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
|
@ -38,7 +37,7 @@ from core.model_runtime.utils.encoders import jsonable_encoder
|
|||
from extensions.ext_database import db
|
||||
from factories import variable_factory
|
||||
from libs import helper
|
||||
from libs.helper import TimestampField
|
||||
from libs.helper import TimestampField, UUIDStrOrEmpty
|
||||
from libs.login import current_account_with_tenant, current_user, login_required
|
||||
from models import Account
|
||||
from models.dataset import Pipeline
|
||||
|
|
@ -110,7 +109,7 @@ class NodeIdQuery(BaseModel):
|
|||
|
||||
|
||||
class WorkflowRunQuery(BaseModel):
|
||||
last_id: UUID | None = None
|
||||
last_id: UUIDStrOrEmpty | None = None
|
||||
limit: int = Field(default=20, ge=1, le=100)
|
||||
|
||||
|
||||
|
|
@ -121,6 +120,10 @@ class DatasourceVariablesPayload(BaseModel):
|
|||
start_node_title: str
|
||||
|
||||
|
||||
class RagPipelineRecommendedPluginQuery(BaseModel):
|
||||
type: str = "all"
|
||||
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
DraftWorkflowSyncPayload,
|
||||
|
|
@ -135,6 +138,7 @@ register_schema_models(
|
|||
NodeIdQuery,
|
||||
WorkflowRunQuery,
|
||||
DatasourceVariablesPayload,
|
||||
RagPipelineRecommendedPluginQuery,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -975,11 +979,8 @@ class RagPipelineRecommendedPluginApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("type", type=str, location="args", required=False, default="all")
|
||||
args = parser.parse_args()
|
||||
type = args["type"]
|
||||
query = RagPipelineRecommendedPluginQuery.model_validate(request.args.to_dict())
|
||||
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
recommended_plugins = rag_pipeline_service.get_recommended_plugins(type)
|
||||
recommended_plugins = rag_pipeline_service.get_recommended_plugins(query.type)
|
||||
return recommended_plugins
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load Diff
|
|
@ -30,6 +30,7 @@ from core.errors.error import (
|
|||
from core.helper.trace_id_helper import get_external_trace_id
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from libs import helper
|
||||
from libs.helper import UUIDStrOrEmpty
|
||||
from models.model import App, AppMode, EndUser
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.app_task_service import AppTaskService
|
||||
|
|
@ -52,7 +53,7 @@ class ChatRequestPayload(BaseModel):
|
|||
query: str
|
||||
files: list[dict[str, Any]] | None = None
|
||||
response_mode: Literal["blocking", "streaming"] | None = None
|
||||
conversation_id: str | None = Field(default=None, description="Conversation UUID")
|
||||
conversation_id: UUIDStrOrEmpty | None = Field(default=None, description="Conversation UUID")
|
||||
retriever_from: str = Field(default="dev")
|
||||
auto_generate_name: bool = Field(default=True, description="Auto generate conversation name")
|
||||
workflow_id: str | None = Field(default=None, description="Workflow ID for advanced chat")
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
from typing import Any, Literal
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
|
|
@ -23,12 +22,13 @@ from fields.conversation_variable_fields import (
|
|||
build_conversation_variable_infinite_scroll_pagination_model,
|
||||
build_conversation_variable_model,
|
||||
)
|
||||
from libs.helper import UUIDStrOrEmpty
|
||||
from models.model import App, AppMode, EndUser
|
||||
from services.conversation_service import ConversationService
|
||||
|
||||
|
||||
class ConversationListQuery(BaseModel):
|
||||
last_id: UUID | None = Field(default=None, description="Last conversation ID for pagination")
|
||||
last_id: UUIDStrOrEmpty | None = Field(default=None, description="Last conversation ID for pagination")
|
||||
limit: int = Field(default=20, ge=1, le=100, description="Number of conversations to return")
|
||||
sort_by: Literal["created_at", "-created_at", "updated_at", "-updated_at"] = Field(
|
||||
default="-updated_at", description="Sort order for conversations"
|
||||
|
|
@ -48,7 +48,7 @@ class ConversationRenamePayload(BaseModel):
|
|||
|
||||
|
||||
class ConversationVariablesQuery(BaseModel):
|
||||
last_id: UUID | None = Field(default=None, description="Last variable ID for pagination")
|
||||
last_id: UUIDStrOrEmpty | None = Field(default=None, description="Last variable ID for pagination")
|
||||
limit: int = Field(default=20, ge=1, le=100, description="Number of variables to return")
|
||||
variable_name: str | None = Field(
|
||||
default=None, description="Filter variables by name", min_length=1, max_length=255
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
import logging
|
||||
from typing import Literal
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
|
|
@ -15,6 +14,7 @@ from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate
|
|||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from fields.conversation_fields import ResultResponse
|
||||
from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem
|
||||
from libs.helper import UUIDStrOrEmpty
|
||||
from models.model import App, AppMode, EndUser
|
||||
from services.errors.message import (
|
||||
FirstMessageNotExistsError,
|
||||
|
|
@ -27,8 +27,8 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class MessageListQuery(BaseModel):
|
||||
conversation_id: UUID
|
||||
first_id: UUID | None = None
|
||||
conversation_id: UUIDStrOrEmpty
|
||||
first_id: UUIDStrOrEmpty | None = None
|
||||
limit: int = Field(default=20, ge=1, le=100, description="Number of messages to return")
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,10 @@
|
|||
from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase
|
||||
from controllers.common.schema import register_schema_model
|
||||
from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase, HitTestingPayload
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_rate_limit_check
|
||||
|
||||
register_schema_model(service_api_ns, HitTestingPayload)
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/hit-testing", "/datasets/<uuid:dataset_id>/retrieve")
|
||||
class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase):
|
||||
|
|
@ -15,6 +18,7 @@ class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase):
|
|||
404: "Dataset not found",
|
||||
}
|
||||
)
|
||||
@service_api_ns.expect(service_api_ns.models[HitTestingPayload.__name__])
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def post(self, tenant_id, dataset_id):
|
||||
"""Perform hit testing on a dataset.
|
||||
|
|
|
|||
|
|
@ -7,11 +7,6 @@ from core.workflow.nodes.base.entities import OutputVariableEntity
|
|||
|
||||
|
||||
class WorkflowToolConfigurationUtils:
|
||||
@classmethod
|
||||
def check_parameter_configurations(cls, configurations: list[Mapping[str, Any]]):
|
||||
for configuration in configurations:
|
||||
WorkflowToolParameterConfiguration.model_validate(configuration)
|
||||
|
||||
@classmethod
|
||||
def get_workflow_graph_variables(cls, graph: Mapping[str, Any]) -> Sequence[VariableEntity]:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1,8 +1,6 @@
|
|||
import json
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import or_, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
|
@ -10,8 +8,8 @@ from sqlalchemy.orm import Session
|
|||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
|
||||
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
|
||||
from core.tools.tool_label_manager import ToolLabelManager
|
||||
from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils
|
||||
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
|
||||
from core.tools.workflow_as_tool.tool import WorkflowTool
|
||||
from extensions.ext_database import db
|
||||
|
|
@ -38,12 +36,10 @@ class WorkflowToolManageService:
|
|||
label: str,
|
||||
icon: dict,
|
||||
description: str,
|
||||
parameters: list[Mapping[str, Any]],
|
||||
parameters: list[WorkflowToolParameterConfiguration],
|
||||
privacy_policy: str = "",
|
||||
labels: list[str] | None = None,
|
||||
):
|
||||
WorkflowToolConfigurationUtils.check_parameter_configurations(parameters)
|
||||
|
||||
# check if the name is unique
|
||||
existing_workflow_tool_provider = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
|
|
@ -75,7 +71,7 @@ class WorkflowToolManageService:
|
|||
label=label,
|
||||
icon=json.dumps(icon),
|
||||
description=description,
|
||||
parameter_configuration=json.dumps(parameters),
|
||||
parameter_configuration=json.dumps([p.model_dump() for p in parameters]),
|
||||
privacy_policy=privacy_policy,
|
||||
version=workflow.version,
|
||||
)
|
||||
|
|
@ -104,7 +100,7 @@ class WorkflowToolManageService:
|
|||
label: str,
|
||||
icon: dict,
|
||||
description: str,
|
||||
parameters: list[Mapping[str, Any]],
|
||||
parameters: list[WorkflowToolParameterConfiguration],
|
||||
privacy_policy: str = "",
|
||||
labels: list[str] | None = None,
|
||||
):
|
||||
|
|
@ -122,8 +118,6 @@ class WorkflowToolManageService:
|
|||
:param labels: labels
|
||||
:return: the updated tool
|
||||
"""
|
||||
WorkflowToolConfigurationUtils.check_parameter_configurations(parameters)
|
||||
|
||||
# check if the name is unique
|
||||
existing_workflow_tool_provider = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
|
|
@ -162,7 +156,7 @@ class WorkflowToolManageService:
|
|||
workflow_tool_provider.label = label
|
||||
workflow_tool_provider.icon = json.dumps(icon)
|
||||
workflow_tool_provider.description = description
|
||||
workflow_tool_provider.parameter_configuration = json.dumps(parameters)
|
||||
workflow_tool_provider.parameter_configuration = json.dumps([p.model_dump() for p in parameters])
|
||||
workflow_tool_provider.privacy_policy = privacy_policy
|
||||
workflow_tool_provider.version = workflow.version
|
||||
workflow_tool_provider.updated_at = datetime.now()
|
||||
|
|
|
|||
|
|
@ -3,7 +3,9 @@ from unittest.mock import patch
|
|||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
from pydantic import ValidationError
|
||||
|
||||
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
|
||||
from models.tools import WorkflowToolProvider
|
||||
from models.workflow import Workflow as WorkflowModel
|
||||
from services.account_service import AccountService, TenantService
|
||||
|
|
@ -130,20 +132,24 @@ class TestWorkflowToolManageService:
|
|||
def _create_test_workflow_tool_parameters(self):
|
||||
"""Helper method to create valid workflow tool parameters."""
|
||||
return [
|
||||
{
|
||||
"name": "input_text",
|
||||
"description": "Input text for processing",
|
||||
"form": "form",
|
||||
"type": "string",
|
||||
"required": True,
|
||||
},
|
||||
{
|
||||
"name": "output_format",
|
||||
"description": "Output format specification",
|
||||
"form": "form",
|
||||
"type": "select",
|
||||
"required": False,
|
||||
},
|
||||
WorkflowToolParameterConfiguration.model_validate(
|
||||
{
|
||||
"name": "input_text",
|
||||
"description": "Input text for processing",
|
||||
"form": "form",
|
||||
"type": "string",
|
||||
"required": True,
|
||||
}
|
||||
),
|
||||
WorkflowToolParameterConfiguration.model_validate(
|
||||
{
|
||||
"name": "output_format",
|
||||
"description": "Output format specification",
|
||||
"form": "form",
|
||||
"type": "select",
|
||||
"required": False,
|
||||
}
|
||||
),
|
||||
]
|
||||
|
||||
def test_create_workflow_tool_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
|
|
@ -208,7 +214,7 @@ class TestWorkflowToolManageService:
|
|||
assert created_tool_provider.label == tool_label
|
||||
assert created_tool_provider.icon == json.dumps(tool_icon)
|
||||
assert created_tool_provider.description == tool_description
|
||||
assert created_tool_provider.parameter_configuration == json.dumps(tool_parameters)
|
||||
assert created_tool_provider.parameter_configuration == json.dumps([p.model_dump() for p in tool_parameters])
|
||||
assert created_tool_provider.privacy_policy == tool_privacy_policy
|
||||
assert created_tool_provider.version == workflow.version
|
||||
assert created_tool_provider.user_id == account.id
|
||||
|
|
@ -353,18 +359,9 @@ class TestWorkflowToolManageService:
|
|||
app, account, workflow = self._create_test_app_and_account(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
# Setup invalid workflow tool parameters (missing required fields)
|
||||
invalid_parameters = [
|
||||
{
|
||||
"name": "input_text",
|
||||
# Missing description and form fields
|
||||
"type": "string",
|
||||
"required": True,
|
||||
}
|
||||
]
|
||||
# Attempt to create workflow tool with invalid parameters
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
# Setup invalid workflow tool parameters (missing required fields)
|
||||
WorkflowToolManageService.create_workflow_tool(
|
||||
user_id=account.id,
|
||||
tenant_id=account.current_tenant.id,
|
||||
|
|
@ -373,7 +370,16 @@ class TestWorkflowToolManageService:
|
|||
label=fake.word(),
|
||||
icon={"type": "emoji", "emoji": "🔧"},
|
||||
description=fake.text(max_nb_chars=200),
|
||||
parameters=invalid_parameters,
|
||||
parameters=[
|
||||
WorkflowToolParameterConfiguration.model_validate(
|
||||
{
|
||||
"name": "input_text",
|
||||
# Missing description and form fields
|
||||
"type": "string",
|
||||
"required": True,
|
||||
}
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
# Verify error message contains validation error
|
||||
|
|
@ -579,11 +585,12 @@ class TestWorkflowToolManageService:
|
|||
|
||||
# Verify database state was updated
|
||||
db.session.refresh(created_tool)
|
||||
assert created_tool is not None
|
||||
assert created_tool.name == updated_tool_name
|
||||
assert created_tool.label == updated_tool_label
|
||||
assert created_tool.icon == json.dumps(updated_tool_icon)
|
||||
assert created_tool.description == updated_tool_description
|
||||
assert created_tool.parameter_configuration == json.dumps(updated_tool_parameters)
|
||||
assert created_tool.parameter_configuration == json.dumps([p.model_dump() for p in updated_tool_parameters])
|
||||
assert created_tool.privacy_policy == updated_tool_privacy_policy
|
||||
assert created_tool.version == workflow.version
|
||||
assert created_tool.updated_at is not None
|
||||
|
|
@ -750,13 +757,15 @@ class TestWorkflowToolManageService:
|
|||
|
||||
# Setup workflow tool parameters with FILE type
|
||||
file_parameters = [
|
||||
{
|
||||
"name": "document",
|
||||
"description": "Upload a document",
|
||||
"form": "form",
|
||||
"type": "file",
|
||||
"required": False,
|
||||
}
|
||||
WorkflowToolParameterConfiguration.model_validate(
|
||||
{
|
||||
"name": "document",
|
||||
"description": "Upload a document",
|
||||
"form": "form",
|
||||
"type": "file",
|
||||
"required": False,
|
||||
}
|
||||
)
|
||||
]
|
||||
|
||||
# Execute the method under test
|
||||
|
|
@ -823,13 +832,15 @@ class TestWorkflowToolManageService:
|
|||
|
||||
# Setup workflow tool parameters with FILES type
|
||||
files_parameters = [
|
||||
{
|
||||
"name": "documents",
|
||||
"description": "Upload multiple documents",
|
||||
"form": "form",
|
||||
"type": "files",
|
||||
"required": False,
|
||||
}
|
||||
WorkflowToolParameterConfiguration.model_validate(
|
||||
{
|
||||
"name": "documents",
|
||||
"description": "Upload multiple documents",
|
||||
"form": "form",
|
||||
"type": "files",
|
||||
"required": False,
|
||||
}
|
||||
)
|
||||
]
|
||||
|
||||
# Execute the method under test
|
||||
|
|
|
|||
Loading…
Reference in New Issue