fix: test

Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
-LAN- 2026-05-06 14:40:00 +08:00
parent ba84ecc63f
commit f0442456df
No known key found for this signature in database
GPG Key ID: 6BA0D108DED011FF
5 changed files with 183 additions and 112 deletions

View File

@ -12,14 +12,23 @@ handling is never silently skipped.
"""
import logging
from collections.abc import Generator
from datetime import UTC, datetime
from typing import final, override
from core.app.llm import deduct_llm_quota_for_model, ensure_llm_quota_available_for_model
from core.errors.error import QuotaExceededError
from graphon.enums import BuiltinNodeTypes
from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus
from graphon.graph_engine.entities.commands import AbortCommand, CommandType
from graphon.graph_engine.layers import GraphEngineLayer
from graphon.graph_events import GraphEngineEvent, GraphNodeEventBase, NodeRunSucceededEvent
from graphon.graph_events import (
GraphEngineEvent,
GraphNodeEventBase,
NodeRunFailedEvent,
NodeRunStartedEvent,
NodeRunSucceededEvent,
)
from graphon.node_events import NodeRunResult
from graphon.nodes.base.node import Node
logger = logging.getLogger(__name__)
@ -66,10 +75,9 @@ class LLMQuotaLayer(GraphEngineLayer):
model_identity = self._extract_model_identity_from_node(node)
if model_identity is None:
self._abort_for_missing_model_identity(
node=node,
reason="LLM quota check requires public node model identity before execution.",
)
reason = "LLM quota check requires public node model identity before execution."
self._abort_before_node_run(node=node, reason=reason, error_type="LLMQuotaIdentityError")
logger.error("LLM quota handling aborted, node_id=%s, reason=%s", node.id, reason)
return
provider, model_name = model_identity
@ -80,8 +88,7 @@ class LLMQuotaLayer(GraphEngineLayer):
model=model_name,
)
except QuotaExceededError as exc:
self._set_stop_event(node)
self._send_abort_command(reason=str(exc))
self._abort_before_node_run(node=node, reason=str(exc), error_type=QuotaExceededError.__name__)
logger.warning("LLM quota check failed, node_id=%s, error=%s", node.id, exc)
@override
@ -121,11 +128,73 @@ class LLMQuotaLayer(GraphEngineLayer):
if stop_event is not None:
stop_event.set()
def _abort_before_node_run(self, *, node: Node, reason: str, error_type: str) -> None:
"""Abort the workflow before *node* executes its real logic.

