Refactor account models to use SQLAlchemy 2.0 dataclass mapping (#26415)

Co-authored-by: google-labs-jules[bot] <161369871+google-labs-jules[bot]@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
Asuka Minato 2025-10-10 17:12:12 +09:00 committed by GitHub
parent 2b6882bd97
commit 8a2b208299
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 219 additions and 177 deletions

View File

@ -1,15 +1,16 @@
import enum import enum
import json import json
from dataclasses import field
from datetime import datetime from datetime import datetime
from typing import Any, Optional from typing import Any, Optional
import sqlalchemy as sa import sqlalchemy as sa
from flask_login import UserMixin # type: ignore[import-untyped] from flask_login import UserMixin # type: ignore[import-untyped]
from sqlalchemy import DateTime, String, func, select from sqlalchemy import DateTime, String, func, select
from sqlalchemy.orm import Mapped, Session, mapped_column, reconstructor from sqlalchemy.orm import Mapped, Session, mapped_column
from typing_extensions import deprecated from typing_extensions import deprecated
from models.base import Base from models.base import TypeBase
from .engine import db from .engine import db
from .types import StringUUID from .types import StringUUID
@ -83,31 +84,37 @@ class AccountStatus(enum.StrEnum):
CLOSED = "closed" CLOSED = "closed"
class Account(UserMixin, Base): class Account(UserMixin, TypeBase):
__tablename__ = "accounts" __tablename__ = "accounts"
__table_args__ = (sa.PrimaryKeyConstraint("id", name="account_pkey"), sa.Index("account_email_idx", "email")) __table_args__ = (sa.PrimaryKeyConstraint("id", name="account_pkey"), sa.Index("account_email_idx", "email"))
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False)
name: Mapped[str] = mapped_column(String(255)) name: Mapped[str] = mapped_column(String(255))
email: Mapped[str] = mapped_column(String(255)) email: Mapped[str] = mapped_column(String(255))
password: Mapped[str | None] = mapped_column(String(255)) password: Mapped[str | None] = mapped_column(String(255), default=None)
password_salt: Mapped[str | None] = mapped_column(String(255)) password_salt: Mapped[str | None] = mapped_column(String(255), default=None)
avatar: Mapped[str | None] = mapped_column(String(255), nullable=True) avatar: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None)
interface_language: Mapped[str | None] = mapped_column(String(255)) interface_language: Mapped[str | None] = mapped_column(String(255), default=None)
interface_theme: Mapped[str | None] = mapped_column(String(255), nullable=True) interface_theme: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None)
timezone: Mapped[str | None] = mapped_column(String(255)) timezone: Mapped[str | None] = mapped_column(String(255), default=None)
last_login_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) last_login_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True, default=None)
last_login_ip: Mapped[str | None] = mapped_column(String(255), nullable=True) last_login_ip: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None)
last_active_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), nullable=False) last_active_at: Mapped[datetime] = mapped_column(
status: Mapped[str] = mapped_column(String(16), server_default=sa.text("'active'::character varying")) DateTime, server_default=func.current_timestamp(), nullable=False, init=False
initialized_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) )
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), nullable=False) status: Mapped[str] = mapped_column(
updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), nullable=False) String(16), server_default=sa.text("'active'::character varying"), default="active"
)
initialized_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True, default=None)
created_at: Mapped[datetime] = mapped_column(
DateTime, server_default=func.current_timestamp(), nullable=False, init=False
)
updated_at: Mapped[datetime] = mapped_column(
DateTime, server_default=func.current_timestamp(), nullable=False, init=False
)
@reconstructor role: TenantAccountRole | None = field(default=None, init=False)
def init_on_load(self): _current_tenant: "Tenant | None" = field(default=None, init=False)
self.role: TenantAccountRole | None = None
self._current_tenant: Tenant | None = None
@property @property
def is_password_set(self): def is_password_set(self):
@ -226,18 +233,24 @@ class TenantStatus(enum.StrEnum):
ARCHIVE = "archive" ARCHIVE = "archive"
class Tenant(Base): class Tenant(TypeBase):
__tablename__ = "tenants" __tablename__ = "tenants"
__table_args__ = (sa.PrimaryKeyConstraint("id", name="tenant_pkey"),) __table_args__ = (sa.PrimaryKeyConstraint("id", name="tenant_pkey"),)
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False)
name: Mapped[str] = mapped_column(String(255)) name: Mapped[str] = mapped_column(String(255))
encrypt_public_key: Mapped[str | None] = mapped_column(sa.Text) encrypt_public_key: Mapped[str | None] = mapped_column(sa.Text, default=None)
plan: Mapped[str] = mapped_column(String(255), server_default=sa.text("'basic'::character varying")) plan: Mapped[str] = mapped_column(
status: Mapped[str] = mapped_column(String(255), server_default=sa.text("'normal'::character varying")) String(255), server_default=sa.text("'basic'::character varying"), default="basic"
custom_config: Mapped[str | None] = mapped_column(sa.Text) )
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), nullable=False) status: Mapped[str] = mapped_column(
updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp()) String(255), server_default=sa.text("'normal'::character varying"), default="normal"
)
custom_config: Mapped[str | None] = mapped_column(sa.Text, default=None)
created_at: Mapped[datetime] = mapped_column(
DateTime, server_default=func.current_timestamp(), nullable=False, init=False
)
updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), init=False)
def get_accounts(self) -> list[Account]: def get_accounts(self) -> list[Account]:
return list( return list(
@ -257,7 +270,7 @@ class Tenant(Base):
self.custom_config = json.dumps(value) self.custom_config = json.dumps(value)
class TenantAccountJoin(Base): class TenantAccountJoin(TypeBase):
__tablename__ = "tenant_account_joins" __tablename__ = "tenant_account_joins"
__table_args__ = ( __table_args__ = (
sa.PrimaryKeyConstraint("id", name="tenant_account_join_pkey"), sa.PrimaryKeyConstraint("id", name="tenant_account_join_pkey"),
@ -266,17 +279,21 @@ class TenantAccountJoin(Base):
sa.UniqueConstraint("tenant_id", "account_id", name="unique_tenant_account_join"), sa.UniqueConstraint("tenant_id", "account_id", name="unique_tenant_account_join"),
) )
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False)
tenant_id: Mapped[str] = mapped_column(StringUUID) tenant_id: Mapped[str] = mapped_column(StringUUID)
account_id: Mapped[str] = mapped_column(StringUUID) account_id: Mapped[str] = mapped_column(StringUUID)
current: Mapped[bool] = mapped_column(sa.Boolean, server_default=sa.text("false")) current: Mapped[bool] = mapped_column(sa.Boolean, server_default=sa.text("false"), default=False)
role: Mapped[str] = mapped_column(String(16), server_default="normal") role: Mapped[str] = mapped_column(String(16), server_default="normal", default="normal")
invited_by: Mapped[str | None] = mapped_column(StringUUID) invited_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp()) created_at: Mapped[datetime] = mapped_column(
updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp()) DateTime, server_default=func.current_timestamp(), nullable=False, init=False
)
updated_at: Mapped[datetime] = mapped_column(
DateTime, server_default=func.current_timestamp(), nullable=False, init=False
)
class AccountIntegrate(Base): class AccountIntegrate(TypeBase):
__tablename__ = "account_integrates" __tablename__ = "account_integrates"
__table_args__ = ( __table_args__ = (
sa.PrimaryKeyConstraint("id", name="account_integrate_pkey"), sa.PrimaryKeyConstraint("id", name="account_integrate_pkey"),
@ -284,16 +301,20 @@ class AccountIntegrate(Base):
sa.UniqueConstraint("provider", "open_id", name="unique_provider_open_id"), sa.UniqueConstraint("provider", "open_id", name="unique_provider_open_id"),
) )
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False)
account_id: Mapped[str] = mapped_column(StringUUID) account_id: Mapped[str] = mapped_column(StringUUID)
provider: Mapped[str] = mapped_column(String(16)) provider: Mapped[str] = mapped_column(String(16))
open_id: Mapped[str] = mapped_column(String(255)) open_id: Mapped[str] = mapped_column(String(255))
encrypted_token: Mapped[str] = mapped_column(String(255)) encrypted_token: Mapped[str] = mapped_column(String(255))
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp()) created_at: Mapped[datetime] = mapped_column(
updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp()) DateTime, server_default=func.current_timestamp(), nullable=False, init=False
)
updated_at: Mapped[datetime] = mapped_column(
DateTime, server_default=func.current_timestamp(), nullable=False, init=False
)
class InvitationCode(Base): class InvitationCode(TypeBase):
__tablename__ = "invitation_codes" __tablename__ = "invitation_codes"
__table_args__ = ( __table_args__ = (
sa.PrimaryKeyConstraint("id", name="invitation_code_pkey"), sa.PrimaryKeyConstraint("id", name="invitation_code_pkey"),
@ -301,18 +322,22 @@ class InvitationCode(Base):
sa.Index("invitation_codes_code_idx", "code", "status"), sa.Index("invitation_codes_code_idx", "code", "status"),
) )
id: Mapped[int] = mapped_column(sa.Integer) id: Mapped[int] = mapped_column(sa.Integer, init=False)
batch: Mapped[str] = mapped_column(String(255)) batch: Mapped[str] = mapped_column(String(255))
code: Mapped[str] = mapped_column(String(32)) code: Mapped[str] = mapped_column(String(32))
status: Mapped[str] = mapped_column(String(16), server_default=sa.text("'unused'::character varying")) status: Mapped[str] = mapped_column(
used_at: Mapped[datetime | None] = mapped_column(DateTime) String(16), server_default=sa.text("'unused'::character varying"), default="unused"
used_by_tenant_id: Mapped[str | None] = mapped_column(StringUUID) )
used_by_account_id: Mapped[str | None] = mapped_column(StringUUID) used_at: Mapped[datetime | None] = mapped_column(DateTime, default=None)
deprecated_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) used_by_tenant_id: Mapped[str | None] = mapped_column(StringUUID, default=None)
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=sa.text("CURRENT_TIMESTAMP(0)")) used_by_account_id: Mapped[str | None] = mapped_column(StringUUID, default=None)
deprecated_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True, default=None)
created_at: Mapped[datetime] = mapped_column(
DateTime, server_default=sa.text("CURRENT_TIMESTAMP(0)"), nullable=False, init=False
)
class TenantPluginPermission(Base): class TenantPluginPermission(TypeBase):
class InstallPermission(enum.StrEnum): class InstallPermission(enum.StrEnum):
EVERYONE = "everyone" EVERYONE = "everyone"
ADMINS = "admins" ADMINS = "admins"
@ -329,13 +354,17 @@ class TenantPluginPermission(Base):
sa.UniqueConstraint("tenant_id", name="unique_tenant_plugin"), sa.UniqueConstraint("tenant_id", name="unique_tenant_plugin"),
) )
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
install_permission: Mapped[InstallPermission] = mapped_column(String(16), nullable=False, server_default="everyone") install_permission: Mapped[InstallPermission] = mapped_column(
debug_permission: Mapped[DebugPermission] = mapped_column(String(16), nullable=False, server_default="noone") String(16), nullable=False, server_default="everyone", default=InstallPermission.EVERYONE
)
debug_permission: Mapped[DebugPermission] = mapped_column(
String(16), nullable=False, server_default="noone", default=DebugPermission.NOBODY
)
class TenantPluginAutoUpgradeStrategy(Base): class TenantPluginAutoUpgradeStrategy(TypeBase):
class StrategySetting(enum.StrEnum): class StrategySetting(enum.StrEnum):
DISABLED = "disabled" DISABLED = "disabled"
FIX_ONLY = "fix_only" FIX_ONLY = "fix_only"
@ -352,12 +381,20 @@ class TenantPluginAutoUpgradeStrategy(Base):
sa.UniqueConstraint("tenant_id", name="unique_tenant_plugin_auto_upgrade_strategy"), sa.UniqueConstraint("tenant_id", name="unique_tenant_plugin_auto_upgrade_strategy"),
) )
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
strategy_setting: Mapped[StrategySetting] = mapped_column(String(16), nullable=False, server_default="fix_only") strategy_setting: Mapped[StrategySetting] = mapped_column(
upgrade_time_of_day: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0) # seconds of the day String(16), nullable=False, server_default="fix_only", default=StrategySetting.FIX_ONLY
upgrade_mode: Mapped[UpgradeMode] = mapped_column(String(16), nullable=False, server_default="exclude") )
exclude_plugins: Mapped[list[str]] = mapped_column(sa.ARRAY(String(255)), nullable=False) # plugin_id (author/name) upgrade_mode: Mapped[UpgradeMode] = mapped_column(
include_plugins: Mapped[list[str]] = mapped_column(sa.ARRAY(String(255)), nullable=False) # plugin_id (author/name) String(16), nullable=False, server_default="exclude", default=UpgradeMode.EXCLUDE
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) )
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) exclude_plugins: Mapped[list[str]] = mapped_column(sa.ARRAY(String(255)), nullable=False, default_factory=list)
include_plugins: Mapped[list[str]] = mapped_column(sa.ARRAY(String(255)), nullable=False, default_factory=list)
upgrade_time_of_day: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0)
created_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
)
updated_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
)

