diff --git a/api/controllers/inner_api/plugin/wraps.py b/api/controllers/inner_api/plugin/wraps.py index f8a2c71e80..709bba3f30 100644 --- a/api/controllers/inner_api/plugin/wraps.py +++ b/api/controllers/inner_api/plugin/wraps.py @@ -5,6 +5,7 @@ from typing import Optional from flask import request from flask_restful import reqparse from pydantic import BaseModel +from sqlalchemy.orm import Session from extensions.ext_database import db from models.account import Account, Tenant @@ -12,19 +13,29 @@ from models.model import EndUser from services.account_service import AccountService -def get_user(user_id: str | None) -> Account | EndUser: +def get_user(tenant_id: str, user_id: str | None) -> Account | EndUser: try: - if not user_id: - user_id = "DEFAULT-USER" + with Session(db.engine) as session: + if not user_id: + user_id = "DEFAULT-USER" - if user_id == "DEFAULT-USER": - user_model = db.session.query(EndUser).filter(EndUser.session_id == "DEFAULT-USER").first() - else: - user_model = AccountService.load_user(user_id) - if not user_model: - user_model = db.session.query(EndUser).filter(EndUser.id == user_id).first() - if not user_model: - raise ValueError("user not found") + if user_id == "DEFAULT-USER": + user_model = session.query(EndUser).filter(EndUser.session_id == "DEFAULT-USER").first() + if not user_model: + user_model = EndUser( + tenant_id=tenant_id, + type="service_api", + is_anonymous=True if user_id == "DEFAULT-USER" else False, + session_id=user_id, + ) + session.add(user_model) + session.commit() + else: + user_model = AccountService.load_user(user_id) + if not user_model: + user_model = session.query(EndUser).filter(EndUser.id == user_id).first() + if not user_model: + raise ValueError("user not found") except Exception: raise ValueError("user not found") @@ -45,6 +56,12 @@ def get_user_tenant(view: Optional[Callable] = None): user_id = kwargs.get("user_id") tenant_id = kwargs.get("tenant_id") + if not tenant_id: + raise ValueError("tenant_id is required") + + if not user_id: + user_id = "DEFAULT-USER" + del kwargs["tenant_id"] del kwargs["user_id"] @@ -63,7 +80,7 @@ def get_user_tenant(view: Optional[Callable] = None): raise ValueError("tenant not found") kwargs["tenant_model"] = tenant_model - kwargs["user_model"] = get_user(user_id) + kwargs["user_model"] = get_user(tenant_id, user_id) return view_func(*args, **kwargs) diff --git a/api/models/model.py b/api/models/model.py index e4a9181a3c..1d3468566d 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -1260,7 +1260,7 @@ class OperationLog(Base): updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) -class EndUser(UserMixin, Base): +class EndUser(Base, UserMixin): __tablename__ = "end_users" __table_args__ = ( db.PrimaryKeyConstraint("id", name="end_user_pkey"),