diff --git a/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py b/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py index 44735eb769..0d8f2c09ab 100644 --- a/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py +++ b/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py @@ -8,7 +8,7 @@ using SQLAlchemy 2.0 style queries for WorkflowNodeExecutionModel operations. import json from collections.abc import Sequence from datetime import datetime -from typing import Protocol, cast +from typing import Protocol, cast, override from sqlalchemy import asc, delete, desc, func, select from sqlalchemy.engine import CursorResult @@ -65,6 +65,7 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut """ self._session_maker = session_maker + @override def get_node_last_execution( self, tenant_id: str, @@ -106,6 +107,7 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut with self._session_maker() as session: return session.scalar(stmt) + @override def get_executions_by_workflow_run( self, tenant_id: str, @@ -136,6 +138,7 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut with self._session_maker() as session: return session.execute(stmt).scalars().all() + @override def get_execution_snapshots_by_workflow_run( self, tenant_id: str, @@ -210,6 +213,7 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut loop_id=str(loop_id) if loop_id else None, ) + @override def get_execution_by_id( self, execution_id: str, @@ -242,6 +246,7 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut with self._session_maker() as session: return session.scalar(stmt) + @override def delete_expired_executions( self, tenant_id: str, @@ -289,6 +294,7 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut return total_deleted + @override def delete_executions_by_app( self, tenant_id: str, @@ -336,6 +342,7 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut return total_deleted + @override def get_expired_executions_batch( self, tenant_id: str, @@ -365,6 +372,7 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut with self._session_maker() as session: return session.execute(stmt).scalars().all() + @override def delete_executions_by_ids( self, execution_ids: Sequence[str], @@ -387,6 +395,7 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut session.commit() return result.rowcount + @override def delete_by_runs(self, session: Session, run_ids: Sequence[str]) -> tuple[int, int]: """ Delete node executions (and offloads) for the given workflow runs using workflow_run_id. @@ -420,6 +429,7 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut return node_executions_deleted, offloads_deleted + @override def count_by_runs(self, session: Session, run_ids: Sequence[str]) -> tuple[int, int]: """ Count node executions (and offloads) for the given workflow runs using workflow_run_id. @@ -456,6 +466,7 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut stmt = select(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.workflow_run_id == run_id) return list(session.scalars(stmt)) + @override def get_offloads_by_execution_ids( self, session: Session, diff --git a/api/repositories/sqlalchemy_api_workflow_run_repository.py b/api/repositories/sqlalchemy_api_workflow_run_repository.py index 71a2554a60..cbc9d03e5e 100644 --- a/api/repositories/sqlalchemy_api_workflow_run_repository.py +++ b/api/repositories/sqlalchemy_api_workflow_run_repository.py @@ -25,7 +25,7 @@ import uuid from collections.abc import Callable, Sequence from datetime import datetime from decimal import Decimal -from typing import Any, cast +from typing import Any, cast, override import sqlalchemy as sa from pydantic import ValidationError @@ -123,6 +123,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): """ self._session_maker = session_maker + @override def get_paginated_workflow_runs( self, tenant_id: str, @@ -180,6 +181,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): return InfiniteScrollPagination(data=workflow_runs, limit=limit, has_more=has_more) + @override def get_workflow_run_by_id( self, tenant_id: str, @@ -197,6 +199,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): ) return session.scalar(stmt) + @override def get_workflow_run_by_id_without_tenant( self, run_id: str, @@ -208,6 +211,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): stmt = select(WorkflowRun).where(WorkflowRun.id == run_id) return session.scalar(stmt) + @override def get_workflow_runs_count( self, tenant_id: str, @@ -275,6 +279,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): return {"total": total} | status_counts + @override def get_expired_runs_batch( self, tenant_id: str, @@ -295,6 +300,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): ) return session.scalars(stmt).all() + @override def delete_runs_by_ids( self, run_ids: Sequence[str], @@ -314,6 +320,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): logger.info("Deleted %s workflow runs by IDs", deleted_count) return deleted_count + @override def delete_runs_by_app( self, tenant_id: str, @@ -358,6 +365,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): logger.info("Total deleted %s workflow runs for app %s", total_deleted, app_id) return total_deleted + @override def get_runs_batch_by_time_range( self, start_from: datetime | None, @@ -412,6 +420,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): return session.scalars(stmt).all() + @override def get_archived_run_ids( self, session: Session, @@ -423,6 +432,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): stmt = select(WorkflowArchiveLog.workflow_run_id).where(WorkflowArchiveLog.workflow_run_id.in_(run_ids)) return set(session.scalars(stmt).all()) + @override def get_archived_log_by_run_id( self, run_id: str, @@ -431,6 +441,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): stmt = select(WorkflowArchiveLog).where(WorkflowArchiveLog.workflow_run_id == run_id).limit(1) return session.scalar(stmt) + @override def delete_archive_log_by_run_id( self, session: Session, @@ -440,6 +451,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): result = session.execute(stmt) return cast(CursorResult, result).rowcount or 0 + @override def get_pause_records_by_run_id( self, session: Session, @@ -448,6 +460,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): stmt = select(WorkflowPause).where(WorkflowPause.workflow_run_id == run_id) return list(session.scalars(stmt)) + @override def get_pause_reason_records_by_run_id( self, session: Session, @@ -459,6 +472,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): stmt = select(WorkflowPauseReason).where(WorkflowPauseReason.pause_id.in_(pause_ids)) return list(session.scalars(stmt)) + @override def delete_runs_with_related( self, runs: Sequence[WorkflowRun], @@ -516,6 +530,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): "pause_reasons": pause_reasons_deleted, } + @override def get_app_logs_by_run_id( self, session: Session, @@ -524,6 +539,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): stmt = select(WorkflowAppLog).where(WorkflowAppLog.workflow_run_id == run_id) return list(session.scalars(stmt)) + @override def create_archive_logs( self, session: Session, @@ -585,6 +601,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): session.add_all(archive_logs) return len(archive_logs) + @override def get_archived_runs_by_time_range( self, session: Session, @@ -612,6 +629,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): stmt = stmt.where(WorkflowArchiveLog.tenant_id.in_(tenant_ids)) return list(session.scalars(stmt)) + @override def get_archived_logs_by_time_range( self, session: Session, @@ -634,6 +652,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): stmt = stmt.where(WorkflowArchiveLog.tenant_id.in_(tenant_ids)) return list(session.scalars(stmt)) + @override def count_runs_with_related( self, runs: Sequence[WorkflowRun], @@ -692,6 +711,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): "pause_reasons": int(pause_reasons_count), } + @override def create_workflow_pause( self, workflow_run_id: str, @@ -827,6 +847,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): pause_reasons.append(reason.to_entity()) return pause_reasons + @override def get_workflow_pause( self, workflow_run_id: str, @@ -866,6 +887,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): pause_reasons=pause_reasons, ) + @override def resume_workflow_pause( self, workflow_run_id: str, @@ -934,6 +956,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): pause_reasons=hydrated_pause_reasons, ) + @override def delete_workflow_pause( self, pause_entity: WorkflowPauseEntity, @@ -972,6 +995,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): logger.info("Deleted workflow pause %s for workflow run %s", pause_model.id, pause_model.workflow_run_id) + @override def prune_pauses( self, expiration: datetime, @@ -1044,6 +1068,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): return pruned_record_ids + @override def get_daily_runs_statistics( self, tenant_id: str, @@ -1092,6 +1117,7 @@ WHERE return cast(list[DailyRunsStats], response_data) + @override def get_daily_terminals_statistics( self, tenant_id: str, @@ -1140,6 +1166,7 @@ WHERE return cast(list[DailyTerminalsStats], response_data) + @override def get_daily_token_cost_statistics( self, tenant_id: str, @@ -1193,6 +1220,7 @@ WHERE return cast(list[DailyTokenCostStats], response_data) + @override def get_average_app_interaction_statistics( self, tenant_id: str, @@ -1258,6 +1286,7 @@ GROUP BY return cast(list[AverageInteractionStats], response_data) + @override def get_workflow_run_by_id_and_tenant_id(self, tenant_id: str, run_id: str) -> WorkflowRun | None: """Get a specific workflow run by its id and the associated tenant id.""" with self._session_maker() as session: @@ -1291,13 +1320,16 @@ class _PrivateWorkflowPauseEntity(WorkflowPauseEntity): self._human_input_form = human_input_form @property + @override def id(self) -> str: return self._pause_model.id @property + @override def workflow_execution_id(self) -> str: return self._pause_model.workflow_run_id + @override def get_state(self) -> bytes: """ Retrieve the serialized workflow state from storage. @@ -1319,14 +1351,17 @@ class _PrivateWorkflowPauseEntity(WorkflowPauseEntity): return state_data @property + @override def resumed_at(self) -> datetime | None: return self._pause_model.resumed_at + @override def get_pause_reasons(self) -> Sequence[PauseReason]: if self._pause_reasons is not None: - return list(self._pause_reasons) - return [reason.to_entity() for reason in self._reason_models] + return list(self._pause_reasons) # type: ignore + return [reason.to_entity() for reason in self._reason_models] # type: ignore @property + @override def paused_at(self) -> datetime: return self._pause_model.created_at diff --git a/api/repositories/sqlalchemy_execution_extra_content_repository.py b/api/repositories/sqlalchemy_execution_extra_content_repository.py index 65a1edbf2d..cfac363b19 100644 --- a/api/repositories/sqlalchemy_execution_extra_content_repository.py +++ b/api/repositories/sqlalchemy_execution_extra_content_repository.py @@ -5,7 +5,7 @@ import logging import re from collections import defaultdict from collections.abc import Sequence -from typing import Any +from typing import Any, override from sqlalchemy import select from sqlalchemy.orm import Session, selectinload, sessionmaker @@ -45,6 +45,7 @@ class SQLAlchemyExecutionExtraContentRepository(ExecutionExtraContentRepository) def __init__(self, session_maker: sessionmaker[Session]): self._session_maker = session_maker + @override def get_by_message_ids(self, message_ids: Sequence[str]) -> list[list[ExecutionExtraContentDomainModel]]: if not message_ids: return [] diff --git a/api/repositories/sqlalchemy_workflow_trigger_log_repository.py b/api/repositories/sqlalchemy_workflow_trigger_log_repository.py index 1f6740b066..f805fa0e4b 100644 --- a/api/repositories/sqlalchemy_workflow_trigger_log_repository.py +++ b/api/repositories/sqlalchemy_workflow_trigger_log_repository.py @@ -4,7 +4,7 @@ SQLAlchemy implementation of WorkflowTriggerLogRepository. from collections.abc import Sequence from datetime import UTC, datetime, timedelta -from typing import cast +from typing import cast, override from sqlalchemy import and_, delete, func, select from sqlalchemy.engine import CursorResult @@ -25,18 +25,21 @@ class SQLAlchemyWorkflowTriggerLogRepository(WorkflowTriggerLogRepository): def __init__(self, session: Session): self.session = session + @override def create(self, trigger_log: WorkflowTriggerLog) -> WorkflowTriggerLog: """Create a new trigger log entry.""" self.session.add(trigger_log) self.session.flush() return trigger_log + @override def update(self, trigger_log: WorkflowTriggerLog) -> WorkflowTriggerLog: """Update an existing trigger log entry.""" self.session.merge(trigger_log) self.session.flush() return trigger_log + @override def get_by_id(self, trigger_log_id: str, tenant_id: str | None = None) -> WorkflowTriggerLog | None: """Get a trigger log by its ID.""" query = select(WorkflowTriggerLog).where(WorkflowTriggerLog.id == trigger_log_id) @@ -51,6 +54,7 @@ class SQLAlchemyWorkflowTriggerLogRepository(WorkflowTriggerLogRepository): query = select(WorkflowTriggerLog).where(WorkflowTriggerLog.workflow_run_id == run_id) return list(self.session.scalars(query).all()) + @override def get_failed_for_retry( self, tenant_id: str, max_retry_count: int = 3, limit: int = 100 ) -> Sequence[WorkflowTriggerLog]: @@ -70,6 +74,7 @@ class SQLAlchemyWorkflowTriggerLogRepository(WorkflowTriggerLogRepository): return list(self.session.scalars(query).all()) + @override def get_recent_logs( self, tenant_id: str, app_id: str, hours: int = 24, limit: int = 100, offset: int = 0 ) -> Sequence[WorkflowTriggerLog]: @@ -92,6 +97,7 @@ class SQLAlchemyWorkflowTriggerLogRepository(WorkflowTriggerLogRepository): return list(self.session.scalars(query).all()) + @override def get_by_workflow_run_id(self, workflow_run_id: str) -> WorkflowTriggerLog | None: """Get the trigger log associated with a workflow run.""" query = ( @@ -102,6 +108,7 @@ class SQLAlchemyWorkflowTriggerLogRepository(WorkflowTriggerLogRepository): ) return self.session.scalar(query) + @override def delete_by_run_ids(self, run_ids: Sequence[str]) -> int: """ Delete trigger logs associated with the given workflow run ids. diff --git a/api/repositories/workflow_collaboration_repository.py b/api/repositories/workflow_collaboration_repository.py index 000f80496d..df6cb63515 100644 --- a/api/repositories/workflow_collaboration_repository.py +++ b/api/repositories/workflow_collaboration_repository.py @@ -1,7 +1,7 @@ from __future__ import annotations import json -from typing import TypedDict +from typing import TypedDict, override from extensions.ext_redis import redis_client @@ -28,6 +28,7 @@ class WorkflowCollaborationRepository: def __init__(self) -> None: self._redis = redis_client + @override def __repr__(self) -> str: return f"{self.__class__.__name__}(redis_client={self._redis})"