View File

@ -246,10 +246,8 @@ class AccountService:
) )
) )
account = Account() password_to_set = None
account.email = email salt_to_set = None
account.name = name
if password: if password:
valid_password(password) valid_password(password)
@ -261,14 +259,18 @@ class AccountService:
password_hashed = hash_password(password, salt) password_hashed = hash_password(password, salt)
base64_password_hashed = base64.b64encode(password_hashed).decode() base64_password_hashed = base64.b64encode(password_hashed).decode()
account.password = base64_password_hashed password_to_set = base64_password_hashed
account.password_salt = base64_salt salt_to_set = base64_salt
account.interface_language = interface_language account = Account(
account.interface_theme = interface_theme name=name,
email=email,
# Set timezone based on language password=password_to_set,
account.timezone = language_timezone_mapping.get(interface_language, "UTC") password_salt=salt_to_set,
interface_language=interface_language,
interface_theme=interface_theme,
timezone=language_timezone_mapping.get(interface_language, "UTC"),
)
db.session.add(account) db.session.add(account)
db.session.commit() db.session.commit()

View File

@ -33,17 +33,19 @@ class TestChatMessageApiPermissions:
@pytest.fixture @pytest.fixture
def mock_account(self, monkeypatch: pytest.MonkeyPatch): def mock_account(self, monkeypatch: pytest.MonkeyPatch):
"""Create a mock Account for testing.""" """Create a mock Account for testing."""
account = Account()
account.id = str(uuid.uuid4()) account = Account(
account.name = "Test User" name="Test User",
account.email = "test@example.com" email="test@example.com",
)
account.last_active_at = naive_utc_now() account.last_active_at = naive_utc_now()
account.created_at = naive_utc_now() account.created_at = naive_utc_now()
account.updated_at = naive_utc_now() account.updated_at = naive_utc_now()
account.id = str(uuid.uuid4())
tenant = Tenant() # Create mock tenant
tenant = Tenant(name="Test Tenant")
tenant.id = str(uuid.uuid4()) tenant.id = str(uuid.uuid4())
tenant.name = "Test Tenant"
mock_session_instance = mock.Mock() mock_session_instance = mock.Mock()

