refactor: add @override decorators to storage backend subclasses (#36406) (#36755)

This commit is contained in:
Nishchay Mahor 2026-05-27 23:04:47 -07:00 committed by GitHub
parent b94ff65e9f
commit 57b02e341c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 86 additions and 3 deletions

View File

@ -1,5 +1,6 @@
import posixpath
from collections.abc import Generator
from typing import override
import oss2 as aliyun_s3
@ -29,9 +30,11 @@ class AliyunOssStorage(BaseStorage):
cloudbox_id=dify_config.ALIYUN_CLOUDBOX_ID,
)
@override
def save(self, filename, data):
self.client.put_object(self.__wrapper_folder_filename(filename), data)
@override
def load_once(self, filename: str) -> bytes:
obj = self.client.get_object(self.__wrapper_folder_filename(filename))
data = obj.read()
@ -39,17 +42,21 @@ class AliyunOssStorage(BaseStorage):
return b""
return data
@override
def load_stream(self, filename: str) -> Generator:
obj = self.client.get_object(self.__wrapper_folder_filename(filename))
while chunk := obj.read(4096):
yield chunk
@override
def download(self, filename: str, target_filepath):
self.client.get_object_to_file(self.__wrapper_folder_filename(filename), target_filepath)
@override
def exists(self, filename: str):
return self.client.object_exists(self.__wrapper_folder_filename(filename))
@override
def delete(self, filename: str):
self.client.delete_object(self.__wrapper_folder_filename(filename))

View File

@ -1,5 +1,6 @@
import logging
from collections.abc import Generator
from typing import override
import boto3
from botocore.client import Config
@ -48,9 +49,11 @@ class AwsS3Storage(BaseStorage):
# other error, raise exception
raise
@override
def save(self, filename, data):
self.client.put_object(Bucket=self.bucket_name, Key=filename, Body=data)
@override
def load_once(self, filename: str) -> bytes:
try:
data: bytes = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].read()
@ -61,6 +64,7 @@ class AwsS3Storage(BaseStorage):
raise
return data
@override
def load_stream(self, filename: str) -> Generator:
try:
response = self.client.get_object(Bucket=self.bucket_name, Key=filename)
@ -73,9 +77,11 @@ class AwsS3Storage(BaseStorage):
else:
raise
@override
def download(self, filename, target_filepath):
self.client.download_file(self.bucket_name, filename, target_filepath)
@override
def exists(self, filename):
try:
self.client.head_object(Bucket=self.bucket_name, Key=filename)
@ -83,5 +89,6 @@ class AwsS3Storage(BaseStorage):
except:
return False
@override
def delete(self, filename: str):
self.client.delete_object(Bucket=self.bucket_name, Key=filename)

View File

@ -1,5 +1,6 @@
from collections.abc import Generator
from datetime import timedelta
from typing import override
from azure.identity import ChainedTokenCredential, DefaultAzureCredential
from azure.storage.blob import AccountSasPermissions, BlobServiceClient, ResourceTypes, generate_account_sas
@ -26,6 +27,7 @@ class AzureBlobStorage(BaseStorage):
else:
self.credential = None
@override
def save(self, filename, data):
if not self.bucket_name:
return
@ -34,6 +36,7 @@ class AzureBlobStorage(BaseStorage):
blob_container = client.get_container_client(container=self.bucket_name)
blob_container.upload_blob(filename, data)
@override
def load_once(self, filename: str) -> bytes:
if not self.bucket_name:
raise FileNotFoundError("Azure bucket name is not configured.")
@ -46,6 +49,7 @@ class AzureBlobStorage(BaseStorage):
raise TypeError(f"Expected bytes from blob.readall(), got {type(data).__name__}")
return data
@override
def load_stream(self, filename: str) -> Generator:
if not self.bucket_name:
raise FileNotFoundError("Azure bucket name is not configured.")
@ -55,6 +59,7 @@ class AzureBlobStorage(BaseStorage):
blob_data = blob.download_blob()
yield from blob_data.chunks()
@override
def download(self, filename, target_filepath):
if not self.bucket_name:
return
@ -66,6 +71,7 @@ class AzureBlobStorage(BaseStorage):
blob_data = blob.download_blob()
blob_data.readinto(my_blob)
@override
def exists(self, filename):
if not self.bucket_name:
return False
@ -75,6 +81,7 @@ class AzureBlobStorage(BaseStorage):
blob = client.get_blob_client(container=self.bucket_name, blob=filename)
return blob.exists()
@override
def delete(self, filename: str):
if not self.bucket_name:
return

