mirror of
https://github.com/langgenius/dify.git
synced 2026-05-09 04:36:31 +08:00
fix: test
Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
parent
ba84ecc63f
commit
f0442456df
@ -12,14 +12,23 @@ handling is never silently skipped.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from datetime import UTC, datetime
|
||||
from typing import final, override
|
||||
|
||||
from core.app.llm import deduct_llm_quota_for_model, ensure_llm_quota_available_for_model
|
||||
from core.errors.error import QuotaExceededError
|
||||
from graphon.enums import BuiltinNodeTypes
|
||||
from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus
|
||||
from graphon.graph_engine.entities.commands import AbortCommand, CommandType
|
||||
from graphon.graph_engine.layers import GraphEngineLayer
|
||||
from graphon.graph_events import GraphEngineEvent, GraphNodeEventBase, NodeRunSucceededEvent
|
||||
from graphon.graph_events import (
|
||||
GraphEngineEvent,
|
||||
GraphNodeEventBase,
|
||||
NodeRunFailedEvent,
|
||||
NodeRunStartedEvent,
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
from graphon.node_events import NodeRunResult
|
||||
from graphon.nodes.base.node import Node
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -66,10 +75,9 @@ class LLMQuotaLayer(GraphEngineLayer):
|
||||
|
||||
model_identity = self._extract_model_identity_from_node(node)
|
||||
if model_identity is None:
|
||||
self._abort_for_missing_model_identity(
|
||||
node=node,
|
||||
reason="LLM quota check requires public node model identity before execution.",
|
||||
)
|
||||
reason = "LLM quota check requires public node model identity before execution."
|
||||
self._abort_before_node_run(node=node, reason=reason, error_type="LLMQuotaIdentityError")
|
||||
logger.error("LLM quota handling aborted, node_id=%s, reason=%s", node.id, reason)
|
||||
return
|
||||
|
||||
provider, model_name = model_identity
|
||||
@ -80,8 +88,7 @@ class LLMQuotaLayer(GraphEngineLayer):
|
||||
model=model_name,
|
||||
)
|
||||
except QuotaExceededError as exc:
|
||||
self._set_stop_event(node)
|
||||
self._send_abort_command(reason=str(exc))
|
||||
self._abort_before_node_run(node=node, reason=str(exc), error_type=QuotaExceededError.__name__)
|
||||
logger.warning("LLM quota check failed, node_id=%s, error=%s", node.id, exc)
|
||||
|
||||
@override
|
||||
@ -121,11 +128,73 @@ class LLMQuotaLayer(GraphEngineLayer):
|
||||
if stop_event is not None:
|
||||
stop_event.set()
|
||||
|
||||
def _abort_before_node_run(self, *, node: Node, reason: str, error_type: str) -> None:
    """Stop ``node`` before it executes and abort the surrounding workflow.

    The steps run in a fixed order: signal the node's stop event, strip the
    node's retry / error-strategy handling, swap its ``run`` for a stub that
    only reports a failure, and finally send the abort command.

    Args:
        node: The node that must not execute.
        reason: Message attached to the failed run and to the abort command.
        error_type: Error type name recorded on the failed run result.
    """
    self._set_stop_event(node)
    # Recovery paths are neutralised before the failing stub is installed.
    self._force_node_failure_to_abort(node)
    self._block_current_node_run(
        node=node,
        reason=reason,
        error_type=error_type,
    )
    self._send_abort_command(reason=reason)
|
||||
|
||||
def _abort_for_missing_model_identity(self, *, node: Node, reason: str) -> None:
    """Abort the workflow for a node whose model identity could not be read.

    Signals the node's stop event, sends the abort command carrying
    ``reason``, and records the abort at error level.

    Args:
        node: The node the quota check was attempted for.
        reason: Message forwarded to the abort command and the log record.
    """
    self._set_stop_event(node)
    self._send_abort_command(reason=reason)
    logger.error(
        "LLM quota handling aborted, node_id=%s, reason=%s",
        node.id,
        reason,
    )