View File

@ -32,17 +32,16 @@ class TestModelConfigResourcePermissions:
@pytest.fixture @pytest.fixture
def mock_account(self, monkeypatch: pytest.MonkeyPatch): def mock_account(self, monkeypatch: pytest.MonkeyPatch):
"""Create a mock Account for testing.""" """Create a mock Account for testing."""
account = Account()
account = Account(name="Test User", email="test@example.com")
account.id = str(uuid.uuid4()) account.id = str(uuid.uuid4())
account.name = "Test User"
account.email = "test@example.com"
account.last_active_at = naive_utc_now() account.last_active_at = naive_utc_now()
account.created_at = naive_utc_now() account.created_at = naive_utc_now()
account.updated_at = naive_utc_now() account.updated_at = naive_utc_now()
tenant = Tenant() # Create mock tenant
tenant = Tenant(name="Test Tenant")
tenant.id = str(uuid.uuid4()) tenant.id = str(uuid.uuid4())
tenant.name = "Test Tenant"
mock_session_instance = mock.Mock() mock_session_instance = mock.Mock()

View File

@ -16,6 +16,7 @@ from services.errors.account import (
AccountPasswordError, AccountPasswordError,
AccountRegisterError, AccountRegisterError,
CurrentPasswordIncorrectError, CurrentPasswordIncorrectError,
TenantNotFoundError,
) )
from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkspacesLimitExceededError from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkspacesLimitExceededError
@ -1414,7 +1415,7 @@ class TestTenantService:
) )
# Try to get current tenant (should fail) # Try to get current tenant (should fail)
with pytest.raises(AttributeError): with pytest.raises((AttributeError, TenantNotFoundError)):
TenantService.get_current_tenant_by_account(account) TenantService.get_current_tenant_by_account(account)
def test_switch_tenant_success(self, db_session_with_containers, mock_external_service_dependencies): def test_switch_tenant_success(self, db_session_with_containers, mock_external_service_dependencies):

