mirror of https://github.com/langgenius/dify.git
more typed orm (#28519)
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
parent
2c9e435558
commit
6241b87f90
|
|
@ -1026,19 +1026,21 @@ class Embedding(Base):
|
|||
return cast(list[float], pickle.loads(self.embedding)) # noqa: S301
|
||||
|
||||
|
||||
class DatasetCollectionBinding(Base):
|
||||
class DatasetCollectionBinding(TypeBase):
|
||||
__tablename__ = "dataset_collection_bindings"
|
||||
__table_args__ = (
|
||||
sa.PrimaryKeyConstraint("id", name="dataset_collection_bindings_pkey"),
|
||||
sa.Index("provider_model_name_idx", "provider_name", "model_name"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, primary_key=True, default=lambda: str(uuid4()))
|
||||
id: Mapped[str] = mapped_column(StringUUID, primary_key=True, default=lambda: str(uuid4()), init=False)
|
||||
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
model_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
type = mapped_column(String(40), server_default=sa.text("'dataset'"), nullable=False)
|
||||
collection_name = mapped_column(String(64), nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
type: Mapped[str] = mapped_column(String(40), server_default=sa.text("'dataset'"), nullable=False)
|
||||
collection_name: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
|
||||
)
|
||||
|
||||
|
||||
class TidbAuthBinding(Base):
|
||||
|
|
@ -1176,7 +1178,7 @@ class ExternalKnowledgeBindings(TypeBase):
|
|||
)
|
||||
|
||||
|
||||
class DatasetAutoDisableLog(Base):
|
||||
class DatasetAutoDisableLog(TypeBase):
|
||||
__tablename__ = "dataset_auto_disable_logs"
|
||||
__table_args__ = (
|
||||
sa.PrimaryKeyConstraint("id", name="dataset_auto_disable_log_pkey"),
|
||||
|
|
@ -1185,12 +1187,14 @@ class DatasetAutoDisableLog(Base):
|
|||
sa.Index("dataset_auto_disable_log_created_atx", "created_at"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, default=lambda: str(uuid4()))
|
||||
tenant_id = mapped_column(StringUUID, nullable=False)
|
||||
dataset_id = mapped_column(StringUUID, nullable=False)
|
||||
document_id = mapped_column(StringUUID, nullable=False)
|
||||
notified: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=sa.func.current_timestamp())
|
||||
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
document_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
notified: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"), default=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=sa.func.current_timestamp(), init=False
|
||||
)
|
||||
|
||||
|
||||
class RateLimitLog(TypeBase):
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ from sqlalchemy.orm import Mapped, Session, mapped_column
|
|||
|
||||
from configs import dify_config
|
||||
from constants import DEFAULT_FILE_NUMBER_LIMITS
|
||||
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType
|
||||
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
|
||||
from core.file import helpers as file_helpers
|
||||
from core.tools.signature import sign_tool_file
|
||||
from core.workflow.enums import WorkflowExecutionStatus
|
||||
|
|
@ -594,7 +594,7 @@ class InstalledApp(TypeBase):
|
|||
return tenant
|
||||
|
||||
|
||||
class OAuthProviderApp(Base):
|
||||
class OAuthProviderApp(TypeBase):
|
||||
"""
|
||||
Globally shared OAuth provider app information.
|
||||
Only for Dify Cloud.
|
||||
|
|
@ -606,18 +606,21 @@ class OAuthProviderApp(Base):
|
|||
sa.Index("oauth_provider_app_client_id_idx", "client_id"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, default=lambda: str(uuidv7()))
|
||||
app_icon = mapped_column(String(255), nullable=False)
|
||||
app_label = mapped_column(sa.JSON, nullable=False, default="{}")
|
||||
client_id = mapped_column(String(255), nullable=False)
|
||||
client_secret = mapped_column(String(255), nullable=False)
|
||||
redirect_uris = mapped_column(sa.JSON, nullable=False, default="[]")
|
||||
scope = mapped_column(
|
||||
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()), init=False)
|
||||
app_icon: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
client_id: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
client_secret: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
app_label: Mapped[dict] = mapped_column(sa.JSON, nullable=False, default_factory=dict)
|
||||
redirect_uris: Mapped[list] = mapped_column(sa.JSON, nullable=False, default_factory=list)
|
||||
scope: Mapped[str] = mapped_column(
|
||||
String(255),
|
||||
nullable=False,
|
||||
server_default=sa.text("'read:name read:email read:avatar read:interface_language read:timezone'"),
|
||||
default="read:name read:email read:avatar read:interface_language read:timezone",
|
||||
)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
|
||||
)
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
|
||||
class Conversation(Base):
|
||||
|
|
@ -1335,7 +1338,7 @@ class MessageFeedback(Base):
|
|||
}
|
||||
|
||||
|
||||
class MessageFile(Base):
|
||||
class MessageFile(TypeBase):
|
||||
__tablename__ = "message_files"
|
||||
__table_args__ = (
|
||||
sa.PrimaryKeyConstraint("id", name="message_file_pkey"),
|
||||
|
|
@ -1343,37 +1346,18 @@ class MessageFile(Base):
|
|||
sa.Index("message_file_created_by_idx", "created_by"),
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
message_id: str,
|
||||
type: FileType,
|
||||
transfer_method: FileTransferMethod,
|
||||
url: str | None = None,
|
||||
belongs_to: Literal["user", "assistant"] | None = None,
|
||||
upload_file_id: str | None = None,
|
||||
created_by_role: CreatorUserRole,
|
||||
created_by: str,
|
||||
):
|
||||
self.message_id = message_id
|
||||
self.type = type
|
||||
self.transfer_method = transfer_method
|
||||
self.url = url
|
||||
self.belongs_to = belongs_to
|
||||
self.upload_file_id = upload_file_id
|
||||
self.created_by_role = created_by_role.value
|
||||
self.created_by = created_by
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()))
|
||||
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
|
||||
message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
type: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
transfer_method: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
url: Mapped[str | None] = mapped_column(LongText, nullable=True)
|
||||
belongs_to: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||
upload_file_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
|
||||
created_by_role: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
transfer_method: Mapped[FileTransferMethod] = mapped_column(String(255), nullable=False)
|
||||
created_by_role: Mapped[CreatorUserRole] = mapped_column(String(255), nullable=False)
|
||||
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
belongs_to: Mapped[Literal["user", "assistant"] | None] = mapped_column(String(255), nullable=True, default=None)
|
||||
url: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
|
||||
upload_file_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
|
||||
)
|
||||
|
||||
|
||||
class MessageAnnotation(Base):
|
||||
|
|
@ -1447,22 +1431,28 @@ class AppAnnotationHitHistory(Base):
|
|||
return account
|
||||
|
||||
|
||||
class AppAnnotationSetting(Base):
|
||||
class AppAnnotationSetting(TypeBase):
|
||||
__tablename__ = "app_annotation_settings"
|
||||
__table_args__ = (
|
||||
sa.PrimaryKeyConstraint("id", name="app_annotation_settings_pkey"),
|
||||
sa.Index("app_annotation_settings_app_idx", "app_id"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, default=lambda: str(uuid4()))
|
||||
app_id = mapped_column(StringUUID, nullable=False)
|
||||
score_threshold = mapped_column(Float, nullable=False, server_default=sa.text("0"))
|
||||
collection_binding_id = mapped_column(StringUUID, nullable=False)
|
||||
created_user_id = mapped_column(StringUUID, nullable=False)
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_user_id = mapped_column(StringUUID, nullable=False)
|
||||
updated_at = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
|
||||
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
score_threshold: Mapped[float] = mapped_column(Float, nullable=False, server_default=sa.text("0"))
|
||||
collection_binding_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
created_user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
|
||||
)
|
||||
updated_user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime,
|
||||
nullable=False,
|
||||
server_default=func.current_timestamp(),
|
||||
onupdate=func.current_timestamp(),
|
||||
init=False,
|
||||
)
|
||||
|
||||
@property
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from sqlalchemy.orm import Mapped, mapped_column
|
|||
|
||||
from libs.uuid_utils import uuidv7
|
||||
|
||||
from .base import Base, TypeBase
|
||||
from .base import TypeBase
|
||||
from .engine import db
|
||||
from .types import LongText, StringUUID
|
||||
|
||||
|
|
@ -262,7 +262,7 @@ class ProviderModelSetting(TypeBase):
|
|||
)
|
||||
|
||||
|
||||
class LoadBalancingModelConfig(Base):
|
||||
class LoadBalancingModelConfig(TypeBase):
|
||||
"""
|
||||
Configurations for load balancing models.
|
||||
"""
|
||||
|
|
@ -273,23 +273,25 @@ class LoadBalancingModelConfig(Base):
|
|||
sa.Index("load_balancing_model_config_tenant_provider_model_idx", "tenant_id", "provider_name", "model_type"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()))
|
||||
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
model_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
model_type: Mapped[str] = mapped_column(String(40), nullable=False)
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
encrypted_config: Mapped[str | None] = mapped_column(LongText, nullable=True)
|
||||
credential_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
|
||||
credential_source_type: Mapped[str | None] = mapped_column(String(40), nullable=True)
|
||||
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("true"))
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
encrypted_config: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
|
||||
credential_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
|
||||
credential_source_type: Mapped[str | None] = mapped_column(String(40), nullable=True, default=None)
|
||||
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("true"), default=True)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp(), init=False
|
||||
)
|
||||
|
||||
|
||||
class ProviderCredential(Base):
|
||||
class ProviderCredential(TypeBase):
|
||||
"""
|
||||
Provider credential - stores multiple named credentials for each provider
|
||||
"""
|
||||
|
|
@ -300,18 +302,20 @@ class ProviderCredential(Base):
|
|||
sa.Index("provider_credential_tenant_provider_idx", "tenant_id", "provider_name"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()))
|
||||
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()), init=False)
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
credential_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
encrypted_config: Mapped[str] = mapped_column(LongText, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp(), init=False
|
||||
)
|
||||
|
||||
|
||||
class ProviderModelCredential(Base):
|
||||
class ProviderModelCredential(TypeBase):
|
||||
"""
|
||||
Provider model credential - stores multiple named credentials for each provider model
|
||||
"""
|
||||
|
|
@ -328,14 +332,16 @@ class ProviderModelCredential(Base):
|
|||
),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()))
|
||||
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()), init=False)
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
model_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
model_type: Mapped[str] = mapped_column(String(40), nullable=False)
|
||||
credential_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
encrypted_config: Mapped[str] = mapped_column(LongText, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp(), init=False
|
||||
)
|
||||
|
|
|
|||
|
|
@ -129,27 +129,30 @@ class TriggerOAuthSystemClient(TypeBase):
|
|||
|
||||
|
||||
# tenant level trigger oauth client params (client_id, client_secret, etc.)
|
||||
class TriggerOAuthTenantClient(Base):
|
||||
class TriggerOAuthTenantClient(TypeBase):
|
||||
__tablename__ = "trigger_oauth_tenant_clients"
|
||||
__table_args__ = (
|
||||
sa.PrimaryKeyConstraint("id", name="trigger_oauth_tenant_client_pkey"),
|
||||
sa.UniqueConstraint("tenant_id", "plugin_id", "provider", name="unique_trigger_oauth_tenant_client"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()))
|
||||
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
|
||||
# tenant id
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
plugin_id: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
provider: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"), default=True)
|
||||
# oauth params of the trigger provider
|
||||
encrypted_oauth_params: Mapped[str] = mapped_column(LongText, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
encrypted_oauth_params: Mapped[str] = mapped_column(LongText, nullable=False, default="{}")
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime,
|
||||
nullable=False,
|
||||
server_default=func.current_timestamp(),
|
||||
server_onupdate=func.current_timestamp(),
|
||||
init=False,
|
||||
)
|
||||
|
||||
@property
|
||||
|
|
|
|||
|
|
@ -852,6 +852,7 @@ class TestAgentService:
|
|||
# Add files to message
|
||||
from models.model import MessageFile
|
||||
|
||||
assert message.from_account_id is not None
|
||||
message_file1 = MessageFile(
|
||||
message_id=message.id,
|
||||
type=FileType.IMAGE,
|
||||
|
|
|
|||
|
|
@ -860,22 +860,24 @@ class TestAnnotationService:
|
|||
from models.model import AppAnnotationSetting
|
||||
|
||||
# Create a collection binding first
|
||||
collection_binding = DatasetCollectionBinding()
|
||||
collection_binding.id = fake.uuid4()
|
||||
collection_binding.provider_name = "openai"
|
||||
collection_binding.model_name = "text-embedding-ada-002"
|
||||
collection_binding.type = "annotation"
|
||||
collection_binding.collection_name = f"annotation_collection_{fake.uuid4()}"
|
||||
collection_binding = DatasetCollectionBinding(
|
||||
provider_name="openai",
|
||||
model_name="text-embedding-ada-002",
|
||||
type="annotation",
|
||||
collection_name=f"annotation_collection_{fake.uuid4()}",
|
||||
)
|
||||
collection_binding.id = str(fake.uuid4())
|
||||
db.session.add(collection_binding)
|
||||
db.session.flush()
|
||||
|
||||
# Create annotation setting
|
||||
annotation_setting = AppAnnotationSetting()
|
||||
annotation_setting.app_id = app.id
|
||||
annotation_setting.score_threshold = 0.8
|
||||
annotation_setting.collection_binding_id = collection_binding.id
|
||||
annotation_setting.created_user_id = account.id
|
||||
annotation_setting.updated_user_id = account.id
|
||||
annotation_setting = AppAnnotationSetting(
|
||||
app_id=app.id,
|
||||
score_threshold=0.8,
|
||||
collection_binding_id=collection_binding.id,
|
||||
created_user_id=account.id,
|
||||
updated_user_id=account.id,
|
||||
)
|
||||
db.session.add(annotation_setting)
|
||||
db.session.commit()
|
||||
|
||||
|
|
@ -919,22 +921,24 @@ class TestAnnotationService:
|
|||
from models.model import AppAnnotationSetting
|
||||
|
||||
# Create a collection binding first
|
||||
collection_binding = DatasetCollectionBinding()
|
||||
collection_binding.id = fake.uuid4()
|
||||
collection_binding.provider_name = "openai"
|
||||
collection_binding.model_name = "text-embedding-ada-002"
|
||||
collection_binding.type = "annotation"
|
||||
collection_binding.collection_name = f"annotation_collection_{fake.uuid4()}"
|
||||
collection_binding = DatasetCollectionBinding(
|
||||
provider_name="openai",
|
||||
model_name="text-embedding-ada-002",
|
||||
type="annotation",
|
||||
collection_name=f"annotation_collection_{fake.uuid4()}",
|
||||
)
|
||||
collection_binding.id = str(fake.uuid4())
|
||||
db.session.add(collection_binding)
|
||||
db.session.flush()
|
||||
|
||||
# Create annotation setting
|
||||
annotation_setting = AppAnnotationSetting()
|
||||
annotation_setting.app_id = app.id
|
||||
annotation_setting.score_threshold = 0.8
|
||||
annotation_setting.collection_binding_id = collection_binding.id
|
||||
annotation_setting.created_user_id = account.id
|
||||
annotation_setting.updated_user_id = account.id
|
||||
annotation_setting = AppAnnotationSetting(
|
||||
app_id=app.id,
|
||||
score_threshold=0.8,
|
||||
collection_binding_id=collection_binding.id,
|
||||
created_user_id=account.id,
|
||||
updated_user_id=account.id,
|
||||
)
|
||||
db.session.add(annotation_setting)
|
||||
db.session.commit()
|
||||
|
||||
|
|
@ -1020,22 +1024,24 @@ class TestAnnotationService:
|
|||
from models.model import AppAnnotationSetting
|
||||
|
||||
# Create a collection binding first
|
||||
collection_binding = DatasetCollectionBinding()
|
||||
collection_binding.id = fake.uuid4()
|
||||
collection_binding.provider_name = "openai"
|
||||
collection_binding.model_name = "text-embedding-ada-002"
|
||||
collection_binding.type = "annotation"
|
||||
collection_binding.collection_name = f"annotation_collection_{fake.uuid4()}"
|
||||
collection_binding = DatasetCollectionBinding(
|
||||
provider_name="openai",
|
||||
model_name="text-embedding-ada-002",
|
||||
type="annotation",
|
||||
collection_name=f"annotation_collection_{fake.uuid4()}",
|
||||
)
|
||||
collection_binding.id = str(fake.uuid4())
|
||||
db.session.add(collection_binding)
|
||||
db.session.flush()
|
||||
|
||||
# Create annotation setting
|
||||
annotation_setting = AppAnnotationSetting()
|
||||
annotation_setting.app_id = app.id
|
||||
annotation_setting.score_threshold = 0.8
|
||||
annotation_setting.collection_binding_id = collection_binding.id
|
||||
annotation_setting.created_user_id = account.id
|
||||
annotation_setting.updated_user_id = account.id
|
||||
annotation_setting = AppAnnotationSetting(
|
||||
app_id=app.id,
|
||||
score_threshold=0.8,
|
||||
collection_binding_id=collection_binding.id,
|
||||
created_user_id=account.id,
|
||||
updated_user_id=account.id,
|
||||
)
|
||||
db.session.add(annotation_setting)
|
||||
db.session.commit()
|
||||
|
||||
|
|
@ -1080,22 +1086,24 @@ class TestAnnotationService:
|
|||
from models.model import AppAnnotationSetting
|
||||
|
||||
# Create a collection binding first
|
||||
collection_binding = DatasetCollectionBinding()
|
||||
collection_binding.id = fake.uuid4()
|
||||
collection_binding.provider_name = "openai"
|
||||
collection_binding.model_name = "text-embedding-ada-002"
|
||||
collection_binding.type = "annotation"
|
||||
collection_binding.collection_name = f"annotation_collection_{fake.uuid4()}"
|
||||
collection_binding = DatasetCollectionBinding(
|
||||
provider_name="openai",
|
||||
model_name="text-embedding-ada-002",
|
||||
type="annotation",
|
||||
collection_name=f"annotation_collection_{fake.uuid4()}",
|
||||
)
|
||||
collection_binding.id = str(fake.uuid4())
|
||||
db.session.add(collection_binding)
|
||||
db.session.flush()
|
||||
|
||||
# Create annotation setting
|
||||
annotation_setting = AppAnnotationSetting()
|
||||
annotation_setting.app_id = app.id
|
||||
annotation_setting.score_threshold = 0.8
|
||||
annotation_setting.collection_binding_id = collection_binding.id
|
||||
annotation_setting.created_user_id = account.id
|
||||
annotation_setting.updated_user_id = account.id
|
||||
annotation_setting = AppAnnotationSetting(
|
||||
app_id=app.id,
|
||||
score_threshold=0.8,
|
||||
collection_binding_id=collection_binding.id,
|
||||
created_user_id=account.id,
|
||||
updated_user_id=account.id,
|
||||
)
|
||||
db.session.add(annotation_setting)
|
||||
db.session.commit()
|
||||
|
||||
|
|
@ -1151,22 +1159,25 @@ class TestAnnotationService:
|
|||
from models.model import AppAnnotationSetting
|
||||
|
||||
# Create a collection binding first
|
||||
collection_binding = DatasetCollectionBinding()
|
||||
collection_binding.id = fake.uuid4()
|
||||
collection_binding.provider_name = "openai"
|
||||
collection_binding.model_name = "text-embedding-ada-002"
|
||||
collection_binding.type = "annotation"
|
||||
collection_binding.collection_name = f"annotation_collection_{fake.uuid4()}"
|
||||
collection_binding = DatasetCollectionBinding(
|
||||
provider_name="openai",
|
||||
model_name="text-embedding-ada-002",
|
||||
type="annotation",
|
||||
collection_name=f"annotation_collection_{fake.uuid4()}",
|
||||
)
|
||||
collection_binding.id = str(fake.uuid4())
|
||||
db.session.add(collection_binding)
|
||||
db.session.flush()
|
||||
|
||||
# Create annotation setting
|
||||
annotation_setting = AppAnnotationSetting()
|
||||
annotation_setting.app_id = app.id
|
||||
annotation_setting.score_threshold = 0.8
|
||||
annotation_setting.collection_binding_id = collection_binding.id
|
||||
annotation_setting.created_user_id = account.id
|
||||
annotation_setting.updated_user_id = account.id
|
||||
annotation_setting = AppAnnotationSetting(
|
||||
app_id=app.id,
|
||||
score_threshold=0.8,
|
||||
collection_binding_id=collection_binding.id,
|
||||
created_user_id=account.id,
|
||||
updated_user_id=account.id,
|
||||
)
|
||||
|
||||
db.session.add(annotation_setting)
|
||||
db.session.commit()
|
||||
|
||||
|
|
@ -1211,22 +1222,24 @@ class TestAnnotationService:
|
|||
from models.model import AppAnnotationSetting
|
||||
|
||||
# Create a collection binding first
|
||||
collection_binding = DatasetCollectionBinding()
|
||||
collection_binding.id = fake.uuid4()
|
||||
collection_binding.provider_name = "openai"
|
||||
collection_binding.model_name = "text-embedding-ada-002"
|
||||
collection_binding.type = "annotation"
|
||||
collection_binding.collection_name = f"annotation_collection_{fake.uuid4()}"
|
||||
collection_binding = DatasetCollectionBinding(
|
||||
provider_name="openai",
|
||||
model_name="text-embedding-ada-002",
|
||||
type="annotation",
|
||||
collection_name=f"annotation_collection_{fake.uuid4()}",
|
||||
)
|
||||
collection_binding.id = str(fake.uuid4())
|
||||
db.session.add(collection_binding)
|
||||
db.session.flush()
|
||||
|
||||
# Create annotation setting
|
||||
annotation_setting = AppAnnotationSetting()
|
||||
annotation_setting.app_id = app.id
|
||||
annotation_setting.score_threshold = 0.8
|
||||
annotation_setting.collection_binding_id = collection_binding.id
|
||||
annotation_setting.created_user_id = account.id
|
||||
annotation_setting.updated_user_id = account.id
|
||||
annotation_setting = AppAnnotationSetting(
|
||||
app_id=app.id,
|
||||
score_threshold=0.8,
|
||||
collection_binding_id=collection_binding.id,
|
||||
created_user_id=account.id,
|
||||
updated_user_id=account.id,
|
||||
)
|
||||
db.session.add(annotation_setting)
|
||||
db.session.commit()
|
||||
|
||||
|
|
|
|||
|
|
@ -502,11 +502,11 @@ class TestAddDocumentToIndexTask:
|
|||
auto_disable_logs = []
|
||||
for _ in range(2):
|
||||
log_entry = DatasetAutoDisableLog(
|
||||
id=fake.uuid4(),
|
||||
tenant_id=document.tenant_id,
|
||||
dataset_id=dataset.id,
|
||||
document_id=document.id,
|
||||
)
|
||||
log_entry.id = str(fake.uuid4())
|
||||
db.session.add(log_entry)
|
||||
auto_disable_logs.append(log_entry)
|
||||
|
||||
|
|
|
|||
|
|
@ -39,9 +39,9 @@ def test__to_model_settings(mocker: MockerFixture, mock_provider_entity):
|
|||
ps.id = "id"
|
||||
|
||||
provider_model_settings = [ps]
|
||||
|
||||
load_balancing_model_configs = [
|
||||
LoadBalancingModelConfig(
|
||||
id="id1",
|
||||
tenant_id="tenant_id",
|
||||
provider_name="openai",
|
||||
model_name="gpt-4",
|
||||
|
|
@ -51,7 +51,6 @@ def test__to_model_settings(mocker: MockerFixture, mock_provider_entity):
|
|||
enabled=True,
|
||||
),
|
||||
LoadBalancingModelConfig(
|
||||
id="id2",
|
||||
tenant_id="tenant_id",
|
||||
provider_name="openai",
|
||||
model_name="gpt-4",
|
||||
|
|
@ -61,6 +60,8 @@ def test__to_model_settings(mocker: MockerFixture, mock_provider_entity):
|
|||
enabled=True,
|
||||
),
|
||||
]
|
||||
load_balancing_model_configs[0].id = "id1"
|
||||
load_balancing_model_configs[1].id = "id2"
|
||||
|
||||
mocker.patch(
|
||||
"core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"}
|
||||
|
|
@ -101,7 +102,6 @@ def test__to_model_settings_only_one_lb(mocker: MockerFixture, mock_provider_ent
|
|||
provider_model_settings = [ps]
|
||||
load_balancing_model_configs = [
|
||||
LoadBalancingModelConfig(
|
||||
id="id1",
|
||||
tenant_id="tenant_id",
|
||||
provider_name="openai",
|
||||
model_name="gpt-4",
|
||||
|
|
@ -111,6 +111,7 @@ def test__to_model_settings_only_one_lb(mocker: MockerFixture, mock_provider_ent
|
|||
enabled=True,
|
||||
)
|
||||
]
|
||||
load_balancing_model_configs[0].id = "id1"
|
||||
|
||||
mocker.patch(
|
||||
"core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"}
|
||||
|
|
@ -148,7 +149,6 @@ def test__to_model_settings_lb_disabled(mocker: MockerFixture, mock_provider_ent
|
|||
provider_model_settings = [ps]
|
||||
load_balancing_model_configs = [
|
||||
LoadBalancingModelConfig(
|
||||
id="id1",
|
||||
tenant_id="tenant_id",
|
||||
provider_name="openai",
|
||||
model_name="gpt-4",
|
||||
|
|
@ -158,7 +158,6 @@ def test__to_model_settings_lb_disabled(mocker: MockerFixture, mock_provider_ent
|
|||
enabled=True,
|
||||
),
|
||||
LoadBalancingModelConfig(
|
||||
id="id2",
|
||||
tenant_id="tenant_id",
|
||||
provider_name="openai",
|
||||
model_name="gpt-4",
|
||||
|
|
@ -168,6 +167,8 @@ def test__to_model_settings_lb_disabled(mocker: MockerFixture, mock_provider_ent
|
|||
enabled=True,
|
||||
),
|
||||
]
|
||||
load_balancing_model_configs[0].id = "id1"
|
||||
load_balancing_model_configs[1].id = "id2"
|
||||
|
||||
mocker.patch(
|
||||
"core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"}
|
||||
|
|
|
|||
Loading…
Reference in New Issue