Order matters here: the stop event is set first so the engine stops
scheduling, the node is stripped of retry/fallback strategies, its run()
is replaced with a generator that emits a FAILED result carrying
*reason*/*error_type*, and finally an abort command is broadcast.
"""
self._set_stop_event(node)
self._force_node_failure_to_abort(node)
self._block_current_node_run(node=node, reason=reason, error_type=error_type)
self._send_abort_command(reason=reason)
def _abort_for_missing_model_identity(self, *, node: Node, reason: str) -> None:
"""Lightweight abort: stop event + abort command + error log only.

Unlike _abort_before_node_run this does NOT replace the node's run(),
so the node would still execute if scheduled.
NOTE(review): the precheck path now calls _abort_before_node_run for the
missing-identity case — confirm whether this method is still reachable.
"""
self._set_stop_event(node)
self._send_abort_command(reason=reason)
logger.error("LLM quota handling aborted, node_id=%s, reason=%s", node.id, reason)
@staticmethod
def _force_node_failure_to_abort(node: Node) -> None:
    """Strip retry/fallback behavior from a quota-aborted node.

    Quota aborts must not be converted into retry, default-value, or
    fail-branch execution, so any error strategy is cleared and any retry
    configuration is disabled. Missing attributes are tolerated; unexpected
    errors are logged (never raised) so the abort itself cannot fail.
    """
    data = getattr(node, "node_data", None)
    if data is None:
        return
    try:
        # Clear the fallback strategy and switch retries off so the
        # failure produced by the blocked run() is final.
        if hasattr(data, "error_strategy"):
            data.error_strategy = None
        retry_cfg = getattr(data, "retry_config", None)
        if retry_cfg is not None and hasattr(retry_cfg, "retry_enabled"):
            retry_cfg.retry_enabled = False
    except Exception:
        logger.warning("Failed to force quota-aborted node into abort strategy, node_id=%s", node.id, exc_info=True)
@staticmethod
def _block_current_node_run(*, node: Node, reason: str, error_type: str) -> None:
"""Replace node.run with a generator that reports a quota failure.

The substituted run() yields a start event followed by a FAILED event
carrying *reason*/*error_type*, so downstream consumers still observe a
normal (started -> failed) node lifecycle while the node's real logic is
never executed.
"""
def blocked_run() -> Generator[GraphNodeEventBase, None, None]:
execution_id = node.ensure_execution_id()
# NOTE(review): timestamps are naive-UTC (tzinfo stripped) — presumably
# the project-wide event convention; confirm against real node runs.
start_at = datetime.now(UTC).replace(tzinfo=None)
start_event = NodeRunStartedEvent(
id=execution_id,
node_id=node.id,
node_type=node.node_type,
node_title=str(getattr(node, "title", node.id)),
in_iteration_id=None,
start_at=start_at,
)
# Let the node enrich the synthetic start event if it can; best-effort,
# a failure here must not break the abort path.
populate_start_event = getattr(node, "populate_start_event", None)
if callable(populate_start_event):
try:
populate_start_event(start_event)
except Exception:
logger.warning("Failed to populate quota-aborted start event, node_id=%s", node.id, exc_info=True)
yield start_event
finished_at = datetime.now(UTC).replace(tzinfo=None)
yield NodeRunFailedEvent(
id=execution_id,
node_id=node.id,
node_type=node.node_type,
start_at=start_at,
finished_at=finished_at,
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=reason,
error_type=error_type,
),
error=reason,
)
# object.__setattr__ bypasses any attribute guards (e.g. frozen/validated
# node models) so the replacement sticks even on restricted instances.
object.__setattr__(node, "run", blocked_run)
def _send_abort_command(self, *, reason: str) -> None:
if not self.command_channel or self._abort_sent:
return

View File

@ -132,7 +132,9 @@ class TestAdvancedChatGenerateTaskPipeline:
pipeline._task_state.answer = "partial answer"
pipeline._workflow_run_id = "run-id"
pipeline._graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")),
variable_pool=build_test_variable_pool(
variables=build_system_variables(workflow_execution_id="run-id"),
),
start_at=0.0,
total_tokens=7,
node_run_steps=3,

View File

@ -95,7 +95,9 @@ class TestWorkflowGenerateTaskPipeline:
def test_to_blocking_response_falls_back_to_human_input_required_when_pause_event_missing(self):
pipeline = _make_pipeline()
pipeline._graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")),
variable_pool=build_test_variable_pool(
variables=build_system_variables(workflow_execution_id="run-id"),
),
start_at=0.0,
total_tokens=5,
node_run_steps=2,

View File

@ -1,4 +1,5 @@
import threading
from collections.abc import Generator
from datetime import datetime
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
@ -7,7 +8,7 @@ from core.app.workflow.layers.llm_quota import LLMQuotaLayer
from core.errors.error import QuotaExceededError
from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus
from graphon.graph_engine.entities.commands import CommandType
from graphon.graph_events import NodeRunSucceededEvent
from graphon.graph_events import NodeRunFailedEvent, NodeRunStartedEvent, NodeRunSucceededEvent
from graphon.model_runtime.entities.llm_entities import LLMUsage
from graphon.node_events import NodeRunResult
@ -44,6 +45,28 @@ def _build_node(*, node_type: BuiltinNodeTypes = BuiltinNodeTypes.LLM) -> MagicM
return node
class _RunnableQuotaNode:
    """Minimal node stand-in whose run() the quota layer may intercept."""

    id = "node-id"
    execution_id = "execution-id"
    node_type = BuiltinNodeTypes.LLM
    title = "LLM node"

    def __init__(self, *, stop_event: threading.Event, node_data: SimpleNamespace | None = None) -> None:
        if node_data is None:
            node_data = SimpleNamespace(model=_build_public_model_identity())
        self.node_data = node_data
        self.graph_runtime_state = SimpleNamespace(stop_event=stop_event)
        # Flips to True only if the original (unblocked) run() executes.
        self.original_run_called = False

    def ensure_execution_id(self) -> str:
        return self.execution_id

    def populate_start_event(self, event: NodeRunStartedEvent) -> None:
        _ = event

    def run(self) -> Generator[NodeRunSucceededEvent, None, None]:
        self.original_run_called = True
        yield _build_succeeded_event()
def test_deduct_quota_called_for_successful_llm_node() -> None:
layer = LLMQuotaLayer(tenant_id="tenant-id")
node = _build_node(node_type=BuiltinNodeTypes.LLM)
@ -147,6 +170,48 @@ def test_quota_precheck_failure_aborts_workflow_immediately() -> None:
assert abort_command.reason == "Model provider openai quota exceeded."
def test_quota_precheck_failure_blocks_current_node_run() -> None:
    """A quota failure must swap the node's run() for start + failed events."""
    layer = LLMQuotaLayer(tenant_id="tenant-id")
    layer.command_channel = MagicMock()
    stop_event = threading.Event()
    node = _RunnableQuotaNode(stop_event=stop_event)

    with patch(
        "core.app.workflow.layers.llm_quota.ensure_llm_quota_available_for_model",
        autospec=True,
        side_effect=QuotaExceededError("Model provider openai quota exceeded."),
    ):
        layer.on_node_run_start(node)

    emitted = list(node.run())

    # The original run() was replaced, so its body must never execute.
    assert not node.original_run_called
    started, failed = emitted[0], emitted[1]
    assert isinstance(started, NodeRunStartedEvent)
    assert isinstance(failed, NodeRunFailedEvent)
    assert failed.error == "Model provider openai quota exceeded."
    assert failed.node_run_result.status == WorkflowNodeExecutionStatus.FAILED
    assert failed.node_run_result.error_type == QuotaExceededError.__name__