View File

@ -44,27 +44,26 @@ class TestWorkflowService:
Account: Created test account instance Account: Created test account instance
""" """
fake = fake or Faker() fake = fake or Faker()
account = Account() account = Account(
account.id = fake.uuid4() email=fake.email(),
account.email = fake.email() name=fake.name(),
account.name = fake.name() avatar=fake.url(),
account.avatar_url = fake.url() status="active",
account.tenant_id = fake.uuid4() interface_language="en-US", # Set interface language for Site creation
account.status = "active" )
account.type = "normal"
account.role = "owner"
account.interface_language = "en-US" # Set interface language for Site creation
account.created_at = fake.date_time_this_year() account.created_at = fake.date_time_this_year()
account.id = fake.uuid4()
account.updated_at = account.created_at account.updated_at = account.created_at
# Create a tenant for the account # Create a tenant for the account
from models.account import Tenant from models.account import Tenant
tenant = Tenant() tenant = Tenant(
tenant.id = account.tenant_id name=f"Test Tenant {fake.company()}",
tenant.name = f"Test Tenant {fake.company()}" plan="basic",
tenant.plan = "basic" status="active",
tenant.status = "active" )
tenant.id = account.current_tenant_id
tenant.created_at = fake.date_time_this_year() tenant.created_at = fake.date_time_this_year()
tenant.updated_at = tenant.created_at tenant.updated_at = tenant.created_at
@ -91,20 +90,21 @@ class TestWorkflowService:
App: Created test app instance App: Created test app instance
""" """
fake = fake or Faker() fake = fake or Faker()
app = App() app = App(
app.id = fake.uuid4() id=fake.uuid4(),
app.tenant_id = fake.uuid4() tenant_id=fake.uuid4(),
app.name = fake.company() name=fake.company(),
app.description = fake.text() description=fake.text(),
app.mode = AppMode.WORKFLOW mode=AppMode.WORKFLOW,
app.icon_type = "emoji" icon_type="emoji",
app.icon = "🤖" icon="🤖",
app.icon_background = "#FFEAD5" icon_background="#FFEAD5",
app.enable_site = True enable_site=True,
app.enable_api = True enable_api=True,
app.created_by = fake.uuid4() created_by=fake.uuid4(),
workflow_id=None, # Will be set when workflow is created
)
app.updated_by = app.created_by app.updated_by = app.created_by
app.workflow_id = None # Will be set when workflow is created
from extensions.ext_database import db from extensions.ext_database import db
@ -126,19 +126,20 @@ class TestWorkflowService:
Workflow: Created test workflow instance Workflow: Created test workflow instance
""" """
fake = fake or Faker() fake = fake or Faker()
workflow = Workflow() workflow = Workflow(
workflow.id = fake.uuid4() id=fake.uuid4(),
workflow.tenant_id = app.tenant_id tenant_id=app.tenant_id,
workflow.app_id = app.id app_id=app.id,
workflow.type = WorkflowType.WORKFLOW.value type=WorkflowType.WORKFLOW.value,
workflow.version = Workflow.VERSION_DRAFT version=Workflow.VERSION_DRAFT,
workflow.graph = json.dumps({"nodes": [], "edges": []}) graph=json.dumps({"nodes": [], "edges": []}),
workflow.features = json.dumps({"features": []}) features=json.dumps({"features": []}),
# unique_hash is a computed property based on graph and features # unique_hash is a computed property based on graph and features
workflow.created_by = account.id created_by=account.id,
workflow.updated_by = account.id updated_by=account.id,
workflow.environment_variables = [] environment_variables=[],
workflow.conversation_variables = [] conversation_variables=[],
)
from extensions.ext_database import db from extensions.ext_database import db

