chore: add missing @override decorators to api/repositories (#37138)

This commit is contained in:
eryue0220 2026-06-07 20:08:22 +08:00 committed by GitHub
parent fad5656b2e
commit 196c040c99
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 62 additions and 7 deletions

View File

@ -8,7 +8,7 @@ using SQLAlchemy 2.0 style queries for WorkflowNodeExecutionModel operations.
import json
from collections.abc import Sequence
from datetime import datetime
from typing import Protocol, cast
from typing import Protocol, cast, override
from sqlalchemy import asc, delete, desc, func, select
from sqlalchemy.engine import CursorResult
@ -65,6 +65,7 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut
"""
self._session_maker = session_maker
@override
def get_node_last_execution(
self,
tenant_id: str,
@ -106,6 +107,7 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut
with self._session_maker() as session:
return session.scalar(stmt)
@override
def get_executions_by_workflow_run(
self,
tenant_id: str,
@ -136,6 +138,7 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut
with self._session_maker() as session:
return session.execute(stmt).scalars().all()
@override
def get_execution_snapshots_by_workflow_run(
self,
tenant_id: str,
@ -210,6 +213,7 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut
loop_id=str(loop_id) if loop_id else None,
)
@override
def get_execution_by_id(
self,
execution_id: str,
@ -242,6 +246,7 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut
with self._session_maker() as session:
return session.scalar(stmt)
@override
def delete_expired_executions(
self,
tenant_id: str,
@ -289,6 +294,7 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut
return total_deleted
@override
def delete_executions_by_app(
self,
tenant_id: str,
@ -336,6 +342,7 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut
return total_deleted
@override
def get_expired_executions_batch(
self,
tenant_id: str,
@ -365,6 +372,7 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut
with self._session_maker() as session:
return session.execute(stmt).scalars().all()
@override
def delete_executions_by_ids(
self,
execution_ids: Sequence[str],
@ -387,6 +395,7 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut
session.commit()
return result.rowcount
@override
def delete_by_runs(self, session: Session, run_ids: Sequence[str]) -> tuple[int, int]:
"""
Delete node executions (and offloads) for the given workflow runs using workflow_run_id.
@ -420,6 +429,7 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut
return node_executions_deleted, offloads_deleted
@override
def count_by_runs(self, session: Session, run_ids: Sequence[str]) -> tuple[int, int]:
"""
Count node executions (and offloads) for the given workflow runs using workflow_run_id.
@ -456,6 +466,7 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut
stmt = select(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.workflow_run_id == run_id)
return list(session.scalars(stmt))
@override
def get_offloads_by_execution_ids(
self,
session: Session,

View File

@ -25,7 +25,7 @@ import uuid
from collections.abc import Callable, Sequence
from datetime import datetime
from decimal import Decimal
from typing import Any, cast
from typing import Any, cast, override
import sqlalchemy as sa
from pydantic import ValidationError
@ -123,6 +123,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
"""
self._session_maker = session_maker
@override
def get_paginated_workflow_runs(
self,
tenant_id: str,
@ -180,6 +181,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
return InfiniteScrollPagination(data=workflow_runs, limit=limit, has_more=has_more)
@override
def get_workflow_run_by_id(
self,
tenant_id: str,
@ -197,6 +199,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
)
return session.scalar(stmt)
@override
def get_workflow_run_by_id_without_tenant(
self,
run_id: str,
@ -208,6 +211,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
stmt = select(WorkflowRun).where(WorkflowRun.id == run_id)
return session.scalar(stmt)
@override
def get_workflow_runs_count(
self,
tenant_id: str,
@ -275,6 +279,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
return {"total": total} | status_counts
@override
def get_expired_runs_batch(
self,
tenant_id: str,
@ -295,6 +300,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
)
return session.scalars(stmt).all()
@override
def delete_runs_by_ids(
self,
run_ids: Sequence[str],
@ -314,6 +320,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
logger.info("Deleted %s workflow runs by IDs", deleted_count)
return deleted_count
@override
def delete_runs_by_app(
self,
tenant_id: str,
@ -358,6 +365,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
logger.info("Total deleted %s workflow runs for app %s", total_deleted, app_id)
return total_deleted
@override
def get_runs_batch_by_time_range(
self,
start_from: datetime | None,
@ -412,6 +420,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
return session.scalars(stmt).all()
@override
def get_archived_run_ids(
self,
session: Session,
@ -423,6 +432,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
stmt = select(WorkflowArchiveLog.workflow_run_id).where(WorkflowArchiveLog.workflow_run_id.in_(run_ids))
return set(session.scalars(stmt).all())
@override
def get_archived_log_by_run_id(
self,
run_id: str,
@ -431,6 +441,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
stmt = select(WorkflowArchiveLog).where(WorkflowArchiveLog.workflow_run_id == run_id).limit(1)
return session.scalar(stmt)
@override
def delete_archive_log_by_run_id(
self,
session: Session,
@ -440,6 +451,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
result = session.execute(stmt)
return cast(CursorResult, result).rowcount or 0
@override
def get_pause_records_by_run_id(
self,
session: Session,
@ -448,6 +460,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
stmt = select(WorkflowPause).where(WorkflowPause.workflow_run_id == run_id)
return list(session.scalars(stmt))
@override
def get_pause_reason_records_by_run_id(
self,
session: Session,
@ -459,6 +472,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
stmt = select(WorkflowPauseReason).where(WorkflowPauseReason.pause_id.in_(pause_ids))
return list(session.scalars(stmt))
@override
def delete_runs_with_related(
self,
runs: Sequence[WorkflowRun],
@ -516,6 +530,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
"pause_reasons": pause_reasons_deleted,
}
@override
def get_app_logs_by_run_id(
self,
session: Session,
@ -524,6 +539,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
stmt = select(WorkflowAppLog).where(WorkflowAppLog.workflow_run_id == run_id)
return list(session.scalars(stmt))
@override
def create_archive_logs(
self,
session: Session,
@ -585,6 +601,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
session.add_all(archive_logs)
return len(archive_logs)
@override
def get_archived_runs_by_time_range(
self,
session: Session,
@ -612,6 +629,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
stmt = stmt.where(WorkflowArchiveLog.tenant_id.in_(tenant_ids))
return list(session.scalars(stmt))
@override
def get_archived_logs_by_time_range(
self,
session: Session,
@ -634,6 +652,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
stmt = stmt.where(WorkflowArchiveLog.tenant_id.in_(tenant_ids))
return list(session.scalars(stmt))
@override
def count_runs_with_related(
self,
runs: Sequence[WorkflowRun],
@ -692,6 +711,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
"pause_reasons": int(pause_reasons_count),
}
@override
def create_workflow_pause(
self,
workflow_run_id: str,
@ -827,6 +847,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
pause_reasons.append(reason.to_entity())
return pause_reasons
@override
def get_workflow_pause(
self,
workflow_run_id: str,
@ -866,6 +887,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
pause_reasons=pause_reasons,
)
@override
def resume_workflow_pause(
self,
workflow_run_id: str,
@ -934,6 +956,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
pause_reasons=hydrated_pause_reasons,
)
@override
def delete_workflow_pause(
self,
pause_entity: WorkflowPauseEntity,
@ -972,6 +995,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
logger.info("Deleted workflow pause %s for workflow run %s", pause_model.id, pause_model.workflow_run_id)
@override
def prune_pauses(
self,
expiration: datetime,
@ -1044,6 +1068,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
return pruned_record_ids
@override
def get_daily_runs_statistics(
self,
tenant_id: str,
@ -1092,6 +1117,7 @@ WHERE
return cast(list[DailyRunsStats], response_data)
@override
def get_daily_terminals_statistics(
self,
tenant_id: str,
@ -1140,6 +1166,7 @@ WHERE
return cast(list[DailyTerminalsStats], response_data)
@override
def get_daily_token_cost_statistics(
self,
tenant_id: str,
@ -1193,6 +1220,7 @@ WHERE
return cast(list[DailyTokenCostStats], response_data)
@override
def get_average_app_interaction_statistics(
self,
tenant_id: str,
@ -1258,6 +1286,7 @@ GROUP BY
return cast(list[AverageInteractionStats], response_data)
@override
def get_workflow_run_by_id_and_tenant_id(self, tenant_id: str, run_id: str) -> WorkflowRun | None:
"""Get a specific workflow run by its id and the associated tenant id."""
with self._session_maker() as session:
@ -1291,13 +1320,16 @@ class _PrivateWorkflowPauseEntity(WorkflowPauseEntity):
self._human_input_form = human_input_form
@property
@override
def id(self) -> str:
return self._pause_model.id
@property
@override
def workflow_execution_id(self) -> str:
return self._pause_model.workflow_run_id
@override
def get_state(self) -> bytes:
"""
Retrieve the serialized workflow state from storage.
@ -1319,14 +1351,17 @@ class _PrivateWorkflowPauseEntity(WorkflowPauseEntity):
return state_data
@property
@override
def resumed_at(self) -> datetime | None:
return self._pause_model.resumed_at
@override
def get_pause_reasons(self) -> Sequence[PauseReason]:
if self._pause_reasons is not None:
return list(self._pause_reasons)
return [reason.to_entity() for reason in self._reason_models]
return list(self._pause_reasons) # type: ignore
return [reason.to_entity() for reason in self._reason_models] # type: ignore
@property
@override
def paused_at(self) -> datetime:
return self._pause_model.created_at

View File

@ -5,7 +5,7 @@ import logging
import re
from collections import defaultdict
from collections.abc import Sequence
from typing import Any
from typing import Any, override
from sqlalchemy import select
from sqlalchemy.orm import Session, selectinload, sessionmaker
@ -45,6 +45,7 @@ class SQLAlchemyExecutionExtraContentRepository(ExecutionExtraContentRepository)
def __init__(self, session_maker: sessionmaker[Session]):
self._session_maker = session_maker
@override
def get_by_message_ids(self, message_ids: Sequence[str]) -> list[list[ExecutionExtraContentDomainModel]]:
if not message_ids:
return []

View File

@ -4,7 +4,7 @@ SQLAlchemy implementation of WorkflowTriggerLogRepository.
from collections.abc import Sequence
from datetime import UTC, datetime, timedelta
from typing import cast
from typing import cast, override
from sqlalchemy import and_, delete, func, select
from sqlalchemy.engine import CursorResult
@ -25,18 +25,21 @@ class SQLAlchemyWorkflowTriggerLogRepository(WorkflowTriggerLogRepository):
def __init__(self, session: Session):
self.session = session
@override
def create(self, trigger_log: WorkflowTriggerLog) -> WorkflowTriggerLog:
"""Create a new trigger log entry."""
self.session.add(trigger_log)
self.session.flush()
return trigger_log
@override
def update(self, trigger_log: WorkflowTriggerLog) -> WorkflowTriggerLog:
"""Update an existing trigger log entry."""
self.session.merge(trigger_log)
self.session.flush()
return trigger_log
@override
def get_by_id(self, trigger_log_id: str, tenant_id: str | None = None) -> WorkflowTriggerLog | None:
"""Get a trigger log by its ID."""
query = select(WorkflowTriggerLog).where(WorkflowTriggerLog.id == trigger_log_id)
@ -51,6 +54,7 @@ class SQLAlchemyWorkflowTriggerLogRepository(WorkflowTriggerLogRepository):
query = select(WorkflowTriggerLog).where(WorkflowTriggerLog.workflow_run_id == run_id)
return list(self.session.scalars(query).all())
@override
def get_failed_for_retry(
self, tenant_id: str, max_retry_count: int = 3, limit: int = 100
) -> Sequence[WorkflowTriggerLog]:
@ -70,6 +74,7 @@ class SQLAlchemyWorkflowTriggerLogRepository(WorkflowTriggerLogRepository):
return list(self.session.scalars(query).all())
@override
def get_recent_logs(
self, tenant_id: str, app_id: str, hours: int = 24, limit: int = 100, offset: int = 0
) -> Sequence[WorkflowTriggerLog]:
@ -92,6 +97,7 @@ class SQLAlchemyWorkflowTriggerLogRepository(WorkflowTriggerLogRepository):
return list(self.session.scalars(query).all())
@override
def get_by_workflow_run_id(self, workflow_run_id: str) -> WorkflowTriggerLog | None:
"""Get the trigger log associated with a workflow run."""
query = (
@ -102,6 +108,7 @@ class SQLAlchemyWorkflowTriggerLogRepository(WorkflowTriggerLogRepository):
)
return self.session.scalar(query)
@override
def delete_by_run_ids(self, run_ids: Sequence[str]) -> int:
"""
Delete trigger logs associated with the given workflow run ids.

View File

@ -1,7 +1,7 @@
from __future__ import annotations
import json
from typing import TypedDict
from typing import TypedDict, override
from extensions.ext_redis import redis_client
@ -28,6 +28,7 @@ class WorkflowCollaborationRepository:
def __init__(self) -> None:
self._redis = redis_client
@override
def __repr__(self) -> str:
return f"{self.__class__.__name__}(redis_client={self._redis})"