From 57b02e341ccbbad0d3860b24b9ed7020876c389d Mon Sep 17 00:00:00 2001 From: Nishchay Mahor Date: Wed, 27 May 2026 23:04:47 -0700 Subject: [PATCH] refactor: add @override decorators to storage backend subclasses (#36406) (#36755) --- api/extensions/storage/aliyun_oss_storage.py | 7 +++++++ api/extensions/storage/aws_s3_storage.py | 7 +++++++ api/extensions/storage/azure_blob_storage.py | 7 +++++++ api/extensions/storage/baidu_obs_storage.py | 7 +++++++ .../clickzetta_volume/clickzetta_volume_storage.py | 9 ++++++++- api/extensions/storage/google_cloud_storage.py | 8 +++++++- api/extensions/storage/huawei_obs_storage.py | 7 +++++++ api/extensions/storage/opendal_storage.py | 9 ++++++++- api/extensions/storage/oracle_oci_storage.py | 7 +++++++ api/extensions/storage/supabase_storage.py | 7 +++++++ api/extensions/storage/tencent_cos_storage.py | 7 +++++++ api/extensions/storage/volcengine_tos_storage.py | 7 +++++++ 12 files changed, 86 insertions(+), 3 deletions(-) diff --git a/api/extensions/storage/aliyun_oss_storage.py b/api/extensions/storage/aliyun_oss_storage.py index 3d7ef99c9e..5e96e7dbab 100644 --- a/api/extensions/storage/aliyun_oss_storage.py +++ b/api/extensions/storage/aliyun_oss_storage.py @@ -1,5 +1,6 @@ import posixpath from collections.abc import Generator +from typing import override import oss2 as aliyun_s3 @@ -29,9 +30,11 @@ class AliyunOssStorage(BaseStorage): cloudbox_id=dify_config.ALIYUN_CLOUDBOX_ID, ) + @override def save(self, filename, data): self.client.put_object(self.__wrapper_folder_filename(filename), data) + @override def load_once(self, filename: str) -> bytes: obj = self.client.get_object(self.__wrapper_folder_filename(filename)) data = obj.read() @@ -39,17 +42,21 @@ class AliyunOssStorage(BaseStorage): return b"" return data + @override def load_stream(self, filename: str) -> Generator: obj = self.client.get_object(self.__wrapper_folder_filename(filename)) while chunk := obj.read(4096): yield chunk + @override def download(self, filename: str, target_filepath): self.client.get_object_to_file(self.__wrapper_folder_filename(filename), target_filepath) + @override def exists(self, filename: str): return self.client.object_exists(self.__wrapper_folder_filename(filename)) + @override def delete(self, filename: str): self.client.delete_object(self.__wrapper_folder_filename(filename)) diff --git a/api/extensions/storage/aws_s3_storage.py b/api/extensions/storage/aws_s3_storage.py index 978f60c9b0..018aa17ac4 100644 --- a/api/extensions/storage/aws_s3_storage.py +++ b/api/extensions/storage/aws_s3_storage.py @@ -1,5 +1,6 @@ import logging from collections.abc import Generator +from typing import override import boto3 from botocore.client import Config @@ -48,9 +49,11 @@ class AwsS3Storage(BaseStorage): # other error, raise exception raise + @override def save(self, filename, data): self.client.put_object(Bucket=self.bucket_name, Key=filename, Body=data) + @override def load_once(self, filename: str) -> bytes: try: data: bytes = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].read() @@ -61,6 +64,7 @@ class AwsS3Storage(BaseStorage): raise return data + @override def load_stream(self, filename: str) -> Generator: try: response = self.client.get_object(Bucket=self.bucket_name, Key=filename) @@ -73,9 +77,11 @@ class AwsS3Storage(BaseStorage): else: raise + @override def download(self, filename, target_filepath): self.client.download_file(self.bucket_name, filename, target_filepath) + @override def exists(self, filename): try: self.client.head_object(Bucket=self.bucket_name, Key=filename) @@ -83,5 +89,6 @@ class AwsS3Storage(BaseStorage): except: return False + @override def delete(self, filename: str): self.client.delete_object(Bucket=self.bucket_name, Key=filename) diff --git a/api/extensions/storage/azure_blob_storage.py b/api/extensions/storage/azure_blob_storage.py index f270267ce9..a17f61eb9b 100644 --- a/api/extensions/storage/azure_blob_storage.py +++ b/api/extensions/storage/azure_blob_storage.py @@ -1,5 +1,6 @@ from collections.abc import Generator from datetime import timedelta +from typing import override from azure.identity import ChainedTokenCredential, DefaultAzureCredential from azure.storage.blob import AccountSasPermissions, BlobServiceClient, ResourceTypes, generate_account_sas @@ -26,6 +27,7 @@ class AzureBlobStorage(BaseStorage): else: self.credential = None + @override def save(self, filename, data): if not self.bucket_name: return @@ -34,6 +36,7 @@ class AzureBlobStorage(BaseStorage): blob_container = client.get_container_client(container=self.bucket_name) blob_container.upload_blob(filename, data) + @override def load_once(self, filename: str) -> bytes: if not self.bucket_name: raise FileNotFoundError("Azure bucket name is not configured.") @@ -46,6 +49,7 @@ class AzureBlobStorage(BaseStorage): raise TypeError(f"Expected bytes from blob.readall(), got {type(data).__name__}") return data + @override def load_stream(self, filename: str) -> Generator: if not self.bucket_name: raise FileNotFoundError("Azure bucket name is not configured.") @@ -55,6 +59,7 @@ class AzureBlobStorage(BaseStorage): blob_data = blob.download_blob() yield from blob_data.chunks() + @override def download(self, filename, target_filepath): if not self.bucket_name: return @@ -66,6 +71,7 @@ class AzureBlobStorage(BaseStorage): blob_data = blob.download_blob() blob_data.readinto(my_blob) + @override def exists(self, filename): if not self.bucket_name: return False @@ -75,6 +81,7 @@ class AzureBlobStorage(BaseStorage): blob = client.get_blob_client(container=self.bucket_name, blob=filename) return blob.exists() + @override def delete(self, filename: str): if not self.bucket_name: return diff --git a/api/extensions/storage/baidu_obs_storage.py b/api/extensions/storage/baidu_obs_storage.py index 65345b0e4b..9bb9bb3129 100644 --- a/api/extensions/storage/baidu_obs_storage.py +++ b/api/extensions/storage/baidu_obs_storage.py @@ -1,6 +1,7 @@ import base64 import hashlib from collections.abc import Generator +from typing import override from baidubce.auth.bce_credentials import BceCredentials from baidubce.bce_client_configuration import BceClientConfiguration @@ -26,6 +27,7 @@ class BaiduObsStorage(BaseStorage): self.client = BosClient(config=client_config) + @override def save(self, filename, data): md5 = hashlib.md5() md5.update(data) @@ -34,24 +36,29 @@ class BaiduObsStorage(BaseStorage): bucket_name=self.bucket_name, key=filename, data=data, content_length=len(data), content_md5=content_md5 ) + @override def load_once(self, filename: str) -> bytes: response = self.client.get_object(bucket_name=self.bucket_name, key=filename) data: bytes = response.data.read() return data + @override def load_stream(self, filename: str) -> Generator: response = self.client.get_object(bucket_name=self.bucket_name, key=filename).data while chunk := response.read(4096): yield chunk + @override def download(self, filename, target_filepath): self.client.get_object_to_file(bucket_name=self.bucket_name, key=filename, file_name=target_filepath) + @override def exists(self, filename): res = self.client.get_object_meta_data(bucket_name=self.bucket_name, key=filename) if res is None: return False return True + @override def delete(self, filename: str): self.client.delete_object(bucket_name=self.bucket_name, key=filename) diff --git a/api/extensions/storage/clickzetta_volume/clickzetta_volume_storage.py b/api/extensions/storage/clickzetta_volume/clickzetta_volume_storage.py index 05492327c8..f09ccd2a28 100644 --- a/api/extensions/storage/clickzetta_volume/clickzetta_volume_storage.py +++ b/api/extensions/storage/clickzetta_volume/clickzetta_volume_storage.py @@ -10,7 +10,7 @@ import tempfile from collections.abc import Generator from io import BytesIO from pathlib import Path -from typing import Any +from typing import Any, override import clickzetta from pydantic import BaseModel, model_validator @@ -251,6 +251,7 @@ class ClickZettaVolumeStorage(BaseStorage): # Don't raise exception, let the operation continue # The table might exist but not be visible due to permissions + @override def save(self, filename: str, data: bytes): """Save data to ClickZetta Volume. @@ -304,6 +305,7 @@ class ClickZettaVolumeStorage(BaseStorage): # Clean up temporary file Path(temp_file_path).unlink(missing_ok=True) + @override def load_once(self, filename: str) -> bytes: """Load file content from ClickZetta Volume. @@ -364,6 +366,7 @@ class ClickZettaVolumeStorage(BaseStorage): logger.debug("File %s loaded from ClickZetta Volume", filename) return content + @override def load_stream(self, filename: str) -> Generator: """Load file as stream from ClickZetta Volume. @@ -382,6 +385,7 @@ class ClickZettaVolumeStorage(BaseStorage): logger.debug("File %s loaded as stream from ClickZetta Volume", filename) + @override def download(self, filename: str, target_filepath: str): """Download file from ClickZetta Volume to local path. @@ -395,6 +399,7 @@ class ClickZettaVolumeStorage(BaseStorage): logger.debug("File %s downloaded from ClickZetta Volume to %s", filename, target_filepath) + @override def exists(self, filename: str) -> bool: """Check if file exists in ClickZetta Volume. @@ -436,6 +441,7 @@ class ClickZettaVolumeStorage(BaseStorage): logger.warning("Error checking file existence for %s: %s", filename, e) return False + @override def delete(self, filename: str): """Delete file from ClickZetta Volume. @@ -472,6 +478,7 @@ class ClickZettaVolumeStorage(BaseStorage): logger.debug("File %s deleted from ClickZetta Volume", filename) + @override def scan(self, path: str, files: bool = True, directories: bool = False) -> list[str]: """Scan files and directories in ClickZetta Volume. diff --git a/api/extensions/storage/google_cloud_storage.py b/api/extensions/storage/google_cloud_storage.py index 00f7289aa4..83f199c45a 100644 --- a/api/extensions/storage/google_cloud_storage.py +++ b/api/extensions/storage/google_cloud_storage.py @@ -1,7 +1,7 @@ import base64 import io from collections.abc import Generator -from typing import Any +from typing import Any, override from google.cloud import storage as google_cloud_storage # type: ignore from pydantic import TypeAdapter @@ -29,12 +29,14 @@ class GoogleCloudStorage(BaseStorage): else: self.client = google_cloud_storage.Client() + @override def save(self, filename, data): bucket = self.client.get_bucket(self.bucket_name) blob = bucket.blob(filename) with io.BytesIO(data) as stream: blob.upload_from_file(stream) + @override def load_once(self, filename: str) -> bytes: bucket = self.client.get_bucket(self.bucket_name) blob = bucket.get_blob(filename) @@ -43,6 +45,7 @@ class GoogleCloudStorage(BaseStorage): data: bytes = blob.download_as_bytes() return data + @override def load_stream(self, filename: str) -> Generator: bucket = self.client.get_bucket(self.bucket_name) blob = bucket.get_blob(filename) @@ -52,6 +55,7 @@ class GoogleCloudStorage(BaseStorage): while chunk := blob_stream.read(4096): yield chunk + @override def download(self, filename, target_filepath): bucket = self.client.get_bucket(self.bucket_name) blob = bucket.get_blob(filename) @@ -59,11 +63,13 @@ class GoogleCloudStorage(BaseStorage): raise FileNotFoundError("File not found") blob.download_to_filename(target_filepath) + @override def exists(self, filename): bucket = self.client.get_bucket(self.bucket_name) blob = bucket.blob(filename) return blob.exists() + @override def delete(self, filename: str): bucket = self.client.get_bucket(self.bucket_name) bucket.delete_blob(filename) diff --git a/api/extensions/storage/huawei_obs_storage.py b/api/extensions/storage/huawei_obs_storage.py index 2e4961bcd5..72fdabe455 100644 --- a/api/extensions/storage/huawei_obs_storage.py +++ b/api/extensions/storage/huawei_obs_storage.py @@ -1,4 +1,5 @@ from collections.abc import Generator +from typing import override from obs import ObsClient @@ -20,27 +21,33 @@ class HuaweiObsStorage(BaseStorage): path_style=dify_config.HUAWEI_OBS_PATH_STYLE, ) + @override def save(self, filename, data): self.client.putObject(bucketName=self.bucket_name, objectKey=filename, content=data) + @override def load_once(self, filename: str) -> bytes: data: bytes = self.client.getObject(bucketName=self.bucket_name, objectKey=filename)["body"].response.read() return data + @override def load_stream(self, filename: str) -> Generator: response = self.client.getObject(bucketName=self.bucket_name, objectKey=filename)["body"].response while chunk := response.read(4096): yield chunk + @override def download(self, filename, target_filepath): self.client.getObject(bucketName=self.bucket_name, objectKey=filename, downloadPath=target_filepath) + @override def exists(self, filename): res = self._get_meta(filename) if res is None: return False return True + @override def delete(self, filename: str): self.client.deleteObject(bucketName=self.bucket_name, objectKey=filename) diff --git a/api/extensions/storage/opendal_storage.py b/api/extensions/storage/opendal_storage.py index cd7f7db295..214a746941 100644 --- a/api/extensions/storage/opendal_storage.py +++ b/api/extensions/storage/opendal_storage.py @@ -2,7 +2,7 @@ import logging import os from collections.abc import Generator from pathlib import Path -from typing import Any +from typing import Any, override import opendal from dotenv import dotenv_values @@ -41,10 +41,12 @@ class OpenDALStorage(BaseStorage): logger.debug("opendal operator created with scheme %s", scheme) logger.debug("added retry layer to opendal operator") + @override def save(self, filename: str, data: bytes): self.op.write(path=filename, bs=data) logger.debug("file %s saved", filename) + @override def load_once(self, filename: str) -> bytes: if not self.exists(filename): raise FileNotFoundError("File not found") @@ -53,6 +55,7 @@ class OpenDALStorage(BaseStorage): logger.debug("file %s loaded", filename) return content + @override def load_stream(self, filename: str) -> Generator: if not self.exists(filename): raise FileNotFoundError("File not found") @@ -67,6 +70,7 @@ class OpenDALStorage(BaseStorage): yield chunk logger.debug("file %s loaded as stream", filename) + @override def download(self, filename: str, target_filepath: str): if not self.exists(filename): raise FileNotFoundError("File not found") @@ -74,9 +78,11 @@ class OpenDALStorage(BaseStorage): Path(target_filepath).write_bytes(self.op.read(path=filename)) logger.debug("file %s downloaded to %s", filename, target_filepath) + @override def exists(self, filename: str) -> bool: return self.op.exists(path=filename) + @override def delete(self, filename: str): if self.exists(filename): self.op.delete(path=filename) @@ -84,6 +90,7 @@ class OpenDALStorage(BaseStorage): return logger.debug("file %s not found, skip delete", filename) + @override def scan(self, path: str, files: bool = True, directories: bool = False) -> list[str]: if not self.exists(path): raise FileNotFoundError("Path not found") diff --git a/api/extensions/storage/oracle_oci_storage.py b/api/extensions/storage/oracle_oci_storage.py index c7217874e6..cee6647949 100644 --- a/api/extensions/storage/oracle_oci_storage.py +++ b/api/extensions/storage/oracle_oci_storage.py @@ -1,4 +1,5 @@ from collections.abc import Generator +from typing import override import boto3 from botocore.exceptions import ClientError @@ -22,9 +23,11 @@ class OracleOCIStorage(BaseStorage): region_name=dify_config.OCI_REGION, ) + @override def save(self, filename, data): self.client.put_object(Bucket=self.bucket_name, Key=filename, Body=data) + @override def load_once(self, filename: str) -> bytes: try: data: bytes = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].read() @@ -35,6 +38,7 @@ class OracleOCIStorage(BaseStorage): raise return data + @override def load_stream(self, filename: str) -> Generator: try: response = self.client.get_object(Bucket=self.bucket_name, Key=filename) @@ -45,9 +49,11 @@ class OracleOCIStorage(BaseStorage): else: raise + @override def download(self, filename, target_filepath): self.client.download_file(self.bucket_name, filename, target_filepath) + @override def exists(self, filename): try: self.client.head_object(Bucket=self.bucket_name, Key=filename) @@ -55,5 +61,6 @@ class OracleOCIStorage(BaseStorage): except: return False + @override def delete(self, filename: str): self.client.delete_object(Bucket=self.bucket_name, Key=filename) diff --git a/api/extensions/storage/supabase_storage.py b/api/extensions/storage/supabase_storage.py index 76066e12f5..f47153ce37 100644 --- a/api/extensions/storage/supabase_storage.py +++ b/api/extensions/storage/supabase_storage.py @@ -1,6 +1,7 @@ import io from collections.abc import Generator from pathlib import Path +from typing import override from supabase import Client @@ -28,29 +29,35 @@ class SupabaseStorage(BaseStorage): if not self.bucket_exists(): self.client.storage.create_bucket(id=id, name=bucket_name) + @override def save(self, filename, data): self.client.storage.from_(self.bucket_name).upload(filename, data) + @override def load_once(self, filename: str) -> bytes: content: bytes = self.client.storage.from_(self.bucket_name).download(filename) return content + @override def load_stream(self, filename: str) -> Generator: result = self.client.storage.from_(self.bucket_name).download(filename) byte_stream = io.BytesIO(result) while chunk := byte_stream.read(4096): # Read in chunks of 4KB yield chunk + @override def download(self, filename, target_filepath): result = self.client.storage.from_(self.bucket_name).download(filename) Path(target_filepath).write_bytes(result) + @override def exists(self, filename): result = self.client.storage.from_(self.bucket_name).list(path=filename) if len(result) > 0: return True return False + @override def delete(self, filename: str): self.client.storage.from_(self.bucket_name).remove([filename]) diff --git a/api/extensions/storage/tencent_cos_storage.py b/api/extensions/storage/tencent_cos_storage.py index c886c82038..da34a6e007 100644 --- a/api/extensions/storage/tencent_cos_storage.py +++ b/api/extensions/storage/tencent_cos_storage.py @@ -1,4 +1,5 @@ from collections.abc import Generator +from typing import override from qcloud_cos import CosConfig, CosS3Client @@ -29,23 +30,29 @@ class TencentCosStorage(BaseStorage): ) self.client = CosS3Client(config) + @override def save(self, filename, data): self.client.put_object(Bucket=self.bucket_name, Body=data, Key=filename) + @override def load_once(self, filename: str) -> bytes: data: bytes = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].get_raw_stream().read() return data + @override def load_stream(self, filename: str) -> Generator: response = self.client.get_object(Bucket=self.bucket_name, Key=filename) yield from response["Body"].get_stream(chunk_size=4096) + @override def download(self, filename, target_filepath): response = self.client.get_object(Bucket=self.bucket_name, Key=filename) response["Body"].get_stream_to_file(target_filepath) + @override def exists(self, filename): return self.client.object_exists(Bucket=self.bucket_name, Key=filename) + @override def delete(self, filename: str): self.client.delete_object(Bucket=self.bucket_name, Key=filename) diff --git a/api/extensions/storage/volcengine_tos_storage.py b/api/extensions/storage/volcengine_tos_storage.py index d19d6b3032..709f198712 100644 --- a/api/extensions/storage/volcengine_tos_storage.py +++ b/api/extensions/storage/volcengine_tos_storage.py @@ -1,4 +1,5 @@ from collections.abc import Generator +from typing import override import tos @@ -27,11 +28,13 @@ class VolcengineTosStorage(BaseStorage): region=dify_config.VOLCENGINE_TOS_REGION, ) + @override def save(self, filename, data): if not self.bucket_name: raise ValueError("VOLCENGINE_TOS_BUCKET_NAME is not set") self.client.put_object(bucket=self.bucket_name, key=filename, content=data) + @override def load_once(self, filename: str) -> bytes: if not self.bucket_name: raise FileNotFoundError("VOLCENGINE_TOS_BUCKET_NAME is not set") @@ -40,6 +43,7 @@ class VolcengineTosStorage(BaseStorage): raise TypeError(f"Expected bytes, got {type(data).__name__}") return data + @override def load_stream(self, filename: str) -> Generator: if not self.bucket_name: raise FileNotFoundError("VOLCENGINE_TOS_BUCKET_NAME is not set") @@ -47,11 +51,13 @@ class VolcengineTosStorage(BaseStorage): while chunk := response.read(4096): yield chunk + @override def download(self, filename, target_filepath): if not self.bucket_name: raise ValueError("VOLCENGINE_TOS_BUCKET_NAME is not set") self.client.get_object_to_file(bucket=self.bucket_name, key=filename, file_path=target_filepath) + @override def exists(self, filename): if not self.bucket_name: return False @@ -60,6 +66,7 @@ class VolcengineTosStorage(BaseStorage): return False return True + @override def delete(self, filename: str): if not self.bucket_name: return