test(api): cover quota patch branches

This commit is contained in:
-LAN- 2026-05-08 19:51:01 +08:00
parent 810aad3ad0
commit 6d79dbba86
4 changed files with 366 additions and 1 deletions

View File

@ -3,7 +3,7 @@
import datetime
import uuid
from types import SimpleNamespace
from unittest.mock import Mock, sentinel
from unittest.mock import Mock, patch, sentinel
import pytest
@ -205,6 +205,38 @@ class TestPluginModelRuntime:
stream=False,
)
def test_invoke_llm_returns_plugin_stream_directly(self) -> None:
client = Mock(spec=PluginModelClient)
stream_result = iter([])
client.invoke_llm.return_value = stream_result
runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client)
result = runtime.invoke_llm(
provider="langgenius/openai/openai",
model="gpt-4o-mini",
credentials={"api_key": "secret"},
model_parameters={"temperature": 0.3},
prompt_messages=[],
tools=None,
stop=("END",),
stream=True,
)
assert result is stream_result
client.invoke_llm.assert_called_once_with(
tenant_id="tenant",
user_id="user",
plugin_id="langgenius/openai",
provider="openai",
model="gpt-4o-mini",
credentials={"api_key": "secret"},
model_parameters={"temperature": 0.3},
prompt_messages=[],
tools=None,
stop=["END"],
stream=True,
)
def test_invoke_llm_rejects_per_call_user_override(self) -> None:
client = Mock(spec=PluginModelClient)
client.invoke_llm.return_value = sentinel.result
@ -297,6 +329,129 @@ def test_get_model_schema_uses_cached_schema_without_hitting_client(monkeypatch:
client.get_model_schema.assert_not_called()
def test_structured_output_adapter_invokes_bound_runtime_streaming() -> None:
runtime = Mock()
runtime.invoke_llm.return_value = sentinel.stream_result
adapter = model_runtime_module._PluginStructuredOutputModelInstance(
runtime=runtime,
provider="langgenius/openai/openai",
model="gpt-4o-mini",
credentials={"api_key": "secret"},
)
tool = Mock()
result = adapter.invoke_llm(
prompt_messages=[],
model_parameters=None,
tools=[tool],
stop=["END"],
stream=True,
callbacks=sentinel.callbacks,
)
assert result is sentinel.stream_result
runtime.invoke_llm.assert_called_once_with(
provider="langgenius/openai/openai",
model="gpt-4o-mini",
credentials={"api_key": "secret"},
model_parameters={},
prompt_messages=[],
tools=[tool],
stop=["END"],
stream=True,
)
def test_structured_output_adapter_invokes_bound_runtime_non_streaming() -> None:
runtime = Mock()
runtime.invoke_llm.return_value = sentinel.result
adapter = model_runtime_module._PluginStructuredOutputModelInstance(
runtime=runtime,
provider="langgenius/openai/openai",
model="gpt-4o-mini",
credentials={"api_key": "secret"},
)
result = adapter.invoke_llm(
prompt_messages=[],
model_parameters={"temperature": 0},
tools=None,
stop=None,
stream=False,
)
assert result is sentinel.result
runtime.invoke_llm.assert_called_once_with(
provider="langgenius/openai/openai",
model="gpt-4o-mini",
credentials={"api_key": "secret"},
model_parameters={"temperature": 0},
prompt_messages=[],
tools=None,
stop=None,
stream=False,
)
def test_invoke_llm_with_structured_output_delegates_with_bound_adapter() -> None:
client = Mock(spec=PluginModelClient)
runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client)
schema = _build_model_schema()
runtime.get_model_schema = Mock(return_value=schema) # type: ignore[method-assign]
with patch.object(
model_runtime_module,
"invoke_llm_with_structured_output_helper",
return_value=sentinel.structured_result,
) as mock_helper:
result = runtime.invoke_llm_with_structured_output(
provider="langgenius/openai/openai",
model="gpt-4o-mini",
credentials={"api_key": "secret"},
json_schema={"type": "object"},
model_parameters={"temperature": 0},
prompt_messages=[],
stop=("END",),
stream=False,
)
assert result is sentinel.structured_result
runtime.get_model_schema.assert_called_once_with(
provider="langgenius/openai/openai",
model_type=ModelType.LLM,
model="gpt-4o-mini",
credentials={"api_key": "secret"},
)
helper_kwargs = mock_helper.call_args.kwargs
assert helper_kwargs["provider"] == "langgenius/openai/openai"
assert helper_kwargs["model_schema"] == schema
assert helper_kwargs["json_schema"] == {"type": "object"}
assert helper_kwargs["model_parameters"] == {"temperature": 0}
assert helper_kwargs["prompt_messages"] == []
assert helper_kwargs["tools"] is None
assert helper_kwargs["stop"] == ["END"]
assert helper_kwargs["stream"] is False
assert isinstance(helper_kwargs["model_instance"], model_runtime_module._PluginStructuredOutputModelInstance)
def test_invoke_llm_with_structured_output_raises_when_model_schema_is_missing() -> None:
client = Mock(spec=PluginModelClient)
runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client)
runtime.get_model_schema = Mock(return_value=None) # type: ignore[method-assign]
with pytest.raises(ValueError, match="Model schema not found for gpt-4o-mini"):
runtime.invoke_llm_with_structured_output(
provider="langgenius/openai/openai",
model="gpt-4o-mini",
credentials={"api_key": "secret"},
json_schema={"type": "object"},
model_parameters={},
prompt_messages=[],
stop=None,
stream=False,
)
def test_get_model_schema_deletes_invalid_cache_and_refetches(monkeypatch: pytest.MonkeyPatch) -> None:
client = Mock(spec=PluginModelClient)
schema = _build_model_schema()

