diff --git a/api/extensions/ext_app_metrics.py b/api/extensions/ext_app_metrics.py index 56a69a1862..4a6490b9f0 100644 --- a/api/extensions/ext_app_metrics.py +++ b/api/extensions/ext_app_metrics.py @@ -10,14 +10,14 @@ from dify_app import DifyApp def init_app(app: DifyApp): @app.after_request - def after_request(response): + def after_request(response): # pyright: ignore[reportUnusedFunction] """Add Version headers to the response.""" response.headers.add("X-Version", dify_config.project.version) response.headers.add("X-Env", dify_config.DEPLOY_ENV) return response @app.route("/health") - def health(): + def health(): # pyright: ignore[reportUnusedFunction] return Response( json.dumps({"pid": os.getpid(), "status": "ok", "version": dify_config.project.version}), status=200, @@ -25,7 +25,7 @@ def init_app(app: DifyApp): ) @app.route("/threads") - def threads(): + def threads(): # pyright: ignore[reportUnusedFunction] num_threads = threading.active_count() threads = threading.enumerate() @@ -50,7 +50,7 @@ def init_app(app: DifyApp): } @app.route("/db-pool-stat") - def pool_stat(): + def pool_stat(): # pyright: ignore[reportUnusedFunction] from extensions.ext_database import db engine = db.engine diff --git a/api/extensions/ext_database.py b/api/extensions/ext_database.py index 067ce39e4f..c90b1d0a9f 100644 --- a/api/extensions/ext_database.py +++ b/api/extensions/ext_database.py @@ -10,7 +10,7 @@ from models.engine import db logger = logging.getLogger(__name__) # Global flag to avoid duplicate registration of event listener -_GEVENT_COMPATIBILITY_SETUP: bool = False +_gevent_compatibility_setup: bool = False def _safe_rollback(connection): @@ -26,14 +26,14 @@ def _safe_rollback(connection): def _setup_gevent_compatibility(): - global _GEVENT_COMPATIBILITY_SETUP # pylint: disable=global-statement + global _gevent_compatibility_setup # pylint: disable=global-statement # Avoid duplicate registration - if _GEVENT_COMPATIBILITY_SETUP: + if _gevent_compatibility_setup: return @event.listens_for(Pool, "reset") - def _safe_reset(dbapi_connection, connection_record, reset_state): # pylint: disable=unused-argument + def _safe_reset(dbapi_connection, connection_record, reset_state): # pyright: ignore[reportUnusedFunction] if reset_state.terminate_only: return @@ -47,7 +47,7 @@ def _setup_gevent_compatibility(): except (AttributeError, ImportError): _safe_rollback(dbapi_connection) - _GEVENT_COMPATIBILITY_SETUP = True + _gevent_compatibility_setup = True def init_app(app: DifyApp): diff --git a/api/extensions/ext_import_modules.py b/api/extensions/ext_import_modules.py index 9566f430b6..4eb363ff93 100644 --- a/api/extensions/ext_import_modules.py +++ b/api/extensions/ext_import_modules.py @@ -2,4 +2,4 @@ from dify_app import DifyApp def init_app(app: DifyApp): - from events import event_handlers # noqa: F401 + from events import event_handlers # noqa: F401 # pyright: ignore[reportUnusedImport] diff --git a/api/extensions/storage/aliyun_oss_storage.py b/api/extensions/storage/aliyun_oss_storage.py index 00bf5d4f93..5da4737138 100644 --- a/api/extensions/storage/aliyun_oss_storage.py +++ b/api/extensions/storage/aliyun_oss_storage.py @@ -33,7 +33,9 @@ class AliyunOssStorage(BaseStorage): def load_once(self, filename: str) -> bytes: obj = self.client.get_object(self.__wrapper_folder_filename(filename)) - data: bytes = obj.read() + data = obj.read() + if not isinstance(data, bytes): + return b"" return data def load_stream(self, filename: str) -> Generator: diff --git a/api/extensions/storage/aws_s3_storage.py b/api/extensions/storage/aws_s3_storage.py index e755ab089a..6ab2a95e3c 100644 --- a/api/extensions/storage/aws_s3_storage.py +++ b/api/extensions/storage/aws_s3_storage.py @@ -39,10 +39,10 @@ class AwsS3Storage(BaseStorage): self.client.head_bucket(Bucket=self.bucket_name) except ClientError as e: # if bucket not exists, create it - if e.response["Error"]["Code"] == "404": + if e.response.get("Error", {}).get("Code") == "404": self.client.create_bucket(Bucket=self.bucket_name) # if bucket is not accessible, pass, maybe the bucket is existing but not accessible - elif e.response["Error"]["Code"] == "403": + elif e.response.get("Error", {}).get("Code") == "403": pass else: # other error, raise exception @@ -55,7 +55,7 @@ class AwsS3Storage(BaseStorage): try: data: bytes = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].read() except ClientError as ex: - if ex.response["Error"]["Code"] == "NoSuchKey": + if ex.response.get("Error", {}).get("Code") == "NoSuchKey": raise FileNotFoundError("File not found") else: raise @@ -66,7 +66,7 @@ class AwsS3Storage(BaseStorage): response = self.client.get_object(Bucket=self.bucket_name, Key=filename) yield from response["Body"].iter_chunks() except ClientError as ex: - if ex.response["Error"]["Code"] == "NoSuchKey": + if ex.response.get("Error", {}).get("Code") == "NoSuchKey": raise FileNotFoundError("file not found") elif "reached max retries" in str(ex): raise ValueError("please do not request the same file too frequently") diff --git a/api/extensions/storage/azure_blob_storage.py b/api/extensions/storage/azure_blob_storage.py index 9053aece89..4bccaf13c8 100644 --- a/api/extensions/storage/azure_blob_storage.py +++ b/api/extensions/storage/azure_blob_storage.py @@ -27,24 +27,38 @@ class AzureBlobStorage(BaseStorage): self.credential = None def save(self, filename, data): + if not self.bucket_name: + return + client = self._sync_client() blob_container = client.get_container_client(container=self.bucket_name) blob_container.upload_blob(filename, data) def load_once(self, filename: str) -> bytes: + if not self.bucket_name: + raise FileNotFoundError("Azure bucket name is not configured.") + client = self._sync_client() blob = client.get_container_client(container=self.bucket_name) blob = blob.get_blob_client(blob=filename) - data: bytes = blob.download_blob().readall() + data = blob.download_blob().readall() + if not isinstance(data, bytes): + raise TypeError(f"Expected bytes from blob.readall(), got {type(data).__name__}") return data def load_stream(self, filename: str) -> Generator: + if not self.bucket_name: + raise FileNotFoundError("Azure bucket name is not configured.") + client = self._sync_client() blob = client.get_blob_client(container=self.bucket_name, blob=filename) blob_data = blob.download_blob() yield from blob_data.chunks() def download(self, filename, target_filepath): + if not self.bucket_name: + return + client = self._sync_client() blob = client.get_blob_client(container=self.bucket_name, blob=filename) @@ -53,12 +67,18 @@ class AzureBlobStorage(BaseStorage): blob_data.readinto(my_blob) def exists(self, filename): + if not self.bucket_name: + return False + client = self._sync_client() blob = client.get_blob_client(container=self.bucket_name, blob=filename) return blob.exists() def delete(self, filename): + if not self.bucket_name: + return + client = self._sync_client() blob_container = client.get_container_client(container=self.bucket_name) diff --git a/api/extensions/storage/clickzetta_volume/clickzetta_volume_storage.py b/api/extensions/storage/clickzetta_volume/clickzetta_volume_storage.py index 2ffac9a92d..06c528ca41 100644 --- a/api/extensions/storage/clickzetta_volume/clickzetta_volume_storage.py +++ b/api/extensions/storage/clickzetta_volume/clickzetta_volume_storage.py @@ -430,7 +430,7 @@ class ClickZettaVolumeStorage(BaseStorage): rows = self._execute_sql(sql, fetch=True) - exists = len(rows) > 0 + exists = len(rows) > 0 if rows else False logger.debug("File %s exists check: %s", filename, exists) return exists except Exception as e: @@ -509,16 +509,17 @@ class ClickZettaVolumeStorage(BaseStorage): rows = self._execute_sql(sql, fetch=True) result = [] - for row in rows: - file_path = row[0] # relative_path column + if rows: + for row in rows: + file_path = row[0] # relative_path column - # For User Volume, remove dify prefix from results - dify_prefix_with_slash = f"{self._config.dify_prefix}/" - if volume_prefix == "USER VOLUME" and file_path.startswith(dify_prefix_with_slash): - file_path = file_path[len(dify_prefix_with_slash) :] # Remove prefix + # For User Volume, remove dify prefix from results + dify_prefix_with_slash = f"{self._config.dify_prefix}/" + if volume_prefix == "USER VOLUME" and file_path.startswith(dify_prefix_with_slash): + file_path = file_path[len(dify_prefix_with_slash) :] # Remove prefix - if files and not file_path.endswith("/") or directories and file_path.endswith("/"): - result.append(file_path) + if files and not file_path.endswith("/") or directories and file_path.endswith("/"): + result.append(file_path) logger.debug("Scanned %d items in path %s", len(result), path) return result diff --git a/api/extensions/storage/clickzetta_volume/volume_permissions.py b/api/extensions/storage/clickzetta_volume/volume_permissions.py index eb1116638f..6dcf800abb 100644 --- a/api/extensions/storage/clickzetta_volume/volume_permissions.py +++ b/api/extensions/storage/clickzetta_volume/volume_permissions.py @@ -439,6 +439,11 @@ class VolumePermissionManager: self._permission_cache.clear() logger.debug("Permission cache cleared") + @property + def volume_type(self) -> str | None: + """Get the volume type.""" + return self._volume_type + def get_permission_summary(self, dataset_id: str | None = None) -> dict[str, bool]: """Get permission summary @@ -632,13 +637,13 @@ def check_volume_permission(permission_manager: VolumePermissionManager, operati VolumePermissionError: If no permission """ if not permission_manager.validate_operation(operation, dataset_id): - error_message = f"Permission denied for operation '{operation}' on {permission_manager._volume_type} volume" + error_message = f"Permission denied for operation '{operation}' on {permission_manager.volume_type} volume" if dataset_id: error_message += f" (dataset: {dataset_id})" raise VolumePermissionError( error_message, operation=operation, - volume_type=permission_manager._volume_type or "unknown", + volume_type=permission_manager.volume_type or "unknown", dataset_id=dataset_id, ) diff --git a/api/extensions/storage/google_cloud_storage.py b/api/extensions/storage/google_cloud_storage.py index 705639f42e..7f59252f2f 100644 --- a/api/extensions/storage/google_cloud_storage.py +++ b/api/extensions/storage/google_cloud_storage.py @@ -35,12 +35,16 @@ class GoogleCloudStorage(BaseStorage): def load_once(self, filename: str) -> bytes: bucket = self.client.get_bucket(self.bucket_name) blob = bucket.get_blob(filename) + if blob is None: + raise FileNotFoundError("File not found") data: bytes = blob.download_as_bytes() return data def load_stream(self, filename: str) -> Generator: bucket = self.client.get_bucket(self.bucket_name) blob = bucket.get_blob(filename) + if blob is None: + raise FileNotFoundError("File not found") with blob.open(mode="rb") as blob_stream: while chunk := blob_stream.read(4096): yield chunk @@ -48,6 +52,8 @@ class GoogleCloudStorage(BaseStorage): def download(self, filename, target_filepath): bucket = self.client.get_bucket(self.bucket_name) blob = bucket.get_blob(filename) + if blob is None: + raise FileNotFoundError("File not found") blob.download_to_filename(target_filepath) def exists(self, filename): diff --git a/api/extensions/storage/huawei_obs_storage.py b/api/extensions/storage/huawei_obs_storage.py index 07f1d19970..3e75ecb7a9 100644 --- a/api/extensions/storage/huawei_obs_storage.py +++ b/api/extensions/storage/huawei_obs_storage.py @@ -45,7 +45,7 @@ class HuaweiObsStorage(BaseStorage): def _get_meta(self, filename): res = self.client.getObjectMetadata(bucketName=self.bucket_name, objectKey=filename) - if res.status < 300: + if res and res.status and res.status < 300: return res else: return None diff --git a/api/extensions/storage/oracle_oci_storage.py b/api/extensions/storage/oracle_oci_storage.py index 82829f7fd5..acc00cbd6b 100644 --- a/api/extensions/storage/oracle_oci_storage.py +++ b/api/extensions/storage/oracle_oci_storage.py @@ -29,7 +29,7 @@ class OracleOCIStorage(BaseStorage): try: data: bytes = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].read() except ClientError as ex: - if ex.response["Error"]["Code"] == "NoSuchKey": + if ex.response.get("Error", {}).get("Code") == "NoSuchKey": raise FileNotFoundError("File not found") else: raise @@ -40,7 +40,7 @@ class OracleOCIStorage(BaseStorage): response = self.client.get_object(Bucket=self.bucket_name, Key=filename) yield from response["Body"].iter_chunks() except ClientError as ex: - if ex.response["Error"]["Code"] == "NoSuchKey": + if ex.response.get("Error", {}).get("Code") == "NoSuchKey": raise FileNotFoundError("File not found") else: raise diff --git a/api/extensions/storage/supabase_storage.py b/api/extensions/storage/supabase_storage.py index 711c3f7211..2ca84d4c15 100644 --- a/api/extensions/storage/supabase_storage.py +++ b/api/extensions/storage/supabase_storage.py @@ -46,13 +46,13 @@ class SupabaseStorage(BaseStorage): Path(target_filepath).write_bytes(result) def exists(self, filename): - result = self.client.storage.from_(self.bucket_name).list(filename) - if result.count() > 0: + result = self.client.storage.from_(self.bucket_name).list(path=filename) + if len(result) > 0: return True return False def delete(self, filename): - self.client.storage.from_(self.bucket_name).remove(filename) + self.client.storage.from_(self.bucket_name).remove([filename]) def bucket_exists(self): buckets = self.client.storage.list_buckets() diff --git a/api/extensions/storage/volcengine_tos_storage.py b/api/extensions/storage/volcengine_tos_storage.py index 32839d3497..8ed8e4c170 100644 --- a/api/extensions/storage/volcengine_tos_storage.py +++ b/api/extensions/storage/volcengine_tos_storage.py @@ -11,6 +11,14 @@ class VolcengineTosStorage(BaseStorage): def __init__(self): super().__init__() + if not dify_config.VOLCENGINE_TOS_ACCESS_KEY: + raise ValueError("VOLCENGINE_TOS_ACCESS_KEY is not set") + if not dify_config.VOLCENGINE_TOS_SECRET_KEY: + raise ValueError("VOLCENGINE_TOS_SECRET_KEY is not set") + if not dify_config.VOLCENGINE_TOS_ENDPOINT: + raise ValueError("VOLCENGINE_TOS_ENDPOINT is not set") + if not dify_config.VOLCENGINE_TOS_REGION: + raise ValueError("VOLCENGINE_TOS_REGION is not set") self.bucket_name = dify_config.VOLCENGINE_TOS_BUCKET_NAME self.client = tos.TosClientV2( ak=dify_config.VOLCENGINE_TOS_ACCESS_KEY, @@ -20,27 +28,39 @@ class VolcengineTosStorage(BaseStorage): ) def save(self, filename, data): + if not self.bucket_name: + raise ValueError("VOLCENGINE_TOS_BUCKET_NAME is not set") self.client.put_object(bucket=self.bucket_name, key=filename, content=data) def load_once(self, filename: str) -> bytes: + if not self.bucket_name: + raise FileNotFoundError("VOLCENGINE_TOS_BUCKET_NAME is not set") data = self.client.get_object(bucket=self.bucket_name, key=filename).read() if not isinstance(data, bytes): raise TypeError(f"Expected bytes, got {type(data).__name__}") return data def load_stream(self, filename: str) -> Generator: + if not self.bucket_name: + raise FileNotFoundError("VOLCENGINE_TOS_BUCKET_NAME is not set") response = self.client.get_object(bucket=self.bucket_name, key=filename) while chunk := response.read(4096): yield chunk def download(self, filename, target_filepath): + if not self.bucket_name: + raise ValueError("VOLCENGINE_TOS_BUCKET_NAME is not set") self.client.get_object_to_file(bucket=self.bucket_name, key=filename, file_path=target_filepath) def exists(self, filename): + if not self.bucket_name: + return False res = self.client.head_object(bucket=self.bucket_name, key=filename) if res.status_code != 200: return False return True def delete(self, filename): + if not self.bucket_name: + return self.client.delete_object(bucket=self.bucket_name, key=filename) diff --git a/api/pyrightconfig.json b/api/pyrightconfig.json index 00dda8b087..d98e30bb80 100644 --- a/api/pyrightconfig.json +++ b/api/pyrightconfig.json @@ -5,7 +5,6 @@ ".venv", "migrations/", "core/rag", - "extensions", "core/app/app_config/easy_ui_based_app/dataset" ], "typeCheckingMode": "strict", diff --git a/api/tests/unit_tests/extensions/storage/test_supabase_storage.py b/api/tests/unit_tests/extensions/storage/test_supabase_storage.py index 958072223e..476f87269c 100644 --- a/api/tests/unit_tests/extensions/storage/test_supabase_storage.py +++ b/api/tests/unit_tests/extensions/storage/test_supabase_storage.py @@ -172,73 +172,31 @@ class TestSupabaseStorage: assert "test-bucket" in [call[0][0] for call in mock_client.storage.from_.call_args_list if call[0]] mock_client.storage.from_().download.assert_called_with("test.txt") - def test_exists_with_list_containing_items(self, storage_with_mock_client): - """Test exists returns True when list() returns items (using len() > 0).""" + def test_exists_returns_true_when_file_found(self, storage_with_mock_client): + """Test exists returns True when list() returns items.""" storage, mock_client = storage_with_mock_client - # Mock list return with special object that has count() method - mock_list_result = Mock() - mock_list_result.count.return_value = 1 - mock_client.storage.from_().list.return_value = mock_list_result + mock_client.storage.from_().list.return_value = [{"name": "test.txt"}] result = storage.exists("test.txt") assert result is True - # from_ gets called during init too, so just check it was called with the right bucket assert "test-bucket" in [call[0][0] for call in mock_client.storage.from_.call_args_list if call[0]] - mock_client.storage.from_().list.assert_called_with("test.txt") + mock_client.storage.from_().list.assert_called_with(path="test.txt") - def test_exists_with_count_method_greater_than_zero(self, storage_with_mock_client): - """Test exists returns True when list result has count() > 0.""" + def test_exists_returns_false_when_file_not_found(self, storage_with_mock_client): + """Test exists returns False when list() returns an empty list.""" storage, mock_client = storage_with_mock_client - # Mock list return with count() method - mock_list_result = Mock() - mock_list_result.count.return_value = 1 - mock_client.storage.from_().list.return_value = mock_list_result - - result = storage.exists("test.txt") - - assert result is True - # Verify the correct calls were made - assert "test-bucket" in [call[0][0] for call in mock_client.storage.from_.call_args_list if call[0]] - mock_client.storage.from_().list.assert_called_with("test.txt") - mock_list_result.count.assert_called() - - def test_exists_with_count_method_zero(self, storage_with_mock_client): - """Test exists returns False when list result has count() == 0.""" - storage, mock_client = storage_with_mock_client - - # Mock list return with count() method returning 0 - mock_list_result = Mock() - mock_list_result.count.return_value = 0 - mock_client.storage.from_().list.return_value = mock_list_result + mock_client.storage.from_().list.return_value = [] result = storage.exists("test.txt") assert result is False - # Verify the correct calls were made assert "test-bucket" in [call[0][0] for call in mock_client.storage.from_.call_args_list if call[0]] - mock_client.storage.from_().list.assert_called_with("test.txt") - mock_list_result.count.assert_called() + mock_client.storage.from_().list.assert_called_with(path="test.txt") - def test_exists_with_empty_list(self, storage_with_mock_client): - """Test exists returns False when list() returns empty list.""" - storage, mock_client = storage_with_mock_client - - # Mock list return with special object that has count() method returning 0 - mock_list_result = Mock() - mock_list_result.count.return_value = 0 - mock_client.storage.from_().list.return_value = mock_list_result - - result = storage.exists("test.txt") - - assert result is False - # Verify the correct calls were made - assert "test-bucket" in [call[0][0] for call in mock_client.storage.from_.call_args_list if call[0]] - mock_client.storage.from_().list.assert_called_with("test.txt") - - def test_delete_calls_remove_with_filename(self, storage_with_mock_client): + def test_delete_calls_remove_with_filename_in_list(self, storage_with_mock_client): """Test delete calls remove([...]) (some client versions require a list).""" storage, mock_client = storage_with_mock_client @@ -247,7 +205,7 @@ class TestSupabaseStorage: storage.delete(filename) mock_client.storage.from_.assert_called_once_with("test-bucket") - mock_client.storage.from_().remove.assert_called_once_with(filename) + mock_client.storage.from_().remove.assert_called_once_with([filename]) def test_bucket_exists_returns_true_when_bucket_found(self): """Test bucket_exists returns True when bucket is found in list.""" diff --git a/api/tests/unit_tests/oss/volcengine_tos/test_volcengine_tos.py b/api/tests/unit_tests/oss/volcengine_tos/test_volcengine_tos.py index 04988e85d8..1659205ec0 100644 --- a/api/tests/unit_tests/oss/volcengine_tos/test_volcengine_tos.py +++ b/api/tests/unit_tests/oss/volcengine_tos/test_volcengine_tos.py @@ -1,3 +1,5 @@ +from unittest.mock import patch + import pytest from tos import TosClientV2 # type: ignore @@ -13,7 +15,13 @@ class TestVolcengineTos(BaseStorageTest): @pytest.fixture(autouse=True) def setup_method(self, setup_volcengine_tos_mock): """Executed before each test method.""" - self.storage = VolcengineTosStorage() + with patch("extensions.storage.volcengine_tos_storage.dify_config") as mock_config: + mock_config.VOLCENGINE_TOS_ACCESS_KEY = "test_access_key" + mock_config.VOLCENGINE_TOS_SECRET_KEY = "test_secret_key" + mock_config.VOLCENGINE_TOS_ENDPOINT = "test_endpoint" + mock_config.VOLCENGINE_TOS_REGION = "test_region" + self.storage = VolcengineTosStorage() + self.storage.bucket_name = get_example_bucket() self.storage.client = TosClientV2( ak="dify",