mirror of
https://github.com/langgenius/dify.git
synced 2026-05-09 12:59:18 +08:00
chore(quota_layer): remove useless method
Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
parent
f0442456df
commit
588ecc487f
@ -12,8 +12,6 @@ handling is never silently skipped.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from datetime import UTC, datetime
|
||||
from typing import final, override
|
||||
|
||||
from core.app.llm import deduct_llm_quota_for_model, ensure_llm_quota_available_for_model
|
||||
@ -21,13 +19,7 @@ from core.errors.error import QuotaExceededError
|
||||
from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus
|
||||
from graphon.graph_engine.entities.commands import AbortCommand, CommandType
|
||||
from graphon.graph_engine.layers import GraphEngineLayer
|
||||
from graphon.graph_events import (
|
||||
GraphEngineEvent,
|
||||
GraphNodeEventBase,
|
||||
NodeRunFailedEvent,
|
||||
NodeRunStartedEvent,
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
from graphon.graph_events import GraphEngineEvent, GraphNodeEventBase, NodeRunSucceededEvent
|
||||
from graphon.node_events import NodeRunResult
|
||||
from graphon.nodes.base.node import Node
|
||||
|
||||
@ -130,8 +122,17 @@ class LLMQuotaLayer(GraphEngineLayer):
|
||||
|
||||
def _abort_before_node_run(self, *, node: Node, reason: str, error_type: str) -> None:
|
||||
self._set_stop_event(node)
|
||||
self._force_node_failure_to_abort(node)
|
||||
self._block_current_node_run(node=node, reason=reason, error_type=error_type)
|
||||
node.node_data.error_strategy = None
|
||||
node.node_data.retry_config.retry_enabled = False
|
||||
|
||||
def quota_aborted_run() -> NodeRunResult:
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=reason,
|
||||
error_type=error_type,
|
||||
)
|
||||
|
||||
node._run = quota_aborted_run # type: ignore[method-assign]
|
||||
self._send_abort_command(reason=reason)
|
||||
|
||||
def _abort_for_missing_model_identity(self, *, node: Node, reason: str) -> None:
|
||||
@ -139,62 +140,6 @@ class LLMQuotaLayer(GraphEngineLayer):
|
||||
self._send_abort_command(reason=reason)
|
||||
logger.error("LLM quota handling aborted, node_id=%s, reason=%s", node.id, reason)
|
||||
|
||||
@staticmethod
|
||||
def _force_node_failure_to_abort(node: Node) -> None:
|
||||
node_data = getattr(node, "node_data", None)
|
||||
if node_data is None:
|
||||
return
|
||||
|
||||
# Quota aborts must not be converted into retry, default-value, or fail-branch execution.
|
||||
try:
|
||||
if hasattr(node_data, "error_strategy"):
|
||||
node_data.error_strategy = None
|
||||
|
||||
retry_config = getattr(node_data, "retry_config", None)
|
||||
if retry_config is not None and hasattr(retry_config, "retry_enabled"):
|
||||
retry_config.retry_enabled = False
|
||||
except Exception:
|
||||
logger.warning("Failed to force quota-aborted node into abort strategy, node_id=%s", node.id, exc_info=True)
|
||||
|
||||
@staticmethod
|
||||
def _block_current_node_run(*, node: Node, reason: str, error_type: str) -> None:
|
||||
def blocked_run() -> Generator[GraphNodeEventBase, None, None]:
|
||||
execution_id = node.ensure_execution_id()
|
||||
start_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
start_event = NodeRunStartedEvent(
|
||||
id=execution_id,
|
||||
node_id=node.id,
|
||||
node_type=node.node_type,
|
||||
node_title=str(getattr(node, "title", node.id)),
|
||||
in_iteration_id=None,
|
||||
start_at=start_at,
|
||||
)
|
||||
populate_start_event = getattr(node, "populate_start_event", None)
|
||||
if callable(populate_start_event):
|
||||
try:
|
||||
populate_start_event(start_event)
|
||||
except Exception:
|
||||
logger.warning("Failed to populate quota-aborted start event, node_id=%s", node.id, exc_info=True)
|
||||
|
||||
yield start_event
|
||||
|
||||
finished_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
yield NodeRunFailedEvent(
|
||||
id=execution_id,
|
||||
node_id=node.id,
|
||||
node_type=node.node_type,
|
||||
start_at=start_at,
|
||||
finished_at=finished_at,
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=reason,
|
||||
error_type=error_type,
|
||||
),
|
||||
error=reason,
|
||||
)
|
||||
|
||||
object.__setattr__(node, "run", blocked_run)
|
||||
|
||||
def _send_abort_command(self, *, reason: str) -> None:
|
||||
if not self.command_channel or self._abort_sent:
|
||||
return
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
import threading
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
@ -8,7 +7,7 @@ from core.app.workflow.layers.llm_quota import LLMQuotaLayer
|
||||
from core.errors.error import QuotaExceededError
|
||||
from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus
|
||||
from graphon.graph_engine.entities.commands import CommandType
|
||||
from graphon.graph_events import NodeRunFailedEvent, NodeRunStartedEvent, NodeRunSucceededEvent
|
||||
from graphon.graph_events import NodeRunSucceededEvent
|
||||
from graphon.model_runtime.entities.llm_entities import LLMUsage
|
||||
from graphon.node_events import NodeRunResult
|
||||
|
||||
@ -35,12 +34,20 @@ def _build_public_model_identity(*, provider: str = "openai", model_name: str =
|
||||
return SimpleNamespace(provider=provider, name=model_name)
|
||||
|
||||
|
||||
def _build_node_data(*, model: SimpleNamespace | None = None) -> SimpleNamespace:
|
||||
return SimpleNamespace(
|
||||
error_strategy=None,
|
||||
retry_config=SimpleNamespace(retry_enabled=False),
|
||||
model=model,
|
||||
)
|
||||
|
||||
|
||||
def _build_node(*, node_type: BuiltinNodeTypes = BuiltinNodeTypes.LLM) -> MagicMock:
|
||||
node = MagicMock()
|
||||
node.id = "node-id"
|
||||
node.execution_id = "execution-id"
|
||||
node.node_type = node_type
|
||||
node.node_data = SimpleNamespace(model=_build_public_model_identity())
|
||||
node.node_data = _build_node_data(model=_build_public_model_identity())
|
||||
node.model_instance = SimpleNamespace(provider="stale-provider", model_name="stale-model")
|
||||
return node
|
||||
|
||||
@ -52,19 +59,13 @@ class _RunnableQuotaNode:
|
||||
title = "LLM node"
|
||||
|
||||
def __init__(self, *, stop_event: threading.Event, node_data: SimpleNamespace | None = None) -> None:
|
||||
self.node_data = node_data or SimpleNamespace(model=_build_public_model_identity())
|
||||
self.node_data = node_data or _build_node_data(model=_build_public_model_identity())
|
||||
self.graph_runtime_state = SimpleNamespace(stop_event=stop_event)
|
||||
self.original_run_called = False
|
||||
|
||||
def ensure_execution_id(self) -> str:
|
||||
return self.execution_id
|
||||
|
||||
def populate_start_event(self, event: NodeRunStartedEvent) -> None:
|
||||
_ = event
|
||||
|
||||
def run(self) -> Generator[NodeRunSucceededEvent, None, None]:
|
||||
def _run(self) -> NodeRunResult:
|
||||
self.original_run_called = True
|
||||
yield _build_succeeded_event()
|
||||
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED)
|
||||
|
||||
|
||||
def test_deduct_quota_called_for_successful_llm_node() -> None:
|
||||
@ -184,13 +185,11 @@ def test_quota_precheck_failure_blocks_current_node_run() -> None:
|
||||
):
|
||||
layer.on_node_run_start(node)
|
||||
|
||||
events = list(node.run())
|
||||
result = node._run()
|
||||
assert not node.original_run_called
|
||||
assert isinstance(events[0], NodeRunStartedEvent)
|
||||
assert isinstance(events[1], NodeRunFailedEvent)
|
||||
assert events[1].error == "Model provider openai quota exceeded."
|
||||
assert events[1].node_run_result.status == WorkflowNodeExecutionStatus.FAILED
|
||||
assert events[1].node_run_result.error_type == QuotaExceededError.__name__
|
||||
assert result.status == WorkflowNodeExecutionStatus.FAILED
|
||||
assert result.error == "Model provider openai quota exceeded."
|
||||
assert result.error_type == QuotaExceededError.__name__
|
||||
|
||||
|
||||
def test_missing_model_identity_blocks_current_node_run() -> None:
|
||||
@ -198,17 +197,16 @@ def test_missing_model_identity_blocks_current_node_run() -> None:
|
||||
stop_event = threading.Event()
|
||||
layer.command_channel = MagicMock()
|
||||
|
||||
node = _RunnableQuotaNode(stop_event=stop_event, node_data=SimpleNamespace())
|
||||
node = _RunnableQuotaNode(stop_event=stop_event, node_data=_build_node_data())
|
||||
|
||||
with patch("core.app.workflow.layers.llm_quota.ensure_llm_quota_available_for_model", autospec=True) as mock_check:
|
||||
layer.on_node_run_start(node)
|
||||
|
||||
events = list(node.run())
|
||||
result = node._run()
|
||||
assert not node.original_run_called
|
||||
assert isinstance(events[1], NodeRunFailedEvent)
|
||||
assert events[1].error == "LLM quota check requires public node model identity before execution."
|
||||
assert events[1].node_run_result.status == WorkflowNodeExecutionStatus.FAILED
|
||||
assert events[1].node_run_result.error_type == "LLMQuotaIdentityError"
|
||||
assert result.status == WorkflowNodeExecutionStatus.FAILED
|
||||
assert result.error == "LLM quota check requires public node model identity before execution."
|
||||
assert result.error_type == "LLMQuotaIdentityError"
|
||||
mock_check.assert_not_called()
|
||||
|
||||
|
||||
@ -239,7 +237,7 @@ def test_precheck_requires_public_node_model_config() -> None:
|
||||
layer.command_channel = MagicMock()
|
||||
|
||||
node = _build_node(node_type=BuiltinNodeTypes.LLM)
|
||||
node.node_data = SimpleNamespace()
|
||||
node.node_data = _build_node_data()
|
||||
node.graph_runtime_state = MagicMock()
|
||||
node.graph_runtime_state.stop_event = stop_event
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user