refactor(api): decouple llm quota layer from model instances

Introduce tenant-scoped quota helpers that accept provider and model identity directly.

Refactor the workflow quota layer and engine wiring to bill from the node's public model configuration and the result event's model identity instead of reconstructing ModelInstance. Keep the legacy ModelInstance helpers as deprecated wrappers with focused test coverage.
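
For orientation, a minimal sketch of the new call shapes (the call site is hypothetical; helper names, keyword arguments, and import paths are taken from the diffs below):

    from core.app.llm import deduct_llm_quota_for_model, ensure_llm_quota_available_for_model
    from graphon.model_runtime.entities.llm_entities import LLMUsage

    tenant_id = "tenant-id"         # placeholder tenant
    usage = LLMUsage.empty_usage()  # usage reported by the finished node run

    # Pre-check: raises QuotaExceededError when the tenant's system provider
    # model is already out of quota; no-op for non-system providers.
    ensure_llm_quota_available_for_model(tenant_id=tenant_id, provider="openai", model="gpt-4o")

    # Post-run billing from the same public identity; no ModelInstance is
    # reconstructed inside the workflow layer.
    deduct_llm_quota_for_model(tenant_id=tenant_id, provider="openai", model="gpt-4o", usage=usage)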
-LAN- 2026-04-22 14:08:27 +08:00
parent 4839fcc4f8
commit 9b65e53c12
7 changed files with 393 additions and 160 deletions
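
Engine wiring note: both WorkflowEntry and _WorkflowChildEngineBuilder now construct the quota layer with the tenant it bills for, so the layer no longer reads the tenant from a per-node run context. A minimal construction sketch (the surrounding GraphEngine setup is elided; import path and constructor follow the diffs below):

    from core.app.workflow.layers.llm_quota import LLMQuotaLayer

    quota_layer = LLMQuotaLayer(tenant_id="tenant-id")  # placeholder tenant id
    # graph_engine.layer(quota_layer)  # attached alongside the other engine layers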

View File

@ -1,5 +1,15 @@
"""LLM-related application services."""
from .quota import deduct_llm_quota, ensure_llm_quota_available
from .quota import (
deduct_llm_quota,
deduct_llm_quota_for_model,
ensure_llm_quota_available,
ensure_llm_quota_available_for_model,
)
__all__ = ["deduct_llm_quota", "ensure_llm_quota_available"]
__all__ = [
"deduct_llm_quota",
"deduct_llm_quota_for_model",
"ensure_llm_quota_available",
"ensure_llm_quota_available_for_model",
]

View File

@ -1,3 +1,12 @@
"""Tenant-scoped helpers for checking and deducting LLM provider quota.
Workflow callers now bill quota from public model identity instead of passing a
fully prepared ``ModelInstance``. Keep the model-instance helpers as thin,
deprecated adapters so non-workflow code can move independently.
"""
import warnings
from sqlalchemy import update
from sqlalchemy.orm import sessionmaker
@ -6,32 +15,41 @@ from core.entities.model_entities import ModelStatus
from core.entities.provider_entities import ProviderQuotaType, QuotaUnit
from core.errors.error import QuotaExceededError
from core.model_manager import ModelInstance
from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
from extensions.ext_database import db
from graphon.model_runtime.entities.llm_entities import LLMUsage
from graphon.model_runtime.entities.model_entities import ModelType
from libs.datetime_utils import naive_utc_now
from models.provider import Provider, ProviderType
from models.provider_ids import ModelProviderID
def ensure_llm_quota_available(*, model_instance: ModelInstance) -> None:
provider_model_bundle = model_instance.provider_model_bundle
provider_configuration = provider_model_bundle.configuration
def _get_provider_configuration(*, tenant_id: str, provider: str):
"""Resolve the tenant-bound provider configuration for quota decisions."""
provider_manager = create_plugin_provider_manager(tenant_id=tenant_id)
provider_configuration = provider_manager.get_configurations(tenant_id).get(provider)
if provider_configuration is None:
raise ValueError(f"Provider {provider} does not exist.")
return provider_configuration
def ensure_llm_quota_available_for_model(*, tenant_id: str, provider: str, model: str) -> None:
"""Raise when a tenant-bound system provider model is already out of quota."""
provider_configuration = _get_provider_configuration(tenant_id=tenant_id, provider=provider)
if provider_configuration.using_provider_type != ProviderType.SYSTEM:
return
provider_model = provider_configuration.get_provider_model(
model_type=model_instance.model_type_instance.model_type,
model=model_instance.model_name,
model_type=ModelType.LLM,
model=model,
)
if provider_model and provider_model.status == ModelStatus.QUOTA_EXCEEDED:
raise QuotaExceededError(f"Model provider {model_instance.provider} quota exceeded.")
raise QuotaExceededError(f"Model provider {provider} quota exceeded.")
def deduct_llm_quota(*, tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None:
provider_model_bundle = model_instance.provider_model_bundle
provider_configuration = provider_model_bundle.configuration
def deduct_llm_quota_for_model(*, tenant_id: str, provider: str, model: str, usage: LLMUsage) -> None:
"""Deduct tenant-bound quota for the resolved LLM model identity."""
provider_configuration = _get_provider_configuration(tenant_id=tenant_id, provider=provider)
if provider_configuration.using_provider_type != ProviderType.SYSTEM:
return
@ -52,7 +70,7 @@ def deduct_llm_quota(*, tenant_id: str, model_instance: ModelInstance, usage: LL
if quota_unit == QuotaUnit.TOKENS:
used_quota = usage.total_tokens
elif quota_unit == QuotaUnit.CREDITS:
used_quota = dify_config.get_model_credits(model_instance.model_name)
used_quota = dify_config.get_model_credits(model)
else:
used_quota = 1
@ -80,7 +98,7 @@ def deduct_llm_quota(*, tenant_id: str, model_instance: ModelInstance, usage: LL
.where(
Provider.tenant_id == tenant_id,
# TODO: Use provider name with prefix after the data migration.
Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
Provider.provider_name == ModelProviderID(provider).provider_name,
Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == system_configuration.current_quota_type,
Provider.quota_limit > Provider.quota_used,
@ -91,3 +109,34 @@ def deduct_llm_quota(*, tenant_id: str, model_instance: ModelInstance, usage: LL
)
)
session.execute(stmt)
def ensure_llm_quota_available(*, model_instance: ModelInstance) -> None:
"""Deprecated compatibility wrapper for callers that still pass ModelInstance."""
warnings.warn(
"ensure_llm_quota_available(model_instance=...) is deprecated; "
"use ensure_llm_quota_available_for_model(...) instead.",
DeprecationWarning,
stacklevel=2,
)
ensure_llm_quota_available_for_model(
tenant_id=model_instance.provider_model_bundle.configuration.tenant_id,
provider=model_instance.provider,
model=model_instance.model_name,
)
def deduct_llm_quota(*, tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None:
"""Deprecated compatibility wrapper for callers that still pass ModelInstance."""
warnings.warn(
"deduct_llm_quota(tenant_id=..., model_instance=..., usage=...) is deprecated; "
"use deduct_llm_quota_for_model(...) instead.",
DeprecationWarning,
stacklevel=2,
)
deduct_llm_quota_for_model(
tenant_id=tenant_id,
provider=model_instance.provider,
model=model_instance.model_name,
usage=usage,
)

View File

@ -1,16 +1,19 @@
"""
LLM quota deduction layer for GraphEngine.
This layer centralizes model-quota deduction outside node implementations.
This layer centralizes model-quota handling outside node implementations.
Graphon LLM-backed nodes expose provider/model identity through public node
configuration and, after execution, through ``node_run_result.inputs``. Resolve
quota billing from that public identity instead of depending on
``ModelInstance`` reconstruction inside the workflow layer.
"""
import logging
from typing import final, override
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext
from core.app.llm import deduct_llm_quota, ensure_llm_quota_available
from core.app.llm import deduct_llm_quota_for_model, ensure_llm_quota_available_for_model
from core.errors.error import QuotaExceededError
from core.model_manager import ModelInstance
from graphon.enums import BuiltinNodeTypes
from graphon.graph_engine.entities.commands import AbortCommand, CommandType
from graphon.graph_engine.layers import GraphEngineLayer
@ -18,14 +21,25 @@ from graphon.graph_events import GraphEngineEvent, GraphNodeEventBase, NodeRunSu
from graphon.nodes.base.node import Node
logger = logging.getLogger(__name__)
_QUOTA_NODE_TYPES = frozenset(
[
BuiltinNodeTypes.LLM,
BuiltinNodeTypes.PARAMETER_EXTRACTOR,
BuiltinNodeTypes.QUESTION_CLASSIFIER,
]
)
@final
class LLMQuotaLayer(GraphEngineLayer):
"""Graph layer that applies LLM quota deduction after node execution."""
"""Graph layer that applies tenant-scoped quota checks to LLM-backed nodes."""
def __init__(self) -> None:
tenant_id: str
_abort_sent: bool
def __init__(self, tenant_id: str) -> None:
super().__init__()
self.tenant_id = tenant_id
self._abort_sent = False
@override
@ -45,12 +59,20 @@ class LLMQuotaLayer(GraphEngineLayer):
if self._abort_sent:
return
model_instance = self._extract_model_instance(node)
if model_instance is None:
if not self._supports_quota(node):
return
model_identity = self._extract_model_identity_from_node(node)
if model_identity is None:
return
provider, model_name = model_identity
try:
ensure_llm_quota_available(model_instance=model_instance)
ensure_llm_quota_available_for_model(
tenant_id=self.tenant_id,
provider=provider,
model=model_name,
)
except QuotaExceededError as exc:
self._set_stop_event(node)
self._send_abort_command(reason=str(exc))
@ -60,18 +82,20 @@ class LLMQuotaLayer(GraphEngineLayer):
def on_node_run_end(
self, node: Node, error: Exception | None, result_event: GraphNodeEventBase | None = None
) -> None:
if error is not None or not isinstance(result_event, NodeRunSucceededEvent):
if error is not None or not isinstance(result_event, NodeRunSucceededEvent) or not self._supports_quota(node):
return
model_instance = self._extract_model_instance(node)
if model_instance is None:
model_identity = self._extract_model_identity_from_result_event(result_event)
if model_identity is None:
return
provider, model_name = model_identity
try:
dify_ctx = DifyRunContext.model_validate(node.require_run_context_value(DIFY_RUN_CONTEXT_KEY))
deduct_llm_quota(
tenant_id=dify_ctx.tenant_id,
model_instance=model_instance,
deduct_llm_quota_for_model(
tenant_id=self.tenant_id,
provider=provider,
model=model_name,
usage=result_event.node_run_result.llm_usage,
)
except QuotaExceededError as exc:
@ -103,35 +127,38 @@ class LLMQuotaLayer(GraphEngineLayer):
logger.exception("Failed to send quota abort command")
@staticmethod
def _extract_model_instance(node: Node) -> ModelInstance | None:
match node.node_type:
case BuiltinNodeTypes.LLM | BuiltinNodeTypes.PARAMETER_EXTRACTOR | BuiltinNodeTypes.QUESTION_CLASSIFIER:
pass
case _:
return None
def _supports_quota(node: Node) -> bool:
return node.node_type in _QUOTA_NODE_TYPES
try:
model_instance = getattr(node, "model_instance", None)
except AttributeError:
@staticmethod
def _extract_model_identity_from_result_event(result_event: NodeRunSucceededEvent) -> tuple[str, str] | None:
provider = result_event.node_run_result.inputs.get("model_provider")
model_name = result_event.node_run_result.inputs.get("model_name")
if isinstance(provider, str) and provider and isinstance(model_name, str) and model_name:
return provider, model_name
return None
@staticmethod
def _extract_model_identity_from_node(node: Node) -> tuple[str, str] | None:
node_data = getattr(node, "node_data", None)
if node_data is None:
node_data = getattr(node, "data", None)
model_config = getattr(node_data, "model", None)
if model_config is None:
logger.warning(
"LLMQuotaLayer skipped quota deduction because node does not expose a model instance, node_id=%s",
"LLMQuotaLayer skipped quota handling because node model config is missing, node_id=%s",
node.id,
)
return None
if isinstance(model_instance, ModelInstance):
return model_instance
raw_model_instance = getattr(model_instance, "_model_instance", None)
if isinstance(raw_model_instance, ModelInstance):
return raw_model_instance
private_model_instance = getattr(node, "_model_instance", None)
if isinstance(private_model_instance, ModelInstance):
return private_model_instance
wrapped_private_model_instance = getattr(private_model_instance, "_model_instance", None)
if isinstance(wrapped_private_model_instance, ModelInstance):
return wrapped_private_model_instance
provider = getattr(model_config, "provider", None)
model_name = getattr(model_config, "name", None)
if isinstance(provider, str) and provider and isinstance(model_name, str) and model_name:
return provider, model_name
logger.warning(
"LLMQuotaLayer skipped quota handling because node model identity is invalid, node_id=%s",
node.id,
)
return None

View File

@ -46,6 +46,11 @@ _file_access_controller = DatabaseFileAccessController()
class _WorkflowChildEngineBuilder:
tenant_id: str
def __init__(self, *, tenant_id: str) -> None:
self.tenant_id = tenant_id
@staticmethod
def _has_node_id(graph_config: Mapping[str, Any], node_id: str) -> bool | None:
"""
@ -107,7 +112,7 @@ class _WorkflowChildEngineBuilder:
config=config,
child_engine_builder=self,
)
child_engine.layer(LLMQuotaLayer())
child_engine.layer(LLMQuotaLayer(tenant_id=self.tenant_id))
return child_engine
@ -176,7 +181,7 @@ class WorkflowEntry:
self.command_channel = command_channel
execution_context = capture_current_context()
graph_runtime_state.execution_context = execution_context
self._child_engine_builder = _WorkflowChildEngineBuilder()
self._child_engine_builder = _WorkflowChildEngineBuilder(tenant_id=tenant_id)
self.graph_engine = GraphEngine(
workflow_id=workflow_id,
graph=graph,
@ -208,7 +213,7 @@ class WorkflowEntry:
max_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS, max_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME
)
self.graph_engine.layer(limits_layer)
self.graph_engine.layer(LLMQuotaLayer())
self.graph_engine.layer(LLMQuotaLayer(tenant_id=tenant_id))
# Add observability layer when OTel is enabled
if dify_config.ENABLE_OTEL or is_instrument_flag_enabled():

View File

@ -0,0 +1,117 @@
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from core.app.llm.quota import (
deduct_llm_quota,
deduct_llm_quota_for_model,
ensure_llm_quota_available,
ensure_llm_quota_available_for_model,
)
from core.entities.model_entities import ModelStatus
from core.entities.provider_entities import ProviderQuotaType, QuotaUnit
from core.errors.error import QuotaExceededError
from graphon.model_runtime.entities.llm_entities import LLMUsage
from models.provider import ProviderType
def test_ensure_llm_quota_available_for_model_raises_when_system_model_is_exhausted() -> None:
provider_configuration = SimpleNamespace(
using_provider_type=ProviderType.SYSTEM,
get_provider_model=MagicMock(return_value=SimpleNamespace(status=ModelStatus.QUOTA_EXCEEDED)),
)
provider_manager = MagicMock()
provider_manager.get_configurations.return_value.get.return_value = provider_configuration
with (
patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager),
pytest.raises(QuotaExceededError, match="Model provider openai quota exceeded."),
):
ensure_llm_quota_available_for_model(
tenant_id="tenant-id",
provider="openai",
model="gpt-4o",
)
def test_deduct_llm_quota_for_model_uses_identity_based_trial_billing() -> None:
usage = LLMUsage.empty_usage()
usage.total_tokens = 42
provider_configuration = SimpleNamespace(
using_provider_type=ProviderType.SYSTEM,
system_configuration=SimpleNamespace(
current_quota_type=ProviderQuotaType.TRIAL,
quota_configurations=[
SimpleNamespace(
quota_type=ProviderQuotaType.TRIAL,
quota_unit=QuotaUnit.TOKENS,
quota_limit=100,
)
],
),
)
provider_manager = MagicMock()
provider_manager.get_configurations.return_value.get.return_value = provider_configuration
with (
patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager),
patch("services.credit_pool_service.CreditPoolService.check_and_deduct_credits") as mock_deduct_credits,
):
deduct_llm_quota_for_model(
tenant_id="tenant-id",
provider="openai",
model="gpt-4o",
usage=usage,
)
mock_deduct_credits.assert_called_once_with(
tenant_id="tenant-id",
credits_required=42,
)
def test_ensure_llm_quota_available_wrapper_warns_and_delegates() -> None:
model_instance = SimpleNamespace(
provider="openai",
model_name="gpt-4o",
provider_model_bundle=SimpleNamespace(configuration=SimpleNamespace(tenant_id="tenant-id")),
)
with (
pytest.deprecated_call(match="ensure_llm_quota_available\\(model_instance=.*deprecated"),
patch("core.app.llm.quota.ensure_llm_quota_available_for_model") as mock_ensure,
):
ensure_llm_quota_available(model_instance=model_instance)
mock_ensure.assert_called_once_with(
tenant_id="tenant-id",
provider="openai",
model="gpt-4o",
)
def test_deduct_llm_quota_wrapper_warns_and_delegates() -> None:
usage = LLMUsage.empty_usage()
model_instance = SimpleNamespace(
provider="openai",
model_name="gpt-4o",
)
with (
pytest.deprecated_call(match="deduct_llm_quota\\(tenant_id=.*deprecated"),
patch("core.app.llm.quota.deduct_llm_quota_for_model") as mock_deduct,
):
deduct_llm_quota(
tenant_id="tenant-id",
model_instance=model_instance,
usage=usage,
)
mock_deduct.assert_called_once_with(
tenant_id="tenant-id",
provider="openai",
model="gpt-4o",
usage=usage,
)

View File

@ -3,10 +3,8 @@ from datetime import datetime
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
from core.app.entities.app_invoke_entities import DifyRunContext, InvokeFrom, UserFrom
from core.app.workflow.layers.llm_quota import LLMQuotaLayer
from core.errors.error import QuotaExceededError
from core.model_manager import ModelInstance
from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus
from graphon.graph_engine.entities.commands import CommandType
from graphon.graph_events import NodeRunSucceededEvent
@ -14,17 +12,7 @@ from graphon.model_runtime.entities.llm_entities import LLMUsage
from graphon.node_events import NodeRunResult
def _build_dify_context() -> DifyRunContext:
return DifyRunContext(
tenant_id="tenant-id",
app_id="app-id",
user_id="user-id",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
)
def _build_succeeded_event() -> NodeRunSucceededEvent:
def _build_succeeded_event(*, provider: str = "openai", model_name: str = "gpt-4o") -> NodeRunSucceededEvent:
return NodeRunSucceededEvent(
id="execution-id",
node_id="llm-node-id",
@ -32,89 +20,80 @@ def _build_succeeded_event() -> NodeRunSucceededEvent:
start_at=datetime.now(),
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs={"question": "hello"},
inputs={
"question": "hello",
"model_provider": provider,
"model_name": model_name,
},
llm_usage=LLMUsage.empty_usage(),
),
)
def _build_wrapped_model_instance() -> tuple[SimpleNamespace, ModelInstance]:
raw_model_instance = ModelInstance.__new__(ModelInstance)
return SimpleNamespace(_model_instance=raw_model_instance), raw_model_instance
def _build_public_model_identity(*, provider: str = "openai", model_name: str = "gpt-4o") -> SimpleNamespace:
return SimpleNamespace(provider=provider, name=model_name)
def _build_node(*, node_type: BuiltinNodeTypes = BuiltinNodeTypes.LLM) -> MagicMock:
node = MagicMock()
node.id = "node-id"
node.execution_id = "execution-id"
node.node_type = node_type
node.node_data = SimpleNamespace(model=_build_public_model_identity())
node.model_instance = SimpleNamespace(provider="stale-provider", model_name="stale-model")
return node
def test_deduct_quota_called_for_successful_llm_node() -> None:
layer = LLMQuotaLayer()
node = MagicMock()
node.id = "llm-node-id"
node.execution_id = "execution-id"
node.node_type = BuiltinNodeTypes.LLM
node.tenant_id = "tenant-id"
node.require_run_context_value.return_value = _build_dify_context()
node.model_instance, raw_model_instance = _build_wrapped_model_instance()
layer = LLMQuotaLayer(tenant_id="tenant-id")
node = _build_node(node_type=BuiltinNodeTypes.LLM)
result_event = _build_succeeded_event()
with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota", autospec=True) as mock_deduct:
with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota_for_model", autospec=True) as mock_deduct:
layer.on_node_run_end(node=node, error=None, result_event=result_event)
mock_deduct.assert_called_once_with(
tenant_id="tenant-id",
model_instance=raw_model_instance,
provider="openai",
model="gpt-4o",
usage=result_event.node_run_result.llm_usage,
)
def test_deduct_quota_called_for_question_classifier_node() -> None:
layer = LLMQuotaLayer()
node = MagicMock()
node.id = "question-classifier-node-id"
node.execution_id = "execution-id"
node.node_type = BuiltinNodeTypes.QUESTION_CLASSIFIER
node.tenant_id = "tenant-id"
node.require_run_context_value.return_value = _build_dify_context()
node.model_instance, raw_model_instance = _build_wrapped_model_instance()
layer = LLMQuotaLayer(tenant_id="tenant-id")
node = _build_node(node_type=BuiltinNodeTypes.QUESTION_CLASSIFIER)
result_event = _build_succeeded_event(provider="anthropic", model_name="claude-3-7-sonnet")
result_event = _build_succeeded_event()
with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota", autospec=True) as mock_deduct:
with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota_for_model", autospec=True) as mock_deduct:
layer.on_node_run_end(node=node, error=None, result_event=result_event)
mock_deduct.assert_called_once_with(
tenant_id="tenant-id",
model_instance=raw_model_instance,
provider="anthropic",
model="claude-3-7-sonnet",
usage=result_event.node_run_result.llm_usage,
)
def test_non_llm_node_is_ignored() -> None:
layer = LLMQuotaLayer()
node = MagicMock()
node.id = "start-node-id"
node.execution_id = "execution-id"
node.node_type = BuiltinNodeTypes.START
node.tenant_id = "tenant-id"
node.require_run_context_value.return_value = _build_dify_context()
node._model_instance = object()
layer = LLMQuotaLayer(tenant_id="tenant-id")
node = _build_node(node_type=BuiltinNodeTypes.START)
result_event = _build_succeeded_event()
with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota", autospec=True) as mock_deduct:
with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota_for_model", autospec=True) as mock_deduct:
layer.on_node_run_end(node=node, error=None, result_event=result_event)
mock_deduct.assert_not_called()
def test_quota_error_is_handled_in_layer() -> None:
layer = LLMQuotaLayer()
node = MagicMock()
node.id = "llm-node-id"
node.execution_id = "execution-id"
node.node_type = BuiltinNodeTypes.LLM
node.tenant_id = "tenant-id"
node.require_run_context_value.return_value = _build_dify_context()
node.model_instance = object()
layer = LLMQuotaLayer(tenant_id="tenant-id")
node = _build_node(node_type=BuiltinNodeTypes.LLM)
result_event = _build_succeeded_event()
with patch(
"core.app.workflow.layers.llm_quota.deduct_llm_quota",
"core.app.workflow.layers.llm_quota.deduct_llm_quota_for_model",
autospec=True,
side_effect=ValueError("quota exceeded"),
):
@ -122,23 +101,17 @@ def test_quota_error_is_handled_in_layer() -> None:
def test_quota_deduction_exceeded_aborts_workflow_immediately() -> None:
layer = LLMQuotaLayer()
layer = LLMQuotaLayer(tenant_id="tenant-id")
stop_event = threading.Event()
layer.command_channel = MagicMock()
node = MagicMock()
node.id = "llm-node-id"
node.execution_id = "execution-id"
node.node_type = BuiltinNodeTypes.LLM
node.tenant_id = "tenant-id"
node.require_run_context_value.return_value = _build_dify_context()
node.model_instance, _ = _build_wrapped_model_instance()
node = _build_node(node_type=BuiltinNodeTypes.LLM)
node.graph_runtime_state = MagicMock()
node.graph_runtime_state.stop_event = stop_event
result_event = _build_succeeded_event()
with patch(
"core.app.workflow.layers.llm_quota.deduct_llm_quota",
"core.app.workflow.layers.llm_quota.deduct_llm_quota_for_model",
autospec=True,
side_effect=QuotaExceededError("No credits remaining"),
):
@ -152,19 +125,16 @@ def test_quota_deduction_exceeded_aborts_workflow_immediately() -> None:
def test_quota_precheck_failure_aborts_workflow_immediately() -> None:
layer = LLMQuotaLayer()
layer = LLMQuotaLayer(tenant_id="tenant-id")
stop_event = threading.Event()
layer.command_channel = MagicMock()
node = MagicMock()
node.id = "llm-node-id"
node.node_type = BuiltinNodeTypes.LLM
node.model_instance, _ = _build_wrapped_model_instance()
node = _build_node(node_type=BuiltinNodeTypes.LLM)
node.graph_runtime_state = MagicMock()
node.graph_runtime_state.stop_event = stop_event
with patch(
"core.app.workflow.layers.llm_quota.ensure_llm_quota_available",
"core.app.workflow.layers.llm_quota.ensure_llm_quota_available_for_model",
autospec=True,
side_effect=QuotaExceededError("Model provider openai quota exceeded."),
):
@ -178,20 +148,44 @@ def test_quota_precheck_failure_aborts_workflow_immediately() -> None:
def test_quota_precheck_passes_without_abort() -> None:
layer = LLMQuotaLayer()
layer = LLMQuotaLayer(tenant_id="tenant-id")
stop_event = threading.Event()
layer.command_channel = MagicMock()
node = MagicMock()
node.id = "llm-node-id"
node.node_type = BuiltinNodeTypes.LLM
node.model_instance, raw_model_instance = _build_wrapped_model_instance()
node = _build_node(node_type=BuiltinNodeTypes.LLM)
node.graph_runtime_state = MagicMock()
node.graph_runtime_state.stop_event = stop_event
with patch("core.app.workflow.layers.llm_quota.ensure_llm_quota_available", autospec=True) as mock_check:
with patch("core.app.workflow.layers.llm_quota.ensure_llm_quota_available_for_model", autospec=True) as mock_check:
layer.on_node_run_start(node)
assert not stop_event.is_set()
mock_check.assert_called_once_with(model_instance=raw_model_instance)
mock_check.assert_called_once_with(
tenant_id="tenant-id",
provider="openai",
model="gpt-4o",
)
layer.command_channel.send_command.assert_not_called()
def test_precheck_requires_public_node_model_config() -> None:
layer = LLMQuotaLayer(tenant_id="tenant-id")
node = _build_node(node_type=BuiltinNodeTypes.LLM)
node.node_data = SimpleNamespace()
with patch("core.app.workflow.layers.llm_quota.ensure_llm_quota_available_for_model", autospec=True) as mock_check:
layer.on_node_run_start(node)
mock_check.assert_not_called()
def test_deduction_requires_public_event_model_identity() -> None:
layer = LLMQuotaLayer(tenant_id="tenant-id")
node = _build_node(node_type=BuiltinNodeTypes.LLM)
result_event = _build_succeeded_event()
result_event.node_run_result.inputs = {"question": "hello"}
with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota_for_model", autospec=True) as mock_deduct:
layer.on_node_run_end(node=node, error=None, result_event=result_event)
mock_deduct.assert_not_called()

View File

@ -7,7 +7,6 @@ import pytest
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
from core.model_manager import ModelInstance
from core.workflow import workflow_entry
from core.workflow.system_variables import default_system_variables
from graphon.entities.base_node_data import BaseNodeData
@ -16,10 +15,12 @@ from graphon.errors import WorkflowNodeRunFailedError
from graphon.file import File, FileTransferMethod, FileType
from graphon.graph import Graph
from graphon.graph_events import GraphRunFailedEvent
from graphon.model_runtime.entities.llm_entities import LLMUsage
from graphon.model_runtime.entities.llm_entities import LLMMode, LLMUsage
from graphon.node_events import NodeRunResult
from graphon.nodes import BuiltinNodeTypes
from graphon.nodes.base.node import Node
from graphon.nodes.llm.entities import ContextConfig, LLMNodeData, ModelConfig
from graphon.nodes.question_classifier.entities import QuestionClassifierNodeData
from graphon.runtime import ChildGraphNotFoundError, VariablePool
from graphon.variables.variables import StringVariable
from tests.workflow_test_utils import build_test_graph_init_params, build_test_variable_pool
@ -29,9 +30,30 @@ def _build_typed_node_config(node_type: NodeType):
return {"id": "node-id", "data": BaseNodeData(type=node_type)}
def _build_wrapped_model_instance() -> tuple[SimpleNamespace, ModelInstance]:
raw_model_instance = ModelInstance.__new__(ModelInstance)
return SimpleNamespace(_model_instance=raw_model_instance), raw_model_instance
def _build_model_config(*, provider: str = "openai", model_name: str = "gpt-4o") -> ModelConfig:
return ModelConfig(provider=provider, name=model_name, mode=LLMMode.CHAT)
def _build_llm_node_data(*, provider: str = "openai", model_name: str = "gpt-4o") -> LLMNodeData:
return LLMNodeData(
type=BuiltinNodeTypes.LLM,
title="Child Model",
model=_build_model_config(provider=provider, model_name=model_name),
prompt_template=[],
context=ContextConfig(enabled=False),
)
def _build_question_classifier_node_data(
*, provider: str = "openai", model_name: str = "gpt-4o"
) -> QuestionClassifierNodeData:
return QuestionClassifierNodeData(
type=BuiltinNodeTypes.QUESTION_CLASSIFIER,
title="Child Model",
query_variable_selector=["sys", "query"],
model=_build_model_config(provider=provider, model_name=model_name),
classes=[],
)
class _FakeModelNodeMixin:
@ -40,22 +62,26 @@ class _FakeModelNodeMixin:
return "1"
def post_init(self) -> None:
self.model_instance, self.raw_model_instance = _build_wrapped_model_instance()
self.model_instance = SimpleNamespace(provider="stale-provider", model_name="stale-model")
self.usage_snapshot = LLMUsage.empty_usage()
self.usage_snapshot.total_tokens = 1
def _run(self) -> NodeRunResult:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs={
"model_provider": self.node_data.model.provider,
"model_name": self.node_data.model.name,
},
llm_usage=self.usage_snapshot,
)
class _FakeLLMNode(_FakeModelNodeMixin, Node[BaseNodeData]):
class _FakeLLMNode(_FakeModelNodeMixin, Node[LLMNodeData]):
node_type = BuiltinNodeTypes.LLM
class _FakeQuestionClassifierNode(_FakeModelNodeMixin, Node[BaseNodeData]):
class _FakeQuestionClassifierNode(_FakeModelNodeMixin, Node[QuestionClassifierNodeData]):
node_type = BuiltinNodeTypes.QUESTION_CLASSIFIER
@ -75,7 +101,7 @@ class TestWorkflowChildEngineBuilder:
assert result is expected
def test_build_child_engine_raises_when_root_node_is_missing(self):
builder = workflow_entry._WorkflowChildEngineBuilder()
builder = workflow_entry._WorkflowChildEngineBuilder(tenant_id="tenant-id")
graph_init_params = SimpleNamespace(graph_config={"nodes": []})
parent_graph_runtime_state = SimpleNamespace(
execution_context=sentinel.execution_context,
@ -92,7 +118,7 @@ class TestWorkflowChildEngineBuilder:
)
def test_build_child_engine_constructs_graph_engine_with_quota_layer_only(self):
builder = workflow_entry._WorkflowChildEngineBuilder()
builder = workflow_entry._WorkflowChildEngineBuilder(tenant_id="tenant-id")
graph_init_params = SimpleNamespace(graph_config={"nodes": [{"id": "root"}]})
parent_graph_runtime_state = SimpleNamespace(
execution_context=sentinel.execution_context,
@ -114,7 +140,7 @@ class TestWorkflowChildEngineBuilder:
patch.object(workflow_entry, "GraphEngine", return_value=child_engine) as graph_engine_cls,
patch.object(workflow_entry, "GraphEngineConfig", return_value=sentinel.graph_engine_config),
patch.object(workflow_entry, "InMemoryChannel", return_value=sentinel.command_channel),
patch.object(workflow_entry, "LLMQuotaLayer", return_value=sentinel.llm_quota_layer),
patch.object(workflow_entry, "LLMQuotaLayer", return_value=sentinel.llm_quota_layer) as llm_quota_layer_cls,
):
result = builder.build_child_engine(
workflow_id="workflow-id",
@ -147,11 +173,12 @@ class TestWorkflowChildEngineBuilder:
config=sentinel.graph_engine_config,
child_engine_builder=builder,
)
llm_quota_layer_cls.assert_called_once_with(tenant_id="tenant-id")
assert child_engine.layer.call_args_list == [((sentinel.llm_quota_layer,), {})]
@pytest.mark.parametrize("node_cls", [_FakeLLMNode, _FakeQuestionClassifierNode])
def test_build_child_engine_runs_llm_quota_layer_for_child_model_nodes(self, node_cls):
builder = workflow_entry._WorkflowChildEngineBuilder()
builder = workflow_entry._WorkflowChildEngineBuilder(tenant_id="tenant-id")
graph_init_params = build_test_graph_init_params(
graph_config={"nodes": [{"id": "root"}], "edges": []},
)
@ -163,12 +190,10 @@ class TestWorkflowChildEngineBuilder:
def build_graph(*, graph_config, node_factory, root_node_id):
_ = graph_config
node_data = _build_llm_node_data() if node_cls is _FakeLLMNode else _build_question_classifier_node_data()
node = node_cls(
node_id=root_node_id,
data=BaseNodeData(
type=node_cls.node_type,
title="Child Model",
),
data=node_data,
graph_init_params=node_factory.graph_init_params,
graph_runtime_state=node_factory.graph_runtime_state,
)
@ -191,8 +216,8 @@ class TestWorkflowChildEngineBuilder:
),
),
patch.object(workflow_entry.Graph, "init", side_effect=build_graph),
patch("core.app.workflow.layers.llm_quota.ensure_llm_quota_available") as ensure_quota,
patch("core.app.workflow.layers.llm_quota.deduct_llm_quota") as deduct_quota,
patch("core.app.workflow.layers.llm_quota.ensure_llm_quota_available_for_model") as ensure_quota,
patch("core.app.workflow.layers.llm_quota.deduct_llm_quota_for_model") as deduct_quota,
):
child_engine = builder.build_child_engine(
workflow_id="workflow-id",
@ -203,10 +228,15 @@ class TestWorkflowChildEngineBuilder:
list(child_engine.run())
node = created_node["node"]
ensure_quota.assert_called_once_with(model_instance=node.raw_model_instance)
ensure_quota.assert_called_once_with(
tenant_id="tenant-id",
provider=node.node_data.model.provider,
model=node.node_data.model.name,
)
deduct_quota.assert_called_once_with(
tenant_id="tenant",
model_instance=node.raw_model_instance,
tenant_id="tenant-id",
provider=node.node_data.model.provider,
model=node.node_data.model.name,
usage=node.usage_snapshot,
)
@ -252,7 +282,7 @@ class TestWorkflowEntryInit:
"ExecutionLimitsLayer",
return_value=execution_limits_layer,
) as execution_limits_layer_cls,
patch.object(workflow_entry, "LLMQuotaLayer", return_value=llm_quota_layer),
patch.object(workflow_entry, "LLMQuotaLayer", return_value=llm_quota_layer) as llm_quota_layer_cls,
patch.object(workflow_entry, "ObservabilityLayer", return_value=observability_layer),
):
entry = workflow_entry.WorkflowEntry(
@ -291,6 +321,7 @@ class TestWorkflowEntryInit:
max_steps=workflow_entry.dify_config.WORKFLOW_MAX_EXECUTION_STEPS,
max_time=workflow_entry.dify_config.WORKFLOW_MAX_EXECUTION_TIME,
)
llm_quota_layer_cls.assert_called_once_with(tenant_id="tenant-id")
assert graph_engine.layer.call_args_list == [
((debug_layer,), {}),
((execution_limits_layer,), {}),