From 3b225c01dac41b4ab537f856a95d728e1327ad02 Mon Sep 17 00:00:00 2001 From: wangxiaolei Date: Mon, 19 Jan 2026 12:18:51 +0800 Subject: [PATCH] refactor: refactor workflow context (#30607) --- api/app_factory.py | 4 + api/context/__init__.py | 74 ++++ api/context/flask_app_context.py | 198 +++++++++++ api/core/app/apps/workflow/app_generator.py | 5 +- api/core/tools/workflow_as_tool/tool.py | 36 +- api/core/workflow/context/__init__.py | 22 ++ .../workflow/context/execution_context.py | 216 ++++++++++++ .../workflow/graph_engine/graph_engine.py | 20 +- api/core/workflow/graph_engine/worker.py | 28 +- .../worker_management/worker_pool.py | 20 +- .../nodes/iteration/iteration_node.py | 18 +- .../core/workflow/context/__init__.py | 1 + .../context/test_execution_context.py | 258 ++++++++++++++ .../context/test_flask_app_context.py | 316 ++++++++++++++++++ 14 files changed, 1145 insertions(+), 71 deletions(-) create mode 100644 api/context/__init__.py create mode 100644 api/context/flask_app_context.py create mode 100644 api/core/workflow/context/__init__.py create mode 100644 api/core/workflow/context/execution_context.py create mode 100644 api/tests/unit_tests/core/workflow/context/__init__.py create mode 100644 api/tests/unit_tests/core/workflow/context/test_execution_context.py create mode 100644 api/tests/unit_tests/core/workflow/context/test_flask_app_context.py diff --git a/api/app_factory.py b/api/app_factory.py index f827842d68..1fb01d2e91 100644 --- a/api/app_factory.py +++ b/api/app_factory.py @@ -71,6 +71,8 @@ def create_app() -> DifyApp: def initialize_extensions(app: DifyApp): + # Initialize Flask context capture for workflow execution + from context.flask_app_context import init_flask_context from extensions import ( ext_app_metrics, ext_blueprints, @@ -100,6 +102,8 @@ def initialize_extensions(app: DifyApp): ext_warnings, ) + init_flask_context() + extensions = [ ext_timezone, ext_logging, diff --git a/api/context/__init__.py b/api/context/__init__.py new file mode 100644 index 0000000000..aebf9750ce --- /dev/null +++ b/api/context/__init__.py @@ -0,0 +1,74 @@ +""" +Core Context - Framework-agnostic context management. + +This module provides context management that is independent of any specific +web framework. Framework-specific implementations register their context +capture functions at application initialization time. + +This ensures the workflow layer remains completely decoupled from Flask +or any other web framework. +""" + +import contextvars +from collections.abc import Callable + +from core.workflow.context.execution_context import ( + ExecutionContext, + IExecutionContext, + NullAppContext, +) + +# Global capturer function - set by framework-specific modules +_capturer: Callable[[], IExecutionContext] | None = None + + +def register_context_capturer(capturer: Callable[[], IExecutionContext]) -> None: + """ + Register a context capture function. + + This should be called by framework-specific modules (e.g., Flask) + during application initialization. + + Args: + capturer: Function that captures current context and returns IExecutionContext + """ + global _capturer + _capturer = capturer + + +def capture_current_context() -> IExecutionContext: + """ + Capture current execution context. + + This function uses the registered context capturer. If no capturer + is registered, it returns a minimal context with only contextvars + (suitable for non-framework environments like tests or standalone scripts). + + Returns: + IExecutionContext with captured context + """ + if _capturer is None: + # No framework registered - return minimal context + return ExecutionContext( + app_context=NullAppContext(), + context_vars=contextvars.copy_context(), + ) + + return _capturer() + + +def reset_context_provider() -> None: + """ + Reset the context capturer. + + This is primarily useful for testing to ensure a clean state. + """ + global _capturer + _capturer = None + + +__all__ = [ + "capture_current_context", + "register_context_capturer", + "reset_context_provider", +] diff --git a/api/context/flask_app_context.py b/api/context/flask_app_context.py new file mode 100644 index 0000000000..4b693cd91f --- /dev/null +++ b/api/context/flask_app_context.py @@ -0,0 +1,198 @@ +""" +Flask App Context - Flask implementation of AppContext interface. +""" + +import contextvars +from collections.abc import Generator +from contextlib import contextmanager +from typing import Any, final + +from flask import Flask, current_app, g + +from context import register_context_capturer +from core.workflow.context.execution_context import ( + AppContext, + IExecutionContext, +) + + +@final +class FlaskAppContext(AppContext): + """ + Flask implementation of AppContext. + + This adapts Flask's app context to the AppContext interface. + """ + + def __init__(self, flask_app: Flask) -> None: + """ + Initialize Flask app context. + + Args: + flask_app: The Flask application instance + """ + self._flask_app = flask_app + + def get_config(self, key: str, default: Any = None) -> Any: + """Get configuration value from Flask app config.""" + return self._flask_app.config.get(key, default) + + def get_extension(self, name: str) -> Any: + """Get Flask extension by name.""" + return self._flask_app.extensions.get(name) + + @contextmanager + def enter(self) -> Generator[None, None, None]: + """Enter Flask app context.""" + with self._flask_app.app_context(): + yield + + @property + def flask_app(self) -> Flask: + """Get the underlying Flask app instance.""" + return self._flask_app + + +def capture_flask_context(user: Any = None) -> IExecutionContext: + """ + Capture current Flask execution context. + + This function captures the Flask app context and contextvars from the + current environment. It should be called from within a Flask request or + app context. + + Args: + user: Optional user object to include in context + + Returns: + IExecutionContext with captured Flask context + + Raises: + RuntimeError: If called outside Flask context + """ + # Get Flask app instance + flask_app = current_app._get_current_object() # type: ignore + + # Save current user if available + saved_user = user + if saved_user is None: + # Check for user in g (flask-login) + if hasattr(g, "_login_user"): + saved_user = g._login_user + + # Capture contextvars + context_vars = contextvars.copy_context() + + return FlaskExecutionContext( + flask_app=flask_app, + context_vars=context_vars, + user=saved_user, + ) + + +@final +class FlaskExecutionContext: + """ + Flask-specific execution context. + + This is a specialized version of ExecutionContext that includes Flask app + context. It provides the same interface as ExecutionContext but with + Flask-specific implementation. + """ + + def __init__( + self, + flask_app: Flask, + context_vars: contextvars.Context, + user: Any = None, + ) -> None: + """ + Initialize Flask execution context. + + Args: + flask_app: Flask application instance + context_vars: Python contextvars + user: Optional user object + """ + self._app_context = FlaskAppContext(flask_app) + self._context_vars = context_vars + self._user = user + self._flask_app = flask_app + + @property + def app_context(self) -> FlaskAppContext: + """Get Flask app context.""" + return self._app_context + + @property + def context_vars(self) -> contextvars.Context: + """Get context variables.""" + return self._context_vars + + @property + def user(self) -> Any: + """Get user object.""" + return self._user + + def __enter__(self) -> "FlaskExecutionContext": + """Enter the Flask execution context.""" + # Restore context variables + for var, val in self._context_vars.items(): + var.set(val) + + # Save current user from g if available + saved_user = None + if hasattr(g, "_login_user"): + saved_user = g._login_user + + # Enter Flask app context + self._cm = self._app_context.enter() + self._cm.__enter__() + + # Restore user in new app context + if saved_user is not None: + g._login_user = saved_user + + return self + + def __exit__(self, *args: Any) -> None: + """Exit the Flask execution context.""" + if hasattr(self, "_cm"): + self._cm.__exit__(*args) + + @contextmanager + def enter(self) -> Generator[None, None, None]: + """Enter Flask execution context as context manager.""" + # Restore context variables + for var, val in self._context_vars.items(): + var.set(val) + + # Save current user from g if available + saved_user = None + if hasattr(g, "_login_user"): + saved_user = g._login_user + + # Enter Flask app context + with self._flask_app.app_context(): + # Restore user in new app context + if saved_user is not None: + g._login_user = saved_user + yield + + +def init_flask_context() -> None: + """ + Initialize Flask context capture by registering the capturer. + + This function should be called during Flask application initialization + to register the Flask-specific context capturer with the core context module. + + Example: + app = Flask(__name__) + init_flask_context() # Register Flask context capturer + + Note: + This function does not need the app instance as it uses Flask's + `current_app` to get the app when capturing context. + """ + register_context_capturer(capture_flask_context) diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index 0165c74295..2be773f103 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -8,7 +8,7 @@ from typing import Any, Literal, Union, overload from flask import Flask, current_app from pydantic import ValidationError from sqlalchemy import select -from sqlalchemy.orm import Session, sessionmaker +from sqlalchemy.orm import sessionmaker import contexts from configs import dify_config @@ -23,6 +23,7 @@ from core.app.apps.workflow.generate_response_converter import WorkflowAppGenera from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse +from core.db.session_factory import session_factory from core.helper.trace_id_helper import extract_external_trace_id_from_args from core.model_runtime.errors.invoke import InvokeAuthorizationError from core.ops.ops_trace_manager import TraceQueueManager @@ -476,7 +477,7 @@ class WorkflowAppGenerator(BaseAppGenerator): :return: """ with preserve_flask_contexts(flask_app, context_vars=context): - with Session(db.engine, expire_on_commit=False) as session: + with session_factory.create_session() as session: workflow = session.scalar( select(Workflow).where( Workflow.tenant_id == application_generate_entity.app_config.tenant_id, diff --git a/api/core/tools/workflow_as_tool/tool.py b/api/core/tools/workflow_as_tool/tool.py index 389db8a972..283744b43b 100644 --- a/api/core/tools/workflow_as_tool/tool.py +++ b/api/core/tools/workflow_as_tool/tool.py @@ -5,7 +5,6 @@ import logging from collections.abc import Generator, Mapping, Sequence from typing import Any, cast -from flask import has_request_context from sqlalchemy import select from core.db.session_factory import session_factory @@ -29,6 +28,21 @@ from models.workflow import Workflow logger = logging.getLogger(__name__) +def _try_resolve_user_from_request() -> Account | EndUser | None: + """ + Try to resolve user from Flask request context. + + Returns None if not in a request context or if user is not available. + """ + # Note: `current_user` is a LocalProxy. Never compare it with None directly. + # Use _get_current_object() to dereference the proxy + user = getattr(current_user, "_get_current_object", lambda: current_user)() + # Check if we got a valid user object + if user is not None and hasattr(user, "id"): + return user + return None + + class WorkflowTool(Tool): """ Workflow tool. @@ -209,21 +223,13 @@ class WorkflowTool(Tool): Returns: Account | EndUser | None: The resolved user object, or None if resolution fails. """ - if has_request_context(): - return self._resolve_user_from_request() - else: - return self._resolve_user_from_database(user_id=user_id) + # Try to resolve user from request context first + user = _try_resolve_user_from_request() + if user is not None: + return user - def _resolve_user_from_request(self) -> Account | EndUser | None: - """ - Resolve user from Flask request context. - """ - try: - # Note: `current_user` is a LocalProxy. Never compare it with None directly. - return getattr(current_user, "_get_current_object", lambda: current_user)() - except Exception as e: - logger.warning("Failed to resolve user from request context: %s", e) - return None + # Fall back to database resolution + return self._resolve_user_from_database(user_id=user_id) def _resolve_user_from_database(self, user_id: str) -> Account | EndUser | None: """ diff --git a/api/core/workflow/context/__init__.py b/api/core/workflow/context/__init__.py new file mode 100644 index 0000000000..31e1f2c8d9 --- /dev/null +++ b/api/core/workflow/context/__init__.py @@ -0,0 +1,22 @@ +""" +Execution Context - Context management for workflow execution. + +This package provides Flask-independent context management for workflow +execution in multi-threaded environments. +""" + +from core.workflow.context.execution_context import ( + AppContext, + ExecutionContext, + IExecutionContext, + NullAppContext, + capture_current_context, +) + +__all__ = [ + "AppContext", + "ExecutionContext", + "IExecutionContext", + "NullAppContext", + "capture_current_context", +] diff --git a/api/core/workflow/context/execution_context.py b/api/core/workflow/context/execution_context.py new file mode 100644 index 0000000000..5a4203be93 --- /dev/null +++ b/api/core/workflow/context/execution_context.py @@ -0,0 +1,216 @@ +""" +Execution Context - Abstracted context management for workflow execution. +""" + +import contextvars +from abc import ABC, abstractmethod +from collections.abc import Generator +from contextlib import AbstractContextManager, contextmanager +from typing import Any, Protocol, final, runtime_checkable + + +class AppContext(ABC): + """ + Abstract application context interface. + + This abstraction allows workflow execution to work with or without Flask + by providing a common interface for application context management. + """ + + @abstractmethod + def get_config(self, key: str, default: Any = None) -> Any: + """Get configuration value by key.""" + pass + + @abstractmethod + def get_extension(self, name: str) -> Any: + """Get Flask extension by name (e.g., 'db', 'cache').""" + pass + + @abstractmethod + def enter(self) -> AbstractContextManager[None]: + """Enter the application context.""" + pass + + +@runtime_checkable +class IExecutionContext(Protocol): + """ + Protocol for execution context. + + This protocol defines the interface that all execution contexts must implement, + allowing both ExecutionContext and FlaskExecutionContext to be used interchangeably. + """ + + def __enter__(self) -> "IExecutionContext": + """Enter the execution context.""" + ... + + def __exit__(self, *args: Any) -> None: + """Exit the execution context.""" + ... + + @property + def user(self) -> Any: + """Get user object.""" + ... + + +@final +class ExecutionContext: + """ + Execution context for workflow execution in worker threads. + + This class encapsulates all context needed for workflow execution: + - Application context (Flask app or standalone) + - Context variables for Python contextvars + - User information (optional) + + It is designed to be serializable and passable to worker threads. + """ + + def __init__( + self, + app_context: AppContext | None = None, + context_vars: contextvars.Context | None = None, + user: Any = None, + ) -> None: + """ + Initialize execution context. + + Args: + app_context: Application context (Flask or standalone) + context_vars: Python contextvars to preserve + user: User object (optional) + """ + self._app_context = app_context + self._context_vars = context_vars + self._user = user + + @property + def app_context(self) -> AppContext | None: + """Get application context.""" + return self._app_context + + @property + def context_vars(self) -> contextvars.Context | None: + """Get context variables.""" + return self._context_vars + + @property + def user(self) -> Any: + """Get user object.""" + return self._user + + @contextmanager + def enter(self) -> Generator[None, None, None]: + """ + Enter this execution context. + + This is a convenience method that creates a context manager. + """ + # Restore context variables if provided + if self._context_vars: + for var, val in self._context_vars.items(): + var.set(val) + + # Enter app context if available + if self._app_context is not None: + with self._app_context.enter(): + yield + else: + yield + + def __enter__(self) -> "ExecutionContext": + """Enter the execution context.""" + self._cm = self.enter() + self._cm.__enter__() + return self + + def __exit__(self, *args: Any) -> None: + """Exit the execution context.""" + if hasattr(self, "_cm"): + self._cm.__exit__(*args) + + +class NullAppContext(AppContext): + """ + Null implementation of AppContext for non-Flask environments. + + This is used when running without Flask (e.g., in tests or standalone mode). + """ + + def __init__(self, config: dict[str, Any] | None = None) -> None: + """ + Initialize null app context. + + Args: + config: Optional configuration dictionary + """ + self._config = config or {} + self._extensions: dict[str, Any] = {} + + def get_config(self, key: str, default: Any = None) -> Any: + """Get configuration value by key.""" + return self._config.get(key, default) + + def get_extension(self, name: str) -> Any: + """Get extension by name.""" + return self._extensions.get(name) + + def set_extension(self, name: str, extension: Any) -> None: + """Set extension by name.""" + self._extensions[name] = extension + + @contextmanager + def enter(self) -> Generator[None, None, None]: + """Enter null context (no-op).""" + yield + + +class ExecutionContextBuilder: + """ + Builder for creating ExecutionContext instances. + + This provides a fluent API for building execution contexts. + """ + + def __init__(self) -> None: + self._app_context: AppContext | None = None + self._context_vars: contextvars.Context | None = None + self._user: Any = None + + def with_app_context(self, app_context: AppContext) -> "ExecutionContextBuilder": + """Set application context.""" + self._app_context = app_context + return self + + def with_context_vars(self, context_vars: contextvars.Context) -> "ExecutionContextBuilder": + """Set context variables.""" + self._context_vars = context_vars + return self + + def with_user(self, user: Any) -> "ExecutionContextBuilder": + """Set user.""" + self._user = user + return self + + def build(self) -> ExecutionContext: + """Build the execution context.""" + return ExecutionContext( + app_context=self._app_context, + context_vars=self._context_vars, + user=self._user, + ) + + +def capture_current_context() -> IExecutionContext: + """ + Capture current execution context from the calling environment. + + Returns: + IExecutionContext with captured context + """ + from context import capture_current_context + + return capture_current_context() diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index 9a870d7bf5..dbb2727c98 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -7,15 +7,13 @@ Domain-Driven Design principles for improved maintainability and testability. from __future__ import annotations -import contextvars import logging import queue import threading from collections.abc import Generator from typing import TYPE_CHECKING, cast, final -from flask import Flask, current_app - +from core.workflow.context import capture_current_context from core.workflow.enums import NodeExecutionType from core.workflow.graph import Graph from core.workflow.graph_events import ( @@ -159,17 +157,8 @@ class GraphEngine: self._layers: list[GraphEngineLayer] = [] # === Worker Pool Setup === - # Capture Flask app context for worker threads - flask_app: Flask | None = None - try: - app = current_app._get_current_object() # type: ignore - if isinstance(app, Flask): - flask_app = app - except RuntimeError: - pass - - # Capture context variables for worker threads - context_vars = contextvars.copy_context() + # Capture execution context for worker threads + execution_context = capture_current_context() # Create worker pool for parallel node execution self._worker_pool = WorkerPool( @@ -177,8 +166,7 @@ class GraphEngine: event_queue=self._event_queue, graph=self._graph, layers=self._layers, - flask_app=flask_app, - context_vars=context_vars, + execution_context=execution_context, min_workers=self._min_workers, max_workers=self._max_workers, scale_up_threshold=self._scale_up_threshold, diff --git a/api/core/workflow/graph_engine/worker.py b/api/core/workflow/graph_engine/worker.py index 83419830b6..95db5c5c92 100644 --- a/api/core/workflow/graph_engine/worker.py +++ b/api/core/workflow/graph_engine/worker.py @@ -5,26 +5,27 @@ Workers pull node IDs from the ready_queue, execute nodes, and push events to the event_queue for the dispatcher to process. """ -import contextvars import queue import threading import time from collections.abc import Sequence from datetime import datetime -from typing import final +from typing import TYPE_CHECKING, final from uuid import uuid4 -from flask import Flask from typing_extensions import override +from core.workflow.context import IExecutionContext from core.workflow.graph import Graph from core.workflow.graph_engine.layers.base import GraphEngineLayer from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent from core.workflow.nodes.base.node import Node -from libs.flask_utils import preserve_flask_contexts from .ready_queue import ReadyQueue +if TYPE_CHECKING: + pass + @final class Worker(threading.Thread): @@ -44,8 +45,7 @@ class Worker(threading.Thread): layers: Sequence[GraphEngineLayer], stop_event: threading.Event, worker_id: int = 0, - flask_app: Flask | None = None, - context_vars: contextvars.Context | None = None, + execution_context: IExecutionContext | None = None, ) -> None: """ Initialize worker thread. @@ -56,19 +56,17 @@ class Worker(threading.Thread): graph: Graph containing nodes to execute layers: Graph engine layers for node execution hooks worker_id: Unique identifier for this worker - flask_app: Optional Flask application for context preservation - context_vars: Optional context variables to preserve in worker thread + execution_context: Optional execution context for context preservation """ super().__init__(name=f"GraphWorker-{worker_id}", daemon=True) self._ready_queue = ready_queue self._event_queue = event_queue self._graph = graph self._worker_id = worker_id - self._flask_app = flask_app - self._context_vars = context_vars - self._last_task_time = time.time() + self._execution_context = execution_context self._stop_event = stop_event self._layers = layers if layers is not None else [] + self._last_task_time = time.time() def stop(self) -> None: """Worker is controlled via shared stop_event from GraphEngine. @@ -135,11 +133,9 @@ class Worker(threading.Thread): error: Exception | None = None - if self._flask_app and self._context_vars: - with preserve_flask_contexts( - flask_app=self._flask_app, - context_vars=self._context_vars, - ): + # Execute the node with preserved context if execution context is provided + if self._execution_context is not None: + with self._execution_context: self._invoke_node_run_start_hooks(node) try: node_events = node.run() diff --git a/api/core/workflow/graph_engine/worker_management/worker_pool.py b/api/core/workflow/graph_engine/worker_management/worker_pool.py index df76ebe882..9ce7d16e93 100644 --- a/api/core/workflow/graph_engine/worker_management/worker_pool.py +++ b/api/core/workflow/graph_engine/worker_management/worker_pool.py @@ -8,9 +8,10 @@ DynamicScaler, and WorkerFactory into a single class. import logging import queue import threading -from typing import TYPE_CHECKING, final +from typing import final from configs import dify_config +from core.workflow.context import IExecutionContext from core.workflow.graph import Graph from core.workflow.graph_events import GraphNodeEventBase @@ -20,11 +21,6 @@ from ..worker import Worker logger = logging.getLogger(__name__) -if TYPE_CHECKING: - from contextvars import Context - - from flask import Flask - @final class WorkerPool: @@ -42,8 +38,7 @@ class WorkerPool: graph: Graph, layers: list[GraphEngineLayer], stop_event: threading.Event, - flask_app: "Flask | None" = None, - context_vars: "Context | None" = None, + execution_context: IExecutionContext | None = None, min_workers: int | None = None, max_workers: int | None = None, scale_up_threshold: int | None = None, @@ -57,8 +52,7 @@ class WorkerPool: event_queue: Queue for worker events graph: The workflow graph layers: Graph engine layers for node execution hooks - flask_app: Optional Flask app for context preservation - context_vars: Optional context variables + execution_context: Optional execution context for context preservation min_workers: Minimum number of workers max_workers: Maximum number of workers scale_up_threshold: Queue depth to trigger scale up @@ -67,8 +61,7 @@ class WorkerPool: self._ready_queue = ready_queue self._event_queue = event_queue self._graph = graph - self._flask_app = flask_app - self._context_vars = context_vars + self._execution_context = execution_context self._layers = layers # Scaling parameters with defaults @@ -152,8 +145,7 @@ class WorkerPool: graph=self._graph, layers=self._layers, worker_id=worker_id, - flask_app=self._flask_app, - context_vars=self._context_vars, + execution_context=self._execution_context, stop_event=self._stop_event, ) diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index 91df2e4e0b..569a4196fb 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -1,11 +1,9 @@ -import contextvars import logging from collections.abc import Generator, Mapping, Sequence from concurrent.futures import Future, ThreadPoolExecutor, as_completed from datetime import UTC, datetime from typing import TYPE_CHECKING, Any, NewType, cast -from flask import Flask, current_app from typing_extensions import TypeIs from core.model_runtime.entities.llm_entities import LLMUsage @@ -39,7 +37,6 @@ from core.workflow.nodes.base.node import Node from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData from core.workflow.runtime import VariablePool from libs.datetime_utils import naive_utc_now -from libs.flask_utils import preserve_flask_contexts from .exc import ( InvalidIteratorValueError, @@ -51,6 +48,7 @@ from .exc import ( ) if TYPE_CHECKING: + from core.workflow.context import IExecutionContext from core.workflow.graph_engine import GraphEngine logger = logging.getLogger(__name__) @@ -252,8 +250,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): self._execute_single_iteration_parallel, index=index, item=item, - flask_app=current_app._get_current_object(), # type: ignore - context_vars=contextvars.copy_context(), + execution_context=self._capture_execution_context(), ) future_to_index[future] = index @@ -306,11 +303,10 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): self, index: int, item: object, - flask_app: Flask, - context_vars: contextvars.Context, + execution_context: "IExecutionContext", ) -> tuple[datetime, list[GraphNodeEventBase], object | None, dict[str, Variable], LLMUsage]: """Execute a single iteration in parallel mode and return results.""" - with preserve_flask_contexts(flask_app=flask_app, context_vars=context_vars): + with execution_context: iter_start_at = datetime.now(UTC).replace(tzinfo=None) events: list[GraphNodeEventBase] = [] outputs_temp: list[object] = [] @@ -339,6 +335,12 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): graph_engine.graph_runtime_state.llm_usage, ) + def _capture_execution_context(self) -> "IExecutionContext": + """Capture current execution context for parallel iterations.""" + from core.workflow.context import capture_current_context + + return capture_current_context() + def _handle_iteration_success( self, started_at: datetime, diff --git a/api/tests/unit_tests/core/workflow/context/__init__.py b/api/tests/unit_tests/core/workflow/context/__init__.py new file mode 100644 index 0000000000..ac81c5c9e8 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/context/__init__.py @@ -0,0 +1 @@ +"""Tests for workflow context management.""" diff --git a/api/tests/unit_tests/core/workflow/context/test_execution_context.py b/api/tests/unit_tests/core/workflow/context/test_execution_context.py new file mode 100644 index 0000000000..217c39385c --- /dev/null +++ b/api/tests/unit_tests/core/workflow/context/test_execution_context.py @@ -0,0 +1,258 @@ +"""Tests for execution context module.""" + +import contextvars +from typing import Any +from unittest.mock import MagicMock + +import pytest + +from core.workflow.context.execution_context import ( + AppContext, + ExecutionContext, + ExecutionContextBuilder, + IExecutionContext, + NullAppContext, +) + + +class TestAppContext: + """Test AppContext abstract base class.""" + + def test_app_context_is_abstract(self): + """Test that AppContext cannot be instantiated directly.""" + with pytest.raises(TypeError): + AppContext() # type: ignore + + +class TestNullAppContext: + """Test NullAppContext implementation.""" + + def test_null_app_context_get_config(self): + """Test get_config returns value from config dict.""" + config = {"key1": "value1", "key2": "value2"} + ctx = NullAppContext(config=config) + + assert ctx.get_config("key1") == "value1" + assert ctx.get_config("key2") == "value2" + + def test_null_app_context_get_config_default(self): + """Test get_config returns default when key not found.""" + ctx = NullAppContext() + + assert ctx.get_config("nonexistent", "default") == "default" + assert ctx.get_config("nonexistent") is None + + def test_null_app_context_get_extension(self): + """Test get_extension returns stored extension.""" + ctx = NullAppContext() + extension = MagicMock() + ctx.set_extension("db", extension) + + assert ctx.get_extension("db") == extension + + def test_null_app_context_get_extension_not_found(self): + """Test get_extension returns None when extension not found.""" + ctx = NullAppContext() + + assert ctx.get_extension("nonexistent") is None + + def test_null_app_context_enter_yield(self): + """Test enter method yields without any side effects.""" + ctx = NullAppContext() + + with ctx.enter(): + # Should not raise any exception + pass + + +class TestExecutionContext: + """Test ExecutionContext class.""" + + def test_initialization_with_all_params(self): + """Test ExecutionContext initialization with all parameters.""" + app_ctx = NullAppContext() + context_vars = contextvars.copy_context() + user = MagicMock() + + ctx = ExecutionContext( + app_context=app_ctx, + context_vars=context_vars, + user=user, + ) + + assert ctx.app_context == app_ctx + assert ctx.context_vars == context_vars + assert ctx.user == user + + def test_initialization_with_minimal_params(self): + """Test ExecutionContext initialization with minimal parameters.""" + ctx = ExecutionContext() + + assert ctx.app_context is None + assert ctx.context_vars is None + assert ctx.user is None + + def test_enter_with_context_vars(self): + """Test enter restores context variables.""" + test_var = contextvars.ContextVar("test_var") + test_var.set("original_value") + + # Copy context with the variable + context_vars = contextvars.copy_context() + + # Change the variable + test_var.set("new_value") + + # Create execution context and enter it + ctx = ExecutionContext(context_vars=context_vars) + + with ctx.enter(): + # Variable should be restored to original value + assert test_var.get() == "original_value" + + # After exiting, variable stays at the value from within the context + # (this is expected Python contextvars behavior) + assert test_var.get() == "original_value" + + def test_enter_with_app_context(self): + """Test enter enters app context if available.""" + app_ctx = NullAppContext() + ctx = ExecutionContext(app_context=app_ctx) + + # Should not raise any exception + with ctx.enter(): + pass + + def test_enter_without_app_context(self): + """Test enter works without app context.""" + ctx = ExecutionContext(app_context=None) + + # Should not raise any exception + with ctx.enter(): + pass + + def test_context_manager_protocol(self): + """Test ExecutionContext supports context manager protocol.""" + ctx = ExecutionContext() + + with ctx: + # Should not raise any exception + pass + + def test_user_property(self): + """Test user property returns set user.""" + user = MagicMock() + ctx = ExecutionContext(user=user) + + assert ctx.user == user + + +class TestIExecutionContextProtocol: + """Test IExecutionContext protocol.""" + + def test_execution_context_implements_protocol(self): + """Test that ExecutionContext implements IExecutionContext protocol.""" + ctx = ExecutionContext() + + # Should have __enter__ and __exit__ methods + assert hasattr(ctx, "__enter__") + assert hasattr(ctx, "__exit__") + assert hasattr(ctx, "user") + + def test_protocol_compatibility(self): + """Test that ExecutionContext can be used where IExecutionContext is expected.""" + + def accept_context(context: IExecutionContext) -> Any: + """Function that accepts IExecutionContext protocol.""" + # Just verify it has the required protocol attributes + assert hasattr(context, "__enter__") + assert hasattr(context, "__exit__") + assert hasattr(context, "user") + return context.user + + ctx = ExecutionContext(user="test_user") + result = accept_context(ctx) + + assert result == "test_user" + + def test_protocol_with_flask_execution_context(self): + """Test that IExecutionContext protocol is compatible with different implementations.""" + # Verify the protocol works with ExecutionContext + ctx = ExecutionContext(user="test_user") + + # Should have the required protocol attributes + assert hasattr(ctx, "__enter__") + assert hasattr(ctx, "__exit__") + assert hasattr(ctx, "user") + assert ctx.user == "test_user" + + # Should work as context manager + with ctx: + assert ctx.user == "test_user" + + +class TestExecutionContextBuilder: + """Test ExecutionContextBuilder class.""" + + def test_builder_with_all_params(self): + """Test builder with all parameters set.""" + app_ctx = NullAppContext() + context_vars = contextvars.copy_context() + user = MagicMock() + + ctx = ( + ExecutionContextBuilder().with_app_context(app_ctx).with_context_vars(context_vars).with_user(user).build() + ) + + assert ctx.app_context == app_ctx + assert ctx.context_vars == context_vars + assert ctx.user == user + + def test_builder_with_partial_params(self): + """Test builder with only some parameters set.""" + app_ctx = NullAppContext() + + ctx = ExecutionContextBuilder().with_app_context(app_ctx).build() + + assert ctx.app_context == app_ctx + assert ctx.context_vars is None + assert ctx.user is None + + def test_builder_fluent_interface(self): + """Test builder provides fluent interface.""" + builder = ExecutionContextBuilder() + + # Each method should return the builder + assert isinstance(builder.with_app_context(NullAppContext()), ExecutionContextBuilder) + assert isinstance(builder.with_context_vars(contextvars.copy_context()), ExecutionContextBuilder) + assert isinstance(builder.with_user(None), ExecutionContextBuilder) + + +class TestCaptureCurrentContext: + """Test capture_current_context function.""" + + def test_capture_current_context_returns_context(self): + """Test that capture_current_context returns a valid context.""" + from core.workflow.context.execution_context import capture_current_context + + result = capture_current_context() + + # Should return an object that implements IExecutionContext + assert hasattr(result, "__enter__") + assert hasattr(result, "__exit__") + assert hasattr(result, "user") + + def test_capture_current_context_captures_contextvars(self): + """Test that capture_current_context captures context variables.""" + # Set a context variable before capturing + import contextvars + + test_var = contextvars.ContextVar("capture_test_var") + test_var.set("test_value_123") + + from core.workflow.context.execution_context import capture_current_context + + result = capture_current_context() + + # Context variables should be captured + assert result.context_vars is not None diff --git a/api/tests/unit_tests/core/workflow/context/test_flask_app_context.py b/api/tests/unit_tests/core/workflow/context/test_flask_app_context.py new file mode 100644 index 0000000000..a809b29552 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/context/test_flask_app_context.py @@ -0,0 +1,316 @@ +"""Tests for Flask app context module.""" + +import contextvars +from unittest.mock import MagicMock, patch + +import pytest + + +class TestFlaskAppContext: + """Test FlaskAppContext implementation.""" + + @pytest.fixture + def mock_flask_app(self): + """Create a mock Flask app.""" + app = MagicMock() + app.config = {"TEST_KEY": "test_value"} + app.extensions = {"db": MagicMock(), "cache": MagicMock()} + app.app_context = MagicMock() + app.app_context.return_value.__enter__ = MagicMock(return_value=None) + app.app_context.return_value.__exit__ = MagicMock(return_value=None) + return app + + def test_flask_app_context_initialization(self, mock_flask_app): + """Test FlaskAppContext initialization.""" + # Import here to avoid Flask dependency in test environment + from context.flask_app_context import FlaskAppContext + + ctx = FlaskAppContext(mock_flask_app) + + assert ctx.flask_app == mock_flask_app + + def test_flask_app_context_get_config(self, mock_flask_app): + """Test get_config returns Flask app config value.""" + from context.flask_app_context import FlaskAppContext + + ctx = FlaskAppContext(mock_flask_app) + + assert ctx.get_config("TEST_KEY") == "test_value" + + def test_flask_app_context_get_config_default(self, mock_flask_app): + """Test get_config returns default when key not found.""" + from context.flask_app_context import FlaskAppContext + + ctx = FlaskAppContext(mock_flask_app) + + assert ctx.get_config("NONEXISTENT", "default") == "default" + + def test_flask_app_context_get_extension(self, mock_flask_app): + """Test get_extension returns Flask extension.""" + from context.flask_app_context import FlaskAppContext + + ctx = FlaskAppContext(mock_flask_app) + db_ext = mock_flask_app.extensions["db"] + + assert ctx.get_extension("db") == db_ext + + def test_flask_app_context_get_extension_not_found(self, mock_flask_app): + """Test get_extension returns None when extension not found.""" + from context.flask_app_context import FlaskAppContext + + ctx = FlaskAppContext(mock_flask_app) + + assert ctx.get_extension("nonexistent") is None + + def test_flask_app_context_enter(self, mock_flask_app): + """Test enter method enters Flask app context.""" + from context.flask_app_context import FlaskAppContext + + ctx = FlaskAppContext(mock_flask_app) + + with ctx.enter(): + # Should not raise any exception + pass + + # Verify app_context was called + mock_flask_app.app_context.assert_called_once() + + +class TestFlaskExecutionContext: + """Test FlaskExecutionContext class.""" + + @pytest.fixture + def mock_flask_app(self): + """Create a mock Flask app.""" + app = MagicMock() + app.config = {} + app.app_context = MagicMock() + app.app_context.return_value.__enter__ = MagicMock(return_value=None) + app.app_context.return_value.__exit__ = MagicMock(return_value=None) + return app + + def test_initialization(self, mock_flask_app): + """Test FlaskExecutionContext initialization.""" + from context.flask_app_context import FlaskExecutionContext + + context_vars = contextvars.copy_context() + user = MagicMock() + + ctx = FlaskExecutionContext( + flask_app=mock_flask_app, + context_vars=context_vars, + user=user, + ) + + assert ctx.context_vars == context_vars + assert ctx.user == user + + def test_app_context_property(self, mock_flask_app): + """Test app_context property returns FlaskAppContext.""" + from context.flask_app_context import FlaskAppContext, FlaskExecutionContext + + ctx = FlaskExecutionContext( + flask_app=mock_flask_app, + context_vars=contextvars.copy_context(), + ) + + assert isinstance(ctx.app_context, FlaskAppContext) + assert ctx.app_context.flask_app == mock_flask_app + + def test_context_manager_protocol(self, mock_flask_app): + """Test FlaskExecutionContext supports context manager protocol.""" + from context.flask_app_context import FlaskExecutionContext + + ctx = FlaskExecutionContext( + flask_app=mock_flask_app, + context_vars=contextvars.copy_context(), + ) + + # Should have __enter__ and __exit__ methods + assert hasattr(ctx, "__enter__") + assert hasattr(ctx, "__exit__") + + # Should work as context manager + with ctx: + pass + + +class TestCaptureFlaskContext: + """Test capture_flask_context function.""" + + @patch("context.flask_app_context.current_app") + @patch("context.flask_app_context.g") + def test_capture_flask_context_captures_app(self, mock_g, mock_current_app): + """Test capture_flask_context captures Flask app.""" + mock_app = MagicMock() + mock_app._get_current_object = MagicMock(return_value=mock_app) + mock_current_app._get_current_object = MagicMock(return_value=mock_app) + + from context.flask_app_context import capture_flask_context + + ctx = capture_flask_context() + + assert ctx._flask_app == mock_app + + @patch("context.flask_app_context.current_app") + @patch("context.flask_app_context.g") + def test_capture_flask_context_captures_user_from_g(self, mock_g, mock_current_app): + """Test capture_flask_context captures user from Flask g object.""" + mock_app = MagicMock() + mock_app._get_current_object = MagicMock(return_value=mock_app) + mock_current_app._get_current_object = MagicMock(return_value=mock_app) + + mock_user = MagicMock() + mock_user.id = "user_123" + mock_g._login_user = mock_user + + from context.flask_app_context import capture_flask_context + + ctx = capture_flask_context() + + assert ctx.user == mock_user + + @patch("context.flask_app_context.current_app") + def test_capture_flask_context_with_explicit_user(self, mock_current_app): + """Test capture_flask_context uses explicit user parameter.""" + mock_app = MagicMock() + mock_app._get_current_object = MagicMock(return_value=mock_app) + mock_current_app._get_current_object = MagicMock(return_value=mock_app) + + explicit_user = MagicMock() + explicit_user.id = "user_456" + + from context.flask_app_context import capture_flask_context + + ctx = capture_flask_context(user=explicit_user) + + assert ctx.user == explicit_user + + @patch("context.flask_app_context.current_app") + def test_capture_flask_context_captures_contextvars(self, mock_current_app): + """Test capture_flask_context captures context variables.""" + mock_app = MagicMock() + mock_app._get_current_object = MagicMock(return_value=mock_app) + mock_current_app._get_current_object = MagicMock(return_value=mock_app) + + # Set a context variable + test_var = contextvars.ContextVar("test_var") + test_var.set("test_value") + + from context.flask_app_context import capture_flask_context + + ctx = capture_flask_context() + + # Context variables should be captured + assert ctx.context_vars is not None + # Verify the variable is in the captured context + captured_value = ctx.context_vars[test_var] + assert captured_value == "test_value" + + +class TestFlaskExecutionContextIntegration: + """Integration tests for FlaskExecutionContext.""" + + @pytest.fixture + def mock_flask_app(self): + """Create a mock Flask app with proper app context.""" + app = MagicMock() + app.config = {"TEST": "value"} + app.extensions = {"db": MagicMock()} + + # Mock app context + mock_app_context = MagicMock() + mock_app_context.__enter__ = MagicMock(return_value=None) + mock_app_context.__exit__ = MagicMock(return_value=None) + app.app_context.return_value = mock_app_context + + return app + + def test_enter_restores_context_vars(self, mock_flask_app): + """Test that enter restores captured context variables.""" + # Create a context variable and set a value + test_var = contextvars.ContextVar("integration_test_var") + test_var.set("original_value") + + # Capture the context + context_vars = contextvars.copy_context() + + # Change the value + test_var.set("new_value") + + # Create FlaskExecutionContext and enter it + from context.flask_app_context import FlaskExecutionContext + + ctx = FlaskExecutionContext( + flask_app=mock_flask_app, + context_vars=context_vars, + ) + + with ctx: + # Value should be restored to original + assert test_var.get() == "original_value" + + # After exiting, variable stays at the value from within the context + # (this is expected Python contextvars behavior) + assert test_var.get() == "original_value" + + def test_enter_enters_flask_app_context(self, mock_flask_app): + """Test that enter enters Flask app context.""" + from context.flask_app_context import FlaskExecutionContext + + ctx = FlaskExecutionContext( + flask_app=mock_flask_app, + context_vars=contextvars.copy_context(), + ) + + with ctx: + # Verify app context was entered + assert mock_flask_app.app_context.called + + @patch("context.flask_app_context.g") + def test_enter_restores_user_in_g(self, mock_g, mock_flask_app): + """Test that enter restores user in Flask g object.""" + mock_user = MagicMock() + mock_user.id = "test_user" + + # Note: FlaskExecutionContext saves user from g before entering context, + # then restores it after entering the app context. + # The user passed to constructor is NOT restored to g. + # So we need to test the actual behavior. + + # Create FlaskExecutionContext with user in constructor + from context.flask_app_context import FlaskExecutionContext + + ctx = FlaskExecutionContext( + flask_app=mock_flask_app, + context_vars=contextvars.copy_context(), + user=mock_user, + ) + + # Set user in g before entering (simulating existing user in g) + mock_g._login_user = mock_user + + with ctx: + # After entering, the user from g before entry should be restored + assert mock_g._login_user == mock_user + + # The user in constructor is stored but not automatically restored to g + # (it's available via ctx.user property) + assert ctx.user == mock_user + + def test_enter_method_as_context_manager(self, mock_flask_app): + """Test enter method returns a proper context manager.""" + from context.flask_app_context import FlaskExecutionContext + + ctx = FlaskExecutionContext( + flask_app=mock_flask_app, + context_vars=contextvars.copy_context(), + ) + + # enter() should return a generator/context manager + with ctx.enter(): + # Should work without issues + pass + + # Verify app context was called + assert mock_flask_app.app_context.called