feat(api): port Sandbox + VirtualEnvironment + Skill system from feat/support-agent-sandbox (Phase 5-6)

Port the complete infrastructure for agent sandbox execution and skill system:

Sandbox & Virtual Environment (core/sandbox/, core/virtual_environment/):
- Sandbox entity with lifecycle management (ready/failed/cancelled states)
- SandboxBuilder with fluent API for configuring providers
- 5 VM providers: Local, SSH, Docker, E2B, AWS CodeInterpreter
- VirtualEnvironment base with command execution, file transfer, transport layers
- Channel transport: pipe, queue, socket implementations
- Bash session management and DifyCli binary integration
- Storage: archive storage, file storage, noop storage, presign storage
- Initializers: DifyCli, AppAssets, DraftAppAssets, Skills
- Inspector: file browser, archive/runtime source, script utils
- Security: encryption utils, debug helpers

Skill & App Assets (core/skill/, core/app_assets/, core/app_bundle/):
- Skill entity and manager
- App asset accessor, builder pipeline (file, skill builders)
- App bundle source zip extractor
- Storage and converter utilities

API Endpoints:
- CLI API blueprint (controllers/cli_api/) for sandbox callback
- Sandbox provider management (workspace/sandbox_providers)
- Sandbox file browser (console/sandbox_files)
- App asset management (console/app/app_asset)
- Skill management (console/app/skills)
- Storage file endpoints (controllers/files/storage_files)

Services:
- Sandbox service, provider service, file service
- App asset service, app bundle service

Config:
- CliApiConfig, CreatorsPlatformConfig, CollaborationConfig
- FILES_API_URL for sandbox file access

Note: Controller route registration temporarily commented out (marked TODO)
pending resolution of deep dependency chains (socketio, workflow_comment,
command node, etc.). Core sandbox modules are fully ported and syntax-validated.
110 files changed, 10,549 insertions.

Made-with: Cursor

View File

@ -271,6 +271,17 @@ class PluginConfig(BaseSettings):
)
class CliApiConfig(BaseSettings):
"""
Configuration for CLI API (for dify-cli to call back from external sandbox environments)
"""
CLI_API_URL: str = Field(
description="CLI API URL for external sandbox (e.g., e2b) to call back.",
default="http://localhost:5001",
)
class MarketplaceConfig(BaseSettings):
"""
Configuration for marketplace
@ -287,6 +298,27 @@ class MarketplaceConfig(BaseSettings):
)
class CreatorsPlatformConfig(BaseSettings):
"""
Configuration for creators platform
"""
CREATORS_PLATFORM_FEATURES_ENABLED: bool = Field(
description="Enable or disable creators platform features",
default=True,
)
CREATORS_PLATFORM_API_URL: HttpUrl = Field(
description="Creators Platform API URL",
default=HttpUrl("https://creators.dify.ai"),
)
CREATORS_PLATFORM_OAUTH_CLIENT_ID: str = Field(
description="OAuth client_id for the Creators Platform app registered in Dify",
default="",
)
class EndpointConfig(BaseSettings):
"""
Configuration for various application endpoints and URLs
@ -341,6 +373,15 @@ class FileAccessConfig(BaseSettings):
default="",
)
FILES_API_URL: str = Field(
description="Base URL for storage file ticket API endpoints."
" Used by sandbox containers (internal or external like e2b) that need"
" an absolute, routable address to upload/download files via the API."
" For all-in-one Docker deployments, set to http://localhost."
" For public sandbox environments, set to a public domain or IP.",
default="",
)
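# Illustrative .env values for FILES_API_URL (assumptions restating the field
# description above, not part of this diff):
#   FILES_API_URL=http://localhost             # all-in-one Docker deployment
#   FILES_API_URL=https://api.example.com      # public sandbox environments (e.g. e2b)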
FILES_ACCESS_TIMEOUT: int = Field(
description="Expiration time in seconds for file access URLs",
default=300,
@ -1274,6 +1315,13 @@ class PositionConfig(BaseSettings):
return {item.strip() for item in self.POSITION_TOOL_EXCLUDES.split(",") if item.strip() != ""}
class CollaborationConfig(BaseSettings):
ENABLE_COLLABORATION_MODE: bool = Field(
description="Whether to enable collaboration mode features across the workspace",
default=False,
)
class LoginConfig(BaseSettings):
ENABLE_EMAIL_CODE_LOGIN: bool = Field(
description="whether to enable email code login",
@ -1375,7 +1423,9 @@ class FeatureConfig(
TriggerConfig,
AsyncWorkflowConfig,
PluginConfig,
CliApiConfig,
MarketplaceConfig,
CreatorsPlatformConfig,
DataSetConfig,
EndpointConfig,
FileAccessConfig,
@ -1399,6 +1449,7 @@ class FeatureConfig(
WorkflowConfig,
WorkflowNodeExecutionConfig,
WorkspaceConfig,
CollaborationConfig,
LoginConfig,
AccountConfig,
SwaggerUIConfig,

View File

@ -0,0 +1,27 @@
from flask import Blueprint
from flask_restx import Namespace
from libs.external_api import ExternalApi
bp = Blueprint("cli_api", __name__, url_prefix="/cli/api")
api = ExternalApi(
bp,
version="1.0",
title="CLI API",
description="APIs for Dify CLI to call back from external sandbox environments (e.g., e2b)",
)
# Create namespace
cli_api_ns = Namespace("cli_api", description="CLI API operations", path="/")
from .dify_cli import cli_api as _plugin
api.add_namespace(cli_api_ns)
__all__ = [
"_plugin",
"api",
"bp",
"cli_api_ns",
]

View File

@ -0,0 +1,192 @@
from flask import abort
from flask_restx import Resource
from pydantic import BaseModel
from controllers.cli_api import cli_api_ns
from controllers.cli_api.dify_cli.wraps import get_cli_user_tenant, plugin_data
from controllers.cli_api.wraps import cli_api_only
from controllers.console.wraps import setup_required
from core.app.entities.app_invoke_entities import InvokeFrom
from core.plugin.backwards_invocation.app import PluginAppBackwardsInvocation
from core.plugin.backwards_invocation.base import BaseBackwardsInvocationResponse
from core.plugin.backwards_invocation.model import PluginModelBackwardsInvocation
from core.plugin.backwards_invocation.tool import PluginToolBackwardsInvocation
from core.plugin.entities.request import (
RequestInvokeApp,
RequestInvokeLLM,
RequestInvokeTool,
RequestRequestUploadFile,
)
from core.sandbox.bash.dify_cli import DifyCliToolConfig
from core.session.cli_api import CliContext
from core.skill.entities import ToolInvocationRequest
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.tool_manager import ToolManager
from graphon.file.helpers import get_signed_file_url_for_plugin
from libs.helper import length_prefixed_response
from models.account import Account
from models.model import EndUser, Tenant
class FetchToolItem(BaseModel):
tool_type: str
tool_provider: str
tool_name: str
credential_id: str | None = None
class FetchToolBatchRequest(BaseModel):
tools: list[FetchToolItem]
@cli_api_ns.route("/invoke/llm")
class CliInvokeLLMApi(Resource):
@cli_api_only
@get_cli_user_tenant
@setup_required
@plugin_data(payload_type=RequestInvokeLLM)
def post(
self,
user_model: Account | EndUser,
tenant_model: Tenant,
payload: RequestInvokeLLM,
cli_context: CliContext,
):
def generator():
response = PluginModelBackwardsInvocation.invoke_llm(user_model.id, tenant_model, payload)
return PluginModelBackwardsInvocation.convert_to_event_stream(response)
return length_prefixed_response(0xF, generator())
@cli_api_ns.route("/invoke/tool")
class CliInvokeToolApi(Resource):
@cli_api_only
@get_cli_user_tenant
@setup_required
@plugin_data(payload_type=RequestInvokeTool)
def post(
self,
user_model: Account | EndUser,
tenant_model: Tenant,
payload: RequestInvokeTool,
cli_context: CliContext,
):
tool_type = ToolProviderType.value_of(payload.tool_type)
request = ToolInvocationRequest(
tool_type=tool_type,
provider=payload.provider,
tool_name=payload.tool,
credential_id=payload.credential_id,
)
if cli_context.tool_access and not cli_context.tool_access.is_allowed(request):
abort(403, description=f"Access denied for tool: {payload.provider}/{payload.tool}")
def generator():
return PluginToolBackwardsInvocation.convert_to_event_stream(
PluginToolBackwardsInvocation.invoke_tool(
tenant_id=tenant_model.id,
user_id=user_model.id,
tool_type=tool_type,
provider=payload.provider,
tool_name=payload.tool,
tool_parameters=payload.tool_parameters,
credential_id=payload.credential_id,
),
)
return length_prefixed_response(0xF, generator())
@cli_api_ns.route("/invoke/app")
class CliInvokeAppApi(Resource):
@cli_api_only
@get_cli_user_tenant
@setup_required
@plugin_data(payload_type=RequestInvokeApp)
def post(
self,
user_model: Account | EndUser,
tenant_model: Tenant,
payload: RequestInvokeApp,
cli_context: CliContext,
):
response = PluginAppBackwardsInvocation.invoke_app(
app_id=payload.app_id,
user_id=user_model.id,
tenant_id=tenant_model.id,
conversation_id=payload.conversation_id,
query=payload.query,
stream=payload.response_mode == "streaming",
inputs=payload.inputs,
files=payload.files,
)
return length_prefixed_response(0xF, PluginAppBackwardsInvocation.convert_to_event_stream(response))
@cli_api_ns.route("/upload/file/request")
class CliUploadFileRequestApi(Resource):
@cli_api_only
@get_cli_user_tenant
@setup_required
@plugin_data(payload_type=RequestRequestUploadFile)
def post(
self,
user_model: Account | EndUser,
tenant_model: Tenant,
payload: RequestRequestUploadFile,
cli_context: CliContext,
):
url = get_signed_file_url_for_plugin(
filename=payload.filename,
mimetype=payload.mimetype,
tenant_id=tenant_model.id,
user_id=user_model.id,
)
return BaseBackwardsInvocationResponse(data={"url": url}).model_dump()
@cli_api_ns.route("/fetch/tools/batch")
class CliFetchToolsBatchApi(Resource):
@cli_api_only
@get_cli_user_tenant
@setup_required
@plugin_data(payload_type=FetchToolBatchRequest)
def post(
self,
user_model: Account | EndUser,
tenant_model: Tenant,
payload: FetchToolBatchRequest,
cli_context: CliContext,
):
tools: list[dict] = []
for item in payload.tools:
provider_type = ToolProviderType.value_of(item.tool_type)
request = ToolInvocationRequest(
tool_type=provider_type,
provider=item.tool_provider,
tool_name=item.tool_name,
credential_id=item.credential_id,
)
if cli_context.tool_access and not cli_context.tool_access.is_allowed(request):
abort(403, description=f"Access denied for tool: {item.tool_provider}/{item.tool_name}")
try:
tool_runtime = ToolManager.get_tool_runtime(
tenant_id=tenant_model.id,
provider_type=provider_type,
provider_id=item.tool_provider,
tool_name=item.tool_name,
invoke_from=InvokeFrom.AGENT,
credential_id=item.credential_id,
)
tool_config = DifyCliToolConfig.create_from_tool(tool_runtime)
tools.append(tool_config.model_dump())
except Exception:
continue
return BaseBackwardsInvocationResponse(data={"tools": tools}).model_dump()
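A hypothetical request body for POST /cli/api/fetch/tools/batch, matching the FetchToolBatchRequest model above ("builtin" and the google/google_search tool are illustrative assumptions; tool_type values follow ToolProviderType):
payload = {
    "tools": [
        {
            "tool_type": "builtin",
            "tool_provider": "google",
            "tool_name": "google_search",
            "credential_id": None,
        }
    ]
}
The response wraps the resolved DifyCliToolConfig dumps as {"data": {"tools": [...]}}; tools that fail to resolve are silently skipped rather than failing the batch.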

View File

@ -0,0 +1,137 @@
from collections.abc import Callable
from functools import wraps
from typing import ParamSpec, TypeVar
from flask import current_app, g, request
from flask_login import user_logged_in
from pydantic import BaseModel
from sqlalchemy.orm import Session
from core.session.cli_api import CliApiSession, CliContext
from extensions.ext_database import db
from libs.login import current_user
from models.account import Tenant
from models.model import DefaultEndUserSessionID, EndUser
P = ParamSpec("P")
R = TypeVar("R")
class TenantUserPayload(BaseModel):
tenant_id: str
user_id: str
def get_user(tenant_id: str, user_id: str | None) -> EndUser:
"""
Get current user
NOTE: user_id is not trusted; it could be maliciously set to any value.
As a result, it can only be treated as an end-user id.
"""
if not user_id:
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID
is_anonymous = user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID
try:
with Session(db.engine) as session:
user_model = None
if is_anonymous:
user_model = (
session.query(EndUser)
.where(
EndUser.session_id == user_id,
EndUser.tenant_id == tenant_id,
)
.first()
)
else:
user_model = (
session.query(EndUser)
.where(
EndUser.id == user_id,
EndUser.tenant_id == tenant_id,
)
.first()
)
if not user_model:
user_model = EndUser(
tenant_id=tenant_id,
type="service_api",
is_anonymous=is_anonymous,
session_id=user_id,
)
session.add(user_model)
session.commit()
session.refresh(user_model)
except Exception:
raise ValueError("user not found")
return user_model
def get_cli_user_tenant(view_func: Callable[P, R]):
@wraps(view_func)
def decorated_view(*args: P.args, **kwargs: P.kwargs):
session: CliApiSession | None = getattr(g, "cli_api_session", None)
if session is None:
raise ValueError("session not found")
user_id = session.user_id
tenant_id = session.tenant_id
cli_context = CliContext.model_validate(session.context)
if not user_id:
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID
try:
tenant_model = (
db.session.query(Tenant)
.where(
Tenant.id == tenant_id,
)
.first()
)
except Exception:
raise ValueError("tenant not found")
if not tenant_model:
raise ValueError("tenant not found")
kwargs["tenant_model"] = tenant_model
kwargs["user_model"] = get_user(tenant_id, user_id)
kwargs["cli_context"] = cli_context
current_app.login_manager._update_request_context_with_user(kwargs["user_model"]) # type: ignore
user_logged_in.send(current_app._get_current_object(), user=current_user) # type: ignore
return view_func(*args, **kwargs)
return decorated_view
def plugin_data(view: Callable[P, R] | None = None, *, payload_type: type[BaseModel]):
def decorator(view_func: Callable[P, R]):
@wraps(view_func)
def decorated_view(*args: P.args, **kwargs: P.kwargs):
try:
data = request.get_json()
except Exception:
raise ValueError("invalid json")
try:
payload = payload_type.model_validate(data)
except Exception as e:
raise ValueError(f"invalid payload: {str(e)}")
kwargs["payload"] = payload
return view_func(*args, **kwargs)
return decorated_view
if view is None:
return decorator
else:
return decorator(view)

View File

@ -0,0 +1,56 @@
import hashlib
import hmac
import time
from collections.abc import Callable
from functools import wraps
from typing import ParamSpec, TypeVar
from flask import abort, g, request
from core.session.cli_api import CliApiSessionManager
P = ParamSpec("P")
R = TypeVar("R")
SIGNATURE_TTL_SECONDS = 300
def _verify_signature(session_secret: str, timestamp: str, body: bytes, signature: str) -> bool:
expected = hmac.new(
session_secret.encode(),
f"{timestamp}.".encode() + body,
hashlib.sha256,
).hexdigest()
return hmac.compare_digest(f"sha256={expected}", signature)
def cli_api_only(view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
session_id = request.headers.get("X-Cli-Api-Session-Id")
timestamp = request.headers.get("X-Cli-Api-Timestamp")
signature = request.headers.get("X-Cli-Api-Signature")
if not session_id or not timestamp or not signature:
abort(401)
try:
ts = int(timestamp)
if abs(time.time() - ts) > SIGNATURE_TTL_SECONDS:
abort(401)
except ValueError:
abort(401)
session = CliApiSessionManager().get(session_id)
if not session:
abort(401)
body = request.get_data()
if not _verify_signature(session.secret, timestamp, body, signature):
abort(401)
g.cli_api_session = session
return view(*args, **kwargs)
return decorated
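A minimal client-side sketch (hypothetical helper, not part of this commit) of how dify-cli could produce the three headers that cli_api_only verifies:
import hashlib
import hmac
import time
def sign_request(session_id: str, session_secret: str, body: bytes) -> dict[str, str]:
    # Mirror of _verify_signature: HMAC-SHA256 over "{timestamp}." + raw body,
    # with the hex digest prefixed by "sha256=".
    timestamp = str(int(time.time()))
    digest = hmac.new(
        session_secret.encode(),
        f"{timestamp}.".encode() + body,
        hashlib.sha256,
    ).hexdigest()
    return {
        "X-Cli-Api-Session-Id": session_id,
        "X-Cli-Api-Timestamp": timestamp,
        "X-Cli-Api-Signature": f"sha256={digest}",
    }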

View File

@ -41,6 +41,7 @@ from . import (
init_validate,
notification,
ping,
# sandbox_files, # TODO: enable after full sandbox integration
setup,
spec,
version,
@ -52,6 +53,7 @@ from .app import (
agent,
annotation,
app,
# app_asset, # TODO: enable after full sandbox integration
audio,
completion,
conversation,
@ -62,6 +64,7 @@ from .app import (
model_config,
ops_trace,
site,
# skills, # TODO: enable after full sandbox integration
statistic,
workflow,
workflow_app_log,
@ -130,6 +133,7 @@ from .workspace import (
model_providers,
models,
plugin,
# sandbox_providers, # TODO: enable after full sandbox integration
tool_providers,
trigger_providers,
workspace,

View File

@ -0,0 +1,333 @@
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from controllers.console import console_ns
from controllers.console.app.error import (
AppAssetNodeNotFoundError,
AppAssetPathConflictError,
)
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from core.app.entities.app_asset_entities import BatchUploadNode
from libs.login import current_account_with_tenant, login_required
from models import App
from models.model import AppMode
from services.app_asset_service import AppAssetService
from services.errors.app_asset import (
AppAssetNodeNotFoundError as ServiceNodeNotFoundError,
)
from services.errors.app_asset import (
AppAssetParentNotFoundError,
)
from services.errors.app_asset import (
AppAssetPathConflictError as ServicePathConflictError,
)
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class CreateFolderPayload(BaseModel):
name: str = Field(..., min_length=1, max_length=255)
parent_id: str | None = None
class CreateFilePayload(BaseModel):
name: str = Field(..., min_length=1, max_length=255)
parent_id: str | None = None
@field_validator("name", mode="before")
@classmethod
def strip_name(cls, v: str) -> str:
return v.strip() if isinstance(v, str) else v
@field_validator("parent_id", mode="before")
@classmethod
def empty_to_none(cls, v: str | None) -> str | None:
return v or None
class GetUploadUrlPayload(BaseModel):
name: str = Field(..., min_length=1, max_length=255)
size: int = Field(..., ge=0)
parent_id: str | None = None
@field_validator("name", mode="before")
@classmethod
def strip_name(cls, v: str) -> str:
return v.strip() if isinstance(v, str) else v
@field_validator("parent_id", mode="before")
@classmethod
def empty_to_none(cls, v: str | None) -> str | None:
return v or None
class BatchUploadPayload(BaseModel):
children: list[BatchUploadNode] = Field(..., min_length=1)
parent_id: str | None = None
@field_validator("parent_id", mode="before")
@classmethod
def empty_to_none(cls, v: str | None) -> str | None:
return v or None
class UpdateFileContentPayload(BaseModel):
content: str
class RenameNodePayload(BaseModel):
name: str = Field(..., min_length=1, max_length=255)
class MoveNodePayload(BaseModel):
parent_id: str | None = None
class ReorderNodePayload(BaseModel):
after_node_id: str | None = Field(default=None, description="Place after this node, None for first position")
def reg(cls: type[BaseModel]) -> None:
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
reg(CreateFolderPayload)
reg(CreateFilePayload)
reg(GetUploadUrlPayload)
reg(BatchUploadNode)
reg(BatchUploadPayload)
reg(UpdateFileContentPayload)
reg(RenameNodePayload)
reg(MoveNodePayload)
reg(ReorderNodePayload)
@console_ns.route("/apps/<string:app_id>/assets/tree")
class AppAssetTreeResource(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def get(self, app_model: App):
current_user, _ = current_account_with_tenant()
tree = AppAssetService.get_asset_tree(app_model, current_user.id)
return {"children": [view.model_dump() for view in tree.transform()]}
@console_ns.route("/apps/<string:app_id>/assets/folders")
class AppAssetFolderResource(Resource):
@console_ns.expect(console_ns.models[CreateFolderPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def post(self, app_model: App):
current_user, _ = current_account_with_tenant()
payload = CreateFolderPayload.model_validate(console_ns.payload or {})
try:
node = AppAssetService.create_folder(app_model, current_user.id, payload.name, payload.parent_id)
return node.model_dump(), 201
except AppAssetParentNotFoundError:
raise AppAssetNodeNotFoundError()
except ServicePathConflictError:
raise AppAssetPathConflictError()
@console_ns.route("/apps/<string:app_id>/assets/files/<string:node_id>")
class AppAssetFileDetailResource(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def get(self, app_model: App, node_id: str):
current_user, _ = current_account_with_tenant()
try:
content = AppAssetService.get_file_content(app_model, current_user.id, node_id)
return {"content": content.decode("utf-8", errors="replace")}
except ServiceNodeNotFoundError:
raise AppAssetNodeNotFoundError()
@console_ns.expect(console_ns.models[UpdateFileContentPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def put(self, app_model: App, node_id: str):
current_user, _ = current_account_with_tenant()
file = request.files.get("file")
if file:
content = file.read()
else:
payload = UpdateFileContentPayload.model_validate(console_ns.payload or {})
content = payload.content.encode("utf-8")
try:
node = AppAssetService.update_file_content(app_model, current_user.id, node_id, content)
return node.model_dump()
except ServiceNodeNotFoundError:
raise AppAssetNodeNotFoundError()
@console_ns.route("/apps/<string:app_id>/assets/nodes/<string:node_id>")
class AppAssetNodeResource(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def delete(self, app_model: App, node_id: str):
current_user, _ = current_account_with_tenant()
try:
AppAssetService.delete_node(app_model, current_user.id, node_id)
return {"result": "success"}, 200
except ServiceNodeNotFoundError:
raise AppAssetNodeNotFoundError()
@console_ns.route("/apps/<string:app_id>/assets/nodes/<string:node_id>/rename")
class AppAssetNodeRenameResource(Resource):
@console_ns.expect(console_ns.models[RenameNodePayload.__name__])
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def post(self, app_model: App, node_id: str):
current_user, _ = current_account_with_tenant()
payload = RenameNodePayload.model_validate(console_ns.payload or {})
try:
node = AppAssetService.rename_node(app_model, current_user.id, node_id, payload.name)
return node.model_dump()
except ServiceNodeNotFoundError:
raise AppAssetNodeNotFoundError()
except ServicePathConflictError:
raise AppAssetPathConflictError()
@console_ns.route("/apps/<string:app_id>/assets/nodes/<string:node_id>/move")
class AppAssetNodeMoveResource(Resource):
@console_ns.expect(console_ns.models[MoveNodePayload.__name__])
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def post(self, app_model: App, node_id: str):
current_user, _ = current_account_with_tenant()
payload = MoveNodePayload.model_validate(console_ns.payload or {})
try:
node = AppAssetService.move_node(app_model, current_user.id, node_id, payload.parent_id)
return node.model_dump()
except ServiceNodeNotFoundError:
raise AppAssetNodeNotFoundError()
except AppAssetParentNotFoundError:
raise AppAssetNodeNotFoundError()
except ServicePathConflictError:
raise AppAssetPathConflictError()
@console_ns.route("/apps/<string:app_id>/assets/nodes/<string:node_id>/reorder")
class AppAssetNodeReorderResource(Resource):
@console_ns.expect(console_ns.models[ReorderNodePayload.__name__])
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def post(self, app_model: App, node_id: str):
current_user, _ = current_account_with_tenant()
payload = ReorderNodePayload.model_validate(console_ns.payload or {})
try:
node = AppAssetService.reorder_node(app_model, current_user.id, node_id, payload.after_node_id)
return node.model_dump()
except ServiceNodeNotFoundError:
raise AppAssetNodeNotFoundError()
@console_ns.route("/apps/<string:app_id>/assets/files/<string:node_id>/download-url")
class AppAssetFileDownloadUrlResource(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def get(self, app_model: App, node_id: str):
current_user, _ = current_account_with_tenant()
try:
download_url = AppAssetService.get_file_download_url(app_model, current_user.id, node_id)
return {"download_url": download_url}
except ServiceNodeNotFoundError:
raise AppAssetNodeNotFoundError()
@console_ns.route("/apps/<string:app_id>/assets/files/upload")
class AppAssetFileUploadUrlResource(Resource):
@console_ns.expect(console_ns.models[GetUploadUrlPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def post(self, app_model: App):
current_user, _ = current_account_with_tenant()
payload = GetUploadUrlPayload.model_validate(console_ns.payload or {})
try:
node, upload_url = AppAssetService.get_file_upload_url(
app_model, current_user.id, payload.name, payload.size, payload.parent_id
)
return {"node": node.model_dump(), "upload_url": upload_url}, 201
except AppAssetParentNotFoundError:
raise AppAssetNodeNotFoundError()
except ServicePathConflictError:
raise AppAssetPathConflictError()
@console_ns.route("/apps/<string:app_id>/assets/batch-upload")
class AppAssetBatchUploadResource(Resource):
@console_ns.expect(console_ns.models[BatchUploadPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def post(self, app_model: App):
"""
Create nodes from tree structure and return upload URLs.
Input:
{
"parent_id": "optional-target-folder-id",
"children": [
{"name": "folder1", "node_type": "folder", "children": [
{"name": "file1.txt", "node_type": "file", "size": 1024}
]},
{"name": "root.txt", "node_type": "file", "size": 512}
]
}
Output:
{
"children": [
{"id": "xxx", "name": "folder1", "node_type": "folder", "children": [
{"id": "yyy", "name": "file1.txt", "node_type": "file", "size": 1024, "upload_url": "..."}
]},
{"id": "zzz", "name": "root.txt", "node_type": "file", "size": 512, "upload_url": "..."}
]
}
"""
current_user, _ = current_account_with_tenant()
payload = BatchUploadPayload.model_validate(console_ns.payload or {})
try:
result_children = AppAssetService.batch_create_from_tree(
app_model,
current_user.id,
payload.children,
parent_id=payload.parent_id,
)
return {"children": [child.model_dump() for child in result_children]}, 201
except AppAssetParentNotFoundError:
raise AppAssetNodeNotFoundError()
except ServicePathConflictError:
raise AppAssetPathConflictError()

View File

@ -0,0 +1,38 @@
from flask import request
from flask_restx import Resource
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, current_account_with_tenant, setup_required
from libs.login import login_required
from models import App
from models.model import AppMode
from services.skill_service import SkillService
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/nodes/llm/skills")
class NodeSkillsApi(Resource):
"""Extract tool dependencies from an LLM node's skill prompts.
The client sends the full node ``data`` object in the request body.
The server builds a ``SkillBundle`` in real time from the current draft
``.md`` assets and resolves transitive tool dependencies; no cached
bundle is used.
"""
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def post(self, app_model: App):
current_user, _ = current_account_with_tenant()
node_data = request.get_json(force=True)
if not isinstance(node_data, dict):
return {"tool_dependencies": []}
tool_deps = SkillService.extract_tool_dependencies(
app=app_model,
node_data=node_data,
user_id=current_user.id,
)
return {"tool_dependencies": [d.model_dump() for d in tool_deps]}

View File

@ -0,0 +1,103 @@
from __future__ import annotations
from fastapi.encoders import jsonable_encoder
from flask import request
from flask_restx import Resource, fields
from pydantic import BaseModel, Field
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from libs.login import current_account_with_tenant, login_required
from services.sandbox.sandbox_file_service import SandboxFileService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class SandboxFileListQuery(BaseModel):
path: str | None = Field(default=None, description="Workspace relative path")
recursive: bool = Field(default=False, description="List recursively")
class SandboxFileDownloadRequest(BaseModel):
path: str = Field(..., description="Workspace relative file path")
console_ns.schema_model(
SandboxFileListQuery.__name__,
SandboxFileListQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
SandboxFileDownloadRequest.__name__,
SandboxFileDownloadRequest.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
SANDBOX_FILE_NODE_FIELDS = {
"path": fields.String,
"is_dir": fields.Boolean,
"size": fields.Raw,
"mtime": fields.Raw,
"extension": fields.String,
}
SANDBOX_FILE_DOWNLOAD_TICKET_FIELDS = {
"download_url": fields.String,
"expires_in": fields.Integer,
"export_id": fields.String,
}
sandbox_file_node_model = console_ns.model("SandboxFileNode", SANDBOX_FILE_NODE_FIELDS)
sandbox_file_download_ticket_model = console_ns.model("SandboxFileDownloadTicket", SANDBOX_FILE_DOWNLOAD_TICKET_FIELDS)
@console_ns.route("/apps/<string:app_id>/sandbox/files")
class SandboxFilesApi(Resource):
"""List sandbox files for the current user.
The sandbox_id is derived from the current user's ID, as each user has
their own sandbox workspace per app.
"""
@setup_required
@login_required
@account_initialization_required
@console_ns.expect(console_ns.models[SandboxFileListQuery.__name__])
@console_ns.marshal_list_with(sandbox_file_node_model)
def get(self, app_id: str):
args = SandboxFileListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore[arg-type]
account, tenant_id = current_account_with_tenant()
sandbox_id = account.id
return jsonable_encoder(
SandboxFileService.list_files(
tenant_id=tenant_id,
app_id=app_id,
sandbox_id=sandbox_id,
path=args.path,
recursive=args.recursive,
)
)
@console_ns.route("/apps/<string:app_id>/sandbox/files/download")
class SandboxFileDownloadApi(Resource):
"""Download a sandbox file for the current user.
The sandbox_id is derived from the current user's ID, as each user has
their own sandbox workspace per app.
"""
@setup_required
@login_required
@account_initialization_required
@console_ns.expect(console_ns.models[SandboxFileDownloadRequest.__name__])
@console_ns.marshal_with(sandbox_file_download_ticket_model)
def post(self, app_id: str):
payload = SandboxFileDownloadRequest.model_validate(console_ns.payload or {})
account, tenant_id = current_account_with_tenant()
sandbox_id = account.id
res = SandboxFileService.download_file(
tenant_id=tenant_id, app_id=app_id, sandbox_id=sandbox_id, path=payload.path
)
return jsonable_encoder(res)

View File

@ -0,0 +1,104 @@
import logging
from flask import request
from flask_restx import Resource, fields
from pydantic import BaseModel
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from graphon.model_runtime.utils.encoders import jsonable_encoder
from libs.login import current_account_with_tenant, login_required
from services.sandbox.sandbox_provider_service import SandboxProviderService
logger = logging.getLogger(__name__)
class SandboxProviderConfigRequest(BaseModel):
config: dict
activate: bool = False
class SandboxProviderActivateRequest(BaseModel):
type: str
@console_ns.route("/workspaces/current/sandbox-providers")
class SandboxProviderListApi(Resource):
@console_ns.doc("list_sandbox_providers")
@console_ns.doc(description="Get list of available sandbox providers with configuration status")
@console_ns.response(200, "Success", fields.List(fields.Raw(description="Sandbox provider information")))
@setup_required
@login_required
@account_initialization_required
def get(self):
_, current_tenant_id = current_account_with_tenant()
providers = SandboxProviderService.list_providers(current_tenant_id)
return jsonable_encoder([p.model_dump() for p in providers])
@console_ns.route("/workspaces/current/sandbox-provider/<string:provider_type>/config")
class SandboxProviderConfigApi(Resource):
@console_ns.doc("save_sandbox_provider_config")
@console_ns.doc(description="Save or update configuration for a sandbox provider")
@console_ns.response(200, "Success")
@setup_required
@login_required
@account_initialization_required
def post(self, provider_type: str):
_, current_tenant_id = current_account_with_tenant()
args = SandboxProviderConfigRequest.model_validate(request.get_json())
try:
result = SandboxProviderService.save_config(
tenant_id=current_tenant_id,
provider_type=provider_type,
config=args.config,
activate=args.activate,
)
return result
except ValueError as e:
return {"message": str(e)}, 400
@console_ns.doc("delete_sandbox_provider_config")
@console_ns.doc(description="Delete configuration for a sandbox provider")
@console_ns.response(200, "Success")
@setup_required
@login_required
@account_initialization_required
def delete(self, provider_type: str):
_, current_tenant_id = current_account_with_tenant()
try:
result = SandboxProviderService.delete_config(
tenant_id=current_tenant_id,
provider_type=provider_type,
)
return result
except ValueError as e:
return {"message": str(e)}, 400
@console_ns.route("/workspaces/current/sandbox-provider/<string:provider_type>/activate")
class SandboxProviderActivateApi(Resource):
"""Activate a sandbox provider."""
@console_ns.doc("activate_sandbox_provider")
@console_ns.doc(description="Activate a sandbox provider for the current workspace")
@console_ns.response(200, "Success")
@setup_required
@login_required
@account_initialization_required
def post(self, provider_type: str):
"""Activate a sandbox provider."""
_, current_tenant_id = current_account_with_tenant()
try:
args = SandboxProviderActivateRequest.model_validate(request.get_json())
result = SandboxProviderService.activate_provider(
tenant_id=current_tenant_id,
provider_type=provider_type,
type=args.type,
)
return result
except ValueError as e:
return {"message": str(e)}, 400

View File

@ -0,0 +1,80 @@
"""Token-based file proxy controller for storage operations.
This controller handles file download and upload operations using opaque UUID tokens.
The token maps to the real storage key in Redis, so the actual storage path is never
exposed in the URL.
Routes:
GET /files/storage-files/{token} - Download a file
PUT /files/storage-files/{token} - Upload a file
The operation type (download/upload) is determined by the ticket stored in Redis,
not by the HTTP method. This ensures a download ticket cannot be used for upload
and vice versa.
"""
from urllib.parse import quote
from flask import Response, request
from flask_restx import Resource
from werkzeug.exceptions import Forbidden, NotFound, RequestEntityTooLarge
from controllers.files import files_ns
from extensions.ext_storage import storage
from services.storage_ticket_service import StorageTicketService
@files_ns.route("/storage-files/<string:token>")
class StorageFilesApi(Resource):
"""Handle file operations through token-based URLs."""
def get(self, token: str):
"""Download a file using a token.
The ticket must have op="download", otherwise returns 403.
"""
ticket = StorageTicketService.get_ticket(token)
if ticket is None:
raise Forbidden("Invalid or expired token")
if ticket.op != "download":
raise Forbidden("This token is not valid for download")
try:
generator = storage.load_stream(ticket.storage_key)
except FileNotFoundError:
raise NotFound("File not found")
filename = ticket.filename or ticket.storage_key.rsplit("/", 1)[-1]
encoded_filename = quote(filename)
return Response(
generator,
mimetype="application/octet-stream",
direct_passthrough=True,
headers={
"Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}",
},
)
def put(self, token: str):
"""Upload a file using a token.
The ticket must have op="upload", otherwise returns 403.
If the request body exceeds max_bytes, returns 413.
"""
ticket = StorageTicketService.get_ticket(token)
if ticket is None:
raise Forbidden("Invalid or expired token")
if ticket.op != "upload":
raise Forbidden("This token is not valid for upload")
content = request.get_data()
if ticket.max_bytes is not None and len(content) > ticket.max_bytes:
raise RequestEntityTooLarge(f"Upload exceeds maximum size of {ticket.max_bytes} bytes")
storage.save(ticket.storage_key, content)
return Response(status=204)
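A hypothetical sandbox-side usage sketch, assuming FILES_API_URL points at this API and that both tokens were issued elsewhere via StorageTicketService:
import requests
BASE = "http://localhost/files/storage-files"
upload_token = "..."  # placeholder: an upload-op ticket token (opaque UUID)
download_token = "..."  # placeholder: a download-op ticket token
# Upload: PUT raw bytes; the server checks the ticket's max_bytes and returns 204.
resp = requests.put(f"{BASE}/{upload_token}", data=b"hello")
assert resp.status_code == 204
# Download: GET streams the file back with a Content-Disposition attachment header.
resp = requests.get(f"{BASE}/{download_token}")
with open("result.bin", "wb") as f:
    f.write(resp.content)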

View File

@ -0,0 +1,22 @@
import logging
from core.sandbox import Sandbox
from graphon.graph_engine.layers.base import GraphEngineLayer
from graphon.graph_events.base import GraphEngineEvent
logger = logging.getLogger(__name__)
class SandboxLayer(GraphEngineLayer):
def __init__(self, sandbox: Sandbox) -> None:
super().__init__()
self._sandbox = sandbox
def on_graph_start(self) -> None:
pass
def on_event(self, event: GraphEngineEvent) -> None:
pass
def on_graph_end(self, error: Exception | None) -> None:
self._sandbox.release()

View File

@ -0,0 +1,13 @@
from .constants import AppAssetsAttrs
from .entities import (
AssetItem,
SkillAsset,
)
from .storage import AssetPaths
__all__ = [
"AppAssetsAttrs",
"AssetItem",
"AssetPaths",
"SkillAsset",
]

View File

@ -0,0 +1,180 @@
"""Unified content accessor for app asset nodes.
Accessor is scoped to a single app (tenant_id + app_id), not a single node.
All methods accept an AppAssetNode parameter to identify the target.
CachedContentAccessor is the primary entry point:
- Reads DB first, misses fall through to S3 with sync backfill.
- Writes go to both DB and S3 (dual-write).
- resolve_items() batch-enriches AssetItem lists with DB-cached content
(extension-agnostic), so callers never need to filter by extension.
- Wraps an internal _StorageAccessor for S3 I/O.
Collaborators:
- services.asset_content_service.AssetContentService (DB layer)
- core.app_assets.storage.AssetPaths (S3 key generation)
- extensions.storage.cached_presign_storage.CachedPresignStorage (S3 I/O)
"""
from __future__ import annotations
import logging
from core.app.entities.app_asset_entities import AppAssetNode
from core.app_assets.entities.assets import AssetItem
from core.app_assets.storage import AssetPaths
from extensions.storage.cached_presign_storage import CachedPresignStorage
from services.asset_content_service import AssetContentService
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# S3-only implementation (internal, used as inner delegate)
# ---------------------------------------------------------------------------
class _StorageAccessor:
"""Reads/writes draft content via object storage (S3) only."""
_storage: CachedPresignStorage
_tenant_id: str
_app_id: str
def __init__(self, storage: CachedPresignStorage, tenant_id: str, app_id: str) -> None:
self._storage = storage
self._tenant_id = tenant_id
self._app_id = app_id
def _key(self, node: AppAssetNode) -> str:
return AssetPaths.draft(self._tenant_id, self._app_id, node.id)
def load(self, node: AppAssetNode) -> bytes:
return self._storage.load_once(self._key(node))
def save(self, node: AppAssetNode, content: bytes) -> None:
self._storage.save(self._key(node), content)
def delete(self, node: AppAssetNode) -> None:
try:
self._storage.delete(self._key(node))
except Exception:
logger.warning("Failed to delete storage key %s", self._key(node), exc_info=True)
# ---------------------------------------------------------------------------
# DB-cached implementation (the public API)
# ---------------------------------------------------------------------------
class CachedContentAccessor:
"""App-level content accessor with DB read-through cache over S3.
Read path: DB first -> miss -> S3 fallback -> sync backfill DB
Write path: DB upsert + S3 save (dual-write)
Delete path: DB delete + S3 delete
bulk_load uses a single SQL query for all nodes, with S3 fallback per miss.
Usage:
accessor = CachedContentAccessor(storage, tenant_id, app_id)
content = accessor.load(node)
accessor.save(node, content)
results = accessor.bulk_load(nodes)
"""
_inner: _StorageAccessor
_tenant_id: str
_app_id: str
def __init__(self, storage: CachedPresignStorage, tenant_id: str, app_id: str) -> None:
self._inner = _StorageAccessor(storage, tenant_id, app_id)
self._tenant_id = tenant_id
self._app_id = app_id
def load(self, node: AppAssetNode) -> bytes:
# 1. Try DB
cached = AssetContentService.get(self._tenant_id, self._app_id, node.id)
if cached is not None:
return cached.encode("utf-8")
# 2. Fallback to S3
data = self._inner.load(node)
# 3. Sync backfill DB
AssetContentService.upsert(
tenant_id=self._tenant_id,
app_id=self._app_id,
node_id=node.id,
content=data.decode("utf-8"),
size=len(data),
)
return data
def bulk_load(self, nodes: list[AppAssetNode]) -> dict[str, bytes]:
"""Single SQL for all nodes, S3 fallback + backfill per miss."""
result: dict[str, bytes] = {}
node_ids = [n.id for n in nodes]
cached = AssetContentService.get_many(self._tenant_id, self._app_id, node_ids)
for node in nodes:
if node.id in cached:
result[node.id] = cached[node.id].encode("utf-8")
else:
# S3 fallback + sync backfill
data = self._inner.load(node)
AssetContentService.upsert(
tenant_id=self._tenant_id,
app_id=self._app_id,
node_id=node.id,
content=data.decode("utf-8"),
size=len(data),
)
result[node.id] = data
return result
def save(self, node: AppAssetNode, content: bytes) -> None:
# Dual-write: DB + S3
AssetContentService.upsert(
tenant_id=self._tenant_id,
app_id=self._app_id,
node_id=node.id,
content=content.decode("utf-8"),
size=len(content),
)
self._inner.save(node, content)
def resolve_items(self, items: list[AssetItem]) -> list[AssetItem]:
"""Batch-enrich asset items with DB-cached content.
Queries by ``asset_id`` only (extension-agnostic). Items without
a DB cache row keep their original *content* value (typically
``None``), so only genuinely cached assets (e.g. ``.md`` skill
documents) get populated.
This eliminates the need for callers to filter by file extension
before deciding whether to read from the DB cache.
"""
if not items:
return items
node_ids = [a.asset_id for a in items]
cached = AssetContentService.get_many(self._tenant_id, self._app_id, node_ids)
if not cached:
return items
return [
AssetItem(
asset_id=a.asset_id,
path=a.path,
file_name=a.file_name,
extension=a.extension,
storage_key=a.storage_key,
content=cached[a.asset_id].encode("utf-8") if a.asset_id in cached else a.content,
)
for a in items
]
def delete(self, node: AppAssetNode) -> None:
AssetContentService.delete(self._tenant_id, self._app_id, node.id)
self._inner.delete(node)

View File

@ -0,0 +1,12 @@
from .base import AssetBuilder, BuildContext
from .file_builder import FileBuilder
from .pipeline import AssetBuildPipeline
from .skill_builder import SkillBuilder
__all__ = [
"AssetBuildPipeline",
"AssetBuilder",
"BuildContext",
"FileBuilder",
"SkillBuilder",
]

View File

@ -0,0 +1,20 @@
from dataclasses import dataclass
from typing import Protocol
from core.app.entities.app_asset_entities import AppAssetFileTree, AppAssetNode
from core.app_assets.entities import AssetItem
@dataclass
class BuildContext:
tenant_id: str
app_id: str
build_id: str
class AssetBuilder(Protocol):
def accept(self, node: AppAssetNode) -> bool: ...
def collect(self, node: AppAssetNode, path: str, ctx: BuildContext) -> None: ...
def build(self, tree: AppAssetFileTree, ctx: BuildContext) -> list[AssetItem]: ...

View File

@ -0,0 +1,30 @@
from core.app.entities.app_asset_entities import AppAssetFileTree, AppAssetNode
from core.app_assets.entities import AssetItem
from core.app_assets.storage import AssetPaths
from .base import BuildContext
class FileBuilder:
_nodes: list[tuple[AppAssetNode, str]]
def __init__(self) -> None:
self._nodes = []
def accept(self, node: AppAssetNode) -> bool:
return True
def collect(self, node: AppAssetNode, path: str, ctx: BuildContext) -> None:
self._nodes.append((node, path))
def build(self, tree: AppAssetFileTree, ctx: BuildContext) -> list[AssetItem]:
return [
AssetItem(
asset_id=node.id,
path=path,
file_name=node.name,
extension=node.extension or "",
storage_key=AssetPaths.draft(ctx.tenant_id, ctx.app_id, node.id),
)
for node, path in self._nodes
]

View File

@ -0,0 +1,27 @@
from core.app.entities.app_asset_entities import AppAssetFileTree
from core.app_assets.entities import AssetItem
from .base import AssetBuilder, BuildContext
class AssetBuildPipeline:
_builders: list[AssetBuilder]
def __init__(self, builders: list[AssetBuilder]) -> None:
self._builders = builders
def build_all(self, tree: AppAssetFileTree, ctx: BuildContext) -> list[AssetItem]:
# 1. Distribute: each node goes to first accepting builder
for node in tree.walk_files():
path = tree.get_path(node.id)
for builder in self._builders:
if builder.accept(node):
builder.collect(node, path, ctx)
break
# 2. Each builder builds its collected nodes
results: list[AssetItem] = []
for builder in self._builders:
results.extend(builder.build(tree, ctx))
return results
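A minimal wiring sketch (assumed call site, not shown in this diff; accessor, tree, and the three IDs are taken as given). Because FileBuilder.accept() always returns True, builder order matters: SkillBuilder must come first so .md skill documents are claimed before the catch-all FileBuilder:
ctx = BuildContext(tenant_id=tenant_id, app_id=app_id, build_id=build_id)
pipeline = AssetBuildPipeline([SkillBuilder(accessor), FileBuilder()])
items = pipeline.build_all(tree, ctx)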

View File

@ -0,0 +1,96 @@
"""Builder that compiles ``.md`` skill documents into resolved content.
The builder reads raw draft content from the DB-backed accessor, parses
each into a ``SkillDocument``, assembles a ``SkillBundle`` (with
transitive tool/file dependency resolution), and returns ``AssetItem``
objects whose *content* field carries the resolved bytes in-process.
The assembled ``SkillBundle`` is persisted via ``SkillManager``
(S3 + Redis) **and** retained on the ``bundle`` property so that
callers (e.g. ``DraftAppAssetsInitializer``) can pass it directly to
``sandbox.attrs`` without a redundant Redis/S3 round-trip.
"""
import json
import logging
from core.app.entities.app_asset_entities import AppAssetFileTree, AppAssetNode
from core.app_assets.accessor import CachedContentAccessor
from core.app_assets.entities import AssetItem
from core.skill.assembler import SkillBundleAssembler
from core.skill.entities.skill_bundle import SkillBundle
from core.skill.entities.skill_document import SkillDocument
from .base import BuildContext
logger = logging.getLogger(__name__)
class SkillBuilder:
_nodes: list[tuple[AppAssetNode, str]]
_accessor: CachedContentAccessor
_bundle: SkillBundle | None
def __init__(self, accessor: CachedContentAccessor) -> None:
self._nodes = []
self._accessor = accessor
self._bundle = None
@property
def bundle(self) -> SkillBundle | None:
"""The ``SkillBundle`` produced by the last ``build()`` call, or *None*."""
return self._bundle
def accept(self, node: AppAssetNode) -> bool:
return node.extension == "md"
def collect(self, node: AppAssetNode, path: str, ctx: BuildContext) -> None:
self._nodes.append((node, path))
def build(self, tree: AppAssetFileTree, ctx: BuildContext) -> list[AssetItem]:
from core.skill.skill_manager import SkillManager
if not self._nodes:
bundle = SkillBundle(assets_id=ctx.build_id, asset_tree=tree)
SkillManager.save_bundle(ctx.tenant_id, ctx.app_id, ctx.build_id, bundle)
self._bundle = bundle
return []
# Batch-load all skill draft content in one DB query (with S3 fallback on miss).
nodes_only = [node for node, _ in self._nodes]
raw_contents = self._accessor.bulk_load(nodes_only)
# Parse documents — skip nodes whose draft content is still the empty
# placeholder written at creation time.
documents: dict[str, SkillDocument] = {}
for node, _ in self._nodes:
try:
raw = raw_contents.get(node.id)
if not raw:
continue
data = {"skill_id": node.id, **json.loads(raw)}
documents[node.id] = SkillDocument.model_validate(data)
except (FileNotFoundError, json.JSONDecodeError, TypeError, ValueError) as e:
logger.exception("Failed to load or parse skill document for node %s", node.id)
raise ValueError(f"Failed to load or parse skill document for node {node.id}") from e
bundle = SkillBundleAssembler(tree).assemble_bundle(documents, ctx.build_id)
SkillManager.save_bundle(ctx.tenant_id, ctx.app_id, ctx.build_id, bundle)
self._bundle = bundle
items: list[AssetItem] = []
for node, path in self._nodes:
skill = bundle.get(node.id)
if skill is None:
continue
items.append(
AssetItem(
asset_id=node.id,
path=path,
file_name=node.name,
extension=node.extension or "",
storage_key="",
content=skill.content.encode("utf-8"),
)
)
return items

View File

@ -0,0 +1,8 @@
from core.app.entities.app_asset_entities import AppAssetFileTree
from libs.attr_map import AttrKey
class AppAssetsAttrs:
# Skill artifact set
FILE_TREE = AttrKey("file_tree", AppAssetFileTree)
APP_ASSETS_ID = AttrKey("app_assets_id", str)

View File

@ -0,0 +1,20 @@
from __future__ import annotations
from core.app.entities.app_asset_entities import AppAssetFileTree, AssetNodeType
from core.app_assets.entities import AssetItem
from core.app_assets.storage import AssetPaths
def tree_to_asset_items(tree: AppAssetFileTree, tenant_id: str, app_id: str) -> list[AssetItem]:
"""Convert AppAssetFileTree to list of AssetItem for packaging."""
return [
AssetItem(
asset_id=node.id,
path=tree.get_path(node.id),
file_name=node.name,
extension=node.extension or "",
storage_key=AssetPaths.draft(tenant_id, app_id, node.id),
)
for node in tree.nodes
if node.node_type == AssetNodeType.FILE
]

View File

@ -0,0 +1,7 @@
from .assets import AssetItem
from .skill import SkillAsset
__all__ = [
"AssetItem",
"SkillAsset",
]

View File

@ -0,0 +1,20 @@
from dataclasses import dataclass, field
@dataclass
class AssetItem:
"""A single asset file produced by the build pipeline.
When *content* is set the payload is available in-process and can be
written directly into a ZIP or uploaded to a sandbox VM without an
extra S3 round-trip. When *content* is ``None`` the caller should
fetch the bytes from *storage_key* (the traditional presigned-URL
path).
"""
asset_id: str
path: str
file_name: str
extension: str
storage_key: str
content: bytes | None = field(default=None, repr=False)

View File

@ -0,0 +1,10 @@
from collections.abc import Mapping
from dataclasses import dataclass, field
from typing import Any
from .assets import AssetItem
@dataclass
class SkillAsset(AssetItem):
metadata: Mapping[str, Any] = field(default_factory=dict)

View File

@ -0,0 +1,68 @@
"""App assets storage key generation.
Provides AssetPaths facade for generating storage keys for app assets.
Storage instances are obtained via AppAssetService.get_storage().
"""
from __future__ import annotations
from uuid import UUID
_BASE = "app_assets"
def _check_uuid(value: str, name: str) -> None:
try:
UUID(value)
except (ValueError, TypeError) as e:
raise ValueError(f"{name} must be a valid UUID") from e
class AssetPaths:
"""Facade for generating app asset storage keys."""
@staticmethod
def draft(tenant_id: str, app_id: str, node_id: str) -> str:
"""app_assets/{tenant}/{app}/draft/{node_id}"""
_check_uuid(tenant_id, "tenant_id")
_check_uuid(app_id, "app_id")
_check_uuid(node_id, "node_id")
return f"{_BASE}/{tenant_id}/{app_id}/draft/{node_id}"
@staticmethod
def build_zip(tenant_id: str, app_id: str, assets_id: str) -> str:
"""app_assets/{tenant}/{app}/artifacts/{assets_id}.zip"""
_check_uuid(tenant_id, "tenant_id")
_check_uuid(app_id, "app_id")
_check_uuid(assets_id, "assets_id")
return f"{_BASE}/{tenant_id}/{app_id}/artifacts/{assets_id}.zip"
@staticmethod
def skill_bundle(tenant_id: str, app_id: str, assets_id: str) -> str:
"""app_assets/{tenant}/{app}/artifacts/{assets_id}/skill_artifact_set.json"""
_check_uuid(tenant_id, "tenant_id")
_check_uuid(app_id, "app_id")
_check_uuid(assets_id, "assets_id")
return f"{_BASE}/{tenant_id}/{app_id}/artifacts/{assets_id}/skill_artifact_set.json"
@staticmethod
def source_zip(tenant_id: str, app_id: str, workflow_id: str) -> str:
"""app_assets/{tenant}/{app}/sources/{workflow_id}.zip"""
_check_uuid(tenant_id, "tenant_id")
_check_uuid(app_id, "app_id")
_check_uuid(workflow_id, "workflow_id")
return f"{_BASE}/{tenant_id}/{app_id}/sources/{workflow_id}.zip"
@staticmethod
def bundle_export(tenant_id: str, app_id: str, export_id: str) -> str:
"""app_assets/{tenant}/{app}/bundle_exports/{export_id}.zip"""
_check_uuid(tenant_id, "tenant_id")
_check_uuid(app_id, "app_id")
_check_uuid(export_id, "export_id")
return f"{_BASE}/{tenant_id}/{app_id}/bundle_exports/{export_id}.zip"
@staticmethod
def bundle_import(tenant_id: str, import_id: str) -> str:
"""app_assets/{tenant}/imports/{import_id}.zip"""
_check_uuid(tenant_id, "tenant_id")
return f"{_BASE}/{tenant_id}/imports/{import_id}.zip"

View File

@ -0,0 +1 @@
# App bundle utilities - manifest-driven import/export handled by AppBundleService

View File

@ -0,0 +1,96 @@
from __future__ import annotations
import importlib
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from .bash.dify_cli import (
DifyCliBinary,
DifyCliConfig,
DifyCliEnvConfig,
DifyCliLocator,
DifyCliToolConfig,
)
from .bash.session import SandboxBashSession
from .builder import SandboxBuilder, VMConfig
from .entities import AppAssets, DifyCli, SandboxProviderApiEntity, SandboxType
from .initializer import (
AsyncSandboxInitializer,
SandboxInitializeContext,
SandboxInitializer,
SyncSandboxInitializer,
)
from .initializer.app_assets_initializer import AppAssetsInitializer
from .initializer.dify_cli_initializer import DifyCliInitializer
from .initializer.draft_app_assets_initializer import DraftAppAssetsInitializer
from .sandbox import Sandbox
from .storage import ArchiveSandboxStorage, SandboxStorage
from .utils.debug import sandbox_debug
from .utils.encryption import create_sandbox_config_encrypter, masked_config
__all__ = [
"AppAssets",
"AppAssetsInitializer",
"ArchiveSandboxStorage",
"AsyncSandboxInitializer",
"DifyCli",
"DifyCliBinary",
"DifyCliConfig",
"DifyCliEnvConfig",
"DifyCliInitializer",
"DifyCliLocator",
"DifyCliToolConfig",
"DraftAppAssetsInitializer",
"Sandbox",
"SandboxBashSession",
"SandboxBuilder",
"SandboxInitializeContext",
"SandboxInitializer",
"SandboxProviderApiEntity",
"SandboxStorage",
"SandboxType",
"SyncSandboxInitializer",
"VMConfig",
"create_sandbox_config_encrypter",
"masked_config",
"sandbox_debug",
]
_LAZY_IMPORTS = {
"AppAssets": ("core.sandbox.entities", "AppAssets"),
"AppAssetsInitializer": ("core.sandbox.initializer.app_assets_initializer", "AppAssetsInitializer"),
"AsyncSandboxInitializer": ("core.sandbox.initializer", "AsyncSandboxInitializer"),
"ArchiveSandboxStorage": ("core.sandbox.storage", "ArchiveSandboxStorage"),
"DifyCli": ("core.sandbox.entities", "DifyCli"),
"DifyCliBinary": ("core.sandbox.bash.dify_cli", "DifyCliBinary"),
"DifyCliConfig": ("core.sandbox.bash.dify_cli", "DifyCliConfig"),
"DifyCliEnvConfig": ("core.sandbox.bash.dify_cli", "DifyCliEnvConfig"),
"DifyCliInitializer": ("core.sandbox.initializer.dify_cli_initializer", "DifyCliInitializer"),
"DifyCliLocator": ("core.sandbox.bash.dify_cli", "DifyCliLocator"),
"DifyCliToolConfig": ("core.sandbox.bash.dify_cli", "DifyCliToolConfig"),
"DraftAppAssetsInitializer": ("core.sandbox.initializer.draft_app_assets_initializer", "DraftAppAssetsInitializer"),
"Sandbox": ("core.sandbox.sandbox", "Sandbox"),
"SandboxBashSession": ("core.sandbox.bash.session", "SandboxBashSession"),
"SandboxBuilder": ("core.sandbox.builder", "SandboxBuilder"),
"SandboxInitializeContext": ("core.sandbox.initializer", "SandboxInitializeContext"),
"SandboxInitializer": ("core.sandbox.initializer", "SandboxInitializer"),
"SandboxManager": ("core.sandbox.manager", "SandboxManager"),
"SandboxProviderApiEntity": ("core.sandbox.entities", "SandboxProviderApiEntity"),
"SandboxStorage": ("core.sandbox.storage", "SandboxStorage"),
"SandboxType": ("core.sandbox.entities", "SandboxType"),
"SyncSandboxInitializer": ("core.sandbox.initializer", "SyncSandboxInitializer"),
"VMConfig": ("core.sandbox.builder", "VMConfig"),
"create_sandbox_config_encrypter": ("core.sandbox.utils.encryption", "create_sandbox_config_encrypter"),
"masked_config": ("core.sandbox.utils.encryption", "masked_config"),
"sandbox_debug": ("core.sandbox.utils.debug", "sandbox_debug"),
}
def __getattr__(name: str):
if name not in _LAZY_IMPORTS:
raise AttributeError(f"module 'core.sandbox' has no attribute {name}")
module_path, attr_name = _LAZY_IMPORTS[name]
module = importlib.import_module(module_path)
value = getattr(module, attr_name)
globals()[name] = value
return value

View File

@ -0,0 +1,15 @@
from .dify_cli import (
DifyCliBinary,
DifyCliConfig,
DifyCliEnvConfig,
DifyCliLocator,
DifyCliToolConfig,
)
__all__ = [
"DifyCliBinary",
"DifyCliConfig",
"DifyCliEnvConfig",
"DifyCliLocator",
"DifyCliToolConfig",
]

View File

@ -0,0 +1,139 @@
from collections.abc import Generator
from typing import Any
from core.sandbox.entities import DifyCli
from core.tools.__base.tool import Tool
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import (
ToolDescription,
ToolEntity,
ToolIdentity,
ToolInvokeMessage,
ToolParameter,
ToolProviderType,
)
from core.virtual_environment.__base.helpers import submit_command, with_connection
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
from ..utils.debug import sandbox_debug
COMMAND_TIMEOUT_SECONDS = 60 * 60 * 2 # 2 hours, can be adjusted based on expected command execution times
# Output truncation settings to avoid overwhelming model context
# 8000 chars ≈ 2000-2700 tokens, safe for models with 8K+ context
MAX_OUTPUT_LENGTH = 8000
TRUNCATE_HEAD_LENGTH = 2500 # Keep beginning for context
TRUNCATE_TAIL_LENGTH = 2500 # Keep end for results/errors
def _truncate_output(output: str, name: str = "output") -> str:
"""Truncate output if it exceeds the maximum length.
Keeps the head and tail of the output to preserve context and final results.
"""
if len(output) <= MAX_OUTPUT_LENGTH:
return output
omitted_length = len(output) - TRUNCATE_HEAD_LENGTH - TRUNCATE_TAIL_LENGTH
head = output[:TRUNCATE_HEAD_LENGTH]
tail = output[-TRUNCATE_TAIL_LENGTH:]
return f"{head}\n\n... [{omitted_length} characters omitted from {name}] ...\n\n{tail}"
class SandboxBashTool(Tool):
def __init__(self, sandbox: VirtualEnvironment, tenant_id: str, tools_path: str) -> None:
self._sandbox = sandbox
self._tools_path = tools_path
entity = ToolEntity(
identity=ToolIdentity(
author="Dify",
name="bash",
label=I18nObject(en_US="Bash", zh_Hans="Bash"),
provider="sandbox",
),
parameters=[
ToolParameter.get_simple_instance(
name="bash",
llm_description="The bash command to execute in current working directory",
typ=ToolParameter.ToolParameterType.STRING,
required=True,
),
],
description=ToolDescription(
human=I18nObject(
en_US="Execute bash commands in current working directory",
),
llm="Execute bash commands in current working directory. "
"Use this tool to run shell commands, scripts, or interact with the system. "
"The command will be executed in the current working directory. "
"IMPORTANT: If you generate any output files (images, documents, etc.) that need to be "
"returned or referenced later, you MUST save them to the 'output/' directory "
"(e.g., 'mkdir -p output && cp result.png output/'). Only files in output/ will be collected.",
),
)
runtime = ToolRuntime(tenant_id=tenant_id)
super().__init__(entity=entity, runtime=runtime)
def tool_provider_type(self) -> ToolProviderType:
return ToolProviderType.BUILT_IN
def _invoke(
self,
user_id: str,
tool_parameters: dict[str, Any],
conversation_id: str | None = None,
app_id: str | None = None,
message_id: str | None = None,
) -> Generator[ToolInvokeMessage, None, None]:
command = tool_parameters.get("bash", "")
if not command:
sandbox_debug("bash_tool", "parameters", tool_parameters)
yield self.create_text_message(
'Error: No command provided. The "bash" parameter is required and must contain '
'the shell command to execute. Example: {"bash": "ls -la"}'
)
return
try:
with with_connection(self._sandbox) as conn:
# Build command with embedded environment variables
env_exports = (
f"export PATH={self._tools_path}:/usr/local/bin:/usr/bin:/bin && "
f"export DIFY_CLI_CONFIG={self._tools_path}/{DifyCli.CONFIG_FILENAME} && "
)
full_command = env_exports + command
cmd_list = ["bash", "-c", full_command]
sandbox_debug("bash_tool", "cmd_list", cmd_list)
future = submit_command(
self._sandbox,
conn,
cmd_list,
)
timeout = COMMAND_TIMEOUT_SECONDS if COMMAND_TIMEOUT_SECONDS > 0 else None
result = future.result(timeout=timeout)
stdout = result.stdout.decode("utf-8", errors="replace") if result.stdout else ""
stderr = result.stderr.decode("utf-8", errors="replace") if result.stderr else ""
# Truncate long outputs to avoid overwhelming the model
stdout = _truncate_output(stdout, "stdout")
stderr = _truncate_output(stderr, "stderr")
output_parts: list[str] = []
if stdout:
output_parts.append(f"\n{stdout}")
if stderr:
output_parts.append(f"\n{stderr}")
yield self.create_text_message("\n".join(output_parts))
except TimeoutError:
yield self.create_text_message(f"Error: Command timed out after {COMMAND_TIMEOUT_SECONDS}s")
except Exception as e:
yield self.create_text_message(f"Error: {e!s}")

View File

@ -0,0 +1,164 @@
from __future__ import annotations
from pathlib import Path
from typing import TYPE_CHECKING, Any
from pydantic import BaseModel, Field
from core.session.cli_api import CliApiSession
from core.skill.entities import ToolDependencies, ToolReference
from core.tools.entities.tool_entities import ToolParameter, ToolProviderType
from core.virtual_environment.__base.entities import Arch, OperatingSystem
from graphon.model_runtime.utils.encoders import jsonable_encoder
from ..entities import DifyCli
if TYPE_CHECKING:
from core.tools.__base.tool import Tool
class DifyCliBinary(BaseModel):
operating_system: OperatingSystem = Field(alias="os")
arch: Arch
path: Path
model_config = {
"populate_by_name": True,
"arbitrary_types_allowed": True,
}
class DifyCliLocator:
def __init__(self, root: str | Path | None = None) -> None:
from configs import dify_config
if root is not None:
self._root = Path(root)
elif dify_config.SANDBOX_DIFY_CLI_ROOT:
self._root = Path(dify_config.SANDBOX_DIFY_CLI_ROOT)
else:
api_root = Path(__file__).resolve().parents[3]
self._root = api_root / "bin"
def resolve(self, operating_system: OperatingSystem, arch: Arch) -> DifyCliBinary:
filename = DifyCli.PATH_PATTERN.format(os=operating_system.value, arch=arch.value)
candidate = self._root / filename
if not candidate.is_file():
raise FileNotFoundError(
f"dify CLI binary not found: {candidate}. Configure SANDBOX_DIFY_CLI_ROOT or ensure the file exists."
)
return DifyCliBinary(os=operating_system, arch=arch, path=candidate)
class DifyCliEnvConfig(BaseModel):
files_url: str
cli_api_url: str
cli_api_session_id: str
cli_api_secret: str
class DifyCliToolConfig(BaseModel):
provider_type: str
enabled: bool = True
identity: dict[str, Any]
description: dict[str, Any]
parameters: list[dict[str, Any]]
@classmethod
def transform_provider_type(cls, tool_provider_type: ToolProviderType) -> str:
        match tool_provider_type:
            case ToolProviderType.BUILT_IN | ToolProviderType.PLUGIN:
                return "builtin"
            case ToolProviderType.MCP | ToolProviderType.WORKFLOW | ToolProviderType.API:
                return tool_provider_type.value
            case _:
                raise ValueError(f"Invalid tool provider type: {tool_provider_type}")
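    # e.g. BUILT_IN and PLUGIN both collapse to "builtin"; MCP/WORKFLOW/API keep
    # their own values (presumably "mcp", "workflow", "api").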
@classmethod
def create_from_tool(cls, tool: Tool) -> DifyCliToolConfig:
return cls(
identity=to_json(tool.entity.identity),
provider_type=cls.transform_provider_type(tool.tool_provider_type()),
description=to_json(tool.entity.description),
parameters=[cls.transform_parameter(parameter) for parameter in tool.entity.parameters],
)
@classmethod
def transform_parameter(cls, parameter: ToolParameter) -> dict[str, Any]:
transformed_parameter = to_json(parameter)
transformed_parameter.pop("input_schema", None)
transformed_parameter.pop("form", None)
match parameter.type:
case (
ToolParameter.ToolParameterType.SYSTEM_FILES
| ToolParameter.ToolParameterType.FILE
| ToolParameter.ToolParameterType.FILES
):
return transformed_parameter
case _:
return transformed_parameter
class DifyCliToolReference(BaseModel):
id: str
tool_type: str
tool_name: str
tool_provider: str
credential_id: str | None = None
default_value: dict[str, Any] | None = None
@classmethod
def create_from_tool_reference(cls, reference: ToolReference) -> DifyCliToolReference:
return cls(
id=reference.uuid,
tool_type=reference.type.value,
tool_name=reference.tool_name,
tool_provider=reference.provider,
credential_id=reference.credential_id,
default_value=reference.configuration.default_values() if reference.configuration else None,
)
class DifyCliConfig(BaseModel):
env: DifyCliEnvConfig
tool_references: list[DifyCliToolReference]
tools: list[DifyCliToolConfig]
@classmethod
def create(
cls,
session: CliApiSession,
tenant_id: str,
tool_deps: ToolDependencies,
) -> DifyCliConfig:
from configs import dify_config
cli_api_url = dify_config.CLI_API_URL
return cls(
env=DifyCliEnvConfig(
files_url=dify_config.FILES_API_URL,
cli_api_url=cli_api_url,
cli_api_session_id=session.id,
cli_api_secret=session.secret,
),
tool_references=[DifyCliToolReference.create_from_tool_reference(ref) for ref in tool_deps.references],
tools=[],
)
def to_json(obj: Any) -> dict[str, Any]:
return jsonable_encoder(obj, exclude_unset=True, exclude_defaults=True, exclude_none=True)
__all__ = [
"DifyCliBinary",
"DifyCliConfig",
"DifyCliEnvConfig",
"DifyCliLocator",
"DifyCliToolConfig",
"DifyCliToolReference",
]

View File

@ -0,0 +1,239 @@
from __future__ import annotations
import json
import logging
import mimetypes
import os
import shlex
from types import TracebackType
from core.sandbox.sandbox import Sandbox
from core.session.cli_api import CliApiSession, CliApiSessionManager, CliContext
from core.skill.entities import ToolAccessPolicy
from core.skill.entities.tool_dependencies import ToolDependencies
from core.tools.signature import sign_tool_file
from core.tools.tool_file_manager import ToolFileManager
from core.virtual_environment.__base.helpers import pipeline
from graphon.file import File, FileTransferMethod, FileType
from ..bash.dify_cli import DifyCliConfig
from ..entities import DifyCli
from .bash_tool import SandboxBashTool
logger = logging.getLogger(__name__)
SANDBOX_READY_TIMEOUT = 60 * 10
# Default output directory for sandbox-generated files
SANDBOX_OUTPUT_DIR = "output"
# Maximum number of files to collect from sandbox output
MAX_OUTPUT_FILES = 50
# Maximum file size to collect (10MB)
MAX_OUTPUT_FILE_SIZE = 10 * 1024 * 1024
class SandboxBashSession:
def __init__(self, *, sandbox: Sandbox, node_id: str, tools: ToolDependencies | None) -> None:
self._sandbox = sandbox
self._node_id = node_id
self._tools = tools
self._bash_tool: SandboxBashTool | None = None
self._cli_api_session: CliApiSession | None = None
self._tenant_id = sandbox.tenant_id
self._user_id = sandbox.user_id
self._app_id = sandbox.app_id
self._assets_id = sandbox.assets_id
def __enter__(self) -> SandboxBashSession:
# Ensure sandbox initialization completes before any bash commands run.
self._sandbox.wait_ready(timeout=SANDBOX_READY_TIMEOUT)
cli = DifyCli(self._sandbox.id)
self._cli_api_session = CliApiSessionManager().create(
tenant_id=self._tenant_id,
user_id=self._user_id,
context=CliContext(tool_access=ToolAccessPolicy.from_dependencies(self._tools)),
)
if self._tools is not None and not self._tools.is_empty():
tools_path = self._setup_node_tools_directory(cli, self._node_id, self._tools, self._cli_api_session)
else:
tools_path = cli.global_tools_path
self._bash_tool = SandboxBashTool(
sandbox=self._sandbox.vm,
tenant_id=self._tenant_id,
tools_path=tools_path,
)
return self
def _setup_node_tools_directory(
self,
cli: DifyCli,
node_id: str,
tools: ToolDependencies,
cli_api_session: CliApiSession,
) -> str:
node_tools_path = cli.node_tools_path(node_id)
config_json = json.dumps(
DifyCliConfig.create(session=cli_api_session, tenant_id=self._tenant_id, tool_deps=tools).model_dump(
mode="json"
),
ensure_ascii=False,
)
config_path = shlex.quote(cli.node_config_path(node_id))
vm = self._sandbox.vm
# Merge mkdir + config write into a single pipeline to reduce round-trips.
(
pipeline(vm)
.add(["mkdir", "-p", cli.global_tools_path], error_message="Failed to create global tools dir")
.add(["mkdir", "-p", node_tools_path], error_message="Failed to create node tools dir")
# Use a quoted heredoc (<<'EOF') so the shell performs no expansion on the
# content — safe regardless of $, `, \, or quotes inside the JSON.
.add(
["sh", "-c", f"cat > {config_path} << '__DIFY_CFG__'\n{config_json}\n__DIFY_CFG__"],
error_message="Failed to write CLI config",
)
.execute(raise_on_error=True)
)
pipeline(vm, cwd=node_tools_path).add(
[cli.bin_path, "init"], error_message="Failed to initialize Dify CLI"
).execute(raise_on_error=True)
logger.info(
"Node %s tools initialized, path=%s, tool_count=%d", node_id, node_tools_path, len(tools.references)
)
return node_tools_path
def __exit__(
self,
exc_type: type[BaseException] | None,
exc: BaseException | None,
tb: TracebackType | None,
) -> bool:
try:
if self._cli_api_session is not None:
CliApiSessionManager().delete(self._cli_api_session.id)
logger.debug("Cleaned up SandboxSession session_id=%s", self._cli_api_session.id)
self._cli_api_session = None
except Exception:
logger.exception("Failed to cleanup SandboxSession")
return False
@property
def bash_tool(self) -> SandboxBashTool:
if self._bash_tool is None:
raise RuntimeError("SandboxSession is not initialized")
return self._bash_tool
def collect_output_files(self, output_dir: str = SANDBOX_OUTPUT_DIR) -> list[File]:
"""
Collect files from sandbox output directory and save them as ToolFiles.
Scans the specified output directory in sandbox, downloads each file,
saves it as a ToolFile, and returns a list of File objects. The File
objects will have valid tool_file_id that can be referenced by subsequent
nodes via structured output.
Args:
output_dir: Directory path in sandbox to scan for output files.
Defaults to "output" (relative to workspace).
Returns:
List of File objects representing the collected files.
"""
vm = self._sandbox.vm
collected_files: list[File] = []
try:
file_states = vm.list_files(output_dir, limit=MAX_OUTPUT_FILES)
except Exception as exc:
# Output directory may not exist if no files were generated
logger.debug("Failed to list sandbox output files in %s: %s", output_dir, exc)
return collected_files
tool_file_manager = ToolFileManager()
for file_state in file_states:
# Skip files that are too large
if file_state.size > MAX_OUTPUT_FILE_SIZE:
logger.warning(
"Skipping sandbox output file %s: size %d exceeds limit %d",
file_state.path,
file_state.size,
MAX_OUTPUT_FILE_SIZE,
)
continue
try:
# file_state.path is already relative to working_path (e.g., "output/file.png")
file_content = vm.download_file(file_state.path)
file_binary = file_content.getvalue()
# Determine mime type from extension
filename = os.path.basename(file_state.path)
mime_type, _ = mimetypes.guess_type(filename)
if not mime_type:
mime_type = "application/octet-stream"
# Save as ToolFile
tool_file = tool_file_manager.create_file_by_raw(
user_id=self._user_id,
tenant_id=self._tenant_id,
conversation_id=None,
file_binary=file_binary,
mimetype=mime_type,
filename=filename,
)
# Determine file type from mime type
file_type = _get_file_type_from_mime(mime_type)
extension = os.path.splitext(filename)[1] if "." in filename else ".bin"
url = sign_tool_file(tool_file.id, extension)
# Create File object with tool_file_id as related_id
file_obj = File(
id=tool_file.id, # Use tool_file_id as the File id for easy reference
tenant_id=self._tenant_id,
type=file_type,
transfer_method=FileTransferMethod.TOOL_FILE,
filename=filename,
extension=extension,
mime_type=mime_type,
size=len(file_binary),
related_id=tool_file.id,
url=url,
storage_key=tool_file.file_key,
)
collected_files.append(file_obj)
logger.info(
"Collected sandbox output file: %s -> tool_file_id=%s",
file_state.path,
tool_file.id,
)
except Exception as exc:
logger.warning("Failed to collect sandbox output file %s: %s", file_state.path, exc)
continue
logger.info(
"Collected %d files from sandbox output directory %s",
len(collected_files),
output_dir,
)
return collected_files
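    # Usage sketch (hypothetical call site: `sandbox` comes from SandboxBuilder,
    # `deps` is a ToolDependencies bundle):
    #
    #     with SandboxBashSession(sandbox=sandbox, node_id="node-1", tools=deps) as s:
    #         msgs = s.bash_tool.invoke("user-1", {"bash": "echo hi > output/hi.txt"})
    #         files = s.collect_output_files()  # File objects for output/hi.txt etc.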
def _get_file_type_from_mime(mime_type: str) -> FileType:
"""Determine FileType from mime type."""
if mime_type.startswith("image/"):
return FileType.IMAGE
elif mime_type.startswith("video/"):
return FileType.VIDEO
elif mime_type.startswith("audio/"):
return FileType.AUDIO
elif "text" in mime_type or "pdf" in mime_type:
return FileType.DOCUMENT
else:
return FileType.CUSTOM

222
api/core/sandbox/builder.py Normal file
View File

@ -0,0 +1,222 @@
from __future__ import annotations
import logging
import threading
from collections.abc import Mapping, Sequence
from contextlib import nullcontext
from typing import TYPE_CHECKING, Any, cast
from flask import Flask, current_app, has_app_context
from core.entities.provider_entities import BasicProviderConfig
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
from .entities.sandbox_type import SandboxType
from .initializer import AsyncSandboxInitializer, SandboxInitializeContext, SandboxInitializer, SyncSandboxInitializer
from .sandbox import Sandbox
if TYPE_CHECKING:
from .storage.sandbox_storage import SandboxStorage
logger = logging.getLogger(__name__)
def _get_sandbox_class(sandbox_type: SandboxType) -> type[VirtualEnvironment]:
match sandbox_type:
case SandboxType.DOCKER:
from core.virtual_environment.providers.docker_daemon_sandbox import DockerDaemonEnvironment
return DockerDaemonEnvironment
case SandboxType.E2B:
from core.virtual_environment.providers.e2b_sandbox import E2BEnvironment
return E2BEnvironment
case SandboxType.LOCAL:
from core.virtual_environment.providers.local_without_isolation import LocalVirtualEnvironment
return LocalVirtualEnvironment
case SandboxType.SSH:
from core.virtual_environment.providers.ssh_sandbox import SSHSandboxEnvironment
return SSHSandboxEnvironment
case SandboxType.AWS_CODE_INTERPRETER:
from core.virtual_environment.providers.aws_code_interpreter_sandbox import (
AWSCodeInterpreterEnvironment,
)
return AWSCodeInterpreterEnvironment
case _:
raise ValueError(f"Unsupported sandbox type: {sandbox_type}")
class SandboxBuilder:
_tenant_id: str
_sandbox_type: SandboxType
_user_id: str | None
_app_id: str | None
_options: dict[str, Any]
_environments: dict[str, str]
_initializers: list[SandboxInitializer]
_storage: SandboxStorage | None
_assets_id: str | None
def __init__(self, tenant_id: str, sandbox_type: SandboxType) -> None:
self._tenant_id = tenant_id
self._sandbox_type = sandbox_type
self._user_id = None
self._app_id = None
self._options = {}
self._environments = {}
self._initializers = []
self._storage = None
self._assets_id = None
def user(self, user_id: str) -> SandboxBuilder:
self._user_id = user_id
return self
def app(self, app_id: str) -> SandboxBuilder:
self._app_id = app_id
return self
def options(self, options: Mapping[str, Any]) -> SandboxBuilder:
self._options = dict(options)
return self
def environments(self, environments: Mapping[str, str]) -> SandboxBuilder:
self._environments = dict(environments)
return self
def initializer(self, initializer: SandboxInitializer) -> SandboxBuilder:
self._initializers.append(initializer)
return self
def initializers(self, initializers: Sequence[SandboxInitializer]) -> SandboxBuilder:
self._initializers.extend(initializers)
return self
def storage(self, storage: SandboxStorage, assets_id: str) -> SandboxBuilder:
self._storage = storage
self._assets_id = assets_id
return self
def build(self) -> Sandbox:
"""Create a sandbox and start background initialization.
The builder is responsible for cleaning up any VM or sandbox that was
successfully created if a later setup step fails.
"""
if self._storage is None:
raise ValueError("storage is required, call .storage() before .build()")
if self._assets_id is None:
raise ValueError("assets_id is required, call .storage() before .build()")
if self._user_id is None:
raise ValueError("user_id is required, call .user() before .build()")
if self._app_id is None:
raise ValueError("app_id is required, call .app() before .build()")
ctx = SandboxInitializeContext(
tenant_id=self._tenant_id,
app_id=self._app_id,
assets_id=self._assets_id,
user_id=self._user_id,
)
vm: VirtualEnvironment | None = None
sandbox: Sandbox | None = None
try:
vm_class = _get_sandbox_class(self._sandbox_type)
vm = vm_class(
tenant_id=self._tenant_id,
options=self._options,
environments=self._environments,
user_id=self._user_id,
)
            vm.open_environment()
sandbox = Sandbox(
vm=vm,
storage=self._storage,
tenant_id=self._tenant_id,
user_id=self._user_id,
app_id=self._app_id,
assets_id=self._assets_id,
)
for init in self._initializers:
if isinstance(init, SyncSandboxInitializer):
init.initialize(sandbox, ctx)
except Exception as exc:
logger.exception(
"Failed to initialize sandbox synchronously: tenant_id=%s, app_id=%s", self._tenant_id, self._app_id
)
if sandbox is not None:
sandbox.release()
elif vm is not None:
try:
vm.release_environment()
except Exception:
logger.exception("Failed to release sandbox VM during builder cleanup")
raise RuntimeError("Sandbox initialization failed") from exc
# Run sandbox setup asynchronously so workflow execution can proceed.
# Capture the Flask app before starting the thread for database access.
flask_app: Flask | None = cast(Any, current_app)._get_current_object() if has_app_context() else None
_sandbox: Sandbox = sandbox
def initialize() -> None:
try:
app_context = flask_app.app_context() if flask_app is not None else nullcontext()
with app_context:
for init in self._initializers:
if not isinstance(init, AsyncSandboxInitializer):
continue
if _sandbox.is_cancelled():
return
init.initialize(_sandbox, ctx)
if _sandbox.is_cancelled():
return
_sandbox.mount()
_sandbox.mark_ready()
except Exception as exc:
try:
logger.exception(
"Failed to initialize sandbox: tenant_id=%s, app_id=%s", self._tenant_id, self._app_id
)
_sandbox.release()
_sandbox.mark_failed(exc)
except Exception:
logger.exception(
"Failed to mark sandbox initialization failure: tenant_id=%s, app_id=%s",
self._tenant_id,
self._app_id,
)
# Background init completes or signals failure via sandbox state.
try:
threading.Thread(target=initialize, daemon=True).start()
except Exception:
logger.exception(
"Failed to start sandbox initialization thread: tenant_id=%s, app_id=%s",
self._tenant_id,
self._app_id,
)
sandbox.release()
raise RuntimeError("Sandbox initialization failed")
return sandbox
@staticmethod
def validate(vm_type: SandboxType, options: Mapping[str, Any]) -> None:
vm_class = _get_sandbox_class(vm_type)
vm_class.validate(options)
@classmethod
def draft_id(cls, user_id: str) -> str:
return user_id
class VMConfig:
@staticmethod
def get_schema(vm_type: SandboxType) -> list[BasicProviderConfig]:
return _get_sandbox_class(vm_type).get_config_schema()
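# Usage sketch (hypothetical IDs; `storage` is a SandboxStorage built elsewhere):
#
#     sandbox = (
#         SandboxBuilder(tenant_id="t-1", sandbox_type=SandboxType.DOCKER)
#         .user("u-1")
#         .app("a-1")
#         .storage(storage, assets_id="assets-1")
#         .build()
#     )
#     sandbox.wait_ready(timeout=600)  # blocks until async initializers finish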

View File

@ -0,0 +1,13 @@
from .config import AppAssets, DifyCli
from .files import SandboxFileDownloadTicket, SandboxFileNode
from .providers import SandboxProviderApiEntity
from .sandbox_type import SandboxType
__all__ = [
"AppAssets",
"DifyCli",
"SandboxFileDownloadTicket",
"SandboxFileNode",
"SandboxProviderApiEntity",
"SandboxType",
]

View File

@ -0,0 +1,58 @@
from typing import Final
class DifyCli:
"""Per-sandbox Dify CLI paths, namespaced under ``/tmp/.dify/{env_id}``.
Every sandbox environment gets its own directory tree so that
concurrent sessions on the same host (e.g. SSH provider) never
collide on config files or CLI binaries.
Class-level constants (``CONFIG_FILENAME``, ``PATH_PATTERN``) are
safe to share; all path attributes are instance-level and derived
from the ``env_id`` passed at construction time.
"""
# --- class-level constants (no path component) ---
CONFIG_FILENAME: Final[str] = ".dify_cli.json"
PATH_PATTERN: Final[str] = "dify-cli-{os}-{arch}"
# --- instance attributes ---
root: str
bin_dir: str
bin_path: str
tools_root: str
global_tools_path: str
global_config_path: str
def __init__(self, env_id: str) -> None:
self.root = f"/tmp/.dify/{env_id}"
self.bin_dir = f"{self.root}/bin"
self.bin_path = f"{self.bin_dir}/dify"
self.tools_root = f"{self.root}/tools"
self.global_tools_path = f"{self.root}/tools/global"
self.global_config_path = f"{self.global_tools_path}/{DifyCli.CONFIG_FILENAME}"
def node_tools_path(self, node_id: str) -> str:
return f"{self.tools_root}/{node_id}"
def node_config_path(self, node_id: str) -> str:
return f"{self.node_tools_path(node_id)}/{DifyCli.CONFIG_FILENAME}"
class AppAssets:
"""App Assets constants.
    ``PATH`` is a relative path that each provider resolves against its
    own, already-isolated workspace root. ``zip_path`` is an absolute
temp path and must be namespaced per environment to avoid collisions.
"""
PATH: Final[str] = "skills"
root: str
zip_path: str
def __init__(self, env_id: str) -> None:
self.root = f"/tmp/.dify/{env_id}"
self.zip_path = f"{self.root}/assets.zip"

View File

@ -0,0 +1,19 @@
from __future__ import annotations
from dataclasses import dataclass
@dataclass(frozen=True)
class SandboxFileNode:
path: str
is_dir: bool
size: int | None
mtime: int | None
extension: str | None
@dataclass(frozen=True)
class SandboxFileDownloadTicket:
download_url: str
expires_in: int
export_id: str

View File

@ -0,0 +1,21 @@
from collections.abc import Mapping
from typing import Any
from pydantic import BaseModel, Field
class SandboxProviderApiEntity(BaseModel):
provider_type: str = Field(..., description="Provider type identifier")
is_system_configured: bool = Field(default=False)
is_tenant_configured: bool = Field(default=False)
is_active: bool = Field(default=False)
config: Mapping[str, Any] = Field(default_factory=dict)
config_schema: list[dict[str, Any]] = Field(default_factory=list)
class SandboxProviderEntity(BaseModel):
id: str = Field(..., description="Provider identifier")
provider_type: str = Field(..., description="Provider type identifier")
is_active: bool = Field(default=False)
config: Mapping[str, Any] = Field(default_factory=dict)
config_schema: list[dict[str, Any]] = Field(default_factory=list)

View File

@ -0,0 +1,18 @@
from enum import StrEnum
from configs import dify_config
class SandboxType(StrEnum):
DOCKER = "docker"
E2B = "e2b"
LOCAL = "local"
SSH = "ssh"
AWS_CODE_INTERPRETER = "aws_code_interpreter"
@classmethod
def get_all(cls) -> list[str]:
if dify_config.EDITION == "SELF_HOSTED":
return [p.value for p in cls]
else:
return [p.value for p in cls if p != SandboxType.LOCAL]

View File

@ -0,0 +1,8 @@
from .base import AsyncSandboxInitializer, SandboxInitializeContext, SandboxInitializer, SyncSandboxInitializer
__all__ = [
"AsyncSandboxInitializer",
"SandboxInitializeContext",
"SandboxInitializer",
"SyncSandboxInitializer",
]

View File

@ -0,0 +1,19 @@
import logging
from core.app_assets.constants import AppAssetsAttrs
from core.sandbox.sandbox import Sandbox
from services.app_asset_package_service import AppAssetPackageService
from .base import SandboxInitializeContext, SyncSandboxInitializer
logger = logging.getLogger(__name__)
APP_ASSETS_DOWNLOAD_TIMEOUT = 60 * 10
class AppAssetAttrsInitializer(SyncSandboxInitializer):
def initialize(self, sandbox: Sandbox, ctx: SandboxInitializeContext) -> None:
        # Load the published app assets and record the asset tree on sandbox attrs.
app_assets = AppAssetPackageService.get_tenant_app_assets(ctx.tenant_id, ctx.assets_id)
sandbox.attrs.set(AppAssetsAttrs.FILE_TREE, app_assets.asset_tree)
sandbox.attrs.set(AppAssetsAttrs.APP_ASSETS_ID, ctx.assets_id)

View File

@ -0,0 +1,59 @@
import logging
from core.app_assets.storage import AssetPaths
from core.sandbox.sandbox import Sandbox
from core.virtual_environment.__base.helpers import pipeline
from ..entities import AppAssets
from .base import AsyncSandboxInitializer, SandboxInitializeContext
logger = logging.getLogger(__name__)
APP_ASSETS_DOWNLOAD_TIMEOUT = 60 * 10
class AppAssetsInitializer(AsyncSandboxInitializer):
def initialize(self, sandbox: Sandbox, ctx: SandboxInitializeContext) -> None:
from services.app_asset_service import AppAssetService
# Load published app assets and unzip the artifact bundle.
vm = sandbox.vm
assets = AppAssets(sandbox.id)
asset_storage = AppAssetService.get_storage()
key = AssetPaths.build_zip(ctx.tenant_id, ctx.app_id, ctx.assets_id)
download_url = asset_storage.get_download_url(key)
(
pipeline(vm)
.add(
["mkdir", "-p", assets.root],
error_message="Failed to create assets temp directory",
)
.add(
["sh", "-c", 'curl -fsSL "$1" -o "$2"', "sh", download_url, assets.zip_path],
error_message="Failed to download assets zip",
)
# Create the assets directory first to ensure it exists even if zip is empty
.add(
["mkdir", "-p", AppAssets.PATH],
error_message="Failed to create assets directory",
)
            # Silence unzip errors and tolerate exit code 1 (unzip returns 1 for an empty zip)
# FIXME(Mairuis): should use a more robust way to check if the zip is empty
.add(
["sh", "-c", 'unzip "$1" -d "$2" 2>/dev/null || [ $? -eq 1 ]', "sh", assets.zip_path, AppAssets.PATH],
error_message="Failed to unzip assets",
)
# Ensure directories have execute permission for traversal and files are readable
.add(
["sh", "-c", 'chmod -R u+rwX,go+rX "$1"', "sh", AppAssets.PATH],
error_message="Failed to set permissions on assets",
)
.execute(timeout=APP_ASSETS_DOWNLOAD_TIMEOUT, raise_on_error=True)
)
logger.info(
"App assets initialized for app_id=%s, published_id=%s",
ctx.app_id,
ctx.assets_id,
)

View File

@ -0,0 +1,46 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import TYPE_CHECKING
from core.sandbox.sandbox import Sandbox
if TYPE_CHECKING:
from core.app_assets.entities.assets import AssetItem
@dataclass
class SandboxInitializeContext:
"""Shared identity context passed to every ``SandboxInitializer``.
Carries the common identity fields that virtually every initializer
needs, plus optional artefact slots that sync initializers populate
for async initializers to consume.
Identity fields are immutable by convention; artefact slots are
written at most once during the sync phase and read during the
async phase.
"""
tenant_id: str
app_id: str
assets_id: str
user_id: str
# Populated by DraftAppAssetsInitializer (sync) for
# DraftAppAssetsDownloader (async) to download into the VM.
built_assets: list[AssetItem] | None = field(default=None)
class SandboxInitializer(ABC):
@abstractmethod
def initialize(self, sandbox: Sandbox, ctx: SandboxInitializeContext) -> None: ...
class SyncSandboxInitializer(SandboxInitializer):
"""Marker class for initializers that must run before async setup."""
class AsyncSandboxInitializer(SandboxInitializer):
"""Marker class for initializers that can run in the background."""

View File

@ -0,0 +1,76 @@
from __future__ import annotations
import json
import logging
from io import BytesIO
from pathlib import Path
from core.sandbox.sandbox import Sandbox
from core.session.cli_api import CliApiSessionManager, CliContext
from core.skill.constants import SkillAttrs
from core.skill.entities import ToolAccessPolicy
from core.virtual_environment.__base.helpers import pipeline
from ..bash.dify_cli import DifyCliConfig, DifyCliLocator
from ..entities import DifyCli
from .base import AsyncSandboxInitializer, SandboxInitializeContext
logger = logging.getLogger(__name__)
class DifyCliInitializer(AsyncSandboxInitializer):
def __init__(self, cli_root: str | Path | None = None) -> None:
self._locator = DifyCliLocator(root=cli_root)
def initialize(self, sandbox: Sandbox, ctx: SandboxInitializeContext) -> None:
vm = sandbox.vm
cli = DifyCli(sandbox.id)
        # FIXME(Mairuis): resolving the CLI binary should be made more robust.
binary = self._locator.resolve(vm.metadata.os, vm.metadata.arch)
pipeline(vm).add(["mkdir", "-p", cli.bin_dir], error_message="Failed to create dify CLI directory").execute(
raise_on_error=True
)
vm.upload_file(cli.bin_path, BytesIO(binary.path.read_bytes()))
pipeline(vm).add(
[
"sh",
"-c",
f"cat '{cli.bin_path}' > '{cli.bin_path}.tmp' && "
f"mv '{cli.bin_path}.tmp' '{cli.bin_path}' && "
f"chmod +x '{cli.bin_path}'",
],
error_message="Failed to mark dify CLI as executable",
).execute(raise_on_error=True)
logger.info("Dify CLI uploaded to sandbox, path=%s", cli.bin_path)
bundle = sandbox.attrs.get(SkillAttrs.BUNDLE)
if bundle is None or bundle.get_tool_dependencies().is_empty():
logger.info("No tools found in bundle for assets_id=%s", ctx.assets_id)
return
global_cli_session = CliApiSessionManager().create(
tenant_id=ctx.tenant_id,
user_id=ctx.user_id,
context=CliContext(tool_access=ToolAccessPolicy.from_dependencies(bundle.get_tool_dependencies())),
)
pipeline(vm).add(
["mkdir", "-p", cli.global_tools_path], error_message="Failed to create global tools dir"
).execute(raise_on_error=True)
config = DifyCliConfig.create(global_cli_session, ctx.tenant_id, bundle.get_tool_dependencies())
config_json = json.dumps(config.model_dump(mode="json"), ensure_ascii=False)
config_path = cli.global_config_path
vm.upload_file(config_path, BytesIO(config_json.encode("utf-8")))
pipeline(vm, cwd=cli.global_tools_path).add(
[cli.bin_path, "init"], error_message="Failed to initialize Dify CLI"
).execute(raise_on_error=True)
logger.info("Global tools initialized, path=%s, tool_count=%d", cli.global_tools_path, len(self._tools))

View File

@ -0,0 +1,89 @@
"""Synchronous initializer that compiles draft app assets.
Unlike ``AppAssetsInitializer`` (which downloads a pre-built ZIP for
published assets), this initializer runs the build pipeline on the fly
so that ``.md`` skill documents are compiled and their resolved content
is embedded directly into the download script, avoiding the S3
round-trip that was previously required for resolved keys.
Execution order:
``DraftAppAssetsInitializer`` (sync) compiles assets and publishes
the ``SkillBundle`` to ``sandbox.attrs`` in-memory, so the
downstream ``SkillInitializer`` can skip the Redis/S3 round-trip.
``DraftAppAssetsDownloader`` (async) then pushes the compiled
artefacts into the sandbox VM in the background.
"""
import logging
from core.app_assets.builder.base import BuildContext
from core.app_assets.builder.file_builder import FileBuilder
from core.app_assets.builder.pipeline import AssetBuildPipeline
from core.app_assets.builder.skill_builder import SkillBuilder
from core.app_assets.constants import AppAssetsAttrs
from core.sandbox.entities import AppAssets
from core.sandbox.sandbox import Sandbox
from core.sandbox.services import AssetDownloadService
from core.skill import SkillAttrs
from core.virtual_environment.__base.helpers import pipeline
from services.app_asset_service import AppAssetService
from .base import AsyncSandboxInitializer, SandboxInitializeContext, SyncSandboxInitializer
logger = logging.getLogger(__name__)
class DraftAppAssetsInitializer(SyncSandboxInitializer):
"""Compile draft assets and publish the ``SkillBundle`` to attrs.
The build pipeline compiles ``.md`` skill files in-process.
The resulting ``SkillBundle`` is persisted to Redis/S3 (by
``SkillBuilder``) **and** written to ``sandbox.attrs[BUNDLE]``
so that ``SkillInitializer`` can read it without a round-trip.
Built asset items are stored on ``ctx.built_assets`` for the
async ``DraftAppAssetsDownloader`` to consume.
"""
def initialize(self, sandbox: Sandbox, ctx: SandboxInitializeContext) -> None:
tree = sandbox.attrs.get(AppAssetsAttrs.FILE_TREE)
# --- 1. Run the build pipeline (SkillBuilder compiles .md inline) ---
accessor = AppAssetService.get_accessor(ctx.tenant_id, ctx.app_id)
skill_builder = SkillBuilder(accessor=accessor)
build_pipeline = AssetBuildPipeline([skill_builder, FileBuilder()])
build_ctx = BuildContext(tenant_id=ctx.tenant_id, app_id=ctx.app_id, build_id=ctx.assets_id)
built_assets = build_pipeline.build_all(tree, build_ctx)
ctx.built_assets = built_assets
# Publish the in-memory bundle so SkillInitializer skips Redis/S3.
if skill_builder.bundle is not None:
sandbox.attrs.set(SkillAttrs.BUNDLE, skill_builder.bundle)
class DraftAppAssetsDownloader(AsyncSandboxInitializer):
"""Download the compiled assets into the sandbox VM.
The download script is generated by ``DraftAppAssetsInitializer`` and
includes inline base64 content for compiled skills, as well as
presigned URLs for other files.
"""
_TIMEOUT = 600 # 10 minutes
def initialize(self, sandbox: Sandbox, ctx: SandboxInitializeContext) -> None:
if not ctx.built_assets:
logger.debug("No built assets found for assets_id=%s", ctx.assets_id)
return
download_items = AppAssetService.to_download_items(ctx.built_assets)
script = AssetDownloadService.build_download_script(download_items, AppAssets.PATH)
pipeline(sandbox.vm).add(
["sh", "-c", script],
error_message="Failed to download draft assets",
).execute(timeout=self._TIMEOUT, raise_on_error=True)
logger.info(
"Draft app assets initialized for app_id=%s, assets_id=%s",
ctx.app_id,
ctx.assets_id,
)

View File

@ -0,0 +1,34 @@
from __future__ import annotations
import logging
from core.sandbox.sandbox import Sandbox
from core.skill import SkillAttrs
from core.skill.skill_manager import SkillManager
from .base import SandboxInitializeContext, SyncSandboxInitializer
logger = logging.getLogger(__name__)
class SkillInitializer(SyncSandboxInitializer):
"""Ensure ``sandbox.attrs[BUNDLE]`` is populated for downstream consumers.
In the draft path ``DraftAppAssetsInitializer`` already sets the
bundle on attrs from the in-memory build result, so this initializer
becomes a no-op. In the published path no prior initializer sets
it, so we fall back to ``SkillManager.load_bundle()`` (Redis/S3).
"""
def initialize(self, sandbox: Sandbox, ctx: SandboxInitializeContext) -> None:
# Draft path: bundle already populated by DraftAppAssetsInitializer.
if sandbox.attrs.has(SkillAttrs.BUNDLE):
return
# Published path: load from Redis/S3.
bundle = SkillManager.load_bundle(
ctx.tenant_id,
ctx.app_id,
ctx.assets_id,
)
sandbox.attrs.set(SkillAttrs.BUNDLE, bundle)

View File

@ -0,0 +1,11 @@
from core.sandbox.inspector.archive_source import SandboxFileArchiveSource
from core.sandbox.inspector.base import SandboxFileSource
from core.sandbox.inspector.browser import SandboxFileBrowser
from core.sandbox.inspector.runtime_source import SandboxFileRuntimeSource
__all__ = [
"SandboxFileArchiveSource",
"SandboxFileBrowser",
"SandboxFileRuntimeSource",
"SandboxFileSource",
]

View File

@ -0,0 +1,169 @@
from __future__ import annotations
import logging
import os
from typing import TYPE_CHECKING
from uuid import uuid4
from core.sandbox.entities.files import SandboxFileDownloadTicket, SandboxFileNode
from core.sandbox.inspector.base import SandboxFileSource
from core.sandbox.inspector.script_utils import (
build_detect_kind_command,
build_list_command,
build_upload_command,
guess_content_type,
parse_kind_output,
parse_list_output,
)
from core.sandbox.storage import SandboxFilePaths
from core.virtual_environment.__base.exec import PipelineExecutionError
from core.virtual_environment.__base.helpers import pipeline
from extensions.ext_storage import storage
if TYPE_CHECKING:
from core.zip_sandbox import ZipSandbox
logger = logging.getLogger(__name__)
class SandboxFileArchiveSource(SandboxFileSource):
def _get_archive_download_url(self) -> str:
"""Get a pre-signed download URL for the sandbox archive."""
from extensions.storage.file_presign_storage import FilePresignStorage
storage_key = SandboxFilePaths.archive(self._tenant_id, self._app_id, self._sandbox_id)
if not storage.exists(storage_key):
raise ValueError("Sandbox archive not found")
presign_storage = FilePresignStorage(storage.storage_runner)
return presign_storage.get_download_url(storage_key, self._EXPORT_EXPIRES_IN_SECONDS)
def _create_zip_sandbox(self) -> ZipSandbox:
"""Create a ZipSandbox instance for archive operations."""
from core.zip_sandbox import ZipSandbox
return ZipSandbox(tenant_id=self._tenant_id, user_id="system", app_id=self._app_id)
def exists(self) -> bool:
"""Check if the sandbox archive exists in storage."""
storage_key = SandboxFilePaths.archive(self._tenant_id, self._app_id, self._sandbox_id)
return storage.exists(storage_key)
def list_files(self, *, path: str, recursive: bool) -> list[SandboxFileNode]:
archive_url = self._get_archive_download_url()
with self._create_zip_sandbox() as zs:
# Download and extract the archive
archive_path = zs.download_archive(archive_url, path="workspace.tar.gz")
zs.untar(archive_path=archive_path, dest_dir="workspace")
# List files using Python script in sandbox
try:
list_path = f"workspace/{path}" if path not in (".", "") else "workspace"
results = (
pipeline(zs.vm)
.add(
build_list_command(list_path, recursive),
error_message="Failed to list sandbox files",
)
.execute(timeout=self._LIST_TIMEOUT_SECONDS, raise_on_error=True)
)
except PipelineExecutionError as exc:
raise RuntimeError(str(exc)) from exc
raw = parse_list_output(results[0].stdout)
entries: list[SandboxFileNode] = []
for item in raw:
item_path = str(item.get("path"))
# Strip the "workspace/" prefix from paths
if item_path.startswith("workspace/"):
item_path = item_path[len("workspace/") :]
elif item_path == "workspace":
continue # Skip the workspace directory itself
item_is_dir = bool(item.get("is_dir"))
extension = None
if not item_is_dir:
ext = os.path.splitext(item_path)[1]
extension = ext or None
entries.append(
SandboxFileNode(
path=item_path,
is_dir=item_is_dir,
size=item.get("size"),
mtime=item.get("mtime"),
extension=extension,
)
)
return sorted(entries, key=lambda e: e.path)
def download_file(self, *, path: str) -> SandboxFileDownloadTicket:
"""Download a file or directory from the archived sandbox.
Uses direct upload from sandbox to storage via presigned URL, avoiding
data transfer through the service layer. This preserves binary integrity
(no text encoding issues) and reduces bandwidth overhead.
"""
from services.sandbox.sandbox_file_service import SandboxFileService
archive_url = self._get_archive_download_url()
export_name = os.path.basename(path.rstrip("/")) or "workspace"
export_id = uuid4().hex
with self._create_zip_sandbox() as zs:
archive_path = zs.download_archive(archive_url, path="workspace.tar.gz")
zs.untar(archive_path=archive_path, dest_dir="workspace")
target_path = f"workspace/{path}" if path not in (".", "") else "workspace"
try:
results = (
pipeline(zs.vm)
.add(
build_detect_kind_command(target_path),
error_message="Failed to check path in sandbox",
)
.execute(timeout=self._LIST_TIMEOUT_SECONDS, raise_on_error=True)
)
except PipelineExecutionError as exc:
raise ValueError(str(exc)) from exc
kind = parse_kind_output(results[0].stdout, not_found_message="File not found in sandbox archive")
sandbox_storage = SandboxFileService.get_storage()
is_file = kind == "file"
filename = (os.path.basename(path) or "file") if is_file else f"{export_name}.tar.gz"
export_key = SandboxFilePaths.export(self._tenant_id, self._app_id, self._sandbox_id, export_id)
upload_url = sandbox_storage.get_upload_url(export_key, self._EXPORT_EXPIRES_IN_SECONDS)
content_type = guess_content_type(filename)
# Build pipeline: for directories, tar first then upload; for files, upload directly
archive_temp = f"/tmp/{export_id}.tar.gz"
src_path = target_path if is_file else archive_temp
tar_src = path if path not in (".", "") else "."
try:
(
pipeline(zs.vm)
.add(
["tar", "-czf", archive_temp, "-C", "workspace", tar_src],
error_message="Failed to archive directory",
on=not is_file,
)
.add(
build_upload_command(src_path, upload_url, content_type=content_type),
error_message="Failed to upload file",
)
.add(["rm", "-f", archive_temp], on=not is_file)
.execute(timeout=self._UPLOAD_TIMEOUT_SECONDS, raise_on_error=True)
)
except PipelineExecutionError as exc:
raise RuntimeError(str(exc)) from exc
download_url = sandbox_storage.get_download_url(
export_key, self._EXPORT_EXPIRES_IN_SECONDS, download_filename=filename
)
return SandboxFileDownloadTicket(
download_url=download_url,
expires_in=self._EXPORT_EXPIRES_IN_SECONDS,
export_id=export_id,
)

View File

@ -0,0 +1,33 @@
from __future__ import annotations
import abc
from core.sandbox.entities.files import SandboxFileDownloadTicket, SandboxFileNode
class SandboxFileSource(abc.ABC):
_LIST_TIMEOUT_SECONDS = 30
_UPLOAD_TIMEOUT_SECONDS = 60 * 10
_EXPORT_EXPIRES_IN_SECONDS = 60 * 10
def __init__(self, *, tenant_id: str, app_id: str, sandbox_id: str):
self._tenant_id = tenant_id
self._app_id = app_id
self._sandbox_id = sandbox_id
@abc.abstractmethod
def exists(self) -> bool:
"""Check if the sandbox source exists and is available.
Returns:
True if the sandbox source exists and can be accessed, False otherwise.
"""
raise NotImplementedError
@abc.abstractmethod
def list_files(self, *, path: str, recursive: bool) -> list[SandboxFileNode]:
raise NotImplementedError
@abc.abstractmethod
def download_file(self, *, path: str) -> SandboxFileDownloadTicket:
raise NotImplementedError

View File

@ -0,0 +1,48 @@
from __future__ import annotations
from pathlib import PurePosixPath
from core.sandbox.entities.files import SandboxFileDownloadTicket, SandboxFileNode
from core.sandbox.inspector.archive_source import SandboxFileArchiveSource
from core.sandbox.inspector.base import SandboxFileSource
class SandboxFileBrowser:
def __init__(self, *, tenant_id: str, app_id: str, sandbox_id: str):
self._tenant_id = tenant_id
self._app_id = app_id
self._sandbox_id = sandbox_id
@staticmethod
def _normalize_workspace_path(path: str | None) -> str:
raw = (path or ".").strip()
if raw == "":
raw = "."
p = PurePosixPath(raw)
if p.is_absolute():
raise ValueError("path must be relative")
if any(part == ".." for part in p.parts):
raise ValueError("path must not contain '..'")
normalized = str(p)
return "." if normalized in (".", "") else normalized
def _backend(self) -> SandboxFileSource:
return SandboxFileArchiveSource(
tenant_id=self._tenant_id,
app_id=self._app_id,
sandbox_id=self._sandbox_id,
)
def exists(self) -> bool:
"""Check if the sandbox source exists and is available."""
return self._backend().exists()
def list_files(self, *, path: str | None = None, recursive: bool = False) -> list[SandboxFileNode]:
workspace_path = self._normalize_workspace_path(path)
return self._backend().list_files(path=workspace_path, recursive=recursive)
def download_file(self, *, path: str) -> SandboxFileDownloadTicket:
workspace_path = self._normalize_workspace_path(path)
return self._backend().download_file(path=workspace_path)

View File

@ -0,0 +1,140 @@
from __future__ import annotations
import logging
import os
from uuid import uuid4
from core.sandbox.entities.files import SandboxFileDownloadTicket, SandboxFileNode
from core.sandbox.inspector.base import SandboxFileSource
from core.sandbox.inspector.script_utils import (
build_detect_kind_command,
build_list_command,
build_upload_command,
guess_content_type,
parse_kind_output,
parse_list_output,
)
from core.sandbox.storage import SandboxFilePaths
from core.virtual_environment.__base.exec import PipelineExecutionError
from core.virtual_environment.__base.helpers import pipeline
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
logger = logging.getLogger(__name__)
class SandboxFileRuntimeSource(SandboxFileSource):
def __init__(self, *, tenant_id: str, app_id: str, sandbox_id: str, runtime: VirtualEnvironment):
super().__init__(tenant_id=tenant_id, app_id=app_id, sandbox_id=sandbox_id)
self._runtime = runtime
def exists(self) -> bool:
"""Check if the sandbox runtime exists and is available."""
return self._runtime is not None
def list_files(self, *, path: str, recursive: bool) -> list[SandboxFileNode]:
try:
results = (
pipeline(self._runtime)
.add(
build_list_command(path, recursive),
error_message="Failed to list sandbox files",
)
.execute(timeout=self._LIST_TIMEOUT_SECONDS, raise_on_error=True)
)
except PipelineExecutionError as exc:
raise RuntimeError(str(exc)) from exc
raw = parse_list_output(results[0].stdout)
entries: list[SandboxFileNode] = []
for item in raw:
item_path = str(item.get("path"))
item_is_dir = bool(item.get("is_dir"))
extension = None
if not item_is_dir:
ext = os.path.splitext(item_path)[1]
extension = ext or None
entries.append(
SandboxFileNode(
path=item_path,
is_dir=item_is_dir,
size=item.get("size"),
mtime=item.get("mtime"),
extension=extension,
)
)
return entries
def download_file(self, *, path: str) -> SandboxFileDownloadTicket:
from services.sandbox.sandbox_file_service import SandboxFileService
try:
results = (
pipeline(self._runtime)
.add(
build_detect_kind_command(path),
error_message="Failed to check path in sandbox",
)
.execute(timeout=self._LIST_TIMEOUT_SECONDS, raise_on_error=True)
)
except PipelineExecutionError as exc:
raise ValueError(str(exc)) from exc
kind = parse_kind_output(results[0].stdout, not_found_message="File not found in sandbox")
export_name = os.path.basename(path.rstrip("/")) or "workspace"
filename = f"{export_name}.tar.gz" if kind == "dir" else (os.path.basename(path) or "file")
export_id = uuid4().hex
export_key = SandboxFilePaths.export(
self._tenant_id,
self._app_id,
self._sandbox_id,
export_id,
)
sandbox_storage = SandboxFileService.get_storage()
upload_url = sandbox_storage.get_upload_url(export_key, self._EXPORT_EXPIRES_IN_SECONDS)
content_type = guess_content_type(filename)
if kind == "dir":
archive_path = f"/tmp/{export_id}.tar.gz"
try:
(
pipeline(self._runtime)
.add(
["tar", "-czf", archive_path, "-C", ".", path],
error_message="Failed to archive directory in sandbox",
)
.add(
build_upload_command(archive_path, upload_url, content_type=content_type),
error_message="Failed to upload directory archive from sandbox",
)
.execute(timeout=self._UPLOAD_TIMEOUT_SECONDS, raise_on_error=True)
)
except PipelineExecutionError as exc:
raise RuntimeError(str(exc)) from exc
finally:
try:
pipeline(self._runtime).add(["rm", "-f", archive_path]).execute(timeout=self._LIST_TIMEOUT_SECONDS)
except Exception as exc:
# Best-effort cleanup; do not fail the download on cleanup issues.
logger.debug("Failed to cleanup temp archive %s: %s", archive_path, exc)
else:
try:
(
pipeline(self._runtime)
.add(
build_upload_command(path, upload_url, content_type=content_type),
error_message="Failed to upload file from sandbox",
)
.execute(timeout=self._UPLOAD_TIMEOUT_SECONDS, raise_on_error=True)
)
except PipelineExecutionError as exc:
raise RuntimeError(str(exc)) from exc
download_url = sandbox_storage.get_download_url(export_key, self._EXPORT_EXPIRES_IN_SECONDS)
return SandboxFileDownloadTicket(
download_url=download_url,
expires_in=self._EXPORT_EXPIRES_IN_SECONDS,
export_id=export_id,
)

View File

@ -0,0 +1,118 @@
"""Shared helpers for sandbox inspector shell commands."""
from __future__ import annotations
import json
import mimetypes
from typing import TypedDict, cast
_PYTHON_EXEC_CMD = 'if command -v python3 >/dev/null 2>&1; then py=python3; else py=python; fi; "$py" -c "$0" "$@"'
_LIST_SCRIPT = r"""
import json
import os
import sys
path = sys.argv[1]
recursive = sys.argv[2] == "1"
def norm(rel: str) -> str:
    rel = rel.replace("\\\\", "/")
    # Strip only a leading "./"; lstrip("./") would also eat the leading
    # dot of hidden files such as ".gitignore".
    if rel.startswith("./"):
        rel = rel[2:]
    return rel or "."
def stat_entry(full_path: str, rel_path: str) -> dict[str, object]:
st = os.stat(full_path)
is_dir = os.path.isdir(full_path)
return {
"path": norm(rel_path),
"is_dir": is_dir,
"size": None if is_dir else int(st.st_size),
"mtime": int(st.st_mtime),
}
entries = []
if recursive:
for root, dirs, files in os.walk(path):
for d in dirs:
fp = os.path.join(root, d)
rp = os.path.relpath(fp, ".")
entries.append(stat_entry(fp, rp))
for f in files:
fp = os.path.join(root, f)
rp = os.path.relpath(fp, ".")
entries.append(stat_entry(fp, rp))
else:
if os.path.isfile(path):
rel_path = os.path.relpath(path, ".")
entries.append(stat_entry(path, rel_path))
else:
for item in os.scandir(path):
rel_path = os.path.relpath(item.path, ".")
entries.append(stat_entry(item.path, rel_path))
print(json.dumps(entries))
"""
class ListedEntry(TypedDict):
path: str
is_dir: bool
size: int | None
mtime: int
def build_list_command(path: str, recursive: bool) -> list[str]:
return [
"sh",
"-c",
_PYTHON_EXEC_CMD,
_LIST_SCRIPT,
path,
"1" if recursive else "0",
]
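# With sh -c, the script string is passed as $0 and the remaining args as $1...,
# so "$py" -c "$0" "$@" runs _LIST_SCRIPT with (path, recursive_flag) as argv.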
def parse_list_output(stdout: bytes) -> list[ListedEntry]:
try:
raw = json.loads(stdout.decode("utf-8"))
except Exception as exc:
raise RuntimeError("Malformed sandbox file list output") from exc
if not isinstance(raw, list):
raise RuntimeError("Malformed sandbox file list output")
return cast(list[ListedEntry], raw)
def build_detect_kind_command(path: str) -> list[str]:
return [
"sh",
"-c",
'if [ -d "$1" ]; then echo dir; elif [ -f "$1" ]; then echo file; else exit 2; fi',
"sh",
path,
]
def parse_kind_output(stdout: bytes, *, not_found_message: str) -> str:
kind = stdout.decode("utf-8", errors="replace").strip()
if kind not in ("dir", "file"):
raise ValueError(not_found_message)
return kind
def guess_content_type(filename: str) -> str | None:
content_type, _ = mimetypes.guess_type(filename, strict=False)
if content_type is None:
return None
if content_type.startswith("text/"):
return f"{content_type}; charset=utf-8"
if content_type == "application/json":
return "application/json; charset=utf-8"
return content_type
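# e.g. "a.txt" -> "text/plain; charset=utf-8",
#      "a.json" -> "application/json; charset=utf-8", unknown extensions -> None.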
def build_upload_command(src_path: str, upload_url: str, *, content_type: str | None) -> list[str]:
command = ["curl", "-s", "-f", "-X", "PUT", "-T", src_path]
if content_type:
command.extend(["-H", f"Content-Type: {content_type}"])
command.append(upload_url)
return command
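# e.g. build_upload_command("/tmp/x.tar.gz", url, content_type="application/gzip")
# -> ["curl", "-s", "-f", "-X", "PUT", "-T", "/tmp/x.tar.gz",
#     "-H", "Content-Type: application/gzip", url]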

133
api/core/sandbox/sandbox.py Normal file
View File

@ -0,0 +1,133 @@
from __future__ import annotations
import logging
import threading
from typing import TYPE_CHECKING
from uuid import uuid4
from libs.attr_map import AttrMap
if TYPE_CHECKING:
from core.sandbox.storage.sandbox_storage import SandboxStorage
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
logger = logging.getLogger(__name__)
class Sandbox:
"""Represents a single sandbox environment.
Each ``Sandbox`` owns a stable, path-safe ``id`` (a 32-char hex
UUID4) that is independent of the underlying provider's environment
ID. Use ``sandbox.id`` for any path or resource namespacing
(e.g. ``DifyCli(sandbox.id)``).
The raw provider identifier is still accessible via
``sandbox.vm.metadata.id`` when needed (logging, API calls back to
the provider, etc.).
"""
def __init__(
self,
*,
vm: VirtualEnvironment,
storage: SandboxStorage,
tenant_id: str,
user_id: str,
app_id: str,
assets_id: str,
) -> None:
self._id = uuid4().hex
self._vm = vm
self._storage = storage
self._tenant_id = tenant_id
self._user_id = user_id
self._app_id = app_id
self._assets_id = assets_id
self._attributes = AttrMap()
self._ready_event = threading.Event()
self._cancel_event = threading.Event()
self._init_error: Exception | None = None
@property
def id(self) -> str:
"""Stable, path-safe identifier for this sandbox (UUID4 hex)."""
return self._id
@property
def attrs(self) -> AttrMap:
return self._attributes
@property
def vm(self) -> VirtualEnvironment:
return self._vm
@property
def storage(self) -> SandboxStorage:
return self._storage
@property
def tenant_id(self) -> str:
return self._tenant_id
@property
def user_id(self) -> str:
return self._user_id
@property
def app_id(self) -> str:
return self._app_id
@property
def assets_id(self) -> str:
return self._assets_id
def mark_ready(self) -> None:
# Signal that sandbox initialization has completed successfully.
self._ready_event.set()
def mark_failed(self, error: Exception) -> None:
# Capture initialization error and unblock waiters.
self._init_error = error
self._ready_event.set()
def cancel_init(self) -> None:
# Mark initialization as cancelled to stop background setup.
self._cancel_event.set()
self._ready_event.set()
def is_cancelled(self) -> bool:
return self._cancel_event.is_set()
def wait_ready(self, timeout: float | None = None) -> None:
# Block until initialization completes, fails, or is cancelled.
if not self._ready_event.wait(timeout=timeout):
raise TimeoutError("Sandbox initialization timed out")
if self._cancel_event.is_set():
raise RuntimeError("Sandbox initialization was cancelled")
if self._init_error is not None:
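            # ValueError messages are surfaced verbatim (treated as user-facing);
            # other errors stay generic, with the original chained for logging.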
if isinstance(self._init_error, ValueError):
raise RuntimeError(f"Sandbox initialization failed: {self._init_error}") from self._init_error
else:
raise RuntimeError("Sandbox initialization failed") from self._init_error
def mount(self) -> bool:
return self._storage.mount(self._vm)
def unmount(self) -> bool:
return self._storage.unmount(self._vm)
def release(self) -> None:
self.cancel_init()
sandbox_id = self.id
try:
self._storage.unmount(self._vm)
logger.info("Sandbox storage unmounted: sandbox_id=%s", sandbox_id)
except Exception:
logger.exception("Failed to unmount sandbox storage: sandbox_id=%s", sandbox_id)
try:
self._vm.release_environment()
logger.info("Sandbox released: sandbox_id=%s", sandbox_id)
except Exception:
logger.exception("Failed to release sandbox: sandbox_id=%s", sandbox_id)

View File

@ -0,0 +1,3 @@
from .asset_download_service import AssetDownloadService
__all__ = ["AssetDownloadService"]

View File

@ -0,0 +1,140 @@
"""Shell script builder for downloading / writing assets into a sandbox VM.
Generates a self-contained POSIX shell script that handles two kinds of
``SandboxDownloadItem``:
- Items with *content* written via base64 heredoc (sequential).
- Items with *url* fetched via ``curl``/``wget``/``python3`` with
auto-detection, run as parallel background jobs.
Both kinds can be mixed freely in a single call.
"""
from __future__ import annotations
import base64
import shlex
import textwrap
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from core.zip_sandbox.entities import SandboxDownloadItem
def _build_inline_commands(items: list[SandboxDownloadItem], root_var: str) -> str:
"""Generate shell commands that write base64-encoded content to files."""
lines: list[str] = []
for idx, item in enumerate(items):
assert item.content is not None
dest = f"${{{root_var}}}/{shlex.quote(item.path)}"
encoded = base64.b64encode(item.content).decode("ascii")
lines.append(f'mkdir -p "$(dirname "{dest}")"')
lines.append(f"base64 -d <<'_INLINE_{idx}' > \"{dest}\"")
lines.append(encoded)
lines.append(f"_INLINE_{idx}")
return "\n".join(lines)
def _render_download_script(
root_path: str,
inline_commands: str,
download_commands: str,
need_downloader: bool,
) -> str:
python_download_cmd = (
'python3 - "${url}" "${dest}" <<"PY"\n'
"import sys\n"
"import urllib.request\n"
"url = sys.argv[1]\n"
"dest = sys.argv[2]\n"
"with urllib.request.urlopen(url) as resp:\n"
" data = resp.read()\n"
'with open(dest, "wb") as f:\n'
" f.write(data)\n"
"PY"
)
# Only emit the downloader-detection block when there are remote items.
if need_downloader:
downloader_block = f"""\
if command -v curl >/dev/null 2>&1; then
download_cmd='curl -fsSL "${{url}}" -o "${{dest}}"'
elif command -v wget >/dev/null 2>&1; then
download_cmd='wget -q "${{url}}" -O "${{dest}}"'
elif command -v python3 >/dev/null 2>&1; then
download_cmd={shlex.quote(python_download_cmd)}
else
echo 'No downloader found (curl/wget/python3)' >&2
exit 1
fi
fail_log="$(mktemp)"
download_one() {{
file_path="$1"
url="$2"
dest="${{download_root}}/${{file_path}}"
mkdir -p "$(dirname "${{dest}}")"
eval "${{download_cmd}}" 2>/dev/null || echo "${{file_path}}" >> "${{fail_log}}"
}}"""
else:
downloader_block = ""
# The failure-check block is only meaningful when downloads occurred.
if need_downloader:
wait_block = textwrap.dedent("""\
wait
if [ -s "${fail_log}" ]; then
mv "${fail_log}" "${download_root}/DOWNLOAD_FAILURES.txt"
else
rm -f "${fail_log}"
fi""")
else:
wait_block = ""
script = f"""\
download_root={shlex.quote(root_path)}
mkdir -p "${{download_root}}"
{downloader_block}
{inline_commands}
{download_commands}
{wait_block}
exit 0"""
return script
class AssetDownloadService:
@staticmethod
def build_download_script(
items: list[SandboxDownloadItem],
root_path: str,
) -> str:
"""Build a portable shell script to write inline assets and download remote ones.
Items with *content* are written first (sequential base64 decode),
then items with *url* are fetched in parallel background jobs.
The two kinds can be mixed freely in a single list.
"""
inline = [item for item in items if item.content is not None]
remote = [item for item in items if item.content is None]
inline_commands = _build_inline_commands(inline, "download_root") if inline else ""
commands: list[str] = []
for item in remote:
path = shlex.quote(item.path)
url = shlex.quote(item.url)
commands.append(f"download_one {path} {url} &")
download_commands = "\n".join(commands)
return _render_download_script(
root_path,
inline_commands,
download_commands,
need_downloader=bool(remote),
)
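# Usage sketch (illustrative; assumes SandboxDownloadItem exposes `path`,
# `content`, and `url` exactly as consumed above):
#
#     items = [
#         SandboxDownloadItem(path="config/app.json", content=b"{}", url=None),
#         SandboxDownloadItem(path="data/model.bin", content=None,
#                             url="https://example.com/model.bin"),
#     ]
#     script = AssetDownloadService.build_download_script(items, "/workspace")
#     # The script writes config/app.json inline via a base64 heredoc, fetches
#     # data/model.bin as a background job, and records any failed downloads in
#     # ${download_root}/DOWNLOAD_FAILURES.txt.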

View File

@ -0,0 +1,11 @@
from .archive_storage import ArchiveSandboxStorage
from .noop_storage import NoopSandboxStorage
from .sandbox_file_storage import SandboxFilePaths
from .sandbox_storage import SandboxStorage
__all__ = [
"ArchiveSandboxStorage",
"NoopSandboxStorage",
"SandboxFilePaths",
"SandboxStorage",
]

View File

@ -0,0 +1,83 @@
"""Archive-based sandbox storage for persisting sandbox state."""
from __future__ import annotations
import logging
from core.virtual_environment.__base.helpers import pipeline
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
from extensions.storage.base_storage import BaseStorage
from extensions.storage.cached_presign_storage import CachedPresignStorage
from extensions.storage.file_presign_storage import FilePresignStorage
from .sandbox_file_storage import SandboxFilePaths
from .sandbox_storage import SandboxStorage
logger = logging.getLogger(__name__)
_ARCHIVE_TIMEOUT = 300 # 5 minutes
class ArchiveSandboxStorage(SandboxStorage):
"""Archive-based storage for sandbox workspace persistence."""
def __init__(
self,
tenant_id: str,
app_id: str,
sandbox_id: str,
storage: BaseStorage,
exclude_patterns: list[str] | None = None,
):
self._sandbox_id = sandbox_id
self._exclude_patterns = exclude_patterns or []
self._storage_key = SandboxFilePaths.archive(tenant_id, app_id, sandbox_id)
self._storage = CachedPresignStorage(
storage=FilePresignStorage(storage),
cache_key_prefix="sandbox_archives",
)
def mount(self, sandbox: VirtualEnvironment) -> bool:
"""Load archive from storage into sandbox workspace."""
if not self.exists():
logger.debug("No archive found for sandbox %s, skipping mount", self._sandbox_id)
return False
download_url = self._storage.get_download_url(self._storage_key, _ARCHIVE_TIMEOUT)
archive = "archive.tar.gz"
(
pipeline(sandbox)
.add(["curl", "-fsSL", download_url, "-o", archive], error_message="Failed to download archive")
.add(["sh", "-c", 'tar -xzf "$1" 2>/dev/null; exit $?', "sh", archive], error_message="Failed to extract")
.add(["rm", archive], error_message="Failed to cleanup")
.execute(timeout=_ARCHIVE_TIMEOUT, raise_on_error=True)
)
logger.info("Mounted archive for sandbox %s", self._sandbox_id)
return True
def unmount(self, sandbox: VirtualEnvironment) -> bool:
"""Save sandbox workspace to storage as archive."""
upload_url = self._storage.get_upload_url(self._storage_key, _ARCHIVE_TIMEOUT)
archive = f"/tmp/{self._sandbox_id}.tar.gz"
exclude_args = [f"--exclude={p}" for p in self._exclude_patterns]
(
pipeline(sandbox)
.add(["tar", "-czf", archive, *exclude_args, "-C", ".", "."], error_message="Failed to create archive")
.add(["curl", "-sf", "-X", "PUT", "-T", archive, upload_url], error_message="Failed to upload archive")
.execute(timeout=_ARCHIVE_TIMEOUT, raise_on_error=True)
)
logger.info("Unmounted archive for sandbox %s", self._sandbox_id)
return True
def exists(self) -> bool:
return self._storage.exists(self._storage_key)
def delete(self) -> None:
try:
self._storage.delete(self._storage_key)
logger.info("Deleted archive for sandbox %s", self._sandbox_id)
except Exception:
logger.exception("Failed to delete archive for sandbox %s", self._sandbox_id)

View File

@ -0,0 +1,18 @@
from core.sandbox.storage.sandbox_storage import SandboxStorage
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
class NoopSandboxStorage(SandboxStorage):
"""A no-op storage implementation that does nothing on mount/unmount."""
def mount(self, sandbox: VirtualEnvironment) -> bool:
return True
def unmount(self, sandbox: VirtualEnvironment) -> bool:
return True
def exists(self) -> bool:
return False
def delete(self) -> None:
return

View File

@ -0,0 +1,21 @@
"""Sandbox file storage key generation.
Provides SandboxFilePaths facade for generating storage keys for sandbox files.
Storage instances are obtained via SandboxFileService.get_storage().
"""
from __future__ import annotations
class SandboxFilePaths:
"""Facade for generating sandbox file storage keys."""
@staticmethod
def export(tenant_id: str, app_id: str, sandbox_id: str, export_id: str) -> str:
"""sandbox_files/{tenant}/{app}/{sandbox}/{export_id}/{filename}"""
return f"sandbox_files/{tenant_id}/{app_id}/{sandbox_id}/{export_id}"
@staticmethod
def archive(tenant_id: str, app_id: str, sandbox_id: str) -> str:
"""sandbox_archives/{tenant}/{app}/{sandbox}.tar.gz"""
return f"sandbox_archives/{tenant_id}/{app_id}/{sandbox_id}.tar.gz"

View File

@ -0,0 +1,21 @@
from abc import ABC, abstractmethod
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
class SandboxStorage(ABC):
@abstractmethod
def mount(self, sandbox: VirtualEnvironment) -> bool:
"""Load files from storage into VM. Returns True if files were loaded."""
@abstractmethod
def unmount(self, sandbox: VirtualEnvironment) -> bool:
"""Save files from VM to storage. Returns True if files were saved."""
@abstractmethod
def exists(self) -> bool:
"""Check if storage has saved data."""
@abstractmethod
def delete(self) -> None:
"""Delete saved data from storage."""

View File

@ -0,0 +1,2 @@
# Sandbox utilities
# Connection helpers have been moved to core.virtual_environment.helpers

View File

@ -0,0 +1,22 @@
"""Sandbox debug utilities. TODO: Remove this module when sandbox debugging is complete."""
from typing import Any
SANDBOX_DEBUG_ENABLED = True
def sandbox_debug(tag: str, message: str, data: Any = None) -> None:
if not SANDBOX_DEBUG_ENABLED:
return
# Lazy import to avoid circular dependency
from core.callback_handler.agent_tool_callback_handler import print_text
print_text(f"\n[{tag}]\n", color="blue")
if data is not None:
print_text(f"{message}: {data}\n", color="blue")
else:
print_text(f"{message}\n", color="blue")

View File

@ -0,0 +1,48 @@
from collections.abc import Mapping
from typing import Any
from core.entities.provider_entities import BasicProviderConfig
from core.helper.provider_cache import ProviderCredentialsCache
from core.helper.provider_encryption import ProviderConfigCache, ProviderConfigEncrypter, create_provider_encrypter
class SandboxProviderConfigCache(ProviderCredentialsCache):
def __init__(self, tenant_id: str, provider_type: str):
super().__init__(tenant_id=tenant_id, provider_type=provider_type)
def _generate_cache_key(self, **kwargs) -> str:
tenant_id = kwargs["tenant_id"]
provider_type = kwargs["provider_type"]
return f"sandbox_config:tenant_id:{tenant_id}:provider_type:{provider_type}"
def create_sandbox_config_encrypter(
tenant_id: str,
config_schema: list[BasicProviderConfig],
provider_type: str,
) -> tuple[ProviderConfigEncrypter, ProviderConfigCache]:
cache = SandboxProviderConfigCache(tenant_id=tenant_id, provider_type=provider_type)
return create_provider_encrypter(tenant_id=tenant_id, config=config_schema, cache=cache)
def masked_config(
schemas: list[BasicProviderConfig],
config: Mapping[str, Any],
) -> Mapping[str, Any]:
masked = dict(config)
configs = {x.name: x for x in schemas}
for key, value in config.items():
schema = configs.get(key)
if not schema:
masked[key] = value
continue
if schema.type == BasicProviderConfig.Type.SECRET_INPUT:
if not isinstance(value, str):
continue
if len(value) <= 4:
masked[key] = "*" * len(value)
else:
masked[key] = value[:2] + "*" * (len(value) - 4) + value[-2:]
else:
masked[key] = value
return masked
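# Masking sketch (assumes BasicProviderConfig accepts `type` and `name` keyword
# arguments; the exact constructor may differ):
#
#     schemas = [BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name="api_key")]
#     masked_config(schemas, {"api_key": "sk-abcdef123456", "region": "us-east-1"})
#     # -> {"api_key": "sk" + "*" * 11 + "56", "region": "us-east-1"}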

View File

@ -0,0 +1,11 @@
from .constants import SkillAttrs
from .entities import ToolDependencies, ToolDependency, ToolReference
from .skill_manager import SkillManager
__all__ = [
"SkillAttrs",
"SkillManager",
"ToolDependencies",
"ToolDependency",
"ToolReference",
]

View File

@ -0,0 +1,6 @@
from core.skill.assembler.assemblers import SkillBundleAssembler, SkillDocumentAssembler
__all__ = [
"SkillBundleAssembler",
"SkillDocumentAssembler",
]

View File

@ -0,0 +1,80 @@
from collections.abc import Mapping
from core.app.entities.app_asset_entities import AppAssetFileTree
from core.skill.assembler.common import (
build_skill_graph,
compute_transitive_dependance,
expand_referenced_skill_ids,
get_metadata,
process_skill_content,
)
from core.skill.entities.skill_bundle import Skill, SkillBundle, SkillDependance
from core.skill.entities.skill_document import SkillDocument
class SkillBundleAssembler:
_file_tree: AppAssetFileTree
def __init__(self, file_tree: AppAssetFileTree) -> None:
self._file_tree = file_tree
def assemble_bundle(
self,
documents: Mapping[str, SkillDocument],
assets_id: str,
) -> SkillBundle:
direct_skills: dict[str, Skill] = {}
for skill_id, doc in documents.items():
metadata = get_metadata(doc.content, doc.metadata)
direct_dependance = SkillDependance.from_metadata(metadata)
direct_skills[skill_id] = Skill(
skill_id=skill_id,
direct_dependance=direct_dependance,
dependance=direct_dependance,
content=process_skill_content(doc.content, metadata, self._file_tree, skill_id),
)
graph = build_skill_graph(direct_skills, self._file_tree)
transitive_map = compute_transitive_dependance(direct_skills, graph)
compiled_skills: dict[str, Skill] = {}
for skill_id, skill in direct_skills.items():
compiled_skills[skill_id] = skill.model_copy(update={"dependance": transitive_map[skill_id]})
return SkillBundle(asset_tree=self._file_tree, assets_id=assets_id, skills=compiled_skills)
class SkillDocumentAssembler:
_bundle: SkillBundle
def __init__(self, bundle: SkillBundle) -> None:
self._bundle = bundle
def assemble_document(self, document: SkillDocument, base_path: str = "") -> Skill:
metadata = get_metadata(document.content, document.metadata)
direct_dependance = SkillDependance.from_metadata(metadata)
resolved_content = process_skill_content(
document.content,
metadata,
self._bundle.asset_tree,
document.skill_id,
base_path,
)
transitive_dependance = direct_dependance
known_skill_ids = set(self._bundle.skills.keys())
referenced_skill_ids = expand_referenced_skill_ids(
direct_dependance.files, known_skill_ids, self._bundle.asset_tree
)
for skill_id in sorted(referenced_skill_ids):
referenced_skill = self._bundle.get(skill_id)
if referenced_skill is None:
continue
transitive_dependance = transitive_dependance | referenced_skill.dependance
return Skill(
skill_id=document.skill_id,
direct_dependance=direct_dependance,
dependance=transitive_dependance,
content=resolved_content,
)

View File

@ -0,0 +1,136 @@
from collections import deque
from collections.abc import Mapping
from core.app.entities.app_asset_entities import AppAssetFileTree, AssetNodeType
from core.skill.assembler.replacers import (
FILE_PATTERN,
TOOL_METADATA_PATTERN,
FileReplacer,
Replacer,
ToolGroupReplacer,
ToolReplacer,
)
from core.skill.entities.skill_bundle import Skill, SkillDependance
from core.skill.entities.skill_metadata import FileReference, SkillMetadata, ToolReference
def process_skill_content(
content: str,
metadata: SkillMetadata,
file_tree: AppAssetFileTree,
current_id: str,
base_path: str = "",
) -> str:
"""Resolve all placeholders in content through the ordered replacer pipeline."""
replacers: list[Replacer] = [
FileReplacer(file_tree, current_id, base_path),
ToolGroupReplacer(metadata),
ToolReplacer(metadata),
]
for replacer in replacers:
content = replacer.resolve(content)
return content
def get_metadata(content: str, metadata: SkillMetadata) -> SkillMetadata:
"""Parse effective metadata from content placeholders and raw metadata."""
tools: dict[str, ToolReference] = {}
# find all tool refs actually used in content
for match in TOOL_METADATA_PATTERN.finditer(content):
provider, name, uuid = match.group(1), match.group(2), match.group(3)
tool_ref = metadata.tools.get(uuid)
if tool_ref is None:
raise ValueError(f"Tool reference with UUID {uuid} not found in metadata")
tool_ref.uuid = uuid
tool_ref.tool_name = name
tool_ref.provider = provider
tools[uuid] = tool_ref
# find all file refs
files: set[FileReference] = set()
for match in FILE_PATTERN.finditer(content):
source, asset_id = match.group(1), match.group(2)
files.add(FileReference(source=source, asset_id=asset_id))
return SkillMetadata(tools=tools, files=files)
def build_skill_graph(skills: Mapping[str, Skill], file_tree: AppAssetFileTree) -> dict[str, set[str]]:
"""Build adjacency list: skill_id -> referenced skill IDs."""
known_skill_ids = set(skills.keys())
graph: dict[str, set[str]] = {skill_id: set() for skill_id in known_skill_ids}
for skill_id, skill in skills.items():
graph[skill_id] = expand_referenced_skill_ids(skill.direct_dependance.files, known_skill_ids, file_tree)
return graph
def compute_transitive_dependance(
skills: Mapping[str, Skill],
graph: Mapping[str, set[str]],
) -> dict[str, SkillDependance]:
"""Compute transitive dependency closure with fixed-point iteration."""
dependance_map = {skill_id: skill.direct_dependance for skill_id, skill in skills.items()}
changed = True
while changed:
changed = False
for skill_id in sorted(skills.keys()):
merged = dependance_map[skill_id]
for dep_skill_id in sorted(graph.get(skill_id, set())):
if dep_skill_id == skill_id:
continue
merged = merged | dependance_map[dep_skill_id]
if merged != dependance_map[skill_id]:
dependance_map[skill_id] = merged
changed = True
return dependance_map
def expand_referenced_skill_ids(
refs: set[FileReference],
known_skill_ids: set[str],
file_tree: AppAssetFileTree,
) -> set[str]:
"""Resolve file/folder references to concrete known skill IDs."""
resolved: set[str] = set()
for ref in refs:
node = file_tree.get(ref.asset_id)
if node is None:
continue
if node.node_type == AssetNodeType.FILE:
if node.id in known_skill_ids:
resolved.add(node.id)
continue
descendant_ids = file_tree.get_descendant_ids(node.id)
for descendant_id in descendant_ids:
descendant = file_tree.get(descendant_id)
if descendant is None or descendant.node_type != AssetNodeType.FILE:
continue
if descendant_id in known_skill_ids:
resolved.add(descendant_id)
return resolved
def collect_transitive_skill_ids(
root_skill_ids: set[str],
graph: Mapping[str, set[str]],
) -> set[str]:
"""Collect all transitively reachable skill IDs from roots via BFS."""
visited: set[str] = set()
queue = deque(sorted(root_skill_ids))
while queue:
current = queue.popleft()
if current in visited:
continue
visited.add(current)
for next_skill_id in sorted(graph.get(current, set())):
if next_skill_id not in visited:
queue.append(next_skill_id)
return visited
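# Worked example on a tiny graph (pure data, runnable as-is):
#
#     graph = {"a": {"b"}, "b": {"c"}, "c": set()}
#     collect_transitive_skill_ids({"a"}, graph)
#     # -> {"a", "b", "c"}
#
# compute_transitive_dependance walks the same structure to a fixed point,
# merging SkillDependance objects with `|` until no skill's closure changes.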

View File

@ -0,0 +1,108 @@
"""Placeholder replacers for skill content.
Each replacer handles one category of ``§[...]§`` placeholder via the unified
``Replacer`` protocol. The shared ``resolve_content`` pipeline in
``core.skill.assembler.common`` builds a ``list[Replacer]`` and applies them
in order:
``FileReplacer`` ``ToolGroupReplacer`` ``ToolReplacer``
``ToolGroupReplacer`` MUST run before ``ToolReplacer`` so that group brackets
``[§[tool]...§, §[tool]...§]`` are resolved atomically; otherwise individual
tool replacement would destroy the group structure.
"""
import re
from typing import Protocol
from core.app.entities.app_asset_entities import AppAssetFileTree
from core.skill.entities.skill_metadata import SkillMetadata
TOOL_METADATA_PATTERN: re.Pattern[str] = re.compile(r"§\[tool\]\.\[([^\]]+)\]\.\[([^\]]+)\]\.\[([^\]]+)\]§")
TOOL_PATTERN: re.Pattern[str] = re.compile(r"§\[tool\]\.\[.*?\]\.\[.*?\]\.\[(.*?)\]§")
TOOL_GROUP_PATTERN: re.Pattern[str] = re.compile(
    r"\[\s*§\[tool\]\.\[[^\]]+\]\.\[[^\]]+\]\.\[[^\]]+\]§"
    r"(?:\s*,\s*§\[tool\]\.\[[^\]]+\]\.\[[^\]]+\]\.\[[^\]]+\]§)*\s*\]"
)
FILE_PATTERN: re.Pattern[str] = re.compile(r"§\[file\]\.\[([^\]]+)\]\.\[([^\]]+)\]§")
class Replacer(Protocol):
def resolve(self, content: str) -> str: ...
class FileReplacer:
_tree: AppAssetFileTree
_current_id: str
_base_path: str
def __init__(self, tree: AppAssetFileTree, current_id: str, base_path: str = "") -> None:
self._tree = tree
self._current_id = current_id
self._base_path = base_path.rstrip("/")
def resolve(self, content: str) -> str:
return FILE_PATTERN.sub(self._replace_match, content)
def _replace_match(self, match: re.Match[str]) -> str:
target_id = match.group(2)
source_node = self._tree.get(self._current_id)
target_node = self._tree.get(target_id)
if target_node is None:
return "[File not found]"
if source_node is not None:
return self._tree.relative_path(source_node, target_node)
full_path = self._tree.get_path(target_node.id)
if self._base_path:
return f"{self._base_path}/{full_path}"
return full_path
class ToolReplacer:
_metadata: SkillMetadata
def __init__(self, metadata: SkillMetadata) -> None:
self._metadata = metadata
def resolve(self, content: str) -> str:
return TOOL_PATTERN.sub(self._replace_match, content)
def _replace_match(self, match: re.Match[str]) -> str:
tool_id = match.group(1)
tool_ref = self._metadata.tools.get(tool_id)
if tool_ref is None:
return f"[Tool not found or disabled: {tool_id}]"
if not tool_ref.enabled:
return ""
return f"[Executable: {tool_ref.tool_name}_{tool_ref.uuid} --help command]"
class ToolGroupReplacer:
_metadata: SkillMetadata
def __init__(self, metadata: SkillMetadata) -> None:
self._metadata = metadata
def resolve(self, content: str) -> str:
return TOOL_GROUP_PATTERN.sub(self._replace_match, content)
def _replace_match(self, match: re.Match[str]) -> str:
group_text = match.group(0)
enabled_renders: list[str] = []
for tool_match in TOOL_PATTERN.finditer(group_text):
tool_id = tool_match.group(1)
tool_ref = self._metadata.tools.get(tool_id)
if tool_ref is None:
enabled_renders.append(f"[Tool not found or disabled: {tool_id}]")
continue
if not tool_ref.enabled:
continue
enabled_renders.append(f"[Executable: {tool_ref.tool_name}_{tool_ref.uuid} --help command]")
if not enabled_renders:
return ""
return "[" + ", ".join(enabled_renders) + "]"

View File

@ -0,0 +1,6 @@
from core.skill.entities.skill_bundle import SkillBundle
from libs.attr_map import AttrKey
class SkillAttrs:
BUNDLE = AttrKey("skill_bundle", SkillBundle)

View File

@ -0,0 +1,29 @@
from .skill_bundle import Skill, SkillBundle, SkillDependance
from .skill_document import SkillDocument
from .skill_metadata import (
FileReference,
SkillMetadata,
ToolConfiguration,
ToolFieldConfig,
ToolReference,
)
from .tool_access_policy import ToolAccessDescription, ToolAccessPolicy, ToolDescription, ToolInvocationRequest
from .tool_dependencies import ToolDependencies, ToolDependency
__all__ = [
"FileReference",
"Skill",
"SkillBundle",
"SkillDependance",
"SkillDocument",
"SkillMetadata",
"ToolAccessDescription",
"ToolAccessPolicy",
"ToolConfiguration",
"ToolDependencies",
"ToolDependency",
"ToolDescription",
"ToolFieldConfig",
"ToolInvocationRequest",
"ToolReference",
]

View File

@ -0,0 +1,16 @@
from pydantic import BaseModel, Field
from core.skill.entities.tool_dependencies import ToolDependency
class NodeSkillInfo(BaseModel):
"""Information about skills referenced by a workflow node.
Used by the whole-workflow skills endpoint to return per-node
tool dependency information.
"""
node_id: str = Field(description="The node ID")
tool_dependencies: list[ToolDependency] = Field(
default_factory=list, description="Tool dependencies extracted from skill prompts"
)

View File

@ -0,0 +1,94 @@
from typing import TYPE_CHECKING
from pydantic import BaseModel, ConfigDict, Field
from core.app.entities.app_asset_entities import AppAssetFileTree
from core.skill.entities.skill_metadata import FileReference
from core.skill.entities.tool_dependencies import ToolDependencies, ToolDependency
if TYPE_CHECKING:
from core.skill.entities.skill_metadata import SkillMetadata
class SkillDependance(BaseModel):
model_config = ConfigDict(extra="forbid")
tools: ToolDependencies = Field(description="Direct tool dependencies parsed from this skill only")
files: set[FileReference] = Field(
default_factory=set,
description="Direct file references parsed from this skill only",
)
def __or__(self, other: "SkillDependance") -> "SkillDependance":
return SkillDependance(tools=self.tools.merge(other.tools), files=self.files | other.files)
@staticmethod
def from_metadata(metadata: "SkillMetadata") -> "SkillDependance":
"""Convert parsed metadata into direct tool/file dependency model."""
from core.skill.entities.skill_metadata import ToolReference
dep_map: dict[str, ToolDependency] = {}
ref_map: dict[str, ToolReference] = {}
for tool_ref in metadata.tools.values():
dep_map.setdefault(
tool_ref.tool_id(),
ToolDependency(
type=tool_ref.type,
provider=tool_ref.provider,
tool_name=tool_ref.tool_name,
enabled=tool_ref.enabled,
),
)
ref_map.setdefault(tool_ref.uuid, tool_ref)
return SkillDependance(
tools=ToolDependencies(
dependencies=[dep_map[key] for key in sorted(dep_map.keys())],
references=[ref_map[key] for key in sorted(ref_map.keys())],
),
files=metadata.files,
)
class Skill(BaseModel):
model_config = ConfigDict(extra="forbid")
    skill_id: str = Field(description="Unique identifier for this skill, matching the skill asset ID")
direct_dependance: SkillDependance = Field(description="Direct dependencies parsed from this skill only")
dependance: SkillDependance = Field(description="All dependencies including transitive closure")
content: str = Field(description="Resolved content with all references replaced")
@property
def tools(self) -> ToolDependencies:
return self.dependance.tools
class SkillBundle(BaseModel):
model_config = ConfigDict(extra="forbid")
asset_tree: AppAssetFileTree = Field(description="Asset tree for this bundle")
assets_id: str = Field(description="Assets ID this bundle belongs to")
skills: dict[str, Skill] = Field(default_factory=dict)
@property
def entries(self) -> dict[str, Skill]:
return self.skills
def get(self, skill_id: str) -> Skill | None:
return self.skills.get(skill_id)
def get_tool_dependencies(self) -> ToolDependencies:
merged = ToolDependencies()
for skill in self.skills.values():
merged = merged.merge(skill.dependance.tools)
return merged
def put(self, skill: Skill) -> None:
self.skills[skill.skill_id] = skill

View File

@ -0,0 +1,17 @@
from pydantic import BaseModel, ConfigDict, Field
from core.skill.entities.skill_metadata import SkillMetadata
class SkillFile(BaseModel):
model_config = ConfigDict(extra="forbid")
class SkillDocument(BaseModel):
"""Input document for skill compilation."""
model_config = ConfigDict(extra="forbid")
skill_id: str = Field(description="Unique identifier, must match SkillAsset.asset_id")
content: str = Field(description="Raw content with reference placeholders")
metadata: SkillMetadata = Field(default_factory=SkillMetadata, description="Additional metadata for this skill")

View File

@ -0,0 +1,113 @@
from typing import Any
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
from core.tools.entities.tool_entities import ToolProviderType
class ToolFieldConfig(BaseModel):
model_config = ConfigDict(extra="forbid")
id: str
value: Any
auto: bool = False
class ToolConfiguration(BaseModel):
model_config = ConfigDict(extra="forbid")
fields: list[ToolFieldConfig] = Field(
default_factory=list, description="List of field configurations for this tool"
)
def default_values(self) -> dict[str, Any]:
return {field.id: field.value for field in self.fields if field.value is not None}
def create_tool_id(provider: str, tool_name: str) -> str:
return f"{provider}.{tool_name}"
class ToolReference(BaseModel):
model_config = ConfigDict(extra="forbid")
uuid: str = Field(
default="",
description=(
"Unique identifier for this tool reference, used to distinguish multiple references to the same tool"
),
)
type: ToolProviderType = Field(description="The provider type of the tool")
provider: str = Field(
default="",
description="The provider name of the tool plugin. Can be inferred from placeholders during compilation.",
)
tool_name: str = Field(
default="",
description=(
"The tool name defined in the provider plugin. Can be inferred from placeholders during compilation."
),
)
enabled: bool = Field(default=True, description="Whether this tool reference is enabled")
credential_id: str | None = Field(
default=None,
description="Credential ID used to resolve credentials when invoking the tool.",
)
configuration: ToolConfiguration | None = Field(
default=None,
description=(
"Optional configuration for this tool reference, used to provide "
"additional parameters when invoking the tool"
),
)
def reference_id(self) -> str:
return f"{self.provider}.{self.tool_name}.{self.uuid}"
def tool_id(self) -> str:
return f"{self.provider}.{self.tool_name}"
class FileReference(BaseModel):
model_config = ConfigDict(frozen=True)
source: str = Field(default="app")
asset_id: str
@model_validator(mode="before")
@classmethod
def normalize_input(cls, data: Any) -> Any:
if not isinstance(data, dict):
return data
if "asset_id" in data and "source" in data:
return {"source": data.get("source", "app"), "asset_id": data["asset_id"]}
# front end support
if "id" in data:
return {"source": "app", "asset_id": data["id"]}
return data
class SkillMetadata(BaseModel):
model_config = ConfigDict(extra="forbid")
tools: dict[str, ToolReference] = Field(default_factory=dict)
files: set[FileReference] = Field(default_factory=set)
@field_validator("files", mode="before")
@classmethod
def coerce_files_to_set(cls, v: Any) -> set[FileReference] | Any:
if isinstance(v, list):
refs: set[FileReference] = set()
for item in v:
if isinstance(item, dict):
refs.add(FileReference.model_validate(item))
elif isinstance(item, FileReference):
refs.add(item)
return refs
if isinstance(v, dict):
refs = set()
for item in v.values():
if isinstance(item, dict):
refs.add(FileReference.model_validate(item))
return refs
return v

View File

@ -0,0 +1,145 @@
from collections.abc import Mapping
from pydantic import BaseModel, ConfigDict, Field
from core.skill.entities.tool_dependencies import ToolDependencies
from core.tools.entities.tool_entities import ToolProviderType
class ToolDescription(BaseModel):
"""Immutable identifier for a tool (type + provider + name)."""
model_config = ConfigDict(frozen=True)
tool_type: ToolProviderType
provider: str
tool_name: str
def tool_id(self) -> str:
return f"{self.tool_type.value}:{self.provider}:{self.tool_name}"
class ToolAccessDescription(BaseModel):
"""
Per-tool access descriptor that bundles identity with allowed credentials.
Each allowed tool is represented by exactly one ``ToolAccessDescription``.
``allowed_credentials`` captures the set of credential IDs that may be used
when invoking this tool:
    * **empty set**: the tool requires no special credential; only requests
      *without* a ``credential_id`` are accepted.
    * **non-empty set**: the tool requires an explicit credential; the
      request's ``credential_id`` must be a member of this set.
"""
model_config = ConfigDict(frozen=True)
tool_type: ToolProviderType
provider: str
tool_name: str
allowed_credentials: frozenset[str] = Field(default_factory=frozenset)
def tool_id(self) -> str:
return f"{self.tool_type.value}:{self.provider}:{self.tool_name}"
def is_credential_allowed(self, credential_id: str | None) -> bool:
"""Check whether *credential_id* satisfies this tool's credential policy.
* No credentials registered (``allowed_credentials`` is empty)
only requests *without* a credential are accepted.
* Credentials registered the supplied ``credential_id`` must be in
the set.
"""
if credential_id is None or credential_id == "":
return True
return credential_id in self.allowed_credentials
class ToolInvocationRequest(BaseModel):
"""A request to invoke a specific tool with optional credential."""
model_config = ConfigDict(frozen=True)
tool_type: ToolProviderType
provider: str
tool_name: str
credential_id: str | None = None
@property
def tool_description(self) -> ToolDescription:
return ToolDescription(tool_type=self.tool_type, provider=self.provider, tool_name=self.tool_name)
class ToolAccessPolicy(BaseModel):
"""
Determines whether a tool invocation is allowed based on ToolDependencies.
    The policy is built exclusively from ``ToolDependencies.references``: each
    ``ToolReference`` declares both the tool identity *and* the credential that
    may be used. ``ToolDependencies.dependencies`` is a de-duplicated identity
    list and does not participate in access-control decisions.
Rules:
1. The tool must appear in at least one reference.
2. If references for the tool carry credential IDs, the request must supply
one of those exact IDs.
3. If no reference for the tool carries a credential ID, the request must
*not* supply one (use default/ambient credentials).
"""
model_config = ConfigDict(frozen=True)
access_map: Mapping[str, ToolAccessDescription] = Field(default_factory=dict)
@classmethod
def from_dependencies(cls, deps: ToolDependencies | None) -> "ToolAccessPolicy":
"""Build a policy from ``ToolDependencies``.
        Only ``deps.references`` are used. Multiple references to the same
        tool are merged: their credential IDs are unioned into a single
        ``ToolAccessDescription.allowed_credentials`` set.
"""
if deps is None or deps.is_empty():
return cls()
# Accumulate credential sets keyed by tool_id so that multiple
# references to the same tool are merged correctly.
credentials_by_tool: dict[str, set[str]] = {}
first_seen: dict[str, tuple[ToolProviderType, str, str]] = {}
for ref in deps.references:
tool_id = f"{ref.type.value}:{ref.provider}:{ref.tool_name}"
if tool_id not in first_seen:
first_seen[tool_id] = (ref.type, ref.provider, ref.tool_name)
credentials_by_tool[tool_id] = set()
if ref.credential_id is not None:
credentials_by_tool[tool_id].add(ref.credential_id)
access_map: dict[str, ToolAccessDescription] = {}
for tool_id, (tool_type, provider, tool_name) in first_seen.items():
access_map[tool_id] = ToolAccessDescription(
tool_type=tool_type,
provider=provider,
tool_name=tool_name,
allowed_credentials=frozenset(credentials_by_tool[tool_id]),
)
return cls(access_map=access_map)
def is_empty(self) -> bool:
return len(self.access_map) == 0
def is_allowed(self, request: ToolInvocationRequest) -> bool:
"""Check if the tool invocation request is allowed."""
# An empty policy (no references declared) permits any invocation.
if self.is_empty():
return True
tool_id = request.tool_description.tool_id()
access_desc = self.access_map.get(tool_id)
if access_desc is None:
return False
return access_desc.is_credential_allowed(request.credential_id)
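# Policy sketch (assumes ToolProviderType exposes a BUILT_IN member, as in
# Dify's tool entities; the IDs below are hypothetical):
#
#     desc = ToolAccessDescription(
#         tool_type=ToolProviderType.BUILT_IN,
#         provider="google",
#         tool_name="search",
#         allowed_credentials=frozenset({"cred-1"}),
#     )
#     desc.is_credential_allowed("cred-1")  # True
#     desc.is_credential_allowed("cred-2")  # False
#     desc.is_credential_allowed(None)      # False: an explicit credential is required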

View File

@ -0,0 +1,78 @@
from pydantic import BaseModel, ConfigDict, Field
from core.skill.entities.skill_metadata import ToolReference
from core.tools.entities.tool_entities import ToolProviderType
class ToolDependency(BaseModel):
model_config = ConfigDict(extra="forbid")
type: ToolProviderType
provider: str
tool_name: str
enabled: bool = True
def tool_id(self) -> str:
return f"{self.provider}.{self.tool_name}"
class ToolDependencies(BaseModel):
model_config = ConfigDict(extra="forbid")
dependencies: list[ToolDependency] = Field(default_factory=list)
references: list[ToolReference] = Field(default_factory=list)
def is_empty(self) -> bool:
return not self.dependencies and not self.references
def filter(self, tools: list[tuple[str, str]]) -> "ToolDependencies":
tool_names = {f"{provider}.{tool_name}" for provider, tool_name in tools}
return ToolDependencies(
dependencies=[
dependency
for dependency in self.dependencies
if f"{dependency.provider}.{dependency.tool_name}" in tool_names
],
references=[
reference
for reference in self.references
if f"{reference.provider}.{reference.tool_name}" in tool_names
],
)
def merge(self, other: "ToolDependencies") -> "ToolDependencies":
dep_map: dict[str, ToolDependency] = {}
for dep in self.dependencies:
key = f"{dep.provider}.{dep.tool_name}"
dep_map[key] = dep
for dep in other.dependencies:
key = f"{dep.provider}.{dep.tool_name}"
if key not in dep_map:
dep_map[key] = dep
ref_map: dict[str, ToolReference] = {}
for ref in self.references:
ref_map[ref.uuid] = ref
for ref in other.references:
if ref.uuid not in ref_map:
ref_map[ref.uuid] = ref
return ToolDependencies(
dependencies=list(dep_map.values()),
references=list(ref_map.values()),
)
def remove_tools(self, tools: list[ToolDependency]) -> "ToolDependencies":
tool_keys = {f"{tool.provider}.{tool.tool_name}" for tool in tools}
return ToolDependencies(
dependencies=[
dependency
for dependency in self.dependencies
if f"{dependency.provider}.{dependency.tool_name}" not in tool_keys
],
references=[
reference
for reference in self.references
if f"{reference.provider}.{reference.tool_name}" not in tool_keys
],
)

View File

@ -0,0 +1,43 @@
import logging
from core.app_assets.storage import AssetPaths
from core.skill.entities.skill_bundle import SkillBundle
from extensions.ext_redis import redis_client
from services.app_asset_service import AppAssetService
logger = logging.getLogger(__name__)
_CACHE_PREFIX = "skill_bundle"
_CACHE_TTL = 86400 # 24 hours
class SkillManager:
@staticmethod
def load_bundle(tenant_id: str, app_id: str, assets_id: str) -> SkillBundle:
cache_key = f"{_CACHE_PREFIX}:{tenant_id}:{app_id}:{assets_id}"
data = redis_client.get(cache_key)
if data:
return SkillBundle.model_validate_json(data)
key = AssetPaths.skill_bundle(tenant_id, app_id, assets_id)
try:
data = AppAssetService.get_storage().load_once(key)
except FileNotFoundError:
logger.exception(
"Skill bundle not found in storage: key=%s, tenant_id=%s, app_id=%s, assets_id=%s",
key,
tenant_id,
app_id,
assets_id,
)
raise
bundle = SkillBundle.model_validate_json(data)
redis_client.setex(cache_key, _CACHE_TTL, bundle.model_dump_json(indent=2).encode("utf-8"))
return bundle
@staticmethod
def save_bundle(tenant_id: str, app_id: str, assets_id: str, bundle: SkillBundle) -> None:
key = AssetPaths.skill_bundle(tenant_id, app_id, assets_id)
AppAssetService.get_storage().save(key, data=bundle.model_dump_json(indent=2).encode("utf-8"))
cache_key = f"{_CACHE_PREFIX}:{tenant_id}:{app_id}:{assets_id}"
redis_client.delete(cache_key)

View File

@ -0,0 +1,208 @@
import contextlib
import logging
import threading
import time
from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor
from core.virtual_environment.__base.entities import CommandResult, CommandStatus
from core.virtual_environment.__base.exec import NotSupportedOperationError
from core.virtual_environment.channel.exec import TransportEOFError
from core.virtual_environment.channel.transport import TransportReadCloser, TransportWriteCloser
logger = logging.getLogger(__name__)
class CommandTimeoutError(Exception):
pass
class CommandCancelledError(Exception):
pass
class CommandFuture:
"""
Lightweight future for command execution.
Mirrors concurrent.futures.Future API with 4 essential methods:
result(), done(), cancel(), cancelled().
    When a command is cancelled or times out, the future asks the provider
    to terminate the underlying process/session before marking itself done.
    """
def __init__(
self,
pid: str,
stdin_transport: TransportWriteCloser,
stdout_transport: TransportReadCloser,
stderr_transport: TransportReadCloser,
poll_status: Callable[[], CommandStatus],
terminate_command: Callable[[], bool] | None = None,
poll_interval: float = 0.1,
):
self._pid = pid
self._stdin_transport = stdin_transport
self._stdout_transport = stdout_transport
self._stderr_transport = stderr_transport
self._poll_status = poll_status
self._terminate_command = terminate_command
self._poll_interval = poll_interval
self._done_event = threading.Event()
self._lock = threading.Lock()
self._result: CommandResult | None = None
self._exception: BaseException | None = None
self._cancelled = False
self._timed_out = False
self._started = False
self._termination_requested = False
def result(self, timeout: float | None = None) -> CommandResult:
"""
Block until command completes and return result.
Args:
timeout: Maximum seconds to wait. None means wait forever.
Raises:
CommandTimeoutError: If timeout exceeded.
CommandCancelledError: If command was cancelled.
        A timeout is terminal for this future: it triggers best-effort command
        termination, and subsequent ``result()`` calls keep raising ``CommandTimeoutError``.
        """
"""
self._ensure_started()
if not self._done_event.wait(timeout):
self._request_stop(timed_out=True)
raise CommandTimeoutError(f"Command timed out after {timeout}s")
if self._cancelled:
raise CommandCancelledError("Command was cancelled")
if self._timed_out:
raise CommandTimeoutError("Command timed out")
if self._exception is not None:
raise self._exception
assert self._result is not None
return self._result
def done(self) -> bool:
self._ensure_started()
return self._done_event.is_set()
def cancel(self) -> bool:
"""
Attempt to cancel command by terminating it and closing transports.
Returns True if cancelled, False if already completed.
"""
return self._request_stop(cancelled=True)
def cancelled(self) -> bool:
return self._cancelled
def _ensure_started(self) -> None:
with self._lock:
if not self._started:
self._started = True
thread = threading.Thread(target=self._execute, daemon=True)
thread.start()
def _request_stop(self, *, cancelled: bool = False, timed_out: bool = False) -> bool:
should_terminate = False
with self._lock:
if self._done_event.is_set():
return False
if cancelled:
self._cancelled = True
if timed_out:
self._timed_out = True
should_terminate = not self._termination_requested
if should_terminate:
self._termination_requested = True
self._close_transports()
self._done_event.set()
if should_terminate:
self._terminate_running_command()
return True
def _execute(self) -> None:
stdout_buf = bytearray()
stderr_buf = bytearray()
is_combined_stream = self._stdout_transport is self._stderr_transport
try:
with ThreadPoolExecutor(max_workers=2) as executor:
stdout_future = executor.submit(self._drain_transport, self._stdout_transport, stdout_buf)
stderr_future = None
if not is_combined_stream:
stderr_future = executor.submit(self._drain_transport, self._stderr_transport, stderr_buf)
exit_code = self._wait_for_completion()
stdout_future.result()
if stderr_future:
stderr_future.result()
with self._lock:
if not self._cancelled:
self._result = CommandResult(
stdout=bytes(stdout_buf),
stderr=b"" if is_combined_stream else bytes(stderr_buf),
exit_code=exit_code,
pid=self._pid,
)
self._done_event.set()
except Exception as e:
logger.exception("Command execution failed for pid %s", self._pid)
with self._lock:
if not self._cancelled:
self._exception = e
self._done_event.set()
finally:
self._close_transports()
def _wait_for_completion(self) -> int | None:
while not self._cancelled and not self._timed_out:
try:
status = self._poll_status()
except NotSupportedOperationError:
return None
if status.status == CommandStatus.Status.COMPLETED:
return status.exit_code
time.sleep(self._poll_interval)
return None
def _drain_transport(self, transport: TransportReadCloser, buffer: bytearray) -> None:
try:
while True:
buffer.extend(transport.read(4096))
except TransportEOFError:
pass
except Exception:
logger.exception("Failed reading transport")
def _close_transports(self) -> None:
for transport in (self._stdin_transport, self._stdout_transport, self._stderr_transport):
with contextlib.suppress(Exception):
transport.close()
def _terminate_running_command(self) -> None:
if self._terminate_command is None:
return
try:
self._terminate_command()
except Exception:
logger.exception("Failed to terminate command for pid %s", self._pid)

View File

@ -0,0 +1,100 @@
from collections.abc import Mapping
from enum import StrEnum
from typing import Any
from pydantic import BaseModel, Field
class Arch(StrEnum):
"""
Architecture types for virtual environments.
"""
ARM64 = "arm64"
AMD64 = "amd64"
class OperatingSystem(StrEnum):
"""
Operating system types for virtual environments.
"""
LINUX = "linux"
DARWIN = "darwin"
class Metadata(BaseModel):
"""
Returned metadata about a virtual environment.
"""
id: str = Field(description="The unique identifier of the virtual environment.")
arch: Arch = Field(description="Which architecture was used to create the virtual environment.")
os: OperatingSystem = Field(description="The operating system of the virtual environment.")
    store: Mapping[str, Any] = Field(
        default_factory=dict, description="Additional provider-specific data stored with the virtual environment."
    )
class ConnectionHandle(BaseModel):
"""
Handle for managing connections to the virtual environment.
"""
id: str = Field(description="The unique identifier of the connection handle.")
class CommandStatus(BaseModel):
"""
Status of a command executed in the virtual environment.
"""
class Status(StrEnum):
RUNNING = "running"
COMPLETED = "completed"
status: Status = Field(description="The status of the command execution.")
exit_code: int | None = Field(description="The return code of the command execution.")
class FileState(BaseModel):
"""
State of a file in the virtual environment.
"""
size: int = Field(description="The size of the file in bytes.")
path: str = Field(description="The path of the file in the virtual environment.")
created_at: int = Field(description="The creation timestamp of the file.")
updated_at: int = Field(description="The last modified timestamp of the file.")
class CommandResult(BaseModel):
"""
Result of a synchronous command execution.
"""
stdout: bytes = Field(description="Standard output content.")
stderr: bytes = Field(description="Standard error content.")
exit_code: int | None = Field(description="Exit code of the command. None if unavailable.")
pid: str = Field(description="Process ID of the executed command.")
@property
def is_error(self) -> bool:
        return self.exit_code not in (None, 0) or bool(self.stderr)
@property
def error_message(self) -> str:
return self.stderr.decode("utf-8", errors="replace") if self.stderr else ""
@property
def info_message(self) -> str:
return self.stdout.decode("utf-8", errors="replace") if self.stdout else ""
@property
def debug_message(self) -> str:
return (
f"stdout: {self.stdout.decode('utf-8', errors='replace')}\n"
f"stderr: {self.stderr.decode('utf-8', errors='replace')}\n"
f"exit_code: {self.exit_code}\n"
f"pid: {self.pid}"
)
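# Interpretation sketch: exit_code and stderr both feed is_error, so any
# stderr output marks the result as an error even when exit_code is 0.
#
#     ok = CommandResult(stdout=b"done\n", stderr=b"", exit_code=0, pid="1")
#     ok.is_error        # False
#     bad = CommandResult(stdout=b"", stderr=b"boom", exit_code=1, pid="2")
#     bad.is_error       # True
#     bad.error_message  # "boom"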

View File

@ -0,0 +1,64 @@
from __future__ import annotations
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from core.virtual_environment.__base.entities import CommandResult
class ArchNotSupportedError(Exception):
"""Exception raised when the architecture is not supported."""
pass
class VirtualEnvironmentLaunchFailedError(Exception):
"""Exception raised when launching the virtual environment fails."""
pass
class NotSupportedOperationError(Exception):
"""Exception raised when an operation is not supported."""
pass
class SandboxConfigValidationError(ValueError):
"""Exception raised when sandbox configuration validation fails."""
pass
class CommandExecutionError(ValueError):
"""Raised when a command execution fails."""
result: CommandResult
def __init__(self, message: str, result: CommandResult):
super().__init__(message)
self.result = result
@property
def exit_code(self) -> int | None:
return self.result.exit_code
@property
def stderr(self) -> bytes:
return self.result.stderr
class PipelineExecutionError(CommandExecutionError):
"""Raised when a pipeline command fails in strict mode."""
index: int
command: list[str]
results: list[CommandResult]
def __init__(
self, message: str, result: CommandResult, *, index: int, command: list[str], results: list[CommandResult]
):
super().__init__(message, result)
self.index = index
self.command = command
self.results = results

View File

@ -0,0 +1,279 @@
from __future__ import annotations
import contextlib
import shlex
from collections.abc import Generator, Mapping
from contextlib import contextmanager
from dataclasses import dataclass, field
from functools import partial
from core.virtual_environment.__base.command_future import CommandFuture
from core.virtual_environment.__base.entities import CommandResult, ConnectionHandle
from core.virtual_environment.__base.exec import CommandExecutionError, PipelineExecutionError
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
_PIPE_SENTINEL = "__DIFY_PIPE__"
@contextmanager
def with_connection(env: VirtualEnvironment) -> Generator[ConnectionHandle, None, None]:
"""Context manager for VirtualEnvironment connection lifecycle.
Automatically establishes and releases connection handles.
Usage:
with with_connection(env) as conn:
            future = submit_command(env, conn, ["echo", "hello"])
result = future.result(timeout=10)
"""
connection_handle = env.establish_connection()
try:
yield connection_handle
finally:
with contextlib.suppress(Exception):
env.release_connection(connection_handle)
def submit_command(
env: VirtualEnvironment,
connection: ConnectionHandle,
command: list[str],
environments: Mapping[str, str] | None = None,
*,
cwd: str | None = None,
) -> CommandFuture:
"""Execute a command and return a Future for the result.
High-level interface that handles IO draining internally.
For streaming output, use env.execute_command() instead.
Args:
env: The virtual environment to execute the command in.
connection: The connection handle.
command: Command as list of strings.
environments: Environment variables.
cwd: Working directory for the command. If None, uses the provider's default.
Returns:
CommandFuture that can be used to get result with timeout or cancel.
Example:
with with_connection(env) as conn:
            result = submit_command(env, conn, ["ls", "-la"]).result(timeout=30)
"""
pid, stdin_transport, stdout_transport, stderr_transport = env.execute_command(
connection, command, environments, cwd
)
return CommandFuture(
pid=pid,
stdin_transport=stdin_transport,
stdout_transport=stdout_transport,
stderr_transport=stderr_transport,
poll_status=partial(env.get_command_status, connection, pid),
terminate_command=partial(env.terminate_command, connection, pid),
)
def _execute_with_connection(
env: VirtualEnvironment,
conn: ConnectionHandle,
command: list[str],
timeout: float | None,
cwd: str | None,
) -> CommandResult:
"""Internal helper to execute command with given connection."""
future = submit_command(env, conn, command, cwd=cwd)
return future.result(timeout=timeout)
def execute(
env: VirtualEnvironment,
command: list[str],
*,
timeout: float | None = 30,
cwd: str | None = None,
error_message: str = "Command failed",
connection: ConnectionHandle | None = None,
) -> CommandResult:
"""Execute a command with automatic connection management.
Raises CommandExecutionError if the command fails (non-zero exit code).
Args:
env: The virtual environment to execute the command in.
command: The command to execute as a list of strings.
timeout: Maximum time to wait for the command to complete (seconds).
cwd: Working directory for the command.
error_message: Custom error message prefix for failures.
connection: Optional connection handle to reuse. If None, creates and releases a new connection.
Returns:
CommandResult on success.
Raises:
CommandExecutionError: If the command fails.
"""
if connection is not None:
result = _execute_with_connection(env, connection, command, timeout, cwd)
else:
with with_connection(env) as conn:
result = _execute_with_connection(env, conn, command, timeout, cwd)
if result.is_error:
raise CommandExecutionError(f"{error_message}: {result.error_message}", result)
return result
def try_execute(
env: VirtualEnvironment,
command: list[str],
*,
timeout: float | None = 30,
cwd: str | None = None,
connection: ConnectionHandle | None = None,
) -> CommandResult:
"""Execute a command with automatic connection management.
Does not raise on failure - returns the result for caller to handle.
Args:
env: The virtual environment to execute the command in.
command: The command to execute as a list of strings.
timeout: Maximum time to wait for the command to complete (seconds).
cwd: Working directory for the command.
connection: Optional connection handle to reuse. If None, creates and releases a new connection.
Returns:
CommandResult containing stdout, stderr, and exit_code.
"""
if connection is not None:
return _execute_with_connection(env, connection, command, timeout, cwd)
with with_connection(env) as conn:
return _execute_with_connection(env, conn, command, timeout, cwd)
@dataclass(frozen=True)
class _PipelineStep:
argv: list[str]
error_message: str = "Command failed"
@dataclass
class CommandPipeline:
"""Batch multiple commands into a single shell execution (Redis pipeline style).
Example:
results = pipeline(env).add(["echo", "hi"]).add(["ls"]).execute()
# Strict mode: raise on first failure
pipeline(env).add(["mkdir", "/x"], error_message="mkdir failed").execute(raise_on_error=True)
"""
env: VirtualEnvironment
connection: ConnectionHandle | None = None
cwd: str | None = None
environments: Mapping[str, str] | None = None
_steps: list[_PipelineStep] = field(default_factory=list) # pyright: ignore[reportUnknownVariableType]
def add(self, command: list[str], *, error_message: str = "Command failed", on: bool = True) -> CommandPipeline:
if on:
self._steps.append(_PipelineStep(argv=command, error_message=error_message))
return self
def execute(self, *, timeout: float | None = 30, raise_on_error: bool = False) -> list[CommandResult]:
if not self._steps:
return []
script = self._build_script(fail_fast=raise_on_error)
batch_cmd = ["sh", "-c", script]
if self.connection is not None:
batch_result = try_execute(self.env, batch_cmd, timeout=timeout, cwd=self.cwd, connection=self.connection)
else:
with with_connection(self.env) as conn:
batch_result = try_execute(self.env, batch_cmd, timeout=timeout, cwd=self.cwd, connection=conn)
results = self._parse_results(batch_result.stdout, batch_result.pid)
if raise_on_error:
            for i, r in enumerate(results):
if r.is_error:
step = self._steps[i]
raise PipelineExecutionError(
f"{step.error_message}: {r.error_message}",
r,
index=i,
command=step.argv,
results=results,
)
return results
def _build_script(self, *, fail_fast: bool = False) -> str:
lines = [
"run() {",
' i="$1"; shift',
' out="$(mktemp)"; err="$(mktemp)"',
' ("$@") >"$out" 2>"$err"; ec="$?"',
' os="$(wc -c <"$out" | tr -d \' \')"',
' es="$(wc -c <"$err" | tr -d \' \')"',
f' printf \'{_PIPE_SENTINEL} %s %s %s %s\\n\' "$i" "$ec" "$os" "$es"',
' cat "$out"',
' cat "$err"',
' rm -f "$out" "$err"',
' return "$ec"',
"}",
]
suffix = " || exit $?" if fail_fast else ""
for i, step in enumerate(self._steps):
quoted = " ".join(shlex.quote(arg) for arg in step.argv)
lines.append(f"run {i} {quoted}{suffix}")
return "\n".join(lines)
@staticmethod
def _parse_results(stdout: bytes, pid: str) -> list[CommandResult]:
results: list[CommandResult] = []
pos = 0
sentinel = _PIPE_SENTINEL.encode() + b" "
while pos < len(stdout):
nl = stdout.find(b"\n", pos)
if nl == -1:
break
header = stdout[pos : nl + 1]
pos = nl + 1
if not header.startswith(sentinel):
raise ValueError("Malformed pipeline output: missing sentinel")
parts = header.decode().strip().split(" ")
_, idx, ec, os_len, es_len = parts
out_len, err_len = int(os_len), int(es_len)
out_bytes = stdout[pos : pos + out_len]
pos += out_len
err_bytes = stdout[pos : pos + err_len]
pos += err_len
results.append(
CommandResult(
stdout=out_bytes,
stderr=err_bytes,
exit_code=int(ec),
pid=f"{pid}:{idx}",
)
)
return results
def pipeline(
env: VirtualEnvironment,
connection: ConnectionHandle | None = None,
*,
cwd: str | None = None,
environments: Mapping[str, str] | None = None,
) -> CommandPipeline:
return CommandPipeline(env=env, connection=connection, cwd=cwd, environments=environments)
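# Wire-format sketch: each step's output is framed by a sentinel header of
# "<sentinel> <index> <exit_code> <stdout_len> <stderr_len>", followed by the
# raw stdout and stderr bytes; _parse_results can be exercised directly:
#
#     raw = b"__DIFY_PIPE__ 0 0 3 0\nhi\n__DIFY_PIPE__ 1 1 0 4\nboom"
#     CommandPipeline._parse_results(raw, pid="42")
#     # -> [CommandResult(stdout=b"hi\n", stderr=b"", exit_code=0, pid="42:0"),
#     #     CommandResult(stdout=b"", stderr=b"boom", exit_code=1, pid="42:1")]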

View File

@ -0,0 +1,231 @@
from abc import ABC, abstractmethod
from collections.abc import Mapping, Sequence
from io import BytesIO
from typing import Any
from core.entities.provider_entities import BasicProviderConfig
from core.virtual_environment.__base.entities import CommandStatus, ConnectionHandle, FileState, Metadata
from core.virtual_environment.channel.transport import TransportReadCloser, TransportWriteCloser
class VirtualEnvironment(ABC):
"""
Base class for virtual environment implementations.
``VirtualEnvironment`` instances are configured at construction time but do
not allocate provider resources until ``open_enviroment()`` is called.
This keeps object construction side-effect free and gives callers a chance
to own startup error handling explicitly.
"""
tenant_id: str
user_id: str | None
options: Mapping[str, Any]
_environments: Mapping[str, str]
_metadata: Metadata | None
def __init__(
self,
tenant_id: str,
options: Mapping[str, Any],
environments: Mapping[str, str] | None = None,
user_id: str | None = None,
) -> None:
"""
Initialize the virtual environment configuration.
Args:
tenant_id: The tenant ID associated with this environment (required).
options: Provider-specific configuration options.
environments: Environment variables to set in the virtual environment.
user_id: The user ID associated with this environment (optional).
The provider runtime itself is created later by ``open_enviroment()``.
"""
self.tenant_id = tenant_id
self.user_id = user_id
self.options = options
self._environments = dict(environments or {})
self._metadata = None
@property
def metadata(self) -> Metadata:
"""Provider metadata for a started environment.
Raises:
RuntimeError: If the environment has not been started yet.
"""
if self._metadata is None:
raise RuntimeError("Virtual environment has not been started")
return self._metadata
def open_enviroment(self) -> Metadata:
"""Allocate provider resources and return the resulting metadata.
Multiple calls are safe and return the existing metadata after the first
successful start.
"""
if self._metadata is None:
self._metadata = self._construct_environment(self.options, self._environments)
return self._metadata
@abstractmethod
def _construct_environment(self, options: Mapping[str, Any], environments: Mapping[str, str]) -> Metadata:
"""
Construct the unique identifier for the virtual environment.
Returns:
str: The unique identifier of the virtual environment.
"""
@abstractmethod
def upload_file(self, path: str, content: BytesIO) -> None:
"""
Upload a file to the virtual environment.
Args:
path (str): The destination path in the virtual environment.
content (BytesIO): The content of the file to upload.
Raises:
Exception: If the file cannot be uploaded.
"""
@abstractmethod
def download_file(self, path: str) -> BytesIO:
"""
Download a file from the virtual environment.
Args:
path (str): The source path in the virtual environment.
Returns:
BytesIO: The content of the downloaded file.
Raises:
Exception: If the file cannot be downloaded.
"""
@abstractmethod
def list_files(self, directory_path: str, limit: int) -> Sequence[FileState]:
"""
List files in a directory of the virtual environment.
Args:
directory_path (str): The directory path in the virtual environment.
limit (int): The maximum number of files (including recursive paths) to return.
Returns:
Sequence[FileState]: A list of file states in the specified directory.
Raises:
Exception: If the files cannot be listed.
Example:
If the directory structure is like:
/dir
/subdir1
file1.txt
/subdir2
file2.txt
And limit is 2, the returned list may look like:
[
FileState(path="/dir/subdir1/file1.txt", is_directory=False, size=1234, created_at=..., updated_at=...),
FileState(path="/dir/subdir2", is_directory=True, size=0, created_at=..., updated_at=...),
]
"""
@abstractmethod
def establish_connection(self) -> ConnectionHandle:
"""
Establish a connection to the virtual environment.
Returns:
ConnectionHandle: Handle for managing the connection to the virtual environment.
Raises:
Exception: If the connection cannot be established.
"""
@abstractmethod
def release_connection(self, connection_handle: ConnectionHandle) -> None:
"""
Release the connection to the virtual environment.
Args:
connection_handle (ConnectionHandle): The handle for managing the connection.
Raises:
Exception: If the connection cannot be released.
"""
@abstractmethod
def release_environment(self) -> None:
"""
Release the virtual environment.
Calling `release_environment` multiple times is acceptable.
Raises:
Exception: If the environment cannot be released.
"""
def terminate_command(self, connection_handle: ConnectionHandle, pid: str) -> bool:
"""Best-effort termination hook for a running command.
Providers that can map ``pid`` back to a real process/session should
override this method and stop the command. The default implementation is
a no-op so providers without a termination mechanism remain compatible.
"""
_ = connection_handle
_ = pid
return False
@abstractmethod
def execute_command(
self,
connection_handle: ConnectionHandle,
command: list[str],
environments: Mapping[str, str] | None = None,
cwd: str | None = None,
) -> tuple[str, TransportWriteCloser, TransportReadCloser, TransportReadCloser]:
"""
Execute a command in the virtual environment.
Args:
connection_handle (ConnectionHandle): The handle for managing the connection.
command (list[str]): The command to execute as a list of strings.
environments (Mapping[str, str] | None): Environment variables for the command.
cwd (str | None): Working directory for the command. If None, uses the provider's default.
Returns:
tuple[str, TransportWriteCloser, TransportReadCloser, TransportReadCloser]:
A tuple containing the pid and three transport handles: (stdin, stdout, stderr).
After execution, the three handles are closed by the caller.
"""
@classmethod
@abstractmethod
def validate(cls, options: Mapping[str, Any]) -> None:
"""
Validate that options can connect to the provider.
Raises:
SandboxConfigValidationError: If validation fails
"""
@abstractmethod
def get_command_status(self, connection_handle: ConnectionHandle, pid: str) -> CommandStatus:
"""
Get the status of a command executed in the virtual environment.
Args:
connection_handle (ConnectionHandle): The handle for managing the connection.
pid (str): The process ID of the command.
Returns:
CommandStatus: The status of the command execution.
"""
@classmethod
@abstractmethod
def get_config_schema(cls) -> list[BasicProviderConfig]:
pass
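# A minimal lifecycle sketch (assumes a concrete provider from this commit;
# the provider name and options below are illustrative):
#
#     env = DockerDaemonEnvironment(tenant_id="tenant", options={})
#     env.open_environment()                      # allocates resources, returns Metadata
#     handle = env.establish_connection()
#     pid, stdin, stdout, stderr = env.execute_command(handle, ["echo", "hi"])
#     ...                                         # read stdout/stderr until TransportEOFError
#     env.release_connection(handle)
#     env.release_environment()                   # safe to call more than once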

View File

View File

@ -0,0 +1,4 @@
class TransportEOFError(Exception):
"""Exception raised when attempting to read from a closed transport."""
pass

View File

@ -0,0 +1,72 @@
import os
from core.virtual_environment.channel.exec import TransportEOFError
from core.virtual_environment.channel.transport import Transport, TransportReadCloser, TransportWriteCloser
class PipeTransport(Transport):
"""
A Transport implementation using OS pipes. It requires two file descriptors:
one for reading and one for writing.
NOTE: r_fd and w_fd must be a pair created by os.pipe() or returned from subprocess.Popen.
NEVER FORGET TO CALL THE `close()` METHOD, TO AVOID FILE DESCRIPTOR LEAKAGE.
"""
def __init__(self, r_fd: int, w_fd: int):
self.r_fd = r_fd
self.w_fd = w_fd
def write(self, data: bytes) -> None:
try:
os.write(self.w_fd, data)
except OSError:
raise TransportEOFError("Pipe write error, maybe the read end is closed")
def read(self, n: int) -> bytes:
data = os.read(self.r_fd, n)
if data == b"":
raise TransportEOFError("End of Pipe reached")
return data
def close(self) -> None:
os.close(self.r_fd)
os.close(self.w_fd)
class PipeReadCloser(TransportReadCloser):
"""
A Transport implementation using OS pipe for reading.
"""
def __init__(self, r_fd: int):
self.r_fd = r_fd
def read(self, n: int) -> bytes:
data = os.read(self.r_fd, n)
if data == b"":
raise TransportEOFError("End of Pipe reached")
return data
def close(self) -> None:
os.close(self.r_fd)
class PipeWriteCloser(TransportWriteCloser):
"""
A Transport implementation using OS pipe for writing.
"""
def __init__(self, w_fd: int):
self.w_fd = w_fd
def write(self, data: bytes) -> None:
try:
os.write(self.w_fd, data)
except OSError:
raise TransportEOFError("Pipe write error, maybe the read end is closed")
def close(self) -> None:
os.close(self.w_fd)
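# Wiring sketch: pair the transport with a fresh os.pipe(); data written to
# the write end is read back from the read end.
#
#     import os
#
#     r_fd, w_fd = os.pipe()
#     transport = PipeTransport(r_fd, w_fd)
#     transport.write(b"ping")
#     assert transport.read(4) == b"ping"
#     transport.close()  # always close, or the file descriptors leak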

View File

@ -0,0 +1,117 @@
from queue import Empty, Queue
from core.virtual_environment.channel.exec import TransportEOFError
from core.virtual_environment.channel.transport import TransportReadCloser
class QueueTransportReadCloser(TransportReadCloser):
"""
Transport implementation using queues for inter-thread communication.
Usage:
q_transport = QueueTransportReadCloser()
write_handler = q_transport.get_write_handler()
# In writer thread
write_handler.write(b"data")
# In reader thread
data = q_transport.read(1024)
# Close transport when done
q_transport.close()
"""
_QUEUE_GET_TIMEOUT = 5.0
class WriteHandler:
"""
A write handler that writes data to a queue.
"""
def __init__(self, queue: Queue[bytes | None]) -> None:
self.queue = queue
def write(self, data: bytes) -> None:
self.queue.put(data)
def __init__(
self,
) -> None:
"""
Initialize the QueueTransportReadCloser with write function.
"""
self.q = Queue[bytes | None]()
self._read_buffer = bytearray()
self._closed = False
self._write_channel_closed = False
def get_write_handler(self) -> WriteHandler:
"""
Get a write handler that writes to the internal queue.
"""
return QueueTransportReadCloser.WriteHandler(self.q)
def close(self) -> None:
"""
Close the transport by putting a sentinel value in the queue.
"""
if self._write_channel_closed:
raise TransportEOFError("Write channel already closed")
self._write_channel_closed = True
self.q.put(None)
def read(self, n: int) -> bytes:
"""
Read up to n bytes from the queue.
NEVER USE IT IN A MULTI-THREADED CONTEXT WITHOUT PROPER SYNCHRONIZATION.
"""
if n <= 0:
return b""
if self._closed:
raise TransportEOFError("Transport is closed")
to_return = self._drain_buffer(n)
# On the first round of reading from the queue, block to wait for data;
# after that, return immediately once the queue is empty.
rounds = 0
while len(to_return) < n and not self._closed and (self.q.qsize() > 0 or rounds == 0):
try:
chunk = self.q.get(timeout=self._QUEUE_GET_TIMEOUT)
except Empty:
if self._closed:
raise TransportEOFError("Transport is closed")
continue
if chunk is None:
self._closed = True
if len(to_return) == 0:
raise TransportEOFError("Transport is closed")
else:
break
self._read_buffer.extend(chunk)
if n - len(to_return) > 0:
# Drain the buffer if we still need more data
to_return += self._drain_buffer(n - len(to_return))
else:
# No more data needed, break
break
rounds += 1
return to_return
def _drain_buffer(self, n: int) -> bytes:
data = bytes(self._read_buffer[:n])
del self._read_buffer[:n]
return data
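# Round-trip sketch between two threads (timing and payloads are illustrative):
#
#     import threading
#
#     transport = QueueTransportReadCloser()
#     writer = transport.get_write_handler()
#
#     def produce() -> None:
#         writer.write(b"hello ")
#         writer.write(b"world")
#         transport.close()  # enqueue the None sentinel for the reader
#
#     threading.Thread(target=produce, daemon=True).start()
#     data = transport.read(64)  # the first read blocks until data arrives
#     # data contains b"hello " and possibly b"world", depending on timing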

View File

@ -0,0 +1,70 @@
import socket
from core.virtual_environment.channel.exec import TransportEOFError
from core.virtual_environment.channel.transport import Transport, TransportReadCloser, TransportWriteCloser
class SocketTransport(Transport):
"""
A Transport implementation using a socket.
"""
def __init__(self, sock: socket.SocketIO):
self.sock = sock
def write(self, data: bytes) -> None:
try:
self.sock.write(data)
except (ConnectionResetError, BrokenPipeError):
raise TransportEOFError("Socket write error, maybe the read end is closed")
def read(self, n: int) -> bytes:
try:
data = self.sock.read(n)
if data == b"":
raise TransportEOFError("End of Socket reached")
except (ConnectionResetError, BrokenPipeError):
raise TransportEOFError("Socket connection reset")
return data
def close(self) -> None:
self.sock.close()
class SocketReadCloser(TransportReadCloser):
"""
A Transport implementation using a socket for reading.
"""
def __init__(self, sock: socket.SocketIO):
self.sock = sock
def read(self, n: int) -> bytes:
try:
data = self.sock.read(n)
if data == b"":
raise TransportEOFError("End of Socket reached")
return data
except (ConnectionResetError, BrokenPipeError):
raise TransportEOFError("Socket connection reset")
def close(self) -> None:
self.sock.close()
class SocketWriteCloser(TransportWriteCloser):
"""
A Transport implementation using a socket for writing.
"""
def __init__(self, sock: socket.SocketIO):
self.sock = sock
def write(self, data: bytes) -> None:
try:
self.sock.write(data)
except (ConnectionResetError, BrokenPipeError):
raise TransportEOFError("Socket write error, maybe the read end is closed")
def close(self) -> None:
self.sock.close()
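# Wiring sketch using a socketpair; socket.SocketIO is the same raw I/O
# wrapper this module receives from docker-py's attached exec sockets.
#
#     import socket
#
#     a, b = socket.socketpair()
#     transport = SocketTransport(socket.SocketIO(a, "rwb"))
#     b.sendall(b"pong")
#     assert transport.read(4) == b"pong"
#     transport.close()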

View File

@ -0,0 +1,80 @@
from abc import abstractmethod
from typing import Protocol
class TransportCloser(Protocol):
"""
Transport that can be closed.
"""
@abstractmethod
def close(self) -> None:
"""
Close the transport.
"""
class TransportWriter(Protocol):
"""
Transport that can be written to.
"""
@abstractmethod
def write(self, data: bytes) -> None:
"""
Write data to the transport.
Raises TransportEOFError if the transport is closed.
"""
class TransportReader(Protocol):
"""
Transport that can be read from.
"""
@abstractmethod
def read(self, n: int) -> bytes:
"""
Read up to n bytes from the transport.
Raises TransportEOFError if the end of the transport is reached.
"""
class TransportReadCloser(TransportReader, TransportCloser):
"""
Transport that can be read from and closed.
"""
class TransportWriteCloser(TransportWriter, TransportCloser):
"""
Transport that can be written to and closed.
"""
class Transport(TransportReader, TransportWriter, TransportCloser):
"""
Transport that can be read from, written to, and closed.
"""
class NopTransportWriteCloser(TransportWriteCloser):
"""
A no-operation TransportWriteCloser implementation.
This transport does nothing on write and close operations.
"""
def write(self, data: bytes) -> None:
"""
No-operation write method.
"""
pass
def close(self) -> None:
"""
No-operation close method.
"""
pass
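# Because these are typing.Protocols, implementations need no inheritance.
# A minimal in-memory reader sketch (BytesReadCloser is illustrative):
#
#     from io import BytesIO
#
#     from core.virtual_environment.channel.exec import TransportEOFError
#
#     class BytesReadCloser:
#         def __init__(self, data: bytes) -> None:
#             self._buf = BytesIO(data)
#
#         def read(self, n: int) -> bytes:
#             chunk = self._buf.read(n)
#             if chunk == b"":
#                 raise TransportEOFError("End of buffer reached")
#             return chunk
#
#         def close(self) -> None:
#             self._buf.close()
#
#     reader: TransportReadCloser = BytesReadCloser(b"abc")  # structurally typed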

View File

@ -0,0 +1,10 @@
"""
Constants for virtual environment providers.
Centralizes timeout and other configuration values used across different sandbox providers
(E2B, SSH, Docker) to ensure consistency and ease of maintenance.
"""
# Command execution timeout in seconds (5 hours)
# Used by providers to limit how long a single command can run
COMMAND_EXECUTION_TIMEOUT_SECONDS = 5 * 60 * 60 # 18000 seconds

View File

@ -0,0 +1,531 @@
"""
AWS Bedrock AgentCore Code Interpreter sandbox provider.
Uses the AgentCore Code Interpreter built-in tool to provide a sandboxed code execution
environment with shell access and file operations. The Code Interpreter runs in an
isolated microVM managed by AWS and communicates via the `InvokeCodeInterpreter` API.
Two boto3 clients are involved:
- **bedrock-agentcore-control** (Control Plane): manages the Code Interpreter resource
(create / delete). Users must create one beforehand or use the system-provided
``aws.codeinterpreter.v1``.
- **bedrock-agentcore** (Data Plane): manages sessions and executes operations
(start/stop session, execute commands, file I/O).
Key differences from other providers:
- stdin is not supported (same as E2B); ``NopTransportWriteCloser`` is used instead.
- ``executeCommand`` returns the full stdout/stderr once the command completes,
rather than streaming incrementally. We wrap the result with
``QueueTransportReadCloser`` so the upper-layer ``CommandFuture`` works unchanged.
- ``get_command_status`` raises ``NotSupportedOperationError`` (same as E2B).
The synchronous ``executeCommand`` path is used instead.
"""
import logging
import posixpath
import shlex
import threading
from collections.abc import Mapping, Sequence
from enum import StrEnum
from io import BytesIO
from typing import Any
from uuid import uuid4
from core.entities.provider_entities import BasicProviderConfig
from core.virtual_environment.__base.entities import (
Arch,
CommandStatus,
ConnectionHandle,
FileState,
Metadata,
OperatingSystem,
)
from core.virtual_environment.__base.exec import (
ArchNotSupportedError,
NotSupportedOperationError,
SandboxConfigValidationError,
)
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
from core.virtual_environment.channel.queue_transport import QueueTransportReadCloser
from core.virtual_environment.channel.transport import (
NopTransportWriteCloser,
TransportReadCloser,
TransportWriteCloser,
)
from core.virtual_environment.constants import COMMAND_EXECUTION_TIMEOUT_SECONDS
logger = logging.getLogger(__name__)
# Maximum time in seconds that the boto3 read on the EventStream socket is
# allowed to block. Must exceed the longest expected command execution.
_BOTO3_READ_TIMEOUT_SECONDS = COMMAND_EXECUTION_TIMEOUT_SECONDS + 60
class AWSCodeInterpreterEnvironment(VirtualEnvironment):
"""
AWS Bedrock AgentCore Code Interpreter virtual environment provider.
The provider maps the ``VirtualEnvironment`` protocol onto the AgentCore
Code Interpreter Data-Plane API (``InvokeCodeInterpreter``).
Lifecycle:
1. ``_construct_environment`` starts a new session via
``StartCodeInterpreterSession``.
2. Commands and file operations invoke ``InvokeCodeInterpreter`` with
the appropriate ``name`` (``executeCommand``, ``writeFiles``, etc.).
3. ``release_environment`` stops the session via
``StopCodeInterpreterSession``.
Configuration (``OptionsKey``):
- ``aws_access_key_id`` / ``aws_secret_access_key``: IAM credentials.
- ``aws_region``: AWS region (e.g. ``us-east-1``).
- ``code_interpreter_id``: the Code Interpreter resource identifier
(e.g. ``aws.codeinterpreter.v1`` for the system-provided one).
- ``session_timeout_seconds``: optional; defaults to 900 (15 min).
"""
_WORKDIR = "/home/user"
class OptionsKey(StrEnum):
AWS_ACCESS_KEY_ID = "aws_access_key_id"
AWS_SECRET_ACCESS_KEY = "aws_secret_access_key"
AWS_REGION = "aws_region"
CODE_INTERPRETER_ID = "code_interpreter_id"
SESSION_TIMEOUT_SECONDS = "session_timeout_seconds"
class StoreKey(StrEnum):
CLIENT = "client"
SESSION_ID = "session_id"
CODE_INTERPRETER_ID = "code_interpreter_id"
# ------------------------------------------------------------------
# Config schema & validation
# ------------------------------------------------------------------
@classmethod
def get_config_schema(cls) -> list[BasicProviderConfig]:
return [
BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=cls.OptionsKey.AWS_ACCESS_KEY_ID),
BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=cls.OptionsKey.AWS_SECRET_ACCESS_KEY),
BasicProviderConfig(type=BasicProviderConfig.Type.TEXT_INPUT, name=cls.OptionsKey.AWS_REGION),
BasicProviderConfig(type=BasicProviderConfig.Type.TEXT_INPUT, name=cls.OptionsKey.CODE_INTERPRETER_ID),
]
@classmethod
def validate(cls, options: Mapping[str, Any]) -> None:
"""Validate credentials by starting then immediately stopping a session."""
import boto3
from botocore.exceptions import ClientError
for key in (cls.OptionsKey.AWS_ACCESS_KEY_ID, cls.OptionsKey.AWS_SECRET_ACCESS_KEY, cls.OptionsKey.AWS_REGION):
if not options.get(key):
raise SandboxConfigValidationError(f"{key} is required")
code_interpreter_id = options.get(cls.OptionsKey.CODE_INTERPRETER_ID, "")
if not code_interpreter_id:
raise SandboxConfigValidationError("code_interpreter_id is required")
client = boto3.client(
"bedrock-agentcore",
region_name=options[cls.OptionsKey.AWS_REGION],
aws_access_key_id=options[cls.OptionsKey.AWS_ACCESS_KEY_ID],
aws_secret_access_key=options[cls.OptionsKey.AWS_SECRET_ACCESS_KEY],
)
try:
resp = client.start_code_interpreter_session(
codeInterpreterIdentifier=code_interpreter_id,
sessionTimeoutSeconds=60,
)
session_id = resp["sessionId"]
# Immediately stop the validation session.
client.stop_code_interpreter_session(
codeInterpreterIdentifier=code_interpreter_id,
sessionId=session_id,
)
except ClientError as exc:
raise SandboxConfigValidationError(f"AWS AgentCore Code Interpreter validation failed: {exc}") from exc
except Exception as exc:
raise SandboxConfigValidationError(f"AWS AgentCore Code Interpreter connection failed: {exc}") from exc
# ------------------------------------------------------------------
# Environment lifecycle
# ------------------------------------------------------------------
def _construct_environment(self, options: Mapping[str, Any], environments: Mapping[str, str]) -> Metadata:
"""Start a new Code Interpreter session and detect platform info."""
import boto3
from botocore.config import Config
code_interpreter_id: str = options.get(self.OptionsKey.CODE_INTERPRETER_ID, "")
timeout_seconds: int = int(options.get(self.OptionsKey.SESSION_TIMEOUT_SECONDS, 900))
client = boto3.client(
"bedrock-agentcore",
region_name=options[self.OptionsKey.AWS_REGION],
aws_access_key_id=options[self.OptionsKey.AWS_ACCESS_KEY_ID],
aws_secret_access_key=options[self.OptionsKey.AWS_SECRET_ACCESS_KEY],
config=Config(read_timeout=_BOTO3_READ_TIMEOUT_SECONDS),
)
resp = client.start_code_interpreter_session(
codeInterpreterIdentifier=code_interpreter_id,
sessionTimeoutSeconds=timeout_seconds,
)
session_id: str = resp["sessionId"]
logger.info(
"AgentCore Code Interpreter session started: code_interpreter_id=%s, session_id=%s",
code_interpreter_id,
session_id,
)
# Detect architecture and OS via a quick command.
arch = Arch.AMD64
operating_system = OperatingSystem.LINUX
try:
result = self._invoke(client, code_interpreter_id, session_id, "executeCommand", {"command": "uname -m -s"})
system_info = (result.get("stdout") or "").strip()
parts = system_info.split()
if len(parts) >= 2:
operating_system = self._convert_operating_system(parts[0])
arch = self._convert_architecture(parts[1])
elif len(parts) == 1:
arch = self._convert_architecture(parts[0])
except Exception:
logger.warning("Failed to detect platform info, defaulting to Linux/AMD64")
# Inject environment variables if provided.
if environments:
export_parts = [f"export {k}={shlex.quote(v)}" for k, v in environments.items()]
export_cmd = " && ".join(export_parts)
try:
self._invoke(client, code_interpreter_id, session_id, "executeCommand", {"command": export_cmd})
except Exception:
logger.warning("Failed to inject environment variables into AgentCore session")
return Metadata(
id=session_id,
arch=arch,
os=operating_system,
store={
self.StoreKey.CLIENT: client,
self.StoreKey.SESSION_ID: session_id,
self.StoreKey.CODE_INTERPRETER_ID: code_interpreter_id,
},
)
def release_environment(self) -> None:
"""Stop the Code Interpreter session and release resources."""
client = self._client
try:
client.stop_code_interpreter_session(
codeInterpreterIdentifier=self._code_interpreter_id,
sessionId=self._session_id,
)
logger.info("AgentCore Code Interpreter session stopped: session_id=%s", self._session_id)
except Exception:
logger.warning(
"Failed to stop AgentCore Code Interpreter session: session_id=%s",
self._session_id,
exc_info=True,
)
# ------------------------------------------------------------------
# Connection (virtual — no real connection needed)
# ------------------------------------------------------------------
def establish_connection(self) -> ConnectionHandle:
return ConnectionHandle(id=uuid4().hex)
def release_connection(self, connection_handle: ConnectionHandle) -> None:
pass
# ------------------------------------------------------------------
# Command execution
# ------------------------------------------------------------------
def execute_command(
self,
connection_handle: ConnectionHandle,
command: list[str],
environments: Mapping[str, str] | None = None,
cwd: str | None = None,
) -> tuple[str, TransportWriteCloser, TransportReadCloser, TransportReadCloser]:
"""
Execute a shell command via AgentCore Code Interpreter.
The command is executed synchronously on the AWS side; a background thread
calls ``executeCommand`` and feeds the result into queue-based transports
so the caller sees the standard Transport interface.
stdin is not supported; ``NopTransportWriteCloser`` is returned.
"""
stdout_stream = QueueTransportReadCloser()
stderr_stream = QueueTransportReadCloser()
working_dir = self._workspace_path(cwd) if cwd else self._WORKDIR
cmd_str = shlex.join(command)
# Wrap env vars and cwd into the command string since the API only
# accepts a flat ``command`` string argument.
prefix_parts: list[str] = []
if environments:
for k, v in environments.items():
prefix_parts.append(f"export {k}={shlex.quote(v)}")
prefix_parts.append(f"cd {shlex.quote(working_dir)}")
full_cmd = " && ".join([*prefix_parts, cmd_str])
threading.Thread(
target=self._cmd_thread,
args=(full_cmd, stdout_stream, stderr_stream),
daemon=True,
).start()
return (
"N/A",
NopTransportWriteCloser(),
stdout_stream,
stderr_stream,
)
def get_command_status(self, connection_handle: ConnectionHandle, pid: str) -> CommandStatus:
"""Not supported — same as E2B. ``CommandFuture`` handles this gracefully."""
raise NotSupportedOperationError("AgentCore Code Interpreter does not support getting command status.")
# ------------------------------------------------------------------
# File operations
# ------------------------------------------------------------------
def upload_file(self, path: str, content: BytesIO) -> None:
"""Upload a file to the Code Interpreter session."""
remote_path = self._workspace_path(path)
file_bytes = content.read()
self._invoke(
self._client,
self._code_interpreter_id,
self._session_id,
"writeFiles",
{"content": [{"path": remote_path, "blob": file_bytes}]},
)
def download_file(self, path: str) -> BytesIO:
"""Download a file from the Code Interpreter session."""
remote_path = self._workspace_path(path)
result = self._invoke(
self._client,
self._code_interpreter_id,
self._session_id,
"readFiles",
{"path": remote_path},
)
# The response content blocks may contain blob or text data.
content_blocks: list[dict[str, Any]] = result.get("content", [])
for block in content_blocks:
resource = block.get("resource")
if resource:
blob = resource.get("blob")
if blob:
return BytesIO(blob if isinstance(blob, bytes) else blob.encode())
text = resource.get("text")
if text:
return BytesIO(text.encode("utf-8"))
# Fallback: check top-level data/text fields.
if block.get("data"):
data = block["data"]
return BytesIO(data if isinstance(data, bytes) else data.encode())
if block.get("text"):
return BytesIO(block["text"].encode("utf-8"))
return BytesIO(b"")
def list_files(self, directory_path: str, limit: int) -> Sequence[FileState]:
"""List files in a directory of the Code Interpreter session."""
remote_dir = self._workspace_path(directory_path)
result = self._invoke(
self._client,
self._code_interpreter_id,
self._session_id,
"listFiles",
{"directoryPath": remote_dir},
)
# The API returns file information in content blocks.
# Since the exact structure may vary, we also fall back to running
# a shell command to list files if the content format is not parseable.
content_blocks: list[dict[str, Any]] = result.get("content", [])
files: list[FileState] = []
for block in content_blocks:
text = block.get("text", "")
name = block.get("name", "")
size = block.get("size", 0)
uri = block.get("uri", "")
file_path = uri or name or text
if not file_path:
continue
# Normalise to relative path from workdir.
if file_path.startswith(self._WORKDIR):
file_path = posixpath.relpath(file_path, self._WORKDIR)
files.append(
FileState(
path=file_path,
size=size or 0,
created_at=0,
updated_at=0,
)
)
if len(files) >= limit:
break
return files
# ------------------------------------------------------------------
# Internal helpers
# ------------------------------------------------------------------
@property
def _client(self) -> Any:
return self.metadata.store[self.StoreKey.CLIENT]
@property
def _session_id(self) -> str:
return str(self.metadata.store[self.StoreKey.SESSION_ID])
@property
def _code_interpreter_id(self) -> str:
return str(self.metadata.store[self.StoreKey.CODE_INTERPRETER_ID])
def _workspace_path(self, path: str) -> str:
"""Convert a path to an absolute path in the Code Interpreter session."""
normalized = posixpath.normpath(path)
if normalized in ("", "."):
return self._WORKDIR
if normalized.startswith("/"):
return normalized
return posixpath.join(self._WORKDIR, normalized)
def _cmd_thread(
self,
command: str,
stdout_stream: QueueTransportReadCloser,
stderr_stream: QueueTransportReadCloser,
) -> None:
"""Background thread that executes a command and feeds output into queue transports."""
stdout_writer = stdout_stream.get_write_handler()
stderr_writer = stderr_stream.get_write_handler()
try:
result = self._invoke(
self._client,
self._code_interpreter_id,
self._session_id,
"executeCommand",
{"command": command},
)
stdout_data = result.get("stdout", "")
stderr_data = result.get("stderr", "")
if stdout_data:
stdout_writer.write(stdout_data.encode("utf-8") if isinstance(stdout_data, str) else stdout_data)
if stderr_data:
stderr_writer.write(stderr_data.encode("utf-8") if isinstance(stderr_data, str) else stderr_data)
except Exception as exc:
error_msg = f"Command execution failed: {type(exc).__name__}: {exc}\n"
stderr_writer.write(error_msg.encode())
finally:
stdout_stream.close()
stderr_stream.close()
@staticmethod
def _invoke(
client: Any,
code_interpreter_id: str,
session_id: str,
name: str,
arguments: dict[str, Any],
) -> dict[str, Any]:
"""
Call ``InvokeCodeInterpreter`` and extract the structured result.
The API returns an EventStream; this helper iterates over it and merges
``structuredContent`` and ``content`` from all ``result`` events into a
single dict.
Returns a dict that may contain:
- ``stdout``, ``stderr``, ``exitCode``, ``taskId``, ``taskStatus``,
``executionTime`` (from ``structuredContent``)
- ``content`` (list of content blocks)
- ``isError`` (bool)
"""
response = client.invoke_code_interpreter(
codeInterpreterIdentifier=code_interpreter_id,
sessionId=session_id,
name=name,
arguments=arguments,
)
merged: dict[str, Any] = {"content": []}
stream = response.get("stream")
if stream is None:
return merged
for event in stream:
result = event.get("result")
if result is None:
# Check for exception events.
for exc_key in (
"accessDeniedException",
"validationException",
"resourceNotFoundException",
"throttlingException",
"internalServerException",
):
if event.get(exc_key):
raise RuntimeError(f"AgentCore error ({exc_key}): {event[exc_key]}")
continue
if result.get("isError"):
merged["isError"] = True
structured = result.get("structuredContent")
if structured:
merged.update(structured)
content = result.get("content")
if content:
merged["content"].extend(content)
return merged
@staticmethod
def _convert_architecture(arch_str: str) -> Arch:
arch_map: dict[str, Arch] = {
"x86_64": Arch.AMD64,
"aarch64": Arch.ARM64,
"armv7l": Arch.ARM64,
"arm64": Arch.ARM64,
"amd64": Arch.AMD64,
}
if arch_str in arch_map:
return arch_map[arch_str]
raise ArchNotSupportedError(f"Unsupported architecture: {arch_str}")
@staticmethod
def _convert_operating_system(os_str: str) -> OperatingSystem:
os_map: dict[str, OperatingSystem] = {
"Linux": OperatingSystem.LINUX,
"Darwin": OperatingSystem.DARWIN,
}
if os_str in os_map:
return os_map[os_str]
raise ArchNotSupportedError(f"Unsupported operating system: {os_str}")

View File

@ -0,0 +1,281 @@
import logging
import posixpath
import shlex
import threading
import time
from collections.abc import Mapping, Sequence
from datetime import datetime
from enum import StrEnum
from io import BytesIO
from typing import Any, TypedDict, cast
from uuid import uuid4
from daytona import (
CodeLanguage,
CreateSandboxFromImageParams,
CreateSandboxFromSnapshotParams,
Daytona,
DaytonaConfig,
Sandbox,
)
from core.virtual_environment.__base.entities import (
Arch,
CommandStatus,
ConnectionHandle,
FileState,
Metadata,
OperatingSystem,
)
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
from core.virtual_environment.channel.queue_transport import QueueTransportReadCloser
from core.virtual_environment.channel.transport import (
NopTransportWriteCloser,
TransportReadCloser,
TransportWriteCloser,
)
logger = logging.getLogger(__name__)
class _CommandRecord(TypedDict):
"""Record for tracking command execution state."""
thread: threading.Thread
exit_code: int | None
class DaytonaEnvironment(VirtualEnvironment):
"""
Daytona virtual environment provider backed by Daytona Sandboxes.
"""
_DEFAULT_DAYTONA_API_URL = "https://app.daytona.io/api"
class OptionsKey(StrEnum):
API_KEY = "api_key"
API_URL = "api_url"
TARGET = "target"
LANGUAGE = "language"
SNAPSHOT = "snapshot"
IMAGE = "image"
AUTO_STOP_INTERVAL = "auto_stop_interval"
AUTO_ARCHIVE_INTERVAL = "auto_archive_interval"
AUTO_DELETE_INTERVAL = "auto_delete_interval"
PUBLIC = "public"
NAME = "name"
LABELS = "labels"
class StoreKey(StrEnum):
DAYTONA = "daytona"
SANDBOX = "sandbox"
WORKDIR = "workdir"
COMMANDS = "commands"
COMMANDS_LOCK = "commands_lock"
def _construct_environment(self, options: Mapping[str, Any], environments: Mapping[str, str]) -> Metadata:
config = DaytonaConfig(
api_key=cast(str | None, options.get(self.OptionsKey.API_KEY)),
api_url=cast(str | None, options.get(self.OptionsKey.API_URL, self._DEFAULT_DAYTONA_API_URL)),
target=cast(str | None, options.get(self.OptionsKey.TARGET)),
)
daytona = Daytona(config)
language = cast(CodeLanguage, options.get(self.OptionsKey.LANGUAGE, CodeLanguage.PYTHON))
auto_stop_interval = cast(int | None, options.get(self.OptionsKey.AUTO_STOP_INTERVAL))
auto_archive_interval = cast(int | None, options.get(self.OptionsKey.AUTO_ARCHIVE_INTERVAL))
auto_delete_interval = cast(int | None, options.get(self.OptionsKey.AUTO_DELETE_INTERVAL))
public = cast(bool | None, options.get(self.OptionsKey.PUBLIC))
name = cast(str | None, options.get(self.OptionsKey.NAME))
labels = cast(dict[str, str] | None, options.get(self.OptionsKey.LABELS))
image = cast(str | None, options.get(self.OptionsKey.IMAGE))
snapshot = cast(str | None, options.get(self.OptionsKey.SNAPSHOT))
if image is not None:
params = CreateSandboxFromImageParams(
image=image,
language=language,
env_vars=dict(environments or {}),
auto_stop_interval=auto_stop_interval,
auto_archive_interval=auto_archive_interval,
auto_delete_interval=auto_delete_interval,
public=public,
name=name,
labels=labels,
)
else:
params = CreateSandboxFromSnapshotParams(
snapshot=snapshot,
language=language,
env_vars=dict(environments or {}),
auto_stop_interval=auto_stop_interval,
auto_archive_interval=auto_archive_interval,
auto_delete_interval=auto_delete_interval,
public=public,
name=name,
labels=labels,
)
sandbox = daytona.create(params=params)
workdir = sandbox.get_work_dir()
return Metadata(
id=sandbox.id,
arch=Arch.AMD64,
os=OperatingSystem.LINUX,
store={
self.StoreKey.DAYTONA: daytona,
self.StoreKey.SANDBOX: sandbox,
self.StoreKey.WORKDIR: workdir,
self.StoreKey.COMMANDS: {},
self.StoreKey.COMMANDS_LOCK: threading.Lock(),
},
)
def release_environment(self) -> None:
daytona: Daytona = self.metadata.store[self.StoreKey.DAYTONA]
sandbox = self.metadata.store[self.StoreKey.SANDBOX]
try:
daytona.delete(sandbox)
except Exception:
logger.exception("Failed to delete Daytona sandbox %s during cleanup", sandbox.id)
def establish_connection(self) -> ConnectionHandle:
return ConnectionHandle(id=uuid4().hex)
def release_connection(self, connection_handle: ConnectionHandle) -> None:
pass
def upload_file(self, path: str, content: BytesIO) -> None:
remote_path = self._workspace_path(path)
sandbox: Sandbox = self.metadata.store[self.StoreKey.SANDBOX]
sandbox.fs.upload_file(content.getvalue(), remote_path)
def download_file(self, path: str) -> BytesIO:
remote_path = self._workspace_path(path)
sandbox: Sandbox = self.metadata.store[self.StoreKey.SANDBOX]
data = sandbox.fs.download_file(remote_path)
return BytesIO(data)
def list_files(self, directory_path: str, limit: int) -> Sequence[FileState]:
remote_dir = self._workspace_path(directory_path)
sandbox: Sandbox = self.metadata.store[self.StoreKey.SANDBOX]
try:
file_infos = sandbox.fs.list_files(remote_dir)
except Exception:
logger.exception("Failed to list files in directory %s", remote_dir)
return []
files: list[FileState] = []
for info in file_infos:
full_path = posixpath.join(remote_dir, info.name)
relative_path = posixpath.relpath(full_path, self._working_dir)
files.append(
FileState(
path=relative_path,
size=info.size,
created_at=self._parse_mod_time(info.mod_time),
updated_at=self._parse_mod_time(info.mod_time),
)
)
if len(files) >= limit:
break
return files
def execute_command(
self,
connection_handle: ConnectionHandle,
command: list[str],
environments: Mapping[str, str] | None = None,
cwd: str | None = None,
) -> tuple[str, TransportWriteCloser, TransportReadCloser, TransportReadCloser]:
sandbox: Sandbox = self.metadata.store[self.StoreKey.SANDBOX]
stdout_stream = QueueTransportReadCloser()
stderr_stream = QueueTransportReadCloser()
pid = uuid4().hex
working_dir = self._workspace_path(cwd) if cwd else self._working_dir
thread = threading.Thread(
target=self._exec_thread,
args=(pid, sandbox, command, environments or {}, working_dir, stdout_stream, stderr_stream),
daemon=True,
)
# Register the command record before starting the thread so that
# get_command_status can find it while the command is still running.
commands: dict[str, _CommandRecord] = self.metadata.store[self.StoreKey.COMMANDS]
commands_lock: threading.Lock = self.metadata.store[self.StoreKey.COMMANDS_LOCK]
with commands_lock:
commands[pid] = {"thread": thread, "exit_code": None}
thread.start()
return pid, NopTransportWriteCloser(), stdout_stream, stderr_stream
def get_command_status(self, connection_handle: ConnectionHandle, pid: str) -> CommandStatus:
commands: dict[str, _CommandRecord] = self.metadata.store[self.StoreKey.COMMANDS]
commands_lock: threading.Lock = self.metadata.store[self.StoreKey.COMMANDS_LOCK]
with commands_lock:
record = commands.get(pid)
if not record:
return CommandStatus(status=CommandStatus.Status.COMPLETED, exit_code=None)
thread: threading.Thread = record["thread"]
exit_code = record.get("exit_code")
if thread.is_alive():
return CommandStatus(status=CommandStatus.Status.RUNNING, exit_code=None)
return CommandStatus(status=CommandStatus.Status.COMPLETED, exit_code=exit_code)
@property
def _working_dir(self) -> str:
return cast(str, self.metadata.store[self.StoreKey.WORKDIR])
def _workspace_path(self, path: str) -> str:
normalized = posixpath.normpath(path)
if normalized in ("", "."):
return self._working_dir
if normalized.startswith("/"):
return normalized
return posixpath.join(self._working_dir, normalized)
def _exec_thread(
self,
pid: str,
sandbox: Sandbox,
command: list[str],
environments: Mapping[str, str],
cwd: str,
stdout_stream: QueueTransportReadCloser,
stderr_stream: QueueTransportReadCloser,
) -> None:
commands: dict[str, _CommandRecord] = self.metadata.store[self.StoreKey.COMMANDS]
commands_lock: threading.Lock = self.metadata.store[self.StoreKey.COMMANDS_LOCK]
stdout_writer = stdout_stream.get_write_handler()
stderr_writer = stderr_stream.get_write_handler()
exit_code: int | None = None
try:
response = sandbox.process.exec(
command=shlex.join(command),
env=dict(environments),
cwd=cwd,
)
exit_code = response.exit_code
output = response.artifacts.stdout if response.artifacts and response.artifacts.stdout else response.result
if output:
stdout_writer.write(output.encode())
except Exception as exc:
stderr_writer.write(str(exc).encode())
exit_code = 1
finally:
stdout_stream.close()
stderr_stream.close()
with commands_lock:
if pid in commands:
commands[pid]["exit_code"] = exit_code
def _parse_mod_time(self, mod_time: str) -> int:
try:
cleaned = mod_time.replace("Z", "+00:00")
return int(datetime.fromisoformat(cleaned).timestamp())
except (ValueError, AttributeError, OSError):
logger.warning("Failed to parse modification time '%s', falling back to current time", mod_time)
return int(time.time())
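# Polling sketch (env/handle names are illustrative): get_command_status
# reflects the background exec thread's lifecycle.
#
#     pid, _stdin, stdout, stderr = env.execute_command(handle, ["python3", "-V"])
#     status = env.get_command_status(handle, pid)
#     # RUNNING while the thread is alive; COMPLETED with the recorded
#     # exit code once sandbox.process.exec() has returned.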

View File

@ -0,0 +1,635 @@
from __future__ import annotations
import logging
import socket
import tarfile
import threading
from collections.abc import Mapping, Sequence
from enum import IntEnum, StrEnum
from functools import lru_cache
from io import BytesIO
from pathlib import PurePosixPath
from queue import Empty, Queue
from typing import TYPE_CHECKING, Any, cast
from uuid import uuid4
if TYPE_CHECKING:
from docker.models.containers import Container
import docker
from configs import dify_config
from core.entities.provider_entities import BasicProviderConfig
from core.virtual_environment.__base.entities import (
Arch,
CommandStatus,
ConnectionHandle,
FileState,
Metadata,
OperatingSystem,
)
from core.virtual_environment.__base.exec import SandboxConfigValidationError, VirtualEnvironmentLaunchFailedError
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
from core.virtual_environment.channel.exec import TransportEOFError
from core.virtual_environment.channel.socket_transport import SocketWriteCloser
from core.virtual_environment.channel.transport import TransportReadCloser, TransportWriteCloser
class DockerStreamType(IntEnum):
"""
Docker multiplexed stream types.
When Docker exec runs with tty=False, it multiplexes stdout and stderr over a single
socket connection. Each frame is prefixed with an 8-byte header:
[stream_type (1 byte)][0][0][0][payload_size (4 bytes, big-endian)]
This allows the client to distinguish between stdout (type=1) and stderr (type=2).
See: https://docs.docker.com/engine/api/v1.41/#operation/ContainerAttach
"""
STDIN = 0
STDOUT = 1
STDERR = 2
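# Frame layout sketch: a 5-byte stdout payload b"hello" is framed as an
# 8-byte header followed by the payload.
#
#     header = bytes([DockerStreamType.STDOUT, 0, 0, 0]) + (5).to_bytes(4, "big")
#     frame = header + b"hello"
#     assert frame[0] == DockerStreamType.STDOUT
#     assert int.from_bytes(frame[4:8], "big") == len(b"hello")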
class DockerDemuxer:
"""
Demultiplexes Docker's combined stdout/stderr stream using producer-consumer pattern.
Docker exec with tty=False sends stdout and stderr over a single socket,
each frame prefixed with an 8-byte header:
- Byte 0: stream type (1=stdout, 2=stderr)
- Bytes 1-3: reserved (zeros)
- Bytes 4-7: payload size (big-endian uint32)
THREAD SAFETY:
A single background thread reads frames from the socket and dispatches payloads
to thread-safe queues. This avoids race conditions where multiple threads
calling _read_next_frame() simultaneously caused frame header/body corruption,
resulting in incomplete stdout/stderr output.
TIMEOUT HANDLING:
Queue.get() uses a timeout to prevent indefinite blocking when the socket is
closed unexpectedly (e.g., container removed). This allows periodic checks for
error conditions and closed state.
"""
_HEADER_SIZE = 8
_QUEUE_GET_TIMEOUT = 5.0 # seconds
def __init__(self, sock: socket.SocketIO):
self._sock = sock
self._stdout_queue: Queue[bytes | None] = Queue()
self._stderr_queue: Queue[bytes | None] = Queue()
self._closed = False
self._error: BaseException | None = None
self._demux_thread = threading.Thread(
target=self._demux_loop,
daemon=True,
name="docker-demuxer",
)
self._demux_thread.start()
def _demux_loop(self) -> None:
try:
while not self._closed:
header = self._read_exact(self._HEADER_SIZE)
if not header or len(header) < self._HEADER_SIZE:
break
frame_type = header[0]
payload_size = int.from_bytes(header[4:8], "big")
if payload_size == 0:
continue
payload = self._read_exact(payload_size)
if not payload:
break
if frame_type == DockerStreamType.STDOUT:
self._stdout_queue.put(payload)
elif frame_type == DockerStreamType.STDERR:
self._stderr_queue.put(payload)
except BaseException as e:
self._error = e
finally:
self._stdout_queue.put(None)
self._stderr_queue.put(None)
def _read_exact(self, size: int) -> bytes:
data = bytearray()
remaining = size
while remaining > 0:
try:
chunk = self._sock.read(remaining)
if not chunk:
return bytes(data) if data else b""
data.extend(chunk)
remaining -= len(chunk)
except (ConnectionResetError, BrokenPipeError):
return bytes(data) if data else b""
return bytes(data)
def read_stdout(self) -> bytes:
return self._read_from_queue(self._stdout_queue)
def read_stderr(self) -> bytes:
return self._read_from_queue(self._stderr_queue)
def _read_from_queue(self, queue: Queue[bytes | None]) -> bytes:
"""
Read from queue with timeout to prevent indefinite blocking.
When the Docker container is removed or the socket is closed unexpectedly,
the demux thread may be stuck in socket.read(). Using a timeout allows us
to periodically check for errors and closed state instead of blocking forever.
"""
if self._error:
raise TransportEOFError(f"Demuxer error: {self._error}") from self._error
while True:
try:
chunk = queue.get(timeout=self._QUEUE_GET_TIMEOUT)
if chunk is None:
if self._error:
raise TransportEOFError(f"Demuxer error: {str(self._error)}")
raise TransportEOFError("End of demuxed stream")
return chunk
except Empty:
# Timeout - check if we should continue waiting
if self._closed:
raise TransportEOFError("Demuxer closed")
if self._error:
error = cast(BaseException, self._error)
raise TransportEOFError(f"Demuxer error: {error}") from error
# No error, continue waiting
def close(self) -> None:
if not self._closed:
self._closed = True
try:
self._sock.close()
except Exception:
logging.error("Failed to close Docker demuxer socket", exc_info=True)
class DemuxedStdoutReader(TransportReadCloser):
def __init__(self, demuxer: DockerDemuxer):
self._demuxer = demuxer
self._buffer = bytearray()
def read(self, n: int) -> bytes:
if self._buffer:
data = bytes(self._buffer[:n])
del self._buffer[:n]
if data:
return data
chunk = self._demuxer.read_stdout()
if len(chunk) <= n:
return chunk
self._buffer.extend(chunk[n:])
return chunk[:n]
def close(self) -> None:
self._demuxer.close()
class DemuxedStderrReader(TransportReadCloser):
def __init__(self, demuxer: DockerDemuxer):
self._demuxer = demuxer
self._buffer = bytearray()
def read(self, n: int) -> bytes:
if self._buffer:
data = bytes(self._buffer[:n])
del self._buffer[:n]
if data:
return data
chunk = self._demuxer.read_stderr()
if len(chunk) <= n:
return chunk
self._buffer.extend(chunk[n:])
return chunk[:n]
def close(self) -> None:
self._demuxer.close()
"""
EXAMPLE:
from collections.abc import Mapping
from typing import Any
from
from core.virtual_environment.providers.docker_daemon_sandbox import DockerDaemonEnvironment
options: Mapping[str, Any] = {
# OptionsKey values are optional
# DockerDaemonEnvironment.OptionsKey.DOCKER_SOCK: "unix:///var/run/docker.sock",
# DockerDaemonEnvironment.OptionsKey.DOCKER_AGENT_IMAGE: "ubuntu:latest",
# DockerDaemonEnvironment.OptionsKey.DOCKER_AGENT_COMMAND
#
"docker_sock": "unix:///var/run/docker.sock", # optional, default to unix socket
"docker_agent_image": "ubuntu:latest", # optional, default to ubuntu:latest
"docker_agent_command": "/bin/sh -c 'while true; do sleep 1; done'", # optional, default to None
}
environment = DockerDaemonEnvironment(options=options)
connection_handle = environment.establish_connection()
pid, transport_stdout, transport_stderr, transport_stdin = environment.execute_command(
connection_handle, ["uname", "-a"]
)
print(f"Executed command with PID: {pid}")
# consume stdout
# consume stdout
while True:
try:
output = transport_stdout.read(1024)
except TransportEOFError:
logger.info("End of stdout reached")
break
logger.info("Command output: %s", output.decode().strip())
environment.release_connection(connection_handle)
environment.release_environment()
"""
class DockerDaemonEnvironment(VirtualEnvironment):
_WORKING_DIR = "/workspace"
_DEFAULT_DOCKER_IMAGE = "ubuntu:latest"
_DEFAULT_DOCKER_SOCK = "unix:///var/run/docker.sock"  # Default to the local Docker daemon socket
class OptionsKey(StrEnum):
DOCKER_SOCK = "docker_sock"
DOCKER_IMAGE = "docker_image"
DOCKER_COMMAND = "docker_command"
@classmethod
def get_config_schema(cls) -> list[BasicProviderConfig]:
return [
BasicProviderConfig(type=BasicProviderConfig.Type.TEXT_INPUT, name=cls.OptionsKey.DOCKER_SOCK),
BasicProviderConfig(type=BasicProviderConfig.Type.TEXT_INPUT, name=cls.OptionsKey.DOCKER_IMAGE),
]
@classmethod
def validate(cls, options: Mapping[str, Any]) -> None:
# Import Docker SDK lazily so it is loaded after gevent monkey-patching.
import docker
import docker.errors
docker_sock = options.get(cls.OptionsKey.DOCKER_SOCK, cls._DEFAULT_DOCKER_SOCK)
try:
client = docker.DockerClient(base_url=docker_sock)
client.ping()
except docker.errors.DockerException as e:
raise SandboxConfigValidationError(f"Docker connection failed: {e}") from e
def _construct_environment(self, options: Mapping[str, Any], environments: Mapping[str, str]) -> Metadata:
"""
Construct the Docker daemon virtual environment.
"""
docker_client = self.get_docker_daemon(
docker_sock=options.get(self.OptionsKey.DOCKER_SOCK, self._DEFAULT_DOCKER_SOCK)
)
default_docker_image = options.get(self.OptionsKey.DOCKER_IMAGE, self._DEFAULT_DOCKER_IMAGE)
container_command = options.get(self.OptionsKey.DOCKER_COMMAND)
container = docker_client.containers.run(
image=default_docker_image,
command=container_command,
detach=True,
remove=True,
stdin_open=True,
working_dir=self._WORKING_DIR,
environment=dict(environments),
)
# FIXME(yeuoly): find a better solution than this heuristic warning
if dify_config.FILES_URL.startswith("http://localhost") or dify_config.FILES_URL.startswith("http://127.0.0.1"):
logging.warning(
"DIFY_FILES_URL is set to a localhost address. "
"Docker containers may not be able to access the host's localhost. "
"Consider using host.docker.internal or the host machine's IP address."
)
dify_host_port = dify_config.DIFY_PORT
# launch socat inside the container to forward its localhost Dify port to the
# host, so the sandbox can call back the Dify API
container.exec_run( # pyright: ignore[reportUnknownMemberType] #
cmd=[
"bash",
"-c",
f"nohup socat TCP-LISTEN:{dify_host_port},bind=127.0.0.1,fork,reuseaddr "
f"TCP:host.docker.internal:{dify_host_port} >/tmp/socat.log 2>&1",
],
detach=True,
)
# wait for the container to be fully started
container.reload()
if not container.id:
raise VirtualEnvironmentLaunchFailedError("Failed to start Docker container for DockerDaemonEnvironment.")
return Metadata(
id=container.id,
arch=self._get_container_architecture(container),
os=OperatingSystem.LINUX,
)
@classmethod
@lru_cache(maxsize=5)
def get_docker_daemon(cls, docker_sock: str) -> docker.DockerClient:
"""
Get the Docker daemon client.
NOTE: I guess nobody will use more than 5 different docker sockets in practice....
"""
import docker
return docker.DockerClient(base_url=docker_sock)
@classmethod
@lru_cache(maxsize=5)
def get_docker_api_client(cls, docker_sock: str) -> docker.APIClient:
"""
Get the Docker low-level API client.
"""
import docker
return docker.APIClient(base_url=docker_sock)
def get_docker_sock(self) -> str:
"""
Get the Docker socket path.
"""
return self.options.get(self.OptionsKey.DOCKER_SOCK, self._DEFAULT_DOCKER_SOCK)
@property
def _working_dir(self) -> str:
"""
Get the working directory inside the Docker container.
"""
return self._WORKING_DIR
def _get_container(self) -> Container:
"""
Get the Docker container instance.
"""
docker_client = self.get_docker_daemon(self.get_docker_sock())
return docker_client.containers.get(self.metadata.id)
def _normalize_relative_path(self, path: str) -> PurePosixPath:
parts: list[str] = []
for part in PurePosixPath(path).parts:
if part in ("", ".", "/"):
continue
if part == "..":
if not parts:
raise ValueError("Path escapes the workspace.")
parts.pop()
continue
parts.append(part)
return PurePosixPath(*parts)
def _relative_path(self, path: str) -> PurePosixPath:
normalized = self._normalize_relative_path(path)
if normalized.parts:
return normalized
return PurePosixPath()
def _container_path(self, path: str) -> str:
relative = self._relative_path(path)
if not relative.parts:
return self._working_dir
return f"{self._working_dir}/{relative.as_posix()}"
def _workspace_path(self, path: str) -> str:
"""
Convert a path to an absolute path in the Docker container.
Absolute paths are returned as-is, relative paths are joined with _working_dir.
"""
normalized = PurePosixPath(path)
if normalized.is_absolute():
return str(normalized)
return self._container_path(path)
def upload_file(self, path: str, content: BytesIO) -> None:
"""Upload a file to the container.
Files and intermediate directories are created with world-writable permissions
(0o777 for directories, 0o666 for files) to avoid permission issues when the container
runs as a non-root user but Docker's put_archive creates files as root.
"""
container = self._get_container()
normalized = PurePosixPath(path)
if normalized.is_absolute():
parent_dir = str(normalized.parent)
file_name = normalized.name
payload = content.getvalue()
tar_stream = BytesIO()
with tarfile.open(fileobj=tar_stream, mode="w") as tar:
tar_info = tarfile.TarInfo(name=file_name)
tar_info.size = len(payload)
tar_info.mode = 0o666
tar.addfile(tar_info, BytesIO(payload))
tar_stream.seek(0)
container.put_archive(parent_dir, tar_stream.read()) # pyright: ignore[reportUnknownMemberType] #
return
relative_path = self._relative_path(path)
if not relative_path.parts:
raise ValueError("Upload path must point to a file within the workspace.")
payload = content.getvalue()
tar_stream = BytesIO()
with tarfile.open(fileobj=tar_stream, mode="w") as tar:
# Add intermediate directories with proper permissions
for i in range(len(relative_path.parts) - 1):
dir_path = PurePosixPath(*relative_path.parts[: i + 1])
dir_info = tarfile.TarInfo(name=dir_path.as_posix() + "/")
dir_info.type = tarfile.DIRTYPE
dir_info.mode = 0o777
tar.addfile(dir_info)
# Add the file
tar_info = tarfile.TarInfo(name=relative_path.as_posix())
tar_info.size = len(payload)
tar_info.mode = 0o666
tar.addfile(tar_info, BytesIO(payload))
tar_stream.seek(0)
container.put_archive(self._working_dir, tar_stream.read()) # pyright: ignore[reportUnknownMemberType] #
def download_file(self, path: str) -> BytesIO:
container = self._get_container()
container_path = self._workspace_path(path)
stream, _ = container.get_archive(container_path)
tar_stream = BytesIO()
for chunk in stream:
tar_stream.write(chunk)
tar_stream.seek(0)
with tarfile.open(fileobj=tar_stream, mode="r:*") as tar:
members = [member for member in tar.getmembers() if member.isfile()]
if not members:
return BytesIO()
extracted = tar.extractfile(members[0])
if extracted is None:
return BytesIO()
return BytesIO(extracted.read())
def list_files(self, directory_path: str, limit: int) -> Sequence[FileState]:
import docker.errors
container = self._get_container()
container_path = self._container_path(directory_path)
relative_base = self._relative_path(directory_path)
try:
stream, _ = container.get_archive(container_path)
except docker.errors.NotFound:
return []
tar_stream = BytesIO()
for chunk in stream:
tar_stream.write(chunk)
tar_stream.seek(0)
files: list[FileState] = []
archive_root = PurePosixPath(container_path).name
with tarfile.open(fileobj=tar_stream, mode="r:*") as tar:
for member in tar.getmembers():
if not member.isfile():
continue
member_path = PurePosixPath(member.name)
if member_path.parts and member_path.parts[0] == archive_root:
member_path = PurePosixPath(*member_path.parts[1:])
if not member_path.parts:
continue
relative_path = relative_base / member_path
files.append(
FileState(
path=relative_path.as_posix(),
size=member.size,
created_at=int(member.mtime),
updated_at=int(member.mtime),
)
)
if len(files) >= limit:
break
return files
def establish_connection(self) -> ConnectionHandle:
return ConnectionHandle(id=uuid4().hex)
def release_connection(self, connection_handle: ConnectionHandle) -> None:
# No action needed for Docker exec connections
pass
def release_environment(self) -> None:
import docker.errors
try:
container = self._get_container()
except docker.errors.NotFound:
return
try:
container.remove(force=True)
except docker.errors.NotFound:
return
def execute_command(
self,
connection_handle: ConnectionHandle,
command: list[str],
environments: Mapping[str, str] | None = None,
cwd: str | None = None,
) -> tuple[str, TransportWriteCloser, TransportReadCloser, TransportReadCloser]:
container = self._get_container()
container_id = container.id
if not isinstance(container_id, str) or not container_id:
raise RuntimeError("Docker container ID is not available for exec.")
api_client = self.get_docker_api_client(self.get_docker_sock())
working_dir = self._workspace_path(cwd) if cwd else self._working_dir
exec_info: dict[str, object] = cast(
dict[str, object],
api_client.exec_create( # pyright: ignore[reportUnknownMemberType] #
container_id,
cmd=command,
stdin=True,
stdout=True,
stderr=True,
tty=False,
workdir=working_dir,
environment=dict(environments) if environments else None,
),
)
if not isinstance(exec_info.get("Id"), str):
raise RuntimeError("Failed to create Docker exec instance.")
exec_id: str = str(exec_info.get("Id"))
raw_sock: socket.SocketIO = cast(socket.SocketIO, api_client.exec_start(exec_id, socket=True, tty=False)) # pyright: ignore[reportUnknownMemberType] #
stdin_transport = SocketWriteCloser(raw_sock)
demuxer = DockerDemuxer(raw_sock)
stdout_transport = DemuxedStdoutReader(demuxer)
stderr_transport = DemuxedStderrReader(demuxer)
return exec_id, stdin_transport, stdout_transport, stderr_transport
def get_command_status(self, connection_handle: ConnectionHandle, pid: str) -> CommandStatus:
api_client = self.get_docker_api_client(self.get_docker_sock())
inspect: dict[str, object] = cast(dict[str, object], api_client.exec_inspect(pid)) # pyright: ignore[reportUnknownMemberType] #
exit_code = inspect.get("ExitCode")
if inspect.get("Running") or exit_code is None:
return CommandStatus(status=CommandStatus.Status.RUNNING, exit_code=None)
if not isinstance(exit_code, int):
exit_code = None
return CommandStatus(status=CommandStatus.Status.COMPLETED, exit_code=exit_code)
def _get_container_architecture(self, container: Container) -> Arch:
"""
Detect the container's CPU architecture from its image metadata.
Falls back to ``uname -m`` inside the container when image attrs are unavailable.
"""
try:
image = container.image
arch_str = str(image.attrs.get("Architecture", "")).lower() if image else ""
except Exception:
arch_str = ""
if not arch_str:
result = container.exec_run("uname -m")
arch_str = result.output.decode("utf-8", errors="replace").strip().lower() if result.exit_code == 0 else ""
match arch_str:
case "x86_64" | "amd64":
return Arch.AMD64
case "aarch64" | "arm64":
return Arch.ARM64
case _:
logging.warning("Unknown container architecture '%s', defaulting to AMD64", arch_str)
return Arch.AMD64

View File

@ -0,0 +1,360 @@
import logging
import posixpath
import shlex
import threading
from collections.abc import Mapping, Sequence
from enum import StrEnum
from functools import cached_property
from io import BytesIO
from typing import Any
from uuid import uuid4
from configs import dify_config
from core.entities.provider_entities import BasicProviderConfig
from core.virtual_environment.__base.entities import (
Arch,
CommandStatus,
ConnectionHandle,
FileState,
Metadata,
OperatingSystem,
)
from core.virtual_environment.__base.exec import (
ArchNotSupportedError,
NotSupportedOperationError,
SandboxConfigValidationError,
)
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
from core.virtual_environment.channel.queue_transport import QueueTransportReadCloser
from core.virtual_environment.channel.transport import (
NopTransportWriteCloser,
TransportReadCloser,
TransportWriteCloser,
)
from core.virtual_environment.constants import COMMAND_EXECUTION_TIMEOUT_SECONDS
logger = logging.getLogger(__name__)
"""
import logging
from collections.abc import Mapping
from typing import Any
from core.virtual_environment.providers.e2b_sandbox import E2BEnvironment
options: Mapping[str, Any] = {
E2BEnvironment.OptionsKey.API_KEY: "?????????",
E2BEnvironment.OptionsKey.E2B_DEFAULT_TEMPLATE: "code-interpreter-v1",
E2BEnvironment.OptionsKey.E2B_LIST_FILE_DEPTH: 2,
E2BEnvironment.OptionsKey.E2B_API_URL: "https://api.e2b.app",
}
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.DEBUG)
# environment = DockerDaemonEnvironment(options=options)
# environment = LocalVirtualEnvironment(options=options)
environment = E2BEnvironment(options=options)
connection_handle = environment.establish_connection()
pid, transport_stdin, transport_stdout, transport_stderr = environment.execute_command(
connection_handle, ["uname", "-a"]
)
logger.info("Executed command with PID: %s", pid)
# consume stdout
# consume stdout
while True:
try:
output = transport_stdout.read(1024)
except TransportEOFError:
logger.info("End of stdout reached")
break
logger.info("Command output: %s", output.decode().strip())
environment.release_connection(connection_handle)
environment.release_environment()
"""
class E2BEnvironment(VirtualEnvironment):
"""
E2B virtual environment provider.
"""
_WORKDIR = "/home/user"
_E2B_API_URL = "https://api.e2b.app"
class OptionsKey(StrEnum):
API_KEY = "api_key"
E2B_LIST_FILE_DEPTH = "e2b_list_file_depth"
E2B_DEFAULT_TEMPLATE = "e2b_default_template"
E2B_API_URL = "e2b_api_url"
class StoreKey(StrEnum):
SANDBOX = "sandbox"
@classmethod
def get_config_schema(cls) -> list[BasicProviderConfig]:
return [
BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=cls.OptionsKey.API_KEY),
BasicProviderConfig(type=BasicProviderConfig.Type.TEXT_INPUT, name=cls.OptionsKey.E2B_API_URL),
BasicProviderConfig(type=BasicProviderConfig.Type.TEXT_INPUT, name=cls.OptionsKey.E2B_DEFAULT_TEMPLATE),
]
@classmethod
def validate(cls, options: Mapping[str, Any]) -> None:
# Import E2B SDK lazily so it is loaded after gevent monkey-patching.
# See `api/gunicorn.conf.py` for how we patch other third-party libs (e.g. gRPC).
from e2b.exceptions import (
AuthenticationException, # type: ignore[import-untyped]
)
from e2b_code_interpreter import Sandbox # type: ignore[import-untyped]
api_key = options.get(cls.OptionsKey.API_KEY, "")
if not api_key:
raise SandboxConfigValidationError("E2B API key is required")
try:
Sandbox.list(api_key=api_key, limit=1).next_items()
except AuthenticationException as e:
raise SandboxConfigValidationError(f"E2B authentication failed: {e}") from e
except Exception as e:
raise SandboxConfigValidationError(f"E2B connection failed: {e}") from e
def _construct_environment(self, options: Mapping[str, Any], environments: Mapping[str, str]) -> Metadata:
"""
Construct a new E2B virtual environment.
The sandbox lifetime is capped by ``WORKFLOW_MAX_EXECUTION_TIME`` so the
provider can rely on E2B's native timeout instead of a background
keepalive thread that continuously extends the session.
E2B allocates the remote sandbox before metadata probing completes, so
startup failures must best-effort terminate the sandbox before the
exception escapes.
"""
# Import E2B SDK lazily so it is loaded after gevent monkey-patching.
from e2b_code_interpreter import Sandbox # type: ignore[import-untyped]
# TODO: add Dify as the user agent
sandbox = None
sandbox_id: str | None = None
api_key = options.get(self.OptionsKey.API_KEY, "")
try:
sandbox = Sandbox.create(
template=options.get(self.OptionsKey.E2B_DEFAULT_TEMPLATE, "code-interpreter-v1"),
timeout=dify_config.WORKFLOW_MAX_EXECUTION_TIME,
api_key=api_key,
api_url=options.get(self.OptionsKey.E2B_API_URL, self._E2B_API_URL),
envs=dict(environments),
)
info = sandbox.get_info(api_key=api_key)
sandbox_id = info.sandbox_id
system_info = sandbox.commands.run("uname -m -s").stdout.strip()
system_parts = system_info.split()
if len(system_parts) == 2:
os_part, arch_part = system_parts
else:
    # Defensive fallback: `uname -m -s` normally yields exactly two tokens
    # (kernel name first, then machine); otherwise take the first token as
    # the OS and the last as the architecture.
    os_part = system_parts[0] if system_parts else ""
    arch_part = system_parts[-1] if system_parts else ""
return Metadata(
id=info.sandbox_id,
arch=self._convert_architecture(arch_part.strip()),
os=self._convert_operating_system(os_part.strip()),
store={
self.StoreKey.SANDBOX: sandbox,
},
)
except Exception:
if sandbox_id is None and sandbox is not None:
sandbox_id = getattr(sandbox, "sandbox_id", None)
if sandbox_id is not None:
try:
Sandbox.kill(api_key=api_key, sandbox_id=sandbox_id)
except Exception:
logger.exception("Failed to cleanup E2B sandbox after startup failure")
raise
def release_environment(self) -> None:
"""
Release the E2B virtual environment.
"""
from e2b_code_interpreter import Sandbox # type: ignore[import-untyped]
if not Sandbox.kill(api_key=self.api_key, sandbox_id=self.metadata.id):
raise Exception(f"Failed to release E2B sandbox with ID: {self.metadata.id}")
def establish_connection(self) -> ConnectionHandle:
"""
Establish a connection to the E2B virtual environment.
"""
return ConnectionHandle(id=uuid4().hex)
def release_connection(self, connection_handle: ConnectionHandle) -> None:
"""
Release the connection to the E2B virtual environment.
"""
pass
def upload_file(self, path: str, content: BytesIO) -> None:
"""
Upload a file to the E2B virtual environment.
Args:
path (str): The path to upload the file to.
content (BytesIO): The content of the file.
"""
remote_path = self._workspace_path(path)
sandbox = self.metadata.store[self.StoreKey.SANDBOX]
sandbox.files.write(remote_path, content)  # pyright: ignore[reportUnknownMemberType]
def download_file(self, path: str) -> BytesIO:
"""
Download a file from the E2B virtual environment.
Args:
path (str): The path to download the file from.
Returns:
BytesIO: The content of the file.
"""
remote_path = self._workspace_path(path)
sandbox = self.metadata.store[self.StoreKey.SANDBOX]
content = sandbox.files.read(remote_path, format="bytes")
return BytesIO(bytes(content))
def list_files(self, directory_path: str, limit: int) -> Sequence[FileState]:
"""
List files in a directory of the E2B virtual environment.
"""
sandbox = self.metadata.store[self.StoreKey.SANDBOX]
remote_dir = self._workspace_path(directory_path)
files_info = sandbox.files.list(remote_dir, depth=self.options.get(self.OptionsKey.E2B_LIST_FILE_DEPTH, 3))
return [
FileState(
path=posixpath.relpath(file_info.path, self._WORKDIR),
size=file_info.size,
created_at=int(file_info.modified_time.timestamp()),
updated_at=int(file_info.modified_time.timestamp()),
)
for file_info in files_info
]
def execute_command(
self,
connection_handle: ConnectionHandle,
command: list[str],
environments: Mapping[str, str] | None = None,
cwd: str | None = None,
) -> tuple[str, TransportWriteCloser, TransportReadCloser, TransportReadCloser]:
"""
Execute a command in the E2B virtual environment.
STDIN is not yet supported: E2B's command API does not expose an input stream
cleanly, and forcing support now would lead to a bad design, so it is left
for future improvement.
"""
sandbox = self.metadata.store[self.StoreKey.SANDBOX]
stdout_stream = QueueTransportReadCloser()
stderr_stream = QueueTransportReadCloser()
working_dir = self._workspace_path(cwd) if cwd else self._WORKDIR
threading.Thread(
target=self._cmd_thread,
args=(sandbox, command, environments, working_dir, stdout_stream, stderr_stream),
).start()
return (
"N/A",
NopTransportWriteCloser(), # stdin not supported yet
stdout_stream,
stderr_stream,
)
def get_command_status(self, connection_handle: ConnectionHandle, pid: str) -> CommandStatus:
"""
Not supported: E2B does not expose command status yet, so this always raises.
"""
raise NotSupportedOperationError("E2B does not support getting command status yet.")
def _cmd_thread(
self,
sandbox: Any,
command: list[str],
environments: Mapping[str, str] | None,
cwd: str,
stdout_stream: QueueTransportReadCloser,
stderr_stream: QueueTransportReadCloser,
) -> None:
stdout_stream_write_handler = stdout_stream.get_write_handler()
stderr_stream_write_handler = stderr_stream.get_write_handler()
try:
sandbox.commands.run(
cmd=shlex.join(command),
envs=dict(environments or {}),
cwd=cwd,
on_stdout=lambda data: stdout_stream_write_handler.write(data.encode()),
on_stderr=lambda data: stderr_stream_write_handler.write(data.encode()),
timeout=COMMAND_EXECUTION_TIMEOUT_SECONDS,
)
except Exception as e:
# Capture exceptions and write to stderr stream so they can be retrieved via CommandFuture
# This prevents uncaught exceptions from being printed to console
error_msg = f"Command execution failed: {type(e).__name__}: {str(e)}\n"
stderr_stream_write_handler.write(error_msg.encode())
finally:
# Close the write handlers to signal EOF
stdout_stream.close()
stderr_stream.close()
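# Design note (editorial): the E2B SDK only surfaces command output through the
# on_stdout / on_stderr callbacks above, so _cmd_thread bridges those callbacks
# into QueueTransportReadCloser queues. Callers can then drain stdout/stderr
# with the same pull-based read() loop used for every other provider (see the
# usage docstring at the top of this file).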
@cached_property
def api_key(self) -> str:
"""
Get the API key for the E2B environment.
"""
return self.options.get(self.OptionsKey.API_KEY, "")
def _workspace_path(self, path: str) -> str:
"""
Convert a path to an absolute path in the E2B environment.
Absolute paths are returned as-is, relative paths are joined with _WORKDIR.
"""
normalized = posixpath.normpath(path)
if normalized in ("", "."):
return self._WORKDIR
if normalized.startswith("/"):
return normalized
return posixpath.join(self._WORKDIR, normalized)
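# Illustrative mappings (editorial), assuming _WORKDIR == "/home/user":
#   "data/out.txt" -> "/home/user/data/out.txt"
#   "a/../b.txt"   -> "/home/user/b.txt"  (normpath collapses "..")
#   "/tmp/abs.txt" -> "/tmp/abs.txt"      (absolute paths pass through)
#   "" or "."      -> "/home/user"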
def _convert_architecture(self, arch_str: str) -> Arch:
arch_map = {
"x86_64": Arch.AMD64,
"aarch64": Arch.ARM64,
"armv7l": Arch.ARM64,
"arm64": Arch.ARM64,
"amd64": Arch.AMD64,
"arm64v8": Arch.ARM64,
"arm64v7": Arch.ARM64,
}
if arch_str in arch_map:
return arch_map[arch_str]
raise ArchNotSupportedError(f"Unsupported architecture: {arch_str}")
def _convert_operating_system(self, os_str: str) -> OperatingSystem:
os_map = {
"Linux": OperatingSystem.LINUX,
"Darwin": OperatingSystem.DARWIN,
}
if os_str in os_map:
return os_map[os_str]
raise ArchNotSupportedError(f"Unsupported operating system: {os_str}")

View File

@ -0,0 +1,308 @@
import os
import pathlib
import shutil
import signal
import subprocess
from collections.abc import Mapping, Sequence
from functools import cached_property
from io import BytesIO
from platform import machine, system
from typing import Any
from uuid import uuid4
from core.entities.provider_entities import BasicProviderConfig
from core.virtual_environment.__base.entities import (
Arch,
CommandStatus,
ConnectionHandle,
FileState,
Metadata,
OperatingSystem,
)
from core.virtual_environment.__base.exec import ArchNotSupportedError
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
from core.virtual_environment.channel.pipe_transport import PipeReadCloser, PipeWriteCloser
from core.virtual_environment.channel.transport import TransportReadCloser, TransportWriteCloser
"""
USAGE:
import logging
from collections.abc import Mapping
from typing import Any
from core.virtual_environment.channel.exec import TransportEOFError
from core.virtual_environment.providers.local_without_isolation import LocalVirtualEnvironment
options: Mapping[str, Any] = {}
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.DEBUG)
environment = LocalVirtualEnvironment(options=options)
connection_handle = environment.establish_connection()
pid, transport_stdin, transport_stdout, transport_stderr = environment.execute_command(
connection_handle,
["sh", "-lc", "for i in 1 2 3 4 5; do date '+%F %T'; sleep 1; done"],
)
logger.info("Executed command with PID: %s", pid)
# consume stdout
while True:
try:
output = transport_stdout.read(1024)
except TransportEOFError:
logger.info("End of stdout reached")
break
logger.info("Command output: %s", output.decode().strip())
environment.release_connection(connection_handle)
environment.release_environment()
"""
class LocalVirtualEnvironment(VirtualEnvironment):
"""
Local virtual environment provider without isolation.
WARNING: This provider does not provide any isolation. It's only suitable for development and testing purposes.
NEVER USE IT IN PRODUCTION ENVIRONMENTS.
"""
@classmethod
def get_config_schema(cls) -> list[BasicProviderConfig]:
return [
BasicProviderConfig(type=BasicProviderConfig.Type.TEXT_INPUT, name="base_working_path"),
]
@classmethod
def validate(cls, options: Mapping[str, Any]) -> None:
pass
def _construct_environment(self, options: Mapping[str, Any], environments: Mapping[str, str]) -> Metadata:
"""
Construct the local virtual environment.
Without isolation, this method simply creates a working directory for the environment and returns its metadata.
"""
environment_id = uuid4().hex
working_path = os.path.join(self._base_working_path, environment_id)
os.makedirs(working_path, exist_ok=True)
return Metadata(
id=environment_id,
arch=self._get_os_architecture(),
os=self._get_operating_system(),
)
def release_environment(self) -> None:
"""
Release the local virtual environment.
Simply removes the working directory.
"""
working_path = self.get_working_path()
if os.path.exists(working_path):
shutil.rmtree(working_path)
def upload_file(self, path: str, content: BytesIO) -> None:
"""
Upload a file to the local virtual environment.
Args:
path (str): The path to upload the file to.
content (BytesIO): The content of the file.
"""
working_path = self.get_working_path()
full_path = os.path.join(working_path, path)
os.makedirs(os.path.dirname(full_path), exist_ok=True)
pathlib.Path(full_path).write_bytes(content.getbuffer())
def download_file(self, path: str) -> BytesIO:
"""
Download a file from the local virtual environment.
Args:
path (str): The path to download the file from.
Returns:
BytesIO: The content of the file.
"""
working_path = self.get_working_path()
full_path = os.path.join(working_path, path)
content = pathlib.Path(full_path).read_bytes()
return BytesIO(content)
def list_files(self, directory_path: str, limit: int) -> Sequence[FileState]:
"""
List files in a directory of the local virtual environment.
"""
working_path = self.get_working_path()
full_directory_path = os.path.join(working_path, directory_path)
files: list[FileState] = []
for root, _, filenames in os.walk(full_directory_path):
for filename in filenames:
if len(files) >= limit:
break
file_path = os.path.relpath(os.path.join(root, filename), working_path)
state = os.stat(os.path.join(root, filename))
files.append(
FileState(
path=file_path,
size=state.st_size,
created_at=int(state.st_ctime),
updated_at=int(state.st_mtime),
)
)
if len(files) >= limit:
# Limit reached: return early instead of breaking out of the outer loop.
return files
return files
def establish_connection(self) -> ConnectionHandle:
"""
Establish a connection to the local virtual environment.
"""
return ConnectionHandle(
id=uuid4().hex,
)
def release_connection(self, connection_handle: ConnectionHandle) -> None:
"""
Release the connection to the local virtual environment.
"""
# No action needed for local without isolation
pass
def execute_command(
self,
connection_handle: ConnectionHandle,
command: list[str],
environments: Mapping[str, str] | None = None,
cwd: str | None = None,
) -> tuple[str, TransportWriteCloser, TransportReadCloser, TransportReadCloser]:
working_path = os.path.join(self.get_working_path(), cwd) if cwd else self.get_working_path()
stdin_read_fd, stdin_write_fd = os.pipe()
stdout_read_fd, stdout_write_fd = os.pipe()
stderr_read_fd, stderr_write_fd = os.pipe()
try:
process = subprocess.Popen(
command,
stdin=stdin_read_fd,
stdout=stdout_write_fd,
stderr=stderr_write_fd,
cwd=working_path,
close_fds=True,
env=environments,
)
except Exception:
# Clean up file descriptors if process creation fails
for fd in (
stdin_read_fd,
stdin_write_fd,
stdout_read_fd,
stdout_write_fd,
stderr_read_fd,
stderr_write_fd,
):
try:
os.close(fd)
except OSError:
pass
raise
# Close unused fds in the parent process
os.close(stdin_read_fd)
os.close(stdout_write_fd)
os.close(stderr_write_fd)
# Create PipeTransport instances for stdin, stdout, and stderr
stdin_transport = PipeWriteCloser(w_fd=stdin_write_fd)
stdout_transport = PipeReadCloser(r_fd=stdout_read_fd)
stderr_transport = PipeReadCloser(r_fd=stderr_read_fd)
# Return the process ID and file descriptors for stdin, stdout, and stderr
return str(process.pid), stdin_transport, stdout_transport, stderr_transport
def get_command_status(self, connection_handle: ConnectionHandle, pid: str) -> CommandStatus:
pid_int = int(pid)
try:
waited_pid, wait_status = os.waitpid(pid_int, os.WNOHANG)
if waited_pid == 0:
return CommandStatus(status=CommandStatus.Status.RUNNING, exit_code=None)
if os.WIFEXITED(wait_status):
exit_code = os.WEXITSTATUS(wait_status)
elif os.WIFSIGNALED(wait_status):
exit_code = -os.WTERMSIG(wait_status)
else:
exit_code = None
return CommandStatus(status=CommandStatus.Status.COMPLETED, exit_code=exit_code)
except ChildProcessError:
return CommandStatus(status=CommandStatus.Status.COMPLETED, exit_code=None)
def terminate_command(self, connection_handle: ConnectionHandle, pid: str) -> bool:
"""Terminate a locally spawned process by PID when cancellation is requested."""
_ = connection_handle
try:
os.kill(int(pid), signal.SIGTERM)
except ProcessLookupError:
return False
return True
def _get_os_architecture(self) -> Arch:
"""
Get the operating system architecture.
Returns:
Arch: The operating system architecture.
"""
arch = machine()
match arch.lower():
case "x86_64" | "amd64":
return Arch.AMD64
case "aarch64" | "arm64":
return Arch.ARM64
case _:
raise ArchNotSupportedError(f"Unsupported architecture: {arch}")
def _get_operating_system(self) -> OperatingSystem:
os_name = system().lower()
match os_name:
case "linux":
return OperatingSystem.LINUX
case "darwin":
return OperatingSystem.DARWIN
case _:
raise ArchNotSupportedError(f"Unsupported operating system: {os_name}")
@cached_property
def _base_working_path(self) -> str:
"""
Get the base working path for the local virtual environment.
Args:
options (Mapping[str, Any]): Options for requesting the virtual environment.
Returns:
str: The base working path.
"""
cwd = os.getcwd()
return self.options.get("base_working_path", os.path.join(cwd, "local_virtual_environments"))
def get_working_path(self) -> str:
"""
Get the working path for the local virtual environment.
Returns:
str: The working path.
"""
return os.path.join(self._base_working_path, self.metadata.id)
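A short cancellation sketch for this provider (editorial, mirroring the usage docstring above); note that os.waitpid in get_command_status can only reap processes spawned by this same Python process:

import time
from core.virtual_environment.__base.entities import CommandStatus
from core.virtual_environment.providers.local_without_isolation import LocalVirtualEnvironment

environment = LocalVirtualEnvironment(options={})
connection_handle = environment.establish_connection()
pid, _stdin, _stdout, _stderr = environment.execute_command(connection_handle, ["sleep", "30"])
# Request cancellation: sends SIGTERM, returns False if the process is already gone.
environment.terminate_command(connection_handle, pid)
time.sleep(0.1)
status = environment.get_command_status(connection_handle, pid)
# A SIGTERM-killed child reports a negative exit code (-15 via os.WTERMSIG).
print(status.status, status.exit_code)
environment.release_connection(connection_handle)
environment.release_environment()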

View File

@ -0,0 +1,499 @@
from __future__ import annotations
import contextlib
import logging
import shlex
import stat
import threading
import time
from collections.abc import Mapping, Sequence
from enum import StrEnum
from io import BytesIO
from pathlib import PurePosixPath
from typing import Any
from uuid import uuid4
from core.entities.provider_entities import BasicProviderConfig
from core.virtual_environment.__base.entities import (
Arch,
CommandStatus,
ConnectionHandle,
FileState,
Metadata,
OperatingSystem,
)
from core.virtual_environment.__base.exec import SandboxConfigValidationError, VirtualEnvironmentLaunchFailedError
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
from core.virtual_environment.channel.exec import TransportEOFError
from core.virtual_environment.channel.queue_transport import QueueTransportReadCloser
from core.virtual_environment.channel.transport import TransportWriteCloser
from core.virtual_environment.constants import COMMAND_EXECUTION_TIMEOUT_SECONDS
logger = logging.getLogger(__name__)
class _SSHStdinTransport(TransportWriteCloser):
def __init__(self, channel: Any):
self._channel = channel
self._closed = False
def write(self, data: bytes) -> None:
if self._closed:
raise TransportEOFError("Transport is closed")
if not data:
return
self._channel.sendall(data)
def close(self) -> None:
if self._closed:
return
self._closed = True
with contextlib.suppress(Exception):
self._channel.shutdown_write()
class SSHSandboxEnvironment(VirtualEnvironment):
_DEFAULT_SSH_HOST = "agentbox"
_DEFAULT_SSH_PORT = 22
_DEFAULT_BASE_WORKING_PATH = "/workspace/sandboxes"
_SSH_CONNECT_TIMEOUT_SECONDS = 10
_SSH_OPERATION_TIMEOUT_SECONDS = 30
_COMMAND_TIMEOUT_EXIT_CODE = 124
class OptionsKey(StrEnum):
SSH_HOST = "ssh_host"
SSH_PORT = "ssh_port"
SSH_USERNAME = "ssh_username"
SSH_PASSWORD = "ssh_password"
BASE_WORKING_PATH = "base_working_path"
def __init__(
self,
tenant_id: str,
options: Mapping[str, Any],
environments: Mapping[str, str] | None = None,
user_id: str | None = None,
) -> None:
self._connections: dict[str, Any] = {}
self._commands: dict[str, CommandStatus] = {}
self._command_channels: dict[str, Any] = {}
self._lock = threading.Lock()
super().__init__(tenant_id=tenant_id, options=options, environments=environments, user_id=user_id)
@classmethod
def get_config_schema(cls) -> list[BasicProviderConfig]:
return [
BasicProviderConfig(type=BasicProviderConfig.Type.TEXT_INPUT, name=cls.OptionsKey.SSH_HOST),
BasicProviderConfig(type=BasicProviderConfig.Type.TEXT_INPUT, name=cls.OptionsKey.SSH_PORT),
BasicProviderConfig(type=BasicProviderConfig.Type.TEXT_INPUT, name=cls.OptionsKey.SSH_USERNAME),
BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=cls.OptionsKey.SSH_PASSWORD),
BasicProviderConfig(type=BasicProviderConfig.Type.TEXT_INPUT, name=cls.OptionsKey.BASE_WORKING_PATH),
]
@classmethod
def validate(cls, options: Mapping[str, Any]) -> None:
cls._require_non_empty_option(options, cls.OptionsKey.SSH_USERNAME)
cls._require_non_empty_option(options, cls.OptionsKey.SSH_PASSWORD)
with cls._create_ssh_client(options):
return
def _construct_environment(self, options: Mapping[str, Any], environments: Mapping[str, str]) -> Metadata:
environment_id = uuid4().hex
working_path = self._workspace_path_from_id(environment_id)
try:
with self._client() as client:
self._run_command(client, f"mkdir -p {shlex.quote(working_path)}")
arch_stdout = self._run_command(client, "uname -m")
os_stdout = self._run_command(client, "uname -s")
except SandboxConfigValidationError as e:
raise ValueError(f"SSH configuration validation failed, please check sandbox provider: {e}") from e
except Exception as e:
raise VirtualEnvironmentLaunchFailedError(f"Failed to construct SSH environment: {e}") from e
return Metadata(
id=environment_id,
arch=self._parse_arch(arch_stdout.decode("utf-8", errors="replace").strip()),
os=self._parse_os(os_stdout.decode("utf-8", errors="replace").strip()),
store={"working_path": working_path},
)
def establish_connection(self) -> ConnectionHandle:
connection_id = uuid4().hex
client = self._create_ssh_client(self.options)
with self._lock:
self._connections[connection_id] = client
return ConnectionHandle(id=connection_id)
def release_connection(self, connection_handle: ConnectionHandle) -> None:
with self._lock:
client = self._connections.pop(connection_handle.id, None)
if client is not None:
with contextlib.suppress(Exception):
client.close()
def release_environment(self) -> None:
working_path = self.get_working_path()
with contextlib.suppress(Exception):
with self._client() as client:
self._run_command(client, f"rm -rf {shlex.quote(working_path)}")
def execute_command(
self,
connection_handle: ConnectionHandle,
command: list[str],
environments: Mapping[str, str] | None = None,
cwd: str | None = None,
) -> tuple[str, TransportWriteCloser, QueueTransportReadCloser, QueueTransportReadCloser]:
client = self._get_connection(connection_handle)
transport = client.get_transport()
if transport is None:
raise RuntimeError("SSH transport is not available")
channel = transport.open_session()
channel.settimeout(self._SSH_OPERATION_TIMEOUT_SECONDS)
channel.set_combine_stderr(False)
execution_command = self._build_exec_command(command, environments, cwd)
channel.exec_command(execution_command)
pid = uuid4().hex
stdin_transport = _SSHStdinTransport(channel)
stdout_transport = QueueTransportReadCloser()
stderr_transport = QueueTransportReadCloser()
with self._lock:
self._commands[pid] = CommandStatus(status=CommandStatus.Status.RUNNING, exit_code=None)
self._command_channels[pid] = channel
threading.Thread(
target=self._consume_channel_output,
args=(pid, channel, stdout_transport, stderr_transport, COMMAND_EXECUTION_TIMEOUT_SECONDS),
daemon=True,
).start()
return pid, stdin_transport, stdout_transport, stderr_transport
def get_command_status(self, connection_handle: ConnectionHandle, pid: str) -> CommandStatus:
with self._lock:
status = self._commands.get(pid)
if status is None:
return CommandStatus(status=CommandStatus.Status.COMPLETED, exit_code=None)
return status
def terminate_command(self, connection_handle: ConnectionHandle, pid: str) -> bool:
"""Best-effort termination by closing the SSH channel that owns the command."""
_ = connection_handle
with self._lock:
channel = self._command_channels.get(pid)
if channel is None:
return False
self._commands[pid] = CommandStatus(
status=CommandStatus.Status.COMPLETED,
exit_code=self._COMMAND_TIMEOUT_EXIT_CODE,
)
with contextlib.suppress(Exception):
channel.close()
return True
def upload_file(self, path: str, content: BytesIO) -> None:
destination_path = self._workspace_path(path)
with self._client() as client:
sftp = client.open_sftp()
try:
self._set_sftp_operation_timeout(sftp)
self._sftp_mkdirs(sftp, str(PurePosixPath(destination_path).parent))
with sftp.file(destination_path, "wb") as remote_file:
remote_file.write(content.getvalue())
finally:
sftp.close()
def download_file(self, path: str) -> BytesIO:
source_path = self._workspace_path(path)
with self._client() as client:
sftp = client.open_sftp()
try:
self._set_sftp_operation_timeout(sftp)
with sftp.file(source_path, "rb") as remote_file:
return BytesIO(remote_file.read())
finally:
sftp.close()
def list_files(self, directory_path: str, limit: int) -> Sequence[FileState]:
if limit <= 0:
return []
root_directory = self._workspace_path(directory_path)
files: list[FileState] = []
with self._client() as client:
sftp = client.open_sftp()
try:
self._set_sftp_operation_timeout(sftp)
pending = [root_directory]
while pending and len(files) < limit:
current_directory = pending.pop(0)
with contextlib.suppress(FileNotFoundError):
for attr in sftp.listdir_attr(current_directory):
current_path = str(PurePosixPath(current_directory) / attr.filename)
mode = attr.st_mode
if stat.S_ISDIR(mode):
pending.append(current_path)
continue
files.append(
FileState(
path=self._to_relative_workspace_path(current_path),
size=attr.st_size,
created_at=int(attr.st_mtime),
updated_at=int(attr.st_mtime),
)
)
if len(files) >= limit:
break
finally:
sftp.close()
return files
@classmethod
def _require_non_empty_option(cls, options: Mapping[str, Any], key: OptionsKey) -> str:
value = options.get(key)
if not isinstance(value, str) or not value.strip():
raise SandboxConfigValidationError(f"Missing required option: {key}")
return value.strip()
@classmethod
def _create_ssh_client(cls, options: Mapping[str, Any]) -> Any:
import paramiko
host = options.get(cls.OptionsKey.SSH_HOST, cls._DEFAULT_SSH_HOST)
port = options.get(cls.OptionsKey.SSH_PORT, cls._DEFAULT_SSH_PORT)
username = cls._require_non_empty_option(options, cls.OptionsKey.SSH_USERNAME)
password = cls._require_non_empty_option(options, cls.OptionsKey.SSH_PASSWORD)
if not isinstance(host, str) or not host.strip():
raise SandboxConfigValidationError(f"Invalid option value: {cls.OptionsKey.SSH_HOST}")
try:
port_int = int(port)
except (TypeError, ValueError) as e:
raise SandboxConfigValidationError(f"Invalid option value: {cls.OptionsKey.SSH_PORT}") from e
client = paramiko.SSHClient()
client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
try:
client.connect(
hostname=host.strip(),
port=port_int,
username=username,
password=password,
look_for_keys=False,
allow_agent=False,
timeout=cls._SSH_CONNECT_TIMEOUT_SECONDS,
banner_timeout=cls._SSH_CONNECT_TIMEOUT_SECONDS,
auth_timeout=cls._SSH_CONNECT_TIMEOUT_SECONDS,
)
transport = client.get_transport()
if transport is not None:
transport.set_keepalive(30)
except Exception as e:
with contextlib.suppress(Exception):
client.close()
raise SandboxConfigValidationError(f"SSH connection failed: {e}") from e
return client
@contextlib.contextmanager
def _client(self):
client = self._create_ssh_client(self.options)
try:
yield client
finally:
with contextlib.suppress(Exception):
client.close()
def _get_connection(self, connection_handle: ConnectionHandle) -> Any:
with self._lock:
client = self._connections.get(connection_handle.id)
if client is None:
raise ValueError(f"Connection handle not found: {connection_handle.id}")
return client
def _workspace_path_from_id(self, environment_id: str) -> str:
base_path = self.options.get(self.OptionsKey.BASE_WORKING_PATH, self._DEFAULT_BASE_WORKING_PATH)
if not isinstance(base_path, str) or not base_path.strip():
base_path = self._DEFAULT_BASE_WORKING_PATH
return str(PurePosixPath(base_path) / environment_id)
def get_working_path(self) -> str:
working_path = self.metadata.store.get("working_path")
if not isinstance(working_path, str) or not working_path:
return self._workspace_path_from_id(self.metadata.id)
return working_path
def _workspace_path(self, path: str | None) -> str:
if not path:
return self.get_working_path()
normalized = PurePosixPath(path)
if normalized.is_absolute():
return str(normalized)
return str(PurePosixPath(self.get_working_path()) / self._normalize_relative_path(path))
@staticmethod
def _normalize_relative_path(path: str) -> PurePosixPath:
parts: list[str] = []
for part in PurePosixPath(path).parts:
if part in ("", ".", "/"):
continue
if part == "..":
if not parts:
raise ValueError("Path escapes the workspace.")
parts.pop()
continue
parts.append(part)
return PurePosixPath(*parts)
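# Illustrative behavior (editorial):
#   _normalize_relative_path("a/./b/../c") -> PurePosixPath("a/c")
#   _normalize_relative_path("../etc/passwd") raises ValueError, so upload,
#   download, and exec paths cannot escape the workspace root.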
def _to_relative_workspace_path(self, path: str) -> str:
workspace = PurePosixPath(self.get_working_path())
target = PurePosixPath(path)
if target.is_relative_to(workspace):
return target.relative_to(workspace).as_posix()
return target.as_posix()
def _build_exec_command(
self, command: list[str], environments: Mapping[str, str] | None = None, cwd: str | None = None
) -> str:
working_path = self._workspace_path(cwd)
command_body = f"cd {shlex.quote(working_path)} && "
if environments:
env_clause = " ".join(f"{key}={shlex.quote(value)}" for key, value in environments.items())
command_body += f"{env_clause} "
command_body += shlex.join(command)
return f"sh -lc {shlex.quote(command_body)}"
@classmethod
def _run_command(cls, client: Any, command: str) -> bytes:
_, stdout, stderr = client.exec_command(command, timeout=cls._SSH_OPERATION_TIMEOUT_SECONDS)
stdout.channel.settimeout(cls._SSH_OPERATION_TIMEOUT_SECONDS)
deadline = time.monotonic() + COMMAND_EXECUTION_TIMEOUT_SECONDS
while not stdout.channel.exit_status_ready():
if time.monotonic() >= deadline:
with contextlib.suppress(Exception):
stdout.channel.close()
raise TimeoutError(f"SSH command timed out after {COMMAND_EXECUTION_TIMEOUT_SECONDS}s")
time.sleep(0.05)
exit_code = stdout.channel.recv_exit_status()
stdout_data = stdout.read()
stderr_data = stderr.read()
if exit_code != 0:
stderr_text = stderr_data.decode("utf-8", errors="replace")
raise RuntimeError(f"SSH command failed ({exit_code}): {stderr_text}")
return stdout_data
def _consume_channel_output(
self,
pid: str,
channel: Any,
stdout_transport: QueueTransportReadCloser,
stderr_transport: QueueTransportReadCloser,
max_runtime_seconds: int,
) -> None:
stdout_writer = stdout_transport.get_write_handler()
stderr_writer = stderr_transport.get_write_handler()
exit_code: int | None = None
started_at = time.monotonic()
try:
while True:
if time.monotonic() - started_at >= max_runtime_seconds:
exit_code = self._COMMAND_TIMEOUT_EXIT_CODE
stderr_writer.write(f"Command timed out after {max_runtime_seconds}s".encode())
break
if channel.recv_ready():
stdout_writer.write(channel.recv(4096))
if channel.recv_stderr_ready():
stderr_writer.write(channel.recv_stderr(4096))
if channel.exit_status_ready() and not channel.recv_ready() and not channel.recv_stderr_ready():
exit_code = int(channel.recv_exit_status())
break
time.sleep(0.05)
except TimeoutError:
logger.warning("SSH channel read timed out for command %s", pid)
exit_code = self._COMMAND_TIMEOUT_EXIT_CODE
finally:
with contextlib.suppress(Exception):
stdout_transport.close()
with contextlib.suppress(Exception):
stderr_transport.close()
with contextlib.suppress(Exception):
channel.close()
with self._lock:
self._command_channels.pop(pid, None)
self._commands[pid] = CommandStatus(status=CommandStatus.Status.COMPLETED, exit_code=exit_code)
def _set_sftp_operation_timeout(self, sftp: Any) -> None:
with contextlib.suppress(Exception):
sftp.get_channel().settimeout(self._SSH_OPERATION_TIMEOUT_SECONDS)
@staticmethod
def _parse_arch(raw_arch: str) -> Arch:
arch = raw_arch.lower()
if arch in {"x86_64", "amd64"}:
return Arch.AMD64
if arch in {"arm64", "aarch64"}:
return Arch.ARM64
return Arch.AMD64
@staticmethod
def _parse_os(raw_os: str) -> OperatingSystem:
system_name = raw_os.lower()
if system_name == "darwin":
return OperatingSystem.DARWIN
return OperatingSystem.LINUX
@staticmethod
def _sftp_mkdirs(sftp: Any, directory: str) -> None:
if not directory or directory == "/":
return
path = PurePosixPath(directory)
current = PurePosixPath("/") if path.is_absolute() else PurePosixPath()
for part in path.parts:
if part in ("", "/"):
continue
current = current / part
current_path = str(current)
try:
attrs = sftp.stat(current_path)
if not stat.S_ISDIR(attrs.st_mode):
raise OSError(f"Path exists but is not a directory: {current_path}")
continue
except OSError as e:
missing = isinstance(e, FileNotFoundError) or getattr(e, "errno", None) == 2
missing = missing or "no such file" in str(e).lower()
if not missing:
raise
try:
sftp.mkdir(current_path)
except OSError:
# Some SFTP servers report generic "Failure" when directory already exists.
attrs = sftp.stat(current_path)
if not stat.S_ISDIR(attrs.st_mode):
raise OSError(f"Failed to create directory: {current_path}")

View File

@ -90,6 +90,10 @@ def init_app(app: DifyApp):
app.register_blueprint(inner_api_bp)
app.register_blueprint(mcp_bp)
# TODO: enable after full sandbox integration
# from controllers.cli_api import bp as cli_api_bp
# app.register_blueprint(cli_api_bp)
# Register trigger blueprint with CORS for webhook calls
_apply_cors_once(
trigger_bp,
