refactor: replace bare dict with dict[str, Any] in controller and core unit tests (#35181)

This commit is contained in:
wdeveloper16 2026-04-14 19:51:49 +02:00 committed by GitHub
parent 1bcc7f78c7
commit 50a6892c3a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 25 additions and 18 deletions

View File

@ -2,6 +2,7 @@
Unit tests for inner_api plugin decorators
"""
from typing import Any
from unittest.mock import MagicMock, patch
import pytest
@ -232,11 +233,11 @@ class TestGetUserTenant:
class PluginTestPayload:
"""Simple test payload class"""
def __init__(self, data: dict):
def __init__(self, data: dict[str, Any]):
self.value = data.get("value")
@classmethod
def model_validate(cls, data: dict):
def model_validate(cls, data: dict[str, Any]):
return cls(data)
@ -277,7 +278,7 @@ class TestPluginData:
# Arrange
class InvalidPayload:
@classmethod
def model_validate(cls, data: dict):
def model_validate(cls, data: dict[str, Any]):
raise Exception("Validation failed")
@plugin_data(payload_type=InvalidPayload)

View File

@ -1,4 +1,5 @@
from json.decoder import JSONDecodeError
from typing import Any
from unittest.mock import Mock, patch
import pytest
@ -259,8 +260,8 @@ def test_parse_openapi_to_tool_bundle_server_env_and_refs(app):
},
}
extra_info: dict = {}
warning: dict = {}
extra_info: dict[str, Any] = {}
warning: dict[str, Any] = {}
with app.test_request_context(headers={"X-Request-Env": "prod"}):
bundles = ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi, extra_info=extra_info, warning=warning)
@ -298,7 +299,7 @@ def test_parse_swagger_to_openapi_branches():
}
)
warning: dict = {"seed": True}
warning: dict[str, Any] = {"seed": True}
converted = ApiBasedToolSchemaParser.parse_swagger_to_openapi(
{
"servers": [{"url": "https://x"}],

View File

@ -8,6 +8,7 @@ and select_trigger_debug_events orchestrator.
from __future__ import annotations
from datetime import datetime
from typing import Any
from unittest.mock import MagicMock, patch
import pytest
@ -30,7 +31,7 @@ from core.trigger.debug.events import PluginTriggerDebugEvent, WebhookDebugEvent
from tests.unit_tests.core.trigger.conftest import VALID_PROVIDER_ID
def _make_poller_args(node_config: dict | None = None) -> dict:
def _make_poller_args(node_config: dict[str, Any] | None = None) -> dict[str, Any]:
return {
"tenant_id": "t1",
"user_id": "u1",

View File

@ -6,6 +6,7 @@ to FileVariable objects, fixing the "Invalid variable type: ObjectVariable" erro
when passing files to downstream LLM nodes.
"""
from typing import Any
from unittest.mock import Mock, patch
from graphon.entities import GraphInitParams
@ -97,7 +98,7 @@ def create_test_file_dict(
}
def build_webhook_variable_pool(inputs: dict) -> VariablePool:
def build_webhook_variable_pool(inputs: dict[str, Any]) -> VariablePool:
return build_test_variable_pool(
variables=default_system_variables(),
node_id="webhook-node-1",
@ -105,7 +106,7 @@ def build_webhook_variable_pool(inputs: dict) -> VariablePool:
)
def expected_factory_mapping(file_dict: dict) -> dict:
def expected_factory_mapping(file_dict: dict[str, Any]) -> dict[str, Any]:
return {**file_dict, "upload_file_id": file_dict["related_id"]}

View File

@ -1,3 +1,4 @@
from typing import Any
from unittest.mock import patch
import pytest
@ -62,7 +63,7 @@ def create_webhook_node(webhook_data: WebhookData, variable_pool: VariablePool)
return node
def build_webhook_variable_pool(inputs: dict) -> VariablePool:
def build_webhook_variable_pool(inputs: dict[str, Any]) -> VariablePool:
return build_test_variable_pool(
variables=default_system_variables(),
node_id="1",

View File

@ -1,7 +1,7 @@
import threading
import time
from dataclasses import dataclass
from typing import cast
from typing import Any, cast
from unittest.mock import patch
import pytest
@ -30,7 +30,7 @@ class FakeStreamsRedis:
self._dollar_snapshots: dict[str, int] = {}
# Publisher API
def xadd(self, key: str, fields: dict, *, maxlen: int | None = None) -> str:
def xadd(self, key: str, fields: dict[str, Any], *, maxlen: int | None = None) -> str:
"""Append entry to stream; accept optional maxlen for API compatibility.
The test double ignores maxlen trimming semantics; only records the entry.
@ -45,7 +45,7 @@ class FakeStreamsRedis:
self._expire_calls[key] = self._expire_calls.get(key, 0) + 1
# Consumer API
def xread(self, streams: dict, block: int | None = None, count: int | None = None):
def xread(self, streams: dict[str, Any], block: int | None = None, count: int | None = None):
# Expect a single key
assert len(streams) == 1
key, last_id = next(iter(streams.items()))
@ -80,7 +80,7 @@ class BlockingRedis:
def __init__(self) -> None:
self._release = threading.Event()
def xread(self, streams: dict, block: int | None = None, count: int | None = None):
def xread(self, streams: dict[str, Any], block: int | None = None, count: int | None = None):
self._release.wait(timeout=block / 1000.0 if block else None)
return []
@ -245,7 +245,7 @@ class TestStreamsSubscription:
self._fields = fields
self._calls = 0
def xread(self, streams: dict, block: int | None = None, count: int | None = None):
def xread(self, streams: dict[str, Any], block: int | None = None, count: int | None = None):
self._calls += 1
if self._calls == 1:
key = next(iter(streams))

View File

@ -1,6 +1,7 @@
import json
import uuid
from collections import defaultdict, deque
from typing import Any
import pytest
@ -60,7 +61,7 @@ class _FakeStreams:
self._data: dict[str, list[tuple[str, dict]]] = defaultdict(list)
self._seq: dict[str, int] = defaultdict(int)
def xadd(self, key: str, fields: dict, *, maxlen: int | None = None) -> str:
def xadd(self, key: str, fields: dict[str, Any], *, maxlen: int | None = None) -> str:
# maxlen is accepted for API compatibility with redis-py; ignored in this test double
self._seq[key] += 1
eid = f"{self._seq[key]}-0"
@ -71,7 +72,7 @@ class _FakeStreams:
# no-op for tests
return None
def xread(self, streams: dict, block: int | None = None, count: int | None = None):
def xread(self, streams: dict[str, Any], block: int | None = None, count: int | None = None):
assert len(streams) == 1
key, last_id = next(iter(streams.items()))
entries = self._data.get(key, [])

View File

@ -1,4 +1,5 @@
import datetime
from typing import Any
from unittest.mock import MagicMock, patch
import pytest
@ -18,7 +19,7 @@ def make_simple_message(msg_id: str, app_id: str) -> SimpleMessage:
return SimpleMessage(id=msg_id, app_id=app_id, created_at=datetime.datetime(2024, 1, 1))
def make_plan_provider(tenant_plans: dict) -> MagicMock:
def make_plan_provider(tenant_plans: dict[str, Any]) -> MagicMock:
"""Helper to create a mock plan_provider that returns the given tenant_plans."""
provider = MagicMock()
provider.return_value = tenant_plans