add iteration support

This commit is contained in:
takatost 2024-07-25 23:07:27 +08:00
parent df133168dd
commit ae351bd40e
11 changed files with 193 additions and 162 deletions

View File

@ -141,3 +141,15 @@ class VariablePool(BaseModel):
return
hash_key = hash(tuple(selector[1:]))
self.variable_dictionary[selector[0]].pop(hash_key, None)
def remove_node(self, node_id: str, /):
"""
Remove all variables associated with a given node id.
Args:
node_id (str): The node id to remove.
Returns:
None
"""
self.variable_dictionary.pop(node_id, None)

View File

@ -1,5 +1,6 @@
import uuid
from typing import Optional, cast
from collections.abc import Mapping
from typing import Any, Optional, cast
from pydantic import BaseModel, Field
@ -61,7 +62,7 @@ class Graph(BaseModel):
@classmethod
def init(cls,
graph_config: dict,
graph_config: Mapping[str, Any],
root_node_id: Optional[str] = None) -> "Graph":
"""
Init graph

View File

@ -1,3 +1,6 @@
from collections.abc import Mapping
from typing import Any
from pydantic import BaseModel, Field
from core.app.entities.app_invoke_entities import InvokeFrom
@ -11,6 +14,7 @@ class GraphInitParams(BaseModel):
app_id: str = Field(..., description="app id")
workflow_type: WorkflowType = Field(..., description="workflow type")
workflow_id: str = Field(..., description="workflow id")
graph_config: Mapping[str, Any] = Field(..., description="graph config")
user_id: str = Field(..., description="user id")
user_from: UserFrom = Field(..., description="user from, account or end-user")
invoke_from: InvokeFrom = Field(..., description="invoke from, service-api, web-app, explore or debugger")

View File

@ -2,8 +2,8 @@ import logging
import queue
import threading
import time
from collections.abc import Generator
from typing import Optional
from collections.abc import Generator, Mapping
from typing import Any, Optional
from flask import Flask, current_app
from uritemplate.variable import VariableValue
@ -41,24 +41,29 @@ logger = logging.getLogger(__name__)
class GraphEngine:
def __init__(self, tenant_id: str,
app_id: str,
workflow_type: WorkflowType,
workflow_id: str,
user_id: str,
user_from: UserFrom,
invoke_from: InvokeFrom,
call_depth: int,
graph: Graph,
variable_pool: VariablePool,
max_execution_steps: int,
max_execution_time: int) -> None:
def __init__(
self,
tenant_id: str,
app_id: str,
workflow_type: WorkflowType,
workflow_id: str,
user_id: str,
user_from: UserFrom,
invoke_from: InvokeFrom,
call_depth: int,
graph: Graph,
graph_config: Mapping[str, Any],
variable_pool: VariablePool,
max_execution_steps: int,
max_execution_time: int
) -> None:
self.graph = graph
self.init_params = GraphInitParams(
tenant_id=tenant_id,
app_id=app_id,
workflow_type=workflow_type,
workflow_id=workflow_id,
graph_config=graph_config,
user_id=user_id,
user_from=user_from,
invoke_from=invoke_from,

View File

@ -2,14 +2,12 @@ from abc import ABC, abstractmethod
from collections.abc import Generator, Mapping
from typing import Any, Optional
from core.workflow.entities.base_node_data_entities import BaseIterationState, BaseNodeData
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.event import RunCompletedEvent, RunEvent
from core.workflow.nodes.iterable_node_mixin import IterableNodeMixin
class BaseNode(ABC):
@ -26,6 +24,7 @@ class BaseNode(ABC):
self.app_id = graph_init_params.app_id
self.workflow_type = graph_init_params.workflow_type
self.workflow_id = graph_init_params.workflow_id
self.graph_config = graph_init_params.graph_config
self.user_id = graph_init_params.user_id
self.user_from = graph_init_params.user_from
self.invoke_from = graph_init_params.invoke_from
@ -100,37 +99,3 @@ class BaseNode(ABC):
:return:
"""
return self._node_type
class BaseIterationNode(BaseNode, IterableNodeMixin):
@abstractmethod
def _run(self) -> BaseIterationState:
"""
Run node
:return:
"""
raise NotImplementedError
def run(self) -> BaseIterationState:
"""
Run node entry
:return:
"""
return self._run(variable_pool=self.graph_runtime_state.variable_pool)
def get_next_iteration(self, variable_pool: VariablePool, state: BaseIterationState) -> NodeRunResult | str:
"""
Get next iteration start node id based on the graph.
:param graph: graph
:return: next node id
"""
return self._get_next_iteration(variable_pool, state)
@abstractmethod
def _get_next_iteration(self, variable_pool: VariablePool, state: BaseIterationState) -> NodeRunResult | str:
"""
Get next iteration start node id based on the graph.
:param graph: graph
:return: next node id
"""
raise NotImplementedError

