diff --git a/api/core/rag/embedding/cached_embedding.py b/api/core/rag/embedding/cached_embedding.py index 937b8f033c..7fb20c1941 100644 --- a/api/core/rag/embedding/cached_embedding.py +++ b/api/core/rag/embedding/cached_embedding.py @@ -1,5 +1,6 @@ import base64 import logging +import pickle from typing import Any, cast import numpy as np @@ -89,8 +90,8 @@ class CacheEmbedding(Embeddings): model_name=self._model_instance.model, hash=hash, provider_name=self._model_instance.provider, + embedding=pickle.dumps(n_embedding, protocol=pickle.HIGHEST_PROTOCOL), ) - embedding_cache.set_embedding(n_embedding) db.session.add(embedding_cache) cache_embeddings.append(hash) db.session.commit() diff --git a/api/models/dataset.py b/api/models/dataset.py index 3f2d16d3bd..2ea6d98b5f 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -307,7 +307,7 @@ class Dataset(Base): return f"{dify_config.VECTOR_INDEX_NAME_PREFIX}_{normalized_dataset_id}_Node" -class DatasetProcessRule(Base): +class DatasetProcessRule(Base): # bug __tablename__ = "dataset_process_rules" __table_args__ = ( sa.PrimaryKeyConstraint("id", name="dataset_process_rule_pkey"), @@ -1004,7 +1004,7 @@ class DatasetKeywordTable(TypeBase): return None -class Embedding(Base): +class Embedding(TypeBase): __tablename__ = "embeddings" __table_args__ = ( sa.PrimaryKeyConstraint("id", name="embedding_pkey"), @@ -1012,12 +1012,16 @@ class Embedding(Base): sa.Index("created_at_idx", "created_at"), ) - id = mapped_column(StringUUID, primary_key=True, default=lambda: str(uuid4())) - model_name = mapped_column(String(255), nullable=False, server_default=sa.text("'text-embedding-ada-002'")) - hash = mapped_column(String(64), nullable=False) - embedding = mapped_column(BinaryData, nullable=False) - created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) - provider_name = mapped_column(String(255), nullable=False, server_default=sa.text("''")) + id: Mapped[str] = mapped_column(StringUUID, primary_key=True, default=lambda: str(uuid4()), init=False) + model_name: Mapped[str] = mapped_column( + String(255), nullable=False, server_default=sa.text("'text-embedding-ada-002'") + ) + hash: Mapped[str] = mapped_column(String(64), nullable=False) + embedding: Mapped[bytes] = mapped_column(BinaryData, nullable=False) + created_at: Mapped[datetime] = mapped_column( + DateTime, nullable=False, server_default=func.current_timestamp(), init=False + ) + provider_name: Mapped[str] = mapped_column(String(255), nullable=False, server_default=sa.text("''")) def set_embedding(self, embedding_data: list[float]): self.embedding = pickle.dumps(embedding_data, protocol=pickle.HIGHEST_PROTOCOL) @@ -1214,7 +1218,7 @@ class RateLimitLog(TypeBase): ) -class DatasetMetadata(Base): +class DatasetMetadata(TypeBase): __tablename__ = "dataset_metadatas" __table_args__ = ( sa.PrimaryKeyConstraint("id", name="dataset_metadata_pkey"), @@ -1222,20 +1226,26 @@ class DatasetMetadata(Base): sa.Index("dataset_metadata_dataset_idx", "dataset_id"), ) - id = mapped_column(StringUUID, default=lambda: str(uuid4())) - tenant_id = mapped_column(StringUUID, nullable=False) - dataset_id = mapped_column(StringUUID, nullable=False) + id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False) + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False) type: Mapped[str] = mapped_column(String(255), nullable=False) name: Mapped[str] = mapped_column(String(255), nullable=False) - created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=sa.func.current_timestamp()) - updated_at: Mapped[datetime] = mapped_column( - DateTime, nullable=False, server_default=sa.func.current_timestamp(), onupdate=func.current_timestamp() + created_at: Mapped[datetime] = mapped_column( + DateTime, nullable=False, server_default=sa.func.current_timestamp(), init=False ) - created_by = mapped_column(StringUUID, nullable=False) - updated_by = mapped_column(StringUUID, nullable=True) + updated_at: Mapped[datetime] = mapped_column( + DateTime, + nullable=False, + server_default=sa.func.current_timestamp(), + onupdate=func.current_timestamp(), + init=False, + ) + created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) + updated_by: Mapped[str] = mapped_column(StringUUID, nullable=True, default=None) -class DatasetMetadataBinding(Base): +class DatasetMetadataBinding(TypeBase): __tablename__ = "dataset_metadata_bindings" __table_args__ = ( sa.PrimaryKeyConstraint("id", name="dataset_metadata_binding_pkey"), @@ -1245,13 +1255,15 @@ class DatasetMetadataBinding(Base): sa.Index("dataset_metadata_binding_document_idx", "document_id"), ) - id = mapped_column(StringUUID, default=lambda: str(uuid4())) - tenant_id = mapped_column(StringUUID, nullable=False) - dataset_id = mapped_column(StringUUID, nullable=False) - metadata_id = mapped_column(StringUUID, nullable=False) - document_id = mapped_column(StringUUID, nullable=False) - created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) - created_by = mapped_column(StringUUID, nullable=False) + id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False) + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + metadata_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + document_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + created_at: Mapped[datetime] = mapped_column( + DateTime, nullable=False, server_default=func.current_timestamp(), init=False + ) + created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) class PipelineBuiltInTemplate(TypeBase): @@ -1319,22 +1331,30 @@ class PipelineCustomizedTemplate(TypeBase): return "" -class Pipeline(Base): # type: ignore[name-defined] +class Pipeline(TypeBase): __tablename__ = "pipelines" __table_args__ = (sa.PrimaryKeyConstraint("id", name="pipeline_pkey"),) - id = mapped_column(StringUUID, default=lambda: str(uuidv7())) + id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()), init=False) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - name = mapped_column(sa.String(255), nullable=False) - description = mapped_column(LongText, nullable=False, default=sa.text("''")) - workflow_id = mapped_column(StringUUID, nullable=True) - is_public = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) - is_published = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) - created_by = mapped_column(StringUUID, nullable=True) - created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_by = mapped_column(StringUUID, nullable=True) - updated_at = mapped_column( - sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() + name: Mapped[str] = mapped_column(sa.String(255), nullable=False) + description: Mapped[str] = mapped_column(LongText, nullable=False, default=sa.text("''")) + workflow_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None) + is_public: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"), default=False) + is_published: Mapped[bool] = mapped_column( + sa.Boolean, nullable=False, server_default=sa.text("false"), default=False + ) + created_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None) + created_at: Mapped[datetime] = mapped_column( + sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False + ) + updated_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None) + updated_at: Mapped[datetime] = mapped_column( + sa.DateTime, + nullable=False, + server_default=func.current_timestamp(), + onupdate=func.current_timestamp(), + init=False, ) def retrieve_dataset(self, session: Session): diff --git a/api/models/model.py b/api/models/model.py index e2b9da46f1..fb084d1dc6 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -533,7 +533,7 @@ class AppModelConfig(Base): return self -class RecommendedApp(Base): +class RecommendedApp(Base): # bug __tablename__ = "recommended_apps" __table_args__ = ( sa.PrimaryKeyConstraint("id", name="recommended_app_pkey"), @@ -1294,7 +1294,7 @@ class Message(Base): ) -class MessageFeedback(Base): +class MessageFeedback(TypeBase): __tablename__ = "message_feedbacks" __table_args__ = ( sa.PrimaryKeyConstraint("id", name="message_feedback_pkey"), @@ -1303,18 +1303,24 @@ class MessageFeedback(Base): sa.Index("message_feedback_conversation_idx", "conversation_id", "from_source", "rating"), ) - id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4())) + id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False) app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) conversation_id: Mapped[str] = mapped_column(StringUUID, nullable=False) message_id: Mapped[str] = mapped_column(StringUUID, nullable=False) rating: Mapped[str] = mapped_column(String(255), nullable=False) - content: Mapped[str | None] = mapped_column(LongText) from_source: Mapped[str] = mapped_column(String(255), nullable=False) - from_end_user_id: Mapped[str | None] = mapped_column(StringUUID) - from_account_id: Mapped[str | None] = mapped_column(StringUUID) - created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + content: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None) + from_end_user_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None) + from_account_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None) + created_at: Mapped[datetime] = mapped_column( + sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False + ) updated_at: Mapped[datetime] = mapped_column( - sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() + sa.DateTime, + nullable=False, + server_default=func.current_timestamp(), + onupdate=func.current_timestamp(), + init=False, ) @property @@ -1467,22 +1473,28 @@ class AppAnnotationSetting(TypeBase): return collection_binding_detail -class OperationLog(Base): +class OperationLog(TypeBase): __tablename__ = "operation_logs" __table_args__ = ( sa.PrimaryKeyConstraint("id", name="operation_log_pkey"), sa.Index("operation_log_account_action_idx", "tenant_id", "account_id", "action"), ) - id = mapped_column(StringUUID, default=lambda: str(uuid4())) - tenant_id = mapped_column(StringUUID, nullable=False) - account_id = mapped_column(StringUUID, nullable=False) + id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False) + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + account_id: Mapped[str] = mapped_column(StringUUID, nullable=False) action: Mapped[str] = mapped_column(String(255), nullable=False) - content = mapped_column(sa.JSON) - created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + content: Mapped[Any] = mapped_column(sa.JSON) + created_at: Mapped[datetime] = mapped_column( + sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False + ) created_ip: Mapped[str] = mapped_column(String(255), nullable=False) - updated_at = mapped_column( - sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() + updated_at: Mapped[datetime] = mapped_column( + sa.DateTime, + nullable=False, + server_default=func.current_timestamp(), + onupdate=func.current_timestamp(), + init=False, ) @@ -1627,7 +1639,7 @@ class Site(Base): return dify_config.APP_WEB_URL or request.url_root.rstrip("/") -class ApiToken(Base): +class ApiToken(Base): # bug: this uses setattr so idk the field. __tablename__ = "api_tokens" __table_args__ = ( sa.PrimaryKeyConstraint("id", name="api_token_pkey"), @@ -1887,34 +1899,36 @@ class MessageAgentThought(Base): return {} -class DatasetRetrieverResource(Base): +class DatasetRetrieverResource(TypeBase): __tablename__ = "dataset_retriever_resources" __table_args__ = ( sa.PrimaryKeyConstraint("id", name="dataset_retriever_resource_pkey"), sa.Index("dataset_retriever_resource_message_id_idx", "message_id"), ) - id = mapped_column(StringUUID, default=lambda: str(uuid4())) - message_id = mapped_column(StringUUID, nullable=False) + id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False) + message_id: Mapped[str] = mapped_column(StringUUID, nullable=False) position: Mapped[int] = mapped_column(sa.Integer, nullable=False) - dataset_id = mapped_column(StringUUID, nullable=False) - dataset_name = mapped_column(LongText, nullable=False) - document_id = mapped_column(StringUUID, nullable=True) - document_name = mapped_column(LongText, nullable=False) - data_source_type = mapped_column(LongText, nullable=True) - segment_id = mapped_column(StringUUID, nullable=True) + dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + dataset_name: Mapped[str] = mapped_column(LongText, nullable=False) + document_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True) + document_name: Mapped[str] = mapped_column(LongText, nullable=False) + data_source_type: Mapped[str | None] = mapped_column(LongText, nullable=True) + segment_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True) score: Mapped[float | None] = mapped_column(sa.Float, nullable=True) - content = mapped_column(LongText, nullable=False) + content: Mapped[str] = mapped_column(LongText, nullable=False) hit_count: Mapped[int | None] = mapped_column(sa.Integer, nullable=True) word_count: Mapped[int | None] = mapped_column(sa.Integer, nullable=True) segment_position: Mapped[int | None] = mapped_column(sa.Integer, nullable=True) - index_node_hash = mapped_column(LongText, nullable=True) - retriever_from = mapped_column(LongText, nullable=False) - created_by = mapped_column(StringUUID, nullable=False) - created_at = mapped_column(sa.DateTime, nullable=False, server_default=sa.func.current_timestamp()) + index_node_hash: Mapped[str | None] = mapped_column(LongText, nullable=True) + retriever_from: Mapped[str] = mapped_column(LongText, nullable=False) + created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) + created_at: Mapped[datetime] = mapped_column( + sa.DateTime, nullable=False, server_default=sa.func.current_timestamp(), init=False + ) -class Tag(Base): +class Tag(TypeBase): __tablename__ = "tags" __table_args__ = ( sa.PrimaryKeyConstraint("id", name="tag_pkey"), @@ -1924,12 +1938,14 @@ class Tag(Base): TAG_TYPE_LIST = ["knowledge", "app"] - id = mapped_column(StringUUID, default=lambda: str(uuid4())) - tenant_id = mapped_column(StringUUID, nullable=True) - type = mapped_column(String(16), nullable=False) + id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False) + tenant_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True) + type: Mapped[str] = mapped_column(String(16), nullable=False) name: Mapped[str] = mapped_column(String(255), nullable=False) - created_by = mapped_column(StringUUID, nullable=False) - created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) + created_at: Mapped[datetime] = mapped_column( + sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False + ) class TagBinding(TypeBase): diff --git a/api/models/trigger.py b/api/models/trigger.py index e89309551a..088e797f82 100644 --- a/api/models/trigger.py +++ b/api/models/trigger.py @@ -17,7 +17,7 @@ from core.trigger.utils.endpoint import generate_plugin_trigger_endpoint_url, ge from libs.datetime_utils import naive_utc_now from libs.uuid_utils import uuidv7 -from .base import Base, TypeBase +from .base import TypeBase from .engine import db from .enums import AppTriggerStatus, AppTriggerType, CreatorUserRole, WorkflowTriggerStatus from .model import Account @@ -160,7 +160,7 @@ class TriggerOAuthTenantClient(TypeBase): return cast(Mapping[str, Any], json.loads(self.encrypted_oauth_params or "{}")) -class WorkflowTriggerLog(Base): +class WorkflowTriggerLog(TypeBase): """ Workflow Trigger Log @@ -202,7 +202,7 @@ class WorkflowTriggerLog(Base): sa.Index("workflow_trigger_log_workflow_id_idx", "workflow_id"), ) - id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7())) + id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()), init=False) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) workflow_id: Mapped[str] = mapped_column(StringUUID, nullable=False) @@ -214,24 +214,21 @@ class WorkflowTriggerLog(Base): inputs: Mapped[str] = mapped_column(LongText, nullable=False) # Just inputs for easy viewing outputs: Mapped[str | None] = mapped_column(LongText, nullable=True) - status: Mapped[str] = mapped_column( - EnumText(WorkflowTriggerStatus, length=50), nullable=False, default=WorkflowTriggerStatus.PENDING - ) + status: Mapped[str] = mapped_column(EnumText(WorkflowTriggerStatus, length=50), nullable=False) error: Mapped[str | None] = mapped_column(LongText, nullable=True) queue_name: Mapped[str] = mapped_column(String(100), nullable=False) celery_task_id: Mapped[str | None] = mapped_column(String(255), nullable=True) - retry_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0) - - elapsed_time: Mapped[float | None] = mapped_column(sa.Float, nullable=True) - total_tokens: Mapped[int | None] = mapped_column(sa.Integer, nullable=True) - - created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) created_by_role: Mapped[str] = mapped_column(String(255), nullable=False) created_by: Mapped[str] = mapped_column(String(255), nullable=False) - - triggered_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) - finished_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) + retry_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0) + elapsed_time: Mapped[float | None] = mapped_column(sa.Float, nullable=True, default=None) + total_tokens: Mapped[int | None] = mapped_column(sa.Integer, nullable=True, default=None) + created_at: Mapped[datetime] = mapped_column( + DateTime, nullable=False, server_default=func.current_timestamp(), init=False + ) + triggered_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True, default=None) + finished_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True, default=None) @property def created_by_account(self): diff --git a/api/models/workflow.py b/api/models/workflow.py index 0280353d45..f206a6a870 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -96,7 +96,7 @@ class _InvalidGraphDefinitionError(Exception): pass -class Workflow(Base): +class Workflow(Base): # bug """ Workflow, for `Workflow App` and `Chat App workflow mode`. diff --git a/api/services/async_workflow_service.py b/api/services/async_workflow_service.py index 8d62f121e2..e100582511 100644 --- a/api/services/async_workflow_service.py +++ b/api/services/async_workflow_service.py @@ -113,6 +113,8 @@ class AsyncWorkflowService: trigger_data.trigger_metadata.model_dump_json() if trigger_data.trigger_metadata else "{}" ), trigger_type=trigger_data.trigger_type, + workflow_run_id=None, + outputs=None, trigger_data=trigger_data.model_dump_json(), inputs=json.dumps(dict(trigger_data.inputs)), status=WorkflowTriggerStatus.PENDING, @@ -120,6 +122,10 @@ class AsyncWorkflowService: retry_count=0, created_by_role=created_by_role, created_by=created_by, + celery_task_id=None, + error=None, + elapsed_time=None, + total_tokens=None, ) trigger_log = trigger_log_repo.create(trigger_log) diff --git a/api/services/message_service.py b/api/services/message_service.py index 7ed56d80f2..e1a256e64d 100644 --- a/api/services/message_service.py +++ b/api/services/message_service.py @@ -164,6 +164,7 @@ class MessageService: elif not rating and not feedback: raise ValueError("rating cannot be None when feedback not exists") else: + assert rating is not None feedback = MessageFeedback( app_id=app_model.id, conversation_id=message.conversation_id, diff --git a/api/services/rag_pipeline/rag_pipeline_dsl_service.py b/api/services/rag_pipeline/rag_pipeline_dsl_service.py index c02fad4dc6..06f294863d 100644 --- a/api/services/rag_pipeline/rag_pipeline_dsl_service.py +++ b/api/services/rag_pipeline/rag_pipeline_dsl_service.py @@ -580,13 +580,14 @@ class RagPipelineDslService: raise ValueError("Current tenant is not set") # Create new app - pipeline = Pipeline() + pipeline = Pipeline( + tenant_id=account.current_tenant_id, + name=pipeline_data.get("name", ""), + description=pipeline_data.get("description", ""), + created_by=account.id, + updated_by=account.id, + ) pipeline.id = str(uuid4()) - pipeline.tenant_id = account.current_tenant_id - pipeline.name = pipeline_data.get("name", "") - pipeline.description = pipeline_data.get("description", "") - pipeline.created_by = account.id - pipeline.updated_by = account.id self._session.add(pipeline) self._session.commit() diff --git a/api/services/rag_pipeline/rag_pipeline_transform_service.py b/api/services/rag_pipeline/rag_pipeline_transform_service.py index 22025dd44a..84f97907c0 100644 --- a/api/services/rag_pipeline/rag_pipeline_transform_service.py +++ b/api/services/rag_pipeline/rag_pipeline_transform_service.py @@ -198,15 +198,16 @@ class RagPipelineTransformService: graph = workflow_data.get("graph", {}) # Create new app - pipeline = Pipeline() + pipeline = Pipeline( + tenant_id=current_user.current_tenant_id, + name=pipeline_data.get("name", ""), + description=pipeline_data.get("description", ""), + created_by=current_user.id, + updated_by=current_user.id, + is_published=True, + is_public=True, + ) pipeline.id = str(uuid4()) - pipeline.tenant_id = current_user.current_tenant_id - pipeline.name = pipeline_data.get("name", "") - pipeline.description = pipeline_data.get("description", "") - pipeline.created_by = current_user.id - pipeline.updated_by = current_user.id - pipeline.is_published = True - pipeline.is_public = True db.session.add(pipeline) db.session.flush() diff --git a/api/services/tag_service.py b/api/services/tag_service.py index db7ed3d5c3..937e6593fe 100644 --- a/api/services/tag_service.py +++ b/api/services/tag_service.py @@ -79,12 +79,12 @@ class TagService: if TagService.get_tag_by_tag_name(args["type"], current_user.current_tenant_id, args["name"]): raise ValueError("Tag name already exists") tag = Tag( - id=str(uuid.uuid4()), name=args["name"], type=args["type"], created_by=current_user.id, tenant_id=current_user.current_tenant_id, ) + tag.id = str(uuid.uuid4()) db.session.add(tag) db.session.commit() return tag diff --git a/api/tasks/trigger_processing_tasks.py b/api/tasks/trigger_processing_tasks.py index 2619d8dd28..ee1d31aa91 100644 --- a/api/tasks/trigger_processing_tasks.py +++ b/api/tasks/trigger_processing_tasks.py @@ -218,6 +218,8 @@ def _record_trigger_failure_log( finished_at=now, elapsed_time=0.0, total_tokens=0, + outputs=None, + celery_task_id=None, ) session.add(trigger_log) session.commit() diff --git a/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py b/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py index 45eb9d4f78..9297e997e9 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py @@ -384,24 +384,24 @@ class TestCleanDatasetTask: # Create dataset metadata and bindings metadata = DatasetMetadata( - id=str(uuid.uuid4()), dataset_id=dataset.id, tenant_id=tenant.id, name="test_metadata", type="string", created_by=account.id, - created_at=datetime.now(), ) + metadata.id = str(uuid.uuid4()) + metadata.created_at = datetime.now() binding = DatasetMetadataBinding( - id=str(uuid.uuid4()), tenant_id=tenant.id, dataset_id=dataset.id, metadata_id=metadata.id, document_id=documents[0].id, # Use first document as example created_by=account.id, - created_at=datetime.now(), ) + binding.id = str(uuid.uuid4()) + binding.created_at = datetime.now() from extensions.ext_database import db @@ -697,26 +697,26 @@ class TestCleanDatasetTask: for i in range(10): # Create 10 metadata items metadata = DatasetMetadata( - id=str(uuid.uuid4()), dataset_id=dataset.id, tenant_id=tenant.id, name=f"test_metadata_{i}", type="string", created_by=account.id, - created_at=datetime.now(), ) + metadata.id = str(uuid.uuid4()) + metadata.created_at = datetime.now() metadata_items.append(metadata) # Create binding for each metadata item binding = DatasetMetadataBinding( - id=str(uuid.uuid4()), tenant_id=tenant.id, dataset_id=dataset.id, metadata_id=metadata.id, document_id=documents[i % len(documents)].id, created_by=account.id, - created_at=datetime.now(), ) + binding.id = str(uuid.uuid4()) + binding.created_at = datetime.now() bindings.append(binding) from extensions.ext_database import db @@ -966,14 +966,15 @@ class TestCleanDatasetTask: # Create metadata with special characters special_metadata = DatasetMetadata( - id=str(uuid.uuid4()), dataset_id=dataset.id, tenant_id=tenant.id, name=f"metadata_{special_content}", type="string", created_by=account.id, - created_at=datetime.now(), ) + special_metadata.id = str(uuid.uuid4()) + special_metadata.created_at = datetime.now() + db.session.add(special_metadata) db.session.commit() diff --git a/api/tests/test_containers_integration_tests/tasks/test_rag_pipeline_run_tasks.py b/api/tests/test_containers_integration_tests/tasks/test_rag_pipeline_run_tasks.py index c82162238c..e29b98037f 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_rag_pipeline_run_tasks.py +++ b/api/tests/test_containers_integration_tests/tasks/test_rag_pipeline_run_tasks.py @@ -112,13 +112,13 @@ class TestRagPipelineRunTasks: # Create pipeline pipeline = Pipeline( - id=str(uuid.uuid4()), tenant_id=tenant.id, workflow_id=workflow.id, name=fake.company(), description=fake.text(max_nb_chars=100), created_by=account.id, ) + pipeline.id = str(uuid.uuid4()) db.session.add(pipeline) db.session.commit()