refactor(graph_engine): Take GraphRuntimeState out of GraphEngine (#21882)

This commit is contained in:
-LAN- 2025-07-07 13:15:18 +08:00 committed by GitHub
parent de22648b9f
commit 8f723697ef
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 33 additions and 14 deletions

View File

@ -103,7 +103,7 @@ class GraphEngine:
call_depth: int, call_depth: int,
graph: Graph, graph: Graph,
graph_config: Mapping[str, Any], graph_config: Mapping[str, Any],
variable_pool: VariablePool, graph_runtime_state: GraphRuntimeState,
max_execution_steps: int, max_execution_steps: int,
max_execution_time: int, max_execution_time: int,
thread_pool_id: Optional[str] = None, thread_pool_id: Optional[str] = None,
@ -140,7 +140,7 @@ class GraphEngine:
call_depth=call_depth, call_depth=call_depth,
) )
self.graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) self.graph_runtime_state = graph_runtime_state
self.max_execution_steps = max_execution_steps self.max_execution_steps = max_execution_steps
self.max_execution_time = max_execution_time self.max_execution_time = max_execution_time

View File

@ -1,5 +1,6 @@
import contextvars import contextvars
import logging import logging
import time
import uuid import uuid
from collections.abc import Generator, Mapping, Sequence from collections.abc import Generator, Mapping, Sequence
from concurrent.futures import Future, wait from concurrent.futures import Future, wait
@ -133,8 +134,11 @@ class IterationNode(BaseNode[IterationNodeData]):
variable_pool.add([self.node_id, "item"], iterator_list_value[0]) variable_pool.add([self.node_id, "item"], iterator_list_value[0])
# init graph engine # init graph engine
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.graph_engine import GraphEngine, GraphEngineThreadPool from core.workflow.graph_engine.graph_engine import GraphEngine, GraphEngineThreadPool
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
graph_engine = GraphEngine( graph_engine = GraphEngine(
tenant_id=self.tenant_id, tenant_id=self.tenant_id,
app_id=self.app_id, app_id=self.app_id,
@ -146,7 +150,7 @@ class IterationNode(BaseNode[IterationNodeData]):
call_depth=self.workflow_call_depth, call_depth=self.workflow_call_depth,
graph=iteration_graph, graph=iteration_graph,
graph_config=graph_config, graph_config=graph_config,
variable_pool=variable_pool, graph_runtime_state=graph_runtime_state,
max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS, max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS,
max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME, max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME,
thread_pool_id=self.thread_pool_id, thread_pool_id=self.thread_pool_id,

View File

@ -1,5 +1,6 @@
import json import json
import logging import logging
import time
from collections.abc import Generator, Mapping, Sequence from collections.abc import Generator, Mapping, Sequence
from datetime import UTC, datetime from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any, Literal, cast from typing import TYPE_CHECKING, Any, Literal, cast
@ -101,8 +102,11 @@ class LoopNode(BaseNode[LoopNodeData]):
loop_variable_selectors[loop_variable.label] = variable_selector loop_variable_selectors[loop_variable.label] = variable_selector
inputs[loop_variable.label] = processed_segment.value inputs[loop_variable.label] = processed_segment.value
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.graph_engine import GraphEngine from core.workflow.graph_engine.graph_engine import GraphEngine
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
graph_engine = GraphEngine( graph_engine = GraphEngine(
tenant_id=self.tenant_id, tenant_id=self.tenant_id,
app_id=self.app_id, app_id=self.app_id,
@ -114,7 +118,7 @@ class LoopNode(BaseNode[LoopNodeData]):
call_depth=self.workflow_call_depth, call_depth=self.workflow_call_depth,
graph=loop_graph, graph=loop_graph,
graph_config=self.graph_config, graph_config=self.graph_config,
variable_pool=variable_pool, graph_runtime_state=graph_runtime_state,
max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS, max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS,
max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME, max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME,
thread_pool_id=self.thread_pool_id, thread_pool_id=self.thread_pool_id,

View File

@ -69,6 +69,7 @@ class WorkflowEntry:
raise ValueError("Max workflow call depth {} reached.".format(workflow_call_max_depth)) raise ValueError("Max workflow call depth {} reached.".format(workflow_call_max_depth))
# init workflow run state # init workflow run state
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
self.graph_engine = GraphEngine( self.graph_engine = GraphEngine(
tenant_id=tenant_id, tenant_id=tenant_id,
app_id=app_id, app_id=app_id,
@ -80,7 +81,7 @@ class WorkflowEntry:
call_depth=call_depth, call_depth=call_depth,
graph=graph, graph=graph,
graph_config=graph_config, graph_config=graph_config,
variable_pool=variable_pool, graph_runtime_state=graph_runtime_state,
max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS, max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS,
max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME, max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME,
thread_pool_id=thread_pool_id, thread_pool_id=thread_pool_id,

View File

@ -1,3 +1,4 @@
import time
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
@ -19,6 +20,7 @@ from core.workflow.graph_engine.entities.event import (
NodeRunSucceededEvent, NodeRunSucceededEvent,
) )
from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
from core.workflow.graph_engine.graph_engine import GraphEngine from core.workflow.graph_engine.graph_engine import GraphEngine
from core.workflow.nodes.code.code_node import CodeNode from core.workflow.nodes.code.code_node import CodeNode
@ -172,6 +174,7 @@ def test_run_parallel_in_workflow(mock_close, mock_remove):
system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, user_inputs={"query": "hi"} system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, user_inputs={"query": "hi"}
) )
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
graph_engine = GraphEngine( graph_engine = GraphEngine(
tenant_id="111", tenant_id="111",
app_id="222", app_id="222",
@ -183,7 +186,7 @@ def test_run_parallel_in_workflow(mock_close, mock_remove):
invoke_from=InvokeFrom.WEB_APP, invoke_from=InvokeFrom.WEB_APP,
call_depth=0, call_depth=0,
graph=graph, graph=graph,
variable_pool=variable_pool, graph_runtime_state=graph_runtime_state,
max_execution_steps=500, max_execution_steps=500,
max_execution_time=1200, max_execution_time=1200,
) )
@ -299,6 +302,7 @@ def test_run_parallel_in_chatflow(mock_close, mock_remove):
user_inputs={}, user_inputs={},
) )
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
graph_engine = GraphEngine( graph_engine = GraphEngine(
tenant_id="111", tenant_id="111",
app_id="222", app_id="222",
@ -310,7 +314,7 @@ def test_run_parallel_in_chatflow(mock_close, mock_remove):
invoke_from=InvokeFrom.WEB_APP, invoke_from=InvokeFrom.WEB_APP,
call_depth=0, call_depth=0,
graph=graph, graph=graph,
variable_pool=variable_pool, graph_runtime_state=graph_runtime_state,
max_execution_steps=500, max_execution_steps=500,
max_execution_time=1200, max_execution_time=1200,
) )
@ -479,6 +483,7 @@ def test_run_branch(mock_close, mock_remove):
user_inputs={"uid": "takato"}, user_inputs={"uid": "takato"},
) )
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
graph_engine = GraphEngine( graph_engine = GraphEngine(
tenant_id="111", tenant_id="111",
app_id="222", app_id="222",
@ -490,7 +495,7 @@ def test_run_branch(mock_close, mock_remove):
invoke_from=InvokeFrom.WEB_APP, invoke_from=InvokeFrom.WEB_APP,
call_depth=0, call_depth=0,
graph=graph, graph=graph,
variable_pool=variable_pool, graph_runtime_state=graph_runtime_state,
max_execution_steps=500, max_execution_steps=500,
max_execution_time=1200, max_execution_time=1200,
) )
@ -813,6 +818,7 @@ def test_condition_parallel_correct_output(mock_close, mock_remove, app):
system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, user_inputs={"query": "hi"} system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, user_inputs={"query": "hi"}
) )
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
graph_engine = GraphEngine( graph_engine = GraphEngine(
tenant_id="111", tenant_id="111",
app_id="222", app_id="222",
@ -824,7 +830,7 @@ def test_condition_parallel_correct_output(mock_close, mock_remove, app):
invoke_from=InvokeFrom.WEB_APP, invoke_from=InvokeFrom.WEB_APP,
call_depth=0, call_depth=0,
graph=graph, graph=graph,
variable_pool=variable_pool, graph_runtime_state=graph_runtime_state,
max_execution_steps=500, max_execution_steps=500,
max_execution_time=1200, max_execution_time=1200,
) )

