mirror of https://github.com/langgenius/dify.git
feat: workflow support register context and read context (#31265)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Maries <xh001x@hotmail.com>
This commit is contained in:
parent
e80d76af15
commit
34436fc89c
|
|
@ -9,7 +9,7 @@ from typing import Any, final
|
|||
|
||||
from flask import Flask, current_app, g
|
||||
|
||||
from context import register_context_capturer
|
||||
from core.workflow.context import register_context_capturer
|
||||
from core.workflow.context.execution_context import (
|
||||
AppContext,
|
||||
IExecutionContext,
|
||||
|
|
|
|||
|
|
@ -7,16 +7,28 @@ execution in multi-threaded environments.
|
|||
|
||||
from core.workflow.context.execution_context import (
|
||||
AppContext,
|
||||
ContextProviderNotFoundError,
|
||||
ExecutionContext,
|
||||
IExecutionContext,
|
||||
NullAppContext,
|
||||
capture_current_context,
|
||||
read_context,
|
||||
register_context,
|
||||
register_context_capturer,
|
||||
reset_context_provider,
|
||||
)
|
||||
from core.workflow.context.models import SandboxContext
|
||||
|
||||
__all__ = [
|
||||
"AppContext",
|
||||
"ContextProviderNotFoundError",
|
||||
"ExecutionContext",
|
||||
"IExecutionContext",
|
||||
"NullAppContext",
|
||||
"SandboxContext",
|
||||
"capture_current_context",
|
||||
"read_context",
|
||||
"register_context",
|
||||
"register_context_capturer",
|
||||
"reset_context_provider",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -4,9 +4,11 @@ Execution Context - Abstracted context management for workflow execution.
|
|||
|
||||
import contextvars
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Callable, Generator
|
||||
from contextlib import AbstractContextManager, contextmanager
|
||||
from typing import Any, Protocol, final, runtime_checkable
|
||||
from typing import Any, Protocol, TypeVar, final, runtime_checkable
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class AppContext(ABC):
|
||||
|
|
@ -204,13 +206,75 @@ class ExecutionContextBuilder:
|
|||
)
|
||||
|
||||
|
||||
_capturer: Callable[[], IExecutionContext] | None = None
|
||||
|
||||
# Tenant-scoped providers using tuple keys for clarity and constant-time lookup.
|
||||
# Key mapping:
|
||||
# (name, tenant_id) -> provider
|
||||
# - name: namespaced identifier (recommend prefixing, e.g. "workflow.sandbox")
|
||||
# - tenant_id: tenant identifier string
|
||||
# Value:
|
||||
# provider: Callable[[], BaseModel] returning the typed context value
|
||||
# Type-safety note:
|
||||
# - This registry cannot enforce that all providers for a given name return the same BaseModel type.
|
||||
# - Implementors SHOULD provide typed wrappers around register/read (like Go's context best practice),
|
||||
# e.g. def register_sandbox_ctx(tenant_id: str, p: Callable[[], SandboxContext]) and
|
||||
# def read_sandbox_ctx(tenant_id: str) -> SandboxContext.
|
||||
_tenant_context_providers: dict[tuple[str, str], Callable[[], BaseModel]] = {}
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
class ContextProviderNotFoundError(KeyError):
|
||||
"""Raised when a tenant-scoped context provider is missing for a given (name, tenant_id)."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def register_context_capturer(capturer: Callable[[], IExecutionContext]) -> None:
|
||||
"""Register a single enterable execution context capturer (e.g., Flask)."""
|
||||
global _capturer
|
||||
_capturer = capturer
|
||||
|
||||
|
||||
def register_context(name: str, tenant_id: str, provider: Callable[[], BaseModel]) -> None:
|
||||
"""Register a tenant-specific provider for a named context.
|
||||
|
||||
Tip: use a namespaced "name" (e.g., "workflow.sandbox") to avoid key collisions.
|
||||
Consider adding a typed wrapper for this registration in your feature module.
|
||||
"""
|
||||
_tenant_context_providers[(name, tenant_id)] = provider
|
||||
|
||||
|
||||
def read_context(name: str, *, tenant_id: str) -> BaseModel:
|
||||
"""
|
||||
Read a context value for a specific tenant.
|
||||
|
||||
Raises KeyError if the provider for (name, tenant_id) is not registered.
|
||||
"""
|
||||
prov = _tenant_context_providers.get((name, tenant_id))
|
||||
if prov is None:
|
||||
raise ContextProviderNotFoundError(f"Context provider '{name}' not registered for tenant '{tenant_id}'")
|
||||
return prov()
|
||||
|
||||
|
||||
def capture_current_context() -> IExecutionContext:
|
||||
"""
|
||||
Capture current execution context from the calling environment.
|
||||
|
||||
Returns:
|
||||
IExecutionContext with captured context
|
||||
If a capturer is registered (e.g., Flask), use it. Otherwise, return a minimal
|
||||
context with NullAppContext + copy of current contextvars.
|
||||
"""
|
||||
from context import capture_current_context
|
||||
if _capturer is None:
|
||||
return ExecutionContext(
|
||||
app_context=NullAppContext(),
|
||||
context_vars=contextvars.copy_context(),
|
||||
)
|
||||
return _capturer()
|
||||
|
||||
return capture_current_context()
|
||||
|
||||
def reset_context_provider() -> None:
|
||||
"""Reset the capturer and all tenant-scoped context providers (primarily for tests)."""
|
||||
global _capturer
|
||||
_capturer = None
|
||||
_tenant_context_providers.clear()
|
||||
|
|
|
|||
|
|
@ -0,0 +1,13 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from pydantic import AnyHttpUrl, BaseModel
|
||||
|
||||
|
||||
class SandboxContext(BaseModel):
|
||||
"""Typed context for sandbox integration. All fields optional by design."""
|
||||
|
||||
sandbox_url: AnyHttpUrl | None = None
|
||||
sandbox_token: str | None = None # optional, if later needed for auth
|
||||
|
||||
|
||||
__all__ = ["SandboxContext"]
|
||||
|
|
@ -5,6 +5,7 @@ from typing import Any
|
|||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.workflow.context.execution_context import (
|
||||
AppContext,
|
||||
|
|
@ -12,6 +13,8 @@ from core.workflow.context.execution_context import (
|
|||
ExecutionContextBuilder,
|
||||
IExecutionContext,
|
||||
NullAppContext,
|
||||
read_context,
|
||||
register_context,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -256,3 +259,31 @@ class TestCaptureCurrentContext:
|
|||
|
||||
# Context variables should be captured
|
||||
assert result.context_vars is not None
|
||||
|
||||
|
||||
class TestTenantScopedContextRegistry:
|
||||
def setup_method(self):
|
||||
from core.workflow.context import reset_context_provider
|
||||
|
||||
reset_context_provider()
|
||||
|
||||
def teardown_method(self):
|
||||
from core.workflow.context import reset_context_provider
|
||||
|
||||
reset_context_provider()
|
||||
|
||||
def test_tenant_provider_read_ok(self):
|
||||
class SandboxContext(BaseModel):
|
||||
base_url: str | None = None
|
||||
|
||||
register_context("workflow.sandbox", "t1", lambda: SandboxContext(base_url="http://t1"))
|
||||
register_context("workflow.sandbox", "t2", lambda: SandboxContext(base_url="http://t2"))
|
||||
|
||||
assert read_context("workflow.sandbox", tenant_id="t1").base_url == "http://t1"
|
||||
assert read_context("workflow.sandbox", tenant_id="t2").base_url == "http://t2"
|
||||
|
||||
def test_missing_provider_raises_keyerror(self):
|
||||
from core.workflow.context import ContextProviderNotFoundError
|
||||
|
||||
with pytest.raises(ContextProviderNotFoundError):
|
||||
read_context("missing", tenant_id="unknown")
|
||||
|
|
|
|||
Loading…
Reference in New Issue