|
||||
|
||||
@staticmethod
|
||||
def _force_node_failure_to_abort(node: Node) -> None:
|
||||
node_data = getattr(node, "node_data", None)
|
||||
if node_data is None:
|
||||
return
|
||||
|
||||
# Quota aborts must not be converted into retry, default-value, or fail-branch execution.
|
||||
try:
|
||||
if hasattr(node_data, "error_strategy"):
|
||||
node_data.error_strategy = None
|
||||
|
||||
retry_config = getattr(node_data, "retry_config", None)
|
||||
if retry_config is not None and hasattr(retry_config, "retry_enabled"):
|
||||
retry_config.retry_enabled = False
|
||||
except Exception:
|
||||
logger.warning("Failed to force quota-aborted node into abort strategy, node_id=%s", node.id, exc_info=True)
|
||||
|
||||
@staticmethod
def _block_current_node_run(*, node: Node, reason: str, error_type: str) -> None:
    """Replace ``node.run`` with a generator that only reports a failed run.

    The replacement yields a ``NodeRunStartedEvent`` followed by a
    ``NodeRunFailedEvent`` whose result carries ``reason`` and ``error_type``;
    the node's original ``run`` is never invoked through this attribute again.

    Args:
        node: The node whose ``run`` attribute is overwritten.
        reason: Error message placed on the failed event and its run result.
        error_type: Error type name placed on the run result.
    """

    def blocked_run() -> Generator[GraphNodeEventBase, None, None]:
        run_id = node.ensure_execution_id()
        # Timestamps are taken in UTC and then stripped of tzinfo (naive values).
        began = datetime.now(UTC).replace(tzinfo=None)
        started = NodeRunStartedEvent(
            id=run_id,
            node_id=node.id,
            node_type=node.node_type,
            node_title=str(getattr(node, "title", node.id)),
            in_iteration_id=None,
            start_at=began,
        )

        # Let the node enrich the start event when it offers the hook; failures are only logged.
        enrich = getattr(node, "populate_start_event", None)
        if callable(enrich):
            try:
                enrich(started)
            except Exception:
                logger.warning("Failed to populate quota-aborted start event, node_id=%s", node.id, exc_info=True)

        yield started

        ended = datetime.now(UTC).replace(tzinfo=None)
        failure = NodeRunResult(
            status=WorkflowNodeExecutionStatus.FAILED,
            error=reason,
            error_type=error_type,
        )
        yield NodeRunFailedEvent(
            id=run_id,
            node_id=node.id,
            node_type=node.node_type,
            start_at=began,
            finished_at=ended,
            node_run_result=failure,
            error=reason,
        )

    # object.__setattr__ is used instead of plain assignment, bypassing any
    # __setattr__ override defined on the node's class.
    object.__setattr__(node, "run", blocked_run)
|
||||
|
||||
def _send_abort_command(self, *, reason: str) -> None:
|
||||
if not self.command_channel or self._abort_sent:
|
||||
return
|
||||
|
||||
@ -132,7 +132,9 @@ class TestAdvancedChatGenerateTaskPipeline:
|
||||
pipeline._task_state.answer = "partial answer"
|
||||
pipeline._workflow_run_id = "run-id"
|
||||
pipeline._graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")),
|
||||
variable_pool=build_test_variable_pool(
|
||||
variables=build_system_variables(workflow_execution_id="run-id"),
|
||||
),
|
||||
start_at=0.0,
|
||||
total_tokens=7,
|
||||
node_run_steps=3,
|
||||
|
||||
@ -95,7 +95,9 @@ class TestWorkflowGenerateTaskPipeline:
|
||||
def test_to_blocking_response_falls_back_to_human_input_required_when_pause_event_missing(self):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")),
|
||||
variable_pool=build_test_variable_pool(
|
||||
variables=build_system_variables(workflow_execution_id="run-id"),
|
||||
),
|
||||
start_at=0.0,
|
||||
total_tokens=5,
|
||||
node_run_steps=2,
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import threading
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
@ -7,7 +8,7 @@ from core.app.workflow.layers.llm_quota import LLMQuotaLayer
|
||||
from core.errors.error import QuotaExceededError
|
||||
from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus
|
||||
from graphon.graph_engine.entities.commands import CommandType
|
||||
from graphon.graph_events import NodeRunSucceededEvent
|
||||
from graphon.graph_events import NodeRunFailedEvent, NodeRunStartedEvent, NodeRunSucceededEvent
|
||||
from graphon.model_runtime.entities.llm_entities import LLMUsage
|
||||
from graphon.node_events import NodeRunResult
|
||||
|
||||
@ -44,6 +45,28 @@ def _build_node(*, node_type: BuiltinNodeTypes = BuiltinNodeTypes.LLM) -> MagicM
|
||||
return node
|
||||
|
||||
|
||||
class _RunnableQuotaNode:
    """Lightweight node double whose real ``run`` records that it was invoked.

    Tests use ``original_run_called`` to check whether the original ``run``
    was executed or had been replaced beforehand.
    """

    id = "node-id"
    execution_id = "execution-id"
    node_type = BuiltinNodeTypes.LLM
    title = "LLM node"

    def __init__(self, *, stop_event: threading.Event, node_data: SimpleNamespace | None = None) -> None:
        # Fall back to node data with a public model identity when none is supplied.
        if not node_data:
            node_data = SimpleNamespace(model=_build_public_model_identity())
        self.original_run_called = False
        self.graph_runtime_state = SimpleNamespace(stop_event=stop_event)
        self.node_data = node_data

    def ensure_execution_id(self) -> str:
        """Return the fixed execution id of this double."""
        return self.execution_id

    def populate_start_event(self, event: NodeRunStartedEvent) -> None:
        """Accept the start event without modifying it."""
        del event  # intentionally unused

    def run(self) -> Generator[NodeRunSucceededEvent, None, None]:
        """Mark the original run as called and yield one succeeded event."""
        self.original_run_called = True
        yield _build_succeeded_event()
