diff --git a/api/core/repositories/factory.py b/api/core/repositories/factory.py index 09c775f3a6..854c122331 100644 --- a/api/core/repositories/factory.py +++ b/api/core/repositories/factory.py @@ -5,10 +5,7 @@ This module provides a Django-like settings system for repository implementation allowing users to configure different repository backends through string paths. """ -import importlib -import inspect -import logging -from typing import Protocol, Union +from typing import Union from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker @@ -16,12 +13,11 @@ from sqlalchemy.orm import sessionmaker from configs import dify_config from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from libs.module_loading import import_string from models import Account, EndUser from models.enums import WorkflowRunTriggeredFrom from models.workflow import WorkflowNodeExecutionTriggeredFrom -logger = logging.getLogger(__name__) - class RepositoryImportError(Exception): """Raised when a repository implementation cannot be imported or instantiated.""" @@ -37,96 +33,6 @@ class DifyCoreRepositoryFactory: are specified as module paths (e.g., 'module.submodule.ClassName'). """ - @staticmethod - def _import_class(class_path: str) -> type: - """ - Import a class from a module path string. - - Args: - class_path: Full module path to the class (e.g., 'module.submodule.ClassName') - - Returns: - The imported class - - Raises: - RepositoryImportError: If the class cannot be imported - """ - try: - module_path, class_name = class_path.rsplit(".", 1) - module = importlib.import_module(module_path) - repo_class = getattr(module, class_name) - assert isinstance(repo_class, type) - return repo_class - except (ValueError, ImportError, AttributeError) as e: - raise RepositoryImportError(f"Cannot import repository class '{class_path}': {e}") from e - - @staticmethod - def _validate_repository_interface(repository_class: type, expected_interface: type[Protocol]) -> None: # type: ignore - """ - Validate that a class implements the expected repository interface. - - Args: - repository_class: The class to validate - expected_interface: The expected interface/protocol - - Raises: - RepositoryImportError: If the class doesn't implement the interface - """ - # Check if the class has all required methods from the protocol - required_methods = [ - method - for method in dir(expected_interface) - if not method.startswith("_") and callable(getattr(expected_interface, method, None)) - ] - - missing_methods = [] - for method_name in required_methods: - if not hasattr(repository_class, method_name): - missing_methods.append(method_name) - - if missing_methods: - raise RepositoryImportError( - f"Repository class '{repository_class.__name__}' does not implement required methods " - f"{missing_methods} from interface '{expected_interface.__name__}'" - ) - - @staticmethod - def _validate_constructor_signature(repository_class: type, required_params: list[str]) -> None: - """ - Validate that a repository class constructor accepts required parameters. - Args: - repository_class: The class to validate - required_params: List of required parameter names - Raises: - RepositoryImportError: If the constructor doesn't accept required parameters - """ - - try: - # MyPy may flag the line below with the following error: - # - # > Accessing "__init__" on an instance is unsound, since - # > instance.__init__ could be from an incompatible subclass. - # - # Despite this, we need to ensure that the constructor of `repository_class` - # has a compatible signature. - signature = inspect.signature(repository_class.__init__) # type: ignore[misc] - param_names = list(signature.parameters.keys()) - - # Remove 'self' parameter - if "self" in param_names: - param_names.remove("self") - - missing_params = [param for param in required_params if param not in param_names] - if missing_params: - raise RepositoryImportError( - f"Repository class '{repository_class.__name__}' constructor does not accept required parameters: " - f"{missing_params}. Expected parameters: {required_params}" - ) - except Exception as e: - raise RepositoryImportError( - f"Failed to validate constructor signature for '{repository_class.__name__}': {e}" - ) from e - @classmethod def create_workflow_execution_repository( cls, @@ -151,24 +57,16 @@ class DifyCoreRepositoryFactory: RepositoryImportError: If the configured repository cannot be created """ class_path = dify_config.CORE_WORKFLOW_EXECUTION_REPOSITORY - logger.debug("Creating WorkflowExecutionRepository from: %s", class_path) try: - repository_class = cls._import_class(class_path) - cls._validate_repository_interface(repository_class, WorkflowExecutionRepository) - - # All repository types now use the same constructor parameters + repository_class = import_string(class_path) return repository_class( # type: ignore[no-any-return] session_factory=session_factory, user=user, app_id=app_id, triggered_from=triggered_from, ) - except RepositoryImportError: - # Re-raise our custom errors as-is - raise - except Exception as e: - logger.exception("Failed to create WorkflowExecutionRepository") + except (ImportError, Exception) as e: raise RepositoryImportError(f"Failed to create WorkflowExecutionRepository from '{class_path}': {e}") from e @classmethod @@ -195,24 +93,16 @@ class DifyCoreRepositoryFactory: RepositoryImportError: If the configured repository cannot be created """ class_path = dify_config.CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY - logger.debug("Creating WorkflowNodeExecutionRepository from: %s", class_path) try: - repository_class = cls._import_class(class_path) - cls._validate_repository_interface(repository_class, WorkflowNodeExecutionRepository) - - # All repository types now use the same constructor parameters + repository_class = import_string(class_path) return repository_class( # type: ignore[no-any-return] session_factory=session_factory, user=user, app_id=app_id, triggered_from=triggered_from, ) - except RepositoryImportError: - # Re-raise our custom errors as-is - raise - except Exception as e: - logger.exception("Failed to create WorkflowNodeExecutionRepository") + except (ImportError, Exception) as e: raise RepositoryImportError( f"Failed to create WorkflowNodeExecutionRepository from '{class_path}': {e}" ) from e diff --git a/api/libs/module_loading.py b/api/libs/module_loading.py new file mode 100644 index 0000000000..616d072a1b --- /dev/null +++ b/api/libs/module_loading.py @@ -0,0 +1,55 @@ +""" +Module loading utilities similar to Django's module_loading. + +Reference implementation from Django: +https://github.com/django/django/blob/main/django/utils/module_loading.py +""" + +import sys +from importlib import import_module +from typing import Any + + +def cached_import(module_path: str, class_name: str) -> Any: + """ + Import a module and return the named attribute/class from it, with caching. + + Args: + module_path: The module path to import from + class_name: The attribute/class name to retrieve + + Returns: + The imported attribute/class + """ + if not ( + (module := sys.modules.get(module_path)) + and (spec := getattr(module, "__spec__", None)) + and getattr(spec, "_initializing", False) is False + ): + module = import_module(module_path) + return getattr(module, class_name) + + +def import_string(dotted_path: str) -> Any: + """ + Import a dotted module path and return the attribute/class designated by + the last name in the path. Raise ImportError if the import failed. + + Args: + dotted_path: Full module path to the class (e.g., 'module.submodule.ClassName') + + Returns: + The imported class or attribute + + Raises: + ImportError: If the module or attribute cannot be imported + """ + try: + module_path, class_name = dotted_path.rsplit(".", 1) + except ValueError as err: + raise ImportError(f"{dotted_path} doesn't look like a module path") from err + + try: + return cached_import(module_path, class_name) + except AttributeError as err: + raise ImportError(f'Module "{module_path}" does not define a "{class_name}" attribute/class') from err diff --git a/api/repositories/factory.py b/api/repositories/factory.py index 1f0320054c..0be9c8908c 100644 --- a/api/repositories/factory.py +++ b/api/repositories/factory.py @@ -5,17 +5,14 @@ This factory is specifically designed for DifyAPI repositories that handle service-layer operations with dependency injection patterns. """ -import logging - from sqlalchemy.orm import sessionmaker from configs import dify_config from core.repositories import DifyCoreRepositoryFactory, RepositoryImportError +from libs.module_loading import import_string from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository from repositories.api_workflow_run_repository import APIWorkflowRunRepository -logger = logging.getLogger(__name__) - class DifyAPIRepositoryFactory(DifyCoreRepositoryFactory): """ @@ -50,17 +47,9 @@ class DifyAPIRepositoryFactory(DifyCoreRepositoryFactory): class_path = dify_config.API_WORKFLOW_NODE_EXECUTION_REPOSITORY try: - repository_class = cls._import_class(class_path) - cls._validate_repository_interface(repository_class, DifyAPIWorkflowNodeExecutionRepository) - # Service repository requires session_maker parameter - cls._validate_constructor_signature(repository_class, ["session_maker"]) - + repository_class = import_string(class_path) return repository_class(session_maker=session_maker) # type: ignore[no-any-return] - except RepositoryImportError: - # Re-raise our custom errors as-is - raise - except Exception as e: - logger.exception("Failed to create DifyAPIWorkflowNodeExecutionRepository") + except (ImportError, Exception) as e: raise RepositoryImportError( f"Failed to create DifyAPIWorkflowNodeExecutionRepository from '{class_path}': {e}" ) from e @@ -87,15 +76,7 @@ class DifyAPIRepositoryFactory(DifyCoreRepositoryFactory): class_path = dify_config.API_WORKFLOW_RUN_REPOSITORY try: - repository_class = cls._import_class(class_path) - cls._validate_repository_interface(repository_class, APIWorkflowRunRepository) - # Service repository requires session_maker parameter - cls._validate_constructor_signature(repository_class, ["session_maker"]) - + repository_class = import_string(class_path) return repository_class(session_maker=session_maker) # type: ignore[no-any-return] - except RepositoryImportError: - # Re-raise our custom errors as-is - raise - except Exception as e: - logger.exception("Failed to create APIWorkflowRunRepository") + except (ImportError, Exception) as e: raise RepositoryImportError(f"Failed to create APIWorkflowRunRepository from '{class_path}': {e}") from e diff --git a/api/tests/unit_tests/core/repositories/test_factory.py b/api/tests/unit_tests/core/repositories/test_factory.py index 5146e82e8f..30f51902ef 100644 --- a/api/tests/unit_tests/core/repositories/test_factory.py +++ b/api/tests/unit_tests/core/repositories/test_factory.py @@ -2,19 +2,19 @@ Unit tests for the RepositoryFactory. This module tests the factory pattern implementation for creating repository instances -based on configuration, including error handling and validation. +based on configuration, including error handling. """ from unittest.mock import MagicMock, patch import pytest -from pytest_mock import MockerFixture from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker from core.repositories.factory import DifyCoreRepositoryFactory, RepositoryImportError from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from libs.module_loading import import_string from models import Account, EndUser from models.enums import WorkflowRunTriggeredFrom from models.workflow import WorkflowNodeExecutionTriggeredFrom @@ -23,98 +23,30 @@ from models.workflow import WorkflowNodeExecutionTriggeredFrom class TestRepositoryFactory: """Test cases for RepositoryFactory.""" - def test_import_class_success(self): + def test_import_string_success(self): """Test successful class import.""" # Test importing a real class class_path = "unittest.mock.MagicMock" - result = DifyCoreRepositoryFactory._import_class(class_path) + result = import_string(class_path) assert result is MagicMock - def test_import_class_invalid_path(self): + def test_import_string_invalid_path(self): """Test import with invalid module path.""" - with pytest.raises(RepositoryImportError) as exc_info: - DifyCoreRepositoryFactory._import_class("invalid.module.path") - assert "Cannot import repository class" in str(exc_info.value) + with pytest.raises(ImportError) as exc_info: + import_string("invalid.module.path") + assert "No module named" in str(exc_info.value) - def test_import_class_invalid_class_name(self): + def test_import_string_invalid_class_name(self): """Test import with invalid class name.""" - with pytest.raises(RepositoryImportError) as exc_info: - DifyCoreRepositoryFactory._import_class("unittest.mock.NonExistentClass") - assert "Cannot import repository class" in str(exc_info.value) + with pytest.raises(ImportError) as exc_info: + import_string("unittest.mock.NonExistentClass") + assert "does not define" in str(exc_info.value) - def test_import_class_malformed_path(self): + def test_import_string_malformed_path(self): """Test import with malformed path (no dots).""" - with pytest.raises(RepositoryImportError) as exc_info: - DifyCoreRepositoryFactory._import_class("invalidpath") - assert "Cannot import repository class" in str(exc_info.value) - - def test_validate_repository_interface_success(self): - """Test successful interface validation.""" - - # Create a mock class that implements the required methods - class MockRepository: - def save(self): - pass - - def get_by_id(self): - pass - - # Create a mock interface class - class MockInterface: - def save(self): - pass - - def get_by_id(self): - pass - - # Should not raise an exception when all methods are present - DifyCoreRepositoryFactory._validate_repository_interface(MockRepository, MockInterface) - - def test_validate_repository_interface_missing_methods(self): - """Test interface validation with missing methods.""" - - # Create a mock class that's missing required methods - class IncompleteRepository: - def save(self): - pass - - # Missing get_by_id method - - # Create a mock interface that requires both methods - class MockInterface: - def save(self): - pass - - def get_by_id(self): - pass - - def missing_method(self): - pass - - with pytest.raises(RepositoryImportError) as exc_info: - DifyCoreRepositoryFactory._validate_repository_interface(IncompleteRepository, MockInterface) - assert "does not implement required methods" in str(exc_info.value) - - def test_validate_repository_interface_with_private_methods(self): - """Test that private methods are ignored during interface validation.""" - - class MockRepository: - def save(self): - pass - - def _private_method(self): - pass - - # Create a mock interface with private methods - class MockInterface: - def save(self): - pass - - def _private_method(self): - pass - - # Should not raise exception - private methods should be ignored - DifyCoreRepositoryFactory._validate_repository_interface(MockRepository, MockInterface) + with pytest.raises(ImportError) as exc_info: + import_string("invalidpath") + assert "doesn't look like a module path" in str(exc_info.value) @patch("core.repositories.factory.dify_config") def test_create_workflow_execution_repository_success(self, mock_config): @@ -133,11 +65,8 @@ class TestRepositoryFactory: mock_repository_instance = MagicMock(spec=WorkflowExecutionRepository) mock_repository_class.return_value = mock_repository_instance - # Mock the validation methods - with ( - patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class), - patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"), - ): + # Mock import_string + with patch("core.repositories.factory.import_string", return_value=mock_repository_class): result = DifyCoreRepositoryFactory.create_workflow_execution_repository( session_factory=mock_session_factory, user=mock_user, @@ -170,34 +99,7 @@ class TestRepositoryFactory: app_id="test-app-id", triggered_from=WorkflowRunTriggeredFrom.APP_RUN, ) - assert "Cannot import repository class" in str(exc_info.value) - - @patch("core.repositories.factory.dify_config") - def test_create_workflow_execution_repository_validation_error(self, mock_config, mocker: MockerFixture): - """Test WorkflowExecutionRepository creation with validation error.""" - # Setup mock configuration - mock_config.CORE_WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock" - - mock_session_factory = MagicMock(spec=sessionmaker) - mock_user = MagicMock(spec=Account) - - # Mock the import to succeed but validation to fail - mock_repository_class = MagicMock() - mocker.patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class) - mocker.patch.object( - DifyCoreRepositoryFactory, - "_validate_repository_interface", - side_effect=RepositoryImportError("Interface validation failed"), - ) - - with pytest.raises(RepositoryImportError) as exc_info: - DifyCoreRepositoryFactory.create_workflow_execution_repository( - session_factory=mock_session_factory, - user=mock_user, - app_id="test-app-id", - triggered_from=WorkflowRunTriggeredFrom.APP_RUN, - ) - assert "Interface validation failed" in str(exc_info.value) + assert "Failed to create WorkflowExecutionRepository" in str(exc_info.value) @patch("core.repositories.factory.dify_config") def test_create_workflow_execution_repository_instantiation_error(self, mock_config): @@ -212,11 +114,8 @@ class TestRepositoryFactory: mock_repository_class = MagicMock() mock_repository_class.side_effect = Exception("Instantiation failed") - # Mock the validation methods to succeed - with ( - patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class), - patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"), - ): + # Mock import_string to return a failing class + with patch("core.repositories.factory.import_string", return_value=mock_repository_class): with pytest.raises(RepositoryImportError) as exc_info: DifyCoreRepositoryFactory.create_workflow_execution_repository( session_factory=mock_session_factory, @@ -243,11 +142,8 @@ class TestRepositoryFactory: mock_repository_instance = MagicMock(spec=WorkflowNodeExecutionRepository) mock_repository_class.return_value = mock_repository_instance - # Mock the validation methods - with ( - patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class), - patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"), - ): + # Mock import_string + with patch("core.repositories.factory.import_string", return_value=mock_repository_class): result = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( session_factory=mock_session_factory, user=mock_user, @@ -280,34 +176,7 @@ class TestRepositoryFactory: app_id="test-app-id", triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP, ) - assert "Cannot import repository class" in str(exc_info.value) - - @patch("core.repositories.factory.dify_config") - def test_create_workflow_node_execution_repository_validation_error(self, mock_config, mocker: MockerFixture): - """Test WorkflowNodeExecutionRepository creation with validation error.""" - # Setup mock configuration - mock_config.CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY = "unittest.mock.MagicMock" - - mock_session_factory = MagicMock(spec=sessionmaker) - mock_user = MagicMock(spec=EndUser) - - # Mock the import to succeed but validation to fail - mock_repository_class = MagicMock() - mocker.patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class) - mocker.patch.object( - DifyCoreRepositoryFactory, - "_validate_repository_interface", - side_effect=RepositoryImportError("Interface validation failed"), - ) - - with pytest.raises(RepositoryImportError) as exc_info: - DifyCoreRepositoryFactory.create_workflow_node_execution_repository( - session_factory=mock_session_factory, - user=mock_user, - app_id="test-app-id", - triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP, - ) - assert "Interface validation failed" in str(exc_info.value) + assert "Failed to create WorkflowNodeExecutionRepository" in str(exc_info.value) @patch("core.repositories.factory.dify_config") def test_create_workflow_node_execution_repository_instantiation_error(self, mock_config): @@ -322,11 +191,8 @@ class TestRepositoryFactory: mock_repository_class = MagicMock() mock_repository_class.side_effect = Exception("Instantiation failed") - # Mock the validation methods to succeed - with ( - patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class), - patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"), - ): + # Mock import_string to return a failing class + with patch("core.repositories.factory.import_string", return_value=mock_repository_class): with pytest.raises(RepositoryImportError) as exc_info: DifyCoreRepositoryFactory.create_workflow_node_execution_repository( session_factory=mock_session_factory, @@ -359,11 +225,8 @@ class TestRepositoryFactory: mock_repository_instance = MagicMock(spec=WorkflowExecutionRepository) mock_repository_class.return_value = mock_repository_instance - # Mock the validation methods - with ( - patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class), - patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"), - ): + # Mock import_string + with patch("core.repositories.factory.import_string", return_value=mock_repository_class): result = DifyCoreRepositoryFactory.create_workflow_execution_repository( session_factory=mock_engine, # Using Engine instead of sessionmaker user=mock_user,