chore(api): upgrade graphon to v0.3.0 (#35469)

Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: WH-2099 <wh2099@pm.me>
-LAN- 2026-05-09 15:30:03 +08:00 committed by GitHub
parent f3eb3ab4dd
commit 19476109da
80 changed files with 2526 additions and 673 deletions

View File

@ -1,5 +1,15 @@
"""LLM-related application services."""
from .quota import deduct_llm_quota, ensure_llm_quota_available
from .quota import (
deduct_llm_quota,
deduct_llm_quota_for_model,
ensure_llm_quota_available,
ensure_llm_quota_available_for_model,
)
__all__ = ["deduct_llm_quota", "ensure_llm_quota_available"]
__all__ = [
"deduct_llm_quota",
"deduct_llm_quota_for_model",
"ensure_llm_quota_available",
"ensure_llm_quota_available_for_model",
]

View File

@ -1,4 +1,14 @@
from sqlalchemy import update
"""Tenant-scoped helpers for checking and deducting LLM provider quota.
System-hosted quota accounting is currently defined only for LLM models. Keep
the public helpers LLM-specific so callers do not carry unused model-type
plumbing, and fail loudly if the deprecated ``ModelInstance`` wrappers are used
with a non-LLM model.
"""
import warnings
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from configs import dify_config
@ -6,44 +16,47 @@ from core.entities.model_entities import ModelStatus
from core.entities.provider_entities import ProviderQuotaType, QuotaUnit
from core.errors.error import QuotaExceededError
from core.model_manager import ModelInstance
from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
from extensions.ext_database import db
from graphon.model_runtime.entities.llm_entities import LLMUsage
from graphon.model_runtime.entities.model_entities import ModelType
from libs.datetime_utils import naive_utc_now
from models.provider import Provider, ProviderType
from models.provider_ids import ModelProviderID
def ensure_llm_quota_available(*, model_instance: ModelInstance) -> None:
provider_model_bundle = model_instance.provider_model_bundle
provider_configuration = provider_model_bundle.configuration
def _get_provider_configuration(*, tenant_id: str, provider: str):
"""Resolve the tenant-bound provider configuration for quota decisions."""
provider_manager = create_plugin_provider_manager(tenant_id=tenant_id)
provider_configuration = provider_manager.get_configurations(tenant_id).get(provider)
if provider_configuration is None:
raise ValueError(f"Provider {provider} does not exist.")
return provider_configuration
def ensure_llm_quota_available_for_model(*, tenant_id: str, provider: str, model: str) -> None:
"""Raise when a tenant-bound LLM model is already out of quota."""
provider_configuration = _get_provider_configuration(tenant_id=tenant_id, provider=provider)
if provider_configuration.using_provider_type != ProviderType.SYSTEM:
return
provider_model = provider_configuration.get_provider_model(
model_type=model_instance.model_type_instance.model_type,
model=model_instance.model_name,
model_type=ModelType.LLM,
model=model,
)
if provider_model and provider_model.status == ModelStatus.QUOTA_EXCEEDED:
raise QuotaExceededError(f"Model provider {model_instance.provider} quota exceeded.")
raise QuotaExceededError(f"Model provider {provider} quota exceeded.")
def deduct_llm_quota(*, tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None:
provider_model_bundle = model_instance.provider_model_bundle
provider_configuration = provider_model_bundle.configuration
if provider_configuration.using_provider_type != ProviderType.SYSTEM:
return
system_configuration = provider_configuration.system_configuration
def _resolve_llm_used_quota(*, system_configuration, model: str, usage: LLMUsage) -> int | None:
"""Compute the quota impact for an LLM invocation under the current quota mode."""
quota_unit = None
for quota_configuration in system_configuration.quota_configurations:
if quota_configuration.quota_type == system_configuration.current_quota_type:
quota_unit = quota_configuration.quota_unit
if quota_configuration.quota_limit == -1:
return
return None
break
@ -52,42 +65,136 @@ def deduct_llm_quota(*, tenant_id: str, model_instance: ModelInstance, usage: LL
if quota_unit == QuotaUnit.TOKENS:
used_quota = usage.total_tokens
elif quota_unit == QuotaUnit.CREDITS:
used_quota = dify_config.get_model_credits(model_instance.model_name)
used_quota = dify_config.get_model_credits(model)
else:
used_quota = 1
return used_quota
def _deduct_free_llm_quota(
*,
tenant_id: str,
provider: str,
quota_type: ProviderQuotaType,
used_quota: int,
) -> None:
"""Deduct FREE provider quota, capping at the limit before reporting exhaustion."""
quota_exceeded = False
with sessionmaker(bind=db.engine).begin() as session:
provider_record = session.scalar(
select(Provider)
.where(
Provider.tenant_id == tenant_id,
# TODO: Use provider name with prefix after the data migration.
Provider.provider_name == ModelProviderID(provider).provider_name,
Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == quota_type,
)
.with_for_update()
)
if (
provider_record is None
or provider_record.quota_limit is None
or provider_record.quota_used is None
or provider_record.quota_limit <= provider_record.quota_used
):
quota_exceeded = True
else:
available_quota = provider_record.quota_limit - provider_record.quota_used
deducted_quota = min(used_quota, available_quota)
provider_record.quota_used += deducted_quota
provider_record.last_used = naive_utc_now()
quota_exceeded = deducted_quota < used_quota
if quota_exceeded:
raise QuotaExceededError(f"Model provider {provider} quota exceeded.")
def _deduct_used_llm_quota(*, tenant_id: str, provider: str, provider_configuration, used_quota: int | None) -> None:
"""Apply a resolved LLM quota charge against the current provider quota bucket."""
if provider_configuration.using_provider_type != ProviderType.SYSTEM:
return
system_configuration = provider_configuration.system_configuration
if used_quota is not None and system_configuration.current_quota_type is not None:
match system_configuration.current_quota_type:
case ProviderQuotaType.TRIAL:
from services.credit_pool_service import CreditPoolService
CreditPoolService.check_and_deduct_credits(
CreditPoolService.deduct_credits_capped(
tenant_id=tenant_id,
credits_required=used_quota,
)
case ProviderQuotaType.PAID:
from services.credit_pool_service import CreditPoolService
CreditPoolService.check_and_deduct_credits(
CreditPoolService.deduct_credits_capped(
tenant_id=tenant_id,
credits_required=used_quota,
pool_type="paid",
)
case ProviderQuotaType.FREE:
with sessionmaker(bind=db.engine).begin() as session:
stmt = (
update(Provider)
.where(
Provider.tenant_id == tenant_id,
# TODO: Use provider name with prefix after the data migration.
Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == system_configuration.current_quota_type,
Provider.quota_limit > Provider.quota_used,
)
.values(
quota_used=Provider.quota_used + used_quota,
last_used=naive_utc_now(),
)
)
session.execute(stmt)
_deduct_free_llm_quota(
tenant_id=tenant_id,
provider=provider,
quota_type=system_configuration.current_quota_type,
used_quota=used_quota,
)
case _:
return
def deduct_llm_quota_for_model(*, tenant_id: str, provider: str, model: str, usage: LLMUsage) -> None:
"""Deduct tenant-bound quota for the resolved LLM model identity."""
provider_configuration = _get_provider_configuration(tenant_id=tenant_id, provider=provider)
used_quota = _resolve_llm_used_quota(
system_configuration=provider_configuration.system_configuration,
model=model,
usage=usage,
)
_deduct_used_llm_quota(
tenant_id=tenant_id,
provider=provider,
provider_configuration=provider_configuration,
used_quota=used_quota,
)
def _require_llm_model_instance(model_instance: ModelInstance) -> None:
"""Reject deprecated wrapper calls that pass a non-LLM model instance."""
if model_instance.model_type_instance.model_type != ModelType.LLM:
raise ValueError("LLM quota helpers only support LLM model instances.")
def ensure_llm_quota_available(*, model_instance: ModelInstance) -> None:
"""Deprecated compatibility wrapper for callers that still pass ModelInstance."""
warnings.warn(
"ensure_llm_quota_available(model_instance=...) is deprecated; "
"use ensure_llm_quota_available_for_model(...) instead.",
DeprecationWarning,
stacklevel=2,
)
_require_llm_model_instance(model_instance)
ensure_llm_quota_available_for_model(
tenant_id=model_instance.provider_model_bundle.configuration.tenant_id,
provider=model_instance.provider,
model=model_instance.model_name,
)
def deduct_llm_quota(*, tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None:
"""Deprecated compatibility wrapper for callers that still pass ModelInstance."""
warnings.warn(
"deduct_llm_quota(tenant_id=..., model_instance=..., usage=...) is deprecated; "
"use deduct_llm_quota_for_model(...) instead.",
DeprecationWarning,
stacklevel=2,
)
_require_llm_model_instance(model_instance)
deduct_llm_quota_for_model(
tenant_id=tenant_id,
provider=model_instance.provider,
model=model_instance.model_name,
usage=usage,
)
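Taken together, the new _for_model helpers replace ModelInstance plumbing with a plain (tenant_id, provider, model) identity. A minimal caller-side sketch, not part of this diff, assuming usage is the LLMUsage attached to the finished invocation; the import path matches the one used by the quota layer later in this commit:

from core.app.llm import deduct_llm_quota_for_model, ensure_llm_quota_available_for_model
from graphon.model_runtime.entities.llm_entities import LLMUsage

def charge_llm_call(tenant_id: str, provider: str, model: str, usage: LLMUsage) -> None:
    # Pre-check: raises QuotaExceededError when the system-hosted bucket is spent.
    ensure_llm_quota_available_for_model(tenant_id=tenant_id, provider=provider, model=model)
    # ... invoke the model here ...
    # Post-charge: deduct from the public identity under the current quota mode.
    deduct_llm_quota_for_model(tenant_id=tenant_id, provider=provider, model=model, usage=usage)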

View File

@ -1,36 +1,48 @@
"""
LLM quota deduction layer for GraphEngine.
This layer centralizes model-quota deduction outside node implementations.
This layer centralizes model-quota handling outside node implementations.
Graphon LLM-backed nodes expose provider/model identity through public node
configuration and, after execution, through ``node_run_result.inputs``. Resolve
quota billing from that public identity instead of depending on
``ModelInstance`` reconstruction inside the workflow layer. Missing identity on
quota-tracked nodes is treated as a workflow bug and aborts execution so quota
handling is never silently skipped.
"""
import logging
from typing import TYPE_CHECKING, cast, final, override
from typing import final, override
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext
from core.app.llm import deduct_llm_quota, ensure_llm_quota_available
from core.app.llm import deduct_llm_quota_for_model, ensure_llm_quota_available_for_model
from core.errors.error import QuotaExceededError
from core.model_manager import ModelInstance
from graphon.enums import BuiltinNodeTypes
from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus
from graphon.graph_engine.entities.commands import AbortCommand, CommandType
from graphon.graph_engine.layers import GraphEngineLayer
from graphon.graph_events import GraphEngineEvent, GraphNodeEventBase, NodeRunSucceededEvent
from graphon.node_events import NodeRunResult
from graphon.nodes.base.node import Node
if TYPE_CHECKING:
from graphon.nodes.llm.node import LLMNode
from graphon.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
from graphon.nodes.question_classifier.question_classifier_node import QuestionClassifierNode
logger = logging.getLogger(__name__)
_QUOTA_NODE_TYPES = frozenset(
[
BuiltinNodeTypes.LLM,
BuiltinNodeTypes.PARAMETER_EXTRACTOR,
BuiltinNodeTypes.QUESTION_CLASSIFIER,
]
)
@final
class LLMQuotaLayer(GraphEngineLayer):
"""Graph layer that applies LLM quota deduction after node execution."""
"""Graph layer that applies tenant-scoped quota checks to LLM-backed nodes."""
def __init__(self) -> None:
tenant_id: str
_abort_sent: bool
def __init__(self, tenant_id: str) -> None:
super().__init__()
self.tenant_id = tenant_id
self._abort_sent = False
@override
@ -50,33 +62,49 @@ class LLMQuotaLayer(GraphEngineLayer):
if self._abort_sent:
return
model_instance = self._extract_model_instance(node)
if model_instance is None:
if not self._supports_quota(node):
return
model_identity = self._extract_model_identity_from_node(node)
if model_identity is None:
reason = "LLM quota check requires public node model identity before execution."
self._abort_before_node_run(node=node, reason=reason, error_type="LLMQuotaIdentityError")
logger.error("LLM quota handling aborted, node_id=%s, reason=%s", node.id, reason)
return
provider, model_name = model_identity
try:
ensure_llm_quota_available(model_instance=model_instance)
ensure_llm_quota_available_for_model(
tenant_id=self.tenant_id,
provider=provider,
model=model_name,
)
except QuotaExceededError as exc:
self._set_stop_event(node)
self._send_abort_command(reason=str(exc))
self._abort_before_node_run(node=node, reason=str(exc), error_type=QuotaExceededError.__name__)
logger.warning("LLM quota check failed, node_id=%s, error=%s", node.id, exc)
@override
def on_node_run_end(
self, node: Node, error: Exception | None, result_event: GraphNodeEventBase | None = None
) -> None:
if error is not None or not isinstance(result_event, NodeRunSucceededEvent):
if error is not None or not isinstance(result_event, NodeRunSucceededEvent) or not self._supports_quota(node):
return
model_instance = self._extract_model_instance(node)
if model_instance is None:
model_identity = self._extract_model_identity_from_result_event(result_event)
if model_identity is None:
self._abort_for_missing_model_identity(
node=node,
reason="LLM quota deduction requires model identity in the node result event.",
)
return
provider, model_name = model_identity
try:
dify_ctx = DifyRunContext.model_validate(node.require_run_context_value(DIFY_RUN_CONTEXT_KEY))
deduct_llm_quota(
tenant_id=dify_ctx.tenant_id,
model_instance=model_instance,
deduct_llm_quota_for_model(
tenant_id=self.tenant_id,
provider=provider,
model=model_name,
usage=result_event.node_run_result.llm_usage,
)
except QuotaExceededError as exc:
@ -92,6 +120,27 @@ class LLMQuotaLayer(GraphEngineLayer):
if stop_event is not None:
stop_event.set()
def _abort_before_node_run(self, *, node: Node, reason: str, error_type: str) -> None:
self._set_stop_event(node)
node.node_data.error_strategy = None
node.node_data.retry_config.retry_enabled = False
def quota_aborted_run() -> NodeRunResult:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=reason,
error_type=error_type,
)
# TODO: Push Graphon to expose a public pre-run failure/skip hook, then replace this private _run override.
node._run = quota_aborted_run # type: ignore[method-assign]
self._send_abort_command(reason=reason)
def _abort_for_missing_model_identity(self, *, node: Node, reason: str) -> None:
self._set_stop_event(node)
self._send_abort_command(reason=reason)
logger.error("LLM quota handling aborted, node_id=%s, reason=%s", node.id, reason)
def _send_abort_command(self, *, reason: str) -> None:
if not self.command_channel or self._abort_sent:
return
@ -108,29 +157,38 @@ class LLMQuotaLayer(GraphEngineLayer):
logger.exception("Failed to send quota abort command")
@staticmethod
def _extract_model_instance(node: Node) -> ModelInstance | None:
try:
match node.node_type:
case BuiltinNodeTypes.LLM:
model_instance = cast("LLMNode", node).model_instance
case BuiltinNodeTypes.PARAMETER_EXTRACTOR:
model_instance = cast("ParameterExtractorNode", node).model_instance
case BuiltinNodeTypes.QUESTION_CLASSIFIER:
model_instance = cast("QuestionClassifierNode", node).model_instance
case _:
return None
except AttributeError:
def _supports_quota(node: Node) -> bool:
return node.node_type in _QUOTA_NODE_TYPES
@staticmethod
def _extract_model_identity_from_result_event(result_event: NodeRunSucceededEvent) -> tuple[str, str] | None:
provider = result_event.node_run_result.inputs.get("model_provider")
model_name = result_event.node_run_result.inputs.get("model_name")
if isinstance(provider, str) and provider and isinstance(model_name, str) and model_name:
return provider, model_name
return None
@staticmethod
def _extract_model_identity_from_node(node: Node) -> tuple[str, str] | None:
node_data = getattr(node, "node_data", None)
if node_data is None:
node_data = getattr(node, "data", None)
model_config = getattr(node_data, "model", None)
if model_config is None:
logger.warning(
"LLMQuotaLayer skipped quota deduction because node does not expose a model instance, node_id=%s",
"LLMQuotaLayer skipped quota handling because node model config is missing, node_id=%s",
node.id,
)
return None
if isinstance(model_instance, ModelInstance):
return model_instance
raw_model_instance = getattr(model_instance, "_model_instance", None)
if isinstance(raw_model_instance, ModelInstance):
return raw_model_instance
provider = getattr(model_config, "provider", None)
model_name = getattr(model_config, "name", None)
if isinstance(provider, str) and provider and isinstance(model_name, str) and model_name:
return provider, model_name
logger.warning(
"LLMQuotaLayer skipped quota handling because node model identity is invalid, node_id=%s",
node.id,
)
return None
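Because the layer is now constructed with the tenant instead of recovering it from DifyRunContext on every node, engine composition injects it once. A sketch mirroring the WorkflowEntry wiring later in this commit; engine stands in for a composed GraphEngine:

def attach_quota_layer(engine, tenant_id: str) -> None:
    # LLMQuotaLayer is the class defined above; the tenant identity is fixed at construction.
    engine.layer(LLMQuotaLayer(tenant_id=tenant_id))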

View File

@ -23,7 +23,7 @@ from core.entities.provider_entities import (
)
from core.helper import encrypter
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory
from core.plugin.impl.model_runtime_factory import create_model_type_instance, create_plugin_model_assembly
from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
from graphon.model_runtime.entities.provider_entities import (
ConfigurateMethod,
@ -33,7 +33,7 @@ from graphon.model_runtime.entities.provider_entities import (
)
from graphon.model_runtime.model_providers.base.ai_model import AIModel
from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from graphon.model_runtime.runtime import ModelRuntime
from graphon.model_runtime.protocols.runtime import ModelRuntime
from libs.datetime_utils import naive_utc_now
from models.engine import db
from models.enums import CredentialSourceType
@ -106,11 +106,18 @@ class ProviderConfiguration(BaseModel):
"""Attach the already-composed runtime for request-bound call chains."""
self._bound_model_runtime = model_runtime
def _get_runtime_and_provider_factory(self) -> tuple[ModelRuntime, ModelProviderFactory]:
"""Resolve a provider factory that stays aligned with the runtime used by the caller."""
if self._bound_model_runtime is not None:
return self._bound_model_runtime, ModelProviderFactory(runtime=self._bound_model_runtime)
model_assembly = create_plugin_model_assembly(tenant_id=self.tenant_id)
return model_assembly.model_runtime, model_assembly.model_provider_factory
def get_model_provider_factory(self) -> ModelProviderFactory:
"""Return a provider factory that preserves any request-bound runtime."""
if self._bound_model_runtime is not None:
return ModelProviderFactory(model_runtime=self._bound_model_runtime)
return create_plugin_model_provider_factory(tenant_id=self.tenant_id)
_, model_provider_factory = self._get_runtime_and_provider_factory()
return model_provider_factory
def get_current_credentials(self, model_type: ModelType, model: str) -> dict[str, Any] | None:
"""
@ -1392,10 +1399,13 @@ class ProviderConfiguration(BaseModel):
:param model_type: model type
:return:
"""
model_provider_factory = self.get_model_provider_factory()
# Get model instance of LLM
return model_provider_factory.get_model_type_instance(provider=self.provider.provider, model_type=model_type)
model_runtime, model_provider_factory = self._get_runtime_and_provider_factory()
provider_schema = model_provider_factory.get_provider_schema(provider=self.provider.provider)
return create_model_type_instance(
runtime=model_runtime,
provider_schema=provider_schema,
model_type=model_type,
)
def get_model_schema(
self, model_type: ModelType, model: str, credentials: dict[str, Any] | None

View File

@ -4,7 +4,7 @@ from typing import cast
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities import DEFAULT_PLUGIN_ID
from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory
from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly
from extensions.ext_hosting_provider import hosting_configuration
from graphon.model_runtime.entities.model_entities import ModelType
from graphon.model_runtime.errors.invoke import InvokeBadRequestError
@ -41,10 +41,8 @@ def check_moderation(tenant_id: str, model_config: ModelConfigWithCredentialsEnt
text_chunk = secrets.choice(text_chunks)
try:
model_provider_factory = create_plugin_model_provider_factory(tenant_id=tenant_id)
# Get model instance of LLM
model_type_instance = model_provider_factory.get_model_type_instance(
model_assembly = create_plugin_model_assembly(tenant_id=tenant_id)
model_type_instance = model_assembly.create_model_type_instance(
provider=openai_provider_name, model_type=ModelType.MODERATION
)
model_type_instance = cast(ModerationModel, model_type_instance)

View File

@ -4,23 +4,32 @@ import hashlib
import logging
from collections.abc import Generator, Iterable, Sequence
from threading import Lock
from typing import IO, Any, Union
from typing import IO, Any, Literal, cast, overload
from pydantic import ValidationError
from redis import RedisError
from configs import dify_config
from core.llm_generator.output_parser.structured_output import (
invoke_llm_with_structured_output as invoke_llm_with_structured_output_helper,
)
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
from core.plugin.impl.asset import PluginAssetManager
from core.plugin.impl.model import PluginModelClient
from extensions.ext_redis import redis_client
from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from graphon.model_runtime.entities.llm_entities import (
LLMResult,
LLMResultChunk,
LLMResultChunkWithStructuredOutput,
LLMResultWithStructuredOutput,
)
from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelType
from graphon.model_runtime.entities.provider_entities import ProviderEntity
from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult
from graphon.model_runtime.entities.text_embedding_entities import EmbeddingInputType, EmbeddingResult
from graphon.model_runtime.runtime import ModelRuntime
from graphon.model_runtime.model_providers.base.large_language_model import normalize_non_stream_runtime_result
from graphon.model_runtime.protocols.runtime import ModelRuntime
from models.provider_ids import ModelProviderID
logger = logging.getLogger(__name__)
@ -29,6 +38,68 @@ logger = logging.getLogger(__name__)
TENANT_SCOPE_SCHEMA_CACHE_USER_ID = "__DIFY_TS__"
# TODO(-LAN-): Move native structured-output invocation into Graphon's LLM node.
# TODO(-LAN-): Remove this Dify-side adapter once Graphon owns structured output end-to-end.
class _PluginStructuredOutputModelInstance:
"""Bind plugin model identity to the shared structured-output helper.
The structured-output parser is shared with legacy ``ModelInstance`` flows
and only needs an object exposing ``invoke_llm(...)``. ``PluginModelRuntime``
intentionally exposes a lower-level API where provider, model, and
credentials are passed per call. This adapter supplies the small bound
``invoke_llm`` surface the helper needs without constructing a full
``ModelInstance`` or reintroducing model-manager dependencies into the
plugin runtime path.
"""
def __init__(
self,
*,
runtime: PluginModelRuntime,
provider: str,
model: str,
credentials: dict[str, Any],
) -> None:
self._runtime = runtime
self._provider = provider
self._model = model
self._credentials = credentials
def invoke_llm(
self,
*,
prompt_messages: Sequence[PromptMessage],
model_parameters: dict[str, Any] | None = None,
tools: Sequence[PromptMessageTool] | None = None,
stop: Sequence[str] | None = None,
stream: bool = True,
callbacks: object | None = None,
) -> LLMResult | Generator[LLMResultChunk, None, None]:
del callbacks
if stream:
return self._runtime.invoke_llm(
provider=self._provider,
model=self._model,
credentials=self._credentials,
model_parameters=model_parameters or {},
prompt_messages=prompt_messages,
tools=list(tools) if tools else None,
stop=stop,
stream=True,
)
return self._runtime.invoke_llm(
provider=self._provider,
model=self._model,
credentials=self._credentials,
model_parameters=model_parameters or {},
prompt_messages=prompt_messages,
tools=list(tools) if tools else None,
stop=stop,
stream=False,
)
class PluginModelRuntime(ModelRuntime):
"""Plugin-backed runtime adapter bound to tenant context and optional caller scope."""
@ -195,6 +266,34 @@ class PluginModelRuntime(ModelRuntime):
return schema
@overload
def invoke_llm(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
model_parameters: dict[str, Any],
prompt_messages: Sequence[PromptMessage],
tools: list[PromptMessageTool] | None,
stop: Sequence[str] | None,
stream: Literal[False],
) -> LLMResult: ...
@overload
def invoke_llm(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
model_parameters: dict[str, Any],
prompt_messages: Sequence[PromptMessage],
tools: list[PromptMessageTool] | None,
stop: Sequence[str] | None,
stream: Literal[True],
) -> Generator[LLMResultChunk, None, None]: ...
def invoke_llm(
self,
*,
@ -206,9 +305,9 @@ class PluginModelRuntime(ModelRuntime):
tools: list[PromptMessageTool] | None,
stop: Sequence[str] | None,
stream: bool,
) -> Union[LLMResult, Generator[LLMResultChunk, None, None]]:
) -> LLMResult | Generator[LLMResultChunk, None, None]:
plugin_id, provider_name = self._split_provider(provider)
return self.client.invoke_llm(
result = self.client.invoke_llm(
tenant_id=self.tenant_id,
user_id=self.user_id,
plugin_id=plugin_id,
@ -221,6 +320,81 @@ class PluginModelRuntime(ModelRuntime):
stop=list(stop) if stop else None,
stream=stream,
)
if stream:
return result
return normalize_non_stream_runtime_result(
model=model,
prompt_messages=prompt_messages,
result=result,
)
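The Literal overloads above let a type checker narrow invoke_llm's return type from the stream flag instead of forcing callers to cast the union. A sketch under this module's own imports; the provider id and model name are illustrative placeholders:

def invoke_both_modes(runtime: PluginModelRuntime, creds: dict[str, Any], messages: Sequence[PromptMessage]):
    # stream=False now resolves to LLMResult (normalized above for non-stream results)...
    result: LLMResult = runtime.invoke_llm(
        provider="langgenius/openai/openai", model="gpt-4o", credentials=creds,
        model_parameters={}, prompt_messages=messages, tools=None, stop=None, stream=False,
    )
    # ...while stream=True resolves to Generator[LLMResultChunk, None, None].
    chunks: Generator[LLMResultChunk, None, None] = runtime.invoke_llm(
        provider="langgenius/openai/openai", model="gpt-4o", credentials=creds,
        model_parameters={}, prompt_messages=messages, tools=None, stop=None, stream=True,
    )
    return result, chunks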
@overload
def invoke_llm_with_structured_output(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
json_schema: dict[str, Any],
model_parameters: dict[str, Any],
prompt_messages: Sequence[PromptMessage],
stop: Sequence[str] | None,
stream: Literal[False],
) -> LLMResultWithStructuredOutput: ...
@overload
def invoke_llm_with_structured_output(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
json_schema: dict[str, Any],
model_parameters: dict[str, Any],
prompt_messages: Sequence[PromptMessage],
stop: Sequence[str] | None,
stream: Literal[True],
) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ...
def invoke_llm_with_structured_output(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
json_schema: dict[str, Any],
model_parameters: dict[str, Any],
prompt_messages: Sequence[PromptMessage],
stop: Sequence[str] | None,
stream: bool,
) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]:
model_schema = self.get_model_schema(
provider=provider,
model_type=ModelType.LLM,
model=model,
credentials=credentials,
)
if model_schema is None:
raise ValueError(f"Model schema not found for {model}")
adapter = _PluginStructuredOutputModelInstance(
runtime=self,
provider=provider,
model=model,
credentials=credentials,
)
return invoke_llm_with_structured_output_helper(
provider=provider,
model_schema=model_schema,
model_instance=cast(Any, adapter),
prompt_messages=prompt_messages,
json_schema=json_schema,
model_parameters=model_parameters,
tools=None,
stop=list(stop) if stop else None,
stream=stream,
)
def get_llm_num_tokens(
self,

View File

@ -3,13 +3,46 @@ from __future__ import annotations
from typing import TYPE_CHECKING
from core.plugin.impl.model import PluginModelClient
from graphon.model_runtime.entities.model_entities import ModelType
from graphon.model_runtime.entities.provider_entities import ProviderEntity
from graphon.model_runtime.model_providers.base.ai_model import AIModel
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
from graphon.model_runtime.model_providers.base.moderation_model import ModerationModel
from graphon.model_runtime.model_providers.base.rerank_model import RerankModel
from graphon.model_runtime.model_providers.base.speech2text_model import Speech2TextModel
from graphon.model_runtime.model_providers.base.text_embedding_model import TextEmbeddingModel
from graphon.model_runtime.model_providers.base.tts_model import TTSModel
from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from graphon.model_runtime.protocols.runtime import ModelRuntime
if TYPE_CHECKING:
from core.model_manager import ModelManager
from core.plugin.impl.model_runtime import PluginModelRuntime
from core.provider_manager import ProviderManager
_MODEL_CLASS_BY_TYPE: dict[ModelType, type[AIModel]] = {
ModelType.LLM: LargeLanguageModel,
ModelType.TEXT_EMBEDDING: TextEmbeddingModel,
ModelType.RERANK: RerankModel,
ModelType.SPEECH2TEXT: Speech2TextModel,
ModelType.MODERATION: ModerationModel,
ModelType.TTS: TTSModel,
}
def create_model_type_instance(
*,
runtime: ModelRuntime,
provider_schema: ProviderEntity,
model_type: ModelType,
) -> AIModel:
"""Build the graphon model wrapper explicitly against the request runtime."""
model_class = _MODEL_CLASS_BY_TYPE.get(model_type)
if model_class is None:
raise ValueError(f"Unsupported model type: {model_type}")
return model_class(provider_schema=provider_schema, model_runtime=runtime)
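With the class map in place, building a typed model wrapper becomes an explicit two-step: fetch the provider schema from a factory bound to the same runtime, then instantiate the wrapper. A sketch of the flow that PluginModelAssembly.create_model_type_instance below packages up; the provider id is an illustrative placeholder:

def build_llm_wrapper(runtime: ModelRuntime) -> AIModel:
    factory = ModelProviderFactory(runtime=runtime)
    schema = factory.get_provider_schema(provider="langgenius/openai/openai")  # illustrative id
    # Dispatches through _MODEL_CLASS_BY_TYPE, so an unsupported type raises ValueError.
    return create_model_type_instance(runtime=runtime, provider_schema=schema, model_type=ModelType.LLM)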
class PluginModelAssembly:
"""Compose request-scoped model views on top of a single plugin runtime."""
@ -38,9 +71,22 @@ class PluginModelAssembly:
@property
def model_provider_factory(self) -> ModelProviderFactory:
if self._model_provider_factory is None:
self._model_provider_factory = ModelProviderFactory(model_runtime=self.model_runtime)
self._model_provider_factory = ModelProviderFactory(runtime=self.model_runtime)
return self._model_provider_factory
def create_model_type_instance(
self,
*,
provider: str,
model_type: ModelType,
) -> AIModel:
provider_schema = self.model_provider_factory.get_provider_schema(provider=provider)
return create_model_type_instance(
runtime=self.model_runtime,
provider_schema=provider_schema,
model_type=model_type,
)
@property
def provider_manager(self) -> ProviderManager:
if self._provider_manager is None:

View File

@ -56,7 +56,7 @@ from models.provider_ids import ModelProviderID
from services.feature_service import FeatureService
if TYPE_CHECKING:
from graphon.model_runtime.runtime import ModelRuntime
from graphon.model_runtime.protocols.runtime import ModelRuntime
_credentials_adapter: TypeAdapter[dict[str, Any]] = TypeAdapter(dict[str, Any])
@ -165,7 +165,7 @@ class ProviderManager:
)
# Get all provider entities
model_provider_factory = ModelProviderFactory(model_runtime=self._model_runtime)
model_provider_factory = ModelProviderFactory(runtime=self._model_runtime)
provider_entities = model_provider_factory.get_providers()
# Get All preferred provider types of the workspace
@ -362,7 +362,7 @@ class ProviderManager:
if not default_model:
return None
model_provider_factory = ModelProviderFactory(model_runtime=self._model_runtime)
model_provider_factory = ModelProviderFactory(runtime=self._model_runtime)
provider_schema = model_provider_factory.get_provider_schema(provider=default_model.provider_name)
return DefaultModelEntity(

View File

@ -374,11 +374,6 @@ class DifyNodeFactory(NodeFactory):
# Re-validate using the resolved node class so workflow-local node schemas
# stay explicit and constructors receive the concrete typed payload.
resolved_node_data = self._validate_resolved_node_data(node_class, node_data)
config_for_node_init: BaseNodeData | dict[str, Any]
if isinstance(resolved_node_data, BaseNodeData):
config_for_node_init = resolved_node_data.model_dump(mode="python", by_alias=True)
else:
config_for_node_init = resolved_node_data
node_type = node_data.type
node_init_kwargs_factories: Mapping[NodeType, Callable[[], dict[str, object]]] = {
BuiltinNodeTypes.CODE: lambda: {
@ -446,9 +441,10 @@ class DifyNodeFactory(NodeFactory):
},
}
node_init_kwargs = node_init_kwargs_factories.get(node_type, lambda: {})()
constructor_node_data = resolved_node_data.model_dump(mode="python", by_alias=True)
return node_class(
node_id=node_id,
config=config_for_node_init,
data=constructor_node_data,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
**node_init_kwargs,

View File

@ -35,7 +35,7 @@ class AgentNode(Node[AgentNodeData]):
def __init__(
self,
node_id: str,
config: AgentNodeData,
data: AgentNodeData,
*,
graph_init_params: GraphInitParams,
graph_runtime_state: GraphRuntimeState,
@ -46,7 +46,7 @@ class AgentNode(Node[AgentNodeData]):
) -> None:
super().__init__(
node_id=node_id,
config=config,
data=data,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)

View File

@ -36,14 +36,14 @@ class DatasourceNode(Node[DatasourceNodeData]):
def __init__(
self,
node_id: str,
config: DatasourceNodeData,
data: DatasourceNodeData,
*,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
) -> None:
super().__init__(
node_id=node_id,
config=config,
data=data,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)

View File

@ -32,14 +32,14 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
def __init__(
self,
node_id: str,
config: KnowledgeIndexNodeData,
data: KnowledgeIndexNodeData,
*,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
) -> None:
super().__init__(
node_id=node_id,
config=config,
data=data,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)

View File

@ -71,14 +71,14 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
def __init__(
self,
node_id: str,
config: KnowledgeRetrievalNodeData,
data: KnowledgeRetrievalNodeData,
*,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
) -> None:
super().__init__(
node_id=node_id,
config=config,
data=data,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)

View File

@ -3,7 +3,7 @@ from __future__ import annotations
from collections import defaultdict
from collections.abc import Mapping, Sequence
from enum import StrEnum
from typing import Any, Protocol, cast
from typing import Any, Protocol
from uuid import uuid4
from graphon.enums import BuiltinNodeTypes
@ -82,13 +82,10 @@ def build_system_variables(values: Mapping[str, Any] | None = None, /, **kwargs:
normalized = _normalize_system_variable_values(values, **kwargs)
return [
cast(
Variable,
segment_to_variable(
segment=build_segment(value),
selector=system_variable_selector(key),
name=key,
),
segment_to_variable(
segment=build_segment(value),
selector=system_variable_selector(key),
name=key,
)
for key, value in normalized.items()
]
@ -130,13 +127,10 @@ def build_bootstrap_variables(
for node_id, value in rag_pipeline_variables_map.items():
variables.append(
cast(
Variable,
segment_to_variable(
segment=build_segment(value),
selector=(RAG_PIPELINE_VARIABLE_NODE_ID, node_id),
name=node_id,
),
segment_to_variable(
segment=build_segment(value),
selector=(RAG_PIPELINE_VARIABLE_NODE_ID, node_id),
name=node_id,
)
)

View File

@ -46,6 +46,11 @@ _file_access_controller = DatabaseFileAccessController()
class _WorkflowChildEngineBuilder:
tenant_id: str
def __init__(self, *, tenant_id: str) -> None:
self.tenant_id = tenant_id
@staticmethod
def _has_node_id(graph_config: Mapping[str, Any], node_id: str) -> bool | None:
"""
@ -107,7 +112,7 @@ class _WorkflowChildEngineBuilder:
config=config,
child_engine_builder=self,
)
child_engine.layer(LLMQuotaLayer())
child_engine.layer(LLMQuotaLayer(tenant_id=self.tenant_id))
return child_engine
@ -176,7 +181,7 @@ class WorkflowEntry:
self.command_channel = command_channel
execution_context = capture_current_context()
graph_runtime_state.execution_context = execution_context
self._child_engine_builder = _WorkflowChildEngineBuilder()
self._child_engine_builder = _WorkflowChildEngineBuilder(tenant_id=tenant_id)
self.graph_engine = GraphEngine(
workflow_id=workflow_id,
graph=graph,
@ -208,7 +213,7 @@ class WorkflowEntry:
max_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS, max_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME
)
self.graph_engine.layer(limits_layer)
self.graph_engine.layer(LLMQuotaLayer())
self.graph_engine.layer(LLMQuotaLayer(tenant_id=tenant_id))
# Add observability layer when OTel is enabled
if dify_config.ENABLE_OTEL or is_instrument_flag_enabled():

View File

@ -137,17 +137,13 @@ def handle(sender: Message, **kwargs):
if used_quota is not None:
match provider_configuration.system_configuration.current_quota_type:
case ProviderQuotaType.TRIAL:
from services.credit_pool_service import CreditPoolService
CreditPoolService.check_and_deduct_credits(
_deduct_credit_pool_quota_capped(
tenant_id=tenant_id,
credits_required=used_quota,
pool_type="trial",
)
case ProviderQuotaType.PAID:
from services.credit_pool_service import CreditPoolService
CreditPoolService.check_and_deduct_credits(
_deduct_credit_pool_quota_capped(
tenant_id=tenant_id,
credits_required=used_quota,
pool_type="paid",
@ -200,6 +196,26 @@ def handle(sender: Message, **kwargs):
raise
def _deduct_credit_pool_quota_capped(*, tenant_id: str, credits_required: int, pool_type: str) -> None:
"""Apply post-generation credit accounting without failing message persistence on quota exhaustion."""
from services.credit_pool_service import CreditPoolService
deducted_credits = CreditPoolService.deduct_credits_capped(
tenant_id=tenant_id,
credits_required=credits_required,
pool_type=pool_type,
)
if deducted_credits < credits_required:
logger.warning(
"Credit pool exhausted during message-created accounting, "
"tenant_id=%s, pool_type=%s, credits_required=%s, credits_deducted=%s",
tenant_id,
pool_type,
credits_required,
deducted_credits,
)
def _calculate_quota_usage(
*, message: Message, system_configuration: SystemConfiguration, model_name: str
) -> int | None:

View File

@ -45,7 +45,7 @@ dependencies = [
# Emerging: newer and fast-moving, use compatible pins
"fastopenapi[flask]~=0.7.0",
"graphon~=0.2.2",
"graphon~=0.3.0",
"httpx-sse~=0.4.0",
"json-repair~=0.59.4",
]

View File

@ -1,7 +1,7 @@
import logging
from sqlalchemy import select, update
from sqlalchemy.orm import sessionmaker
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
from configs import dify_config
from core.errors.error import QuotaExceededError
@ -13,6 +13,18 @@ logger = logging.getLogger(__name__)
class CreditPoolService:
@staticmethod
def _get_locked_pool(session: Session, tenant_id: str, pool_type: str) -> TenantCreditPool | None:
return session.scalar(
select(TenantCreditPool)
.where(
TenantCreditPool.tenant_id == tenant_id,
TenantCreditPool.pool_type == pool_type,
)
.limit(1)
.with_for_update()
)
@classmethod
def create_default_pool(cls, tenant_id: str) -> TenantCreditPool:
"""create default credit pool for new tenant"""
@ -59,31 +71,57 @@ class CreditPoolService:
credits_required: int,
pool_type: str = "trial",
) -> int:
"""check and deduct credits, returns actual credits deducted"""
pool = cls.get_pool(tenant_id, pool_type)
if not pool:
raise QuotaExceededError("Credit pool not found")
if pool.remaining_credits <= 0:
raise QuotaExceededError("No credits remaining")
# deduct all remaining credits if less than required
actual_credits = min(credits_required, pool.remaining_credits)
"""Deduct exactly the requested credits or raise without mutating the pool."""
if credits_required <= 0:
return 0
try:
with sessionmaker(db.engine).begin() as session:
stmt = (
update(TenantCreditPool)
.where(
TenantCreditPool.tenant_id == tenant_id,
TenantCreditPool.pool_type == pool_type,
)
.values(quota_used=TenantCreditPool.quota_used + actual_credits)
)
session.execute(stmt)
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
pool = cls._get_locked_pool(session=session, tenant_id=tenant_id, pool_type=pool_type)
if not pool:
raise QuotaExceededError("Credit pool not found")
remaining_credits = pool.remaining_credits
if remaining_credits <= 0:
raise QuotaExceededError("No credits remaining")
if remaining_credits < credits_required:
raise QuotaExceededError("Insufficient credits remaining")
pool.quota_used += credits_required
except QuotaExceededError:
raise
except Exception:
logger.exception("Failed to deduct credits for tenant %s", tenant_id)
raise QuotaExceededError("Failed to deduct credits")
return actual_credits
return credits_required
@classmethod
def deduct_credits_capped(
cls,
tenant_id: str,
credits_required: int,
pool_type: str = "trial",
) -> int:
"""Deduct up to the available balance and return the actual deducted credits."""
if credits_required <= 0:
return 0
try:
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
pool = cls._get_locked_pool(session=session, tenant_id=tenant_id, pool_type=pool_type)
if not pool:
logger.warning("Credit pool not found, tenant_id=%s, pool_type=%s", tenant_id, pool_type)
return 0
deducted_credits = min(credits_required, pool.remaining_credits)
if deducted_credits <= 0:
return 0
pool.quota_used += deducted_credits
return deducted_credits
except QuotaExceededError:
raise
except Exception:
logger.exception("Failed to deduct capped credits for tenant %s", tenant_id)
raise QuotaExceededError("Failed to deduct credits")
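The split gives callers two explicit contracts: the strict path deducts all-or-nothing and raises, while the capped path drains what is left and reports the shortfall. A usage sketch with an illustrative 200-credit charge:

def settle_credits(tenant_id: str) -> None:
    try:
        # Strict: either the full 200 credits come out of the pool, or nothing does.
        CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=200)
    except QuotaExceededError:
        pass  # pool untouched; surface the quota error to the caller instead
    # Capped: deduct whatever remains and let the caller decide how to handle the gap.
    deducted = CreditPoolService.deduct_credits_capped(tenant_id=tenant_id, credits_required=200)
    if deducted < 200:
        logger.warning("credit pool exhausted, deducted only %s of 200", deducted)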

View File

@ -157,8 +157,8 @@ class DraftVarLoader(VariableLoader):
# This approach reduces loading time by querying external systems concurrently.
with ThreadPoolExecutor(max_workers=10) as executor:
offloaded_variables = executor.map(self._load_offloaded_variable, offloaded_draft_vars)
for selector, variable in offloaded_variables:
variable_by_selector[selector] = variable
for selector, offloaded_variable in offloaded_variables:
variable_by_selector[selector] = offloaded_variable
return list(variable_by_selector.values())

View File

@ -1251,7 +1251,7 @@ class WorkflowService:
node_data = HumanInputNode.validate_node_data(adapt_human_input_node_data_for_graph(node_config["data"]))
node = HumanInputNode(
node_id=node_config["id"],
config=node_data,
data=node_data,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
runtime=DifyHumanInputNodeRuntime(run_context),

View File

@ -73,7 +73,7 @@ def test_node_integration_minimal_stream(mocker: MockerFixture):
node = DatasourceNode(
node_id="n",
config=DatasourceNodeData(
data=DatasourceNodeData(
type="datasource",
version="1",
title="Datasource",

View File

@ -4,7 +4,7 @@ from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEnti
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
from core.entities.provider_entities import CustomConfiguration, CustomProviderConfiguration, SystemConfiguration
from core.model_manager import ModelInstance
from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory
from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly
from graphon.model_runtime.entities.model_entities import ModelType
from models.provider import ProviderType
@ -15,8 +15,9 @@ def get_mocked_fetch_model_config(
mode: str,
credentials: dict,
):
model_provider_factory = create_plugin_model_provider_factory(tenant_id="9d2074fc-6f86-45a9-b09d-6ecc63b9056b")
model_type_instance = model_provider_factory.get_model_type_instance(provider, ModelType.LLM)
model_assembly = create_plugin_model_assembly(tenant_id="9d2074fc-6f86-45a9-b09d-6ecc63b9056b")
model_provider_factory = model_assembly.model_provider_factory
model_type_instance = model_assembly.create_model_type_instance(provider=provider, model_type=ModelType.LLM)
provider_model_bundle = ProviderModelBundle(
configuration=ProviderConfiguration(
tenant_id="1",

View File

@ -45,7 +45,7 @@ def init_code_node(code_config: dict):
)
# construct variable pool
variable_pool = VariablePool(
variable_pool = VariablePool.from_bootstrap(
system_variables=build_system_variables(user_id="aaa", files=[]),
user_inputs={},
environment_variables=[],
@ -66,7 +66,7 @@ def init_code_node(code_config: dict):
node = CodeNode(
node_id=str(uuid.uuid4()),
config=CodeNodeData.model_validate(code_config["data"]),
data=CodeNodeData.model_validate(code_config["data"]),
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
code_executor=node_factory._code_executor,

View File

@ -55,7 +55,7 @@ def init_http_node(config: dict):
)
# construct variable pool
variable_pool = VariablePool(
variable_pool = VariablePool.from_bootstrap(
system_variables=build_system_variables(user_id="aaa", files=[]),
user_inputs={},
environment_variables=[],
@ -76,7 +76,7 @@ def init_http_node(config: dict):
node = HttpRequestNode(
node_id=str(uuid.uuid4()),
config=HttpRequestNodeData.model_validate(config["data"]),
data=HttpRequestNodeData.model_validate(config["data"]),
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
http_request_config=HTTP_REQUEST_CONFIG,
@ -204,7 +204,7 @@ def test_custom_auth_with_empty_api_key_raises_error(setup_http_mock):
from graphon.runtime import VariablePool
# Create variable pool
variable_pool = VariablePool(
variable_pool = VariablePool.from_bootstrap(
system_variables=build_system_variables(user_id="test", files=[]),
user_inputs={},
environment_variables=[],
@ -702,7 +702,7 @@ def test_nested_object_variable_selector(setup_http_mock):
)
# Create independent variable pool for this test only
variable_pool = VariablePool(
variable_pool = VariablePool.from_bootstrap(
system_variables=build_system_variables(user_id="aaa", files=[]),
user_inputs={},
environment_variables=[],
@ -724,7 +724,7 @@ def test_nested_object_variable_selector(setup_http_mock):
node = HttpRequestNode(
node_id=str(uuid.uuid4()),
config=HttpRequestNodeData.model_validate(graph_config["nodes"][1]["data"]),
data=HttpRequestNodeData.model_validate(graph_config["nodes"][1]["data"]),
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
http_request_config=HTTP_REQUEST_CONFIG,

View File

@ -53,7 +53,7 @@ def init_llm_node(config: dict) -> LLMNode:
)
# construct variable pool
variable_pool = VariablePool(
variable_pool = VariablePool.from_bootstrap(
system_variables=build_system_variables(
user_id="aaa",
app_id=app_id,
@ -77,7 +77,7 @@ def init_llm_node(config: dict) -> LLMNode:
node = LLMNode(
node_id=str(uuid.uuid4()),
config=LLMNodeData.model_validate(config["data"]),
data=LLMNodeData.model_validate(config["data"]),
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
credentials_provider=MagicMock(spec=CredentialsProvider),

View File

@ -56,7 +56,7 @@ def init_parameter_extractor_node(config: dict, memory=None):
)
# construct variable pool
variable_pool = VariablePool(
variable_pool = VariablePool.from_bootstrap(
system_variables=build_system_variables(
user_id="aaa", files=[], query="what's the weather in SF", conversation_id="abababa"
),
@ -71,7 +71,7 @@ def init_parameter_extractor_node(config: dict, memory=None):
node = ParameterExtractorNode(
node_id=str(uuid.uuid4()),
config=ParameterExtractorNodeData.model_validate(config["data"]),
data=ParameterExtractorNodeData.model_validate(config["data"]),
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
credentials_provider=MagicMock(spec=CredentialsProvider),

View File

@ -66,7 +66,7 @@ def test_execute_template_transform():
)
# construct variable pool
variable_pool = VariablePool(
variable_pool = VariablePool.from_bootstrap(
system_variables=build_system_variables(user_id="aaa", files=[]),
user_inputs={},
environment_variables=[],
@ -88,7 +88,7 @@ def test_execute_template_transform():
node = TemplateTransformNode(
node_id=str(uuid.uuid4()),
config=TemplateTransformNodeData.model_validate(config["data"]),
data=TemplateTransformNodeData.model_validate(config["data"]),
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
jinja2_template_renderer=_SimpleJinja2Renderer(),

View File

@ -43,7 +43,7 @@ def init_tool_node(config: dict):
)
# construct variable pool
variable_pool = VariablePool(
variable_pool = VariablePool.from_bootstrap(
system_variables=build_system_variables(user_id="aaa", files=[]),
user_inputs={},
environment_variables=[],
@ -64,7 +64,7 @@ def init_tool_node(config: dict):
node = ToolNode(
node_id=str(uuid.uuid4()),
config=ToolNodeData.model_validate(config["data"]),
data=ToolNodeData.model_validate(config["data"]),
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
tool_file_manager_factory=tool_file_manager_factory,

View File

@ -210,7 +210,9 @@ class TestPauseStatePersistenceLayerTestContainers:
execution_id = workflow_run_id or getattr(self, "test_workflow_run_id", None) or str(uuid.uuid4())
# Create variable pool
variable_pool = VariablePool(system_variables=build_system_variables(workflow_execution_id=execution_id))
variable_pool = VariablePool.from_bootstrap(
system_variables=build_system_variables(workflow_execution_id=execution_id)
)
if variables:
for (node_id, var_key), value in variables.items():
variable_pool.add([node_id, var_key], value)

View File

@ -66,7 +66,7 @@ def _mock_form_repository_with_submission(action_id: str) -> HumanInputFormRepos
def _build_runtime_state(workflow_execution_id: str, app_id: str, workflow_id: str, user_id: str) -> GraphRuntimeState:
variable_pool = VariablePool(
variable_pool = VariablePool.from_bootstrap(
system_variables=build_system_variables(
workflow_execution_id=workflow_execution_id,
app_id=app_id,
@ -102,7 +102,7 @@ def _build_graph(
start_data = StartNodeData(title="start", variables=[])
start_node = StartNode(
node_id="start",
config=start_data,
data=start_data,
graph_init_params=params,
graph_runtime_state=runtime_state,
)
@ -117,7 +117,7 @@ def _build_graph(
)
human_node = HumanInputNode(
node_id="human",
config=human_data,
data=human_data,
graph_init_params=params,
graph_runtime_state=runtime_state,
form_repository=form_repository,
@ -131,7 +131,7 @@ def _build_graph(
)
end_node = EndNode(
node_id="end",
config=end_data,
data=end_data,
graph_init_params=params,
graph_runtime_state=runtime_state,
)

View File

@ -90,16 +90,34 @@ class TestCreditPoolService:
pool = CreditPoolService.get_pool(tenant_id=tenant_id)
assert pool.quota_used == credits_required
def test_check_and_deduct_credits_caps_at_remaining(self, db_session_with_containers: Session):
def test_check_and_deduct_credits_raises_without_deducting_when_insufficient(
self, db_session_with_containers: Session
):
tenant_id = self._create_tenant_id()
pool = CreditPoolService.create_default_pool(tenant_id)
remaining = 5
pool.quota_used = pool.quota_limit - remaining
quota_used = pool.quota_used
db_session_with_containers.commit()
result = CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=200)
with pytest.raises(QuotaExceededError, match="Insufficient credits remaining"):
CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=200)
db_session_with_containers.expire_all()
updated_pool = CreditPoolService.get_pool(tenant_id=tenant_id)
assert updated_pool.quota_used == quota_used
def test_deduct_credits_capped_depletes_available_balance(self, db_session_with_containers: Session):
tenant_id = self._create_tenant_id()
pool = CreditPoolService.create_default_pool(tenant_id)
remaining = 5
pool.quota_used = pool.quota_limit - remaining
quota_limit = pool.quota_limit
db_session_with_containers.commit()
result = CreditPoolService.deduct_credits_capped(tenant_id=tenant_id, credits_required=200)
assert result == remaining
db_session_with_containers.expire_all()
updated_pool = CreditPoolService.get_pool(tenant_id=tenant_id)
assert updated_pool.quota_used == pool.quota_limit
assert updated_pool.quota_used == quota_limit

View File

@ -132,7 +132,9 @@ class TestAdvancedChatGenerateTaskPipeline:
pipeline._task_state.answer = "partial answer"
pipeline._workflow_run_id = "run-id"
pipeline._graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")),
variable_pool=build_test_variable_pool(
variables=build_system_variables(workflow_execution_id="run-id"),
),
start_at=0.0,
total_tokens=7,
node_run_steps=3,
@ -372,7 +374,9 @@ class TestAdvancedChatGenerateTaskPipeline:
pipeline = _make_pipeline()
pipeline._workflow_run_id = "run-id"
pipeline._graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")),
variable_pool=VariablePool.from_bootstrap(
system_variables=build_system_variables(workflow_execution_id="run-id")
),
start_at=0.0,
)
pipeline._workflow_response_converter.workflow_finish_to_stream_response = lambda **kwargs: "finish"
@ -583,7 +587,9 @@ class TestAdvancedChatGenerateTaskPipeline:
self.items = items
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")),
variable_pool=VariablePool.from_bootstrap(
system_variables=build_system_variables(workflow_execution_id="run-id")
),
start_at=0.0,
)
@ -617,7 +623,9 @@ class TestAdvancedChatGenerateTaskPipeline:
def test_handle_message_end_event_applies_output_moderation(self, monkeypatch: pytest.MonkeyPatch):
pipeline = _make_pipeline()
pipeline._graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")),
variable_pool=VariablePool.from_bootstrap(
system_variables=build_system_variables(workflow_execution_id="run-id")
),
start_at=0.0,
)
pipeline._base_task_pipeline.handle_output_moderation_when_task_finished = lambda answer: "safe"

View File

@ -60,7 +60,7 @@ class _StubToolNode(Node[_StubToolNodeData]):
def __init__(
self,
node_id: str,
config: _StubToolNodeData,
data: _StubToolNodeData,
*,
graph_init_params,
graph_runtime_state,
@ -68,7 +68,7 @@ class _StubToolNode(Node[_StubToolNodeData]):
) -> None:
super().__init__(
node_id=node_id,
config=config,
data=data,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
@ -169,7 +169,7 @@ def _build_graph(runtime_state: GraphRuntimeState, *, pause_on: str | None) -> G
def _build_runtime_state(run_id: str) -> GraphRuntimeState:
variable_pool = VariablePool(
variable_pool = VariablePool.from_bootstrap(
system_variables=build_system_variables(user_id="user", app_id="app", workflow_id="workflow"),
user_inputs={},
conversation_variables=[],

View File

@ -54,7 +54,7 @@ class TestWorkflowBasedAppRunner:
runner = WorkflowBasedAppRunner(queue_manager=SimpleNamespace(), app_id="app")
runtime_state = GraphRuntimeState(
variable_pool=VariablePool(system_variables=default_system_variables()),
variable_pool=VariablePool.from_bootstrap(system_variables=default_system_variables()),
start_at=0.0,
)
@ -93,7 +93,7 @@ class TestWorkflowBasedAppRunner:
def test_get_graph_and_variable_pool_for_single_node_run(self, monkeypatch: pytest.MonkeyPatch):
runner = WorkflowBasedAppRunner(queue_manager=SimpleNamespace(), app_id="app")
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(system_variables=default_system_variables()),
variable_pool=VariablePool.from_bootstrap(system_variables=default_system_variables()),
start_at=0.0,
)
@ -164,7 +164,7 @@ class TestWorkflowBasedAppRunner:
app_id="app",
)
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(system_variables=default_system_variables()),
variable_pool=VariablePool.from_bootstrap(system_variables=default_system_variables()),
start_at=0.0,
)
@ -243,7 +243,7 @@ class TestWorkflowBasedAppRunner:
runner = WorkflowBasedAppRunner(queue_manager=_QueueManager(), app_id="app")
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(system_variables=default_system_variables()),
variable_pool=VariablePool.from_bootstrap(system_variables=default_system_variables()),
start_at=0.0,
)
graph_runtime_state.register_paused_node("node-1")
@ -286,7 +286,7 @@ class TestWorkflowBasedAppRunner:
runner = WorkflowBasedAppRunner(queue_manager=_QueueManager(), app_id="app")
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(system_variables=default_system_variables()),
variable_pool=VariablePool.from_bootstrap(system_variables=default_system_variables()),
start_at=0.0,
)
workflow_entry = SimpleNamespace(graph_engine=SimpleNamespace(graph_runtime_state=graph_runtime_state))
@ -425,7 +425,7 @@ class TestWorkflowBasedAppRunner:
runner = WorkflowBasedAppRunner(queue_manager=_QueueManager(), app_id="app")
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(system_variables=default_system_variables()),
variable_pool=VariablePool.from_bootstrap(system_variables=default_system_variables()),
start_at=0.0,
)
workflow_entry = SimpleNamespace(graph_engine=SimpleNamespace(graph_runtime_state=graph_runtime_state))

View File

@ -16,7 +16,7 @@ from models.workflow import Workflow
def _make_graph_state():
variable_pool = VariablePool(
variable_pool = VariablePool.from_bootstrap(
system_variables=default_system_variables(),
user_inputs={},
environment_variables=[],

View File

@ -95,7 +95,9 @@ class TestWorkflowGenerateTaskPipeline:
def test_to_blocking_response_falls_back_to_human_input_required_when_pause_event_missing(self):
pipeline = _make_pipeline()
pipeline._graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")),
variable_pool=build_test_variable_pool(
variables=build_system_variables(workflow_execution_id="run-id"),
),
start_at=0.0,
total_tokens=5,
node_run_steps=2,
@ -283,7 +285,9 @@ class TestWorkflowGenerateTaskPipeline:
pipeline = _make_pipeline()
pipeline._workflow_execution_id = "run-id"
pipeline._graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")),
variable_pool=VariablePool.from_bootstrap(
system_variables=build_system_variables(workflow_execution_id="run-id")
),
start_at=0.0,
)
pipeline._workflow_response_converter.workflow_finish_to_stream_response = lambda **kwargs: "finish"
@ -725,7 +729,9 @@ class TestWorkflowGenerateTaskPipeline:
pipeline = _make_pipeline()
pipeline._workflow_execution_id = "run-id"
pipeline._graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")),
variable_pool=VariablePool.from_bootstrap(
system_variables=build_system_variables(workflow_execution_id="run-id")
),
start_at=0.0,
)
@ -753,7 +759,9 @@ class TestWorkflowGenerateTaskPipeline:
pipeline = _make_pipeline()
pipeline._workflow_execution_id = "run-id"
pipeline._graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")),
variable_pool=VariablePool.from_bootstrap(
system_variables=build_system_variables(workflow_execution_id="run-id")
),
start_at=0.0,
)
pipeline._handle_ping_event = lambda event, **kwargs: iter(["ping"])
@ -769,7 +777,9 @@ class TestWorkflowGenerateTaskPipeline:
def test_process_stream_response_main_match_paths_and_cleanup(self):
pipeline = _make_pipeline()
pipeline._graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")),
variable_pool=VariablePool.from_bootstrap(
system_variables=build_system_variables(workflow_execution_id="run-id")
),
start_at=0.0,
)
pipeline._base_task_pipeline.queue_manager.listen = lambda: iter(

View File

@ -21,7 +21,9 @@ class TestTriggerPostLayer:
)
runtime_state = SimpleNamespace(
outputs={"answer": "ok"},
variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-1")),
variable_pool=VariablePool.from_bootstrap(
system_variables=build_system_variables(workflow_execution_id="run-1")
),
total_tokens=12,
)
@ -60,7 +62,9 @@ class TestTriggerPostLayer:
def test_on_event_handles_missing_trigger_log(self):
runtime_state = SimpleNamespace(
outputs={},
variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-1")),
variable_pool=VariablePool.from_bootstrap(
system_variables=build_system_variables(workflow_execution_id="run-1")
),
total_tokens=0,
)
@ -91,7 +95,9 @@ class TestTriggerPostLayer:
def test_on_event_ignores_non_status_events(self):
runtime_state = SimpleNamespace(
outputs={},
variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-1")),
variable_pool=VariablePool.from_bootstrap(
system_variables=build_system_variables(workflow_execution_id="run-1")
),
total_tokens=0,
)

View File

@ -0,0 +1,617 @@
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from sqlalchemy import create_engine, select
from configs import dify_config
from core.app.llm.quota import (
deduct_llm_quota,
deduct_llm_quota_for_model,
ensure_llm_quota_available,
ensure_llm_quota_available_for_model,
)
from core.entities.model_entities import ModelStatus
from core.entities.provider_entities import ProviderQuotaType, QuotaUnit
from core.errors.error import QuotaExceededError
from graphon.model_runtime.entities.llm_entities import LLMUsage
from graphon.model_runtime.entities.model_entities import ModelType
from models import TenantCreditPool
from models.enums import ProviderQuotaType as ModelProviderQuotaType
from models.provider import Provider, ProviderType
def test_ensure_llm_quota_available_for_model_raises_when_system_model_is_exhausted() -> None:
provider_configuration = SimpleNamespace(
using_provider_type=ProviderType.SYSTEM,
get_provider_model=MagicMock(return_value=SimpleNamespace(status=ModelStatus.QUOTA_EXCEEDED)),
)
provider_manager = MagicMock()
provider_manager.get_configurations.return_value.get.return_value = provider_configuration
with (
patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager),
pytest.raises(QuotaExceededError, match="Model provider openai quota exceeded."),
):
ensure_llm_quota_available_for_model(
tenant_id="tenant-id",
provider="openai",
model="gpt-4o",
)
provider_configuration.get_provider_model.assert_called_once_with(
model_type=ModelType.LLM,
model="gpt-4o",
)
def test_ensure_llm_quota_available_for_model_raises_when_provider_is_missing() -> None:
provider_manager = MagicMock()
provider_manager.get_configurations.return_value.get.return_value = None
with (
patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager),
pytest.raises(ValueError, match="Provider openai does not exist."),
):
ensure_llm_quota_available_for_model(
tenant_id="tenant-id",
provider="openai",
model="gpt-4o",
)
def test_ensure_llm_quota_available_for_model_ignores_custom_provider_configuration() -> None:
provider_configuration = SimpleNamespace(
using_provider_type=ProviderType.CUSTOM,
get_provider_model=MagicMock(),
)
provider_manager = MagicMock()
provider_manager.get_configurations.return_value.get.return_value = provider_configuration
with patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager):
ensure_llm_quota_available_for_model(
tenant_id="tenant-id",
provider="openai",
model="gpt-4o",
)
provider_configuration.get_provider_model.assert_not_called()
def test_deduct_llm_quota_for_model_uses_identity_based_trial_billing() -> None:
usage = LLMUsage.empty_usage()
usage.total_tokens = 42
provider_configuration = SimpleNamespace(
using_provider_type=ProviderType.SYSTEM,
system_configuration=SimpleNamespace(
current_quota_type=ProviderQuotaType.TRIAL,
quota_configurations=[
SimpleNamespace(
quota_type=ProviderQuotaType.TRIAL,
quota_unit=QuotaUnit.TOKENS,
quota_limit=100,
)
],
),
)
provider_manager = MagicMock()
provider_manager.get_configurations.return_value.get.return_value = provider_configuration
with (
patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager),
patch("services.credit_pool_service.CreditPoolService.deduct_credits_capped") as mock_deduct_credits,
):
deduct_llm_quota_for_model(
tenant_id="tenant-id",
provider="openai",
model="gpt-4o",
usage=usage,
)
mock_deduct_credits.assert_called_once_with(
tenant_id="tenant-id",
credits_required=42,
)
def test_deduct_llm_quota_for_model_caps_trial_pool_when_usage_exceeds_remaining() -> None:
usage = LLMUsage.empty_usage()
usage.total_tokens = 3
provider_configuration = SimpleNamespace(
using_provider_type=ProviderType.SYSTEM,
system_configuration=SimpleNamespace(
current_quota_type=ProviderQuotaType.TRIAL,
quota_configurations=[
SimpleNamespace(
quota_type=ProviderQuotaType.TRIAL,
quota_unit=QuotaUnit.TOKENS,
quota_limit=100,
)
],
),
)
provider_manager = MagicMock()
provider_manager.get_configurations.return_value.get.return_value = provider_configuration
engine = create_engine("sqlite:///:memory:")
TenantCreditPool.__table__.create(engine)
with engine.begin() as connection:
connection.execute(
TenantCreditPool.__table__.insert(),
{
"id": "trial-pool",
"tenant_id": "tenant-id",
"pool_type": ModelProviderQuotaType.TRIAL,
"quota_limit": 10,
"quota_used": 9,
},
)
with (
patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager),
patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)),
):
deduct_llm_quota_for_model(
tenant_id="tenant-id",
provider="openai",
model="gpt-4o",
usage=usage,
)
with engine.connect() as connection:
quota_used = connection.scalar(select(TenantCreditPool.quota_used).where(TenantCreditPool.id == "trial-pool"))
assert quota_used == 10
def test_deduct_llm_quota_for_model_returns_for_unbounded_quota() -> None:
usage = LLMUsage.empty_usage()
usage.total_tokens = 42
provider_configuration = SimpleNamespace(
using_provider_type=ProviderType.SYSTEM,
system_configuration=SimpleNamespace(
current_quota_type=ProviderQuotaType.TRIAL,
quota_configurations=[
SimpleNamespace(
quota_type=ProviderQuotaType.TRIAL,
quota_unit=QuotaUnit.TOKENS,
quota_limit=-1,
)
],
),
)
provider_manager = MagicMock()
provider_manager.get_configurations.return_value.get.return_value = provider_configuration
with (
patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager),
patch("services.credit_pool_service.CreditPoolService.deduct_credits_capped") as mock_deduct_credits,
):
deduct_llm_quota_for_model(
tenant_id="tenant-id",
provider="openai",
model="gpt-4o",
usage=usage,
)
mock_deduct_credits.assert_not_called()
def test_deduct_llm_quota_for_model_uses_credit_configuration() -> None:
usage = LLMUsage.empty_usage()
provider_configuration = SimpleNamespace(
using_provider_type=ProviderType.SYSTEM,
system_configuration=SimpleNamespace(
current_quota_type=ProviderQuotaType.TRIAL,
quota_configurations=[
SimpleNamespace(
quota_type=ProviderQuotaType.TRIAL,
quota_unit=QuotaUnit.CREDITS,
quota_limit=100,
)
],
),
)
provider_manager = MagicMock()
provider_manager.get_configurations.return_value.get.return_value = provider_configuration
with (
patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager),
patch.object(type(dify_config), "get_model_credits", return_value=9) as mock_get_model_credits,
patch("services.credit_pool_service.CreditPoolService.deduct_credits_capped") as mock_deduct_credits,
):
deduct_llm_quota_for_model(
tenant_id="tenant-id",
provider="openai",
model="gpt-4o",
usage=usage,
)
mock_get_model_credits.assert_called_once_with("gpt-4o")
mock_deduct_credits.assert_called_once_with(
tenant_id="tenant-id",
credits_required=9,
)
def test_deduct_llm_quota_for_model_uses_single_charge_for_times_quota() -> None:
usage = LLMUsage.empty_usage()
provider_configuration = SimpleNamespace(
using_provider_type=ProviderType.SYSTEM,
system_configuration=SimpleNamespace(
current_quota_type=ProviderQuotaType.TRIAL,
quota_configurations=[
SimpleNamespace(
quota_type=ProviderQuotaType.TRIAL,
quota_unit=QuotaUnit.TIMES,
quota_limit=100,
)
],
),
)
provider_manager = MagicMock()
provider_manager.get_configurations.return_value.get.return_value = provider_configuration
with (
patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager),
patch("services.credit_pool_service.CreditPoolService.deduct_credits_capped") as mock_deduct_credits,
):
deduct_llm_quota_for_model(
tenant_id="tenant-id",
provider="openai",
model="gpt-4o",
usage=usage,
)
mock_deduct_credits.assert_called_once_with(
tenant_id="tenant-id",
credits_required=1,
)
def test_deduct_llm_quota_for_model_uses_paid_billing_pool() -> None:
usage = LLMUsage.empty_usage()
usage.total_tokens = 5
provider_configuration = SimpleNamespace(
using_provider_type=ProviderType.SYSTEM,
system_configuration=SimpleNamespace(
current_quota_type=ProviderQuotaType.PAID,
quota_configurations=[
SimpleNamespace(
quota_type=ProviderQuotaType.PAID,
quota_unit=QuotaUnit.TOKENS,
quota_limit=100,
)
],
),
)
provider_manager = MagicMock()
provider_manager.get_configurations.return_value.get.return_value = provider_configuration
with (
patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager),
patch("services.credit_pool_service.CreditPoolService.deduct_credits_capped") as mock_deduct_credits,
):
deduct_llm_quota_for_model(
tenant_id="tenant-id",
provider="openai",
model="gpt-4o",
usage=usage,
)
mock_deduct_credits.assert_called_once_with(
tenant_id="tenant-id",
credits_required=5,
pool_type="paid",
)
def test_deduct_llm_quota_for_model_updates_free_quota_usage() -> None:
usage = LLMUsage.empty_usage()
usage.total_tokens = 3
provider_configuration = SimpleNamespace(
using_provider_type=ProviderType.SYSTEM,
system_configuration=SimpleNamespace(
current_quota_type=ProviderQuotaType.FREE,
quota_configurations=[
SimpleNamespace(
quota_type=ProviderQuotaType.FREE,
quota_unit=QuotaUnit.TOKENS,
quota_limit=100,
)
],
),
)
provider_manager = MagicMock()
provider_manager.get_configurations.return_value.get.return_value = provider_configuration
engine = create_engine("sqlite:///:memory:")
Provider.__table__.create(engine)
with engine.begin() as connection:
connection.execute(
Provider.__table__.insert(),
[
{
"id": "matching-provider",
"tenant_id": "tenant-id",
"provider_name": "openai",
"provider_type": ProviderType.SYSTEM,
"quota_type": ProviderQuotaType.FREE,
"quota_limit": 100,
"quota_used": 10,
"is_valid": True,
},
{
"id": "other-tenant",
"tenant_id": "other-tenant-id",
"provider_name": "openai",
"provider_type": ProviderType.SYSTEM,
"quota_type": ProviderQuotaType.FREE,
"quota_limit": 100,
"quota_used": 20,
"is_valid": True,
},
{
"id": "other-provider",
"tenant_id": "tenant-id",
"provider_name": "anthropic",
"provider_type": ProviderType.SYSTEM,
"quota_type": ProviderQuotaType.FREE,
"quota_limit": 100,
"quota_used": 30,
"is_valid": True,
},
{
"id": "custom-provider",
"tenant_id": "tenant-id",
"provider_name": "openai",
"provider_type": ProviderType.CUSTOM,
"quota_type": ProviderQuotaType.FREE,
"quota_limit": 100,
"quota_used": 40,
"is_valid": True,
},
],
)
with (
patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager),
patch("core.app.llm.quota.db", SimpleNamespace(engine=engine)),
):
deduct_llm_quota_for_model(
tenant_id="tenant-id",
provider="openai",
model="gpt-4o",
usage=usage,
)
with engine.connect() as connection:
quota_used_by_id = dict(connection.execute(select(Provider.id, Provider.quota_used)).all())
assert quota_used_by_id == {
"matching-provider": 13,
"other-tenant": 20,
"other-provider": 30,
"custom-provider": 40,
}
with engine.begin() as connection:
connection.execute(
Provider.__table__.update().where(Provider.id == "matching-provider").values(quota_limit=13, quota_used=13)
)
with (
patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager),
patch("core.app.llm.quota.db", SimpleNamespace(engine=engine)),
pytest.raises(QuotaExceededError, match="Model provider openai quota exceeded."),
):
deduct_llm_quota_for_model(
tenant_id="tenant-id",
provider="openai",
model="gpt-4o",
usage=usage,
)
with engine.connect() as connection:
exhausted_quota_used = connection.scalar(select(Provider.quota_used).where(Provider.id == "matching-provider"))
assert exhausted_quota_used == 13
def test_deduct_llm_quota_for_model_caps_free_quota_and_raises_when_usage_exceeds_remaining() -> None:
usage = LLMUsage.empty_usage()
usage.total_tokens = 3
provider_configuration = SimpleNamespace(
using_provider_type=ProviderType.SYSTEM,
system_configuration=SimpleNamespace(
current_quota_type=ProviderQuotaType.FREE,
quota_configurations=[
SimpleNamespace(
quota_type=ProviderQuotaType.FREE,
quota_unit=QuotaUnit.TOKENS,
quota_limit=100,
)
],
),
)
provider_manager = MagicMock()
provider_manager.get_configurations.return_value.get.return_value = provider_configuration
engine = create_engine("sqlite:///:memory:")
Provider.__table__.create(engine)
with engine.begin() as connection:
connection.execute(
Provider.__table__.insert(),
{
"id": "matching-provider",
"tenant_id": "tenant-id",
"provider_name": "openai",
"provider_type": ProviderType.SYSTEM,
"quota_type": ProviderQuotaType.FREE,
"quota_limit": 15,
"quota_used": 13,
"is_valid": True,
},
)
with (
patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager),
patch("core.app.llm.quota.db", SimpleNamespace(engine=engine)),
pytest.raises(QuotaExceededError, match="Model provider openai quota exceeded."),
):
deduct_llm_quota_for_model(
tenant_id="tenant-id",
provider="openai",
model="gpt-4o",
usage=usage,
)
with engine.connect() as connection:
quota_used = connection.scalar(select(Provider.quota_used).where(Provider.id == "matching-provider"))
assert quota_used == 15
def test_deduct_llm_quota_for_model_ignores_unknown_quota_type() -> None:
usage = LLMUsage.empty_usage()
usage.total_tokens = 2
provider_configuration = SimpleNamespace(
using_provider_type=ProviderType.SYSTEM,
system_configuration=SimpleNamespace(
current_quota_type="unexpected",
quota_configurations=[
SimpleNamespace(
quota_type="unexpected",
quota_unit=QuotaUnit.TOKENS,
quota_limit=100,
)
],
),
)
provider_manager = MagicMock()
provider_manager.get_configurations.return_value.get.return_value = provider_configuration
with (
patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager),
patch("services.credit_pool_service.CreditPoolService.deduct_credits_capped") as mock_deduct_credits,
patch("core.app.llm.quota.sessionmaker") as mock_sessionmaker,
):
deduct_llm_quota_for_model(
tenant_id="tenant-id",
provider="openai",
model="gpt-4o",
usage=usage,
)
mock_deduct_credits.assert_not_called()
mock_sessionmaker.assert_not_called()
def test_deduct_llm_quota_for_model_ignores_custom_provider_configuration() -> None:
usage = LLMUsage.empty_usage()
usage.total_tokens = 2
provider_configuration = SimpleNamespace(
using_provider_type=ProviderType.CUSTOM,
system_configuration=SimpleNamespace(
current_quota_type=ProviderQuotaType.TRIAL,
quota_configurations=[],
),
)
provider_manager = MagicMock()
provider_manager.get_configurations.return_value.get.return_value = provider_configuration
with (
patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager),
patch("services.credit_pool_service.CreditPoolService.deduct_credits_capped") as mock_deduct_credits,
patch("core.app.llm.quota.sessionmaker") as mock_sessionmaker,
):
deduct_llm_quota_for_model(
tenant_id="tenant-id",
provider="openai",
model="gpt-4o",
usage=usage,
)
mock_deduct_credits.assert_not_called()
mock_sessionmaker.assert_not_called()
def test_ensure_llm_quota_available_wrapper_warns_and_delegates() -> None:
model_instance = SimpleNamespace(
provider="openai",
model_name="gpt-4o",
provider_model_bundle=SimpleNamespace(configuration=SimpleNamespace(tenant_id="tenant-id")),
model_type_instance=SimpleNamespace(model_type=ModelType.LLM),
)
with (
pytest.deprecated_call(match="ensure_llm_quota_available\\(model_instance=.*deprecated"),
patch("core.app.llm.quota.ensure_llm_quota_available_for_model") as mock_ensure,
):
ensure_llm_quota_available(model_instance=model_instance)
mock_ensure.assert_called_once_with(
tenant_id="tenant-id",
provider="openai",
model="gpt-4o",
)
def test_ensure_llm_quota_available_wrapper_rejects_non_llm_model_instances() -> None:
model_instance = SimpleNamespace(
provider="openai",
model_name="gpt-4o",
provider_model_bundle=SimpleNamespace(configuration=SimpleNamespace(tenant_id="tenant-id")),
model_type_instance=SimpleNamespace(model_type=ModelType.TEXT_EMBEDDING),
)
with (
pytest.deprecated_call(match="ensure_llm_quota_available\\(model_instance=.*deprecated"),
pytest.raises(ValueError, match="only support LLM model instances"),
):
ensure_llm_quota_available(model_instance=model_instance)
def test_deduct_llm_quota_wrapper_warns_and_delegates() -> None:
usage = LLMUsage.empty_usage()
usage.total_tokens = 7
model_instance = SimpleNamespace(
provider="openai",
model_name="gpt-4o",
model_type_instance=SimpleNamespace(model_type=ModelType.LLM),
provider_model_bundle=SimpleNamespace(configuration=SimpleNamespace()),
)
with (
pytest.deprecated_call(match="deduct_llm_quota\\(tenant_id=.*deprecated"),
patch("core.app.llm.quota.deduct_llm_quota_for_model") as mock_deduct,
):
deduct_llm_quota(
tenant_id="tenant-id",
model_instance=model_instance,
usage=usage,
)
mock_deduct.assert_called_once_with(
tenant_id="tenant-id",
provider="openai",
model="gpt-4o",
usage=usage,
)
def test_deduct_llm_quota_wrapper_rejects_non_llm_model_instances() -> None:
usage = LLMUsage.empty_usage()
model_instance = SimpleNamespace(
provider="openai",
model_name="gpt-4o",
model_type_instance=SimpleNamespace(model_type=ModelType.TEXT_EMBEDDING),
provider_model_bundle=SimpleNamespace(configuration=SimpleNamespace()),
)
with (
pytest.deprecated_call(match="deduct_llm_quota\\(tenant_id=.*deprecated"),
pytest.raises(ValueError, match="only support LLM model instances"),
):
deduct_llm_quota(
tenant_id="tenant-id",
model_instance=model_instance,
usage=usage,
)
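Note: taken together, these tests pin down the call pattern for the new per-model helpers. A usage sketch for a hypothetical call site; the two quota helpers and the entity import are the ones exercised above, the wrapper itself is illustrative:

from core.app.llm.quota import (
    deduct_llm_quota_for_model,
    ensure_llm_quota_available_for_model,
)
from graphon.model_runtime.entities.llm_entities import LLMUsage


def charge_llm_call(*, tenant_id: str, provider: str, model: str, usage: LLMUsage) -> None:
    # Precheck: raises QuotaExceededError when the system-hosted pool is
    # exhausted, and ValueError when the provider is unknown for the tenant.
    ensure_llm_quota_available_for_model(tenant_id=tenant_id, provider=provider, model=model)
    # ... the actual model invocation would happen between these two calls ...
    # Post-call accounting: trial/paid quota goes through the credit pool
    # service, while free quota is capped against the Provider record.
    deduct_llm_quota_for_model(tenant_id=tenant_id, provider=provider, model=model, usage=usage)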

View File

@ -8,9 +8,9 @@ from graphon.enums import BuiltinNodeTypes
class DummyNode:
def __init__(self, *, node_id, config, graph_init_params, graph_runtime_state, **kwargs):
def __init__(self, *, node_id, data, graph_init_params, graph_runtime_state, **kwargs):
self.id = node_id
self.config = config
self.data = data
self.graph_init_params = graph_init_params
self.graph_runtime_state = graph_runtime_state
self.kwargs = kwargs

View File

@ -60,7 +60,10 @@ def _make_layer(
workflow_execution_id="run-id",
conversation_id="conv-id",
)
runtime_state = GraphRuntimeState(variable_pool=VariablePool(system_variables=system_variables), start_at=0.0)
runtime_state = GraphRuntimeState(
variable_pool=VariablePool.from_bootstrap(system_variables=system_variables),
start_at=0.0,
)
read_only_state = ReadOnlyGraphRuntimeStateWrapper(runtime_state)
application_generate_entity = WorkflowAppGenerateEntity.model_construct(

View File

@ -354,7 +354,8 @@ def test_validate_provider_credentials_handles_hidden_secret_value() -> None:
with _patched_session(mock_session):
with patch(
"core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory
"core.entities.provider_configuration.create_plugin_model_assembly",
return_value=SimpleNamespace(model_runtime=Mock(), model_provider_factory=mock_factory),
):
with patch("core.entities.provider_configuration.encrypter.decrypt_token", return_value="restored-key"):
with patch(
@ -379,7 +380,10 @@ def test_validate_provider_credentials_without_credential_id() -> None:
mock_factory = Mock()
mock_factory.provider_credentials_validate.return_value = {"region": "us"}
with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory):
with patch(
"core.entities.provider_configuration.create_plugin_model_assembly",
return_value=SimpleNamespace(model_runtime=Mock(), model_provider_factory=mock_factory),
):
validated = configuration.validate_provider_credentials(credentials={"region": "us"})
assert validated == {"region": "us"}
@ -426,23 +430,37 @@ def test_switch_preferred_provider_type_creates_record_when_missing() -> None:
def test_get_model_type_instance_and_schema_delegate_to_factory() -> None:
configuration = _build_provider_configuration()
mock_factory = Mock()
mock_model_type_instance = Mock()
mock_schema = _build_ai_model("gpt-4o")
mock_factory.get_model_type_instance.return_value = mock_model_type_instance
mock_factory = Mock()
mock_factory.get_provider_schema.return_value = configuration.provider
mock_factory.get_model_schema.return_value = mock_schema
mock_assembly = Mock()
mock_assembly.model_runtime = Mock()
mock_assembly.model_provider_factory = mock_factory
with patch(
"core.entities.provider_configuration.create_plugin_model_provider_factory",
return_value=mock_factory,
) as mock_factory_builder:
with (
patch(
"core.entities.provider_configuration.create_plugin_model_assembly",
return_value=mock_assembly,
) as mock_assembly_builder,
patch(
"core.entities.provider_configuration.create_model_type_instance",
return_value=mock_model_type_instance,
) as mock_model_builder,
):
model_type_instance = configuration.get_model_type_instance(ModelType.LLM)
model_schema = configuration.get_model_schema(ModelType.LLM, "gpt-4o", {"api_key": "x"})
assert model_type_instance is mock_model_type_instance
assert model_schema is mock_schema
assert mock_factory_builder.call_count == 2
mock_factory.get_model_type_instance.assert_called_once_with(provider="openai", model_type=ModelType.LLM)
assert mock_assembly_builder.call_count == 2
mock_factory.get_provider_schema.assert_called_once_with(provider="openai")
mock_model_builder.assert_called_once_with(
runtime=mock_assembly.model_runtime,
provider_schema=configuration.provider,
model_type=ModelType.LLM,
)
mock_factory.get_model_schema.assert_called_once_with(
provider="openai",
model_type=ModelType.LLM,
@ -456,17 +474,21 @@ def test_get_model_type_instance_and_schema_reuse_bound_runtime_factory() -> Non
bound_runtime = Mock()
configuration.bind_model_runtime(bound_runtime)
mock_factory = Mock()
mock_model_type_instance = Mock()
mock_schema = _build_ai_model("gpt-4o")
mock_factory.get_model_type_instance.return_value = mock_model_type_instance
mock_factory = Mock()
mock_factory.get_provider_schema.return_value = configuration.provider
mock_factory.get_model_schema.return_value = mock_schema
with (
patch(
"core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory
) as mock_factory_cls,
patch("core.entities.provider_configuration.create_plugin_model_provider_factory") as mock_factory_builder,
patch("core.entities.provider_configuration.create_plugin_model_assembly") as mock_assembly_builder,
patch(
"core.entities.provider_configuration.create_model_type_instance",
return_value=mock_model_type_instance,
) as mock_model_builder,
):
model_type_instance = configuration.get_model_type_instance(ModelType.LLM)
model_schema = configuration.get_model_schema(ModelType.LLM, "gpt-4o", {"api_key": "x"})
@ -474,8 +496,14 @@ def test_get_model_type_instance_and_schema_reuse_bound_runtime_factory() -> Non
assert model_type_instance is mock_model_type_instance
assert model_schema is mock_schema
assert mock_factory_cls.call_count == 2
mock_factory_cls.assert_called_with(model_runtime=bound_runtime)
mock_factory_builder.assert_not_called()
mock_factory_cls.assert_called_with(runtime=bound_runtime)
mock_assembly_builder.assert_not_called()
mock_factory.get_provider_schema.assert_called_once_with(provider="openai")
mock_model_builder.assert_called_once_with(
runtime=bound_runtime,
provider_schema=configuration.provider,
model_type=ModelType.LLM,
)
def test_get_provider_model_returns_none_when_model_not_found() -> None:
@ -504,7 +532,10 @@ def test_get_provider_models_system_deduplicates_sorts_and_filters_active() -> N
mock_factory = Mock()
mock_factory.get_provider_schema.return_value = provider_schema
with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory):
with patch(
"core.entities.provider_configuration.create_plugin_model_assembly",
return_value=SimpleNamespace(model_runtime=Mock(), model_provider_factory=mock_factory),
):
all_models = configuration.get_provider_models(model_type=ModelType.LLM, only_active=False)
active_models = configuration.get_provider_models(model_type=ModelType.LLM, only_active=True)
@ -722,7 +753,8 @@ def test_validate_provider_credentials_handles_invalid_original_json() -> None:
with _patched_session(mock_session):
with patch(
"core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory
"core.entities.provider_configuration.create_plugin_model_assembly",
return_value=SimpleNamespace(model_runtime=Mock(), model_provider_factory=mock_factory),
):
with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-key"):
validated = configuration.validate_provider_credentials(
@ -1069,7 +1101,8 @@ def test_validate_custom_model_credentials_supports_hidden_reuse_and_sessionless
with _patched_session(mock_session):
with patch(
"core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory
"core.entities.provider_configuration.create_plugin_model_assembly",
return_value=SimpleNamespace(model_runtime=Mock(), model_provider_factory=mock_factory),
):
with patch("core.entities.provider_configuration.encrypter.decrypt_token", return_value="raw"):
with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"):
@ -1083,7 +1116,10 @@ def test_validate_custom_model_credentials_supports_hidden_reuse_and_sessionless
mock_factory2 = Mock()
mock_factory2.model_credentials_validate.return_value = {"region": "us"}
with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory2):
with patch(
"core.entities.provider_configuration.create_plugin_model_assembly",
return_value=SimpleNamespace(model_runtime=Mock(), model_provider_factory=mock_factory2),
):
validated = configuration.validate_custom_model_credentials(
model_type=ModelType.LLM,
model="gpt-4o",
@ -1575,7 +1611,8 @@ def test_validate_provider_credentials_uses_empty_original_when_record_missing()
with _patched_session(mock_session):
with patch(
"core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory
"core.entities.provider_configuration.create_plugin_model_assembly",
return_value=SimpleNamespace(model_runtime=Mock(), model_provider_factory=mock_factory),
):
with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"):
validated = configuration.validate_provider_credentials(
@ -1701,7 +1738,8 @@ def test_validate_custom_model_credentials_handles_invalid_original_json() -> No
with _patched_session(mock_session):
with patch(
"core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory
"core.entities.provider_configuration.create_plugin_model_assembly",
return_value=SimpleNamespace(model_runtime=Mock(), model_provider_factory=mock_factory),
):
with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"):
validated = configuration.validate_custom_model_credentials(
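Note: every patched call site in this file now stubs create_plugin_model_assembly with the same two-field shape. A shared helper the tests could hypothetically use; the patch target and the stub shape come straight from the hunks above:

from contextlib import contextmanager
from types import SimpleNamespace
from unittest.mock import Mock, patch


@contextmanager
def patched_model_assembly(mock_factory: Mock):
    # Stand-in assembly: a fresh Mock runtime plus the factory under test.
    with patch(
        "core.entities.provider_configuration.create_plugin_model_assembly",
        return_value=SimpleNamespace(model_runtime=Mock(), model_provider_factory=mock_factory),
    ) as patched:
        yield patched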

View File

@ -68,8 +68,8 @@ def test_check_moderation_returns_true_when_model_accepts_text(mocker: MockerFix
mocker.patch("core.helper.moderation.secrets.choice", return_value="chunk")
moderation_model = SimpleNamespace(invoke=lambda **invoke_kwargs: invoke_kwargs["text"] == "chunk")
factory = SimpleNamespace(get_model_type_instance=lambda **_factory_kwargs: moderation_model)
mocker.patch("core.helper.moderation.create_plugin_model_provider_factory", return_value=factory)
assembly = SimpleNamespace(create_model_type_instance=lambda **_factory_kwargs: moderation_model)
mocker.patch("core.helper.moderation.create_plugin_model_assembly", return_value=assembly)
assert (
check_moderation(
@ -91,7 +91,7 @@ def test_check_moderation_returns_true_when_text_is_empty(mocker: MockerFixture)
provider_map={openai_provider: hosting_openai},
),
)
factory_mock = mocker.patch("core.helper.moderation.create_plugin_model_provider_factory")
factory_mock = mocker.patch("core.helper.moderation.create_plugin_model_assembly")
choice_mock = mocker.patch("core.helper.moderation.secrets.choice")
assert (
@ -119,8 +119,8 @@ def test_check_moderation_returns_false_when_model_rejects_text(mocker: MockerFi
mocker.patch("core.helper.moderation.secrets.choice", return_value="chunk")
moderation_model = SimpleNamespace(invoke=lambda **_invoke_kwargs: False)
factory = SimpleNamespace(get_model_type_instance=lambda **_factory_kwargs: moderation_model)
mocker.patch("core.helper.moderation.create_plugin_model_provider_factory", return_value=factory)
assembly = SimpleNamespace(create_model_type_instance=lambda **_factory_kwargs: moderation_model)
mocker.patch("core.helper.moderation.create_plugin_model_assembly", return_value=assembly)
assert (
check_moderation(
@ -147,8 +147,8 @@ def test_check_moderation_raises_bad_request_when_provider_call_fails(mocker: Mo
failing_model = SimpleNamespace(
invoke=lambda **_invoke_kwargs: (_ for _ in ()).throw(RuntimeError("boom")),
)
factory = SimpleNamespace(get_model_type_instance=lambda **_factory_kwargs: failing_model)
mocker.patch("core.helper.moderation.create_plugin_model_provider_factory", return_value=factory)
assembly = SimpleNamespace(create_model_type_instance=lambda **_factory_kwargs: failing_model)
mocker.patch("core.helper.moderation.create_plugin_model_assembly", return_value=assembly)
with pytest.raises(InvokeBadRequestError, match="Rate limit exceeded, please try again later."):
check_moderation(

View File

@ -2,6 +2,7 @@ from unittest.mock import Mock
import pytest
from core.plugin.impl.model_runtime_factory import create_model_type_instance
from graphon.model_runtime.entities.common_entities import I18nObject
from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
from graphon.model_runtime.entities.provider_entities import (
@ -73,7 +74,7 @@ def test_model_provider_factory_resolves_runtime_provider_name() -> None:
supported_model_types=[ModelType.LLM],
configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL],
)
factory = ModelProviderFactory(model_runtime=_FakeModelRuntime([provider]))
factory = ModelProviderFactory(runtime=_FakeModelRuntime([provider]))
provider_schema = factory.get_model_provider("openai")
@ -98,7 +99,7 @@ def test_model_provider_factory_resolves_canonical_short_name_independent_of_pro
configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL],
),
]
factory = ModelProviderFactory(model_runtime=_FakeModelRuntime(providers))
factory = ModelProviderFactory(runtime=_FakeModelRuntime(providers))
provider_schema = factory.get_model_provider("openai")
@ -107,8 +108,8 @@ def test_model_provider_factory_resolves_canonical_short_name_independent_of_pro
def test_model_provider_factory_requires_runtime() -> None:
with pytest.raises(ValueError, match="model_runtime is required"):
ModelProviderFactory(model_runtime=None) # type: ignore[arg-type]
with pytest.raises(ValueError, match="runtime is required"):
ModelProviderFactory(runtime=None) # type: ignore[arg-type]
def test_model_provider_factory_get_providers_returns_runtime_providers() -> None:
@ -119,7 +120,7 @@ def test_model_provider_factory_get_providers_returns_runtime_providers() -> Non
supported_model_types=[ModelType.LLM],
)
]
factory = ModelProviderFactory(model_runtime=_FakeModelRuntime(providers))
factory = ModelProviderFactory(runtime=_FakeModelRuntime(providers))
result = factory.get_providers()
@ -133,7 +134,7 @@ def test_model_provider_factory_get_provider_schema_delegates_to_provider_lookup
provider_name="openai",
supported_model_types=[ModelType.LLM],
)
factory = ModelProviderFactory(model_runtime=_FakeModelRuntime([provider]))
factory = ModelProviderFactory(runtime=_FakeModelRuntime([provider]))
result = factory.get_provider_schema("openai")
@ -142,7 +143,7 @@ def test_model_provider_factory_get_provider_schema_delegates_to_provider_lookup
def test_model_provider_factory_raises_for_unknown_provider() -> None:
factory = ModelProviderFactory(
model_runtime=_FakeModelRuntime(
runtime=_FakeModelRuntime(
[
_build_provider(
provider="langgenius/openai/openai",
@ -172,7 +173,7 @@ def test_model_provider_factory_get_models_filters_provider_and_model_type() ->
models=[_build_model("rerank-v3", ModelType.RERANK)],
),
]
factory = ModelProviderFactory(model_runtime=_FakeModelRuntime(providers))
factory = ModelProviderFactory(runtime=_FakeModelRuntime(providers))
results = factory.get_models(provider="openai", model_type=ModelType.LLM)
@ -196,7 +197,7 @@ def test_model_provider_factory_get_models_skips_providers_without_requested_mod
models=[_build_model("eleven_multilingual_v2", ModelType.TTS)],
),
]
factory = ModelProviderFactory(model_runtime=_FakeModelRuntime(providers))
factory = ModelProviderFactory(runtime=_FakeModelRuntime(providers))
results = factory.get_models(model_type=ModelType.TTS)
@ -214,7 +215,7 @@ def test_model_provider_factory_get_models_without_model_type_keeps_all_provider
models=[_build_model("gpt-4o-mini", ModelType.LLM), _build_model("tts-1", ModelType.TTS)],
)
]
factory = ModelProviderFactory(model_runtime=_FakeModelRuntime(providers))
factory = ModelProviderFactory(runtime=_FakeModelRuntime(providers))
results = factory.get_models(provider="openai")
@ -242,7 +243,7 @@ def test_model_provider_factory_validates_provider_credentials() -> None:
)
]
)
factory = ModelProviderFactory(model_runtime=runtime)
factory = ModelProviderFactory(runtime=runtime)
filtered = factory.provider_credentials_validate(
provider="openai",
@ -258,7 +259,7 @@ def test_model_provider_factory_validates_provider_credentials() -> None:
def test_model_provider_factory_provider_credentials_validate_requires_schema() -> None:
factory = ModelProviderFactory(
model_runtime=_FakeModelRuntime(
runtime=_FakeModelRuntime(
[
_build_provider(
provider="langgenius/openai/openai",
@ -294,7 +295,7 @@ def test_model_provider_factory_validates_model_credentials() -> None:
)
]
)
factory = ModelProviderFactory(model_runtime=runtime)
factory = ModelProviderFactory(runtime=runtime)
filtered = factory.model_credentials_validate(
provider="openai",
@ -314,7 +315,7 @@ def test_model_provider_factory_validates_model_credentials() -> None:
def test_model_provider_factory_model_credentials_validate_requires_schema() -> None:
factory = ModelProviderFactory(
model_runtime=_FakeModelRuntime(
runtime=_FakeModelRuntime(
[
_build_provider(
provider="langgenius/openai/openai",
@ -346,7 +347,7 @@ def test_model_provider_factory_get_model_schema_and_icon_use_canonical_provider
)
runtime.get_model_schema.return_value = "schema"
runtime.get_provider_icon.return_value = (b"icon", "image/png")
factory = ModelProviderFactory(model_runtime=runtime)
factory = ModelProviderFactory(runtime=runtime)
assert (
factory.get_model_schema(
@ -382,39 +383,43 @@ def test_model_provider_factory_get_model_schema_and_icon_use_canonical_provider
(ModelType.TTS, TTSModel),
],
)
def test_model_provider_factory_builds_model_type_instances(
def test_create_model_type_instance_builds_model_wrappers(
model_type: ModelType,
expected_type: type[object],
) -> None:
factory = ModelProviderFactory(
model_runtime=_FakeModelRuntime(
[
_build_provider(
provider="langgenius/openai/openai",
provider_name="openai",
supported_model_types=[model_type],
)
]
)
runtime = _FakeModelRuntime(
[
_build_provider(
provider="langgenius/openai/openai",
provider_name="openai",
supported_model_types=[model_type],
)
]
)
instance = factory.get_model_type_instance("openai", model_type)
instance = create_model_type_instance(
runtime=runtime,
provider_schema=runtime.fetch_model_providers()[0],
model_type=model_type,
)
assert isinstance(instance, expected_type)
def test_model_provider_factory_rejects_unsupported_model_type() -> None:
factory = ModelProviderFactory(
model_runtime=_FakeModelRuntime(
[
_build_provider(
provider="langgenius/openai/openai",
provider_name="openai",
supported_model_types=[ModelType.LLM],
)
]
)
def test_create_model_type_instance_rejects_unsupported_model_type() -> None:
runtime = _FakeModelRuntime(
[
_build_provider(
provider="langgenius/openai/openai",
provider_name="openai",
supported_model_types=[ModelType.LLM],
)
]
)
with pytest.raises(ValueError, match="Unsupported model type: unsupported"):
factory.get_model_type_instance("openai", "unsupported") # type: ignore[arg-type]
create_model_type_instance(
runtime=runtime,
provider_schema=runtime.fetch_model_providers()[0],
model_type="unsupported", # type: ignore[arg-type]
)
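Note: get_model_type_instance moves off the factory onto the module-level create_model_type_instance(runtime=..., provider_schema=..., model_type=...). A minimal caller sketch; the import paths are the ones used in this file, the wrapper function is hypothetical:

from core.plugin.impl.model_runtime_factory import create_model_type_instance
from graphon.model_runtime.entities.model_entities import ModelType


def build_llm_wrapper(runtime, provider_schema):
    # Unsupported model types raise ValueError, as the last test asserts.
    return create_model_type_instance(
        runtime=runtime,
        provider_schema=provider_schema,
        model_type=ModelType.LLM,
    )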

View File

@ -31,6 +31,6 @@ def test_plugin_model_assembly_reuses_single_runtime_across_views():
assert assembly.model_manager is model_manager
mock_runtime_factory.assert_called_once_with(tenant_id="tenant-1", user_id="user-1")
mock_provider_factory_cls.assert_called_once_with(model_runtime=runtime)
mock_provider_factory_cls.assert_called_once_with(runtime=runtime)
mock_provider_manager_cls.assert_called_once_with(model_runtime=runtime)
mock_model_manager_cls.assert_called_once_with(provider_manager=provider_manager)
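Note: these assertions fix the assembly's wiring: one plugin runtime shared by ModelProviderFactory(runtime=...), the provider manager (model_runtime=...), and the model manager (provider_manager=...). A hypothetical consumer relying on that sharing; the factory's module path and keyword arguments are assumed from the surrounding tests:

from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly  # export assumed


def models_for_tenant(tenant_id: str, user_id: str):
    assembly = create_plugin_model_assembly(tenant_id=tenant_id, user_id=user_id)  # kwargs assumed
    # model_runtime, model_provider_factory, and model_manager are all views
    # over the same underlying runtime, which is created once per assembly.
    return assembly.model_manager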

View File

@ -3,7 +3,7 @@
import datetime
import uuid
from types import SimpleNamespace
from unittest.mock import Mock, sentinel
from unittest.mock import Mock, patch, sentinel
import pytest
@ -13,6 +13,8 @@ from core.plugin.impl.model import PluginModelClient
from core.plugin.impl.model_runtime import TENANT_SCOPE_SCHEMA_CACHE_USER_ID, PluginModelRuntime
from core.plugin.impl.model_runtime_factory import create_plugin_model_runtime
from graphon.model_runtime.entities.common_entities import I18nObject
from graphon.model_runtime.entities.llm_entities import LLMResultChunk, LLMResultChunkDelta, LLMUsage
from graphon.model_runtime.entities.message_entities import AssistantPromptMessage
from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
from graphon.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity
@ -146,7 +148,31 @@ class TestPluginModelRuntime:
def test_invoke_llm_resolves_plugin_fields(self) -> None:
client = Mock(spec=PluginModelClient)
client.invoke_llm.return_value = sentinel.result
usage = LLMUsage.empty_usage()
client.invoke_llm.return_value = iter(
[
LLMResultChunk(
model="gpt-4o-mini",
prompt_messages=[],
system_fingerprint="fp-plugin",
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(content="plugin "),
),
),
LLMResultChunk(
model="gpt-4o-mini",
prompt_messages=[],
system_fingerprint="fp-plugin",
delta=LLMResultChunkDelta(
index=1,
message=AssistantPromptMessage(content="response"),
usage=usage,
finish_reason="stop",
),
),
]
)
runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client)
result = runtime.invoke_llm(
@ -160,7 +186,11 @@ class TestPluginModelRuntime:
stream=False,
)
assert result is sentinel.result
assert result.model == "gpt-4o-mini"
assert result.prompt_messages == []
assert result.message.content == "plugin response"
assert result.usage == usage
assert result.system_fingerprint == "fp-plugin"
client.invoke_llm.assert_called_once_with(
tenant_id="tenant",
user_id="user",
@ -175,6 +205,38 @@ class TestPluginModelRuntime:
stream=False,
)
def test_invoke_llm_returns_plugin_stream_directly(self) -> None:
client = Mock(spec=PluginModelClient)
stream_result = iter([])
client.invoke_llm.return_value = stream_result
runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client)
result = runtime.invoke_llm(
provider="langgenius/openai/openai",
model="gpt-4o-mini",
credentials={"api_key": "secret"},
model_parameters={"temperature": 0.3},
prompt_messages=[],
tools=None,
stop=("END",),
stream=True,
)
assert result is stream_result
client.invoke_llm.assert_called_once_with(
tenant_id="tenant",
user_id="user",
plugin_id="langgenius/openai",
provider="openai",
model="gpt-4o-mini",
credentials={"api_key": "secret"},
model_parameters={"temperature": 0.3},
prompt_messages=[],
tools=None,
stop=["END"],
stream=True,
)
def test_invoke_llm_rejects_per_call_user_override(self) -> None:
client = Mock(spec=PluginModelClient)
client.invoke_llm.return_value = sentinel.result
@ -267,6 +329,129 @@ def test_get_model_schema_uses_cached_schema_without_hitting_client(monkeypatch:
client.get_model_schema.assert_not_called()
def test_structured_output_adapter_invokes_bound_runtime_streaming() -> None:
runtime = Mock()
runtime.invoke_llm.return_value = sentinel.stream_result
adapter = model_runtime_module._PluginStructuredOutputModelInstance(
runtime=runtime,
provider="langgenius/openai/openai",
model="gpt-4o-mini",
credentials={"api_key": "secret"},
)
tool = Mock()
result = adapter.invoke_llm(
prompt_messages=[],
model_parameters=None,
tools=[tool],
stop=["END"],
stream=True,
callbacks=sentinel.callbacks,
)
assert result is sentinel.stream_result
runtime.invoke_llm.assert_called_once_with(
provider="langgenius/openai/openai",
model="gpt-4o-mini",
credentials={"api_key": "secret"},
model_parameters={},
prompt_messages=[],
tools=[tool],
stop=["END"],
stream=True,
)
def test_structured_output_adapter_invokes_bound_runtime_non_streaming() -> None:
runtime = Mock()
runtime.invoke_llm.return_value = sentinel.result
adapter = model_runtime_module._PluginStructuredOutputModelInstance(
runtime=runtime,
provider="langgenius/openai/openai",
model="gpt-4o-mini",
credentials={"api_key": "secret"},
)
result = adapter.invoke_llm(
prompt_messages=[],
model_parameters={"temperature": 0},
tools=None,
stop=None,
stream=False,
)
assert result is sentinel.result
runtime.invoke_llm.assert_called_once_with(
provider="langgenius/openai/openai",
model="gpt-4o-mini",
credentials={"api_key": "secret"},
model_parameters={"temperature": 0},
prompt_messages=[],
tools=None,
stop=None,
stream=False,
)
def test_invoke_llm_with_structured_output_delegates_with_bound_adapter() -> None:
client = Mock(spec=PluginModelClient)
runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client)
schema = _build_model_schema()
runtime.get_model_schema = Mock(return_value=schema) # type: ignore[method-assign]
with patch.object(
model_runtime_module,
"invoke_llm_with_structured_output_helper",
return_value=sentinel.structured_result,
) as mock_helper:
result = runtime.invoke_llm_with_structured_output(
provider="langgenius/openai/openai",
model="gpt-4o-mini",
credentials={"api_key": "secret"},
json_schema={"type": "object"},
model_parameters={"temperature": 0},
prompt_messages=[],
stop=("END",),
stream=False,
)
assert result is sentinel.structured_result
runtime.get_model_schema.assert_called_once_with(
provider="langgenius/openai/openai",
model_type=ModelType.LLM,
model="gpt-4o-mini",
credentials={"api_key": "secret"},
)
helper_kwargs = mock_helper.call_args.kwargs
assert helper_kwargs["provider"] == "langgenius/openai/openai"
assert helper_kwargs["model_schema"] == schema
assert helper_kwargs["json_schema"] == {"type": "object"}
assert helper_kwargs["model_parameters"] == {"temperature": 0}
assert helper_kwargs["prompt_messages"] == []
assert helper_kwargs["tools"] is None
assert helper_kwargs["stop"] == ["END"]
assert helper_kwargs["stream"] is False
assert isinstance(helper_kwargs["model_instance"], model_runtime_module._PluginStructuredOutputModelInstance)
def test_invoke_llm_with_structured_output_raises_when_model_schema_is_missing() -> None:
client = Mock(spec=PluginModelClient)
runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client)
runtime.get_model_schema = Mock(return_value=None) # type: ignore[method-assign]
with pytest.raises(ValueError, match="Model schema not found for gpt-4o-mini"):
runtime.invoke_llm_with_structured_output(
provider="langgenius/openai/openai",
model="gpt-4o-mini",
credentials={"api_key": "secret"},
json_schema={"type": "object"},
model_parameters={},
prompt_messages=[],
stop=None,
stream=False,
)
def test_get_model_schema_deletes_invalid_cache_and_refetches(monkeypatch: pytest.MonkeyPatch) -> None:
client = Mock(spec=PluginModelClient)
schema = _build_model_schema()
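Note: the non-streaming branch of invoke_llm now folds the plugin chunk stream into a single result, as the content/usage/fingerprint assertions above imply. The fold below is illustrative only, reconstructed from those assertions rather than from the runtime's actual code; LLMResult is assumed to live beside LLMResultChunk:

from collections.abc import Iterator

from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMUsage
from graphon.model_runtime.entities.message_entities import AssistantPromptMessage


def fold_chunks(chunks: Iterator[LLMResultChunk]) -> LLMResult:
    # Concatenate delta content; keep the last model/prompt/fingerprint seen
    # and whatever usage the final chunk carries.
    content = ""
    usage = LLMUsage.empty_usage()
    model = ""
    prompt_messages: list = []
    fingerprint = None
    for chunk in chunks:
        model = chunk.model
        prompt_messages = chunk.prompt_messages
        fingerprint = chunk.system_fingerprint
        if isinstance(chunk.delta.message.content, str):
            content += chunk.delta.message.content
        if chunk.delta.usage is not None:
            usage = chunk.delta.usage
    return LLMResult(
        model=model,
        prompt_messages=prompt_messages,
        message=AssistantPromptMessage(content=content),
        usage=usage,
        system_fingerprint=fingerprint,
    )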

View File

@ -289,7 +289,7 @@ def test_get_default_model_uses_injected_runtime_for_existing_default_record(moc
result = manager.get_default_model("tenant-id", ModelType.LLM)
mock_factory_cls.assert_called_once_with(model_runtime=manager._model_runtime)
mock_factory_cls.assert_called_once_with(runtime=manager._model_runtime)
assert result is not None
assert result.model == "gpt-4"
assert result.provider.provider == "openai"
@ -316,7 +316,7 @@ def test_get_configurations_uses_injected_runtime_and_adds_provider_aliases(mock
result = manager.get_configurations("tenant-id")
expected_alias = str(ModelProviderID("openai"))
mock_factory_cls.assert_called_once_with(model_runtime=manager._model_runtime)
mock_factory_cls.assert_called_once_with(runtime=manager._model_runtime)
assert result.tenant_id == "tenant-id"
assert expected_alias in provider_records
assert expected_alias in provider_model_records
@ -402,7 +402,7 @@ def test_get_configurations_reuses_cached_result_for_same_tenant(mocker: MockerF
assert first is second
mock_get_all_providers.assert_called_once_with("tenant-id")
mock_factory_cls.assert_called_once_with(model_runtime=manager._model_runtime)
mock_factory_cls.assert_called_once_with(runtime=manager._model_runtime)
mock_provider_configuration.assert_called_once()
provider_configuration.bind_model_runtime.assert_called_once_with(manager._model_runtime)

View File

@ -1,12 +1,11 @@
import logging
import threading
from datetime import datetime
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
from core.app.entities.app_invoke_entities import DifyRunContext, InvokeFrom, UserFrom
from core.app.workflow.layers.llm_quota import LLMQuotaLayer
from core.errors.error import QuotaExceededError
from core.model_manager import ModelInstance
from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus
from graphon.graph_engine.entities.commands import CommandType
from graphon.graph_events import NodeRunSucceededEvent
@ -14,17 +13,7 @@ from graphon.model_runtime.entities.llm_entities import LLMUsage
from graphon.node_events import NodeRunResult
def _build_dify_context() -> DifyRunContext:
return DifyRunContext(
tenant_id="tenant-id",
app_id="app-id",
user_id="user-id",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
)
def _build_succeeded_event() -> NodeRunSucceededEvent:
def _build_succeeded_event(*, provider: str = "openai", model_name: str = "gpt-4o") -> NodeRunSucceededEvent:
return NodeRunSucceededEvent(
id="execution-id",
node_id="llm-node-id",
@ -32,113 +21,162 @@ def _build_succeeded_event() -> NodeRunSucceededEvent:
start_at=datetime.now(),
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs={"question": "hello"},
inputs={
"question": "hello",
"model_provider": provider,
"model_name": model_name,
},
llm_usage=LLMUsage.empty_usage(),
),
)
def _build_wrapped_model_instance() -> tuple[SimpleNamespace, ModelInstance]:
raw_model_instance = ModelInstance.__new__(ModelInstance)
return SimpleNamespace(_model_instance=raw_model_instance), raw_model_instance
def _build_public_model_identity(*, provider: str = "openai", model_name: str = "gpt-4o") -> SimpleNamespace:
return SimpleNamespace(provider=provider, name=model_name)
def _build_node_data(*, model: SimpleNamespace | None = None) -> SimpleNamespace:
return SimpleNamespace(
error_strategy=None,
retry_config=SimpleNamespace(retry_enabled=False),
model=model,
)
def _build_node(*, node_type: BuiltinNodeTypes = BuiltinNodeTypes.LLM) -> MagicMock:
node = MagicMock()
node.id = "node-id"
node.execution_id = "execution-id"
node.node_type = node_type
node.node_data = _build_node_data(model=_build_public_model_identity())
node.model_instance = SimpleNamespace(provider="stale-provider", model_name="stale-model")
return node
class _RunnableQuotaNode:
id = "node-id"
execution_id = "execution-id"
node_type = BuiltinNodeTypes.LLM
title = "LLM node"
def __init__(self, *, stop_event: threading.Event, node_data: SimpleNamespace | None = None) -> None:
self.node_data = node_data or _build_node_data(model=_build_public_model_identity())
self.graph_runtime_state = SimpleNamespace(stop_event=stop_event)
self.original_run_called = False
def _run(self) -> NodeRunResult:
self.original_run_called = True
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED)
def test_deduct_quota_called_for_successful_llm_node() -> None:
layer = LLMQuotaLayer()
node = MagicMock()
node.id = "llm-node-id"
node.execution_id = "execution-id"
node.node_type = BuiltinNodeTypes.LLM
node.tenant_id = "tenant-id"
node.require_run_context_value.return_value = _build_dify_context()
node.model_instance, raw_model_instance = _build_wrapped_model_instance()
layer = LLMQuotaLayer(tenant_id="tenant-id")
node = _build_node(node_type=BuiltinNodeTypes.LLM)
result_event = _build_succeeded_event()
with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota", autospec=True) as mock_deduct:
with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota_for_model", autospec=True) as mock_deduct:
layer.on_node_run_end(node=node, error=None, result_event=result_event)
mock_deduct.assert_called_once_with(
tenant_id="tenant-id",
model_instance=raw_model_instance,
provider="openai",
model="gpt-4o",
usage=result_event.node_run_result.llm_usage,
)
def test_deduct_quota_called_for_question_classifier_node() -> None:
layer = LLMQuotaLayer()
node = MagicMock()
node.id = "question-classifier-node-id"
node.execution_id = "execution-id"
node.node_type = BuiltinNodeTypes.QUESTION_CLASSIFIER
node.tenant_id = "tenant-id"
node.require_run_context_value.return_value = _build_dify_context()
node.model_instance, raw_model_instance = _build_wrapped_model_instance()
layer = LLMQuotaLayer(tenant_id="tenant-id")
node = _build_node(node_type=BuiltinNodeTypes.QUESTION_CLASSIFIER)
result_event = _build_succeeded_event()
result_event = _build_succeeded_event(provider="anthropic", model_name="claude-3-7-sonnet")
with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota", autospec=True) as mock_deduct:
with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota_for_model", autospec=True) as mock_deduct:
layer.on_node_run_end(node=node, error=None, result_event=result_event)
mock_deduct.assert_called_once_with(
tenant_id="tenant-id",
model_instance=raw_model_instance,
provider="anthropic",
model="claude-3-7-sonnet",
usage=result_event.node_run_result.llm_usage,
)
def test_non_llm_node_is_ignored() -> None:
layer = LLMQuotaLayer()
node = MagicMock()
node.id = "start-node-id"
node.execution_id = "execution-id"
node.node_type = BuiltinNodeTypes.START
node.tenant_id = "tenant-id"
node.require_run_context_value.return_value = _build_dify_context()
node._model_instance = object()
layer = LLMQuotaLayer(tenant_id="tenant-id")
node = _build_node(node_type=BuiltinNodeTypes.START)
result_event = _build_succeeded_event()
with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota", autospec=True) as mock_deduct:
with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota_for_model", autospec=True) as mock_deduct:
layer.on_node_run_end(node=node, error=None, result_event=result_event)
mock_deduct.assert_not_called()
def test_quota_error_is_handled_in_layer() -> None:
layer = LLMQuotaLayer()
node = MagicMock()
node.id = "llm-node-id"
node.execution_id = "execution-id"
node.node_type = BuiltinNodeTypes.LLM
node.tenant_id = "tenant-id"
node.require_run_context_value.return_value = _build_dify_context()
node.model_instance = object()
def test_precheck_ignores_non_quota_node() -> None:
layer = LLMQuotaLayer(tenant_id="tenant-id")
node = _build_node(node_type=BuiltinNodeTypes.START)
result_event = _build_succeeded_event()
with patch(
"core.app.workflow.layers.llm_quota.deduct_llm_quota",
autospec=True,
side_effect=ValueError("quota exceeded"),
):
layer.on_node_run_end(node=node, error=None, result_event=result_event)
with patch("core.app.workflow.layers.llm_quota.ensure_llm_quota_available_for_model", autospec=True) as mock_check:
layer.on_node_run_start(node)
mock_check.assert_not_called()
def test_quota_deduction_exceeded_aborts_workflow_immediately() -> None:
layer = LLMQuotaLayer()
def test_quota_error_is_handled_in_layer(caplog) -> None:
layer = LLMQuotaLayer(tenant_id="tenant-id")
stop_event = threading.Event()
layer.command_channel = MagicMock()
node = MagicMock()
node.id = "llm-node-id"
node.execution_id = "execution-id"
node.node_type = BuiltinNodeTypes.LLM
node.tenant_id = "tenant-id"
node.require_run_context_value.return_value = _build_dify_context()
node.model_instance, _ = _build_wrapped_model_instance()
node = _build_node(node_type=BuiltinNodeTypes.LLM)
node.graph_runtime_state = MagicMock()
node.graph_runtime_state.stop_event = stop_event
result_event = _build_succeeded_event()
with (
caplog.at_level(logging.ERROR, logger="core.app.workflow.layers.llm_quota"),
patch(
"core.app.workflow.layers.llm_quota.deduct_llm_quota_for_model",
autospec=True,
side_effect=ValueError("quota exceeded"),
) as mock_deduct,
):
layer.on_node_run_end(node=node, error=None, result_event=result_event)
mock_deduct.assert_called_once_with(
tenant_id="tenant-id",
provider="openai",
model="gpt-4o",
usage=result_event.node_run_result.llm_usage,
)
assert "LLM quota deduction failed, node_id=node-id" in caplog.text
assert not stop_event.is_set()
layer.command_channel.send_command.assert_not_called()
def test_send_abort_command_is_noop_without_channel_or_after_abort() -> None:
layer = LLMQuotaLayer(tenant_id="tenant-id")
layer._send_abort_command(reason="no channel")
layer.command_channel = MagicMock()
layer._abort_sent = True
layer._send_abort_command(reason="already aborted")
layer.command_channel.send_command.assert_not_called()
def test_quota_deduction_exceeded_aborts_workflow_immediately() -> None:
layer = LLMQuotaLayer(tenant_id="tenant-id")
stop_event = threading.Event()
layer.command_channel = MagicMock()
node = _build_node(node_type=BuiltinNodeTypes.LLM)
node.graph_runtime_state = MagicMock()
node.graph_runtime_state.stop_event = stop_event
result_event = _build_succeeded_event()
with patch(
"core.app.workflow.layers.llm_quota.deduct_llm_quota",
"core.app.workflow.layers.llm_quota.deduct_llm_quota_for_model",
autospec=True,
side_effect=QuotaExceededError("No credits remaining"),
):
@ -152,19 +190,16 @@ def test_quota_deduction_exceeded_aborts_workflow_immediately() -> None:
def test_quota_precheck_failure_aborts_workflow_immediately() -> None:
    layer = LLMQuotaLayer(tenant_id="tenant-id")
stop_event = threading.Event()
layer.command_channel = MagicMock()
    node = _build_node(node_type=BuiltinNodeTypes.LLM)
node.graph_runtime_state = MagicMock()
node.graph_runtime_state.stop_event = stop_event
with patch(
"core.app.workflow.layers.llm_quota.ensure_llm_quota_available",
"core.app.workflow.layers.llm_quota.ensure_llm_quota_available_for_model",
autospec=True,
side_effect=QuotaExceededError("Model provider openai quota exceeded."),
):
@ -177,21 +212,140 @@ def test_quota_precheck_failure_aborts_workflow_immediately() -> None:
assert abort_command.reason == "Model provider openai quota exceeded."
def test_quota_precheck_failure_blocks_current_node_run() -> None:
    layer = LLMQuotaLayer(tenant_id="tenant-id")
    stop_event = threading.Event()
    layer.command_channel = MagicMock()
    node = _RunnableQuotaNode(stop_event=stop_event)
with patch(
"core.app.workflow.layers.llm_quota.ensure_llm_quota_available_for_model",
autospec=True,
side_effect=QuotaExceededError("Model provider openai quota exceeded."),
):
layer.on_node_run_start(node)
result = node._run()
assert not node.original_run_called
assert result.status == WorkflowNodeExecutionStatus.FAILED
assert result.error == "Model provider openai quota exceeded."
assert result.error_type == QuotaExceededError.__name__
def test_missing_model_identity_blocks_current_node_run() -> None:
layer = LLMQuotaLayer(tenant_id="tenant-id")
stop_event = threading.Event()
layer.command_channel = MagicMock()
node = _RunnableQuotaNode(stop_event=stop_event, node_data=_build_node_data())
with patch("core.app.workflow.layers.llm_quota.ensure_llm_quota_available_for_model", autospec=True) as mock_check:
layer.on_node_run_start(node)
result = node._run()
assert not node.original_run_called
assert result.status == WorkflowNodeExecutionStatus.FAILED
assert result.error == "LLM quota check requires public node model identity before execution."
assert result.error_type == "LLMQuotaIdentityError"
mock_check.assert_not_called()
def test_quota_precheck_passes_without_abort() -> None:
layer = LLMQuotaLayer(tenant_id="tenant-id")
stop_event = threading.Event()
layer.command_channel = MagicMock()
node = _build_node(node_type=BuiltinNodeTypes.LLM)
node.graph_runtime_state = MagicMock()
node.graph_runtime_state.stop_event = stop_event
with patch("core.app.workflow.layers.llm_quota.ensure_llm_quota_available", autospec=True) as mock_check:
with patch("core.app.workflow.layers.llm_quota.ensure_llm_quota_available_for_model", autospec=True) as mock_check:
layer.on_node_run_start(node)
assert not stop_event.is_set()
mock_check.assert_called_once_with(model_instance=raw_model_instance)
mock_check.assert_called_once_with(
tenant_id="tenant-id",
provider="openai",
model="gpt-4o",
)
layer.command_channel.send_command.assert_not_called()
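Taken together, these tests pin down a precheck of roughly this shape; this is a sketch under those assumptions, not the layer's actual code, and `_resolve_model_identity` / `_fail_or_abort` are hypothetical helpers:

def on_node_run_start(self, node) -> None:
    if node.node_type != BuiltinNodeTypes.LLM:
        return
    # Identity must come from the public node data, never from private state.
    identity = self._resolve_model_identity(node)
    if identity is None:
        self._fail_or_abort(reason="LLM quota check requires public node model identity before execution.")
        return
    ensure_llm_quota_available_for_model(tenant_id=self.tenant_id, provider=identity.provider, model=identity.model)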
def test_precheck_reads_model_identity_from_data_when_node_data_is_absent() -> None:
layer = LLMQuotaLayer(tenant_id="tenant-id")
node = SimpleNamespace(
id="node-id",
node_type=BuiltinNodeTypes.LLM,
data=_build_node_data(model=_build_public_model_identity(provider="anthropic", model_name="claude")),
)
with patch("core.app.workflow.layers.llm_quota.ensure_llm_quota_available_for_model", autospec=True) as mock_check:
layer.on_node_run_start(node)
mock_check.assert_called_once_with(
tenant_id="tenant-id",
provider="anthropic",
model="claude",
)
def test_precheck_rejects_invalid_public_model_identity() -> None:
layer = LLMQuotaLayer(tenant_id="tenant-id")
stop_event = threading.Event()
layer.command_channel = MagicMock()
node = _build_node(node_type=BuiltinNodeTypes.LLM)
node.node_data = _build_node_data(model=_build_public_model_identity(provider="", model_name="gpt-4o"))
node.graph_runtime_state = MagicMock()
node.graph_runtime_state.stop_event = stop_event
with patch("core.app.workflow.layers.llm_quota.ensure_llm_quota_available_for_model", autospec=True) as mock_check:
layer.on_node_run_start(node)
assert stop_event.is_set()
mock_check.assert_not_called()
layer.command_channel.send_command.assert_called_once()
def test_precheck_requires_public_node_model_config() -> None:
layer = LLMQuotaLayer(tenant_id="tenant-id")
stop_event = threading.Event()
layer.command_channel = MagicMock()
node = _build_node(node_type=BuiltinNodeTypes.LLM)
node.node_data = _build_node_data()
node.graph_runtime_state = MagicMock()
node.graph_runtime_state.stop_event = stop_event
with patch("core.app.workflow.layers.llm_quota.ensure_llm_quota_available_for_model", autospec=True) as mock_check:
layer.on_node_run_start(node)
assert stop_event.is_set()
mock_check.assert_not_called()
layer.command_channel.send_command.assert_called_once()
abort_command = layer.command_channel.send_command.call_args.args[0]
assert abort_command.command_type == CommandType.ABORT
assert abort_command.reason == "LLM quota check requires public node model identity before execution."
def test_deduction_requires_public_event_model_identity() -> None:
layer = LLMQuotaLayer(tenant_id="tenant-id")
stop_event = threading.Event()
layer.command_channel = MagicMock()
node = _build_node(node_type=BuiltinNodeTypes.LLM)
node.graph_runtime_state = MagicMock()
node.graph_runtime_state.stop_event = stop_event
result_event = _build_succeeded_event()
result_event.node_run_result.inputs = {"question": "hello"}
with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota_for_model", autospec=True) as mock_deduct:
layer.on_node_run_end(node=node, error=None, result_event=result_event)
assert stop_event.is_set()
mock_deduct.assert_not_called()
layer.command_channel.send_command.assert_called_once()
abort_command = layer.command_channel.send_command.call_args.args[0]
assert abort_command.command_type == CommandType.ABORT
assert abort_command.reason == "LLM quota deduction requires model identity in the node result event."
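Several of these tests assert that at most one ABORT command is sent per run and that a missing channel is tolerated; a guard consistent with that behaviour (a sketch, with `AbortCommand` standing in for whatever command class the channel actually carries):

def _send_abort_command(self, *, reason: str) -> None:
    # Tolerate a missing channel and never send a second abort for one run.
    if self.command_channel is None or self._abort_sent:
        return
    self._abort_sent = True
    self.command_channel.send_command(AbortCommand(command_type=CommandType.ABORT, reason=reason))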

View File

@ -96,7 +96,7 @@ class MockNodeFactory(DifyNodeFactory):
if node_type == BuiltinNodeTypes.CODE:
mock_instance = mock_class(
node_id=node_id,
                data=resolved_node_data,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
mock_config=self.mock_config,
@ -106,7 +106,7 @@ class MockNodeFactory(DifyNodeFactory):
elif node_type == BuiltinNodeTypes.HTTP_REQUEST:
mock_instance = mock_class(
node_id=node_id,
                data=resolved_node_data,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
mock_config=self.mock_config,
@ -122,7 +122,7 @@ class MockNodeFactory(DifyNodeFactory):
}:
mock_instance = mock_class(
node_id=node_id,
                data=resolved_node_data,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
mock_config=self.mock_config,
@ -132,7 +132,7 @@ class MockNodeFactory(DifyNodeFactory):
else:
mock_instance = mock_class(
node_id=node_id,
                data=resolved_node_data,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
mock_config=self.mock_config,

View File

@ -56,7 +56,7 @@ class MockNodeMixin:
def __init__(
self,
node_id: str,
        data: Any,
*,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
@ -98,7 +98,7 @@ class MockNodeMixin:
super().__init__(
node_id=node_id,
            data=data,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
**kwargs,
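These mock updates track the repo-wide rename of the node-constructor keyword from `config=` to `data=`; the underlying base signature is presumably along these lines (a sketch inferred from the mixin above, not the actual graphon definition):

class Node:
    def __init__(self, node_id: str, data: BaseNodeData, *, graph_init_params, graph_runtime_state, **kwargs):
        # Typed node data is now passed positionally-by-name as `data`.
        ...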

View File

@ -111,7 +111,7 @@ class StaticRepo(HumanInputFormRepository):
def _build_runtime_state() -> GraphRuntimeState:
    variable_pool = VariablePool.from_bootstrap(
system_variables=build_system_variables(
user_id="user",
app_id="app",
@ -140,7 +140,7 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor
start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()}
start_node = StartNode(
node_id=start_config["id"],
        data=StartNodeData(title="Start", variables=[]),
graph_init_params=graph_init_params,
graph_runtime_state=runtime_state,
)
@ -155,7 +155,7 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor
human_a_config = {"id": "human_a", "data": human_data.model_dump()}
human_a = HumanInputNode(
node_id=human_a_config["id"],
        data=human_data,
graph_init_params=graph_init_params,
graph_runtime_state=runtime_state,
form_repository=repo,
@ -165,7 +165,7 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor
human_b_config = {"id": "human_b", "data": human_data.model_dump()}
human_b = HumanInputNode(
node_id=human_b_config["id"],
        data=human_data,
graph_init_params=graph_init_params,
graph_runtime_state=runtime_state,
form_repository=repo,
@ -183,7 +183,7 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor
end_config = {"id": "end", "data": end_data.model_dump()}
end_node = EndNode(
node_id=end_config["id"],
        data=end_data,
graph_init_params=graph_init_params,
graph_runtime_state=runtime_state,
)

View File

@ -1,41 +1,36 @@
import time
import uuid
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
from core.workflow.system_variables import build_system_variables
from graphon.enums import WorkflowNodeExecutionStatus
from graphon.nodes.answer.answer_node import AnswerNode
from graphon.nodes.answer.entities import AnswerNodeData
from graphon.runtime import GraphRuntimeState, VariablePool
from tests.workflow_test_utils import build_test_graph_init_params
def _build_variable_pool() -> VariablePool:
    return VariablePool.from_bootstrap(
        system_variables=build_system_variables(user_id="aaa", files=[]),
        user_inputs={},
    )


def _build_answer_node(*, answer: str, variable_pool: VariablePool) -> AnswerNode:
graph_config = {
"edges": [
{
"id": "start-source-answer-target",
"source": "start",
"target": "answer",
},
],
"edges": [],
"nodes": [
{"data": {"type": "start", "title": "Start"}, "id": "start"},
{
"data": {
"title": "123",
"title": "Answer",
"type": "answer",
"answer": "Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.",
"answer": answer,
},
"id": "answer",
},
],
}
init_params = build_test_graph_init_params(
workflow_id="1",
graph_config=graph_config,
@ -46,42 +41,31 @@ def test_execute_answer():
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
)
    graph_runtime_state = GraphRuntimeState(
        variable_pool=variable_pool,
        start_at=time.perf_counter(),
    )
    return AnswerNode(
node_id=str(uuid.uuid4()),
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
        data=AnswerNodeData(
            title="Answer",
type="answer",
answer="Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.",
answer=answer,
),
)
def test_execute_answer_renders_variable_selectors() -> None:
variable_pool = _build_variable_pool()
variable_pool.add(["start", "weather"], "sunny")
variable_pool.add(["llm", "text"], "You are a helpful AI.")
node = _build_answer_node(
answer="Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.",
variable_pool=variable_pool,
)
result = node._run()
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
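The assertion above only checks the run status; the interesting behaviour is the selector substitution. A toy illustration of how `{{#node.var#}}` selectors resolve against a pool while bare placeholders like `{{img}}` are left for fallback handling (not graphon's actual renderer):

import re

def render(template: str, pool: dict[tuple[str, str], str]) -> str:
    # Replace {{#node.var#}} with the pooled value; leave unknown selectors as-is.
    def _sub(match: re.Match[str]) -> str:
        node_id, name = match.group(1).split(".", 1)
        return pool.get((node_id, name), match.group(0))
    return re.sub(r"\{\{#([^#]+)#\}\}", _sub, template)

assert render("Today's weather is {{#start.weather#}}", {("start", "weather"): "sunny"}) == "Today's weather is sunny"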
@ -89,36 +73,11 @@ def test_execute_answer():
def test_execute_answer_renders_structured_output_object_as_json() -> None:
    variable_pool = _build_variable_pool()
variable_pool.add(["1777539038857", "structured_output"], {"type": "greeting"})
    node = _build_answer_node(
        answer="{{#1777539038857.structured_output#}}",
        variable_pool=variable_pool,
    )
result = node._run()
@ -128,35 +87,9 @@ def test_execute_answer_renders_structured_output_object_as_json() -> None:
def test_execute_answer_falls_back_to_plain_selector_text_when_structured_output_missing() -> None:
    node = _build_answer_node(
        answer="{{#1777539038857.structured_output#}}",
        variable_pool=_build_variable_pool(),
    )
result = node._run()

View File

@ -81,7 +81,7 @@ def test_datasource_node_delegates_to_manager_stream(mocker: MockerFixture):
node = DatasourceNode(
node_id="n",
        data=DatasourceNodeData(
type="datasource",
version="1",
title="Datasource",

View File

@ -29,7 +29,7 @@ HTTP_REQUEST_CONFIG = HttpRequestNodeConfig(
def test_executor_with_json_body_and_number_variable():
# Prepare the variable pool
    variable_pool = VariablePool.from_bootstrap(
system_variables=default_system_variables(),
user_inputs={},
)
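`VariablePool.from_bootstrap` replaces direct construction with the legacy bootstrap keyword arguments. A plausible shape for the classmethod, shown only to convey intent (the real signature and seeding helper are assumptions):

@classmethod
def from_bootstrap(cls, *, system_variables, user_inputs=None, environment_variables=(), conversation_variables=()):
    # Build an empty pool, then seed it from the legacy bootstrap inputs.
    pool = cls()
    pool.load_bootstrap(  # hypothetical seeding helper
        system_variables=system_variables,
        user_inputs=user_inputs or {},
        environment_variables=environment_variables,
        conversation_variables=conversation_variables,
    )
    return pool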
@ -85,7 +85,7 @@ def test_executor_with_json_body_and_number_variable():
def test_executor_with_json_body_and_object_variable():
# Prepare the variable pool
    variable_pool = VariablePool.from_bootstrap(
system_variables=default_system_variables(),
user_inputs={},
)
@ -143,7 +143,7 @@ def test_executor_with_json_body_and_object_variable():
def test_executor_with_json_body_and_nested_object_variable():
# Prepare the variable pool
    variable_pool = VariablePool.from_bootstrap(
system_variables=default_system_variables(),
user_inputs={},
)
@ -201,7 +201,7 @@ def test_executor_with_json_body_and_nested_object_variable():
def test_extract_selectors_from_template_with_newline():
    variable_pool = VariablePool.from_bootstrap(system_variables=default_system_variables())
variable_pool.add(("node_id", "custom_query"), "line1\nline2")
node_data = HttpRequestNodeData(
title="Test JSON Body with Nested Object Variable",
@ -230,7 +230,7 @@ def test_extract_selectors_from_template_with_newline():
def test_executor_with_form_data():
# Prepare the variable pool
    variable_pool = VariablePool.from_bootstrap(
system_variables=default_system_variables(),
user_inputs={},
)
@ -320,7 +320,7 @@ def test_init_headers():
node_data=node_data,
timeout=timeout,
http_request_config=HTTP_REQUEST_CONFIG,
        variable_pool=VariablePool.from_bootstrap(system_variables=default_system_variables()),
http_client=ssrf_proxy,
file_manager=file_manager,
)
@ -357,7 +357,7 @@ def test_init_params():
node_data=node_data,
timeout=timeout,
http_request_config=HTTP_REQUEST_CONFIG,
        variable_pool=VariablePool.from_bootstrap(system_variables=default_system_variables()),
http_client=ssrf_proxy,
file_manager=file_manager,
)
@ -390,7 +390,7 @@ def test_init_params():
def test_empty_api_key_raises_error_bearer():
"""Test that empty API key raises AuthorizationConfigError for bearer auth."""
    variable_pool = VariablePool.from_bootstrap(system_variables=default_system_variables())
node_data = HttpRequestNodeData(
title="test",
method="get",
@ -417,7 +417,7 @@ def test_empty_api_key_raises_error_bearer():
def test_empty_api_key_raises_error_basic():
"""Test that empty API key raises AuthorizationConfigError for basic auth."""
    variable_pool = VariablePool.from_bootstrap(system_variables=default_system_variables())
node_data = HttpRequestNodeData(
title="test",
method="get",
@ -444,7 +444,7 @@ def test_empty_api_key_raises_error_basic():
def test_empty_api_key_raises_error_custom():
"""Test that empty API key raises AuthorizationConfigError for custom auth."""
    variable_pool = VariablePool.from_bootstrap(system_variables=default_system_variables())
node_data = HttpRequestNodeData(
title="test",
method="get",
@ -471,7 +471,7 @@ def test_empty_api_key_raises_error_custom():
def test_whitespace_only_api_key_raises_error():
"""Test that whitespace-only API key raises AuthorizationConfigError."""
    variable_pool = VariablePool.from_bootstrap(system_variables=default_system_variables())
node_data = HttpRequestNodeData(
title="test",
method="get",
@ -498,7 +498,7 @@ def test_whitespace_only_api_key_raises_error():
def test_valid_api_key_works():
"""Test that valid API key works correctly for bearer auth."""
    variable_pool = VariablePool.from_bootstrap(system_variables=default_system_variables())
node_data = HttpRequestNodeData(
title="test",
method="get",
@ -536,7 +536,7 @@ def test_executor_with_json_body_and_unquoted_uuid_variable():
# UUID that triggers the json_repair truncation bug
test_uuid = "57eeeeb1-450b-482c-81b9-4be77e95dee2"
    variable_pool = VariablePool.from_bootstrap(
system_variables=default_system_variables(),
user_inputs={},
)
@ -583,7 +583,7 @@ def test_executor_with_json_body_and_unquoted_uuid_with_newlines():
"""
test_uuid = "57eeeeb1-450b-482c-81b9-4be77e95dee2"
    variable_pool = VariablePool.from_bootstrap(
system_variables=default_system_variables(),
user_inputs={},
)
@ -624,7 +624,7 @@ def test_executor_with_json_body_and_unquoted_uuid_with_newlines():
def test_executor_with_json_body_preserves_numbers_and_strings():
"""Test that numbers are preserved and string values are properly quoted."""
    variable_pool = VariablePool.from_bootstrap(
system_variables=default_system_variables(),
user_inputs={},
)

View File

@ -110,12 +110,15 @@ def _build_http_node(
call_depth=0,
)
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(system_variables=build_system_variables(user_id="user", files=[]), user_inputs={}),
variable_pool=VariablePool.from_bootstrap(
system_variables=build_system_variables(user_id="user", files=[]),
user_inputs={},
),
start_at=time.perf_counter(),
)
return HttpRequestNode(
node_id="http-node",
        data=HttpRequestNodeData.model_validate(node_data),
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
http_request_config=HTTP_REQUEST_CONFIG,

View File

@ -149,7 +149,7 @@ def _build_human_input_node(
)
return HumanInputNode(
node_id=node_id,
        data=typed_node_data,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
runtime=runtime,
@ -241,16 +241,16 @@ class TestUserAction:
def test_user_action_length_boundaries(self):
"""Test user action id and title length boundaries."""
action = UserAction(id="a" * 20, title="b" * 20)
action = UserAction(id="a" * 20, title="b" * 100)
assert action.id == "a" * 20
assert action.title == "b" * 20
assert action.title == "b" * 100
@pytest.mark.parametrize(
("field_name", "value"),
[
("id", "a" * 21),
("title", "b" * 21),
("title", "b" * 101),
],
)
def test_user_action_length_limits(self, field_name: str, value: str):
@ -427,7 +427,7 @@ class TestHumanInputNodeVariableResolution:
"""Tests for resolving variable-based defaults in HumanInputNode."""
def test_resolves_variable_defaults(self):
        variable_pool = VariablePool.from_bootstrap(
system_variables=build_system_variables(
user_id="user",
app_id="app",
@ -504,7 +504,7 @@ class TestHumanInputNodeVariableResolution:
assert params.resolved_default_values == expected_values
def test_debugger_falls_back_to_recipient_token_when_webapp_disabled(self):
        variable_pool = VariablePool.from_bootstrap(
system_variables=build_system_variables(
user_id="user",
app_id="app",
@ -565,7 +565,7 @@ class TestHumanInputNodeVariableResolution:
assert not hasattr(pause_event.reason, "form_token")
def test_webapp_runtime_keeps_form_visible_in_ui_when_webapp_delivery_is_enabled(self):
        variable_pool = VariablePool.from_bootstrap(
system_variables=build_system_variables(
user_id="user",
app_id="app",
@ -631,7 +631,7 @@ class TestHumanInputNodeVariableResolution:
assert params.display_in_ui is True
def test_debugger_debug_mode_overrides_email_recipients(self):
        variable_pool = VariablePool.from_bootstrap(
system_variables=build_system_variables(
user_id="user-123",
app_id="app",
@ -748,7 +748,7 @@ class TestHumanInputNodeRenderedContent:
"""Tests for rendering submitted content."""
def test_replaces_outputs_placeholders_after_submission(self):
        variable_pool = VariablePool.from_bootstrap(
system_variables=build_system_variables(
user_id="user",
app_id="app",

View File

@ -40,7 +40,7 @@ def _create_human_input_node(
)
return HumanInputNode(
node_id=config["id"],
        data=node_data,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
form_repository=repo,
@ -51,7 +51,11 @@ def _create_human_input_node(
def _build_node(form_content: str = "Please enter your name:\n\n{{#$output.name#}}") -> HumanInputNode:
system_variables = default_system_variables()
graph_runtime_state = GraphRuntimeState(
        variable_pool=VariablePool.from_bootstrap(
            system_variables=system_variables,
            user_inputs={},
            environment_variables=[],
        ),
start_at=0.0,
)
graph_init_params = GraphInitParams(
@ -114,7 +118,11 @@ def _build_node(form_content: str = "Please enter your name:\n\n{{#$output.name#
def _build_timeout_node() -> HumanInputNode:
system_variables = default_system_variables()
graph_runtime_state = GraphRuntimeState(
        variable_pool=VariablePool.from_bootstrap(
            system_variables=system_variables,
            user_inputs={},
            environment_variables=[],
        ),
start_at=0.0,
)
graph_init_params = GraphInitParams(

View File

@ -32,7 +32,7 @@ class _MissingGraphBuilder:
def _build_runtime_state() -> GraphRuntimeState:
return GraphRuntimeState(
        variable_pool=VariablePool.from_bootstrap(system_variables=default_system_variables(), user_inputs={}),
start_at=0.0,
)
@ -46,7 +46,7 @@ def _build_iteration_node(
init_params = build_test_graph_init_params(graph_config=graph_config)
return IterationNode(
node_id="iteration-node",
        data=IterationNodeData(
type="iteration",
title="Iteration",
iterator_selector=["start", "items"],

View File

@ -41,7 +41,7 @@ def mock_graph_init_params():
@pytest.fixture
def mock_graph_runtime_state():
"""Create mock GraphRuntimeState."""
    variable_pool = VariablePool.from_bootstrap(
system_variables=build_system_variables(user_id=str(uuid.uuid4()), files=[]),
user_inputs={},
environment_variables=[],
@ -103,7 +103,7 @@ def _build_node(
) -> KnowledgeIndexNode:
return KnowledgeIndexNode(
node_id=node_id,
        data=(
node_data
if isinstance(node_data, KnowledgeIndexNodeData)
else KnowledgeIndexNodeData.model_validate(node_data)

View File

@ -47,7 +47,7 @@ def mock_graph_init_params():
@pytest.fixture
def mock_graph_runtime_state():
"""Create mock GraphRuntimeState."""
    variable_pool = VariablePool.from_bootstrap(
system_variables=build_system_variables(user_id=str(uuid.uuid4()), files=[]),
user_inputs={},
environment_variables=[],
@ -118,7 +118,7 @@ class TestKnowledgeRetrievalNode:
# Act
node = KnowledgeRetrievalNode(
node_id=node_id,
config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
data=KnowledgeRetrievalNodeData.model_validate(config["data"]),
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@ -147,7 +147,7 @@ class TestKnowledgeRetrievalNode:
node = KnowledgeRetrievalNode(
node_id=node_id,
config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
data=KnowledgeRetrievalNodeData.model_validate(config["data"]),
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@ -206,7 +206,7 @@ class TestKnowledgeRetrievalNode:
node = KnowledgeRetrievalNode(
node_id=node_id,
config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
data=KnowledgeRetrievalNodeData.model_validate(config["data"]),
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@ -250,7 +250,7 @@ class TestKnowledgeRetrievalNode:
node = KnowledgeRetrievalNode(
node_id=node_id,
config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
data=KnowledgeRetrievalNodeData.model_validate(config["data"]),
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@ -286,7 +286,7 @@ class TestKnowledgeRetrievalNode:
node = KnowledgeRetrievalNode(
node_id=node_id,
config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
data=KnowledgeRetrievalNodeData.model_validate(config["data"]),
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@ -321,7 +321,7 @@ class TestKnowledgeRetrievalNode:
node = KnowledgeRetrievalNode(
node_id=node_id,
config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
data=KnowledgeRetrievalNodeData.model_validate(config["data"]),
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@ -362,7 +362,7 @@ class TestKnowledgeRetrievalNode:
node = KnowledgeRetrievalNode(
node_id=node_id,
config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
data=KnowledgeRetrievalNodeData.model_validate(config["data"]),
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@ -401,7 +401,7 @@ class TestKnowledgeRetrievalNode:
node = KnowledgeRetrievalNode(
node_id=node_id,
config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
data=KnowledgeRetrievalNodeData.model_validate(config["data"]),
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@ -482,7 +482,7 @@ class TestFetchDatasetRetriever:
node = KnowledgeRetrievalNode(
node_id=node_id,
config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
data=KnowledgeRetrievalNodeData.model_validate(config["data"]),
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@ -519,7 +519,7 @@ class TestFetchDatasetRetriever:
node = KnowledgeRetrievalNode(
node_id=node_id,
config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
data=KnowledgeRetrievalNodeData.model_validate(config["data"]),
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@ -574,7 +574,7 @@ class TestFetchDatasetRetriever:
node = KnowledgeRetrievalNode(
node_id=node_id,
config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
data=KnowledgeRetrievalNodeData.model_validate(config["data"]),
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@ -622,7 +622,7 @@ class TestFetchDatasetRetriever:
node = KnowledgeRetrievalNode(
node_id=node_id,
config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
data=KnowledgeRetrievalNodeData.model_validate(config["data"]),
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@ -683,7 +683,7 @@ class TestFetchDatasetRetriever:
config = {"id": node_id, "data": node_data.model_dump()}
node = KnowledgeRetrievalNode(
node_id=node_id,
config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
data=KnowledgeRetrievalNodeData.model_validate(config["data"]),
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)

View File

@ -16,10 +16,10 @@ class TestListOperatorNode:
"""Comprehensive tests for ListOperatorNode."""
@staticmethod
    def _build_node(*, data, graph_init_params, graph_runtime_state):
return ListOperatorNode(
node_id="test",
            data=data if isinstance(data, ListOperatorNodeData) else ListOperatorNodeData.model_validate(data),
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
@ -65,7 +65,7 @@ class TestListOperatorNode:
def _create_node(config, mock_variable):
mock_graph_runtime_state.variable_pool.get.return_value = mock_variable
return self._build_node(
                data=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@ -83,7 +83,7 @@ class TestListOperatorNode:
}
node = self._build_node(
            data=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@ -127,7 +127,7 @@ class TestListOperatorNode:
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
node = self._build_node(
            data=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@ -153,7 +153,7 @@ class TestListOperatorNode:
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
node = self._build_node(
            data=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@ -177,7 +177,7 @@ class TestListOperatorNode:
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
node = self._build_node(
            data=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@ -201,7 +201,7 @@ class TestListOperatorNode:
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
node = self._build_node(
            data=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@ -228,7 +228,7 @@ class TestListOperatorNode:
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
node = self._build_node(
            data=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@ -255,7 +255,7 @@ class TestListOperatorNode:
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
node = self._build_node(
            data=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@ -282,7 +282,7 @@ class TestListOperatorNode:
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
node = self._build_node(
            data=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@ -312,7 +312,7 @@ class TestListOperatorNode:
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
node = self._build_node(
            data=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@ -335,7 +335,7 @@ class TestListOperatorNode:
mock_graph_runtime_state.variable_pool.get.return_value = None
node = self._build_node(
            data=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@ -359,7 +359,7 @@ class TestListOperatorNode:
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
node = self._build_node(
            data=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@ -384,7 +384,7 @@ class TestListOperatorNode:
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
node = self._build_node(
            data=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@ -408,7 +408,7 @@ class TestListOperatorNode:
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
node = self._build_node(
            data=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@ -432,7 +432,7 @@ class TestListOperatorNode:
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
node = self._build_node(
            data=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@ -456,7 +456,7 @@ class TestListOperatorNode:
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
node = self._build_node(
            data=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@ -483,7 +483,7 @@ class TestListOperatorNode:
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
node = self._build_node(
            data=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)

View File

@ -15,7 +15,7 @@ from core.app.llm.model_access import (
)
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
from core.entities.provider_entities import CustomConfiguration, SystemConfiguration
from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.workflow.system_variables import default_system_variables
from graphon.entities import GraphInitParams
@ -187,7 +187,7 @@ def graph_init_params() -> GraphInitParams:
@pytest.fixture
def graph_runtime_state() -> GraphRuntimeState:
    variable_pool = VariablePool.from_bootstrap(
system_variables=default_system_variables(),
user_inputs={},
)
@ -208,7 +208,7 @@ def llm_node(
http_client = mock.MagicMock()
node = LLMNode(
node_id="1",
        data=llm_node_data,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
credentials_provider=mock_credentials_provider,
@ -241,9 +241,10 @@ def model_config(monkeypatch: pytest.MonkeyPatch):
)
# Create actual provider and model type instances
    model_assembly = create_plugin_model_assembly(tenant_id="test")
    model_provider_factory = model_assembly.model_provider_factory
provider_instance = model_provider_factory.get_model_provider("openai")
model_type_instance = model_provider_factory.get_model_type_instance("openai", ModelType.LLM)
model_type_instance = model_assembly.create_model_type_instance(provider="openai", model_type=ModelType.LLM)
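    # `create_plugin_model_assembly` (new in this upgrade) bundles the provider
    # factory together with per-model-type instance creation; the attribute and
    # method used above reflect its observed interface in this test, not a
    # documented contract.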
# Create a ProviderModelBundle
provider_model_bundle = ProviderModelBundle(
@ -1173,7 +1174,7 @@ def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_stat
http_client = mock.MagicMock()
node = LLMNode(
node_id="1",
        data=llm_node_data,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
credentials_provider=mock_credentials_provider,

View File

@ -28,7 +28,7 @@ def _build_template_transform_node(
)
return TemplateTransformNode(
node_id=node_id,
        data=typed_node_data,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
**kwargs,

View File

@ -39,7 +39,7 @@ def mock_graph_runtime_state():
def test_node_uses_default_max_output_length_when_not_overridden(graph_init_params, mock_graph_runtime_state):
node = TemplateTransformNode(
node_id="test_node",
        data=TemplateTransformNodeData(
title="Template Transform",
type="template-transform",
variables=[],

View File

@ -35,7 +35,10 @@ def _build_context(graph_config: Mapping[str, object]) -> tuple[GraphInitParams,
invoke_from="debugger",
)
runtime_state = GraphRuntimeState(
variable_pool=VariablePool(system_variables=build_system_variables(user_id="user", files=[]), user_inputs={}),
variable_pool=VariablePool.from_bootstrap(
system_variables=build_system_variables(user_id="user", files=[]),
user_inputs={},
),
start_at=0.0,
)
return init_params, runtime_state
@ -62,7 +65,7 @@ def test_node_hydrates_data_during_initialization():
node = _SampleNode(
node_id="node-1",
        data=_build_node_data(),
graph_init_params=init_params,
graph_runtime_state=runtime_state,
)
@ -82,13 +85,16 @@ def test_node_accepts_invoke_from_enum():
invoke_from=InvokeFrom.DEBUGGER,
)
runtime_state = GraphRuntimeState(
variable_pool=VariablePool(system_variables=build_system_variables(user_id="user", files=[]), user_inputs={}),
variable_pool=VariablePool.from_bootstrap(
system_variables=build_system_variables(user_id="user", files=[]),
user_inputs={},
),
start_at=0.0,
)
node = _SampleNode(
node_id="node-1",
        data=_build_node_data(),
graph_init_params=init_params,
graph_runtime_state=runtime_state,
)
@ -140,7 +146,7 @@ def test_node_hydration_preserves_compatibility_extra_fields():
node = _SampleNode(
node_id="node-1",
config=node_config["data"],
data=node_config["data"],
graph_init_params=init_params,
graph_runtime_state=runtime_state,
)

View File

@ -49,7 +49,7 @@ def document_extractor_node(graph_init_params):
http_client = Mock()
node = DocumentExtractorNode(
node_id="test_node_id",
        data=node_data,
graph_init_params=graph_init_params,
graph_runtime_state=Mock(),
http_client=http_client,
@ -186,12 +186,13 @@ def test_run_extract_text(
monkeypatch.setattr("graphon.file.file_manager.download", mock_download)
dispatch_mock = None
if mime_type == "application/pdf":
        dispatch_mock = Mock(return_value=expected_text[0])
        monkeypatch.setattr("graphon.nodes.document_extractor.node._extract_text_by_file_extension", dispatch_mock)
elif mime_type.startswith("application/vnd.openxmlformats"):
        dispatch_mock = Mock(return_value=expected_text[0])
        monkeypatch.setattr("graphon.nodes.document_extractor.node._extract_text_by_mime_type", dispatch_mock)
result = document_extractor_node._run()
@ -200,6 +201,19 @@ def test_run_extract_text(
assert result.outputs is not None
assert result.outputs["text"] == ArrayStringSegment(value=expected_text)
if mime_type == "application/pdf":
dispatch_mock.assert_called_once_with(
file_content=file_content,
file_extension=extension,
unstructured_api_config=document_extractor_node._unstructured_api_config,
)
elif mime_type.startswith("application/vnd.openxmlformats"):
dispatch_mock.assert_called_once_with(
file_content=file_content,
mime_type=mime_type,
unstructured_api_config=document_extractor_node._unstructured_api_config,
)
if transfer_method == FileTransferMethod.REMOTE_URL:
document_extractor_node._http_client.get.assert_called_once_with("https://example.com/file.txt")
elif transfer_method == FileTransferMethod.LOCAL_FILE:
@ -439,24 +453,42 @@ def test_extract_text_from_file_routes_excel_inputs(document_extractor_node, ext
file.extension = extension
file.mime_type = mime_type
    with patch(
        "graphon.nodes.document_extractor.node._download_file_content",
        return_value=b"excel",
    ):
if extension:
with patch(
"graphon.nodes.document_extractor.node._extract_text_by_file_extension",
return_value="excel text",
) as mock_extract:
result = _extract_text_from_file(
document_extractor_node.http_client,
file,
unstructured_api_config=document_extractor_node._unstructured_api_config,
)
mock_extract.assert_called_once_with(
file_content=b"excel",
file_extension=extension,
unstructured_api_config=document_extractor_node._unstructured_api_config,
)
else:
with patch(
"graphon.nodes.document_extractor.node._extract_text_by_mime_type",
return_value="excel text",
) as mock_extract:
result = _extract_text_from_file(
document_extractor_node.http_client,
file,
unstructured_api_config=document_extractor_node._unstructured_api_config,
)
mock_extract.assert_called_once_with(
file_content=b"excel",
mime_type=mime_type,
unstructured_api_config=document_extractor_node._unstructured_api_config,
)
assert result == "excel text"
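The refactor consolidates the per-format extractors behind two dispatch helpers; their keyword-only shape follows directly from the assertions above (bodies elided, routing behaviour assumed):

def _extract_text_by_file_extension(*, file_content: bytes, file_extension: str, unstructured_api_config) -> str:
    ...  # routes on the extension (.pdf, .xlsx, ...) to the concrete extractor

def _extract_text_by_mime_type(*, file_content: bytes, mime_type: str, unstructured_api_config) -> str:
    ...  # fallback routing when only a MIME type is available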
mock_extract.assert_called_once_with(b"excel")
def test_extract_text_from_file_rejects_missing_extension_and_mime_type(document_extractor_node):

View File

@ -29,7 +29,7 @@ def _build_if_else_node(
node_id=str(uuid.uuid4()),
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
        data=node_data if isinstance(node_data, IfElseNodeData) else IfElseNodeData.model_validate(node_data),
)
@ -48,7 +48,10 @@ def test_execute_if_else_result_true():
)
# construct variable pool
pool = VariablePool(system_variables=build_system_variables(user_id="aaa", files=[]), user_inputs={})
pool = VariablePool.from_bootstrap(
system_variables=build_system_variables(user_id="aaa", files=[]),
user_inputs={},
)
pool.add(["start", "array_contains"], ["ab", "def"])
pool.add(["start", "array_not_contains"], ["ac", "def"])
pool.add(["start", "contains"], "cabcde")
@ -148,7 +151,7 @@ def test_execute_if_else_result_false():
)
# construct variable pool
    pool = VariablePool.from_bootstrap(
system_variables=build_system_variables(user_id="aaa", files=[]),
user_inputs={},
environment_variables=[],
@ -305,7 +308,7 @@ def test_execute_if_else_boolean_conditions(condition: Condition):
)
# construct variable pool with boolean values
    pool = VariablePool.from_bootstrap(
system_variables=build_system_variables(files=[], user_id="aaa"),
)
pool.add(["start", "bool_true"], True)
@ -359,7 +362,7 @@ def test_execute_if_else_boolean_false_conditions():
)
# construct variable pool with boolean values
    pool = VariablePool.from_bootstrap(
system_variables=build_system_variables(files=[], user_id="aaa"),
)
pool.add(["start", "bool_true"], True)
@ -424,7 +427,7 @@ def test_execute_if_else_boolean_cases_structure():
)
# construct variable pool with boolean values
    pool = VariablePool.from_bootstrap(
system_variables=build_system_variables(files=[], user_id="aaa"),
)
pool.add(["start", "bool_true"], True)

View File

@ -22,7 +22,7 @@ from graphon.variables import ArrayFileSegment
def _build_list_operator_node(node_data: ListOperatorNodeData, graph_init_params) -> ListOperatorNode:
return ListOperatorNode(
node_id="test_node_id",
        data=node_data,
graph_init_params=graph_init_params,
graph_runtime_state=MagicMock(),
)

View File

@ -31,7 +31,7 @@ def make_start_node(user_inputs, variables):
return StartNode(
node_id="start",
        data=node_data,
graph_init_params=build_test_graph_init_params(
workflow_id="wf",
graph_config={},
@ -260,7 +260,7 @@ def test_start_node_outputs_full_variable_pool_snapshot():
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
node = StartNode(
node_id="start",
        data=node_data,
graph_init_params=build_test_graph_init_params(
workflow_id="wf",
graph_config={},

View File

@ -99,7 +99,7 @@ def tool_node(monkeypatch) -> ToolNode:
call_depth=0,
)
variable_pool = VariablePool(system_variables=build_system_variables(user_id="user-id"))
variable_pool = VariablePool.from_bootstrap(system_variables=build_system_variables(user_id="user-id"))
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0)
config = graph_config["nodes"][0]
@ -110,7 +110,7 @@ def tool_node(monkeypatch) -> ToolNode:
node = ToolNode(
node_id="node-instance",
config=ToolNodeData.model_validate(config["data"]),
data=ToolNodeData.model_validate(config["data"]),
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
tool_file_manager_factory=tool_file_manager_factory,

View File

@ -44,7 +44,7 @@ def test_trigger_event_node_run_populates_trigger_info_metadata() -> None:
init_params, runtime_state = _build_context(graph_config={})
node = TriggerEventNode(
node_id="node-1",
        data=_build_node_data(),
graph_init_params=init_params,
graph_runtime_state=runtime_state,
)

View File

@ -52,7 +52,7 @@ def create_webhook_node(
node = TriggerWebhookNode(
node_id="webhook-node-1",
        data=webhook_data,
graph_init_params=graph_init_params,
graph_runtime_state=runtime_state,
)

View File

@ -44,7 +44,7 @@ def create_webhook_node(webhook_data: WebhookData, variable_pool: VariablePool)
)
node = TriggerWebhookNode(
node_id="1",
        data=webhook_data,
graph_init_params=graph_init_params,
graph_runtime_state=runtime_state,
)

View File

@ -1,3 +1,4 @@
from collections.abc import Mapping
from types import SimpleNamespace
from unittest.mock import MagicMock, patch, sentinel
@ -11,19 +12,20 @@ from graphon.entities.base_node_data import BaseNodeData
from graphon.enums import BuiltinNodeTypes, NodeType
from graphon.nodes.code.entities import CodeLanguage
from graphon.nodes.llm.entities import LLMNodeData
from graphon.nodes.llm.node import LLMNode
from graphon.variables.segments import StringSegment
def _assert_typed_node_config(config, *, node_id: str, node_type: NodeType, version: str = "1") -> None:
def _assert_constructor_node_data(data, *, node_id: str, node_type: NodeType, version: str = "1") -> None:
_ = node_id
    if isinstance(data, BaseNodeData):
        assert data.type == node_type
        assert data.version == version
return
    assert isinstance(data, Mapping)
    assert data["type"] == node_type
    assert data.get("version", "1") == version
def _node_constructor(*, return_value):
@ -470,7 +472,7 @@ class TestDifyNodeFactoryCreateNode:
matched_node_class.assert_called_once()
kwargs = matched_node_class.call_args.kwargs
assert kwargs["node_id"] == "node-id"
_assert_typed_node_config(kwargs["config"], node_id="node-id", node_type=BuiltinNodeTypes.START, version="9")
_assert_constructor_node_data(kwargs["data"], node_id="node-id", node_type=BuiltinNodeTypes.START, version="9")
assert kwargs["graph_init_params"] is sentinel.graph_init_params
assert kwargs["graph_runtime_state"] is factory.graph_runtime_state
latest_node_class.assert_not_called()
@ -492,7 +494,7 @@ class TestDifyNodeFactoryCreateNode:
latest_node_class.assert_called_once()
kwargs = latest_node_class.call_args.kwargs
assert kwargs["node_id"] == "node-id"
_assert_typed_node_config(kwargs["config"], node_id="node-id", node_type=BuiltinNodeTypes.START, version="9")
_assert_constructor_node_data(kwargs["data"], node_id="node-id", node_type=BuiltinNodeTypes.START, version="9")
assert kwargs["graph_init_params"] is sentinel.graph_init_params
assert kwargs["graph_runtime_state"] is factory.graph_runtime_state
@ -530,7 +532,7 @@ class TestDifyNodeFactoryCreateNode:
assert result is created_node
kwargs = constructor.call_args.kwargs
assert kwargs["node_id"] == "node-id"
_assert_typed_node_config(kwargs["config"], node_id="node-id", node_type=node_type)
_assert_constructor_node_data(kwargs["data"], node_id="node-id", node_type=node_type)
assert kwargs["graph_init_params"] is sentinel.graph_init_params
assert kwargs["graph_runtime_state"] is factory.graph_runtime_state
@ -599,11 +601,12 @@ class TestDifyNodeFactoryCreateNode:
prepared_llm.assert_called_once_with(sentinel.model_instance)
assert kwargs["model_instance"] is wrapped_model_instance
    def test_create_node_passes_alias_preserving_llm_data_to_constructor(self, monkeypatch, factory):
created_node = object()
constructor = _node_constructor(return_value=created_node)
constructor.validate_node_data.side_effect = lambda node_data: LLMNodeData.model_validate(
node_data.model_dump(mode="python") if isinstance(node_data, BaseNodeData) else node_data
)
monkeypatch.setattr(factory, "_resolve_node_class", MagicMock(return_value=constructor))
monkeypatch.setattr(factory, "_build_llm_compatible_node_init_kwargs", MagicMock(return_value={}))
@ -629,10 +632,56 @@ class TestDifyNodeFactoryCreateNode:
factory.create_node(node_config)
config = constructor.call_args.kwargs["config"]
assert isinstance(config, dict)
assert config["structured_output_enabled"] is True
assert "structured_output_switch_on" not in config
data = constructor.call_args.kwargs["data"]
assert isinstance(data, Mapping)
assert data["structured_output_enabled"] is True
assert "structured_output_switch_on" not in data
assert LLMNodeData.model_validate(data).structured_output_enabled is True
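`structured_output_enabled` and `structured_output_switch_on` behave as an alias pair on LLMNodeData: the factory may receive either spelling, validation normalizes to the canonical field, and the legacy name remains readable (as the next test asserts). A minimal Pydantic sketch of that pattern, illustrative rather than the real model:

from pydantic import BaseModel, Field

class AliasedFlags(BaseModel):
    model_config = {"populate_by_name": True}

    structured_output_enabled: bool = Field(default=False, alias="structured_output_switch_on")

    @property
    def structured_output_switch_on(self) -> bool:
        # Legacy read-only spelling kept for callers that still use it.
        return self.structured_output_enabled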
def test_create_node_preserves_structured_output_switch_after_graphon_constructor(self, monkeypatch, factory):
factory.graph_init_params = SimpleNamespace(
workflow_id="workflow-id",
graph_config={},
run_context={},
call_depth=0,
)
monkeypatch.setattr(factory, "_resolve_node_class", MagicMock(return_value=LLMNode))
monkeypatch.setattr(
factory,
"_build_llm_compatible_node_init_kwargs",
MagicMock(
return_value={
"model_instance": sentinel.model_instance,
"llm_file_saver": sentinel.llm_file_saver,
"prompt_message_serializer": sentinel.prompt_message_serializer,
}
),
)
node_config = {
"id": "llm-node-id",
"data": {
"type": BuiltinNodeTypes.LLM,
"title": "LLM",
"model": {"provider": "provider", "name": "model", "mode": "chat", "completion_params": {}},
"prompt_template": [{"role": "system", "text": "x"}],
"context": {"enabled": False, "variable_selector": []},
"vision": {"enabled": False},
"structured_output_enabled": True,
"structured_output": {
"schema": {
"type": "object",
"properties": {"type": {"type": "string"}},
"required": ["type"],
}
},
},
}
node = factory.create_node(node_config)
assert node.node_data.structured_output_switch_on is True
assert node.node_data.structured_output_enabled is True
@pytest.mark.parametrize(
("node_type", "constructor_name", "expected_extra_kwargs"),
@ -711,7 +760,7 @@ class TestDifyNodeFactoryCreateNode:
constructor_kwargs = constructor.call_args.kwargs
assert constructor_kwargs["node_id"] == "node-id"
_assert_typed_node_config(constructor_kwargs["config"], node_id="node-id", node_type=node_type)
_assert_constructor_node_data(constructor_kwargs["data"], node_id="node-id", node_type=node_type)
assert constructor_kwargs["graph_init_params"] is sentinel.graph_init_params
assert constructor_kwargs["graph_runtime_state"] is factory.graph_runtime_state
assert constructor_kwargs["credentials_provider"] is sentinel.credentials_provider

View File

@ -109,8 +109,8 @@ class TestVariablePool:
assert pool.get([ENVIRONMENT_VARIABLE_NODE_ID, "env_var_1"]) is not None
assert pool.get([CONVERSATION_VARIABLE_NODE_ID, "conv_var_1"]) is not None
    def test_from_bootstrap_loads_legacy_bootstrap_kwargs(self):
        pool = VariablePool.from_bootstrap(
system_variables=build_system_variables(user_id="test_user_id"),
environment_variables=[StringVariable(name="env_var", value="env-value")],
conversation_variables=[StringVariable(name="conv_var", value="conv-value")],

View File

@ -55,7 +55,7 @@ class TestWorkflowEntry:
def test_mapping_user_inputs_to_variable_pool_with_system_variables(self):
"""Test mapping system variables from user inputs to variable pool."""
# Initialize variable pool with system variables
        variable_pool = VariablePool.from_bootstrap(
system_variables=build_system_variables(
user_id="test_user_id",
app_id="test_app_id",
@ -128,7 +128,7 @@ class TestWorkflowEntry:
return NodeConfigDictAdapter.validate_python(node_config)
workflow = StubWorkflow()
        variable_pool = VariablePool.from_bootstrap(system_variables=default_system_variables(), user_inputs={})
expected_limits = CodeNodeLimits(
max_string_length=dify_config.CODE_MAX_STRING_LENGTH,
max_number=dify_config.CODE_MAX_NUMBER,
@ -157,7 +157,7 @@ class TestWorkflowEntry:
"""Test mapping environment variables from user inputs to variable pool."""
# Initialize variable pool with environment variables
env_var = StringVariable(name="API_KEY", value="existing_key")
        variable_pool = VariablePool.from_bootstrap(
system_variables=default_system_variables(),
environment_variables=[env_var],
user_inputs={},
@ -198,7 +198,7 @@ class TestWorkflowEntry:
"""Test mapping conversation variables from user inputs to variable pool."""
# Initialize variable pool with conversation variables
conv_var = StringVariable(name="last_message", value="Hello")
        variable_pool = VariablePool.from_bootstrap(
system_variables=default_system_variables(),
conversation_variables=[conv_var],
user_inputs={},
@@ -239,7 +239,7 @@ class TestWorkflowEntry:
def test_mapping_user_inputs_to_variable_pool_with_regular_variables(self):
"""Test mapping regular node variables from user inputs to variable pool."""
# Initialize empty variable pool
-variable_pool = VariablePool(
+variable_pool = VariablePool.from_bootstrap(
system_variables=default_system_variables(),
user_inputs={},
)
@@ -281,7 +281,7 @@ class TestWorkflowEntry:
def test_mapping_user_inputs_with_file_handling(self):
"""Test mapping file inputs from user inputs to variable pool."""
-variable_pool = VariablePool(
+variable_pool = VariablePool.from_bootstrap(
system_variables=default_system_variables(),
user_inputs={},
)
@@ -340,7 +340,7 @@ class TestWorkflowEntry:
def test_mapping_user_inputs_missing_variable_error(self):
"""Test that mapping raises error when required variable is missing."""
-variable_pool = VariablePool(
+variable_pool = VariablePool.from_bootstrap(
system_variables=default_system_variables(),
user_inputs={},
)
@@ -366,7 +366,7 @@ class TestWorkflowEntry:
def test_mapping_user_inputs_with_alternative_key_format(self):
"""Test mapping with alternative key format (without node prefix)."""
-variable_pool = VariablePool(
+variable_pool = VariablePool.from_bootstrap(
system_variables=default_system_variables(),
user_inputs={},
)
@@ -396,7 +396,7 @@ class TestWorkflowEntry:
def test_mapping_user_inputs_with_complex_selectors(self):
"""Test mapping with complex node variable keys."""
-variable_pool = VariablePool(
+variable_pool = VariablePool.from_bootstrap(
system_variables=default_system_variables(),
user_inputs={},
)
@@ -432,7 +432,7 @@ class TestWorkflowEntry:
def test_mapping_user_inputs_invalid_node_variable(self):
"""Test that mapping handles invalid node variable format."""
-variable_pool = VariablePool(
+variable_pool = VariablePool.from_bootstrap(
system_variables=default_system_variables(),
user_inputs={},
)
@@ -463,7 +463,7 @@ class TestWorkflowEntry:
env_var = StringVariable(name="API_KEY", value="existing_key")
conv_var = StringVariable(name="session_id", value="session123")
-variable_pool = VariablePool(
+variable_pool = VariablePool.from_bootstrap(
system_variables=build_system_variables(
user_id="test_user",
app_id="test_app",


@@ -7,7 +7,6 @@ import pytest
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
-from core.model_manager import ModelInstance
from core.workflow import workflow_entry
from core.workflow.system_variables import default_system_variables
from graphon.entities.base_node_data import BaseNodeData
@@ -16,10 +15,12 @@ from graphon.errors import WorkflowNodeRunFailedError
from graphon.file import File, FileTransferMethod, FileType
from graphon.graph import Graph
from graphon.graph_events import GraphRunFailedEvent
-from graphon.model_runtime.entities.llm_entities import LLMUsage
+from graphon.model_runtime.entities.llm_entities import LLMMode, LLMUsage
from graphon.node_events import NodeRunResult
from graphon.nodes import BuiltinNodeTypes
from graphon.nodes.base.node import Node
+from graphon.nodes.llm.entities import ContextConfig, LLMNodeData, ModelConfig
+from graphon.nodes.question_classifier.entities import QuestionClassifierNodeData
from graphon.runtime import ChildGraphNotFoundError, VariablePool
from graphon.variables.variables import StringVariable
from tests.workflow_test_utils import build_test_graph_init_params, build_test_variable_pool
@@ -29,9 +30,30 @@ def _build_typed_node_config(node_type: NodeType):
return {"id": "node-id", "data": BaseNodeData(type=node_type)}
-def _build_wrapped_model_instance() -> tuple[SimpleNamespace, ModelInstance]:
-raw_model_instance = ModelInstance.__new__(ModelInstance)
-return SimpleNamespace(_model_instance=raw_model_instance), raw_model_instance
+def _build_model_config(*, provider: str = "openai", model_name: str = "gpt-4o") -> ModelConfig:
+return ModelConfig(provider=provider, name=model_name, mode=LLMMode.CHAT)
+def _build_llm_node_data(*, provider: str = "openai", model_name: str = "gpt-4o") -> LLMNodeData:
+return LLMNodeData(
+type=BuiltinNodeTypes.LLM,
+title="Child Model",
+model=_build_model_config(provider=provider, model_name=model_name),
+prompt_template=[],
+context=ContextConfig(enabled=False),
+)
+def _build_question_classifier_node_data(
+*, provider: str = "openai", model_name: str = "gpt-4o"
+) -> QuestionClassifierNodeData:
+return QuestionClassifierNodeData(
+type=BuiltinNodeTypes.QUESTION_CLASSIFIER,
+title="Child Model",
+query_variable_selector=["sys", "query"],
+model=_build_model_config(provider=provider, model_name=model_name),
+classes=[],
+)
class _FakeModelNodeMixin:
@@ -40,22 +62,26 @@ class _FakeModelNodeMixin:
return "1"
def post_init(self) -> None:
-self.model_instance, self.raw_model_instance = _build_wrapped_model_instance()
+self.model_instance = SimpleNamespace(provider="stale-provider", model_name="stale-model")
self.usage_snapshot = LLMUsage.empty_usage()
self.usage_snapshot.total_tokens = 1
def _run(self) -> NodeRunResult:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
+inputs={
+"model_provider": self.node_data.model.provider,
+"model_name": self.node_data.model.name,
+},
llm_usage=self.usage_snapshot,
)
-class _FakeLLMNode(_FakeModelNodeMixin, Node[BaseNodeData]):
+class _FakeLLMNode(_FakeModelNodeMixin, Node[LLMNodeData]):
node_type = BuiltinNodeTypes.LLM
-class _FakeQuestionClassifierNode(_FakeModelNodeMixin, Node[BaseNodeData]):
+class _FakeQuestionClassifierNode(_FakeModelNodeMixin, Node[QuestionClassifierNodeData]):
node_type = BuiltinNodeTypes.QUESTION_CLASSIFIER
@@ -75,7 +101,7 @@ class TestWorkflowChildEngineBuilder:
assert result is expected
def test_build_child_engine_raises_when_root_node_is_missing(self):
-builder = workflow_entry._WorkflowChildEngineBuilder()
+builder = workflow_entry._WorkflowChildEngineBuilder(tenant_id="tenant-id")
graph_init_params = SimpleNamespace(graph_config={"nodes": []})
parent_graph_runtime_state = SimpleNamespace(
execution_context=sentinel.execution_context,
@@ -92,7 +118,7 @@
)
def test_build_child_engine_constructs_graph_engine_with_quota_layer_only(self):
-builder = workflow_entry._WorkflowChildEngineBuilder()
+builder = workflow_entry._WorkflowChildEngineBuilder(tenant_id="tenant-id")
graph_init_params = SimpleNamespace(graph_config={"nodes": [{"id": "root"}]})
parent_graph_runtime_state = SimpleNamespace(
execution_context=sentinel.execution_context,
@@ -114,7 +140,7 @@
patch.object(workflow_entry, "GraphEngine", return_value=child_engine) as graph_engine_cls,
patch.object(workflow_entry, "GraphEngineConfig", return_value=sentinel.graph_engine_config),
patch.object(workflow_entry, "InMemoryChannel", return_value=sentinel.command_channel),
-patch.object(workflow_entry, "LLMQuotaLayer", return_value=sentinel.llm_quota_layer),
+patch.object(workflow_entry, "LLMQuotaLayer", return_value=sentinel.llm_quota_layer) as llm_quota_layer_cls,
):
result = builder.build_child_engine(
workflow_id="workflow-id",
@@ -147,11 +173,12 @@
config=sentinel.graph_engine_config,
child_engine_builder=builder,
)
+llm_quota_layer_cls.assert_called_once_with(tenant_id="tenant-id")
assert child_engine.layer.call_args_list == [((sentinel.llm_quota_layer,), {})]
@pytest.mark.parametrize("node_cls", [_FakeLLMNode, _FakeQuestionClassifierNode])
def test_build_child_engine_runs_llm_quota_layer_for_child_model_nodes(self, node_cls):
-builder = workflow_entry._WorkflowChildEngineBuilder()
+builder = workflow_entry._WorkflowChildEngineBuilder(tenant_id="tenant-id")
graph_init_params = build_test_graph_init_params(
graph_config={"nodes": [{"id": "root"}], "edges": []},
)
@@ -163,12 +190,10 @@
def build_graph(*, graph_config, node_factory, root_node_id):
_ = graph_config
+node_data = _build_llm_node_data() if node_cls is _FakeLLMNode else _build_question_classifier_node_data()
node = node_cls(
node_id=root_node_id,
-config=BaseNodeData(
-type=node_cls.node_type,
-title="Child Model",
-),
+data=node_data,
graph_init_params=node_factory.graph_init_params,
graph_runtime_state=node_factory.graph_runtime_state,
)
@@ -191,8 +216,8 @@
),
),
patch.object(workflow_entry.Graph, "init", side_effect=build_graph),
-patch("core.app.workflow.layers.llm_quota.ensure_llm_quota_available") as ensure_quota,
-patch("core.app.workflow.layers.llm_quota.deduct_llm_quota") as deduct_quota,
+patch("core.app.workflow.layers.llm_quota.ensure_llm_quota_available_for_model") as ensure_quota,
+patch("core.app.workflow.layers.llm_quota.deduct_llm_quota_for_model") as deduct_quota,
):
child_engine = builder.build_child_engine(
workflow_id="workflow-id",
@@ -203,10 +228,15 @@
list(child_engine.run())
node = created_node["node"]
-ensure_quota.assert_called_once_with(model_instance=node.raw_model_instance)
+ensure_quota.assert_called_once_with(
+tenant_id="tenant-id",
+provider=node.node_data.model.provider,
+model=node.node_data.model.name,
+)
deduct_quota.assert_called_once_with(
-tenant_id="tenant",
-model_instance=node.raw_model_instance,
+tenant_id="tenant-id",
+provider=node.node_data.model.provider,
+model=node.node_data.model.name,
usage=node.usage_snapshot,
)
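
These assertions pin down the reworked quota surface: instead of threading a live ModelInstance into the layer, the child engine now identifies the billed model by tenant_id, provider, and model name. A sketch of the check/deduct pairing under that contract; the run_model_node wrapper and its hook placement are illustrative assumptions, only the two helpers and their keyword arguments come from this diff:

from core.app.workflow.layers.llm_quota import (
    deduct_llm_quota_for_model,
    ensure_llm_quota_available_for_model,
)

def run_model_node(node, tenant_id: str):
    model = node.node_data.model
    # Refuse to start when the system-hosted quota is already exhausted.
    ensure_llm_quota_available_for_model(tenant_id=tenant_id, provider=model.provider, model=model.name)
    result = node.run()  # illustrative; graphon actually drives nodes through the engine
    # Charge the tenant for whatever the node reports having consumed.
    if result.llm_usage is not None:
        deduct_llm_quota_for_model(
            tenant_id=tenant_id,
            provider=model.provider,
            model=model.name,
            usage=result.llm_usage,
        )
    return result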
@@ -252,7 +282,7 @@ class TestWorkflowEntryInit:
"ExecutionLimitsLayer",
return_value=execution_limits_layer,
) as execution_limits_layer_cls,
-patch.object(workflow_entry, "LLMQuotaLayer", return_value=llm_quota_layer),
+patch.object(workflow_entry, "LLMQuotaLayer", return_value=llm_quota_layer) as llm_quota_layer_cls,
patch.object(workflow_entry, "ObservabilityLayer", return_value=observability_layer),
):
entry = workflow_entry.WorkflowEntry(
@@ -291,6 +321,7 @@
max_steps=workflow_entry.dify_config.WORKFLOW_MAX_EXECUTION_STEPS,
max_time=workflow_entry.dify_config.WORKFLOW_MAX_EXECUTION_TIME,
)
+llm_quota_layer_cls.assert_called_once_with(tenant_id="tenant-id")
assert graph_engine.layer.call_args_list == [
((debug_layer,), {}),
((execution_limits_layer,), {}),
@@ -334,7 +365,7 @@ class TestWorkflowEntrySingleStepRun:
def extract_variable_selector_to_variable_mapping(**_kwargs):
return {}
-variable_pool = VariablePool(system_variables=default_system_variables(), user_inputs={})
+variable_pool = VariablePool.from_bootstrap(system_variables=default_system_variables(), user_inputs={})
variable_loader = MagicMock()
variable_loader.load_variables.return_value = [
StringVariable(


@@ -0,0 +1,130 @@
from types import SimpleNamespace
from unittest.mock import patch
from uuid import uuid4
from sqlalchemy import create_engine, select
from core.app.entities.app_invoke_entities import ChatAppGenerateEntity
from core.entities.provider_entities import ProviderQuotaType, QuotaUnit
from events.event_handlers import update_provider_when_message_created
from models import TenantCreditPool
from models.provider import ProviderType
def test_message_created_trial_credit_accounting_does_not_raise_when_balance_is_insufficient() -> None:
engine = create_engine("sqlite:///:memory:")
TenantCreditPool.__table__.create(engine)
tenant_id = str(uuid4())
pool_id = str(uuid4())
with engine.begin() as connection:
connection.execute(
TenantCreditPool.__table__.insert(),
{
"id": pool_id,
"tenant_id": tenant_id,
"pool_type": ProviderQuotaType.TRIAL,
"quota_limit": 10,
"quota_used": 9,
},
)
system_configuration = SimpleNamespace(
current_quota_type=ProviderQuotaType.TRIAL,
quota_configurations=[
SimpleNamespace(
quota_type=ProviderQuotaType.TRIAL,
quota_unit=QuotaUnit.TOKENS,
quota_limit=10,
)
],
)
application_generate_entity = ChatAppGenerateEntity.model_construct(
app_config=SimpleNamespace(tenant_id=tenant_id),
model_conf=SimpleNamespace(
provider="openai",
model="gpt-4o",
provider_model_bundle=SimpleNamespace(
configuration=SimpleNamespace(
using_provider_type=ProviderType.SYSTEM,
system_configuration=system_configuration,
)
),
),
)
message = SimpleNamespace(message_tokens=2, answer_tokens=1)
with (
patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)),
patch.object(update_provider_when_message_created, "_execute_provider_updates"),
):
update_provider_when_message_created.handle(
sender=message,
application_generate_entity=application_generate_entity,
)
with engine.connect() as connection:
quota_used = connection.scalar(select(TenantCreditPool.quota_used).where(TenantCreditPool.id == pool_id))
assert quota_used == 10
def test_message_created_paid_credit_accounting_uses_paid_pool() -> None:
tenant_id = str(uuid4())
system_configuration = SimpleNamespace(
current_quota_type=ProviderQuotaType.PAID,
quota_configurations=[
SimpleNamespace(
quota_type=ProviderQuotaType.PAID,
quota_unit=QuotaUnit.TOKENS,
quota_limit=10,
)
],
)
application_generate_entity = ChatAppGenerateEntity.model_construct(
app_config=SimpleNamespace(tenant_id=tenant_id),
model_conf=SimpleNamespace(
provider="openai",
model="gpt-4o",
provider_model_bundle=SimpleNamespace(
configuration=SimpleNamespace(
using_provider_type=ProviderType.SYSTEM,
system_configuration=system_configuration,
)
),
),
)
message = SimpleNamespace(message_tokens=2, answer_tokens=1)
with (
patch.object(update_provider_when_message_created, "_deduct_credit_pool_quota_capped") as mock_deduct,
patch.object(update_provider_when_message_created, "_execute_provider_updates"),
):
update_provider_when_message_created.handle(
sender=message,
application_generate_entity=application_generate_entity,
)
mock_deduct.assert_called_once_with(
tenant_id=tenant_id,
credits_required=3,
pool_type="paid",
)
def test_capped_credit_pool_accounting_skips_exhaustion_warning_when_full_amount_is_deducted(caplog) -> None:
with patch(
"services.credit_pool_service.CreditPoolService.deduct_credits_capped",
return_value=3,
) as mock_deduct:
update_provider_when_message_created._deduct_credit_pool_quota_capped(
tenant_id="tenant-id",
credits_required=3,
pool_type="trial",
)
mock_deduct.assert_called_once_with(
tenant_id="tenant-id",
credits_required=3,
pool_type="trial",
)
assert "Credit pool exhausted during message-created accounting" not in caplog.text


@@ -0,0 +1,158 @@
from types import SimpleNamespace
from unittest.mock import patch
from uuid import uuid4
import pytest
from sqlalchemy import create_engine, select
from sqlalchemy.engine import Engine
from core.errors.error import QuotaExceededError
from models import TenantCreditPool
from models.enums import ProviderQuotaType
from services.credit_pool_service import CreditPoolService
def _create_engine_with_pool(*, quota_limit: int, quota_used: int) -> tuple[Engine, str, str]:
engine = create_engine("sqlite:///:memory:")
TenantCreditPool.__table__.create(engine)
tenant_id = str(uuid4())
pool_id = str(uuid4())
with engine.begin() as connection:
connection.execute(
TenantCreditPool.__table__.insert(),
{
"id": pool_id,
"tenant_id": tenant_id,
"pool_type": ProviderQuotaType.TRIAL,
"quota_limit": quota_limit,
"quota_used": quota_used,
},
)
return engine, tenant_id, pool_id
def _get_quota_used(*, engine: Engine, pool_id: str) -> int | None:
with engine.connect() as connection:
return connection.scalar(select(TenantCreditPool.quota_used).where(TenantCreditPool.id == pool_id))
def test_check_and_deduct_credits_deducts_exact_amount_when_sufficient() -> None:
engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=2)
with patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)):
deducted_credits = CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=3)
assert deducted_credits == 3
assert _get_quota_used(engine=engine, pool_id=pool_id) == 5
def test_check_and_deduct_credits_returns_zero_for_non_positive_request() -> None:
assert CreditPoolService.check_and_deduct_credits(tenant_id=str(uuid4()), credits_required=0) == 0
def test_check_and_deduct_credits_raises_when_pool_is_missing() -> None:
engine = create_engine("sqlite:///:memory:")
TenantCreditPool.__table__.create(engine)
with (
patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)),
pytest.raises(QuotaExceededError, match="Credit pool not found"),
):
CreditPoolService.check_and_deduct_credits(tenant_id=str(uuid4()), credits_required=1)
def test_check_and_deduct_credits_raises_when_pool_is_empty() -> None:
engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=10)
with (
patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)),
pytest.raises(QuotaExceededError, match="No credits remaining"),
):
CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=1)
assert _get_quota_used(engine=engine, pool_id=pool_id) == 10
def test_check_and_deduct_credits_raises_without_partial_deduction_when_insufficient() -> None:
engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=9)
with (
patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)),
pytest.raises(QuotaExceededError, match="Insufficient credits remaining"),
):
CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=3)
assert _get_quota_used(engine=engine, pool_id=pool_id) == 9
def test_check_and_deduct_credits_wraps_unexpected_deduction_errors() -> None:
engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=2)
with (
patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)),
patch.object(CreditPoolService, "_get_locked_pool", side_effect=RuntimeError("database unavailable")),
pytest.raises(QuotaExceededError, match="Failed to deduct credits"),
):
CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=1)
assert _get_quota_used(engine=engine, pool_id=pool_id) == 2
def test_deduct_credits_capped_returns_zero_for_non_positive_request() -> None:
assert CreditPoolService.deduct_credits_capped(tenant_id=str(uuid4()), credits_required=0) == 0
def test_deduct_credits_capped_returns_zero_when_pool_is_missing() -> None:
engine = create_engine("sqlite:///:memory:")
TenantCreditPool.__table__.create(engine)
with patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)):
deducted_credits = CreditPoolService.deduct_credits_capped(tenant_id=str(uuid4()), credits_required=1)
assert deducted_credits == 0
def test_deduct_credits_capped_returns_zero_when_pool_is_empty() -> None:
engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=10)
with patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)):
deducted_credits = CreditPoolService.deduct_credits_capped(tenant_id=tenant_id, credits_required=1)
assert deducted_credits == 0
assert _get_quota_used(engine=engine, pool_id=pool_id) == 10
def test_deduct_credits_capped_deducts_only_remaining_balance_when_insufficient() -> None:
engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=9)
with patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)):
deducted_credits = CreditPoolService.deduct_credits_capped(tenant_id=tenant_id, credits_required=3)
assert deducted_credits == 1
assert _get_quota_used(engine=engine, pool_id=pool_id) == 10
def test_deduct_credits_capped_wraps_unexpected_deduction_errors() -> None:
engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=2)
with (
patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)),
patch.object(CreditPoolService, "_get_locked_pool", side_effect=RuntimeError("database unavailable")),
pytest.raises(QuotaExceededError, match="Failed to deduct credits"),
):
CreditPoolService.deduct_credits_capped(tenant_id=tenant_id, credits_required=1)
assert _get_quota_used(engine=engine, pool_id=pool_id) == 2
def test_deduct_credits_capped_reraises_quota_exceeded_errors() -> None:
engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=2)
with (
patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)),
patch.object(CreditPoolService, "_get_locked_pool", side_effect=QuotaExceededError("quota unavailable")),
pytest.raises(QuotaExceededError, match="quota unavailable"),
):
CreditPoolService.deduct_credits_capped(tenant_id=tenant_id, credits_required=1)
assert _get_quota_used(engine=engine, pool_id=pool_id) == 2
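
Taken together, these tests pin down two deduction contracts: check_and_deduct_credits is all-or-nothing (it raises QuotaExceededError on any shortfall and leaves quota_used untouched), while deduct_credits_capped is best-effort (it never raises for a shortfall and returns however much it actually took, possibly zero). A usage sketch under that reading; the tenant_id value is illustrative:

from core.errors.error import QuotaExceededError
from services.credit_pool_service import CreditPoolService

tenant_id = "tenant-id"  # illustrative placeholder

# Strict pre-flight accounting: raises instead of partially deducting.
try:
    CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=3)
except QuotaExceededError:
    ...  # reject the request; the pool balance is unchanged

# Post-hoc accounting: caps at the remaining balance and reports the result.
deducted = CreditPoolService.deduct_credits_capped(tenant_id=tenant_id, credits_required=3)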


@@ -2845,7 +2845,7 @@ class TestWorkflowServiceFreeNodeExecution:
mock_node_cls.validate_node_data.assert_called_once_with(sentinel.adapted_node_data)
mock_node_cls.assert_called_once_with(
node_id="n-1",
-config=sentinel.node_data,
+data=sentinel.node_data,
graph_init_params=mock_graph_init_context_cls.return_value.to_graph_init_params.return_value,
graph_runtime_state=ANY,
runtime=mock_runtime_cls.return_value,

api/uv.lock (generated)

@@ -1597,7 +1597,7 @@ requires-dist = [
{ name = "gmpy2", specifier = ">=2.3.0" },
{ name = "google-api-python-client", specifier = ">=2.195.0" },
{ name = "google-cloud-aiplatform", specifier = ">=1.149.0,<2.0.0" },
-{ name = "graphon", specifier = "~=0.2.2" },
+{ name = "graphon", specifier = "~=0.3.0" },
{ name = "gunicorn", specifier = ">=25.3.0" },
{ name = "httpx", extras = ["socks"], specifier = ">=0.28.1,<1.0.0" },
{ name = "httpx-sse", specifier = "~=0.4.0" },
@@ -2940,7 +2940,7 @@ httpx = [
[[package]]
name = "graphon"
-version = "0.2.2"
+version = "0.3.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "charset-normalizer" },
@@ -2961,9 +2961,9 @@ dependencies = [
{ name = "unstructured", extra = ["docx", "epub", "md", "ppt", "pptx"] },
{ name = "webvtt-py" },
]
-sdist = { url = "https://files.pythonhosted.org/packages/08/50/e745a79c5f742f88f6011a1f7c9ba2c2f9cc1beedd982f0b192f1ab8c748/graphon-0.2.2.tar.gz", hash = "sha256:141f0de536171850f1af6f738dc66f0285aadd3c097f1dad2a038636789e0aa5", size = 236360, upload-time = "2026-04-17T08:52:28.047Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/bf/62/83593d6e7a139ff124711ea05882cadca7065c11a38763aa9360d7e76804/graphon-0.3.0.tar.gz", hash = "sha256:cd38f842ae3dcfa956428b952efbe2a3ea9c1581446647142accbbdeb638b876", size = 241176, upload-time = "2026-04-21T15:18:48.291Z" }
wheels = [
-{ url = "https://files.pythonhosted.org/packages/de/89/a6340afdaf5169d17a318e00fc685fb67ed99baa602c2cbbbf6af6a76096/graphon-0.2.2-py3-none-any.whl", hash = "sha256:754e544d08779138f99eac6547ab08559463680e2c76488b05e1c978210392b4", size = 340808, upload-time = "2026-04-17T08:52:26.5Z" },
+{ url = "https://files.pythonhosted.org/packages/b3/f7/81ee8f0368aa6a2d47f97fecc5d4a12865c987906798cbddd0e3b8387f33/graphon-0.3.0-py3-none-any.whl", hash = "sha256:9cca45ebab2a79fd4d04432f55b5b962e9e4f34fa037cc20fee7f18ec80eaa5d", size = 348486, upload-time = "2026-04-21T15:18:46.737Z" },
]
[[package]]