|
||||
|
||||
|
||||
def test_deduct_quota_called_for_successful_llm_node() -> None:
|
||||
layer = LLMQuotaLayer(tenant_id="tenant-id")
|
||||
node = _build_node(node_type=BuiltinNodeTypes.LLM)
|
||||
@ -147,6 +170,48 @@ def test_quota_precheck_failure_aborts_workflow_immediately() -> None:
|
||||
assert abort_command.reason == "Model provider openai quota exceeded."
|
||||
|
||||
|
||||
def test_quota_precheck_failure_blocks_current_node_run() -> None:
    """A failing quota pre-check replaces the node's run with start + failed events."""
    node = _RunnableQuotaNode(stop_event=threading.Event())
    layer = LLMQuotaLayer(tenant_id="tenant-id")
    layer.command_channel = MagicMock()

    failing_precheck = patch(
        "core.app.workflow.layers.llm_quota.ensure_llm_quota_available_for_model",
        autospec=True,
        side_effect=QuotaExceededError("Model provider openai quota exceeded."),
    )
    with failing_precheck:
        layer.on_node_run_start(node)

    events = list(node.run())

    # The original generator must never have been entered.
    assert not node.original_run_called
    assert isinstance(events[0], NodeRunStartedEvent)
    assert isinstance(events[1], NodeRunFailedEvent)

    failed_event = events[1]
    assert failed_event.error == "Model provider openai quota exceeded."
    assert failed_event.node_run_result.status == WorkflowNodeExecutionStatus.FAILED
    assert failed_event.node_run_result.error_type == QuotaExceededError.__name__
|
||||
|
||||
|
||||
def test_missing_model_identity_blocks_current_node_run() -> None:
    """Node data without a model identity fails the node without running the quota check."""
    # Empty node data: there is no ``model`` attribute to read an identity from.
    node = _RunnableQuotaNode(stop_event=threading.Event(), node_data=SimpleNamespace())
    layer = LLMQuotaLayer(tenant_id="tenant-id")
    layer.command_channel = MagicMock()

    with patch(
        "core.app.workflow.layers.llm_quota.ensure_llm_quota_available_for_model",
        autospec=True,
    ) as mock_check:
        layer.on_node_run_start(node)

    events = list(node.run())

    assert not node.original_run_called
    assert isinstance(events[1], NodeRunFailedEvent)

    failed_event = events[1]
    assert failed_event.error == "LLM quota check requires public node model identity before execution."
    assert failed_event.node_run_result.status == WorkflowNodeExecutionStatus.FAILED
    assert failed_event.node_run_result.error_type == "LLMQuotaIdentityError"
    mock_check.assert_not_called()
|
||||
|
||||
|
||||
def test_quota_precheck_passes_without_abort() -> None:
|
||||
layer = LLMQuotaLayer(tenant_id="tenant-id")
|
||||
stop_event = threading.Event()
|
||||
|
||||
@ -1,41 +1,36 @@
|
||||
import time
|
||||
import uuid
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
|
||||
from core.workflow.node_factory import DifyNodeFactory
|
||||
from core.workflow.system_variables import build_system_variables
|
||||
from extensions.ext_database import db
|
||||
from graphon.enums import WorkflowNodeExecutionStatus
|
||||
from graphon.graph import Graph
|
||||
from graphon.nodes.answer.answer_node import AnswerNode
|
||||
from graphon.nodes.answer.entities import AnswerNodeData
|
||||
from graphon.runtime import GraphRuntimeState, VariablePool
|
||||
from tests.workflow_test_utils import build_test_graph_init_params
|
||||
|
||||
|
||||
def test_execute_answer():
|
||||
def _build_variable_pool() -> VariablePool:
    """Return a new variable pool bootstrapped with test system variables and no user inputs."""
    system_variables = build_system_variables(user_id="aaa", files=[])
    return VariablePool.from_bootstrap(system_variables=system_variables, user_inputs={})
