add typing to all wraps (#25405)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Asuka Minato 2025-09-09 17:48:33 +09:00 committed by GitHub
parent eb52216a9c
commit 38057b1b0e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 61 additions and 46 deletions

View File

@ -1,6 +1,6 @@
from collections.abc import Callable
from functools import wraps
from typing import Optional, Union
from typing import Optional, ParamSpec, TypeVar, Union
from controllers.console.app.error import AppNotFoundError
from extensions.ext_database import db
@ -8,6 +8,9 @@ from libs.login import current_user
from models import App, AppMode
from models.account import Account
P = ParamSpec("P")
R = TypeVar("R")
def _load_app_model(app_id: str) -> Optional[App]:
assert isinstance(current_user, Account)
@ -19,10 +22,10 @@ def _load_app_model(app_id: str) -> Optional[App]:
return app_model
def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[AppMode], None] = None):
def decorator(view_func):
def get_app_model(view: Optional[Callable[P, R]] = None, *, mode: Union[AppMode, list[AppMode], None] = None):
def decorator(view_func: Callable[P, R]):
@wraps(view_func)
def decorated_view(*args, **kwargs):
def decorated_view(*args: P.args, **kwargs: P.kwargs):
if not kwargs.get("app_id"):
raise ValueError("missing app_id in path parameters")

View File

@ -1,6 +1,6 @@
from collections.abc import Callable
from functools import wraps
from typing import Optional
from typing import Optional, ParamSpec, TypeVar
from flask import current_app, request
from flask_login import user_logged_in
@ -14,6 +14,9 @@ from libs.login import _get_user
from models.account import Tenant
from models.model import EndUser
P = ParamSpec("P")
R = TypeVar("R")
def get_user(tenant_id: str, user_id: str | None) -> EndUser:
"""
@ -52,19 +55,19 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser:
return user_model
def get_user_tenant(view: Optional[Callable] = None):
def decorator(view_func):
def get_user_tenant(view: Optional[Callable[P, R]] = None):
def decorator(view_func: Callable[P, R]):
@wraps(view_func)
def decorated_view(*args, **kwargs):
def decorated_view(*args: P.args, **kwargs: P.kwargs):
# fetch json body
parser = reqparse.RequestParser()
parser.add_argument("tenant_id", type=str, required=True, location="json")
parser.add_argument("user_id", type=str, required=True, location="json")
kwargs = parser.parse_args()
p = parser.parse_args()
user_id = kwargs.get("user_id")
tenant_id = kwargs.get("tenant_id")
user_id: Optional[str] = p.get("user_id")
tenant_id: str = p.get("tenant_id")
if not tenant_id:
raise ValueError("tenant_id is required")
@ -107,9 +110,9 @@ def get_user_tenant(view: Optional[Callable] = None):
return decorator(view)
def plugin_data(view: Optional[Callable] = None, *, payload_type: type[BaseModel]):
def decorator(view_func):
def decorated_view(*args, **kwargs):
def plugin_data(view: Optional[Callable[P, R]] = None, *, payload_type: type[BaseModel]):
def decorator(view_func: Callable[P, R]):
def decorated_view(*args: P.args, **kwargs: P.kwargs):
try:
data = request.get_json()
except Exception:

View File

@ -46,9 +46,9 @@ def enterprise_inner_api_only(view: Callable[P, R]):
return decorated
def enterprise_inner_api_user_auth(view):
def enterprise_inner_api_user_auth(view: Callable[P, R]):
@wraps(view)
def decorated(*args, **kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs):
if not dify_config.INNER_API:
return view(*args, **kwargs)

View File

@ -19,7 +19,7 @@ class ModelProviderAvailableModelApi(Resource):
}
)
@validate_dataset_token
def get(self, _, model_type):
def get(self, _, model_type: str):
"""Get available models by model type.
Returns a list of available models for the specified model type.

View File

