diff --git a/api/core/app/llm/__init__.py b/api/core/app/llm/__init__.py
index f069bede74..d20a5b2344 100644
--- a/api/core/app/llm/__init__.py
+++ b/api/core/app/llm/__init__.py
@@ -1,5 +1,15 @@
 """LLM-related application services."""
 
-from .quota import deduct_llm_quota, ensure_llm_quota_available
+from .quota import (
+    deduct_llm_quota,
+    deduct_llm_quota_for_model,
+    ensure_llm_quota_available,
+    ensure_llm_quota_available_for_model,
+)
 
-__all__ = ["deduct_llm_quota", "ensure_llm_quota_available"]
+__all__ = [
+    "deduct_llm_quota",
+    "deduct_llm_quota_for_model",
+    "ensure_llm_quota_available",
+    "ensure_llm_quota_available_for_model",
+]
diff --git a/api/core/app/llm/quota.py b/api/core/app/llm/quota.py
index b6039e1e4e..723b090de9 100644
--- a/api/core/app/llm/quota.py
+++ b/api/core/app/llm/quota.py
@@ -1,3 +1,12 @@
+"""Tenant-scoped helpers for checking and deducting LLM provider quota.
+
+Workflow callers now bill quota from public model identity instead of passing a
+fully prepared ``ModelInstance``. Keep the model-instance helpers as thin,
+deprecated adapters so non-workflow code can move independently.
+"""
+
+import warnings
+
 from sqlalchemy import update
 from sqlalchemy.orm import sessionmaker
 
@@ -6,32 +15,41 @@ from core.entities.model_entities import ModelStatus
 from core.entities.provider_entities import ProviderQuotaType, QuotaUnit
 from core.errors.error import QuotaExceededError
 from core.model_manager import ModelInstance
+from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
 from extensions.ext_database import db
 from graphon.model_runtime.entities.llm_entities import LLMUsage
+from graphon.model_runtime.entities.model_entities import ModelType
 from libs.datetime_utils import naive_utc_now
 from models.provider import Provider, ProviderType
 from models.provider_ids import ModelProviderID
 
 
-def ensure_llm_quota_available(*, model_instance: ModelInstance) -> None:
-    provider_model_bundle = model_instance.provider_model_bundle
-    provider_configuration = provider_model_bundle.configuration
+def _get_provider_configuration(*, tenant_id: str, provider: str):
+    """Resolve the tenant-bound provider configuration for quota decisions."""
+    provider_manager = create_plugin_provider_manager(tenant_id=tenant_id)
+    provider_configuration = provider_manager.get_configurations(tenant_id).get(provider)
+    if provider_configuration is None:
+        raise ValueError(f"Provider {provider} does not exist.")
+    return provider_configuration
+
+def ensure_llm_quota_available_for_model(*, tenant_id: str, provider: str, model: str) -> None:
+    """Raise when a tenant-bound system provider model is already out of quota."""
+    provider_configuration = _get_provider_configuration(tenant_id=tenant_id, provider=provider)
 
     if provider_configuration.using_provider_type != ProviderType.SYSTEM:
         return
 
     provider_model = provider_configuration.get_provider_model(
-        model_type=model_instance.model_type_instance.model_type,
-        model=model_instance.model_name,
+        model_type=ModelType.LLM,
+        model=model,
     )
     if provider_model and provider_model.status == ModelStatus.QUOTA_EXCEEDED:
