This commit is contained in:
Asuka Minato 2025-12-29 15:43:28 +08:00 committed by GitHub
commit 9e19a05368
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
20 changed files with 612 additions and 577 deletions

View File

@ -49,6 +49,7 @@ select = [
"S301", # suspicious-pickle-usage, disallow use of `pickle` and its wrappers.
"S302", # suspicious-marshal-usage, disallow use of `marshal` module
"S311", # suspicious-non-cryptographic-random-usage,
"TID", # flake8-tidy-imports
]
@ -84,6 +85,7 @@ ignore = [
"SIM113", # enumerate-for-loop
"SIM117", # multiple-with-statements
"SIM210", # if-expr-with-true-false
"TID252", # allow relative imports from parent modules
]
[lint.per-file-ignores]
@ -112,3 +114,11 @@ allowed-unused-imports = [
"tests.integration_tests",
"tests.unit_tests",
]
[lint.flake8-tidy-imports]
[lint.flake8-tidy-imports.banned-api."flask_restx.reqparse"]
msg = "Use Pydantic payload/query models instead of reqparse."
[lint.flake8-tidy-imports.banned-api."flask_restx.reqparse.RequestParser"]
msg = "Use Pydantic payload/query models instead of reqparse."

View File

@ -1,7 +1,7 @@
import logging
from typing import Any
from flask_restx import marshal, reqparse
from flask_restx import marshal
from pydantic import BaseModel, Field
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
@ -56,15 +56,10 @@ class DatasetsHitTestingBase:
HitTestingService.hit_testing_args_check(args)
@staticmethod
def parse_args():
parser = (
reqparse.RequestParser()
.add_argument("query", type=str, required=False, location="json")
.add_argument("attachment_ids", type=list, required=False, location="json")
.add_argument("retrieval_model", type=dict, required=False, location="json")
.add_argument("external_retrieval_model", type=dict, required=False, location="json")
)
return parser.parse_args()
def parse_args(payload: dict[str, Any] | None = None) -> dict[str, Any]:
"""Validate and return hit-testing arguments from an incoming payload."""
hit_testing_payload = HitTestingPayload.model_validate(payload or {})
return hit_testing_payload.model_dump(exclude_none=True)
@staticmethod
def perform_hit_testing(dataset, args):

View File