@ -3,7 +3,7 @@ from collections.abc import Callable
from datetime import timedelta
from enum import StrEnum, auto
from functools import wraps
from typing import Optional, ParamSpec, TypeVar
from typing import Concatenate, Optional, ParamSpec, TypeVar
from flask import current_app, request
from flask_login import user_logged_in
@ -25,6 +25,7 @@ from services.feature_service import FeatureService
P = ParamSpec("P")
R = TypeVar("R")
T = TypeVar("T")
class WhereisUserArg(StrEnum):
@ -42,10 +43,10 @@ class FetchUserArg(BaseModel):
required: bool = False
def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optional[FetchUserArg] = None):
def decorator(view_func):
def validate_app_token(view: Optional[Callable[P, R]] = None, *, fetch_user_arg: Optional[FetchUserArg] = None):
def decorator(view_func: Callable[P, R]):
@wraps(view_func)
def decorated_view(*args, **kwargs):
def decorated_view(*args: P.args, **kwargs: P.kwargs):
api_token = validate_and_get_api_token("app")
app_model = db.session.query(App).where(App.id == api_token.app_id).first()
@ -189,10 +190,10 @@ def cloud_edition_billing_rate_limit_check(resource: str, api_token_type: str):
return interceptor
def validate_dataset_token(view=None):
def decorator(view):
def validate_dataset_token(view: Optional[Callable[Concatenate[T, P], R]] = None):
def decorator(view: Callable[Concatenate[T, P], R]):
@wraps(view)
def decorated(*args, **kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs):
api_token = validate_and_get_api_token("dataset")
tenant_account_join = (
db.session.query(Tenant, TenantAccountJoin)

View File

@ -1,6 +1,7 @@
from collections.abc import Callable
from datetime import UTC, datetime
from functools import wraps
from typing import ParamSpec, TypeVar
from typing import Concatenate, Optional, ParamSpec, TypeVar
from flask import request
from flask_restx import Resource
@ -20,12 +21,11 @@ P = ParamSpec("P")
R = TypeVar("R")
def validate_jwt_token(view=None):
def decorator(view):
def validate_jwt_token(view: Optional[Callable[Concatenate[App, EndUser, P], R]] = None):
def decorator(view: Callable[Concatenate[App, EndUser, P], R]):
@wraps(view)
def decorated(*args, **kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs):
app_model, end_user = decode_jwt_token()
return view(app_model, end_user, *args, **kwargs)
return decorated

View File

@ -1,8 +1,9 @@
import json
import logging
import uuid
from collections.abc import Callable
from functools import wraps
from typing import Any, Optional
from typing import Any, Concatenate, Optional, ParamSpec, TypeVar
from mo_vector.client import MoVectorClient # type: ignore
from pydantic import BaseModel, model_validator
@ -17,7 +18,6 @@ from extensions.ext_redis import redis_client
from models.dataset import Dataset
logger = logging.getLogger(__name__)
from typing import ParamSpec, TypeVar
P = ParamSpec("P")
R = TypeVar("R")
@ -47,16 +47,6 @@ class MatrixoneConfig(BaseModel):
return values
def ensure_client(func):
@wraps(func)
def wrapper(self, *args, **kwargs):
if self.client is None:
self.client = self._get_client(None, False)
return func(self, *args, **kwargs)
return wrapper
class MatrixoneVector(BaseVector):
"""
Matrixone vector storage implementation.
@ -216,6 +206,19 @@ class MatrixoneVector(BaseVector):
self.client.delete()
T = TypeVar("T", bound=MatrixoneVector)
def ensure_client(func: Callable[Concatenate[T, P], R]):
@wraps(func)
def wrapper(self: T, *args: P.args, **kwargs: P.kwargs):
if self.client is None:
self.client = self._get_client(None, False)
return func(self, *args, **kwargs)
return wrapper
class MatrixoneVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> MatrixoneVector:
if dataset.index_struct_dict:

View File

@ -6,10 +6,12 @@ from pydantic import BaseModel
from services.enterprise.base import EnterprisePluginManagerRequest
from services.errors.base import BaseServiceError
logger = logging.getLogger(__name__)
class PluginCredentialType(enum.Enum):
MODEL = 0
TOOL = 1
class PluginCredentialType(enum.IntEnum):
MODEL = enum.auto()
TOOL = enum.auto()
def to_number(self):
return self.value
@ -47,6 +49,9 @@ class PluginManagerService:
if not ret.get("result", False):
raise CredentialPolicyViolationError("Credentials not available: Please use ENTERPRISE global credentials")
logging.debug(
f"Credential policy compliance checked for {body.provider} with credential {body.dify_credential_id}, result: {ret.get('result', False)}"
logger.debug(
"Credential policy compliance checked for %s with credential %s, result: %s",
body.provider,
body.dify_credential_id,
ret.get("result", False),
)