-        raise QuotaExceededError(f"Model provider {model_instance.provider} quota exceeded.")
+        raise QuotaExceededError(f"Model provider {provider} quota exceeded.")
 
 
-def deduct_llm_quota(*, tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None:
-    provider_model_bundle = model_instance.provider_model_bundle
-    provider_configuration = provider_model_bundle.configuration
-
+def deduct_llm_quota_for_model(*, tenant_id: str, provider: str, model: str, usage: LLMUsage) -> None:
+    """Deduct tenant-bound quota for the resolved LLM model identity."""
+    provider_configuration = _get_provider_configuration(tenant_id=tenant_id, provider=provider)
     if provider_configuration.using_provider_type != ProviderType.SYSTEM:
         return
 
@@ -52,7 +70,7 @@ def deduct_llm_quota(*, tenant_id: str, model_instance: ModelInstance, usage: LL
     if quota_unit == QuotaUnit.TOKENS:
         used_quota = usage.total_tokens
     elif quota_unit == QuotaUnit.CREDITS:
-        used_quota = dify_config.get_model_credits(model_instance.model_name)
+        used_quota = dify_config.get_model_credits(model)
     else:
         used_quota = 1
 
@@ -80,7 +98,7 @@ def deduct_llm_quota(*, tenant_id: str, model_instance: ModelInstance, usage: LL
         .where(
             Provider.tenant_id == tenant_id,
             # TODO: Use provider name with prefix after the data migration.
-            Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
+            Provider.provider_name == ModelProviderID(provider).provider_name,
             Provider.provider_type == ProviderType.SYSTEM.value,
             Provider.quota_type == system_configuration.current_quota_type,
             Provider.quota_limit > Provider.quota_used,
@@ -91,3 +109,34 @@ def deduct_llm_quota(*, tenant_id: str, model_instance: ModelInstance, usage: LL
         )
     )
     session.execute(stmt)
+
+
+def ensure_llm_quota_available(*, model_instance: ModelInstance) -> None:
+    """Deprecated compatibility wrapper for callers that still pass ModelInstance."""
+    warnings.warn(
+        "ensure_llm_quota_available(model_instance=...) is deprecated; "
+        "use ensure_llm_quota_available_for_model(...) instead.",
+        DeprecationWarning,
+        stacklevel=2,
+    )
+    ensure_llm_quota_available_for_model(
+        tenant_id=model_instance.provider_model_bundle.configuration.tenant_id,
+        provider=model_instance.provider,
+        model=model_instance.model_name,
+    )
+
+
+def deduct_llm_quota(*, tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None:
+    """Deprecated compatibility wrapper for callers that still pass ModelInstance."""
+    warnings.warn(
+        "deduct_llm_quota(tenant_id=..., model_instance=..., usage=...) is deprecated; "
+        "use deduct_llm_quota_for_model(...) instead.",
+        DeprecationWarning,
+        stacklevel=2,
+    )
+    deduct_llm_quota_for_model(
+        tenant_id=tenant_id,
+        provider=model_instance.provider,
+        model=model_instance.model_name,
+        usage=usage,
+    )
diff --git a/api/core/app/workflow/layers/llm_quota.py b/api/core/app/workflow/layers/llm_quota.py
index 69be25bfb5..c72dde8481 100644
--- a/api/core/app/workflow/layers/llm_quota.py
+++ b/api/core/app/workflow/layers/llm_quota.py
@@ -1,16 +1,19 @@
 """
 LLM quota deduction layer for GraphEngine.
 
-This layer centralizes model-quota deduction outside node implementations.
+This layer centralizes model-quota handling outside node implementations.
+
+Graphon LLM-backed nodes expose provider/model identity through public node
+configuration and, after execution, through ``node_run_result.inputs``. Resolve
+quota billing from that public identity instead of depending on
+``ModelInstance`` reconstruction inside the workflow layer.
""" import logging from typing import final, override -from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext -from core.app.llm import deduct_llm_quota, ensure_llm_quota_available +from core.app.llm import deduct_llm_quota_for_model, ensure_llm_quota_available_for_model from core.errors.error import QuotaExceededError -from core.model_manager import ModelInstance from graphon.enums import BuiltinNodeTypes from graphon.graph_engine.entities.commands import AbortCommand, CommandType from graphon.graph_engine.layers import GraphEngineLayer @@ -18,14 +21,25 @@ from graphon.graph_events import GraphEngineEvent, GraphNodeEventBase, NodeRunSu from graphon.nodes.base.node import Node logger = logging.getLogger(__name__) +_QUOTA_NODE_TYPES = frozenset( + [ + BuiltinNodeTypes.LLM, + BuiltinNodeTypes.PARAMETER_EXTRACTOR, + BuiltinNodeTypes.QUESTION_CLASSIFIER, + ] +) @final class LLMQuotaLayer(GraphEngineLayer): - """Graph layer that applies LLM quota deduction after node execution.""" + """Graph layer that applies tenant-scoped quota checks to LLM-backed nodes.""" - def __init__(self) -> None: + tenant_id: str + _abort_sent: bool + + def __init__(self, tenant_id: str) -> None: super().__init__() + self.tenant_id = tenant_id self._abort_sent = False @override @@ -45,12 +59,20 @@ class LLMQuotaLayer(GraphEngineLayer): if self._abort_sent: return - model_instance = self._extract_model_instance(node) - if model_instance is None: + if not self._supports_quota(node): return + model_identity = self._extract_model_identity_from_node(node) + if model_identity is None: + return + + provider, model_name = model_identity try: - ensure_llm_quota_available(model_instance=model_instance) + ensure_llm_quota_available_for_model( + tenant_id=self.tenant_id, + provider=provider, + model=model_name, + ) except QuotaExceededError as exc: self._set_stop_event(node) self._send_abort_command(reason=str(exc)) @@ -60,18 +82,20 @@ class LLMQuotaLayer(GraphEngineLayer): def on_node_run_end( self, node: Node, error: Exception | None, result_event: GraphNodeEventBase | None = None ) -> None: - if error is not None or not isinstance(result_event, NodeRunSucceededEvent): + if error is not None or not isinstance(result_event, NodeRunSucceededEvent) or not self._supports_quota(node): return - model_instance = self._extract_model_instance(node) - if model_instance is None: + model_identity = self._extract_model_identity_from_result_event(result_event) + if model_identity is None: return + provider, model_name = model_identity + try: - dify_ctx = DifyRunContext.model_validate(node.require_run_context_value(DIFY_RUN_CONTEXT_KEY)) - deduct_llm_quota( - tenant_id=dify_ctx.tenant_id, - model_instance=model_instance, + deduct_llm_quota_for_model( + tenant_id=self.tenant_id, + provider=provider, + model=model_name, usage=result_event.node_run_result.llm_usage, ) except QuotaExceededError as exc: @@ -103,35 +127,38 @@ class LLMQuotaLayer(GraphEngineLayer): logger.exception("Failed to send quota abort command") @staticmethod - def _extract_model_instance(node: Node) -> ModelInstance | None: - match node.node_type: - case BuiltinNodeTypes.LLM | BuiltinNodeTypes.PARAMETER_EXTRACTOR | BuiltinNodeTypes.QUESTION_CLASSIFIER: - pass - case _: - return None + def _supports_quota(node: Node) -> bool: + return node.node_type in _QUOTA_NODE_TYPES - try: - model_instance = getattr(node, "model_instance", None) - except AttributeError: + @staticmethod + def _extract_model_identity_from_result_event(result_event: 
NodeRunSucceededEvent) -> tuple[str, str] | None: + provider = result_event.node_run_result.inputs.get("model_provider") + model_name = result_event.node_run_result.inputs.get("model_name") + if isinstance(provider, str) and provider and isinstance(model_name, str) and model_name: + return provider, model_name + return None + + @staticmethod + def _extract_model_identity_from_node(node: Node) -> tuple[str, str] | None: + node_data = getattr(node, "node_data", None) + if node_data is None: + node_data = getattr(node, "data", None) + + model_config = getattr(node_data, "model", None) + if model_config is None: logger.warning( - "LLMQuotaLayer skipped quota deduction because node does not expose a model instance, node_id=%s", + "LLMQuotaLayer skipped quota handling because node model config is missing, node_id=%s", node.id, ) return None - if isinstance(model_instance, ModelInstance): - return model_instance - - raw_model_instance = getattr(model_instance, "_model_instance", None) - if isinstance(raw_model_instance, ModelInstance): - return raw_model_instance - - private_model_instance = getattr(node, "_model_instance", None) - if isinstance(private_model_instance, ModelInstance): - return private_model_instance - - wrapped_private_model_instance = getattr(private_model_instance, "_model_instance", None) - if isinstance(wrapped_private_model_instance, ModelInstance): - return wrapped_private_model_instance + provider = getattr(model_config, "provider", None) + model_name = getattr(model_config, "name", None) + if isinstance(provider, str) and provider and isinstance(model_name, str) and model_name: + return provider, model_name + logger.warning( + "LLMQuotaLayer skipped quota handling because node model identity is invalid, node_id=%s", + node.id, + ) return None diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index 4e2f603e5b..3019704dac 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -46,6 +46,11 @@ _file_access_controller = DatabaseFileAccessController() class _WorkflowChildEngineBuilder: + tenant_id: str + + def __init__(self, *, tenant_id: str) -> None: + self.tenant_id = tenant_id + @staticmethod def _has_node_id(graph_config: Mapping[str, Any], node_id: str) -> bool | None: """ @@ -107,7 +112,7 @@ class _WorkflowChildEngineBuilder: config=config, child_engine_builder=self, ) - child_engine.layer(LLMQuotaLayer()) + child_engine.layer(LLMQuotaLayer(tenant_id=self.tenant_id)) return child_engine @@ -176,7 +181,7 @@ class WorkflowEntry: self.command_channel = command_channel execution_context = capture_current_context() graph_runtime_state.execution_context = execution_context - self._child_engine_builder = _WorkflowChildEngineBuilder() + self._child_engine_builder = _WorkflowChildEngineBuilder(tenant_id=tenant_id) self.graph_engine = GraphEngine( workflow_id=workflow_id, graph=graph, @@ -208,7 +213,7 @@ class WorkflowEntry: max_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS, max_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME ) self.graph_engine.layer(limits_layer) - self.graph_engine.layer(LLMQuotaLayer()) + self.graph_engine.layer(LLMQuotaLayer(tenant_id=tenant_id)) # Add observability layer when OTel is enabled if dify_config.ENABLE_OTEL or is_instrument_flag_enabled(): diff --git a/api/tests/unit_tests/core/app/test_llm_quota.py b/api/tests/unit_tests/core/app/test_llm_quota.py new file mode 100644 index 0000000000..b94e3aa758 --- /dev/null +++ b/api/tests/unit_tests/core/app/test_llm_quota.py @@ -0,0 
+1,117 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +from core.app.llm.quota import ( + deduct_llm_quota, + deduct_llm_quota_for_model, + ensure_llm_quota_available, + ensure_llm_quota_available_for_model, +) +from core.entities.model_entities import ModelStatus +from core.entities.provider_entities import ProviderQuotaType, QuotaUnit +from core.errors.error import QuotaExceededError +from graphon.model_runtime.entities.llm_entities import LLMUsage +from models.provider import ProviderType + + +def test_ensure_llm_quota_available_for_model_raises_when_system_model_is_exhausted() -> None: + provider_configuration = SimpleNamespace( + using_provider_type=ProviderType.SYSTEM, + get_provider_model=MagicMock(return_value=SimpleNamespace(status=ModelStatus.QUOTA_EXCEEDED)), + ) + provider_manager = MagicMock() + provider_manager.get_configurations.return_value.get.return_value = provider_configuration + + with ( + patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager), + pytest.raises(QuotaExceededError, match="Model provider openai quota exceeded."), + ): + ensure_llm_quota_available_for_model( + tenant_id="tenant-id", + provider="openai", + model="gpt-4o", + ) + + +def test_deduct_llm_quota_for_model_uses_identity_based_trial_billing() -> None: + usage = LLMUsage.empty_usage() + usage.total_tokens = 42 + + provider_configuration = SimpleNamespace( + using_provider_type=ProviderType.SYSTEM, + system_configuration=SimpleNamespace( + current_quota_type=ProviderQuotaType.TRIAL, + quota_configurations=[ + SimpleNamespace( + quota_type=ProviderQuotaType.TRIAL, + quota_unit=QuotaUnit.TOKENS, + quota_limit=100, + ) + ], + ), + ) + provider_manager = MagicMock() + provider_manager.get_configurations.return_value.get.return_value = provider_configuration + + with ( + patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager), + patch("services.credit_pool_service.CreditPoolService.check_and_deduct_credits") as mock_deduct_credits, + ): + deduct_llm_quota_for_model( + tenant_id="tenant-id", + provider="openai", + model="gpt-4o", + usage=usage, + ) + + mock_deduct_credits.assert_called_once_with( + tenant_id="tenant-id", + credits_required=42, + ) + + +def test_ensure_llm_quota_available_wrapper_warns_and_delegates() -> None: + model_instance = SimpleNamespace( + provider="openai", + model_name="gpt-4o", + provider_model_bundle=SimpleNamespace(configuration=SimpleNamespace(tenant_id="tenant-id")), + ) + + with ( + pytest.deprecated_call(match="ensure_llm_quota_available\\(model_instance=.*deprecated"), + patch("core.app.llm.quota.ensure_llm_quota_available_for_model") as mock_ensure, + ): + ensure_llm_quota_available(model_instance=model_instance) + + mock_ensure.assert_called_once_with( + tenant_id="tenant-id", + provider="openai", + model="gpt-4o", + ) + + +def test_deduct_llm_quota_wrapper_warns_and_delegates() -> None: + usage = LLMUsage.empty_usage() + model_instance = SimpleNamespace( + provider="openai", + model_name="gpt-4o", + ) + + with ( + pytest.deprecated_call(match="deduct_llm_quota\\(tenant_id=.*deprecated"), + patch("core.app.llm.quota.deduct_llm_quota_for_model") as mock_deduct, + ): + deduct_llm_quota( + tenant_id="tenant-id", + model_instance=model_instance, + usage=usage, + ) + + mock_deduct.assert_called_once_with( + tenant_id="tenant-id", + provider="openai", + model="gpt-4o", + usage=usage, + ) diff --git 
a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py index 5d6667257f..91ef57f27a 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py @@ -3,10 +3,8 @@ from datetime import datetime from types import SimpleNamespace from unittest.mock import MagicMock, patch -from core.app.entities.app_invoke_entities import DifyRunContext, InvokeFrom, UserFrom from core.app.workflow.layers.llm_quota import LLMQuotaLayer from core.errors.error import QuotaExceededError -from core.model_manager import ModelInstance from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus from graphon.graph_engine.entities.commands import CommandType from graphon.graph_events import NodeRunSucceededEvent @@ -14,17 +12,7 @@ from graphon.model_runtime.entities.llm_entities import LLMUsage from graphon.node_events import NodeRunResult -def _build_dify_context() -> DifyRunContext: - return DifyRunContext( - tenant_id="tenant-id", - app_id="app-id", - user_id="user-id", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, - ) - - -def _build_succeeded_event() -> NodeRunSucceededEvent: +def _build_succeeded_event(*, provider: str = "openai", model_name: str = "gpt-4o") -> NodeRunSucceededEvent: return NodeRunSucceededEvent( id="execution-id", node_id="llm-node-id", @@ -32,89 +20,80 @@ def _build_succeeded_event() -> NodeRunSucceededEvent: start_at=datetime.now(), node_run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs={"question": "hello"}, + inputs={ + "question": "hello", + "model_provider": provider, + "model_name": model_name, + }, llm_usage=LLMUsage.empty_usage(), ), ) -def _build_wrapped_model_instance() -> tuple[SimpleNamespace, ModelInstance]: - raw_model_instance = ModelInstance.__new__(ModelInstance) - return SimpleNamespace(_model_instance=raw_model_instance), raw_model_instance +def _build_public_model_identity(*, provider: str = "openai", model_name: str = "gpt-4o") -> SimpleNamespace: + return SimpleNamespace(provider=provider, name=model_name) + + +def _build_node(*, node_type: BuiltinNodeTypes = BuiltinNodeTypes.LLM) -> MagicMock: + node = MagicMock() + node.id = "node-id" + node.execution_id = "execution-id" + node.node_type = node_type + node.node_data = SimpleNamespace(model=_build_public_model_identity()) + node.model_instance = SimpleNamespace(provider="stale-provider", model_name="stale-model") + return node def test_deduct_quota_called_for_successful_llm_node() -> None: - layer = LLMQuotaLayer() - node = MagicMock() - node.id = "llm-node-id" - node.execution_id = "execution-id" - node.node_type = BuiltinNodeTypes.LLM - node.tenant_id = "tenant-id" - node.require_run_context_value.return_value = _build_dify_context() - node.model_instance, raw_model_instance = _build_wrapped_model_instance() - + layer = LLMQuotaLayer(tenant_id="tenant-id") + node = _build_node(node_type=BuiltinNodeTypes.LLM) result_event = _build_succeeded_event() - with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota", autospec=True) as mock_deduct: + + with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota_for_model", autospec=True) as mock_deduct: layer.on_node_run_end(node=node, error=None, result_event=result_event) mock_deduct.assert_called_once_with( tenant_id="tenant-id", - model_instance=raw_model_instance, + provider="openai", + model="gpt-4o", 
         usage=result_event.node_run_result.llm_usage,
     )
 
 
 def test_deduct_quota_called_for_question_classifier_node() -> None:
-    layer = LLMQuotaLayer()
-    node = MagicMock()
-    node.id = "question-classifier-node-id"
-    node.execution_id = "execution-id"
-    node.node_type = BuiltinNodeTypes.QUESTION_CLASSIFIER
-    node.tenant_id = "tenant-id"
-    node.require_run_context_value.return_value = _build_dify_context()
-    node.model_instance, raw_model_instance = _build_wrapped_model_instance()
+    layer = LLMQuotaLayer(tenant_id="tenant-id")
+    node = _build_node(node_type=BuiltinNodeTypes.QUESTION_CLASSIFIER)
+    result_event = _build_succeeded_event(provider="anthropic", model_name="claude-3-7-sonnet")
 
-    result_event = _build_succeeded_event()
-    with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota", autospec=True) as mock_deduct:
+    with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota_for_model", autospec=True) as mock_deduct:
         layer.on_node_run_end(node=node, error=None, result_event=result_event)
 
     mock_deduct.assert_called_once_with(
         tenant_id="tenant-id",
-        model_instance=raw_model_instance,
+        provider="anthropic",
+        model="claude-3-7-sonnet",
         usage=result_event.node_run_result.llm_usage,
     )
 
 
 def test_non_llm_node_is_ignored() -> None:
-    layer = LLMQuotaLayer()
-    node = MagicMock()
-    node.id = "start-node-id"
-    node.execution_id = "execution-id"
-    node.node_type = BuiltinNodeTypes.START
-    node.tenant_id = "tenant-id"
-    node.require_run_context_value.return_value = _build_dify_context()
-    node._model_instance = object()
-
+    layer = LLMQuotaLayer(tenant_id="tenant-id")
+    node = _build_node(node_type=BuiltinNodeTypes.START)
     result_event = _build_succeeded_event()
-    with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota", autospec=True) as mock_deduct:
+
+    with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota_for_model", autospec=True) as mock_deduct:
         layer.on_node_run_end(node=node, error=None, result_event=result_event)
 
     mock_deduct.assert_not_called()
 
 
 def test_quota_error_is_handled_in_layer() -> None:
-    layer = LLMQuotaLayer()
-    node = MagicMock()
-    node.id = "llm-node-id"
-    node.execution_id = "execution-id"
-    node.node_type = BuiltinNodeTypes.LLM
-    node.tenant_id = "tenant-id"
-    node.require_run_context_value.return_value = _build_dify_context()
-    node.model_instance = object()
-
+    layer = LLMQuotaLayer(tenant_id="tenant-id")
+    node = _build_node(node_type=BuiltinNodeTypes.LLM)
     result_event = _build_succeeded_event()
+
     with patch(
-        "core.app.workflow.layers.llm_quota.deduct_llm_quota",
+        "core.app.workflow.layers.llm_quota.deduct_llm_quota_for_model",
         autospec=True,
         side_effect=ValueError("quota exceeded"),
     ):
@@ -122,23 +101,17 @@ def test_quota_error_is_handled_in_layer() -> None:
 
 
 def test_quota_deduction_exceeded_aborts_workflow_immediately() -> None:
-    layer = LLMQuotaLayer()
+    layer = LLMQuotaLayer(tenant_id="tenant-id")
     stop_event = threading.Event()
     layer.command_channel = MagicMock()
 
-    node = MagicMock()
-    node.id = "llm-node-id"
-    node.execution_id = "execution-id"
-    node.node_type = BuiltinNodeTypes.LLM
-    node.tenant_id = "tenant-id"
-    node.require_run_context_value.return_value = _build_dify_context()
-    node.model_instance, _ = _build_wrapped_model_instance()
+    node = _build_node(node_type=BuiltinNodeTypes.LLM)
     node.graph_runtime_state = MagicMock()
     node.graph_runtime_state.stop_event = stop_event
 
     result_event = _build_succeeded_event()
     with patch(
-        "core.app.workflow.layers.llm_quota.deduct_llm_quota",
+        "core.app.workflow.layers.llm_quota.deduct_llm_quota_for_model",
         autospec=True,
         side_effect=QuotaExceededError("No credits remaining"),
     ):
@@ -152,19 +125,16 @@ def test_quota_deduction_exceeded_aborts_workflow_immediately() -> None:
 
 
 def test_quota_precheck_failure_aborts_workflow_immediately() -> None:
-    layer = LLMQuotaLayer()
+    layer = LLMQuotaLayer(tenant_id="tenant-id")
    stop_event = threading.Event()
     layer.command_channel = MagicMock()
 
-    node = MagicMock()
-    node.id = "llm-node-id"
-    node.node_type = BuiltinNodeTypes.LLM
-    node.model_instance, _ = _build_wrapped_model_instance()
+    node = _build_node(node_type=BuiltinNodeTypes.LLM)
     node.graph_runtime_state = MagicMock()
     node.graph_runtime_state.stop_event = stop_event
 
     with patch(
-        "core.app.workflow.layers.llm_quota.ensure_llm_quota_available",
+        "core.app.workflow.layers.llm_quota.ensure_llm_quota_available_for_model",
         autospec=True,
         side_effect=QuotaExceededError("Model provider openai quota exceeded."),
     ):
@@ -178,20 +148,44 @@ def test_quota_precheck_failure_aborts_workflow_immediately() -> None:
 
 
 def test_quota_precheck_passes_without_abort() -> None:
-    layer = LLMQuotaLayer()
+    layer = LLMQuotaLayer(tenant_id="tenant-id")
     stop_event = threading.Event()
     layer.command_channel = MagicMock()
 
-    node = MagicMock()
-    node.id = "llm-node-id"
-    node.node_type = BuiltinNodeTypes.LLM
-    node.model_instance, raw_model_instance = _build_wrapped_model_instance()
+    node = _build_node(node_type=BuiltinNodeTypes.LLM)
     node.graph_runtime_state = MagicMock()
     node.graph_runtime_state.stop_event = stop_event
 
-    with patch("core.app.workflow.layers.llm_quota.ensure_llm_quota_available", autospec=True) as mock_check:
+    with patch("core.app.workflow.layers.llm_quota.ensure_llm_quota_available_for_model", autospec=True) as mock_check:
         layer.on_node_run_start(node)
 
     assert not stop_event.is_set()
-    mock_check.assert_called_once_with(model_instance=raw_model_instance)
+    mock_check.assert_called_once_with(
+        tenant_id="tenant-id",
+        provider="openai",
+        model="gpt-4o",
+    )
     layer.command_channel.send_command.assert_not_called()
+
+
+def test_precheck_requires_public_node_model_config() -> None:
+    layer = LLMQuotaLayer(tenant_id="tenant-id")
+    node = _build_node(node_type=BuiltinNodeTypes.LLM)
+    node.node_data = SimpleNamespace()
+
+    with patch("core.app.workflow.layers.llm_quota.ensure_llm_quota_available_for_model", autospec=True) as mock_check:
+        layer.on_node_run_start(node)
+
+    mock_check.assert_not_called()
+
+
+def test_deduction_requires_public_event_model_identity() -> None:
+    layer = LLMQuotaLayer(tenant_id="tenant-id")
+    node = _build_node(node_type=BuiltinNodeTypes.LLM)
+    result_event = _build_succeeded_event()
+    result_event.node_run_result.inputs = {"question": "hello"}
+
+    with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota_for_model", autospec=True) as mock_deduct:
+        layer.on_node_run_end(node=node, error=None, result_event=result_event)
+
+    mock_deduct.assert_not_called()
diff --git a/api/tests/unit_tests/core/workflow/test_workflow_entry_helpers.py b/api/tests/unit_tests/core/workflow/test_workflow_entry_helpers.py
index c830255926..cf60d99e82 100644
--- a/api/tests/unit_tests/core/workflow/test_workflow_entry_helpers.py
+++ b/api/tests/unit_tests/core/workflow/test_workflow_entry_helpers.py
@@ -7,7 +7,6 @@ import pytest
 
 from core.app.apps.exc import GenerateTaskStoppedError
 from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
-from core.model_manager import ModelInstance
 from core.workflow import workflow_entry
 from core.workflow.system_variables import default_system_variables
 from graphon.entities.base_node_data import BaseNodeData
@@ -16,10 +15,12 @@ from graphon.errors import WorkflowNodeRunFailedError
 from graphon.file import File, FileTransferMethod, FileType
 from graphon.graph import Graph
 from graphon.graph_events import GraphRunFailedEvent
-from graphon.model_runtime.entities.llm_entities import LLMUsage
+from graphon.model_runtime.entities.llm_entities import LLMMode, LLMUsage
 from graphon.node_events import NodeRunResult
 from graphon.nodes import BuiltinNodeTypes
 from graphon.nodes.base.node import Node
+from graphon.nodes.llm.entities import ContextConfig, LLMNodeData, ModelConfig
+from graphon.nodes.question_classifier.entities import QuestionClassifierNodeData
 from graphon.runtime import ChildGraphNotFoundError, VariablePool
 from graphon.variables.variables import StringVariable
 from tests.workflow_test_utils import build_test_graph_init_params, build_test_variable_pool
@@ -29,9 +30,30 @@ def _build_typed_node_config(node_type: NodeType):
     return {"id": "node-id", "data": BaseNodeData(type=node_type)}
 
 
-def _build_wrapped_model_instance() -> tuple[SimpleNamespace, ModelInstance]:
-    raw_model_instance = ModelInstance.__new__(ModelInstance)
-    return SimpleNamespace(_model_instance=raw_model_instance), raw_model_instance
+def _build_model_config(*, provider: str = "openai", model_name: str = "gpt-4o") -> ModelConfig:
+    return ModelConfig(provider=provider, name=model_name, mode=LLMMode.CHAT)
+
+
+def _build_llm_node_data(*, provider: str = "openai", model_name: str = "gpt-4o") -> LLMNodeData:
+    return LLMNodeData(
+        type=BuiltinNodeTypes.LLM,
+        title="Child Model",
+        model=_build_model_config(provider=provider, model_name=model_name),
+        prompt_template=[],
+        context=ContextConfig(enabled=False),
+    )
+
+
+def _build_question_classifier_node_data(
+    *, provider: str = "openai", model_name: str = "gpt-4o"
+) -> QuestionClassifierNodeData:
+    return QuestionClassifierNodeData(
+        type=BuiltinNodeTypes.QUESTION_CLASSIFIER,
+        title="Child Model",
+        query_variable_selector=["sys", "query"],
+        model=_build_model_config(provider=provider, model_name=model_name),
+        classes=[],
+    )
 
 
 class _FakeModelNodeMixin:
@@ -40,22 +62,26 @@ class _FakeModelNodeMixin:
         return "1"
 
     def post_init(self) -> None:
-        self.model_instance, self.raw_model_instance = _build_wrapped_model_instance()
+        self.model_instance = SimpleNamespace(provider="stale-provider", model_name="stale-model")
         self.usage_snapshot = LLMUsage.empty_usage()
         self.usage_snapshot.total_tokens = 1
 
     def _run(self) -> NodeRunResult:
         return NodeRunResult(
             status=WorkflowNodeExecutionStatus.SUCCEEDED,
+            inputs={
+                "model_provider": self.node_data.model.provider,
+                "model_name": self.node_data.model.name,
+            },
             llm_usage=self.usage_snapshot,
         )
 
 
-class _FakeLLMNode(_FakeModelNodeMixin, Node[BaseNodeData]):
+class _FakeLLMNode(_FakeModelNodeMixin, Node[LLMNodeData]):
     node_type = BuiltinNodeTypes.LLM
 
 
-class _FakeQuestionClassifierNode(_FakeModelNodeMixin, Node[BaseNodeData]):
+class _FakeQuestionClassifierNode(_FakeModelNodeMixin, Node[QuestionClassifierNodeData]):
     node_type = BuiltinNodeTypes.QUESTION_CLASSIFIER
 
 
@@ -75,7 +101,7 @@ class TestWorkflowChildEngineBuilder:
         assert result is expected
 
     def test_build_child_engine_raises_when_root_node_is_missing(self):
-        builder = workflow_entry._WorkflowChildEngineBuilder()
+        builder = workflow_entry._WorkflowChildEngineBuilder(tenant_id="tenant-id")
         graph_init_params = SimpleNamespace(graph_config={"nodes": []})
         parent_graph_runtime_state = SimpleNamespace(
             execution_context=sentinel.execution_context,
@@ -92,7 +118,7 @@ class TestWorkflowChildEngineBuilder:
         )
 
     def test_build_child_engine_constructs_graph_engine_with_quota_layer_only(self):
-        builder = workflow_entry._WorkflowChildEngineBuilder()
+        builder = workflow_entry._WorkflowChildEngineBuilder(tenant_id="tenant-id")
         graph_init_params = SimpleNamespace(graph_config={"nodes": [{"id": "root"}]})
         parent_graph_runtime_state = SimpleNamespace(
             execution_context=sentinel.execution_context,
@@ -114,7 +140,7 @@ class TestWorkflowChildEngineBuilder:
             patch.object(workflow_entry, "GraphEngine", return_value=child_engine) as graph_engine_cls,
             patch.object(workflow_entry, "GraphEngineConfig", return_value=sentinel.graph_engine_config),
             patch.object(workflow_entry, "InMemoryChannel", return_value=sentinel.command_channel),
-            patch.object(workflow_entry, "LLMQuotaLayer", return_value=sentinel.llm_quota_layer),
+            patch.object(workflow_entry, "LLMQuotaLayer", return_value=sentinel.llm_quota_layer) as llm_quota_layer_cls,
         ):
             result = builder.build_child_engine(
                 workflow_id="workflow-id",
@@ -147,11 +173,12 @@ class TestWorkflowChildEngineBuilder:
             config=sentinel.graph_engine_config,
             child_engine_builder=builder,
         )
+        llm_quota_layer_cls.assert_called_once_with(tenant_id="tenant-id")
         assert child_engine.layer.call_args_list == [((sentinel.llm_quota_layer,), {})]
 
     @pytest.mark.parametrize("node_cls", [_FakeLLMNode, _FakeQuestionClassifierNode])
     def test_build_child_engine_runs_llm_quota_layer_for_child_model_nodes(self, node_cls):
-        builder = workflow_entry._WorkflowChildEngineBuilder()
+        builder = workflow_entry._WorkflowChildEngineBuilder(tenant_id="tenant-id")
         graph_init_params = build_test_graph_init_params(
             graph_config={"nodes": [{"id": "root"}], "edges": []},
         )
@@ -163,12 +190,10 @@ class TestWorkflowChildEngineBuilder:
 
         def build_graph(*, graph_config, node_factory, root_node_id):
             _ = graph_config
+            node_data = _build_llm_node_data() if node_cls is _FakeLLMNode else _build_question_classifier_node_data()
             node = node_cls(
                 node_id=root_node_id,
-                data=BaseNodeData(
-                    type=node_cls.node_type,
-                    title="Child Model",
-                ),
+                data=node_data,
                 graph_init_params=node_factory.graph_init_params,
                 graph_runtime_state=node_factory.graph_runtime_state,
             )
@@ -191,8 +216,8 @@ class TestWorkflowChildEngineBuilder:
                 ),
             ),
             patch.object(workflow_entry.Graph, "init", side_effect=build_graph),
-            patch("core.app.workflow.layers.llm_quota.ensure_llm_quota_available") as ensure_quota,
-            patch("core.app.workflow.layers.llm_quota.deduct_llm_quota") as deduct_quota,
+            patch("core.app.workflow.layers.llm_quota.ensure_llm_quota_available_for_model") as ensure_quota,
+            patch("core.app.workflow.layers.llm_quota.deduct_llm_quota_for_model") as deduct_quota,
         ):
             child_engine = builder.build_child_engine(
                 workflow_id="workflow-id",
@@ -203,10 +228,15 @@ class TestWorkflowChildEngineBuilder:
         list(child_engine.run())
 
         node = created_node["node"]
-        ensure_quota.assert_called_once_with(model_instance=node.raw_model_instance)
+        ensure_quota.assert_called_once_with(
+            tenant_id="tenant-id",
+            provider=node.node_data.model.provider,
+            model=node.node_data.model.name,
+        )
         deduct_quota.assert_called_once_with(
-            tenant_id="tenant",
-            model_instance=node.raw_model_instance,
+            tenant_id="tenant-id",
+            provider=node.node_data.model.provider,
+            model=node.node_data.model.name,
             usage=node.usage_snapshot,
         )
 
@@ -252,7 +282,7 @@ class TestWorkflowEntryInit:
                 "ExecutionLimitsLayer",
                 return_value=execution_limits_layer,
             ) as execution_limits_layer_cls,
-            patch.object(workflow_entry, "LLMQuotaLayer", return_value=llm_quota_layer),
+            patch.object(workflow_entry, "LLMQuotaLayer", return_value=llm_quota_layer) as llm_quota_layer_cls,
             patch.object(workflow_entry, "ObservabilityLayer", return_value=observability_layer),
         ):
             entry = workflow_entry.WorkflowEntry(
@@ -291,6 +321,7 @@ class TestWorkflowEntryInit:
             max_steps=workflow_entry.dify_config.WORKFLOW_MAX_EXECUTION_STEPS,
             max_time=workflow_entry.dify_config.WORKFLOW_MAX_EXECUTION_TIME,
         )
+        llm_quota_layer_cls.assert_called_once_with(tenant_id="tenant-id")
         assert graph_engine.layer.call_args_list == [
             ((debug_layer,), {}),
             ((execution_limits_layer,), {}),