mirror of
https://github.com/langgenius/dify.git
synced 2026-05-10 05:56:31 +08:00
chore(api): upgrade graphon to v0.3.0 (#35469)
Signed-off-by: -LAN- <laipz8200@outlook.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: WH-2099 <wh2099@pm.me>
This commit is contained in:
parent
f3eb3ab4dd
commit
19476109da
@ -1,5 +1,15 @@
|
||||
"""LLM-related application services."""
|
||||
|
||||
from .quota import deduct_llm_quota, ensure_llm_quota_available
|
||||
from .quota import (
|
||||
deduct_llm_quota,
|
||||
deduct_llm_quota_for_model,
|
||||
ensure_llm_quota_available,
|
||||
ensure_llm_quota_available_for_model,
|
||||
)
|
||||
|
||||
__all__ = ["deduct_llm_quota", "ensure_llm_quota_available"]
|
||||
__all__ = [
|
||||
"deduct_llm_quota",
|
||||
"deduct_llm_quota_for_model",
|
||||
"ensure_llm_quota_available",
|
||||
"ensure_llm_quota_available_for_model",
|
||||
]
|
||||
|
||||
@ -1,4 +1,14 @@
|
||||
from sqlalchemy import update
|
||||
"""Tenant-scoped helpers for checking and deducting LLM provider quota.
|
||||
|
||||
System-hosted quota accounting is currently defined only for LLM models. Keep
|
||||
the public helpers LLM-specific so callers do not carry unused model-type
|
||||
plumbing, and fail loudly if the deprecated ``ModelInstance`` wrappers are used
|
||||
with a non-LLM model.
|
||||
"""
|
||||
|
||||
import warnings
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from configs import dify_config
|
||||
@ -6,44 +16,47 @@ from core.entities.model_entities import ModelStatus
|
||||
from core.entities.provider_entities import ProviderQuotaType, QuotaUnit
|
||||
from core.errors.error import QuotaExceededError
|
||||
from core.model_manager import ModelInstance
|
||||
from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
|
||||
from extensions.ext_database import db
|
||||
from graphon.model_runtime.entities.llm_entities import LLMUsage
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.provider import Provider, ProviderType
|
||||
from models.provider_ids import ModelProviderID
|
||||
|
||||
|
||||
def ensure_llm_quota_available(*, model_instance: ModelInstance) -> None:
|
||||
provider_model_bundle = model_instance.provider_model_bundle
|
||||
provider_configuration = provider_model_bundle.configuration
|
||||
def _get_provider_configuration(*, tenant_id: str, provider: str):
|
||||
"""Resolve the tenant-bound provider configuration for quota decisions."""
|
||||
provider_manager = create_plugin_provider_manager(tenant_id=tenant_id)
|
||||
provider_configuration = provider_manager.get_configurations(tenant_id).get(provider)
|
||||
if provider_configuration is None:
|
||||
raise ValueError(f"Provider {provider} does not exist.")
|
||||
return provider_configuration
|
||||
|
||||
|
||||
def ensure_llm_quota_available_for_model(*, tenant_id: str, provider: str, model: str) -> None:
|
||||
"""Raise when a tenant-bound LLM model is already out of quota."""
|
||||
provider_configuration = _get_provider_configuration(tenant_id=tenant_id, provider=provider)
|
||||
if provider_configuration.using_provider_type != ProviderType.SYSTEM:
|
||||
return
|
||||
|
||||
provider_model = provider_configuration.get_provider_model(
|
||||
model_type=model_instance.model_type_instance.model_type,
|
||||
model=model_instance.model_name,
|
||||
model_type=ModelType.LLM,
|
||||
model=model,
|
||||
)
|
||||
if provider_model and provider_model.status == ModelStatus.QUOTA_EXCEEDED:
|
||||
raise QuotaExceededError(f"Model provider {model_instance.provider} quota exceeded.")
|
||||
raise QuotaExceededError(f"Model provider {provider} quota exceeded.")
|
||||
|
||||
|
||||
def deduct_llm_quota(*, tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None:
|
||||
provider_model_bundle = model_instance.provider_model_bundle
|
||||
provider_configuration = provider_model_bundle.configuration
|
||||
|
||||
if provider_configuration.using_provider_type != ProviderType.SYSTEM:
|
||||
return
|
||||
|
||||
system_configuration = provider_configuration.system_configuration
|
||||
|
||||
def _resolve_llm_used_quota(*, system_configuration, model: str, usage: LLMUsage) -> int | None:
|
||||
"""Compute the quota impact for an LLM invocation under the current quota mode."""
|
||||
quota_unit = None
|
||||
for quota_configuration in system_configuration.quota_configurations:
|
||||
if quota_configuration.quota_type == system_configuration.current_quota_type:
|
||||
quota_unit = quota_configuration.quota_unit
|
||||
|
||||
if quota_configuration.quota_limit == -1:
|
||||
return
|
||||
return None
|
||||
|
||||
break
|
||||
|
||||
@ -52,42 +65,136 @@ def deduct_llm_quota(*, tenant_id: str, model_instance: ModelInstance, usage: LL
|
||||
if quota_unit == QuotaUnit.TOKENS:
|
||||
used_quota = usage.total_tokens
|
||||
elif quota_unit == QuotaUnit.CREDITS:
|
||||
used_quota = dify_config.get_model_credits(model_instance.model_name)
|
||||
used_quota = dify_config.get_model_credits(model)
|
||||
else:
|
||||
used_quota = 1
|
||||
|
||||
return used_quota
|
||||
|
||||
|
||||
def _deduct_free_llm_quota(
|
||||
*,
|
||||
tenant_id: str,
|
||||
provider: str,
|
||||
quota_type: ProviderQuotaType,
|
||||
used_quota: int,
|
||||
) -> None:
|
||||
"""Deduct FREE provider quota, capping at the limit before reporting exhaustion."""
|
||||
quota_exceeded = False
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
provider_record = session.scalar(
|
||||
select(Provider)
|
||||
.where(
|
||||
Provider.tenant_id == tenant_id,
|
||||
# TODO: Use provider name with prefix after the data migration.
|
||||
Provider.provider_name == ModelProviderID(provider).provider_name,
|
||||
Provider.provider_type == ProviderType.SYSTEM.value,
|
||||
Provider.quota_type == quota_type,
|
||||
)
|
||||
.with_for_update()
|
||||
)
|
||||
if (
|
||||
provider_record is None
|
||||
or provider_record.quota_limit is None
|
||||
or provider_record.quota_used is None
|
||||
or provider_record.quota_limit <= provider_record.quota_used
|
||||
):
|
||||
quota_exceeded = True
|
||||
else:
|
||||
available_quota = provider_record.quota_limit - provider_record.quota_used
|
||||
deducted_quota = min(used_quota, available_quota)
|
||||
provider_record.quota_used += deducted_quota
|
||||
provider_record.last_used = naive_utc_now()
|
||||
quota_exceeded = deducted_quota < used_quota
|
||||
|
||||
if quota_exceeded:
|
||||
raise QuotaExceededError(f"Model provider {provider} quota exceeded.")
|
||||
|
||||
|
||||
def _deduct_used_llm_quota(*, tenant_id: str, provider: str, provider_configuration, used_quota: int | None) -> None:
|
||||
"""Apply a resolved LLM quota charge against the current provider quota bucket."""
|
||||
if provider_configuration.using_provider_type != ProviderType.SYSTEM:
|
||||
return
|
||||
|
||||
system_configuration = provider_configuration.system_configuration
|
||||
if used_quota is not None and system_configuration.current_quota_type is not None:
|
||||
match system_configuration.current_quota_type:
|
||||
case ProviderQuotaType.TRIAL:
|
||||
from services.credit_pool_service import CreditPoolService
|
||||
|
||||
CreditPoolService.check_and_deduct_credits(
|
||||
CreditPoolService.deduct_credits_capped(
|
||||
tenant_id=tenant_id,
|
||||
credits_required=used_quota,
|
||||
)
|
||||
case ProviderQuotaType.PAID:
|
||||
from services.credit_pool_service import CreditPoolService
|
||||
|
||||
CreditPoolService.check_and_deduct_credits(
|
||||
CreditPoolService.deduct_credits_capped(
|
||||
tenant_id=tenant_id,
|
||||
credits_required=used_quota,
|
||||
pool_type="paid",
|
||||
)
|
||||
case ProviderQuotaType.FREE:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
stmt = (
|
||||
update(Provider)
|
||||
.where(
|
||||
Provider.tenant_id == tenant_id,
|
||||
# TODO: Use provider name with prefix after the data migration.
|
||||
Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
|
||||
Provider.provider_type == ProviderType.SYSTEM.value,
|
||||
Provider.quota_type == system_configuration.current_quota_type,
|
||||
Provider.quota_limit > Provider.quota_used,
|
||||
)
|
||||
.values(
|
||||
quota_used=Provider.quota_used + used_quota,
|
||||
last_used=naive_utc_now(),
|
||||
)
|
||||
)
|
||||
session.execute(stmt)
|
||||
_deduct_free_llm_quota(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
quota_type=system_configuration.current_quota_type,
|
||||
used_quota=used_quota,
|
||||
)
|
||||
case _:
|
||||
return
|
||||
|
||||
|
||||
def deduct_llm_quota_for_model(*, tenant_id: str, provider: str, model: str, usage: LLMUsage) -> None:
|
||||
"""Deduct tenant-bound quota for the resolved LLM model identity."""
|
||||
provider_configuration = _get_provider_configuration(tenant_id=tenant_id, provider=provider)
|
||||
used_quota = _resolve_llm_used_quota(
|
||||
system_configuration=provider_configuration.system_configuration,
|
||||
model=model,
|
||||
usage=usage,
|
||||
)
|
||||
_deduct_used_llm_quota(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
provider_configuration=provider_configuration,
|
||||
used_quota=used_quota,
|
||||
)
|
||||
|
||||
|
||||
def _require_llm_model_instance(model_instance: ModelInstance) -> None:
|
||||
"""Reject deprecated wrapper calls that pass a non-LLM model instance."""
|
||||
if model_instance.model_type_instance.model_type != ModelType.LLM:
|
||||
raise ValueError("LLM quota helpers only support LLM model instances.")
|
||||
|
||||
|
||||
def ensure_llm_quota_available(*, model_instance: ModelInstance) -> None:
|
||||
"""Deprecated compatibility wrapper for callers that still pass ModelInstance."""
|
||||
warnings.warn(
|
||||
"ensure_llm_quota_available(model_instance=...) is deprecated; "
|
||||
"use ensure_llm_quota_available_for_model(...) instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
_require_llm_model_instance(model_instance)
|
||||
ensure_llm_quota_available_for_model(
|
||||
tenant_id=model_instance.provider_model_bundle.configuration.tenant_id,
|
||||
provider=model_instance.provider,
|
||||
model=model_instance.model_name,
|
||||
)
|
||||
|
||||
|
||||
def deduct_llm_quota(*, tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None:
|
||||
"""Deprecated compatibility wrapper for callers that still pass ModelInstance."""
|
||||
warnings.warn(
|
||||
"deduct_llm_quota(tenant_id=..., model_instance=..., usage=...) is deprecated; "
|
||||
"use deduct_llm_quota_for_model(...) instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
_require_llm_model_instance(model_instance)
|
||||
deduct_llm_quota_for_model(
|
||||
tenant_id=tenant_id,
|
||||
provider=model_instance.provider,
|
||||
model=model_instance.model_name,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
@ -1,36 +1,48 @@
|
||||
"""
|
||||
LLM quota deduction layer for GraphEngine.
|
||||
|
||||
This layer centralizes model-quota deduction outside node implementations.
|
||||
This layer centralizes model-quota handling outside node implementations.
|
||||
|
||||
Graphon LLM-backed nodes expose provider/model identity through public node
|
||||
configuration and, after execution, through ``node_run_result.inputs``. Resolve
|
||||
quota billing from that public identity instead of depending on
|
||||
``ModelInstance`` reconstruction inside the workflow layer. Missing identity on
|
||||
quota-tracked nodes is treated as a workflow bug and aborts execution so quota
|
||||
handling is never silently skipped.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, cast, final, override
|
||||
from typing import final, override
|
||||
|
||||
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext
|
||||
from core.app.llm import deduct_llm_quota, ensure_llm_quota_available
|
||||
from core.app.llm import deduct_llm_quota_for_model, ensure_llm_quota_available_for_model
|
||||
from core.errors.error import QuotaExceededError
|
||||
from core.model_manager import ModelInstance
|
||||
from graphon.enums import BuiltinNodeTypes
|
||||
from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus
|
||||
from graphon.graph_engine.entities.commands import AbortCommand, CommandType
|
||||
from graphon.graph_engine.layers import GraphEngineLayer
|
||||
from graphon.graph_events import GraphEngineEvent, GraphNodeEventBase, NodeRunSucceededEvent
|
||||
from graphon.node_events import NodeRunResult
|
||||
from graphon.nodes.base.node import Node
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from graphon.nodes.llm.node import LLMNode
|
||||
from graphon.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
|
||||
from graphon.nodes.question_classifier.question_classifier_node import QuestionClassifierNode
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
_QUOTA_NODE_TYPES = frozenset(
|
||||
[
|
||||
BuiltinNodeTypes.LLM,
|
||||
BuiltinNodeTypes.PARAMETER_EXTRACTOR,
|
||||
BuiltinNodeTypes.QUESTION_CLASSIFIER,
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@final
|
||||
class LLMQuotaLayer(GraphEngineLayer):
|
||||
"""Graph layer that applies LLM quota deduction after node execution."""
|
||||
"""Graph layer that applies tenant-scoped quota checks to LLM-backed nodes."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
tenant_id: str
|
||||
_abort_sent: bool
|
||||
|
||||
def __init__(self, tenant_id: str) -> None:
|
||||
super().__init__()
|
||||
self.tenant_id = tenant_id
|
||||
self._abort_sent = False
|
||||
|
||||
@override
|
||||
@ -50,33 +62,49 @@ class LLMQuotaLayer(GraphEngineLayer):
|
||||
if self._abort_sent:
|
||||
return
|
||||
|
||||
model_instance = self._extract_model_instance(node)
|
||||
if model_instance is None:
|
||||
if not self._supports_quota(node):
|
||||
return
|
||||
|
||||
model_identity = self._extract_model_identity_from_node(node)
|
||||
if model_identity is None:
|
||||
reason = "LLM quota check requires public node model identity before execution."
|
||||
self._abort_before_node_run(node=node, reason=reason, error_type="LLMQuotaIdentityError")
|
||||
logger.error("LLM quota handling aborted, node_id=%s, reason=%s", node.id, reason)
|
||||
return
|
||||
|
||||
provider, model_name = model_identity
|
||||
try:
|
||||
ensure_llm_quota_available(model_instance=model_instance)
|
||||
ensure_llm_quota_available_for_model(
|
||||
tenant_id=self.tenant_id,
|
||||
provider=provider,
|
||||
model=model_name,
|
||||
)
|
||||
except QuotaExceededError as exc:
|
||||
self._set_stop_event(node)
|
||||
self._send_abort_command(reason=str(exc))
|
||||
self._abort_before_node_run(node=node, reason=str(exc), error_type=QuotaExceededError.__name__)
|
||||
logger.warning("LLM quota check failed, node_id=%s, error=%s", node.id, exc)
|
||||
|
||||
@override
|
||||
def on_node_run_end(
|
||||
self, node: Node, error: Exception | None, result_event: GraphNodeEventBase | None = None
|
||||
) -> None:
|
||||
if error is not None or not isinstance(result_event, NodeRunSucceededEvent):
|
||||
if error is not None or not isinstance(result_event, NodeRunSucceededEvent) or not self._supports_quota(node):
|
||||
return
|
||||
|
||||
model_instance = self._extract_model_instance(node)
|
||||
if model_instance is None:
|
||||
model_identity = self._extract_model_identity_from_result_event(result_event)
|
||||
if model_identity is None:
|
||||
self._abort_for_missing_model_identity(
|
||||
node=node,
|
||||
reason="LLM quota deduction requires model identity in the node result event.",
|
||||
)
|
||||
return
|
||||
|
||||
provider, model_name = model_identity
|
||||
|
||||
try:
|
||||
dify_ctx = DifyRunContext.model_validate(node.require_run_context_value(DIFY_RUN_CONTEXT_KEY))
|
||||
deduct_llm_quota(
|
||||
tenant_id=dify_ctx.tenant_id,
|
||||
model_instance=model_instance,
|
||||
deduct_llm_quota_for_model(
|
||||
tenant_id=self.tenant_id,
|
||||
provider=provider,
|
||||
model=model_name,
|
||||
usage=result_event.node_run_result.llm_usage,
|
||||
)
|
||||
except QuotaExceededError as exc:
|
||||
@ -92,6 +120,27 @@ class LLMQuotaLayer(GraphEngineLayer):
|
||||
if stop_event is not None:
|
||||
stop_event.set()
|
||||
|
||||
def _abort_before_node_run(self, *, node: Node, reason: str, error_type: str) -> None:
|
||||
self._set_stop_event(node)
|
||||
node.node_data.error_strategy = None
|
||||
node.node_data.retry_config.retry_enabled = False
|
||||
|
||||
def quota_aborted_run() -> NodeRunResult:
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=reason,
|
||||
error_type=error_type,
|
||||
)
|
||||
|
||||
# TODO: Push Graphon to expose a public pre-run failure/skip hook, then replace this private _run override.
|
||||
node._run = quota_aborted_run # type: ignore[method-assign]
|
||||
self._send_abort_command(reason=reason)
|
||||
|
||||
def _abort_for_missing_model_identity(self, *, node: Node, reason: str) -> None:
|
||||
self._set_stop_event(node)
|
||||
self._send_abort_command(reason=reason)
|
||||
logger.error("LLM quota handling aborted, node_id=%s, reason=%s", node.id, reason)
|
||||
|
||||
def _send_abort_command(self, *, reason: str) -> None:
|
||||
if not self.command_channel or self._abort_sent:
|
||||
return
|
||||
@ -108,29 +157,38 @@ class LLMQuotaLayer(GraphEngineLayer):
|
||||
logger.exception("Failed to send quota abort command")
|
||||
|
||||
@staticmethod
|
||||
def _extract_model_instance(node: Node) -> ModelInstance | None:
|
||||
try:
|
||||
match node.node_type:
|
||||
case BuiltinNodeTypes.LLM:
|
||||
model_instance = cast("LLMNode", node).model_instance
|
||||
case BuiltinNodeTypes.PARAMETER_EXTRACTOR:
|
||||
model_instance = cast("ParameterExtractorNode", node).model_instance
|
||||
case BuiltinNodeTypes.QUESTION_CLASSIFIER:
|
||||
model_instance = cast("QuestionClassifierNode", node).model_instance
|
||||
case _:
|
||||
return None
|
||||
except AttributeError:
|
||||
def _supports_quota(node: Node) -> bool:
|
||||
return node.node_type in _QUOTA_NODE_TYPES
|
||||
|
||||
@staticmethod
|
||||
def _extract_model_identity_from_result_event(result_event: NodeRunSucceededEvent) -> tuple[str, str] | None:
|
||||
provider = result_event.node_run_result.inputs.get("model_provider")
|
||||
model_name = result_event.node_run_result.inputs.get("model_name")
|
||||
if isinstance(provider, str) and provider and isinstance(model_name, str) and model_name:
|
||||
return provider, model_name
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _extract_model_identity_from_node(node: Node) -> tuple[str, str] | None:
|
||||
node_data = getattr(node, "node_data", None)
|
||||
if node_data is None:
|
||||
node_data = getattr(node, "data", None)
|
||||
|
||||
model_config = getattr(node_data, "model", None)
|
||||
if model_config is None:
|
||||
logger.warning(
|
||||
"LLMQuotaLayer skipped quota deduction because node does not expose a model instance, node_id=%s",
|
||||
"LLMQuotaLayer skipped quota handling because node model config is missing, node_id=%s",
|
||||
node.id,
|
||||
)
|
||||
return None
|
||||
|
||||
if isinstance(model_instance, ModelInstance):
|
||||
return model_instance
|
||||
|
||||
raw_model_instance = getattr(model_instance, "_model_instance", None)
|
||||
if isinstance(raw_model_instance, ModelInstance):
|
||||
return raw_model_instance
|
||||
provider = getattr(model_config, "provider", None)
|
||||
model_name = getattr(model_config, "name", None)
|
||||
if isinstance(provider, str) and provider and isinstance(model_name, str) and model_name:
|
||||
return provider, model_name
|
||||
|
||||
logger.warning(
|
||||
"LLMQuotaLayer skipped quota handling because node model identity is invalid, node_id=%s",
|
||||
node.id,
|
||||
)
|
||||
return None
|
||||
|
||||
@ -23,7 +23,7 @@ from core.entities.provider_entities import (
|
||||
)
|
||||
from core.helper import encrypter
|
||||
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
|
||||
from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory
|
||||
from core.plugin.impl.model_runtime_factory import create_model_type_instance, create_plugin_model_assembly
|
||||
from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
|
||||
from graphon.model_runtime.entities.provider_entities import (
|
||||
ConfigurateMethod,
|
||||
@ -33,7 +33,7 @@ from graphon.model_runtime.entities.provider_entities import (
|
||||
)
|
||||
from graphon.model_runtime.model_providers.base.ai_model import AIModel
|
||||
from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||
from graphon.model_runtime.runtime import ModelRuntime
|
||||
from graphon.model_runtime.protocols.runtime import ModelRuntime
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.engine import db
|
||||
from models.enums import CredentialSourceType
|
||||
@ -106,11 +106,18 @@ class ProviderConfiguration(BaseModel):
|
||||
"""Attach the already-composed runtime for request-bound call chains."""
|
||||
self._bound_model_runtime = model_runtime
|
||||
|
||||
def _get_runtime_and_provider_factory(self) -> tuple[ModelRuntime, ModelProviderFactory]:
|
||||
"""Resolve a provider factory that stays aligned with the runtime used by the caller."""
|
||||
if self._bound_model_runtime is not None:
|
||||
return self._bound_model_runtime, ModelProviderFactory(runtime=self._bound_model_runtime)
|
||||
|
||||
model_assembly = create_plugin_model_assembly(tenant_id=self.tenant_id)
|
||||
return model_assembly.model_runtime, model_assembly.model_provider_factory
|
||||
|
||||
def get_model_provider_factory(self) -> ModelProviderFactory:
|
||||
"""Return a provider factory that preserves any request-bound runtime."""
|
||||
if self._bound_model_runtime is not None:
|
||||
return ModelProviderFactory(model_runtime=self._bound_model_runtime)
|
||||
return create_plugin_model_provider_factory(tenant_id=self.tenant_id)
|
||||
_, model_provider_factory = self._get_runtime_and_provider_factory()
|
||||
return model_provider_factory
|
||||
|
||||
def get_current_credentials(self, model_type: ModelType, model: str) -> dict[str, Any] | None:
|
||||
"""
|
||||
@ -1392,10 +1399,13 @@ class ProviderConfiguration(BaseModel):
|
||||
:param model_type: model type
|
||||
:return:
|
||||
"""
|
||||
model_provider_factory = self.get_model_provider_factory()
|
||||
|
||||
# Get model instance of LLM
|
||||
return model_provider_factory.get_model_type_instance(provider=self.provider.provider, model_type=model_type)
|
||||
model_runtime, model_provider_factory = self._get_runtime_and_provider_factory()
|
||||
provider_schema = model_provider_factory.get_provider_schema(provider=self.provider.provider)
|
||||
return create_model_type_instance(
|
||||
runtime=model_runtime,
|
||||
provider_schema=provider_schema,
|
||||
model_type=model_type,
|
||||
)
|
||||
|
||||
def get_model_schema(
|
||||
self, model_type: ModelType, model: str, credentials: dict[str, Any] | None
|
||||
|
||||
@ -4,7 +4,7 @@ from typing import cast
|
||||
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.entities import DEFAULT_PLUGIN_ID
|
||||
from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory
|
||||
from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly
|
||||
from extensions.ext_hosting_provider import hosting_configuration
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from graphon.model_runtime.errors.invoke import InvokeBadRequestError
|
||||
@ -41,10 +41,8 @@ def check_moderation(tenant_id: str, model_config: ModelConfigWithCredentialsEnt
|
||||
text_chunk = secrets.choice(text_chunks)
|
||||
|
||||
try:
|
||||
model_provider_factory = create_plugin_model_provider_factory(tenant_id=tenant_id)
|
||||
|
||||
# Get model instance of LLM
|
||||
model_type_instance = model_provider_factory.get_model_type_instance(
|
||||
model_assembly = create_plugin_model_assembly(tenant_id=tenant_id)
|
||||
model_type_instance = model_assembly.create_model_type_instance(
|
||||
provider=openai_provider_name, model_type=ModelType.MODERATION
|
||||
)
|
||||
model_type_instance = cast(ModerationModel, model_type_instance)
|
||||
|
||||
@ -4,23 +4,32 @@ import hashlib
|
||||
import logging
|
||||
from collections.abc import Generator, Iterable, Sequence
|
||||
from threading import Lock
|
||||
from typing import IO, Any, Union
|
||||
from typing import IO, Any, Literal, cast, overload
|
||||
|
||||
from pydantic import ValidationError
|
||||
from redis import RedisError
|
||||
|
||||
from configs import dify_config
|
||||
from core.llm_generator.output_parser.structured_output import (
|
||||
invoke_llm_with_structured_output as invoke_llm_with_structured_output_helper,
|
||||
)
|
||||
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
|
||||
from core.plugin.impl.asset import PluginAssetManager
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
from extensions.ext_redis import redis_client
|
||||
from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
|
||||
from graphon.model_runtime.entities.llm_entities import (
|
||||
LLMResult,
|
||||
LLMResultChunk,
|
||||
LLMResultChunkWithStructuredOutput,
|
||||
LLMResultWithStructuredOutput,
|
||||
)
|
||||
from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
|
||||
from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelType
|
||||
from graphon.model_runtime.entities.provider_entities import ProviderEntity
|
||||
from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult
|
||||
from graphon.model_runtime.entities.text_embedding_entities import EmbeddingInputType, EmbeddingResult
|
||||
from graphon.model_runtime.runtime import ModelRuntime
|
||||
from graphon.model_runtime.model_providers.base.large_language_model import normalize_non_stream_runtime_result
|
||||
from graphon.model_runtime.protocols.runtime import ModelRuntime
|
||||
from models.provider_ids import ModelProviderID
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -29,6 +38,68 @@ logger = logging.getLogger(__name__)
|
||||
TENANT_SCOPE_SCHEMA_CACHE_USER_ID = "__DIFY_TS__"
|
||||
|
||||
|
||||
# TODO(-LAN-): Move native structured-output invocation into Graphon's LLM node.
|
||||
# TODO(-LAN-): Remove this Dify-side adapter once Graphon owns structured output end-to-end.
|
||||
class _PluginStructuredOutputModelInstance:
|
||||
"""Bind plugin model identity to the shared structured-output helper.
|
||||
|
||||
The structured-output parser is shared with legacy ``ModelInstance`` flows
|
||||
and only needs an object exposing ``invoke_llm(...)``. ``PluginModelRuntime``
|
||||
intentionally exposes a lower-level API where provider, model, and
|
||||
credentials are passed per call. This adapter supplies the small bound
|
||||
``invoke_llm`` surface the helper needs without constructing a full
|
||||
``ModelInstance`` or reintroducing model-manager dependencies into the
|
||||
plugin runtime path.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
runtime: PluginModelRuntime,
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict[str, Any],
|
||||
) -> None:
|
||||
self._runtime = runtime
|
||||
self._provider = provider
|
||||
self._model = model
|
||||
self._credentials = credentials
|
||||
|
||||
def invoke_llm(
|
||||
self,
|
||||
*,
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
model_parameters: dict[str, Any] | None = None,
|
||||
tools: Sequence[PromptMessageTool] | None = None,
|
||||
stop: Sequence[str] | None = None,
|
||||
stream: bool = True,
|
||||
callbacks: object | None = None,
|
||||
) -> LLMResult | Generator[LLMResultChunk, None, None]:
|
||||
del callbacks
|
||||
if stream:
|
||||
return self._runtime.invoke_llm(
|
||||
provider=self._provider,
|
||||
model=self._model,
|
||||
credentials=self._credentials,
|
||||
model_parameters=model_parameters or {},
|
||||
prompt_messages=prompt_messages,
|
||||
tools=list(tools) if tools else None,
|
||||
stop=stop,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
return self._runtime.invoke_llm(
|
||||
provider=self._provider,
|
||||
model=self._model,
|
||||
credentials=self._credentials,
|
||||
model_parameters=model_parameters or {},
|
||||
prompt_messages=prompt_messages,
|
||||
tools=list(tools) if tools else None,
|
||||
stop=stop,
|
||||
stream=False,
|
||||
)
|
||||
|
||||
|
||||
class PluginModelRuntime(ModelRuntime):
|
||||
"""Plugin-backed runtime adapter bound to tenant context and optional caller scope."""
|
||||
|
||||
@ -195,6 +266,34 @@ class PluginModelRuntime(ModelRuntime):
|
||||
|
||||
return schema
|
||||
|
||||
@overload
|
||||
def invoke_llm(
|
||||
self,
|
||||
*,
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict[str, Any],
|
||||
model_parameters: dict[str, Any],
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
tools: list[PromptMessageTool] | None,
|
||||
stop: Sequence[str] | None,
|
||||
stream: Literal[False],
|
||||
) -> LLMResult: ...
|
||||
|
||||
@overload
|
||||
def invoke_llm(
|
||||
self,
|
||||
*,
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict[str, Any],
|
||||
model_parameters: dict[str, Any],
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
tools: list[PromptMessageTool] | None,
|
||||
stop: Sequence[str] | None,
|
||||
stream: Literal[True],
|
||||
) -> Generator[LLMResultChunk, None, None]: ...
|
||||
|
||||
def invoke_llm(
|
||||
self,
|
||||
*,
|
||||
@ -206,9 +305,9 @@ class PluginModelRuntime(ModelRuntime):
|
||||
tools: list[PromptMessageTool] | None,
|
||||
stop: Sequence[str] | None,
|
||||
stream: bool,
|
||||
) -> Union[LLMResult, Generator[LLMResultChunk, None, None]]:
|
||||
) -> LLMResult | Generator[LLMResultChunk, None, None]:
|
||||
plugin_id, provider_name = self._split_provider(provider)
|
||||
return self.client.invoke_llm(
|
||||
result = self.client.invoke_llm(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=self.user_id,
|
||||
plugin_id=plugin_id,
|
||||
@ -221,6 +320,81 @@ class PluginModelRuntime(ModelRuntime):
|
||||
stop=list(stop) if stop else None,
|
||||
stream=stream,
|
||||
)
|
||||
if stream:
|
||||
return result
|
||||
|
||||
return normalize_non_stream_runtime_result(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
result=result,
|
||||
)
|
||||
|
||||
@overload
|
||||
def invoke_llm_with_structured_output(
|
||||
self,
|
||||
*,
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict[str, Any],
|
||||
json_schema: dict[str, Any],
|
||||
model_parameters: dict[str, Any],
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
stop: Sequence[str] | None,
|
||||
stream: Literal[False],
|
||||
) -> LLMResultWithStructuredOutput: ...
|
||||
|
||||
@overload
|
||||
def invoke_llm_with_structured_output(
|
||||
self,
|
||||
*,
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict[str, Any],
|
||||
json_schema: dict[str, Any],
|
||||
model_parameters: dict[str, Any],
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
stop: Sequence[str] | None,
|
||||
stream: Literal[True],
|
||||
) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ...
|
||||
|
||||
def invoke_llm_with_structured_output(
|
||||
self,
|
||||
*,
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict[str, Any],
|
||||
json_schema: dict[str, Any],
|
||||
model_parameters: dict[str, Any],
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
stop: Sequence[str] | None,
|
||||
stream: bool,
|
||||
) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]:
|
||||
model_schema = self.get_model_schema(
|
||||
provider=provider,
|
||||
model_type=ModelType.LLM,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
)
|
||||
if model_schema is None:
|
||||
raise ValueError(f"Model schema not found for {model}")
|
||||
|
||||
adapter = _PluginStructuredOutputModelInstance(
|
||||
runtime=self,
|
||||
provider=provider,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
)
|
||||
return invoke_llm_with_structured_output_helper(
|
||||
provider=provider,
|
||||
model_schema=model_schema,
|
||||
model_instance=cast(Any, adapter),
|
||||
prompt_messages=prompt_messages,
|
||||
json_schema=json_schema,
|
||||
model_parameters=model_parameters,
|
||||
tools=None,
|
||||
stop=list(stop) if stop else None,
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
def get_llm_num_tokens(
|
||||
self,
|
||||
|
||||
@ -3,13 +3,46 @@ from __future__ import annotations
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from graphon.model_runtime.entities.provider_entities import ProviderEntity
|
||||
from graphon.model_runtime.model_providers.base.ai_model import AIModel
|
||||
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
|
||||
from graphon.model_runtime.model_providers.base.moderation_model import ModerationModel
|
||||
from graphon.model_runtime.model_providers.base.rerank_model import RerankModel
|
||||
from graphon.model_runtime.model_providers.base.speech2text_model import Speech2TextModel
|
||||
from graphon.model_runtime.model_providers.base.text_embedding_model import TextEmbeddingModel
|
||||
from graphon.model_runtime.model_providers.base.tts_model import TTSModel
|
||||
from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||
from graphon.model_runtime.protocols.runtime import ModelRuntime
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.model_manager import ModelManager
|
||||
from core.plugin.impl.model_runtime import PluginModelRuntime
|
||||
from core.provider_manager import ProviderManager
|
||||
|
||||
_MODEL_CLASS_BY_TYPE: dict[ModelType, type[AIModel]] = {
|
||||
ModelType.LLM: LargeLanguageModel,
|
||||
ModelType.TEXT_EMBEDDING: TextEmbeddingModel,
|
||||
ModelType.RERANK: RerankModel,
|
||||
ModelType.SPEECH2TEXT: Speech2TextModel,
|
||||
ModelType.MODERATION: ModerationModel,
|
||||
ModelType.TTS: TTSModel,
|
||||
}
|
||||
|
||||
|
||||
def create_model_type_instance(
|
||||
*,
|
||||
runtime: ModelRuntime,
|
||||
provider_schema: ProviderEntity,
|
||||
model_type: ModelType,
|
||||
) -> AIModel:
|
||||
"""Build the graphon model wrapper explicitly against the request runtime."""
|
||||
model_class = _MODEL_CLASS_BY_TYPE.get(model_type)
|
||||
if model_class is None:
|
||||
raise ValueError(f"Unsupported model type: {model_type}")
|
||||
|
||||
return model_class(provider_schema=provider_schema, model_runtime=runtime)
|
||||
|
||||
|
||||
class PluginModelAssembly:
|
||||
"""Compose request-scoped model views on top of a single plugin runtime."""
|
||||
@ -38,9 +71,22 @@ class PluginModelAssembly:
|
||||
@property
|
||||
def model_provider_factory(self) -> ModelProviderFactory:
|
||||
if self._model_provider_factory is None:
|
||||
self._model_provider_factory = ModelProviderFactory(model_runtime=self.model_runtime)
|
||||
self._model_provider_factory = ModelProviderFactory(runtime=self.model_runtime)
|
||||
return self._model_provider_factory
|
||||
|
||||
def create_model_type_instance(
|
||||
self,
|
||||
*,
|
||||
provider: str,
|
||||
model_type: ModelType,
|
||||
) -> AIModel:
|
||||
provider_schema = self.model_provider_factory.get_provider_schema(provider=provider)
|
||||
return create_model_type_instance(
|
||||
runtime=self.model_runtime,
|
||||
provider_schema=provider_schema,
|
||||
model_type=model_type,
|
||||
)
|
||||
|
||||
@property
|
||||
def provider_manager(self) -> ProviderManager:
|
||||
if self._provider_manager is None:
|
||||
|
||||
@ -56,7 +56,7 @@ from models.provider_ids import ModelProviderID
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from graphon.model_runtime.runtime import ModelRuntime
|
||||
from graphon.model_runtime.protocols.runtime import ModelRuntime
|
||||
|
||||
_credentials_adapter: TypeAdapter[dict[str, Any]] = TypeAdapter(dict[str, Any])
|
||||
|
||||
@ -165,7 +165,7 @@ class ProviderManager:
|
||||
)
|
||||
|
||||
# Get all provider entities
|
||||
model_provider_factory = ModelProviderFactory(model_runtime=self._model_runtime)
|
||||
model_provider_factory = ModelProviderFactory(runtime=self._model_runtime)
|
||||
provider_entities = model_provider_factory.get_providers()
|
||||
|
||||
# Get All preferred provider types of the workspace
|
||||
@ -362,7 +362,7 @@ class ProviderManager:
|
||||
if not default_model:
|
||||
return None
|
||||
|
||||
model_provider_factory = ModelProviderFactory(model_runtime=self._model_runtime)
|
||||
model_provider_factory = ModelProviderFactory(runtime=self._model_runtime)
|
||||
provider_schema = model_provider_factory.get_provider_schema(provider=default_model.provider_name)
|
||||
|
||||
return DefaultModelEntity(
|
||||
|
||||
@ -374,11 +374,6 @@ class DifyNodeFactory(NodeFactory):
|
||||
# Re-validate using the resolved node class so workflow-local node schemas
|
||||
# stay explicit and constructors receive the concrete typed payload.
|
||||
resolved_node_data = self._validate_resolved_node_data(node_class, node_data)
|
||||
config_for_node_init: BaseNodeData | dict[str, Any]
|
||||
if isinstance(resolved_node_data, BaseNodeData):
|
||||
config_for_node_init = resolved_node_data.model_dump(mode="python", by_alias=True)
|
||||
else:
|
||||
config_for_node_init = resolved_node_data
|
||||
node_type = node_data.type
|
||||
node_init_kwargs_factories: Mapping[NodeType, Callable[[], dict[str, object]]] = {
|
||||
BuiltinNodeTypes.CODE: lambda: {
|
||||
@ -446,9 +441,10 @@ class DifyNodeFactory(NodeFactory):
|
||||
},
|
||||
}
|
||||
node_init_kwargs = node_init_kwargs_factories.get(node_type, lambda: {})()
|
||||
constructor_node_data = resolved_node_data.model_dump(mode="python", by_alias=True)
|
||||
return node_class(
|
||||
node_id=node_id,
|
||||
config=config_for_node_init,
|
||||
data=constructor_node_data,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
**node_init_kwargs,
|
||||
|
||||
@ -35,7 +35,7 @@ class AgentNode(Node[AgentNodeData]):
|
||||
def __init__(
|
||||
self,
|
||||
node_id: str,
|
||||
config: AgentNodeData,
|
||||
data: AgentNodeData,
|
||||
*,
|
||||
graph_init_params: GraphInitParams,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
@ -46,7 +46,7 @@ class AgentNode(Node[AgentNodeData]):
|
||||
) -> None:
|
||||
super().__init__(
|
||||
node_id=node_id,
|
||||
config=config,
|
||||
data=data,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
@ -36,14 +36,14 @@ class DatasourceNode(Node[DatasourceNodeData]):
|
||||
def __init__(
|
||||
self,
|
||||
node_id: str,
|
||||
config: DatasourceNodeData,
|
||||
data: DatasourceNodeData,
|
||||
*,
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
) -> None:
|
||||
super().__init__(
|
||||
node_id=node_id,
|
||||
config=config,
|
||||
data=data,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
@ -32,14 +32,14 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
|
||||
def __init__(
|
||||
self,
|
||||
node_id: str,
|
||||
config: KnowledgeIndexNodeData,
|
||||
data: KnowledgeIndexNodeData,
|
||||
*,
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
) -> None:
|
||||
super().__init__(
|
||||
node_id=node_id,
|
||||
config=config,
|
||||
data=data,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
@ -71,14 +71,14 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
def __init__(
|
||||
self,
|
||||
node_id: str,
|
||||
config: KnowledgeRetrievalNodeData,
|
||||
data: KnowledgeRetrievalNodeData,
|
||||
*,
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
) -> None:
|
||||
super().__init__(
|
||||
node_id=node_id,
|
||||
config=config,
|
||||
data=data,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
from collections import defaultdict
|
||||
from collections.abc import Mapping, Sequence
|
||||
from enum import StrEnum
|
||||
from typing import Any, Protocol, cast
|
||||
from typing import Any, Protocol
|
||||
from uuid import uuid4
|
||||
|
||||
from graphon.enums import BuiltinNodeTypes
|
||||
@ -82,13 +82,10 @@ def build_system_variables(values: Mapping[str, Any] | None = None, /, **kwargs:
|
||||
normalized = _normalize_system_variable_values(values, **kwargs)
|
||||
|
||||
return [
|
||||
cast(
|
||||
Variable,
|
||||
segment_to_variable(
|
||||
segment=build_segment(value),
|
||||
selector=system_variable_selector(key),
|
||||
name=key,
|
||||
),
|
||||
segment_to_variable(
|
||||
segment=build_segment(value),
|
||||
selector=system_variable_selector(key),
|
||||
name=key,
|
||||
)
|
||||
for key, value in normalized.items()
|
||||
]
|
||||
@ -130,13 +127,10 @@ def build_bootstrap_variables(
|
||||
|
||||
for node_id, value in rag_pipeline_variables_map.items():
|
||||
variables.append(
|
||||
cast(
|
||||
Variable,
|
||||
segment_to_variable(
|
||||
segment=build_segment(value),
|
||||
selector=(RAG_PIPELINE_VARIABLE_NODE_ID, node_id),
|
||||
name=node_id,
|
||||
),
|
||||
segment_to_variable(
|
||||
segment=build_segment(value),
|
||||
selector=(RAG_PIPELINE_VARIABLE_NODE_ID, node_id),
|
||||
name=node_id,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@ -46,6 +46,11 @@ _file_access_controller = DatabaseFileAccessController()
|
||||
|
||||
|
||||
class _WorkflowChildEngineBuilder:
|
||||
tenant_id: str
|
||||
|
||||
def __init__(self, *, tenant_id: str) -> None:
|
||||
self.tenant_id = tenant_id
|
||||
|
||||
@staticmethod
|
||||
def _has_node_id(graph_config: Mapping[str, Any], node_id: str) -> bool | None:
|
||||
"""
|
||||
@ -107,7 +112,7 @@ class _WorkflowChildEngineBuilder:
|
||||
config=config,
|
||||
child_engine_builder=self,
|
||||
)
|
||||
child_engine.layer(LLMQuotaLayer())
|
||||
child_engine.layer(LLMQuotaLayer(tenant_id=self.tenant_id))
|
||||
return child_engine
|
||||
|
||||
|
||||
@ -176,7 +181,7 @@ class WorkflowEntry:
|
||||
self.command_channel = command_channel
|
||||
execution_context = capture_current_context()
|
||||
graph_runtime_state.execution_context = execution_context
|
||||
self._child_engine_builder = _WorkflowChildEngineBuilder()
|
||||
self._child_engine_builder = _WorkflowChildEngineBuilder(tenant_id=tenant_id)
|
||||
self.graph_engine = GraphEngine(
|
||||
workflow_id=workflow_id,
|
||||
graph=graph,
|
||||
@ -208,7 +213,7 @@ class WorkflowEntry:
|
||||
max_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS, max_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME
|
||||
)
|
||||
self.graph_engine.layer(limits_layer)
|
||||
self.graph_engine.layer(LLMQuotaLayer())
|
||||
self.graph_engine.layer(LLMQuotaLayer(tenant_id=tenant_id))
|
||||
|
||||
# Add observability layer when OTel is enabled
|
||||
if dify_config.ENABLE_OTEL or is_instrument_flag_enabled():
|
||||
|
||||
@ -137,17 +137,13 @@ def handle(sender: Message, **kwargs):
|
||||
if used_quota is not None:
|
||||
match provider_configuration.system_configuration.current_quota_type:
|
||||
case ProviderQuotaType.TRIAL:
|
||||
from services.credit_pool_service import CreditPoolService
|
||||
|
||||
CreditPoolService.check_and_deduct_credits(
|
||||
_deduct_credit_pool_quota_capped(
|
||||
tenant_id=tenant_id,
|
||||
credits_required=used_quota,
|
||||
pool_type="trial",
|
||||
)
|
||||
case ProviderQuotaType.PAID:
|
||||
from services.credit_pool_service import CreditPoolService
|
||||
|
||||
CreditPoolService.check_and_deduct_credits(
|
||||
_deduct_credit_pool_quota_capped(
|
||||
tenant_id=tenant_id,
|
||||
credits_required=used_quota,
|
||||
pool_type="paid",
|
||||
@ -200,6 +196,26 @@ def handle(sender: Message, **kwargs):
|
||||
raise
|
||||
|
||||
|
||||
def _deduct_credit_pool_quota_capped(*, tenant_id: str, credits_required: int, pool_type: str) -> None:
|
||||
"""Apply post-generation credit accounting without failing message persistence on quota exhaustion."""
|
||||
from services.credit_pool_service import CreditPoolService
|
||||
|
||||
deducted_credits = CreditPoolService.deduct_credits_capped(
|
||||
tenant_id=tenant_id,
|
||||
credits_required=credits_required,
|
||||
pool_type=pool_type,
|
||||
)
|
||||
if deducted_credits < credits_required:
|
||||
logger.warning(
|
||||
"Credit pool exhausted during message-created accounting, "
|
||||
"tenant_id=%s, pool_type=%s, credits_required=%s, credits_deducted=%s",
|
||||
tenant_id,
|
||||
pool_type,
|
||||
credits_required,
|
||||
deducted_credits,
|
||||
)
|
||||
|
||||
|
||||
def _calculate_quota_usage(
|
||||
*, message: Message, system_configuration: SystemConfiguration, model_name: str
|
||||
) -> int | None:
|
||||
|
||||
@ -45,7 +45,7 @@ dependencies = [
|
||||
|
||||
# Emerging: newer and fast-moving, use compatible pins
|
||||
"fastopenapi[flask]~=0.7.0",
|
||||
"graphon~=0.2.2",
|
||||
"graphon~=0.3.0",
|
||||
"httpx-sse~=0.4.0",
|
||||
"json-repair~=0.59.4",
|
||||
]
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import logging
|
||||
|
||||
from sqlalchemy import select, update
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from configs import dify_config
|
||||
from core.errors.error import QuotaExceededError
|
||||
@ -13,6 +13,18 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CreditPoolService:
|
||||
@staticmethod
|
||||
def _get_locked_pool(session: Session, tenant_id: str, pool_type: str) -> TenantCreditPool | None:
|
||||
return session.scalar(
|
||||
select(TenantCreditPool)
|
||||
.where(
|
||||
TenantCreditPool.tenant_id == tenant_id,
|
||||
TenantCreditPool.pool_type == pool_type,
|
||||
)
|
||||
.limit(1)
|
||||
.with_for_update()
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def create_default_pool(cls, tenant_id: str) -> TenantCreditPool:
|
||||
"""create default credit pool for new tenant"""
|
||||
@ -59,31 +71,57 @@ class CreditPoolService:
|
||||
credits_required: int,
|
||||
pool_type: str = "trial",
|
||||
) -> int:
|
||||
"""check and deduct credits, returns actual credits deducted"""
|
||||
|
||||
pool = cls.get_pool(tenant_id, pool_type)
|
||||
if not pool:
|
||||
raise QuotaExceededError("Credit pool not found")
|
||||
|
||||
if pool.remaining_credits <= 0:
|
||||
raise QuotaExceededError("No credits remaining")
|
||||
|
||||
# deduct all remaining credits if less than required
|
||||
actual_credits = min(credits_required, pool.remaining_credits)
|
||||
"""Deduct exactly the requested credits or raise without mutating the pool."""
|
||||
if credits_required <= 0:
|
||||
return 0
|
||||
|
||||
try:
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
stmt = (
|
||||
update(TenantCreditPool)
|
||||
.where(
|
||||
TenantCreditPool.tenant_id == tenant_id,
|
||||
TenantCreditPool.pool_type == pool_type,
|
||||
)
|
||||
.values(quota_used=TenantCreditPool.quota_used + actual_credits)
|
||||
)
|
||||
session.execute(stmt)
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
pool = cls._get_locked_pool(session=session, tenant_id=tenant_id, pool_type=pool_type)
|
||||
if not pool:
|
||||
raise QuotaExceededError("Credit pool not found")
|
||||
|
||||
remaining_credits = pool.remaining_credits
|
||||
if remaining_credits <= 0:
|
||||
raise QuotaExceededError("No credits remaining")
|
||||
if remaining_credits < credits_required:
|
||||
raise QuotaExceededError("Insufficient credits remaining")
|
||||
|
||||
pool.quota_used += credits_required
|
||||
except QuotaExceededError:
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception("Failed to deduct credits for tenant %s", tenant_id)
|
||||
raise QuotaExceededError("Failed to deduct credits")
|
||||
|
||||
return actual_credits
|
||||
return credits_required
|
||||
|
||||
@classmethod
|
||||
def deduct_credits_capped(
|
||||
cls,
|
||||
tenant_id: str,
|
||||
credits_required: int,
|
||||
pool_type: str = "trial",
|
||||
) -> int:
|
||||
"""Deduct up to the available balance and return the actual deducted credits."""
|
||||
if credits_required <= 0:
|
||||
return 0
|
||||
|
||||
try:
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
pool = cls._get_locked_pool(session=session, tenant_id=tenant_id, pool_type=pool_type)
|
||||
if not pool:
|
||||
logger.warning("Credit pool not found, tenant_id=%s, pool_type=%s", tenant_id, pool_type)
|
||||
return 0
|
||||
|
||||
deducted_credits = min(credits_required, pool.remaining_credits)
|
||||
if deducted_credits <= 0:
|
||||
return 0
|
||||
|
||||
pool.quota_used += deducted_credits
|
||||
return deducted_credits
|
||||
except QuotaExceededError:
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception("Failed to deduct capped credits for tenant %s", tenant_id)
|
||||
raise QuotaExceededError("Failed to deduct credits")
|
||||
|
||||
@ -157,8 +157,8 @@ class DraftVarLoader(VariableLoader):
|
||||
# This approach reduces loading time by querying external systems concurrently.
|
||||
with ThreadPoolExecutor(max_workers=10) as executor:
|
||||
offloaded_variables = executor.map(self._load_offloaded_variable, offloaded_draft_vars)
|
||||
for selector, variable in offloaded_variables:
|
||||
variable_by_selector[selector] = variable
|
||||
for selector, offloaded_variable in offloaded_variables:
|
||||
variable_by_selector[selector] = offloaded_variable
|
||||
|
||||
return list(variable_by_selector.values())
|
||||
|
||||
|
||||
@ -1251,7 +1251,7 @@ class WorkflowService:
|
||||
node_data = HumanInputNode.validate_node_data(adapt_human_input_node_data_for_graph(node_config["data"]))
|
||||
node = HumanInputNode(
|
||||
node_id=node_config["id"],
|
||||
config=node_data,
|
||||
data=node_data,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
runtime=DifyHumanInputNodeRuntime(run_context),
|
||||
|
||||
@ -73,7 +73,7 @@ def test_node_integration_minimal_stream(mocker: MockerFixture):
|
||||
|
||||
node = DatasourceNode(
|
||||
node_id="n",
|
||||
config=DatasourceNodeData(
|
||||
data=DatasourceNodeData(
|
||||
type="datasource",
|
||||
version="1",
|
||||
title="Datasource",
|
||||
|
||||
@ -4,7 +4,7 @@ from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEnti
|
||||
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
|
||||
from core.entities.provider_entities import CustomConfiguration, CustomProviderConfiguration, SystemConfiguration
|
||||
from core.model_manager import ModelInstance
|
||||
from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory
|
||||
from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from models.provider import ProviderType
|
||||
|
||||
@ -15,8 +15,9 @@ def get_mocked_fetch_model_config(
|
||||
mode: str,
|
||||
credentials: dict,
|
||||
):
|
||||
model_provider_factory = create_plugin_model_provider_factory(tenant_id="9d2074fc-6f86-45a9-b09d-6ecc63b9056b")
|
||||
model_type_instance = model_provider_factory.get_model_type_instance(provider, ModelType.LLM)
|
||||
model_assembly = create_plugin_model_assembly(tenant_id="9d2074fc-6f86-45a9-b09d-6ecc63b9056b")
|
||||
model_provider_factory = model_assembly.model_provider_factory
|
||||
model_type_instance = model_assembly.create_model_type_instance(provider=provider, model_type=ModelType.LLM)
|
||||
provider_model_bundle = ProviderModelBundle(
|
||||
configuration=ProviderConfiguration(
|
||||
tenant_id="1",
|
||||
|
||||
@ -45,7 +45,7 @@ def init_code_node(code_config: dict):
|
||||
)
|
||||
|
||||
# construct variable pool
|
||||
variable_pool = VariablePool(
|
||||
variable_pool = VariablePool.from_bootstrap(
|
||||
system_variables=build_system_variables(user_id="aaa", files=[]),
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
@ -66,7 +66,7 @@ def init_code_node(code_config: dict):
|
||||
|
||||
node = CodeNode(
|
||||
node_id=str(uuid.uuid4()),
|
||||
config=CodeNodeData.model_validate(code_config["data"]),
|
||||
data=CodeNodeData.model_validate(code_config["data"]),
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
code_executor=node_factory._code_executor,
|
||||
|
||||
@ -55,7 +55,7 @@ def init_http_node(config: dict):
|
||||
)
|
||||
|
||||
# construct variable pool
|
||||
variable_pool = VariablePool(
|
||||
variable_pool = VariablePool.from_bootstrap(
|
||||
system_variables=build_system_variables(user_id="aaa", files=[]),
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
@ -76,7 +76,7 @@ def init_http_node(config: dict):
|
||||
|
||||
node = HttpRequestNode(
|
||||
node_id=str(uuid.uuid4()),
|
||||
config=HttpRequestNodeData.model_validate(config["data"]),
|
||||
data=HttpRequestNodeData.model_validate(config["data"]),
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
http_request_config=HTTP_REQUEST_CONFIG,
|
||||
@ -204,7 +204,7 @@ def test_custom_auth_with_empty_api_key_raises_error(setup_http_mock):
|
||||
from graphon.runtime import VariablePool
|
||||
|
||||
# Create variable pool
|
||||
variable_pool = VariablePool(
|
||||
variable_pool = VariablePool.from_bootstrap(
|
||||
system_variables=build_system_variables(user_id="test", files=[]),
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
@ -702,7 +702,7 @@ def test_nested_object_variable_selector(setup_http_mock):
|
||||
)
|
||||
|
||||
# Create independent variable pool for this test only
|
||||
variable_pool = VariablePool(
|
||||
variable_pool = VariablePool.from_bootstrap(
|
||||
system_variables=build_system_variables(user_id="aaa", files=[]),
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
@ -724,7 +724,7 @@ def test_nested_object_variable_selector(setup_http_mock):
|
||||
|
||||
node = HttpRequestNode(
|
||||
node_id=str(uuid.uuid4()),
|
||||
config=HttpRequestNodeData.model_validate(graph_config["nodes"][1]["data"]),
|
||||
data=HttpRequestNodeData.model_validate(graph_config["nodes"][1]["data"]),
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
http_request_config=HTTP_REQUEST_CONFIG,
|
||||
|
||||
@ -53,7 +53,7 @@ def init_llm_node(config: dict) -> LLMNode:
|
||||
)
|
||||
|
||||
# construct variable pool
|
||||
variable_pool = VariablePool(
|
||||
variable_pool = VariablePool.from_bootstrap(
|
||||
system_variables=build_system_variables(
|
||||
user_id="aaa",
|
||||
app_id=app_id,
|
||||
@ -77,7 +77,7 @@ def init_llm_node(config: dict) -> LLMNode:
|
||||
|
||||
node = LLMNode(
|
||||
node_id=str(uuid.uuid4()),
|
||||
config=LLMNodeData.model_validate(config["data"]),
|
||||
data=LLMNodeData.model_validate(config["data"]),
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
credentials_provider=MagicMock(spec=CredentialsProvider),
|
||||
|
||||
@ -56,7 +56,7 @@ def init_parameter_extractor_node(config: dict, memory=None):
|
||||
)
|
||||
|
||||
# construct variable pool
|
||||
variable_pool = VariablePool(
|
||||
variable_pool = VariablePool.from_bootstrap(
|
||||
system_variables=build_system_variables(
|
||||
user_id="aaa", files=[], query="what's the weather in SF", conversation_id="abababa"
|
||||
),
|
||||
@ -71,7 +71,7 @@ def init_parameter_extractor_node(config: dict, memory=None):
|
||||
|
||||
node = ParameterExtractorNode(
|
||||
node_id=str(uuid.uuid4()),
|
||||
config=ParameterExtractorNodeData.model_validate(config["data"]),
|
||||
data=ParameterExtractorNodeData.model_validate(config["data"]),
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
credentials_provider=MagicMock(spec=CredentialsProvider),
|
||||
|
||||
@ -66,7 +66,7 @@ def test_execute_template_transform():
|
||||
)
|
||||
|
||||
# construct variable pool
|
||||
variable_pool = VariablePool(
|
||||
variable_pool = VariablePool.from_bootstrap(
|
||||
system_variables=build_system_variables(user_id="aaa", files=[]),
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
@ -88,7 +88,7 @@ def test_execute_template_transform():
|
||||
|
||||
node = TemplateTransformNode(
|
||||
node_id=str(uuid.uuid4()),
|
||||
config=TemplateTransformNodeData.model_validate(config["data"]),
|
||||
data=TemplateTransformNodeData.model_validate(config["data"]),
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
jinja2_template_renderer=_SimpleJinja2Renderer(),
|
||||
|
||||
@ -43,7 +43,7 @@ def init_tool_node(config: dict):
|
||||
)
|
||||
|
||||
# construct variable pool
|
||||
variable_pool = VariablePool(
|
||||
variable_pool = VariablePool.from_bootstrap(
|
||||
system_variables=build_system_variables(user_id="aaa", files=[]),
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
@ -64,7 +64,7 @@ def init_tool_node(config: dict):
|
||||
|
||||
node = ToolNode(
|
||||
node_id=str(uuid.uuid4()),
|
||||
config=ToolNodeData.model_validate(config["data"]),
|
||||
data=ToolNodeData.model_validate(config["data"]),
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
tool_file_manager_factory=tool_file_manager_factory,
|
||||
|
||||
@ -210,7 +210,9 @@ class TestPauseStatePersistenceLayerTestContainers:
|
||||
execution_id = workflow_run_id or getattr(self, "test_workflow_run_id", None) or str(uuid.uuid4())
|
||||
|
||||
# Create variable pool
|
||||
variable_pool = VariablePool(system_variables=build_system_variables(workflow_execution_id=execution_id))
|
||||
variable_pool = VariablePool.from_bootstrap(
|
||||
system_variables=build_system_variables(workflow_execution_id=execution_id)
|
||||
)
|
||||
if variables:
|
||||
for (node_id, var_key), value in variables.items():
|
||||
variable_pool.add([node_id, var_key], value)
|
||||
|
||||
@ -66,7 +66,7 @@ def _mock_form_repository_with_submission(action_id: str) -> HumanInputFormRepos
|
||||
|
||||
|
||||
def _build_runtime_state(workflow_execution_id: str, app_id: str, workflow_id: str, user_id: str) -> GraphRuntimeState:
|
||||
variable_pool = VariablePool(
|
||||
variable_pool = VariablePool.from_bootstrap(
|
||||
system_variables=build_system_variables(
|
||||
workflow_execution_id=workflow_execution_id,
|
||||
app_id=app_id,
|
||||
@ -102,7 +102,7 @@ def _build_graph(
|
||||
start_data = StartNodeData(title="start", variables=[])
|
||||
start_node = StartNode(
|
||||
node_id="start",
|
||||
config=start_data,
|
||||
data=start_data,
|
||||
graph_init_params=params,
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
@ -117,7 +117,7 @@ def _build_graph(
|
||||
)
|
||||
human_node = HumanInputNode(
|
||||
node_id="human",
|
||||
config=human_data,
|
||||
data=human_data,
|
||||
graph_init_params=params,
|
||||
graph_runtime_state=runtime_state,
|
||||
form_repository=form_repository,
|
||||
@ -131,7 +131,7 @@ def _build_graph(
|
||||
)
|
||||
end_node = EndNode(
|
||||
node_id="end",
|
||||
config=end_data,
|
||||
data=end_data,
|
||||
graph_init_params=params,
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
|
||||
@ -90,16 +90,34 @@ class TestCreditPoolService:
|
||||
pool = CreditPoolService.get_pool(tenant_id=tenant_id)
|
||||
assert pool.quota_used == credits_required
|
||||
|
||||
def test_check_and_deduct_credits_caps_at_remaining(self, db_session_with_containers: Session):
|
||||
def test_check_and_deduct_credits_raises_without_deducting_when_insufficient(
|
||||
self, db_session_with_containers: Session
|
||||
):
|
||||
tenant_id = self._create_tenant_id()
|
||||
pool = CreditPoolService.create_default_pool(tenant_id)
|
||||
remaining = 5
|
||||
pool.quota_used = pool.quota_limit - remaining
|
||||
quota_used = pool.quota_used
|
||||
db_session_with_containers.commit()
|
||||
|
||||
result = CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=200)
|
||||
with pytest.raises(QuotaExceededError, match="Insufficient credits remaining"):
|
||||
CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=200)
|
||||
|
||||
db_session_with_containers.expire_all()
|
||||
updated_pool = CreditPoolService.get_pool(tenant_id=tenant_id)
|
||||
assert updated_pool.quota_used == quota_used
|
||||
|
||||
def test_deduct_credits_capped_depletes_available_balance(self, db_session_with_containers: Session):
|
||||
tenant_id = self._create_tenant_id()
|
||||
pool = CreditPoolService.create_default_pool(tenant_id)
|
||||
remaining = 5
|
||||
pool.quota_used = pool.quota_limit - remaining
|
||||
quota_limit = pool.quota_limit
|
||||
db_session_with_containers.commit()
|
||||
|
||||
result = CreditPoolService.deduct_credits_capped(tenant_id=tenant_id, credits_required=200)
|
||||
|
||||
assert result == remaining
|
||||
db_session_with_containers.expire_all()
|
||||
updated_pool = CreditPoolService.get_pool(tenant_id=tenant_id)
|
||||
assert updated_pool.quota_used == pool.quota_limit
|
||||
assert updated_pool.quota_used == quota_limit
|
||||
|
||||
@ -132,7 +132,9 @@ class TestAdvancedChatGenerateTaskPipeline:
|
||||
pipeline._task_state.answer = "partial answer"
|
||||
pipeline._workflow_run_id = "run-id"
|
||||
pipeline._graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")),
|
||||
variable_pool=build_test_variable_pool(
|
||||
variables=build_system_variables(workflow_execution_id="run-id"),
|
||||
),
|
||||
start_at=0.0,
|
||||
total_tokens=7,
|
||||
node_run_steps=3,
|
||||
@ -372,7 +374,9 @@ class TestAdvancedChatGenerateTaskPipeline:
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._workflow_run_id = "run-id"
|
||||
pipeline._graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")),
|
||||
variable_pool=VariablePool.from_bootstrap(
|
||||
system_variables=build_system_variables(workflow_execution_id="run-id")
|
||||
),
|
||||
start_at=0.0,
|
||||
)
|
||||
pipeline._workflow_response_converter.workflow_finish_to_stream_response = lambda **kwargs: "finish"
|
||||
@ -583,7 +587,9 @@ class TestAdvancedChatGenerateTaskPipeline:
|
||||
self.items = items
|
||||
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")),
|
||||
variable_pool=VariablePool.from_bootstrap(
|
||||
system_variables=build_system_variables(workflow_execution_id="run-id")
|
||||
),
|
||||
start_at=0.0,
|
||||
)
|
||||
|
||||
@ -617,7 +623,9 @@ class TestAdvancedChatGenerateTaskPipeline:
|
||||
def test_handle_message_end_event_applies_output_moderation(self, monkeypatch: pytest.MonkeyPatch):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")),
|
||||
variable_pool=VariablePool.from_bootstrap(
|
||||
system_variables=build_system_variables(workflow_execution_id="run-id")
|
||||
),
|
||||
start_at=0.0,
|
||||
)
|
||||
pipeline._base_task_pipeline.handle_output_moderation_when_task_finished = lambda answer: "safe"
|
||||
|
||||
@ -60,7 +60,7 @@ class _StubToolNode(Node[_StubToolNodeData]):
|
||||
def __init__(
|
||||
self,
|
||||
node_id: str,
|
||||
config: _StubToolNodeData,
|
||||
data: _StubToolNodeData,
|
||||
*,
|
||||
graph_init_params,
|
||||
graph_runtime_state,
|
||||
@ -68,7 +68,7 @@ class _StubToolNode(Node[_StubToolNodeData]):
|
||||
) -> None:
|
||||
super().__init__(
|
||||
node_id=node_id,
|
||||
config=config,
|
||||
data=data,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
@ -169,7 +169,7 @@ def _build_graph(runtime_state: GraphRuntimeState, *, pause_on: str | None) -> G
|
||||
|
||||
|
||||
def _build_runtime_state(run_id: str) -> GraphRuntimeState:
|
||||
variable_pool = VariablePool(
|
||||
variable_pool = VariablePool.from_bootstrap(
|
||||
system_variables=build_system_variables(user_id="user", app_id="app", workflow_id="workflow"),
|
||||
user_inputs={},
|
||||
conversation_variables=[],
|
||||
|
||||
@ -54,7 +54,7 @@ class TestWorkflowBasedAppRunner:
|
||||
runner = WorkflowBasedAppRunner(queue_manager=SimpleNamespace(), app_id="app")
|
||||
|
||||
runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(system_variables=default_system_variables()),
|
||||
variable_pool=VariablePool.from_bootstrap(system_variables=default_system_variables()),
|
||||
start_at=0.0,
|
||||
)
|
||||
|
||||
@ -93,7 +93,7 @@ class TestWorkflowBasedAppRunner:
|
||||
def test_get_graph_and_variable_pool_for_single_node_run(self, monkeypatch: pytest.MonkeyPatch):
|
||||
runner = WorkflowBasedAppRunner(queue_manager=SimpleNamespace(), app_id="app")
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(system_variables=default_system_variables()),
|
||||
variable_pool=VariablePool.from_bootstrap(system_variables=default_system_variables()),
|
||||
start_at=0.0,
|
||||
)
|
||||
|
||||
@ -164,7 +164,7 @@ class TestWorkflowBasedAppRunner:
|
||||
app_id="app",
|
||||
)
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(system_variables=default_system_variables()),
|
||||
variable_pool=VariablePool.from_bootstrap(system_variables=default_system_variables()),
|
||||
start_at=0.0,
|
||||
)
|
||||
|
||||
@ -243,7 +243,7 @@ class TestWorkflowBasedAppRunner:
|
||||
|
||||
runner = WorkflowBasedAppRunner(queue_manager=_QueueManager(), app_id="app")
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(system_variables=default_system_variables()),
|
||||
variable_pool=VariablePool.from_bootstrap(system_variables=default_system_variables()),
|
||||
start_at=0.0,
|
||||
)
|
||||
graph_runtime_state.register_paused_node("node-1")
|
||||
@ -286,7 +286,7 @@ class TestWorkflowBasedAppRunner:
|
||||
|
||||
runner = WorkflowBasedAppRunner(queue_manager=_QueueManager(), app_id="app")
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(system_variables=default_system_variables()),
|
||||
variable_pool=VariablePool.from_bootstrap(system_variables=default_system_variables()),
|
||||
start_at=0.0,
|
||||
)
|
||||
workflow_entry = SimpleNamespace(graph_engine=SimpleNamespace(graph_runtime_state=graph_runtime_state))
|
||||
@ -425,7 +425,7 @@ class TestWorkflowBasedAppRunner:
|
||||
|
||||
runner = WorkflowBasedAppRunner(queue_manager=_QueueManager(), app_id="app")
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(system_variables=default_system_variables()),
|
||||
variable_pool=VariablePool.from_bootstrap(system_variables=default_system_variables()),
|
||||
start_at=0.0,
|
||||
)
|
||||
workflow_entry = SimpleNamespace(graph_engine=SimpleNamespace(graph_runtime_state=graph_runtime_state))
|
||||
|
||||
@ -16,7 +16,7 @@ from models.workflow import Workflow
|
||||
|
||||
|
||||
def _make_graph_state():
|
||||
variable_pool = VariablePool(
|
||||
variable_pool = VariablePool.from_bootstrap(
|
||||
system_variables=default_system_variables(),
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
|
||||
@ -95,7 +95,9 @@ class TestWorkflowGenerateTaskPipeline:
|
||||
def test_to_blocking_response_falls_back_to_human_input_required_when_pause_event_missing(self):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")),
|
||||
variable_pool=build_test_variable_pool(
|
||||
variables=build_system_variables(workflow_execution_id="run-id"),
|
||||
),
|
||||
start_at=0.0,
|
||||
total_tokens=5,
|
||||
node_run_steps=2,
|
||||
@ -283,7 +285,9 @@ class TestWorkflowGenerateTaskPipeline:
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._workflow_execution_id = "run-id"
|
||||
pipeline._graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")),
|
||||
variable_pool=VariablePool.from_bootstrap(
|
||||
system_variables=build_system_variables(workflow_execution_id="run-id")
|
||||
),
|
||||
start_at=0.0,
|
||||
)
|
||||
pipeline._workflow_response_converter.workflow_finish_to_stream_response = lambda **kwargs: "finish"
|
||||
@ -725,7 +729,9 @@ class TestWorkflowGenerateTaskPipeline:
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._workflow_execution_id = "run-id"
|
||||
pipeline._graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")),
|
||||
variable_pool=VariablePool.from_bootstrap(
|
||||
system_variables=build_system_variables(workflow_execution_id="run-id")
|
||||
),
|
||||
start_at=0.0,
|
||||
)
|
||||
|
||||
@ -753,7 +759,9 @@ class TestWorkflowGenerateTaskPipeline:
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._workflow_execution_id = "run-id"
|
||||
pipeline._graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")),
|
||||
variable_pool=VariablePool.from_bootstrap(
|
||||
system_variables=build_system_variables(workflow_execution_id="run-id")
|
||||
),
|
||||
start_at=0.0,
|
||||
)
|
||||
pipeline._handle_ping_event = lambda event, **kwargs: iter(["ping"])
|
||||
@ -769,7 +777,9 @@ class TestWorkflowGenerateTaskPipeline:
|
||||
def test_process_stream_response_main_match_paths_and_cleanup(self):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")),
|
||||
variable_pool=VariablePool.from_bootstrap(
|
||||
system_variables=build_system_variables(workflow_execution_id="run-id")
|
||||
),
|
||||
start_at=0.0,
|
||||
)
|
||||
pipeline._base_task_pipeline.queue_manager.listen = lambda: iter(
|
||||
|
||||
@ -21,7 +21,9 @@ class TestTriggerPostLayer:
|
||||
)
|
||||
runtime_state = SimpleNamespace(
|
||||
outputs={"answer": "ok"},
|
||||
variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-1")),
|
||||
variable_pool=VariablePool.from_bootstrap(
|
||||
system_variables=build_system_variables(workflow_execution_id="run-1")
|
||||
),
|
||||
total_tokens=12,
|
||||
)
|
||||
|
||||
@ -60,7 +62,9 @@ class TestTriggerPostLayer:
|
||||
def test_on_event_handles_missing_trigger_log(self):
|
||||
runtime_state = SimpleNamespace(
|
||||
outputs={},
|
||||
variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-1")),
|
||||
variable_pool=VariablePool.from_bootstrap(
|
||||
system_variables=build_system_variables(workflow_execution_id="run-1")
|
||||
),
|
||||
total_tokens=0,
|
||||
)
|
||||
|
||||
@ -91,7 +95,9 @@ class TestTriggerPostLayer:
|
||||
def test_on_event_ignores_non_status_events(self):
|
||||
runtime_state = SimpleNamespace(
|
||||
outputs={},
|
||||
variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-1")),
|
||||
variable_pool=VariablePool.from_bootstrap(
|
||||
system_variables=build_system_variables(workflow_execution_id="run-1")
|
||||
),
|
||||
total_tokens=0,
|
||||
)
|
||||
|
||||
|
||||
617
api/tests/unit_tests/core/app/test_llm_quota.py
Normal file
617
api/tests/unit_tests/core/app/test_llm_quota.py
Normal file
@ -0,0 +1,617 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import create_engine, select
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.llm.quota import (
|
||||
deduct_llm_quota,
|
||||
deduct_llm_quota_for_model,
|
||||
ensure_llm_quota_available,
|
||||
ensure_llm_quota_available_for_model,
|
||||
)
|
||||
from core.entities.model_entities import ModelStatus
|
||||
from core.entities.provider_entities import ProviderQuotaType, QuotaUnit
|
||||
from core.errors.error import QuotaExceededError
|
||||
from graphon.model_runtime.entities.llm_entities import LLMUsage
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from models import TenantCreditPool
|
||||
from models.enums import ProviderQuotaType as ModelProviderQuotaType
|
||||
from models.provider import Provider, ProviderType
|
||||
|
||||
|
||||
def test_ensure_llm_quota_available_for_model_raises_when_system_model_is_exhausted() -> None:
    """A system-hosted model whose status is QUOTA_EXCEEDED must raise QuotaExceededError."""
    provider_configuration = SimpleNamespace(
        using_provider_type=ProviderType.SYSTEM,
        get_provider_model=MagicMock(return_value=SimpleNamespace(status=ModelStatus.QUOTA_EXCEEDED)),
    )
    provider_manager = MagicMock()
    provider_manager.get_configurations.return_value.get.return_value = provider_configuration

    with (
        patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager),
        pytest.raises(QuotaExceededError, match="Model provider openai quota exceeded."),
    ):
        ensure_llm_quota_available_for_model(
            tenant_id="tenant-id",
            provider="openai",
            model="gpt-4o",
        )

    # The quota check must look up the LLM entry for the requested model.
    provider_configuration.get_provider_model.assert_called_once_with(
        model_type=ModelType.LLM,
        model="gpt-4o",
    )
|
||||
|
||||
|
||||
def test_ensure_llm_quota_available_for_model_raises_when_provider_is_missing() -> None:
    """An unknown provider name must raise ValueError instead of silently passing."""
    provider_manager = MagicMock()
    # No configuration exists for the requested provider.
    provider_manager.get_configurations.return_value.get.return_value = None

    with (
        patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager),
        pytest.raises(ValueError, match="Provider openai does not exist."),
    ):
        ensure_llm_quota_available_for_model(
            tenant_id="tenant-id",
            provider="openai",
            model="gpt-4o",
        )
|
||||
|
||||
|
||||
def test_ensure_llm_quota_available_for_model_ignores_custom_provider_configuration() -> None:
    """Custom (non-system) providers are not quota-checked: the model lookup is skipped."""
    provider_configuration = SimpleNamespace(
        using_provider_type=ProviderType.CUSTOM,
        get_provider_model=MagicMock(),
    )
    provider_manager = MagicMock()
    provider_manager.get_configurations.return_value.get.return_value = provider_configuration

    with patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager):
        ensure_llm_quota_available_for_model(
            tenant_id="tenant-id",
            provider="openai",
            model="gpt-4o",
        )

    provider_configuration.get_provider_model.assert_not_called()
|
||||
|
||||
|
||||
def test_deduct_llm_quota_for_model_uses_identity_based_trial_billing() -> None:
    """TRIAL quota measured in TOKENS deducts exactly the usage's total_tokens from the credit pool."""
    usage = LLMUsage.empty_usage()
    usage.total_tokens = 42
    provider_configuration = SimpleNamespace(
        using_provider_type=ProviderType.SYSTEM,
        system_configuration=SimpleNamespace(
            current_quota_type=ProviderQuotaType.TRIAL,
            quota_configurations=[
                SimpleNamespace(
                    quota_type=ProviderQuotaType.TRIAL,
                    quota_unit=QuotaUnit.TOKENS,
                    quota_limit=100,
                )
            ],
        ),
    )
    provider_manager = MagicMock()
    provider_manager.get_configurations.return_value.get.return_value = provider_configuration

    with (
        patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager),
        patch("services.credit_pool_service.CreditPoolService.deduct_credits_capped") as mock_deduct_credits,
    ):
        deduct_llm_quota_for_model(
            tenant_id="tenant-id",
            provider="openai",
            model="gpt-4o",
            usage=usage,
        )

    # One token == one credit under the TOKENS quota unit.
    mock_deduct_credits.assert_called_once_with(
        tenant_id="tenant-id",
        credits_required=42,
    )
|
||||
|
||||
|
||||
def test_deduct_llm_quota_for_model_caps_trial_pool_when_usage_exceeds_remaining() -> None:
    """Deducting more tokens than the trial pool holds caps quota_used at quota_limit (no overshoot)."""
    usage = LLMUsage.empty_usage()
    usage.total_tokens = 3  # pool below only has 1 credit remaining
    provider_configuration = SimpleNamespace(
        using_provider_type=ProviderType.SYSTEM,
        system_configuration=SimpleNamespace(
            current_quota_type=ProviderQuotaType.TRIAL,
            quota_configurations=[
                SimpleNamespace(
                    quota_type=ProviderQuotaType.TRIAL,
                    quota_unit=QuotaUnit.TOKENS,
                    quota_limit=100,
                )
            ],
        ),
    )
    provider_manager = MagicMock()
    provider_manager.get_configurations.return_value.get.return_value = provider_configuration
    # Real in-memory DB so the capped UPDATE executed by the service can be observed.
    engine = create_engine("sqlite:///:memory:")
    TenantCreditPool.__table__.create(engine)
    with engine.begin() as connection:
        connection.execute(
            TenantCreditPool.__table__.insert(),
            {
                "id": "trial-pool",
                "tenant_id": "tenant-id",
                "pool_type": ModelProviderQuotaType.TRIAL,
                "quota_limit": 10,
                "quota_used": 9,
            },
        )

    with (
        patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager),
        patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)),
    ):
        deduct_llm_quota_for_model(
            tenant_id="tenant-id",
            provider="openai",
            model="gpt-4o",
            usage=usage,
        )

    with engine.connect() as connection:
        quota_used = connection.scalar(select(TenantCreditPool.quota_used).where(TenantCreditPool.id == "trial-pool"))

    # 9 used + 3 requested would be 12, but usage is capped at the limit of 10.
    assert quota_used == 10
|
||||
|
||||
|
||||
def test_deduct_llm_quota_for_model_returns_for_unbounded_quota() -> None:
    """A quota_limit of -1 means unlimited: no credits are ever deducted."""
    usage = LLMUsage.empty_usage()
    usage.total_tokens = 42
    provider_configuration = SimpleNamespace(
        using_provider_type=ProviderType.SYSTEM,
        system_configuration=SimpleNamespace(
            current_quota_type=ProviderQuotaType.TRIAL,
            quota_configurations=[
                SimpleNamespace(
                    quota_type=ProviderQuotaType.TRIAL,
                    quota_unit=QuotaUnit.TOKENS,
                    quota_limit=-1,  # sentinel for "unbounded"
                )
            ],
        ),
    )
    provider_manager = MagicMock()
    provider_manager.get_configurations.return_value.get.return_value = provider_configuration

    with (
        patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager),
        patch("services.credit_pool_service.CreditPoolService.deduct_credits_capped") as mock_deduct_credits,
    ):
        deduct_llm_quota_for_model(
            tenant_id="tenant-id",
            provider="openai",
            model="gpt-4o",
            usage=usage,
        )

    mock_deduct_credits.assert_not_called()
|
||||
|
||||
|
||||
def test_deduct_llm_quota_for_model_uses_credit_configuration() -> None:
    """Under the CREDITS quota unit, the per-model credit price from dify_config is charged."""
    usage = LLMUsage.empty_usage()
    provider_configuration = SimpleNamespace(
        using_provider_type=ProviderType.SYSTEM,
        system_configuration=SimpleNamespace(
            current_quota_type=ProviderQuotaType.TRIAL,
            quota_configurations=[
                SimpleNamespace(
                    quota_type=ProviderQuotaType.TRIAL,
                    quota_unit=QuotaUnit.CREDITS,
                    quota_limit=100,
                )
            ],
        ),
    )
    provider_manager = MagicMock()
    provider_manager.get_configurations.return_value.get.return_value = provider_configuration

    with (
        patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager),
        # Patch on the config's type so the bound method lookup is intercepted.
        patch.object(type(dify_config), "get_model_credits", return_value=9) as mock_get_model_credits,
        patch("services.credit_pool_service.CreditPoolService.deduct_credits_capped") as mock_deduct_credits,
    ):
        deduct_llm_quota_for_model(
            tenant_id="tenant-id",
            provider="openai",
            model="gpt-4o",
            usage=usage,
        )

    mock_get_model_credits.assert_called_once_with("gpt-4o")
    mock_deduct_credits.assert_called_once_with(
        tenant_id="tenant-id",
        credits_required=9,
    )
|
||||
|
||||
|
||||
def test_deduct_llm_quota_for_model_uses_single_charge_for_times_quota() -> None:
    """The TIMES quota unit charges a flat 1 credit per invocation regardless of token usage."""
    usage = LLMUsage.empty_usage()
    provider_configuration = SimpleNamespace(
        using_provider_type=ProviderType.SYSTEM,
        system_configuration=SimpleNamespace(
            current_quota_type=ProviderQuotaType.TRIAL,
            quota_configurations=[
                SimpleNamespace(
                    quota_type=ProviderQuotaType.TRIAL,
                    quota_unit=QuotaUnit.TIMES,
                    quota_limit=100,
                )
            ],
        ),
    )
    provider_manager = MagicMock()
    provider_manager.get_configurations.return_value.get.return_value = provider_configuration

    with (
        patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager),
        patch("services.credit_pool_service.CreditPoolService.deduct_credits_capped") as mock_deduct_credits,
    ):
        deduct_llm_quota_for_model(
            tenant_id="tenant-id",
            provider="openai",
            model="gpt-4o",
            usage=usage,
        )

    mock_deduct_credits.assert_called_once_with(
        tenant_id="tenant-id",
        credits_required=1,
    )
|
||||
|
||||
|
||||
def test_deduct_llm_quota_for_model_uses_paid_billing_pool() -> None:
    """PAID quota deducts token usage from the dedicated 'paid' credit pool."""
    usage = LLMUsage.empty_usage()
    usage.total_tokens = 5
    provider_configuration = SimpleNamespace(
        using_provider_type=ProviderType.SYSTEM,
        system_configuration=SimpleNamespace(
            current_quota_type=ProviderQuotaType.PAID,
            quota_configurations=[
                SimpleNamespace(
                    quota_type=ProviderQuotaType.PAID,
                    quota_unit=QuotaUnit.TOKENS,
                    quota_limit=100,
                )
            ],
        ),
    )
    provider_manager = MagicMock()
    provider_manager.get_configurations.return_value.get.return_value = provider_configuration

    with (
        patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager),
        patch("services.credit_pool_service.CreditPoolService.deduct_credits_capped") as mock_deduct_credits,
    ):
        deduct_llm_quota_for_model(
            tenant_id="tenant-id",
            provider="openai",
            model="gpt-4o",
            usage=usage,
        )

    # Unlike trial billing, the paid pool is selected explicitly.
    mock_deduct_credits.assert_called_once_with(
        tenant_id="tenant-id",
        credits_required=5,
        pool_type="paid",
    )
|
||||
|
||||
|
||||
def test_deduct_llm_quota_for_model_updates_free_quota_usage() -> None:
    """FREE quota increments quota_used on exactly the matching Provider row, and raises once exhausted.

    Seeds four rows — the target plus decoys differing by tenant, provider name,
    and provider type — to prove the UPDATE's WHERE clause is scoped correctly.
    """
    usage = LLMUsage.empty_usage()
    usage.total_tokens = 3
    provider_configuration = SimpleNamespace(
        using_provider_type=ProviderType.SYSTEM,
        system_configuration=SimpleNamespace(
            current_quota_type=ProviderQuotaType.FREE,
            quota_configurations=[
                SimpleNamespace(
                    quota_type=ProviderQuotaType.FREE,
                    quota_unit=QuotaUnit.TOKENS,
                    quota_limit=100,
                )
            ],
        ),
    )
    provider_manager = MagicMock()
    provider_manager.get_configurations.return_value.get.return_value = provider_configuration
    engine = create_engine("sqlite:///:memory:")
    Provider.__table__.create(engine)
    with engine.begin() as connection:
        connection.execute(
            Provider.__table__.insert(),
            [
                {
                    "id": "matching-provider",
                    "tenant_id": "tenant-id",
                    "provider_name": "openai",
                    "provider_type": ProviderType.SYSTEM,
                    "quota_type": ProviderQuotaType.FREE,
                    "quota_limit": 100,
                    "quota_used": 10,
                    "is_valid": True,
                },
                {
                    "id": "other-tenant",
                    "tenant_id": "other-tenant-id",
                    "provider_name": "openai",
                    "provider_type": ProviderType.SYSTEM,
                    "quota_type": ProviderQuotaType.FREE,
                    "quota_limit": 100,
                    "quota_used": 20,
                    "is_valid": True,
                },
                {
                    "id": "other-provider",
                    "tenant_id": "tenant-id",
                    "provider_name": "anthropic",
                    "provider_type": ProviderType.SYSTEM,
                    "quota_type": ProviderQuotaType.FREE,
                    "quota_limit": 100,
                    "quota_used": 30,
                    "is_valid": True,
                },
                {
                    "id": "custom-provider",
                    "tenant_id": "tenant-id",
                    "provider_name": "openai",
                    "provider_type": ProviderType.CUSTOM,
                    "quota_type": ProviderQuotaType.FREE,
                    "quota_limit": 100,
                    "quota_used": 40,
                    "is_valid": True,
                },
            ],
        )

    with (
        patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager),
        patch("core.app.llm.quota.db", SimpleNamespace(engine=engine)),
    ):
        deduct_llm_quota_for_model(
            tenant_id="tenant-id",
            provider="openai",
            model="gpt-4o",
            usage=usage,
        )

    with engine.connect() as connection:
        quota_used_by_id = dict(connection.execute(select(Provider.id, Provider.quota_used)).all())

    # Only the matching system/FREE/openai row for this tenant advances (10 -> 13).
    assert quota_used_by_id == {
        "matching-provider": 13,
        "other-tenant": 20,
        "other-provider": 30,
        "custom-provider": 40,
    }

    # Exhaust the matching row, then verify the next deduction refuses to go further.
    with engine.begin() as connection:
        connection.execute(
            Provider.__table__.update().where(Provider.id == "matching-provider").values(quota_limit=13, quota_used=13)
        )

    with (
        patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager),
        patch("core.app.llm.quota.db", SimpleNamespace(engine=engine)),
        pytest.raises(QuotaExceededError, match="Model provider openai quota exceeded."),
    ):
        deduct_llm_quota_for_model(
            tenant_id="tenant-id",
            provider="openai",
            model="gpt-4o",
            usage=usage,
        )

    with engine.connect() as connection:
        exhausted_quota_used = connection.scalar(select(Provider.quota_used).where(Provider.id == "matching-provider"))

    # The exhausted pool must not be pushed past its limit.
    assert exhausted_quota_used == 13
|
||||
|
||||
|
||||
def test_deduct_llm_quota_for_model_caps_free_quota_and_raises_when_usage_exceeds_remaining() -> None:
    """FREE quota with insufficient headroom is capped at quota_limit and raises QuotaExceededError."""
    usage = LLMUsage.empty_usage()
    usage.total_tokens = 3  # only 2 tokens remain in the pool below
    provider_configuration = SimpleNamespace(
        using_provider_type=ProviderType.SYSTEM,
        system_configuration=SimpleNamespace(
            current_quota_type=ProviderQuotaType.FREE,
            quota_configurations=[
                SimpleNamespace(
                    quota_type=ProviderQuotaType.FREE,
                    quota_unit=QuotaUnit.TOKENS,
                    quota_limit=100,
                )
            ],
        ),
    )
    provider_manager = MagicMock()
    provider_manager.get_configurations.return_value.get.return_value = provider_configuration
    engine = create_engine("sqlite:///:memory:")
    Provider.__table__.create(engine)
    with engine.begin() as connection:
        connection.execute(
            Provider.__table__.insert(),
            {
                "id": "matching-provider",
                "tenant_id": "tenant-id",
                "provider_name": "openai",
                "provider_type": ProviderType.SYSTEM,
                "quota_type": ProviderQuotaType.FREE,
                "quota_limit": 15,
                "quota_used": 13,
            
                "is_valid": True,
            },
        )

    with (
        patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager),
        patch("core.app.llm.quota.db", SimpleNamespace(engine=engine)),
        pytest.raises(QuotaExceededError, match="Model provider openai quota exceeded."),
    ):
        deduct_llm_quota_for_model(
            tenant_id="tenant-id",
            provider="openai",
            model="gpt-4o",
            usage=usage,
        )

    with engine.connect() as connection:
        quota_used = connection.scalar(select(Provider.quota_used).where(Provider.id == "matching-provider"))

    # 13 used + 3 requested would be 16; usage is capped at the limit of 15.
    assert quota_used == 15
|
||||
|
||||
|
||||
def test_deduct_llm_quota_for_model_ignores_unknown_quota_type() -> None:
    """An unrecognized quota type triggers neither credit deduction nor a DB session."""
    usage = LLMUsage.empty_usage()
    usage.total_tokens = 2
    provider_configuration = SimpleNamespace(
        using_provider_type=ProviderType.SYSTEM,
        system_configuration=SimpleNamespace(
            current_quota_type="unexpected",
            quota_configurations=[
                SimpleNamespace(
                    quota_type="unexpected",
                    quota_unit=QuotaUnit.TOKENS,
                    quota_limit=100,
                )
            ],
        ),
    )
    provider_manager = MagicMock()
    provider_manager.get_configurations.return_value.get.return_value = provider_configuration

    with (
        patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager),
        patch("services.credit_pool_service.CreditPoolService.deduct_credits_capped") as mock_deduct_credits,
        patch("core.app.llm.quota.sessionmaker") as mock_sessionmaker,
    ):
        deduct_llm_quota_for_model(
            tenant_id="tenant-id",
            provider="openai",
            model="gpt-4o",
            usage=usage,
        )

    mock_deduct_credits.assert_not_called()
    mock_sessionmaker.assert_not_called()
|
||||
|
||||
|
||||
def test_deduct_llm_quota_for_model_ignores_custom_provider_configuration() -> None:
    """Deduction is a no-op when the tenant is using its own (custom) provider credentials."""
    usage = LLMUsage.empty_usage()
    usage.total_tokens = 2
    provider_configuration = SimpleNamespace(
        using_provider_type=ProviderType.CUSTOM,
        system_configuration=SimpleNamespace(
            current_quota_type=ProviderQuotaType.TRIAL,
            quota_configurations=[],
        ),
    )
    provider_manager = MagicMock()
    provider_manager.get_configurations.return_value.get.return_value = provider_configuration

    with (
        patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager),
        patch("services.credit_pool_service.CreditPoolService.deduct_credits_capped") as mock_deduct_credits,
        patch("core.app.llm.quota.sessionmaker") as mock_sessionmaker,
    ):
        deduct_llm_quota_for_model(
            tenant_id="tenant-id",
            provider="openai",
            model="gpt-4o",
            usage=usage,
        )

    mock_deduct_credits.assert_not_called()
    mock_sessionmaker.assert_not_called()
|
||||
|
||||
|
||||
def test_ensure_llm_quota_available_wrapper_warns_and_delegates() -> None:
    """The deprecated ModelInstance wrapper emits DeprecationWarning and forwards to the new helper."""
    model_instance = SimpleNamespace(
        provider="openai",
        model_name="gpt-4o",
        provider_model_bundle=SimpleNamespace(configuration=SimpleNamespace(tenant_id="tenant-id")),
        model_type_instance=SimpleNamespace(model_type=ModelType.LLM),
    )

    with (
        pytest.deprecated_call(match="ensure_llm_quota_available\\(model_instance=.*deprecated"),
        patch("core.app.llm.quota.ensure_llm_quota_available_for_model") as mock_ensure,
    ):
        ensure_llm_quota_available(model_instance=model_instance)

    # tenant_id is pulled from the instance's provider bundle configuration.
    mock_ensure.assert_called_once_with(
        tenant_id="tenant-id",
        provider="openai",
        model="gpt-4o",
    )
|
||||
|
||||
|
||||
def test_ensure_llm_quota_available_wrapper_rejects_non_llm_model_instances() -> None:
    """The deprecated wrapper still warns, but refuses non-LLM model instances with ValueError."""
    model_instance = SimpleNamespace(
        provider="openai",
        model_name="gpt-4o",
        provider_model_bundle=SimpleNamespace(configuration=SimpleNamespace(tenant_id="tenant-id")),
        model_type_instance=SimpleNamespace(model_type=ModelType.TEXT_EMBEDDING),
    )

    with (
        pytest.deprecated_call(match="ensure_llm_quota_available\\(model_instance=.*deprecated"),
        pytest.raises(ValueError, match="only support LLM model instances"),
    ):
        ensure_llm_quota_available(model_instance=model_instance)
|
||||
|
||||
|
||||
def test_deduct_llm_quota_wrapper_warns_and_delegates() -> None:
    """The deprecated deduct_llm_quota wrapper warns and delegates with unpacked model fields."""
    usage = LLMUsage.empty_usage()
    usage.total_tokens = 7
    model_instance = SimpleNamespace(
        provider="openai",
        model_name="gpt-4o",
        model_type_instance=SimpleNamespace(model_type=ModelType.LLM),
        provider_model_bundle=SimpleNamespace(configuration=SimpleNamespace()),
    )

    with (
        pytest.deprecated_call(match="deduct_llm_quota\\(tenant_id=.*deprecated"),
        patch("core.app.llm.quota.deduct_llm_quota_for_model") as mock_deduct,
    ):
        deduct_llm_quota(
            tenant_id="tenant-id",
            model_instance=model_instance,
            usage=usage,
        )

    mock_deduct.assert_called_once_with(
        tenant_id="tenant-id",
        provider="openai",
        model="gpt-4o",
        usage=usage,
    )
|
||||
|
||||
|
||||
def test_deduct_llm_quota_wrapper_rejects_non_llm_model_instances() -> None:
    """The deprecated deduct wrapper warns, then raises ValueError for non-LLM instances."""
    usage = LLMUsage.empty_usage()
    model_instance = SimpleNamespace(
        provider="openai",
        model_name="gpt-4o",
        model_type_instance=SimpleNamespace(model_type=ModelType.TEXT_EMBEDDING),
        provider_model_bundle=SimpleNamespace(configuration=SimpleNamespace()),
    )

    with (
        pytest.deprecated_call(match="deduct_llm_quota\\(tenant_id=.*deprecated"),
        pytest.raises(ValueError, match="only support LLM model instances"),
    ):
        deduct_llm_quota(
            tenant_id="tenant-id",
            model_instance=model_instance,
            usage=usage,
        )
|
||||
@ -8,9 +8,9 @@ from graphon.enums import BuiltinNodeTypes
|
||||
|
||||
|
||||
class DummyNode:
|
||||
def __init__(self, *, node_id, config, graph_init_params, graph_runtime_state, **kwargs):
|
||||
def __init__(self, *, node_id, data, graph_init_params, graph_runtime_state, **kwargs):
|
||||
self.id = node_id
|
||||
self.config = config
|
||||
self.data = data
|
||||
self.graph_init_params = graph_init_params
|
||||
self.graph_runtime_state = graph_runtime_state
|
||||
self.kwargs = kwargs
|
||||
|
||||
@ -60,7 +60,10 @@ def _make_layer(
|
||||
workflow_execution_id="run-id",
|
||||
conversation_id="conv-id",
|
||||
)
|
||||
runtime_state = GraphRuntimeState(variable_pool=VariablePool(system_variables=system_variables), start_at=0.0)
|
||||
runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool.from_bootstrap(system_variables=system_variables),
|
||||
start_at=0.0,
|
||||
)
|
||||
read_only_state = ReadOnlyGraphRuntimeStateWrapper(runtime_state)
|
||||
|
||||
application_generate_entity = WorkflowAppGenerateEntity.model_construct(
|
||||
|
||||
@ -354,7 +354,8 @@ def test_validate_provider_credentials_handles_hidden_secret_value() -> None:
|
||||
|
||||
with _patched_session(mock_session):
|
||||
with patch(
|
||||
"core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory
|
||||
"core.entities.provider_configuration.create_plugin_model_assembly",
|
||||
return_value=SimpleNamespace(model_runtime=Mock(), model_provider_factory=mock_factory),
|
||||
):
|
||||
with patch("core.entities.provider_configuration.encrypter.decrypt_token", return_value="restored-key"):
|
||||
with patch(
|
||||
@ -379,7 +380,10 @@ def test_validate_provider_credentials_without_credential_id() -> None:
|
||||
mock_factory = Mock()
|
||||
mock_factory.provider_credentials_validate.return_value = {"region": "us"}
|
||||
|
||||
with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory):
|
||||
with patch(
|
||||
"core.entities.provider_configuration.create_plugin_model_assembly",
|
||||
return_value=SimpleNamespace(model_runtime=Mock(), model_provider_factory=mock_factory),
|
||||
):
|
||||
validated = configuration.validate_provider_credentials(credentials={"region": "us"})
|
||||
|
||||
assert validated == {"region": "us"}
|
||||
@ -426,23 +430,37 @@ def test_switch_preferred_provider_type_creates_record_when_missing() -> None:
|
||||
|
||||
def test_get_model_type_instance_and_schema_delegate_to_factory() -> None:
|
||||
configuration = _build_provider_configuration()
|
||||
mock_factory = Mock()
|
||||
mock_model_type_instance = Mock()
|
||||
mock_schema = _build_ai_model("gpt-4o")
|
||||
mock_factory.get_model_type_instance.return_value = mock_model_type_instance
|
||||
mock_factory = Mock()
|
||||
mock_factory.get_provider_schema.return_value = configuration.provider
|
||||
mock_factory.get_model_schema.return_value = mock_schema
|
||||
mock_assembly = Mock()
|
||||
mock_assembly.model_runtime = Mock()
|
||||
mock_assembly.model_provider_factory = mock_factory
|
||||
|
||||
with patch(
|
||||
"core.entities.provider_configuration.create_plugin_model_provider_factory",
|
||||
return_value=mock_factory,
|
||||
) as mock_factory_builder:
|
||||
with (
|
||||
patch(
|
||||
"core.entities.provider_configuration.create_plugin_model_assembly",
|
||||
return_value=mock_assembly,
|
||||
) as mock_assembly_builder,
|
||||
patch(
|
||||
"core.entities.provider_configuration.create_model_type_instance",
|
||||
return_value=mock_model_type_instance,
|
||||
) as mock_model_builder,
|
||||
):
|
||||
model_type_instance = configuration.get_model_type_instance(ModelType.LLM)
|
||||
model_schema = configuration.get_model_schema(ModelType.LLM, "gpt-4o", {"api_key": "x"})
|
||||
|
||||
assert model_type_instance is mock_model_type_instance
|
||||
assert model_schema is mock_schema
|
||||
assert mock_factory_builder.call_count == 2
|
||||
mock_factory.get_model_type_instance.assert_called_once_with(provider="openai", model_type=ModelType.LLM)
|
||||
assert mock_assembly_builder.call_count == 2
|
||||
mock_factory.get_provider_schema.assert_called_once_with(provider="openai")
|
||||
mock_model_builder.assert_called_once_with(
|
||||
runtime=mock_assembly.model_runtime,
|
||||
provider_schema=configuration.provider,
|
||||
model_type=ModelType.LLM,
|
||||
)
|
||||
mock_factory.get_model_schema.assert_called_once_with(
|
||||
provider="openai",
|
||||
model_type=ModelType.LLM,
|
||||
@ -456,17 +474,21 @@ def test_get_model_type_instance_and_schema_reuse_bound_runtime_factory() -> Non
|
||||
bound_runtime = Mock()
|
||||
configuration.bind_model_runtime(bound_runtime)
|
||||
|
||||
mock_factory = Mock()
|
||||
mock_model_type_instance = Mock()
|
||||
mock_schema = _build_ai_model("gpt-4o")
|
||||
mock_factory.get_model_type_instance.return_value = mock_model_type_instance
|
||||
mock_factory = Mock()
|
||||
mock_factory.get_provider_schema.return_value = configuration.provider
|
||||
mock_factory.get_model_schema.return_value = mock_schema
|
||||
|
||||
with (
|
||||
patch(
|
||||
"core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory
|
||||
) as mock_factory_cls,
|
||||
patch("core.entities.provider_configuration.create_plugin_model_provider_factory") as mock_factory_builder,
|
||||
patch("core.entities.provider_configuration.create_plugin_model_assembly") as mock_assembly_builder,
|
||||
patch(
|
||||
"core.entities.provider_configuration.create_model_type_instance",
|
||||
return_value=mock_model_type_instance,
|
||||
) as mock_model_builder,
|
||||
):
|
||||
model_type_instance = configuration.get_model_type_instance(ModelType.LLM)
|
||||
model_schema = configuration.get_model_schema(ModelType.LLM, "gpt-4o", {"api_key": "x"})
|
||||
@ -474,8 +496,14 @@ def test_get_model_type_instance_and_schema_reuse_bound_runtime_factory() -> Non
|
||||
assert model_type_instance is mock_model_type_instance
|
||||
assert model_schema is mock_schema
|
||||
assert mock_factory_cls.call_count == 2
|
||||
mock_factory_cls.assert_called_with(model_runtime=bound_runtime)
|
||||
mock_factory_builder.assert_not_called()
|
||||
mock_factory_cls.assert_called_with(runtime=bound_runtime)
|
||||
mock_assembly_builder.assert_not_called()
|
||||
mock_factory.get_provider_schema.assert_called_once_with(provider="openai")
|
||||
mock_model_builder.assert_called_once_with(
|
||||
runtime=bound_runtime,
|
||||
provider_schema=configuration.provider,
|
||||
model_type=ModelType.LLM,
|
||||
)
|
||||
|
||||
|
||||
def test_get_provider_model_returns_none_when_model_not_found() -> None:
|
||||
@ -504,7 +532,10 @@ def test_get_provider_models_system_deduplicates_sorts_and_filters_active() -> N
|
||||
mock_factory = Mock()
|
||||
mock_factory.get_provider_schema.return_value = provider_schema
|
||||
|
||||
with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory):
|
||||
with patch(
|
||||
"core.entities.provider_configuration.create_plugin_model_assembly",
|
||||
return_value=SimpleNamespace(model_runtime=Mock(), model_provider_factory=mock_factory),
|
||||
):
|
||||
all_models = configuration.get_provider_models(model_type=ModelType.LLM, only_active=False)
|
||||
active_models = configuration.get_provider_models(model_type=ModelType.LLM, only_active=True)
|
||||
|
||||
@ -722,7 +753,8 @@ def test_validate_provider_credentials_handles_invalid_original_json() -> None:
|
||||
|
||||
with _patched_session(mock_session):
|
||||
with patch(
|
||||
"core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory
|
||||
"core.entities.provider_configuration.create_plugin_model_assembly",
|
||||
return_value=SimpleNamespace(model_runtime=Mock(), model_provider_factory=mock_factory),
|
||||
):
|
||||
with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-key"):
|
||||
validated = configuration.validate_provider_credentials(
|
||||
@ -1069,7 +1101,8 @@ def test_validate_custom_model_credentials_supports_hidden_reuse_and_sessionless
|
||||
|
||||
with _patched_session(mock_session):
|
||||
with patch(
|
||||
"core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory
|
||||
"core.entities.provider_configuration.create_plugin_model_assembly",
|
||||
return_value=SimpleNamespace(model_runtime=Mock(), model_provider_factory=mock_factory),
|
||||
):
|
||||
with patch("core.entities.provider_configuration.encrypter.decrypt_token", return_value="raw"):
|
||||
with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"):
|
||||
@ -1083,7 +1116,10 @@ def test_validate_custom_model_credentials_supports_hidden_reuse_and_sessionless
|
||||
|
||||
mock_factory2 = Mock()
|
||||
mock_factory2.model_credentials_validate.return_value = {"region": "us"}
|
||||
with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory2):
|
||||
with patch(
|
||||
"core.entities.provider_configuration.create_plugin_model_assembly",
|
||||
return_value=SimpleNamespace(model_runtime=Mock(), model_provider_factory=mock_factory2),
|
||||
):
|
||||
validated = configuration.validate_custom_model_credentials(
|
||||
model_type=ModelType.LLM,
|
||||
model="gpt-4o",
|
||||
@ -1575,7 +1611,8 @@ def test_validate_provider_credentials_uses_empty_original_when_record_missing()
|
||||
|
||||
with _patched_session(mock_session):
|
||||
with patch(
|
||||
"core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory
|
||||
"core.entities.provider_configuration.create_plugin_model_assembly",
|
||||
return_value=SimpleNamespace(model_runtime=Mock(), model_provider_factory=mock_factory),
|
||||
):
|
||||
with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"):
|
||||
validated = configuration.validate_provider_credentials(
|
||||
@ -1701,7 +1738,8 @@ def test_validate_custom_model_credentials_handles_invalid_original_json() -> No
|
||||
|
||||
with _patched_session(mock_session):
|
||||
with patch(
|
||||
"core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory
|
||||
"core.entities.provider_configuration.create_plugin_model_assembly",
|
||||
return_value=SimpleNamespace(model_runtime=Mock(), model_provider_factory=mock_factory),
|
||||
):
|
||||
with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"):
|
||||
validated = configuration.validate_custom_model_credentials(
|
||||
|
||||
@ -68,8 +68,8 @@ def test_check_moderation_returns_true_when_model_accepts_text(mocker: MockerFix
|
||||
mocker.patch("core.helper.moderation.secrets.choice", return_value="chunk")
|
||||
|
||||
moderation_model = SimpleNamespace(invoke=lambda **invoke_kwargs: invoke_kwargs["text"] == "chunk")
|
||||
factory = SimpleNamespace(get_model_type_instance=lambda **_factory_kwargs: moderation_model)
|
||||
mocker.patch("core.helper.moderation.create_plugin_model_provider_factory", return_value=factory)
|
||||
assembly = SimpleNamespace(create_model_type_instance=lambda **_factory_kwargs: moderation_model)
|
||||
mocker.patch("core.helper.moderation.create_plugin_model_assembly", return_value=assembly)
|
||||
|
||||
assert (
|
||||
check_moderation(
|
||||
@ -91,7 +91,7 @@ def test_check_moderation_returns_true_when_text_is_empty(mocker: MockerFixture)
|
||||
provider_map={openai_provider: hosting_openai},
|
||||
),
|
||||
)
|
||||
factory_mock = mocker.patch("core.helper.moderation.create_plugin_model_provider_factory")
|
||||
factory_mock = mocker.patch("core.helper.moderation.create_plugin_model_assembly")
|
||||
choice_mock = mocker.patch("core.helper.moderation.secrets.choice")
|
||||
|
||||
assert (
|
||||
@ -119,8 +119,8 @@ def test_check_moderation_returns_false_when_model_rejects_text(mocker: MockerFi
|
||||
mocker.patch("core.helper.moderation.secrets.choice", return_value="chunk")
|
||||
|
||||
moderation_model = SimpleNamespace(invoke=lambda **_invoke_kwargs: False)
|
||||
factory = SimpleNamespace(get_model_type_instance=lambda **_factory_kwargs: moderation_model)
|
||||
mocker.patch("core.helper.moderation.create_plugin_model_provider_factory", return_value=factory)
|
||||
assembly = SimpleNamespace(create_model_type_instance=lambda **_factory_kwargs: moderation_model)
|
||||
mocker.patch("core.helper.moderation.create_plugin_model_assembly", return_value=assembly)
|
||||
|
||||
assert (
|
||||
check_moderation(
|
||||
@ -147,8 +147,8 @@ def test_check_moderation_raises_bad_request_when_provider_call_fails(mocker: Mo
|
||||
failing_model = SimpleNamespace(
|
||||
invoke=lambda **_invoke_kwargs: (_ for _ in ()).throw(RuntimeError("boom")),
|
||||
)
|
||||
factory = SimpleNamespace(get_model_type_instance=lambda **_factory_kwargs: failing_model)
|
||||
mocker.patch("core.helper.moderation.create_plugin_model_provider_factory", return_value=factory)
|
||||
assembly = SimpleNamespace(create_model_type_instance=lambda **_factory_kwargs: failing_model)
|
||||
mocker.patch("core.helper.moderation.create_plugin_model_assembly", return_value=assembly)
|
||||
|
||||
with pytest.raises(InvokeBadRequestError, match="Rate limit exceeded, please try again later."):
|
||||
check_moderation(
|
||||
|
||||
@ -2,6 +2,7 @@ from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.plugin.impl.model_runtime_factory import create_model_type_instance
|
||||
from graphon.model_runtime.entities.common_entities import I18nObject
|
||||
from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
|
||||
from graphon.model_runtime.entities.provider_entities import (
|
||||
@ -73,7 +74,7 @@ def test_model_provider_factory_resolves_runtime_provider_name() -> None:
|
||||
supported_model_types=[ModelType.LLM],
|
||||
configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL],
|
||||
)
|
||||
factory = ModelProviderFactory(model_runtime=_FakeModelRuntime([provider]))
|
||||
factory = ModelProviderFactory(runtime=_FakeModelRuntime([provider]))
|
||||
|
||||
provider_schema = factory.get_model_provider("openai")
|
||||
|
||||
@ -98,7 +99,7 @@ def test_model_provider_factory_resolves_canonical_short_name_independent_of_pro
|
||||
configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL],
|
||||
),
|
||||
]
|
||||
factory = ModelProviderFactory(model_runtime=_FakeModelRuntime(providers))
|
||||
factory = ModelProviderFactory(runtime=_FakeModelRuntime(providers))
|
||||
|
||||
provider_schema = factory.get_model_provider("openai")
|
||||
|
||||
@ -107,8 +108,8 @@ def test_model_provider_factory_resolves_canonical_short_name_independent_of_pro
|
||||
|
||||
|
||||
def test_model_provider_factory_requires_runtime() -> None:
|
||||
with pytest.raises(ValueError, match="model_runtime is required"):
|
||||
ModelProviderFactory(model_runtime=None) # type: ignore[arg-type]
|
||||
with pytest.raises(ValueError, match="runtime is required"):
|
||||
ModelProviderFactory(runtime=None) # type: ignore[arg-type]
|
||||
|
||||
|
||||
def test_model_provider_factory_get_providers_returns_runtime_providers() -> None:
|
||||
@ -119,7 +120,7 @@ def test_model_provider_factory_get_providers_returns_runtime_providers() -> Non
|
||||
supported_model_types=[ModelType.LLM],
|
||||
)
|
||||
]
|
||||
factory = ModelProviderFactory(model_runtime=_FakeModelRuntime(providers))
|
||||
factory = ModelProviderFactory(runtime=_FakeModelRuntime(providers))
|
||||
|
||||
result = factory.get_providers()
|
||||
|
||||
@ -133,7 +134,7 @@ def test_model_provider_factory_get_provider_schema_delegates_to_provider_lookup
|
||||
provider_name="openai",
|
||||
supported_model_types=[ModelType.LLM],
|
||||
)
|
||||
factory = ModelProviderFactory(model_runtime=_FakeModelRuntime([provider]))
|
||||
factory = ModelProviderFactory(runtime=_FakeModelRuntime([provider]))
|
||||
|
||||
result = factory.get_provider_schema("openai")
|
||||
|
||||
@ -142,7 +143,7 @@ def test_model_provider_factory_get_provider_schema_delegates_to_provider_lookup
|
||||
|
||||
def test_model_provider_factory_raises_for_unknown_provider() -> None:
|
||||
factory = ModelProviderFactory(
|
||||
model_runtime=_FakeModelRuntime(
|
||||
runtime=_FakeModelRuntime(
|
||||
[
|
||||
_build_provider(
|
||||
provider="langgenius/openai/openai",
|
||||
@ -172,7 +173,7 @@ def test_model_provider_factory_get_models_filters_provider_and_model_type() ->
|
||||
models=[_build_model("rerank-v3", ModelType.RERANK)],
|
||||
),
|
||||
]
|
||||
factory = ModelProviderFactory(model_runtime=_FakeModelRuntime(providers))
|
||||
factory = ModelProviderFactory(runtime=_FakeModelRuntime(providers))
|
||||
|
||||
results = factory.get_models(provider="openai", model_type=ModelType.LLM)
|
||||
|
||||
@ -196,7 +197,7 @@ def test_model_provider_factory_get_models_skips_providers_without_requested_mod
|
||||
models=[_build_model("eleven_multilingual_v2", ModelType.TTS)],
|
||||
),
|
||||
]
|
||||
factory = ModelProviderFactory(model_runtime=_FakeModelRuntime(providers))
|
||||
factory = ModelProviderFactory(runtime=_FakeModelRuntime(providers))
|
||||
|
||||
results = factory.get_models(model_type=ModelType.TTS)
|
||||
|
||||
@ -214,7 +215,7 @@ def test_model_provider_factory_get_models_without_model_type_keeps_all_provider
|
||||
models=[_build_model("gpt-4o-mini", ModelType.LLM), _build_model("tts-1", ModelType.TTS)],
|
||||
)
|
||||
]
|
||||
factory = ModelProviderFactory(model_runtime=_FakeModelRuntime(providers))
|
||||
factory = ModelProviderFactory(runtime=_FakeModelRuntime(providers))
|
||||
|
||||
results = factory.get_models(provider="openai")
|
||||
|
||||
@ -242,7 +243,7 @@ def test_model_provider_factory_validates_provider_credentials() -> None:
|
||||
)
|
||||
]
|
||||
)
|
||||
factory = ModelProviderFactory(model_runtime=runtime)
|
||||
factory = ModelProviderFactory(runtime=runtime)
|
||||
|
||||
filtered = factory.provider_credentials_validate(
|
||||
provider="openai",
|
||||
@ -258,7 +259,7 @@ def test_model_provider_factory_validates_provider_credentials() -> None:
|
||||
|
||||
def test_model_provider_factory_provider_credentials_validate_requires_schema() -> None:
|
||||
factory = ModelProviderFactory(
|
||||
model_runtime=_FakeModelRuntime(
|
||||
runtime=_FakeModelRuntime(
|
||||
[
|
||||
_build_provider(
|
||||
provider="langgenius/openai/openai",
|
||||
@ -294,7 +295,7 @@ def test_model_provider_factory_validates_model_credentials() -> None:
|
||||
)
|
||||
]
|
||||
)
|
||||
factory = ModelProviderFactory(model_runtime=runtime)
|
||||
factory = ModelProviderFactory(runtime=runtime)
|
||||
|
||||
filtered = factory.model_credentials_validate(
|
||||
provider="openai",
|
||||
@ -314,7 +315,7 @@ def test_model_provider_factory_validates_model_credentials() -> None:
|
||||
|
||||
def test_model_provider_factory_model_credentials_validate_requires_schema() -> None:
|
||||
factory = ModelProviderFactory(
|
||||
model_runtime=_FakeModelRuntime(
|
||||
runtime=_FakeModelRuntime(
|
||||
[
|
||||
_build_provider(
|
||||
provider="langgenius/openai/openai",
|
||||
@ -346,7 +347,7 @@ def test_model_provider_factory_get_model_schema_and_icon_use_canonical_provider
|
||||
)
|
||||
runtime.get_model_schema.return_value = "schema"
|
||||
runtime.get_provider_icon.return_value = (b"icon", "image/png")
|
||||
factory = ModelProviderFactory(model_runtime=runtime)
|
||||
factory = ModelProviderFactory(runtime=runtime)
|
||||
|
||||
assert (
|
||||
factory.get_model_schema(
|
||||
@ -382,39 +383,43 @@ def test_model_provider_factory_get_model_schema_and_icon_use_canonical_provider
|
||||
(ModelType.TTS, TTSModel),
|
||||
],
|
||||
)
|
||||
def test_model_provider_factory_builds_model_type_instances(
|
||||
def test_create_model_type_instance_builds_model_wrappers(
|
||||
model_type: ModelType,
|
||||
expected_type: type[object],
|
||||
) -> None:
|
||||
factory = ModelProviderFactory(
|
||||
model_runtime=_FakeModelRuntime(
|
||||
[
|
||||
_build_provider(
|
||||
provider="langgenius/openai/openai",
|
||||
provider_name="openai",
|
||||
supported_model_types=[model_type],
|
||||
)
|
||||
]
|
||||
)
|
||||
runtime = _FakeModelRuntime(
|
||||
[
|
||||
_build_provider(
|
||||
provider="langgenius/openai/openai",
|
||||
provider_name="openai",
|
||||
supported_model_types=[model_type],
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
instance = factory.get_model_type_instance("openai", model_type)
|
||||
instance = create_model_type_instance(
|
||||
runtime=runtime,
|
||||
provider_schema=runtime.fetch_model_providers()[0],
|
||||
model_type=model_type,
|
||||
)
|
||||
|
||||
assert isinstance(instance, expected_type)
|
||||
|
||||
|
||||
def test_model_provider_factory_rejects_unsupported_model_type() -> None:
|
||||
factory = ModelProviderFactory(
|
||||
model_runtime=_FakeModelRuntime(
|
||||
[
|
||||
_build_provider(
|
||||
provider="langgenius/openai/openai",
|
||||
provider_name="openai",
|
||||
supported_model_types=[ModelType.LLM],
|
||||
)
|
||||
]
|
||||
)
|
||||
def test_create_model_type_instance_rejects_unsupported_model_type() -> None:
|
||||
runtime = _FakeModelRuntime(
|
||||
[
|
||||
_build_provider(
|
||||
provider="langgenius/openai/openai",
|
||||
provider_name="openai",
|
||||
supported_model_types=[ModelType.LLM],
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Unsupported model type: unsupported"):
|
||||
factory.get_model_type_instance("openai", "unsupported") # type: ignore[arg-type]
|
||||
create_model_type_instance(
|
||||
runtime=runtime,
|
||||
provider_schema=runtime.fetch_model_providers()[0],
|
||||
model_type="unsupported", # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
@ -31,6 +31,6 @@ def test_plugin_model_assembly_reuses_single_runtime_across_views():
|
||||
assert assembly.model_manager is model_manager
|
||||
|
||||
mock_runtime_factory.assert_called_once_with(tenant_id="tenant-1", user_id="user-1")
|
||||
mock_provider_factory_cls.assert_called_once_with(model_runtime=runtime)
|
||||
mock_provider_factory_cls.assert_called_once_with(runtime=runtime)
|
||||
mock_provider_manager_cls.assert_called_once_with(model_runtime=runtime)
|
||||
mock_model_manager_cls.assert_called_once_with(provider_manager=provider_manager)
|
||||
|
||||
@ -3,7 +3,7 @@
|
||||
import datetime
|
||||
import uuid
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock, sentinel
|
||||
from unittest.mock import Mock, patch, sentinel
|
||||
|
||||
import pytest
|
||||
|
||||
@ -13,6 +13,8 @@ from core.plugin.impl.model import PluginModelClient
|
||||
from core.plugin.impl.model_runtime import TENANT_SCOPE_SCHEMA_CACHE_USER_ID, PluginModelRuntime
|
||||
from core.plugin.impl.model_runtime_factory import create_plugin_model_runtime
|
||||
from graphon.model_runtime.entities.common_entities import I18nObject
|
||||
from graphon.model_runtime.entities.llm_entities import LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
||||
from graphon.model_runtime.entities.message_entities import AssistantPromptMessage
|
||||
from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
|
||||
from graphon.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity
|
||||
|
||||
@ -146,7 +148,31 @@ class TestPluginModelRuntime:
|
||||
|
||||
def test_invoke_llm_resolves_plugin_fields(self) -> None:
|
||||
client = Mock(spec=PluginModelClient)
|
||||
client.invoke_llm.return_value = sentinel.result
|
||||
usage = LLMUsage.empty_usage()
|
||||
client.invoke_llm.return_value = iter(
|
||||
[
|
||||
LLMResultChunk(
|
||||
model="gpt-4o-mini",
|
||||
prompt_messages=[],
|
||||
system_fingerprint="fp-plugin",
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(content="plugin "),
|
||||
),
|
||||
),
|
||||
LLMResultChunk(
|
||||
model="gpt-4o-mini",
|
||||
prompt_messages=[],
|
||||
system_fingerprint="fp-plugin",
|
||||
delta=LLMResultChunkDelta(
|
||||
index=1,
|
||||
message=AssistantPromptMessage(content="response"),
|
||||
usage=usage,
|
||||
finish_reason="stop",
|
||||
),
|
||||
),
|
||||
]
|
||||
)
|
||||
runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client)
|
||||
|
||||
result = runtime.invoke_llm(
|
||||
@ -160,7 +186,11 @@ class TestPluginModelRuntime:
|
||||
stream=False,
|
||||
)
|
||||
|
||||
assert result is sentinel.result
|
||||
assert result.model == "gpt-4o-mini"
|
||||
assert result.prompt_messages == []
|
||||
assert result.message.content == "plugin response"
|
||||
assert result.usage == usage
|
||||
assert result.system_fingerprint == "fp-plugin"
|
||||
client.invoke_llm.assert_called_once_with(
|
||||
tenant_id="tenant",
|
||||
user_id="user",
|
||||
@ -175,6 +205,38 @@ class TestPluginModelRuntime:
|
||||
stream=False,
|
||||
)
|
||||
|
||||
def test_invoke_llm_returns_plugin_stream_directly(self) -> None:
|
||||
client = Mock(spec=PluginModelClient)
|
||||
stream_result = iter([])
|
||||
client.invoke_llm.return_value = stream_result
|
||||
runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client)
|
||||
|
||||
result = runtime.invoke_llm(
|
||||
provider="langgenius/openai/openai",
|
||||
model="gpt-4o-mini",
|
||||
credentials={"api_key": "secret"},
|
||||
model_parameters={"temperature": 0.3},
|
||||
prompt_messages=[],
|
||||
tools=None,
|
||||
stop=("END",),
|
||||
stream=True,
|
||||
)
|
||||
|
||||
assert result is stream_result
|
||||
client.invoke_llm.assert_called_once_with(
|
||||
tenant_id="tenant",
|
||||
user_id="user",
|
||||
plugin_id="langgenius/openai",
|
||||
provider="openai",
|
||||
model="gpt-4o-mini",
|
||||
credentials={"api_key": "secret"},
|
||||
model_parameters={"temperature": 0.3},
|
||||
prompt_messages=[],
|
||||
tools=None,
|
||||
stop=["END"],
|
||||
stream=True,
|
||||
)
|
||||
|
||||
def test_invoke_llm_rejects_per_call_user_override(self) -> None:
|
||||
client = Mock(spec=PluginModelClient)
|
||||
client.invoke_llm.return_value = sentinel.result
|
||||
@ -267,6 +329,129 @@ def test_get_model_schema_uses_cached_schema_without_hitting_client(monkeypatch:
|
||||
client.get_model_schema.assert_not_called()
|
||||
|
||||
|
||||
def test_structured_output_adapter_invokes_bound_runtime_streaming() -> None:
|
||||
runtime = Mock()
|
||||
runtime.invoke_llm.return_value = sentinel.stream_result
|
||||
adapter = model_runtime_module._PluginStructuredOutputModelInstance(
|
||||
runtime=runtime,
|
||||
provider="langgenius/openai/openai",
|
||||
model="gpt-4o-mini",
|
||||
credentials={"api_key": "secret"},
|
||||
)
|
||||
tool = Mock()
|
||||
|
||||
result = adapter.invoke_llm(
|
||||
prompt_messages=[],
|
||||
model_parameters=None,
|
||||
tools=[tool],
|
||||
stop=["END"],
|
||||
stream=True,
|
||||
callbacks=sentinel.callbacks,
|
||||
)
|
||||
|
||||
assert result is sentinel.stream_result
|
||||
runtime.invoke_llm.assert_called_once_with(
|
||||
provider="langgenius/openai/openai",
|
||||
model="gpt-4o-mini",
|
||||
credentials={"api_key": "secret"},
|
||||
model_parameters={},
|
||||
prompt_messages=[],
|
||||
tools=[tool],
|
||||
stop=["END"],
|
||||
stream=True,
|
||||
)
|
||||
|
||||
|
||||
def test_structured_output_adapter_invokes_bound_runtime_non_streaming() -> None:
|
||||
runtime = Mock()
|
||||
runtime.invoke_llm.return_value = sentinel.result
|
||||
adapter = model_runtime_module._PluginStructuredOutputModelInstance(
|
||||
runtime=runtime,
|
||||
provider="langgenius/openai/openai",
|
||||
model="gpt-4o-mini",
|
||||
credentials={"api_key": "secret"},
|
||||
)
|
||||
|
||||
result = adapter.invoke_llm(
|
||||
prompt_messages=[],
|
||||
model_parameters={"temperature": 0},
|
||||
tools=None,
|
||||
stop=None,
|
||||
stream=False,
|
||||
)
|
||||
|
||||
assert result is sentinel.result
|
||||
runtime.invoke_llm.assert_called_once_with(
|
||||
provider="langgenius/openai/openai",
|
||||
model="gpt-4o-mini",
|
||||
credentials={"api_key": "secret"},
|
||||
model_parameters={"temperature": 0},
|
||||
prompt_messages=[],
|
||||
tools=None,
|
||||
stop=None,
|
||||
stream=False,
|
||||
)
|
||||
|
||||
|
||||
def test_invoke_llm_with_structured_output_delegates_with_bound_adapter() -> None:
|
||||
client = Mock(spec=PluginModelClient)
|
||||
runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client)
|
||||
schema = _build_model_schema()
|
||||
runtime.get_model_schema = Mock(return_value=schema) # type: ignore[method-assign]
|
||||
|
||||
with patch.object(
|
||||
model_runtime_module,
|
||||
"invoke_llm_with_structured_output_helper",
|
||||
return_value=sentinel.structured_result,
|
||||
) as mock_helper:
|
||||
result = runtime.invoke_llm_with_structured_output(
|
||||
provider="langgenius/openai/openai",
|
||||
model="gpt-4o-mini",
|
||||
credentials={"api_key": "secret"},
|
||||
json_schema={"type": "object"},
|
||||
model_parameters={"temperature": 0},
|
||||
prompt_messages=[],
|
||||
stop=("END",),
|
||||
stream=False,
|
||||
)
|
||||
|
||||
assert result is sentinel.structured_result
|
||||
runtime.get_model_schema.assert_called_once_with(
|
||||
provider="langgenius/openai/openai",
|
||||
model_type=ModelType.LLM,
|
||||
model="gpt-4o-mini",
|
||||
credentials={"api_key": "secret"},
|
||||
)
|
||||
helper_kwargs = mock_helper.call_args.kwargs
|
||||
assert helper_kwargs["provider"] == "langgenius/openai/openai"
|
||||
assert helper_kwargs["model_schema"] == schema
|
||||
assert helper_kwargs["json_schema"] == {"type": "object"}
|
||||
assert helper_kwargs["model_parameters"] == {"temperature": 0}
|
||||
assert helper_kwargs["prompt_messages"] == []
|
||||
assert helper_kwargs["tools"] is None
|
||||
assert helper_kwargs["stop"] == ["END"]
|
||||
assert helper_kwargs["stream"] is False
|
||||
assert isinstance(helper_kwargs["model_instance"], model_runtime_module._PluginStructuredOutputModelInstance)
|
||||
|
||||
|
||||
def test_invoke_llm_with_structured_output_raises_when_model_schema_is_missing() -> None:
|
||||
client = Mock(spec=PluginModelClient)
|
||||
runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client)
|
||||
runtime.get_model_schema = Mock(return_value=None) # type: ignore[method-assign]
|
||||
|
||||
with pytest.raises(ValueError, match="Model schema not found for gpt-4o-mini"):
|
||||
runtime.invoke_llm_with_structured_output(
|
||||
provider="langgenius/openai/openai",
|
||||
model="gpt-4o-mini",
|
||||
credentials={"api_key": "secret"},
|
||||
json_schema={"type": "object"},
|
||||
model_parameters={},
|
||||
prompt_messages=[],
|
||||
stop=None,
|
||||
stream=False,
|
||||
)
|
||||
|
||||
|
||||
def test_get_model_schema_deletes_invalid_cache_and_refetches(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
client = Mock(spec=PluginModelClient)
|
||||
schema = _build_model_schema()
|
||||
|
||||
@ -289,7 +289,7 @@ def test_get_default_model_uses_injected_runtime_for_existing_default_record(moc
|
||||
|
||||
result = manager.get_default_model("tenant-id", ModelType.LLM)
|
||||
|
||||
mock_factory_cls.assert_called_once_with(model_runtime=manager._model_runtime)
|
||||
mock_factory_cls.assert_called_once_with(runtime=manager._model_runtime)
|
||||
assert result is not None
|
||||
assert result.model == "gpt-4"
|
||||
assert result.provider.provider == "openai"
|
||||
@ -316,7 +316,7 @@ def test_get_configurations_uses_injected_runtime_and_adds_provider_aliases(mock
|
||||
result = manager.get_configurations("tenant-id")
|
||||
|
||||
expected_alias = str(ModelProviderID("openai"))
|
||||
mock_factory_cls.assert_called_once_with(model_runtime=manager._model_runtime)
|
||||
mock_factory_cls.assert_called_once_with(runtime=manager._model_runtime)
|
||||
assert result.tenant_id == "tenant-id"
|
||||
assert expected_alias in provider_records
|
||||
assert expected_alias in provider_model_records
|
||||
@ -402,7 +402,7 @@ def test_get_configurations_reuses_cached_result_for_same_tenant(mocker: MockerF
|
||||
|
||||
assert first is second
|
||||
mock_get_all_providers.assert_called_once_with("tenant-id")
|
||||
mock_factory_cls.assert_called_once_with(model_runtime=manager._model_runtime)
|
||||
mock_factory_cls.assert_called_once_with(runtime=manager._model_runtime)
|
||||
mock_provider_configuration.assert_called_once()
|
||||
provider_configuration.bind_model_runtime.assert_called_once_with(manager._model_runtime)
|
||||
|
||||
|
||||
@ -1,12 +1,11 @@
|
||||
import logging
|
||||
import threading
|
||||
from datetime import datetime
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from core.app.entities.app_invoke_entities import DifyRunContext, InvokeFrom, UserFrom
|
||||
from core.app.workflow.layers.llm_quota import LLMQuotaLayer
|
||||
from core.errors.error import QuotaExceededError
|
||||
from core.model_manager import ModelInstance
|
||||
from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus
|
||||
from graphon.graph_engine.entities.commands import CommandType
|
||||
from graphon.graph_events import NodeRunSucceededEvent
|
||||
@ -14,17 +13,7 @@ from graphon.model_runtime.entities.llm_entities import LLMUsage
|
||||
from graphon.node_events import NodeRunResult
|
||||
|
||||
|
||||
def _build_dify_context() -> DifyRunContext:
|
||||
return DifyRunContext(
|
||||
tenant_id="tenant-id",
|
||||
app_id="app-id",
|
||||
user_id="user-id",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
)
|
||||
|
||||
|
||||
def _build_succeeded_event() -> NodeRunSucceededEvent:
|
||||
def _build_succeeded_event(*, provider: str = "openai", model_name: str = "gpt-4o") -> NodeRunSucceededEvent:
|
||||
return NodeRunSucceededEvent(
|
||||
id="execution-id",
|
||||
node_id="llm-node-id",
|
||||
@ -32,113 +21,162 @@ def _build_succeeded_event() -> NodeRunSucceededEvent:
|
||||
start_at=datetime.now(),
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs={"question": "hello"},
|
||||
inputs={
|
||||
"question": "hello",
|
||||
"model_provider": provider,
|
||||
"model_name": model_name,
|
||||
},
|
||||
llm_usage=LLMUsage.empty_usage(),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _build_wrapped_model_instance() -> tuple[SimpleNamespace, ModelInstance]:
|
||||
raw_model_instance = ModelInstance.__new__(ModelInstance)
|
||||
return SimpleNamespace(_model_instance=raw_model_instance), raw_model_instance
|
||||
def _build_public_model_identity(*, provider: str = "openai", model_name: str = "gpt-4o") -> SimpleNamespace:
|
||||
return SimpleNamespace(provider=provider, name=model_name)
|
||||
|
||||
|
||||
def _build_node_data(*, model: SimpleNamespace | None = None) -> SimpleNamespace:
|
||||
return SimpleNamespace(
|
||||
error_strategy=None,
|
||||
retry_config=SimpleNamespace(retry_enabled=False),
|
||||
model=model,
|
||||
)
|
||||
|
||||
|
||||
def _build_node(*, node_type: BuiltinNodeTypes = BuiltinNodeTypes.LLM) -> MagicMock:
|
||||
node = MagicMock()
|
||||
node.id = "node-id"
|
||||
node.execution_id = "execution-id"
|
||||
node.node_type = node_type
|
||||
node.node_data = _build_node_data(model=_build_public_model_identity())
|
||||
node.model_instance = SimpleNamespace(provider="stale-provider", model_name="stale-model")
|
||||
return node
|
||||
|
||||
|
||||
class _RunnableQuotaNode:
|
||||
id = "node-id"
|
||||
execution_id = "execution-id"
|
||||
node_type = BuiltinNodeTypes.LLM
|
||||
title = "LLM node"
|
||||
|
||||
def __init__(self, *, stop_event: threading.Event, node_data: SimpleNamespace | None = None) -> None:
|
||||
self.node_data = node_data or _build_node_data(model=_build_public_model_identity())
|
||||
self.graph_runtime_state = SimpleNamespace(stop_event=stop_event)
|
||||
self.original_run_called = False
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
self.original_run_called = True
|
||||
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED)
|
||||
|
||||
|
||||
def test_deduct_quota_called_for_successful_llm_node() -> None:
|
||||
layer = LLMQuotaLayer()
|
||||
node = MagicMock()
|
||||
node.id = "llm-node-id"
|
||||
node.execution_id = "execution-id"
|
||||
node.node_type = BuiltinNodeTypes.LLM
|
||||
node.tenant_id = "tenant-id"
|
||||
node.require_run_context_value.return_value = _build_dify_context()
|
||||
node.model_instance, raw_model_instance = _build_wrapped_model_instance()
|
||||
|
||||
layer = LLMQuotaLayer(tenant_id="tenant-id")
|
||||
node = _build_node(node_type=BuiltinNodeTypes.LLM)
|
||||
result_event = _build_succeeded_event()
|
||||
with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota", autospec=True) as mock_deduct:
|
||||
|
||||
with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota_for_model", autospec=True) as mock_deduct:
|
||||
layer.on_node_run_end(node=node, error=None, result_event=result_event)
|
||||
|
||||
mock_deduct.assert_called_once_with(
|
||||
tenant_id="tenant-id",
|
||||
model_instance=raw_model_instance,
|
||||
provider="openai",
|
||||
model="gpt-4o",
|
||||
usage=result_event.node_run_result.llm_usage,
|
||||
)
|
||||
|
||||
|
||||
def test_deduct_quota_called_for_question_classifier_node() -> None:
|
||||
layer = LLMQuotaLayer()
|
||||
node = MagicMock()
|
||||
node.id = "question-classifier-node-id"
|
||||
node.execution_id = "execution-id"
|
||||
node.node_type = BuiltinNodeTypes.QUESTION_CLASSIFIER
|
||||
node.tenant_id = "tenant-id"
|
||||
node.require_run_context_value.return_value = _build_dify_context()
|
||||
node.model_instance, raw_model_instance = _build_wrapped_model_instance()
|
||||
layer = LLMQuotaLayer(tenant_id="tenant-id")
|
||||
node = _build_node(node_type=BuiltinNodeTypes.QUESTION_CLASSIFIER)
|
||||
result_event = _build_succeeded_event(provider="anthropic", model_name="claude-3-7-sonnet")
|
||||
|
||||
result_event = _build_succeeded_event()
|
||||
with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota", autospec=True) as mock_deduct:
|
||||
with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota_for_model", autospec=True) as mock_deduct:
|
||||
layer.on_node_run_end(node=node, error=None, result_event=result_event)
|
||||
|
||||
mock_deduct.assert_called_once_with(
|
||||
tenant_id="tenant-id",
|
||||
model_instance=raw_model_instance,
|
||||
provider="anthropic",
|
||||
model="claude-3-7-sonnet",
|
||||
usage=result_event.node_run_result.llm_usage,
|
||||
)
|
||||
|
||||
|
||||
def test_non_llm_node_is_ignored() -> None:
|
||||
layer = LLMQuotaLayer()
|
||||
node = MagicMock()
|
||||
node.id = "start-node-id"
|
||||
node.execution_id = "execution-id"
|
||||
node.node_type = BuiltinNodeTypes.START
|
||||
node.tenant_id = "tenant-id"
|
||||
node.require_run_context_value.return_value = _build_dify_context()
|
||||
node._model_instance = object()
|
||||
|
||||
layer = LLMQuotaLayer(tenant_id="tenant-id")
|
||||
node = _build_node(node_type=BuiltinNodeTypes.START)
|
||||
result_event = _build_succeeded_event()
|
||||
with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota", autospec=True) as mock_deduct:
|
||||
|
||||
with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota_for_model", autospec=True) as mock_deduct:
|
||||
layer.on_node_run_end(node=node, error=None, result_event=result_event)
|
||||
|
||||
mock_deduct.assert_not_called()
|
||||
|
||||
|
||||
def test_quota_error_is_handled_in_layer() -> None:
|
||||
layer = LLMQuotaLayer()
|
||||
node = MagicMock()
|
||||
node.id = "llm-node-id"
|
||||
node.execution_id = "execution-id"
|
||||
node.node_type = BuiltinNodeTypes.LLM
|
||||
node.tenant_id = "tenant-id"
|
||||
node.require_run_context_value.return_value = _build_dify_context()
|
||||
node.model_instance = object()
|
||||
def test_precheck_ignores_non_quota_node() -> None:
|
||||
layer = LLMQuotaLayer(tenant_id="tenant-id")
|
||||
node = _build_node(node_type=BuiltinNodeTypes.START)
|
||||
|
||||
result_event = _build_succeeded_event()
|
||||
with patch(
|
||||
"core.app.workflow.layers.llm_quota.deduct_llm_quota",
|
||||
autospec=True,
|
||||
side_effect=ValueError("quota exceeded"),
|
||||
):
|
||||
layer.on_node_run_end(node=node, error=None, result_event=result_event)
|
||||
with patch("core.app.workflow.layers.llm_quota.ensure_llm_quota_available_for_model", autospec=True) as mock_check:
|
||||
layer.on_node_run_start(node)
|
||||
|
||||
mock_check.assert_not_called()
|
||||
|
||||
|
||||
def test_quota_deduction_exceeded_aborts_workflow_immediately() -> None:
|
||||
layer = LLMQuotaLayer()
|
||||
def test_quota_error_is_handled_in_layer(caplog) -> None:
|
||||
layer = LLMQuotaLayer(tenant_id="tenant-id")
|
||||
stop_event = threading.Event()
|
||||
layer.command_channel = MagicMock()
|
||||
|
||||
node = MagicMock()
|
||||
node.id = "llm-node-id"
|
||||
node.execution_id = "execution-id"
|
||||
node.node_type = BuiltinNodeTypes.LLM
|
||||
node.tenant_id = "tenant-id"
|
||||
node.require_run_context_value.return_value = _build_dify_context()
|
||||
node.model_instance, _ = _build_wrapped_model_instance()
|
||||
node = _build_node(node_type=BuiltinNodeTypes.LLM)
|
||||
node.graph_runtime_state = MagicMock()
|
||||
node.graph_runtime_state.stop_event = stop_event
|
||||
result_event = _build_succeeded_event()
|
||||
|
||||
with (
|
||||
caplog.at_level(logging.ERROR, logger="core.app.workflow.layers.llm_quota"),
|
||||
patch(
|
||||
"core.app.workflow.layers.llm_quota.deduct_llm_quota_for_model",
|
||||
autospec=True,
|
||||
side_effect=ValueError("quota exceeded"),
|
||||
) as mock_deduct,
|
||||
):
|
||||
layer.on_node_run_end(node=node, error=None, result_event=result_event)
|
||||
|
||||
mock_deduct.assert_called_once_with(
|
||||
tenant_id="tenant-id",
|
||||
provider="openai",
|
||||
model="gpt-4o",
|
||||
usage=result_event.node_run_result.llm_usage,
|
||||
)
|
||||
assert "LLM quota deduction failed, node_id=node-id" in caplog.text
|
||||
assert not stop_event.is_set()
|
||||
layer.command_channel.send_command.assert_not_called()
|
||||
|
||||
|
||||
def test_send_abort_command_is_noop_without_channel_or_after_abort() -> None:
|
||||
layer = LLMQuotaLayer(tenant_id="tenant-id")
|
||||
|
||||
layer._send_abort_command(reason="no channel")
|
||||
|
||||
layer.command_channel = MagicMock()
|
||||
layer._abort_sent = True
|
||||
layer._send_abort_command(reason="already aborted")
|
||||
|
||||
layer.command_channel.send_command.assert_not_called()
|
||||
|
||||
|
||||
def test_quota_deduction_exceeded_aborts_workflow_immediately() -> None:
|
||||
layer = LLMQuotaLayer(tenant_id="tenant-id")
|
||||
stop_event = threading.Event()
|
||||
layer.command_channel = MagicMock()
|
||||
|
||||
node = _build_node(node_type=BuiltinNodeTypes.LLM)
|
||||
node.graph_runtime_state = MagicMock()
|
||||
node.graph_runtime_state.stop_event = stop_event
|
||||
|
||||
result_event = _build_succeeded_event()
|
||||
with patch(
|
||||
"core.app.workflow.layers.llm_quota.deduct_llm_quota",
|
||||
"core.app.workflow.layers.llm_quota.deduct_llm_quota_for_model",
|
||||
autospec=True,
|
||||
side_effect=QuotaExceededError("No credits remaining"),
|
||||
):
|
||||
@ -152,19 +190,16 @@ def test_quota_deduction_exceeded_aborts_workflow_immediately() -> None:
|
||||
|
||||
|
||||
def test_quota_precheck_failure_aborts_workflow_immediately() -> None:
|
||||
layer = LLMQuotaLayer()
|
||||
layer = LLMQuotaLayer(tenant_id="tenant-id")
|
||||
stop_event = threading.Event()
|
||||
layer.command_channel = MagicMock()
|
||||
|
||||
node = MagicMock()
|
||||
node.id = "llm-node-id"
|
||||
node.node_type = BuiltinNodeTypes.LLM
|
||||
node.model_instance, _ = _build_wrapped_model_instance()
|
||||
node = _build_node(node_type=BuiltinNodeTypes.LLM)
|
||||
node.graph_runtime_state = MagicMock()
|
||||
node.graph_runtime_state.stop_event = stop_event
|
||||
|
||||
with patch(
|
||||
"core.app.workflow.layers.llm_quota.ensure_llm_quota_available",
|
||||
"core.app.workflow.layers.llm_quota.ensure_llm_quota_available_for_model",
|
||||
autospec=True,
|
||||
side_effect=QuotaExceededError("Model provider openai quota exceeded."),
|
||||
):
|
||||
@ -177,21 +212,140 @@ def test_quota_precheck_failure_aborts_workflow_immediately() -> None:
|
||||
assert abort_command.reason == "Model provider openai quota exceeded."
|
||||
|
||||
|
||||
def test_quota_precheck_passes_without_abort() -> None:
|
||||
layer = LLMQuotaLayer()
|
||||
def test_quota_precheck_failure_blocks_current_node_run() -> None:
|
||||
layer = LLMQuotaLayer(tenant_id="tenant-id")
|
||||
stop_event = threading.Event()
|
||||
layer.command_channel = MagicMock()
|
||||
|
||||
node = MagicMock()
|
||||
node.id = "llm-node-id"
|
||||
node.node_type = BuiltinNodeTypes.LLM
|
||||
node.model_instance, raw_model_instance = _build_wrapped_model_instance()
|
||||
node = _RunnableQuotaNode(stop_event=stop_event)
|
||||
|
||||
with patch(
|
||||
"core.app.workflow.layers.llm_quota.ensure_llm_quota_available_for_model",
|
||||
autospec=True,
|
||||
side_effect=QuotaExceededError("Model provider openai quota exceeded."),
|
||||
):
|
||||
layer.on_node_run_start(node)
|
||||
|
||||
result = node._run()
|
||||
assert not node.original_run_called
|
||||
assert result.status == WorkflowNodeExecutionStatus.FAILED
|
||||
assert result.error == "Model provider openai quota exceeded."
|
||||
assert result.error_type == QuotaExceededError.__name__
|
||||
|
||||
|
||||
def test_missing_model_identity_blocks_current_node_run() -> None:
|
||||
layer = LLMQuotaLayer(tenant_id="tenant-id")
|
||||
stop_event = threading.Event()
|
||||
layer.command_channel = MagicMock()
|
||||
|
||||
node = _RunnableQuotaNode(stop_event=stop_event, node_data=_build_node_data())
|
||||
|
||||
with patch("core.app.workflow.layers.llm_quota.ensure_llm_quota_available_for_model", autospec=True) as mock_check:
|
||||
layer.on_node_run_start(node)
|
||||
|
||||
result = node._run()
|
||||
assert not node.original_run_called
|
||||
assert result.status == WorkflowNodeExecutionStatus.FAILED
|
||||
assert result.error == "LLM quota check requires public node model identity before execution."
|
||||
assert result.error_type == "LLMQuotaIdentityError"
|
||||
mock_check.assert_not_called()
|
||||
|
||||
|
||||
def test_quota_precheck_passes_without_abort() -> None:
|
||||
layer = LLMQuotaLayer(tenant_id="tenant-id")
|
||||
stop_event = threading.Event()
|
||||
layer.command_channel = MagicMock()
|
||||
|
||||
node = _build_node(node_type=BuiltinNodeTypes.LLM)
|
||||
node.graph_runtime_state = MagicMock()
|
||||
node.graph_runtime_state.stop_event = stop_event
|
||||
|
||||
with patch("core.app.workflow.layers.llm_quota.ensure_llm_quota_available", autospec=True) as mock_check:
|
||||
with patch("core.app.workflow.layers.llm_quota.ensure_llm_quota_available_for_model", autospec=True) as mock_check:
|
||||
layer.on_node_run_start(node)
|
||||
|
||||
assert not stop_event.is_set()
|
||||
mock_check.assert_called_once_with(model_instance=raw_model_instance)
|
||||
mock_check.assert_called_once_with(
|
||||
tenant_id="tenant-id",
|
||||
provider="openai",
|
||||
model="gpt-4o",
|
||||
)
|
||||
layer.command_channel.send_command.assert_not_called()
|
||||
|
||||
|
||||
def test_precheck_reads_model_identity_from_data_when_node_data_is_absent() -> None:
|
||||
layer = LLMQuotaLayer(tenant_id="tenant-id")
|
||||
node = SimpleNamespace(
|
||||
id="node-id",
|
||||
node_type=BuiltinNodeTypes.LLM,
|
||||
data=_build_node_data(model=_build_public_model_identity(provider="anthropic", model_name="claude")),
|
||||
)
|
||||
|
||||
with patch("core.app.workflow.layers.llm_quota.ensure_llm_quota_available_for_model", autospec=True) as mock_check:
|
||||
layer.on_node_run_start(node)
|
||||
|
||||
mock_check.assert_called_once_with(
|
||||
tenant_id="tenant-id",
|
||||
provider="anthropic",
|
||||
model="claude",
|
||||
)
|
||||
|
||||
|
||||
def test_precheck_rejects_invalid_public_model_identity() -> None:
|
||||
layer = LLMQuotaLayer(tenant_id="tenant-id")
|
||||
stop_event = threading.Event()
|
||||
layer.command_channel = MagicMock()
|
||||
|
||||
node = _build_node(node_type=BuiltinNodeTypes.LLM)
|
||||
node.node_data = _build_node_data(model=_build_public_model_identity(provider="", model_name="gpt-4o"))
|
||||
node.graph_runtime_state = MagicMock()
|
||||
node.graph_runtime_state.stop_event = stop_event
|
||||
|
||||
with patch("core.app.workflow.layers.llm_quota.ensure_llm_quota_available_for_model", autospec=True) as mock_check:
|
||||
layer.on_node_run_start(node)
|
||||
|
||||
assert stop_event.is_set()
|
||||
mock_check.assert_not_called()
|
||||
layer.command_channel.send_command.assert_called_once()
|
||||
|
||||
|
||||
def test_precheck_requires_public_node_model_config() -> None:
|
||||
layer = LLMQuotaLayer(tenant_id="tenant-id")
|
||||
stop_event = threading.Event()
|
||||
layer.command_channel = MagicMock()
|
||||
|
||||
node = _build_node(node_type=BuiltinNodeTypes.LLM)
|
||||
node.node_data = _build_node_data()
|
||||
node.graph_runtime_state = MagicMock()
|
||||
node.graph_runtime_state.stop_event = stop_event
|
||||
|
||||
with patch("core.app.workflow.layers.llm_quota.ensure_llm_quota_available_for_model", autospec=True) as mock_check:
|
||||
layer.on_node_run_start(node)
|
||||
|
||||
assert stop_event.is_set()
|
||||
mock_check.assert_not_called()
|
||||
layer.command_channel.send_command.assert_called_once()
|
||||
abort_command = layer.command_channel.send_command.call_args.args[0]
|
||||
assert abort_command.command_type == CommandType.ABORT
|
||||
assert abort_command.reason == "LLM quota check requires public node model identity before execution."
|
||||
|
||||
|
||||
def test_deduction_requires_public_event_model_identity() -> None:
|
||||
layer = LLMQuotaLayer(tenant_id="tenant-id")
|
||||
stop_event = threading.Event()
|
||||
layer.command_channel = MagicMock()
|
||||
|
||||
node = _build_node(node_type=BuiltinNodeTypes.LLM)
|
||||
node.graph_runtime_state = MagicMock()
|
||||
node.graph_runtime_state.stop_event = stop_event
|
||||
result_event = _build_succeeded_event()
|
||||
result_event.node_run_result.inputs = {"question": "hello"}
|
||||
|
||||
with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota_for_model", autospec=True) as mock_deduct:
|
||||
layer.on_node_run_end(node=node, error=None, result_event=result_event)
|
||||
|
||||
assert stop_event.is_set()
|
||||
mock_deduct.assert_not_called()
|
||||
layer.command_channel.send_command.assert_called_once()
|
||||
abort_command = layer.command_channel.send_command.call_args.args[0]
|
||||
assert abort_command.command_type == CommandType.ABORT
|
||||
assert abort_command.reason == "LLM quota deduction requires model identity in the node result event."
|
||||
|
||||
@ -96,7 +96,7 @@ class MockNodeFactory(DifyNodeFactory):
|
||||
if node_type == BuiltinNodeTypes.CODE:
|
||||
mock_instance = mock_class(
|
||||
node_id=node_id,
|
||||
config=resolved_node_data,
|
||||
data=resolved_node_data,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
mock_config=self.mock_config,
|
||||
@ -106,7 +106,7 @@ class MockNodeFactory(DifyNodeFactory):
|
||||
elif node_type == BuiltinNodeTypes.HTTP_REQUEST:
|
||||
mock_instance = mock_class(
|
||||
node_id=node_id,
|
||||
config=resolved_node_data,
|
||||
data=resolved_node_data,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
mock_config=self.mock_config,
|
||||
@ -122,7 +122,7 @@ class MockNodeFactory(DifyNodeFactory):
|
||||
}:
|
||||
mock_instance = mock_class(
|
||||
node_id=node_id,
|
||||
config=resolved_node_data,
|
||||
data=resolved_node_data,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
mock_config=self.mock_config,
|
||||
@ -132,7 +132,7 @@ class MockNodeFactory(DifyNodeFactory):
|
||||
else:
|
||||
mock_instance = mock_class(
|
||||
node_id=node_id,
|
||||
config=resolved_node_data,
|
||||
data=resolved_node_data,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
mock_config=self.mock_config,
|
||||
|
||||
@ -56,7 +56,7 @@ class MockNodeMixin:
|
||||
def __init__(
|
||||
self,
|
||||
node_id: str,
|
||||
config: Any,
|
||||
data: Any,
|
||||
*,
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
@ -98,7 +98,7 @@ class MockNodeMixin:
|
||||
|
||||
super().__init__(
|
||||
node_id=node_id,
|
||||
config=config,
|
||||
data=data,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
**kwargs,
|
||||
|
||||
@ -111,7 +111,7 @@ class StaticRepo(HumanInputFormRepository):
|
||||
|
||||
|
||||
def _build_runtime_state() -> GraphRuntimeState:
|
||||
variable_pool = VariablePool(
|
||||
variable_pool = VariablePool.from_bootstrap(
|
||||
system_variables=build_system_variables(
|
||||
user_id="user",
|
||||
app_id="app",
|
||||
@ -140,7 +140,7 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor
|
||||
start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()}
|
||||
start_node = StartNode(
|
||||
node_id=start_config["id"],
|
||||
config=StartNodeData(title="Start", variables=[]),
|
||||
data=StartNodeData(title="Start", variables=[]),
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
@ -155,7 +155,7 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor
|
||||
human_a_config = {"id": "human_a", "data": human_data.model_dump()}
|
||||
human_a = HumanInputNode(
|
||||
node_id=human_a_config["id"],
|
||||
config=human_data,
|
||||
data=human_data,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=runtime_state,
|
||||
form_repository=repo,
|
||||
@ -165,7 +165,7 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor
|
||||
human_b_config = {"id": "human_b", "data": human_data.model_dump()}
|
||||
human_b = HumanInputNode(
|
||||
node_id=human_b_config["id"],
|
||||
config=human_data,
|
||||
data=human_data,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=runtime_state,
|
||||
form_repository=repo,
|
||||
@ -183,7 +183,7 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor
|
||||
end_config = {"id": "end", "data": end_data.model_dump()}
|
||||
end_node = EndNode(
|
||||
node_id=end_config["id"],
|
||||
config=end_data,
|
||||
data=end_data,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
|
||||
@ -1,41 +1,36 @@
|
||||
import time
|
||||
import uuid
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
|
||||
from core.workflow.node_factory import DifyNodeFactory
|
||||
from core.workflow.system_variables import build_system_variables
|
||||
from extensions.ext_database import db
|
||||
from graphon.enums import WorkflowNodeExecutionStatus
|
||||
from graphon.graph import Graph
|
||||
from graphon.nodes.answer.answer_node import AnswerNode
|
||||
from graphon.nodes.answer.entities import AnswerNodeData
|
||||
from graphon.runtime import GraphRuntimeState, VariablePool
|
||||
from tests.workflow_test_utils import build_test_graph_init_params
|
||||
|
||||
|
||||
def test_execute_answer():
|
||||
def _build_variable_pool() -> VariablePool:
|
||||
return VariablePool.from_bootstrap(
|
||||
system_variables=build_system_variables(user_id="aaa", files=[]),
|
||||
user_inputs={},
|
||||
)
|
||||
|
||||
|
||||
def _build_answer_node(*, answer: str, variable_pool: VariablePool) -> AnswerNode:
|
||||
graph_config = {
|
||||
"edges": [
|
||||
{
|
||||
"id": "start-source-answer-target",
|
||||
"source": "start",
|
||||
"target": "answer",
|
||||
},
|
||||
],
|
||||
"edges": [],
|
||||
"nodes": [
|
||||
{"data": {"type": "start", "title": "Start"}, "id": "start"},
|
||||
{
|
||||
"data": {
|
||||
"title": "123",
|
||||
"title": "Answer",
|
||||
"type": "answer",
|
||||
"answer": "Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.",
|
||||
"answer": answer,
|
||||
},
|
||||
"id": "answer",
|
||||
},
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
init_params = build_test_graph_init_params(
|
||||
workflow_id="1",
|
||||
graph_config=graph_config,
|
||||
@ -46,42 +41,31 @@ def test_execute_answer():
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
# construct variable pool
|
||||
variable_pool = VariablePool(
|
||||
system_variables=build_system_variables(user_id="aaa", files=[]),
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
conversation_variables=[],
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=variable_pool,
|
||||
start_at=time.perf_counter(),
|
||||
)
|
||||
variable_pool.add(["start", "weather"], "sunny")
|
||||
variable_pool.add(["llm", "text"], "You are a helpful AI.")
|
||||
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
# create node factory
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start")
|
||||
|
||||
node = AnswerNode(
|
||||
return AnswerNode(
|
||||
node_id=str(uuid.uuid4()),
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
config=AnswerNodeData(
|
||||
title="123",
|
||||
data=AnswerNodeData(
|
||||
title="Answer",
|
||||
type="answer",
|
||||
answer="Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.",
|
||||
answer=answer,
|
||||
),
|
||||
)
|
||||
|
||||
# Mock db.session.close()
|
||||
db.session.close = MagicMock()
|
||||
|
||||
# execute node
|
||||
def test_execute_answer_renders_variable_selectors() -> None:
|
||||
variable_pool = _build_variable_pool()
|
||||
variable_pool.add(["start", "weather"], "sunny")
|
||||
variable_pool.add(["llm", "text"], "You are a helpful AI.")
|
||||
node = _build_answer_node(
|
||||
answer="Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.",
|
||||
variable_pool=variable_pool,
|
||||
)
|
||||
|
||||
result = node._run()
|
||||
|
||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
@ -89,36 +73,11 @@ def test_execute_answer():
|
||||
|
||||
|
||||
def test_execute_answer_renders_structured_output_object_as_json() -> None:
|
||||
init_params = build_test_graph_init_params(
|
||||
workflow_id="1",
|
||||
graph_config={"nodes": [], "edges": []},
|
||||
tenant_id="1",
|
||||
app_id="1",
|
||||
user_id="1",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables=build_system_variables(user_id="aaa", files=[]),
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
conversation_variables=[],
|
||||
)
|
||||
variable_pool = _build_variable_pool()
|
||||
variable_pool.add(["1777539038857", "structured_output"], {"type": "greeting"})
|
||||
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
node = AnswerNode(
|
||||
node_id=str(uuid.uuid4()),
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
config=AnswerNodeData(
|
||||
title="123",
|
||||
type="answer",
|
||||
answer="{{#1777539038857.structured_output#}}",
|
||||
),
|
||||
node = _build_answer_node(
|
||||
answer="{{#1777539038857.structured_output#}}",
|
||||
variable_pool=variable_pool,
|
||||
)
|
||||
|
||||
result = node._run()
|
||||
@ -128,35 +87,9 @@ def test_execute_answer_renders_structured_output_object_as_json() -> None:
|
||||
|
||||
|
||||
def test_execute_answer_falls_back_to_plain_selector_text_when_structured_output_missing() -> None:
|
||||
init_params = build_test_graph_init_params(
|
||||
workflow_id="1",
|
||||
graph_config={"nodes": [], "edges": []},
|
||||
tenant_id="1",
|
||||
app_id="1",
|
||||
user_id="1",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables=build_system_variables(user_id="aaa", files=[]),
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
conversation_variables=[],
|
||||
)
|
||||
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
node = AnswerNode(
|
||||
node_id=str(uuid.uuid4()),
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
config=AnswerNodeData(
|
||||
title="123",
|
||||
type="answer",
|
||||
answer="{{#1777539038857.structured_output#}}",
|
||||
),
|
||||
node = _build_answer_node(
|
||||
answer="{{#1777539038857.structured_output#}}",
|
||||
variable_pool=_build_variable_pool(),
|
||||
)
|
||||
|
||||
result = node._run()
|
||||
|
||||
@ -81,7 +81,7 @@ def test_datasource_node_delegates_to_manager_stream(mocker: MockerFixture):
|
||||
|
||||
node = DatasourceNode(
|
||||
node_id="n",
|
||||
config=DatasourceNodeData(
|
||||
data=DatasourceNodeData(
|
||||
type="datasource",
|
||||
version="1",
|
||||
title="Datasource",
|
||||
|
||||
@ -29,7 +29,7 @@ HTTP_REQUEST_CONFIG = HttpRequestNodeConfig(
|
||||
|
||||
def test_executor_with_json_body_and_number_variable():
|
||||
# Prepare the variable pool
|
||||
variable_pool = VariablePool(
|
||||
variable_pool = VariablePool.from_bootstrap(
|
||||
system_variables=default_system_variables(),
|
||||
user_inputs={},
|
||||
)
|
||||
@ -85,7 +85,7 @@ def test_executor_with_json_body_and_number_variable():
|
||||
|
||||
def test_executor_with_json_body_and_object_variable():
|
||||
# Prepare the variable pool
|
||||
variable_pool = VariablePool(
|
||||
variable_pool = VariablePool.from_bootstrap(
|
||||
system_variables=default_system_variables(),
|
||||
user_inputs={},
|
||||
)
|
||||
@ -143,7 +143,7 @@ def test_executor_with_json_body_and_object_variable():
|
||||
|
||||
def test_executor_with_json_body_and_nested_object_variable():
|
||||
# Prepare the variable pool
|
||||
variable_pool = VariablePool(
|
||||
variable_pool = VariablePool.from_bootstrap(
|
||||
system_variables=default_system_variables(),
|
||||
user_inputs={},
|
||||
)
|
||||
@ -201,7 +201,7 @@ def test_executor_with_json_body_and_nested_object_variable():
|
||||
|
||||
|
||||
def test_extract_selectors_from_template_with_newline():
|
||||
variable_pool = VariablePool(system_variables=default_system_variables())
|
||||
variable_pool = VariablePool.from_bootstrap(system_variables=default_system_variables())
|
||||
variable_pool.add(("node_id", "custom_query"), "line1\nline2")
|
||||
node_data = HttpRequestNodeData(
|
||||
title="Test JSON Body with Nested Object Variable",
|
||||
@ -230,7 +230,7 @@ def test_extract_selectors_from_template_with_newline():
|
||||
|
||||
def test_executor_with_form_data():
|
||||
# Prepare the variable pool
|
||||
variable_pool = VariablePool(
|
||||
variable_pool = VariablePool.from_bootstrap(
|
||||
system_variables=default_system_variables(),
|
||||
user_inputs={},
|
||||
)
|
||||
@ -320,7 +320,7 @@ def test_init_headers():
|
||||
node_data=node_data,
|
||||
timeout=timeout,
|
||||
http_request_config=HTTP_REQUEST_CONFIG,
|
||||
variable_pool=VariablePool(system_variables=default_system_variables()),
|
||||
variable_pool=VariablePool.from_bootstrap(system_variables=default_system_variables()),
|
||||
http_client=ssrf_proxy,
|
||||
file_manager=file_manager,
|
||||
)
|
||||
@ -357,7 +357,7 @@ def test_init_params():
|
||||
node_data=node_data,
|
||||
timeout=timeout,
|
||||
http_request_config=HTTP_REQUEST_CONFIG,
|
||||
variable_pool=VariablePool(system_variables=default_system_variables()),
|
||||
variable_pool=VariablePool.from_bootstrap(system_variables=default_system_variables()),
|
||||
http_client=ssrf_proxy,
|
||||
file_manager=file_manager,
|
||||
)
|
||||
@ -390,7 +390,7 @@ def test_init_params():
|
||||
|
||||
def test_empty_api_key_raises_error_bearer():
|
||||
"""Test that empty API key raises AuthorizationConfigError for bearer auth."""
|
||||
variable_pool = VariablePool(system_variables=default_system_variables())
|
||||
variable_pool = VariablePool.from_bootstrap(system_variables=default_system_variables())
|
||||
node_data = HttpRequestNodeData(
|
||||
title="test",
|
||||
method="get",
|
||||
@ -417,7 +417,7 @@ def test_empty_api_key_raises_error_bearer():
|
||||
|
||||
def test_empty_api_key_raises_error_basic():
|
||||
"""Test that empty API key raises AuthorizationConfigError for basic auth."""
|
||||
variable_pool = VariablePool(system_variables=default_system_variables())
|
||||
variable_pool = VariablePool.from_bootstrap(system_variables=default_system_variables())
|
||||
node_data = HttpRequestNodeData(
|
||||
title="test",
|
||||
method="get",
|
||||
@ -444,7 +444,7 @@ def test_empty_api_key_raises_error_basic():
|
||||
|
||||
def test_empty_api_key_raises_error_custom():
|
||||
"""Test that empty API key raises AuthorizationConfigError for custom auth."""
|
||||
variable_pool = VariablePool(system_variables=default_system_variables())
|
||||
variable_pool = VariablePool.from_bootstrap(system_variables=default_system_variables())
|
||||
node_data = HttpRequestNodeData(
|
||||
title="test",
|
||||
method="get",
|
||||
@ -471,7 +471,7 @@ def test_empty_api_key_raises_error_custom():
|
||||
|
||||
def test_whitespace_only_api_key_raises_error():
|
||||
"""Test that whitespace-only API key raises AuthorizationConfigError."""
|
||||
variable_pool = VariablePool(system_variables=default_system_variables())
|
||||
variable_pool = VariablePool.from_bootstrap(system_variables=default_system_variables())
|
||||
node_data = HttpRequestNodeData(
|
||||
title="test",
|
||||
method="get",
|
||||
@ -498,7 +498,7 @@ def test_whitespace_only_api_key_raises_error():
|
||||
|
||||
def test_valid_api_key_works():
|
||||
"""Test that valid API key works correctly for bearer auth."""
|
||||
variable_pool = VariablePool(system_variables=default_system_variables())
|
||||
variable_pool = VariablePool.from_bootstrap(system_variables=default_system_variables())
|
||||
node_data = HttpRequestNodeData(
|
||||
title="test",
|
||||
method="get",
|
||||
@ -536,7 +536,7 @@ def test_executor_with_json_body_and_unquoted_uuid_variable():
|
||||
# UUID that triggers the json_repair truncation bug
|
||||
test_uuid = "57eeeeb1-450b-482c-81b9-4be77e95dee2"
|
||||
|
||||
variable_pool = VariablePool(
|
||||
variable_pool = VariablePool.from_bootstrap(
|
||||
system_variables=default_system_variables(),
|
||||
user_inputs={},
|
||||
)
|
||||
@ -583,7 +583,7 @@ def test_executor_with_json_body_and_unquoted_uuid_with_newlines():
|
||||
"""
|
||||
test_uuid = "57eeeeb1-450b-482c-81b9-4be77e95dee2"
|
||||
|
||||
variable_pool = VariablePool(
|
||||
variable_pool = VariablePool.from_bootstrap(
|
||||
system_variables=default_system_variables(),
|
||||
user_inputs={},
|
||||
)
|
||||
@ -624,7 +624,7 @@ def test_executor_with_json_body_and_unquoted_uuid_with_newlines():
|
||||
|
||||
def test_executor_with_json_body_preserves_numbers_and_strings():
|
||||
"""Test that numbers are preserved and string values are properly quoted."""
|
||||
variable_pool = VariablePool(
|
||||
variable_pool = VariablePool.from_bootstrap(
|
||||
system_variables=default_system_variables(),
|
||||
user_inputs={},
|
||||
)
|
||||
|
||||
@ -110,12 +110,15 @@ def _build_http_node(
|
||||
call_depth=0,
|
||||
)
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(system_variables=build_system_variables(user_id="user", files=[]), user_inputs={}),
|
||||
variable_pool=VariablePool.from_bootstrap(
|
||||
system_variables=build_system_variables(user_id="user", files=[]),
|
||||
user_inputs={},
|
||||
),
|
||||
start_at=time.perf_counter(),
|
||||
)
|
||||
return HttpRequestNode(
|
||||
node_id="http-node",
|
||||
config=HttpRequestNodeData.model_validate(node_data),
|
||||
data=HttpRequestNodeData.model_validate(node_data),
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
http_request_config=HTTP_REQUEST_CONFIG,
|
||||
|
||||
@ -149,7 +149,7 @@ def _build_human_input_node(
|
||||
)
|
||||
return HumanInputNode(
|
||||
node_id=node_id,
|
||||
config=typed_node_data,
|
||||
data=typed_node_data,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
runtime=runtime,
|
||||
@ -241,16 +241,16 @@ class TestUserAction:
|
||||
|
||||
def test_user_action_length_boundaries(self):
|
||||
"""Test user action id and title length boundaries."""
|
||||
action = UserAction(id="a" * 20, title="b" * 20)
|
||||
action = UserAction(id="a" * 20, title="b" * 100)
|
||||
|
||||
assert action.id == "a" * 20
|
||||
assert action.title == "b" * 20
|
||||
assert action.title == "b" * 100
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("field_name", "value"),
|
||||
[
|
||||
("id", "a" * 21),
|
||||
("title", "b" * 21),
|
||||
("title", "b" * 101),
|
||||
],
|
||||
)
|
||||
def test_user_action_length_limits(self, field_name: str, value: str):
|
||||
@ -427,7 +427,7 @@ class TestHumanInputNodeVariableResolution:
|
||||
"""Tests for resolving variable-based defaults in HumanInputNode."""
|
||||
|
||||
def test_resolves_variable_defaults(self):
|
||||
variable_pool = VariablePool(
|
||||
variable_pool = VariablePool.from_bootstrap(
|
||||
system_variables=build_system_variables(
|
||||
user_id="user",
|
||||
app_id="app",
|
||||
@ -504,7 +504,7 @@ class TestHumanInputNodeVariableResolution:
|
||||
assert params.resolved_default_values == expected_values
|
||||
|
||||
def test_debugger_falls_back_to_recipient_token_when_webapp_disabled(self):
|
||||
variable_pool = VariablePool(
|
||||
variable_pool = VariablePool.from_bootstrap(
|
||||
system_variables=build_system_variables(
|
||||
user_id="user",
|
||||
app_id="app",
|
||||
@ -565,7 +565,7 @@ class TestHumanInputNodeVariableResolution:
|
||||
assert not hasattr(pause_event.reason, "form_token")
|
||||
|
||||
def test_webapp_runtime_keeps_form_visible_in_ui_when_webapp_delivery_is_enabled(self):
|
||||
variable_pool = VariablePool(
|
||||
variable_pool = VariablePool.from_bootstrap(
|
||||
system_variables=build_system_variables(
|
||||
user_id="user",
|
||||
app_id="app",
|
||||
@ -631,7 +631,7 @@ class TestHumanInputNodeVariableResolution:
|
||||
assert params.display_in_ui is True
|
||||
|
||||
def test_debugger_debug_mode_overrides_email_recipients(self):
|
||||
variable_pool = VariablePool(
|
||||
variable_pool = VariablePool.from_bootstrap(
|
||||
system_variables=build_system_variables(
|
||||
user_id="user-123",
|
||||
app_id="app",
|
||||
@ -748,7 +748,7 @@ class TestHumanInputNodeRenderedContent:
|
||||
"""Tests for rendering submitted content."""
|
||||
|
||||
def test_replaces_outputs_placeholders_after_submission(self):
|
||||
variable_pool = VariablePool(
|
||||
variable_pool = VariablePool.from_bootstrap(
|
||||
system_variables=build_system_variables(
|
||||
user_id="user",
|
||||
app_id="app",
|
||||
|
||||
@ -40,7 +40,7 @@ def _create_human_input_node(
|
||||
)
|
||||
return HumanInputNode(
|
||||
node_id=config["id"],
|
||||
config=node_data,
|
||||
data=node_data,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
form_repository=repo,
|
||||
@ -51,7 +51,11 @@ def _create_human_input_node(
|
||||
def _build_node(form_content: str = "Please enter your name:\n\n{{#$output.name#}}") -> HumanInputNode:
|
||||
system_variables = default_system_variables()
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(system_variables=system_variables, user_inputs={}, environment_variables=[]),
|
||||
variable_pool=VariablePool.from_bootstrap(
|
||||
system_variables=system_variables,
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
),
|
||||
start_at=0.0,
|
||||
)
|
||||
graph_init_params = GraphInitParams(
|
||||
@ -114,7 +118,11 @@ def _build_node(form_content: str = "Please enter your name:\n\n{{#$output.name#
|
||||
def _build_timeout_node() -> HumanInputNode:
|
||||
system_variables = default_system_variables()
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(system_variables=system_variables, user_inputs={}, environment_variables=[]),
|
||||
variable_pool=VariablePool.from_bootstrap(
|
||||
system_variables=system_variables,
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
),
|
||||
start_at=0.0,
|
||||
)
|
||||
graph_init_params = GraphInitParams(
|
||||
|
||||
@ -32,7 +32,7 @@ class _MissingGraphBuilder:
|
||||
|
||||
def _build_runtime_state() -> GraphRuntimeState:
|
||||
return GraphRuntimeState(
|
||||
variable_pool=VariablePool(system_variables=default_system_variables(), user_inputs={}),
|
||||
variable_pool=VariablePool.from_bootstrap(system_variables=default_system_variables(), user_inputs={}),
|
||||
start_at=0.0,
|
||||
)
|
||||
|
||||
@ -46,7 +46,7 @@ def _build_iteration_node(
|
||||
init_params = build_test_graph_init_params(graph_config=graph_config)
|
||||
return IterationNode(
|
||||
node_id="iteration-node",
|
||||
config=IterationNodeData(
|
||||
data=IterationNodeData(
|
||||
type="iteration",
|
||||
title="Iteration",
|
||||
iterator_selector=["start", "items"],
|
||||
|
||||
@ -41,7 +41,7 @@ def mock_graph_init_params():
|
||||
@pytest.fixture
|
||||
def mock_graph_runtime_state():
|
||||
"""Create mock GraphRuntimeState."""
|
||||
variable_pool = VariablePool(
|
||||
variable_pool = VariablePool.from_bootstrap(
|
||||
system_variables=build_system_variables(user_id=str(uuid.uuid4()), files=[]),
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
@ -103,7 +103,7 @@ def _build_node(
|
||||
) -> KnowledgeIndexNode:
|
||||
return KnowledgeIndexNode(
|
||||
node_id=node_id,
|
||||
config=(
|
||||
data=(
|
||||
node_data
|
||||
if isinstance(node_data, KnowledgeIndexNodeData)
|
||||
else KnowledgeIndexNodeData.model_validate(node_data)
|
||||
|
||||
@ -47,7 +47,7 @@ def mock_graph_init_params():
|
||||
@pytest.fixture
|
||||
def mock_graph_runtime_state():
|
||||
"""Create mock GraphRuntimeState."""
|
||||
variable_pool = VariablePool(
|
||||
variable_pool = VariablePool.from_bootstrap(
|
||||
system_variables=build_system_variables(user_id=str(uuid.uuid4()), files=[]),
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
@ -118,7 +118,7 @@ class TestKnowledgeRetrievalNode:
|
||||
# Act
|
||||
node = KnowledgeRetrievalNode(
|
||||
node_id=node_id,
|
||||
config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
|
||||
data=KnowledgeRetrievalNodeData.model_validate(config["data"]),
|
||||
graph_init_params=mock_graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
)
|
||||
@ -147,7 +147,7 @@ class TestKnowledgeRetrievalNode:
|
||||
|
||||
node = KnowledgeRetrievalNode(
|
||||
node_id=node_id,
|
||||
config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
|
||||
data=KnowledgeRetrievalNodeData.model_validate(config["data"]),
|
||||
graph_init_params=mock_graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
)
|
||||
@ -206,7 +206,7 @@ class TestKnowledgeRetrievalNode:
|
||||
|
||||
node = KnowledgeRetrievalNode(
|
||||
node_id=node_id,
|
||||
config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
|
||||
data=KnowledgeRetrievalNodeData.model_validate(config["data"]),
|
||||
graph_init_params=mock_graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
)
|
||||
@ -250,7 +250,7 @@ class TestKnowledgeRetrievalNode:
|
||||
|
||||
node = KnowledgeRetrievalNode(
|
||||
node_id=node_id,
|
||||
config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
|
||||
data=KnowledgeRetrievalNodeData.model_validate(config["data"]),
|
||||
graph_init_params=mock_graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
)
|
||||
@ -286,7 +286,7 @@ class TestKnowledgeRetrievalNode:
|
||||
|
||||
node = KnowledgeRetrievalNode(
|
||||
node_id=node_id,
|
||||
config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
|
||||
data=KnowledgeRetrievalNodeData.model_validate(config["data"]),
|
||||
graph_init_params=mock_graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
)
|
||||
@ -321,7 +321,7 @@ class TestKnowledgeRetrievalNode:
|
||||
|
||||
node = KnowledgeRetrievalNode(
|
||||
node_id=node_id,
|
||||
config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
|
||||
data=KnowledgeRetrievalNodeData.model_validate(config["data"]),
|
||||
graph_init_params=mock_graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
)
|
||||
@ -362,7 +362,7 @@ class TestKnowledgeRetrievalNode:
|
||||
|
||||
node = KnowledgeRetrievalNode(
|
||||
node_id=node_id,
|
||||
config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
|
||||
data=KnowledgeRetrievalNodeData.model_validate(config["data"]),
|
||||
graph_init_params=mock_graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
)
|
||||
@ -401,7 +401,7 @@ class TestKnowledgeRetrievalNode:
|
||||
|
||||
node = KnowledgeRetrievalNode(
|
||||
node_id=node_id,
|
||||
config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
|
||||
data=KnowledgeRetrievalNodeData.model_validate(config["data"]),
|
||||
graph_init_params=mock_graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
)
|
||||
@ -482,7 +482,7 @@ class TestFetchDatasetRetriever:
|
||||
|
||||
node = KnowledgeRetrievalNode(
|
||||
node_id=node_id,
|
||||
config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
|
||||
data=KnowledgeRetrievalNodeData.model_validate(config["data"]),
|
||||
graph_init_params=mock_graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
)
|
||||
@ -519,7 +519,7 @@ class TestFetchDatasetRetriever:
|
||||
|
||||
node = KnowledgeRetrievalNode(
|
||||
node_id=node_id,
|
||||
config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
|
||||
data=KnowledgeRetrievalNodeData.model_validate(config["data"]),
|
||||
graph_init_params=mock_graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
)
|
||||
@ -574,7 +574,7 @@ class TestFetchDatasetRetriever:
|
||||
|
||||
node = KnowledgeRetrievalNode(
|
||||
node_id=node_id,
|
||||
config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
|
||||
data=KnowledgeRetrievalNodeData.model_validate(config["data"]),
|
||||
graph_init_params=mock_graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
)
|
||||
@ -622,7 +622,7 @@ class TestFetchDatasetRetriever:
|
||||
|
||||
node = KnowledgeRetrievalNode(
|
||||
node_id=node_id,
|
||||
config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
|
||||
data=KnowledgeRetrievalNodeData.model_validate(config["data"]),
|
||||
graph_init_params=mock_graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
)
|
||||
@ -683,7 +683,7 @@ class TestFetchDatasetRetriever:
|
||||
config = {"id": node_id, "data": node_data.model_dump()}
|
||||
node = KnowledgeRetrievalNode(
|
||||
node_id=node_id,
|
||||
config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
|
||||
data=KnowledgeRetrievalNodeData.model_validate(config["data"]),
|
||||
graph_init_params=mock_graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
)
|
||||
|
||||
@ -16,10 +16,10 @@ class TestListOperatorNode:
|
||||
"""Comprehensive tests for ListOperatorNode."""
|
||||
|
||||
@staticmethod
|
||||
def _build_node(*, config, graph_init_params, graph_runtime_state):
|
||||
def _build_node(*, data, graph_init_params, graph_runtime_state):
|
||||
return ListOperatorNode(
|
||||
node_id="test",
|
||||
config=config if isinstance(config, ListOperatorNodeData) else ListOperatorNodeData.model_validate(config),
|
||||
data=data if isinstance(data, ListOperatorNodeData) else ListOperatorNodeData.model_validate(data),
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
@ -65,7 +65,7 @@ class TestListOperatorNode:
|
||||
def _create_node(config, mock_variable):
|
||||
mock_graph_runtime_state.variable_pool.get.return_value = mock_variable
|
||||
return self._build_node(
|
||||
config=config,
|
||||
data=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
)
|
||||
@ -83,7 +83,7 @@ class TestListOperatorNode:
|
||||
}
|
||||
|
||||
node = self._build_node(
|
||||
config=config,
|
||||
data=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
)
|
||||
@ -127,7 +127,7 @@ class TestListOperatorNode:
|
||||
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
|
||||
|
||||
node = self._build_node(
|
||||
config=config,
|
||||
data=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
)
|
||||
@ -153,7 +153,7 @@ class TestListOperatorNode:
|
||||
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
|
||||
|
||||
node = self._build_node(
|
||||
config=config,
|
||||
data=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
)
|
||||
@ -177,7 +177,7 @@ class TestListOperatorNode:
|
||||
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
|
||||
|
||||
node = self._build_node(
|
||||
config=config,
|
||||
data=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
)
|
||||
@ -201,7 +201,7 @@ class TestListOperatorNode:
|
||||
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
|
||||
|
||||
node = self._build_node(
|
||||
config=config,
|
||||
data=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
)
|
||||
@ -228,7 +228,7 @@ class TestListOperatorNode:
|
||||
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
|
||||
|
||||
node = self._build_node(
|
||||
config=config,
|
||||
data=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
)
|
||||
@ -255,7 +255,7 @@ class TestListOperatorNode:
|
||||
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
|
||||
|
||||
node = self._build_node(
|
||||
config=config,
|
||||
data=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
)
|
||||
@ -282,7 +282,7 @@ class TestListOperatorNode:
|
||||
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
|
||||
|
||||
node = self._build_node(
|
||||
config=config,
|
||||
data=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
)
|
||||
@ -312,7 +312,7 @@ class TestListOperatorNode:
|
||||
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
|
||||
|
||||
node = self._build_node(
|
||||
config=config,
|
||||
data=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
)
|
||||
@ -335,7 +335,7 @@ class TestListOperatorNode:
|
||||
mock_graph_runtime_state.variable_pool.get.return_value = None
|
||||
|
||||
node = self._build_node(
|
||||
config=config,
|
||||
data=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
)
|
||||
@ -359,7 +359,7 @@ class TestListOperatorNode:
|
||||
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
|
||||
|
||||
node = self._build_node(
|
||||
config=config,
|
||||
data=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
)
|
||||
@ -384,7 +384,7 @@ class TestListOperatorNode:
|
||||
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
|
||||
|
||||
node = self._build_node(
|
||||
config=config,
|
||||
data=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
)
|
||||
@ -408,7 +408,7 @@ class TestListOperatorNode:
|
||||
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
|
||||
|
||||
node = self._build_node(
|
||||
config=config,
|
||||
data=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
)
|
||||
@ -432,7 +432,7 @@ class TestListOperatorNode:
|
||||
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
|
||||
|
||||
node = self._build_node(
|
||||
config=config,
|
||||
data=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
)
|
||||
@ -456,7 +456,7 @@ class TestListOperatorNode:
|
||||
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
|
||||
|
||||
node = self._build_node(
|
||||
config=config,
|
||||
data=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
)
|
||||
@ -483,7 +483,7 @@ class TestListOperatorNode:
|
||||
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
|
||||
|
||||
node = self._build_node(
|
||||
config=config,
|
||||
data=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
)
|
||||
|
||||
@ -15,7 +15,7 @@ from core.app.llm.model_access import (
|
||||
)
|
||||
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
|
||||
from core.entities.provider_entities import CustomConfiguration, SystemConfiguration
|
||||
from core.plugin.impl.model_runtime_factory import create_plugin_model_runtime
|
||||
from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly
|
||||
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||
from core.workflow.system_variables import default_system_variables
|
||||
from graphon.entities import GraphInitParams
|
||||
@ -187,7 +187,7 @@ def graph_init_params() -> GraphInitParams:
|
||||
|
||||
@pytest.fixture
|
||||
def graph_runtime_state() -> GraphRuntimeState:
|
||||
variable_pool = VariablePool(
|
||||
variable_pool = VariablePool.from_bootstrap(
|
||||
system_variables=default_system_variables(),
|
||||
user_inputs={},
|
||||
)
|
||||
@ -208,7 +208,7 @@ def llm_node(
|
||||
http_client = mock.MagicMock()
|
||||
node = LLMNode(
|
||||
node_id="1",
|
||||
config=llm_node_data,
|
||||
data=llm_node_data,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
credentials_provider=mock_credentials_provider,
|
||||
@ -241,9 +241,10 @@ def model_config(monkeypatch: pytest.MonkeyPatch):
|
||||
)
|
||||
|
||||
# Create actual provider and model type instances
|
||||
model_provider_factory = ModelProviderFactory(model_runtime=create_plugin_model_runtime(tenant_id="test"))
|
||||
model_assembly = create_plugin_model_assembly(tenant_id="test")
|
||||
model_provider_factory = model_assembly.model_provider_factory
|
||||
provider_instance = model_provider_factory.get_model_provider("openai")
|
||||
model_type_instance = model_provider_factory.get_model_type_instance("openai", ModelType.LLM)
|
||||
model_type_instance = model_assembly.create_model_type_instance(provider="openai", model_type=ModelType.LLM)
|
||||
|
||||
# Create a ProviderModelBundle
|
||||
provider_model_bundle = ProviderModelBundle(
|
||||
@ -1173,7 +1174,7 @@ def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_stat
|
||||
http_client = mock.MagicMock()
|
||||
node = LLMNode(
|
||||
node_id="1",
|
||||
config=llm_node_data,
|
||||
data=llm_node_data,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
credentials_provider=mock_credentials_provider,
|
||||
|
||||
@ -28,7 +28,7 @@ def _build_template_transform_node(
|
||||
)
|
||||
return TemplateTransformNode(
|
||||
node_id=node_id,
|
||||
config=typed_node_data,
|
||||
data=typed_node_data,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
**kwargs,
|
||||
|
||||
@ -39,7 +39,7 @@ def mock_graph_runtime_state():
|
||||
def test_node_uses_default_max_output_length_when_not_overridden(graph_init_params, mock_graph_runtime_state):
|
||||
node = TemplateTransformNode(
|
||||
node_id="test_node",
|
||||
config=TemplateTransformNodeData(
|
||||
data=TemplateTransformNodeData(
|
||||
title="Template Transform",
|
||||
type="template-transform",
|
||||
variables=[],
|
||||
|
||||
@ -35,7 +35,10 @@ def _build_context(graph_config: Mapping[str, object]) -> tuple[GraphInitParams,
|
||||
invoke_from="debugger",
|
||||
)
|
||||
runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(system_variables=build_system_variables(user_id="user", files=[]), user_inputs={}),
|
||||
variable_pool=VariablePool.from_bootstrap(
|
||||
system_variables=build_system_variables(user_id="user", files=[]),
|
||||
user_inputs={},
|
||||
),
|
||||
start_at=0.0,
|
||||
)
|
||||
return init_params, runtime_state
|
||||
@ -62,7 +65,7 @@ def test_node_hydrates_data_during_initialization():
|
||||
|
||||
node = _SampleNode(
|
||||
node_id="node-1",
|
||||
config=_build_node_data(),
|
||||
data=_build_node_data(),
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
@ -82,13 +85,16 @@ def test_node_accepts_invoke_from_enum():
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
)
|
||||
runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(system_variables=build_system_variables(user_id="user", files=[]), user_inputs={}),
|
||||
variable_pool=VariablePool.from_bootstrap(
|
||||
system_variables=build_system_variables(user_id="user", files=[]),
|
||||
user_inputs={},
|
||||
),
|
||||
start_at=0.0,
|
||||
)
|
||||
|
||||
node = _SampleNode(
|
||||
node_id="node-1",
|
||||
config=_build_node_data(),
|
||||
data=_build_node_data(),
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
@ -140,7 +146,7 @@ def test_node_hydration_preserves_compatibility_extra_fields():
|
||||
|
||||
node = _SampleNode(
|
||||
node_id="node-1",
|
||||
config=node_config["data"],
|
||||
data=node_config["data"],
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
|
||||
@ -49,7 +49,7 @@ def document_extractor_node(graph_init_params):
|
||||
http_client = Mock()
|
||||
node = DocumentExtractorNode(
|
||||
node_id="test_node_id",
|
||||
config=node_data,
|
||||
data=node_data,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=Mock(),
|
||||
http_client=http_client,
|
||||
@ -186,12 +186,13 @@ def test_run_extract_text(
|
||||
|
||||
monkeypatch.setattr("graphon.file.file_manager.download", mock_download)
|
||||
|
||||
dispatch_mock = None
|
||||
if mime_type == "application/pdf":
|
||||
mock_pdf_extract = Mock(return_value=expected_text[0])
|
||||
monkeypatch.setattr("graphon.nodes.document_extractor.node._extract_text_from_pdf", mock_pdf_extract)
|
||||
dispatch_mock = Mock(return_value=expected_text[0])
|
||||
monkeypatch.setattr("graphon.nodes.document_extractor.node._extract_text_by_file_extension", dispatch_mock)
|
||||
elif mime_type.startswith("application/vnd.openxmlformats"):
|
||||
mock_docx_extract = Mock(return_value=expected_text[0])
|
||||
monkeypatch.setattr("graphon.nodes.document_extractor.node._extract_text_from_docx", mock_docx_extract)
|
||||
dispatch_mock = Mock(return_value=expected_text[0])
|
||||
monkeypatch.setattr("graphon.nodes.document_extractor.node._extract_text_by_mime_type", dispatch_mock)
|
||||
|
||||
result = document_extractor_node._run()
|
||||
|
||||
@ -200,6 +201,19 @@ def test_run_extract_text(
|
||||
assert result.outputs is not None
|
||||
assert result.outputs["text"] == ArrayStringSegment(value=expected_text)
|
||||
|
||||
if mime_type == "application/pdf":
|
||||
dispatch_mock.assert_called_once_with(
|
||||
file_content=file_content,
|
||||
file_extension=extension,
|
||||
unstructured_api_config=document_extractor_node._unstructured_api_config,
|
||||
)
|
||||
elif mime_type.startswith("application/vnd.openxmlformats"):
|
||||
dispatch_mock.assert_called_once_with(
|
||||
file_content=file_content,
|
||||
mime_type=mime_type,
|
||||
unstructured_api_config=document_extractor_node._unstructured_api_config,
|
||||
)
|
||||
|
||||
if transfer_method == FileTransferMethod.REMOTE_URL:
|
||||
document_extractor_node._http_client.get.assert_called_once_with("https://example.com/file.txt")
|
||||
elif transfer_method == FileTransferMethod.LOCAL_FILE:
|
||||
@ -439,24 +453,42 @@ def test_extract_text_from_file_routes_excel_inputs(document_extractor_node, ext
|
||||
file.extension = extension
|
||||
file.mime_type = mime_type
|
||||
|
||||
with (
|
||||
patch(
|
||||
"graphon.nodes.document_extractor.node._download_file_content",
|
||||
return_value=b"excel",
|
||||
),
|
||||
patch(
|
||||
"graphon.nodes.document_extractor.node._extract_text_from_excel",
|
||||
return_value="excel text",
|
||||
) as mock_extract,
|
||||
with patch(
|
||||
"graphon.nodes.document_extractor.node._download_file_content",
|
||||
return_value=b"excel",
|
||||
):
|
||||
result = _extract_text_from_file(
|
||||
document_extractor_node.http_client,
|
||||
file,
|
||||
unstructured_api_config=document_extractor_node._unstructured_api_config,
|
||||
)
|
||||
if extension:
|
||||
with patch(
|
||||
"graphon.nodes.document_extractor.node._extract_text_by_file_extension",
|
||||
return_value="excel text",
|
||||
) as mock_extract:
|
||||
result = _extract_text_from_file(
|
||||
document_extractor_node.http_client,
|
||||
file,
|
||||
unstructured_api_config=document_extractor_node._unstructured_api_config,
|
||||
)
|
||||
mock_extract.assert_called_once_with(
|
||||
file_content=b"excel",
|
||||
file_extension=extension,
|
||||
unstructured_api_config=document_extractor_node._unstructured_api_config,
|
||||
)
|
||||
else:
|
||||
with patch(
|
||||
"graphon.nodes.document_extractor.node._extract_text_by_mime_type",
|
||||
return_value="excel text",
|
||||
) as mock_extract:
|
||||
result = _extract_text_from_file(
|
||||
document_extractor_node.http_client,
|
||||
file,
|
||||
unstructured_api_config=document_extractor_node._unstructured_api_config,
|
||||
)
|
||||
mock_extract.assert_called_once_with(
|
||||
file_content=b"excel",
|
||||
mime_type=mime_type,
|
||||
unstructured_api_config=document_extractor_node._unstructured_api_config,
|
||||
)
|
||||
|
||||
assert result == "excel text"
|
||||
mock_extract.assert_called_once_with(b"excel")
|
||||
|
||||
|
||||
def test_extract_text_from_file_rejects_missing_extension_and_mime_type(document_extractor_node):
|
||||
|
||||
@ -29,7 +29,7 @@ def _build_if_else_node(
|
||||
node_id=str(uuid.uuid4()),
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
config=node_data if isinstance(node_data, IfElseNodeData) else IfElseNodeData.model_validate(node_data),
|
||||
data=node_data if isinstance(node_data, IfElseNodeData) else IfElseNodeData.model_validate(node_data),
|
||||
)
|
||||
|
||||
|
||||
@ -48,7 +48,10 @@ def test_execute_if_else_result_true():
|
||||
)
|
||||
|
||||
# construct variable pool
|
||||
pool = VariablePool(system_variables=build_system_variables(user_id="aaa", files=[]), user_inputs={})
|
||||
pool = VariablePool.from_bootstrap(
|
||||
system_variables=build_system_variables(user_id="aaa", files=[]),
|
||||
user_inputs={},
|
||||
)
|
||||
pool.add(["start", "array_contains"], ["ab", "def"])
|
||||
pool.add(["start", "array_not_contains"], ["ac", "def"])
|
||||
pool.add(["start", "contains"], "cabcde")
|
||||
@ -148,7 +151,7 @@ def test_execute_if_else_result_false():
|
||||
)
|
||||
|
||||
# construct variable pool
|
||||
pool = VariablePool(
|
||||
pool = VariablePool.from_bootstrap(
|
||||
system_variables=build_system_variables(user_id="aaa", files=[]),
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
@ -305,7 +308,7 @@ def test_execute_if_else_boolean_conditions(condition: Condition):
|
||||
)
|
||||
|
||||
# construct variable pool with boolean values
|
||||
pool = VariablePool(
|
||||
pool = VariablePool.from_bootstrap(
|
||||
system_variables=build_system_variables(files=[], user_id="aaa"),
|
||||
)
|
||||
pool.add(["start", "bool_true"], True)
|
||||
@ -359,7 +362,7 @@ def test_execute_if_else_boolean_false_conditions():
|
||||
)
|
||||
|
||||
# construct variable pool with boolean values
|
||||
pool = VariablePool(
|
||||
pool = VariablePool.from_bootstrap(
|
||||
system_variables=build_system_variables(files=[], user_id="aaa"),
|
||||
)
|
||||
pool.add(["start", "bool_true"], True)
|
||||
@ -424,7 +427,7 @@ def test_execute_if_else_boolean_cases_structure():
|
||||
)
|
||||
|
||||
# construct variable pool with boolean values
|
||||
pool = VariablePool(
|
||||
pool = VariablePool.from_bootstrap(
|
||||
system_variables=build_system_variables(files=[], user_id="aaa"),
|
||||
)
|
||||
pool.add(["start", "bool_true"], True)
|
||||
|
||||
@ -22,7 +22,7 @@ from graphon.variables import ArrayFileSegment
|
||||
def _build_list_operator_node(node_data: ListOperatorNodeData, graph_init_params) -> ListOperatorNode:
|
||||
return ListOperatorNode(
|
||||
node_id="test_node_id",
|
||||
config=node_data,
|
||||
data=node_data,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=MagicMock(),
|
||||
)
|
||||
|
||||
@ -31,7 +31,7 @@ def make_start_node(user_inputs, variables):
|
||||
|
||||
return StartNode(
|
||||
node_id="start",
|
||||
config=node_data,
|
||||
data=node_data,
|
||||
graph_init_params=build_test_graph_init_params(
|
||||
workflow_id="wf",
|
||||
graph_config={},
|
||||
@ -260,7 +260,7 @@ def test_start_node_outputs_full_variable_pool_snapshot():
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
node = StartNode(
|
||||
node_id="start",
|
||||
config=node_data,
|
||||
data=node_data,
|
||||
graph_init_params=build_test_graph_init_params(
|
||||
workflow_id="wf",
|
||||
graph_config={},
|
||||
|
||||
@ -99,7 +99,7 @@ def tool_node(monkeypatch) -> ToolNode:
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
variable_pool = VariablePool(system_variables=build_system_variables(user_id="user-id"))
|
||||
variable_pool = VariablePool.from_bootstrap(system_variables=build_system_variables(user_id="user-id"))
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0)
|
||||
|
||||
config = graph_config["nodes"][0]
|
||||
@ -110,7 +110,7 @@ def tool_node(monkeypatch) -> ToolNode:
|
||||
|
||||
node = ToolNode(
|
||||
node_id="node-instance",
|
||||
config=ToolNodeData.model_validate(config["data"]),
|
||||
data=ToolNodeData.model_validate(config["data"]),
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
tool_file_manager_factory=tool_file_manager_factory,
|
||||
|
||||
@ -44,7 +44,7 @@ def test_trigger_event_node_run_populates_trigger_info_metadata() -> None:
|
||||
init_params, runtime_state = _build_context(graph_config={})
|
||||
node = TriggerEventNode(
|
||||
node_id="node-1",
|
||||
config=_build_node_data(),
|
||||
data=_build_node_data(),
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
|
||||
@ -52,7 +52,7 @@ def create_webhook_node(
|
||||
|
||||
node = TriggerWebhookNode(
|
||||
node_id="webhook-node-1",
|
||||
config=webhook_data,
|
||||
data=webhook_data,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
|
||||
@ -44,7 +44,7 @@ def create_webhook_node(webhook_data: WebhookData, variable_pool: VariablePool)
|
||||
)
|
||||
node = TriggerWebhookNode(
|
||||
node_id="1",
|
||||
config=webhook_data,
|
||||
data=webhook_data,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
from collections.abc import Mapping
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch, sentinel
|
||||
|
||||
@ -11,19 +12,20 @@ from graphon.entities.base_node_data import BaseNodeData
|
||||
from graphon.enums import BuiltinNodeTypes, NodeType
|
||||
from graphon.nodes.code.entities import CodeLanguage
|
||||
from graphon.nodes.llm.entities import LLMNodeData
|
||||
from graphon.nodes.llm.node import LLMNode
|
||||
from graphon.variables.segments import StringSegment
|
||||
|
||||
|
||||
def _assert_typed_node_config(config, *, node_id: str, node_type: NodeType, version: str = "1") -> None:
|
||||
def _assert_constructor_node_data(data, *, node_id: str, node_type: NodeType, version: str = "1") -> None:
|
||||
_ = node_id
|
||||
if isinstance(config, BaseNodeData):
|
||||
assert config.type == node_type
|
||||
assert config.version == version
|
||||
if isinstance(data, BaseNodeData):
|
||||
assert data.type == node_type
|
||||
assert data.version == version
|
||||
return
|
||||
|
||||
assert isinstance(config, dict)
|
||||
assert config["type"] == node_type
|
||||
assert config["version"] == version
|
||||
assert isinstance(data, Mapping)
|
||||
assert data["type"] == node_type
|
||||
assert data.get("version", "1") == version
|
||||
|
||||
|
||||
def _node_constructor(*, return_value):
|
||||
@ -470,7 +472,7 @@ class TestDifyNodeFactoryCreateNode:
|
||||
matched_node_class.assert_called_once()
|
||||
kwargs = matched_node_class.call_args.kwargs
|
||||
assert kwargs["node_id"] == "node-id"
|
||||
_assert_typed_node_config(kwargs["config"], node_id="node-id", node_type=BuiltinNodeTypes.START, version="9")
|
||||
_assert_constructor_node_data(kwargs["data"], node_id="node-id", node_type=BuiltinNodeTypes.START, version="9")
|
||||
assert kwargs["graph_init_params"] is sentinel.graph_init_params
|
||||
assert kwargs["graph_runtime_state"] is factory.graph_runtime_state
|
||||
latest_node_class.assert_not_called()
|
||||
@ -492,7 +494,7 @@ class TestDifyNodeFactoryCreateNode:
|
||||
latest_node_class.assert_called_once()
|
||||
kwargs = latest_node_class.call_args.kwargs
|
||||
assert kwargs["node_id"] == "node-id"
|
||||
_assert_typed_node_config(kwargs["config"], node_id="node-id", node_type=BuiltinNodeTypes.START, version="9")
|
||||
_assert_constructor_node_data(kwargs["data"], node_id="node-id", node_type=BuiltinNodeTypes.START, version="9")
|
||||
assert kwargs["graph_init_params"] is sentinel.graph_init_params
|
||||
assert kwargs["graph_runtime_state"] is factory.graph_runtime_state
|
||||
|
||||
@ -530,7 +532,7 @@ class TestDifyNodeFactoryCreateNode:
|
||||
assert result is created_node
|
||||
kwargs = constructor.call_args.kwargs
|
||||
assert kwargs["node_id"] == "node-id"
|
||||
_assert_typed_node_config(kwargs["config"], node_id="node-id", node_type=node_type)
|
||||
_assert_constructor_node_data(kwargs["data"], node_id="node-id", node_type=node_type)
|
||||
assert kwargs["graph_init_params"] is sentinel.graph_init_params
|
||||
assert kwargs["graph_runtime_state"] is factory.graph_runtime_state
|
||||
|
||||
@ -599,11 +601,12 @@ class TestDifyNodeFactoryCreateNode:
|
||||
prepared_llm.assert_called_once_with(sentinel.model_instance)
|
||||
assert kwargs["model_instance"] is wrapped_model_instance
|
||||
|
||||
def test_create_node_passes_alias_preserving_llm_config_to_constructor(
|
||||
self, monkeypatch: pytest.MonkeyPatch, factory
|
||||
):
|
||||
def test_create_node_passes_alias_preserving_llm_data_to_constructor(self, monkeypatch, factory):
|
||||
created_node = object()
|
||||
constructor = _node_constructor(return_value=created_node)
|
||||
constructor.validate_node_data.side_effect = lambda node_data: LLMNodeData.model_validate(
|
||||
node_data.model_dump(mode="python") if isinstance(node_data, BaseNodeData) else node_data
|
||||
)
|
||||
monkeypatch.setattr(factory, "_resolve_node_class", MagicMock(return_value=constructor))
|
||||
monkeypatch.setattr(factory, "_build_llm_compatible_node_init_kwargs", MagicMock(return_value={}))
|
||||
|
||||
@ -629,10 +632,56 @@ class TestDifyNodeFactoryCreateNode:
|
||||
|
||||
factory.create_node(node_config)
|
||||
|
||||
config = constructor.call_args.kwargs["config"]
|
||||
assert isinstance(config, dict)
|
||||
assert config["structured_output_enabled"] is True
|
||||
assert "structured_output_switch_on" not in config
|
||||
data = constructor.call_args.kwargs["data"]
|
||||
assert isinstance(data, Mapping)
|
||||
assert data["structured_output_enabled"] is True
|
||||
assert "structured_output_switch_on" not in data
|
||||
assert LLMNodeData.model_validate(data).structured_output_enabled is True
|
||||
|
||||
def test_create_node_preserves_structured_output_switch_after_graphon_constructor(self, monkeypatch, factory):
|
||||
factory.graph_init_params = SimpleNamespace(
|
||||
workflow_id="workflow-id",
|
||||
graph_config={},
|
||||
run_context={},
|
||||
call_depth=0,
|
||||
)
|
||||
monkeypatch.setattr(factory, "_resolve_node_class", MagicMock(return_value=LLMNode))
|
||||
monkeypatch.setattr(
|
||||
factory,
|
||||
"_build_llm_compatible_node_init_kwargs",
|
||||
MagicMock(
|
||||
return_value={
|
||||
"model_instance": sentinel.model_instance,
|
||||
"llm_file_saver": sentinel.llm_file_saver,
|
||||
"prompt_message_serializer": sentinel.prompt_message_serializer,
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
node_config = {
|
||||
"id": "llm-node-id",
|
||||
"data": {
|
||||
"type": BuiltinNodeTypes.LLM,
|
||||
"title": "LLM",
|
||||
"model": {"provider": "provider", "name": "model", "mode": "chat", "completion_params": {}},
|
||||
"prompt_template": [{"role": "system", "text": "x"}],
|
||||
"context": {"enabled": False, "variable_selector": []},
|
||||
"vision": {"enabled": False},
|
||||
"structured_output_enabled": True,
|
||||
"structured_output": {
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"properties": {"type": {"type": "string"}},
|
||||
"required": ["type"],
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
node = factory.create_node(node_config)
|
||||
|
||||
assert node.node_data.structured_output_switch_on is True
|
||||
assert node.node_data.structured_output_enabled is True
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("node_type", "constructor_name", "expected_extra_kwargs"),
|
||||
@ -711,7 +760,7 @@ class TestDifyNodeFactoryCreateNode:
|
||||
|
||||
constructor_kwargs = constructor.call_args.kwargs
|
||||
assert constructor_kwargs["node_id"] == "node-id"
|
||||
_assert_typed_node_config(constructor_kwargs["config"], node_id="node-id", node_type=node_type)
|
||||
_assert_constructor_node_data(constructor_kwargs["data"], node_id="node-id", node_type=node_type)
|
||||
assert constructor_kwargs["graph_init_params"] is sentinel.graph_init_params
|
||||
assert constructor_kwargs["graph_runtime_state"] is factory.graph_runtime_state
|
||||
assert constructor_kwargs["credentials_provider"] is sentinel.credentials_provider
|
||||
|
||||
@ -109,8 +109,8 @@ class TestVariablePool:
|
||||
assert pool.get([ENVIRONMENT_VARIABLE_NODE_ID, "env_var_1"]) is not None
|
||||
assert pool.get([CONVERSATION_VARIABLE_NODE_ID, "conv_var_1"]) is not None
|
||||
|
||||
def test_constructor_loads_legacy_bootstrap_kwargs(self):
|
||||
pool = VariablePool(
|
||||
def test_from_bootstrap_loads_legacy_bootstrap_kwargs(self):
|
||||
pool = VariablePool.from_bootstrap(
|
||||
system_variables=build_system_variables(user_id="test_user_id"),
|
||||
environment_variables=[StringVariable(name="env_var", value="env-value")],
|
||||
conversation_variables=[StringVariable(name="conv_var", value="conv-value")],
|
||||
|
||||
@ -55,7 +55,7 @@ class TestWorkflowEntry:
|
||||
def test_mapping_user_inputs_to_variable_pool_with_system_variables(self):
|
||||
"""Test mapping system variables from user inputs to variable pool."""
|
||||
# Initialize variable pool with system variables
|
||||
variable_pool = VariablePool(
|
||||
variable_pool = VariablePool.from_bootstrap(
|
||||
system_variables=build_system_variables(
|
||||
user_id="test_user_id",
|
||||
app_id="test_app_id",
|
||||
@ -128,7 +128,7 @@ class TestWorkflowEntry:
|
||||
return NodeConfigDictAdapter.validate_python(node_config)
|
||||
|
||||
workflow = StubWorkflow()
|
||||
variable_pool = VariablePool(system_variables=default_system_variables(), user_inputs={})
|
||||
variable_pool = VariablePool.from_bootstrap(system_variables=default_system_variables(), user_inputs={})
|
||||
expected_limits = CodeNodeLimits(
|
||||
max_string_length=dify_config.CODE_MAX_STRING_LENGTH,
|
||||
max_number=dify_config.CODE_MAX_NUMBER,
|
||||
@ -157,7 +157,7 @@ class TestWorkflowEntry:
|
||||
"""Test mapping environment variables from user inputs to variable pool."""
|
||||
# Initialize variable pool with environment variables
|
||||
env_var = StringVariable(name="API_KEY", value="existing_key")
|
||||
variable_pool = VariablePool(
|
||||
variable_pool = VariablePool.from_bootstrap(
|
||||
system_variables=default_system_variables(),
|
||||
environment_variables=[env_var],
|
||||
user_inputs={},
|
||||
@ -198,7 +198,7 @@ class TestWorkflowEntry:
|
||||
"""Test mapping conversation variables from user inputs to variable pool."""
|
||||
# Initialize variable pool with conversation variables
|
||||
conv_var = StringVariable(name="last_message", value="Hello")
|
||||
variable_pool = VariablePool(
|
||||
variable_pool = VariablePool.from_bootstrap(
|
||||
system_variables=default_system_variables(),
|
||||
conversation_variables=[conv_var],
|
||||
user_inputs={},
|
||||
@ -239,7 +239,7 @@ class TestWorkflowEntry:
|
||||
def test_mapping_user_inputs_to_variable_pool_with_regular_variables(self):
|
||||
"""Test mapping regular node variables from user inputs to variable pool."""
|
||||
# Initialize empty variable pool
|
||||
variable_pool = VariablePool(
|
||||
variable_pool = VariablePool.from_bootstrap(
|
||||
system_variables=default_system_variables(),
|
||||
user_inputs={},
|
||||
)
|
||||
@ -281,7 +281,7 @@ class TestWorkflowEntry:
|
||||
|
||||
def test_mapping_user_inputs_with_file_handling(self):
|
||||
"""Test mapping file inputs from user inputs to variable pool."""
|
||||
variable_pool = VariablePool(
|
||||
variable_pool = VariablePool.from_bootstrap(
|
||||
system_variables=default_system_variables(),
|
||||
user_inputs={},
|
||||
)
|
||||
@ -340,7 +340,7 @@ class TestWorkflowEntry:
|
||||
|
||||
def test_mapping_user_inputs_missing_variable_error(self):
|
||||
"""Test that mapping raises error when required variable is missing."""
|
||||
variable_pool = VariablePool(
|
||||
variable_pool = VariablePool.from_bootstrap(
|
||||
system_variables=default_system_variables(),
|
||||
user_inputs={},
|
||||
)
|
||||
@ -366,7 +366,7 @@ class TestWorkflowEntry:
|
||||
|
||||
def test_mapping_user_inputs_with_alternative_key_format(self):
|
||||
"""Test mapping with alternative key format (without node prefix)."""
|
||||
variable_pool = VariablePool(
|
||||
variable_pool = VariablePool.from_bootstrap(
|
||||
system_variables=default_system_variables(),
|
||||
user_inputs={},
|
||||
)
|
||||
@ -396,7 +396,7 @@ class TestWorkflowEntry:
|
||||
|
||||
def test_mapping_user_inputs_with_complex_selectors(self):
|
||||
"""Test mapping with complex node variable keys."""
|
||||
variable_pool = VariablePool(
|
||||
variable_pool = VariablePool.from_bootstrap(
|
||||
system_variables=default_system_variables(),
|
||||
user_inputs={},
|
||||
)
|
||||
@ -432,7 +432,7 @@ class TestWorkflowEntry:
|
||||
|
||||
def test_mapping_user_inputs_invalid_node_variable(self):
|
||||
"""Test that mapping handles invalid node variable format."""
|
||||
variable_pool = VariablePool(
|
||||
variable_pool = VariablePool.from_bootstrap(
|
||||
system_variables=default_system_variables(),
|
||||
user_inputs={},
|
||||
)
|
||||
@ -463,7 +463,7 @@ class TestWorkflowEntry:
|
||||
env_var = StringVariable(name="API_KEY", value="existing_key")
|
||||
conv_var = StringVariable(name="session_id", value="session123")
|
||||
|
||||
variable_pool = VariablePool(
|
||||
variable_pool = VariablePool.from_bootstrap(
|
||||
system_variables=build_system_variables(
|
||||
user_id="test_user",
|
||||
app_id="test_app",
|
||||
|
||||
@ -7,7 +7,6 @@ import pytest
|
||||
|
||||
from core.app.apps.exc import GenerateTaskStoppedError
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
|
||||
from core.model_manager import ModelInstance
|
||||
from core.workflow import workflow_entry
|
||||
from core.workflow.system_variables import default_system_variables
|
||||
from graphon.entities.base_node_data import BaseNodeData
|
||||
@ -16,10 +15,12 @@ from graphon.errors import WorkflowNodeRunFailedError
|
||||
from graphon.file import File, FileTransferMethod, FileType
|
||||
from graphon.graph import Graph
|
||||
from graphon.graph_events import GraphRunFailedEvent
|
||||
from graphon.model_runtime.entities.llm_entities import LLMUsage
|
||||
from graphon.model_runtime.entities.llm_entities import LLMMode, LLMUsage
|
||||
from graphon.node_events import NodeRunResult
|
||||
from graphon.nodes import BuiltinNodeTypes
|
||||
from graphon.nodes.base.node import Node
|
||||
from graphon.nodes.llm.entities import ContextConfig, LLMNodeData, ModelConfig
|
||||
from graphon.nodes.question_classifier.entities import QuestionClassifierNodeData
|
||||
from graphon.runtime import ChildGraphNotFoundError, VariablePool
|
||||
from graphon.variables.variables import StringVariable
|
||||
from tests.workflow_test_utils import build_test_graph_init_params, build_test_variable_pool
|
||||
@ -29,9 +30,30 @@ def _build_typed_node_config(node_type: NodeType):
|
||||
return {"id": "node-id", "data": BaseNodeData(type=node_type)}
|
||||
|
||||
|
||||
def _build_wrapped_model_instance() -> tuple[SimpleNamespace, ModelInstance]:
|
||||
raw_model_instance = ModelInstance.__new__(ModelInstance)
|
||||
return SimpleNamespace(_model_instance=raw_model_instance), raw_model_instance
|
||||
def _build_model_config(*, provider: str = "openai", model_name: str = "gpt-4o") -> ModelConfig:
|
||||
return ModelConfig(provider=provider, name=model_name, mode=LLMMode.CHAT)
|
||||
|
||||
|
||||
def _build_llm_node_data(*, provider: str = "openai", model_name: str = "gpt-4o") -> LLMNodeData:
|
||||
return LLMNodeData(
|
||||
type=BuiltinNodeTypes.LLM,
|
||||
title="Child Model",
|
||||
model=_build_model_config(provider=provider, model_name=model_name),
|
||||
prompt_template=[],
|
||||
context=ContextConfig(enabled=False),
|
||||
)
|
||||
|
||||
|
||||
def _build_question_classifier_node_data(
|
||||
*, provider: str = "openai", model_name: str = "gpt-4o"
|
||||
) -> QuestionClassifierNodeData:
|
||||
return QuestionClassifierNodeData(
|
||||
type=BuiltinNodeTypes.QUESTION_CLASSIFIER,
|
||||
title="Child Model",
|
||||
query_variable_selector=["sys", "query"],
|
||||
model=_build_model_config(provider=provider, model_name=model_name),
|
||||
classes=[],
|
||||
)
|
||||
|
||||
|
||||
class _FakeModelNodeMixin:
|
||||
@ -40,22 +62,26 @@ class _FakeModelNodeMixin:
|
||||
return "1"
|
||||
|
||||
def post_init(self) -> None:
|
||||
self.model_instance, self.raw_model_instance = _build_wrapped_model_instance()
|
||||
self.model_instance = SimpleNamespace(provider="stale-provider", model_name="stale-model")
|
||||
self.usage_snapshot = LLMUsage.empty_usage()
|
||||
self.usage_snapshot.total_tokens = 1
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs={
|
||||
"model_provider": self.node_data.model.provider,
|
||||
"model_name": self.node_data.model.name,
|
||||
},
|
||||
llm_usage=self.usage_snapshot,
|
||||
)
|
||||
|
||||
|
||||
class _FakeLLMNode(_FakeModelNodeMixin, Node[BaseNodeData]):
|
||||
class _FakeLLMNode(_FakeModelNodeMixin, Node[LLMNodeData]):
|
||||
node_type = BuiltinNodeTypes.LLM
|
||||
|
||||
|
||||
class _FakeQuestionClassifierNode(_FakeModelNodeMixin, Node[BaseNodeData]):
|
||||
class _FakeQuestionClassifierNode(_FakeModelNodeMixin, Node[QuestionClassifierNodeData]):
|
||||
node_type = BuiltinNodeTypes.QUESTION_CLASSIFIER
|
||||
|
||||
|
||||
@ -75,7 +101,7 @@ class TestWorkflowChildEngineBuilder:
|
||||
assert result is expected
|
||||
|
||||
def test_build_child_engine_raises_when_root_node_is_missing(self):
|
||||
builder = workflow_entry._WorkflowChildEngineBuilder()
|
||||
builder = workflow_entry._WorkflowChildEngineBuilder(tenant_id="tenant-id")
|
||||
graph_init_params = SimpleNamespace(graph_config={"nodes": []})
|
||||
parent_graph_runtime_state = SimpleNamespace(
|
||||
execution_context=sentinel.execution_context,
|
||||
@ -92,7 +118,7 @@ class TestWorkflowChildEngineBuilder:
|
||||
)
|
||||
|
||||
def test_build_child_engine_constructs_graph_engine_with_quota_layer_only(self):
|
||||
builder = workflow_entry._WorkflowChildEngineBuilder()
|
||||
builder = workflow_entry._WorkflowChildEngineBuilder(tenant_id="tenant-id")
|
||||
graph_init_params = SimpleNamespace(graph_config={"nodes": [{"id": "root"}]})
|
||||
parent_graph_runtime_state = SimpleNamespace(
|
||||
execution_context=sentinel.execution_context,
|
||||
@ -114,7 +140,7 @@ class TestWorkflowChildEngineBuilder:
|
||||
patch.object(workflow_entry, "GraphEngine", return_value=child_engine) as graph_engine_cls,
|
||||
patch.object(workflow_entry, "GraphEngineConfig", return_value=sentinel.graph_engine_config),
|
||||
patch.object(workflow_entry, "InMemoryChannel", return_value=sentinel.command_channel),
|
||||
patch.object(workflow_entry, "LLMQuotaLayer", return_value=sentinel.llm_quota_layer),
|
||||
patch.object(workflow_entry, "LLMQuotaLayer", return_value=sentinel.llm_quota_layer) as llm_quota_layer_cls,
|
||||
):
|
||||
result = builder.build_child_engine(
|
||||
workflow_id="workflow-id",
|
||||
@ -147,11 +173,12 @@ class TestWorkflowChildEngineBuilder:
|
||||
config=sentinel.graph_engine_config,
|
||||
child_engine_builder=builder,
|
||||
)
|
||||
llm_quota_layer_cls.assert_called_once_with(tenant_id="tenant-id")
|
||||
assert child_engine.layer.call_args_list == [((sentinel.llm_quota_layer,), {})]
|
||||
|
||||
@pytest.mark.parametrize("node_cls", [_FakeLLMNode, _FakeQuestionClassifierNode])
|
||||
def test_build_child_engine_runs_llm_quota_layer_for_child_model_nodes(self, node_cls):
|
||||
builder = workflow_entry._WorkflowChildEngineBuilder()
|
||||
builder = workflow_entry._WorkflowChildEngineBuilder(tenant_id="tenant-id")
|
||||
graph_init_params = build_test_graph_init_params(
|
||||
graph_config={"nodes": [{"id": "root"}], "edges": []},
|
||||
)
|
||||
@ -163,12 +190,10 @@ class TestWorkflowChildEngineBuilder:
|
||||
|
||||
def build_graph(*, graph_config, node_factory, root_node_id):
|
||||
_ = graph_config
|
||||
node_data = _build_llm_node_data() if node_cls is _FakeLLMNode else _build_question_classifier_node_data()
|
||||
node = node_cls(
|
||||
node_id=root_node_id,
|
||||
config=BaseNodeData(
|
||||
type=node_cls.node_type,
|
||||
title="Child Model",
|
||||
),
|
||||
data=node_data,
|
||||
graph_init_params=node_factory.graph_init_params,
|
||||
graph_runtime_state=node_factory.graph_runtime_state,
|
||||
)
|
||||
@ -191,8 +216,8 @@ class TestWorkflowChildEngineBuilder:
|
||||
),
|
||||
),
|
||||
patch.object(workflow_entry.Graph, "init", side_effect=build_graph),
|
||||
patch("core.app.workflow.layers.llm_quota.ensure_llm_quota_available") as ensure_quota,
|
||||
patch("core.app.workflow.layers.llm_quota.deduct_llm_quota") as deduct_quota,
|
||||
patch("core.app.workflow.layers.llm_quota.ensure_llm_quota_available_for_model") as ensure_quota,
|
||||
patch("core.app.workflow.layers.llm_quota.deduct_llm_quota_for_model") as deduct_quota,
|
||||
):
|
||||
child_engine = builder.build_child_engine(
|
||||
workflow_id="workflow-id",
|
||||
@ -203,10 +228,15 @@ class TestWorkflowChildEngineBuilder:
|
||||
list(child_engine.run())
|
||||
|
||||
node = created_node["node"]
|
||||
ensure_quota.assert_called_once_with(model_instance=node.raw_model_instance)
|
||||
ensure_quota.assert_called_once_with(
|
||||
tenant_id="tenant-id",
|
||||
provider=node.node_data.model.provider,
|
||||
model=node.node_data.model.name,
|
||||
)
|
||||
deduct_quota.assert_called_once_with(
|
||||
tenant_id="tenant",
|
||||
model_instance=node.raw_model_instance,
|
||||
tenant_id="tenant-id",
|
||||
provider=node.node_data.model.provider,
|
||||
model=node.node_data.model.name,
|
||||
usage=node.usage_snapshot,
|
||||
)
|
||||
|
||||
@ -252,7 +282,7 @@ class TestWorkflowEntryInit:
|
||||
"ExecutionLimitsLayer",
|
||||
return_value=execution_limits_layer,
|
||||
) as execution_limits_layer_cls,
|
||||
patch.object(workflow_entry, "LLMQuotaLayer", return_value=llm_quota_layer),
|
||||
patch.object(workflow_entry, "LLMQuotaLayer", return_value=llm_quota_layer) as llm_quota_layer_cls,
|
||||
patch.object(workflow_entry, "ObservabilityLayer", return_value=observability_layer),
|
||||
):
|
||||
entry = workflow_entry.WorkflowEntry(
|
||||
@ -291,6 +321,7 @@ class TestWorkflowEntryInit:
|
||||
max_steps=workflow_entry.dify_config.WORKFLOW_MAX_EXECUTION_STEPS,
|
||||
max_time=workflow_entry.dify_config.WORKFLOW_MAX_EXECUTION_TIME,
|
||||
)
|
||||
llm_quota_layer_cls.assert_called_once_with(tenant_id="tenant-id")
|
||||
assert graph_engine.layer.call_args_list == [
|
||||
((debug_layer,), {}),
|
||||
((execution_limits_layer,), {}),
|
||||
@ -334,7 +365,7 @@ class TestWorkflowEntrySingleStepRun:
|
||||
def extract_variable_selector_to_variable_mapping(**_kwargs):
|
||||
return {}
|
||||
|
||||
variable_pool = VariablePool(system_variables=default_system_variables(), user_inputs={})
|
||||
variable_pool = VariablePool.from_bootstrap(system_variables=default_system_variables(), user_inputs={})
|
||||
variable_loader = MagicMock()
|
||||
variable_loader.load_variables.return_value = [
|
||||
StringVariable(
|
||||
|
||||
@ -0,0 +1,130 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy import create_engine, select
|
||||
|
||||
from core.app.entities.app_invoke_entities import ChatAppGenerateEntity
|
||||
from core.entities.provider_entities import ProviderQuotaType, QuotaUnit
|
||||
from events.event_handlers import update_provider_when_message_created
|
||||
from models import TenantCreditPool
|
||||
from models.provider import ProviderType
|
||||
|
||||
|
||||
def test_message_created_trial_credit_accounting_does_not_raise_when_balance_is_insufficient() -> None:
|
||||
engine = create_engine("sqlite:///:memory:")
|
||||
TenantCreditPool.__table__.create(engine)
|
||||
tenant_id = str(uuid4())
|
||||
pool_id = str(uuid4())
|
||||
with engine.begin() as connection:
|
||||
connection.execute(
|
||||
TenantCreditPool.__table__.insert(),
|
||||
{
|
||||
"id": pool_id,
|
||||
"tenant_id": tenant_id,
|
||||
"pool_type": ProviderQuotaType.TRIAL,
|
||||
"quota_limit": 10,
|
||||
"quota_used": 9,
|
||||
},
|
||||
)
|
||||
|
||||
system_configuration = SimpleNamespace(
|
||||
current_quota_type=ProviderQuotaType.TRIAL,
|
||||
quota_configurations=[
|
||||
SimpleNamespace(
|
||||
quota_type=ProviderQuotaType.TRIAL,
|
||||
quota_unit=QuotaUnit.TOKENS,
|
||||
quota_limit=10,
|
||||
)
|
||||
],
|
||||
)
|
||||
application_generate_entity = ChatAppGenerateEntity.model_construct(
|
||||
app_config=SimpleNamespace(tenant_id=tenant_id),
|
||||
model_conf=SimpleNamespace(
|
||||
provider="openai",
|
||||
model="gpt-4o",
|
||||
provider_model_bundle=SimpleNamespace(
|
||||
configuration=SimpleNamespace(
|
||||
using_provider_type=ProviderType.SYSTEM,
|
||||
system_configuration=system_configuration,
|
||||
)
|
||||
),
|
||||
),
|
||||
)
|
||||
message = SimpleNamespace(message_tokens=2, answer_tokens=1)
|
||||
|
||||
with (
|
||||
patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)),
|
||||
patch.object(update_provider_when_message_created, "_execute_provider_updates"),
|
||||
):
|
||||
update_provider_when_message_created.handle(
|
||||
sender=message,
|
||||
application_generate_entity=application_generate_entity,
|
||||
)
|
||||
|
||||
with engine.connect() as connection:
|
||||
quota_used = connection.scalar(select(TenantCreditPool.quota_used).where(TenantCreditPool.id == pool_id))
|
||||
|
||||
assert quota_used == 10
|
||||
|
||||
|
||||
def test_message_created_paid_credit_accounting_uses_paid_pool() -> None:
|
||||
tenant_id = str(uuid4())
|
||||
system_configuration = SimpleNamespace(
|
||||
current_quota_type=ProviderQuotaType.PAID,
|
||||
quota_configurations=[
|
||||
SimpleNamespace(
|
||||
quota_type=ProviderQuotaType.PAID,
|
||||
quota_unit=QuotaUnit.TOKENS,
|
||||
quota_limit=10,
|
||||
)
|
||||
],
|
||||
)
|
||||
application_generate_entity = ChatAppGenerateEntity.model_construct(
|
||||
app_config=SimpleNamespace(tenant_id=tenant_id),
|
||||
model_conf=SimpleNamespace(
|
||||
provider="openai",
|
||||
model="gpt-4o",
|
||||
provider_model_bundle=SimpleNamespace(
|
||||
configuration=SimpleNamespace(
|
||||
using_provider_type=ProviderType.SYSTEM,
|
||||
system_configuration=system_configuration,
|
||||
)
|
||||
),
|
||||
),
|
||||
)
|
||||
message = SimpleNamespace(message_tokens=2, answer_tokens=1)
|
||||
|
||||
with (
|
||||
patch.object(update_provider_when_message_created, "_deduct_credit_pool_quota_capped") as mock_deduct,
|
||||
patch.object(update_provider_when_message_created, "_execute_provider_updates"),
|
||||
):
|
||||
update_provider_when_message_created.handle(
|
||||
sender=message,
|
||||
application_generate_entity=application_generate_entity,
|
||||
)
|
||||
|
||||
mock_deduct.assert_called_once_with(
|
||||
tenant_id=tenant_id,
|
||||
credits_required=3,
|
||||
pool_type="paid",
|
||||
)
|
||||
|
||||
|
||||
def test_capped_credit_pool_accounting_skips_exhaustion_warning_when_full_amount_is_deducted(caplog) -> None:
|
||||
with patch(
|
||||
"services.credit_pool_service.CreditPoolService.deduct_credits_capped",
|
||||
return_value=3,
|
||||
) as mock_deduct:
|
||||
update_provider_when_message_created._deduct_credit_pool_quota_capped(
|
||||
tenant_id="tenant-id",
|
||||
credits_required=3,
|
||||
pool_type="trial",
|
||||
)
|
||||
|
||||
mock_deduct.assert_called_once_with(
|
||||
tenant_id="tenant-id",
|
||||
credits_required=3,
|
||||
pool_type="trial",
|
||||
)
|
||||
assert "Credit pool exhausted during message-created accounting" not in caplog.text
|
||||
158
api/tests/unit_tests/services/test_credit_pool_service.py
Normal file
158
api/tests/unit_tests/services/test_credit_pool_service.py
Normal file
@ -0,0 +1,158 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import create_engine, select
|
||||
from sqlalchemy.engine import Engine
|
||||
|
||||
from core.errors.error import QuotaExceededError
|
||||
from models import TenantCreditPool
|
||||
from models.enums import ProviderQuotaType
|
||||
from services.credit_pool_service import CreditPoolService
|
||||
|
||||
|
||||
def _create_engine_with_pool(*, quota_limit: int, quota_used: int) -> tuple[Engine, str, str]:
|
||||
engine = create_engine("sqlite:///:memory:")
|
||||
TenantCreditPool.__table__.create(engine)
|
||||
tenant_id = str(uuid4())
|
||||
pool_id = str(uuid4())
|
||||
with engine.begin() as connection:
|
||||
connection.execute(
|
||||
TenantCreditPool.__table__.insert(),
|
||||
{
|
||||
"id": pool_id,
|
||||
"tenant_id": tenant_id,
|
||||
"pool_type": ProviderQuotaType.TRIAL,
|
||||
"quota_limit": quota_limit,
|
||||
"quota_used": quota_used,
|
||||
},
|
||||
)
|
||||
return engine, tenant_id, pool_id
|
||||
|
||||
|
||||
def _get_quota_used(*, engine: Engine, pool_id: str) -> int | None:
|
||||
with engine.connect() as connection:
|
||||
return connection.scalar(select(TenantCreditPool.quota_used).where(TenantCreditPool.id == pool_id))
|
||||
|
||||
|
||||
def test_check_and_deduct_credits_deducts_exact_amount_when_sufficient() -> None:
|
||||
engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=2)
|
||||
|
||||
with patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)):
|
||||
deducted_credits = CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=3)
|
||||
|
||||
assert deducted_credits == 3
|
||||
assert _get_quota_used(engine=engine, pool_id=pool_id) == 5
|
||||
|
||||
|
||||
def test_check_and_deduct_credits_returns_zero_for_non_positive_request() -> None:
|
||||
assert CreditPoolService.check_and_deduct_credits(tenant_id=str(uuid4()), credits_required=0) == 0
|
||||
|
||||
|
||||
def test_check_and_deduct_credits_raises_when_pool_is_missing() -> None:
    """A tenant without any pool row is rejected with QuotaExceededError."""
    engine = create_engine("sqlite:///:memory:")
    TenantCreditPool.__table__.create(engine)
    fake_db = SimpleNamespace(engine=engine)

    with patch("services.credit_pool_service.db", fake_db):
        with pytest.raises(QuotaExceededError, match="Credit pool not found"):
            CreditPoolService.check_and_deduct_credits(tenant_id=str(uuid4()), credits_required=1)
def test_check_and_deduct_credits_raises_when_pool_is_empty() -> None:
    """An exhausted pool rejects the request and leaves its counter untouched."""
    engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=10)

    with patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)):
        with pytest.raises(QuotaExceededError, match="No credits remaining"):
            CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=1)

    assert _get_quota_used(engine=engine, pool_id=pool_id) == 10
def test_check_and_deduct_credits_raises_without_partial_deduction_when_insufficient() -> None:
    """Requests larger than the balance fail atomically: nothing is deducted."""
    engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=9)

    with patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)):
        with pytest.raises(QuotaExceededError, match="Insufficient credits remaining"):
            CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=3)

    # Counter is unchanged — no partial deduction happened.
    assert _get_quota_used(engine=engine, pool_id=pool_id) == 9
def test_check_and_deduct_credits_wraps_unexpected_deduction_errors() -> None:
    """Unexpected infrastructure failures surface as QuotaExceededError and deduct nothing."""
    engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=2)
    infra_failure = RuntimeError("database unavailable")

    with patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)):
        with patch.object(CreditPoolService, "_get_locked_pool", side_effect=infra_failure):
            with pytest.raises(QuotaExceededError, match="Failed to deduct credits"):
                CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=1)

    assert _get_quota_used(engine=engine, pool_id=pool_id) == 2
def test_deduct_credits_capped_returns_zero_for_non_positive_request() -> None:
    """Non-positive requests short-circuit to zero without touching the database."""
    result = CreditPoolService.deduct_credits_capped(tenant_id=str(uuid4()), credits_required=0)
    assert result == 0
def test_deduct_credits_capped_returns_zero_when_pool_is_missing() -> None:
    """Unlike the strict variant, a missing pool yields 0 instead of raising."""
    engine = create_engine("sqlite:///:memory:")
    TenantCreditPool.__table__.create(engine)
    fake_db = SimpleNamespace(engine=engine)

    with patch("services.credit_pool_service.db", fake_db):
        deducted = CreditPoolService.deduct_credits_capped(tenant_id=str(uuid4()), credits_required=1)

    assert deducted == 0
def test_deduct_credits_capped_returns_zero_when_pool_is_empty() -> None:
    """An exhausted pool caps the deduction at zero and keeps its counter untouched."""
    engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=10)
    fake_db = SimpleNamespace(engine=engine)

    with patch("services.credit_pool_service.db", fake_db):
        deducted = CreditPoolService.deduct_credits_capped(tenant_id=tenant_id, credits_required=1)

    assert deducted == 0
    assert _get_quota_used(engine=engine, pool_id=pool_id) == 10
def test_deduct_credits_capped_deducts_only_remaining_balance_when_insufficient() -> None:
    """When the request exceeds the balance, only the remaining credit is taken."""
    engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=9)
    fake_db = SimpleNamespace(engine=engine)

    with patch("services.credit_pool_service.db", fake_db):
        deducted = CreditPoolService.deduct_credits_capped(tenant_id=tenant_id, credits_required=3)

    # Only 1 credit remained, so the deduction is capped at 1 and the pool is drained.
    assert deducted == 1
    assert _get_quota_used(engine=engine, pool_id=pool_id) == 10
def test_deduct_credits_capped_wraps_unexpected_deduction_errors() -> None:
    """Unexpected infrastructure failures surface as QuotaExceededError and deduct nothing."""
    engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=2)
    infra_failure = RuntimeError("database unavailable")

    with patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)):
        with patch.object(CreditPoolService, "_get_locked_pool", side_effect=infra_failure):
            with pytest.raises(QuotaExceededError, match="Failed to deduct credits"):
                CreditPoolService.deduct_credits_capped(tenant_id=tenant_id, credits_required=1)

    assert _get_quota_used(engine=engine, pool_id=pool_id) == 2
def test_deduct_credits_capped_reraises_quota_exceeded_errors() -> None:
    """A QuotaExceededError raised inside the helper propagates unwrapped."""
    engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=2)
    original_error = QuotaExceededError("quota unavailable")

    with patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)):
        with patch.object(CreditPoolService, "_get_locked_pool", side_effect=original_error):
            with pytest.raises(QuotaExceededError, match="quota unavailable"):
                CreditPoolService.deduct_credits_capped(tenant_id=tenant_id, credits_required=1)

    # Nothing was deducted on the failure path.
    assert _get_quota_used(engine=engine, pool_id=pool_id) == 2
||||
@ -2845,7 +2845,7 @@ class TestWorkflowServiceFreeNodeExecution:
|
||||
mock_node_cls.validate_node_data.assert_called_once_with(sentinel.adapted_node_data)
|
||||
mock_node_cls.assert_called_once_with(
|
||||
node_id="n-1",
|
||||
config=sentinel.node_data,
|
||||
data=sentinel.node_data,
|
||||
graph_init_params=mock_graph_init_context_cls.return_value.to_graph_init_params.return_value,
|
||||
graph_runtime_state=ANY,
|
||||
runtime=mock_runtime_cls.return_value,
|
||||
|
||||
8
api/uv.lock
generated
8
api/uv.lock
generated
@ -1597,7 +1597,7 @@ requires-dist = [
|
||||
{ name = "gmpy2", specifier = ">=2.3.0" },
|
||||
{ name = "google-api-python-client", specifier = ">=2.195.0" },
|
||||
{ name = "google-cloud-aiplatform", specifier = ">=1.149.0,<2.0.0" },
|
||||
{ name = "graphon", specifier = "~=0.2.2" },
|
||||
{ name = "graphon", specifier = "~=0.3.0" },
|
||||
{ name = "gunicorn", specifier = ">=25.3.0" },
|
||||
{ name = "httpx", extras = ["socks"], specifier = ">=0.28.1,<1.0.0" },
|
||||
{ name = "httpx-sse", specifier = "~=0.4.0" },
|
||||
@ -2940,7 +2940,7 @@ httpx = [
|
||||
|
||||
[[package]]
|
||||
name = "graphon"
|
||||
version = "0.2.2"
|
||||
version = "0.3.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "charset-normalizer" },
|
||||
@ -2961,9 +2961,9 @@ dependencies = [
|
||||
{ name = "unstructured", extra = ["docx", "epub", "md", "ppt", "pptx"] },
|
||||
{ name = "webvtt-py" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/08/50/e745a79c5f742f88f6011a1f7c9ba2c2f9cc1beedd982f0b192f1ab8c748/graphon-0.2.2.tar.gz", hash = "sha256:141f0de536171850f1af6f738dc66f0285aadd3c097f1dad2a038636789e0aa5", size = 236360, upload-time = "2026-04-17T08:52:28.047Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/bf/62/83593d6e7a139ff124711ea05882cadca7065c11a38763aa9360d7e76804/graphon-0.3.0.tar.gz", hash = "sha256:cd38f842ae3dcfa956428b952efbe2a3ea9c1581446647142accbbdeb638b876", size = 241176, upload-time = "2026-04-21T15:18:48.291Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/de/89/a6340afdaf5169d17a318e00fc685fb67ed99baa602c2cbbbf6af6a76096/graphon-0.2.2-py3-none-any.whl", hash = "sha256:754e544d08779138f99eac6547ab08559463680e2c76488b05e1c978210392b4", size = 340808, upload-time = "2026-04-17T08:52:26.5Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/b3/f7/81ee8f0368aa6a2d47f97fecc5d4a12865c987906798cbddd0e3b8387f33/graphon-0.3.0-py3-none-any.whl", hash = "sha256:9cca45ebab2a79fd4d04432f55b5b962e9e4f34fa037cc20fee7f18ec80eaa5d", size = 348486, upload-time = "2026-04-21T15:18:46.737Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
Loading…
Reference in New Issue
Block a user