View File

@ -1,16 +1,25 @@
import logging
from typing import Any, cast
from configs import dify_config
from core.model_runtime.utils.encoders import jsonable_encoder
from core.workflow.entities.base_node_data_entities import BaseIterationState
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseIterationNode
from core.workflow.nodes.iteration.entities import IterationNodeData, IterationState
from core.workflow.graph_engine.entities.event import GraphRunFailedEvent, NodeRunSucceededEvent
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.run_condition import RunCondition
from core.workflow.graph_engine.graph_engine import GraphEngine
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.event import RunCompletedEvent
from core.workflow.nodes.iteration.entities import IterationNodeData
from core.workflow.utils.condition.entities import Condition
from core.workflow.workflow_entry import WorkflowRunFailedError
from models.workflow import WorkflowNodeExecutionStatus
logger = logging.getLogger(__name__)
class IterationNode(BaseIterationNode):
class IterationNode(BaseNode):
"""
Iteration Node.
"""
@ -22,92 +31,144 @@ class IterationNode(BaseIterationNode):
Run the node.
"""
self.node_data = cast(IterationNodeData, self.node_data)
iterator = variable_pool.get_any(self.node_data.iterator_selector)
iterator_list_value = self.graph_runtime_state.variable_pool.get_any(self.node_data.iterator_selector)
if not isinstance(iterator, list):
raise ValueError(f"Invalid iterator value: {iterator}, please provide a list.")
if not isinstance(iterator_list_value, list):
raise ValueError(f"Invalid iterator value: {iterator_list_value}, please provide a list.")
state = IterationState(iteration_node_id=self.node_id, index=-1, inputs={
'iterator_selector': iterator
}, outputs=[], metadata=IterationState.MetaData(
iterator_length=len(iterator) if iterator is not None else 0
))
self._set_current_iteration_variable(self.graph_runtime_state.variable_pool, state)
return state
root_node_id = self.node_data.start_node_id
graph_config = self.graph_config
def _get_next_iteration(self, variable_pool: VariablePool, state: IterationState) -> NodeRunResult | str:
"""
Get next iteration start node id based on the graph.
:param graph: graph
:return: next node id
"""
# resolve current output
self._resolve_current_output(variable_pool, state)
# move to next iteration
self._next_iteration(variable_pool, state)
# init graph
iteration_graph = Graph.init(
graph_config=graph_config,
root_node_id=root_node_id
)
node_data = cast(IterationNodeData, self.node_data)
if self._reached_iteration_limit(variable_pool, state):
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={
'output': jsonable_encoder(state.outputs)
}
if not iteration_graph:
raise ValueError('iteration graph not found')
leaf_node_ids = iteration_graph.get_leaf_node_ids()
iteration_leaf_node_ids = []
for leaf_node_id in leaf_node_ids:
node_config = iteration_graph.node_id_config_mapping.get(leaf_node_id)
if not node_config:
continue
leaf_node_iteration_id = node_config.get("data", {}).get("iteration_id")
if not leaf_node_iteration_id:
continue
if leaf_node_iteration_id != self.node_id:
continue
iteration_leaf_node_ids.append(leaf_node_id)
# add condition of end nodes to root node
iteration_graph.add_extra_edge(
source_node_id=leaf_node_id,
target_node_id=root_node_id,
run_condition=RunCondition(
type="condition",
conditions=[
Condition(
variable_selector=[self.node_id, "index"],
comparison_operator="<",
value=len(iterator_list_value)
)
]
)
)
return node_data.start_node_id
def _set_current_iteration_variable(self, variable_pool: VariablePool, state: IterationState):
"""
Set current iteration variable.
:variable_pool: variable pool
"""
node_data = cast(IterationNodeData, self.node_data)
variable_pool.add((self.node_id, 'index'), state.index)
# get the iterator value
iterator = variable_pool.get_any(node_data.iterator_selector)
variable_pool = self.graph_runtime_state.variable_pool
if iterator is None or not isinstance(iterator, list):
return
if state.index < len(iterator):
variable_pool.add((self.node_id, 'item'), iterator[state.index])
# append iteration variable (item, index) to variable pool
variable_pool.add(
[self.node_id, 'index'],
0
)
variable_pool.add(
[self.node_id, 'item'],
iterator_list_value[0]
)
def _next_iteration(self, variable_pool: VariablePool, state: IterationState):
"""
Move to next iteration.
:param variable_pool: variable pool
"""
state.index += 1
self._set_current_iteration_variable(variable_pool, state)
# init graph engine
graph_engine = GraphEngine(
tenant_id=self.tenant_id,
app_id=self.app_id,
workflow_type=self.workflow_type,
workflow_id=self.workflow_id,
user_id=self.user_id,
user_from=self.user_from,
invoke_from=self.invoke_from,
call_depth=self.workflow_call_depth,
graph=iteration_graph,
graph_config=graph_config,
variable_pool=variable_pool,
max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS,
max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME
)
def _reached_iteration_limit(self, variable_pool: VariablePool, state: IterationState):
"""
Check if iteration limit is reached.
:return: True if iteration limit is reached, False otherwise
"""
node_data = cast(IterationNodeData, self.node_data)
iterator = variable_pool.get_any(node_data.iterator_selector)
try:
# run workflow
rst = graph_engine.run()
outputs: list[Any] = []
for event in rst:
yield event
if isinstance(event, NodeRunSucceededEvent):
# handle iteration run result
if event.node_id in iteration_leaf_node_ids:
# append to iteration output variable list
outputs.append(variable_pool.get_any(self.node_data.output_selector))
if iterator is None or not isinstance(iterator, list):
return True
# remove all nodes outputs from variable pool
for node_id in iteration_graph.node_ids:
variable_pool.remove_node(node_id)
return state.index >= len(iterator)
def _resolve_current_output(self, variable_pool: VariablePool, state: IterationState):
"""
Resolve current output.
:param variable_pool: variable pool
"""
output_selector = cast(IterationNodeData, self.node_data).output_selector
output = variable_pool.get_any(output_selector)
# clear the output for this iteration
variable_pool.remove([self.node_id] + output_selector[1:])
state.current_output = output
if output is not None:
state.outputs.append(output)
# move to next iteration
next_index = variable_pool.get_any([self.node_id, 'index']) + 1
variable_pool.add(
[self.node_id, 'index'],
next_index
)
variable_pool.add(
[self.node_id, 'item'],
iterator_list_value[next_index]
)
elif isinstance(event, GraphRunFailedEvent):
# iteration run failed
raise WorkflowRunFailedError(event.reason)
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={
'output': jsonable_encoder(outputs)
}
)
)
except WorkflowRunFailedError as e:
# iteration run failed
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=str(e),
)
)
except Exception as e:
# iteration run failed
logger.exception("Iteration run failed")
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=str(e),
)
)
finally:
# remove iteration variable (item, index) from variable pool after iteration run completed
variable_pool.remove([self.node_id, 'index'])
variable_pool.remove([self.node_id, 'item'])
@classmethod
def _extract_variable_selector_to_variable_mapping(cls, node_data: IterationNodeData) -> dict[str, list[str]]:
@ -119,19 +180,3 @@ class IterationNode(BaseIterationNode):
return {
'input_selector': node_data.iterator_selector,
}
@classmethod
def get_conditions(cls, node_config: dict[str, Any]) -> list[Condition]:
"""
Get conditions.
"""
node_id = node_config.get('id')
if not node_id:
return []
return [Condition(
variable_selector=[node_id, 'index'],
comparison_operator="",
value_type="value_selector",
value_selector=node_config.get('data', {}).get('iterator_selector')
)]

View File

@ -1,13 +1,12 @@
from typing import Any
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseIterationNode
from core.workflow.entities.node_entities import NodeType
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.loop.entities import LoopNodeData, LoopState
from core.workflow.utils.condition.entities import Condition
class LoopNode(BaseIterationNode):
class LoopNode(BaseNode):
"""
Loop Node.
"""
@ -17,12 +16,6 @@ class LoopNode(BaseIterationNode):
def _run(self) -> LoopState:
return super()._run()
def _get_next_iteration(self, variable_loop: VariablePool) -> NodeRunResult | str:
"""
Get next iteration start node id based on the graph.
"""
pass
@classmethod
def get_conditions(cls, node_config: dict[str, Any]) -> list[Condition]:
"""

