refactor: rename model instance extraction method to improve clarity and update related logic in LLMQuotaLayer; enhance unit tests for model instance handling

This commit is contained in:
Novice 2026-03-24 15:33:48 +08:00
parent e6c4bf7320
commit 7e65659239
No known key found for this signature in database
GPG Key ID: A253106A7475AA3E
3 changed files with 86 additions and 68 deletions

View File

@ -21,6 +21,12 @@ from dify_graph.nodes.base.node import Node
logger = logging.getLogger(__name__)
_LLM_LIKE_NODE_TYPES = {
BuiltinNodeTypes.LLM,
BuiltinNodeTypes.PARAMETER_EXTRACTOR,
BuiltinNodeTypes.QUESTION_CLASSIFIER,
}
@final
class LLMQuotaLayer(GraphEngineLayer):
@ -47,7 +53,7 @@ class LLMQuotaLayer(GraphEngineLayer):
if self._abort_sent:
return
model_instance = self._extract_model_instance(node)
model_instance = self._build_model_instance(node)
if model_instance is None:
return
@ -65,7 +71,7 @@ class LLMQuotaLayer(GraphEngineLayer):
if error is not None or not isinstance(result_event, NodeRunSucceededEvent):
return
model_instance = self._extract_model_instance(node)
model_instance = self._build_model_instance(node)
if model_instance is None:
return
@ -105,16 +111,22 @@ class LLMQuotaLayer(GraphEngineLayer):
logger.exception("Failed to send quota abort command")
@staticmethod
def _extract_model_instance(node: Node) -> ModelInstance | None:
match node.node_type:
case BuiltinNodeTypes.LLM | BuiltinNodeTypes.PARAMETER_EXTRACTOR | BuiltinNodeTypes.QUESTION_CLASSIFIER:
instance: ModelInstance | None = getattr(node, "model_instance", None)
if instance is not None:
return instance
logger.warning(
"LLMQuotaLayer skipped quota deduction because node does not expose a model instance, node_id=%s",
node.id,
)
return None
case _:
return None
def _build_model_instance(node: Node) -> ModelInstance | None:
if node.node_type not in _LLM_LIKE_NODE_TYPES:
return None
model_config = getattr(node.node_data, "model", None)
if model_config is None:
return None
try:
from dify_graph.nodes.llm.llm_utils import fetch_model_config
model_instance, _ = fetch_model_config(
tenant_id=node.tenant_id,
node_data_model=model_config,
)
return model_instance
except Exception:
logger.warning("Failed to build ModelInstance for quota check, node_id=%s", node.id, exc_info=True)
return None

View File

