diff --git a/api/core/app/workflow/layers/llm_quota.py b/api/core/app/workflow/layers/llm_quota.py
index 89fcd496e9..5a9bd1fbda 100644
--- a/api/core/app/workflow/layers/llm_quota.py
+++ b/api/core/app/workflow/layers/llm_quota.py
@@ -12,14 +12,23 @@ handling is never silently skipped.
 """

 import logging
+from collections.abc import Generator
+from datetime import UTC, datetime
 from typing import final, override

 from core.app.llm import deduct_llm_quota_for_model, ensure_llm_quota_available_for_model
 from core.errors.error import QuotaExceededError
-from graphon.enums import BuiltinNodeTypes
+from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus
 from graphon.graph_engine.entities.commands import AbortCommand, CommandType
 from graphon.graph_engine.layers import GraphEngineLayer
-from graphon.graph_events import GraphEngineEvent, GraphNodeEventBase, NodeRunSucceededEvent
+from graphon.graph_events import (
+    GraphEngineEvent,
+    GraphNodeEventBase,
+    NodeRunFailedEvent,
+    NodeRunStartedEvent,
+    NodeRunSucceededEvent,
+)
+from graphon.node_events import NodeRunResult
 from graphon.nodes.base.node import Node

 logger = logging.getLogger(__name__)
@@ -66,10 +75,9 @@ class LLMQuotaLayer(GraphEngineLayer):

         model_identity = self._extract_model_identity_from_node(node)
         if model_identity is None:
-            self._abort_for_missing_model_identity(
-                node=node,
-                reason="LLM quota check requires public node model identity before execution.",
-            )
+            reason = "LLM quota check requires public node model identity before execution."
+            self._abort_before_node_run(node=node, reason=reason, error_type="LLMQuotaIdentityError")
+            logger.error("LLM quota handling aborted, node_id=%s, reason=%s", node.id, reason)
             return

         provider, model_name = model_identity
@@ -80,8 +88,7 @@
                 model=model_name,
             )
         except QuotaExceededError as exc:
-            self._set_stop_event(node)
-            self._send_abort_command(reason=str(exc))
+            self._abort_before_node_run(node=node, reason=str(exc), error_type=QuotaExceededError.__name__)
             logger.warning("LLM quota check failed, node_id=%s, error=%s", node.id, exc)

     @override
@@ -121,11 +128,73 @@
         if stop_event is not None:
             stop_event.set()

+    def _abort_before_node_run(self, *, node: Node, reason: str, error_type: str) -> None:
+        self._set_stop_event(node)
+        self._force_node_failure_to_abort(node)
+        self._block_current_node_run(node=node, reason=reason, error_type=error_type)
+        self._send_abort_command(reason=reason)
+
     def _abort_for_missing_model_identity(self, *, node: Node, reason: str) -> None:
         self._set_stop_event(node)
         self._send_abort_command(reason=reason)
         logger.error("LLM quota handling aborted, node_id=%s, reason=%s", node.id, reason)

+    @staticmethod
+    def _force_node_failure_to_abort(node: Node) -> None:
+        node_data = getattr(node, "node_data", None)
+        if node_data is None:
+            return
+
+        # Quota aborts must not be converted into retry, default-value, or fail-branch execution.
+        try:
+            if hasattr(node_data, "error_strategy"):
+                node_data.error_strategy = None
+
+            retry_config = getattr(node_data, "retry_config", None)
+            if retry_config is not None and hasattr(retry_config, "retry_enabled"):
+                retry_config.retry_enabled = False
+        except Exception:
+            logger.warning("Failed to force quota-aborted node into abort strategy, node_id=%s", node.id, exc_info=True)
+
+    @staticmethod
+    def _block_current_node_run(*, node: Node, reason: str, error_type: str) -> None:
+        def blocked_run() -> Generator[GraphNodeEventBase, None, None]:
+            execution_id = node.ensure_execution_id()
+            start_at = datetime.now(UTC).replace(tzinfo=None)
+            start_event = NodeRunStartedEvent(
+                id=execution_id,
+                node_id=node.id,
+                node_type=node.node_type,
+                node_title=str(getattr(node, "title", node.id)),
+                in_iteration_id=None,
+                start_at=start_at,
+            )
+            populate_start_event = getattr(node, "populate_start_event", None)
+            if callable(populate_start_event):
+                try:
+                    populate_start_event(start_event)
+                except Exception:
+                    logger.warning("Failed to populate quota-aborted start event, node_id=%s", node.id, exc_info=True)
+
+            yield start_event
+
+            finished_at = datetime.now(UTC).replace(tzinfo=None)
+            yield NodeRunFailedEvent(
+                id=execution_id,
+                node_id=node.id,
+                node_type=node.node_type,
+                start_at=start_at,
+                finished_at=finished_at,
+                node_run_result=NodeRunResult(
+                    status=WorkflowNodeExecutionStatus.FAILED,
+                    error=reason,
+                    error_type=error_type,
+                ),
+                error=reason,
+            )
+
+        object.__setattr__(node, "run", blocked_run)
+
     def _send_abort_command(self, *, reason: str) -> None:
         if not self.command_channel or self._abort_sent:
             return
diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py
index 9847f513e4..ecbd83cba8 100644
--- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py
+++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py
@@ -132,7 +132,9 @@ class TestAdvancedChatGenerateTaskPipeline:
         pipeline._task_state.answer = "partial answer"
         pipeline._workflow_run_id = "run-id"
         pipeline._graph_runtime_state = GraphRuntimeState(
-            variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")),
+            variable_pool=build_test_variable_pool(
+                variables=build_system_variables(workflow_execution_id="run-id"),
+            ),
             start_at=0.0,
             total_tokens=7,
             node_run_steps=3,
diff --git a/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py b/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py
index 52a3cdb159..1a22c3de7f 100644
--- a/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py
+++ b/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py
@@ -95,7 +95,9 @@ class TestWorkflowGenerateTaskPipeline:
     def test_to_blocking_response_falls_back_to_human_input_required_when_pause_event_missing(self):
         pipeline = _make_pipeline()
         pipeline._graph_runtime_state = GraphRuntimeState(
-            variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")),
+            variable_pool=build_test_variable_pool(
+                variables=build_system_variables(workflow_execution_id="run-id"),
+            ),
             start_at=0.0,
             total_tokens=5,
             node_run_steps=2,
diff --git a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py
index d4d9722182..9b7641544a 100644
--- a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py
+++ b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py
@@ -1,4 +1,5 @@
 import threading
+from collections.abc import Generator
 from datetime import datetime
 from types import SimpleNamespace
 from unittest.mock import MagicMock, patch
@@ -7,7 +8,7 @@
 from core.app.workflow.layers.llm_quota import LLMQuotaLayer
 from core.errors.error import QuotaExceededError
 from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus
 from graphon.graph_engine.entities.commands import CommandType
-from graphon.graph_events import NodeRunSucceededEvent
+from graphon.graph_events import NodeRunFailedEvent, NodeRunStartedEvent, NodeRunSucceededEvent
 from graphon.model_runtime.entities.llm_entities import LLMUsage
 from graphon.node_events import NodeRunResult
@@ -44,6 +45,28 @@ def _build_node(*, node_type: BuiltinNodeTypes = BuiltinNodeTypes.LLM) -> MagicM
     return node


+class _RunnableQuotaNode:
+    id = "node-id"
+    execution_id = "execution-id"
+    node_type = BuiltinNodeTypes.LLM
+    title = "LLM node"
+
+    def __init__(self, *, stop_event: threading.Event, node_data: SimpleNamespace | None = None) -> None:
+        self.node_data = node_data or SimpleNamespace(model=_build_public_model_identity())
+        self.graph_runtime_state = SimpleNamespace(stop_event=stop_event)
+        self.original_run_called = False
+
+    def ensure_execution_id(self) -> str:
+        return self.execution_id
+
+    def populate_start_event(self, event: NodeRunStartedEvent) -> None:
+        _ = event
+
+    def run(self) -> Generator[NodeRunSucceededEvent, None, None]:
+        self.original_run_called = True
+        yield _build_succeeded_event()
+
+
 def test_deduct_quota_called_for_successful_llm_node() -> None:
     layer = LLMQuotaLayer(tenant_id="tenant-id")
     node = _build_node(node_type=BuiltinNodeTypes.LLM)
@@ -147,6 +170,48 @@ def test_quota_precheck_failure_aborts_workflow_immediately() -> None:
     assert abort_command.reason == "Model provider openai quota exceeded."


+def test_quota_precheck_failure_blocks_current_node_run() -> None:
+    layer = LLMQuotaLayer(tenant_id="tenant-id")
+    stop_event = threading.Event()
+    layer.command_channel = MagicMock()
+
+    node = _RunnableQuotaNode(stop_event=stop_event)
+
+    with patch(
+        "core.app.workflow.layers.llm_quota.ensure_llm_quota_available_for_model",
+        autospec=True,
+        side_effect=QuotaExceededError("Model provider openai quota exceeded."),
+    ):
+        layer.on_node_run_start(node)
+
+    events = list(node.run())
+    assert not node.original_run_called
+    assert isinstance(events[0], NodeRunStartedEvent)
+    assert isinstance(events[1], NodeRunFailedEvent)
+    assert events[1].error == "Model provider openai quota exceeded."
+    assert events[1].node_run_result.status == WorkflowNodeExecutionStatus.FAILED
+    assert events[1].node_run_result.error_type == QuotaExceededError.__name__
+
+
+def test_missing_model_identity_blocks_current_node_run() -> None:
+    layer = LLMQuotaLayer(tenant_id="tenant-id")
+    stop_event = threading.Event()
+    layer.command_channel = MagicMock()
+
+    node = _RunnableQuotaNode(stop_event=stop_event, node_data=SimpleNamespace())
+
+    with patch("core.app.workflow.layers.llm_quota.ensure_llm_quota_available_for_model", autospec=True) as mock_check:
+        layer.on_node_run_start(node)
+
+    events = list(node.run())
+    assert not node.original_run_called
+    assert isinstance(events[1], NodeRunFailedEvent)
+    assert events[1].error == "LLM quota check requires public node model identity before execution."
+    assert events[1].node_run_result.status == WorkflowNodeExecutionStatus.FAILED
+    assert events[1].node_run_result.error_type == "LLMQuotaIdentityError"
+    mock_check.assert_not_called()
+
+
 def test_quota_precheck_passes_without_abort() -> None:
     layer = LLMQuotaLayer(tenant_id="tenant-id")
     stop_event = threading.Event()
diff --git a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py
index 387f508154..2603e29be6 100644
--- a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py
+++ b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py
@@ -1,41 +1,36 @@
 import time
 import uuid
-from unittest.mock import MagicMock

 from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
-from core.workflow.node_factory import DifyNodeFactory
 from core.workflow.system_variables import build_system_variables
-from extensions.ext_database import db
 from graphon.enums import WorkflowNodeExecutionStatus
-from graphon.graph import Graph
 from graphon.nodes.answer.answer_node import AnswerNode
 from graphon.nodes.answer.entities import AnswerNodeData
 from graphon.runtime import GraphRuntimeState, VariablePool
 from tests.workflow_test_utils import build_test_graph_init_params


-def test_execute_answer():
+def _build_variable_pool() -> VariablePool:
+    return VariablePool.from_bootstrap(
+        system_variables=build_system_variables(user_id="aaa", files=[]),
+        user_inputs={},
+    )
+
+
+def _build_answer_node(*, answer: str, variable_pool: VariablePool) -> AnswerNode:
     graph_config = {
-        "edges": [
-            {
-                "id": "start-source-answer-target",
-                "source": "start",
-                "target": "answer",
-            },
-        ],
+        "edges": [],
         "nodes": [
-            {"data": {"type": "start", "title": "Start"}, "id": "start"},
             {
                 "data": {
-                    "title": "123",
+                    "title": "Answer",
                     "type": "answer",
-                    "answer": "Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.",
+                    "answer": answer,
                 },
                 "id": "answer",
-            },
+            }
         ],
     }
-
     init_params = build_test_graph_init_params(
         workflow_id="1",
         graph_config=graph_config,
@@ -46,42 +41,31 @@
         tenant_id="1",
         app_id="1",
         user_id="1",
         user_from=UserFrom.ACCOUNT,
         invoke_from=InvokeFrom.DEBUGGER,
         call_depth=0,
     )
-
-    # construct variable pool
-    variable_pool = VariablePool.from_bootstrap(
-        system_variables=build_system_variables(user_id="aaa", files=[]),
-        user_inputs={},
-        environment_variables=[],
-        conversation_variables=[],
+    graph_runtime_state = GraphRuntimeState(
+        variable_pool=variable_pool,
+        start_at=time.perf_counter(),
     )
-    variable_pool.add(["start", "weather"], "sunny")
-    variable_pool.add(["llm", "text"], "You are a helpful AI.")
-
-    graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
-
-    # create node factory
-    node_factory = DifyNodeFactory(
-        graph_init_params=init_params,
-        graph_runtime_state=graph_runtime_state,
-    )
-
-    graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start")
-
-    node = AnswerNode(
+    return AnswerNode(
         node_id=str(uuid.uuid4()),
         graph_init_params=init_params,
         graph_runtime_state=graph_runtime_state,
         data=AnswerNodeData(
-            title="123",
+            title="Answer",
             type="answer",
-            answer="Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.",
+            answer=answer,
         ),
     )

-    # Mock db.session.close()
-    db.session.close = MagicMock()

-    # execute node
+def test_execute_answer_renders_variable_selectors() -> None:
+    variable_pool = _build_variable_pool()
+    variable_pool.add(["start", "weather"], "sunny")
+    variable_pool.add(["llm", "text"], "You are a helpful AI.")
+    node = _build_answer_node(
+        answer="Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.",
+        variable_pool=variable_pool,
+    )
+
     result = node._run()

     assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
@@ -89,36 +73,11 @@ def test_execute_answer_renders_structured_output_object_as_json() -> None:
-    init_params = build_test_graph_init_params(
-        workflow_id="1",
-        graph_config={"nodes": [], "edges": []},
-        tenant_id="1",
-        app_id="1",
-        user_id="1",
-        user_from=UserFrom.ACCOUNT,
-        invoke_from=InvokeFrom.DEBUGGER,
-        call_depth=0,
-    )
-
-    variable_pool = VariablePool(
-        system_variables=build_system_variables(user_id="aaa", files=[]),
-        user_inputs={},
-        environment_variables=[],
-        conversation_variables=[],
-    )
+    variable_pool = _build_variable_pool()
     variable_pool.add(["1777539038857", "structured_output"], {"type": "greeting"})
-
-    graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
-
-    node = AnswerNode(
-        node_id=str(uuid.uuid4()),
-        graph_init_params=init_params,
-        graph_runtime_state=graph_runtime_state,
-        config=AnswerNodeData(
-            title="123",
-            type="answer",
-            answer="{{#1777539038857.structured_output#}}",
-        ),
+    node = _build_answer_node(
+        answer="{{#1777539038857.structured_output#}}",
+        variable_pool=variable_pool,
     )

     result = node._run()
@@ -128,35 +87,9 @@ def test_execute_answer_falls_back_to_plain_selector_text_when_structured_output_missing() -> None:
-    init_params = build_test_graph_init_params(
-        workflow_id="1",
-        graph_config={"nodes": [], "edges": []},
-        tenant_id="1",
-        app_id="1",
-        user_id="1",
-        user_from=UserFrom.ACCOUNT,
-        invoke_from=InvokeFrom.DEBUGGER,
-        call_depth=0,
-    )
-
-    variable_pool = VariablePool(
-        system_variables=build_system_variables(user_id="aaa", files=[]),
-        user_inputs={},
-        environment_variables=[],
-        conversation_variables=[],
-    )
-
-    graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
-
-    node = AnswerNode(
-        node_id=str(uuid.uuid4()),
-        graph_init_params=init_params,
-        graph_runtime_state=graph_runtime_state,
-        config=AnswerNodeData(
-            title="123",
-            type="answer",
-            answer="{{#1777539038857.structured_output#}}",
-        ),
+    node = _build_answer_node(
+        answer="{{#1777539038857.structured_output#}}",
+        variable_pool=_build_variable_pool(),
     )

     result = node._run()