From 932be0ad6403a478d18f9a73898ae61c605716c8 Mon Sep 17 00:00:00 2001 From: Harry Date: Mon, 5 Jan 2026 15:48:31 +0800 Subject: [PATCH] feat: session management for InnerAPI&VM --- api/controllers/inner_api/plugin/wraps.py | 15 ++- api/controllers/inner_api/wraps.py | 15 ++- api/core/session/__init__.py | 11 ++ api/core/session/inner_api.py | 19 ++++ api/core/session/session.py | 106 ++++++++++++++++++ .../virtual_environment/session/__init__.py | 3 + .../session/sandbox_session.py | 47 ++++++++ 7 files changed, 209 insertions(+), 7 deletions(-) create mode 100644 api/core/session/__init__.py create mode 100644 api/core/session/inner_api.py create mode 100644 api/core/session/session.py create mode 100644 api/core/virtual_environment/session/__init__.py create mode 100644 api/core/virtual_environment/session/sandbox_session.py diff --git a/api/controllers/inner_api/plugin/wraps.py b/api/controllers/inner_api/plugin/wraps.py index edf3ac393c..5f8dae5f87 100644 --- a/api/controllers/inner_api/plugin/wraps.py +++ b/api/controllers/inner_api/plugin/wraps.py @@ -7,6 +7,7 @@ from flask_login import user_logged_in from pydantic import BaseModel from sqlalchemy.orm import Session +from core.session.inner_api import InnerApiSession, InnerApiSessionManager from extensions.ext_database import db from libs.login import current_user from models.account import Tenant @@ -74,10 +75,18 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser: def get_user_tenant(view_func: Callable[P, R]): @wraps(view_func) def decorated_view(*args: P.args, **kwargs: P.kwargs): - payload = TenantUserPayload.model_validate(request.get_json(silent=True) or {}) + session_id = request.headers.get("X-Inner-Api-Session-Id") - user_id = payload.user_id - tenant_id = payload.tenant_id + if session_id: + session: InnerApiSession | None = InnerApiSessionManager().get(session_id) + if not session: + raise ValueError("session not found") + user_id = session.user_id + tenant_id = session.tenant_id + else: + payload = TenantUserPayload.model_validate(request.get_json(silent=True) or {}) + user_id = payload.user_id + tenant_id = payload.tenant_id if not tenant_id: raise ValueError("tenant_id is required") diff --git a/api/controllers/inner_api/wraps.py b/api/controllers/inner_api/wraps.py index 4bdcc6832a..9c859462db 100644 --- a/api/controllers/inner_api/wraps.py +++ b/api/controllers/inner_api/wraps.py @@ -5,6 +5,8 @@ from hashlib import sha1 from hmac import new as hmac_new from typing import ParamSpec, TypeVar +from core.session.inner_api import InnerApiSessionManager + P = ParamSpec("P") R = TypeVar("R") from flask import abort, request @@ -85,14 +87,19 @@ def enterprise_inner_api_user_auth(view: Callable[P, R]): def plugin_inner_api_only(view: Callable[P, R]): @wraps(view) def decorated(*args: P.args, **kwargs: P.kwargs): + # if session id is provided, using session id to validate + session_id = request.headers.get("X-Inner-Api-Session-Id") + if session_id and InnerApiSessionManager().exists(session_id): + return view(*args, **kwargs) + if not dify_config.PLUGIN_DAEMON_KEY: abort(404) - # get header 'X-Inner-Api-Key' + # if inner api key is provided, using inner api key to validate inner_api_key = request.headers.get("X-Inner-Api-Key") - if not inner_api_key or inner_api_key != dify_config.INNER_API_KEY_FOR_PLUGIN: - abort(404) + if inner_api_key and inner_api_key == dify_config.INNER_API_KEY_FOR_PLUGIN: + return view(*args, **kwargs) - return view(*args, **kwargs) + abort(404) return decorated diff --git a/api/core/session/__init__.py b/api/core/session/__init__.py new file mode 100644 index 0000000000..bb3f2c97f6 --- /dev/null +++ b/api/core/session/__init__.py @@ -0,0 +1,11 @@ +from .inner_api import InnerApiSession, InnerApiSessionManager +from .session import BaseSession, RedisSessionStorage, SessionManager, SessionStorage + +__all__ = [ + "BaseSession", + "InnerApiSession", + "InnerApiSessionManager", + "RedisSessionStorage", + "SessionManager", + "SessionStorage", +] diff --git a/api/core/session/inner_api.py b/api/core/session/inner_api.py new file mode 100644 index 0000000000..7b033397ec --- /dev/null +++ b/api/core/session/inner_api.py @@ -0,0 +1,19 @@ +from typing import Any + +from .session import BaseSession, SessionManager + + +class InnerApiSession(BaseSession): + """Inner API Session""" + + pass + + +class InnerApiSessionManager(SessionManager[InnerApiSession]): + def __init__(self, ttl: int | None = None): + super().__init__(key_prefix="inner_api_session", session_class=InnerApiSession, ttl=ttl) + + def create(self, tenant_id: str, user_id: str, context: dict[str, Any] | None = None) -> InnerApiSession: + session = InnerApiSession(tenant_id=tenant_id, user_id=user_id, context=context or {}) + self.save(session) + return session diff --git a/api/core/session/session.py b/api/core/session/session.py new file mode 100644 index 0000000000..620ea39b3a --- /dev/null +++ b/api/core/session/session.py @@ -0,0 +1,106 @@ +import json +import logging +import uuid +from datetime import UTC, datetime +from typing import Any, Generic, Protocol, TypeVar + +from pydantic import BaseModel, Field, ValidationError + +from extensions.ext_redis import redis_client + +logger = logging.getLogger(__name__) + + +class SessionStorage(Protocol): + """Session storage interface.""" + + def get(self, key: str) -> str | None: ... + def set(self, key: str, value: str, ttl: int) -> None: ... + def delete(self, key: str) -> bool: ... + def exists(self, key: str) -> bool: ... + def refresh_ttl(self, key: str, ttl: int) -> bool: ... + + +class RedisSessionStorage: + """Redis storage implementation (default).""" + + def get(self, key: str) -> str | None: + result = redis_client.get(key) + if result is None: + return None + return result.decode() if isinstance(result, bytes) else result + + def set(self, key: str, value: str, ttl: int) -> None: + redis_client.setex(key, ttl, value) + + def delete(self, key: str) -> bool: + return redis_client.delete(key) > 0 + + def exists(self, key: str) -> bool: + return redis_client.exists(key) > 0 + + def refresh_ttl(self, key: str, ttl: int) -> bool: + return bool(redis_client.expire(key, ttl)) + + +class BaseSession(BaseModel): + """Base session model.""" + + id: str = Field(default_factory=lambda: str(uuid.uuid4())) + tenant_id: str + user_id: str + created_at: datetime = Field(default_factory=lambda: datetime.now(UTC)) + updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC)) + context: dict[str, Any] = Field(default_factory=dict) + + def update_timestamp(self) -> None: + self.updated_at = datetime.now(UTC) + + +T = TypeVar("T", bound=BaseSession) + + +class SessionManager(Generic[T]): + """Generic session manager.""" + + DEFAULT_TTL = 7200 # 2 hours + + def __init__( + self, + key_prefix: str, + session_class: type[T], + storage: SessionStorage | None = None, + ttl: int | None = None, + ): + self._key_prefix = key_prefix + self._session_class = session_class + self._storage = storage or RedisSessionStorage() + self._ttl = ttl or self.DEFAULT_TTL + + def _get_key(self, session_id: str) -> str: + return f"{self._key_prefix}:{session_id}" + + def save(self, session: T) -> None: + session.update_timestamp() + key = self._get_key(session.id) + self._storage.set(key, session.model_dump_json(), self._ttl) + + def get(self, session_id: str) -> T | None: + key = self._get_key(session_id) + data = self._storage.get(key) + if data is None: + return None + try: + return self._session_class.model_validate(json.loads(data)) + except (json.JSONDecodeError, ValidationError) as e: + logger.warning("Failed to deserialize session %s: %s", session_id, e) + return None + + def delete(self, session_id: str) -> bool: + return self._storage.delete(self._get_key(session_id)) + + def exists(self, session_id: str) -> bool: + return self._storage.exists(self._get_key(session_id)) + + def refresh_ttl(self, session_id: str) -> bool: + return self._storage.refresh_ttl(self._get_key(session_id), self._ttl) diff --git a/api/core/virtual_environment/session/__init__.py b/api/core/virtual_environment/session/__init__.py new file mode 100644 index 0000000000..a4c24ce169 --- /dev/null +++ b/api/core/virtual_environment/session/__init__.py @@ -0,0 +1,3 @@ +from .sandbox_session import SandboxProvider, SandboxSession, SandboxSessionManager + +__all__ = ["SandboxProvider", "SandboxSession", "SandboxSessionManager"] diff --git a/api/core/virtual_environment/session/sandbox_session.py b/api/core/virtual_environment/session/sandbox_session.py new file mode 100644 index 0000000000..d2bf64d7eb --- /dev/null +++ b/api/core/virtual_environment/session/sandbox_session.py @@ -0,0 +1,47 @@ +from enum import StrEnum +from typing import Any + +from pydantic import Field + +from core.session import BaseSession, SessionManager +from core.virtual_environment.__base.entities import Arch + + +class SandboxProvider(StrEnum): + E2B = "e2b" + DOCKER = "docker" + LOCAL = "local" + + +class SandboxSession(BaseSession): + provider: SandboxProvider + sandbox_id: str + arch: Arch + connection_config: dict[str, Any] = Field(default_factory=dict) + + +class SandboxSessionManager(SessionManager[SandboxSession]): + def __init__(self, ttl: int | None = None): + super().__init__(key_prefix="sandbox_session", session_class=SandboxSession, ttl=ttl) + + def create( + self, + tenant_id: str, + user_id: str, + provider: SandboxProvider, + sandbox_id: str, + arch: Arch, + connection_config: dict[str, Any] | None = None, + context: dict[str, Any] | None = None, + ) -> SandboxSession: + session = SandboxSession( + tenant_id=tenant_id, + user_id=user_id, + provider=provider, + sandbox_id=sandbox_id, + arch=arch, + connection_config=connection_config or {}, + context=context or {}, + ) + self.save(session) + return session