diff --git a/api/core/tools/__base/tool_provider.py b/api/core/tools/__base/tool_provider.py index 49cbf70378..70e4fe1ff7 100644 --- a/api/core/tools/__base/tool_provider.py +++ b/api/core/tools/__base/tool_provider.py @@ -11,8 +11,10 @@ from core.tools.entities.tool_entities import ( from core.tools.errors import ToolProviderCredentialValidationError -class ToolProviderController(ABC): - def __init__(self, entity: ToolProviderEntity): +class ToolProviderController[ToolProviderEntityT: ToolProviderEntity, ToolProviderToolT: Tool | None](ABC): + entity: ToolProviderEntityT + + def __init__(self, entity: ToolProviderEntityT): self.entity = entity def get_credentials_schema(self) -> list[ProviderConfig]: @@ -24,7 +26,7 @@ class ToolProviderController(ABC): return deepcopy(self.entity.credentials_schema) @abstractmethod - def get_tool(self, tool_name: str) -> Tool: + def get_tool(self, tool_name: str) -> ToolProviderToolT: """ returns a tool that the provider can provide diff --git a/api/core/tools/builtin_tool/provider.py b/api/core/tools/builtin_tool/provider.py index 52d86f0648..a01dbdbeed 100644 --- a/api/core/tools/builtin_tool/provider.py +++ b/api/core/tools/builtin_tool/provider.py @@ -21,7 +21,7 @@ from core.tools.errors import ( from core.tools.utils.yaml_utils import load_yaml_file_cached -class BuiltinToolProviderController(ToolProviderController): +class BuiltinToolProviderController(ToolProviderController[ToolProviderEntity, BuiltinTool | None]): tools: list[BuiltinTool] def __init__(self, **data: Any): @@ -163,7 +163,8 @@ class BuiltinToolProviderController(ToolProviderController): """ return self._get_builtin_tools() - def get_tool(self, tool_name: str) -> BuiltinTool | None: # type: ignore + @override + def get_tool(self, tool_name: str) -> BuiltinTool | None: """ returns the tool that the provider can provide """ diff --git a/api/core/tools/custom_tool/provider.py b/api/core/tools/custom_tool/provider.py index 520a55dbd3..ade5b894f9 100644 --- a/api/core/tools/custom_tool/provider.py +++ b/api/core/tools/custom_tool/provider.py @@ -24,7 +24,7 @@ from extensions.ext_database import db from models.tools import ApiToolProvider -class ApiToolProviderController(ToolProviderController): +class ApiToolProviderController(ToolProviderController[ToolProviderEntity, ApiTool]): provider_id: str tenant_id: str tools: list[ApiTool] = Field(default_factory=list) diff --git a/api/core/tools/mcp_tool/provider.py b/api/core/tools/mcp_tool/provider.py index 52414153b8..8648612172 100644 --- a/api/core/tools/mcp_tool/provider.py +++ b/api/core/tools/mcp_tool/provider.py @@ -18,7 +18,7 @@ from models.tools import MCPToolProvider from services.tools.tools_transform_service import ToolTransformService -class MCPToolProviderController(ToolProviderController): +class MCPToolProviderController(ToolProviderController[ToolProviderEntityWithPlugin, MCPTool]): def __init__( self, entity: ToolProviderEntityWithPlugin, diff --git a/api/core/tools/plugin_tool/provider.py b/api/core/tools/plugin_tool/provider.py index fc6ec14284..526b3680b0 100644 --- a/api/core/tools/plugin_tool/provider.py +++ b/api/core/tools/plugin_tool/provider.py @@ -9,7 +9,9 @@ from core.tools.plugin_tool.tool import PluginTool class PluginToolProviderController(BuiltinToolProviderController): - entity: ToolProviderEntityWithPlugin + # TODO: Split the credential/schema helpers from BuiltinToolProviderController + # so plugin providers do not need to inherit builtin tool-loading behavior. + entity: ToolProviderEntityWithPlugin # pyrefly: ignore[bad-override-mutable-attribute] tenant_id: str plugin_id: str plugin_unique_identifier: str @@ -46,7 +48,8 @@ class PluginToolProviderController(BuiltinToolProviderController): ): raise ToolProviderCredentialValidationError("Invalid credentials") - def get_tool(self, tool_name: str) -> PluginTool: # type: ignore + @override + def get_tool(self, tool_name: str) -> PluginTool: # type: ignore[override] # pyrefly: ignore[bad-override] """ return tool with given name """ @@ -65,7 +68,8 @@ class PluginToolProviderController(BuiltinToolProviderController): plugin_unique_identifier=self.plugin_unique_identifier, ) - def get_tools(self) -> list[PluginTool]: # type: ignore + @override + def get_tools(self) -> list[PluginTool]: # type: ignore[override] # pyrefly: ignore[bad-override] """ get all tools """ diff --git a/api/core/tools/workflow_as_tool/provider.py b/api/core/tools/workflow_as_tool/provider.py index 41212bcec8..1611cd1d63 100644 --- a/api/core/tools/workflow_as_tool/provider.py +++ b/api/core/tools/workflow_as_tool/provider.py @@ -3,7 +3,6 @@ from __future__ import annotations from collections.abc import Mapping from typing import override -from pydantic import Field from sqlalchemy import select from sqlalchemy.orm import Session @@ -43,13 +42,14 @@ VARIABLE_TO_PARAMETER_TYPE_MAPPING = { } -class WorkflowToolProviderController(ToolProviderController): +class WorkflowToolProviderController(ToolProviderController[ToolProviderEntity, WorkflowTool | None]): provider_id: str - tools: list[WorkflowTool] = Field(default_factory=list) + tools: list[WorkflowTool] | None def __init__(self, entity: ToolProviderEntity, provider_id: str): super().__init__(entity=entity) self.provider_id = provider_id + self.tools = None @classmethod def from_db(cls, db_provider: WorkflowToolProvider) -> WorkflowToolProviderController: @@ -241,7 +241,8 @@ class WorkflowToolProviderController(ToolProviderController): return self.tools - def get_tool(self, tool_name: str) -> WorkflowTool | None: # type: ignore + @override + def get_tool(self, tool_name: str) -> WorkflowTool | None: """ get tool by name diff --git a/api/extensions/storage/huawei_obs_storage.py b/api/extensions/storage/huawei_obs_storage.py index 72fdabe455..c840b17086 100644 --- a/api/extensions/storage/huawei_obs_storage.py +++ b/api/extensions/storage/huawei_obs_storage.py @@ -27,12 +27,14 @@ class HuaweiObsStorage(BaseStorage): @override def load_once(self, filename: str) -> bytes: - data: bytes = self.client.getObject(bucketName=self.bucket_name, objectKey=filename)["body"].response.read() + # TODO: Huawei SDK lacks proper typing + data: bytes = self.client.getObject(bucketName=self.bucket_name, objectKey=filename).body.response.read() # type: ignore return data @override def load_stream(self, filename: str) -> Generator: - response = self.client.getObject(bucketName=self.bucket_name, objectKey=filename)["body"].response + # TODO: Huawei SDK lacks proper typing + response = self.client.getObject(bucketName=self.bucket_name, objectKey=filename).body.response # type: ignore while chunk := response.read(4096): yield chunk diff --git a/api/libs/gmpy2_pkcs10aep_cipher.py b/api/libs/gmpy2_pkcs10aep_cipher.py index ef26699fb3..ef1cc56e0e 100644 --- a/api/libs/gmpy2_pkcs10aep_cipher.py +++ b/api/libs/gmpy2_pkcs10aep_cipher.py @@ -20,6 +20,7 @@ # =================================================================== from hashlib import sha1 +from typing import TYPE_CHECKING, cast import Crypto.Hash.SHA1 import Crypto.Util.number @@ -30,6 +31,9 @@ from Crypto.Util.number import bytes_to_long, ceil_div, long_to_bytes from Crypto.Util.py3compat import bord from Crypto.Util.strxor import strxor +if TYPE_CHECKING: + from Crypto.Signature.pss import HashModule + class PKCS1OAepCipher: """Cipher object for PKCS#1 v1.5 OAEP. @@ -70,7 +74,7 @@ class PKCS1OAepCipher: if mgfunc: self._mgf = mgfunc else: - self._mgf = lambda x, y: MGF1(x, y, self._hashObj) + self._mgf = lambda x, y: MGF1(x, y, cast("HashModule", self._hashObj)) self._label = bytes(label) self._randfunc = randfunc diff --git a/api/pyrefly-local-excludes.txt b/api/pyrefly-local-excludes.txt index cd451dec25..c9637136d4 100644 --- a/api/pyrefly-local-excludes.txt +++ b/api/pyrefly-local-excludes.txt @@ -53,11 +53,3 @@ providers/vdb/vdb-vikingdb/src/dify_vdb_vikingdb/vikingdb_vector.py providers/vdb/vdb-vikingdb/tests/unit_tests/test_vikingdb_vector.py providers/vdb/vdb-weaviate/src/dify_vdb_weaviate/weaviate_vector.py providers/vdb/vdb-weaviate/tests/unit_tests/test_weaviate_vector.py -core/tools/mcp_tool/provider.py -core/tools/plugin_tool/provider.py -core/tools/workflow_as_tool/provider.py -extensions/storage/huawei_obs_storage.py -libs/gmpy2_pkcs10aep_cipher.py -services/audio_service.py -services/document_indexing_proxy/document_indexing_task_proxy.py -services/document_indexing_proxy/duplicate_document_indexing_task_proxy.py diff --git a/api/services/audio_service.py b/api/services/audio_service.py index c80b2f43fd..a9024eb3bd 100644 --- a/api/services/audio_service.py +++ b/api/services/audio_service.py @@ -30,7 +30,7 @@ logger = logging.getLogger(__name__) class AudioService: @classmethod - def transcript_asr(cls, app_model: App, file: FileStorage, end_user: str | None = None): + def transcript_asr(cls, app_model: App, file: FileStorage | None, end_user: str | None = None): if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: workflow = app_model.workflow if workflow is None: @@ -141,14 +141,14 @@ class AudioService: else: response = invoke_tts(text_content=message.answer, app_model=app_model, voice=voice, is_draft=is_draft) if isinstance(response, Generator): - return Response(stream_with_context(response), content_type="audio/mpeg") + return Response(stream_with_context(response), content_type="audio/mpeg") # type: ignore return response else: if text is None: raise ValueError("Text is required") response = invoke_tts(text_content=text, app_model=app_model, voice=voice, is_draft=is_draft) if isinstance(response, Generator): - return Response(stream_with_context(response), content_type="audio/mpeg") + return Response(stream_with_context(response), content_type="audio/mpeg") # type: ignore return response @classmethod diff --git a/api/services/document_indexing_proxy/document_indexing_task_proxy.py b/api/services/document_indexing_proxy/document_indexing_task_proxy.py index d9295899cb..23aab28f81 100644 --- a/api/services/document_indexing_proxy/document_indexing_task_proxy.py +++ b/api/services/document_indexing_proxy/document_indexing_task_proxy.py @@ -1,4 +1,5 @@ -from typing import ClassVar +from collections.abc import Callable +from typing import Any, ClassVar from services.document_indexing_proxy.batch_indexing_base import BatchDocumentIndexingProxy from tasks.document_indexing_task import normal_document_indexing_task, priority_document_indexing_task @@ -8,5 +9,5 @@ class DocumentIndexingTaskProxy(BatchDocumentIndexingProxy): """Proxy for document indexing tasks.""" QUEUE_NAME: ClassVar[str] = "document_indexing" - NORMAL_TASK_FUNC = normal_document_indexing_task # pyrefly: ignore[missing-override-decorator] - PRIORITY_TASK_FUNC = priority_document_indexing_task # pyrefly: ignore[missing-override-decorator] + NORMAL_TASK_FUNC: ClassVar[Callable[..., Any]] = normal_document_indexing_task # pyrefly: ignore + PRIORITY_TASK_FUNC: ClassVar[Callable[..., Any]] = priority_document_indexing_task # pyrefly: ignore diff --git a/api/services/document_indexing_proxy/duplicate_document_indexing_task_proxy.py b/api/services/document_indexing_proxy/duplicate_document_indexing_task_proxy.py index 224cab1e14..88b25aff4f 100644 --- a/api/services/document_indexing_proxy/duplicate_document_indexing_task_proxy.py +++ b/api/services/document_indexing_proxy/duplicate_document_indexing_task_proxy.py @@ -1,4 +1,5 @@ -from typing import ClassVar +from collections.abc import Callable +from typing import Any, ClassVar from services.document_indexing_proxy.batch_indexing_base import BatchDocumentIndexingProxy from tasks.duplicate_document_indexing_task import ( @@ -6,10 +7,12 @@ from tasks.duplicate_document_indexing_task import ( priority_duplicate_document_indexing_task, ) +TaskFunc = Callable[..., Any] + class DuplicateDocumentIndexingTaskProxy(BatchDocumentIndexingProxy): """Proxy for duplicate document indexing tasks.""" QUEUE_NAME: ClassVar[str] = "duplicate_document_indexing" - NORMAL_TASK_FUNC = normal_duplicate_document_indexing_task # pyrefly: ignore[missing-override-decorator] - PRIORITY_TASK_FUNC = priority_duplicate_document_indexing_task # pyrefly: ignore[missing-override-decorator] + NORMAL_TASK_FUNC: ClassVar[TaskFunc] = normal_duplicate_document_indexing_task # pyrefly: ignore + PRIORITY_TASK_FUNC: ClassVar[TaskFunc] = priority_duplicate_document_indexing_task # pyrefly: ignore diff --git a/api/tests/unit_tests/core/tools/test_builtin_tool_provider.py b/api/tests/unit_tests/core/tools/test_builtin_tool_provider.py index b21a5c3e24..649545ae68 100644 --- a/api/tests/unit_tests/core/tools/test_builtin_tool_provider.py +++ b/api/tests/unit_tests/core/tools/test_builtin_tool_provider.py @@ -1,7 +1,7 @@ from __future__ import annotations from collections.abc import Generator -from typing import Any +from typing import Any, override from unittest.mock import patch import pytest @@ -16,6 +16,7 @@ from core.tools.errors import ToolProviderNotFoundError class _FakeBuiltinTool(BuiltinTool): + @override def _invoke( self, user_id: str, @@ -30,6 +31,7 @@ class _FakeBuiltinTool(BuiltinTool): class _ConcreteBuiltinProvider(BuiltinToolProviderController): last_validation: tuple[str, dict[str, Any]] | None = None + @override def _validate_credentials(self, user_id: str, credentials: dict[str, Any]): self.last_validation = (user_id, credentials) diff --git a/api/tests/unit_tests/core/tools/test_tool_label_manager.py b/api/tests/unit_tests/core/tools/test_tool_label_manager.py index e13f430f9b..ba9f8ba9e1 100644 --- a/api/tests/unit_tests/core/tools/test_tool_label_manager.py +++ b/api/tests/unit_tests/core/tools/test_tool_label_manager.py @@ -1,7 +1,7 @@ from __future__ import annotations from types import SimpleNamespace -from typing import Any +from typing import Any, override from unittest.mock import MagicMock, PropertyMock, patch import pytest @@ -14,6 +14,7 @@ from core.tools.workflow_as_tool.provider import WorkflowToolProviderController # Create a mock class for testing abstract/base classes class _ConcreteBuiltinToolProviderController(BuiltinToolProviderController): + @override def _validate_credentials(self, user_id: str, credentials: dict[str, Any]): return None diff --git a/api/tests/unit_tests/core/tools/test_tool_provider_controller.py b/api/tests/unit_tests/core/tools/test_tool_provider_controller.py index 30b8494c92..9648305289 100644 --- a/api/tests/unit_tests/core/tools/test_tool_provider_controller.py +++ b/api/tests/unit_tests/core/tools/test_tool_provider_controller.py @@ -1,7 +1,7 @@ from __future__ import annotations from collections.abc import Generator -from typing import Any +from typing import Any, override import pytest @@ -22,9 +22,11 @@ from core.tools.errors import ToolProviderCredentialValidationError class _DummyTool(Tool): + @override def tool_provider_type(self) -> ToolProviderType: return ToolProviderType.BUILT_IN + @override def _invoke( self, user_id: str, @@ -36,7 +38,8 @@ class _DummyTool(Tool): yield self.create_text_message("ok") -class _DummyController(ToolProviderController): +class _DummyController(ToolProviderController[ToolProviderEntity, Tool]): + @override def get_tool(self, tool_name: str) -> Tool: entity = ToolEntity( identity=ToolIdentity( diff --git a/api/tests/unit_tests/core/tools/workflow_as_tool/test_provider.py b/api/tests/unit_tests/core/tools/workflow_as_tool/test_provider.py index 5a585c609a..b876fa64b9 100644 --- a/api/tests/unit_tests/core/tools/workflow_as_tool/test_provider.py +++ b/api/tests/unit_tests/core/tools/workflow_as_tool/test_provider.py @@ -1,19 +1,30 @@ from __future__ import annotations +import json from types import SimpleNamespace +from typing import Any, cast from unittest.mock import MagicMock, Mock, patch import pytest +from core.tools.__base.tool_runtime import ToolRuntime from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ( + ToolDescription, + ToolEntity, + ToolIdentity, ToolParameter, ToolProviderEntity, ToolProviderIdentity, ToolProviderType, ) from core.tools.workflow_as_tool.provider import WorkflowToolProviderController +from core.tools.workflow_as_tool.tool import WorkflowTool from graphon.variables.input_entities import VariableEntity, VariableEntityType +from models.account import Account +from models.model import App +from models.tools import WorkflowToolProvider +from models.workflow import Workflow, WorkflowType def _controller() -> WorkflowToolProviderController: @@ -30,6 +41,64 @@ def _controller() -> WorkflowToolProviderController: return WorkflowToolProviderController(entity=entity, provider_id="provider-1") +def _app() -> App: + return App(id="app-1") + + +def _account() -> Account: + return Account(name="Alice", email="alice@example.com") + + +def _workflow() -> Workflow: + return Workflow.new( + tenant_id="tenant-1", + app_id="app-1", + type=WorkflowType.WORKFLOW.value, + version="1", + graph=json.dumps({"nodes": []}), + features="{}", + created_by="user-1", + environment_variables=[], + conversation_variables=[], + rag_pipeline_variables=[], + ) + + +def _db_provider(*, parameter_configuration: str = "[]") -> WorkflowToolProvider: + return WorkflowToolProvider( + name="workflow_tool", + label="WF Provider", + icon="icon.svg", + app_id="app-1", + version="1", + user_id="user-1", + tenant_id="tenant-1", + description="desc", + parameter_configuration=parameter_configuration, + ) + + +def _workflow_tool(name: str = "workflow_tool") -> WorkflowTool: + return WorkflowTool( + workflow_as_tool_id="provider-1", + entity=ToolEntity( + identity=ToolIdentity( + author="author", + name=name, + label=I18nObject(en_US=name), + provider="provider-1", + ), + description=ToolDescription(human=I18nObject(en_US="desc"), llm="desc"), + parameters=[], + ), + runtime=ToolRuntime(tenant_id="tenant-1"), + workflow_app_id="app-1", + workflow_entities={"app": _app(), "workflow": _workflow()}, + version="1", + workflow_call_depth=0, + ) + + def _mock_session_with_begin() -> Mock: session = Mock() begin_cm = Mock() @@ -42,25 +111,18 @@ def _mock_session_with_begin() -> Mock: def test_get_db_provider_tool_builds_entity(): controller = _controller() session = Mock() - workflow = SimpleNamespace(graph_dict={"nodes": []}, features_dict={}) + workflow = _workflow() session.scalar.return_value = workflow - app = SimpleNamespace(id="app-1") - db_provider = SimpleNamespace( - id="provider-1", - app_id="app-1", - version="1", - label="WF Provider", - description="desc", - icon="icon.svg", - name="workflow_tool", - tenant_id="tenant-1", - user_id="user-1", - parameter_configurations=[ - SimpleNamespace(name="country", description="Country", form=ToolParameter.ToolParameterForm.FORM), - SimpleNamespace(name="files", description="files", form=ToolParameter.ToolParameterForm.FORM), - ], + app = _app() + db_provider = _db_provider( + parameter_configuration=json.dumps( + [ + {"name": "country", "description": "Country", "form": ToolParameter.ToolParameterForm.FORM.value}, + {"name": "files", "description": "files", "form": ToolParameter.ToolParameterForm.FORM.value}, + ] + ) ) - user = SimpleNamespace(name="Alice") + user = _account() variables = [ VariableEntity( variable="country", @@ -94,8 +156,9 @@ def test_get_db_provider_tool_builds_entity(): assert tool.entity.identity.name == "workflow_tool" # "json" output is reserved for ToolInvokeMessage.VariableMessage and filtered out. - assert tool.entity.output_schema["properties"] == {"answer": {"type": "string", "description": ""}} - assert "json" not in tool.entity.output_schema["properties"] + properties = cast(dict[str, Any], tool.entity.output_schema["properties"]) + assert properties == {"answer": {"type": "string", "description": ""}} + assert "json" not in properties assert tool.entity.parameters[0].type == ToolParameter.ToolParameterType.SELECT assert tool.entity.parameters[1].type == ToolParameter.ToolParameterType.SYSTEM_FILES assert controller.provider_type == ToolProviderType.WORKFLOW @@ -103,7 +166,7 @@ def test_get_db_provider_tool_builds_entity(): def test_get_tool_returns_hit_or_none(): controller = _controller() - tool = SimpleNamespace(entity=SimpleNamespace(identity=SimpleNamespace(name="workflow_tool"))) + tool = _workflow_tool() controller.tools = [tool] assert controller.get_tool("workflow_tool") is tool @@ -112,29 +175,16 @@ def test_get_tool_returns_hit_or_none(): def test_get_tools_returns_cached(): controller = _controller() - cached_tools = [SimpleNamespace(entity=SimpleNamespace(identity=SimpleNamespace(name="wf-cached")))] - controller.tools = cached_tools # type: ignore[assignment] + cached_tools = [_workflow_tool("wf-cached")] + controller.tools = cached_tools assert controller.get_tools("tenant-1") == cached_tools def test_from_db_builds_controller(): - controller = _controller() - - app = SimpleNamespace(id="app-1") - user = SimpleNamespace(name="Alice") - db_provider = SimpleNamespace( - id="provider-1", - app_id="app-1", - version="1", - user_id="user-1", - label="WF Provider", - description="desc", - icon="icon.svg", - name="workflow_tool", - tenant_id="tenant-1", - parameter_configurations=[], - ) + app = _app() + user = _account() + db_provider = _db_provider() session = _mock_session_with_begin() session.scalar.return_value = db_provider session.get.side_effect = [app, user] @@ -148,7 +198,7 @@ def test_from_db_builds_controller(): with patch.object( WorkflowToolProviderController, "_get_db_provider_tool", - return_value=SimpleNamespace(entity=SimpleNamespace(identity=SimpleNamespace(name="wf"))), + return_value=_workflow_tool("wf"), ): built = WorkflowToolProviderController.from_db(db_provider) assert isinstance(built, WorkflowToolProviderController) @@ -157,7 +207,7 @@ def test_from_db_builds_controller(): def test_get_tools_returns_empty_when_provider_missing(): controller = _controller() - controller.tools = None # type: ignore[assignment] + controller.tools = None with patch("core.tools.workflow_as_tool.provider.db") as mock_db: mock_db.engine = object() @@ -171,19 +221,8 @@ def test_get_tools_returns_empty_when_provider_missing(): def test_get_tools_raises_when_app_missing(): controller = _controller() - controller.tools = None # type: ignore[assignment] - db_provider = SimpleNamespace( - id="provider-1", - app_id="app-1", - version="1", - user_id="user-1", - label="WF Provider", - description="desc", - icon="icon.svg", - name="workflow_tool", - tenant_id="tenant-1", - parameter_configurations=[], - ) + controller.tools = None + db_provider = _db_provider() with patch("core.tools.workflow_as_tool.provider.db") as mock_db: mock_db.engine = object() diff --git a/api/tests/unit_tests/services/test_document_indexing_task_proxy.py b/api/tests/unit_tests/services/test_document_indexing_task_proxy.py index 28de9efa57..082bb7aa86 100644 --- a/api/tests/unit_tests/services/test_document_indexing_task_proxy.py +++ b/api/tests/unit_tests/services/test_document_indexing_task_proxy.py @@ -10,7 +10,7 @@ class DocumentIndexingTaskProxyTestDataFactory: """Factory class for creating test data and mock objects for DocumentIndexingTaskProxy tests.""" @staticmethod - def create_mock_features(billing_enabled: bool = False, plan: CloudPlan = CloudPlan.SANDBOX) -> Mock: + def create_mock_features(billing_enabled: bool = False, plan: CloudPlan | str | None = CloudPlan.SANDBOX) -> Mock: """Create mock features with billing configuration.""" features = Mock() features.billing = Mock() diff --git a/api/tests/unit_tests/services/test_duplicate_document_indexing_task_proxy.py b/api/tests/unit_tests/services/test_duplicate_document_indexing_task_proxy.py index 20358d6a0c..e0370edec9 100644 --- a/api/tests/unit_tests/services/test_duplicate_document_indexing_task_proxy.py +++ b/api/tests/unit_tests/services/test_duplicate_document_indexing_task_proxy.py @@ -12,7 +12,7 @@ class DuplicateDocumentIndexingTaskProxyTestDataFactory: """Factory class for creating test data and mock objects for DuplicateDocumentIndexingTaskProxy tests.""" @staticmethod - def create_mock_features(billing_enabled: bool = False, plan: CloudPlan = CloudPlan.SANDBOX) -> Mock: + def create_mock_features(billing_enabled: bool = False, plan: CloudPlan | str | None = CloudPlan.SANDBOX) -> Mock: """Create mock features with billing configuration.""" features = Mock() features.billing = Mock()