diff --git a/api/libs/broadcast_channel/channel.py b/api/libs/broadcast_channel/channel.py index 8eeac37232..a4957c9771 100644 --- a/api/libs/broadcast_channel/channel.py +++ b/api/libs/broadcast_channel/channel.py @@ -8,7 +8,7 @@ import types from abc import abstractmethod from collections.abc import Iterator from contextlib import AbstractContextManager -from typing import Protocol, Self +from typing import Protocol, Self, override class Subscription(AbstractContextManager["Subscription"], Protocol): @@ -37,10 +37,12 @@ class Subscription(AbstractContextManager["Subscription"], Protocol): """close closes the subscription, releases any resources associated with it.""" ... + @override def __enter__(self) -> Self: """`__enter__` does the setup logic of the subscription (if any), and return itself.""" return self + @override def __exit__( self, exc_type: type[BaseException] | None, diff --git a/api/libs/broadcast_channel/redis/_subscription.py b/api/libs/broadcast_channel/redis/_subscription.py index 9fe50445e4..15355a7762 100644 --- a/api/libs/broadcast_channel/redis/_subscription.py +++ b/api/libs/broadcast_channel/redis/_subscription.py @@ -3,7 +3,7 @@ import queue import threading import types from collections.abc import Generator, Iterator -from typing import Any, Self +from typing import Any, Self, override from libs.broadcast_channel.channel import Subscription from libs.broadcast_channel.exc import SubscriptionClosedError @@ -165,6 +165,7 @@ class RedisSubscriptionBase(Subscription): yield item + @override def __iter__(self) -> Iterator[bytes]: """Return an iterator over messages from the subscription.""" if self._closed.is_set(): @@ -172,6 +173,7 @@ class RedisSubscriptionBase(Subscription): self._start_if_needed() return iter(self._message_iterator()) + @override def receive(self, timeout: float | None = 0.1) -> bytes | None: """Receive the next message from the subscription.""" if self._closed.is_set(): @@ -185,11 +187,13 @@ class RedisSubscriptionBase(Subscription): return item + @override def __enter__(self) -> Self: """Context manager entry point.""" self._start_if_needed() return self + @override def __exit__( self, exc_type: type[BaseException] | None, @@ -200,6 +204,7 @@ class RedisSubscriptionBase(Subscription): self.close() return None + @override def close(self) -> None: """Close the subscription and clean up resources.""" if self._closed.is_set(): diff --git a/api/libs/broadcast_channel/redis/channel.py b/api/libs/broadcast_channel/redis/channel.py index 7f13ebaabc..bf304cc4a0 100644 --- a/api/libs/broadcast_channel/redis/channel.py +++ b/api/libs/broadcast_channel/redis/channel.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any +from typing import Any, override from extensions.redis_names import serialize_redis_name from libs.broadcast_channel.channel import Producer, Subscriber, Subscription @@ -68,20 +68,25 @@ class Topic: class _RedisSubscription(RedisSubscriptionBase): """Regular Redis pub/sub subscription implementation.""" + @override def _get_subscription_type(self) -> str: return "regular" + @override def _subscribe(self) -> None: assert self._pubsub is not None self._pubsub.subscribe(self._topic) + @override def _unsubscribe(self) -> None: assert self._pubsub is not None self._pubsub.unsubscribe(self._topic) + @override def _get_message(self) -> dict[str, Any] | None: assert self._pubsub is not None return self._pubsub.get_message(ignore_subscribe_messages=True, timeout=1) + @override def _get_message_type(self) -> str: return "message" diff --git a/api/libs/broadcast_channel/redis/sharded_channel.py b/api/libs/broadcast_channel/redis/sharded_channel.py index 02dc987107..a7303c0782 100644 --- a/api/libs/broadcast_channel/redis/sharded_channel.py +++ b/api/libs/broadcast_channel/redis/sharded_channel.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any +from typing import Any, override from extensions.redis_names import serialize_redis_name from libs.broadcast_channel.channel import Producer, Subscriber, Subscription @@ -64,17 +64,21 @@ class ShardedTopic: class _RedisShardedSubscription(RedisSubscriptionBase): """Redis 7.0+ sharded pub/sub subscription implementation.""" + @override def _get_subscription_type(self) -> str: return "sharded" + @override def _subscribe(self) -> None: assert self._pubsub is not None self._pubsub.ssubscribe(self._topic) # type: ignore[attr-defined] + @override def _unsubscribe(self) -> None: assert self._pubsub is not None self._pubsub.sunsubscribe(self._topic) # type: ignore[attr-defined] + @override def _get_message(self) -> dict[str, Any] | None: assert self._pubsub is not None # NOTE(QuantumGhost): this is an issue in @@ -101,5 +105,6 @@ class _RedisShardedSubscription(RedisSubscriptionBase): else: raise AssertionError("client should be either Redis or RedisCluster.") + @override def _get_message_type(self) -> str: return "smessage" diff --git a/api/libs/broadcast_channel/redis/streams_channel.py b/api/libs/broadcast_channel/redis/streams_channel.py index 985b253c7c..30c1458579 100644 --- a/api/libs/broadcast_channel/redis/streams_channel.py +++ b/api/libs/broadcast_channel/redis/streams_channel.py @@ -4,7 +4,7 @@ import logging import queue import threading from collections.abc import Iterator -from typing import Self +from typing import Self, override from extensions.redis_names import serialize_redis_name from libs.broadcast_channel.channel import Producer, Subscriber, Subscription @@ -165,6 +165,7 @@ class _StreamsSubscription(Subscription): ) self._listener.start() + @override def __iter__(self) -> Iterator[bytes]: # Iterator delegates to receive with timeout; stops on closure. with self._lock: @@ -181,6 +182,7 @@ class _StreamsSubscription(Subscription): if item is not None: yield item + @override def receive(self, timeout: float | None = 0.1) -> bytes | None: with self._lock: if self._closed: @@ -200,6 +202,7 @@ class _StreamsSubscription(Subscription): assert isinstance(item, (bytes, bytearray)), "Unexpected item type in stream queue" return bytes(item) + @override def close(self) -> None: with self._lock: if self._closed: @@ -221,11 +224,13 @@ class _StreamsSubscription(Subscription): ) # Context manager helpers + @override def __enter__(self) -> Self: with self._lock: self._start_if_needed() return self + @override def __exit__(self, exc_type, exc_value, traceback) -> bool | None: self.close() return None diff --git a/api/libs/email_template_renderer.py b/api/libs/email_template_renderer.py index 98ea30ab46..f249187e21 100644 --- a/api/libs/email_template_renderer.py +++ b/api/libs/email_template_renderer.py @@ -4,7 +4,7 @@ Email template rendering helpers with configurable safety modes. import time from collections.abc import Mapping -from typing import Any +from typing import Any, override from flask import render_template_string from jinja2.runtime import Context @@ -21,6 +21,7 @@ class SandboxedEnvironment(ImmutableSandboxedEnvironment): self._deadline = time.time() + timeout if timeout else None super().__init__(*args, **kwargs) + @override def call(self, context: Context, obj: Any, *args: Any, **kwargs: Any) -> Any: if self._deadline is not None and time.time() > self._deadline: raise TimeoutError("Template rendering timeout") diff --git a/api/libs/helper.py b/api/libs/helper.py index b66079fd5f..a31b546624 100644 --- a/api/libs/helper.py +++ b/api/libs/helper.py @@ -10,7 +10,7 @@ import uuid from collections.abc import Callable, Generator, Mapping from datetime import datetime from hashlib import sha256 -from typing import TYPE_CHECKING, Annotated, Any, Protocol, cast, overload +from typing import TYPE_CHECKING, Annotated, Any, Protocol, cast, overload, override from uuid import UUID from zoneinfo import available_timezones @@ -128,6 +128,7 @@ def run(script): class AppIconUrlField(fields.Raw): + @override def output(self, key, obj, **kwargs): if obj is None: return None @@ -163,6 +164,7 @@ def build_avatar_url(avatar: str | None) -> str | None: class AvatarUrlField(fields.Raw): + @override def output(self, key, obj, **kwargs): if obj is None: return None @@ -175,11 +177,13 @@ class AvatarUrlField(fields.Raw): class TimestampField(fields.Raw): + @override def format(self, value) -> int: return int(value.timestamp()) class OptionalTimestampField(fields.Raw): + @override def format(self, value) -> int | None: if value is None: return None diff --git a/api/libs/oauth.py b/api/libs/oauth.py index 309f2aa812..687fd08657 100644 --- a/api/libs/oauth.py +++ b/api/libs/oauth.py @@ -4,7 +4,7 @@ import json import logging import urllib.parse from dataclasses import dataclass -from typing import NotRequired, TypedDict +from typing import NotRequired, TypedDict, override import httpx from pydantic import TypeAdapter, ValidationError @@ -145,6 +145,7 @@ class GitHubOAuth(OAuth): _USER_INFO_URL = "https://api.github.com/user" _EMAIL_INFO_URL = "https://api.github.com/user/emails" + @override def get_authorization_url( self, invite_token: str | None = None, @@ -161,6 +162,7 @@ class GitHubOAuth(OAuth): params["state"] = state return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}" + @override def get_access_token(self, code: str) -> str: data = { "client_id": self.client_id, @@ -179,6 +181,7 @@ class GitHubOAuth(OAuth): return access_token + @override def get_raw_user_info(self, token: str) -> JsonObject: headers = {"Authorization": f"token {token}"} response = _http_client.get(self._USER_INFO_URL, headers=headers) @@ -219,6 +222,7 @@ class GitHubOAuth(OAuth): return "" + @override def _transform_user_info(self, raw_info: JsonObject) -> OAuthUserInfo: payload = GITHUB_RAW_USER_INFO_ADAPTER.validate_python(raw_info) email = payload.get("email") or "" @@ -238,6 +242,7 @@ class GoogleOAuth(OAuth): _TOKEN_URL = "https://oauth2.googleapis.com/token" _USER_INFO_URL = "https://www.googleapis.com/oauth2/v3/userinfo" + @override def get_authorization_url( self, invite_token: str | None = None, @@ -255,6 +260,7 @@ class GoogleOAuth(OAuth): params["state"] = state return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}" + @override def get_access_token(self, code: str) -> str: data = { "client_id": self.client_id, @@ -274,12 +280,14 @@ class GoogleOAuth(OAuth): return access_token + @override def get_raw_user_info(self, token: str) -> JsonObject: headers = {"Authorization": f"Bearer {token}"} response = _http_client.get(self._USER_INFO_URL, headers=headers) response.raise_for_status() return _json_object(response) + @override def _transform_user_info(self, raw_info: JsonObject) -> OAuthUserInfo: payload = GOOGLE_RAW_USER_INFO_ADAPTER.validate_python(raw_info) return OAuthUserInfo(id=str(payload["sub"]), name="", email=payload["email"]) diff --git a/api/libs/oauth_data_source.py b/api/libs/oauth_data_source.py index 934aacb45b..d9971d3992 100644 --- a/api/libs/oauth_data_source.py +++ b/api/libs/oauth_data_source.py @@ -1,5 +1,5 @@ import urllib.parse -from typing import Any, Literal, TypedDict +from typing import Any, Literal, TypedDict, override import httpx from flask_login import current_user @@ -64,6 +64,7 @@ class NotionOAuth(OAuthDataSource): _NOTION_BLOCK_SEARCH = "https://api.notion.com/v1/blocks" _NOTION_BOT_USER = "https://api.notion.com/v1/users/me" + @override def get_authorization_url(self) -> str: params = { "client_id": self.client_id, @@ -73,6 +74,7 @@ class NotionOAuth(OAuthDataSource): } return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}" + @override def get_access_token(self, code: str) -> None: data = {"code": code, "grant_type": "authorization_code", "redirect_uri": self.redirect_uri} headers = {"Accept": "application/json"}