View File

@ -1,6 +1,7 @@
import base64
import hashlib
from collections.abc import Generator
from typing import override
from baidubce.auth.bce_credentials import BceCredentials
from baidubce.bce_client_configuration import BceClientConfiguration
@ -26,6 +27,7 @@ class BaiduObsStorage(BaseStorage):
self.client = BosClient(config=client_config)
@override
def save(self, filename, data):
md5 = hashlib.md5()
md5.update(data)
@ -34,24 +36,29 @@ class BaiduObsStorage(BaseStorage):
bucket_name=self.bucket_name, key=filename, data=data, content_length=len(data), content_md5=content_md5
)
@override
def load_once(self, filename: str) -> bytes:
response = self.client.get_object(bucket_name=self.bucket_name, key=filename)
data: bytes = response.data.read()
return data
@override
def load_stream(self, filename: str) -> Generator:
response = self.client.get_object(bucket_name=self.bucket_name, key=filename).data
while chunk := response.read(4096):
yield chunk
@override
def download(self, filename, target_filepath):
self.client.get_object_to_file(bucket_name=self.bucket_name, key=filename, file_name=target_filepath)
@override
def exists(self, filename):
res = self.client.get_object_meta_data(bucket_name=self.bucket_name, key=filename)
if res is None:
return False
return True
@override
def delete(self, filename: str):
self.client.delete_object(bucket_name=self.bucket_name, key=filename)

View File

