chore(api): Fix several typing errors (#37248)

This commit is contained in:
chariri 2026-06-12 23:02:09 +09:00 committed by GitHub
parent ad96501e09
commit 7cf75c3cc5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
18 changed files with 148 additions and 93 deletions

View File

@ -11,8 +11,10 @@ from core.tools.entities.tool_entities import (
from core.tools.errors import ToolProviderCredentialValidationError
class ToolProviderController(ABC):
def __init__(self, entity: ToolProviderEntity):
class ToolProviderController[ToolProviderEntityT: ToolProviderEntity, ToolProviderToolT: Tool | None](ABC):
entity: ToolProviderEntityT
def __init__(self, entity: ToolProviderEntityT):
self.entity = entity
def get_credentials_schema(self) -> list[ProviderConfig]:
@ -24,7 +26,7 @@ class ToolProviderController(ABC):
return deepcopy(self.entity.credentials_schema)
@abstractmethod
def get_tool(self, tool_name: str) -> Tool:
def get_tool(self, tool_name: str) -> ToolProviderToolT:
"""
returns a tool that the provider can provide

View File

@ -21,7 +21,7 @@ from core.tools.errors import (
from core.tools.utils.yaml_utils import load_yaml_file_cached
class BuiltinToolProviderController(ToolProviderController):
class BuiltinToolProviderController(ToolProviderController[ToolProviderEntity, BuiltinTool | None]):
tools: list[BuiltinTool]
def __init__(self, **data: Any):
@ -163,7 +163,8 @@ class BuiltinToolProviderController(ToolProviderController):
"""
return self._get_builtin_tools()
def get_tool(self, tool_name: str) -> BuiltinTool | None: # type: ignore
@override
def get_tool(self, tool_name: str) -> BuiltinTool | None:
"""
returns the tool that the provider can provide
"""

View File

@ -24,7 +24,7 @@ from extensions.ext_database import db
from models.tools import ApiToolProvider
class ApiToolProviderController(ToolProviderController):
class ApiToolProviderController(ToolProviderController[ToolProviderEntity, ApiTool]):
provider_id: str
tenant_id: str
tools: list[ApiTool] = Field(default_factory=list)

View File

@ -18,7 +18,7 @@ from models.tools import MCPToolProvider
from services.tools.tools_transform_service import ToolTransformService
class MCPToolProviderController(ToolProviderController):
class MCPToolProviderController(ToolProviderController[ToolProviderEntityWithPlugin, MCPTool]):
def __init__(
self,
entity: ToolProviderEntityWithPlugin,

View File

@ -9,7 +9,9 @@ from core.tools.plugin_tool.tool import PluginTool
class PluginToolProviderController(BuiltinToolProviderController):
entity: ToolProviderEntityWithPlugin
# TODO: Split the credential/schema helpers from BuiltinToolProviderController
# so plugin providers do not need to inherit builtin tool-loading behavior.
entity: ToolProviderEntityWithPlugin # pyrefly: ignore[bad-override-mutable-attribute]
tenant_id: str
plugin_id: str
plugin_unique_identifier: str
@ -46,7 +48,8 @@ class PluginToolProviderController(BuiltinToolProviderController):
):
raise ToolProviderCredentialValidationError("Invalid credentials")
def get_tool(self, tool_name: str) -> PluginTool: # type: ignore
@override
def get_tool(self, tool_name: str) -> PluginTool: # type: ignore[override] # pyrefly: ignore[bad-override]
"""
return tool with given name
"""
@ -65,7 +68,8 @@ class PluginToolProviderController(BuiltinToolProviderController):
plugin_unique_identifier=self.plugin_unique_identifier,
)
def get_tools(self) -> list[PluginTool]: # type: ignore
@override
def get_tools(self) -> list[PluginTool]: # type: ignore[override] # pyrefly: ignore[bad-override]
"""
get all tools
"""

View File

@ -3,7 +3,6 @@ from __future__ import annotations
from collections.abc import Mapping
from typing import override
from pydantic import Field
from sqlalchemy import select
from sqlalchemy.orm import Session
@ -43,13 +42,14 @@ VARIABLE_TO_PARAMETER_TYPE_MAPPING = {
}
class WorkflowToolProviderController(ToolProviderController):
class WorkflowToolProviderController(ToolProviderController[ToolProviderEntity, WorkflowTool | None]):
provider_id: str
tools: list[WorkflowTool] = Field(default_factory=list)
tools: list[WorkflowTool] | None
def __init__(self, entity: ToolProviderEntity, provider_id: str):
super().__init__(entity=entity)
self.provider_id = provider_id
self.tools = None
@classmethod
def from_db(cls, db_provider: WorkflowToolProvider) -> WorkflowToolProviderController:
@ -241,7 +241,8 @@ class WorkflowToolProviderController(ToolProviderController):
return self.tools
def get_tool(self, tool_name: str) -> WorkflowTool | None: # type: ignore
@override
def get_tool(self, tool_name: str) -> WorkflowTool | None:
"""
get tool by name

View File

@ -27,12 +27,14 @@ class HuaweiObsStorage(BaseStorage):
@override
def load_once(self, filename: str) -> bytes:
data: bytes = self.client.getObject(bucketName=self.bucket_name, objectKey=filename)["body"].response.read()
# TODO: Huawei SDK lacks proper typing
data: bytes = self.client.getObject(bucketName=self.bucket_name, objectKey=filename).body.response.read() # type: ignore
return data
@override
def load_stream(self, filename: str) -> Generator:
response = self.client.getObject(bucketName=self.bucket_name, objectKey=filename)["body"].response
# TODO: Huawei SDK lacks proper typing
response = self.client.getObject(bucketName=self.bucket_name, objectKey=filename).body.response # type: ignore
while chunk := response.read(4096):
yield chunk

View File

@ -20,6 +20,7 @@
# ===================================================================
from hashlib import sha1
from typing import TYPE_CHECKING, cast
import Crypto.Hash.SHA1
import Crypto.Util.number
@ -30,6 +31,9 @@ from Crypto.Util.number import bytes_to_long, ceil_div, long_to_bytes
from Crypto.Util.py3compat import bord
from Crypto.Util.strxor import strxor
if TYPE_CHECKING:
from Crypto.Signature.pss import HashModule
class PKCS1OAepCipher:
"""Cipher object for PKCS#1 v1.5 OAEP.
@ -70,7 +74,7 @@ class PKCS1OAepCipher:
if mgfunc:
self._mgf = mgfunc
else:
self._mgf = lambda x, y: MGF1(x, y, self._hashObj)
self._mgf = lambda x, y: MGF1(x, y, cast("HashModule", self._hashObj))
self._label = bytes(label)
self._randfunc = randfunc

View File

@ -53,11 +53,3 @@ providers/vdb/vdb-vikingdb/src/dify_vdb_vikingdb/vikingdb_vector.py
providers/vdb/vdb-vikingdb/tests/unit_tests/test_vikingdb_vector.py
providers/vdb/vdb-weaviate/src/dify_vdb_weaviate/weaviate_vector.py
providers/vdb/vdb-weaviate/tests/unit_tests/test_weaviate_vector.py
core/tools/mcp_tool/provider.py
core/tools/plugin_tool/provider.py
core/tools/workflow_as_tool/provider.py
extensions/storage/huawei_obs_storage.py
libs/gmpy2_pkcs10aep_cipher.py
services/audio_service.py
services/document_indexing_proxy/document_indexing_task_proxy.py
services/document_indexing_proxy/duplicate_document_indexing_task_proxy.py

View File

@ -30,7 +30,7 @@ logger = logging.getLogger(__name__)
class AudioService:
@classmethod
def transcript_asr(cls, app_model: App, file: FileStorage, end_user: str | None = None):
def transcript_asr(cls, app_model: App, file: FileStorage | None, end_user: str | None = None):
if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
workflow = app_model.workflow
if workflow is None:
@ -141,14 +141,14 @@ class AudioService:
else:
response = invoke_tts(text_content=message.answer, app_model=app_model, voice=voice, is_draft=is_draft)
if isinstance(response, Generator):
return Response(stream_with_context(response), content_type="audio/mpeg")
return Response(stream_with_context(response), content_type="audio/mpeg") # type: ignore
return response
else:
if text is None:
raise ValueError("Text is required")
response = invoke_tts(text_content=text, app_model=app_model, voice=voice, is_draft=is_draft)
if isinstance(response, Generator):
return Response(stream_with_context(response), content_type="audio/mpeg")
return Response(stream_with_context(response), content_type="audio/mpeg") # type: ignore
return response
@classmethod

View File

@ -1,4 +1,5 @@
from typing import ClassVar
from collections.abc import Callable
from typing import Any, ClassVar
from services.document_indexing_proxy.batch_indexing_base import BatchDocumentIndexingProxy
from tasks.document_indexing_task import normal_document_indexing_task, priority_document_indexing_task
@ -8,5 +9,5 @@ class DocumentIndexingTaskProxy(BatchDocumentIndexingProxy):
"""Proxy for document indexing tasks."""
QUEUE_NAME: ClassVar[str] = "document_indexing"
NORMAL_TASK_FUNC = normal_document_indexing_task # pyrefly: ignore[missing-override-decorator]
PRIORITY_TASK_FUNC = priority_document_indexing_task # pyrefly: ignore[missing-override-decorator]
NORMAL_TASK_FUNC: ClassVar[Callable[..., Any]] = normal_document_indexing_task # pyrefly: ignore
PRIORITY_TASK_FUNC: ClassVar[Callable[..., Any]] = priority_document_indexing_task # pyrefly: ignore

View File

@ -1,4 +1,5 @@
from typing import ClassVar
from collections.abc import Callable
from typing import Any, ClassVar
from services.document_indexing_proxy.batch_indexing_base import BatchDocumentIndexingProxy
from tasks.duplicate_document_indexing_task import (
@ -6,10 +7,12 @@ from tasks.duplicate_document_indexing_task import (
priority_duplicate_document_indexing_task,
)
TaskFunc = Callable[..., Any]
class DuplicateDocumentIndexingTaskProxy(BatchDocumentIndexingProxy):
"""Proxy for duplicate document indexing tasks."""
QUEUE_NAME: ClassVar[str] = "duplicate_document_indexing"
NORMAL_TASK_FUNC = normal_duplicate_document_indexing_task # pyrefly: ignore[missing-override-decorator]
PRIORITY_TASK_FUNC = priority_duplicate_document_indexing_task # pyrefly: ignore[missing-override-decorator]
NORMAL_TASK_FUNC: ClassVar[TaskFunc] = normal_duplicate_document_indexing_task # pyrefly: ignore
PRIORITY_TASK_FUNC: ClassVar[TaskFunc] = priority_duplicate_document_indexing_task # pyrefly: ignore

View File

@ -1,7 +1,7 @@
from __future__ import annotations
from collections.abc import Generator
from typing import Any
from typing import Any, override
from unittest.mock import patch
import pytest
@ -16,6 +16,7 @@ from core.tools.errors import ToolProviderNotFoundError
class _FakeBuiltinTool(BuiltinTool):
@override
def _invoke(
self,
user_id: str,
@ -30,6 +31,7 @@ class _FakeBuiltinTool(BuiltinTool):
class _ConcreteBuiltinProvider(BuiltinToolProviderController):
last_validation: tuple[str, dict[str, Any]] | None = None
@override
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]):
self.last_validation = (user_id, credentials)

View File

@ -1,7 +1,7 @@
from __future__ import annotations
from types import SimpleNamespace
from typing import Any
from typing import Any, override
from unittest.mock import MagicMock, PropertyMock, patch
import pytest
@ -14,6 +14,7 @@ from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
# Create a mock class for testing abstract/base classes
class _ConcreteBuiltinToolProviderController(BuiltinToolProviderController):
@override
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]):
return None