def test_missing_model_identity_blocks_current_node_run() -> None:
    """Nodes without a model identity are failed without checking quota at all."""
    layer = LLMQuotaLayer(tenant_id="tenant-id")
    layer.command_channel = MagicMock()
    stop_event = threading.Event()
    node = _RunnableQuotaNode(stop_event=stop_event, node_data=SimpleNamespace())

    with patch("core.app.workflow.layers.llm_quota.ensure_llm_quota_available_for_model", autospec=True) as mock_check:
        layer.on_node_run_start(node)

    emitted = list(node.run())

    # Run was blocked before the node body could execute.
    assert not node.original_run_called
    failed = emitted[1]
    assert isinstance(failed, NodeRunFailedEvent)
    assert failed.error == "LLM quota check requires public node model identity before execution."
    assert failed.node_run_result.status == WorkflowNodeExecutionStatus.FAILED
    assert failed.node_run_result.error_type == "LLMQuotaIdentityError"
    # Identity is a precondition: the quota helper must never be consulted.
    mock_check.assert_not_called()
def test_quota_precheck_passes_without_abort() -> None:
layer = LLMQuotaLayer(tenant_id="tenant-id")
stop_event = threading.Event()

View File

@ -1,41 +1,36 @@
import time
import uuid
from unittest.mock import MagicMock
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
from core.workflow.node_factory import DifyNodeFactory
from core.workflow.system_variables import build_system_variables
from extensions.ext_database import db
from graphon.enums import WorkflowNodeExecutionStatus
from graphon.graph import Graph
from graphon.nodes.answer.answer_node import AnswerNode
from graphon.nodes.answer.entities import AnswerNodeData
from graphon.runtime import GraphRuntimeState, VariablePool
from tests.workflow_test_utils import build_test_graph_init_params
def test_execute_answer():
def _build_variable_pool() -> VariablePool:
    """Return a fresh pool seeded with the standard test system variables."""
    system_vars = build_system_variables(user_id="aaa", files=[])
    return VariablePool.from_bootstrap(
        system_variables=system_vars,
        user_inputs={},
    )