@ -10,7 +10,7 @@ import tempfile
from collections.abc import Generator
from io import BytesIO
from pathlib import Path
from typing import Any
from typing import Any, override
import clickzetta
from pydantic import BaseModel, model_validator
@ -251,6 +251,7 @@ class ClickZettaVolumeStorage(BaseStorage):
# Don't raise exception, let the operation continue
# The table might exist but not be visible due to permissions
@override
def save(self, filename: str, data: bytes):
"""Save data to ClickZetta Volume.
@ -304,6 +305,7 @@ class ClickZettaVolumeStorage(BaseStorage):
# Clean up temporary file
Path(temp_file_path).unlink(missing_ok=True)
@override
def load_once(self, filename: str) -> bytes:
"""Load file content from ClickZetta Volume.
@ -364,6 +366,7 @@ class ClickZettaVolumeStorage(BaseStorage):
logger.debug("File %s loaded from ClickZetta Volume", filename)
return content
@override
def load_stream(self, filename: str) -> Generator:
"""Load file as stream from ClickZetta Volume.
@ -382,6 +385,7 @@ class ClickZettaVolumeStorage(BaseStorage):
logger.debug("File %s loaded as stream from ClickZetta Volume", filename)
@override
def download(self, filename: str, target_filepath: str):
"""Download file from ClickZetta Volume to local path.
@ -395,6 +399,7 @@ class ClickZettaVolumeStorage(BaseStorage):
logger.debug("File %s downloaded from ClickZetta Volume to %s", filename, target_filepath)
@override
def exists(self, filename: str) -> bool:
"""Check if file exists in ClickZetta Volume.
@ -436,6 +441,7 @@ class ClickZettaVolumeStorage(BaseStorage):
logger.warning("Error checking file existence for %s: %s", filename, e)
return False
@override
def delete(self, filename: str):
"""Delete file from ClickZetta Volume.
@ -472,6 +478,7 @@ class ClickZettaVolumeStorage(BaseStorage):
logger.debug("File %s deleted from ClickZetta Volume", filename)
@override
def scan(self, path: str, files: bool = True, directories: bool = False) -> list[str]:
"""Scan files and directories in ClickZetta Volume.

View File

@ -1,7 +1,7 @@
import base64
import io
from collections.abc import Generator
from typing import Any
from typing import Any, override
from google.cloud import storage as google_cloud_storage # type: ignore
from pydantic import TypeAdapter
@ -29,12 +29,14 @@ class GoogleCloudStorage(BaseStorage):
else:
self.client = google_cloud_storage.Client()
@override
def save(self, filename, data):
bucket = self.client.get_bucket(self.bucket_name)
blob = bucket.blob(filename)
with io.BytesIO(data) as stream:
blob.upload_from_file(stream)
@override
def load_once(self, filename: str) -> bytes:
bucket = self.client.get_bucket(self.bucket_name)
blob = bucket.get_blob(filename)
@ -43,6 +45,7 @@ class GoogleCloudStorage(BaseStorage):
data: bytes = blob.download_as_bytes()
return data
@override
def load_stream(self, filename: str) -> Generator:
bucket = self.client.get_bucket(self.bucket_name)
blob = bucket.get_blob(filename)
@ -52,6 +55,7 @@ class GoogleCloudStorage(BaseStorage):
while chunk := blob_stream.read(4096):
yield chunk
@override
def download(self, filename, target_filepath):
bucket = self.client.get_bucket(self.bucket_name)
blob = bucket.get_blob(filename)
@ -59,11 +63,13 @@ class GoogleCloudStorage(BaseStorage):
raise FileNotFoundError("File not found")
blob.download_to_filename(target_filepath)
@override
def exists(self, filename):
bucket = self.client.get_bucket(self.bucket_name)
blob = bucket.blob(filename)
return blob.exists()
@override
def delete(self, filename: str):
bucket = self.client.get_bucket(self.bucket_name)
bucket.delete_blob(filename)

View File

@ -1,4 +1,5 @@
from collections.abc import Generator
from typing import override
from obs import ObsClient
@ -20,27 +21,33 @@ class HuaweiObsStorage(BaseStorage):
path_style=dify_config.HUAWEI_OBS_PATH_STYLE,
)
@override
def save(self, filename, data):
self.client.putObject(bucketName=self.bucket_name, objectKey=filename, content=data)
@override
def load_once(self, filename: str) -> bytes:
data: bytes = self.client.getObject(bucketName=self.bucket_name, objectKey=filename)["body"].response.read()
return data
@override
def load_stream(self, filename: str) -> Generator:
response = self.client.getObject(bucketName=self.bucket_name, objectKey=filename)["body"].response
while chunk := response.read(4096):
yield chunk
@override
def download(self, filename, target_filepath):
self.client.getObject(bucketName=self.bucket_name, objectKey=filename, downloadPath=target_filepath)
@override
def exists(self, filename):
res = self._get_meta(filename)
if res is None:
return False
return True
@override
def delete(self, filename: str):
self.client.deleteObject(bucketName=self.bucket_name, objectKey=filename)

View File

@ -2,7 +2,7 @@ import logging
import os
from collections.abc import Generator
from pathlib import Path
from typing import Any
from typing import Any, override
import opendal
from dotenv import dotenv_values
@ -41,10 +41,12 @@ class OpenDALStorage(BaseStorage):
logger.debug("opendal operator created with scheme %s", scheme)
logger.debug("added retry layer to opendal operator")
@override
def save(self, filename: str, data: bytes):
self.op.write(path=filename, bs=data)
logger.debug("file %s saved", filename)
@override
def load_once(self, filename: str) -> bytes:
if not self.exists(filename):
raise FileNotFoundError("File not found")
@ -53,6 +55,7 @@ class OpenDALStorage(BaseStorage):
logger.debug("file %s loaded", filename)
return content
@override
def load_stream(self, filename: str) -> Generator:
if not self.exists(filename):
raise FileNotFoundError("File not found")
@ -67,6 +70,7 @@ class OpenDALStorage(BaseStorage):
yield chunk
logger.debug("file %s loaded as stream", filename)
@override
def download(self, filename: str, target_filepath: str):
if not self.exists(filename):
raise FileNotFoundError("File not found")
@ -74,9 +78,11 @@ class OpenDALStorage(BaseStorage):
Path(target_filepath).write_bytes(self.op.read(path=filename))
logger.debug("file %s downloaded to %s", filename, target_filepath)
@override
def exists(self, filename: str) -> bool:
return self.op.exists(path=filename)
@override
def delete(self, filename: str):
if self.exists(filename):
self.op.delete(path=filename)
@ -84,6 +90,7 @@ class OpenDALStorage(BaseStorage):
return
logger.debug("file %s not found, skip delete", filename)
@override
def scan(self, path: str, files: bool = True, directories: bool = False) -> list[str]:
if not self.exists(path):
raise FileNotFoundError("Path not found")

View File

@ -1,4 +1,5 @@
from collections.abc import Generator
from typing import override
import boto3
from botocore.exceptions import ClientError
@ -22,9 +23,11 @@ class OracleOCIStorage(BaseStorage):
region_name=dify_config.OCI_REGION,
)
@override
def save(self, filename, data):
self.client.put_object(Bucket=self.bucket_name, Key=filename, Body=data)
@override
def load_once(self, filename: str) -> bytes:
try:
data: bytes = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].read()
@ -35,6 +38,7 @@ class OracleOCIStorage(BaseStorage):
raise
return data
@override
def load_stream(self, filename: str) -> Generator:
try:
response = self.client.get_object(Bucket=self.bucket_name, Key=filename)
@ -45,9 +49,11 @@ class OracleOCIStorage(BaseStorage):
else:
raise
@override
def download(self, filename, target_filepath):
self.client.download_file(self.bucket_name, filename, target_filepath)
@override
def exists(self, filename):
try:
self.client.head_object(Bucket=self.bucket_name, Key=filename)
@ -55,5 +61,6 @@ class OracleOCIStorage(BaseStorage):
except:
return False
@override
def delete(self, filename: str):
self.client.delete_object(Bucket=self.bucket_name, Key=filename)

View File

@ -1,6 +1,7 @@
import io
from collections.abc import Generator
from pathlib import Path
from typing import override
from supabase import Client
@ -28,29 +29,35 @@ class SupabaseStorage(BaseStorage):
if not self.bucket_exists():
self.client.storage.create_bucket(id=id, name=bucket_name)
@override
def save(self, filename, data):
self.client.storage.from_(self.bucket_name).upload(filename, data)
@override
def load_once(self, filename: str) -> bytes:
content: bytes = self.client.storage.from_(self.bucket_name).download(filename)
return content
@override
def load_stream(self, filename: str) -> Generator:
result = self.client.storage.from_(self.bucket_name).download(filename)
byte_stream = io.BytesIO(result)
while chunk := byte_stream.read(4096): # Read in chunks of 4KB
yield chunk
@override
def download(self, filename, target_filepath):
result = self.client.storage.from_(self.bucket_name).download(filename)
Path(target_filepath).write_bytes(result)
@override
def exists(self, filename):
result = self.client.storage.from_(self.bucket_name).list(path=filename)
if len(result) > 0:
return True
return False
@override
def delete(self, filename: str):
self.client.storage.from_(self.bucket_name).remove([filename])

View File

@ -1,4 +1,5 @@
from collections.abc import Generator
from typing import override
from qcloud_cos import CosConfig, CosS3Client
@ -29,23 +30,29 @@ class TencentCosStorage(BaseStorage):
)
self.client = CosS3Client(config)
@override
def save(self, filename, data):
self.client.put_object(Bucket=self.bucket_name, Body=data, Key=filename)
@override
def load_once(self, filename: str) -> bytes:
data: bytes = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].get_raw_stream().read()
return data
@override
def load_stream(self, filename: str) -> Generator:
response = self.client.get_object(Bucket=self.bucket_name, Key=filename)
yield from response["Body"].get_stream(chunk_size=4096)
@override
def download(self, filename, target_filepath):
response = self.client.get_object(Bucket=self.bucket_name, Key=filename)
response["Body"].get_stream_to_file(target_filepath)
@override
def exists(self, filename):
return self.client.object_exists(Bucket=self.bucket_name, Key=filename)
@override
def delete(self, filename: str):
self.client.delete_object(Bucket=self.bucket_name, Key=filename)

View File

@ -1,4 +1,5 @@
from collections.abc import Generator
from typing import override
import tos
@ -27,11 +28,13 @@ class VolcengineTosStorage(BaseStorage):
region=dify_config.VOLCENGINE_TOS_REGION,
)
@override
def save(self, filename, data):
if not self.bucket_name:
raise ValueError("VOLCENGINE_TOS_BUCKET_NAME is not set")
self.client.put_object(bucket=self.bucket_name, key=filename, content=data)
@override
def load_once(self, filename: str) -> bytes:
if not self.bucket_name:
raise FileNotFoundError("VOLCENGINE_TOS_BUCKET_NAME is not set")
@ -40,6 +43,7 @@ class VolcengineTosStorage(BaseStorage):
raise TypeError(f"Expected bytes, got {type(data).__name__}")
return data
@override
def load_stream(self, filename: str) -> Generator:
if not self.bucket_name:
raise FileNotFoundError("VOLCENGINE_TOS_BUCKET_NAME is not set")
@ -47,11 +51,13 @@ class VolcengineTosStorage(BaseStorage):
while chunk := response.read(4096):
yield chunk
@override
def download(self, filename, target_filepath):
if not self.bucket_name:
raise ValueError("VOLCENGINE_TOS_BUCKET_NAME is not set")
self.client.get_object_to_file(bucket=self.bucket_name, key=filename, file_path=target_filepath)
@override
def exists(self, filename):
if not self.bucket_name:
return False
@ -60,6 +66,7 @@ class VolcengineTosStorage(BaseStorage):
return False
return True
@override
def delete(self, filename: str):
if not self.bucket_name:
return