View File

@ -1,7 +1,7 @@
from __future__ import annotations
from collections.abc import Generator
from typing import Any
from typing import Any, override
import pytest
@ -22,9 +22,11 @@ from core.tools.errors import ToolProviderCredentialValidationError
class _DummyTool(Tool):
@override
def tool_provider_type(self) -> ToolProviderType:
return ToolProviderType.BUILT_IN
@override
def _invoke(
self,
user_id: str,
@ -36,7 +38,8 @@ class _DummyTool(Tool):
yield self.create_text_message("ok")
class _DummyController(ToolProviderController):
class _DummyController(ToolProviderController[ToolProviderEntity, Tool]):
@override
def get_tool(self, tool_name: str) -> Tool:
entity = ToolEntity(
identity=ToolIdentity(

View File

@ -1,19 +1,30 @@
from __future__ import annotations
import json
from types import SimpleNamespace
from typing import Any, cast
from unittest.mock import MagicMock, Mock, patch
import pytest
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import (
ToolDescription,
ToolEntity,
ToolIdentity,
ToolParameter,
ToolProviderEntity,
ToolProviderIdentity,
ToolProviderType,
)
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
from core.tools.workflow_as_tool.tool import WorkflowTool
from graphon.variables.input_entities import VariableEntity, VariableEntityType
from models.account import Account
from models.model import App
from models.tools import WorkflowToolProvider
from models.workflow import Workflow, WorkflowType
def _controller() -> WorkflowToolProviderController:
@ -30,6 +41,64 @@ def _controller() -> WorkflowToolProviderController:
return WorkflowToolProviderController(entity=entity, provider_id="provider-1")
def _app() -> App:
return App(id="app-1")
def _account() -> Account:
return Account(name="Alice", email="alice@example.com")
def _workflow() -> Workflow:
return Workflow.new(
tenant_id="tenant-1",
app_id="app-1",
type=WorkflowType.WORKFLOW.value,
version="1",
graph=json.dumps({"nodes": []}),
features="{}",
created_by="user-1",
environment_variables=[],
conversation_variables=[],
rag_pipeline_variables=[],
)
def _db_provider(*, parameter_configuration: str = "[]") -> WorkflowToolProvider:
return WorkflowToolProvider(
name="workflow_tool",
label="WF Provider",
icon="icon.svg",
app_id="app-1",
version="1",
user_id="user-1",
tenant_id="tenant-1",
description="desc",
parameter_configuration=parameter_configuration,
)
def _workflow_tool(name: str = "workflow_tool") -> WorkflowTool:
return WorkflowTool(
workflow_as_tool_id="provider-1",
entity=ToolEntity(
identity=ToolIdentity(
author="author",
name=name,
label=I18nObject(en_US=name),
provider="provider-1",
),
description=ToolDescription(human=I18nObject(en_US="desc"), llm="desc"),
parameters=[],
),
runtime=ToolRuntime(tenant_id="tenant-1"),
workflow_app_id="app-1",
workflow_entities={"app": _app(), "workflow": _workflow()},
version="1",
workflow_call_depth=0,
)
def _mock_session_with_begin() -> Mock:
session = Mock()
begin_cm = Mock()
@ -42,25 +111,18 @@ def _mock_session_with_begin() -> Mock:
def test_get_db_provider_tool_builds_entity():
controller = _controller()
session = Mock()
workflow = SimpleNamespace(graph_dict={"nodes": []}, features_dict={})
workflow = _workflow()
session.scalar.return_value = workflow
app = SimpleNamespace(id="app-1")
db_provider = SimpleNamespace(
id="provider-1",
app_id="app-1",
version="1",
label="WF Provider",
description="desc",
icon="icon.svg",
name="workflow_tool",
tenant_id="tenant-1",
user_id="user-1",
parameter_configurations=[
SimpleNamespace(name="country", description="Country", form=ToolParameter.ToolParameterForm.FORM),
SimpleNamespace(name="files", description="files", form=ToolParameter.ToolParameterForm.FORM),
],
app = _app()
db_provider = _db_provider(
parameter_configuration=json.dumps(
[
{"name": "country", "description": "Country", "form": ToolParameter.ToolParameterForm.FORM.value},
{"name": "files", "description": "files", "form": ToolParameter.ToolParameterForm.FORM.value},
]
)
)
user = SimpleNamespace(name="Alice")
user = _account()
variables = [
VariableEntity(
variable="country",
@ -94,8 +156,9 @@ def test_get_db_provider_tool_builds_entity():
assert tool.entity.identity.name == "workflow_tool"
# "json" output is reserved for ToolInvokeMessage.VariableMessage and filtered out.
assert tool.entity.output_schema["properties"] == {"answer": {"type": "string", "description": ""}}
assert "json" not in tool.entity.output_schema["properties"]
properties = cast(dict[str, Any], tool.entity.output_schema["properties"])
assert properties == {"answer": {"type": "string", "description": ""}}
assert "json" not in properties
assert tool.entity.parameters[0].type == ToolParameter.ToolParameterType.SELECT
assert tool.entity.parameters[1].type == ToolParameter.ToolParameterType.SYSTEM_FILES
assert controller.provider_type == ToolProviderType.WORKFLOW
@ -103,7 +166,7 @@ def test_get_db_provider_tool_builds_entity():
def test_get_tool_returns_hit_or_none():
controller = _controller()
tool = SimpleNamespace(entity=SimpleNamespace(identity=SimpleNamespace(name="workflow_tool")))
tool = _workflow_tool()
controller.tools = [tool]
assert controller.get_tool("workflow_tool") is tool
@ -112,29 +175,16 @@ def test_get_tool_returns_hit_or_none():
def test_get_tools_returns_cached():
controller = _controller()
cached_tools = [SimpleNamespace(entity=SimpleNamespace(identity=SimpleNamespace(name="wf-cached")))]
controller.tools = cached_tools # type: ignore[assignment]
cached_tools = [_workflow_tool("wf-cached")]
controller.tools = cached_tools
assert controller.get_tools("tenant-1") == cached_tools
def test_from_db_builds_controller():
controller = _controller()
app = SimpleNamespace(id="app-1")
user = SimpleNamespace(name="Alice")
db_provider = SimpleNamespace(
id="provider-1",
app_id="app-1",
version="1",
user_id="user-1",
label="WF Provider",
description="desc",
icon="icon.svg",
name="workflow_tool",
tenant_id="tenant-1",
parameter_configurations=[],
)
app = _app()
user = _account()
db_provider = _db_provider()
session = _mock_session_with_begin()
session.scalar.return_value = db_provider
session.get.side_effect = [app, user]
@ -148,7 +198,7 @@ def test_from_db_builds_controller():
with patch.object(
WorkflowToolProviderController,
"_get_db_provider_tool",
return_value=SimpleNamespace(entity=SimpleNamespace(identity=SimpleNamespace(name="wf"))),
return_value=_workflow_tool("wf"),
):
built = WorkflowToolProviderController.from_db(db_provider)
assert isinstance(built, WorkflowToolProviderController)
@ -157,7 +207,7 @@ def test_from_db_builds_controller():
def test_get_tools_returns_empty_when_provider_missing():
controller = _controller()
controller.tools = None # type: ignore[assignment]
controller.tools = None
with patch("core.tools.workflow_as_tool.provider.db") as mock_db:
mock_db.engine = object()
@ -171,19 +221,8 @@ def test_get_tools_returns_empty_when_provider_missing():
def test_get_tools_raises_when_app_missing():
controller = _controller()
controller.tools = None # type: ignore[assignment]
db_provider = SimpleNamespace(
id="provider-1",
app_id="app-1",
version="1",
user_id="user-1",
label="WF Provider",
description="desc",
icon="icon.svg",
name="workflow_tool",
tenant_id="tenant-1",
parameter_configurations=[],
)
controller.tools = None
db_provider = _db_provider()
with patch("core.tools.workflow_as_tool.provider.db") as mock_db:
mock_db.engine = object()

View File

@ -10,7 +10,7 @@ class DocumentIndexingTaskProxyTestDataFactory:
"""Factory class for creating test data and mock objects for DocumentIndexingTaskProxy tests."""
@staticmethod
def create_mock_features(billing_enabled: bool = False, plan: CloudPlan = CloudPlan.SANDBOX) -> Mock:
def create_mock_features(billing_enabled: bool = False, plan: CloudPlan | str | None = CloudPlan.SANDBOX) -> Mock:
"""Create mock features with billing configuration."""
features = Mock()
features.billing = Mock()

View File

@ -12,7 +12,7 @@ class DuplicateDocumentIndexingTaskProxyTestDataFactory:
"""Factory class for creating test data and mock objects for DuplicateDocumentIndexingTaskProxy tests."""
@staticmethod
def create_mock_features(billing_enabled: bool = False, plan: CloudPlan = CloudPlan.SANDBOX) -> Mock:
def create_mock_features(billing_enabled: bool = False, plan: CloudPlan | str | None = CloudPlan.SANDBOX) -> Mock:
"""Create mock features with billing configuration."""
features = Mock()
features.billing = Mock()