mirror of
https://github.com/langgenius/dify.git
synced 2026-06-10 18:24:09 +08:00
chore: add missing @override decorators to api/repositories (#37138)
This commit is contained in:
parent
fad5656b2e
commit
196c040c99
@ -8,7 +8,7 @@ using SQLAlchemy 2.0 style queries for WorkflowNodeExecutionModel operations.
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from typing import Protocol, cast
|
||||
from typing import Protocol, cast, override
|
||||
|
||||
from sqlalchemy import asc, delete, desc, func, select
|
||||
from sqlalchemy.engine import CursorResult
|
||||
@ -65,6 +65,7 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut
|
||||
"""
|
||||
self._session_maker = session_maker
|
||||
|
||||
@override
|
||||
def get_node_last_execution(
|
||||
self,
|
||||
tenant_id: str,
|
||||
@ -106,6 +107,7 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut
|
||||
with self._session_maker() as session:
|
||||
return session.scalar(stmt)
|
||||
|
||||
@override
|
||||
def get_executions_by_workflow_run(
|
||||
self,
|
||||
tenant_id: str,
|
||||
@ -136,6 +138,7 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut
|
||||
with self._session_maker() as session:
|
||||
return session.execute(stmt).scalars().all()
|
||||
|
||||
@override
|
||||
def get_execution_snapshots_by_workflow_run(
|
||||
self,
|
||||
tenant_id: str,
|
||||
@ -210,6 +213,7 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut
|
||||
loop_id=str(loop_id) if loop_id else None,
|
||||
)
|
||||
|
||||
@override
|
||||
def get_execution_by_id(
|
||||
self,
|
||||
execution_id: str,
|
||||
@ -242,6 +246,7 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut
|
||||
with self._session_maker() as session:
|
||||
return session.scalar(stmt)
|
||||
|
||||
@override
|
||||
def delete_expired_executions(
|
||||
self,
|
||||
tenant_id: str,
|
||||
@ -289,6 +294,7 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut
|
||||
|
||||
return total_deleted
|
||||
|
||||
@override
|
||||
def delete_executions_by_app(
|
||||
self,
|
||||
tenant_id: str,
|
||||
@ -336,6 +342,7 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut
|
||||
|
||||
return total_deleted
|
||||
|
||||
@override
|
||||
def get_expired_executions_batch(
|
||||
self,
|
||||
tenant_id: str,
|
||||
@ -365,6 +372,7 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut
|
||||
with self._session_maker() as session:
|
||||
return session.execute(stmt).scalars().all()
|
||||
|
||||
@override
|
||||
def delete_executions_by_ids(
|
||||
self,
|
||||
execution_ids: Sequence[str],
|
||||
@ -387,6 +395,7 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut
|
||||
session.commit()
|
||||
return result.rowcount
|
||||
|
||||
@override
|
||||
def delete_by_runs(self, session: Session, run_ids: Sequence[str]) -> tuple[int, int]:
|
||||
"""
|
||||
Delete node executions (and offloads) for the given workflow runs using workflow_run_id.
|
||||
@ -420,6 +429,7 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut
|
||||
|
||||
return node_executions_deleted, offloads_deleted
|
||||
|
||||
@override
|
||||
def count_by_runs(self, session: Session, run_ids: Sequence[str]) -> tuple[int, int]:
|
||||
"""
|
||||
Count node executions (and offloads) for the given workflow runs using workflow_run_id.
|
||||
@ -456,6 +466,7 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut
|
||||
stmt = select(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.workflow_run_id == run_id)
|
||||
return list(session.scalars(stmt))
|
||||
|
||||
@override
|
||||
def get_offloads_by_execution_ids(
|
||||
self,
|
||||
session: Session,
|
||||
|
||||
@ -25,7 +25,7 @@ import uuid
|
||||
from collections.abc import Callable, Sequence
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
from typing import Any, cast
|
||||
from typing import Any, cast, override
|
||||
|
||||
import sqlalchemy as sa
|
||||
from pydantic import ValidationError
|
||||
@ -123,6 +123,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
|
||||
"""
|
||||
self._session_maker = session_maker
|
||||
|
||||
@override
|
||||
def get_paginated_workflow_runs(
|
||||
self,
|
||||
tenant_id: str,
|
||||
@ -180,6 +181,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
|
||||
|
||||
return InfiniteScrollPagination(data=workflow_runs, limit=limit, has_more=has_more)
|
||||
|
||||
@override
|
||||
def get_workflow_run_by_id(
|
||||
self,
|
||||
tenant_id: str,
|
||||
@ -197,6 +199,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
|
||||
)
|
||||
return session.scalar(stmt)
|
||||
|
||||
@override
|
||||
def get_workflow_run_by_id_without_tenant(
|
||||
self,
|
||||
run_id: str,
|
||||
@ -208,6 +211,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
|
||||
stmt = select(WorkflowRun).where(WorkflowRun.id == run_id)
|
||||
return session.scalar(stmt)
|
||||
|
||||
@override
|
||||
def get_workflow_runs_count(
|
||||
self,
|
||||
tenant_id: str,
|
||||
@ -275,6 +279,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
|
||||
|
||||
return {"total": total} | status_counts
|
||||
|
||||
@override
|
||||
def get_expired_runs_batch(
|
||||
self,
|
||||
tenant_id: str,
|
||||
@ -295,6 +300,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
|
||||
)
|
||||
return session.scalars(stmt).all()
|
||||
|
||||
@override
|
||||
def delete_runs_by_ids(
|
||||
self,
|
||||
run_ids: Sequence[str],
|
||||
@ -314,6 +320,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
|
||||
logger.info("Deleted %s workflow runs by IDs", deleted_count)
|
||||
return deleted_count
|
||||
|
||||
@override
|
||||
def delete_runs_by_app(
|
||||
self,
|
||||
tenant_id: str,
|
||||
@ -358,6 +365,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
|
||||
logger.info("Total deleted %s workflow runs for app %s", total_deleted, app_id)
|
||||
return total_deleted
|
||||
|
||||
@override
|
||||
def get_runs_batch_by_time_range(
|
||||
self,
|
||||
start_from: datetime | None,
|
||||
@ -412,6 +420,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
|
||||
|
||||
return session.scalars(stmt).all()
|
||||
|
||||
@override
|
||||
def get_archived_run_ids(
|
||||
self,
|
||||
session: Session,
|
||||
@ -423,6 +432,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
|
||||
stmt = select(WorkflowArchiveLog.workflow_run_id).where(WorkflowArchiveLog.workflow_run_id.in_(run_ids))
|
||||
return set(session.scalars(stmt).all())
|
||||
|
||||
@override
|
||||
def get_archived_log_by_run_id(
|
||||
self,
|
||||
run_id: str,
|
||||
@ -431,6 +441,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
|
||||
stmt = select(WorkflowArchiveLog).where(WorkflowArchiveLog.workflow_run_id == run_id).limit(1)
|
||||
return session.scalar(stmt)
|
||||
|
||||
@override
|
||||
def delete_archive_log_by_run_id(
|
||||
self,
|
||||
session: Session,
|
||||
@ -440,6 +451,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
|
||||
result = session.execute(stmt)
|
||||
return cast(CursorResult, result).rowcount or 0
|
||||
|
||||
@override
|
||||
def get_pause_records_by_run_id(
|
||||
self,
|
||||
session: Session,
|
||||
@ -448,6 +460,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
|
||||
stmt = select(WorkflowPause).where(WorkflowPause.workflow_run_id == run_id)
|
||||
return list(session.scalars(stmt))
|
||||
|
||||
@override
|
||||
def get_pause_reason_records_by_run_id(
|
||||
self,
|
||||
session: Session,
|
||||
@ -459,6 +472,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
|
||||
stmt = select(WorkflowPauseReason).where(WorkflowPauseReason.pause_id.in_(pause_ids))
|
||||
return list(session.scalars(stmt))
|
||||
|
||||
@override
|
||||
def delete_runs_with_related(
|
||||
self,
|
||||
runs: Sequence[WorkflowRun],
|
||||
@ -516,6 +530,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
|
||||
"pause_reasons": pause_reasons_deleted,
|
||||
}
|
||||
|
||||
@override
|
||||
def get_app_logs_by_run_id(
|
||||
self,
|
||||
session: Session,
|
||||
@ -524,6 +539,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
|
||||
stmt = select(WorkflowAppLog).where(WorkflowAppLog.workflow_run_id == run_id)
|
||||
return list(session.scalars(stmt))
|
||||
|
||||
@override
|
||||
def create_archive_logs(
|
||||
self,
|
||||
session: Session,
|
||||
@ -585,6 +601,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
|
||||
session.add_all(archive_logs)
|
||||
return len(archive_logs)
|
||||
|
||||
@override
|
||||
def get_archived_runs_by_time_range(
|
||||
self,
|
||||
session: Session,
|
||||
@ -612,6 +629,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
|
||||
stmt = stmt.where(WorkflowArchiveLog.tenant_id.in_(tenant_ids))
|
||||
return list(session.scalars(stmt))
|
||||
|
||||
@override
|
||||
def get_archived_logs_by_time_range(
|
||||
self,
|
||||
session: Session,
|
||||
@ -634,6 +652,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
|
||||
stmt = stmt.where(WorkflowArchiveLog.tenant_id.in_(tenant_ids))
|
||||
return list(session.scalars(stmt))
|
||||
|
||||
@override
|
||||
def count_runs_with_related(
|
||||
self,
|
||||
runs: Sequence[WorkflowRun],
|
||||
@ -692,6 +711,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
|
||||
"pause_reasons": int(pause_reasons_count),
|
||||
}
|
||||
|
||||
@override
|
||||
def create_workflow_pause(
|
||||
self,
|
||||
workflow_run_id: str,
|
||||
@ -827,6 +847,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
|
||||
pause_reasons.append(reason.to_entity())
|
||||
return pause_reasons
|
||||
|
||||
@override
|
||||
def get_workflow_pause(
|
||||
self,
|
||||
workflow_run_id: str,
|
||||
@ -866,6 +887,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
|
||||
pause_reasons=pause_reasons,
|
||||
)
|
||||
|
||||
@override
|
||||
def resume_workflow_pause(
|
||||
self,
|
||||
workflow_run_id: str,
|
||||
@ -934,6 +956,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
|
||||
pause_reasons=hydrated_pause_reasons,
|
||||
)
|
||||
|
||||
@override
|
||||
def delete_workflow_pause(
|
||||
self,
|
||||
pause_entity: WorkflowPauseEntity,
|
||||
@ -972,6 +995,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
|
||||
|
||||
logger.info("Deleted workflow pause %s for workflow run %s", pause_model.id, pause_model.workflow_run_id)
|
||||
|
||||
@override
|
||||
def prune_pauses(
|
||||
self,
|
||||
expiration: datetime,
|
||||
@ -1044,6 +1068,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
|
||||
|
||||
return pruned_record_ids
|
||||
|
||||
@override
|
||||
def get_daily_runs_statistics(
|
||||
self,
|
||||
tenant_id: str,
|
||||
@ -1092,6 +1117,7 @@ WHERE
|
||||
|
||||
return cast(list[DailyRunsStats], response_data)
|
||||
|
||||
@override
|
||||
def get_daily_terminals_statistics(
|
||||
self,
|
||||
tenant_id: str,
|
||||
@ -1140,6 +1166,7 @@ WHERE
|
||||
|
||||
return cast(list[DailyTerminalsStats], response_data)
|
||||
|
||||
@override
|
||||
def get_daily_token_cost_statistics(
|
||||
self,
|
||||
tenant_id: str,
|
||||
@ -1193,6 +1220,7 @@ WHERE
|
||||
|
||||
return cast(list[DailyTokenCostStats], response_data)
|
||||
|
||||
@override
|
||||
def get_average_app_interaction_statistics(
|
||||
self,
|
||||
tenant_id: str,
|
||||
@ -1258,6 +1286,7 @@ GROUP BY
|
||||
|
||||
return cast(list[AverageInteractionStats], response_data)
|
||||
|
||||
@override
|
||||
def get_workflow_run_by_id_and_tenant_id(self, tenant_id: str, run_id: str) -> WorkflowRun | None:
|
||||
"""Get a specific workflow run by its id and the associated tenant id."""
|
||||
with self._session_maker() as session:
|
||||
@ -1291,13 +1320,16 @@ class _PrivateWorkflowPauseEntity(WorkflowPauseEntity):
|
||||
self._human_input_form = human_input_form
|
||||
|
||||
@property
|
||||
@override
|
||||
def id(self) -> str:
|
||||
return self._pause_model.id
|
||||
|
||||
@property
|
||||
@override
|
||||
def workflow_execution_id(self) -> str:
|
||||
return self._pause_model.workflow_run_id
|
||||
|
||||
@override
|
||||
def get_state(self) -> bytes:
|
||||
"""
|
||||
Retrieve the serialized workflow state from storage.
|
||||
@ -1319,14 +1351,17 @@ class _PrivateWorkflowPauseEntity(WorkflowPauseEntity):
|
||||
return state_data
|
||||
|
||||
@property
|
||||
@override
|
||||
def resumed_at(self) -> datetime | None:
|
||||
return self._pause_model.resumed_at
|
||||
|
||||
@override
|
||||
def get_pause_reasons(self) -> Sequence[PauseReason]:
|
||||
if self._pause_reasons is not None:
|
||||
return list(self._pause_reasons)
|
||||
return [reason.to_entity() for reason in self._reason_models]
|
||||
return list(self._pause_reasons) # type: ignore
|
||||
return [reason.to_entity() for reason in self._reason_models] # type: ignore
|
||||
|
||||
@property
|
||||
@override
|
||||
def paused_at(self) -> datetime:
|
||||
return self._pause_model.created_at
|
||||
|
||||
@ -5,7 +5,7 @@ import logging
|
||||
import re
|
||||
from collections import defaultdict
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
from typing import Any, override
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session, selectinload, sessionmaker
|
||||
@ -45,6 +45,7 @@ class SQLAlchemyExecutionExtraContentRepository(ExecutionExtraContentRepository)
|
||||
def __init__(self, session_maker: sessionmaker[Session]):
|
||||
self._session_maker = session_maker
|
||||
|
||||
@override
|
||||
def get_by_message_ids(self, message_ids: Sequence[str]) -> list[list[ExecutionExtraContentDomainModel]]:
|
||||
if not message_ids:
|
||||
return []
|
||||
|
||||
@ -4,7 +4,7 @@ SQLAlchemy implementation of WorkflowTriggerLogRepository.
|
||||
|
||||
from collections.abc import Sequence
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import cast
|
||||
from typing import cast, override
|
||||
|
||||
from sqlalchemy import and_, delete, func, select
|
||||
from sqlalchemy.engine import CursorResult
|
||||
@ -25,18 +25,21 @@ class SQLAlchemyWorkflowTriggerLogRepository(WorkflowTriggerLogRepository):
|
||||
def __init__(self, session: Session):
|
||||
self.session = session
|
||||
|
||||
@override
|
||||
def create(self, trigger_log: WorkflowTriggerLog) -> WorkflowTriggerLog:
|
||||
"""Create a new trigger log entry."""
|
||||
self.session.add(trigger_log)
|
||||
self.session.flush()
|
||||
return trigger_log
|
||||
|
||||
@override
|
||||
def update(self, trigger_log: WorkflowTriggerLog) -> WorkflowTriggerLog:
|
||||
"""Update an existing trigger log entry."""
|
||||
self.session.merge(trigger_log)
|
||||
self.session.flush()
|
||||
return trigger_log
|
||||
|
||||
@override
|
||||
def get_by_id(self, trigger_log_id: str, tenant_id: str | None = None) -> WorkflowTriggerLog | None:
|
||||
"""Get a trigger log by its ID."""
|
||||
query = select(WorkflowTriggerLog).where(WorkflowTriggerLog.id == trigger_log_id)
|
||||
@ -51,6 +54,7 @@ class SQLAlchemyWorkflowTriggerLogRepository(WorkflowTriggerLogRepository):
|
||||
query = select(WorkflowTriggerLog).where(WorkflowTriggerLog.workflow_run_id == run_id)
|
||||
return list(self.session.scalars(query).all())
|
||||
|
||||
@override
|
||||
def get_failed_for_retry(
|
||||
self, tenant_id: str, max_retry_count: int = 3, limit: int = 100
|
||||
) -> Sequence[WorkflowTriggerLog]:
|
||||
@ -70,6 +74,7 @@ class SQLAlchemyWorkflowTriggerLogRepository(WorkflowTriggerLogRepository):
|
||||
|
||||
return list(self.session.scalars(query).all())
|
||||
|
||||
@override
|
||||
def get_recent_logs(
|
||||
self, tenant_id: str, app_id: str, hours: int = 24, limit: int = 100, offset: int = 0
|
||||
) -> Sequence[WorkflowTriggerLog]:
|
||||
@ -92,6 +97,7 @@ class SQLAlchemyWorkflowTriggerLogRepository(WorkflowTriggerLogRepository):
|
||||
|
||||
return list(self.session.scalars(query).all())
|
||||
|
||||
@override
|
||||
def get_by_workflow_run_id(self, workflow_run_id: str) -> WorkflowTriggerLog | None:
|
||||
"""Get the trigger log associated with a workflow run."""
|
||||
query = (
|
||||
@ -102,6 +108,7 @@ class SQLAlchemyWorkflowTriggerLogRepository(WorkflowTriggerLogRepository):
|
||||
)
|
||||
return self.session.scalar(query)
|
||||
|
||||
@override
|
||||
def delete_by_run_ids(self, run_ids: Sequence[str]) -> int:
|
||||
"""
|
||||
Delete trigger logs associated with the given workflow run ids.
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import TypedDict
|
||||
from typing import TypedDict, override
|
||||
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
@ -28,6 +28,7 @@ class WorkflowCollaborationRepository:
|
||||
def __init__(self) -> None:
|
||||
self._redis = redis_client
|
||||
|
||||
@override
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}(redis_client={self._redis})"
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user