diff --git a/api/core/plugin/impl/model_runtime_factory.py b/api/core/plugin/impl/model_runtime_factory.py index 98a5660fdf..fbe307ea60 100644 --- a/api/core/plugin/impl/model_runtime_factory.py +++ b/api/core/plugin/impl/model_runtime_factory.py @@ -125,20 +125,6 @@ def create_plugin_model_provider_factory(*, tenant_id: str, user_id: str | None return create_plugin_model_assembly(tenant_id=tenant_id, user_id=user_id).model_provider_factory -def create_plugin_model_type_instance( - *, - tenant_id: str, - provider: str, - model_type: ModelType, - user_id: str | None = None, -) -> AIModel: - """Create a tenant-bound model wrapper for the requested provider and model type.""" - return create_plugin_model_assembly(tenant_id=tenant_id, user_id=user_id).create_model_type_instance( - provider=provider, - model_type=model_type, - ) - - def create_plugin_provider_manager(*, tenant_id: str, user_id: str | None = None) -> ProviderManager: """Create a tenant-bound provider manager for service flows.""" return create_plugin_model_assembly(tenant_id=tenant_id, user_id=user_id).provider_manager diff --git a/api/tests/unit_tests/core/app/test_llm_quota.py b/api/tests/unit_tests/core/app/test_llm_quota.py index 4d195eff46..de6aa1ec9f 100644 --- a/api/tests/unit_tests/core/app/test_llm_quota.py +++ b/api/tests/unit_tests/core/app/test_llm_quota.py @@ -1,7 +1,8 @@ from types import SimpleNamespace -from unittest.mock import MagicMock, patch, sentinel +from unittest.mock import MagicMock, patch import pytest +from sqlalchemy import create_engine, select from configs import dify_config from core.app.llm.quota import ( @@ -15,7 +16,7 @@ from core.entities.provider_entities import ProviderQuotaType, QuotaUnit from core.errors.error import QuotaExceededError from graphon.model_runtime.entities.llm_entities import LLMUsage from graphon.model_runtime.entities.model_entities import ModelType -from models.provider import ProviderType +from models.provider import Provider, ProviderType def test_ensure_llm_quota_available_for_model_raises_when_system_model_is_exhausted() -> None: @@ -271,17 +272,58 @@ def test_deduct_llm_quota_for_model_updates_free_quota_usage() -> None: ) provider_manager = MagicMock() provider_manager.get_configurations.return_value.get.return_value = provider_configuration - session = MagicMock() - session_context = MagicMock() - session_context.__enter__.return_value = session - session_context.__exit__.return_value = False - session_factory = MagicMock() - session_factory.begin.return_value = session_context + engine = create_engine("sqlite:///:memory:") + Provider.__table__.create(engine) + with engine.begin() as connection: + connection.execute( + Provider.__table__.insert(), + [ + { + "id": "matching-provider", + "tenant_id": "tenant-id", + "provider_name": "openai", + "provider_type": ProviderType.SYSTEM, + "quota_type": ProviderQuotaType.FREE, + "quota_limit": 100, + "quota_used": 10, + "is_valid": True, + }, + { + "id": "other-tenant", + "tenant_id": "other-tenant-id", + "provider_name": "openai", + "provider_type": ProviderType.SYSTEM, + "quota_type": ProviderQuotaType.FREE, + "quota_limit": 100, + "quota_used": 20, + "is_valid": True, + }, + { + "id": "other-provider", + "tenant_id": "tenant-id", + "provider_name": "anthropic", + "provider_type": ProviderType.SYSTEM, + "quota_type": ProviderQuotaType.FREE, + "quota_limit": 100, + "quota_used": 30, + "is_valid": True, + }, + { + "id": "custom-provider", + "tenant_id": "tenant-id", + "provider_name": "openai", + "provider_type": ProviderType.CUSTOM, + "quota_type": ProviderQuotaType.FREE, + "quota_limit": 100, + "quota_used": 40, + "is_valid": True, + }, + ], + ) with ( patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager), - patch("core.app.llm.quota.db", SimpleNamespace(engine=sentinel.engine)), - patch("core.app.llm.quota.sessionmaker", return_value=session_factory), + patch("core.app.llm.quota.db", SimpleNamespace(engine=engine)), ): deduct_llm_quota_for_model( tenant_id="tenant-id", @@ -290,7 +332,36 @@ def test_deduct_llm_quota_for_model_updates_free_quota_usage() -> None: usage=usage, ) - session.execute.assert_called_once() + with engine.connect() as connection: + quota_used_by_id = dict(connection.execute(select(Provider.id, Provider.quota_used)).all()) + + assert quota_used_by_id == { + "matching-provider": 13, + "other-tenant": 20, + "other-provider": 30, + "custom-provider": 40, + } + + with engine.begin() as connection: + connection.execute( + Provider.__table__.update().where(Provider.id == "matching-provider").values(quota_limit=13, quota_used=13) + ) + + with ( + patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager), + patch("core.app.llm.quota.db", SimpleNamespace(engine=engine)), + ): + deduct_llm_quota_for_model( + tenant_id="tenant-id", + provider="openai", + model="gpt-4o", + usage=usage, + ) + + with engine.connect() as connection: + exhausted_quota_used = connection.scalar(select(Provider.quota_used).where(Provider.id == "matching-provider")) + + assert exhausted_quota_used == 13 def test_deduct_llm_quota_for_model_ignores_unknown_quota_type() -> None: @@ -357,42 +428,6 @@ def test_deduct_llm_quota_for_model_ignores_custom_provider_configuration() -> N mock_sessionmaker.assert_not_called() -def test_deduct_llm_quota_for_model_reuses_resolved_provider_configuration_for_deduction() -> None: - usage = LLMUsage.empty_usage() - usage.total_tokens = 42 - provider_configuration = SimpleNamespace( - using_provider_type=ProviderType.SYSTEM, - system_configuration=SimpleNamespace( - current_quota_type=ProviderQuotaType.TRIAL, - quota_configurations=[ - SimpleNamespace( - quota_type=ProviderQuotaType.TRIAL, - quota_unit=QuotaUnit.TOKENS, - quota_limit=100, - ) - ], - ), - ) - - with ( - patch("core.app.llm.quota._get_provider_configuration", return_value=provider_configuration), - patch("core.app.llm.quota._deduct_used_llm_quota") as mock_deduct, - ): - deduct_llm_quota_for_model( - tenant_id="tenant-id", - provider="openai", - model="gpt-4o", - usage=usage, - ) - - mock_deduct.assert_called_once_with( - tenant_id="tenant-id", - provider="openai", - provider_configuration=provider_configuration, - used_quota=42, - ) - - def test_ensure_llm_quota_available_wrapper_warns_and_delegates() -> None: model_instance = SimpleNamespace( provider="openai", diff --git a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py index b51fa454cc..6eb7a602a7 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py @@ -1,3 +1,4 @@ +import logging import threading from datetime import datetime from types import SimpleNamespace @@ -111,18 +112,36 @@ def test_non_llm_node_is_ignored() -> None: mock_deduct.assert_not_called() -def test_quota_error_is_handled_in_layer() -> None: +def test_quota_error_is_handled_in_layer(caplog) -> None: layer = LLMQuotaLayer(tenant_id="tenant-id") + stop_event = threading.Event() + layer.command_channel = MagicMock() + node = _build_node(node_type=BuiltinNodeTypes.LLM) + node.graph_runtime_state = MagicMock() + node.graph_runtime_state.stop_event = stop_event result_event = _build_succeeded_event() - with patch( - "core.app.workflow.layers.llm_quota.deduct_llm_quota_for_model", - autospec=True, - side_effect=ValueError("quota exceeded"), + with ( + caplog.at_level(logging.ERROR, logger="core.app.workflow.layers.llm_quota"), + patch( + "core.app.workflow.layers.llm_quota.deduct_llm_quota_for_model", + autospec=True, + side_effect=ValueError("quota exceeded"), + ) as mock_deduct, ): layer.on_node_run_end(node=node, error=None, result_event=result_event) + mock_deduct.assert_called_once_with( + tenant_id="tenant-id", + provider="openai", + model="gpt-4o", + usage=result_event.node_run_result.llm_usage, + ) + assert "LLM quota deduction failed, node_id=node-id" in caplog.text + assert not stop_event.is_set() + layer.command_channel.send_command.assert_not_called() + def test_quota_deduction_exceeded_aborts_workflow_immediately() -> None: layer = LLMQuotaLayer(tenant_id="tenant-id")