View File

@ -112,6 +112,16 @@ def test_non_llm_node_is_ignored() -> None:
mock_deduct.assert_not_called()
def test_precheck_ignores_non_quota_node() -> None:
layer = LLMQuotaLayer(tenant_id="tenant-id")
node = _build_node(node_type=BuiltinNodeTypes.START)
with patch("core.app.workflow.layers.llm_quota.ensure_llm_quota_available_for_model", autospec=True) as mock_check:
layer.on_node_run_start(node)
mock_check.assert_not_called()
def test_quota_error_is_handled_in_layer(caplog) -> None:
layer = LLMQuotaLayer(tenant_id="tenant-id")
stop_event = threading.Event()
@ -143,6 +153,18 @@ def test_quota_error_is_handled_in_layer(caplog) -> None:
layer.command_channel.send_command.assert_not_called()
def test_send_abort_command_is_noop_without_channel_or_after_abort() -> None:
layer = LLMQuotaLayer(tenant_id="tenant-id")
layer._send_abort_command(reason="no channel")
layer.command_channel = MagicMock()
layer._abort_sent = True
layer._send_abort_command(reason="already aborted")
layer.command_channel.send_command.assert_not_called()
def test_quota_deduction_exceeded_aborts_workflow_immediately() -> None:
layer = LLMQuotaLayer(tenant_id="tenant-id")
stop_event = threading.Event()
@ -250,6 +272,42 @@ def test_quota_precheck_passes_without_abort() -> None:
layer.command_channel.send_command.assert_not_called()
def test_precheck_reads_model_identity_from_data_when_node_data_is_absent() -> None:
layer = LLMQuotaLayer(tenant_id="tenant-id")
node = SimpleNamespace(
id="node-id",
node_type=BuiltinNodeTypes.LLM,
data=_build_node_data(model=_build_public_model_identity(provider="anthropic", model_name="claude")),
)
with patch("core.app.workflow.layers.llm_quota.ensure_llm_quota_available_for_model", autospec=True) as mock_check:
layer.on_node_run_start(node)
mock_check.assert_called_once_with(
tenant_id="tenant-id",
provider="anthropic",
model="claude",
)
def test_precheck_rejects_invalid_public_model_identity() -> None:
layer = LLMQuotaLayer(tenant_id="tenant-id")
stop_event = threading.Event()
layer.command_channel = MagicMock()
node = _build_node(node_type=BuiltinNodeTypes.LLM)
node.node_data = _build_node_data(model=_build_public_model_identity(provider="", model_name="gpt-4o"))
node.graph_runtime_state = MagicMock()
node.graph_runtime_state.stop_event = stop_event
with patch("core.app.workflow.layers.llm_quota.ensure_llm_quota_available_for_model", autospec=True) as mock_check:
layer.on_node_run_start(node)
assert stop_event.is_set()
mock_check.assert_not_called()
layer.command_channel.send_command.assert_called_once()
def test_precheck_requires_public_node_model_config() -> None:
layer = LLMQuotaLayer(tenant_id="tenant-id")
stop_event = threading.Event()

View File

@ -66,3 +66,65 @@ def test_message_created_trial_credit_accounting_does_not_raise_when_balance_is_
quota_used = connection.scalar(select(TenantCreditPool.quota_used).where(TenantCreditPool.id == pool_id))
assert quota_used == 10
def test_message_created_paid_credit_accounting_uses_paid_pool() -> None:
tenant_id = str(uuid4())
system_configuration = SimpleNamespace(
current_quota_type=ProviderQuotaType.PAID,
quota_configurations=[
SimpleNamespace(
quota_type=ProviderQuotaType.PAID,
quota_unit=QuotaUnit.TOKENS,
quota_limit=10,
)
],
)
application_generate_entity = ChatAppGenerateEntity.model_construct(
app_config=SimpleNamespace(tenant_id=tenant_id),
model_conf=SimpleNamespace(
provider="openai",
model="gpt-4o",
provider_model_bundle=SimpleNamespace(
configuration=SimpleNamespace(
using_provider_type=ProviderType.SYSTEM,
system_configuration=system_configuration,
)
),
),
)
message = SimpleNamespace(message_tokens=2, answer_tokens=1)
with (
patch.object(update_provider_when_message_created, "_deduct_credit_pool_quota_capped") as mock_deduct,
patch.object(update_provider_when_message_created, "_execute_provider_updates"),
):
update_provider_when_message_created.handle(
sender=message,
application_generate_entity=application_generate_entity,
)
mock_deduct.assert_called_once_with(
tenant_id=tenant_id,
credits_required=3,
pool_type="paid",
)
def test_capped_credit_pool_accounting_skips_exhaustion_warning_when_full_amount_is_deducted(caplog) -> None:
with patch(
"services.credit_pool_service.CreditPoolService.deduct_credits_capped",
return_value=3,
) as mock_deduct:
update_provider_when_message_created._deduct_credit_pool_quota_capped(
tenant_id="tenant-id",
credits_required=3,
pool_type="trial",
)
mock_deduct.assert_called_once_with(
tenant_id="tenant-id",
credits_required=3,
pool_type="trial",
)
assert "Credit pool exhausted during message-created accounting" not in caplog.text