View File

@ -48,11 +48,8 @@ class TestDeleteSegmentFromIndexTask:
Tenant: Created test tenant instance Tenant: Created test tenant instance
""" """
fake = fake or Faker() fake = fake or Faker()
tenant = Tenant() tenant = Tenant(name=f"Test Tenant {fake.company()}", plan="basic", status="active")
tenant.id = fake.uuid4() tenant.id = fake.uuid4()
tenant.name = f"Test Tenant {fake.company()}"
tenant.plan = "basic"
tenant.status = "active"
tenant.created_at = fake.date_time_this_year() tenant.created_at = fake.date_time_this_year()
tenant.updated_at = tenant.created_at tenant.updated_at = tenant.created_at
@ -73,16 +70,14 @@ class TestDeleteSegmentFromIndexTask:
Account: Created test account instance Account: Created test account instance
""" """
fake = fake or Faker() fake = fake or Faker()
account = Account() account = Account(
name=fake.name(),
email=fake.email(),
avatar=fake.url(),
status="active",
interface_language="en-US",
)
account.id = fake.uuid4() account.id = fake.uuid4()
account.email = fake.email()
account.name = fake.name()
account.avatar_url = fake.url()
account.tenant_id = tenant.id
account.status = "active"
account.type = "normal"
account.role = "owner"
account.interface_language = "en-US"
account.created_at = fake.date_time_this_year() account.created_at = fake.date_time_this_year()
account.updated_at = account.created_at account.updated_at = account.created_at

