mirror of https://github.com/langgenius/dify.git
add iteration support
parent df133168dd
commit ae351bd40e
@@ -141,3 +141,15 @@ class VariablePool(BaseModel):
             return
         hash_key = hash(tuple(selector[1:]))
         self.variable_dictionary[selector[0]].pop(hash_key, None)
+
+    def remove_node(self, node_id: str, /):
+        """
+        Remove all variables associated with a given node id.
+
+        Args:
+            node_id (str): The node id to remove.
+
+        Returns:
+            None
+        """
+        self.variable_dictionary.pop(node_id, None)
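For readers skimming the hunk above: `remove` clears a single selector, while the new `remove_node` drops every variable registered under a node id in one call. A minimal usage sketch, assuming an already-constructed VariablePool instance named `pool` and a hypothetical node id; the constructor arguments are omitted and are not part of this commit:

node_id = 'iteration_node'          # hypothetical node id for illustration
pool.add((node_id, 'index'), 0)     # register variables under the node id
pool.add((node_id, 'item'), 'first')
pool.remove([node_id, 'index'])     # clears only this selector
pool.remove_node(node_id)           # clears everything left for the node id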
@@ -1,5 +1,6 @@
 import uuid
-from typing import Optional, cast
+from collections.abc import Mapping
+from typing import Any, Optional, cast

 from pydantic import BaseModel, Field

@@ -61,7 +62,7 @@ class Graph(BaseModel):

     @classmethod
     def init(cls,
-             graph_config: dict,
+             graph_config: Mapping[str, Any],
              root_node_id: Optional[str] = None) -> "Graph":
         """
         Init graph
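Widening `graph_config` from `dict` to `Mapping[str, Any]` keeps existing callers working, since a plain dict already satisfies `Mapping`. A sketch of the call; the node ids and graph shape below are placeholders rather than values taken from this commit:

graph_config = {
    "nodes": [{"id": "start", "data": {"type": "start", "title": "Start"}}],
    "edges": [],
}
graph = Graph.init(graph_config=graph_config, root_node_id="start")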
@@ -1,3 +1,6 @@
+from collections.abc import Mapping
+from typing import Any
+
 from pydantic import BaseModel, Field

 from core.app.entities.app_invoke_entities import InvokeFrom
@@ -11,6 +14,7 @@ class GraphInitParams(BaseModel):
     app_id: str = Field(..., description="app id")
     workflow_type: WorkflowType = Field(..., description="workflow type")
     workflow_id: str = Field(..., description="workflow id")
+    graph_config: Mapping[str, Any] = Field(..., description="graph config")
     user_id: str = Field(..., description="user id")
     user_from: UserFrom = Field(..., description="user from, account or end-user")
     invoke_from: InvokeFrom = Field(..., description="invoke from, service-api, web-app, explore or debugger")
@@ -2,8 +2,8 @@ import logging
 import queue
 import threading
 import time
-from collections.abc import Generator
-from typing import Optional
+from collections.abc import Generator, Mapping
+from typing import Any, Optional

 from flask import Flask, current_app
 from uritemplate.variable import VariableValue
@@ -41,24 +41,29 @@ logger = logging.getLogger(__name__)


 class GraphEngine:
-    def __init__(self, tenant_id: str,
-                 app_id: str,
-                 workflow_type: WorkflowType,
-                 workflow_id: str,
-                 user_id: str,
-                 user_from: UserFrom,
-                 invoke_from: InvokeFrom,
-                 call_depth: int,
-                 graph: Graph,
-                 variable_pool: VariablePool,
-                 max_execution_steps: int,
-                 max_execution_time: int) -> None:
+    def __init__(
+        self,
+        tenant_id: str,
+        app_id: str,
+        workflow_type: WorkflowType,
+        workflow_id: str,
+        user_id: str,
+        user_from: UserFrom,
+        invoke_from: InvokeFrom,
+        call_depth: int,
+        graph: Graph,
+        graph_config: Mapping[str, Any],
+        variable_pool: VariablePool,
+        max_execution_steps: int,
+        max_execution_time: int
+    ) -> None:
         self.graph = graph
         self.init_params = GraphInitParams(
             tenant_id=tenant_id,
             app_id=app_id,
             workflow_type=workflow_type,
             workflow_id=workflow_id,
+            graph_config=graph_config,
             user_id=user_id,
             user_from=user_from,
             invoke_from=invoke_from,
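The engine now requires `graph_config` and forwards it into `GraphInitParams`, so every node can later read the raw workflow config (`self.graph_config` on `BaseNode` in the next hunks) and, in the iteration node's case, rebuild a sub-graph from it. A sketch of the new call shape only; all values below are placeholders, not taken from the commit:

engine = GraphEngine(
    tenant_id="tenant-1",
    app_id="app-1",
    workflow_type=WorkflowType.WORKFLOW,
    workflow_id="workflow-1",
    user_id="user-1",
    user_from=UserFrom.ACCOUNT,
    invoke_from=InvokeFrom.DEBUGGER,
    call_depth=0,
    graph=graph,                  # parsed Graph built via Graph.init(...)
    graph_config=graph_config,    # the raw Mapping the Graph was built from
    variable_pool=variable_pool,
    max_execution_steps=500,
    max_execution_time=1200,
)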
@@ -2,14 +2,12 @@ from abc import ABC, abstractmethod
 from collections.abc import Generator, Mapping
 from typing import Any, Optional

-from core.workflow.entities.base_node_data_entities import BaseIterationState, BaseNodeData
+from core.workflow.entities.base_node_data_entities import BaseNodeData
 from core.workflow.entities.node_entities import NodeRunResult, NodeType
-from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.graph_engine.entities.graph import Graph
 from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
 from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
 from core.workflow.nodes.event import RunCompletedEvent, RunEvent
-from core.workflow.nodes.iterable_node_mixin import IterableNodeMixin


 class BaseNode(ABC):
@@ -26,6 +24,7 @@ class BaseNode(ABC):
         self.app_id = graph_init_params.app_id
         self.workflow_type = graph_init_params.workflow_type
         self.workflow_id = graph_init_params.workflow_id
+        self.graph_config = graph_init_params.graph_config
         self.user_id = graph_init_params.user_id
         self.user_from = graph_init_params.user_from
         self.invoke_from = graph_init_params.invoke_from
@@ -100,37 +99,3 @@ class BaseNode(ABC):
         :return:
         """
         return self._node_type
-
-
-class BaseIterationNode(BaseNode, IterableNodeMixin):
-    @abstractmethod
-    def _run(self) -> BaseIterationState:
-        """
-        Run node
-        :return:
-        """
-        raise NotImplementedError
-
-    def run(self) -> BaseIterationState:
-        """
-        Run node entry
-        :return:
-        """
-        return self._run(variable_pool=self.graph_runtime_state.variable_pool)
-
-    def get_next_iteration(self, variable_pool: VariablePool, state: BaseIterationState) -> NodeRunResult | str:
-        """
-        Get next iteration start node id based on the graph.
-        :param graph: graph
-        :return: next node id
-        """
-        return self._get_next_iteration(variable_pool, state)
-
-    @abstractmethod
-    def _get_next_iteration(self, variable_pool: VariablePool, state: BaseIterationState) -> NodeRunResult | str:
-        """
-        Get next iteration start node id based on the graph.
-        :param graph: graph
-        :return: next node id
-        """
-        raise NotImplementedError
@@ -1,16 +1,25 @@
+import logging
 from typing import Any, cast

+from configs import dify_config
 from core.model_runtime.utils.encoders import jsonable_encoder
-from core.workflow.entities.base_node_data_entities import BaseIterationState
 from core.workflow.entities.node_entities import NodeRunResult, NodeType
-from core.workflow.entities.variable_pool import VariablePool
-from core.workflow.nodes.base_node import BaseIterationNode
-from core.workflow.nodes.iteration.entities import IterationNodeData, IterationState
+from core.workflow.graph_engine.entities.event import GraphRunFailedEvent, NodeRunSucceededEvent
+from core.workflow.graph_engine.entities.graph import Graph
+from core.workflow.graph_engine.entities.run_condition import RunCondition
+from core.workflow.graph_engine.graph_engine import GraphEngine
+from core.workflow.nodes.base_node import BaseNode
+from core.workflow.nodes.event import RunCompletedEvent
+from core.workflow.nodes.iteration.entities import IterationNodeData
+from core.workflow.utils.condition.entities import Condition
+from core.workflow.workflow_entry import WorkflowRunFailedError
 from models.workflow import WorkflowNodeExecutionStatus

+logger = logging.getLogger(__name__)

-class IterationNode(BaseIterationNode):
+
+class IterationNode(BaseNode):
     """
     Iteration Node.
     """
@@ -22,92 +31,144 @@ class IterationNode(BaseIterationNode):
         Run the node.
         """
-        self.node_data = cast(IterationNodeData, self.node_data)
-        iterator = variable_pool.get_any(self.node_data.iterator_selector)
+        iterator_list_value = self.graph_runtime_state.variable_pool.get_any(self.node_data.iterator_selector)

-        if not isinstance(iterator, list):
-            raise ValueError(f"Invalid iterator value: {iterator}, please provide a list.")
+        if not isinstance(iterator_list_value, list):
+            raise ValueError(f"Invalid iterator value: {iterator_list_value}, please provide a list.")

-        state = IterationState(iteration_node_id=self.node_id, index=-1, inputs={
-            'iterator_selector': iterator
-        }, outputs=[], metadata=IterationState.MetaData(
-            iterator_length=len(iterator) if iterator is not None else 0
-        ))

-        self._set_current_iteration_variable(self.graph_runtime_state.variable_pool, state)
-        return state
+        root_node_id = self.node_data.start_node_id
+        graph_config = self.graph_config

-    def _get_next_iteration(self, variable_pool: VariablePool, state: IterationState) -> NodeRunResult | str:
-        """
-        Get next iteration start node id based on the graph.
-        :param graph: graph
-        :return: next node id
-        """
-        # resolve current output
-        self._resolve_current_output(variable_pool, state)
-        # move to next iteration
-        self._next_iteration(variable_pool, state)
+        # init graph
+        iteration_graph = Graph.init(
+            graph_config=graph_config,
+            root_node_id=root_node_id
+        )

-        node_data = cast(IterationNodeData, self.node_data)
-        if self._reached_iteration_limit(variable_pool, state):
-            return NodeRunResult(
-                status=WorkflowNodeExecutionStatus.SUCCEEDED,
-                outputs={
-                    'output': jsonable_encoder(state.outputs)
-                }
+        if not iteration_graph:
+            raise ValueError('iteration graph not found')

+        leaf_node_ids = iteration_graph.get_leaf_node_ids()
+        iteration_leaf_node_ids = []
+        for leaf_node_id in leaf_node_ids:
+            node_config = iteration_graph.node_id_config_mapping.get(leaf_node_id)
+            if not node_config:
+                continue

+            leaf_node_iteration_id = node_config.get("data", {}).get("iteration_id")
+            if not leaf_node_iteration_id:
+                continue

+            if leaf_node_iteration_id != self.node_id:
+                continue

+            iteration_leaf_node_ids.append(leaf_node_id)

+            # add condition of end nodes to root node
+            iteration_graph.add_extra_edge(
+                source_node_id=leaf_node_id,
+                target_node_id=root_node_id,
+                run_condition=RunCondition(
+                    type="condition",
+                    conditions=[
+                        Condition(
+                            variable_selector=[self.node_id, "index"],
+                            comparison_operator="<",
+                            value=len(iterator_list_value)
+                        )
+                    ]
+                )
+            )

-        return node_data.start_node_id

-    def _set_current_iteration_variable(self, variable_pool: VariablePool, state: IterationState):
-        """
-        Set current iteration variable.
-        :variable_pool: variable pool
-        """
-        node_data = cast(IterationNodeData, self.node_data)

-        variable_pool.add((self.node_id, 'index'), state.index)
-        # get the iterator value
-        iterator = variable_pool.get_any(node_data.iterator_selector)
+        variable_pool = self.graph_runtime_state.variable_pool

-        if iterator is None or not isinstance(iterator, list):
-            return

-        if state.index < len(iterator):
-            variable_pool.add((self.node_id, 'item'), iterator[state.index])
+        # append iteration variable (item, index) to variable pool
+        variable_pool.add(
+            [self.node_id, 'index'],
+            0
+        )
+        variable_pool.add(
+            [self.node_id, 'item'],
+            iterator_list_value[0]
+        )

-    def _next_iteration(self, variable_pool: VariablePool, state: IterationState):
-        """
-        Move to next iteration.
-        :param variable_pool: variable pool
-        """
-        state.index += 1
-        self._set_current_iteration_variable(variable_pool, state)
+        # init graph engine
+        graph_engine = GraphEngine(
+            tenant_id=self.tenant_id,
+            app_id=self.app_id,
+            workflow_type=self.workflow_type,
+            workflow_id=self.workflow_id,
+            user_id=self.user_id,
+            user_from=self.user_from,
+            invoke_from=self.invoke_from,
+            call_depth=self.workflow_call_depth,
+            graph=iteration_graph,
+            graph_config=graph_config,
+            variable_pool=variable_pool,
+            max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS,
+            max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME
+        )

-    def _reached_iteration_limit(self, variable_pool: VariablePool, state: IterationState):
-        """
-        Check if iteration limit is reached.
-        :return: True if iteration limit is reached, False otherwise
-        """
-        node_data = cast(IterationNodeData, self.node_data)
-        iterator = variable_pool.get_any(node_data.iterator_selector)
+        try:
+            # run workflow
+            rst = graph_engine.run()
+            outputs: list[Any] = []
+            for event in rst:
+                yield event
+                if isinstance(event, NodeRunSucceededEvent):
+                    # handle iteration run result
+                    if event.node_id in iteration_leaf_node_ids:
+                        # append to iteration output variable list
+                        outputs.append(variable_pool.get_any(self.node_data.output_selector))

-        if iterator is None or not isinstance(iterator, list):
-            return True
+                        # remove all nodes outputs from variable pool
+                        for node_id in iteration_graph.node_ids:
+                            variable_pool.remove_node(node_id)

-        return state.index >= len(iterator)

-    def _resolve_current_output(self, variable_pool: VariablePool, state: IterationState):
-        """
-        Resolve current output.
-        :param variable_pool: variable pool
-        """
-        output_selector = cast(IterationNodeData, self.node_data).output_selector
-        output = variable_pool.get_any(output_selector)
-        # clear the output for this iteration
-        variable_pool.remove([self.node_id] + output_selector[1:])
-        state.current_output = output
-        if output is not None:
-            state.outputs.append(output)
+                        # move to next iteration
+                        next_index = variable_pool.get_any([self.node_id, 'index']) + 1
+                        variable_pool.add(
+                            [self.node_id, 'index'],
+                            next_index
+                        )

+                        variable_pool.add(
+                            [self.node_id, 'item'],
+                            iterator_list_value[next_index]
+                        )
+                elif isinstance(event, GraphRunFailedEvent):
+                    # iteration run failed
+                    raise WorkflowRunFailedError(event.reason)

+            yield RunCompletedEvent(
+                run_result=NodeRunResult(
+                    status=WorkflowNodeExecutionStatus.SUCCEEDED,
+                    outputs={
+                        'output': jsonable_encoder(outputs)
+                    }
+                )
+            )
+        except WorkflowRunFailedError as e:
+            # iteration run failed
+            yield RunCompletedEvent(
+                run_result=NodeRunResult(
+                    status=WorkflowNodeExecutionStatus.FAILED,
+                    error=str(e),
+                )
+            )
+        except Exception as e:
+            # iteration run failed
+            logger.exception("Iteration run failed")
+            yield RunCompletedEvent(
+                run_result=NodeRunResult(
+                    status=WorkflowNodeExecutionStatus.FAILED,
+                    error=str(e),
+                )
+            )
+        finally:
+            # remove iteration variable (item, index) from variable pool after iteration run completed
+            variable_pool.remove([self.node_id, 'index'])
+            variable_pool.remove([self.node_id, 'item'])

     @classmethod
     def _extract_variable_selector_to_variable_mapping(cls, node_data: IterationNodeData) -> dict[str, list[str]]:
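The rewritten `_run` no longer steps an `IterationState` by hand; it wires every leaf node of the iteration sub-graph back to the start node with a run condition on `[node_id, 'index'] < len(iterator)`, bumps `index`/`item` in the variable pool after each successful pass, and lets the graph engine drive the looping. A self-contained sketch of that control flow in plain Python, where a dict stands in for the variable pool and nothing below is dify API:

items = ["a", "b", "c"]
pool = {("iter_node", "index"): 0, ("iter_node", "item"): items[0]}
outputs = []

while True:
    # stand-in for one pass through the iteration sub-graph
    outputs.append(pool[("iter_node", "item")].upper())

    # after a leaf node succeeds, the commit advances index/item in the pool
    next_index = pool[("iter_node", "index")] + 1
    pool[("iter_node", "index")] = next_index

    # the run condition attached to the loop-back edge: index < len(iterator)
    if not next_index < len(items):
        break
    pool[("iter_node", "item")] = items[next_index]

print(outputs)  # ['A', 'B', 'C']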
@@ -119,19 +180,3 @@ class IterationNode(BaseIterationNode):
         return {
             'input_selector': node_data.iterator_selector,
         }
-
-    @classmethod
-    def get_conditions(cls, node_config: dict[str, Any]) -> list[Condition]:
-        """
-        Get conditions.
-        """
-        node_id = node_config.get('id')
-        if not node_id:
-            return []
-
-        return [Condition(
-            variable_selector=[node_id, 'index'],
-            comparison_operator="≤",
-            value_type="value_selector",
-            value_selector=node_config.get('data', {}).get('iterator_selector')
-        )]
@@ -1,13 +1,12 @@
 from typing import Any

-from core.workflow.entities.node_entities import NodeRunResult, NodeType
-from core.workflow.entities.variable_pool import VariablePool
-from core.workflow.nodes.base_node import BaseIterationNode
+from core.workflow.entities.node_entities import NodeType
+from core.workflow.nodes.base_node import BaseNode
 from core.workflow.nodes.loop.entities import LoopNodeData, LoopState
 from core.workflow.utils.condition.entities import Condition


-class LoopNode(BaseIterationNode):
+class LoopNode(BaseNode):
     """
     Loop Node.
     """
@@ -17,12 +16,6 @@ class LoopNode(BaseIterationNode):
     def _run(self) -> LoopState:
         return super()._run()

-    def _get_next_iteration(self, variable_loop: VariablePool) -> NodeRunResult | str:
-        """
-        Get next iteration start node id based on the graph.
-        """
-        pass
-
     @classmethod
     def get_conditions(cls, node_config: dict[str, Any]) -> list[Condition]:
         """
@@ -98,9 +98,10 @@ class WorkflowEntry:
             invoke_from=invoke_from,
             call_depth=call_depth,
             graph=graph,
+            graph_config=graph_config,
             variable_pool=variable_pool,
-            max_execution_steps=current_app.config.get("WORKFLOW_MAX_EXECUTION_STEPS"),
-            max_execution_time=current_app.config.get("WORKFLOW_MAX_EXECUTION_TIME")
+            max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS,
+            max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME
         )

         # init workflow run
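Reading the limits from `dify_config` instead of `current_app.config` gives typed values and drops the need for a Flask application context at this call site. A sketch of the difference; only the two `WORKFLOW_*` names come from the commit, everything else is illustrative:

from configs import dify_config

max_steps = dify_config.WORKFLOW_MAX_EXECUTION_STEPS   # plain int from the typed settings object
max_time = dify_config.WORKFLOW_MAX_EXECUTION_TIME

# previous pattern: needs an app context and returns an untyped, possibly missing value
# max_steps = current_app.config.get("WORKFLOW_MAX_EXECUTION_STEPS")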
@@ -155,7 +156,6 @@ class WorkflowEntry:
         )

         predecessor_node: BaseNode | None = None
-        current_iteration_node: BaseIterationNode | None = None
         has_entry_node = False
         max_execution_steps = dify_config.WORKFLOW_MAX_EXECUTION_STEPS
         max_execution_time = dify_config.WORKFLOW_MAX_EXECUTION_TIME
@@ -610,7 +610,7 @@ class WorkflowEntry:
         for callback in callbacks:
             callback.on_workflow_run_started()

-    def _workflow_run_success(self, callbacks: Sequence[WorkflowCallback]) -> None:
+    def _workflow_run_success(self, callbacks: Sequence[BaseWorkflowCallback]) -> None:
         """
         Workflow run success
         :param callbacks: workflow callbacks
@@ -211,6 +211,7 @@ def test_run_parallel_in_workflow(mock_close, mock_remove, mocker):
         app_id="222",
         workflow_type=WorkflowType.WORKFLOW,
         workflow_id="333",
+        graph_config=graph_config,
         user_id="444",
         user_from=UserFrom.ACCOUNT,
         invoke_from=InvokeFrom.WEB_APP,
@@ -372,6 +373,7 @@ def test_run_parallel_in_chatflow(mock_close, mock_remove):
         app_id="222",
         workflow_type=WorkflowType.CHAT,
         workflow_id="333",
+        graph_config=graph_config,
         user_id="444",
         user_from=UserFrom.ACCOUNT,
         invoke_from=InvokeFrom.WEB_APP,
@@ -531,6 +533,7 @@ def test_run_branch(mock_close, mock_remove):
         app_id="222",
         workflow_type=WorkflowType.CHAT,
         workflow_id="333",
+        graph_config=graph_config,
         user_id="444",
         user_from=UserFrom.ACCOUNT,
         invoke_from=InvokeFrom.WEB_APP,
@@ -46,6 +46,7 @@ def test_execute_answer():
         app_id='1',
         workflow_type=WorkflowType.WORKFLOW,
         workflow_id='1',
+        graph_config=graph_config,
         user_id='1',
         user_from=UserFrom.ACCOUNT,
         invoke_from=InvokeFrom.DEBUGGER,
@@ -34,6 +34,7 @@ def test_execute_if_else_result_true():
         app_id='1',
         workflow_type=WorkflowType.WORKFLOW,
         workflow_id='1',
+        graph_config=graph_config,
         user_id='1',
         user_from=UserFrom.ACCOUNT,
         invoke_from=InvokeFrom.DEBUGGER,
@@ -237,6 +238,7 @@ def test_execute_if_else_result_false():
         app_id='1',
         workflow_type=WorkflowType.WORKFLOW,
         workflow_id='1',
+        graph_config=graph_config,
         user_id='1',
         user_from=UserFrom.ACCOUNT,
         invoke_from=InvokeFrom.DEBUGGER,