View File

@ -1,7 +1,9 @@
import time
from unittest.mock import patch from unittest.mock import patch
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import NodeRunResult, WorkflowNodeExecutionMetadataKey from core.workflow.entities.node_entities import NodeRunResult, WorkflowNodeExecutionMetadataKey
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import SystemVariableKey from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.event import ( from core.workflow.graph_engine.entities.event import (
@ -11,6 +13,7 @@ from core.workflow.graph_engine.entities.event import (
NodeRunStreamChunkEvent, NodeRunStreamChunkEvent,
) )
from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.graph_engine import GraphEngine from core.workflow.graph_engine.graph_engine import GraphEngine
from core.workflow.nodes.event.event import RunCompletedEvent, RunStreamChunkEvent from core.workflow.nodes.event.event import RunCompletedEvent, RunStreamChunkEvent
from core.workflow.nodes.llm.node import LLMNode from core.workflow.nodes.llm.node import LLMNode
@ -163,15 +166,16 @@ class ContinueOnErrorTestHelper:
def create_test_graph_engine(graph_config: dict, user_inputs: dict | None = None): def create_test_graph_engine(graph_config: dict, user_inputs: dict | None = None):
"""Helper method to create a graph engine instance for testing""" """Helper method to create a graph engine instance for testing"""
graph = Graph.init(graph_config=graph_config) graph = Graph.init(graph_config=graph_config)
variable_pool = { variable_pool = VariablePool(
"system_variables": { system_variables={
SystemVariableKey.QUERY: "clear", SystemVariableKey.QUERY: "clear",
SystemVariableKey.FILES: [], SystemVariableKey.FILES: [],
SystemVariableKey.CONVERSATION_ID: "abababa", SystemVariableKey.CONVERSATION_ID: "abababa",
SystemVariableKey.USER_ID: "aaa", SystemVariableKey.USER_ID: "aaa",
}, },
"user_inputs": user_inputs or {"uid": "takato"}, user_inputs=user_inputs or {"uid": "takato"},
} )
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
return GraphEngine( return GraphEngine(
tenant_id="111", tenant_id="111",
@ -184,7 +188,7 @@ class ContinueOnErrorTestHelper:
invoke_from=InvokeFrom.WEB_APP, invoke_from=InvokeFrom.WEB_APP,
call_depth=0, call_depth=0,
graph=graph, graph=graph,
variable_pool=variable_pool, graph_runtime_state=graph_runtime_state,
max_execution_steps=500, max_execution_steps=500,
max_execution_time=1200, max_execution_time=1200,
) )