mirror of
https://github.com/langgenius/dify.git
synced 2026-06-07 16:23:44 +08:00
chore: add missing @override decorators to api/libs (#37012)
This commit is contained in:
parent
35a55813d2
commit
bb3c9929f9
@ -8,7 +8,7 @@ import types
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Iterator
|
||||
from contextlib import AbstractContextManager
|
||||
from typing import Protocol, Self
|
||||
from typing import Protocol, Self, override
|
||||
|
||||
|
||||
class Subscription(AbstractContextManager["Subscription"], Protocol):
|
||||
@ -37,10 +37,12 @@ class Subscription(AbstractContextManager["Subscription"], Protocol):
|
||||
"""close closes the subscription, releases any resources associated with it."""
|
||||
...
|
||||
|
||||
@override
|
||||
def __enter__(self) -> Self:
|
||||
"""`__enter__` does the setup logic of the subscription (if any), and return itself."""
|
||||
return self
|
||||
|
||||
@override
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
|
||||
@ -3,7 +3,7 @@ import queue
|
||||
import threading
|
||||
import types
|
||||
from collections.abc import Generator, Iterator
|
||||
from typing import Any, Self
|
||||
from typing import Any, Self, override
|
||||
|
||||
from libs.broadcast_channel.channel import Subscription
|
||||
from libs.broadcast_channel.exc import SubscriptionClosedError
|
||||
@ -165,6 +165,7 @@ class RedisSubscriptionBase(Subscription):
|
||||
|
||||
yield item
|
||||
|
||||
@override
|
||||
def __iter__(self) -> Iterator[bytes]:
|
||||
"""Return an iterator over messages from the subscription."""
|
||||
if self._closed.is_set():
|
||||
@ -172,6 +173,7 @@ class RedisSubscriptionBase(Subscription):
|
||||
self._start_if_needed()
|
||||
return iter(self._message_iterator())
|
||||
|
||||
@override
|
||||
def receive(self, timeout: float | None = 0.1) -> bytes | None:
|
||||
"""Receive the next message from the subscription."""
|
||||
if self._closed.is_set():
|
||||
@ -185,11 +187,13 @@ class RedisSubscriptionBase(Subscription):
|
||||
|
||||
return item
|
||||
|
||||
@override
|
||||
def __enter__(self) -> Self:
|
||||
"""Context manager entry point."""
|
||||
self._start_if_needed()
|
||||
return self
|
||||
|
||||
@override
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
@ -200,6 +204,7 @@ class RedisSubscriptionBase(Subscription):
|
||||
self.close()
|
||||
return None
|
||||
|
||||
@override
|
||||
def close(self) -> None:
|
||||
"""Close the subscription and clean up resources."""
|
||||
if self._closed.is_set():
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from typing import Any, override
|
||||
|
||||
from extensions.redis_names import serialize_redis_name
|
||||
from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
|
||||
@ -68,20 +68,25 @@ class Topic:
|
||||
class _RedisSubscription(RedisSubscriptionBase):
|
||||
"""Regular Redis pub/sub subscription implementation."""
|
||||
|
||||
@override
|
||||
def _get_subscription_type(self) -> str:
|
||||
return "regular"
|
||||
|
||||
@override
|
||||
def _subscribe(self) -> None:
|
||||
assert self._pubsub is not None
|
||||
self._pubsub.subscribe(self._topic)
|
||||
|
||||
@override
|
||||
def _unsubscribe(self) -> None:
|
||||
assert self._pubsub is not None
|
||||
self._pubsub.unsubscribe(self._topic)
|
||||
|
||||
@override
|
||||
def _get_message(self) -> dict[str, Any] | None:
|
||||
assert self._pubsub is not None
|
||||
return self._pubsub.get_message(ignore_subscribe_messages=True, timeout=1)
|
||||
|
||||
@override
|
||||
def _get_message_type(self) -> str:
|
||||
return "message"
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from typing import Any, override
|
||||
|
||||
from extensions.redis_names import serialize_redis_name
|
||||
from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
|
||||
@ -64,17 +64,21 @@ class ShardedTopic:
|
||||
class _RedisShardedSubscription(RedisSubscriptionBase):
|
||||
"""Redis 7.0+ sharded pub/sub subscription implementation."""
|
||||
|
||||
@override
|
||||
def _get_subscription_type(self) -> str:
|
||||
return "sharded"
|
||||
|
||||
@override
|
||||
def _subscribe(self) -> None:
|
||||
assert self._pubsub is not None
|
||||
self._pubsub.ssubscribe(self._topic) # type: ignore[attr-defined]
|
||||
|
||||
@override
|
||||
def _unsubscribe(self) -> None:
|
||||
assert self._pubsub is not None
|
||||
self._pubsub.sunsubscribe(self._topic) # type: ignore[attr-defined]
|
||||
|
||||
@override
|
||||
def _get_message(self) -> dict[str, Any] | None:
|
||||
assert self._pubsub is not None
|
||||
# NOTE(QuantumGhost): this is an issue in
|
||||
@ -101,5 +105,6 @@ class _RedisShardedSubscription(RedisSubscriptionBase):
|
||||
else:
|
||||
raise AssertionError("client should be either Redis or RedisCluster.")
|
||||
|
||||
@override
|
||||
def _get_message_type(self) -> str:
|
||||
return "smessage"
|
||||
|
||||
@ -4,7 +4,7 @@ import logging
|
||||
import queue
|
||||
import threading
|
||||
from collections.abc import Iterator
|
||||
from typing import Self
|
||||
from typing import Self, override
|
||||
|
||||
from extensions.redis_names import serialize_redis_name
|
||||
from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
|
||||
@ -165,6 +165,7 @@ class _StreamsSubscription(Subscription):
|
||||
)
|
||||
self._listener.start()
|
||||
|
||||
@override
|
||||
def __iter__(self) -> Iterator[bytes]:
|
||||
# Iterator delegates to receive with timeout; stops on closure.
|
||||
with self._lock:
|
||||
@ -181,6 +182,7 @@ class _StreamsSubscription(Subscription):
|
||||
if item is not None:
|
||||
yield item
|
||||
|
||||
@override
|
||||
def receive(self, timeout: float | None = 0.1) -> bytes | None:
|
||||
with self._lock:
|
||||
if self._closed:
|
||||
@ -200,6 +202,7 @@ class _StreamsSubscription(Subscription):
|
||||
assert isinstance(item, (bytes, bytearray)), "Unexpected item type in stream queue"
|
||||
return bytes(item)
|
||||
|
||||
@override
|
||||
def close(self) -> None:
|
||||
with self._lock:
|
||||
if self._closed:
|
||||
@ -221,11 +224,13 @@ class _StreamsSubscription(Subscription):
|
||||
)
|
||||
|
||||
# Context manager helpers
|
||||
@override
|
||||
def __enter__(self) -> Self:
|
||||
with self._lock:
|
||||
self._start_if_needed()
|
||||
return self
|
||||
|
||||
@override
|
||||
def __exit__(self, exc_type, exc_value, traceback) -> bool | None:
|
||||
self.close()
|
||||
return None
|
||||
|
||||
@ -4,7 +4,7 @@ Email template rendering helpers with configurable safety modes.
|
||||
|
||||
import time
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
from typing import Any, override
|
||||
|
||||
from flask import render_template_string
|
||||
from jinja2.runtime import Context
|
||||
@ -21,6 +21,7 @@ class SandboxedEnvironment(ImmutableSandboxedEnvironment):
|
||||
self._deadline = time.time() + timeout if timeout else None
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
@override
|
||||
def call(self, context: Context, obj: Any, *args: Any, **kwargs: Any) -> Any:
|
||||
if self._deadline is not None and time.time() > self._deadline:
|
||||
raise TimeoutError("Template rendering timeout")
|
||||
|
||||
@ -10,7 +10,7 @@ import uuid
|
||||
from collections.abc import Callable, Generator, Mapping
|
||||
from datetime import datetime
|
||||
from hashlib import sha256
|
||||
from typing import TYPE_CHECKING, Annotated, Any, Protocol, cast, overload
|
||||
from typing import TYPE_CHECKING, Annotated, Any, Protocol, cast, overload, override
|
||||
from uuid import UUID
|
||||
from zoneinfo import available_timezones
|
||||
|
||||
@ -128,6 +128,7 @@ def run(script):
|
||||
|
||||
|
||||
class AppIconUrlField(fields.Raw):
|
||||
@override
|
||||
def output(self, key, obj, **kwargs):
|
||||
if obj is None:
|
||||
return None
|
||||
@ -163,6 +164,7 @@ def build_avatar_url(avatar: str | None) -> str | None:
|
||||
|
||||
|
||||
class AvatarUrlField(fields.Raw):
|
||||
@override
|
||||
def output(self, key, obj, **kwargs):
|
||||
if obj is None:
|
||||
return None
|
||||
@ -175,11 +177,13 @@ class AvatarUrlField(fields.Raw):
|
||||
|
||||
|
||||
class TimestampField(fields.Raw):
|
||||
@override
|
||||
def format(self, value) -> int:
|
||||
return int(value.timestamp())
|
||||
|
||||
|
||||
class OptionalTimestampField(fields.Raw):
|
||||
@override
|
||||
def format(self, value) -> int | None:
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
@ -4,7 +4,7 @@ import json
|
||||
import logging
|
||||
import urllib.parse
|
||||
from dataclasses import dataclass
|
||||
from typing import NotRequired, TypedDict
|
||||
from typing import NotRequired, TypedDict, override
|
||||
|
||||
import httpx
|
||||
from pydantic import TypeAdapter, ValidationError
|
||||
@ -145,6 +145,7 @@ class GitHubOAuth(OAuth):
|
||||
_USER_INFO_URL = "https://api.github.com/user"
|
||||
_EMAIL_INFO_URL = "https://api.github.com/user/emails"
|
||||
|
||||
@override
|
||||
def get_authorization_url(
|
||||
self,
|
||||
invite_token: str | None = None,
|
||||
@ -161,6 +162,7 @@ class GitHubOAuth(OAuth):
|
||||
params["state"] = state
|
||||
return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}"
|
||||
|
||||
@override
|
||||
def get_access_token(self, code: str) -> str:
|
||||
data = {
|
||||
"client_id": self.client_id,
|
||||
@ -179,6 +181,7 @@ class GitHubOAuth(OAuth):
|
||||
|
||||
return access_token
|
||||
|
||||
@override
|
||||
def get_raw_user_info(self, token: str) -> JsonObject:
|
||||
headers = {"Authorization": f"token {token}"}
|
||||
response = _http_client.get(self._USER_INFO_URL, headers=headers)
|
||||
@ -219,6 +222,7 @@ class GitHubOAuth(OAuth):
|
||||
|
||||
return ""
|
||||
|
||||
@override
|
||||
def _transform_user_info(self, raw_info: JsonObject) -> OAuthUserInfo:
|
||||
payload = GITHUB_RAW_USER_INFO_ADAPTER.validate_python(raw_info)
|
||||
email = payload.get("email") or ""
|
||||
@ -238,6 +242,7 @@ class GoogleOAuth(OAuth):
|
||||
_TOKEN_URL = "https://oauth2.googleapis.com/token"
|
||||
_USER_INFO_URL = "https://www.googleapis.com/oauth2/v3/userinfo"
|
||||
|
||||
@override
|
||||
def get_authorization_url(
|
||||
self,
|
||||
invite_token: str | None = None,
|
||||
@ -255,6 +260,7 @@ class GoogleOAuth(OAuth):
|
||||
params["state"] = state
|
||||
return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}"
|
||||
|
||||
@override
|
||||
def get_access_token(self, code: str) -> str:
|
||||
data = {
|
||||
"client_id": self.client_id,
|
||||
@ -274,12 +280,14 @@ class GoogleOAuth(OAuth):
|
||||
|
||||
return access_token
|
||||
|
||||
@override
|
||||
def get_raw_user_info(self, token: str) -> JsonObject:
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
response = _http_client.get(self._USER_INFO_URL, headers=headers)
|
||||
response.raise_for_status()
|
||||
return _json_object(response)
|
||||
|
||||
@override
|
||||
def _transform_user_info(self, raw_info: JsonObject) -> OAuthUserInfo:
|
||||
payload = GOOGLE_RAW_USER_INFO_ADAPTER.validate_python(raw_info)
|
||||
return OAuthUserInfo(id=str(payload["sub"]), name="", email=payload["email"])
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import urllib.parse
|
||||
from typing import Any, Literal, TypedDict
|
||||
from typing import Any, Literal, TypedDict, override
|
||||
|
||||
import httpx
|
||||
from flask_login import current_user
|
||||
@ -64,6 +64,7 @@ class NotionOAuth(OAuthDataSource):
|
||||
_NOTION_BLOCK_SEARCH = "https://api.notion.com/v1/blocks"
|
||||
_NOTION_BOT_USER = "https://api.notion.com/v1/users/me"
|
||||
|
||||
@override
|
||||
def get_authorization_url(self) -> str:
|
||||
params = {
|
||||
"client_id": self.client_id,
|
||||
@ -73,6 +74,7 @@ class NotionOAuth(OAuthDataSource):
|
||||
}
|
||||
return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}"
|
||||
|
||||
@override
|
||||
def get_access_token(self, code: str) -> None:
|
||||
data = {"code": code, "grant_type": "authorization_code", "redirect_uri": self.redirect_uri}
|
||||
headers = {"Accept": "application/json"}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user