mirror of https://github.com/langgenius/dify.git
Refactor account models to use SQLAlchemy 2.0 dataclass mapping (#26415)
Co-authored-by: google-labs-jules[bot] <161369871+google-labs-jules[bot]@users.noreply.github.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
parent
2b6882bd97
commit
8a2b208299
|
|
@ -1,15 +1,16 @@
|
||||||
import enum
|
import enum
|
||||||
import json
|
import json
|
||||||
|
from dataclasses import field
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
from flask_login import UserMixin # type: ignore[import-untyped]
|
from flask_login import UserMixin # type: ignore[import-untyped]
|
||||||
from sqlalchemy import DateTime, String, func, select
|
from sqlalchemy import DateTime, String, func, select
|
||||||
from sqlalchemy.orm import Mapped, Session, mapped_column, reconstructor
|
from sqlalchemy.orm import Mapped, Session, mapped_column
|
||||||
from typing_extensions import deprecated
|
from typing_extensions import deprecated
|
||||||
|
|
||||||
from models.base import Base
|
from models.base import TypeBase
|
||||||
|
|
||||||
from .engine import db
|
from .engine import db
|
||||||
from .types import StringUUID
|
from .types import StringUUID
|
||||||
|
|
@ -83,31 +84,37 @@ class AccountStatus(enum.StrEnum):
|
||||||
CLOSED = "closed"
|
CLOSED = "closed"
|
||||||
|
|
||||||
|
|
||||||
class Account(UserMixin, Base):
|
class Account(UserMixin, TypeBase):
|
||||||
__tablename__ = "accounts"
|
__tablename__ = "accounts"
|
||||||
__table_args__ = (sa.PrimaryKeyConstraint("id", name="account_pkey"), sa.Index("account_email_idx", "email"))
|
__table_args__ = (sa.PrimaryKeyConstraint("id", name="account_pkey"), sa.Index("account_email_idx", "email"))
|
||||||
|
|
||||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False)
|
||||||
name: Mapped[str] = mapped_column(String(255))
|
name: Mapped[str] = mapped_column(String(255))
|
||||||
email: Mapped[str] = mapped_column(String(255))
|
email: Mapped[str] = mapped_column(String(255))
|
||||||
password: Mapped[str | None] = mapped_column(String(255))
|
password: Mapped[str | None] = mapped_column(String(255), default=None)
|
||||||
password_salt: Mapped[str | None] = mapped_column(String(255))
|
password_salt: Mapped[str | None] = mapped_column(String(255), default=None)
|
||||||
avatar: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
avatar: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None)
|
||||||
interface_language: Mapped[str | None] = mapped_column(String(255))
|
interface_language: Mapped[str | None] = mapped_column(String(255), default=None)
|
||||||
interface_theme: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
interface_theme: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None)
|
||||||
timezone: Mapped[str | None] = mapped_column(String(255))
|
timezone: Mapped[str | None] = mapped_column(String(255), default=None)
|
||||||
last_login_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
last_login_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True, default=None)
|
||||||
last_login_ip: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
last_login_ip: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None)
|
||||||
last_active_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), nullable=False)
|
last_active_at: Mapped[datetime] = mapped_column(
|
||||||
status: Mapped[str] = mapped_column(String(16), server_default=sa.text("'active'::character varying"))
|
DateTime, server_default=func.current_timestamp(), nullable=False, init=False
|
||||||
initialized_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
)
|
||||||
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), nullable=False)
|
status: Mapped[str] = mapped_column(
|
||||||
updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), nullable=False)
|
String(16), server_default=sa.text("'active'::character varying"), default="active"
|
||||||
|
)
|
||||||
|
initialized_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True, default=None)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime, server_default=func.current_timestamp(), nullable=False, init=False
|
||||||
|
)
|
||||||
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime, server_default=func.current_timestamp(), nullable=False, init=False
|
||||||
|
)
|
||||||
|
|
||||||
@reconstructor
|
role: TenantAccountRole | None = field(default=None, init=False)
|
||||||
def init_on_load(self):
|
_current_tenant: "Tenant | None" = field(default=None, init=False)
|
||||||
self.role: TenantAccountRole | None = None
|
|
||||||
self._current_tenant: Tenant | None = None
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_password_set(self):
|
def is_password_set(self):
|
||||||
|
|
@ -226,18 +233,24 @@ class TenantStatus(enum.StrEnum):
|
||||||
ARCHIVE = "archive"
|
ARCHIVE = "archive"
|
||||||
|
|
||||||
|
|
||||||
class Tenant(Base):
|
class Tenant(TypeBase):
|
||||||
__tablename__ = "tenants"
|
__tablename__ = "tenants"
|
||||||
__table_args__ = (sa.PrimaryKeyConstraint("id", name="tenant_pkey"),)
|
__table_args__ = (sa.PrimaryKeyConstraint("id", name="tenant_pkey"),)
|
||||||
|
|
||||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False)
|
||||||
name: Mapped[str] = mapped_column(String(255))
|
name: Mapped[str] = mapped_column(String(255))
|
||||||
encrypt_public_key: Mapped[str | None] = mapped_column(sa.Text)
|
encrypt_public_key: Mapped[str | None] = mapped_column(sa.Text, default=None)
|
||||||
plan: Mapped[str] = mapped_column(String(255), server_default=sa.text("'basic'::character varying"))
|
plan: Mapped[str] = mapped_column(
|
||||||
status: Mapped[str] = mapped_column(String(255), server_default=sa.text("'normal'::character varying"))
|
String(255), server_default=sa.text("'basic'::character varying"), default="basic"
|
||||||
custom_config: Mapped[str | None] = mapped_column(sa.Text)
|
)
|
||||||
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), nullable=False)
|
status: Mapped[str] = mapped_column(
|
||||||
updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp())
|
String(255), server_default=sa.text("'normal'::character varying"), default="normal"
|
||||||
|
)
|
||||||
|
custom_config: Mapped[str | None] = mapped_column(sa.Text, default=None)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime, server_default=func.current_timestamp(), nullable=False, init=False
|
||||||
|
)
|
||||||
|
updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), init=False)
|
||||||
|
|
||||||
def get_accounts(self) -> list[Account]:
|
def get_accounts(self) -> list[Account]:
|
||||||
return list(
|
return list(
|
||||||
|
|
@ -257,7 +270,7 @@ class Tenant(Base):
|
||||||
self.custom_config = json.dumps(value)
|
self.custom_config = json.dumps(value)
|
||||||
|
|
||||||
|
|
||||||
class TenantAccountJoin(Base):
|
class TenantAccountJoin(TypeBase):
|
||||||
__tablename__ = "tenant_account_joins"
|
__tablename__ = "tenant_account_joins"
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
sa.PrimaryKeyConstraint("id", name="tenant_account_join_pkey"),
|
sa.PrimaryKeyConstraint("id", name="tenant_account_join_pkey"),
|
||||||
|
|
@ -266,17 +279,21 @@ class TenantAccountJoin(Base):
|
||||||
sa.UniqueConstraint("tenant_id", "account_id", name="unique_tenant_account_join"),
|
sa.UniqueConstraint("tenant_id", "account_id", name="unique_tenant_account_join"),
|
||||||
)
|
)
|
||||||
|
|
||||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False)
|
||||||
tenant_id: Mapped[str] = mapped_column(StringUUID)
|
tenant_id: Mapped[str] = mapped_column(StringUUID)
|
||||||
account_id: Mapped[str] = mapped_column(StringUUID)
|
account_id: Mapped[str] = mapped_column(StringUUID)
|
||||||
current: Mapped[bool] = mapped_column(sa.Boolean, server_default=sa.text("false"))
|
current: Mapped[bool] = mapped_column(sa.Boolean, server_default=sa.text("false"), default=False)
|
||||||
role: Mapped[str] = mapped_column(String(16), server_default="normal")
|
role: Mapped[str] = mapped_column(String(16), server_default="normal", default="normal")
|
||||||
invited_by: Mapped[str | None] = mapped_column(StringUUID)
|
invited_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
|
||||||
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp())
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp())
|
DateTime, server_default=func.current_timestamp(), nullable=False, init=False
|
||||||
|
)
|
||||||
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime, server_default=func.current_timestamp(), nullable=False, init=False
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class AccountIntegrate(Base):
|
class AccountIntegrate(TypeBase):
|
||||||
__tablename__ = "account_integrates"
|
__tablename__ = "account_integrates"
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
sa.PrimaryKeyConstraint("id", name="account_integrate_pkey"),
|
sa.PrimaryKeyConstraint("id", name="account_integrate_pkey"),
|
||||||
|
|
@ -284,16 +301,20 @@ class AccountIntegrate(Base):
|
||||||
sa.UniqueConstraint("provider", "open_id", name="unique_provider_open_id"),
|
sa.UniqueConstraint("provider", "open_id", name="unique_provider_open_id"),
|
||||||
)
|
)
|
||||||
|
|
||||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False)
|
||||||
account_id: Mapped[str] = mapped_column(StringUUID)
|
account_id: Mapped[str] = mapped_column(StringUUID)
|
||||||
provider: Mapped[str] = mapped_column(String(16))
|
provider: Mapped[str] = mapped_column(String(16))
|
||||||
open_id: Mapped[str] = mapped_column(String(255))
|
open_id: Mapped[str] = mapped_column(String(255))
|
||||||
encrypted_token: Mapped[str] = mapped_column(String(255))
|
encrypted_token: Mapped[str] = mapped_column(String(255))
|
||||||
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp())
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp())
|
DateTime, server_default=func.current_timestamp(), nullable=False, init=False
|
||||||
|
)
|
||||||
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime, server_default=func.current_timestamp(), nullable=False, init=False
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class InvitationCode(Base):
|
class InvitationCode(TypeBase):
|
||||||
__tablename__ = "invitation_codes"
|
__tablename__ = "invitation_codes"
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
sa.PrimaryKeyConstraint("id", name="invitation_code_pkey"),
|
sa.PrimaryKeyConstraint("id", name="invitation_code_pkey"),
|
||||||
|
|
@ -301,18 +322,22 @@ class InvitationCode(Base):
|
||||||
sa.Index("invitation_codes_code_idx", "code", "status"),
|
sa.Index("invitation_codes_code_idx", "code", "status"),
|
||||||
)
|
)
|
||||||
|
|
||||||
id: Mapped[int] = mapped_column(sa.Integer)
|
id: Mapped[int] = mapped_column(sa.Integer, init=False)
|
||||||
batch: Mapped[str] = mapped_column(String(255))
|
batch: Mapped[str] = mapped_column(String(255))
|
||||||
code: Mapped[str] = mapped_column(String(32))
|
code: Mapped[str] = mapped_column(String(32))
|
||||||
status: Mapped[str] = mapped_column(String(16), server_default=sa.text("'unused'::character varying"))
|
status: Mapped[str] = mapped_column(
|
||||||
used_at: Mapped[datetime | None] = mapped_column(DateTime)
|
String(16), server_default=sa.text("'unused'::character varying"), default="unused"
|
||||||
used_by_tenant_id: Mapped[str | None] = mapped_column(StringUUID)
|
)
|
||||||
used_by_account_id: Mapped[str | None] = mapped_column(StringUUID)
|
used_at: Mapped[datetime | None] = mapped_column(DateTime, default=None)
|
||||||
deprecated_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
used_by_tenant_id: Mapped[str | None] = mapped_column(StringUUID, default=None)
|
||||||
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=sa.text("CURRENT_TIMESTAMP(0)"))
|
used_by_account_id: Mapped[str | None] = mapped_column(StringUUID, default=None)
|
||||||
|
deprecated_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True, default=None)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime, server_default=sa.text("CURRENT_TIMESTAMP(0)"), nullable=False, init=False
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TenantPluginPermission(Base):
|
class TenantPluginPermission(TypeBase):
|
||||||
class InstallPermission(enum.StrEnum):
|
class InstallPermission(enum.StrEnum):
|
||||||
EVERYONE = "everyone"
|
EVERYONE = "everyone"
|
||||||
ADMINS = "admins"
|
ADMINS = "admins"
|
||||||
|
|
@ -329,13 +354,17 @@ class TenantPluginPermission(Base):
|
||||||
sa.UniqueConstraint("tenant_id", name="unique_tenant_plugin"),
|
sa.UniqueConstraint("tenant_id", name="unique_tenant_plugin"),
|
||||||
)
|
)
|
||||||
|
|
||||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False)
|
||||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||||
install_permission: Mapped[InstallPermission] = mapped_column(String(16), nullable=False, server_default="everyone")
|
install_permission: Mapped[InstallPermission] = mapped_column(
|
||||||
debug_permission: Mapped[DebugPermission] = mapped_column(String(16), nullable=False, server_default="noone")
|
String(16), nullable=False, server_default="everyone", default=InstallPermission.EVERYONE
|
||||||
|
)
|
||||||
|
debug_permission: Mapped[DebugPermission] = mapped_column(
|
||||||
|
String(16), nullable=False, server_default="noone", default=DebugPermission.NOBODY
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TenantPluginAutoUpgradeStrategy(Base):
|
class TenantPluginAutoUpgradeStrategy(TypeBase):
|
||||||
class StrategySetting(enum.StrEnum):
|
class StrategySetting(enum.StrEnum):
|
||||||
DISABLED = "disabled"
|
DISABLED = "disabled"
|
||||||
FIX_ONLY = "fix_only"
|
FIX_ONLY = "fix_only"
|
||||||
|
|
@ -352,12 +381,20 @@ class TenantPluginAutoUpgradeStrategy(Base):
|
||||||
sa.UniqueConstraint("tenant_id", name="unique_tenant_plugin_auto_upgrade_strategy"),
|
sa.UniqueConstraint("tenant_id", name="unique_tenant_plugin_auto_upgrade_strategy"),
|
||||||
)
|
)
|
||||||
|
|
||||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False)
|
||||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||||
strategy_setting: Mapped[StrategySetting] = mapped_column(String(16), nullable=False, server_default="fix_only")
|
strategy_setting: Mapped[StrategySetting] = mapped_column(
|
||||||
upgrade_time_of_day: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0) # seconds of the day
|
String(16), nullable=False, server_default="fix_only", default=StrategySetting.FIX_ONLY
|
||||||
upgrade_mode: Mapped[UpgradeMode] = mapped_column(String(16), nullable=False, server_default="exclude")
|
)
|
||||||
exclude_plugins: Mapped[list[str]] = mapped_column(sa.ARRAY(String(255)), nullable=False) # plugin_id (author/name)
|
upgrade_mode: Mapped[UpgradeMode] = mapped_column(
|
||||||
include_plugins: Mapped[list[str]] = mapped_column(sa.ARRAY(String(255)), nullable=False) # plugin_id (author/name)
|
String(16), nullable=False, server_default="exclude", default=UpgradeMode.EXCLUDE
|
||||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
)
|
||||||
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
exclude_plugins: Mapped[list[str]] = mapped_column(sa.ARRAY(String(255)), nullable=False, default_factory=list)
|
||||||
|
include_plugins: Mapped[list[str]] = mapped_column(sa.ARRAY(String(255)), nullable=False, default_factory=list)
|
||||||
|
upgrade_time_of_day: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
|
||||||
|
)
|
||||||
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -246,10 +246,8 @@ class AccountService:
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
account = Account()
|
password_to_set = None
|
||||||
account.email = email
|
salt_to_set = None
|
||||||
account.name = name
|
|
||||||
|
|
||||||
if password:
|
if password:
|
||||||
valid_password(password)
|
valid_password(password)
|
||||||
|
|
||||||
|
|
@ -261,14 +259,18 @@ class AccountService:
|
||||||
password_hashed = hash_password(password, salt)
|
password_hashed = hash_password(password, salt)
|
||||||
base64_password_hashed = base64.b64encode(password_hashed).decode()
|
base64_password_hashed = base64.b64encode(password_hashed).decode()
|
||||||
|
|
||||||
account.password = base64_password_hashed
|
password_to_set = base64_password_hashed
|
||||||
account.password_salt = base64_salt
|
salt_to_set = base64_salt
|
||||||
|
|
||||||
account.interface_language = interface_language
|
account = Account(
|
||||||
account.interface_theme = interface_theme
|
name=name,
|
||||||
|
email=email,
|
||||||
# Set timezone based on language
|
password=password_to_set,
|
||||||
account.timezone = language_timezone_mapping.get(interface_language, "UTC")
|
password_salt=salt_to_set,
|
||||||
|
interface_language=interface_language,
|
||||||
|
interface_theme=interface_theme,
|
||||||
|
timezone=language_timezone_mapping.get(interface_language, "UTC"),
|
||||||
|
)
|
||||||
|
|
||||||
db.session.add(account)
|
db.session.add(account)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
|
||||||
|
|
@ -33,17 +33,19 @@ class TestChatMessageApiPermissions:
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_account(self, monkeypatch: pytest.MonkeyPatch):
|
def mock_account(self, monkeypatch: pytest.MonkeyPatch):
|
||||||
"""Create a mock Account for testing."""
|
"""Create a mock Account for testing."""
|
||||||
account = Account()
|
|
||||||
account.id = str(uuid.uuid4())
|
account = Account(
|
||||||
account.name = "Test User"
|
name="Test User",
|
||||||
account.email = "test@example.com"
|
email="test@example.com",
|
||||||
|
)
|
||||||
account.last_active_at = naive_utc_now()
|
account.last_active_at = naive_utc_now()
|
||||||
account.created_at = naive_utc_now()
|
account.created_at = naive_utc_now()
|
||||||
account.updated_at = naive_utc_now()
|
account.updated_at = naive_utc_now()
|
||||||
|
account.id = str(uuid.uuid4())
|
||||||
|
|
||||||
tenant = Tenant()
|
# Create mock tenant
|
||||||
|
tenant = Tenant(name="Test Tenant")
|
||||||
tenant.id = str(uuid.uuid4())
|
tenant.id = str(uuid.uuid4())
|
||||||
tenant.name = "Test Tenant"
|
|
||||||
|
|
||||||
mock_session_instance = mock.Mock()
|
mock_session_instance = mock.Mock()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -32,17 +32,16 @@ class TestModelConfigResourcePermissions:
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_account(self, monkeypatch: pytest.MonkeyPatch):
|
def mock_account(self, monkeypatch: pytest.MonkeyPatch):
|
||||||
"""Create a mock Account for testing."""
|
"""Create a mock Account for testing."""
|
||||||
account = Account()
|
|
||||||
|
account = Account(name="Test User", email="test@example.com")
|
||||||
account.id = str(uuid.uuid4())
|
account.id = str(uuid.uuid4())
|
||||||
account.name = "Test User"
|
|
||||||
account.email = "test@example.com"
|
|
||||||
account.last_active_at = naive_utc_now()
|
account.last_active_at = naive_utc_now()
|
||||||
account.created_at = naive_utc_now()
|
account.created_at = naive_utc_now()
|
||||||
account.updated_at = naive_utc_now()
|
account.updated_at = naive_utc_now()
|
||||||
|
|
||||||
tenant = Tenant()
|
# Create mock tenant
|
||||||
|
tenant = Tenant(name="Test Tenant")
|
||||||
tenant.id = str(uuid.uuid4())
|
tenant.id = str(uuid.uuid4())
|
||||||
tenant.name = "Test Tenant"
|
|
||||||
|
|
||||||
mock_session_instance = mock.Mock()
|
mock_session_instance = mock.Mock()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -16,6 +16,7 @@ from services.errors.account import (
|
||||||
AccountPasswordError,
|
AccountPasswordError,
|
||||||
AccountRegisterError,
|
AccountRegisterError,
|
||||||
CurrentPasswordIncorrectError,
|
CurrentPasswordIncorrectError,
|
||||||
|
TenantNotFoundError,
|
||||||
)
|
)
|
||||||
from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkspacesLimitExceededError
|
from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkspacesLimitExceededError
|
||||||
|
|
||||||
|
|
@ -1414,7 +1415,7 @@ class TestTenantService:
|
||||||
)
|
)
|
||||||
|
|
||||||
# Try to get current tenant (should fail)
|
# Try to get current tenant (should fail)
|
||||||
with pytest.raises(AttributeError):
|
with pytest.raises((AttributeError, TenantNotFoundError)):
|
||||||
TenantService.get_current_tenant_by_account(account)
|
TenantService.get_current_tenant_by_account(account)
|
||||||
|
|
||||||
def test_switch_tenant_success(self, db_session_with_containers, mock_external_service_dependencies):
|
def test_switch_tenant_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||||
|
|
|
||||||
|
|
@ -44,27 +44,26 @@ class TestWorkflowService:
|
||||||
Account: Created test account instance
|
Account: Created test account instance
|
||||||
"""
|
"""
|
||||||
fake = fake or Faker()
|
fake = fake or Faker()
|
||||||
account = Account()
|
account = Account(
|
||||||
account.id = fake.uuid4()
|
email=fake.email(),
|
||||||
account.email = fake.email()
|
name=fake.name(),
|
||||||
account.name = fake.name()
|
avatar=fake.url(),
|
||||||
account.avatar_url = fake.url()
|
status="active",
|
||||||
account.tenant_id = fake.uuid4()
|
interface_language="en-US", # Set interface language for Site creation
|
||||||
account.status = "active"
|
)
|
||||||
account.type = "normal"
|
|
||||||
account.role = "owner"
|
|
||||||
account.interface_language = "en-US" # Set interface language for Site creation
|
|
||||||
account.created_at = fake.date_time_this_year()
|
account.created_at = fake.date_time_this_year()
|
||||||
|
account.id = fake.uuid4()
|
||||||
account.updated_at = account.created_at
|
account.updated_at = account.created_at
|
||||||
|
|
||||||
# Create a tenant for the account
|
# Create a tenant for the account
|
||||||
from models.account import Tenant
|
from models.account import Tenant
|
||||||
|
|
||||||
tenant = Tenant()
|
tenant = Tenant(
|
||||||
tenant.id = account.tenant_id
|
name=f"Test Tenant {fake.company()}",
|
||||||
tenant.name = f"Test Tenant {fake.company()}"
|
plan="basic",
|
||||||
tenant.plan = "basic"
|
status="active",
|
||||||
tenant.status = "active"
|
)
|
||||||
|
tenant.id = account.current_tenant_id
|
||||||
tenant.created_at = fake.date_time_this_year()
|
tenant.created_at = fake.date_time_this_year()
|
||||||
tenant.updated_at = tenant.created_at
|
tenant.updated_at = tenant.created_at
|
||||||
|
|
||||||
|
|
@ -91,20 +90,21 @@ class TestWorkflowService:
|
||||||
App: Created test app instance
|
App: Created test app instance
|
||||||
"""
|
"""
|
||||||
fake = fake or Faker()
|
fake = fake or Faker()
|
||||||
app = App()
|
app = App(
|
||||||
app.id = fake.uuid4()
|
id=fake.uuid4(),
|
||||||
app.tenant_id = fake.uuid4()
|
tenant_id=fake.uuid4(),
|
||||||
app.name = fake.company()
|
name=fake.company(),
|
||||||
app.description = fake.text()
|
description=fake.text(),
|
||||||
app.mode = AppMode.WORKFLOW
|
mode=AppMode.WORKFLOW,
|
||||||
app.icon_type = "emoji"
|
icon_type="emoji",
|
||||||
app.icon = "🤖"
|
icon="🤖",
|
||||||
app.icon_background = "#FFEAD5"
|
icon_background="#FFEAD5",
|
||||||
app.enable_site = True
|
enable_site=True,
|
||||||
app.enable_api = True
|
enable_api=True,
|
||||||
app.created_by = fake.uuid4()
|
created_by=fake.uuid4(),
|
||||||
|
workflow_id=None, # Will be set when workflow is created
|
||||||
|
)
|
||||||
app.updated_by = app.created_by
|
app.updated_by = app.created_by
|
||||||
app.workflow_id = None # Will be set when workflow is created
|
|
||||||
|
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
|
|
||||||
|
|
@ -126,19 +126,20 @@ class TestWorkflowService:
|
||||||
Workflow: Created test workflow instance
|
Workflow: Created test workflow instance
|
||||||
"""
|
"""
|
||||||
fake = fake or Faker()
|
fake = fake or Faker()
|
||||||
workflow = Workflow()
|
workflow = Workflow(
|
||||||
workflow.id = fake.uuid4()
|
id=fake.uuid4(),
|
||||||
workflow.tenant_id = app.tenant_id
|
tenant_id=app.tenant_id,
|
||||||
workflow.app_id = app.id
|
app_id=app.id,
|
||||||
workflow.type = WorkflowType.WORKFLOW.value
|
type=WorkflowType.WORKFLOW.value,
|
||||||
workflow.version = Workflow.VERSION_DRAFT
|
version=Workflow.VERSION_DRAFT,
|
||||||
workflow.graph = json.dumps({"nodes": [], "edges": []})
|
graph=json.dumps({"nodes": [], "edges": []}),
|
||||||
workflow.features = json.dumps({"features": []})
|
features=json.dumps({"features": []}),
|
||||||
# unique_hash is a computed property based on graph and features
|
# unique_hash is a computed property based on graph and features
|
||||||
workflow.created_by = account.id
|
created_by=account.id,
|
||||||
workflow.updated_by = account.id
|
updated_by=account.id,
|
||||||
workflow.environment_variables = []
|
environment_variables=[],
|
||||||
workflow.conversation_variables = []
|
conversation_variables=[],
|
||||||
|
)
|
||||||
|
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -48,11 +48,8 @@ class TestDeleteSegmentFromIndexTask:
|
||||||
Tenant: Created test tenant instance
|
Tenant: Created test tenant instance
|
||||||
"""
|
"""
|
||||||
fake = fake or Faker()
|
fake = fake or Faker()
|
||||||
tenant = Tenant()
|
tenant = Tenant(name=f"Test Tenant {fake.company()}", plan="basic", status="active")
|
||||||
tenant.id = fake.uuid4()
|
tenant.id = fake.uuid4()
|
||||||
tenant.name = f"Test Tenant {fake.company()}"
|
|
||||||
tenant.plan = "basic"
|
|
||||||
tenant.status = "active"
|
|
||||||
tenant.created_at = fake.date_time_this_year()
|
tenant.created_at = fake.date_time_this_year()
|
||||||
tenant.updated_at = tenant.created_at
|
tenant.updated_at = tenant.created_at
|
||||||
|
|
||||||
|
|
@ -73,16 +70,14 @@ class TestDeleteSegmentFromIndexTask:
|
||||||
Account: Created test account instance
|
Account: Created test account instance
|
||||||
"""
|
"""
|
||||||
fake = fake or Faker()
|
fake = fake or Faker()
|
||||||
account = Account()
|
account = Account(
|
||||||
|
name=fake.name(),
|
||||||
|
email=fake.email(),
|
||||||
|
avatar=fake.url(),
|
||||||
|
status="active",
|
||||||
|
interface_language="en-US",
|
||||||
|
)
|
||||||
account.id = fake.uuid4()
|
account.id = fake.uuid4()
|
||||||
account.email = fake.email()
|
|
||||||
account.name = fake.name()
|
|
||||||
account.avatar_url = fake.url()
|
|
||||||
account.tenant_id = tenant.id
|
|
||||||
account.status = "active"
|
|
||||||
account.type = "normal"
|
|
||||||
account.role = "owner"
|
|
||||||
account.interface_language = "en-US"
|
|
||||||
account.created_at = fake.date_time_this_year()
|
account.created_at = fake.date_time_this_year()
|
||||||
account.updated_at = account.created_at
|
account.updated_at = account.created_at
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -43,27 +43,30 @@ class TestDisableSegmentsFromIndexTask:
|
||||||
Account: Created test account instance
|
Account: Created test account instance
|
||||||
"""
|
"""
|
||||||
fake = fake or Faker()
|
fake = fake or Faker()
|
||||||
account = Account()
|
account = Account(
|
||||||
|
email=fake.email(),
|
||||||
|
name=fake.name(),
|
||||||
|
avatar=fake.url(),
|
||||||
|
status="active",
|
||||||
|
interface_language="en-US",
|
||||||
|
)
|
||||||
account.id = fake.uuid4()
|
account.id = fake.uuid4()
|
||||||
account.email = fake.email()
|
# monkey-patch attributes for test setup
|
||||||
account.name = fake.name()
|
|
||||||
account.avatar_url = fake.url()
|
|
||||||
account.tenant_id = fake.uuid4()
|
account.tenant_id = fake.uuid4()
|
||||||
account.status = "active"
|
|
||||||
account.type = "normal"
|
account.type = "normal"
|
||||||
account.role = "owner"
|
account.role = "owner"
|
||||||
account.interface_language = "en-US"
|
|
||||||
account.created_at = fake.date_time_this_year()
|
account.created_at = fake.date_time_this_year()
|
||||||
account.updated_at = account.created_at
|
account.updated_at = account.created_at
|
||||||
|
|
||||||
# Create a tenant for the account
|
# Create a tenant for the account
|
||||||
from models.account import Tenant
|
from models.account import Tenant
|
||||||
|
|
||||||
tenant = Tenant()
|
tenant = Tenant(
|
||||||
|
name=f"Test Tenant {fake.company()}",
|
||||||
|
plan="basic",
|
||||||
|
status="active",
|
||||||
|
)
|
||||||
tenant.id = account.tenant_id
|
tenant.id = account.tenant_id
|
||||||
tenant.name = f"Test Tenant {fake.company()}"
|
|
||||||
tenant.plan = "basic"
|
|
||||||
tenant.status = "active"
|
|
||||||
tenant.created_at = fake.date_time_this_year()
|
tenant.created_at = fake.date_time_this_year()
|
||||||
tenant.updated_at = tenant.created_at
|
tenant.updated_at = tenant.created_at
|
||||||
|
|
||||||
|
|
@ -91,20 +94,21 @@ class TestDisableSegmentsFromIndexTask:
|
||||||
Dataset: Created test dataset instance
|
Dataset: Created test dataset instance
|
||||||
"""
|
"""
|
||||||
fake = fake or Faker()
|
fake = fake or Faker()
|
||||||
dataset = Dataset()
|
dataset = Dataset(
|
||||||
dataset.id = fake.uuid4()
|
id=fake.uuid4(),
|
||||||
dataset.tenant_id = account.tenant_id
|
tenant_id=account.tenant_id,
|
||||||
dataset.name = f"Test Dataset {fake.word()}"
|
name=f"Test Dataset {fake.word()}",
|
||||||
dataset.description = fake.text(max_nb_chars=200)
|
description=fake.text(max_nb_chars=200),
|
||||||
dataset.provider = "vendor"
|
provider="vendor",
|
||||||
dataset.permission = "only_me"
|
permission="only_me",
|
||||||
dataset.data_source_type = "upload_file"
|
data_source_type="upload_file",
|
||||||
dataset.indexing_technique = "high_quality"
|
indexing_technique="high_quality",
|
||||||
dataset.created_by = account.id
|
created_by=account.id,
|
||||||
dataset.updated_by = account.id
|
updated_by=account.id,
|
||||||
dataset.embedding_model = "text-embedding-ada-002"
|
embedding_model="text-embedding-ada-002",
|
||||||
dataset.embedding_model_provider = "openai"
|
embedding_model_provider="openai",
|
||||||
dataset.built_in_field_enabled = False
|
built_in_field_enabled=False,
|
||||||
|
)
|
||||||
|
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
|
|
||||||
|
|
@ -128,6 +132,7 @@ class TestDisableSegmentsFromIndexTask:
|
||||||
"""
|
"""
|
||||||
fake = fake or Faker()
|
fake = fake or Faker()
|
||||||
document = DatasetDocument()
|
document = DatasetDocument()
|
||||||
|
|
||||||
document.id = fake.uuid4()
|
document.id = fake.uuid4()
|
||||||
document.tenant_id = dataset.tenant_id
|
document.tenant_id = dataset.tenant_id
|
||||||
document.dataset_id = dataset.id
|
document.dataset_id = dataset.id
|
||||||
|
|
@ -153,7 +158,6 @@ class TestDisableSegmentsFromIndexTask:
|
||||||
document.archived = False
|
document.archived = False
|
||||||
document.doc_form = "text_model" # Use text_model form for testing
|
document.doc_form = "text_model" # Use text_model form for testing
|
||||||
document.doc_language = "en"
|
document.doc_language = "en"
|
||||||
|
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
|
|
||||||
db.session.add(document)
|
db.session.add(document)
|
||||||
|
|
|
||||||
|
|
@ -96,9 +96,9 @@ class TestMailInviteMemberTask:
|
||||||
password=fake.password(),
|
password=fake.password(),
|
||||||
interface_language="en-US",
|
interface_language="en-US",
|
||||||
status=AccountStatus.ACTIVE.value,
|
status=AccountStatus.ACTIVE.value,
|
||||||
created_at=datetime.now(UTC),
|
|
||||||
updated_at=datetime.now(UTC),
|
|
||||||
)
|
)
|
||||||
|
account.created_at = datetime.now(UTC)
|
||||||
|
account.updated_at = datetime.now(UTC)
|
||||||
db_session_with_containers.add(account)
|
db_session_with_containers.add(account)
|
||||||
db_session_with_containers.commit()
|
db_session_with_containers.commit()
|
||||||
db_session_with_containers.refresh(account)
|
db_session_with_containers.refresh(account)
|
||||||
|
|
@ -106,9 +106,9 @@ class TestMailInviteMemberTask:
|
||||||
# Create tenant
|
# Create tenant
|
||||||
tenant = Tenant(
|
tenant = Tenant(
|
||||||
name=fake.company(),
|
name=fake.company(),
|
||||||
created_at=datetime.now(UTC),
|
|
||||||
updated_at=datetime.now(UTC),
|
|
||||||
)
|
)
|
||||||
|
tenant.created_at = datetime.now(UTC)
|
||||||
|
tenant.updated_at = datetime.now(UTC)
|
||||||
db_session_with_containers.add(tenant)
|
db_session_with_containers.add(tenant)
|
||||||
db_session_with_containers.commit()
|
db_session_with_containers.commit()
|
||||||
db_session_with_containers.refresh(tenant)
|
db_session_with_containers.refresh(tenant)
|
||||||
|
|
@ -118,8 +118,8 @@ class TestMailInviteMemberTask:
|
||||||
tenant_id=tenant.id,
|
tenant_id=tenant.id,
|
||||||
account_id=account.id,
|
account_id=account.id,
|
||||||
role=TenantAccountRole.OWNER.value,
|
role=TenantAccountRole.OWNER.value,
|
||||||
created_at=datetime.now(UTC),
|
|
||||||
)
|
)
|
||||||
|
tenant_join.created_at = datetime.now(UTC)
|
||||||
db_session_with_containers.add(tenant_join)
|
db_session_with_containers.add(tenant_join)
|
||||||
db_session_with_containers.commit()
|
db_session_with_containers.commit()
|
||||||
|
|
||||||
|
|
@ -164,9 +164,10 @@ class TestMailInviteMemberTask:
|
||||||
password="",
|
password="",
|
||||||
interface_language="en-US",
|
interface_language="en-US",
|
||||||
status=AccountStatus.PENDING.value,
|
status=AccountStatus.PENDING.value,
|
||||||
created_at=datetime.now(UTC),
|
|
||||||
updated_at=datetime.now(UTC),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
account.created_at = datetime.now(UTC)
|
||||||
|
account.updated_at = datetime.now(UTC)
|
||||||
db_session_with_containers.add(account)
|
db_session_with_containers.add(account)
|
||||||
db_session_with_containers.commit()
|
db_session_with_containers.commit()
|
||||||
db_session_with_containers.refresh(account)
|
db_session_with_containers.refresh(account)
|
||||||
|
|
@ -176,8 +177,8 @@ class TestMailInviteMemberTask:
|
||||||
tenant_id=tenant.id,
|
tenant_id=tenant.id,
|
||||||
account_id=account.id,
|
account_id=account.id,
|
||||||
role=TenantAccountRole.NORMAL.value,
|
role=TenantAccountRole.NORMAL.value,
|
||||||
created_at=datetime.now(UTC),
|
|
||||||
)
|
)
|
||||||
|
tenant_join.created_at = datetime.now(UTC)
|
||||||
db_session_with_containers.add(tenant_join)
|
db_session_with_containers.add(tenant_join)
|
||||||
db_session_with_containers.commit()
|
db_session_with_containers.commit()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,7 @@ class TestExtractTenantId:
|
||||||
def test_extract_tenant_id_from_account_with_tenant(self):
|
def test_extract_tenant_id_from_account_with_tenant(self):
|
||||||
"""Test extracting tenant_id from Account with current_tenant_id."""
|
"""Test extracting tenant_id from Account with current_tenant_id."""
|
||||||
# Create a mock Account object
|
# Create a mock Account object
|
||||||
account = Account()
|
account = Account(name="test", email="test@example.com")
|
||||||
# Mock the current_tenant_id property
|
# Mock the current_tenant_id property
|
||||||
account._current_tenant = type("MockTenant", (), {"id": "account-tenant-123"})()
|
account._current_tenant = type("MockTenant", (), {"id": "account-tenant-123"})()
|
||||||
|
|
||||||
|
|
@ -21,7 +21,7 @@ class TestExtractTenantId:
|
||||||
def test_extract_tenant_id_from_account_without_tenant(self):
|
def test_extract_tenant_id_from_account_without_tenant(self):
|
||||||
"""Test extracting tenant_id from Account without current_tenant_id."""
|
"""Test extracting tenant_id from Account without current_tenant_id."""
|
||||||
# Create a mock Account object
|
# Create a mock Account object
|
||||||
account = Account()
|
account = Account(name="test", email="test@example.com")
|
||||||
account._current_tenant = None
|
account._current_tenant = None
|
||||||
|
|
||||||
tenant_id = extract_tenant_id(account)
|
tenant_id = extract_tenant_id(account)
|
||||||
|
|
|
||||||
|
|
@ -59,12 +59,11 @@ def session():
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_user():
|
def mock_user():
|
||||||
"""Create a user instance for testing."""
|
"""Create a user instance for testing."""
|
||||||
user = Account()
|
user = Account(name="test", email="test@example.com")
|
||||||
user.id = "test-user-id"
|
user.id = "test-user-id"
|
||||||
|
|
||||||
tenant = Tenant()
|
tenant = Tenant(name="Test Workspace")
|
||||||
tenant.id = "test-tenant"
|
tenant.id = "test-tenant"
|
||||||
tenant.name = "Test Workspace"
|
|
||||||
user._current_tenant = MagicMock()
|
user._current_tenant = MagicMock()
|
||||||
user._current_tenant.id = "test-tenant"
|
user._current_tenant.id = "test-tenant"
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -47,7 +47,8 @@ class TestDraftVariableSaver:
|
||||||
|
|
||||||
def test__should_variable_be_visible(self):
|
def test__should_variable_be_visible(self):
|
||||||
mock_session = MagicMock(spec=Session)
|
mock_session = MagicMock(spec=Session)
|
||||||
mock_user = Account(id=str(uuid.uuid4()))
|
mock_user = Account(name="test", email="test@example.com")
|
||||||
|
mock_user.id = str(uuid.uuid4())
|
||||||
test_app_id = self._get_test_app_id()
|
test_app_id = self._get_test_app_id()
|
||||||
saver = DraftVariableSaver(
|
saver = DraftVariableSaver(
|
||||||
session=mock_session,
|
session=mock_session,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue