diff --git a/api/models/dataset.py b/api/models/dataset.py index 4bc802bb1c..3f2d16d3bd 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -1026,19 +1026,21 @@ class Embedding(Base): return cast(list[float], pickle.loads(self.embedding)) # noqa: S301 -class DatasetCollectionBinding(Base): +class DatasetCollectionBinding(TypeBase): __tablename__ = "dataset_collection_bindings" __table_args__ = ( sa.PrimaryKeyConstraint("id", name="dataset_collection_bindings_pkey"), sa.Index("provider_model_name_idx", "provider_name", "model_name"), ) - id = mapped_column(StringUUID, primary_key=True, default=lambda: str(uuid4())) + id: Mapped[str] = mapped_column(StringUUID, primary_key=True, default=lambda: str(uuid4()), init=False) provider_name: Mapped[str] = mapped_column(String(255), nullable=False) model_name: Mapped[str] = mapped_column(String(255), nullable=False) - type = mapped_column(String(40), server_default=sa.text("'dataset'"), nullable=False) - collection_name = mapped_column(String(64), nullable=False) - created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + type: Mapped[str] = mapped_column(String(40), server_default=sa.text("'dataset'"), nullable=False) + collection_name: Mapped[str] = mapped_column(String(64), nullable=False) + created_at: Mapped[datetime] = mapped_column( + DateTime, nullable=False, server_default=func.current_timestamp(), init=False + ) class TidbAuthBinding(Base): @@ -1176,7 +1178,7 @@ class ExternalKnowledgeBindings(TypeBase): ) -class DatasetAutoDisableLog(Base): +class DatasetAutoDisableLog(TypeBase): __tablename__ = "dataset_auto_disable_logs" __table_args__ = ( sa.PrimaryKeyConstraint("id", name="dataset_auto_disable_log_pkey"), @@ -1185,12 +1187,14 @@ class DatasetAutoDisableLog(Base): sa.Index("dataset_auto_disable_log_created_atx", "created_at"), ) - id = mapped_column(StringUUID, default=lambda: str(uuid4())) - tenant_id = mapped_column(StringUUID, nullable=False) - dataset_id = mapped_column(StringUUID, nullable=False) - document_id = mapped_column(StringUUID, nullable=False) - notified: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) - created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=sa.func.current_timestamp()) + id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False) + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + document_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + notified: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"), default=False) + created_at: Mapped[datetime] = mapped_column( + DateTime, nullable=False, server_default=sa.func.current_timestamp(), init=False + ) class RateLimitLog(TypeBase): diff --git a/api/models/model.py b/api/models/model.py index b0bf46e7d7..e2b9da46f1 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -16,7 +16,7 @@ from sqlalchemy.orm import Mapped, Session, mapped_column from configs import dify_config from constants import DEFAULT_FILE_NUMBER_LIMITS -from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType +from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod from core.file import helpers as file_helpers from core.tools.signature import sign_tool_file from core.workflow.enums import WorkflowExecutionStatus @@ -594,7 +594,7 @@ class InstalledApp(TypeBase): return tenant -class OAuthProviderApp(Base): +class OAuthProviderApp(TypeBase): """ Globally shared OAuth provider app information. Only for Dify Cloud. @@ -606,18 +606,21 @@ class OAuthProviderApp(Base): sa.Index("oauth_provider_app_client_id_idx", "client_id"), ) - id = mapped_column(StringUUID, default=lambda: str(uuidv7())) - app_icon = mapped_column(String(255), nullable=False) - app_label = mapped_column(sa.JSON, nullable=False, default="{}") - client_id = mapped_column(String(255), nullable=False) - client_secret = mapped_column(String(255), nullable=False) - redirect_uris = mapped_column(sa.JSON, nullable=False, default="[]") - scope = mapped_column( + id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()), init=False) + app_icon: Mapped[str] = mapped_column(String(255), nullable=False) + client_id: Mapped[str] = mapped_column(String(255), nullable=False) + client_secret: Mapped[str] = mapped_column(String(255), nullable=False) + app_label: Mapped[dict] = mapped_column(sa.JSON, nullable=False, default_factory=dict) + redirect_uris: Mapped[list] = mapped_column(sa.JSON, nullable=False, default_factory=list) + scope: Mapped[str] = mapped_column( String(255), nullable=False, server_default=sa.text("'read:name read:email read:avatar read:interface_language read:timezone'"), + default="read:name read:email read:avatar read:interface_language read:timezone", + ) + created_at: Mapped[datetime] = mapped_column( + sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False ) - created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) class Conversation(Base): @@ -1335,7 +1338,7 @@ class MessageFeedback(Base): } -class MessageFile(Base): +class MessageFile(TypeBase): __tablename__ = "message_files" __table_args__ = ( sa.PrimaryKeyConstraint("id", name="message_file_pkey"), @@ -1343,37 +1346,18 @@ class MessageFile(Base): sa.Index("message_file_created_by_idx", "created_by"), ) - def __init__( - self, - *, - message_id: str, - type: FileType, - transfer_method: FileTransferMethod, - url: str | None = None, - belongs_to: Literal["user", "assistant"] | None = None, - upload_file_id: str | None = None, - created_by_role: CreatorUserRole, - created_by: str, - ): - self.message_id = message_id - self.type = type - self.transfer_method = transfer_method - self.url = url - self.belongs_to = belongs_to - self.upload_file_id = upload_file_id - self.created_by_role = created_by_role.value - self.created_by = created_by - - id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4())) + id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False) message_id: Mapped[str] = mapped_column(StringUUID, nullable=False) type: Mapped[str] = mapped_column(String(255), nullable=False) - transfer_method: Mapped[str] = mapped_column(String(255), nullable=False) - url: Mapped[str | None] = mapped_column(LongText, nullable=True) - belongs_to: Mapped[str | None] = mapped_column(String(255), nullable=True) - upload_file_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True) - created_by_role: Mapped[str] = mapped_column(String(255), nullable=False) + transfer_method: Mapped[FileTransferMethod] = mapped_column(String(255), nullable=False) + created_by_role: Mapped[CreatorUserRole] = mapped_column(String(255), nullable=False) created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) - created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + belongs_to: Mapped[Literal["user", "assistant"] | None] = mapped_column(String(255), nullable=True, default=None) + url: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None) + upload_file_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None) + created_at: Mapped[datetime] = mapped_column( + sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False + ) class MessageAnnotation(Base): @@ -1447,22 +1431,28 @@ class AppAnnotationHitHistory(Base): return account -class AppAnnotationSetting(Base): +class AppAnnotationSetting(TypeBase): __tablename__ = "app_annotation_settings" __table_args__ = ( sa.PrimaryKeyConstraint("id", name="app_annotation_settings_pkey"), sa.Index("app_annotation_settings_app_idx", "app_id"), ) - id = mapped_column(StringUUID, default=lambda: str(uuid4())) - app_id = mapped_column(StringUUID, nullable=False) - score_threshold = mapped_column(Float, nullable=False, server_default=sa.text("0")) - collection_binding_id = mapped_column(StringUUID, nullable=False) - created_user_id = mapped_column(StringUUID, nullable=False) - created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_user_id = mapped_column(StringUUID, nullable=False) - updated_at = mapped_column( - sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() + id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False) + app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + score_threshold: Mapped[float] = mapped_column(Float, nullable=False, server_default=sa.text("0")) + collection_binding_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + created_user_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + created_at: Mapped[datetime] = mapped_column( + sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False + ) + updated_user_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + updated_at: Mapped[datetime] = mapped_column( + sa.DateTime, + nullable=False, + server_default=func.current_timestamp(), + onupdate=func.current_timestamp(), + init=False, ) @property diff --git a/api/models/provider.py b/api/models/provider.py index a840a483ab..577e098a2e 100644 --- a/api/models/provider.py +++ b/api/models/provider.py @@ -9,7 +9,7 @@ from sqlalchemy.orm import Mapped, mapped_column from libs.uuid_utils import uuidv7 -from .base import Base, TypeBase +from .base import TypeBase from .engine import db from .types import LongText, StringUUID @@ -262,7 +262,7 @@ class ProviderModelSetting(TypeBase): ) -class LoadBalancingModelConfig(Base): +class LoadBalancingModelConfig(TypeBase): """ Configurations for load balancing models. """ @@ -273,23 +273,25 @@ class LoadBalancingModelConfig(Base): sa.Index("load_balancing_model_config_tenant_provider_model_idx", "tenant_id", "provider_name", "model_type"), ) - id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4())) + id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) provider_name: Mapped[str] = mapped_column(String(255), nullable=False) model_name: Mapped[str] = mapped_column(String(255), nullable=False) model_type: Mapped[str] = mapped_column(String(40), nullable=False) name: Mapped[str] = mapped_column(String(255), nullable=False) - encrypted_config: Mapped[str | None] = mapped_column(LongText, nullable=True) - credential_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True) - credential_source_type: Mapped[str | None] = mapped_column(String(40), nullable=True) - enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("true")) - created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + encrypted_config: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None) + credential_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None) + credential_source_type: Mapped[str | None] = mapped_column(String(40), nullable=True, default=None) + enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("true"), default=True) + created_at: Mapped[datetime] = mapped_column( + DateTime, nullable=False, server_default=func.current_timestamp(), init=False + ) updated_at: Mapped[datetime] = mapped_column( - DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() + DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp(), init=False ) -class ProviderCredential(Base): +class ProviderCredential(TypeBase): """ Provider credential - stores multiple named credentials for each provider """ @@ -300,18 +302,20 @@ class ProviderCredential(Base): sa.Index("provider_credential_tenant_provider_idx", "tenant_id", "provider_name"), ) - id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7())) + id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()), init=False) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) provider_name: Mapped[str] = mapped_column(String(255), nullable=False) credential_name: Mapped[str] = mapped_column(String(255), nullable=False) encrypted_config: Mapped[str] = mapped_column(LongText, nullable=False) - created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column( + DateTime, nullable=False, server_default=func.current_timestamp(), init=False + ) updated_at: Mapped[datetime] = mapped_column( - DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() + DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp(), init=False ) -class ProviderModelCredential(Base): +class ProviderModelCredential(TypeBase): """ Provider model credential - stores multiple named credentials for each provider model """ @@ -328,14 +332,16 @@ class ProviderModelCredential(Base): ), ) - id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7())) + id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()), init=False) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) provider_name: Mapped[str] = mapped_column(String(255), nullable=False) model_name: Mapped[str] = mapped_column(String(255), nullable=False) model_type: Mapped[str] = mapped_column(String(40), nullable=False) credential_name: Mapped[str] = mapped_column(String(255), nullable=False) encrypted_config: Mapped[str] = mapped_column(LongText, nullable=False) - created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at: Mapped[datetime] = mapped_column( - DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() + created_at: Mapped[datetime] = mapped_column( + DateTime, nullable=False, server_default=func.current_timestamp(), init=False + ) + updated_at: Mapped[datetime] = mapped_column( + DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp(), init=False ) diff --git a/api/models/trigger.py b/api/models/trigger.py index 753fdb227b..e89309551a 100644 --- a/api/models/trigger.py +++ b/api/models/trigger.py @@ -129,27 +129,30 @@ class TriggerOAuthSystemClient(TypeBase): # tenant level trigger oauth client params (client_id, client_secret, etc.) -class TriggerOAuthTenantClient(Base): +class TriggerOAuthTenantClient(TypeBase): __tablename__ = "trigger_oauth_tenant_clients" __table_args__ = ( sa.PrimaryKeyConstraint("id", name="trigger_oauth_tenant_client_pkey"), sa.UniqueConstraint("tenant_id", "plugin_id", "provider", name="unique_trigger_oauth_tenant_client"), ) - id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4())) + id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False) # tenant id tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) plugin_id: Mapped[str] = mapped_column(String(255), nullable=False) provider: Mapped[str] = mapped_column(String(255), nullable=False) enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"), default=True) # oauth params of the trigger provider - encrypted_oauth_params: Mapped[str] = mapped_column(LongText, nullable=False) - created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + encrypted_oauth_params: Mapped[str] = mapped_column(LongText, nullable=False, default="{}") + created_at: Mapped[datetime] = mapped_column( + DateTime, nullable=False, server_default=func.current_timestamp(), init=False + ) updated_at: Mapped[datetime] = mapped_column( DateTime, nullable=False, server_default=func.current_timestamp(), server_onupdate=func.current_timestamp(), + init=False, ) @property diff --git a/api/tests/test_containers_integration_tests/services/test_agent_service.py b/api/tests/test_containers_integration_tests/services/test_agent_service.py index ca513319b2..3be2798085 100644 --- a/api/tests/test_containers_integration_tests/services/test_agent_service.py +++ b/api/tests/test_containers_integration_tests/services/test_agent_service.py @@ -852,6 +852,7 @@ class TestAgentService: # Add files to message from models.model import MessageFile + assert message.from_account_id is not None message_file1 = MessageFile( message_id=message.id, type=FileType.IMAGE, diff --git a/api/tests/test_containers_integration_tests/services/test_annotation_service.py b/api/tests/test_containers_integration_tests/services/test_annotation_service.py index 2b03ec1c26..da73122cd7 100644 --- a/api/tests/test_containers_integration_tests/services/test_annotation_service.py +++ b/api/tests/test_containers_integration_tests/services/test_annotation_service.py @@ -860,22 +860,24 @@ class TestAnnotationService: from models.model import AppAnnotationSetting # Create a collection binding first - collection_binding = DatasetCollectionBinding() - collection_binding.id = fake.uuid4() - collection_binding.provider_name = "openai" - collection_binding.model_name = "text-embedding-ada-002" - collection_binding.type = "annotation" - collection_binding.collection_name = f"annotation_collection_{fake.uuid4()}" + collection_binding = DatasetCollectionBinding( + provider_name="openai", + model_name="text-embedding-ada-002", + type="annotation", + collection_name=f"annotation_collection_{fake.uuid4()}", + ) + collection_binding.id = str(fake.uuid4()) db.session.add(collection_binding) db.session.flush() # Create annotation setting - annotation_setting = AppAnnotationSetting() - annotation_setting.app_id = app.id - annotation_setting.score_threshold = 0.8 - annotation_setting.collection_binding_id = collection_binding.id - annotation_setting.created_user_id = account.id - annotation_setting.updated_user_id = account.id + annotation_setting = AppAnnotationSetting( + app_id=app.id, + score_threshold=0.8, + collection_binding_id=collection_binding.id, + created_user_id=account.id, + updated_user_id=account.id, + ) db.session.add(annotation_setting) db.session.commit() @@ -919,22 +921,24 @@ class TestAnnotationService: from models.model import AppAnnotationSetting # Create a collection binding first - collection_binding = DatasetCollectionBinding() - collection_binding.id = fake.uuid4() - collection_binding.provider_name = "openai" - collection_binding.model_name = "text-embedding-ada-002" - collection_binding.type = "annotation" - collection_binding.collection_name = f"annotation_collection_{fake.uuid4()}" + collection_binding = DatasetCollectionBinding( + provider_name="openai", + model_name="text-embedding-ada-002", + type="annotation", + collection_name=f"annotation_collection_{fake.uuid4()}", + ) + collection_binding.id = str(fake.uuid4()) db.session.add(collection_binding) db.session.flush() # Create annotation setting - annotation_setting = AppAnnotationSetting() - annotation_setting.app_id = app.id - annotation_setting.score_threshold = 0.8 - annotation_setting.collection_binding_id = collection_binding.id - annotation_setting.created_user_id = account.id - annotation_setting.updated_user_id = account.id + annotation_setting = AppAnnotationSetting( + app_id=app.id, + score_threshold=0.8, + collection_binding_id=collection_binding.id, + created_user_id=account.id, + updated_user_id=account.id, + ) db.session.add(annotation_setting) db.session.commit() @@ -1020,22 +1024,24 @@ class TestAnnotationService: from models.model import AppAnnotationSetting # Create a collection binding first - collection_binding = DatasetCollectionBinding() - collection_binding.id = fake.uuid4() - collection_binding.provider_name = "openai" - collection_binding.model_name = "text-embedding-ada-002" - collection_binding.type = "annotation" - collection_binding.collection_name = f"annotation_collection_{fake.uuid4()}" + collection_binding = DatasetCollectionBinding( + provider_name="openai", + model_name="text-embedding-ada-002", + type="annotation", + collection_name=f"annotation_collection_{fake.uuid4()}", + ) + collection_binding.id = str(fake.uuid4()) db.session.add(collection_binding) db.session.flush() # Create annotation setting - annotation_setting = AppAnnotationSetting() - annotation_setting.app_id = app.id - annotation_setting.score_threshold = 0.8 - annotation_setting.collection_binding_id = collection_binding.id - annotation_setting.created_user_id = account.id - annotation_setting.updated_user_id = account.id + annotation_setting = AppAnnotationSetting( + app_id=app.id, + score_threshold=0.8, + collection_binding_id=collection_binding.id, + created_user_id=account.id, + updated_user_id=account.id, + ) db.session.add(annotation_setting) db.session.commit() @@ -1080,22 +1086,24 @@ class TestAnnotationService: from models.model import AppAnnotationSetting # Create a collection binding first - collection_binding = DatasetCollectionBinding() - collection_binding.id = fake.uuid4() - collection_binding.provider_name = "openai" - collection_binding.model_name = "text-embedding-ada-002" - collection_binding.type = "annotation" - collection_binding.collection_name = f"annotation_collection_{fake.uuid4()}" + collection_binding = DatasetCollectionBinding( + provider_name="openai", + model_name="text-embedding-ada-002", + type="annotation", + collection_name=f"annotation_collection_{fake.uuid4()}", + ) + collection_binding.id = str(fake.uuid4()) db.session.add(collection_binding) db.session.flush() # Create annotation setting - annotation_setting = AppAnnotationSetting() - annotation_setting.app_id = app.id - annotation_setting.score_threshold = 0.8 - annotation_setting.collection_binding_id = collection_binding.id - annotation_setting.created_user_id = account.id - annotation_setting.updated_user_id = account.id + annotation_setting = AppAnnotationSetting( + app_id=app.id, + score_threshold=0.8, + collection_binding_id=collection_binding.id, + created_user_id=account.id, + updated_user_id=account.id, + ) db.session.add(annotation_setting) db.session.commit() @@ -1151,22 +1159,25 @@ class TestAnnotationService: from models.model import AppAnnotationSetting # Create a collection binding first - collection_binding = DatasetCollectionBinding() - collection_binding.id = fake.uuid4() - collection_binding.provider_name = "openai" - collection_binding.model_name = "text-embedding-ada-002" - collection_binding.type = "annotation" - collection_binding.collection_name = f"annotation_collection_{fake.uuid4()}" + collection_binding = DatasetCollectionBinding( + provider_name="openai", + model_name="text-embedding-ada-002", + type="annotation", + collection_name=f"annotation_collection_{fake.uuid4()}", + ) + collection_binding.id = str(fake.uuid4()) db.session.add(collection_binding) db.session.flush() # Create annotation setting - annotation_setting = AppAnnotationSetting() - annotation_setting.app_id = app.id - annotation_setting.score_threshold = 0.8 - annotation_setting.collection_binding_id = collection_binding.id - annotation_setting.created_user_id = account.id - annotation_setting.updated_user_id = account.id + annotation_setting = AppAnnotationSetting( + app_id=app.id, + score_threshold=0.8, + collection_binding_id=collection_binding.id, + created_user_id=account.id, + updated_user_id=account.id, + ) + db.session.add(annotation_setting) db.session.commit() @@ -1211,22 +1222,24 @@ class TestAnnotationService: from models.model import AppAnnotationSetting # Create a collection binding first - collection_binding = DatasetCollectionBinding() - collection_binding.id = fake.uuid4() - collection_binding.provider_name = "openai" - collection_binding.model_name = "text-embedding-ada-002" - collection_binding.type = "annotation" - collection_binding.collection_name = f"annotation_collection_{fake.uuid4()}" + collection_binding = DatasetCollectionBinding( + provider_name="openai", + model_name="text-embedding-ada-002", + type="annotation", + collection_name=f"annotation_collection_{fake.uuid4()}", + ) + collection_binding.id = str(fake.uuid4()) db.session.add(collection_binding) db.session.flush() # Create annotation setting - annotation_setting = AppAnnotationSetting() - annotation_setting.app_id = app.id - annotation_setting.score_threshold = 0.8 - annotation_setting.collection_binding_id = collection_binding.id - annotation_setting.created_user_id = account.id - annotation_setting.updated_user_id = account.id + annotation_setting = AppAnnotationSetting( + app_id=app.id, + score_threshold=0.8, + collection_binding_id=collection_binding.id, + created_user_id=account.id, + updated_user_id=account.id, + ) db.session.add(annotation_setting) db.session.commit() diff --git a/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py index f1530bcac6..9478bb9ddb 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py @@ -502,11 +502,11 @@ class TestAddDocumentToIndexTask: auto_disable_logs = [] for _ in range(2): log_entry = DatasetAutoDisableLog( - id=fake.uuid4(), tenant_id=document.tenant_id, dataset_id=dataset.id, document_id=document.id, ) + log_entry.id = str(fake.uuid4()) db.session.add(log_entry) auto_disable_logs.append(log_entry) diff --git a/api/tests/unit_tests/core/test_provider_manager.py b/api/tests/unit_tests/core/test_provider_manager.py index dbbda5f74c..3163d53b87 100644 --- a/api/tests/unit_tests/core/test_provider_manager.py +++ b/api/tests/unit_tests/core/test_provider_manager.py @@ -39,9 +39,9 @@ def test__to_model_settings(mocker: MockerFixture, mock_provider_entity): ps.id = "id" provider_model_settings = [ps] + load_balancing_model_configs = [ LoadBalancingModelConfig( - id="id1", tenant_id="tenant_id", provider_name="openai", model_name="gpt-4", @@ -51,7 +51,6 @@ def test__to_model_settings(mocker: MockerFixture, mock_provider_entity): enabled=True, ), LoadBalancingModelConfig( - id="id2", tenant_id="tenant_id", provider_name="openai", model_name="gpt-4", @@ -61,6 +60,8 @@ def test__to_model_settings(mocker: MockerFixture, mock_provider_entity): enabled=True, ), ] + load_balancing_model_configs[0].id = "id1" + load_balancing_model_configs[1].id = "id2" mocker.patch( "core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"} @@ -101,7 +102,6 @@ def test__to_model_settings_only_one_lb(mocker: MockerFixture, mock_provider_ent provider_model_settings = [ps] load_balancing_model_configs = [ LoadBalancingModelConfig( - id="id1", tenant_id="tenant_id", provider_name="openai", model_name="gpt-4", @@ -111,6 +111,7 @@ def test__to_model_settings_only_one_lb(mocker: MockerFixture, mock_provider_ent enabled=True, ) ] + load_balancing_model_configs[0].id = "id1" mocker.patch( "core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"} @@ -148,7 +149,6 @@ def test__to_model_settings_lb_disabled(mocker: MockerFixture, mock_provider_ent provider_model_settings = [ps] load_balancing_model_configs = [ LoadBalancingModelConfig( - id="id1", tenant_id="tenant_id", provider_name="openai", model_name="gpt-4", @@ -158,7 +158,6 @@ def test__to_model_settings_lb_disabled(mocker: MockerFixture, mock_provider_ent enabled=True, ), LoadBalancingModelConfig( - id="id2", tenant_id="tenant_id", provider_name="openai", model_name="gpt-4", @@ -168,6 +167,8 @@ def test__to_model_settings_lb_disabled(mocker: MockerFixture, mock_provider_ent enabled=True, ), ] + load_balancing_model_configs[0].id = "id1" + load_balancing_model_configs[1].id = "id2" mocker.patch( "core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"}