View File

@ -43,27 +43,30 @@ class TestDisableSegmentsFromIndexTask:
Account: Created test account instance Account: Created test account instance
""" """
fake = fake or Faker() fake = fake or Faker()
account = Account() account = Account(
email=fake.email(),
name=fake.name(),
avatar=fake.url(),
status="active",
interface_language="en-US",
)
account.id = fake.uuid4() account.id = fake.uuid4()
account.email = fake.email() # monkey-patch attributes for test setup
account.name = fake.name()
account.avatar_url = fake.url()
account.tenant_id = fake.uuid4() account.tenant_id = fake.uuid4()
account.status = "active"
account.type = "normal" account.type = "normal"
account.role = "owner" account.role = "owner"
account.interface_language = "en-US"
account.created_at = fake.date_time_this_year() account.created_at = fake.date_time_this_year()
account.updated_at = account.created_at account.updated_at = account.created_at
# Create a tenant for the account # Create a tenant for the account
from models.account import Tenant from models.account import Tenant
tenant = Tenant() tenant = Tenant(
name=f"Test Tenant {fake.company()}",
plan="basic",
status="active",
)
tenant.id = account.tenant_id tenant.id = account.tenant_id
tenant.name = f"Test Tenant {fake.company()}"
tenant.plan = "basic"
tenant.status = "active"
tenant.created_at = fake.date_time_this_year() tenant.created_at = fake.date_time_this_year()
tenant.updated_at = tenant.created_at tenant.updated_at = tenant.created_at
@ -91,20 +94,21 @@ class TestDisableSegmentsFromIndexTask:
Dataset: Created test dataset instance Dataset: Created test dataset instance
""" """
fake = fake or Faker() fake = fake or Faker()
dataset = Dataset() dataset = Dataset(
dataset.id = fake.uuid4() id=fake.uuid4(),
dataset.tenant_id = account.tenant_id tenant_id=account.tenant_id,
dataset.name = f"Test Dataset {fake.word()}" name=f"Test Dataset {fake.word()}",
dataset.description = fake.text(max_nb_chars=200) description=fake.text(max_nb_chars=200),
dataset.provider = "vendor" provider="vendor",
dataset.permission = "only_me" permission="only_me",
dataset.data_source_type = "upload_file" data_source_type="upload_file",
dataset.indexing_technique = "high_quality" indexing_technique="high_quality",
dataset.created_by = account.id created_by=account.id,
dataset.updated_by = account.id updated_by=account.id,
dataset.embedding_model = "text-embedding-ada-002" embedding_model="text-embedding-ada-002",
dataset.embedding_model_provider = "openai" embedding_model_provider="openai",
dataset.built_in_field_enabled = False built_in_field_enabled=False,
)
from extensions.ext_database import db from extensions.ext_database import db
@ -128,6 +132,7 @@ class TestDisableSegmentsFromIndexTask:
""" """
fake = fake or Faker() fake = fake or Faker()
document = DatasetDocument() document = DatasetDocument()
document.id = fake.uuid4() document.id = fake.uuid4()
document.tenant_id = dataset.tenant_id document.tenant_id = dataset.tenant_id
document.dataset_id = dataset.id document.dataset_id = dataset.id
@ -153,7 +158,6 @@ class TestDisableSegmentsFromIndexTask:
document.archived = False document.archived = False
document.doc_form = "text_model" # Use text_model form for testing document.doc_form = "text_model" # Use text_model form for testing
document.doc_language = "en" document.doc_language = "en"
from extensions.ext_database import db from extensions.ext_database import db
db.session.add(document) db.session.add(document)