|
||||
|
||||
|
||||
def _build_answer_node(*, answer: str, variable_pool: VariablePool) -> AnswerNode:
|
||||
graph_config = {
|
||||
"edges": [
|
||||
{
|
||||
"id": "start-source-answer-target",
|
||||
"source": "start",
|
||||
"target": "answer",
|
||||
},
|
||||
],
|
||||
"edges": [],
|
||||
"nodes": [
|
||||
{"data": {"type": "start", "title": "Start"}, "id": "start"},
|
||||
{
|
||||
"data": {
|
||||
"title": "123",
|
||||
"title": "Answer",
|
||||
"type": "answer",
|
||||
"answer": "Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.",
|
||||
"answer": answer,
|
||||
},
|
||||
"id": "answer",
|
||||
},
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
init_params = build_test_graph_init_params(
|
||||
workflow_id="1",
|
||||
graph_config=graph_config,
|
||||
@ -46,42 +41,31 @@ def test_execute_answer():
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
# construct variable pool
|
||||
variable_pool = VariablePool.from_bootstrap(
|
||||
system_variables=build_system_variables(user_id="aaa", files=[]),
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
conversation_variables=[],
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=variable_pool,
|
||||
start_at=time.perf_counter(),
|
||||
)
|
||||
variable_pool.add(["start", "weather"], "sunny")
|
||||
variable_pool.add(["llm", "text"], "You are a helpful AI.")
|
||||
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
# create node factory
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start")
|
||||
|
||||
node = AnswerNode(
|
||||
return AnswerNode(
|
||||
node_id=str(uuid.uuid4()),
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
data=AnswerNodeData(
|
||||
title="123",
|
||||
title="Answer",
|
||||
type="answer",
|
||||
answer="Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.",
|
||||
answer=answer,
|
||||
),
|
||||
)
|
||||
|
||||
# Mock db.session.close()
|
||||
db.session.close = MagicMock()
|
||||
|
||||
# execute node
|
||||
def test_execute_answer_renders_variable_selectors() -> None:
|
||||
variable_pool = _build_variable_pool()
|
||||
variable_pool.add(["start", "weather"], "sunny")
|
||||
variable_pool.add(["llm", "text"], "You are a helpful AI.")
|
||||
node = _build_answer_node(
|
||||
answer="Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.",
|
||||
variable_pool=variable_pool,
|
||||
)
|
||||
|
||||
result = node._run()
|
||||
|
||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
@ -89,36 +73,11 @@ def test_execute_answer():
|
||||
|
||||
|
||||
def test_execute_answer_renders_structured_output_object_as_json() -> None:
|
||||
init_params = build_test_graph_init_params(
|
||||
workflow_id="1",
|
||||
graph_config={"nodes": [], "edges": []},
|
||||
tenant_id="1",
|
||||
app_id="1",
|
||||
user_id="1",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables=build_system_variables(user_id="aaa", files=[]),
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
conversation_variables=[],
|
||||
)
|
||||
variable_pool = _build_variable_pool()
|
||||
variable_pool.add(["1777539038857", "structured_output"], {"type": "greeting"})
|
||||
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
node = AnswerNode(
|
||||
node_id=str(uuid.uuid4()),
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
config=AnswerNodeData(
|
||||
title="123",
|
||||
type="answer",
|
||||
answer="{{#1777539038857.structured_output#}}",
|
||||
),
|
||||
node = _build_answer_node(
|
||||
answer="{{#1777539038857.structured_output#}}",
|
||||
variable_pool=variable_pool,
|
||||
)
|
||||
|
||||
result = node._run()
|
||||
@ -128,35 +87,9 @@ def test_execute_answer_renders_structured_output_object_as_json() -> None:
|
||||
|
||||
|
||||
def test_execute_answer_falls_back_to_plain_selector_text_when_structured_output_missing() -> None:
|
||||
init_params = build_test_graph_init_params(
|
||||
workflow_id="1",
|
||||
graph_config={"nodes": [], "edges": []},
|
||||
tenant_id="1",
|
||||
app_id="1",
|
||||
user_id="1",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables=build_system_variables(user_id="aaa", files=[]),
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
conversation_variables=[],
|
||||
)
|
||||
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
node = AnswerNode(
|
||||
node_id=str(uuid.uuid4()),
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
config=AnswerNodeData(
|
||||
title="123",
|
||||
type="answer",
|
||||
answer="{{#1777539038857.structured_output#}}",
|
||||
),
|
||||
node = _build_answer_node(
|
||||
answer="{{#1777539038857.structured_output#}}",
|
||||
variable_pool=_build_variable_pool(),
|
||||
)
|
||||
|
||||
result = node._run()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user