mirror of https://github.com/langgenius/dify.git
Refactor account models to use SQLAlchemy 2.0 dataclass mapping (#26415)
Co-authored-by: google-labs-jules[bot] <161369871+google-labs-jules[bot]@users.noreply.github.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
parent
2b6882bd97
commit
8a2b208299
|
|
@ -1,15 +1,16 @@
|
|||
import enum
|
||||
import json
|
||||
from dataclasses import field
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional
|
||||
|
||||
import sqlalchemy as sa
|
||||
from flask_login import UserMixin # type: ignore[import-untyped]
|
||||
from sqlalchemy import DateTime, String, func, select
|
||||
from sqlalchemy.orm import Mapped, Session, mapped_column, reconstructor
|
||||
from sqlalchemy.orm import Mapped, Session, mapped_column
|
||||
from typing_extensions import deprecated
|
||||
|
||||
from models.base import Base
|
||||
from models.base import TypeBase
|
||||
|
||||
from .engine import db
|
||||
from .types import StringUUID
|
||||
|
|
@ -83,31 +84,37 @@ class AccountStatus(enum.StrEnum):
|
|||
CLOSED = "closed"
|
||||
|
||||
|
||||
class Account(UserMixin, Base):
|
||||
class Account(UserMixin, TypeBase):
|
||||
__tablename__ = "accounts"
|
||||
__table_args__ = (sa.PrimaryKeyConstraint("id", name="account_pkey"), sa.Index("account_email_idx", "email"))
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False)
|
||||
name: Mapped[str] = mapped_column(String(255))
|
||||
email: Mapped[str] = mapped_column(String(255))
|
||||
password: Mapped[str | None] = mapped_column(String(255))
|
||||
password_salt: Mapped[str | None] = mapped_column(String(255))
|
||||
avatar: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||
interface_language: Mapped[str | None] = mapped_column(String(255))
|
||||
interface_theme: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||
timezone: Mapped[str | None] = mapped_column(String(255))
|
||||
last_login_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
||||
last_login_ip: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||
last_active_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), nullable=False)
|
||||
status: Mapped[str] = mapped_column(String(16), server_default=sa.text("'active'::character varying"))
|
||||
initialized_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), nullable=False)
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), nullable=False)
|
||||
password: Mapped[str | None] = mapped_column(String(255), default=None)
|
||||
password_salt: Mapped[str | None] = mapped_column(String(255), default=None)
|
||||
avatar: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None)
|
||||
interface_language: Mapped[str | None] = mapped_column(String(255), default=None)
|
||||
interface_theme: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None)
|
||||
timezone: Mapped[str | None] = mapped_column(String(255), default=None)
|
||||
last_login_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True, default=None)
|
||||
last_login_ip: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None)
|
||||
last_active_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, server_default=func.current_timestamp(), nullable=False, init=False
|
||||
)
|
||||
status: Mapped[str] = mapped_column(
|
||||
String(16), server_default=sa.text("'active'::character varying"), default="active"
|
||||
)
|
||||
initialized_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True, default=None)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, server_default=func.current_timestamp(), nullable=False, init=False
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, server_default=func.current_timestamp(), nullable=False, init=False
|
||||
)
|
||||
|
||||
@reconstructor
|
||||
def init_on_load(self):
|
||||
self.role: TenantAccountRole | None = None
|
||||
self._current_tenant: Tenant | None = None
|
||||
role: TenantAccountRole | None = field(default=None, init=False)
|
||||
_current_tenant: "Tenant | None" = field(default=None, init=False)
|
||||
|
||||
@property
|
||||
def is_password_set(self):
|
||||
|
|
@ -226,18 +233,24 @@ class TenantStatus(enum.StrEnum):
|
|||
ARCHIVE = "archive"
|
||||
|
||||
|
||||
class Tenant(Base):
|
||||
class Tenant(TypeBase):
|
||||
__tablename__ = "tenants"
|
||||
__table_args__ = (sa.PrimaryKeyConstraint("id", name="tenant_pkey"),)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False)
|
||||
name: Mapped[str] = mapped_column(String(255))
|
||||
encrypt_public_key: Mapped[str | None] = mapped_column(sa.Text)
|
||||
plan: Mapped[str] = mapped_column(String(255), server_default=sa.text("'basic'::character varying"))
|
||||
status: Mapped[str] = mapped_column(String(255), server_default=sa.text("'normal'::character varying"))
|
||||
custom_config: Mapped[str | None] = mapped_column(sa.Text)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), nullable=False)
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp())
|
||||
encrypt_public_key: Mapped[str | None] = mapped_column(sa.Text, default=None)
|
||||
plan: Mapped[str] = mapped_column(
|
||||
String(255), server_default=sa.text("'basic'::character varying"), default="basic"
|
||||
)
|
||||
status: Mapped[str] = mapped_column(
|
||||
String(255), server_default=sa.text("'normal'::character varying"), default="normal"
|
||||
)
|
||||
custom_config: Mapped[str | None] = mapped_column(sa.Text, default=None)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, server_default=func.current_timestamp(), nullable=False, init=False
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), init=False)
|
||||
|
||||
def get_accounts(self) -> list[Account]:
|
||||
return list(
|
||||
|
|
@ -257,7 +270,7 @@ class Tenant(Base):
|
|||
self.custom_config = json.dumps(value)
|
||||
|
||||
|
||||
class TenantAccountJoin(Base):
|
||||
class TenantAccountJoin(TypeBase):
|
||||
__tablename__ = "tenant_account_joins"
|
||||
__table_args__ = (
|
||||
sa.PrimaryKeyConstraint("id", name="tenant_account_join_pkey"),
|
||||
|
|
@ -266,17 +279,21 @@ class TenantAccountJoin(Base):
|
|||
sa.UniqueConstraint("tenant_id", "account_id", name="unique_tenant_account_join"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False)
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID)
|
||||
account_id: Mapped[str] = mapped_column(StringUUID)
|
||||
current: Mapped[bool] = mapped_column(sa.Boolean, server_default=sa.text("false"))
|
||||
role: Mapped[str] = mapped_column(String(16), server_default="normal")
|
||||
invited_by: Mapped[str | None] = mapped_column(StringUUID)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp())
|
||||
current: Mapped[bool] = mapped_column(sa.Boolean, server_default=sa.text("false"), default=False)
|
||||
role: Mapped[str] = mapped_column(String(16), server_default="normal", default="normal")
|
||||
invited_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, server_default=func.current_timestamp(), nullable=False, init=False
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, server_default=func.current_timestamp(), nullable=False, init=False
|
||||
)
|
||||
|
||||
|
||||
class AccountIntegrate(Base):
|
||||
class AccountIntegrate(TypeBase):
|
||||
__tablename__ = "account_integrates"
|
||||
__table_args__ = (
|
||||
sa.PrimaryKeyConstraint("id", name="account_integrate_pkey"),
|
||||
|
|
@ -284,16 +301,20 @@ class AccountIntegrate(Base):
|
|||
sa.UniqueConstraint("provider", "open_id", name="unique_provider_open_id"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False)
|
||||
account_id: Mapped[str] = mapped_column(StringUUID)
|
||||
provider: Mapped[str] = mapped_column(String(16))
|
||||
open_id: Mapped[str] = mapped_column(String(255))
|
||||
encrypted_token: Mapped[str] = mapped_column(String(255))
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp())
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, server_default=func.current_timestamp(), nullable=False, init=False
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, server_default=func.current_timestamp(), nullable=False, init=False
|
||||
)
|
||||
|
||||
|
||||
class InvitationCode(Base):
|
||||
class InvitationCode(TypeBase):
|
||||
__tablename__ = "invitation_codes"
|
||||
__table_args__ = (
|
||||
sa.PrimaryKeyConstraint("id", name="invitation_code_pkey"),
|
||||
|
|
@ -301,18 +322,22 @@ class InvitationCode(Base):
|
|||
sa.Index("invitation_codes_code_idx", "code", "status"),
|
||||
)
|
||||
|
||||
id: Mapped[int] = mapped_column(sa.Integer)
|
||||
id: Mapped[int] = mapped_column(sa.Integer, init=False)
|
||||
batch: Mapped[str] = mapped_column(String(255))
|
||||
code: Mapped[str] = mapped_column(String(32))
|
||||
status: Mapped[str] = mapped_column(String(16), server_default=sa.text("'unused'::character varying"))
|
||||
used_at: Mapped[datetime | None] = mapped_column(DateTime)
|
||||
used_by_tenant_id: Mapped[str | None] = mapped_column(StringUUID)
|
||||
used_by_account_id: Mapped[str | None] = mapped_column(StringUUID)
|
||||
deprecated_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=sa.text("CURRENT_TIMESTAMP(0)"))
|
||||
status: Mapped[str] = mapped_column(
|
||||
String(16), server_default=sa.text("'unused'::character varying"), default="unused"
|
||||
)
|
||||
used_at: Mapped[datetime | None] = mapped_column(DateTime, default=None)
|
||||
used_by_tenant_id: Mapped[str | None] = mapped_column(StringUUID, default=None)
|
||||
used_by_account_id: Mapped[str | None] = mapped_column(StringUUID, default=None)
|
||||
deprecated_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True, default=None)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, server_default=sa.text("CURRENT_TIMESTAMP(0)"), nullable=False, init=False
|
||||
)
|
||||
|
||||
|
||||
class TenantPluginPermission(Base):
|
||||
class TenantPluginPermission(TypeBase):
|
||||
class InstallPermission(enum.StrEnum):
|
||||
EVERYONE = "everyone"
|
||||
ADMINS = "admins"
|
||||
|
|
@ -329,13 +354,17 @@ class TenantPluginPermission(Base):
|
|||
sa.UniqueConstraint("tenant_id", name="unique_tenant_plugin"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False)
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
install_permission: Mapped[InstallPermission] = mapped_column(String(16), nullable=False, server_default="everyone")
|
||||
debug_permission: Mapped[DebugPermission] = mapped_column(String(16), nullable=False, server_default="noone")
|
||||
install_permission: Mapped[InstallPermission] = mapped_column(
|
||||
String(16), nullable=False, server_default="everyone", default=InstallPermission.EVERYONE
|
||||
)
|
||||
debug_permission: Mapped[DebugPermission] = mapped_column(
|
||||
String(16), nullable=False, server_default="noone", default=DebugPermission.NOBODY
|
||||
)
|
||||
|
||||
|
||||
class TenantPluginAutoUpgradeStrategy(Base):
|
||||
class TenantPluginAutoUpgradeStrategy(TypeBase):
|
||||
class StrategySetting(enum.StrEnum):
|
||||
DISABLED = "disabled"
|
||||
FIX_ONLY = "fix_only"
|
||||
|
|
@ -352,12 +381,20 @@ class TenantPluginAutoUpgradeStrategy(Base):
|
|||
sa.UniqueConstraint("tenant_id", name="unique_tenant_plugin_auto_upgrade_strategy"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False)
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
strategy_setting: Mapped[StrategySetting] = mapped_column(String(16), nullable=False, server_default="fix_only")
|
||||
upgrade_time_of_day: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0) # seconds of the day
|
||||
upgrade_mode: Mapped[UpgradeMode] = mapped_column(String(16), nullable=False, server_default="exclude")
|
||||
exclude_plugins: Mapped[list[str]] = mapped_column(sa.ARRAY(String(255)), nullable=False) # plugin_id (author/name)
|
||||
include_plugins: Mapped[list[str]] = mapped_column(sa.ARRAY(String(255)), nullable=False) # plugin_id (author/name)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
strategy_setting: Mapped[StrategySetting] = mapped_column(
|
||||
String(16), nullable=False, server_default="fix_only", default=StrategySetting.FIX_ONLY
|
||||
)
|
||||
upgrade_mode: Mapped[UpgradeMode] = mapped_column(
|
||||
String(16), nullable=False, server_default="exclude", default=UpgradeMode.EXCLUDE
|
||||
)
|
||||
exclude_plugins: Mapped[list[str]] = mapped_column(sa.ARRAY(String(255)), nullable=False, default_factory=list)
|
||||
include_plugins: Mapped[list[str]] = mapped_column(sa.ARRAY(String(255)), nullable=False, default_factory=list)
|
||||
upgrade_time_of_day: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
|
||||
)
|
||||
|
|
|
|||
|
|
@ -246,10 +246,8 @@ class AccountService:
|
|||
)
|
||||
)
|
||||
|
||||
account = Account()
|
||||
account.email = email
|
||||
account.name = name
|
||||
|
||||
password_to_set = None
|
||||
salt_to_set = None
|
||||
if password:
|
||||
valid_password(password)
|
||||
|
||||
|
|
@ -261,14 +259,18 @@ class AccountService:
|
|||
password_hashed = hash_password(password, salt)
|
||||
base64_password_hashed = base64.b64encode(password_hashed).decode()
|
||||
|
||||
account.password = base64_password_hashed
|
||||
account.password_salt = base64_salt
|
||||
password_to_set = base64_password_hashed
|
||||
salt_to_set = base64_salt
|
||||
|
||||
account.interface_language = interface_language
|
||||
account.interface_theme = interface_theme
|
||||
|
||||
# Set timezone based on language
|
||||
account.timezone = language_timezone_mapping.get(interface_language, "UTC")
|
||||
account = Account(
|
||||
name=name,
|
||||
email=email,
|
||||
password=password_to_set,
|
||||
password_salt=salt_to_set,
|
||||
interface_language=interface_language,
|
||||
interface_theme=interface_theme,
|
||||
timezone=language_timezone_mapping.get(interface_language, "UTC"),
|
||||
)
|
||||
|
||||
db.session.add(account)
|
||||
db.session.commit()
|
||||
|
|
|
|||
|
|
@ -33,17 +33,19 @@ class TestChatMessageApiPermissions:
|
|||
@pytest.fixture
|
||||
def mock_account(self, monkeypatch: pytest.MonkeyPatch):
|
||||
"""Create a mock Account for testing."""
|
||||
account = Account()
|
||||
account.id = str(uuid.uuid4())
|
||||
account.name = "Test User"
|
||||
account.email = "test@example.com"
|
||||
|
||||
account = Account(
|
||||
name="Test User",
|
||||
email="test@example.com",
|
||||
)
|
||||
account.last_active_at = naive_utc_now()
|
||||
account.created_at = naive_utc_now()
|
||||
account.updated_at = naive_utc_now()
|
||||
account.id = str(uuid.uuid4())
|
||||
|
||||
tenant = Tenant()
|
||||
# Create mock tenant
|
||||
tenant = Tenant(name="Test Tenant")
|
||||
tenant.id = str(uuid.uuid4())
|
||||
tenant.name = "Test Tenant"
|
||||
|
||||
mock_session_instance = mock.Mock()
|
||||
|
||||
|
|
|
|||
|
|
@ -32,17 +32,16 @@ class TestModelConfigResourcePermissions:
|
|||
@pytest.fixture
|
||||
def mock_account(self, monkeypatch: pytest.MonkeyPatch):
|
||||
"""Create a mock Account for testing."""
|
||||
account = Account()
|
||||
|
||||
account = Account(name="Test User", email="test@example.com")
|
||||
account.id = str(uuid.uuid4())
|
||||
account.name = "Test User"
|
||||
account.email = "test@example.com"
|
||||
account.last_active_at = naive_utc_now()
|
||||
account.created_at = naive_utc_now()
|
||||
account.updated_at = naive_utc_now()
|
||||
|
||||
tenant = Tenant()
|
||||
# Create mock tenant
|
||||
tenant = Tenant(name="Test Tenant")
|
||||
tenant.id = str(uuid.uuid4())
|
||||
tenant.name = "Test Tenant"
|
||||
|
||||
mock_session_instance = mock.Mock()
|
||||
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ from services.errors.account import (
|
|||
AccountPasswordError,
|
||||
AccountRegisterError,
|
||||
CurrentPasswordIncorrectError,
|
||||
TenantNotFoundError,
|
||||
)
|
||||
from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkspacesLimitExceededError
|
||||
|
||||
|
|
@ -1414,7 +1415,7 @@ class TestTenantService:
|
|||
)
|
||||
|
||||
# Try to get current tenant (should fail)
|
||||
with pytest.raises(AttributeError):
|
||||
with pytest.raises((AttributeError, TenantNotFoundError)):
|
||||
TenantService.get_current_tenant_by_account(account)
|
||||
|
||||
def test_switch_tenant_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
|
|
|
|||
|
|
@ -44,27 +44,26 @@ class TestWorkflowService:
|
|||
Account: Created test account instance
|
||||
"""
|
||||
fake = fake or Faker()
|
||||
account = Account()
|
||||
account.id = fake.uuid4()
|
||||
account.email = fake.email()
|
||||
account.name = fake.name()
|
||||
account.avatar_url = fake.url()
|
||||
account.tenant_id = fake.uuid4()
|
||||
account.status = "active"
|
||||
account.type = "normal"
|
||||
account.role = "owner"
|
||||
account.interface_language = "en-US" # Set interface language for Site creation
|
||||
account = Account(
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
avatar=fake.url(),
|
||||
status="active",
|
||||
interface_language="en-US", # Set interface language for Site creation
|
||||
)
|
||||
account.created_at = fake.date_time_this_year()
|
||||
account.id = fake.uuid4()
|
||||
account.updated_at = account.created_at
|
||||
|
||||
# Create a tenant for the account
|
||||
from models.account import Tenant
|
||||
|
||||
tenant = Tenant()
|
||||
tenant.id = account.tenant_id
|
||||
tenant.name = f"Test Tenant {fake.company()}"
|
||||
tenant.plan = "basic"
|
||||
tenant.status = "active"
|
||||
tenant = Tenant(
|
||||
name=f"Test Tenant {fake.company()}",
|
||||
plan="basic",
|
||||
status="active",
|
||||
)
|
||||
tenant.id = account.current_tenant_id
|
||||
tenant.created_at = fake.date_time_this_year()
|
||||
tenant.updated_at = tenant.created_at
|
||||
|
||||
|
|
@ -91,20 +90,21 @@ class TestWorkflowService:
|
|||
App: Created test app instance
|
||||
"""
|
||||
fake = fake or Faker()
|
||||
app = App()
|
||||
app.id = fake.uuid4()
|
||||
app.tenant_id = fake.uuid4()
|
||||
app.name = fake.company()
|
||||
app.description = fake.text()
|
||||
app.mode = AppMode.WORKFLOW
|
||||
app.icon_type = "emoji"
|
||||
app.icon = "🤖"
|
||||
app.icon_background = "#FFEAD5"
|
||||
app.enable_site = True
|
||||
app.enable_api = True
|
||||
app.created_by = fake.uuid4()
|
||||
app = App(
|
||||
id=fake.uuid4(),
|
||||
tenant_id=fake.uuid4(),
|
||||
name=fake.company(),
|
||||
description=fake.text(),
|
||||
mode=AppMode.WORKFLOW,
|
||||
icon_type="emoji",
|
||||
icon="🤖",
|
||||
icon_background="#FFEAD5",
|
||||
enable_site=True,
|
||||
enable_api=True,
|
||||
created_by=fake.uuid4(),
|
||||
workflow_id=None, # Will be set when workflow is created
|
||||
)
|
||||
app.updated_by = app.created_by
|
||||
app.workflow_id = None # Will be set when workflow is created
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
|
|
@ -126,19 +126,20 @@ class TestWorkflowService:
|
|||
Workflow: Created test workflow instance
|
||||
"""
|
||||
fake = fake or Faker()
|
||||
workflow = Workflow()
|
||||
workflow.id = fake.uuid4()
|
||||
workflow.tenant_id = app.tenant_id
|
||||
workflow.app_id = app.id
|
||||
workflow.type = WorkflowType.WORKFLOW.value
|
||||
workflow.version = Workflow.VERSION_DRAFT
|
||||
workflow.graph = json.dumps({"nodes": [], "edges": []})
|
||||
workflow.features = json.dumps({"features": []})
|
||||
workflow = Workflow(
|
||||
id=fake.uuid4(),
|
||||
tenant_id=app.tenant_id,
|
||||
app_id=app.id,
|
||||
type=WorkflowType.WORKFLOW.value,
|
||||
version=Workflow.VERSION_DRAFT,
|
||||
graph=json.dumps({"nodes": [], "edges": []}),
|
||||
features=json.dumps({"features": []}),
|
||||
# unique_hash is a computed property based on graph and features
|
||||
workflow.created_by = account.id
|
||||
workflow.updated_by = account.id
|
||||
workflow.environment_variables = []
|
||||
workflow.conversation_variables = []
|
||||
created_by=account.id,
|
||||
updated_by=account.id,
|
||||
environment_variables=[],
|
||||
conversation_variables=[],
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
|
|
|
|||
|
|
@ -48,11 +48,8 @@ class TestDeleteSegmentFromIndexTask:
|
|||
Tenant: Created test tenant instance
|
||||
"""
|
||||
fake = fake or Faker()
|
||||
tenant = Tenant()
|
||||
tenant = Tenant(name=f"Test Tenant {fake.company()}", plan="basic", status="active")
|
||||
tenant.id = fake.uuid4()
|
||||
tenant.name = f"Test Tenant {fake.company()}"
|
||||
tenant.plan = "basic"
|
||||
tenant.status = "active"
|
||||
tenant.created_at = fake.date_time_this_year()
|
||||
tenant.updated_at = tenant.created_at
|
||||
|
||||
|
|
@ -73,16 +70,14 @@ class TestDeleteSegmentFromIndexTask:
|
|||
Account: Created test account instance
|
||||
"""
|
||||
fake = fake or Faker()
|
||||
account = Account()
|
||||
account = Account(
|
||||
name=fake.name(),
|
||||
email=fake.email(),
|
||||
avatar=fake.url(),
|
||||
status="active",
|
||||
interface_language="en-US",
|
||||
)
|
||||
account.id = fake.uuid4()
|
||||
account.email = fake.email()
|
||||
account.name = fake.name()
|
||||
account.avatar_url = fake.url()
|
||||
account.tenant_id = tenant.id
|
||||
account.status = "active"
|
||||
account.type = "normal"
|
||||
account.role = "owner"
|
||||
account.interface_language = "en-US"
|
||||
account.created_at = fake.date_time_this_year()
|
||||
account.updated_at = account.created_at
|
||||
|
||||
|
|
|
|||
|
|
@ -43,27 +43,30 @@ class TestDisableSegmentsFromIndexTask:
|
|||
Account: Created test account instance
|
||||
"""
|
||||
fake = fake or Faker()
|
||||
account = Account()
|
||||
account = Account(
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
avatar=fake.url(),
|
||||
status="active",
|
||||
interface_language="en-US",
|
||||
)
|
||||
account.id = fake.uuid4()
|
||||
account.email = fake.email()
|
||||
account.name = fake.name()
|
||||
account.avatar_url = fake.url()
|
||||
# monkey-patch attributes for test setup
|
||||
account.tenant_id = fake.uuid4()
|
||||
account.status = "active"
|
||||
account.type = "normal"
|
||||
account.role = "owner"
|
||||
account.interface_language = "en-US"
|
||||
account.created_at = fake.date_time_this_year()
|
||||
account.updated_at = account.created_at
|
||||
|
||||
# Create a tenant for the account
|
||||
from models.account import Tenant
|
||||
|
||||
tenant = Tenant()
|
||||
tenant = Tenant(
|
||||
name=f"Test Tenant {fake.company()}",
|
||||
plan="basic",
|
||||
status="active",
|
||||
)
|
||||
tenant.id = account.tenant_id
|
||||
tenant.name = f"Test Tenant {fake.company()}"
|
||||
tenant.plan = "basic"
|
||||
tenant.status = "active"
|
||||
tenant.created_at = fake.date_time_this_year()
|
||||
tenant.updated_at = tenant.created_at
|
||||
|
||||
|
|
@ -91,20 +94,21 @@ class TestDisableSegmentsFromIndexTask:
|
|||
Dataset: Created test dataset instance
|
||||
"""
|
||||
fake = fake or Faker()
|
||||
dataset = Dataset()
|
||||
dataset.id = fake.uuid4()
|
||||
dataset.tenant_id = account.tenant_id
|
||||
dataset.name = f"Test Dataset {fake.word()}"
|
||||
dataset.description = fake.text(max_nb_chars=200)
|
||||
dataset.provider = "vendor"
|
||||
dataset.permission = "only_me"
|
||||
dataset.data_source_type = "upload_file"
|
||||
dataset.indexing_technique = "high_quality"
|
||||
dataset.created_by = account.id
|
||||
dataset.updated_by = account.id
|
||||
dataset.embedding_model = "text-embedding-ada-002"
|
||||
dataset.embedding_model_provider = "openai"
|
||||
dataset.built_in_field_enabled = False
|
||||
dataset = Dataset(
|
||||
id=fake.uuid4(),
|
||||
tenant_id=account.tenant_id,
|
||||
name=f"Test Dataset {fake.word()}",
|
||||
description=fake.text(max_nb_chars=200),
|
||||
provider="vendor",
|
||||
permission="only_me",
|
||||
data_source_type="upload_file",
|
||||
indexing_technique="high_quality",
|
||||
created_by=account.id,
|
||||
updated_by=account.id,
|
||||
embedding_model="text-embedding-ada-002",
|
||||
embedding_model_provider="openai",
|
||||
built_in_field_enabled=False,
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
|
|
@ -128,6 +132,7 @@ class TestDisableSegmentsFromIndexTask:
|
|||
"""
|
||||
fake = fake or Faker()
|
||||
document = DatasetDocument()
|
||||
|
||||
document.id = fake.uuid4()
|
||||
document.tenant_id = dataset.tenant_id
|
||||
document.dataset_id = dataset.id
|
||||
|
|
@ -153,7 +158,6 @@ class TestDisableSegmentsFromIndexTask:
|
|||
document.archived = False
|
||||
document.doc_form = "text_model" # Use text_model form for testing
|
||||
document.doc_language = "en"
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(document)
|
||||
|
|
|
|||
|
|
@ -96,9 +96,9 @@ class TestMailInviteMemberTask:
|
|||
password=fake.password(),
|
||||
interface_language="en-US",
|
||||
status=AccountStatus.ACTIVE.value,
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
account.created_at = datetime.now(UTC)
|
||||
account.updated_at = datetime.now(UTC)
|
||||
db_session_with_containers.add(account)
|
||||
db_session_with_containers.commit()
|
||||
db_session_with_containers.refresh(account)
|
||||
|
|
@ -106,9 +106,9 @@ class TestMailInviteMemberTask:
|
|||
# Create tenant
|
||||
tenant = Tenant(
|
||||
name=fake.company(),
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
tenant.created_at = datetime.now(UTC)
|
||||
tenant.updated_at = datetime.now(UTC)
|
||||
db_session_with_containers.add(tenant)
|
||||
db_session_with_containers.commit()
|
||||
db_session_with_containers.refresh(tenant)
|
||||
|
|
@ -118,8 +118,8 @@ class TestMailInviteMemberTask:
|
|||
tenant_id=tenant.id,
|
||||
account_id=account.id,
|
||||
role=TenantAccountRole.OWNER.value,
|
||||
created_at=datetime.now(UTC),
|
||||
)
|
||||
tenant_join.created_at = datetime.now(UTC)
|
||||
db_session_with_containers.add(tenant_join)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
|
|
@ -164,9 +164,10 @@ class TestMailInviteMemberTask:
|
|||
password="",
|
||||
interface_language="en-US",
|
||||
status=AccountStatus.PENDING.value,
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
|
||||
account.created_at = datetime.now(UTC)
|
||||
account.updated_at = datetime.now(UTC)
|
||||
db_session_with_containers.add(account)
|
||||
db_session_with_containers.commit()
|
||||
db_session_with_containers.refresh(account)
|
||||
|
|
@ -176,8 +177,8 @@ class TestMailInviteMemberTask:
|
|||
tenant_id=tenant.id,
|
||||
account_id=account.id,
|
||||
role=TenantAccountRole.NORMAL.value,
|
||||
created_at=datetime.now(UTC),
|
||||
)
|
||||
tenant_join.created_at = datetime.now(UTC)
|
||||
db_session_with_containers.add(tenant_join)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ class TestExtractTenantId:
|
|||
def test_extract_tenant_id_from_account_with_tenant(self):
|
||||
"""Test extracting tenant_id from Account with current_tenant_id."""
|
||||
# Create a mock Account object
|
||||
account = Account()
|
||||
account = Account(name="test", email="test@example.com")
|
||||
# Mock the current_tenant_id property
|
||||
account._current_tenant = type("MockTenant", (), {"id": "account-tenant-123"})()
|
||||
|
||||
|
|
@ -21,7 +21,7 @@ class TestExtractTenantId:
|
|||
def test_extract_tenant_id_from_account_without_tenant(self):
|
||||
"""Test extracting tenant_id from Account without current_tenant_id."""
|
||||
# Create a mock Account object
|
||||
account = Account()
|
||||
account = Account(name="test", email="test@example.com")
|
||||
account._current_tenant = None
|
||||
|
||||
tenant_id = extract_tenant_id(account)
|
||||
|
|
|
|||
|
|
@ -59,12 +59,11 @@ def session():
|
|||
@pytest.fixture
|
||||
def mock_user():
|
||||
"""Create a user instance for testing."""
|
||||
user = Account()
|
||||
user = Account(name="test", email="test@example.com")
|
||||
user.id = "test-user-id"
|
||||
|
||||
tenant = Tenant()
|
||||
tenant = Tenant(name="Test Workspace")
|
||||
tenant.id = "test-tenant"
|
||||
tenant.name = "Test Workspace"
|
||||
user._current_tenant = MagicMock()
|
||||
user._current_tenant.id = "test-tenant"
|
||||
|
||||
|
|
|
|||
|
|
@ -47,7 +47,8 @@ class TestDraftVariableSaver:
|
|||
|
||||
def test__should_variable_be_visible(self):
|
||||
mock_session = MagicMock(spec=Session)
|
||||
mock_user = Account(id=str(uuid.uuid4()))
|
||||
mock_user = Account(name="test", email="test@example.com")
|
||||
mock_user.id = str(uuid.uuid4())
|
||||
test_app_id = self._get_test_app_id()
|
||||
saver = DraftVariableSaver(
|
||||
session=mock_session,
|
||||
|
|
|
|||
Loading…
Reference in New Issue