View File

@ -96,9 +96,9 @@ class TestMailInviteMemberTask:
password=fake.password(), password=fake.password(),
interface_language="en-US", interface_language="en-US",
status=AccountStatus.ACTIVE.value, status=AccountStatus.ACTIVE.value,
created_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
) )
account.created_at = datetime.now(UTC)
account.updated_at = datetime.now(UTC)
db_session_with_containers.add(account) db_session_with_containers.add(account)
db_session_with_containers.commit() db_session_with_containers.commit()
db_session_with_containers.refresh(account) db_session_with_containers.refresh(account)
@ -106,9 +106,9 @@ class TestMailInviteMemberTask:
# Create tenant # Create tenant
tenant = Tenant( tenant = Tenant(
name=fake.company(), name=fake.company(),
created_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
) )
tenant.created_at = datetime.now(UTC)
tenant.updated_at = datetime.now(UTC)
db_session_with_containers.add(tenant) db_session_with_containers.add(tenant)
db_session_with_containers.commit() db_session_with_containers.commit()
db_session_with_containers.refresh(tenant) db_session_with_containers.refresh(tenant)
@ -118,8 +118,8 @@ class TestMailInviteMemberTask:
tenant_id=tenant.id, tenant_id=tenant.id,
account_id=account.id, account_id=account.id,
role=TenantAccountRole.OWNER.value, role=TenantAccountRole.OWNER.value,
created_at=datetime.now(UTC),
) )
tenant_join.created_at = datetime.now(UTC)
db_session_with_containers.add(tenant_join) db_session_with_containers.add(tenant_join)
db_session_with_containers.commit() db_session_with_containers.commit()
@ -164,9 +164,10 @@ class TestMailInviteMemberTask:
password="", password="",
interface_language="en-US", interface_language="en-US",
status=AccountStatus.PENDING.value, status=AccountStatus.PENDING.value,
created_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
) )
account.created_at = datetime.now(UTC)
account.updated_at = datetime.now(UTC)
db_session_with_containers.add(account) db_session_with_containers.add(account)
db_session_with_containers.commit() db_session_with_containers.commit()
db_session_with_containers.refresh(account) db_session_with_containers.refresh(account)
@ -176,8 +177,8 @@ class TestMailInviteMemberTask:
tenant_id=tenant.id, tenant_id=tenant.id,
account_id=account.id, account_id=account.id,
role=TenantAccountRole.NORMAL.value, role=TenantAccountRole.NORMAL.value,
created_at=datetime.now(UTC),
) )
tenant_join.created_at = datetime.now(UTC)
db_session_with_containers.add(tenant_join) db_session_with_containers.add(tenant_join)
db_session_with_containers.commit() db_session_with_containers.commit()

