mirror of https://github.com/langgenius/dify.git
fix: iteration and loop node single step run (#26036)
This commit is contained in:
parent
1e3df09fc6
commit
d823da18db
|
|
@ -79,29 +79,12 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
|||
if not app_record:
|
||||
raise ValueError("App not found")
|
||||
|
||||
if self.application_generate_entity.single_iteration_run:
|
||||
# if only single iteration run is requested
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool.empty(),
|
||||
start_at=time.time(),
|
||||
)
|
||||
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
|
||||
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
|
||||
# Handle single iteration or single loop run
|
||||
graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution(
|
||||
workflow=self._workflow,
|
||||
node_id=self.application_generate_entity.single_iteration_run.node_id,
|
||||
user_inputs=dict(self.application_generate_entity.single_iteration_run.inputs),
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
elif self.application_generate_entity.single_loop_run:
|
||||
# if only single loop run is requested
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool.empty(),
|
||||
start_at=time.time(),
|
||||
)
|
||||
graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
|
||||
workflow=self._workflow,
|
||||
node_id=self.application_generate_entity.single_loop_run.node_id,
|
||||
user_inputs=dict(self.application_generate_entity.single_loop_run.inputs),
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
single_iteration_run=self.application_generate_entity.single_iteration_run,
|
||||
single_loop_run=self.application_generate_entity.single_loop_run,
|
||||
)
|
||||
else:
|
||||
inputs = self.application_generate_entity.inputs
|
||||
|
|
|
|||
|
|
@ -427,6 +427,9 @@ class PipelineGenerator(BaseAppGenerator):
|
|||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
workflow_execution_id=str(uuid.uuid4()),
|
||||
single_iteration_run=RagPipelineGenerateEntity.SingleIterationRunEntity(
|
||||
node_id=node_id, inputs=args["inputs"]
|
||||
),
|
||||
)
|
||||
contexts.plugin_tool_providers.set({})
|
||||
contexts.plugin_tool_providers_lock.set(threading.Lock())
|
||||
|
|
@ -465,6 +468,7 @@ class PipelineGenerator(BaseAppGenerator):
|
|||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
streaming=streaming,
|
||||
variable_loader=var_loader,
|
||||
context=contextvars.copy_context(),
|
||||
)
|
||||
|
||||
def single_loop_generate(
|
||||
|
|
@ -559,6 +563,7 @@ class PipelineGenerator(BaseAppGenerator):
|
|||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
streaming=streaming,
|
||||
variable_loader=var_loader,
|
||||
context=contextvars.copy_context(),
|
||||
)
|
||||
|
||||
def _generate_worker(
|
||||
|
|
|
|||
|
|
@ -86,29 +86,12 @@ class PipelineRunner(WorkflowBasedAppRunner):
|
|||
db.session.close()
|
||||
|
||||
# if only single iteration run is requested
|
||||
if self.application_generate_entity.single_iteration_run:
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool.empty(),
|
||||
start_at=time.time(),
|
||||
)
|
||||
# if only single iteration run is requested
|
||||
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
|
||||
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
|
||||
# Handle single iteration or single loop run
|
||||
graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution(
|
||||
workflow=workflow,
|
||||
node_id=self.application_generate_entity.single_iteration_run.node_id,
|
||||
user_inputs=self.application_generate_entity.single_iteration_run.inputs,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
elif self.application_generate_entity.single_loop_run:
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool.empty(),
|
||||
start_at=time.time(),
|
||||
)
|
||||
# if only single loop run is requested
|
||||
graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
|
||||
workflow=workflow,
|
||||
node_id=self.application_generate_entity.single_loop_run.node_id,
|
||||
user_inputs=self.application_generate_entity.single_loop_run.inputs,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
single_iteration_run=self.application_generate_entity.single_iteration_run,
|
||||
single_loop_run=self.application_generate_entity.single_loop_run,
|
||||
)
|
||||
else:
|
||||
inputs = self.application_generate_entity.inputs
|
||||
|
|
|
|||
|
|
@ -51,30 +51,12 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
|||
app_config = self.application_generate_entity.app_config
|
||||
app_config = cast(WorkflowAppConfig, app_config)
|
||||
|
||||
# if only single iteration run is requested
|
||||
if self.application_generate_entity.single_iteration_run:
|
||||
# if only single iteration run is requested
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool.empty(),
|
||||
start_at=time.time(),
|
||||
)
|
||||
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
|
||||
# if only single iteration or single loop run is requested
|
||||
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
|
||||
graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution(
|
||||
workflow=self._workflow,
|
||||
node_id=self.application_generate_entity.single_iteration_run.node_id,
|
||||
user_inputs=self.application_generate_entity.single_iteration_run.inputs,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
elif self.application_generate_entity.single_loop_run:
|
||||
# if only single loop run is requested
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool.empty(),
|
||||
start_at=time.time(),
|
||||
)
|
||||
graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
|
||||
workflow=self._workflow,
|
||||
node_id=self.application_generate_entity.single_loop_run.node_id,
|
||||
user_inputs=self.application_generate_entity.single_loop_run.inputs,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
single_iteration_run=self.application_generate_entity.single_iteration_run,
|
||||
single_loop_run=self.application_generate_entity.single_loop_run,
|
||||
)
|
||||
else:
|
||||
inputs = self.application_generate_entity.inputs
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import time
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, cast
|
||||
|
||||
|
|
@ -119,15 +120,81 @@ class WorkflowBasedAppRunner:
|
|||
|
||||
return graph
|
||||
|
||||
def _get_graph_and_variable_pool_of_single_iteration(
|
||||
def _prepare_single_node_execution(
|
||||
self,
|
||||
workflow: Workflow,
|
||||
single_iteration_run: Any | None = None,
|
||||
single_loop_run: Any | None = None,
|
||||
) -> tuple[Graph, VariablePool, GraphRuntimeState]:
|
||||
"""
|
||||
Prepare graph, variable pool, and runtime state for single node execution
|
||||
(either single iteration or single loop).
|
||||
|
||||
Args:
|
||||
workflow: The workflow instance
|
||||
single_iteration_run: SingleIterationRunEntity if running single iteration, None otherwise
|
||||
single_loop_run: SingleLoopRunEntity if running single loop, None otherwise
|
||||
|
||||
Returns:
|
||||
A tuple containing (graph, variable_pool, graph_runtime_state)
|
||||
|
||||
Raises:
|
||||
ValueError: If neither single_iteration_run nor single_loop_run is specified
|
||||
"""
|
||||
# Create initial runtime state with variable pool containing environment variables
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(
|
||||
system_variables=SystemVariable.empty(),
|
||||
user_inputs={},
|
||||
environment_variables=workflow.environment_variables,
|
||||
),
|
||||
start_at=time.time(),
|
||||
)
|
||||
|
||||
# Determine which type of single node execution and get graph/variable_pool
|
||||
if single_iteration_run:
|
||||
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
|
||||
workflow=workflow,
|
||||
node_id=single_iteration_run.node_id,
|
||||
user_inputs=dict(single_iteration_run.inputs),
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
elif single_loop_run:
|
||||
graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
|
||||
workflow=workflow,
|
||||
node_id=single_loop_run.node_id,
|
||||
user_inputs=dict(single_loop_run.inputs),
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
else:
|
||||
raise ValueError("Neither single_iteration_run nor single_loop_run is specified")
|
||||
|
||||
# Return the graph, variable_pool, and the same graph_runtime_state used during graph creation
|
||||
# This ensures all nodes in the graph reference the same GraphRuntimeState instance
|
||||
return graph, variable_pool, graph_runtime_state
|
||||
|
||||
def _get_graph_and_variable_pool_for_single_node_run(
|
||||
self,
|
||||
workflow: Workflow,
|
||||
node_id: str,
|
||||
user_inputs: dict,
|
||||
user_inputs: dict[str, Any],
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
node_type_filter_key: str, # 'iteration_id' or 'loop_id'
|
||||
node_type_label: str = "node", # 'iteration' or 'loop' for error messages
|
||||
) -> tuple[Graph, VariablePool]:
|
||||
"""
|
||||
Get variable pool of single iteration
|
||||
Get graph and variable pool for single node execution (iteration or loop).
|
||||
|
||||
Args:
|
||||
workflow: The workflow instance
|
||||
node_id: The node ID to execute
|
||||
user_inputs: User inputs for the node
|
||||
graph_runtime_state: The graph runtime state
|
||||
node_type_filter_key: The key to filter nodes ('iteration_id' or 'loop_id')
|
||||
node_type_label: Label for error messages ('iteration' or 'loop')
|
||||
|
||||
Returns:
|
||||
A tuple containing (graph, variable_pool)
|
||||
"""
|
||||
# fetch workflow graph
|
||||
graph_config = workflow.graph_dict
|
||||
|
|
@ -145,18 +212,22 @@ class WorkflowBasedAppRunner:
|
|||
if not isinstance(graph_config.get("edges"), list):
|
||||
raise ValueError("edges in workflow graph must be a list")
|
||||
|
||||
# filter nodes only in iteration
|
||||
# filter nodes only in the specified node type (iteration or loop)
|
||||
main_node_config = next((n for n in graph_config.get("nodes", []) if n.get("id") == node_id), None)
|
||||
start_node_id = main_node_config.get("data", {}).get("start_node_id") if main_node_config else None
|
||||
node_configs = [
|
||||
node
|
||||
for node in graph_config.get("nodes", [])
|
||||
if node.get("id") == node_id or node.get("data", {}).get("iteration_id", "") == node_id
|
||||
if node.get("id") == node_id
|
||||
or node.get("data", {}).get(node_type_filter_key, "") == node_id
|
||||
or (start_node_id and node.get("id") == start_node_id)
|
||||
]
|
||||
|
||||
graph_config["nodes"] = node_configs
|
||||
|
||||
node_ids = [node.get("id") for node in node_configs]
|
||||
|
||||
# filter edges only in iteration
|
||||
# filter edges only in the specified node type
|
||||
edge_configs = [
|
||||
edge
|
||||
for edge in graph_config.get("edges", [])
|
||||
|
|
@ -190,30 +261,26 @@ class WorkflowBasedAppRunner:
|
|||
raise ValueError("graph not found in workflow")
|
||||
|
||||
# fetch node config from node id
|
||||
iteration_node_config = None
|
||||
target_node_config = None
|
||||
for node in node_configs:
|
||||
if node.get("id") == node_id:
|
||||
iteration_node_config = node
|
||||
target_node_config = node
|
||||
break
|
||||
|
||||
if not iteration_node_config:
|
||||
raise ValueError("iteration node id not found in workflow graph")
|
||||
if not target_node_config:
|
||||
raise ValueError(f"{node_type_label} node id not found in workflow graph")
|
||||
|
||||
# Get node class
|
||||
node_type = NodeType(iteration_node_config.get("data", {}).get("type"))
|
||||
node_version = iteration_node_config.get("data", {}).get("version", "1")
|
||||
node_type = NodeType(target_node_config.get("data", {}).get("type"))
|
||||
node_version = target_node_config.get("data", {}).get("version", "1")
|
||||
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
|
||||
|
||||
# init variable pool
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.empty(),
|
||||
user_inputs={},
|
||||
environment_variables=workflow.environment_variables,
|
||||
)
|
||||
# Use the variable pool from graph_runtime_state instead of creating a new one
|
||||
variable_pool = graph_runtime_state.variable_pool
|
||||
|
||||
try:
|
||||
variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
|
||||
graph_config=workflow.graph_dict, config=iteration_node_config
|
||||
graph_config=workflow.graph_dict, config=target_node_config
|
||||
)
|
||||
except NotImplementedError:
|
||||
variable_mapping = {}
|
||||
|
|
@ -234,120 +301,44 @@ class WorkflowBasedAppRunner:
|
|||
|
||||
return graph, variable_pool
|
||||
|
||||
def _get_graph_and_variable_pool_of_single_iteration(
|
||||
self,
|
||||
workflow: Workflow,
|
||||
node_id: str,
|
||||
user_inputs: dict[str, Any],
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
) -> tuple[Graph, VariablePool]:
|
||||
"""
|
||||
Get variable pool of single iteration
|
||||
"""
|
||||
return self._get_graph_and_variable_pool_for_single_node_run(
|
||||
workflow=workflow,
|
||||
node_id=node_id,
|
||||
user_inputs=user_inputs,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
node_type_filter_key="iteration_id",
|
||||
node_type_label="iteration",
|
||||
)
|
||||
|
||||
def _get_graph_and_variable_pool_of_single_loop(
|
||||
self,
|
||||
workflow: Workflow,
|
||||
node_id: str,
|
||||
user_inputs: dict,
|
||||
user_inputs: dict[str, Any],
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
) -> tuple[Graph, VariablePool]:
|
||||
"""
|
||||
Get variable pool of single loop
|
||||
"""
|
||||
# fetch workflow graph
|
||||
graph_config = workflow.graph_dict
|
||||
if not graph_config:
|
||||
raise ValueError("workflow graph not found")
|
||||
|
||||
graph_config = cast(dict[str, Any], graph_config)
|
||||
|
||||
if "nodes" not in graph_config or "edges" not in graph_config:
|
||||
raise ValueError("nodes or edges not found in workflow graph")
|
||||
|
||||
if not isinstance(graph_config.get("nodes"), list):
|
||||
raise ValueError("nodes in workflow graph must be a list")
|
||||
|
||||
if not isinstance(graph_config.get("edges"), list):
|
||||
raise ValueError("edges in workflow graph must be a list")
|
||||
|
||||
# filter nodes only in loop
|
||||
node_configs = [
|
||||
node
|
||||
for node in graph_config.get("nodes", [])
|
||||
if node.get("id") == node_id or node.get("data", {}).get("loop_id", "") == node_id
|
||||
]
|
||||
|
||||
graph_config["nodes"] = node_configs
|
||||
|
||||
node_ids = [node.get("id") for node in node_configs]
|
||||
|
||||
# filter edges only in loop
|
||||
edge_configs = [
|
||||
edge
|
||||
for edge in graph_config.get("edges", [])
|
||||
if (edge.get("source") is None or edge.get("source") in node_ids)
|
||||
and (edge.get("target") is None or edge.get("target") in node_ids)
|
||||
]
|
||||
|
||||
graph_config["edges"] = edge_configs
|
||||
|
||||
# Create required parameters for Graph.init
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id=workflow.tenant_id,
|
||||
app_id=self._app_id,
|
||||
workflow_id=workflow.id,
|
||||
graph_config=graph_config,
|
||||
user_id="",
|
||||
user_from=UserFrom.ACCOUNT.value,
|
||||
invoke_from=InvokeFrom.SERVICE_API.value,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=graph_init_params,
|
||||
return self._get_graph_and_variable_pool_for_single_node_run(
|
||||
workflow=workflow,
|
||||
node_id=node_id,
|
||||
user_inputs=user_inputs,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
node_type_filter_key="loop_id",
|
||||
node_type_label="loop",
|
||||
)
|
||||
|
||||
# init graph
|
||||
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=node_id)
|
||||
|
||||
if not graph:
|
||||
raise ValueError("graph not found in workflow")
|
||||
|
||||
# fetch node config from node id
|
||||
loop_node_config = None
|
||||
for node in node_configs:
|
||||
if node.get("id") == node_id:
|
||||
loop_node_config = node
|
||||
break
|
||||
|
||||
if not loop_node_config:
|
||||
raise ValueError("loop node id not found in workflow graph")
|
||||
|
||||
# Get node class
|
||||
node_type = NodeType(loop_node_config.get("data", {}).get("type"))
|
||||
node_version = loop_node_config.get("data", {}).get("version", "1")
|
||||
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
|
||||
|
||||
# init variable pool
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.empty(),
|
||||
user_inputs={},
|
||||
environment_variables=workflow.environment_variables,
|
||||
)
|
||||
|
||||
try:
|
||||
variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
|
||||
graph_config=workflow.graph_dict, config=loop_node_config
|
||||
)
|
||||
except NotImplementedError:
|
||||
variable_mapping = {}
|
||||
load_into_variable_pool(
|
||||
self._variable_loader,
|
||||
variable_pool=variable_pool,
|
||||
variable_mapping=variable_mapping,
|
||||
user_inputs=user_inputs,
|
||||
)
|
||||
|
||||
WorkflowEntry.mapping_user_inputs_to_variable_pool(
|
||||
variable_mapping=variable_mapping,
|
||||
user_inputs=user_inputs,
|
||||
variable_pool=variable_pool,
|
||||
tenant_id=workflow.tenant_id,
|
||||
)
|
||||
|
||||
return graph, variable_pool
|
||||
|
||||
def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent):
|
||||
"""
|
||||
Handle event
|
||||
|
|
|
|||
|
|
@ -372,43 +372,16 @@ class IterationNode(Node):
|
|||
variable_mapping: dict[str, Sequence[str]] = {
|
||||
f"{node_id}.input_selector": typed_node_data.iterator_selector,
|
||||
}
|
||||
iteration_node_ids = set()
|
||||
|
||||
# init graph
|
||||
from core.workflow.entities import GraphInitParams, GraphRuntimeState
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.nodes.node_factory import DifyNodeFactory
|
||||
|
||||
# Create minimal GraphInitParams for static analysis
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id="",
|
||||
app_id="",
|
||||
workflow_id="",
|
||||
graph_config=graph_config,
|
||||
user_id="",
|
||||
user_from="",
|
||||
invoke_from="",
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
# Create minimal GraphRuntimeState for static analysis
|
||||
from core.workflow.entities import VariablePool
|
||||
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(),
|
||||
start_at=0,
|
||||
)
|
||||
|
||||
# Create node factory for static analysis
|
||||
node_factory = DifyNodeFactory(graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state)
|
||||
|
||||
iteration_graph = Graph.init(
|
||||
graph_config=graph_config,
|
||||
node_factory=node_factory,
|
||||
root_node_id=typed_node_data.start_node_id,
|
||||
)
|
||||
|
||||
if not iteration_graph:
|
||||
raise IterationGraphNotFoundError("iteration graph not found")
|
||||
# Find all nodes that belong to this loop
|
||||
nodes = graph_config.get("nodes", [])
|
||||
for node in nodes:
|
||||
node_data = node.get("data", {})
|
||||
if node_data.get("iteration_id") == node_id:
|
||||
in_iteration_node_id = node.get("id")
|
||||
if in_iteration_node_id:
|
||||
iteration_node_ids.add(in_iteration_node_id)
|
||||
|
||||
# Get node configs from graph_config instead of non-existent node_id_config_mapping
|
||||
node_configs = {node["id"]: node for node in graph_config.get("nodes", []) if "id" in node}
|
||||
|
|
@ -444,9 +417,7 @@ class IterationNode(Node):
|
|||
variable_mapping.update(sub_node_variable_mapping)
|
||||
|
||||
# remove variable out from iteration
|
||||
variable_mapping = {
|
||||
key: value for key, value in variable_mapping.items() if value[0] not in iteration_graph.node_ids
|
||||
}
|
||||
variable_mapping = {key: value for key, value in variable_mapping.items() if value[0] not in iteration_node_ids}
|
||||
|
||||
return variable_mapping
|
||||
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import contextlib
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Callable, Generator, Mapping, Sequence
|
||||
|
|
@ -127,11 +128,13 @@ class LoopNode(Node):
|
|||
try:
|
||||
reach_break_condition = False
|
||||
if break_conditions:
|
||||
_, _, reach_break_condition = condition_processor.process_conditions(
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
conditions=break_conditions,
|
||||
operator=logical_operator,
|
||||
)
|
||||
with contextlib.suppress(ValueError):
|
||||
_, _, reach_break_condition = condition_processor.process_conditions(
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
conditions=break_conditions,
|
||||
operator=logical_operator,
|
||||
)
|
||||
|
||||
if reach_break_condition:
|
||||
loop_count = 0
|
||||
cost_tokens = 0
|
||||
|
|
@ -295,42 +298,11 @@ class LoopNode(Node):
|
|||
|
||||
variable_mapping = {}
|
||||
|
||||
# init graph
|
||||
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.nodes.node_factory import DifyNodeFactory
|
||||
# Extract loop node IDs statically from graph_config
|
||||
|
||||
# Create minimal GraphInitParams for static analysis
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id="",
|
||||
app_id="",
|
||||
workflow_id="",
|
||||
graph_config=graph_config,
|
||||
user_id="",
|
||||
user_from="",
|
||||
invoke_from="",
|
||||
call_depth=0,
|
||||
)
|
||||
loop_node_ids = cls._extract_loop_node_ids_from_config(graph_config, node_id)
|
||||
|
||||
# Create minimal GraphRuntimeState for static analysis
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(),
|
||||
start_at=0,
|
||||
)
|
||||
|
||||
# Create node factory for static analysis
|
||||
node_factory = DifyNodeFactory(graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state)
|
||||
|
||||
loop_graph = Graph.init(
|
||||
graph_config=graph_config,
|
||||
node_factory=node_factory,
|
||||
root_node_id=typed_node_data.start_node_id,
|
||||
)
|
||||
|
||||
if not loop_graph:
|
||||
raise ValueError("loop graph not found")
|
||||
|
||||
# Get node configs from graph_config instead of non-existent node_id_config_mapping
|
||||
# Get node configs from graph_config
|
||||
node_configs = {node["id"]: node for node in graph_config.get("nodes", []) if "id" in node}
|
||||
for sub_node_id, sub_node_config in node_configs.items():
|
||||
if sub_node_config.get("data", {}).get("loop_id") != node_id:
|
||||
|
|
@ -371,12 +343,35 @@ class LoopNode(Node):
|
|||
variable_mapping[f"{node_id}.{loop_variable.label}"] = selector
|
||||
|
||||
# remove variable out from loop
|
||||
variable_mapping = {
|
||||
key: value for key, value in variable_mapping.items() if value[0] not in loop_graph.node_ids
|
||||
}
|
||||
variable_mapping = {key: value for key, value in variable_mapping.items() if value[0] not in loop_node_ids}
|
||||
|
||||
return variable_mapping
|
||||
|
||||
@classmethod
|
||||
def _extract_loop_node_ids_from_config(cls, graph_config: Mapping[str, Any], loop_node_id: str) -> set[str]:
|
||||
"""
|
||||
Extract node IDs that belong to a specific loop from graph configuration.
|
||||
|
||||
This method statically analyzes the graph configuration to find all nodes
|
||||
that are part of the specified loop, without creating actual node instances.
|
||||
|
||||
:param graph_config: the complete graph configuration
|
||||
:param loop_node_id: the ID of the loop node
|
||||
:return: set of node IDs that belong to the loop
|
||||
"""
|
||||
loop_node_ids = set()
|
||||
|
||||
# Find all nodes that belong to this loop
|
||||
nodes = graph_config.get("nodes", [])
|
||||
for node in nodes:
|
||||
node_data = node.get("data", {})
|
||||
if node_data.get("loop_id") == loop_node_id:
|
||||
node_id = node.get("id")
|
||||
if node_id:
|
||||
loop_node_ids.add(node_id)
|
||||
|
||||
return loop_node_ids
|
||||
|
||||
@staticmethod
|
||||
def _get_segment_for_constant(var_type: SegmentType, original_value: Any) -> Segment:
|
||||
"""Get the appropriate segment type for a constant value."""
|
||||
|
|
|
|||
Loading…
Reference in New Issue