diff --git a/api/models/account.py b/api/models/account.py index 8c1f990aa2..86cd9e41b5 100644 --- a/api/models/account.py +++ b/api/models/account.py @@ -1,15 +1,16 @@ import enum import json +from dataclasses import field from datetime import datetime from typing import Any, Optional import sqlalchemy as sa from flask_login import UserMixin # type: ignore[import-untyped] from sqlalchemy import DateTime, String, func, select -from sqlalchemy.orm import Mapped, Session, mapped_column, reconstructor +from sqlalchemy.orm import Mapped, Session, mapped_column from typing_extensions import deprecated -from models.base import Base +from models.base import TypeBase from .engine import db from .types import StringUUID @@ -83,31 +84,37 @@ class AccountStatus(enum.StrEnum): CLOSED = "closed" -class Account(UserMixin, Base): +class Account(UserMixin, TypeBase): __tablename__ = "accounts" __table_args__ = (sa.PrimaryKeyConstraint("id", name="account_pkey"), sa.Index("account_email_idx", "email")) - id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False) name: Mapped[str] = mapped_column(String(255)) email: Mapped[str] = mapped_column(String(255)) - password: Mapped[str | None] = mapped_column(String(255)) - password_salt: Mapped[str | None] = mapped_column(String(255)) - avatar: Mapped[str | None] = mapped_column(String(255), nullable=True) - interface_language: Mapped[str | None] = mapped_column(String(255)) - interface_theme: Mapped[str | None] = mapped_column(String(255), nullable=True) - timezone: Mapped[str | None] = mapped_column(String(255)) - last_login_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) - last_login_ip: Mapped[str | None] = mapped_column(String(255), nullable=True) - last_active_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), nullable=False) - status: Mapped[str] = mapped_column(String(16), server_default=sa.text("'active'::character varying")) - initialized_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) - created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), nullable=False) - updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), nullable=False) + password: Mapped[str | None] = mapped_column(String(255), default=None) + password_salt: Mapped[str | None] = mapped_column(String(255), default=None) + avatar: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None) + interface_language: Mapped[str | None] = mapped_column(String(255), default=None) + interface_theme: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None) + timezone: Mapped[str | None] = mapped_column(String(255), default=None) + last_login_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True, default=None) + last_login_ip: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None) + last_active_at: Mapped[datetime] = mapped_column( + DateTime, server_default=func.current_timestamp(), nullable=False, init=False + ) + status: Mapped[str] = mapped_column( + String(16), server_default=sa.text("'active'::character varying"), default="active" + ) + initialized_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True, default=None) + created_at: Mapped[datetime] = mapped_column( + DateTime, server_default=func.current_timestamp(), nullable=False, init=False + ) + updated_at: Mapped[datetime] = mapped_column( + DateTime, server_default=func.current_timestamp(), nullable=False, init=False + ) - @reconstructor - def init_on_load(self): - self.role: TenantAccountRole | None = None - self._current_tenant: Tenant | None = None + role: TenantAccountRole | None = field(default=None, init=False) + _current_tenant: "Tenant | None" = field(default=None, init=False) @property def is_password_set(self): @@ -226,18 +233,24 @@ class TenantStatus(enum.StrEnum): ARCHIVE = "archive" -class Tenant(Base): +class Tenant(TypeBase): __tablename__ = "tenants" __table_args__ = (sa.PrimaryKeyConstraint("id", name="tenant_pkey"),) - id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False) name: Mapped[str] = mapped_column(String(255)) - encrypt_public_key: Mapped[str | None] = mapped_column(sa.Text) - plan: Mapped[str] = mapped_column(String(255), server_default=sa.text("'basic'::character varying")) - status: Mapped[str] = mapped_column(String(255), server_default=sa.text("'normal'::character varying")) - custom_config: Mapped[str | None] = mapped_column(sa.Text) - created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), nullable=False) - updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp()) + encrypt_public_key: Mapped[str | None] = mapped_column(sa.Text, default=None) + plan: Mapped[str] = mapped_column( + String(255), server_default=sa.text("'basic'::character varying"), default="basic" + ) + status: Mapped[str] = mapped_column( + String(255), server_default=sa.text("'normal'::character varying"), default="normal" + ) + custom_config: Mapped[str | None] = mapped_column(sa.Text, default=None) + created_at: Mapped[datetime] = mapped_column( + DateTime, server_default=func.current_timestamp(), nullable=False, init=False + ) + updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), init=False) def get_accounts(self) -> list[Account]: return list( @@ -257,7 +270,7 @@ class Tenant(Base): self.custom_config = json.dumps(value) -class TenantAccountJoin(Base): +class TenantAccountJoin(TypeBase): __tablename__ = "tenant_account_joins" __table_args__ = ( sa.PrimaryKeyConstraint("id", name="tenant_account_join_pkey"), @@ -266,17 +279,21 @@ class TenantAccountJoin(Base): sa.UniqueConstraint("tenant_id", "account_id", name="unique_tenant_account_join"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False) tenant_id: Mapped[str] = mapped_column(StringUUID) account_id: Mapped[str] = mapped_column(StringUUID) - current: Mapped[bool] = mapped_column(sa.Boolean, server_default=sa.text("false")) - role: Mapped[str] = mapped_column(String(16), server_default="normal") - invited_by: Mapped[str | None] = mapped_column(StringUUID) - created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp()) - updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp()) + current: Mapped[bool] = mapped_column(sa.Boolean, server_default=sa.text("false"), default=False) + role: Mapped[str] = mapped_column(String(16), server_default="normal", default="normal") + invited_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None) + created_at: Mapped[datetime] = mapped_column( + DateTime, server_default=func.current_timestamp(), nullable=False, init=False + ) + updated_at: Mapped[datetime] = mapped_column( + DateTime, server_default=func.current_timestamp(), nullable=False, init=False + ) -class AccountIntegrate(Base): +class AccountIntegrate(TypeBase): __tablename__ = "account_integrates" __table_args__ = ( sa.PrimaryKeyConstraint("id", name="account_integrate_pkey"), @@ -284,16 +301,20 @@ class AccountIntegrate(Base): sa.UniqueConstraint("provider", "open_id", name="unique_provider_open_id"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False) account_id: Mapped[str] = mapped_column(StringUUID) provider: Mapped[str] = mapped_column(String(16)) open_id: Mapped[str] = mapped_column(String(255)) encrypted_token: Mapped[str] = mapped_column(String(255)) - created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp()) - updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column( + DateTime, server_default=func.current_timestamp(), nullable=False, init=False + ) + updated_at: Mapped[datetime] = mapped_column( + DateTime, server_default=func.current_timestamp(), nullable=False, init=False + ) -class InvitationCode(Base): +class InvitationCode(TypeBase): __tablename__ = "invitation_codes" __table_args__ = ( sa.PrimaryKeyConstraint("id", name="invitation_code_pkey"), @@ -301,18 +322,22 @@ class InvitationCode(Base): sa.Index("invitation_codes_code_idx", "code", "status"), ) - id: Mapped[int] = mapped_column(sa.Integer) + id: Mapped[int] = mapped_column(sa.Integer, init=False) batch: Mapped[str] = mapped_column(String(255)) code: Mapped[str] = mapped_column(String(32)) - status: Mapped[str] = mapped_column(String(16), server_default=sa.text("'unused'::character varying")) - used_at: Mapped[datetime | None] = mapped_column(DateTime) - used_by_tenant_id: Mapped[str | None] = mapped_column(StringUUID) - used_by_account_id: Mapped[str | None] = mapped_column(StringUUID) - deprecated_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) - created_at: Mapped[datetime] = mapped_column(DateTime, server_default=sa.text("CURRENT_TIMESTAMP(0)")) + status: Mapped[str] = mapped_column( + String(16), server_default=sa.text("'unused'::character varying"), default="unused" + ) + used_at: Mapped[datetime | None] = mapped_column(DateTime, default=None) + used_by_tenant_id: Mapped[str | None] = mapped_column(StringUUID, default=None) + used_by_account_id: Mapped[str | None] = mapped_column(StringUUID, default=None) + deprecated_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True, default=None) + created_at: Mapped[datetime] = mapped_column( + DateTime, server_default=sa.text("CURRENT_TIMESTAMP(0)"), nullable=False, init=False + ) -class TenantPluginPermission(Base): +class TenantPluginPermission(TypeBase): class InstallPermission(enum.StrEnum): EVERYONE = "everyone" ADMINS = "admins" @@ -329,13 +354,17 @@ class TenantPluginPermission(Base): sa.UniqueConstraint("tenant_id", name="unique_tenant_plugin"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - install_permission: Mapped[InstallPermission] = mapped_column(String(16), nullable=False, server_default="everyone") - debug_permission: Mapped[DebugPermission] = mapped_column(String(16), nullable=False, server_default="noone") + install_permission: Mapped[InstallPermission] = mapped_column( + String(16), nullable=False, server_default="everyone", default=InstallPermission.EVERYONE + ) + debug_permission: Mapped[DebugPermission] = mapped_column( + String(16), nullable=False, server_default="noone", default=DebugPermission.NOBODY + ) -class TenantPluginAutoUpgradeStrategy(Base): +class TenantPluginAutoUpgradeStrategy(TypeBase): class StrategySetting(enum.StrEnum): DISABLED = "disabled" FIX_ONLY = "fix_only" @@ -352,12 +381,20 @@ class TenantPluginAutoUpgradeStrategy(Base): sa.UniqueConstraint("tenant_id", name="unique_tenant_plugin_auto_upgrade_strategy"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - strategy_setting: Mapped[StrategySetting] = mapped_column(String(16), nullable=False, server_default="fix_only") - upgrade_time_of_day: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0) # seconds of the day - upgrade_mode: Mapped[UpgradeMode] = mapped_column(String(16), nullable=False, server_default="exclude") - exclude_plugins: Mapped[list[str]] = mapped_column(sa.ARRAY(String(255)), nullable=False) # plugin_id (author/name) - include_plugins: Mapped[list[str]] = mapped_column(sa.ARRAY(String(255)), nullable=False) # plugin_id (author/name) - created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + strategy_setting: Mapped[StrategySetting] = mapped_column( + String(16), nullable=False, server_default="fix_only", default=StrategySetting.FIX_ONLY + ) + upgrade_mode: Mapped[UpgradeMode] = mapped_column( + String(16), nullable=False, server_default="exclude", default=UpgradeMode.EXCLUDE + ) + exclude_plugins: Mapped[list[str]] = mapped_column(sa.ARRAY(String(255)), nullable=False, default_factory=list) + include_plugins: Mapped[list[str]] = mapped_column(sa.ARRAY(String(255)), nullable=False, default_factory=list) + upgrade_time_of_day: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0) + created_at: Mapped[datetime] = mapped_column( + DateTime, nullable=False, server_default=func.current_timestamp(), init=False + ) + updated_at: Mapped[datetime] = mapped_column( + DateTime, nullable=False, server_default=func.current_timestamp(), init=False + ) diff --git a/api/services/account_service.py b/api/services/account_service.py index 0e699d16da..77b8744020 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -246,10 +246,8 @@ class AccountService: ) ) - account = Account() - account.email = email - account.name = name - + password_to_set = None + salt_to_set = None if password: valid_password(password) @@ -261,14 +259,18 @@ class AccountService: password_hashed = hash_password(password, salt) base64_password_hashed = base64.b64encode(password_hashed).decode() - account.password = base64_password_hashed - account.password_salt = base64_salt + password_to_set = base64_password_hashed + salt_to_set = base64_salt - account.interface_language = interface_language - account.interface_theme = interface_theme - - # Set timezone based on language - account.timezone = language_timezone_mapping.get(interface_language, "UTC") + account = Account( + name=name, + email=email, + password=password_to_set, + password_salt=salt_to_set, + interface_language=interface_language, + interface_theme=interface_theme, + timezone=language_timezone_mapping.get(interface_language, "UTC"), + ) db.session.add(account) db.session.commit() diff --git a/api/tests/integration_tests/controllers/console/app/test_chat_message_permissions.py b/api/tests/integration_tests/controllers/console/app/test_chat_message_permissions.py index da1524ff2e..4d1c1227bd 100644 --- a/api/tests/integration_tests/controllers/console/app/test_chat_message_permissions.py +++ b/api/tests/integration_tests/controllers/console/app/test_chat_message_permissions.py @@ -33,17 +33,19 @@ class TestChatMessageApiPermissions: @pytest.fixture def mock_account(self, monkeypatch: pytest.MonkeyPatch): """Create a mock Account for testing.""" - account = Account() - account.id = str(uuid.uuid4()) - account.name = "Test User" - account.email = "test@example.com" + + account = Account( + name="Test User", + email="test@example.com", + ) account.last_active_at = naive_utc_now() account.created_at = naive_utc_now() account.updated_at = naive_utc_now() + account.id = str(uuid.uuid4()) - tenant = Tenant() + # Create mock tenant + tenant = Tenant(name="Test Tenant") tenant.id = str(uuid.uuid4()) - tenant.name = "Test Tenant" mock_session_instance = mock.Mock() diff --git a/api/tests/integration_tests/controllers/console/app/test_model_config_permissions.py b/api/tests/integration_tests/controllers/console/app/test_model_config_permissions.py index c0fd56ef63..e158f26f3a 100644 --- a/api/tests/integration_tests/controllers/console/app/test_model_config_permissions.py +++ b/api/tests/integration_tests/controllers/console/app/test_model_config_permissions.py @@ -32,17 +32,16 @@ class TestModelConfigResourcePermissions: @pytest.fixture def mock_account(self, monkeypatch: pytest.MonkeyPatch): """Create a mock Account for testing.""" - account = Account() + + account = Account(name="Test User", email="test@example.com") account.id = str(uuid.uuid4()) - account.name = "Test User" - account.email = "test@example.com" account.last_active_at = naive_utc_now() account.created_at = naive_utc_now() account.updated_at = naive_utc_now() - tenant = Tenant() + # Create mock tenant + tenant = Tenant(name="Test Tenant") tenant.id = str(uuid.uuid4()) - tenant.name = "Test Tenant" mock_session_instance = mock.Mock() diff --git a/api/tests/test_containers_integration_tests/services/test_account_service.py b/api/tests/test_containers_integration_tests/services/test_account_service.py index c98406d845..0a2fb955ae 100644 --- a/api/tests/test_containers_integration_tests/services/test_account_service.py +++ b/api/tests/test_containers_integration_tests/services/test_account_service.py @@ -16,6 +16,7 @@ from services.errors.account import ( AccountPasswordError, AccountRegisterError, CurrentPasswordIncorrectError, + TenantNotFoundError, ) from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkspacesLimitExceededError @@ -1414,7 +1415,7 @@ class TestTenantService: ) # Try to get current tenant (should fail) - with pytest.raises(AttributeError): + with pytest.raises((AttributeError, TenantNotFoundError)): TenantService.get_current_tenant_by_account(account) def test_switch_tenant_success(self, db_session_with_containers, mock_external_service_dependencies): diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_service.py index 60150667ed..0dd3909ba7 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_service.py @@ -44,27 +44,26 @@ class TestWorkflowService: Account: Created test account instance """ fake = fake or Faker() - account = Account() - account.id = fake.uuid4() - account.email = fake.email() - account.name = fake.name() - account.avatar_url = fake.url() - account.tenant_id = fake.uuid4() - account.status = "active" - account.type = "normal" - account.role = "owner" - account.interface_language = "en-US" # Set interface language for Site creation + account = Account( + email=fake.email(), + name=fake.name(), + avatar=fake.url(), + status="active", + interface_language="en-US", # Set interface language for Site creation + ) account.created_at = fake.date_time_this_year() + account.id = fake.uuid4() account.updated_at = account.created_at # Create a tenant for the account from models.account import Tenant - tenant = Tenant() - tenant.id = account.tenant_id - tenant.name = f"Test Tenant {fake.company()}" - tenant.plan = "basic" - tenant.status = "active" + tenant = Tenant( + name=f"Test Tenant {fake.company()}", + plan="basic", + status="active", + ) + tenant.id = account.current_tenant_id tenant.created_at = fake.date_time_this_year() tenant.updated_at = tenant.created_at @@ -91,20 +90,21 @@ class TestWorkflowService: App: Created test app instance """ fake = fake or Faker() - app = App() - app.id = fake.uuid4() - app.tenant_id = fake.uuid4() - app.name = fake.company() - app.description = fake.text() - app.mode = AppMode.WORKFLOW - app.icon_type = "emoji" - app.icon = "🤖" - app.icon_background = "#FFEAD5" - app.enable_site = True - app.enable_api = True - app.created_by = fake.uuid4() + app = App( + id=fake.uuid4(), + tenant_id=fake.uuid4(), + name=fake.company(), + description=fake.text(), + mode=AppMode.WORKFLOW, + icon_type="emoji", + icon="🤖", + icon_background="#FFEAD5", + enable_site=True, + enable_api=True, + created_by=fake.uuid4(), + workflow_id=None, # Will be set when workflow is created + ) app.updated_by = app.created_by - app.workflow_id = None # Will be set when workflow is created from extensions.ext_database import db @@ -126,19 +126,20 @@ class TestWorkflowService: Workflow: Created test workflow instance """ fake = fake or Faker() - workflow = Workflow() - workflow.id = fake.uuid4() - workflow.tenant_id = app.tenant_id - workflow.app_id = app.id - workflow.type = WorkflowType.WORKFLOW.value - workflow.version = Workflow.VERSION_DRAFT - workflow.graph = json.dumps({"nodes": [], "edges": []}) - workflow.features = json.dumps({"features": []}) - # unique_hash is a computed property based on graph and features - workflow.created_by = account.id - workflow.updated_by = account.id - workflow.environment_variables = [] - workflow.conversation_variables = [] + workflow = Workflow( + id=fake.uuid4(), + tenant_id=app.tenant_id, + app_id=app.id, + type=WorkflowType.WORKFLOW.value, + version=Workflow.VERSION_DRAFT, + graph=json.dumps({"nodes": [], "edges": []}), + features=json.dumps({"features": []}), + # unique_hash is a computed property based on graph and features + created_by=account.id, + updated_by=account.id, + environment_variables=[], + conversation_variables=[], + ) from extensions.ext_database import db diff --git a/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py index 7af4f238be..94e9b76965 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py @@ -48,11 +48,8 @@ class TestDeleteSegmentFromIndexTask: Tenant: Created test tenant instance """ fake = fake or Faker() - tenant = Tenant() + tenant = Tenant(name=f"Test Tenant {fake.company()}", plan="basic", status="active") tenant.id = fake.uuid4() - tenant.name = f"Test Tenant {fake.company()}" - tenant.plan = "basic" - tenant.status = "active" tenant.created_at = fake.date_time_this_year() tenant.updated_at = tenant.created_at @@ -73,16 +70,14 @@ class TestDeleteSegmentFromIndexTask: Account: Created test account instance """ fake = fake or Faker() - account = Account() + account = Account( + name=fake.name(), + email=fake.email(), + avatar=fake.url(), + status="active", + interface_language="en-US", + ) account.id = fake.uuid4() - account.email = fake.email() - account.name = fake.name() - account.avatar_url = fake.url() - account.tenant_id = tenant.id - account.status = "active" - account.type = "normal" - account.role = "owner" - account.interface_language = "en-US" account.created_at = fake.date_time_this_year() account.updated_at = account.created_at diff --git a/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py index 5fdb8c617c..0b36e0914a 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py @@ -43,27 +43,30 @@ class TestDisableSegmentsFromIndexTask: Account: Created test account instance """ fake = fake or Faker() - account = Account() + account = Account( + email=fake.email(), + name=fake.name(), + avatar=fake.url(), + status="active", + interface_language="en-US", + ) account.id = fake.uuid4() - account.email = fake.email() - account.name = fake.name() - account.avatar_url = fake.url() + # monkey-patch attributes for test setup account.tenant_id = fake.uuid4() - account.status = "active" account.type = "normal" account.role = "owner" - account.interface_language = "en-US" account.created_at = fake.date_time_this_year() account.updated_at = account.created_at # Create a tenant for the account from models.account import Tenant - tenant = Tenant() + tenant = Tenant( + name=f"Test Tenant {fake.company()}", + plan="basic", + status="active", + ) tenant.id = account.tenant_id - tenant.name = f"Test Tenant {fake.company()}" - tenant.plan = "basic" - tenant.status = "active" tenant.created_at = fake.date_time_this_year() tenant.updated_at = tenant.created_at @@ -91,20 +94,21 @@ class TestDisableSegmentsFromIndexTask: Dataset: Created test dataset instance """ fake = fake or Faker() - dataset = Dataset() - dataset.id = fake.uuid4() - dataset.tenant_id = account.tenant_id - dataset.name = f"Test Dataset {fake.word()}" - dataset.description = fake.text(max_nb_chars=200) - dataset.provider = "vendor" - dataset.permission = "only_me" - dataset.data_source_type = "upload_file" - dataset.indexing_technique = "high_quality" - dataset.created_by = account.id - dataset.updated_by = account.id - dataset.embedding_model = "text-embedding-ada-002" - dataset.embedding_model_provider = "openai" - dataset.built_in_field_enabled = False + dataset = Dataset( + id=fake.uuid4(), + tenant_id=account.tenant_id, + name=f"Test Dataset {fake.word()}", + description=fake.text(max_nb_chars=200), + provider="vendor", + permission="only_me", + data_source_type="upload_file", + indexing_technique="high_quality", + created_by=account.id, + updated_by=account.id, + embedding_model="text-embedding-ada-002", + embedding_model_provider="openai", + built_in_field_enabled=False, + ) from extensions.ext_database import db @@ -128,6 +132,7 @@ class TestDisableSegmentsFromIndexTask: """ fake = fake or Faker() document = DatasetDocument() + document.id = fake.uuid4() document.tenant_id = dataset.tenant_id document.dataset_id = dataset.id @@ -153,7 +158,6 @@ class TestDisableSegmentsFromIndexTask: document.archived = False document.doc_form = "text_model" # Use text_model form for testing document.doc_language = "en" - from extensions.ext_database import db db.session.add(document) diff --git a/api/tests/test_containers_integration_tests/tasks/test_mail_invite_member_task.py b/api/tests/test_containers_integration_tests/tasks/test_mail_invite_member_task.py index 8fef87b317..ead7757c13 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_mail_invite_member_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_mail_invite_member_task.py @@ -96,9 +96,9 @@ class TestMailInviteMemberTask: password=fake.password(), interface_language="en-US", status=AccountStatus.ACTIVE.value, - created_at=datetime.now(UTC), - updated_at=datetime.now(UTC), ) + account.created_at = datetime.now(UTC) + account.updated_at = datetime.now(UTC) db_session_with_containers.add(account) db_session_with_containers.commit() db_session_with_containers.refresh(account) @@ -106,9 +106,9 @@ class TestMailInviteMemberTask: # Create tenant tenant = Tenant( name=fake.company(), - created_at=datetime.now(UTC), - updated_at=datetime.now(UTC), ) + tenant.created_at = datetime.now(UTC) + tenant.updated_at = datetime.now(UTC) db_session_with_containers.add(tenant) db_session_with_containers.commit() db_session_with_containers.refresh(tenant) @@ -118,8 +118,8 @@ class TestMailInviteMemberTask: tenant_id=tenant.id, account_id=account.id, role=TenantAccountRole.OWNER.value, - created_at=datetime.now(UTC), ) + tenant_join.created_at = datetime.now(UTC) db_session_with_containers.add(tenant_join) db_session_with_containers.commit() @@ -164,9 +164,10 @@ class TestMailInviteMemberTask: password="", interface_language="en-US", status=AccountStatus.PENDING.value, - created_at=datetime.now(UTC), - updated_at=datetime.now(UTC), ) + + account.created_at = datetime.now(UTC) + account.updated_at = datetime.now(UTC) db_session_with_containers.add(account) db_session_with_containers.commit() db_session_with_containers.refresh(account) @@ -176,8 +177,8 @@ class TestMailInviteMemberTask: tenant_id=tenant.id, account_id=account.id, role=TenantAccountRole.NORMAL.value, - created_at=datetime.now(UTC), ) + tenant_join.created_at = datetime.now(UTC) db_session_with_containers.add(tenant_join) db_session_with_containers.commit() diff --git a/api/tests/unit_tests/libs/test_helper.py b/api/tests/unit_tests/libs/test_helper.py index b7701055f5..85789bfa7e 100644 --- a/api/tests/unit_tests/libs/test_helper.py +++ b/api/tests/unit_tests/libs/test_helper.py @@ -11,7 +11,7 @@ class TestExtractTenantId: def test_extract_tenant_id_from_account_with_tenant(self): """Test extracting tenant_id from Account with current_tenant_id.""" # Create a mock Account object - account = Account() + account = Account(name="test", email="test@example.com") # Mock the current_tenant_id property account._current_tenant = type("MockTenant", (), {"id": "account-tenant-123"})() @@ -21,7 +21,7 @@ class TestExtractTenantId: def test_extract_tenant_id_from_account_without_tenant(self): """Test extracting tenant_id from Account without current_tenant_id.""" # Create a mock Account object - account = Account() + account = Account(name="test", email="test@example.com") account._current_tenant = None tenant_id = extract_tenant_id(account) diff --git a/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_repository.py b/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_repository.py index fadd1ee88f..28b339fe85 100644 --- a/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_repository.py +++ b/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_repository.py @@ -59,12 +59,11 @@ def session(): @pytest.fixture def mock_user(): """Create a user instance for testing.""" - user = Account() + user = Account(name="test", email="test@example.com") user.id = "test-user-id" - tenant = Tenant() + tenant = Tenant(name="Test Workspace") tenant.id = "test-tenant" - tenant.name = "Test Workspace" user._current_tenant = MagicMock() user._current_tenant.id = "test-tenant" diff --git a/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py b/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py index 7e324ca4db..66361f26e0 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py @@ -47,7 +47,8 @@ class TestDraftVariableSaver: def test__should_variable_be_visible(self): mock_session = MagicMock(spec=Session) - mock_user = Account(id=str(uuid.uuid4())) + mock_user = Account(name="test", email="test@example.com") + mock_user.id = str(uuid.uuid4()) test_app_id = self._get_test_app_id() saver = DraftVariableSaver( session=mock_session,