mirror of
https://github.com/langgenius/dify.git
synced 2026-05-09 04:36:31 +08:00
test(api): cover quota patch branches
This commit is contained in:
parent
810aad3ad0
commit
6d79dbba86
@ -3,7 +3,7 @@
|
||||
import datetime
|
||||
import uuid
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock, sentinel
|
||||
from unittest.mock import Mock, patch, sentinel
|
||||
|
||||
import pytest
|
||||
|
||||
@ -205,6 +205,38 @@ class TestPluginModelRuntime:
|
||||
stream=False,
|
||||
)
|
||||
|
||||
def test_invoke_llm_returns_plugin_stream_directly(self) -> None:
|
||||
client = Mock(spec=PluginModelClient)
|
||||
stream_result = iter([])
|
||||
client.invoke_llm.return_value = stream_result
|
||||
runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client)
|
||||
|
||||
result = runtime.invoke_llm(
|
||||
provider="langgenius/openai/openai",
|
||||
model="gpt-4o-mini",
|
||||
credentials={"api_key": "secret"},
|
||||
model_parameters={"temperature": 0.3},
|
||||
prompt_messages=[],
|
||||
tools=None,
|
||||
stop=("END",),
|
||||
stream=True,
|
||||
)
|
||||
|
||||
assert result is stream_result
|
||||
client.invoke_llm.assert_called_once_with(
|
||||
tenant_id="tenant",
|
||||
user_id="user",
|
||||
plugin_id="langgenius/openai",
|
||||
provider="openai",
|
||||
model="gpt-4o-mini",
|
||||
credentials={"api_key": "secret"},
|
||||
model_parameters={"temperature": 0.3},
|
||||
prompt_messages=[],
|
||||
tools=None,
|
||||
stop=["END"],
|
||||
stream=True,
|
||||
)
|
||||
|
||||
def test_invoke_llm_rejects_per_call_user_override(self) -> None:
|
||||
client = Mock(spec=PluginModelClient)
|
||||
client.invoke_llm.return_value = sentinel.result
|
||||
@ -297,6 +329,129 @@ def test_get_model_schema_uses_cached_schema_without_hitting_client(monkeypatch:
|
||||
client.get_model_schema.assert_not_called()
|
||||
|
||||
|
||||
def test_structured_output_adapter_invokes_bound_runtime_streaming() -> None:
|
||||
runtime = Mock()
|
||||
runtime.invoke_llm.return_value = sentinel.stream_result
|
||||
adapter = model_runtime_module._PluginStructuredOutputModelInstance(
|
||||
runtime=runtime,
|
||||
provider="langgenius/openai/openai",
|
||||
model="gpt-4o-mini",
|
||||
credentials={"api_key": "secret"},
|
||||
)
|
||||
tool = Mock()
|
||||
|
||||
result = adapter.invoke_llm(
|
||||
prompt_messages=[],
|
||||
model_parameters=None,
|
||||
tools=[tool],
|
||||
stop=["END"],
|
||||
stream=True,
|
||||
callbacks=sentinel.callbacks,
|
||||
)
|
||||
|
||||
assert result is sentinel.stream_result
|
||||
runtime.invoke_llm.assert_called_once_with(
|
||||
provider="langgenius/openai/openai",
|
||||
model="gpt-4o-mini",
|
||||
credentials={"api_key": "secret"},
|
||||
model_parameters={},
|
||||
prompt_messages=[],
|
||||
tools=[tool],
|
||||
stop=["END"],
|
||||
stream=True,
|
||||
)
|
||||
|
||||
|
||||
def test_structured_output_adapter_invokes_bound_runtime_non_streaming() -> None:
|
||||
runtime = Mock()
|
||||
runtime.invoke_llm.return_value = sentinel.result
|
||||
adapter = model_runtime_module._PluginStructuredOutputModelInstance(
|
||||
runtime=runtime,
|
||||
provider="langgenius/openai/openai",
|
||||
model="gpt-4o-mini",
|
||||
credentials={"api_key": "secret"},
|
||||
)
|
||||
|
||||
result = adapter.invoke_llm(
|
||||
prompt_messages=[],
|
||||
model_parameters={"temperature": 0},
|
||||
tools=None,
|
||||
stop=None,
|
||||
stream=False,
|
||||
)
|
||||
|
||||
assert result is sentinel.result
|
||||
runtime.invoke_llm.assert_called_once_with(
|
||||
provider="langgenius/openai/openai",
|
||||
model="gpt-4o-mini",
|
||||
credentials={"api_key": "secret"},
|
||||
model_parameters={"temperature": 0},
|
||||
prompt_messages=[],
|
||||
tools=None,
|
||||
stop=None,
|
||||
stream=False,
|
||||
)
|
||||
|
||||
|
||||
def test_invoke_llm_with_structured_output_delegates_with_bound_adapter() -> None:
|
||||
client = Mock(spec=PluginModelClient)
|
||||
runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client)
|
||||
schema = _build_model_schema()
|
||||
runtime.get_model_schema = Mock(return_value=schema) # type: ignore[method-assign]
|
||||
|
||||
with patch.object(
|
||||
model_runtime_module,
|
||||
"invoke_llm_with_structured_output_helper",
|
||||
return_value=sentinel.structured_result,
|
||||
) as mock_helper:
|
||||
result = runtime.invoke_llm_with_structured_output(
|
||||
provider="langgenius/openai/openai",
|
||||
model="gpt-4o-mini",
|
||||
credentials={"api_key": "secret"},
|
||||
json_schema={"type": "object"},
|
||||
model_parameters={"temperature": 0},
|
||||
prompt_messages=[],
|
||||
stop=("END",),
|
||||
stream=False,
|
||||
)
|
||||
|
||||
assert result is sentinel.structured_result
|
||||
runtime.get_model_schema.assert_called_once_with(
|
||||
provider="langgenius/openai/openai",
|
||||
model_type=ModelType.LLM,
|
||||
model="gpt-4o-mini",
|
||||
credentials={"api_key": "secret"},
|
||||
)
|
||||
helper_kwargs = mock_helper.call_args.kwargs
|
||||
assert helper_kwargs["provider"] == "langgenius/openai/openai"
|
||||
assert helper_kwargs["model_schema"] == schema
|
||||
assert helper_kwargs["json_schema"] == {"type": "object"}
|
||||
assert helper_kwargs["model_parameters"] == {"temperature": 0}
|
||||
assert helper_kwargs["prompt_messages"] == []
|
||||
assert helper_kwargs["tools"] is None
|
||||
assert helper_kwargs["stop"] == ["END"]
|
||||
assert helper_kwargs["stream"] is False
|
||||
assert isinstance(helper_kwargs["model_instance"], model_runtime_module._PluginStructuredOutputModelInstance)
|
||||
|
||||
|
||||
def test_invoke_llm_with_structured_output_raises_when_model_schema_is_missing() -> None:
|
||||
client = Mock(spec=PluginModelClient)
|
||||
runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client)
|
||||
runtime.get_model_schema = Mock(return_value=None) # type: ignore[method-assign]
|
||||
|
||||
with pytest.raises(ValueError, match="Model schema not found for gpt-4o-mini"):
|
||||
runtime.invoke_llm_with_structured_output(
|
||||
provider="langgenius/openai/openai",
|
||||
model="gpt-4o-mini",
|
||||
credentials={"api_key": "secret"},
|
||||
json_schema={"type": "object"},
|
||||
model_parameters={},
|
||||
prompt_messages=[],
|
||||
stop=None,
|
||||
stream=False,
|
||||
)
|
||||
|
||||
|
||||
def test_get_model_schema_deletes_invalid_cache_and_refetches(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
client = Mock(spec=PluginModelClient)
|
||||
schema = _build_model_schema()
|
||||
|
||||
@ -112,6 +112,16 @@ def test_non_llm_node_is_ignored() -> None:
|
||||
mock_deduct.assert_not_called()
|
||||
|
||||
|
||||
def test_precheck_ignores_non_quota_node() -> None:
|
||||
layer = LLMQuotaLayer(tenant_id="tenant-id")
|
||||
node = _build_node(node_type=BuiltinNodeTypes.START)
|
||||
|
||||
with patch("core.app.workflow.layers.llm_quota.ensure_llm_quota_available_for_model", autospec=True) as mock_check:
|
||||
layer.on_node_run_start(node)
|
||||
|
||||
mock_check.assert_not_called()
|
||||
|
||||
|
||||
def test_quota_error_is_handled_in_layer(caplog) -> None:
|
||||
layer = LLMQuotaLayer(tenant_id="tenant-id")
|
||||
stop_event = threading.Event()
|
||||
@ -143,6 +153,18 @@ def test_quota_error_is_handled_in_layer(caplog) -> None:
|
||||
layer.command_channel.send_command.assert_not_called()
|
||||
|
||||
|
||||
def test_send_abort_command_is_noop_without_channel_or_after_abort() -> None:
|
||||
layer = LLMQuotaLayer(tenant_id="tenant-id")
|
||||
|
||||
layer._send_abort_command(reason="no channel")
|
||||
|
||||
layer.command_channel = MagicMock()
|
||||
layer._abort_sent = True
|
||||
layer._send_abort_command(reason="already aborted")
|
||||
|
||||
layer.command_channel.send_command.assert_not_called()
|
||||
|
||||
|
||||
def test_quota_deduction_exceeded_aborts_workflow_immediately() -> None:
|
||||
layer = LLMQuotaLayer(tenant_id="tenant-id")
|
||||
stop_event = threading.Event()
|
||||
@ -250,6 +272,42 @@ def test_quota_precheck_passes_without_abort() -> None:
|
||||
layer.command_channel.send_command.assert_not_called()
|
||||
|
||||
|
||||
def test_precheck_reads_model_identity_from_data_when_node_data_is_absent() -> None:
|
||||
layer = LLMQuotaLayer(tenant_id="tenant-id")
|
||||
node = SimpleNamespace(
|
||||
id="node-id",
|
||||
node_type=BuiltinNodeTypes.LLM,
|
||||
data=_build_node_data(model=_build_public_model_identity(provider="anthropic", model_name="claude")),
|
||||
)
|
||||
|
||||
with patch("core.app.workflow.layers.llm_quota.ensure_llm_quota_available_for_model", autospec=True) as mock_check:
|
||||
layer.on_node_run_start(node)
|
||||
|
||||
mock_check.assert_called_once_with(
|
||||
tenant_id="tenant-id",
|
||||
provider="anthropic",
|
||||
model="claude",
|
||||
)
|
||||
|
||||
|
||||
def test_precheck_rejects_invalid_public_model_identity() -> None:
|
||||
layer = LLMQuotaLayer(tenant_id="tenant-id")
|
||||
stop_event = threading.Event()
|
||||
layer.command_channel = MagicMock()
|
||||
|
||||
node = _build_node(node_type=BuiltinNodeTypes.LLM)
|
||||
node.node_data = _build_node_data(model=_build_public_model_identity(provider="", model_name="gpt-4o"))
|
||||
node.graph_runtime_state = MagicMock()
|
||||
node.graph_runtime_state.stop_event = stop_event
|
||||
|
||||
with patch("core.app.workflow.layers.llm_quota.ensure_llm_quota_available_for_model", autospec=True) as mock_check:
|
||||
layer.on_node_run_start(node)
|
||||
|
||||
assert stop_event.is_set()
|
||||
mock_check.assert_not_called()
|
||||
layer.command_channel.send_command.assert_called_once()
|
||||
|
||||
|
||||
def test_precheck_requires_public_node_model_config() -> None:
|
||||
layer = LLMQuotaLayer(tenant_id="tenant-id")
|
||||
stop_event = threading.Event()
|
||||
|
||||
@ -66,3 +66,65 @@ def test_message_created_trial_credit_accounting_does_not_raise_when_balance_is_
|
||||
quota_used = connection.scalar(select(TenantCreditPool.quota_used).where(TenantCreditPool.id == pool_id))
|
||||
|
||||
assert quota_used == 10
|
||||
|
||||
|
||||
def test_message_created_paid_credit_accounting_uses_paid_pool() -> None:
|
||||
tenant_id = str(uuid4())
|
||||
system_configuration = SimpleNamespace(
|
||||
current_quota_type=ProviderQuotaType.PAID,
|
||||
quota_configurations=[
|
||||
SimpleNamespace(
|
||||
quota_type=ProviderQuotaType.PAID,
|
||||
quota_unit=QuotaUnit.TOKENS,
|
||||
quota_limit=10,
|
||||
)
|
||||
],
|
||||
)
|
||||
application_generate_entity = ChatAppGenerateEntity.model_construct(
|
||||
app_config=SimpleNamespace(tenant_id=tenant_id),
|
||||
model_conf=SimpleNamespace(
|
||||
provider="openai",
|
||||
model="gpt-4o",
|
||||
provider_model_bundle=SimpleNamespace(
|
||||
configuration=SimpleNamespace(
|
||||
using_provider_type=ProviderType.SYSTEM,
|
||||
system_configuration=system_configuration,
|
||||
)
|
||||
),
|
||||
),
|
||||
)
|
||||
message = SimpleNamespace(message_tokens=2, answer_tokens=1)
|
||||
|
||||
with (
|
||||
patch.object(update_provider_when_message_created, "_deduct_credit_pool_quota_capped") as mock_deduct,
|
||||
patch.object(update_provider_when_message_created, "_execute_provider_updates"),
|
||||
):
|
||||
update_provider_when_message_created.handle(
|
||||
sender=message,
|
||||
application_generate_entity=application_generate_entity,
|
||||
)
|
||||
|
||||
mock_deduct.assert_called_once_with(
|
||||
tenant_id=tenant_id,
|
||||
credits_required=3,
|
||||
pool_type="paid",
|
||||
)
|
||||
|
||||
|
||||
def test_capped_credit_pool_accounting_skips_exhaustion_warning_when_full_amount_is_deducted(caplog) -> None:
|
||||
with patch(
|
||||
"services.credit_pool_service.CreditPoolService.deduct_credits_capped",
|
||||
return_value=3,
|
||||
) as mock_deduct:
|
||||
update_provider_when_message_created._deduct_credit_pool_quota_capped(
|
||||
tenant_id="tenant-id",
|
||||
credits_required=3,
|
||||
pool_type="trial",
|
||||
)
|
||||
|
||||
mock_deduct.assert_called_once_with(
|
||||
tenant_id="tenant-id",
|
||||
credits_required=3,
|
||||
pool_type="trial",
|
||||
)
|
||||
assert "Credit pool exhausted during message-created accounting" not in caplog.text
|
||||
|
||||
@ -46,6 +46,33 @@ def test_check_and_deduct_credits_deducts_exact_amount_when_sufficient() -> None
|
||||
assert _get_quota_used(engine=engine, pool_id=pool_id) == 5
|
||||
|
||||
|
||||
def test_check_and_deduct_credits_returns_zero_for_non_positive_request() -> None:
|
||||
assert CreditPoolService.check_and_deduct_credits(tenant_id=str(uuid4()), credits_required=0) == 0
|
||||
|
||||
|
||||
def test_check_and_deduct_credits_raises_when_pool_is_missing() -> None:
|
||||
engine = create_engine("sqlite:///:memory:")
|
||||
TenantCreditPool.__table__.create(engine)
|
||||
|
||||
with (
|
||||
patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)),
|
||||
pytest.raises(QuotaExceededError, match="Credit pool not found"),
|
||||
):
|
||||
CreditPoolService.check_and_deduct_credits(tenant_id=str(uuid4()), credits_required=1)
|
||||
|
||||
|
||||
def test_check_and_deduct_credits_raises_when_pool_is_empty() -> None:
|
||||
engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=10)
|
||||
|
||||
with (
|
||||
patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)),
|
||||
pytest.raises(QuotaExceededError, match="No credits remaining"),
|
||||
):
|
||||
CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=1)
|
||||
|
||||
assert _get_quota_used(engine=engine, pool_id=pool_id) == 10
|
||||
|
||||
|
||||
def test_check_and_deduct_credits_raises_without_partial_deduction_when_insufficient() -> None:
|
||||
engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=9)
|
||||
|
||||
@ -58,6 +85,43 @@ def test_check_and_deduct_credits_raises_without_partial_deduction_when_insuffic
|
||||
assert _get_quota_used(engine=engine, pool_id=pool_id) == 9
|
||||
|
||||
|
||||
def test_check_and_deduct_credits_wraps_unexpected_deduction_errors() -> None:
|
||||
engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=2)
|
||||
|
||||
with (
|
||||
patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)),
|
||||
patch.object(CreditPoolService, "_get_locked_pool", side_effect=RuntimeError("database unavailable")),
|
||||
pytest.raises(QuotaExceededError, match="Failed to deduct credits"),
|
||||
):
|
||||
CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=1)
|
||||
|
||||
assert _get_quota_used(engine=engine, pool_id=pool_id) == 2
|
||||
|
||||
|
||||
def test_deduct_credits_capped_returns_zero_for_non_positive_request() -> None:
|
||||
assert CreditPoolService.deduct_credits_capped(tenant_id=str(uuid4()), credits_required=0) == 0
|
||||
|
||||
|
||||
def test_deduct_credits_capped_returns_zero_when_pool_is_missing() -> None:
|
||||
engine = create_engine("sqlite:///:memory:")
|
||||
TenantCreditPool.__table__.create(engine)
|
||||
|
||||
with patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)):
|
||||
deducted_credits = CreditPoolService.deduct_credits_capped(tenant_id=str(uuid4()), credits_required=1)
|
||||
|
||||
assert deducted_credits == 0
|
||||
|
||||
|
||||
def test_deduct_credits_capped_returns_zero_when_pool_is_empty() -> None:
|
||||
engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=10)
|
||||
|
||||
with patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)):
|
||||
deducted_credits = CreditPoolService.deduct_credits_capped(tenant_id=tenant_id, credits_required=1)
|
||||
|
||||
assert deducted_credits == 0
|
||||
assert _get_quota_used(engine=engine, pool_id=pool_id) == 10
|
||||
|
||||
|
||||
def test_deduct_credits_capped_deducts_only_remaining_balance_when_insufficient() -> None:
|
||||
engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=9)
|
||||
|
||||
@ -66,3 +130,29 @@ def test_deduct_credits_capped_deducts_only_remaining_balance_when_insufficient(
|
||||
|
||||
assert deducted_credits == 1
|
||||
assert _get_quota_used(engine=engine, pool_id=pool_id) == 10
|
||||
|
||||
|
||||
def test_deduct_credits_capped_wraps_unexpected_deduction_errors() -> None:
|
||||
engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=2)
|
||||
|
||||
with (
|
||||
patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)),
|
||||
patch.object(CreditPoolService, "_get_locked_pool", side_effect=RuntimeError("database unavailable")),
|
||||
pytest.raises(QuotaExceededError, match="Failed to deduct credits"),
|
||||
):
|
||||
CreditPoolService.deduct_credits_capped(tenant_id=tenant_id, credits_required=1)
|
||||
|
||||
assert _get_quota_used(engine=engine, pool_id=pool_id) == 2
|
||||
|
||||
|
||||
def test_deduct_credits_capped_reraises_quota_exceeded_errors() -> None:
|
||||
engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=2)
|
||||
|
||||
with (
|
||||
patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)),
|
||||
patch.object(CreditPoolService, "_get_locked_pool", side_effect=QuotaExceededError("quota unavailable")),
|
||||
pytest.raises(QuotaExceededError, match="quota unavailable"),
|
||||
):
|
||||
CreditPoolService.deduct_credits_capped(tenant_id=tenant_id, credits_required=1)
|
||||
|
||||
assert _get_quota_used(engine=engine, pool_id=pool_id) == 2
|
||||
|
||||
Loading…
Reference in New Issue
Block a user