View File

@ -98,9 +98,10 @@ class WorkflowEntry:
invoke_from=invoke_from,
call_depth=call_depth,
graph=graph,
graph_config=graph_config,
variable_pool=variable_pool,
max_execution_steps=current_app.config.get("WORKFLOW_MAX_EXECUTION_STEPS"),
max_execution_time=current_app.config.get("WORKFLOW_MAX_EXECUTION_TIME")
max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS,
max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME
)
# init workflow run
@ -155,7 +156,6 @@ class WorkflowEntry:
)
predecessor_node: BaseNode | None = None
current_iteration_node: BaseIterationNode | None = None
has_entry_node = False
max_execution_steps = dify_config.WORKFLOW_MAX_EXECUTION_STEPS
max_execution_time = dify_config.WORKFLOW_MAX_EXECUTION_TIME
@ -610,7 +610,7 @@ class WorkflowEntry:
for callback in callbacks:
callback.on_workflow_run_started()
def _workflow_run_success(self, callbacks: Sequence[WorkflowCallback]) -> None:
def _workflow_run_success(self, callbacks: Sequence[BaseWorkflowCallback]) -> None:
"""
Workflow run success
:param callbacks: workflow callbacks

View File

@ -211,6 +211,7 @@ def test_run_parallel_in_workflow(mock_close, mock_remove, mocker):
app_id="222",
workflow_type=WorkflowType.WORKFLOW,
workflow_id="333",
graph_config=graph_config,
user_id="444",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.WEB_APP,
@ -372,6 +373,7 @@ def test_run_parallel_in_chatflow(mock_close, mock_remove):
app_id="222",
workflow_type=WorkflowType.CHAT,
workflow_id="333",
graph_config=graph_config,
user_id="444",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.WEB_APP,
@ -531,6 +533,7 @@ def test_run_branch(mock_close, mock_remove):
app_id="222",
workflow_type=WorkflowType.CHAT,
workflow_id="333",
graph_config=graph_config,
user_id="444",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.WEB_APP,

View File

@ -46,6 +46,7 @@ def test_execute_answer():
app_id='1',
workflow_type=WorkflowType.WORKFLOW,
workflow_id='1',
graph_config=graph_config,
user_id='1',
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,

View File

@ -34,6 +34,7 @@ def test_execute_if_else_result_true():
app_id='1',
workflow_type=WorkflowType.WORKFLOW,
workflow_id='1',
graph_config=graph_config,
user_id='1',
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
@ -237,6 +238,7 @@ def test_execute_if_else_result_false():
app_id='1',
workflow_type=WorkflowType.WORKFLOW,
workflow_id='1',
graph_config=graph_config,
user_id='1',
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,