@ -1,10 +1,9 @@
import json
import logging
from typing import Any, Literal, cast
from uuid import UUID
from flask import abort, request
from flask_restx import Resource, marshal_with, reqparse # type: ignore
from flask_restx import Resource, marshal_with # type: ignore
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
@ -38,7 +37,7 @@ from fields.workflow_run_fields import (
workflow_run_pagination_fields,
)
from libs import helper
from libs.helper import TimestampField
from libs.helper import TimestampField, UUIDStrOrEmpty
from libs.login import current_account_with_tenant, current_user, login_required
from models import Account
from models.dataset import Pipeline
@ -110,7 +109,7 @@ class NodeIdQuery(BaseModel):
class WorkflowRunQuery(BaseModel):
last_id: UUID | None = None
last_id: UUIDStrOrEmpty | None = None
limit: int = Field(default=20, ge=1, le=100)
@ -121,6 +120,10 @@ class DatasourceVariablesPayload(BaseModel):
start_node_title: str
class RagPipelineRecommendedPluginQuery(BaseModel):
type: str = "all"
register_schema_models(
console_ns,
DraftWorkflowSyncPayload,
@ -135,6 +138,7 @@ register_schema_models(
NodeIdQuery,
WorkflowRunQuery,
DatasourceVariablesPayload,
RagPipelineRecommendedPluginQuery,
)
@ -975,11 +979,8 @@ class RagPipelineRecommendedPluginApi(Resource):
@login_required
@account_initialization_required
def get(self):
parser = reqparse.RequestParser()
parser.add_argument("type", type=str, location="args", required=False, default="all")
args = parser.parse_args()
type = args["type"]
query = RagPipelineRecommendedPluginQuery.model_validate(request.args.to_dict())
rag_pipeline_service = RagPipelineService()
recommended_plugins = rag_pipeline_service.get_recommended_plugins(type)
recommended_plugins = rag_pipeline_service.get_recommended_plugins(query.type)
return recommended_plugins

View File

@ -11,6 +11,20 @@ from libs.login import current_account_with_tenant, login_required
from models import TenantAccountRole
from services.model_load_balancing_service import ModelLoadBalancingService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class LoadBalancingCredentialPayload(BaseModel):
model: str
model_type: ModelType
credentials: dict
console_ns.schema_model(
LoadBalancingCredentialPayload.__name__,
LoadBalancingCredentialPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
class LoadBalancingCredentialPayload(BaseModel):
model: str

File diff suppressed because it is too large Load Diff

View File

@ -1,15 +1,14 @@
import logging
from collections.abc import Mapping
from typing import Any
from flask import make_response, redirect, request
from flask_restx import Resource, reqparse
from pydantic import BaseModel, Field
from flask_restx import Resource
from pydantic import BaseModel
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, Forbidden
from configs import dify_config
from constants import HIDDEN_VALUE, UNKNOWN_VALUE
from controllers.common.schema import register_schema_models
from controllers.web.error import NotFoundError
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin_daemon import CredentialType
@ -36,29 +35,32 @@ from ..wraps import (
logger = logging.getLogger(__name__)
class TriggerSubscriptionUpdateRequest(BaseModel):
"""Request payload for updating a trigger subscription"""
name: str | None = Field(default=None, description="The name for the subscription")
credentials: Mapping[str, Any] | None = Field(default=None, description="The credentials for the subscription")
parameters: Mapping[str, Any] | None = Field(default=None, description="The parameters for the subscription")
properties: Mapping[str, Any] | None = Field(default=None, description="The properties for the subscription")
class TriggerSubscriptionBuilderCreatePayload(BaseModel):
credential_type: str | None = None
class TriggerSubscriptionVerifyRequest(BaseModel):
"""Request payload for verifying subscription credentials."""
credentials: Mapping[str, Any] = Field(description="The credentials to verify")
class TriggerSubscriptionBuilderVerifyPayload(BaseModel):
credentials: dict[str, Any] | None = None
console_ns.schema_model(
TriggerSubscriptionUpdateRequest.__name__,
TriggerSubscriptionUpdateRequest.model_json_schema(ref_template="#/definitions/{model}"),
)
class TriggerSubscriptionBuilderUpdatePayload(BaseModel):
name: str | None = None
parameters: dict[str, Any] | None = None
properties: dict[str, Any] | None = None
credentials: dict[str, Any] | None = None
console_ns.schema_model(
TriggerSubscriptionVerifyRequest.__name__,
TriggerSubscriptionVerifyRequest.model_json_schema(ref_template="#/definitions/{model}"),
class TriggerOAuthClientPayload(BaseModel):
client_params: dict[str, Any] | None = None
enabled: bool | None = None
register_schema_models(
console_ns,
TriggerSubscriptionBuilderCreatePayload,
TriggerSubscriptionBuilderVerifyPayload,
TriggerSubscriptionBuilderUpdatePayload,
TriggerOAuthClientPayload,
)
@ -127,16 +129,11 @@ class TriggerSubscriptionListApi(Resource):
raise
parser = reqparse.RequestParser().add_argument(
"credential_type", type=str, required=False, nullable=True, location="json"
)
@console_ns.route(
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/create",
)
class TriggerSubscriptionBuilderCreateApi(Resource):
@console_ns.expect(parser)
@console_ns.expect(console_ns.models[TriggerSubscriptionBuilderCreatePayload.__name__])
@setup_required
@login_required
@edit_permission_required
@ -146,10 +143,10 @@ class TriggerSubscriptionBuilderCreateApi(Resource):
user = current_user
assert user.current_tenant_id is not None
args = parser.parse_args()
payload = TriggerSubscriptionBuilderCreatePayload.model_validate(console_ns.payload or {})
try:
credential_type = CredentialType.of(args.get("credential_type") or CredentialType.UNAUTHORIZED.value)
credential_type = CredentialType.of(payload.credential_type or CredentialType.UNAUTHORIZED.value)
subscription_builder = TriggerSubscriptionBuilderService.create_trigger_subscription_builder(
tenant_id=user.current_tenant_id,
user_id=user.id,
@ -177,18 +174,11 @@ class TriggerSubscriptionBuilderGetApi(Resource):
)
parser_api = (
reqparse.RequestParser()
# The credentials of the subscription builder
.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
)
@console_ns.route(
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/verify-and-update/<path:subscription_builder_id>",
)
class TriggerSubscriptionBuilderVerifyAndUpdateApi(Resource):
@console_ns.expect(parser_api)
class TriggerSubscriptionBuilderVerifyApi(Resource):
@console_ns.expect(console_ns.models[TriggerSubscriptionBuilderVerifyPayload.__name__])
@setup_required
@login_required
@edit_permission_required
@ -198,7 +188,7 @@ class TriggerSubscriptionBuilderVerifyAndUpdateApi(Resource):
user = current_user
assert user.current_tenant_id is not None
args = parser_api.parse_args()
payload = TriggerSubscriptionBuilderVerifyPayload.model_validate(console_ns.payload or {})
try:
# Use atomic update_and_verify to prevent race conditions
@ -208,7 +198,7 @@ class TriggerSubscriptionBuilderVerifyAndUpdateApi(Resource):
provider_id=TriggerProviderID(provider),
subscription_builder_id=subscription_builder_id,
subscription_builder_updater=SubscriptionBuilderUpdater(
credentials=args.get("credentials", None),
credentials=payload.credentials,
),
)
except Exception as e:
@ -216,24 +206,11 @@ class TriggerSubscriptionBuilderVerifyAndUpdateApi(Resource):
raise ValueError(str(e)) from e
parser_update_api = (
reqparse.RequestParser()
# The name of the subscription builder
.add_argument("name", type=str, required=False, nullable=True, location="json")
# The parameters of the subscription builder
.add_argument("parameters", type=dict, required=False, nullable=True, location="json")
# The properties of the subscription builder
.add_argument("properties", type=dict, required=False, nullable=True, location="json")
# The credentials of the subscription builder
.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
)
@console_ns.route(
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/update/<path:subscription_builder_id>",
)
class TriggerSubscriptionBuilderUpdateApi(Resource):
@console_ns.expect(parser_update_api)
@console_ns.expect(console_ns.models[TriggerSubscriptionBuilderUpdatePayload.__name__])
@setup_required
@login_required
@edit_permission_required
@ -244,7 +221,7 @@ class TriggerSubscriptionBuilderUpdateApi(Resource):
assert isinstance(user, Account)
assert user.current_tenant_id is not None
args = parser_update_api.parse_args()
payload = TriggerSubscriptionBuilderUpdatePayload.model_validate(console_ns.payload or {})
try:
return jsonable_encoder(
TriggerSubscriptionBuilderService.update_trigger_subscription_builder(
@ -252,10 +229,10 @@ class TriggerSubscriptionBuilderUpdateApi(Resource):
provider_id=TriggerProviderID(provider),
subscription_builder_id=subscription_builder_id,
subscription_builder_updater=SubscriptionBuilderUpdater(
name=args.get("name", None),
parameters=args.get("parameters", None),
properties=args.get("properties", None),
credentials=args.get("credentials", None),
name=payload.name,
parameters=payload.parameters,
properties=payload.properties,
credentials=payload.credentials,
),
)
)
@ -290,7 +267,7 @@ class TriggerSubscriptionBuilderLogsApi(Resource):
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/build/<path:subscription_builder_id>",
)
class TriggerSubscriptionBuilderBuildApi(Resource):
@console_ns.expect(parser_update_api)
@console_ns.expect(console_ns.models[TriggerSubscriptionBuilderUpdatePayload.__name__])
@setup_required
@login_required
@edit_permission_required
@ -299,7 +276,7 @@ class TriggerSubscriptionBuilderBuildApi(Resource):
"""Build a subscription instance for a trigger provider"""
user = current_user
assert user.current_tenant_id is not None
args = parser_update_api.parse_args()
payload = TriggerSubscriptionBuilderUpdatePayload.model_validate(console_ns.payload or {})
try:
# Use atomic update_and_build to prevent race conditions
TriggerSubscriptionBuilderService.update_and_build_builder(
@ -308,9 +285,9 @@ class TriggerSubscriptionBuilderBuildApi(Resource):
provider_id=TriggerProviderID(provider),
subscription_builder_id=subscription_builder_id,
subscription_builder_updater=SubscriptionBuilderUpdater(
name=args.get("name", None),
parameters=args.get("parameters", None),
properties=args.get("properties", None),
name=payload.name,
parameters=payload.parameters,
properties=payload.properties,
),
)
return 200
@ -581,13 +558,6 @@ class TriggerOAuthCallbackApi(Resource):
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
parser_oauth_client = (
reqparse.RequestParser()
.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
.add_argument("enabled", type=bool, required=False, nullable=True, location="json")
)
@console_ns.route("/workspaces/current/trigger-provider/<path:provider>/oauth/client")
class TriggerOAuthClientManageApi(Resource):
@setup_required
@ -635,7 +605,7 @@ class TriggerOAuthClientManageApi(Resource):
logger.exception("Error getting OAuth client", exc_info=e)
raise
@console_ns.expect(parser_oauth_client)
@console_ns.expect(console_ns.models[TriggerOAuthClientPayload.__name__])
@setup_required
@login_required
@is_admin_or_owner_required
@ -645,15 +615,15 @@ class TriggerOAuthClientManageApi(Resource):
user = current_user
assert user.current_tenant_id is not None
args = parser_oauth_client.parse_args()
payload = TriggerOAuthClientPayload.model_validate(console_ns.payload or {})
try:
provider_id = TriggerProviderID(provider)
return TriggerProviderService.save_custom_oauth_client_params(
tenant_id=user.current_tenant_id,
provider_id=provider_id,
client_params=args.get("client_params"),
enabled=args.get("enabled"),
client_params=payload.client_params,
enabled=payload.enabled,
)
except ValueError as e:

View File

@ -30,6 +30,7 @@ from core.errors.error import (
from core.helper.trace_id_helper import get_external_trace_id
from core.model_runtime.errors.invoke import InvokeError
from libs import helper
from libs.helper import UUIDStrOrEmpty
from models.model import App, AppMode, EndUser
from services.app_generate_service import AppGenerateService
from services.app_task_service import AppTaskService
@ -52,7 +53,7 @@ class ChatRequestPayload(BaseModel):
query: str
files: list[dict[str, Any]] | None = None
response_mode: Literal["blocking", "streaming"] | None = None
conversation_id: str | None = Field(default=None, description="Conversation UUID")
conversation_id: UUIDStrOrEmpty | None = Field(default=None, description="Conversation UUID")
retriever_from: str = Field(default="dev")
auto_generate_name: bool = Field(default=True, description="Auto generate conversation name")
workflow_id: str | None = Field(default=None, description="Workflow ID for advanced chat")

View File

@ -1,5 +1,4 @@
from typing import Any, Literal
from uuid import UUID
from flask import request
from flask_restx import Resource
@ -24,12 +23,13 @@ from fields.conversation_variable_fields import (
build_conversation_variable_infinite_scroll_pagination_model,
build_conversation_variable_model,
)
from libs.helper import UUIDStrOrEmpty
from models.model import App, AppMode, EndUser
from services.conversation_service import ConversationService
class ConversationListQuery(BaseModel):
last_id: UUID | None = Field(default=None, description="Last conversation ID for pagination")
last_id: UUIDStrOrEmpty | None = Field(default=None, description="Last conversation ID for pagination")
limit: int = Field(default=20, ge=1, le=100, description="Number of conversations to return")
sort_by: Literal["created_at", "-created_at", "updated_at", "-updated_at"] = Field(
default="-updated_at", description="Sort order for conversations"
@ -49,7 +49,7 @@ class ConversationRenamePayload(BaseModel):
class ConversationVariablesQuery(BaseModel):
last_id: UUID | None = Field(default=None, description="Last variable ID for pagination")
last_id: UUIDStrOrEmpty | None = Field(default=None, description="Last variable ID for pagination")
limit: int = Field(default=20, ge=1, le=100, description="Number of variables to return")
variable_name: str | None = Field(
default=None, description="Filter variables by name", min_length=1, max_length=255

View File

@ -1,7 +1,6 @@
import json
import logging
from typing import Literal
from uuid import UUID
from flask import request
from flask_restx import Namespace, Resource, fields
@ -17,7 +16,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from fields.conversation_fields import build_message_file_model
from fields.message_fields import build_agent_thought_model, build_feedback_model
from fields.raws import FilesContainedField
from libs.helper import TimestampField
from libs.helper import TimestampField, UUIDStrOrEmpty
from models.model import App, AppMode, EndUser
from services.errors.message import (
FirstMessageNotExistsError,
@ -30,8 +29,8 @@ logger = logging.getLogger(__name__)
class MessageListQuery(BaseModel):
conversation_id: UUID
first_id: UUID | None = None
conversation_id: UUIDStrOrEmpty
first_id: UUIDStrOrEmpty | None = None
limit: int = Field(default=20, ge=1, le=100, description="Number of messages to return")

View File

@ -1,7 +1,10 @@
from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase
from controllers.common.schema import register_schema_model
from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase, HitTestingPayload
from controllers.service_api import service_api_ns
from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_rate_limit_check
register_schema_model(service_api_ns, HitTestingPayload)
@service_api_ns.route("/datasets/<uuid:dataset_id>/hit-testing", "/datasets/<uuid:dataset_id>/retrieve")
class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase):
@ -15,6 +18,7 @@ class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase):
404: "Dataset not found",
}
)
@service_api_ns.expect(service_api_ns.models[HitTestingPayload.__name__])
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id, dataset_id):
"""Perform hit testing on a dataset.
@ -24,7 +28,7 @@ class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase):
dataset_id_str = str(dataset_id)
dataset = self.get_and_validate_dataset(dataset_id_str)
args = self.parse_args()
args = self.parse_args(service_api_ns.payload)
self.hit_testing_args_check(args)
return self.perform_hit_testing(dataset, args)

View File

@ -18,7 +18,6 @@ from controllers.web.error import (
ProviderQuotaExceededError,
UnsupportedAudioTypeError,
)
from controllers.web.wraps import WebApiResource
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError
from libs.helper import uuid_value

View File

@ -1,8 +1,12 @@
from flask_restx import fields, marshal_with, reqparse
from flask_restx.inputs import int_range
from typing import Literal
from flask import request
from flask_restx import fields, marshal_with
from pydantic import BaseModel, Field, field_validator, model_validator
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
from controllers.common.schema import register_schema_models
from controllers.web import web_ns
from controllers.web.error import NotChatAppError
from controllers.web.wraps import WebApiResource
@ -16,6 +20,35 @@ from services.errors.conversation import ConversationNotExistsError, LastConvers
from services.web_conversation_service import WebConversationService
class ConversationListQuery(BaseModel):
last_id: str | None = None
limit: int = Field(default=20, ge=1, le=100)
pinned: bool | None = None
sort_by: Literal["created_at", "-created_at", "updated_at", "-updated_at"] = "-updated_at"
@field_validator("last_id")
@classmethod
def validate_last_id(cls, value: str | None) -> str | None:
if value is None:
return value
return uuid_value(value)
class ConversationRenamePayload(BaseModel):
name: str | None = None
auto_generate: bool = False
@model_validator(mode="after")
def validate_name_requirement(self):
if not self.auto_generate:
if self.name is None or not self.name.strip():
raise ValueError("name is required when auto_generate is false")
return self
register_schema_models(web_ns, ConversationListQuery, ConversationRenamePayload)
@web_ns.route("/conversations")
class ConversationListApi(WebApiResource):
@web_ns.doc("Get Conversation List")
@ -60,25 +93,8 @@ class ConversationListApi(WebApiResource):
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
parser = (
reqparse.RequestParser()
.add_argument("last_id", type=uuid_value, location="args")
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
.add_argument("pinned", type=str, choices=["true", "false", None], location="args")
.add_argument(
"sort_by",
type=str,
choices=["created_at", "-created_at", "updated_at", "-updated_at"],
required=False,
default="-updated_at",
location="args",
)
)
args = parser.parse_args()
pinned = None
if "pinned" in args and args["pinned"] is not None:
pinned = args["pinned"] == "true"
raw_args = request.args.to_dict()
query = ConversationListQuery.model_validate(raw_args)
try:
with Session(db.engine) as session:
@ -86,11 +102,11 @@ class ConversationListApi(WebApiResource):
session=session,
app_model=app_model,
user=end_user,
last_id=args["last_id"],
limit=args["limit"],
last_id=query.last_id,
limit=query.limit,
invoke_from=InvokeFrom.WEB_APP,
pinned=pinned,
sort_by=args["sort_by"],
pinned=query.pinned,
sort_by=query.sort_by,
)
except LastConversationNotExistsError:
raise NotFound("Last Conversation Not Exists.")
@ -163,15 +179,12 @@ class ConversationRenameApi(WebApiResource):
conversation_id = str(c_id)
parser = (
reqparse.RequestParser()
.add_argument("name", type=str, required=False, location="json")
.add_argument("auto_generate", type=bool, required=False, default=False, location="json")
)
args = parser.parse_args()
payload = ConversationRenamePayload.model_validate(web_ns.payload or {})
try:
return ConversationService.rename(app_model, conversation_id, end_user, args["name"], args["auto_generate"])
return ConversationService.rename(
app_model, conversation_id, end_user, payload.name or "", payload.auto_generate
)
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")

View File

@ -1,9 +1,11 @@
from flask import make_response, request
from flask_restx import Resource, reqparse
from flask_restx import Resource
from jwt import InvalidTokenError
from pydantic import BaseModel, Field, field_validator
import services
from configs import dify_config
from controllers.common.schema import register_schema_models
from controllers.console.auth.error import (
AuthenticationFailedError,
EmailCodeError,
@ -13,7 +15,7 @@ from controllers.console.error import AccountBannedError
from controllers.console.wraps import only_edition_enterprise, setup_required
from controllers.web import web_ns
from controllers.web.wraps import decode_jwt_token
from libs.helper import email
from libs.helper import EmailStr
from libs.passport import PassportService
from libs.password import valid_password
from libs.token import (
@ -25,6 +27,30 @@ from services.app_service import AppService
from services.webapp_auth_service import WebAppAuthService
class LoginPayload(BaseModel):
email: EmailStr
password: str
@field_validator("password")
@classmethod
def validate_password(cls, value: str) -> str:
return valid_password(value)
class EmailCodeLoginSendPayload(BaseModel):
email: EmailStr
language: str | None = None
class EmailCodeLoginVerifyPayload(BaseModel):
email: EmailStr
code: str
token: str = Field(min_length=1)
register_schema_models(web_ns, LoginPayload, EmailCodeLoginSendPayload, EmailCodeLoginVerifyPayload)
@web_ns.route("/login")
class LoginApi(Resource):
"""Resource for web app email/password login."""
@ -44,15 +70,10 @@ class LoginApi(Resource):
)
def post(self):
"""Authenticate user and login."""
parser = (
reqparse.RequestParser()
.add_argument("email", type=email, required=True, location="json")
.add_argument("password", type=valid_password, required=True, location="json")
)
args = parser.parse_args()
payload = LoginPayload.model_validate(web_ns.payload or {})
try:
account = WebAppAuthService.authenticate(args["email"], args["password"])
account = WebAppAuthService.authenticate(payload.email, payload.password)
except services.errors.account.AccountLoginError:
raise AccountBannedError()
except services.errors.account.AccountPasswordError:
@ -139,6 +160,7 @@ class EmailCodeLoginSendEmailApi(Resource):
@only_edition_enterprise
@web_ns.doc("send_email_code_login")
@web_ns.doc(description="Send email verification code for login")
@web_ns.expect(web_ns.models[EmailCodeLoginSendPayload.__name__])
@web_ns.doc(
responses={
200: "Email code sent successfully",
@ -147,19 +169,14 @@ class EmailCodeLoginSendEmailApi(Resource):
}
)
def post(self):
parser = (
reqparse.RequestParser()
.add_argument("email", type=email, required=True, location="json")
.add_argument("language", type=str, required=False, location="json")
)
args = parser.parse_args()
payload = EmailCodeLoginSendPayload.model_validate(web_ns.payload or {})
if args["language"] is not None and args["language"] == "zh-Hans":
if payload.language == "zh-Hans":
language = "zh-Hans"
else:
language = "en-US"
account = WebAppAuthService.get_user_through_email(args["email"])
account = WebAppAuthService.get_user_through_email(payload.email)
if account is None:
raise AuthenticationFailedError()
else:
@ -173,6 +190,7 @@ class EmailCodeLoginApi(Resource):
@only_edition_enterprise
@web_ns.doc("verify_email_code_login")
@web_ns.doc(description="Verify email code and complete login")
@web_ns.expect(web_ns.models[EmailCodeLoginVerifyPayload.__name__])
@web_ns.doc(
responses={
200: "Email code verified and login successful",
@ -182,33 +200,27 @@ class EmailCodeLoginApi(Resource):
}
)
def post(self):
parser = (
reqparse.RequestParser()
.add_argument("email", type=str, required=True, location="json")
.add_argument("code", type=str, required=True, location="json")
.add_argument("token", type=str, required=True, location="json")
)
args = parser.parse_args()
payload = EmailCodeLoginVerifyPayload.model_validate(web_ns.payload or {})
user_email = args["email"]
user_email = payload.email
token_data = WebAppAuthService.get_email_code_login_data(args["token"])
token_data = WebAppAuthService.get_email_code_login_data(payload.token)
if token_data is None:
raise InvalidTokenError()
if token_data["email"] != args["email"]:
if token_data["email"] != payload.email:
raise InvalidEmailError()
if token_data["code"] != args["code"]:
if token_data["code"] != payload.code:
raise EmailCodeError()
WebAppAuthService.revoke_email_code_login_token(args["token"])
WebAppAuthService.revoke_email_code_login_token(payload.token)
account = WebAppAuthService.get_user_through_email(user_email)
if not account:
raise AuthenticationFailedError()
token = WebAppAuthService.login(account=account)
AccountService.reset_login_error_rate_limit(args["email"])
AccountService.reset_login_error_rate_limit(payload.email)
response = make_response({"result": "success", "data": {"access_token": token}})
# set_access_token_to_cookie(request, response, token, samesite="None", httponly=False)
return response

View File

@ -1,12 +1,14 @@
from flask_restx import fields, marshal_with, reqparse
from flask_restx.inputs import int_range
from flask import request
from flask_restx import fields, marshal_with
from pydantic import BaseModel, Field
from werkzeug.exceptions import NotFound
from controllers.common.schema import register_schema_models
from controllers.web import web_ns
from controllers.web.error import NotCompletionAppError
from controllers.web.wraps import WebApiResource
from fields.conversation_fields import message_file_fields
from libs.helper import TimestampField, uuid_value
from libs.helper import TimestampField, UUIDStrOrEmpty
from services.errors.message import MessageNotExistsError
from services.saved_message_service import SavedMessageService
@ -23,6 +25,18 @@ message_fields = {
}
class SavedMessageListQuery(BaseModel):
last_id: UUIDStrOrEmpty | None = None
limit: int = Field(default=20, ge=1, le=100)
class SavedMessageCreatePayload(BaseModel):
message_id: UUIDStrOrEmpty
register_schema_models(web_ns, SavedMessageListQuery, SavedMessageCreatePayload)
@web_ns.route("/saved-messages")
class SavedMessageListApi(WebApiResource):
saved_message_infinite_scroll_pagination_fields = {
@ -63,14 +77,10 @@ class SavedMessageListApi(WebApiResource):
if app_model.mode != "completion":
raise NotCompletionAppError()
parser = (
reqparse.RequestParser()
.add_argument("last_id", type=uuid_value, location="args")
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
)
args = parser.parse_args()
raw_args = request.args.to_dict()
query = SavedMessageListQuery.model_validate(raw_args)
return SavedMessageService.pagination_by_last_id(app_model, end_user, args["last_id"], args["limit"])
return SavedMessageService.pagination_by_last_id(app_model, end_user, query.last_id, query.limit)
@web_ns.doc("Save Message")
@web_ns.doc(description="Save a specific message for later reference.")
@ -94,11 +104,10 @@ class SavedMessageListApi(WebApiResource):
if app_model.mode != "completion":
raise NotCompletionAppError()
parser = reqparse.RequestParser().add_argument("message_id", type=uuid_value, required=True, location="json")
args = parser.parse_args()
payload = SavedMessageCreatePayload.model_validate(web_ns.payload or {})
try:
SavedMessageService.save(app_model, end_user, args["message_id"])
SavedMessageService.save(app_model, end_user, payload.message_id)
except MessageNotExistsError:
raise NotFound("Message Not Exists.")

View File

@ -1,8 +1,10 @@
import logging
from typing import Any
from flask_restx import reqparse
from pydantic import BaseModel
from werkzeug.exceptions import InternalServerError
from controllers.common.schema import register_schema_models
from controllers.web import web_ns
from controllers.web.error import (
CompletionRequestError,
@ -27,8 +29,16 @@ from models.model import App, AppMode, EndUser
from services.app_generate_service import AppGenerateService
from services.errors.llm import InvokeRateLimitError
class WorkflowRunPayload(BaseModel):
inputs: dict[str, Any]
files: list[dict[str, Any]] | None = None
logger = logging.getLogger(__name__)
register_schema_models(web_ns, WorkflowRunPayload)
@web_ns.route("/workflows/run")
class WorkflowRunApi(WebApiResource):
@ -58,12 +68,8 @@ class WorkflowRunApi(WebApiResource):
if app_mode != AppMode.WORKFLOW:
raise NotWorkflowAppError()
parser = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
.add_argument("files", type=list, required=False, location="json")
)
args = parser.parse_args()
payload = WorkflowRunPayload.model_validate(web_ns.payload or {})
args = payload.model_dump(exclude_none=True)
try:
response = AppGenerateService.generate(

View File

@ -378,7 +378,7 @@ class ApiBasedToolSchemaParser:
@staticmethod
def auto_parse_to_tool_bundle(
content: str, extra_info: dict | None = None, warning: dict | None = None
) -> tuple[list[ApiToolBundle], str]:
) -> tuple[list[ApiToolBundle], ApiProviderSchemaType]:
"""
auto parse to tool bundle

View File

@ -7,11 +7,6 @@ from core.workflow.nodes.base.entities import OutputVariableEntity
class WorkflowToolConfigurationUtils:
@classmethod
def check_parameter_configurations(cls, configurations: list[Mapping[str, Any]]):
for configuration in configurations:
WorkflowToolParameterConfiguration.model_validate(configuration)
@classmethod
def get_workflow_graph_variables(cls, graph: Mapping[str, Any]) -> Sequence[VariableEntity]:
"""

View File

@ -86,7 +86,9 @@ class ApiToolManageService:
raise ValueError(f"invalid schema: {str(e)}")
@staticmethod
def convert_schema_to_tool_bundles(schema: str, extra_info: dict | None = None) -> tuple[list[ApiToolBundle], str]:
def convert_schema_to_tool_bundles(
schema: str, extra_info: dict | None = None
) -> tuple[list[ApiToolBundle], ApiProviderSchemaType]:
"""
convert schema to tool bundles
@ -104,7 +106,7 @@ class ApiToolManageService:
provider_name: str,
icon: dict,
credentials: dict,
schema_type: str,
schema_type: ApiProviderSchemaType,
schema: str,
privacy_policy: str,
custom_disclaimer: str,
@ -113,9 +115,6 @@ class ApiToolManageService:
"""
create api tool provider
"""
if schema_type not in [member.value for member in ApiProviderSchemaType]:
raise ValueError(f"invalid schema type {schema}")
provider_name = provider_name.strip()
# check if the provider exists
@ -245,18 +244,15 @@ class ApiToolManageService:
original_provider: str,
icon: dict,
credentials: dict,
schema_type: str,
_schema_type: ApiProviderSchemaType,
schema: str,
privacy_policy: str,
privacy_policy: str | None,
custom_disclaimer: str,
labels: list[str],
):
"""
update api tool provider
"""
if schema_type not in [member.value for member in ApiProviderSchemaType]:
raise ValueError(f"invalid schema type {schema}")
provider_name = provider_name.strip()
# check if the provider exists
@ -281,7 +277,7 @@ class ApiToolManageService:
provider.icon = json.dumps(icon)
provider.schema = schema
provider.description = extra_info.get("description", "")
provider.schema_type_str = ApiProviderSchemaType.OPENAPI
provider.schema_type_str = schema_type
provider.tools_str = json.dumps(jsonable_encoder(tool_bundles))
provider.privacy_policy = privacy_policy
provider.custom_disclaimer = custom_disclaimer
@ -366,7 +362,7 @@ class ApiToolManageService:
tool_name: str,
credentials: dict,
parameters: dict,
schema_type: str,
schema_type: ApiProviderSchemaType,
schema: str,
):
"""

View File

@ -1,8 +1,6 @@
import json
import logging
from collections.abc import Mapping
from datetime import datetime
from typing import Any
from sqlalchemy import or_, select
@ -11,8 +9,8 @@ from core.helper.tool_provider_cache import ToolProviderListCache
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.__base.tool_provider import ToolProviderController
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
from core.tools.tool_label_manager import ToolLabelManager
from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
from core.tools.workflow_as_tool.tool import WorkflowTool
from extensions.ext_database import db
@ -39,12 +37,10 @@ class WorkflowToolManageService:
label: str,
icon: dict,
description: str,
parameters: list[Mapping[str, Any]],
parameters: list[WorkflowToolParameterConfiguration],
privacy_policy: str = "",
labels: list[str] | None = None,
):
WorkflowToolConfigurationUtils.check_parameter_configurations(parameters)
# check if the name is unique
existing_workflow_tool_provider = (
db.session.query(WorkflowToolProvider)
@ -109,7 +105,7 @@ class WorkflowToolManageService:
label: str,
icon: dict,
description: str,
parameters: list[Mapping[str, Any]],
parameters: list[WorkflowToolParameterConfiguration],
privacy_policy: str = "",
labels: list[str] | None = None,
):
@ -127,8 +123,6 @@ class WorkflowToolManageService:
:param labels: labels
:return: the updated tool
"""
WorkflowToolConfigurationUtils.check_parameter_configurations(parameters)
# check if the name is unique
existing_workflow_tool_provider = (
db.session.query(WorkflowToolProvider)
@ -167,7 +161,7 @@ class WorkflowToolManageService:
workflow_tool_provider.label = label
workflow_tool_provider.icon = json.dumps(icon)
workflow_tool_provider.description = description
workflow_tool_provider.parameter_configuration = json.dumps(parameters)
workflow_tool_provider.parameter_configuration = json.dumps([p.model_dump() for p in parameters])
workflow_tool_provider.privacy_policy = privacy_policy
workflow_tool_provider.version = workflow.version
workflow_tool_provider.updated_at = datetime.now()

View File

@ -3,7 +3,9 @@ from unittest.mock import patch
import pytest
from faker import Faker
from pydantic import ValidationError
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
from models.tools import WorkflowToolProvider
from models.workflow import Workflow as WorkflowModel
from services.account_service import AccountService, TenantService
@ -130,20 +132,24 @@ class TestWorkflowToolManageService:
def _create_test_workflow_tool_parameters(self):
"""Helper method to create valid workflow tool parameters."""
return [
{
"name": "input_text",
"description": "Input text for processing",
"form": "form",
"type": "string",
"required": True,
},
{
"name": "output_format",
"description": "Output format specification",
"form": "form",
"type": "select",
"required": False,
},
WorkflowToolParameterConfiguration.model_validate(
{
"name": "input_text",
"description": "Input text for processing",
"form": "form",
"type": "string",
"required": True,
}
),
WorkflowToolParameterConfiguration.model_validate(
{
"name": "output_format",
"description": "Output format specification",
"form": "form",
"type": "select",
"required": False,
}
),
]
def test_create_workflow_tool_success(self, db_session_with_containers, mock_external_service_dependencies):
@ -208,7 +214,7 @@ class TestWorkflowToolManageService:
assert created_tool_provider.label == tool_label
assert created_tool_provider.icon == json.dumps(tool_icon)
assert created_tool_provider.description == tool_description
assert created_tool_provider.parameter_configuration == json.dumps(tool_parameters)
assert created_tool_provider.parameter_configuration == json.dumps([p.model_dump() for p in tool_parameters])
assert created_tool_provider.privacy_policy == tool_privacy_policy
assert created_tool_provider.version == workflow.version
assert created_tool_provider.user_id == account.id
@ -353,18 +359,9 @@ class TestWorkflowToolManageService:
app, account, workflow = self._create_test_app_and_account(
db_session_with_containers, mock_external_service_dependencies
)
# Setup invalid workflow tool parameters (missing required fields)
invalid_parameters = [
{
"name": "input_text",
# Missing description and form fields
"type": "string",
"required": True,
}
]
# Attempt to create workflow tool with invalid parameters
with pytest.raises(ValueError) as exc_info:
with pytest.raises(ValidationError) as exc_info:
# Setup invalid workflow tool parameters (missing required fields)
WorkflowToolManageService.create_workflow_tool(
user_id=account.id,
tenant_id=account.current_tenant.id,
@ -373,7 +370,16 @@ class TestWorkflowToolManageService:
label=fake.word(),
icon={"type": "emoji", "emoji": "🔧"},
description=fake.text(max_nb_chars=200),
parameters=invalid_parameters,
parameters=[
WorkflowToolParameterConfiguration.model_validate(
{
"name": "input_text",
# Missing description and form fields
"type": "string",
"required": True,
}
)
],
)
# Verify error message contains validation error
@ -579,11 +585,12 @@ class TestWorkflowToolManageService:
# Verify database state was updated
db.session.refresh(created_tool)
assert created_tool is not None
assert created_tool.name == updated_tool_name
assert created_tool.label == updated_tool_label
assert created_tool.icon == json.dumps(updated_tool_icon)
assert created_tool.description == updated_tool_description
assert created_tool.parameter_configuration == json.dumps(updated_tool_parameters)
assert created_tool.parameter_configuration == json.dumps([p.model_dump() for p in updated_tool_parameters])
assert created_tool.privacy_policy == updated_tool_privacy_policy
assert created_tool.version == workflow.version
assert created_tool.updated_at is not None