View File

@ -11,7 +11,7 @@ class TestExtractTenantId:
def test_extract_tenant_id_from_account_with_tenant(self): def test_extract_tenant_id_from_account_with_tenant(self):
"""Test extracting tenant_id from Account with current_tenant_id.""" """Test extracting tenant_id from Account with current_tenant_id."""
# Create a mock Account object # Create a mock Account object
account = Account() account = Account(name="test", email="test@example.com")
# Mock the current_tenant_id property # Mock the current_tenant_id property
account._current_tenant = type("MockTenant", (), {"id": "account-tenant-123"})() account._current_tenant = type("MockTenant", (), {"id": "account-tenant-123"})()
@ -21,7 +21,7 @@ class TestExtractTenantId:
def test_extract_tenant_id_from_account_without_tenant(self): def test_extract_tenant_id_from_account_without_tenant(self):
"""Test extracting tenant_id from Account without current_tenant_id.""" """Test extracting tenant_id from Account without current_tenant_id."""
# Create a mock Account object # Create a mock Account object
account = Account() account = Account(name="test", email="test@example.com")
account._current_tenant = None account._current_tenant = None
tenant_id = extract_tenant_id(account) tenant_id = extract_tenant_id(account)

View File

@ -59,12 +59,11 @@ def session():
@pytest.fixture @pytest.fixture
def mock_user(): def mock_user():
"""Create a user instance for testing.""" """Create a user instance for testing."""
user = Account() user = Account(name="test", email="test@example.com")
user.id = "test-user-id" user.id = "test-user-id"
tenant = Tenant() tenant = Tenant(name="Test Workspace")
tenant.id = "test-tenant" tenant.id = "test-tenant"
tenant.name = "Test Workspace"
user._current_tenant = MagicMock() user._current_tenant = MagicMock()
user._current_tenant.id = "test-tenant" user._current_tenant.id = "test-tenant"

View File

@ -47,7 +47,8 @@ class TestDraftVariableSaver:
def test__should_variable_be_visible(self): def test__should_variable_be_visible(self):
mock_session = MagicMock(spec=Session) mock_session = MagicMock(spec=Session)
mock_user = Account(id=str(uuid.uuid4())) mock_user = Account(name="test", email="test@example.com")
mock_user.id = str(uuid.uuid4())
test_app_id = self._get_test_app_id() test_app_id = self._get_test_app_id()
saver = DraftVariableSaver( saver = DraftVariableSaver(
session=mock_session, session=mock_session,