mirror of
https://github.com/langgenius/dify.git
synced 2026-04-29 04:26:30 +08:00
fix(api):LLM node losing Flask context during parallel iterations (#26098)
This commit is contained in:
parent
25c69ac540
commit
a4acc64afd
@ -1,9 +1,11 @@
|
|||||||
|
import contextvars
|
||||||
import logging
|
import logging
|
||||||
from collections.abc import Generator, Mapping, Sequence
|
from collections.abc import Generator, Mapping, Sequence
|
||||||
from concurrent.futures import Future, ThreadPoolExecutor, as_completed
|
from concurrent.futures import Future, ThreadPoolExecutor, as_completed
|
||||||
from datetime import UTC, datetime
|
from datetime import UTC, datetime
|
||||||
from typing import TYPE_CHECKING, Any, NewType, cast
|
from typing import TYPE_CHECKING, Any, NewType, cast
|
||||||
|
|
||||||
|
from flask import Flask, current_app
|
||||||
from typing_extensions import TypeIs
|
from typing_extensions import TypeIs
|
||||||
|
|
||||||
from core.variables import IntegerVariable, NoneSegment
|
from core.variables import IntegerVariable, NoneSegment
|
||||||
@ -35,6 +37,7 @@ from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
|||||||
from core.workflow.nodes.base.node import Node
|
from core.workflow.nodes.base.node import Node
|
||||||
from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
|
from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
|
||||||
from libs.datetime_utils import naive_utc_now
|
from libs.datetime_utils import naive_utc_now
|
||||||
|
from libs.flask_utils import preserve_flask_contexts
|
||||||
|
|
||||||
from .exc import (
|
from .exc import (
|
||||||
InvalidIteratorValueError,
|
InvalidIteratorValueError,
|
||||||
@ -239,6 +242,8 @@ class IterationNode(Node):
|
|||||||
self._execute_single_iteration_parallel,
|
self._execute_single_iteration_parallel,
|
||||||
index=index,
|
index=index,
|
||||||
item=item,
|
item=item,
|
||||||
|
flask_app=current_app._get_current_object(), # type: ignore
|
||||||
|
context_vars=contextvars.copy_context(),
|
||||||
)
|
)
|
||||||
future_to_index[future] = index
|
future_to_index[future] = index
|
||||||
|
|
||||||
@ -281,26 +286,29 @@ class IterationNode(Node):
|
|||||||
self,
|
self,
|
||||||
index: int,
|
index: int,
|
||||||
item: object,
|
item: object,
|
||||||
|
flask_app: Flask,
|
||||||
|
context_vars: contextvars.Context,
|
||||||
) -> tuple[datetime, list[GraphNodeEventBase], object | None, int]:
|
) -> tuple[datetime, list[GraphNodeEventBase], object | None, int]:
|
||||||
"""Execute a single iteration in parallel mode and return results."""
|
"""Execute a single iteration in parallel mode and return results."""
|
||||||
iter_start_at = datetime.now(UTC).replace(tzinfo=None)
|
with preserve_flask_contexts(flask_app=flask_app, context_vars=context_vars):
|
||||||
events: list[GraphNodeEventBase] = []
|
iter_start_at = datetime.now(UTC).replace(tzinfo=None)
|
||||||
outputs_temp: list[object] = []
|
events: list[GraphNodeEventBase] = []
|
||||||
|
outputs_temp: list[object] = []
|
||||||
|
|
||||||
graph_engine = self._create_graph_engine(index, item)
|
graph_engine = self._create_graph_engine(index, item)
|
||||||
|
|
||||||
# Collect events instead of yielding them directly
|
# Collect events instead of yielding them directly
|
||||||
for event in self._run_single_iter(
|
for event in self._run_single_iter(
|
||||||
variable_pool=graph_engine.graph_runtime_state.variable_pool,
|
variable_pool=graph_engine.graph_runtime_state.variable_pool,
|
||||||
outputs=outputs_temp,
|
outputs=outputs_temp,
|
||||||
graph_engine=graph_engine,
|
graph_engine=graph_engine,
|
||||||
):
|
):
|
||||||
events.append(event)
|
events.append(event)
|
||||||
|
|
||||||
# Get the output value from the temporary outputs list
|
# Get the output value from the temporary outputs list
|
||||||
output_value = outputs_temp[0] if outputs_temp else None
|
output_value = outputs_temp[0] if outputs_temp else None
|
||||||
|
|
||||||
return iter_start_at, events, output_value, graph_engine.graph_runtime_state.total_tokens
|
return iter_start_at, events, output_value, graph_engine.graph_runtime_state.total_tokens
|
||||||
|
|
||||||
def _handle_iteration_success(
|
def _handle_iteration_success(
|
||||||
self,
|
self,
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user