@ -10,6 +10,8 @@ from dify_graph.graph_events.node import NodeRunSucceededEvent
from dify_graph.model_runtime.entities.llm_entities import LLMUsage
from dify_graph.node_events import NodeRunResult
_FETCH_MODEL_CONFIG_PATH = "dify_graph.nodes.llm.llm_utils.fetch_model_config"
def _build_succeeded_event() -> NodeRunSucceededEvent:
return NodeRunSucceededEvent(
@ -25,44 +27,52 @@ def _build_succeeded_event() -> NodeRunSucceededEvent:
)
def test_deduct_quota_called_for_successful_llm_node() -> None:
layer = LLMQuotaLayer()
def _make_llm_node(*, node_type: BuiltinNodeTypes = BuiltinNodeTypes.LLM) -> MagicMock:
node = MagicMock()
node.id = "llm-node-id"
node.execution_id = "execution-id"
node.node_type = BuiltinNodeTypes.LLM
node.node_type = node_type
node.tenant_id = "tenant-id"
node.require_dify_context.return_value.tenant_id = "tenant-id"
node.model_instance = object()
node.node_data.model = MagicMock(name="model-config")
return node
def test_deduct_quota_called_for_successful_llm_node() -> None:
layer = LLMQuotaLayer()
node = _make_llm_node()
fake_instance = MagicMock(name="model-instance")
result_event = _build_succeeded_event()
with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota", autospec=True) as mock_deduct:
with (
patch(_FETCH_MODEL_CONFIG_PATH, return_value=(fake_instance, MagicMock())),
patch("core.app.workflow.layers.llm_quota.deduct_llm_quota", autospec=True) as mock_deduct,
):
layer.on_node_run_end(node=node, error=None, result_event=result_event)
mock_deduct.assert_called_once_with(
tenant_id="tenant-id",
model_instance=node.model_instance,
model_instance=fake_instance,
usage=result_event.node_run_result.llm_usage,
)
def test_deduct_quota_called_for_question_classifier_node() -> None:
layer = LLMQuotaLayer()
node = MagicMock()
node = _make_llm_node(node_type=BuiltinNodeTypes.QUESTION_CLASSIFIER)
node.id = "question-classifier-node-id"
node.execution_id = "execution-id"
node.node_type = BuiltinNodeTypes.QUESTION_CLASSIFIER
node.tenant_id = "tenant-id"
node.require_dify_context.return_value.tenant_id = "tenant-id"
node.model_instance = object()
fake_instance = MagicMock(name="model-instance")
result_event = _build_succeeded_event()
with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota", autospec=True) as mock_deduct:
with (
patch(_FETCH_MODEL_CONFIG_PATH, return_value=(fake_instance, MagicMock())),
patch("core.app.workflow.layers.llm_quota.deduct_llm_quota", autospec=True) as mock_deduct,
):
layer.on_node_run_end(node=node, error=None, result_event=result_event)
mock_deduct.assert_called_once_with(
tenant_id="tenant-id",
model_instance=node.model_instance,
model_instance=fake_instance,
usage=result_event.node_run_result.llm_usage,
)
@ -74,8 +84,6 @@ def test_non_llm_node_is_ignored() -> None:
node.execution_id = "execution-id"
node.node_type = BuiltinNodeTypes.START
node.tenant_id = "tenant-id"
node.require_dify_context.return_value.tenant_id = "tenant-id"
node._model_instance = object()
result_event = _build_succeeded_event()
with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota", autospec=True) as mock_deduct:
@ -86,19 +94,17 @@ def test_non_llm_node_is_ignored() -> None:
def test_quota_error_is_handled_in_layer() -> None:
layer = LLMQuotaLayer()
node = MagicMock()
node.id = "llm-node-id"
node.execution_id = "execution-id"
node.node_type = BuiltinNodeTypes.LLM
node.tenant_id = "tenant-id"
node.require_dify_context.return_value.tenant_id = "tenant-id"
node.model_instance = object()
node = _make_llm_node()
fake_instance = MagicMock(name="model-instance")
result_event = _build_succeeded_event()
with patch(
"core.app.workflow.layers.llm_quota.deduct_llm_quota",
autospec=True,
side_effect=ValueError("quota exceeded"),
with (
patch(_FETCH_MODEL_CONFIG_PATH, return_value=(fake_instance, MagicMock())),
patch(
"core.app.workflow.layers.llm_quota.deduct_llm_quota",
autospec=True,
side_effect=ValueError("quota exceeded"),
),
):
layer.on_node_run_end(node=node, error=None, result_event=result_event)
@ -108,21 +114,19 @@ def test_quota_deduction_exceeded_aborts_workflow_immediately() -> None:
stop_event = threading.Event()
layer.command_channel = MagicMock()
node = MagicMock()
node.id = "llm-node-id"
node.execution_id = "execution-id"
node.node_type = BuiltinNodeTypes.LLM
node.tenant_id = "tenant-id"
node.require_dify_context.return_value.tenant_id = "tenant-id"
node.model_instance = object()
node = _make_llm_node()
node.graph_runtime_state = MagicMock()
node.graph_runtime_state.stop_event = stop_event
fake_instance = MagicMock(name="model-instance")
result_event = _build_succeeded_event()
with patch(
"core.app.workflow.layers.llm_quota.deduct_llm_quota",
autospec=True,
side_effect=QuotaExceededError("No credits remaining"),
with (
patch(_FETCH_MODEL_CONFIG_PATH, return_value=(fake_instance, MagicMock())),
patch(
"core.app.workflow.layers.llm_quota.deduct_llm_quota",
autospec=True,
side_effect=QuotaExceededError("No credits remaining"),
),
):
layer.on_node_run_end(node=node, error=None, result_event=result_event)
@ -138,17 +142,18 @@ def test_quota_precheck_failure_aborts_workflow_immediately() -> None:
stop_event = threading.Event()
layer.command_channel = MagicMock()
node = MagicMock()
node.id = "llm-node-id"
node.node_type = BuiltinNodeTypes.LLM
node.model_instance = object()
node = _make_llm_node()
node.graph_runtime_state = MagicMock()
node.graph_runtime_state.stop_event = stop_event
fake_instance = MagicMock(name="model-instance")
with patch(
"core.app.workflow.layers.llm_quota.ensure_llm_quota_available",
autospec=True,
side_effect=QuotaExceededError("Model provider openai quota exceeded."),
with (
patch(_FETCH_MODEL_CONFIG_PATH, return_value=(fake_instance, MagicMock())),
patch(
"core.app.workflow.layers.llm_quota.ensure_llm_quota_available",
autospec=True,
side_effect=QuotaExceededError("Model provider openai quota exceeded."),
),
):
layer.on_node_run_start(node)
@ -164,16 +169,17 @@ def test_quota_precheck_passes_without_abort() -> None:
stop_event = threading.Event()
layer.command_channel = MagicMock()
node = MagicMock()
node.id = "llm-node-id"
node.node_type = BuiltinNodeTypes.LLM
node.model_instance = object()
node = _make_llm_node()
node.graph_runtime_state = MagicMock()
node.graph_runtime_state.stop_event = stop_event
fake_instance = MagicMock(name="model-instance")
with patch("core.app.workflow.layers.llm_quota.ensure_llm_quota_available", autospec=True) as mock_check:
with (
patch(_FETCH_MODEL_CONFIG_PATH, return_value=(fake_instance, MagicMock())),
patch("core.app.workflow.layers.llm_quota.ensure_llm_quota_available", autospec=True) as mock_check,
):
layer.on_node_run_start(node)
assert not stop_event.is_set()
mock_check.assert_called_once_with(model_instance=node.model_instance)
mock_check.assert_called_once_with(model_instance=fake_instance)
layer.command_channel.send_command.assert_not_called()

View File

@ -140,7 +140,7 @@ const NodePanel: FC<Props> = ({
size={inMessage ? 'xs' : 'sm'}
className={cn('mr-2 shrink-0', inMessage && '!mr-1')}
type={nodeInfo.node_type}
toolIcon={nodeInfo.extras as string | { content: string, background: string } | undefined}
toolIcon={((nodeInfo.extras as { icon?: string } | undefined)?.icon || nodeInfo.extras) as string | { content: string, background: string } | undefined}
/>
<Tooltip
popupContent={