View File

@ -46,6 +46,33 @@ def test_check_and_deduct_credits_deducts_exact_amount_when_sufficient() -> None
assert _get_quota_used(engine=engine, pool_id=pool_id) == 5
def test_check_and_deduct_credits_returns_zero_for_non_positive_request() -> None:
assert CreditPoolService.check_and_deduct_credits(tenant_id=str(uuid4()), credits_required=0) == 0
def test_check_and_deduct_credits_raises_when_pool_is_missing() -> None:
engine = create_engine("sqlite:///:memory:")
TenantCreditPool.__table__.create(engine)
with (
patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)),
pytest.raises(QuotaExceededError, match="Credit pool not found"),
):
CreditPoolService.check_and_deduct_credits(tenant_id=str(uuid4()), credits_required=1)
def test_check_and_deduct_credits_raises_when_pool_is_empty() -> None:
engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=10)
with (
patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)),
pytest.raises(QuotaExceededError, match="No credits remaining"),
):
CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=1)
assert _get_quota_used(engine=engine, pool_id=pool_id) == 10
def test_check_and_deduct_credits_raises_without_partial_deduction_when_insufficient() -> None:
engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=9)
@ -58,6 +85,43 @@ def test_check_and_deduct_credits_raises_without_partial_deduction_when_insuffic
assert _get_quota_used(engine=engine, pool_id=pool_id) == 9
def test_check_and_deduct_credits_wraps_unexpected_deduction_errors() -> None:
engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=2)
with (
patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)),
patch.object(CreditPoolService, "_get_locked_pool", side_effect=RuntimeError("database unavailable")),
pytest.raises(QuotaExceededError, match="Failed to deduct credits"),
):
CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=1)
assert _get_quota_used(engine=engine, pool_id=pool_id) == 2
def test_deduct_credits_capped_returns_zero_for_non_positive_request() -> None:
assert CreditPoolService.deduct_credits_capped(tenant_id=str(uuid4()), credits_required=0) == 0
def test_deduct_credits_capped_returns_zero_when_pool_is_missing() -> None:
engine = create_engine("sqlite:///:memory:")
TenantCreditPool.__table__.create(engine)
with patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)):
deducted_credits = CreditPoolService.deduct_credits_capped(tenant_id=str(uuid4()), credits_required=1)
assert deducted_credits == 0
def test_deduct_credits_capped_returns_zero_when_pool_is_empty() -> None:
engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=10)
with patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)):
deducted_credits = CreditPoolService.deduct_credits_capped(tenant_id=tenant_id, credits_required=1)
assert deducted_credits == 0
assert _get_quota_used(engine=engine, pool_id=pool_id) == 10
def test_deduct_credits_capped_deducts_only_remaining_balance_when_insufficient() -> None:
engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=9)
@ -66,3 +130,29 @@ def test_deduct_credits_capped_deducts_only_remaining_balance_when_insufficient(
assert deducted_credits == 1
assert _get_quota_used(engine=engine, pool_id=pool_id) == 10
def test_deduct_credits_capped_wraps_unexpected_deduction_errors() -> None:
engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=2)
with (
patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)),
patch.object(CreditPoolService, "_get_locked_pool", side_effect=RuntimeError("database unavailable")),
pytest.raises(QuotaExceededError, match="Failed to deduct credits"),
):
CreditPoolService.deduct_credits_capped(tenant_id=tenant_id, credits_required=1)
assert _get_quota_used(engine=engine, pool_id=pool_id) == 2
def test_deduct_credits_capped_reraises_quota_exceeded_errors() -> None:
engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=2)
with (
patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)),
patch.object(CreditPoolService, "_get_locked_pool", side_effect=QuotaExceededError("quota unavailable")),
pytest.raises(QuotaExceededError, match="quota unavailable"),
):
CreditPoolService.deduct_credits_capped(tenant_id=tenant_id, credits_required=1)
assert _get_quota_used(engine=engine, pool_id=pool_id) == 2