mirror of
https://github.com/langgenius/dify.git
synced 2026-05-13 08:57:28 +08:00
refactor: replace bare dict with dict[str, Any] in controller and core unit tests (#35181)
This commit is contained in:
parent
1bcc7f78c7
commit
50a6892c3a
@ -2,6 +2,7 @@
|
||||
Unit tests for inner_api plugin decorators
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
@ -232,11 +233,11 @@ class TestGetUserTenant:
|
||||
class PluginTestPayload:
|
||||
"""Simple test payload class"""
|
||||
|
||||
def __init__(self, data: dict):
|
||||
def __init__(self, data: dict[str, Any]):
|
||||
self.value = data.get("value")
|
||||
|
||||
@classmethod
|
||||
def model_validate(cls, data: dict):
|
||||
def model_validate(cls, data: dict[str, Any]):
|
||||
return cls(data)
|
||||
|
||||
|
||||
@ -277,7 +278,7 @@ class TestPluginData:
|
||||
# Arrange
|
||||
class InvalidPayload:
|
||||
@classmethod
|
||||
def model_validate(cls, data: dict):
|
||||
def model_validate(cls, data: dict[str, Any]):
|
||||
raise Exception("Validation failed")
|
||||
|
||||
@plugin_data(payload_type=InvalidPayload)
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
from json.decoder import JSONDecodeError
|
||||
from typing import Any
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
@ -259,8 +260,8 @@ def test_parse_openapi_to_tool_bundle_server_env_and_refs(app):
|
||||
},
|
||||
}
|
||||
|
||||
extra_info: dict = {}
|
||||
warning: dict = {}
|
||||
extra_info: dict[str, Any] = {}
|
||||
warning: dict[str, Any] = {}
|
||||
with app.test_request_context(headers={"X-Request-Env": "prod"}):
|
||||
bundles = ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi, extra_info=extra_info, warning=warning)
|
||||
|
||||
@ -298,7 +299,7 @@ def test_parse_swagger_to_openapi_branches():
|
||||
}
|
||||
)
|
||||
|
||||
warning: dict = {"seed": True}
|
||||
warning: dict[str, Any] = {"seed": True}
|
||||
converted = ApiBasedToolSchemaParser.parse_swagger_to_openapi(
|
||||
{
|
||||
"servers": [{"url": "https://x"}],
|
||||
|
||||
@ -8,6 +8,7 @@ and select_trigger_debug_events orchestrator.
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
@ -30,7 +31,7 @@ from core.trigger.debug.events import PluginTriggerDebugEvent, WebhookDebugEvent
|
||||
from tests.unit_tests.core.trigger.conftest import VALID_PROVIDER_ID
|
||||
|
||||
|
||||
def _make_poller_args(node_config: dict | None = None) -> dict:
|
||||
def _make_poller_args(node_config: dict[str, Any] | None = None) -> dict[str, Any]:
|
||||
return {
|
||||
"tenant_id": "t1",
|
||||
"user_id": "u1",
|
||||
|
||||
@ -6,6 +6,7 @@ to FileVariable objects, fixing the "Invalid variable type: ObjectVariable" erro
|
||||
when passing files to downstream LLM nodes.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from graphon.entities import GraphInitParams
|
||||
@ -97,7 +98,7 @@ def create_test_file_dict(
|
||||
}
|
||||
|
||||
|
||||
def build_webhook_variable_pool(inputs: dict) -> VariablePool:
|
||||
def build_webhook_variable_pool(inputs: dict[str, Any]) -> VariablePool:
|
||||
return build_test_variable_pool(
|
||||
variables=default_system_variables(),
|
||||
node_id="webhook-node-1",
|
||||
@ -105,7 +106,7 @@ def build_webhook_variable_pool(inputs: dict) -> VariablePool:
|
||||
)
|
||||
|
||||
|
||||
def expected_factory_mapping(file_dict: dict) -> dict:
|
||||
def expected_factory_mapping(file_dict: dict[str, Any]) -> dict[str, Any]:
|
||||
return {**file_dict, "upload_file_id": file_dict["related_id"]}
|
||||
|
||||
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
@ -62,7 +63,7 @@ def create_webhook_node(webhook_data: WebhookData, variable_pool: VariablePool)
|
||||
return node
|
||||
|
||||
|
||||
def build_webhook_variable_pool(inputs: dict) -> VariablePool:
|
||||
def build_webhook_variable_pool(inputs: dict[str, Any]) -> VariablePool:
|
||||
return build_test_variable_pool(
|
||||
variables=default_system_variables(),
|
||||
node_id="1",
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import cast
|
||||
from typing import Any, cast
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
@ -30,7 +30,7 @@ class FakeStreamsRedis:
|
||||
self._dollar_snapshots: dict[str, int] = {}
|
||||
|
||||
# Publisher API
|
||||
def xadd(self, key: str, fields: dict, *, maxlen: int | None = None) -> str:
|
||||
def xadd(self, key: str, fields: dict[str, Any], *, maxlen: int | None = None) -> str:
|
||||
"""Append entry to stream; accept optional maxlen for API compatibility.
|
||||
|
||||
The test double ignores maxlen trimming semantics; only records the entry.
|
||||
@ -45,7 +45,7 @@ class FakeStreamsRedis:
|
||||
self._expire_calls[key] = self._expire_calls.get(key, 0) + 1
|
||||
|
||||
# Consumer API
|
||||
def xread(self, streams: dict, block: int | None = None, count: int | None = None):
|
||||
def xread(self, streams: dict[str, Any], block: int | None = None, count: int | None = None):
|
||||
# Expect a single key
|
||||
assert len(streams) == 1
|
||||
key, last_id = next(iter(streams.items()))
|
||||
@ -80,7 +80,7 @@ class BlockingRedis:
|
||||
def __init__(self) -> None:
|
||||
self._release = threading.Event()
|
||||
|
||||
def xread(self, streams: dict, block: int | None = None, count: int | None = None):
|
||||
def xread(self, streams: dict[str, Any], block: int | None = None, count: int | None = None):
|
||||
self._release.wait(timeout=block / 1000.0 if block else None)
|
||||
return []
|
||||
|
||||
@ -245,7 +245,7 @@ class TestStreamsSubscription:
|
||||
self._fields = fields
|
||||
self._calls = 0
|
||||
|
||||
def xread(self, streams: dict, block: int | None = None, count: int | None = None):
|
||||
def xread(self, streams: dict[str, Any], block: int | None = None, count: int | None = None):
|
||||
self._calls += 1
|
||||
if self._calls == 1:
|
||||
key = next(iter(streams))
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import json
|
||||
import uuid
|
||||
from collections import defaultdict, deque
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
@ -60,7 +61,7 @@ class _FakeStreams:
|
||||
self._data: dict[str, list[tuple[str, dict]]] = defaultdict(list)
|
||||
self._seq: dict[str, int] = defaultdict(int)
|
||||
|
||||
def xadd(self, key: str, fields: dict, *, maxlen: int | None = None) -> str:
|
||||
def xadd(self, key: str, fields: dict[str, Any], *, maxlen: int | None = None) -> str:
|
||||
# maxlen is accepted for API compatibility with redis-py; ignored in this test double
|
||||
self._seq[key] += 1
|
||||
eid = f"{self._seq[key]}-0"
|
||||
@ -71,7 +72,7 @@ class _FakeStreams:
|
||||
# no-op for tests
|
||||
return None
|
||||
|
||||
def xread(self, streams: dict, block: int | None = None, count: int | None = None):
|
||||
def xread(self, streams: dict[str, Any], block: int | None = None, count: int | None = None):
|
||||
assert len(streams) == 1
|
||||
key, last_id = next(iter(streams.items()))
|
||||
entries = self._data.get(key, [])
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import datetime
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
@ -18,7 +19,7 @@ def make_simple_message(msg_id: str, app_id: str) -> SimpleMessage:
|
||||
return SimpleMessage(id=msg_id, app_id=app_id, created_at=datetime.datetime(2024, 1, 1))
|
||||
|
||||
|
||||
def make_plan_provider(tenant_plans: dict) -> MagicMock:
|
||||
def make_plan_provider(tenant_plans: dict[str, Any]) -> MagicMock:
|
||||
"""Helper to create a mock plan_provider that returns the given tenant_plans."""
|
||||
provider = MagicMock()
|
||||
provider.return_value = tenant_plans
|
||||
|
||||
Loading…
Reference in New Issue
Block a user