fix(api): resolve graphon upgrade CI failures
Repair the plugin model runtime for Graphon 0.3.0 by implementing structured-output support, normalizing non-stream LLM calls, and tightening workflow-layer type safety. Update the remaining workflow tests to use VariablePool.from_bootstrap(...) and the new node data constructor API so the CI unit and integration suites match the upgraded runtime behavior.
commit f13faa69b3
parent 05aba26091
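Note on "normalizing non-stream LLM calls": the runtime hunks below funnel every stream=False response through normalize_non_stream_runtime_result before returning it. That helper's implementation is not part of this diff; the sketch below only illustrates the kind of collapse such a normalizer performs, with stand-in types and assumed field names.

# Hypothetical sketch, not the shipped normalize_non_stream_runtime_result:
# collapse a chunk stream into one final result when the caller asked for
# stream=False. Chunk/Result are stand-in types with assumed fields.
from collections.abc import Generator
from dataclasses import dataclass

@dataclass
class Chunk:
    text: str

@dataclass
class Result:
    model: str
    text: str = ""

def normalize_non_stream(model: str, result: Result | Generator[Chunk, None, None]) -> Result:
    if isinstance(result, Generator):
        # Daemon streamed anyway: drain the generator and merge the deltas.
        return Result(model=model, text="".join(chunk.text for chunk in result))
    return result  # already a complete result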
@@ -5,7 +5,7 @@ This layer centralizes model-quota deduction outside node implementations.
 """
 
 import logging
-from typing import TYPE_CHECKING, cast, final, override
+from typing import final, override
 
 from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext
 from core.app.llm import deduct_llm_quota, ensure_llm_quota_available
@@ -17,11 +17,6 @@ from graphon.graph_engine.layers import GraphEngineLayer
 from graphon.graph_events import GraphEngineEvent, GraphNodeEventBase, NodeRunSucceededEvent
 from graphon.nodes.base.node import Node
 
-if TYPE_CHECKING:
-    from graphon.nodes.llm.node import LLMNode
-    from graphon.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
-    from graphon.nodes.question_classifier.question_classifier_node import QuestionClassifierNode
-
 logger = logging.getLogger(__name__)
 
 
@@ -109,16 +104,14 @@ class LLMQuotaLayer(GraphEngineLayer):
 
     @staticmethod
    def _extract_model_instance(node: Node) -> ModelInstance | None:
+        match node.node_type:
+            case BuiltinNodeTypes.LLM | BuiltinNodeTypes.PARAMETER_EXTRACTOR | BuiltinNodeTypes.QUESTION_CLASSIFIER:
+                pass
+            case _:
+                return None
+
         try:
-            match node.node_type:
-                case BuiltinNodeTypes.LLM:
-                    model_instance = cast("LLMNode", node).model_instance
-                case BuiltinNodeTypes.PARAMETER_EXTRACTOR:
-                    model_instance = cast("ParameterExtractorNode", node).model_instance
-                case BuiltinNodeTypes.QUESTION_CLASSIFIER:
-                    model_instance = cast("QuestionClassifierNode", node).model_instance
-                case _:
-                    return None
+            model_instance = getattr(node, "model_instance", None)
         except AttributeError:
             logger.warning(
                 "LLMQuotaLayer skipped quota deduction because node does not expose a model instance, node_id=%s",
@@ -133,4 +126,12 @@ class LLMQuotaLayer(GraphEngineLayer):
         if isinstance(raw_model_instance, ModelInstance):
             return raw_model_instance
 
+        private_model_instance = getattr(node, "_model_instance", None)
+        if isinstance(private_model_instance, ModelInstance):
+            return private_model_instance
+
+        wrapped_private_model_instance = getattr(private_model_instance, "_model_instance", None)
+        if isinstance(wrapped_private_model_instance, ModelInstance):
+            return wrapped_private_model_instance
+
         return None
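Net effect of the four hunks above: _extract_model_instance no longer needs TYPE_CHECKING-only node imports or per-node-type casts; it gates on node_type, then resolves the model instance through duck-typed attribute lookups with a widening fallback chain. A condensed restatement, with a stand-in stub type (the unchanged middle of the real method is elided):

# Condensed restatement of the lookup order produced by the hunks above.
# ModelInstance is a stand-in stub for dify's class; the unchanged middle of
# the real method is elided.
class ModelInstance:
    pass

def extract_model_instance(node: object) -> ModelInstance | None:
    raw = getattr(node, "model_instance", None)          # 1. public attribute
    if isinstance(raw, ModelInstance):
        return raw
    private = getattr(node, "_model_instance", None)     # 2. private attribute
    if isinstance(private, ModelInstance):
        return private
    wrapped = getattr(private, "_model_instance", None)  # 3. one wrapper deeper
    if isinstance(wrapped, ModelInstance):
        return wrapped
    return None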
@@ -4,22 +4,31 @@ import hashlib
 import logging
 from collections.abc import Generator, Iterable, Sequence
 from threading import Lock
-from typing import IO, Any, Union
+from typing import IO, Any, Literal, cast, overload
 
 from pydantic import ValidationError
 from redis import RedisError
 
 from configs import dify_config
+from core.llm_generator.output_parser.structured_output import (
+    invoke_llm_with_structured_output as invoke_llm_with_structured_output_helper,
+)
 from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
 from core.plugin.impl.asset import PluginAssetManager
 from core.plugin.impl.model import PluginModelClient
 from extensions.ext_redis import redis_client
-from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
+from graphon.model_runtime.entities.llm_entities import (
+    LLMResult,
+    LLMResultChunk,
+    LLMResultChunkWithStructuredOutput,
+    LLMResultWithStructuredOutput,
+)
 from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
 from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelType
 from graphon.model_runtime.entities.provider_entities import ProviderEntity
 from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult
 from graphon.model_runtime.entities.text_embedding_entities import EmbeddingInputType, EmbeddingResult
+from graphon.model_runtime.model_providers.base.large_language_model import normalize_non_stream_runtime_result
 from graphon.model_runtime.protocols.runtime import ModelRuntime
 from models.provider_ids import ModelProviderID
 
@@ -29,6 +38,57 @@ logger = logging.getLogger(__name__)
 TENANT_SCOPE_SCHEMA_CACHE_USER_ID = "__DIFY_TS__"
 
 
+class _PluginStructuredOutputModelInstance:
+    """Reuse the shared structured-output helper without depending on `ModelInstance`."""
+
+    def __init__(
+        self,
+        *,
+        runtime: PluginModelRuntime,
+        provider: str,
+        model: str,
+        credentials: dict[str, Any],
+    ) -> None:
+        self._runtime = runtime
+        self._provider = provider
+        self._model = model
+        self._credentials = credentials
+
+    def invoke_llm(
+        self,
+        *,
+        prompt_messages: Sequence[PromptMessage],
+        model_parameters: dict[str, Any] | None = None,
+        tools: Sequence[PromptMessageTool] | None = None,
+        stop: Sequence[str] | None = None,
+        stream: bool = True,
+        callbacks: object | None = None,
+    ) -> LLMResult | Generator[LLMResultChunk, None, None]:
+        del callbacks
+        if stream:
+            return self._runtime.invoke_llm(
+                provider=self._provider,
+                model=self._model,
+                credentials=self._credentials,
+                model_parameters=model_parameters or {},
+                prompt_messages=prompt_messages,
+                tools=list(tools) if tools else None,
+                stop=stop,
+                stream=True,
+            )
+
+        return self._runtime.invoke_llm(
+            provider=self._provider,
+            model=self._model,
+            credentials=self._credentials,
+            model_parameters=model_parameters or {},
+            prompt_messages=prompt_messages,
+            tools=list(tools) if tools else None,
+            stop=stop,
+            stream=False,
+        )
+
+
 class PluginModelRuntime(ModelRuntime):
     """Plugin-backed runtime adapter bound to tenant context and optional caller scope."""
 
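_PluginStructuredOutputModelInstance exists because the shared helper imported above was written against ModelInstance.invoke_llm; the adapter satisfies that contract by delegating to the runtime and discarding callbacks, which the plugin client does not accept. Structurally this is plain duck typing, as the stand-in Protocol below illustrates (the Protocol itself is illustrative, not part of the codebase):

# Illustrative Protocol (not in the codebase): the shared structured-output
# helper only ever calls invoke_llm, so any object with a compatible method
# can stand in for ModelInstance -- no inheritance required.
from typing import Any, Protocol

class SupportsInvokeLLM(Protocol):
    def invoke_llm(
        self,
        *,
        prompt_messages: Any,
        model_parameters: Any = None,
        tools: Any = None,
        stop: Any = None,
        stream: bool = True,
        callbacks: Any = None,
    ) -> Any: ...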
@@ -195,6 +255,34 @@ class PluginModelRuntime(ModelRuntime):
 
         return schema
 
+    @overload
+    def invoke_llm(
+        self,
+        *,
+        provider: str,
+        model: str,
+        credentials: dict[str, Any],
+        model_parameters: dict[str, Any],
+        prompt_messages: Sequence[PromptMessage],
+        tools: list[PromptMessageTool] | None,
+        stop: Sequence[str] | None,
+        stream: Literal[False],
+    ) -> LLMResult: ...
+
+    @overload
+    def invoke_llm(
+        self,
+        *,
+        provider: str,
+        model: str,
+        credentials: dict[str, Any],
+        model_parameters: dict[str, Any],
+        prompt_messages: Sequence[PromptMessage],
+        tools: list[PromptMessageTool] | None,
+        stop: Sequence[str] | None,
+        stream: Literal[True],
+    ) -> Generator[LLMResultChunk, None, None]: ...
+
     def invoke_llm(
         self,
         *,
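The Literal[False]/Literal[True] overloads let a type checker pick the exact return type from the stream argument at each call site, while the runtime keeps a single bool-typed implementation. A self-contained demo of the same pattern:

# Self-contained demo of the Literal-overload pattern used above: one bool
# implementation, two precise signatures for the type checker.
from collections.abc import Generator
from typing import Literal, overload

@overload
def fetch(stream: Literal[False]) -> str: ...
@overload
def fetch(stream: Literal[True]) -> Generator[str, None, None]: ...
def fetch(stream: bool) -> str | Generator[str, None, None]:
    if stream:
        return (part for part in ("a", "b"))
    return "ab"

text: str = fetch(False)  # checker narrows the result to str
chunks = fetch(True)      # ...and here to Generator[str, None, None]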
@@ -206,9 +294,9 @@ class PluginModelRuntime(ModelRuntime):
         tools: list[PromptMessageTool] | None,
         stop: Sequence[str] | None,
         stream: bool,
-    ) -> Union[LLMResult, Generator[LLMResultChunk, None, None]]:
+    ) -> LLMResult | Generator[LLMResultChunk, None, None]:
         plugin_id, provider_name = self._split_provider(provider)
-        return self.client.invoke_llm(
+        result = self.client.invoke_llm(
             tenant_id=self.tenant_id,
             user_id=self.user_id,
             plugin_id=plugin_id,
@@ -221,6 +309,81 @@ class PluginModelRuntime(ModelRuntime):
             stop=list(stop) if stop else None,
             stream=stream,
         )
+        if stream:
+            return result
+
+        return normalize_non_stream_runtime_result(
+            model=model,
+            prompt_messages=prompt_messages,
+            result=result,
+        )
+
+    @overload
+    def invoke_llm_with_structured_output(
+        self,
+        *,
+        provider: str,
+        model: str,
+        credentials: dict[str, Any],
+        json_schema: dict[str, Any],
+        model_parameters: dict[str, Any],
+        prompt_messages: Sequence[PromptMessage],
+        stop: Sequence[str] | None,
+        stream: Literal[False],
+    ) -> LLMResultWithStructuredOutput: ...
+
+    @overload
+    def invoke_llm_with_structured_output(
+        self,
+        *,
+        provider: str,
+        model: str,
+        credentials: dict[str, Any],
+        json_schema: dict[str, Any],
+        model_parameters: dict[str, Any],
+        prompt_messages: Sequence[PromptMessage],
+        stop: Sequence[str] | None,
+        stream: Literal[True],
+    ) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ...
+
+    def invoke_llm_with_structured_output(
+        self,
+        *,
+        provider: str,
+        model: str,
+        credentials: dict[str, Any],
+        json_schema: dict[str, Any],
+        model_parameters: dict[str, Any],
+        prompt_messages: Sequence[PromptMessage],
+        stop: Sequence[str] | None,
+        stream: bool,
+    ) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]:
+        model_schema = self.get_model_schema(
+            provider=provider,
+            model_type=ModelType.LLM,
+            model=model,
+            credentials=credentials,
+        )
+        if model_schema is None:
+            raise ValueError(f"Model schema not found for {model}")
+
+        adapter = _PluginStructuredOutputModelInstance(
+            runtime=self,
+            provider=provider,
+            model=model,
+            credentials=credentials,
+        )
+        return invoke_llm_with_structured_output_helper(
+            provider=provider,
+            model_schema=model_schema,
+            model_instance=cast(Any, adapter),
+            prompt_messages=prompt_messages,
+            json_schema=json_schema,
+            model_parameters=model_parameters,
+            tools=None,
+            stop=list(stop) if stop else None,
+            stream=stream,
+        )
 
     def get_llm_num_tokens(
         self,
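End to end, the new invoke_llm_with_structured_output resolves the model schema, wraps the runtime in the adapter, and leaves prompt construction and output parsing to the shared helper. A hypothetical non-stream call; runtime is assumed to be a configured PluginModelRuntime, and the provider, model, and credentials are placeholders, not values from this diff:

# Hypothetical usage of the new API; all concrete values are placeholders.
json_schema = {
    "type": "object",
    "properties": {"city": {"type": "string"}, "temp_c": {"type": "number"}},
    "required": ["city", "temp_c"],
}
result = runtime.invoke_llm_with_structured_output(
    provider="langgenius/openai/openai",
    model="gpt-4o-mini",
    credentials={"api_key": "..."},
    json_schema=json_schema,
    model_parameters={"temperature": 0},
    prompt_messages=[],  # real calls pass PromptMessage instances
    stop=None,
    stream=False,
)
# With stream=False, the overloads type `result` as LLMResultWithStructuredOutput.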
@@ -3,7 +3,7 @@ from __future__ import annotations
 from collections import defaultdict
 from collections.abc import Mapping, Sequence
 from enum import StrEnum
-from typing import Any, Protocol, cast
+from typing import Any, Protocol
 from uuid import uuid4
 
 from graphon.enums import BuiltinNodeTypes
@@ -82,13 +82,10 @@ def build_system_variables(values: Mapping[str, Any] | None = None, /, **kwargs:
     normalized = _normalize_system_variable_values(values, **kwargs)
 
     return [
-        cast(
-            Variable,
-            segment_to_variable(
-                segment=build_segment(value),
-                selector=system_variable_selector(key),
-                name=key,
-            ),
-        )
+        segment_to_variable(
+            segment=build_segment(value),
+            selector=system_variable_selector(key),
+            name=key,
+        )
         for key, value in normalized.items()
     ]
@@ -130,13 +127,10 @@ def build_bootstrap_variables(
 
     for node_id, value in rag_pipeline_variables_map.items():
         variables.append(
-            cast(
-                Variable,
-                segment_to_variable(
-                    segment=build_segment(value),
-                    selector=(RAG_PIPELINE_VARIABLE_NODE_ID, node_id),
-                    name=node_id,
-                ),
-            )
+            segment_to_variable(
+                segment=build_segment(value),
+                selector=(RAG_PIPELINE_VARIABLE_NODE_ID, node_id),
+                name=node_id,
+            )
         )
 
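Both comprehension sites above drop the cast(Variable, ...) wrapper, and the cast import goes with it; the implication (not stated in the diff) is that segment_to_variable is annotated to return Variable under Graphon 0.3.0. A minimal illustration, with assumed annotations, of why the cast becomes redundant once the return type is precise:

# Minimal illustration with assumed annotations: once a factory is declared
# to return the target type, cast(Variable, ...) is noise for the checker.
from typing import cast

class Segment: ...
class Variable(Segment): ...

def to_variable_old(segment: Segment) -> Segment:   # vague return type
    return Variable()

def to_variable_new(segment: Segment) -> Variable:  # assumed 0.3.0 annotation
    return Variable()

old_var: Variable = cast(Variable, to_variable_old(Segment()))  # cast required
new_var: Variable = to_variable_new(Segment())                  # no cast needed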
@@ -157,8 +157,8 @@ class DraftVarLoader(VariableLoader):
         # This approach reduces loading time by querying external systems concurrently.
         with ThreadPoolExecutor(max_workers=10) as executor:
             offloaded_variables = executor.map(self._load_offloaded_variable, offloaded_draft_vars)
-            for selector, variable in offloaded_variables:
-                variable_by_selector[selector] = variable
+            for selector, offloaded_variable in offloaded_variables:
+                variable_by_selector[selector] = offloaded_variable
 
         return list(variable_by_selector.values())
@@ -45,7 +45,7 @@ def init_code_node(code_config: dict):
     )
 
     # construct variable pool
-    variable_pool = VariablePool(
+    variable_pool = VariablePool.from_bootstrap(
         system_variables=build_system_variables(user_id="aaa", files=[]),
         user_inputs={},
         environment_variables=[],
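The same mechanical substitution repeats through the test fixtures below: the bare VariablePool(...) constructor becomes the VariablePool.from_bootstrap(...) classmethod with identical keyword arguments. A sketch of the constructor-to-classmethod shape; the internals are assumptions, since the real VariablePool lives in graphon.runtime:

# Sketch with assumed internals -- not the real VariablePool. A classmethod
# gives the legacy bootstrap kwargs one named entry point while __init__
# stays minimal, which is the API shape the tests below migrate to.
from typing import Any, Self

class Pool:
    def __init__(self) -> None:
        self.variables: dict[str, Any] = {}

    @classmethod
    def from_bootstrap(
        cls,
        *,
        system_variables: tuple[Any, ...] = (),
        user_inputs: dict[str, Any] | None = None,
        environment_variables: tuple[Any, ...] = (),
    ) -> Self:
        pool = cls()
        for var in (*system_variables, *environment_variables):
            pool.variables[var.name] = var
        pool.variables.update(user_inputs or {})
        return pool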
@@ -55,7 +55,7 @@ def init_http_node(config: dict):
     )
 
     # construct variable pool
-    variable_pool = VariablePool(
+    variable_pool = VariablePool.from_bootstrap(
         system_variables=build_system_variables(user_id="aaa", files=[]),
         user_inputs={},
         environment_variables=[],
@@ -204,7 +204,7 @@ def test_custom_auth_with_empty_api_key_raises_error(setup_http_mock):
     from graphon.runtime import VariablePool
 
     # Create variable pool
-    variable_pool = VariablePool(
+    variable_pool = VariablePool.from_bootstrap(
         system_variables=build_system_variables(user_id="test", files=[]),
         user_inputs={},
         environment_variables=[],
@@ -702,7 +702,7 @@ def test_nested_object_variable_selector(setup_http_mock):
     )
 
     # Create independent variable pool for this test only
-    variable_pool = VariablePool(
+    variable_pool = VariablePool.from_bootstrap(
         system_variables=build_system_variables(user_id="aaa", files=[]),
         user_inputs={},
         environment_variables=[],
@@ -53,7 +53,7 @@ def init_llm_node(config: dict) -> LLMNode:
     )
 
     # construct variable pool
-    variable_pool = VariablePool(
+    variable_pool = VariablePool.from_bootstrap(
         system_variables=build_system_variables(
             user_id="aaa",
             app_id=app_id,
@@ -56,7 +56,7 @@ def init_parameter_extractor_node(config: dict, memory=None):
     )
 
     # construct variable pool
-    variable_pool = VariablePool(
+    variable_pool = VariablePool.from_bootstrap(
         system_variables=build_system_variables(
             user_id="aaa", files=[], query="what's the weather in SF", conversation_id="abababa"
         ),
@@ -66,7 +66,7 @@ def test_execute_template_transform():
     )
 
     # construct variable pool
-    variable_pool = VariablePool(
+    variable_pool = VariablePool.from_bootstrap(
         system_variables=build_system_variables(user_id="aaa", files=[]),
         user_inputs={},
         environment_variables=[],
@@ -41,7 +41,7 @@ def init_tool_node(config: dict):
     )
 
    # construct variable pool
-    variable_pool = VariablePool(
+    variable_pool = VariablePool.from_bootstrap(
         system_variables=build_system_variables(user_id="aaa", files=[]),
         user_inputs={},
         environment_variables=[],
@@ -66,7 +66,7 @@ def _mock_form_repository_with_submission(action_id: str) -> HumanInputFormRepos
 
 
 def _build_runtime_state(workflow_execution_id: str, app_id: str, workflow_id: str, user_id: str) -> GraphRuntimeState:
-    variable_pool = VariablePool(
+    variable_pool = VariablePool.from_bootstrap(
         system_variables=build_system_variables(
             workflow_execution_id=workflow_execution_id,
             app_id=app_id,
@@ -16,7 +16,7 @@ from models.workflow import Workflow
 
 
 def _make_graph_state():
-    variable_pool = VariablePool(
+    variable_pool = VariablePool.from_bootstrap(
         system_variables=default_system_variables(),
         user_inputs={},
         environment_variables=[],
@@ -8,9 +8,9 @@ from graphon.enums import BuiltinNodeTypes
 
 
 class DummyNode:
-    def __init__(self, *, node_id, config, graph_init_params, graph_runtime_state, **kwargs):
+    def __init__(self, *, node_id, data, graph_init_params, graph_runtime_state, **kwargs):
         self.id = node_id
-        self.config = config
+        self.data = data
         self.graph_init_params = graph_init_params
         self.graph_runtime_state = graph_runtime_state
         self.kwargs = kwargs
@@ -14,16 +14,11 @@ from graphon.nodes.llm.entities import LLMNodeData
 from graphon.variables.segments import StringSegment
 
 
-def _assert_typed_node_config(config, *, node_id: str, node_type: NodeType, version: str = "1") -> None:
+def _assert_typed_node_data(data, *, node_id: str, node_type: NodeType, version: str = "1") -> None:
     _ = node_id
-    if isinstance(config, BaseNodeData):
-        assert config.type == node_type
-        assert config.version == version
-        return
-
-    assert isinstance(config, dict)
-    assert config["type"] == node_type
-    assert config["version"] == version
+    assert isinstance(data, BaseNodeData)
+    assert data.type == node_type
+    assert data.version == version
 
 
 def _node_constructor(*, return_value):
@@ -470,7 +465,7 @@ class TestDifyNodeFactoryCreateNode:
         matched_node_class.assert_called_once()
         kwargs = matched_node_class.call_args.kwargs
         assert kwargs["node_id"] == "node-id"
-        _assert_typed_node_config(kwargs["config"], node_id="node-id", node_type=BuiltinNodeTypes.START, version="9")
+        _assert_typed_node_data(kwargs["data"], node_id="node-id", node_type=BuiltinNodeTypes.START, version="9")
         assert kwargs["graph_init_params"] is sentinel.graph_init_params
         assert kwargs["graph_runtime_state"] is factory.graph_runtime_state
         latest_node_class.assert_not_called()
@@ -490,7 +485,7 @@ class TestDifyNodeFactoryCreateNode:
         latest_node_class.assert_called_once()
         kwargs = latest_node_class.call_args.kwargs
         assert kwargs["node_id"] == "node-id"
-        _assert_typed_node_config(kwargs["config"], node_id="node-id", node_type=BuiltinNodeTypes.START, version="9")
+        _assert_typed_node_data(kwargs["data"], node_id="node-id", node_type=BuiltinNodeTypes.START, version="9")
         assert kwargs["graph_init_params"] is sentinel.graph_init_params
         assert kwargs["graph_runtime_state"] is factory.graph_runtime_state
 
@@ -528,7 +523,7 @@ class TestDifyNodeFactoryCreateNode:
         assert result is created_node
         kwargs = constructor.call_args.kwargs
         assert kwargs["node_id"] == "node-id"
-        _assert_typed_node_config(kwargs["config"], node_id="node-id", node_type=node_type)
+        _assert_typed_node_data(kwargs["data"], node_id="node-id", node_type=node_type)
         assert kwargs["graph_init_params"] is sentinel.graph_init_params
         assert kwargs["graph_runtime_state"] is factory.graph_runtime_state
 
@@ -597,7 +592,7 @@ class TestDifyNodeFactoryCreateNode:
         prepared_llm.assert_called_once_with(sentinel.model_instance)
         assert kwargs["model_instance"] is wrapped_model_instance
 
-    def test_create_node_passes_alias_preserving_llm_config_to_constructor(self, monkeypatch, factory):
+    def test_create_node_passes_typed_llm_data_to_constructor(self, monkeypatch, factory):
         created_node = object()
         constructor = _node_constructor(return_value=created_node)
         monkeypatch.setattr(factory, "_resolve_node_class", MagicMock(return_value=constructor))
@@ -625,10 +620,10 @@ class TestDifyNodeFactoryCreateNode:
 
         factory.create_node(node_config)
 
-        config = constructor.call_args.kwargs["config"]
-        assert isinstance(config, dict)
-        assert config["structured_output_enabled"] is True
-        assert "structured_output_switch_on" not in config
+        data = constructor.call_args.kwargs["data"]
+        assert isinstance(data, LLMNodeData)
+        assert data.structured_output_enabled is True
+        assert data.structured_output_switch_on is True
 
     @pytest.mark.parametrize(
         ("node_type", "constructor_name", "expected_extra_kwargs"),
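The rewritten assertions check structured_output_enabled and its legacy spelling structured_output_switch_on on the typed LLMNodeData and expect both to be True, implying the two names stay in sync. A minimal pydantic sketch of one way to get that behavior; this is an assumption about LLMNodeData's internals, which this diff does not show:

# Assumed behavior, sketched with pydantic v2: one stored field, populated
# from either spelling, with the legacy name exposed as a property.
from pydantic import BaseModel, ConfigDict, Field

class NodeData(BaseModel):
    model_config = ConfigDict(populate_by_name=True)
    structured_output_enabled: bool = Field(default=False, alias="structured_output_switch_on")

    @property
    def structured_output_switch_on(self) -> bool:
        return self.structured_output_enabled

data = NodeData.model_validate({"structured_output_switch_on": True})
assert data.structured_output_enabled is True
assert data.structured_output_switch_on is True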
@pytest.mark.parametrize(
|
||||
("node_type", "constructor_name", "expected_extra_kwargs"),
|
||||
@ -707,7 +702,7 @@ class TestDifyNodeFactoryCreateNode:
|
||||
|
||||
constructor_kwargs = constructor.call_args.kwargs
|
||||
assert constructor_kwargs["node_id"] == "node-id"
|
||||
_assert_typed_node_config(constructor_kwargs["config"], node_id="node-id", node_type=node_type)
|
||||
_assert_typed_node_data(constructor_kwargs["data"], node_id="node-id", node_type=node_type)
|
||||
assert constructor_kwargs["graph_init_params"] is sentinel.graph_init_params
|
||||
assert constructor_kwargs["graph_runtime_state"] is factory.graph_runtime_state
|
||||
assert constructor_kwargs["credentials_provider"] is sentinel.credentials_provider
|
||||
|
||||
@@ -109,8 +109,8 @@ class TestVariablePool:
         assert pool.get([ENVIRONMENT_VARIABLE_NODE_ID, "env_var_1"]) is not None
         assert pool.get([CONVERSATION_VARIABLE_NODE_ID, "conv_var_1"]) is not None
 
-    def test_constructor_loads_legacy_bootstrap_kwargs(self):
-        pool = VariablePool(
+    def test_from_bootstrap_loads_legacy_bootstrap_kwargs(self):
+        pool = VariablePool.from_bootstrap(
             system_variables=build_system_variables(user_id="test_user_id"),
             environment_variables=[StringVariable(name="env_var", value="env-value")],
             conversation_variables=[StringVariable(name="conv_var", value="conv-value")],
@@ -2845,7 +2845,7 @@ class TestWorkflowServiceFreeNodeExecution:
         mock_node_cls.validate_node_data.assert_called_once_with(sentinel.adapted_node_data)
         mock_node_cls.assert_called_once_with(
             node_id="n-1",
-            config=sentinel.node_data,
+            data=sentinel.node_data,
             graph_init_params=mock_graph_init_context_cls.return_value.to_graph_init_params.return_value,
             graph_runtime_state=ANY,
             runtime=mock_runtime_cls.return_value,