diff --git a/api/core/app/workflow/layers/llm_quota.py b/api/core/app/workflow/layers/llm_quota.py index 617e58341f..a39a1c66a8 100644 --- a/api/core/app/workflow/layers/llm_quota.py +++ b/api/core/app/workflow/layers/llm_quota.py @@ -21,6 +21,12 @@ from dify_graph.nodes.base.node import Node logger = logging.getLogger(__name__) +_LLM_LIKE_NODE_TYPES = { + BuiltinNodeTypes.LLM, + BuiltinNodeTypes.PARAMETER_EXTRACTOR, + BuiltinNodeTypes.QUESTION_CLASSIFIER, +} + @final class LLMQuotaLayer(GraphEngineLayer): @@ -47,7 +53,7 @@ class LLMQuotaLayer(GraphEngineLayer): if self._abort_sent: return - model_instance = self._extract_model_instance(node) + model_instance = self._build_model_instance(node) if model_instance is None: return @@ -65,7 +71,7 @@ class LLMQuotaLayer(GraphEngineLayer): if error is not None or not isinstance(result_event, NodeRunSucceededEvent): return - model_instance = self._extract_model_instance(node) + model_instance = self._build_model_instance(node) if model_instance is None: return @@ -105,16 +111,22 @@ class LLMQuotaLayer(GraphEngineLayer): logger.exception("Failed to send quota abort command") @staticmethod - def _extract_model_instance(node: Node) -> ModelInstance | None: - match node.node_type: - case BuiltinNodeTypes.LLM | BuiltinNodeTypes.PARAMETER_EXTRACTOR | BuiltinNodeTypes.QUESTION_CLASSIFIER: - instance: ModelInstance | None = getattr(node, "model_instance", None) - if instance is not None: - return instance - logger.warning( - "LLMQuotaLayer skipped quota deduction because node does not expose a model instance, node_id=%s", - node.id, - ) - return None - case _: - return None + def _build_model_instance(node: Node) -> ModelInstance | None: + if node.node_type not in _LLM_LIKE_NODE_TYPES: + return None + + model_config = getattr(node.node_data, "model", None) + if model_config is None: + return None + + try: + from dify_graph.nodes.llm.llm_utils import fetch_model_config + + model_instance, _ = fetch_model_config( + tenant_id=node.tenant_id, + node_data_model=model_config, + ) + return model_instance + except Exception: + logger.warning("Failed to build ModelInstance for quota check, node_id=%s", node.id, exc_info=True) + return None diff --git a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py index 2a36f712fd..6fc9c905e6 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py @@ -10,6 +10,8 @@ from dify_graph.graph_events.node import NodeRunSucceededEvent from dify_graph.model_runtime.entities.llm_entities import LLMUsage from dify_graph.node_events import NodeRunResult +_FETCH_MODEL_CONFIG_PATH = "dify_graph.nodes.llm.llm_utils.fetch_model_config" + def _build_succeeded_event() -> NodeRunSucceededEvent: return NodeRunSucceededEvent( @@ -25,44 +27,52 @@ def _build_succeeded_event() -> NodeRunSucceededEvent: ) -def test_deduct_quota_called_for_successful_llm_node() -> None: - layer = LLMQuotaLayer() +def _make_llm_node(*, node_type: BuiltinNodeTypes = BuiltinNodeTypes.LLM) -> MagicMock: node = MagicMock() node.id = "llm-node-id" node.execution_id = "execution-id" - node.node_type = BuiltinNodeTypes.LLM + node.node_type = node_type node.tenant_id = "tenant-id" node.require_dify_context.return_value.tenant_id = "tenant-id" - node.model_instance = object() + node.node_data.model = MagicMock(name="model-config") + return node + + +def test_deduct_quota_called_for_successful_llm_node() -> None: + layer = LLMQuotaLayer() + node = _make_llm_node() + fake_instance = MagicMock(name="model-instance") result_event = _build_succeeded_event() - with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota", autospec=True) as mock_deduct: + with ( + patch(_FETCH_MODEL_CONFIG_PATH, return_value=(fake_instance, MagicMock())), + patch("core.app.workflow.layers.llm_quota.deduct_llm_quota", autospec=True) as mock_deduct, + ): layer.on_node_run_end(node=node, error=None, result_event=result_event) mock_deduct.assert_called_once_with( tenant_id="tenant-id", - model_instance=node.model_instance, + model_instance=fake_instance, usage=result_event.node_run_result.llm_usage, ) def test_deduct_quota_called_for_question_classifier_node() -> None: layer = LLMQuotaLayer() - node = MagicMock() + node = _make_llm_node(node_type=BuiltinNodeTypes.QUESTION_CLASSIFIER) node.id = "question-classifier-node-id" - node.execution_id = "execution-id" - node.node_type = BuiltinNodeTypes.QUESTION_CLASSIFIER - node.tenant_id = "tenant-id" - node.require_dify_context.return_value.tenant_id = "tenant-id" - node.model_instance = object() + fake_instance = MagicMock(name="model-instance") result_event = _build_succeeded_event() - with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota", autospec=True) as mock_deduct: + with ( + patch(_FETCH_MODEL_CONFIG_PATH, return_value=(fake_instance, MagicMock())), + patch("core.app.workflow.layers.llm_quota.deduct_llm_quota", autospec=True) as mock_deduct, + ): layer.on_node_run_end(node=node, error=None, result_event=result_event) mock_deduct.assert_called_once_with( tenant_id="tenant-id", - model_instance=node.model_instance, + model_instance=fake_instance, usage=result_event.node_run_result.llm_usage, ) @@ -74,8 +84,6 @@ def test_non_llm_node_is_ignored() -> None: node.execution_id = "execution-id" node.node_type = BuiltinNodeTypes.START node.tenant_id = "tenant-id" - node.require_dify_context.return_value.tenant_id = "tenant-id" - node._model_instance = object() result_event = _build_succeeded_event() with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota", autospec=True) as mock_deduct: @@ -86,19 +94,17 @@ def test_non_llm_node_is_ignored() -> None: def test_quota_error_is_handled_in_layer() -> None: layer = LLMQuotaLayer() - node = MagicMock() - node.id = "llm-node-id" - node.execution_id = "execution-id" - node.node_type = BuiltinNodeTypes.LLM - node.tenant_id = "tenant-id" - node.require_dify_context.return_value.tenant_id = "tenant-id" - node.model_instance = object() + node = _make_llm_node() + fake_instance = MagicMock(name="model-instance") result_event = _build_succeeded_event() - with patch( - "core.app.workflow.layers.llm_quota.deduct_llm_quota", - autospec=True, - side_effect=ValueError("quota exceeded"), + with ( + patch(_FETCH_MODEL_CONFIG_PATH, return_value=(fake_instance, MagicMock())), + patch( + "core.app.workflow.layers.llm_quota.deduct_llm_quota", + autospec=True, + side_effect=ValueError("quota exceeded"), + ), ): layer.on_node_run_end(node=node, error=None, result_event=result_event) @@ -108,21 +114,19 @@ def test_quota_deduction_exceeded_aborts_workflow_immediately() -> None: stop_event = threading.Event() layer.command_channel = MagicMock() - node = MagicMock() - node.id = "llm-node-id" - node.execution_id = "execution-id" - node.node_type = BuiltinNodeTypes.LLM - node.tenant_id = "tenant-id" - node.require_dify_context.return_value.tenant_id = "tenant-id" - node.model_instance = object() + node = _make_llm_node() node.graph_runtime_state = MagicMock() node.graph_runtime_state.stop_event = stop_event + fake_instance = MagicMock(name="model-instance") result_event = _build_succeeded_event() - with patch( - "core.app.workflow.layers.llm_quota.deduct_llm_quota", - autospec=True, - side_effect=QuotaExceededError("No credits remaining"), + with ( + patch(_FETCH_MODEL_CONFIG_PATH, return_value=(fake_instance, MagicMock())), + patch( + "core.app.workflow.layers.llm_quota.deduct_llm_quota", + autospec=True, + side_effect=QuotaExceededError("No credits remaining"), + ), ): layer.on_node_run_end(node=node, error=None, result_event=result_event) @@ -138,17 +142,18 @@ def test_quota_precheck_failure_aborts_workflow_immediately() -> None: stop_event = threading.Event() layer.command_channel = MagicMock() - node = MagicMock() - node.id = "llm-node-id" - node.node_type = BuiltinNodeTypes.LLM - node.model_instance = object() + node = _make_llm_node() node.graph_runtime_state = MagicMock() node.graph_runtime_state.stop_event = stop_event + fake_instance = MagicMock(name="model-instance") - with patch( - "core.app.workflow.layers.llm_quota.ensure_llm_quota_available", - autospec=True, - side_effect=QuotaExceededError("Model provider openai quota exceeded."), + with ( + patch(_FETCH_MODEL_CONFIG_PATH, return_value=(fake_instance, MagicMock())), + patch( + "core.app.workflow.layers.llm_quota.ensure_llm_quota_available", + autospec=True, + side_effect=QuotaExceededError("Model provider openai quota exceeded."), + ), ): layer.on_node_run_start(node) @@ -164,16 +169,17 @@ def test_quota_precheck_passes_without_abort() -> None: stop_event = threading.Event() layer.command_channel = MagicMock() - node = MagicMock() - node.id = "llm-node-id" - node.node_type = BuiltinNodeTypes.LLM - node.model_instance = object() + node = _make_llm_node() node.graph_runtime_state = MagicMock() node.graph_runtime_state.stop_event = stop_event + fake_instance = MagicMock(name="model-instance") - with patch("core.app.workflow.layers.llm_quota.ensure_llm_quota_available", autospec=True) as mock_check: + with ( + patch(_FETCH_MODEL_CONFIG_PATH, return_value=(fake_instance, MagicMock())), + patch("core.app.workflow.layers.llm_quota.ensure_llm_quota_available", autospec=True) as mock_check, + ): layer.on_node_run_start(node) assert not stop_event.is_set() - mock_check.assert_called_once_with(model_instance=node.model_instance) + mock_check.assert_called_once_with(model_instance=fake_instance) layer.command_channel.send_command.assert_not_called() diff --git a/web/app/components/workflow/run/node.tsx b/web/app/components/workflow/run/node.tsx index dfda405345..076a7f3a2b 100644 --- a/web/app/components/workflow/run/node.tsx +++ b/web/app/components/workflow/run/node.tsx @@ -140,7 +140,7 @@ const NodePanel: FC = ({ size={inMessage ? 'xs' : 'sm'} className={cn('mr-2 shrink-0', inMessage && '!mr-1')} type={nodeInfo.node_type} - toolIcon={nodeInfo.extras as string | { content: string, background: string } | undefined} + toolIcon={((nodeInfo.extras as { icon?: string } | undefined)?.icon || nodeInfo.extras) as string | { content: string, background: string } | undefined} />