def _build_answer_node(*, answer: str, variable_pool: VariablePool) -> AnswerNode:
graph_config = {
"edges": [
{
"id": "start-source-answer-target",
"source": "start",
"target": "answer",
},
],
"edges": [],
"nodes": [
{"data": {"type": "start", "title": "Start"}, "id": "start"},
{
"data": {
"title": "123",
"title": "Answer",
"type": "answer",
"answer": "Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.",
"answer": answer,
},
"id": "answer",
},
}
],
}
init_params = build_test_graph_init_params(
workflow_id="1",
graph_config=graph_config,
@ -46,42 +41,31 @@ def test_execute_answer():
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
)
# construct variable pool
variable_pool = VariablePool.from_bootstrap(
system_variables=build_system_variables(user_id="aaa", files=[]),
user_inputs={},
environment_variables=[],
conversation_variables=[],
graph_runtime_state = GraphRuntimeState(
variable_pool=variable_pool,
start_at=time.perf_counter(),
)
variable_pool.add(["start", "weather"], "sunny")
variable_pool.add(["llm", "text"], "You are a helpful AI.")
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
# create node factory
node_factory = DifyNodeFactory(
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
)
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start")
node = AnswerNode(
return AnswerNode(
node_id=str(uuid.uuid4()),
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
data=AnswerNodeData(
title="123",
title="Answer",
type="answer",
answer="Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.",
answer=answer,
),
)
# Mock db.session.close()
db.session.close = MagicMock()
# execute node
def test_execute_answer_renders_variable_selectors() -> None:
variable_pool = _build_variable_pool()
variable_pool.add(["start", "weather"], "sunny")
variable_pool.add(["llm", "text"], "You are a helpful AI.")
node = _build_answer_node(
answer="Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.",
variable_pool=variable_pool,
)
result = node._run()
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
@ -89,36 +73,11 @@ def test_execute_answer():
def test_execute_answer_renders_structured_output_object_as_json() -> None:
init_params = build_test_graph_init_params(
workflow_id="1",
graph_config={"nodes": [], "edges": []},
tenant_id="1",
app_id="1",
user_id="1",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
)
variable_pool = VariablePool(
system_variables=build_system_variables(user_id="aaa", files=[]),
user_inputs={},
environment_variables=[],
conversation_variables=[],
)
variable_pool = _build_variable_pool()
variable_pool.add(["1777539038857", "structured_output"], {"type": "greeting"})
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
node = AnswerNode(
node_id=str(uuid.uuid4()),
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
config=AnswerNodeData(
title="123",
type="answer",
answer="{{#1777539038857.structured_output#}}",
),
node = _build_answer_node(
answer="{{#1777539038857.structured_output#}}",
variable_pool=variable_pool,
)
result = node._run()
@ -128,35 +87,9 @@ def test_execute_answer_renders_structured_output_object_as_json() -> None:
def test_execute_answer_falls_back_to_plain_selector_text_when_structured_output_missing() -> None:
init_params = build_test_graph_init_params(
workflow_id="1",
graph_config={"nodes": [], "edges": []},
tenant_id="1",
app_id="1",
user_id="1",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
)
variable_pool = VariablePool(
system_variables=build_system_variables(user_id="aaa", files=[]),
user_inputs={},
environment_variables=[],
conversation_variables=[],
)
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
node = AnswerNode(
node_id=str(uuid.uuid4()),
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
config=AnswerNodeData(
title="123",
type="answer",
answer="{{#1777539038857.structured_output#}}",
),
node = _build_answer_node(
answer="{{#1777539038857.structured_output#}}",
variable_pool=_build_variable_pool(),
)
result = node._run()