From 0bd17c6d0f0e8bca33f0dc11e6fcd0c6994db449 Mon Sep 17 00:00:00 2001 From: Harry Date: Fri, 16 Jan 2026 18:21:53 +0800 Subject: [PATCH] refactor(sandbox): sandbox provider system default configuration --- .../console/workspace/sandbox_providers.py | 46 +- api/core/sandbox/__init__.py | 7 +- api/core/sandbox/entities/__init__.py | 3 + api/core/sandbox/entities/providers.py | 21 + api/core/sandbox/factory.py | 81 ---- api/core/sandbox/vm.py | 109 +++++ .../__base/virtual_environment.py | 6 + .../providers/docker_daemon_sandbox.py | 8 + .../providers/e2b_sandbox.py | 9 + .../providers/local_without_isolation.py | 7 + ...e916693_sandbox_provider_configure_type.py | 68 +++ api/models/sandbox.py | 3 +- api/services/sandbox/__init__.py | 3 - .../sandbox/sandbox_provider_service.py | 397 ++++++------------ .../core/virtual_environment/test_factory.py | 30 +- .../sandbox-provider-page/provider-card.tsx | 2 +- web/contract/console/sandbox-provider.ts | 20 - web/contract/router.ts | 4 - web/service/use-sandbox-provider.ts | 15 - 19 files changed, 382 insertions(+), 457 deletions(-) create mode 100644 api/core/sandbox/entities/__init__.py create mode 100644 api/core/sandbox/entities/providers.py delete mode 100644 api/core/sandbox/factory.py create mode 100644 api/core/sandbox/vm.py create mode 100644 api/migrations/versions/2026_01_16_1728-45471e916693_sandbox_provider_configure_type.py delete mode 100644 api/services/sandbox/__init__.py diff --git a/api/controllers/console/workspace/sandbox_providers.py b/api/controllers/console/workspace/sandbox_providers.py index 561f2f1cfb..4ffbd3baef 100644 --- a/api/controllers/console/workspace/sandbox_providers.py +++ b/api/controllers/console/workspace/sandbox_providers.py @@ -13,45 +13,18 @@ logger = logging.getLogger(__name__) @console_ns.route("/workspaces/current/sandbox-providers") class SandboxProviderListApi(Resource): - """List all sandbox providers for the current tenant.""" - @console_ns.doc("list_sandbox_providers") @console_ns.doc(description="Get list of available sandbox providers with configuration status") - @console_ns.response( - 200, - "Success", - fields.List(fields.Raw(description="Sandbox provider information")), - ) + @console_ns.response(200, "Success", fields.List(fields.Raw(description="Sandbox provider information"))) @setup_required @login_required @account_initialization_required def get(self): - """List all sandbox providers.""" _, current_tenant_id = current_account_with_tenant() providers = SandboxProviderService.list_providers(current_tenant_id) return jsonable_encoder([p.model_dump() for p in providers]) -@console_ns.route("/workspaces/current/sandbox-provider/") -class SandboxProviderApi(Resource): - """Get specific sandbox provider details.""" - - @console_ns.doc("get_sandbox_provider") - @console_ns.doc(description="Get specific sandbox provider details") - @console_ns.doc(params={"provider_type": "Sandbox provider type (e2b, docker, local)"}) - @console_ns.response(200, "Success", fields.Raw(description="Sandbox provider details")) - @setup_required - @login_required - @account_initialization_required - def get(self, provider_type: str): - """Get a specific sandbox provider.""" - _, current_tenant_id = current_account_with_tenant() - provider = SandboxProviderService.get_provider(current_tenant_id, provider_type) - if not provider: - return {"message": f"Provider {provider_type} not found"}, 404 - return jsonable_encoder(provider.model_dump()) - - config_parser = reqparse.RequestParser() config_parser.add_argument("config", type=dict, required=True, location="json") @@ -120,20 +93,3 @@ class SandboxProviderActivateApi(Resource): return result except ValueError as e: return {"message": str(e)}, 400 - - -@console_ns.route("/workspaces/current/sandbox-provider/active") -class SandboxProviderActiveApi(Resource): - """Get the currently active sandbox provider.""" - - @console_ns.doc("get_active_sandbox_provider") - @console_ns.doc(description="Get the currently active sandbox provider for the workspace") - @console_ns.response(200, "Success") - @setup_required - @login_required - @account_initialization_required - def get(self): - """Get the active sandbox provider.""" - _, current_tenant_id = current_account_with_tenant() - active_provider = SandboxProviderService.get_active_provider(current_tenant_id) - return {"provider_type": active_provider} diff --git a/api/core/sandbox/__init__.py b/api/core/sandbox/__init__.py index 7fd6e4b309..6d309e1af1 100644 --- a/api/core/sandbox/__init__.py +++ b/api/core/sandbox/__init__.py @@ -12,13 +12,13 @@ from .constants import ( DIFY_CLI_PATH, DIFY_CLI_PATH_PATTERN, ) -from .factory import VMBuilder, VMType from .initializer import AppAssetsInitializer, DifyCliInitializer, SandboxInitializer from .manager import SandboxManager from .session import SandboxSession from .storage import ArchiveSandboxStorage, SandboxStorage from .utils.debug import sandbox_debug from .utils.encryption import create_sandbox_config_encrypter, masked_config +from .vm import SandboxBuilder, SandboxType, VMConfig __all__ = [ "APP_ASSETS_PATH", @@ -34,12 +34,13 @@ __all__ = [ "DifyCliInitializer", "DifyCliLocator", "DifyCliToolConfig", + "SandboxBuilder", "SandboxInitializer", "SandboxManager", "SandboxSession", "SandboxStorage", - "VMBuilder", - "VMType", + "SandboxType", + "VMConfig", "create_sandbox_config_encrypter", "masked_config", "sandbox_debug", diff --git a/api/core/sandbox/entities/__init__.py b/api/core/sandbox/entities/__init__.py new file mode 100644 index 0000000000..6829ca31ba --- /dev/null +++ b/api/core/sandbox/entities/__init__.py @@ -0,0 +1,3 @@ +from .providers import SandboxProviderApiEntity + +__all__ = ["SandboxProviderApiEntity"] diff --git a/api/core/sandbox/entities/providers.py b/api/core/sandbox/entities/providers.py new file mode 100644 index 0000000000..82c00cb144 --- /dev/null +++ b/api/core/sandbox/entities/providers.py @@ -0,0 +1,21 @@ +from collections.abc import Mapping +from typing import Any + +from pydantic import BaseModel, Field + + +class SandboxProviderApiEntity(BaseModel): + provider_type: str = Field(..., description="Provider type identifier") + is_system_configured: bool = Field(default=False) + is_tenant_configured: bool = Field(default=False) + is_active: bool = Field(default=False) + config: Mapping[str, Any] = Field(default_factory=dict) + config_schema: list[dict[str, Any]] = Field(default_factory=list) + + +class SandboxProviderEntity(BaseModel): + id: str = Field(..., description="Provider identifier") + provider_type: str = Field(..., description="Provider type identifier") + is_active: bool = Field(default=False) + config: Mapping[str, Any] = Field(default_factory=dict) + config_schema: list[dict[str, Any]] = Field(default_factory=list) diff --git a/api/core/sandbox/factory.py b/api/core/sandbox/factory.py deleted file mode 100644 index bd248bb6c8..0000000000 --- a/api/core/sandbox/factory.py +++ /dev/null @@ -1,81 +0,0 @@ -from __future__ import annotations - -from collections.abc import Mapping, Sequence -from enum import StrEnum -from typing import TYPE_CHECKING, Any - -from core.virtual_environment.__base.virtual_environment import VirtualEnvironment - -if TYPE_CHECKING: - from .initializer import SandboxInitializer - - -class VMType(StrEnum): - DOCKER = "docker" - E2B = "e2b" - LOCAL = "local" - - -def _get_vm_class(vm_type: VMType) -> type[VirtualEnvironment]: - match vm_type: - case VMType.DOCKER: - from core.virtual_environment.providers.docker_daemon_sandbox import DockerDaemonEnvironment - - return DockerDaemonEnvironment - case VMType.E2B: - from core.virtual_environment.providers.e2b_sandbox import E2BEnvironment - - return E2BEnvironment - case VMType.LOCAL: - from core.virtual_environment.providers.local_without_isolation import LocalVirtualEnvironment - - return LocalVirtualEnvironment - case _: - raise ValueError(f"Unsupported VM type: {vm_type}") - - -class VMBuilder: - def __init__(self, tenant_id: str, vm_type: VMType) -> None: - self._tenant_id = tenant_id - self._vm_type = vm_type - self._user_id: str | None = None - self._options: dict[str, Any] = {} - self._environments: dict[str, str] = {} - self._initializers: list[SandboxInitializer] = [] - - def user(self, user_id: str) -> VMBuilder: - self._user_id = user_id - return self - - def options(self, options: Mapping[str, Any]) -> VMBuilder: - self._options = dict(options) - return self - - def environments(self, environments: Mapping[str, str]) -> VMBuilder: - self._environments = dict(environments) - return self - - def initializer(self, initializer: SandboxInitializer) -> VMBuilder: - self._initializers.append(initializer) - return self - - def initializers(self, initializers: Sequence[SandboxInitializer]) -> VMBuilder: - self._initializers.extend(initializers) - return self - - def build(self) -> VirtualEnvironment: - vm_class = _get_vm_class(self._vm_type) - vm = vm_class( - tenant_id=self._tenant_id, - options=self._options, - environments=self._environments, - user_id=self._user_id, - ) - for init in self._initializers: - init.initialize(vm) - return vm - - @staticmethod - def validate(vm_type: VMType, options: Mapping[str, Any]) -> None: - vm_class = _get_vm_class(vm_type) - vm_class.validate(options) diff --git a/api/core/sandbox/vm.py b/api/core/sandbox/vm.py new file mode 100644 index 0000000000..911d442992 --- /dev/null +++ b/api/core/sandbox/vm.py @@ -0,0 +1,109 @@ +""" +Facade module for virtual machine providers. + +Provides unified interfaces to access different VM provider implementations +(E2B, Docker, Local) through VMType, VMBuilder, and VMConfig. +""" + +from __future__ import annotations + +from collections.abc import Mapping, Sequence +from enum import StrEnum +from typing import Any + +from configs import dify_config +from core.entities.provider_entities import BasicProviderConfig +from core.virtual_environment.__base.virtual_environment import VirtualEnvironment + +from .initializer import SandboxInitializer + + +class SandboxType(StrEnum): + """ + Sandbox types. + """ + + DOCKER = "docker" + E2B = "e2b" + LOCAL = "local" + + @classmethod + def get_all(cls) -> list[str]: + """ + Get all available sandbox types. + """ + if dify_config.EDITION == "SELF_HOSTED": + return [p.value for p in cls] + else: + return [p.value for p in cls if p != SandboxType.LOCAL] + + +def _get_sandbox_class(sandbox_type: SandboxType) -> type[VirtualEnvironment]: + match sandbox_type: + case SandboxType.DOCKER: + from core.virtual_environment.providers.docker_daemon_sandbox import DockerDaemonEnvironment + + return DockerDaemonEnvironment + case SandboxType.E2B: + from core.virtual_environment.providers.e2b_sandbox import E2BEnvironment + + return E2BEnvironment + case SandboxType.LOCAL: + from core.virtual_environment.providers.local_without_isolation import LocalVirtualEnvironment + + return LocalVirtualEnvironment + case _: + raise ValueError(f"Unsupported sandbox type: {sandbox_type}") + + +class SandboxBuilder: + def __init__(self, tenant_id: str, sandbox_type: SandboxType) -> None: + self._tenant_id = tenant_id + self._sandbox_type = sandbox_type + self._user_id: str | None = None + self._options: dict[str, Any] = {} + self._environments: dict[str, str] = {} + self._initializers: list[SandboxInitializer] = [] + + def user(self, user_id: str) -> SandboxBuilder: + self._user_id = user_id + return self + + def options(self, options: Mapping[str, Any]) -> SandboxBuilder: + self._options = dict(options) + return self + + def environments(self, environments: Mapping[str, str]) -> SandboxBuilder: + self._environments = dict(environments) + return self + + def initializer(self, initializer: SandboxInitializer) -> SandboxBuilder: + self._initializers.append(initializer) + return self + + def initializers(self, initializers: Sequence[SandboxInitializer]) -> SandboxBuilder: + self._initializers.extend(initializers) + return self + + def build(self) -> VirtualEnvironment: + vm_class = _get_sandbox_class(self._sandbox_type) + vm = vm_class( + tenant_id=self._tenant_id, + options=self._options, + environments=self._environments, + user_id=self._user_id, + ) + for init in self._initializers: + init.initialize(vm) + return vm + + @staticmethod + def validate(vm_type: SandboxType, options: Mapping[str, Any]) -> None: + vm_class = _get_sandbox_class(vm_type) + vm_class.validate(options) + + +class VMConfig: + @staticmethod + def get_schema(vm_type: SandboxType) -> list[BasicProviderConfig]: + return _get_sandbox_class(vm_type).get_config_schema() diff --git a/api/core/virtual_environment/__base/virtual_environment.py b/api/core/virtual_environment/__base/virtual_environment.py index 6fe4f0f081..826d3a6ec7 100644 --- a/api/core/virtual_environment/__base/virtual_environment.py +++ b/api/core/virtual_environment/__base/virtual_environment.py @@ -3,6 +3,7 @@ from collections.abc import Mapping, Sequence from io import BytesIO from typing import Any +from core.entities.provider_entities import BasicProviderConfig from core.virtual_environment.__base.entities import CommandStatus, ConnectionHandle, FileState, Metadata from core.virtual_environment.channel.transport import TransportReadCloser, TransportWriteCloser @@ -174,3 +175,8 @@ class VirtualEnvironment(ABC): Returns: CommandStatus: The status of the command execution. """ + + @classmethod + @abstractmethod + def get_config_schema(cls) -> list[BasicProviderConfig]: + pass diff --git a/api/core/virtual_environment/providers/docker_daemon_sandbox.py b/api/core/virtual_environment/providers/docker_daemon_sandbox.py index a326270298..ccbc9699d2 100644 --- a/api/core/virtual_environment/providers/docker_daemon_sandbox.py +++ b/api/core/virtual_environment/providers/docker_daemon_sandbox.py @@ -15,6 +15,7 @@ import docker.errors from docker.models.containers import Container import docker +from core.entities.provider_entities import BasicProviderConfig from core.virtual_environment.__base.entities import ( Arch, CommandStatus, @@ -256,6 +257,13 @@ class DockerDaemonEnvironment(VirtualEnvironment): DOCKER_IMAGE = "docker_image" DOCKER_COMMAND = "docker_command" + @classmethod + def get_config_schema(cls) -> list[BasicProviderConfig]: + return [ + BasicProviderConfig(type=BasicProviderConfig.Type.TEXT_INPUT, name=cls.OptionsKey.DOCKER_SOCK), + BasicProviderConfig(type=BasicProviderConfig.Type.TEXT_INPUT, name=cls.OptionsKey.DOCKER_IMAGE), + ] + @classmethod def validate(cls, options: Mapping[str, Any]) -> None: docker_sock = options.get(cls.OptionsKey.DOCKER_SOCK, cls._DEFAULT_DOCKER_SOCK) diff --git a/api/core/virtual_environment/providers/e2b_sandbox.py b/api/core/virtual_environment/providers/e2b_sandbox.py index 89d1c6ec83..58386585e7 100644 --- a/api/core/virtual_environment/providers/e2b_sandbox.py +++ b/api/core/virtual_environment/providers/e2b_sandbox.py @@ -10,6 +10,7 @@ from uuid import uuid4 from e2b_code_interpreter import Sandbox # type: ignore[import-untyped] +from core.entities.provider_entities import BasicProviderConfig from core.virtual_environment.__base.entities import ( Arch, CommandStatus, @@ -96,6 +97,14 @@ class E2BEnvironment(VirtualEnvironment): class StoreKey(StrEnum): SANDBOX = "sandbox" + @classmethod + def get_config_schema(cls) -> list[BasicProviderConfig]: + return [ + BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=cls.OptionsKey.API_KEY), + BasicProviderConfig(type=BasicProviderConfig.Type.TEXT_INPUT, name=cls.OptionsKey.E2B_API_URL), + BasicProviderConfig(type=BasicProviderConfig.Type.TEXT_INPUT, name=cls.OptionsKey.E2B_DEFAULT_TEMPLATE), + ] + @classmethod def validate(cls, options: Mapping[str, Any]) -> None: from e2b.exceptions import AuthenticationException # type: ignore[import-untyped] diff --git a/api/core/virtual_environment/providers/local_without_isolation.py b/api/core/virtual_environment/providers/local_without_isolation.py index 6888ffe226..f9ff34c73a 100644 --- a/api/core/virtual_environment/providers/local_without_isolation.py +++ b/api/core/virtual_environment/providers/local_without_isolation.py @@ -8,6 +8,7 @@ from platform import machine, system from typing import Any from uuid import uuid4 +from core.entities.provider_entities import BasicProviderConfig from core.virtual_environment.__base.entities import ( Arch, CommandStatus, @@ -72,6 +73,12 @@ class LocalVirtualEnvironment(VirtualEnvironment): NEVER USE IT IN PRODUCTION ENVIRONMENTS. """ + @classmethod + def get_config_schema(cls) -> list[BasicProviderConfig]: + return [ + BasicProviderConfig(type=BasicProviderConfig.Type.TEXT_INPUT, name="base_working_path"), + ] + @classmethod def validate(cls, options: Mapping[str, Any]) -> None: pass diff --git a/api/migrations/versions/2026_01_16_1728-45471e916693_sandbox_provider_configure_type.py b/api/migrations/versions/2026_01_16_1728-45471e916693_sandbox_provider_configure_type.py new file mode 100644 index 0000000000..8b89c0d496 --- /dev/null +++ b/api/migrations/versions/2026_01_16_1728-45471e916693_sandbox_provider_configure_type.py @@ -0,0 +1,68 @@ +"""sandbox_provider_configure_type + +Revision ID: 45471e916693 +Revises: d88f3edbd99d +Create Date: 2026-01-16 17:28:46.691473 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '45471e916693' +down_revision = 'd88f3edbd99d' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('tenant_credit_pools', + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', models.types.StringUUID(), nullable=False), + sa.Column('pool_type', sa.String(length=40), server_default='trial', nullable=False), + sa.Column('quota_limit', sa.BigInteger(), nullable=False), + sa.Column('quota_used', sa.BigInteger(), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.PrimaryKeyConstraint('id', name='tenant_credit_pool_pkey') + ) + with op.batch_alter_table('tenant_credit_pools', schema=None) as batch_op: + batch_op.create_index('tenant_credit_pool_pool_type_idx', ['pool_type'], unique=False) + batch_op.create_index('tenant_credit_pool_tenant_id_idx', ['tenant_id'], unique=False) + + with op.batch_alter_table('messages', schema=None) as batch_op: + batch_op.create_index('message_created_at_id_idx', ['created_at', 'id'], unique=False) + + with op.batch_alter_table('sandbox_providers', schema=None) as batch_op: + batch_op.add_column(sa.Column('configure_type', sa.String(length=20), server_default='user', nullable=False)) + batch_op.drop_constraint(batch_op.f('unique_sandbox_provider_tenant_type'), type_='unique') + batch_op.create_unique_constraint('unique_sandbox_provider_tenant_type', ['tenant_id', 'provider_type', 'configure_type']) + + with op.batch_alter_table('workflow_runs', schema=None) as batch_op: + batch_op.create_index('workflow_run_created_at_id_idx', ['created_at', 'id'], unique=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('workflow_runs', schema=None) as batch_op: + batch_op.drop_index('workflow_run_created_at_id_idx') + + with op.batch_alter_table('sandbox_providers', schema=None) as batch_op: + batch_op.drop_constraint('unique_sandbox_provider_tenant_type', type_='unique') + batch_op.create_unique_constraint(batch_op.f('unique_sandbox_provider_tenant_type'), ['tenant_id', 'provider_type'], postgresql_nulls_not_distinct=False) + batch_op.drop_column('configure_type') + + with op.batch_alter_table('messages', schema=None) as batch_op: + batch_op.drop_index('message_created_at_id_idx') + + with op.batch_alter_table('tenant_credit_pools', schema=None) as batch_op: + batch_op.drop_index('tenant_credit_pool_tenant_id_idx') + batch_op.drop_index('tenant_credit_pool_pool_type_idx') + + op.drop_table('tenant_credit_pools') + # ### end Alembic commands ### diff --git a/api/models/sandbox.py b/api/models/sandbox.py index 1ca95eba29..2e64cfbfd9 100644 --- a/api/models/sandbox.py +++ b/api/models/sandbox.py @@ -51,7 +51,7 @@ class SandboxProvider(TypeBase): __tablename__ = "sandbox_providers" __table_args__ = ( sa.PrimaryKeyConstraint("id", name="sandbox_provider_pkey"), - sa.UniqueConstraint("tenant_id", "provider_type", name="unique_sandbox_provider_tenant_type"), + sa.UniqueConstraint("tenant_id", "provider_type", "configure_type", name="unique_sandbox_provider_tenant_type"), sa.Index("idx_sandbox_providers_tenant_id", "tenant_id"), sa.Index("idx_sandbox_providers_tenant_active", "tenant_id", "is_active"), ) @@ -62,6 +62,7 @@ class SandboxProvider(TypeBase): tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) provider_type: Mapped[str] = mapped_column(String(50), nullable=False, comment="e2b, docker, local") encrypted_config: Mapped[str] = mapped_column(LongText, nullable=False, comment="Encrypted config JSON") + configure_type: Mapped[str] = mapped_column(String(20), nullable=False, server_default="user", default="user") is_active: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"), default=False) created_at: Mapped[datetime] = mapped_column( DateTime, nullable=False, server_default=func.current_timestamp(), init=False diff --git a/api/services/sandbox/__init__.py b/api/services/sandbox/__init__.py deleted file mode 100644 index d450b3aab8..0000000000 --- a/api/services/sandbox/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .sandbox_provider_service import SandboxProviderService - -__all__ = ["SandboxProviderService"] diff --git a/api/services/sandbox/sandbox_provider_service.py b/api/services/sandbox/sandbox_provider_service.py index 09eaeb2be0..77a6c88f2a 100644 --- a/api/services/sandbox/sandbox_provider_service.py +++ b/api/services/sandbox/sandbox_provider_service.py @@ -1,332 +1,191 @@ -""" -Sandbox Provider Service for managing sandbox configurations. - -Supports three provider types: -- e2b: Cloud-based sandbox (requires API key) -- docker: Local Docker-based sandbox (self-hosted) -- local: Local execution without isolation (self-hosted only) -""" - import json import logging from collections.abc import Mapping -from enum import StrEnum from typing import Any -from pydantic import BaseModel, Field, model_validator from sqlalchemy.orm import Session -from configs import dify_config from constants import HIDDEN_VALUE -from core.entities.provider_entities import BasicProviderConfig -from core.sandbox import VMBuilder, VMType, create_sandbox_config_encrypter, masked_config -from core.tools.utils.system_encryption import ( - decrypt_system_params, -) +from core.sandbox import SandboxBuilder, SandboxType, VMConfig, create_sandbox_config_encrypter, masked_config +from core.sandbox.entities import SandboxProviderApiEntity +from core.sandbox.entities.providers import SandboxProviderEntity +from core.tools.utils.system_encryption import decrypt_system_params from extensions.ext_database import db from models.sandbox import SandboxProvider, SandboxProviderSystemConfig logger = logging.getLogger(__name__) -class SandboxProviderType(StrEnum): - E2B = "e2b" - DOCKER = "docker" - LOCAL = "local" +def _get_encrypter(tenant_id: str, provider_type: str): + return create_sandbox_config_encrypter(tenant_id, VMConfig.get_schema(SandboxType(provider_type)), provider_type)[0] -class E2BConfig(BaseModel): - api_key: str = "" - e2b_api_url: str = "https://api.e2b.app" - e2b_default_template: str = "code-interpreter-v1" - - @model_validator(mode="before") - @classmethod - def check_required(cls, values: dict[str, Any]) -> dict[str, Any]: - if not values.get("api_key"): - raise ValueError("api_key is required") - return values - - -class DockerConfig(BaseModel): - docker_sock: str = "unix:///var/run/docker.sock" - docker_image: str = "ubuntu:latest" - - -class LocalConfig(BaseModel): - pass - - -PROVIDER_CONFIG_MODELS: dict[str, type[BaseModel]] = { - SandboxProviderType.E2B: E2BConfig, - SandboxProviderType.DOCKER: DockerConfig, - SandboxProviderType.LOCAL: LocalConfig, -} - -PROVIDER_CONFIG_SCHEMAS: dict[str, list[BasicProviderConfig]] = { - SandboxProviderType.E2B: [ - BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name="api_key"), - BasicProviderConfig(type=BasicProviderConfig.Type.TEXT_INPUT, name="e2b_api_url"), - BasicProviderConfig(type=BasicProviderConfig.Type.TEXT_INPUT, name="e2b_default_template"), - ], - SandboxProviderType.DOCKER: [ - BasicProviderConfig(type=BasicProviderConfig.Type.TEXT_INPUT, name="docker_sock"), - BasicProviderConfig(type=BasicProviderConfig.Type.TEXT_INPUT, name="docker_image"), - ], - SandboxProviderType.LOCAL: [ - BasicProviderConfig(type=BasicProviderConfig.Type.TEXT_INPUT, name="base_working_path"), - ], -} - - -class SandboxProviderInfo(BaseModel): - provider_type: str = Field(..., description="Provider type identifier") - label: str = Field(..., description="Display name") - description: str = Field(..., description="Provider description") - icon: str = Field(..., description="Icon identifier") - is_system_configured: bool = Field(default=False, description="Whether system default is configured") - is_tenant_configured: bool = Field(default=False, description="Whether tenant has custom config") - is_active: bool = Field(default=False, description="Whether this provider is active for the tenant") - config: Mapping[str, Any] = Field(default_factory=dict, description="Masked config") - config_schema: list[dict[str, Any]] = Field(default_factory=list, description="Config form schema") - - -PROVIDER_METADATA: dict[str, dict[str, str]] = { - SandboxProviderType.E2B: { - "label": "E2B", - "description": "Cloud-based sandbox powered by E2B. Secure, scalable, and managed.", - "icon": "e2b", - }, - SandboxProviderType.DOCKER: { - "label": "Docker", - "description": "Local Docker-based sandbox. Requires Docker daemon running on the host.", - "icon": "docker", - }, - SandboxProviderType.LOCAL: { - "label": "Local", - "description": "Local execution without isolation. Only for development/testing.", - "icon": "local", - }, -} +def _query_tenant_config(session: Session, tenant_id: str, provider_type: str) -> SandboxProvider | None: + return ( + session.query(SandboxProvider) + .filter(SandboxProvider.tenant_id == tenant_id, SandboxProvider.provider_type == provider_type) + .first() + ) class SandboxProviderService: @classmethod - def get_available_provider_types(cls) -> list[str]: - providers = [SandboxProviderType.E2B, SandboxProviderType.DOCKER] - if dify_config.EDITION == "SELF_HOSTED": - providers.append(SandboxProviderType.LOCAL) - return [provider.value for provider in providers] - - @classmethod - def list_providers(cls, tenant_id: str) -> list[SandboxProviderInfo]: - result: list[SandboxProviderInfo] = [] - + def list_providers(cls, tenant_id: str) -> list[SandboxProviderApiEntity]: with Session(db.engine, expire_on_commit=False) as session: + provider_types = SandboxType.get_all() tenant_configs = { - cfg.provider_type: cfg - for cfg in session.query(SandboxProvider).filter(SandboxProvider.tenant_id == tenant_id).all() + config.provider_type: config + for config in session.query(SandboxProvider).filter(SandboxProvider.tenant_id == tenant_id).all() + } + system_configs = { + config.provider_type: config + for config in session.query(SandboxProviderSystemConfig) + .filter(SandboxProviderSystemConfig.provider_type.in_(provider_types)) + .all() } - system_defaults = {cfg.provider_type for cfg in session.query(SandboxProviderSystemConfig).all()} - for provider_type in cls.get_available_provider_types(): + providers: list[SandboxProviderApiEntity] = [] + current_provider = cls.get_active_sandbox_config(session, tenant_id) + for provider_type in SandboxType.get_all(): tenant_config = tenant_configs.get(provider_type) - schema = PROVIDER_CONFIG_SCHEMAS.get(provider_type, []) - metadata = PROVIDER_METADATA.get(provider_type, {}) - - config: Mapping[str, Any] = {} - if tenant_config and tenant_config.config: - encrypter, _ = create_sandbox_config_encrypter(tenant_id, schema, provider_type) - config = masked_config(schema, encrypter.decrypt(tenant_config.config)) - - result.append( - SandboxProviderInfo( - provider_type=provider_type, - label=metadata.get("label", provider_type), - description=metadata.get("description", ""), - icon=metadata.get("icon", provider_type), - is_system_configured=provider_type in system_defaults and tenant_config is None, - is_tenant_configured=tenant_config is not None, - is_active=tenant_config.is_active if tenant_config else False, - config=config, - config_schema=[{"name": c.name, "type": c.type.value} for c in schema], + schema = VMConfig.get_schema(SandboxType(provider_type)) + if tenant_config: + is_tenant_configured = tenant_config.configure_type == "user" + if is_tenant_configured: + decrypted_config = _get_encrypter(tenant_id, provider_type).decrypt(data=tenant_config.config) + config = masked_config(schemas=schema, config=decrypted_config) + else: + config = {} + providers.append( + SandboxProviderApiEntity( + provider_type=provider_type, + is_system_configured=system_configs.get(provider_type) is not None, + is_tenant_configured=is_tenant_configured, + is_active=current_provider.id == tenant_config.id, + config=config, + config_schema=[c.model_dump() for c in schema], + ) ) - ) - - return result - - @classmethod - def get_provider(cls, tenant_id: str, provider_type: str) -> SandboxProviderInfo | None: - if provider_type not in cls.get_available_provider_types(): - return None - - providers = cls.list_providers(tenant_id) - for provider in providers: - if provider.provider_type == provider_type: - return provider - return None + else: + system_config = system_configs.get(provider_type) + providers.append( + SandboxProviderApiEntity( + provider_type=provider_type, + is_active=system_config is not None and system_config.id == current_provider.id, + is_system_configured=system_config is not None, + config_schema=[c.model_dump() for c in schema], + ) + ) + return providers @classmethod def validate_config(cls, provider_type: str, config: Mapping[str, Any]) -> None: - model_class = PROVIDER_CONFIG_MODELS.get(provider_type) - if model_class: - model_class.model_validate(config) - - VMBuilder.validate(VMType(provider_type), config) + SandboxBuilder.validate(SandboxType(provider_type), config) @classmethod - def save_config( - cls, - tenant_id: str, - provider_type: str, - config: Mapping[str, Any], - ) -> dict[str, Any]: - if provider_type not in cls.get_available_provider_types(): + def save_config(cls, tenant_id: str, provider_type: str, config: Mapping[str, Any]) -> dict[str, Any]: + if provider_type not in SandboxType.get_all(): raise ValueError(f"Invalid provider type: {provider_type}") with Session(db.engine) as session: - existing = ( - session.query(SandboxProvider) - .filter( - SandboxProvider.tenant_id == tenant_id, - SandboxProvider.provider_type == provider_type, - ) - .first() - ) - - schema = PROVIDER_CONFIG_SCHEMAS.get(provider_type, []) - encrypter, _ = create_sandbox_config_encrypter(tenant_id, schema, provider_type) - - final_config = dict(config) - if existing and existing.config: - existing_config = encrypter.decrypt(existing.config) - for key, value in final_config.items(): - if value == HIDDEN_VALUE: - final_config[key] = existing_config.get(key, "") - - cls.validate_config(provider_type, final_config) - - encrypted = encrypter.encrypt(final_config) - - if existing: - existing.encrypted_config = json.dumps(encrypted) - else: - new_config = SandboxProvider( + provider = _query_tenant_config(session, tenant_id, provider_type) + encrypter = _get_encrypter(tenant_id, provider_type) + if not provider: + provider = SandboxProvider( tenant_id=tenant_id, provider_type=provider_type, - encrypted_config=json.dumps(encrypted), - is_active=False, + encrypted_config=json.dumps({}), ) - session.add(new_config) + session.add(provider) + new_config = dict(config) + old_config = encrypter.decrypt(provider.config) + for key, value in new_config.items(): + if value == HIDDEN_VALUE: + new_config[key] = old_config.get(key, "") + + cls.validate_config(provider_type, new_config) + + provider.encrypted_config = json.dumps(encrypter.encrypt(new_config)) + provider.is_active = provider.is_active or cls.is_system_default_config(session, tenant_id) + provider.configure_type = "user" session.commit() - return {"result": "success"} @classmethod def delete_config(cls, tenant_id: str, provider_type: str) -> dict[str, Any]: with Session(db.engine) as session: - config = ( - session.query(SandboxProvider) - .filter( - SandboxProvider.tenant_id == tenant_id, - SandboxProvider.provider_type == provider_type, - ) - .first() - ) - - if not config: - return {"result": "success"} - - session.delete(config) - session.commit() - + if config := _query_tenant_config(session, tenant_id, provider_type): + session.delete(config) + session.commit() return {"result": "success"} + @classmethod + def is_system_default_config(cls, session: Session, tenant_id: str) -> bool: + system_configed: SandboxProviderSystemConfig | None = session.query(SandboxProviderSystemConfig).first() + if not system_configed: + return False + active_config = cls.get_active_sandbox_config(session, tenant_id) + return active_config.id == system_configed.id + @classmethod def activate_provider(cls, tenant_id: str, provider_type: str) -> dict[str, Any]: - if provider_type not in cls.get_available_provider_types(): + if provider_type not in SandboxType.get_all(): raise ValueError(f"Invalid provider type: {provider_type}") with Session(db.engine) as session: - tenant_config = ( - session.query(SandboxProvider) - .filter( - SandboxProvider.tenant_id == tenant_id, - SandboxProvider.provider_type == provider_type, - ) - .first() - ) + tenant_config = _query_tenant_config(session, tenant_id, provider_type) + system_config = session.query(SandboxProviderSystemConfig).filter_by(provider_type=provider_type).first() - system_default = ( - session.query(SandboxProviderSystemConfig) - .filter(SandboxProviderSystemConfig.provider_type == provider_type) - .first() - ) - - config_schema = PROVIDER_CONFIG_SCHEMAS.get(provider_type, []) - needs_config = len(config_schema) > 0 - - if needs_config and not tenant_config and not system_default: - raise ValueError(f"Provider {provider_type} is not configured. Please add configuration first.") - - session.query(SandboxProvider).filter( - SandboxProvider.tenant_id == tenant_id, - ).update({"is_active": False}) + session.query(SandboxProvider).filter(SandboxProvider.tenant_id == tenant_id).update({"is_active": False}) + # using tenant config if tenant_config: tenant_config.is_active = True - else: - new_config = SandboxProvider( - tenant_id=tenant_id, - provider_type=provider_type, - encrypted_config=json.dumps({}), - is_active=True, + session.commit() + return {"result": "success"} + + # using system config + if system_config: + session.add( + SandboxProvider( + is_active=True, + tenant_id=tenant_id, + configure_type="system", + provider_type=provider_type, + encrypted_config=json.dumps({}), + ) ) - session.add(new_config) + session.commit() + return {"result": "success"} - session.commit() - - return {"result": "success"} + raise ValueError(f"No sandbox provider configured for tenant {tenant_id} and provider type {provider_type}") @classmethod - def get_active_provider(cls, tenant_id: str) -> str | None: - with Session(db.engine, expire_on_commit=False) as session: - config = ( - session.query(SandboxProvider) - .filter( - SandboxProvider.tenant_id == tenant_id, - SandboxProvider.is_active.is_(True), - ) - .first() + def get_active_sandbox_config(cls, session: Session, tenant_id: str) -> SandboxProviderEntity: + tenant_configed = ( + session.query(SandboxProvider) + .filter(SandboxProvider.tenant_id == tenant_id, SandboxProvider.is_active.is_(True)) + .first() + ) + if tenant_configed: + config = _get_encrypter(tenant_id, tenant_configed.provider_type).decrypt(tenant_configed.config) + return SandboxProviderEntity( + id=tenant_configed.id, provider_type=tenant_configed.provider_type, config=config ) - return config.provider_type if config else None + + system_configed: SandboxProviderSystemConfig | None = session.query(SandboxProviderSystemConfig).first() + if system_configed: + return SandboxProviderEntity( + id=system_configed.id, + provider_type=system_configed.provider_type, + config=decrypt_system_params(system_configed.encrypted_config), + ) + + raise ValueError(f"No sandbox provider configured for tenant {tenant_id}") @classmethod - def create_sandbox_builder(cls, tenant_id: str) -> VMBuilder: + def create_sandbox_builder(cls, tenant_id: str) -> SandboxBuilder: with Session(db.engine, expire_on_commit=False) as session: - tenant_config = ( - session.query(SandboxProvider) - .filter( - SandboxProvider.tenant_id == tenant_id, - SandboxProvider.is_active.is_(True), - ) - .first() - ) - config: Mapping[str, Any] = {} - provider_type = None - if tenant_config: - schema = PROVIDER_CONFIG_SCHEMAS.get(tenant_config.provider_type, []) - encrypter, _ = create_sandbox_config_encrypter(tenant_id, schema, tenant_config.provider_type) - config = encrypter.decrypt(tenant_config.config) - provider_type = tenant_config.provider_type - else: - system_default = session.query(SandboxProviderSystemConfig).first() - if system_default: - config = decrypt_system_params(system_default.encrypted_config) - provider_type = system_default.provider_type - - if not config or not provider_type: - raise ValueError(f"No active sandbox provider for tenant {tenant_id} or system default") - - return VMBuilder(tenant_id, VMType(provider_type)).options(config) + provider_type, config = cls.get_active_sandbox_config(session, tenant_id) + return SandboxBuilder(tenant_id, SandboxType(provider_type)).options(config) diff --git a/api/tests/unit_tests/core/virtual_environment/test_factory.py b/api/tests/unit_tests/core/virtual_environment/test_factory.py index 6a68f3b454..edf8e3e499 100644 --- a/api/tests/unit_tests/core/virtual_environment/test_factory.py +++ b/api/tests/unit_tests/core/virtual_environment/test_factory.py @@ -3,20 +3,20 @@ from unittest.mock import MagicMock, patch import pytest -from core.sandbox import VMBuilder, VMType +from core.sandbox import SandboxBuilder, SandboxType from core.virtual_environment.__base.virtual_environment import VirtualEnvironment class TestVMType: def test_values(self): - assert VMType.DOCKER == "docker" - assert VMType.E2B == "e2b" - assert VMType.LOCAL == "local" + assert SandboxType.DOCKER == "docker" + assert SandboxType.E2B == "e2b" + assert SandboxType.LOCAL == "local" def test_is_string_enum(self): - assert isinstance(VMType.DOCKER.value, str) - assert isinstance(VMType.E2B.value, str) - assert isinstance(VMType.LOCAL.value, str) + assert isinstance(SandboxType.DOCKER.value, str) + assert isinstance(SandboxType.E2B.value, str) + assert isinstance(SandboxType.LOCAL.value, str) class TestVMBuilder: @@ -29,7 +29,7 @@ class TestVMBuilder: mock_class, ): result = ( - VMBuilder("test-tenant", VMType.DOCKER) + SandboxBuilder("test-tenant", SandboxType.DOCKER) .options({"docker_image": "python:3.11-slim"}) .environments({"PYTHONUNBUFFERED": "1"}) .build() @@ -51,7 +51,7 @@ class TestVMBuilder: "core.virtual_environment.providers.docker_daemon_sandbox.DockerDaemonEnvironment", mock_class, ): - VMBuilder("test-tenant", VMType.DOCKER).user("user-123").build() + SandboxBuilder("test-tenant", SandboxType.DOCKER).user("user-123").build() mock_class.assert_called_once_with( tenant_id="test-tenant", @@ -69,7 +69,7 @@ class TestVMBuilder: "core.virtual_environment.providers.docker_daemon_sandbox.DockerDaemonEnvironment", mock_class, ): - VMBuilder("test-tenant", VMType.DOCKER).initializer(mock_initializer).build() + SandboxBuilder("test-tenant", SandboxType.DOCKER).initializer(mock_initializer).build() mock_initializer.initialize.assert_called_once_with(mock_instance) @@ -80,7 +80,7 @@ class TestVMBuilder: "core.virtual_environment.providers.local_without_isolation.LocalVirtualEnvironment", return_value=mock_instance, ) as mock_class: - VMBuilder("test-tenant", VMType.LOCAL).build() + SandboxBuilder("test-tenant", SandboxType.LOCAL).build() mock_class.assert_called_once() def test_build_e2b(self): @@ -90,12 +90,12 @@ class TestVMBuilder: "core.virtual_environment.providers.e2b_sandbox.E2BEnvironment", return_value=mock_instance, ) as mock_class: - VMBuilder("test-tenant", VMType.E2B).build() + SandboxBuilder("test-tenant", SandboxType.E2B).build() mock_class.assert_called_once() def test_build_unsupported_type_raises(self): with pytest.raises(ValueError, match="Unsupported VM type"): - VMBuilder("test-tenant", "unsupported").build() # type: ignore[arg-type] + SandboxBuilder("test-tenant", "unsupported").build() # type: ignore[arg-type] def test_validate(self): mock_class = MagicMock() @@ -104,13 +104,13 @@ class TestVMBuilder: "core.virtual_environment.providers.docker_daemon_sandbox.DockerDaemonEnvironment", mock_class, ): - VMBuilder.validate(VMType.DOCKER, {"key": "value"}) + SandboxBuilder.validate(SandboxType.DOCKER, {"key": "value"}) mock_class.validate.assert_called_once_with({"key": "value"}) class TestVMBuilderIntegration: def test_local_sandbox(self, tmp_path: Path): - sandbox = VMBuilder("test-tenant", VMType.LOCAL).options({"base_working_path": str(tmp_path)}).build() + sandbox = SandboxBuilder("test-tenant", SandboxType.LOCAL).options({"base_working_path": str(tmp_path)}).build() try: assert sandbox is not None diff --git a/web/app/components/header/account-setting/sandbox-provider-page/provider-card.tsx b/web/app/components/header/account-setting/sandbox-provider-page/provider-card.tsx index 21cfd18da2..c1a529402f 100644 --- a/web/app/components/header/account-setting/sandbox-provider-page/provider-card.tsx +++ b/web/app/components/header/account-setting/sandbox-provider-page/provider-card.tsx @@ -47,7 +47,7 @@ const ProviderCard = ({ {provider.label} - {provider.is_system_configured && ( + {provider.is_system_configured && !provider.is_tenant_configured && ( {t('sandboxProvider.managedBySaas', { ns: 'common' })} diff --git a/web/contract/console/sandbox-provider.ts b/web/contract/console/sandbox-provider.ts index 5084e24355..5f78e26150 100644 --- a/web/contract/console/sandbox-provider.ts +++ b/web/contract/console/sandbox-provider.ts @@ -10,18 +10,6 @@ export const getSandboxProviderListContract = base .input(type()) .output(type()) -export const getSandboxProviderContract = base - .route({ - path: '/workspaces/current/sandbox-provider/{providerType}', - method: 'GET', - }) - .input(type<{ - params: { - providerType: string - } - }>()) - .output(type()) - export const saveSandboxProviderConfigContract = base .route({ path: '/workspaces/current/sandbox-provider/{providerType}/config', @@ -60,11 +48,3 @@ export const activateSandboxProviderContract = base } }>()) .output(type<{ result: string }>()) - -export const getActiveSandboxProviderContract = base - .route({ - path: '/workspaces/current/sandbox-provider/active', - method: 'GET', - }) - .input(type()) - .output(type<{ provider_type: string | null }>()) diff --git a/web/contract/router.ts b/web/contract/router.ts index b2422b6cc1..5d6514b47d 100644 --- a/web/contract/router.ts +++ b/web/contract/router.ts @@ -16,8 +16,6 @@ import { bindPartnerStackContract, invoicesContract } from './console/billing' import { activateSandboxProviderContract, deleteSandboxProviderConfigContract, - getActiveSandboxProviderContract, - getSandboxProviderContract, getSandboxProviderListContract, saveSandboxProviderConfigContract, } from './console/sandbox-provider' @@ -40,11 +38,9 @@ export const consoleRouterContract = { }, sandboxProvider: { getSandboxProviderList: getSandboxProviderListContract, - getSandboxProvider: getSandboxProviderContract, saveSandboxProviderConfig: saveSandboxProviderConfigContract, deleteSandboxProviderConfig: deleteSandboxProviderConfigContract, activateSandboxProvider: activateSandboxProviderContract, - getActiveSandboxProvider: getActiveSandboxProviderContract, }, appAsset: { tree: treeContract, diff --git a/web/service/use-sandbox-provider.ts b/web/service/use-sandbox-provider.ts index 486c3ba94a..975d3008c6 100644 --- a/web/service/use-sandbox-provider.ts +++ b/web/service/use-sandbox-provider.ts @@ -12,14 +12,6 @@ export const useGetSandboxProviderList = () => { }) } -export const useGetSandboxProvider = (providerType: string) => { - return useQuery({ - queryKey: consoleQuery.sandboxProvider.getSandboxProvider.queryKey({ input: { params: { providerType } } }), - queryFn: () => consoleClient.sandboxProvider.getSandboxProvider({ params: { providerType } }), - enabled: !!providerType, - }) -} - export const useSaveSandboxProviderConfig = () => { const queryClient = useQueryClient() return useMutation({ @@ -65,10 +57,3 @@ export const useActivateSandboxProvider = () => { }, }) } - -export const useGetActiveSandboxProvider = () => { - return useQuery({ - queryKey: consoleQuery.sandboxProvider.getActiveSandboxProvider.queryKey(), - queryFn: () => consoleClient.sandboxProvider.getActiveSandboxProvider(), - }) -}