Mirror of https://github.com/langgenius/dify.git (synced 2026-05-10 05:56:31 +08:00)
feat(api): port Sandbox + VirtualEnvironment + Skill system from feat/support-agent-sandbox (Phase 5-6)
Port the complete infrastructure for agent sandbox execution and skill system:

Sandbox & Virtual Environment (core/sandbox/, core/virtual_environment/):
- Sandbox entity with lifecycle management (ready/failed/cancelled states)
- SandboxBuilder with fluent API for configuring providers
- 5 VM providers: Local, SSH, Docker, E2B, AWS CodeInterpreter
- VirtualEnvironment base with command execution, file transfer, transport layers
- Channel transport: pipe, queue, socket implementations
- Bash session management and DifyCli binary integration
- Storage: archive storage, file storage, noop storage, presign storage
- Initializers: DifyCli, AppAssets, DraftAppAssets, Skills
- Inspector: file browser, archive/runtime source, script utils
- Security: encryption utils, debug helpers

Skill & App Assets (core/skill/, core/app_assets/, core/app_bundle/):
- Skill entity and manager
- App asset accessor, builder pipeline (file, skill builders)
- App bundle source zip extractor
- Storage and converter utilities

API Endpoints:
- CLI API blueprint (controllers/cli_api/) for sandbox callback
- Sandbox provider management (workspace/sandbox_providers)
- Sandbox file browser (console/sandbox_files)
- App asset management (console/app/app_asset)
- Skill management (console/app/skills)
- Storage file endpoints (controllers/files/storage_files)

Services:
- Sandbox service, provider service, file service
- App asset service, app bundle service

Config:
- CliApiConfig, CreatorsPlatformConfig, CollaborationConfig
- FILES_API_URL for sandbox file access

Note: Controller route registration temporarily commented out (marked TODO)
pending resolution of deep dependency chains (socketio, workflow_comment,
command node, etc.). Core sandbox modules are fully ported and
syntax-validated.

110 files changed, 10,549 insertions.

Made-with: Cursor
This commit is contained in:
parent d9d1e9b63a
commit 0c7e7e0c4e
@@ -271,6 +271,17 @@ class PluginConfig(BaseSettings):
    )


class CliApiConfig(BaseSettings):
    """
    Configuration for CLI API (for dify-cli to call back from external sandbox environments)
    """

    CLI_API_URL: str = Field(
        description="CLI API URL for external sandbox (e.g., e2b) to call back.",
        default="http://localhost:5001",
    )


class MarketplaceConfig(BaseSettings):
    """
    Configuration for marketplace
@@ -287,6 +298,27 @@ class MarketplaceConfig(BaseSettings):
    )


class CreatorsPlatformConfig(BaseSettings):
    """
    Configuration for creators platform
    """

    CREATORS_PLATFORM_FEATURES_ENABLED: bool = Field(
        description="Enable or disable creators platform features",
        default=True,
    )

    CREATORS_PLATFORM_API_URL: HttpUrl = Field(
        description="Creators Platform API URL",
        default=HttpUrl("https://creators.dify.ai"),
    )

    CREATORS_PLATFORM_OAUTH_CLIENT_ID: str = Field(
        description="OAuth client_id for the Creators Platform app registered in Dify",
        default="",
    )


class EndpointConfig(BaseSettings):
    """
    Configuration for various application endpoints and URLs
@@ -341,6 +373,15 @@ class FileAccessConfig(BaseSettings):
        default="",
    )

    FILES_API_URL: str = Field(
        description="Base URL for storage file ticket API endpoints."
        " Used by sandbox containers (internal or external like e2b) that need"
        " an absolute, routable address to upload/download files via the API."
        " For all-in-one Docker deployments, set to http://localhost."
        " For public sandbox environments, set to a public domain or IP.",
        default="",
    )

    FILES_ACCESS_TIMEOUT: int = Field(
        description="Expiration time in seconds for file access URLs",
        default=300,
@@ -1274,6 +1315,13 @@ class PositionConfig(BaseSettings):
        return {item.strip() for item in self.POSITION_TOOL_EXCLUDES.split(",") if item.strip() != ""}


class CollaborationConfig(BaseSettings):
    ENABLE_COLLABORATION_MODE: bool = Field(
        description="Whether to enable collaboration mode features across the workspace",
        default=False,
    )


class LoginConfig(BaseSettings):
    ENABLE_EMAIL_CODE_LOGIN: bool = Field(
        description="whether to enable email code login",
@@ -1375,7 +1423,9 @@ class FeatureConfig(
    TriggerConfig,
    AsyncWorkflowConfig,
    PluginConfig,
    CliApiConfig,
    MarketplaceConfig,
    CreatorsPlatformConfig,
    DataSetConfig,
    EndpointConfig,
    FileAccessConfig,
@@ -1399,6 +1449,7 @@ class FeatureConfig(
    WorkflowConfig,
    WorkflowNodeExecutionConfig,
    WorkspaceConfig,
    CollaborationConfig,
    LoginConfig,
    AccountConfig,
    SwaggerUIConfig,
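All of these settings mix into FeatureConfig above, so once merged they surface as attributes of the global settings object with the defaults shown in the diff. A quick smoke check, assuming the `dify_config` singleton that the configs package conventionally exports:

# Assumes `from configs import dify_config` is the settings entry point.
from configs import dify_config

print(dify_config.CLI_API_URL)  # "http://localhost:5001" by default
print(dify_config.FILES_API_URL or "<unset>")  # empty until configured for sandbox file access
print(dify_config.ENABLE_COLLABORATION_MODE)  # False by default
print(dify_config.CREATORS_PLATFORM_FEATURES_ENABLED)  # True by default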
api/controllers/cli_api/__init__.py (new file, 27 lines)
@@ -0,0 +1,27 @@
from flask import Blueprint
from flask_restx import Namespace

from libs.external_api import ExternalApi

bp = Blueprint("cli_api", __name__, url_prefix="/cli/api")

api = ExternalApi(
    bp,
    version="1.0",
    title="CLI API",
    description="APIs for Dify CLI to call back from external sandbox environments (e.g., e2b)",
)

# Create namespace
cli_api_ns = Namespace("cli_api", description="CLI API operations", path="/")

from .dify_cli import cli_api as _plugin

api.add_namespace(cli_api_ns)

__all__ = [
    "_plugin",
    "api",
    "bp",
    "cli_api_ns",
]
api/controllers/cli_api/dify_cli/__init__.py (new file, empty)
api/controllers/cli_api/dify_cli/cli_api.py (new file, 192 lines)
@@ -0,0 +1,192 @@
from flask import abort
from flask_restx import Resource
from pydantic import BaseModel

from controllers.cli_api import cli_api_ns
from controllers.cli_api.dify_cli.wraps import get_cli_user_tenant, plugin_data
from controllers.cli_api.wraps import cli_api_only
from controllers.console.wraps import setup_required
from core.app.entities.app_invoke_entities import InvokeFrom
from core.plugin.backwards_invocation.app import PluginAppBackwardsInvocation
from core.plugin.backwards_invocation.base import BaseBackwardsInvocationResponse
from core.plugin.backwards_invocation.model import PluginModelBackwardsInvocation
from core.plugin.backwards_invocation.tool import PluginToolBackwardsInvocation
from core.plugin.entities.request import (
    RequestInvokeApp,
    RequestInvokeLLM,
    RequestInvokeTool,
    RequestRequestUploadFile,
)
from core.sandbox.bash.dify_cli import DifyCliToolConfig
from core.session.cli_api import CliContext
from core.skill.entities import ToolInvocationRequest
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.tool_manager import ToolManager
from graphon.file.helpers import get_signed_file_url_for_plugin
from libs.helper import length_prefixed_response
from models.account import Account
from models.model import EndUser, Tenant


class FetchToolItem(BaseModel):
    tool_type: str
    tool_provider: str
    tool_name: str
    credential_id: str | None = None


class FetchToolBatchRequest(BaseModel):
    tools: list[FetchToolItem]


@cli_api_ns.route("/invoke/llm")
class CliInvokeLLMApi(Resource):
    @cli_api_only
    @get_cli_user_tenant
    @setup_required
    @plugin_data(payload_type=RequestInvokeLLM)
    def post(
        self,
        user_model: Account | EndUser,
        tenant_model: Tenant,
        payload: RequestInvokeLLM,
        cli_context: CliContext,
    ):
        def generator():
            response = PluginModelBackwardsInvocation.invoke_llm(user_model.id, tenant_model, payload)
            return PluginModelBackwardsInvocation.convert_to_event_stream(response)

        return length_prefixed_response(0xF, generator())


@cli_api_ns.route("/invoke/tool")
class CliInvokeToolApi(Resource):
    @cli_api_only
    @get_cli_user_tenant
    @setup_required
    @plugin_data(payload_type=RequestInvokeTool)
    def post(
        self,
        user_model: Account | EndUser,
        tenant_model: Tenant,
        payload: RequestInvokeTool,
        cli_context: CliContext,
    ):
        tool_type = ToolProviderType.value_of(payload.tool_type)

        request = ToolInvocationRequest(
            tool_type=tool_type,
            provider=payload.provider,
            tool_name=payload.tool,
            credential_id=payload.credential_id,
        )
        if cli_context.tool_access and not cli_context.tool_access.is_allowed(request):
            abort(403, description=f"Access denied for tool: {payload.provider}/{payload.tool}")

        def generator():
            return PluginToolBackwardsInvocation.convert_to_event_stream(
                PluginToolBackwardsInvocation.invoke_tool(
                    tenant_id=tenant_model.id,
                    user_id=user_model.id,
                    tool_type=tool_type,
                    provider=payload.provider,
                    tool_name=payload.tool,
                    tool_parameters=payload.tool_parameters,
                    credential_id=payload.credential_id,
                ),
            )

        return length_prefixed_response(0xF, generator())


@cli_api_ns.route("/invoke/app")
class CliInvokeAppApi(Resource):
    @cli_api_only
    @get_cli_user_tenant
    @setup_required
    @plugin_data(payload_type=RequestInvokeApp)
    def post(
        self,
        user_model: Account | EndUser,
        tenant_model: Tenant,
        payload: RequestInvokeApp,
        cli_context: CliContext,
    ):
        response = PluginAppBackwardsInvocation.invoke_app(
            app_id=payload.app_id,
            user_id=user_model.id,
            tenant_id=tenant_model.id,
            conversation_id=payload.conversation_id,
            query=payload.query,
            stream=payload.response_mode == "streaming",
            inputs=payload.inputs,
            files=payload.files,
        )

        return length_prefixed_response(0xF, PluginAppBackwardsInvocation.convert_to_event_stream(response))


@cli_api_ns.route("/upload/file/request")
class CliUploadFileRequestApi(Resource):
    @cli_api_only
    @get_cli_user_tenant
    @setup_required
    @plugin_data(payload_type=RequestRequestUploadFile)
    def post(
        self,
        user_model: Account | EndUser,
        tenant_model: Tenant,
        payload: RequestRequestUploadFile,
        cli_context: CliContext,
    ):
        url = get_signed_file_url_for_plugin(
            filename=payload.filename,
            mimetype=payload.mimetype,
            tenant_id=tenant_model.id,
            user_id=user_model.id,
        )
        return BaseBackwardsInvocationResponse(data={"url": url}).model_dump()


@cli_api_ns.route("/fetch/tools/batch")
class CliFetchToolsBatchApi(Resource):
    @cli_api_only
    @get_cli_user_tenant
    @setup_required
    @plugin_data(payload_type=FetchToolBatchRequest)
    def post(
        self,
        user_model: Account | EndUser,
        tenant_model: Tenant,
        payload: FetchToolBatchRequest,
        cli_context: CliContext,
    ):
        tools: list[dict] = []

        for item in payload.tools:
            provider_type = ToolProviderType.value_of(item.tool_type)

            request = ToolInvocationRequest(
                tool_type=provider_type,
                provider=item.tool_provider,
                tool_name=item.tool_name,
                credential_id=item.credential_id,
            )
            if cli_context.tool_access and not cli_context.tool_access.is_allowed(request):
                abort(403, description=f"Access denied for tool: {item.tool_provider}/{item.tool_name}")

            try:
                tool_runtime = ToolManager.get_tool_runtime(
                    tenant_id=tenant_model.id,
                    provider_type=provider_type,
                    provider_id=item.tool_provider,
                    tool_name=item.tool_name,
                    invoke_from=InvokeFrom.AGENT,
                    credential_id=item.credential_id,
                )
                tool_config = DifyCliToolConfig.create_from_tool(tool_runtime)
                tools.append(tool_config.model_dump())
            except Exception:
                continue

        return BaseBackwardsInvocationResponse(data={"tools": tools}).model_dump()
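For orientation, the batch endpoint above takes a FetchToolBatchRequest body and answers with one DifyCliToolConfig dump per resolvable tool; a tool the session's tool_access denies aborts the whole request with 403, while a tool that merely fails to resolve is skipped. A hedged example body (the provider and tool names are hypothetical):

# Example body for POST /cli/api/fetch/tools/batch, shaped by
# FetchToolBatchRequest above. Provider/tool names are hypothetical.
payload = {
    "tools": [
        {
            "tool_type": "builtin",       # parsed with ToolProviderType.value_of
            "tool_provider": "time",      # hypothetical provider id
            "tool_name": "current_time",  # hypothetical tool name
            "credential_id": None,
        }
    ]
}
# Response shape: {"data": {"tools": [<DifyCliToolConfig dump>, ...]}}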
api/controllers/cli_api/dify_cli/wraps.py (new file, 137 lines)
@@ -0,0 +1,137 @@
from collections.abc import Callable
from functools import wraps
from typing import ParamSpec, TypeVar

from flask import current_app, g, request
from flask_login import user_logged_in
from pydantic import BaseModel
from sqlalchemy.orm import Session

from core.session.cli_api import CliApiSession, CliContext
from extensions.ext_database import db
from libs.login import current_user
from models.account import Tenant
from models.model import DefaultEndUserSessionID, EndUser

P = ParamSpec("P")
R = TypeVar("R")


class TenantUserPayload(BaseModel):
    tenant_id: str
    user_id: str


def get_user(tenant_id: str, user_id: str | None) -> EndUser:
    """
    Get the current user.

    NOTE: user_id is not trusted; it could be maliciously set to any value.
    As a result, it can only be treated as an end-user id.
    """
    if not user_id:
        user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID
    is_anonymous = user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID
    try:
        with Session(db.engine) as session:
            user_model = None

            if is_anonymous:
                user_model = (
                    session.query(EndUser)
                    .where(
                        EndUser.session_id == user_id,
                        EndUser.tenant_id == tenant_id,
                    )
                    .first()
                )
            else:
                user_model = (
                    session.query(EndUser)
                    .where(
                        EndUser.id == user_id,
                        EndUser.tenant_id == tenant_id,
                    )
                    .first()
                )

            if not user_model:
                user_model = EndUser(
                    tenant_id=tenant_id,
                    type="service_api",
                    is_anonymous=is_anonymous,
                    session_id=user_id,
                )
                session.add(user_model)
                session.commit()
                session.refresh(user_model)

    except Exception:
        raise ValueError("user not found")

    return user_model


def get_cli_user_tenant(view_func: Callable[P, R]):
    @wraps(view_func)
    def decorated_view(*args: P.args, **kwargs: P.kwargs):
        session: CliApiSession | None = getattr(g, "cli_api_session", None)
        if session is None:
            raise ValueError("session not found")

        user_id = session.user_id
        tenant_id = session.tenant_id
        cli_context = CliContext.model_validate(session.context)

        if not user_id:
            user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID

        try:
            tenant_model = (
                db.session.query(Tenant)
                .where(
                    Tenant.id == tenant_id,
                )
                .first()
            )
        except Exception:
            raise ValueError("tenant not found")

        if not tenant_model:
            raise ValueError("tenant not found")

        kwargs["tenant_model"] = tenant_model
        kwargs["user_model"] = get_user(tenant_id, user_id)
        kwargs["cli_context"] = cli_context

        current_app.login_manager._update_request_context_with_user(kwargs["user_model"])  # type: ignore
        user_logged_in.send(current_app._get_current_object(), user=current_user)  # type: ignore

        return view_func(*args, **kwargs)

    return decorated_view


def plugin_data(view: Callable[P, R] | None = None, *, payload_type: type[BaseModel]):
    def decorator(view_func: Callable[P, R]):
        @wraps(view_func)
        def decorated_view(*args: P.args, **kwargs: P.kwargs):
            try:
                data = request.get_json()
            except Exception:
                raise ValueError("invalid json")

            try:
                payload = payload_type.model_validate(data)
            except Exception as e:
                raise ValueError(f"invalid payload: {str(e)}")

            kwargs["payload"] = payload
            return view_func(*args, **kwargs)

        return decorated_view

    if view is None:
        return decorator
    else:
        return decorator(view)
api/controllers/cli_api/wraps.py (new file, 56 lines)
@@ -0,0 +1,56 @@
import hashlib
import hmac
import time
from collections.abc import Callable
from functools import wraps
from typing import ParamSpec, TypeVar

from flask import abort, g, request

from core.session.cli_api import CliApiSessionManager

P = ParamSpec("P")
R = TypeVar("R")

SIGNATURE_TTL_SECONDS = 300


def _verify_signature(session_secret: str, timestamp: str, body: bytes, signature: str) -> bool:
    expected = hmac.new(
        session_secret.encode(),
        f"{timestamp}.".encode() + body,
        hashlib.sha256,
    ).hexdigest()
    return hmac.compare_digest(f"sha256={expected}", signature)


def cli_api_only(view: Callable[P, R]):
    @wraps(view)
    def decorated(*args: P.args, **kwargs: P.kwargs):
        session_id = request.headers.get("X-Cli-Api-Session-Id")
        timestamp = request.headers.get("X-Cli-Api-Timestamp")
        signature = request.headers.get("X-Cli-Api-Signature")

        if not session_id or not timestamp or not signature:
            abort(401)

        try:
            ts = int(timestamp)
            if abs(time.time() - ts) > SIGNATURE_TTL_SECONDS:
                abort(401)
        except ValueError:
            abort(401)

        session = CliApiSessionManager().get(session_id)
        if not session:
            abort(401)

        body = request.get_data()
        if not _verify_signature(session.secret, timestamp, body, signature):
            abort(401)

        g.cli_api_session = session

        return view(*args, **kwargs)

    return decorated
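The verification scheme above is symmetric, so the client side can be read straight off _verify_signature: HMAC-SHA256 the string f"{timestamp}." concatenated with the raw body, keyed by the session secret, and send the hex digest with an "sha256=" prefix within the 300-second TTL. A minimal signing sketch, not the actual dify-cli implementation; the session values and the requests dependency are assumptions:

import hashlib
import hmac
import json
import time

import requests  # assumed HTTP client; any client works


def sign_cli_request(session_id: str, session_secret: str, body: bytes) -> dict[str, str]:
    """Build the three headers cli_api_only checks, mirroring _verify_signature."""
    timestamp = str(int(time.time()))  # must be within SIGNATURE_TTL_SECONDS of server time
    digest = hmac.new(
        session_secret.encode(),
        f"{timestamp}.".encode() + body,
        hashlib.sha256,
    ).hexdigest()
    return {
        "X-Cli-Api-Session-Id": session_id,
        "X-Cli-Api-Timestamp": timestamp,
        "X-Cli-Api-Signature": f"sha256={digest}",
        "Content-Type": "application/json",
    }


# Hypothetical session credentials, for illustration only.
body = json.dumps({"tools": []}).encode()
headers = sign_cli_request("sess-123", "secret-abc", body)
requests.post("http://localhost:5001/cli/api/fetch/tools/batch", data=body, headers=headers)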
api/controllers/console/__init__.py (modified)
@@ -41,6 +41,7 @@ from . import (
    init_validate,
    notification,
    ping,
    # sandbox_files,  # TODO: enable after full sandbox integration
    setup,
    spec,
    version,
@@ -52,6 +53,7 @@ from .app import (
    agent,
    annotation,
    app,
    # app_asset,  # TODO: enable after full sandbox integration
    audio,
    completion,
    conversation,
@@ -62,6 +64,7 @@ from .app import (
    model_config,
    ops_trace,
    site,
    # skills,  # TODO: enable after full sandbox integration
    statistic,
    workflow,
    workflow_app_log,
@@ -130,6 +133,7 @@ from .workspace import (
    model_providers,
    models,
    plugin,
    # sandbox_providers,  # TODO: enable after full sandbox integration
    tool_providers,
    trigger_providers,
    workspace,
api/controllers/console/app/app_asset.py (new file, 333 lines)
@@ -0,0 +1,333 @@
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator

from controllers.console import console_ns
from controllers.console.app.error import (
    AppAssetNodeNotFoundError,
    AppAssetPathConflictError,
)
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from core.app.entities.app_asset_entities import BatchUploadNode
from libs.login import current_account_with_tenant, login_required
from models import App
from models.model import AppMode
from services.app_asset_service import AppAssetService
from services.errors.app_asset import (
    AppAssetNodeNotFoundError as ServiceNodeNotFoundError,
)
from services.errors.app_asset import (
    AppAssetParentNotFoundError,
)
from services.errors.app_asset import (
    AppAssetPathConflictError as ServicePathConflictError,
)

DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"


class CreateFolderPayload(BaseModel):
    name: str = Field(..., min_length=1, max_length=255)
    parent_id: str | None = None


class CreateFilePayload(BaseModel):
    name: str = Field(..., min_length=1, max_length=255)
    parent_id: str | None = None

    @field_validator("name", mode="before")
    @classmethod
    def strip_name(cls, v: str) -> str:
        return v.strip() if isinstance(v, str) else v

    @field_validator("parent_id", mode="before")
    @classmethod
    def empty_to_none(cls, v: str | None) -> str | None:
        return v or None


class GetUploadUrlPayload(BaseModel):
    name: str = Field(..., min_length=1, max_length=255)
    size: int = Field(..., ge=0)
    parent_id: str | None = None

    @field_validator("name", mode="before")
    @classmethod
    def strip_name(cls, v: str) -> str:
        return v.strip() if isinstance(v, str) else v

    @field_validator("parent_id", mode="before")
    @classmethod
    def empty_to_none(cls, v: str | None) -> str | None:
        return v or None


class BatchUploadPayload(BaseModel):
    children: list[BatchUploadNode] = Field(..., min_length=1)
    parent_id: str | None = None

    @field_validator("parent_id", mode="before")
    @classmethod
    def empty_to_none(cls, v: str | None) -> str | None:
        return v or None


class UpdateFileContentPayload(BaseModel):
    content: str


class RenameNodePayload(BaseModel):
    name: str = Field(..., min_length=1, max_length=255)


class MoveNodePayload(BaseModel):
    parent_id: str | None = None


class ReorderNodePayload(BaseModel):
    after_node_id: str | None = Field(default=None, description="Place after this node, None for first position")


def reg(cls: type[BaseModel]) -> None:
    console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))


reg(CreateFolderPayload)
reg(CreateFilePayload)
reg(GetUploadUrlPayload)
reg(BatchUploadNode)
reg(BatchUploadPayload)
reg(UpdateFileContentPayload)
reg(RenameNodePayload)
reg(MoveNodePayload)
reg(ReorderNodePayload)


@console_ns.route("/apps/<string:app_id>/assets/tree")
class AppAssetTreeResource(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
    def get(self, app_model: App):
        current_user, _ = current_account_with_tenant()
        tree = AppAssetService.get_asset_tree(app_model, current_user.id)
        return {"children": [view.model_dump() for view in tree.transform()]}


@console_ns.route("/apps/<string:app_id>/assets/folders")
class AppAssetFolderResource(Resource):
    @console_ns.expect(console_ns.models[CreateFolderPayload.__name__])
    @setup_required
    @login_required
    @account_initialization_required
    @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
    def post(self, app_model: App):
        current_user, _ = current_account_with_tenant()
        payload = CreateFolderPayload.model_validate(console_ns.payload or {})

        try:
            node = AppAssetService.create_folder(app_model, current_user.id, payload.name, payload.parent_id)
            return node.model_dump(), 201
        except AppAssetParentNotFoundError:
            raise AppAssetNodeNotFoundError()
        except ServicePathConflictError:
            raise AppAssetPathConflictError()


@console_ns.route("/apps/<string:app_id>/assets/files/<string:node_id>")
class AppAssetFileDetailResource(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
    def get(self, app_model: App, node_id: str):
        current_user, _ = current_account_with_tenant()
        try:
            content = AppAssetService.get_file_content(app_model, current_user.id, node_id)
            return {"content": content.decode("utf-8", errors="replace")}
        except ServiceNodeNotFoundError:
            raise AppAssetNodeNotFoundError()

    @console_ns.expect(console_ns.models[UpdateFileContentPayload.__name__])
    @setup_required
    @login_required
    @account_initialization_required
    @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
    def put(self, app_model: App, node_id: str):
        current_user, _ = current_account_with_tenant()

        file = request.files.get("file")
        if file:
            content = file.read()
        else:
            payload = UpdateFileContentPayload.model_validate(console_ns.payload or {})
            content = payload.content.encode("utf-8")

        try:
            node = AppAssetService.update_file_content(app_model, current_user.id, node_id, content)
            return node.model_dump()
        except ServiceNodeNotFoundError:
            raise AppAssetNodeNotFoundError()


@console_ns.route("/apps/<string:app_id>/assets/nodes/<string:node_id>")
class AppAssetNodeResource(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
    def delete(self, app_model: App, node_id: str):
        current_user, _ = current_account_with_tenant()
        try:
            AppAssetService.delete_node(app_model, current_user.id, node_id)
            return {"result": "success"}, 200
        except ServiceNodeNotFoundError:
            raise AppAssetNodeNotFoundError()


@console_ns.route("/apps/<string:app_id>/assets/nodes/<string:node_id>/rename")
class AppAssetNodeRenameResource(Resource):
    @console_ns.expect(console_ns.models[RenameNodePayload.__name__])
    @setup_required
    @login_required
    @account_initialization_required
    @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
    def post(self, app_model: App, node_id: str):
        current_user, _ = current_account_with_tenant()
        payload = RenameNodePayload.model_validate(console_ns.payload or {})

        try:
            node = AppAssetService.rename_node(app_model, current_user.id, node_id, payload.name)
            return node.model_dump()
        except ServiceNodeNotFoundError:
            raise AppAssetNodeNotFoundError()
        except ServicePathConflictError:
            raise AppAssetPathConflictError()


@console_ns.route("/apps/<string:app_id>/assets/nodes/<string:node_id>/move")
class AppAssetNodeMoveResource(Resource):
    @console_ns.expect(console_ns.models[MoveNodePayload.__name__])
    @setup_required
    @login_required
    @account_initialization_required
    @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
    def post(self, app_model: App, node_id: str):
        current_user, _ = current_account_with_tenant()
        payload = MoveNodePayload.model_validate(console_ns.payload or {})

        try:
            node = AppAssetService.move_node(app_model, current_user.id, node_id, payload.parent_id)
            return node.model_dump()
        except ServiceNodeNotFoundError:
            raise AppAssetNodeNotFoundError()
        except AppAssetParentNotFoundError:
            raise AppAssetNodeNotFoundError()
        except ServicePathConflictError:
            raise AppAssetPathConflictError()


@console_ns.route("/apps/<string:app_id>/assets/nodes/<string:node_id>/reorder")
class AppAssetNodeReorderResource(Resource):
    @console_ns.expect(console_ns.models[ReorderNodePayload.__name__])
    @setup_required
    @login_required
    @account_initialization_required
    @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
    def post(self, app_model: App, node_id: str):
        current_user, _ = current_account_with_tenant()
        payload = ReorderNodePayload.model_validate(console_ns.payload or {})

        try:
            node = AppAssetService.reorder_node(app_model, current_user.id, node_id, payload.after_node_id)
            return node.model_dump()
        except ServiceNodeNotFoundError:
            raise AppAssetNodeNotFoundError()


@console_ns.route("/apps/<string:app_id>/assets/files/<string:node_id>/download-url")
class AppAssetFileDownloadUrlResource(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
    def get(self, app_model: App, node_id: str):
        current_user, _ = current_account_with_tenant()
        try:
            download_url = AppAssetService.get_file_download_url(app_model, current_user.id, node_id)
            return {"download_url": download_url}
        except ServiceNodeNotFoundError:
            raise AppAssetNodeNotFoundError()


@console_ns.route("/apps/<string:app_id>/assets/files/upload")
class AppAssetFileUploadUrlResource(Resource):
    @console_ns.expect(console_ns.models[GetUploadUrlPayload.__name__])
    @setup_required
    @login_required
    @account_initialization_required
    @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
    def post(self, app_model: App):
        current_user, _ = current_account_with_tenant()
        payload = GetUploadUrlPayload.model_validate(console_ns.payload or {})

        try:
            node, upload_url = AppAssetService.get_file_upload_url(
                app_model, current_user.id, payload.name, payload.size, payload.parent_id
            )
            return {"node": node.model_dump(), "upload_url": upload_url}, 201
        except AppAssetParentNotFoundError:
            raise AppAssetNodeNotFoundError()
        except ServicePathConflictError:
            raise AppAssetPathConflictError()


@console_ns.route("/apps/<string:app_id>/assets/batch-upload")
class AppAssetBatchUploadResource(Resource):
    @console_ns.expect(console_ns.models[BatchUploadPayload.__name__])
    @setup_required
    @login_required
    @account_initialization_required
    @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
    def post(self, app_model: App):
        """
        Create nodes from tree structure and return upload URLs.

        Input:
        {
            "parent_id": "optional-target-folder-id",
            "children": [
                {"name": "folder1", "node_type": "folder", "children": [
                    {"name": "file1.txt", "node_type": "file", "size": 1024}
                ]},
                {"name": "root.txt", "node_type": "file", "size": 512}
            ]
        }

        Output:
        {
            "children": [
                {"id": "xxx", "name": "folder1", "node_type": "folder", "children": [
                    {"id": "yyy", "name": "file1.txt", "node_type": "file", "size": 1024, "upload_url": "..."}
                ]},
                {"id": "zzz", "name": "root.txt", "node_type": "file", "size": 512, "upload_url": "..."}
            ]
        }
        """
        current_user, _ = current_account_with_tenant()
        payload = BatchUploadPayload.model_validate(console_ns.payload or {})

        try:
            result_children = AppAssetService.batch_create_from_tree(
                app_model,
                current_user.id,
                payload.children,
                parent_id=payload.parent_id,
            )
            return {"children": [child.model_dump() for child in result_children]}, 201
        except AppAssetParentNotFoundError:
            raise AppAssetNodeNotFoundError()
        except ServicePathConflictError:
            raise AppAssetPathConflictError()
api/controllers/console/app/skills.py (new file, 38 lines)
@@ -0,0 +1,38 @@
from flask import request
from flask_restx import Resource

from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, current_account_with_tenant, setup_required
from libs.login import login_required
from models import App
from models.model import AppMode
from services.skill_service import SkillService


@console_ns.route("/apps/<uuid:app_id>/workflows/draft/nodes/llm/skills")
class NodeSkillsApi(Resource):
    """Extract tool dependencies from an LLM node's skill prompts.

    The client sends the full node ``data`` object in the request body.
    The server builds a ``SkillBundle`` in real time from the current draft
    ``.md`` assets and resolves transitive tool dependencies — no cached
    bundle is used.
    """

    @setup_required
    @login_required
    @account_initialization_required
    @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
    def post(self, app_model: App):
        current_user, _ = current_account_with_tenant()
        node_data = request.get_json(force=True)
        if not isinstance(node_data, dict):
            return {"tool_dependencies": []}

        tool_deps = SkillService.extract_tool_dependencies(
            app=app_model,
            node_data=node_data,
            user_id=current_user.id,
        )
        return {"tool_dependencies": [d.model_dump() for d in tool_deps]}
api/controllers/console/sandbox_files.py (new file, 103 lines)
@@ -0,0 +1,103 @@
from __future__ import annotations

from fastapi.encoders import jsonable_encoder
from flask import request
from flask_restx import Resource, fields
from pydantic import BaseModel, Field

from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from libs.login import current_account_with_tenant, login_required
from services.sandbox.sandbox_file_service import SandboxFileService

DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"


class SandboxFileListQuery(BaseModel):
    path: str | None = Field(default=None, description="Workspace relative path")
    recursive: bool = Field(default=False, description="List recursively")


class SandboxFileDownloadRequest(BaseModel):
    path: str = Field(..., description="Workspace relative file path")


console_ns.schema_model(
    SandboxFileListQuery.__name__,
    SandboxFileListQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
    SandboxFileDownloadRequest.__name__,
    SandboxFileDownloadRequest.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)


SANDBOX_FILE_NODE_FIELDS = {
    "path": fields.String,
    "is_dir": fields.Boolean,
    "size": fields.Raw,
    "mtime": fields.Raw,
    "extension": fields.String,
}


SANDBOX_FILE_DOWNLOAD_TICKET_FIELDS = {
    "download_url": fields.String,
    "expires_in": fields.Integer,
    "export_id": fields.String,
}


sandbox_file_node_model = console_ns.model("SandboxFileNode", SANDBOX_FILE_NODE_FIELDS)
sandbox_file_download_ticket_model = console_ns.model("SandboxFileDownloadTicket", SANDBOX_FILE_DOWNLOAD_TICKET_FIELDS)


@console_ns.route("/apps/<string:app_id>/sandbox/files")
class SandboxFilesApi(Resource):
    """List sandbox files for the current user.

    The sandbox_id is derived from the current user's ID, as each user has
    their own sandbox workspace per app.
    """

    @setup_required
    @login_required
    @account_initialization_required
    @console_ns.expect(console_ns.models[SandboxFileListQuery.__name__])
    @console_ns.marshal_list_with(sandbox_file_node_model)
    def get(self, app_id: str):
        args = SandboxFileListQuery.model_validate(request.args.to_dict(flat=True))  # type: ignore[arg-type]
        account, tenant_id = current_account_with_tenant()
        sandbox_id = account.id
        return jsonable_encoder(
            SandboxFileService.list_files(
                tenant_id=tenant_id,
                app_id=app_id,
                sandbox_id=sandbox_id,
                path=args.path,
                recursive=args.recursive,
            )
        )


@console_ns.route("/apps/<string:app_id>/sandbox/files/download")
class SandboxFileDownloadApi(Resource):
    """Download a sandbox file for the current user.

    The sandbox_id is derived from the current user's ID, as each user has
    their own sandbox workspace per app.
    """

    @setup_required
    @login_required
    @account_initialization_required
    @console_ns.expect(console_ns.models[SandboxFileDownloadRequest.__name__])
    @console_ns.marshal_with(sandbox_file_download_ticket_model)
    def post(self, app_id: str):
        payload = SandboxFileDownloadRequest.model_validate(console_ns.payload or {})
        account, tenant_id = current_account_with_tenant()
        sandbox_id = account.id
        res = SandboxFileService.download_file(
            tenant_id=tenant_id, app_id=app_id, sandbox_id=sandbox_id, path=payload.path
        )
        return jsonable_encoder(res)
api/controllers/console/workspace/sandbox_providers.py (new file, 104 lines)
@@ -0,0 +1,104 @@
import logging

from flask import request
from flask_restx import Resource, fields
from pydantic import BaseModel

from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from graphon.model_runtime.utils.encoders import jsonable_encoder
from libs.login import current_account_with_tenant, login_required
from services.sandbox.sandbox_provider_service import SandboxProviderService

logger = logging.getLogger(__name__)


class SandboxProviderConfigRequest(BaseModel):
    config: dict
    activate: bool = False


class SandboxProviderActivateRequest(BaseModel):
    type: str


@console_ns.route("/workspaces/current/sandbox-providers")
class SandboxProviderListApi(Resource):
    @console_ns.doc("list_sandbox_providers")
    @console_ns.doc(description="Get list of available sandbox providers with configuration status")
    @console_ns.response(200, "Success", fields.List(fields.Raw(description="Sandbox provider information")))
    @setup_required
    @login_required
    @account_initialization_required
    def get(self):
        _, current_tenant_id = current_account_with_tenant()
        providers = SandboxProviderService.list_providers(current_tenant_id)
        return jsonable_encoder([p.model_dump() for p in providers])


@console_ns.route("/workspaces/current/sandbox-provider/<string:provider_type>/config")
class SandboxProviderConfigApi(Resource):
    @console_ns.doc("save_sandbox_provider_config")
    @console_ns.doc(description="Save or update configuration for a sandbox provider")
    @console_ns.response(200, "Success")
    @setup_required
    @login_required
    @account_initialization_required
    def post(self, provider_type: str):
        _, current_tenant_id = current_account_with_tenant()
        args = SandboxProviderConfigRequest.model_validate(request.get_json())

        try:
            result = SandboxProviderService.save_config(
                tenant_id=current_tenant_id,
                provider_type=provider_type,
                config=args.config,
                activate=args.activate,
            )
            return result
        except ValueError as e:
            return {"message": str(e)}, 400

    @console_ns.doc("delete_sandbox_provider_config")
    @console_ns.doc(description="Delete configuration for a sandbox provider")
    @console_ns.response(200, "Success")
    @setup_required
    @login_required
    @account_initialization_required
    def delete(self, provider_type: str):
        _, current_tenant_id = current_account_with_tenant()

        try:
            result = SandboxProviderService.delete_config(
                tenant_id=current_tenant_id,
                provider_type=provider_type,
            )
            return result
        except ValueError as e:
            return {"message": str(e)}, 400


@console_ns.route("/workspaces/current/sandbox-provider/<string:provider_type>/activate")
class SandboxProviderActivateApi(Resource):
    """Activate a sandbox provider."""

    @console_ns.doc("activate_sandbox_provider")
    @console_ns.doc(description="Activate a sandbox provider for the current workspace")
    @console_ns.response(200, "Success")
    @setup_required
    @login_required
    @account_initialization_required
    def post(self, provider_type: str):
        """Activate a sandbox provider."""
        _, current_tenant_id = current_account_with_tenant()

        try:
            args = SandboxProviderActivateRequest.model_validate(request.get_json())
            result = SandboxProviderService.activate_provider(
                tenant_id=current_tenant_id,
                provider_type=provider_type,
                type=args.type,
            )
            return result
        except ValueError as e:
            return {"message": str(e)}, 400
api/controllers/files/storage_files.py (new file, 80 lines)
@@ -0,0 +1,80 @@
"""Token-based file proxy controller for storage operations.

This controller handles file download and upload operations using opaque UUID tokens.
The token maps to the real storage key in Redis, so the actual storage path is never
exposed in the URL.

Routes:
    GET /files/storage-files/{token} - Download a file
    PUT /files/storage-files/{token} - Upload a file

The operation type (download/upload) is determined by the ticket stored in Redis,
not by the HTTP method. This ensures a download ticket cannot be used for upload
and vice versa.
"""

from urllib.parse import quote

from flask import Response, request
from flask_restx import Resource
from werkzeug.exceptions import Forbidden, NotFound, RequestEntityTooLarge

from controllers.files import files_ns
from extensions.ext_storage import storage
from services.storage_ticket_service import StorageTicketService


@files_ns.route("/storage-files/<string:token>")
class StorageFilesApi(Resource):
    """Handle file operations through token-based URLs."""

    def get(self, token: str):
        """Download a file using a token.

        The ticket must have op="download", otherwise returns 403.
        """
        ticket = StorageTicketService.get_ticket(token)
        if ticket is None:
            raise Forbidden("Invalid or expired token")

        if ticket.op != "download":
            raise Forbidden("This token is not valid for download")

        try:
            generator = storage.load_stream(ticket.storage_key)
        except FileNotFoundError:
            raise NotFound("File not found")

        filename = ticket.filename or ticket.storage_key.rsplit("/", 1)[-1]
        encoded_filename = quote(filename)

        return Response(
            generator,
            mimetype="application/octet-stream",
            direct_passthrough=True,
            headers={
                "Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}",
            },
        )

    def put(self, token: str):
        """Upload a file using a token.

        The ticket must have op="upload", otherwise returns 403.
        If the request body exceeds max_bytes, returns 413.
        """
        ticket = StorageTicketService.get_ticket(token)
        if ticket is None:
            raise Forbidden("Invalid or expired token")

        if ticket.op != "upload":
            raise Forbidden("This token is not valid for upload")

        content = request.get_data()

        if ticket.max_bytes is not None and len(content) > ticket.max_bytes:
            raise RequestEntityTooLarge(f"Upload exceeds maximum size of {ticket.max_bytes} bytes")

        storage.save(ticket.storage_key, content)

        return Response(status=204)
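Consuming a ticket URL is plain HTTP against the two routes above; the allowed operation comes from the ticket, not the verb. A hedged sketch (the tokens are hypothetical UUIDs minted elsewhere by StorageTicketService, and the requests dependency is an assumption):

import requests  # assumed HTTP client

base = "http://localhost:5001/files/storage-files"

# Download: succeeds only if the token's ticket has op="download".
resp = requests.get(f"{base}/11111111-2222-3333-4444-555555555555")
resp.raise_for_status()
with open("result.bin", "wb") as f:
    f.write(resp.content)

# Upload: a different token whose ticket has op="upload"; bodies over
# the ticket's max_bytes come back as 413, success is an empty 204.
resp = requests.put(f"{base}/66666666-7777-8888-9999-000000000000", data=b"hello")
assert resp.status_code == 204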
api/core/app/layers/__init__.py (new file, empty)
api/core/app/layers/sandbox_layer.py (new file, 22 lines)
@@ -0,0 +1,22 @@
import logging

from core.sandbox import Sandbox
from graphon.graph_engine.layers.base import GraphEngineLayer
from graphon.graph_events.base import GraphEngineEvent

logger = logging.getLogger(__name__)


class SandboxLayer(GraphEngineLayer):
    def __init__(self, sandbox: Sandbox) -> None:
        super().__init__()
        self._sandbox = sandbox

    def on_graph_start(self) -> None:
        pass

    def on_event(self, event: GraphEngineEvent) -> None:
        pass

    def on_graph_end(self, error: Exception | None) -> None:
        self._sandbox.release()
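SandboxLayer leans entirely on on_graph_end, which receives the error (or None) when the graph finishes either way, so sandbox release does not depend on the other hooks. As a point of comparison only, here is a purely illustrative layer on the same three hooks; it is not part of this commit and assumes nothing about GraphEngineLayer beyond the surface shown above:

import logging

from graphon.graph_engine.layers.base import GraphEngineLayer
from graphon.graph_events.base import GraphEngineEvent

logger = logging.getLogger(__name__)


class LoggingLayer(GraphEngineLayer):
    """Illustrative only: logs the same lifecycle hooks SandboxLayer implements."""

    def on_graph_start(self) -> None:
        logger.info("graph started")

    def on_event(self, event: GraphEngineEvent) -> None:
        logger.debug("graph event: %s", type(event).__name__)

    def on_graph_end(self, error: Exception | None) -> None:
        # Fires on success and failure alike, which is what makes this
        # hook suitable for cleanup duties like Sandbox.release().
        logger.info("graph ended (error=%r)", error)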
api/core/app_assets/__init__.py (new file, 13 lines)
@@ -0,0 +1,13 @@
from .constants import AppAssetsAttrs
from .entities import (
    AssetItem,
    SkillAsset,
)
from .storage import AssetPaths

__all__ = [
    "AppAssetsAttrs",
    "AssetItem",
    "AssetPaths",
    "SkillAsset",
]
api/core/app_assets/accessor.py (new file, 180 lines)
@@ -0,0 +1,180 @@
"""Unified content accessor for app asset nodes.

Accessor is scoped to a single app (tenant_id + app_id), not a single node.
All methods accept an AppAssetNode parameter to identify the target.

CachedContentAccessor is the primary entry point:
- Reads DB first, misses fall through to S3 with sync backfill.
- Writes go to both DB and S3 (dual-write).
- resolve_items() batch-enriches AssetItem lists with DB-cached content
  (extension-agnostic), so callers never need to filter by extension.
- Wraps an internal _StorageAccessor for S3 I/O.

Collaborators:
- services.asset_content_service.AssetContentService (DB layer)
- core.app_assets.storage.AssetPaths (S3 key generation)
- extensions.storage.cached_presign_storage.CachedPresignStorage (S3 I/O)
"""

from __future__ import annotations

import logging

from core.app.entities.app_asset_entities import AppAssetNode
from core.app_assets.entities.assets import AssetItem
from core.app_assets.storage import AssetPaths
from extensions.storage.cached_presign_storage import CachedPresignStorage
from services.asset_content_service import AssetContentService

logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# S3-only implementation (internal, used as inner delegate)
# ---------------------------------------------------------------------------


class _StorageAccessor:
    """Reads/writes draft content via object storage (S3) only."""

    _storage: CachedPresignStorage
    _tenant_id: str
    _app_id: str

    def __init__(self, storage: CachedPresignStorage, tenant_id: str, app_id: str) -> None:
        self._storage = storage
        self._tenant_id = tenant_id
        self._app_id = app_id

    def _key(self, node: AppAssetNode) -> str:
        return AssetPaths.draft(self._tenant_id, self._app_id, node.id)

    def load(self, node: AppAssetNode) -> bytes:
        return self._storage.load_once(self._key(node))

    def save(self, node: AppAssetNode, content: bytes) -> None:
        self._storage.save(self._key(node), content)

    def delete(self, node: AppAssetNode) -> None:
        try:
            self._storage.delete(self._key(node))
        except Exception:
            logger.warning("Failed to delete storage key %s", self._key(node), exc_info=True)


# ---------------------------------------------------------------------------
# DB-cached implementation (the public API)
# ---------------------------------------------------------------------------


class CachedContentAccessor:
    """App-level content accessor with DB read-through cache over S3.

    Read path: DB first -> miss -> S3 fallback -> sync backfill DB
    Write path: DB upsert + S3 save (dual-write)
    Delete path: DB delete + S3 delete

    bulk_load uses a single SQL query for all nodes, with S3 fallback per miss.

    Usage:
        accessor = CachedContentAccessor(storage, tenant_id, app_id)
        content = accessor.load(node)
        accessor.save(node, content)
        results = accessor.bulk_load(nodes)
    """

    _inner: _StorageAccessor
    _tenant_id: str
    _app_id: str

    def __init__(self, storage: CachedPresignStorage, tenant_id: str, app_id: str) -> None:
        self._inner = _StorageAccessor(storage, tenant_id, app_id)
        self._tenant_id = tenant_id
        self._app_id = app_id

    def load(self, node: AppAssetNode) -> bytes:
        # 1. Try DB
        cached = AssetContentService.get(self._tenant_id, self._app_id, node.id)
        if cached is not None:
            return cached.encode("utf-8")

        # 2. Fallback to S3
        data = self._inner.load(node)

        # 3. Sync backfill DB
        AssetContentService.upsert(
            tenant_id=self._tenant_id,
            app_id=self._app_id,
            node_id=node.id,
            content=data.decode("utf-8"),
            size=len(data),
        )
        return data

    def bulk_load(self, nodes: list[AppAssetNode]) -> dict[str, bytes]:
        """Single SQL for all nodes, S3 fallback + backfill per miss."""
        result: dict[str, bytes] = {}
        node_ids = [n.id for n in nodes]
        cached = AssetContentService.get_many(self._tenant_id, self._app_id, node_ids)

        for node in nodes:
            if node.id in cached:
                result[node.id] = cached[node.id].encode("utf-8")
            else:
                # S3 fallback + sync backfill
                data = self._inner.load(node)
                AssetContentService.upsert(
                    tenant_id=self._tenant_id,
                    app_id=self._app_id,
                    node_id=node.id,
                    content=data.decode("utf-8"),
                    size=len(data),
                )
                result[node.id] = data
        return result

    def save(self, node: AppAssetNode, content: bytes) -> None:
        # Dual-write: DB + S3
        AssetContentService.upsert(
            tenant_id=self._tenant_id,
            app_id=self._app_id,
            node_id=node.id,
            content=content.decode("utf-8"),
            size=len(content),
        )
        self._inner.save(node, content)

    def resolve_items(self, items: list[AssetItem]) -> list[AssetItem]:
        """Batch-enrich asset items with DB-cached content.

        Queries by ``asset_id`` only — extension-agnostic. Items without
        a DB cache row keep their original *content* value (typically
        ``None``), so only genuinely cached assets (e.g. ``.md`` skill
        documents) get populated.

        This eliminates the need for callers to filter by file extension
        before deciding whether to read from the DB cache.
        """
        if not items:
            return items

        node_ids = [a.asset_id for a in items]
        cached = AssetContentService.get_many(self._tenant_id, self._app_id, node_ids)

        if not cached:
            return items

        return [
            AssetItem(
                asset_id=a.asset_id,
                path=a.path,
                file_name=a.file_name,
                extension=a.extension,
                storage_key=a.storage_key,
                content=cached[a.asset_id].encode("utf-8") if a.asset_id in cached else a.content,
            )
            for a in items
        ]

    def delete(self, node: AppAssetNode) -> None:
        AssetContentService.delete(self._tenant_id, self._app_id, node.id)
        self._inner.delete(node)
api/core/app_assets/builder/__init__.py
Normal file
12
api/core/app_assets/builder/__init__.py
Normal file
@ -0,0 +1,12 @@
|
||||
from .base import AssetBuilder, BuildContext
|
||||
from .file_builder import FileBuilder
|
||||
from .pipeline import AssetBuildPipeline
|
||||
from .skill_builder import SkillBuilder
|
||||
|
||||
__all__ = [
|
||||
"AssetBuildPipeline",
|
||||
"AssetBuilder",
|
||||
"BuildContext",
|
||||
"FileBuilder",
|
||||
"SkillBuilder",
|
||||
]
|
||||
api/core/app_assets/builder/base.py (new file, 20 lines)
@@ -0,0 +1,20 @@
from dataclasses import dataclass
from typing import Protocol

from core.app.entities.app_asset_entities import AppAssetFileTree, AppAssetNode
from core.app_assets.entities import AssetItem


@dataclass
class BuildContext:
    tenant_id: str
    app_id: str
    build_id: str


class AssetBuilder(Protocol):
    def accept(self, node: AppAssetNode) -> bool: ...

    def collect(self, node: AppAssetNode, path: str, ctx: BuildContext) -> None: ...

    def build(self, tree: AppAssetFileTree, ctx: BuildContext) -> list[AssetItem]: ...
api/core/app_assets/builder/file_builder.py (new file, 30 lines)
@@ -0,0 +1,30 @@
from core.app.entities.app_asset_entities import AppAssetFileTree, AppAssetNode
from core.app_assets.entities import AssetItem
from core.app_assets.storage import AssetPaths

from .base import BuildContext


class FileBuilder:
    _nodes: list[tuple[AppAssetNode, str]]

    def __init__(self) -> None:
        self._nodes = []

    def accept(self, node: AppAssetNode) -> bool:
        return True

    def collect(self, node: AppAssetNode, path: str, ctx: BuildContext) -> None:
        self._nodes.append((node, path))

    def build(self, tree: AppAssetFileTree, ctx: BuildContext) -> list[AssetItem]:
        return [
            AssetItem(
                asset_id=node.id,
                path=path,
                file_name=node.name,
                extension=node.extension or "",
                storage_key=AssetPaths.draft(ctx.tenant_id, ctx.app_id, node.id),
            )
            for node, path in self._nodes
        ]

27
api/core/app_assets/builder/pipeline.py
Normal file
@ -0,0 +1,27 @@
from core.app.entities.app_asset_entities import AppAssetFileTree
from core.app_assets.entities import AssetItem

from .base import AssetBuilder, BuildContext


class AssetBuildPipeline:
    _builders: list[AssetBuilder]

    def __init__(self, builders: list[AssetBuilder]) -> None:
        self._builders = builders

    def build_all(self, tree: AppAssetFileTree, ctx: BuildContext) -> list[AssetItem]:
        # 1. Distribute: each node goes to the first accepting builder
        for node in tree.walk_files():
            path = tree.get_path(node.id)
            for builder in self._builders:
                if builder.accept(node):
                    builder.collect(node, path, ctx)
                    break

        # 2. Each builder builds its collected nodes
        results: list[AssetItem] = []
        for builder in self._builders:
            results.extend(builder.build(tree, ctx))

        return results
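
The distribution rule above is "first accepting builder wins". A minimal, self-contained sketch of that rule (the stub node/tree/builder types below are hypothetical stand-ins for AppAssetNode, AppAssetFileTree, and the real builders):

from dataclasses import dataclass


@dataclass
class _StubNode:
    id: str
    extension: str


class _StubTree:
    """Hypothetical stand-in exposing the two methods the pipeline relies on."""

    def __init__(self, nodes: list[_StubNode]) -> None:
        self._nodes = nodes

    def walk_files(self):
        return iter(self._nodes)

    def get_path(self, node_id: str) -> str:
        return f"skills/{node_id}"


class _MdBuilder:
    def __init__(self) -> None:
        self.collected: list[str] = []

    def accept(self, node: _StubNode) -> bool:
        return node.extension == "md"

    def collect(self, node: _StubNode, path: str, ctx: object) -> None:
        self.collected.append(path)


class _CatchAllBuilder(_MdBuilder):
    def accept(self, node: _StubNode) -> bool:
        return True


md, rest = _MdBuilder(), _CatchAllBuilder()
tree = _StubTree([_StubNode("a", "md"), _StubNode("b", "py")])
# Same loop as AssetBuildPipeline.build_all, step 1:
for node in tree.walk_files():
    for builder in (md, rest):
        if builder.accept(node):
            builder.collect(node, tree.get_path(node.id), ctx=None)
            break
assert md.collected == ["skills/a"]    # .md claimed by the first builder
assert rest.collected == ["skills/b"]  # everything else falls through
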

96
api/core/app_assets/builder/skill_builder.py
Normal file
@ -0,0 +1,96 @@
"""Builder that compiles ``.md`` skill documents into resolved content.

The builder reads raw draft content from the DB-backed accessor, parses
each into a ``SkillDocument``, assembles a ``SkillBundle`` (with
transitive tool/file dependency resolution), and returns ``AssetItem``
objects whose *content* field carries the resolved bytes in-process.

The assembled ``SkillBundle`` is persisted via ``SkillManager``
(S3 + Redis) **and** retained on the ``bundle`` property so that
callers (e.g. ``DraftAppAssetsInitializer``) can pass it directly to
``sandbox.attrs`` without a redundant Redis/S3 round-trip.
"""

import json
import logging

from core.app.entities.app_asset_entities import AppAssetFileTree, AppAssetNode
from core.app_assets.accessor import CachedContentAccessor
from core.app_assets.entities import AssetItem
from core.skill.assembler import SkillBundleAssembler
from core.skill.entities.skill_bundle import SkillBundle
from core.skill.entities.skill_document import SkillDocument

from .base import BuildContext

logger = logging.getLogger(__name__)


class SkillBuilder:
    _nodes: list[tuple[AppAssetNode, str]]
    _accessor: CachedContentAccessor
    _bundle: SkillBundle | None

    def __init__(self, accessor: CachedContentAccessor) -> None:
        self._nodes = []
        self._accessor = accessor
        self._bundle = None

    @property
    def bundle(self) -> SkillBundle | None:
        """The ``SkillBundle`` produced by the last ``build()`` call, or *None*."""
        return self._bundle

    def accept(self, node: AppAssetNode) -> bool:
        return node.extension == "md"

    def collect(self, node: AppAssetNode, path: str, ctx: BuildContext) -> None:
        self._nodes.append((node, path))

    def build(self, tree: AppAssetFileTree, ctx: BuildContext) -> list[AssetItem]:
        from core.skill.skill_manager import SkillManager

        if not self._nodes:
            bundle = SkillBundle(assets_id=ctx.build_id, asset_tree=tree)
            SkillManager.save_bundle(ctx.tenant_id, ctx.app_id, ctx.build_id, bundle)
            self._bundle = bundle
            return []

        # Batch-load all skill draft content in one DB query (with S3 fallback on miss).
        nodes_only = [node for node, _ in self._nodes]
        raw_contents = self._accessor.bulk_load(nodes_only)

        # Parse documents — skip nodes whose draft content is still the empty
        # placeholder written at creation time.
        documents: dict[str, SkillDocument] = {}
        for node, _ in self._nodes:
            try:
                raw = raw_contents.get(node.id)
                if not raw:
                    continue
                data = {"skill_id": node.id, **json.loads(raw)}
                documents[node.id] = SkillDocument.model_validate(data)
            except (FileNotFoundError, json.JSONDecodeError, TypeError, ValueError) as e:
                logger.exception("Failed to load or parse skill document for node %s", node.id)
                raise ValueError(f"Failed to load or parse skill document for node {node.id}") from e

        bundle = SkillBundleAssembler(tree).assemble_bundle(documents, ctx.build_id)
        SkillManager.save_bundle(ctx.tenant_id, ctx.app_id, ctx.build_id, bundle)
        self._bundle = bundle

        items: list[AssetItem] = []
        for node, path in self._nodes:
            skill = bundle.get(node.id)
            if skill is None:
                continue
            items.append(
                AssetItem(
                    asset_id=node.id,
                    path=path,
                    file_name=node.name,
                    extension=node.extension or "",
                    storage_key="",
                    content=skill.content.encode("utf-8"),
                )
            )
        return items
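
One detail worth calling out: the parse step injects the node id into the stored draft JSON before validation. A minimal sketch of that merge (the field names other than skill_id are illustrative, not the actual SkillDocument schema):

import json

raw = '{"name": "summarize", "content": "# Summarize..."}'  # draft JSON as stored
node_id = "0b4a3c2e"  # hypothetical node id, shortened for the example
data = {"skill_id": node_id, **json.loads(raw)}
assert data["skill_id"] == node_id
# SkillDocument.model_validate(data) would then enforce the real schema and
# raise on invalid drafts, which build() converts into a ValueError naming
# the offending node.
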

8
api/core/app_assets/constants.py
Normal file
@ -0,0 +1,8 @@
from core.app.entities.app_asset_entities import AppAssetFileTree
from libs.attr_map import AttrKey


class AppAssetsAttrs:
    # Skill artifact set
    FILE_TREE = AttrKey("file_tree", AppAssetFileTree)
    APP_ASSETS_ID = AttrKey("app_assets_id", str)

20
api/core/app_assets/converters.py
Normal file
@ -0,0 +1,20 @@
from __future__ import annotations

from core.app.entities.app_asset_entities import AppAssetFileTree, AssetNodeType
from core.app_assets.entities import AssetItem
from core.app_assets.storage import AssetPaths


def tree_to_asset_items(tree: AppAssetFileTree, tenant_id: str, app_id: str) -> list[AssetItem]:
    """Convert AppAssetFileTree to a list of AssetItem for packaging."""
    return [
        AssetItem(
            asset_id=node.id,
            path=tree.get_path(node.id),
            file_name=node.name,
            extension=node.extension or "",
            storage_key=AssetPaths.draft(tenant_id, app_id, node.id),
        )
        for node in tree.nodes
        if node.node_type == AssetNodeType.FILE
    ]

7
api/core/app_assets/entities/__init__.py
Normal file
@ -0,0 +1,7 @@
from .assets import AssetItem
from .skill import SkillAsset

__all__ = [
    "AssetItem",
    "SkillAsset",
]

20
api/core/app_assets/entities/assets.py
Normal file
@ -0,0 +1,20 @@
from dataclasses import dataclass, field


@dataclass
class AssetItem:
    """A single asset file produced by the build pipeline.

    When *content* is set the payload is available in-process and can be
    written directly into a ZIP or uploaded to a sandbox VM without an
    extra S3 round-trip. When *content* is ``None`` the caller should
    fetch the bytes from *storage_key* (the traditional presigned-URL
    path).
    """

    asset_id: str
    path: str
    file_name: str
    extension: str
    storage_key: str
    content: bytes | None = field(default=None, repr=False)
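
A short sketch of the two consumption modes the docstring describes; fetch_from_storage is a hypothetical callable standing in for the presigned-URL download path:

from collections.abc import Callable


def materialize(item: AssetItem, fetch_from_storage: Callable[[str], bytes]) -> bytes:
    # In-process payload wins; otherwise fall back to the storage key.
    if item.content is not None:
        return item.content
    return fetch_from_storage(item.storage_key)
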

10
api/core/app_assets/entities/skill.py
Normal file
@ -0,0 +1,10 @@
from collections.abc import Mapping
from dataclasses import dataclass, field
from typing import Any

from .assets import AssetItem


@dataclass
class SkillAsset(AssetItem):
    metadata: Mapping[str, Any] = field(default_factory=dict)

68
api/core/app_assets/storage.py
Normal file
@ -0,0 +1,68 @@
"""App assets storage key generation.

Provides AssetPaths facade for generating storage keys for app assets.
Storage instances are obtained via AppAssetService.get_storage().
"""

from __future__ import annotations

from uuid import UUID

_BASE = "app_assets"


def _check_uuid(value: str, name: str) -> None:
    try:
        UUID(value)
    except (ValueError, TypeError) as e:
        raise ValueError(f"{name} must be a valid UUID") from e


class AssetPaths:
    """Facade for generating app asset storage keys."""

    @staticmethod
    def draft(tenant_id: str, app_id: str, node_id: str) -> str:
        """app_assets/{tenant}/{app}/draft/{node_id}"""
        _check_uuid(tenant_id, "tenant_id")
        _check_uuid(app_id, "app_id")
        _check_uuid(node_id, "node_id")
        return f"{_BASE}/{tenant_id}/{app_id}/draft/{node_id}"

    @staticmethod
    def build_zip(tenant_id: str, app_id: str, assets_id: str) -> str:
        """app_assets/{tenant}/{app}/artifacts/{assets_id}.zip"""
        _check_uuid(tenant_id, "tenant_id")
        _check_uuid(app_id, "app_id")
        _check_uuid(assets_id, "assets_id")
        return f"{_BASE}/{tenant_id}/{app_id}/artifacts/{assets_id}.zip"

    @staticmethod
    def skill_bundle(tenant_id: str, app_id: str, assets_id: str) -> str:
        """app_assets/{tenant}/{app}/artifacts/{assets_id}/skill_artifact_set.json"""
        _check_uuid(tenant_id, "tenant_id")
        _check_uuid(app_id, "app_id")
        _check_uuid(assets_id, "assets_id")
        return f"{_BASE}/{tenant_id}/{app_id}/artifacts/{assets_id}/skill_artifact_set.json"

    @staticmethod
    def source_zip(tenant_id: str, app_id: str, workflow_id: str) -> str:
        """app_assets/{tenant}/{app}/sources/{workflow_id}.zip"""
        _check_uuid(tenant_id, "tenant_id")
        _check_uuid(app_id, "app_id")
        _check_uuid(workflow_id, "workflow_id")
        return f"{_BASE}/{tenant_id}/{app_id}/sources/{workflow_id}.zip"

    @staticmethod
    def bundle_export(tenant_id: str, app_id: str, export_id: str) -> str:
        """app_assets/{tenant}/{app}/bundle_exports/{export_id}.zip"""
        _check_uuid(tenant_id, "tenant_id")
        _check_uuid(app_id, "app_id")
        _check_uuid(export_id, "export_id")
        return f"{_BASE}/{tenant_id}/{app_id}/bundle_exports/{export_id}.zip"

    @staticmethod
    def bundle_import(tenant_id: str, import_id: str) -> str:
        """app_assets/{tenant}/imports/{import_id}.zip"""
        _check_uuid(tenant_id, "tenant_id")
        return f"{_BASE}/{tenant_id}/imports/{import_id}.zip"
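
Key generation is deterministic and validates its inputs up front; for example (the UUIDs are arbitrary):

import uuid

tenant, app, node = (str(uuid.uuid4()) for _ in range(3))
key = AssetPaths.draft(tenant, app, node)
assert key == f"app_assets/{tenant}/{app}/draft/{node}"

try:
    AssetPaths.draft("not-a-uuid", app, node)
except ValueError as e:
    print(e)  # tenant_id must be a valid UUID
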

1
api/core/app_bundle/__init__.py
Normal file
@ -0,0 +1 @@
# App bundle utilities - manifest-driven import/export handled by AppBundleService

96
api/core/sandbox/__init__.py
Normal file
@ -0,0 +1,96 @@
from __future__ import annotations

import importlib
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from .bash.dify_cli import (
        DifyCliBinary,
        DifyCliConfig,
        DifyCliEnvConfig,
        DifyCliLocator,
        DifyCliToolConfig,
    )
    from .bash.session import SandboxBashSession
    from .builder import SandboxBuilder, VMConfig
    from .entities import AppAssets, DifyCli, SandboxProviderApiEntity, SandboxType
    from .initializer import (
        AsyncSandboxInitializer,
        SandboxInitializeContext,
        SandboxInitializer,
        SyncSandboxInitializer,
    )
    from .initializer.app_assets_initializer import AppAssetsInitializer
    from .initializer.dify_cli_initializer import DifyCliInitializer
    from .initializer.draft_app_assets_initializer import DraftAppAssetsInitializer
    from .sandbox import Sandbox
    from .storage import ArchiveSandboxStorage, SandboxStorage
    from .utils.debug import sandbox_debug
    from .utils.encryption import create_sandbox_config_encrypter, masked_config

__all__ = [
    "AppAssets",
    "AppAssetsInitializer",
    "ArchiveSandboxStorage",
    "AsyncSandboxInitializer",
    "DifyCli",
    "DifyCliBinary",
    "DifyCliConfig",
    "DifyCliEnvConfig",
    "DifyCliInitializer",
    "DifyCliLocator",
    "DifyCliToolConfig",
    "DraftAppAssetsInitializer",
    "Sandbox",
    "SandboxBashSession",
    "SandboxBuilder",
    "SandboxInitializeContext",
    "SandboxInitializer",
    "SandboxProviderApiEntity",
    "SandboxStorage",
    "SandboxType",
    "SyncSandboxInitializer",
    "VMConfig",
    "create_sandbox_config_encrypter",
    "masked_config",
    "sandbox_debug",
]

_LAZY_IMPORTS = {
    "AppAssets": ("core.sandbox.entities", "AppAssets"),
    "AppAssetsInitializer": ("core.sandbox.initializer.app_assets_initializer", "AppAssetsInitializer"),
    "AsyncSandboxInitializer": ("core.sandbox.initializer", "AsyncSandboxInitializer"),
    "ArchiveSandboxStorage": ("core.sandbox.storage", "ArchiveSandboxStorage"),
    "DifyCli": ("core.sandbox.entities", "DifyCli"),
    "DifyCliBinary": ("core.sandbox.bash.dify_cli", "DifyCliBinary"),
    "DifyCliConfig": ("core.sandbox.bash.dify_cli", "DifyCliConfig"),
    "DifyCliEnvConfig": ("core.sandbox.bash.dify_cli", "DifyCliEnvConfig"),
    "DifyCliInitializer": ("core.sandbox.initializer.dify_cli_initializer", "DifyCliInitializer"),
    "DifyCliLocator": ("core.sandbox.bash.dify_cli", "DifyCliLocator"),
    "DifyCliToolConfig": ("core.sandbox.bash.dify_cli", "DifyCliToolConfig"),
    "DraftAppAssetsInitializer": ("core.sandbox.initializer.draft_app_assets_initializer", "DraftAppAssetsInitializer"),
    "Sandbox": ("core.sandbox.sandbox", "Sandbox"),
    "SandboxBashSession": ("core.sandbox.bash.session", "SandboxBashSession"),
    "SandboxBuilder": ("core.sandbox.builder", "SandboxBuilder"),
    "SandboxInitializeContext": ("core.sandbox.initializer", "SandboxInitializeContext"),
    "SandboxInitializer": ("core.sandbox.initializer", "SandboxInitializer"),
    "SandboxManager": ("core.sandbox.manager", "SandboxManager"),
    "SandboxProviderApiEntity": ("core.sandbox.entities", "SandboxProviderApiEntity"),
    "SandboxStorage": ("core.sandbox.storage", "SandboxStorage"),
    "SandboxType": ("core.sandbox.entities", "SandboxType"),
    "SyncSandboxInitializer": ("core.sandbox.initializer", "SyncSandboxInitializer"),
    "VMConfig": ("core.sandbox.builder", "VMConfig"),
    "create_sandbox_config_encrypter": ("core.sandbox.utils.encryption", "create_sandbox_config_encrypter"),
    "masked_config": ("core.sandbox.utils.encryption", "masked_config"),
    "sandbox_debug": ("core.sandbox.utils.debug", "sandbox_debug"),
}


def __getattr__(name: str):
    if name not in _LAZY_IMPORTS:
        raise AttributeError(f"module 'core.sandbox' has no attribute {name}")
    module_path, attr_name = _LAZY_IMPORTS[name]
    module = importlib.import_module(module_path)
    value = getattr(module, attr_name)
    globals()[name] = value
    return value
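
The module-level __getattr__ is PEP 562 lazy loading: nothing under core.sandbox is imported until an attribute is first touched, and the resolved value is cached in globals() so later lookups bypass __getattr__ entirely. The same pattern in isolation, using the stdlib instead of core.sandbox modules:

import importlib

_LAZY = {"sqrt": ("math", "sqrt")}


def __getattr__(name: str):
    if name not in _LAZY:
        raise AttributeError(name)
    module_path, attr_name = _LAZY[name]
    value = getattr(importlib.import_module(module_path), attr_name)
    globals()[name] = value  # cache so the next access is a plain module-dict hit
    return value
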

15
api/core/sandbox/bash/__init__.py
Normal file
@ -0,0 +1,15 @@
from .dify_cli import (
    DifyCliBinary,
    DifyCliConfig,
    DifyCliEnvConfig,
    DifyCliLocator,
    DifyCliToolConfig,
)

__all__ = [
    "DifyCliBinary",
    "DifyCliConfig",
    "DifyCliEnvConfig",
    "DifyCliLocator",
    "DifyCliToolConfig",
]

139
api/core/sandbox/bash/bash_tool.py
Normal file
@ -0,0 +1,139 @@
from collections.abc import Generator
from typing import Any

from core.sandbox.entities import DifyCli
from core.tools.__base.tool import Tool
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import (
    ToolDescription,
    ToolEntity,
    ToolIdentity,
    ToolInvokeMessage,
    ToolParameter,
    ToolProviderType,
)
from core.virtual_environment.__base.helpers import submit_command, with_connection
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment

from ..utils.debug import sandbox_debug

COMMAND_TIMEOUT_SECONDS = 60 * 60 * 2  # 2 hours, can be adjusted based on expected command execution times

# Output truncation settings to avoid overwhelming model context
# 8000 chars ≈ 2000-2700 tokens, safe for models with 8K+ context
MAX_OUTPUT_LENGTH = 8000
TRUNCATE_HEAD_LENGTH = 2500  # Keep beginning for context
TRUNCATE_TAIL_LENGTH = 2500  # Keep end for results/errors


def _truncate_output(output: str, name: str = "output") -> str:
    """Truncate output if it exceeds the maximum length.

    Keeps the head and tail of the output to preserve context and final results.
    """
    if len(output) <= MAX_OUTPUT_LENGTH:
        return output

    omitted_length = len(output) - TRUNCATE_HEAD_LENGTH - TRUNCATE_TAIL_LENGTH
    head = output[:TRUNCATE_HEAD_LENGTH]
    tail = output[-TRUNCATE_TAIL_LENGTH:]

    return f"{head}\n\n... [{omitted_length} characters omitted from {name}] ...\n\n{tail}"


class SandboxBashTool(Tool):
    def __init__(self, sandbox: VirtualEnvironment, tenant_id: str, tools_path: str) -> None:
        self._sandbox = sandbox
        self._tools_path = tools_path

        entity = ToolEntity(
            identity=ToolIdentity(
                author="Dify",
                name="bash",
                label=I18nObject(en_US="Bash", zh_Hans="Bash"),
                provider="sandbox",
            ),
            parameters=[
                ToolParameter.get_simple_instance(
                    name="bash",
                    llm_description="The bash command to execute in the current working directory",
                    typ=ToolParameter.ToolParameterType.STRING,
                    required=True,
                ),
            ],
            description=ToolDescription(
                human=I18nObject(
                    en_US="Execute bash commands in the current working directory",
                ),
                llm="Execute bash commands in the current working directory. "
                "Use this tool to run shell commands, scripts, or interact with the system. "
                "The command will be executed in the current working directory. "
                "IMPORTANT: If you generate any output files (images, documents, etc.) that need to be "
                "returned or referenced later, you MUST save them to the 'output/' directory "
                "(e.g., 'mkdir -p output && cp result.png output/'). Only files in output/ will be collected.",
            ),
        )

        runtime = ToolRuntime(tenant_id=tenant_id)
        super().__init__(entity=entity, runtime=runtime)

    def tool_provider_type(self) -> ToolProviderType:
        return ToolProviderType.BUILT_IN

    def _invoke(
        self,
        user_id: str,
        tool_parameters: dict[str, Any],
        conversation_id: str | None = None,
        app_id: str | None = None,
        message_id: str | None = None,
    ) -> Generator[ToolInvokeMessage, None, None]:
        command = tool_parameters.get("bash", "")
        if not command:
            sandbox_debug("bash_tool", "parameters", tool_parameters)
            yield self.create_text_message(
                'Error: No command provided. The "bash" parameter is required and must contain '
                'the shell command to execute. Example: {"bash": "ls -la"}'
            )
            return

        try:
            with with_connection(self._sandbox) as conn:
                # Build command with embedded environment variables
                env_exports = (
                    f"export PATH={self._tools_path}:/usr/local/bin:/usr/bin:/bin && "
                    f"export DIFY_CLI_CONFIG={self._tools_path}/{DifyCli.CONFIG_FILENAME} && "
                )
                full_command = env_exports + command

                cmd_list = ["bash", "-c", full_command]
                sandbox_debug("bash_tool", "cmd_list", cmd_list)

                future = submit_command(
                    self._sandbox,
                    conn,
                    cmd_list,
                )
                timeout = COMMAND_TIMEOUT_SECONDS if COMMAND_TIMEOUT_SECONDS > 0 else None
                result = future.result(timeout=timeout)

                stdout = result.stdout.decode("utf-8", errors="replace") if result.stdout else ""
                stderr = result.stderr.decode("utf-8", errors="replace") if result.stderr else ""

                # Truncate long outputs to avoid overwhelming the model
                stdout = _truncate_output(stdout, "stdout")
                stderr = _truncate_output(stderr, "stderr")

                output_parts: list[str] = []
                if stdout:
                    output_parts.append(f"\n{stdout}")
                if stderr:
                    output_parts.append(f"\n{stderr}")

                yield self.create_text_message("\n".join(output_parts))

        except TimeoutError:
            yield self.create_text_message(f"Error: Command timed out after {COMMAND_TIMEOUT_SECONDS}s")
        except Exception as e:
            yield self.create_text_message(f"Error: {e!s}")
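
The truncation arithmetic keeps the first and last 2,500 characters and reports the omitted middle, so a 10,000-character stdout loses exactly 5,000 characters:

sample = "x" * 10_000
truncated = _truncate_output(sample, "stdout")
assert truncated.startswith("x" * 2500)
assert truncated.endswith("x" * 2500)
assert "[5000 characters omitted from stdout]" in truncated
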

164
api/core/sandbox/bash/dify_cli.py
Normal file
@ -0,0 +1,164 @@
from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING, Any

from pydantic import BaseModel, Field

from core.session.cli_api import CliApiSession
from core.skill.entities import ToolDependencies, ToolReference
from core.tools.entities.tool_entities import ToolParameter, ToolProviderType
from core.virtual_environment.__base.entities import Arch, OperatingSystem
from graphon.model_runtime.utils.encoders import jsonable_encoder

from ..entities import DifyCli

if TYPE_CHECKING:
    from core.tools.__base.tool import Tool


class DifyCliBinary(BaseModel):
    operating_system: OperatingSystem = Field(alias="os")
    arch: Arch
    path: Path

    model_config = {
        "populate_by_name": True,
        "arbitrary_types_allowed": True,
    }


class DifyCliLocator:
    def __init__(self, root: str | Path | None = None) -> None:
        from configs import dify_config

        if root is not None:
            self._root = Path(root)
        elif dify_config.SANDBOX_DIFY_CLI_ROOT:
            self._root = Path(dify_config.SANDBOX_DIFY_CLI_ROOT)
        else:
            api_root = Path(__file__).resolve().parents[3]
            self._root = api_root / "bin"

    def resolve(self, operating_system: OperatingSystem, arch: Arch) -> DifyCliBinary:
        filename = DifyCli.PATH_PATTERN.format(os=operating_system.value, arch=arch.value)
        candidate = self._root / filename
        if not candidate.is_file():
            raise FileNotFoundError(
                f"dify CLI binary not found: {candidate}. Configure SANDBOX_DIFY_CLI_ROOT or ensure the file exists."
            )

        return DifyCliBinary(os=operating_system, arch=arch, path=candidate)


class DifyCliEnvConfig(BaseModel):
    files_url: str
    cli_api_url: str
    cli_api_session_id: str
    cli_api_secret: str


class DifyCliToolConfig(BaseModel):
    provider_type: str
    enabled: bool = True
    identity: dict[str, Any]
    description: dict[str, Any]
    parameters: list[dict[str, Any]]

    @classmethod
    def transform_provider_type(cls, tool_provider_type: ToolProviderType) -> str:
        match tool_provider_type:
            case ToolProviderType.BUILT_IN | ToolProviderType.PLUGIN:
                return "builtin"
            case ToolProviderType.MCP | ToolProviderType.WORKFLOW | ToolProviderType.API:
                return tool_provider_type.value
            case _:
                raise ValueError(f"Invalid tool provider type: {tool_provider_type}")

    @classmethod
    def create_from_tool(cls, tool: Tool) -> DifyCliToolConfig:
        return cls(
            identity=to_json(tool.entity.identity),
            provider_type=cls.transform_provider_type(tool.tool_provider_type()),
            description=to_json(tool.entity.description),
            parameters=[cls.transform_parameter(parameter) for parameter in tool.entity.parameters],
        )

    @classmethod
    def transform_parameter(cls, parameter: ToolParameter) -> dict[str, Any]:
        transformed_parameter = to_json(parameter)
        transformed_parameter.pop("input_schema", None)
        transformed_parameter.pop("form", None)
        # File-type parameters (FILE/FILES/SYSTEM_FILES) currently need no special
        # handling; every parameter passes through unchanged.
        return transformed_parameter


class DifyCliToolReference(BaseModel):
    id: str
    tool_type: str
    tool_name: str
    tool_provider: str
    credential_id: str | None = None
    default_value: dict[str, Any] | None = None

    @classmethod
    def create_from_tool_reference(cls, reference: ToolReference) -> DifyCliToolReference:
        return cls(
            id=reference.uuid,
            tool_type=reference.type.value,
            tool_name=reference.tool_name,
            tool_provider=reference.provider,
            credential_id=reference.credential_id,
            default_value=reference.configuration.default_values() if reference.configuration else None,
        )


class DifyCliConfig(BaseModel):
    env: DifyCliEnvConfig
    tool_references: list[DifyCliToolReference]
    tools: list[DifyCliToolConfig]

    @classmethod
    def create(
        cls,
        session: CliApiSession,
        tenant_id: str,
        tool_deps: ToolDependencies,
    ) -> DifyCliConfig:
        from configs import dify_config

        cli_api_url = dify_config.CLI_API_URL

        return cls(
            env=DifyCliEnvConfig(
                files_url=dify_config.FILES_API_URL,
                cli_api_url=cli_api_url,
                cli_api_session_id=session.id,
                cli_api_secret=session.secret,
            ),
            tool_references=[DifyCliToolReference.create_from_tool_reference(ref) for ref in tool_deps.references],
            tools=[],
        )


def to_json(obj: Any) -> dict[str, Any]:
    return jsonable_encoder(obj, exclude_unset=True, exclude_defaults=True, exclude_none=True)


__all__ = [
    "DifyCliBinary",
    "DifyCliConfig",
    "DifyCliEnvConfig",
    "DifyCliLocator",
    "DifyCliToolConfig",
    "DifyCliToolReference",
]
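
Serialized with model_dump(mode="json") and written into the sandbox, the resulting config has roughly this shape (all values below are illustrative, not real credentials):

example_config = {
    "env": {
        "files_url": "http://localhost:5001",    # dify_config.FILES_API_URL
        "cli_api_url": "http://localhost:5001",  # dify_config.CLI_API_URL
        "cli_api_session_id": "<session-id>",
        "cli_api_secret": "<session-secret>",
    },
    "tool_references": [
        {
            "id": "<reference-uuid>",
            "tool_type": "builtin",
            "tool_name": "<tool-name>",
            "tool_provider": "<provider>",
            "credential_id": None,
            "default_value": None,
        }
    ],
    "tools": [],  # DifyCliConfig.create currently leaves this empty
}
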

239
api/core/sandbox/bash/session.py
Normal file
@ -0,0 +1,239 @@
from __future__ import annotations

import json
import logging
import mimetypes
import os
import shlex
from types import TracebackType

from core.sandbox.sandbox import Sandbox
from core.session.cli_api import CliApiSession, CliApiSessionManager, CliContext
from core.skill.entities import ToolAccessPolicy
from core.skill.entities.tool_dependencies import ToolDependencies
from core.tools.signature import sign_tool_file
from core.tools.tool_file_manager import ToolFileManager
from core.virtual_environment.__base.helpers import pipeline
from graphon.file import File, FileTransferMethod, FileType

from ..bash.dify_cli import DifyCliConfig
from ..entities import DifyCli
from .bash_tool import SandboxBashTool

logger = logging.getLogger(__name__)

SANDBOX_READY_TIMEOUT = 60 * 10

# Default output directory for sandbox-generated files
SANDBOX_OUTPUT_DIR = "output"
# Maximum number of files to collect from sandbox output
MAX_OUTPUT_FILES = 50
# Maximum file size to collect (10MB)
MAX_OUTPUT_FILE_SIZE = 10 * 1024 * 1024


class SandboxBashSession:
    def __init__(self, *, sandbox: Sandbox, node_id: str, tools: ToolDependencies | None) -> None:
        self._sandbox = sandbox
        self._node_id = node_id
        self._tools = tools
        self._bash_tool: SandboxBashTool | None = None
        self._cli_api_session: CliApiSession | None = None
        self._tenant_id = sandbox.tenant_id
        self._user_id = sandbox.user_id
        self._app_id = sandbox.app_id
        self._assets_id = sandbox.assets_id

    def __enter__(self) -> SandboxBashSession:
        # Ensure sandbox initialization completes before any bash commands run.
        self._sandbox.wait_ready(timeout=SANDBOX_READY_TIMEOUT)
        cli = DifyCli(self._sandbox.id)
        self._cli_api_session = CliApiSessionManager().create(
            tenant_id=self._tenant_id,
            user_id=self._user_id,
            context=CliContext(tool_access=ToolAccessPolicy.from_dependencies(self._tools)),
        )
        if self._tools is not None and not self._tools.is_empty():
            tools_path = self._setup_node_tools_directory(cli, self._node_id, self._tools, self._cli_api_session)
        else:
            tools_path = cli.global_tools_path

        self._bash_tool = SandboxBashTool(
            sandbox=self._sandbox.vm,
            tenant_id=self._tenant_id,
            tools_path=tools_path,
        )
        return self

    def _setup_node_tools_directory(
        self,
        cli: DifyCli,
        node_id: str,
        tools: ToolDependencies,
        cli_api_session: CliApiSession,
    ) -> str:
        node_tools_path = cli.node_tools_path(node_id)
        config_json = json.dumps(
            DifyCliConfig.create(session=cli_api_session, tenant_id=self._tenant_id, tool_deps=tools).model_dump(
                mode="json"
            ),
            ensure_ascii=False,
        )
        config_path = shlex.quote(cli.node_config_path(node_id))

        vm = self._sandbox.vm
        # Merge mkdir + config write into a single pipeline to reduce round-trips.
        (
            pipeline(vm)
            .add(["mkdir", "-p", cli.global_tools_path], error_message="Failed to create global tools dir")
            .add(["mkdir", "-p", node_tools_path], error_message="Failed to create node tools dir")
            # Use a quoted heredoc (<<'EOF') so the shell performs no expansion on the
            # content — safe regardless of $, `, \, or quotes inside the JSON.
            .add(
                ["sh", "-c", f"cat > {config_path} << '__DIFY_CFG__'\n{config_json}\n__DIFY_CFG__"],
                error_message="Failed to write CLI config",
            )
            .execute(raise_on_error=True)
        )

        pipeline(vm, cwd=node_tools_path).add(
            [cli.bin_path, "init"], error_message="Failed to initialize Dify CLI"
        ).execute(raise_on_error=True)

        logger.info(
            "Node %s tools initialized, path=%s, tool_count=%d", node_id, node_tools_path, len(tools.references)
        )
        return node_tools_path

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        tb: TracebackType | None,
    ) -> bool:
        try:
            if self._cli_api_session is not None:
                CliApiSessionManager().delete(self._cli_api_session.id)
                logger.debug("Cleaned up SandboxBashSession session_id=%s", self._cli_api_session.id)
                self._cli_api_session = None
        except Exception:
            logger.exception("Failed to clean up SandboxBashSession")
        return False

    @property
    def bash_tool(self) -> SandboxBashTool:
        if self._bash_tool is None:
            raise RuntimeError("SandboxBashSession is not initialized")
        return self._bash_tool

    def collect_output_files(self, output_dir: str = SANDBOX_OUTPUT_DIR) -> list[File]:
        """
        Collect files from the sandbox output directory and save them as ToolFiles.

        Scans the specified output directory in the sandbox, downloads each file,
        saves it as a ToolFile, and returns a list of File objects. The File
        objects have a valid tool_file_id that can be referenced by subsequent
        nodes via structured output.

        Args:
            output_dir: Directory path in the sandbox to scan for output files.
                Defaults to "output" (relative to the workspace).

        Returns:
            List of File objects representing the collected files.
        """
        vm = self._sandbox.vm
        collected_files: list[File] = []

        try:
            file_states = vm.list_files(output_dir, limit=MAX_OUTPUT_FILES)
        except Exception as exc:
            # Output directory may not exist if no files were generated
            logger.debug("Failed to list sandbox output files in %s: %s", output_dir, exc)
            return collected_files

        tool_file_manager = ToolFileManager()

        for file_state in file_states:
            # Skip files that are too large
            if file_state.size > MAX_OUTPUT_FILE_SIZE:
                logger.warning(
                    "Skipping sandbox output file %s: size %d exceeds limit %d",
                    file_state.path,
                    file_state.size,
                    MAX_OUTPUT_FILE_SIZE,
                )
                continue

            try:
                # file_state.path is already relative to working_path (e.g., "output/file.png")
                file_content = vm.download_file(file_state.path)
                file_binary = file_content.getvalue()

                # Determine mime type from extension
                filename = os.path.basename(file_state.path)
                mime_type, _ = mimetypes.guess_type(filename)
                if not mime_type:
                    mime_type = "application/octet-stream"

                # Save as ToolFile
                tool_file = tool_file_manager.create_file_by_raw(
                    user_id=self._user_id,
                    tenant_id=self._tenant_id,
                    conversation_id=None,
                    file_binary=file_binary,
                    mimetype=mime_type,
                    filename=filename,
                )

                # Determine file type from mime type
                file_type = _get_file_type_from_mime(mime_type)
                extension = os.path.splitext(filename)[1] if "." in filename else ".bin"
                url = sign_tool_file(tool_file.id, extension)

                # Create File object with tool_file_id as related_id
                file_obj = File(
                    id=tool_file.id,  # Use tool_file_id as the File id for easy reference
                    tenant_id=self._tenant_id,
                    type=file_type,
                    transfer_method=FileTransferMethod.TOOL_FILE,
                    filename=filename,
                    extension=extension,
                    mime_type=mime_type,
                    size=len(file_binary),
                    related_id=tool_file.id,
                    url=url,
                    storage_key=tool_file.file_key,
                )
                collected_files.append(file_obj)

                logger.info(
                    "Collected sandbox output file: %s -> tool_file_id=%s",
                    file_state.path,
                    tool_file.id,
                )

            except Exception as exc:
                logger.warning("Failed to collect sandbox output file %s: %s", file_state.path, exc)
                continue

        logger.info(
            "Collected %d files from sandbox output directory %s",
            len(collected_files),
            output_dir,
        )
        return collected_files


def _get_file_type_from_mime(mime_type: str) -> FileType:
    """Determine FileType from a mime type."""
    if mime_type.startswith("image/"):
        return FileType.IMAGE
    elif mime_type.startswith("video/"):
        return FileType.VIDEO
    elif mime_type.startswith("audio/"):
        return FileType.AUDIO
    elif "text" in mime_type or "pdf" in mime_type:
        return FileType.DOCUMENT
    else:
        return FileType.CUSTOM
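
A hedged sketch of the intended call pattern, assuming a ready Sandbox named sandbox and that Tool.invoke forwards to _invoke as in the Tool base class (argument names mirror SandboxBashTool._invoke):

with SandboxBashSession(sandbox=sandbox, node_id="node-1", tools=None) as session:
    for message in session.bash_tool.invoke(
        user_id="user-1",
        tool_parameters={"bash": "mkdir -p output && echo hi > output/hi.txt"},
    ):
        ...  # stream ToolInvokeMessage results
    files = session.collect_output_files()  # list[File] backed by ToolFiles
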

222
api/core/sandbox/builder.py
Normal file
@ -0,0 +1,222 @@
from __future__ import annotations

import logging
import threading
from collections.abc import Mapping, Sequence
from contextlib import nullcontext
from typing import TYPE_CHECKING, Any, cast

from flask import Flask, current_app, has_app_context

from core.entities.provider_entities import BasicProviderConfig
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment

from .entities.sandbox_type import SandboxType
from .initializer import AsyncSandboxInitializer, SandboxInitializeContext, SandboxInitializer, SyncSandboxInitializer
from .sandbox import Sandbox

if TYPE_CHECKING:
    from .storage.sandbox_storage import SandboxStorage

logger = logging.getLogger(__name__)


def _get_sandbox_class(sandbox_type: SandboxType) -> type[VirtualEnvironment]:
    match sandbox_type:
        case SandboxType.DOCKER:
            from core.virtual_environment.providers.docker_daemon_sandbox import DockerDaemonEnvironment

            return DockerDaemonEnvironment
        case SandboxType.E2B:
            from core.virtual_environment.providers.e2b_sandbox import E2BEnvironment

            return E2BEnvironment
        case SandboxType.LOCAL:
            from core.virtual_environment.providers.local_without_isolation import LocalVirtualEnvironment

            return LocalVirtualEnvironment
        case SandboxType.SSH:
            from core.virtual_environment.providers.ssh_sandbox import SSHSandboxEnvironment

            return SSHSandboxEnvironment
        case SandboxType.AWS_CODE_INTERPRETER:
            from core.virtual_environment.providers.aws_code_interpreter_sandbox import (
                AWSCodeInterpreterEnvironment,
            )

            return AWSCodeInterpreterEnvironment
        case _:
            raise ValueError(f"Unsupported sandbox type: {sandbox_type}")


class SandboxBuilder:
    _tenant_id: str
    _sandbox_type: SandboxType
    _user_id: str | None
    _app_id: str | None
    _options: dict[str, Any]
    _environments: dict[str, str]
    _initializers: list[SandboxInitializer]
    _storage: SandboxStorage | None
    _assets_id: str | None

    def __init__(self, tenant_id: str, sandbox_type: SandboxType) -> None:
        self._tenant_id = tenant_id
        self._sandbox_type = sandbox_type
        self._user_id = None
        self._app_id = None
        self._options = {}
        self._environments = {}
        self._initializers = []
        self._storage = None
        self._assets_id = None

    def user(self, user_id: str) -> SandboxBuilder:
        self._user_id = user_id
        return self

    def app(self, app_id: str) -> SandboxBuilder:
        self._app_id = app_id
        return self

    def options(self, options: Mapping[str, Any]) -> SandboxBuilder:
        self._options = dict(options)
        return self

    def environments(self, environments: Mapping[str, str]) -> SandboxBuilder:
        self._environments = dict(environments)
        return self

    def initializer(self, initializer: SandboxInitializer) -> SandboxBuilder:
        self._initializers.append(initializer)
        return self

    def initializers(self, initializers: Sequence[SandboxInitializer]) -> SandboxBuilder:
        self._initializers.extend(initializers)
        return self

    def storage(self, storage: SandboxStorage, assets_id: str) -> SandboxBuilder:
        self._storage = storage
        self._assets_id = assets_id
        return self

    def build(self) -> Sandbox:
        """Create a sandbox and start background initialization.

        The builder is responsible for cleaning up any VM or sandbox that was
        successfully created if a later setup step fails.
        """
        if self._storage is None:
            raise ValueError("storage is required, call .storage() before .build()")
        if self._assets_id is None:
            raise ValueError("assets_id is required, call .storage() before .build()")
        if self._user_id is None:
            raise ValueError("user_id is required, call .user() before .build()")
        if self._app_id is None:
            raise ValueError("app_id is required, call .app() before .build()")

        ctx = SandboxInitializeContext(
            tenant_id=self._tenant_id,
            app_id=self._app_id,
            assets_id=self._assets_id,
            user_id=self._user_id,
        )
        vm: VirtualEnvironment | None = None
        sandbox: Sandbox | None = None
        try:
            vm_class = _get_sandbox_class(self._sandbox_type)
            vm = vm_class(
                tenant_id=self._tenant_id,
                options=self._options,
                environments=self._environments,
                user_id=self._user_id,
            )
            vm.open_enviroment()
            sandbox = Sandbox(
                vm=vm,
                storage=self._storage,
                tenant_id=self._tenant_id,
                user_id=self._user_id,
                app_id=self._app_id,
                assets_id=self._assets_id,
            )

            for init in self._initializers:
                if isinstance(init, SyncSandboxInitializer):
                    init.initialize(sandbox, ctx)
        except Exception as exc:
            logger.exception(
                "Failed to initialize sandbox synchronously: tenant_id=%s, app_id=%s", self._tenant_id, self._app_id
            )
            if sandbox is not None:
                sandbox.release()
            elif vm is not None:
                try:
                    vm.release_environment()
                except Exception:
                    logger.exception("Failed to release sandbox VM during builder cleanup")
            raise RuntimeError("Sandbox initialization failed") from exc

        # Run sandbox setup asynchronously so workflow execution can proceed.
        # Capture the Flask app before starting the thread for database access.
        flask_app: Flask | None = cast(Any, current_app)._get_current_object() if has_app_context() else None

        _sandbox: Sandbox = sandbox

        def initialize() -> None:
            try:
                app_context = flask_app.app_context() if flask_app is not None else nullcontext()
                with app_context:
                    for init in self._initializers:
                        if not isinstance(init, AsyncSandboxInitializer):
                            continue

                        if _sandbox.is_cancelled():
                            return
                        init.initialize(_sandbox, ctx)

                    if _sandbox.is_cancelled():
                        return
                    _sandbox.mount()
                    _sandbox.mark_ready()
            except Exception as exc:
                try:
                    logger.exception(
                        "Failed to initialize sandbox: tenant_id=%s, app_id=%s", self._tenant_id, self._app_id
                    )
                    _sandbox.release()
                    _sandbox.mark_failed(exc)
                except Exception:
                    logger.exception(
                        "Failed to mark sandbox initialization failure: tenant_id=%s, app_id=%s",
                        self._tenant_id,
                        self._app_id,
                    )

        # Background init completes or signals failure via sandbox state.
        try:
            threading.Thread(target=initialize, daemon=True).start()
        except Exception:
            logger.exception(
                "Failed to start sandbox initialization thread: tenant_id=%s, app_id=%s",
                self._tenant_id,
                self._app_id,
            )
            sandbox.release()
            raise RuntimeError("Sandbox initialization failed")
        return sandbox

    @staticmethod
    def validate(vm_type: SandboxType, options: Mapping[str, Any]) -> None:
        vm_class = _get_sandbox_class(vm_type)
        vm_class.validate(options)

    @classmethod
    def draft_id(cls, user_id: str) -> str:
        return user_id


class VMConfig:
    @staticmethod
    def get_schema(vm_type: SandboxType) -> list[BasicProviderConfig]:
        return _get_sandbox_class(vm_type).get_config_schema()
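
Putting the fluent API together, a minimal sketch of constructing a sandbox (my_storage and the ids are placeholders; build() raises ValueError if .user(), .app(), or .storage() were skipped):

sandbox = (
    SandboxBuilder(tenant_id="<tenant-uuid>", sandbox_type=SandboxType.LOCAL)
    .user("<user-uuid>")
    .app("<app-uuid>")
    .storage(my_storage, assets_id="<assets-uuid>")  # my_storage: any SandboxStorage
    .initializer(DifyCliInitializer())  # AsyncSandboxInitializer: runs in the background thread
    .build()
)
sandbox.wait_ready(timeout=600)  # block until background initialization finishes or fails
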

13
api/core/sandbox/entities/__init__.py
Normal file
@ -0,0 +1,13 @@
from .config import AppAssets, DifyCli
from .files import SandboxFileDownloadTicket, SandboxFileNode
from .providers import SandboxProviderApiEntity
from .sandbox_type import SandboxType

__all__ = [
    "AppAssets",
    "DifyCli",
    "SandboxFileDownloadTicket",
    "SandboxFileNode",
    "SandboxProviderApiEntity",
    "SandboxType",
]

58
api/core/sandbox/entities/config.py
Normal file
@ -0,0 +1,58 @@
from typing import Final


class DifyCli:
    """Per-sandbox Dify CLI paths, namespaced under ``/tmp/.dify/{env_id}``.

    Every sandbox environment gets its own directory tree so that
    concurrent sessions on the same host (e.g. SSH provider) never
    collide on config files or CLI binaries.

    Class-level constants (``CONFIG_FILENAME``, ``PATH_PATTERN``) are
    safe to share; all path attributes are instance-level and derived
    from the ``env_id`` passed at construction time.
    """

    # --- class-level constants (no path component) ---
    CONFIG_FILENAME: Final[str] = ".dify_cli.json"
    PATH_PATTERN: Final[str] = "dify-cli-{os}-{arch}"

    # --- instance attributes ---
    root: str
    bin_dir: str
    bin_path: str
    tools_root: str
    global_tools_path: str
    global_config_path: str

    def __init__(self, env_id: str) -> None:
        self.root = f"/tmp/.dify/{env_id}"
        self.bin_dir = f"{self.root}/bin"
        self.bin_path = f"{self.bin_dir}/dify"
        self.tools_root = f"{self.root}/tools"
        self.global_tools_path = f"{self.root}/tools/global"
        self.global_config_path = f"{self.global_tools_path}/{DifyCli.CONFIG_FILENAME}"

    def node_tools_path(self, node_id: str) -> str:
        return f"{self.tools_root}/{node_id}"

    def node_config_path(self, node_id: str) -> str:
        return f"{self.node_tools_path(node_id)}/{DifyCli.CONFIG_FILENAME}"


class AppAssets:
    """App Assets constants.

    ``PATH`` is a relative path resolved by each provider against its
    own workspace root — already isolated. ``zip_path`` is an absolute
    temp path and must be namespaced per environment to avoid collisions.
    """

    PATH: Final[str] = "skills"

    root: str
    zip_path: str

    def __init__(self, env_id: str) -> None:
        self.root = f"/tmp/.dify/{env_id}"
        self.zip_path = f"{self.root}/assets.zip"
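
For an environment id of, say, "env-123", the derived layout is fully deterministic:

cli = DifyCli("env-123")
assert cli.root == "/tmp/.dify/env-123"
assert cli.bin_path == "/tmp/.dify/env-123/bin/dify"
assert cli.node_config_path("node-1") == "/tmp/.dify/env-123/tools/node-1/.dify_cli.json"

assets = AppAssets("env-123")
assert assets.zip_path == "/tmp/.dify/env-123/assets.zip"
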

19
api/core/sandbox/entities/files.py
Normal file
@ -0,0 +1,19 @@
from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class SandboxFileNode:
    path: str
    is_dir: bool
    size: int | None
    mtime: int | None
    extension: str | None


@dataclass(frozen=True)
class SandboxFileDownloadTicket:
    download_url: str
    expires_in: int
    export_id: str

21
api/core/sandbox/entities/providers.py
Normal file
@ -0,0 +1,21 @@
from collections.abc import Mapping
from typing import Any

from pydantic import BaseModel, Field


class SandboxProviderApiEntity(BaseModel):
    provider_type: str = Field(..., description="Provider type identifier")
    is_system_configured: bool = Field(default=False)
    is_tenant_configured: bool = Field(default=False)
    is_active: bool = Field(default=False)
    config: Mapping[str, Any] = Field(default_factory=dict)
    config_schema: list[dict[str, Any]] = Field(default_factory=list)


class SandboxProviderEntity(BaseModel):
    id: str = Field(..., description="Provider identifier")
    provider_type: str = Field(..., description="Provider type identifier")
    is_active: bool = Field(default=False)
    config: Mapping[str, Any] = Field(default_factory=dict)
    config_schema: list[dict[str, Any]] = Field(default_factory=list)

18
api/core/sandbox/entities/sandbox_type.py
Normal file
@ -0,0 +1,18 @@
from enum import StrEnum

from configs import dify_config


class SandboxType(StrEnum):
    DOCKER = "docker"
    E2B = "e2b"
    LOCAL = "local"
    SSH = "ssh"
    AWS_CODE_INTERPRETER = "aws_code_interpreter"

    @classmethod
    def get_all(cls) -> list[str]:
        if dify_config.EDITION == "SELF_HOSTED":
            return [p.value for p in cls]
        else:
            return [p.value for p in cls if p != SandboxType.LOCAL]
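
The edition gate only affects the unisolated local provider:

# With dify_config.EDITION == "SELF_HOSTED":
#   SandboxType.get_all() -> ["docker", "e2b", "local", "ssh", "aws_code_interpreter"]
# Any other edition returns the same list without "local".
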

8
api/core/sandbox/initializer/__init__.py
Normal file
@ -0,0 +1,8 @@
from .base import AsyncSandboxInitializer, SandboxInitializeContext, SandboxInitializer, SyncSandboxInitializer

__all__ = [
    "AsyncSandboxInitializer",
    "SandboxInitializeContext",
    "SandboxInitializer",
    "SyncSandboxInitializer",
]

19
api/core/sandbox/initializer/app_asset_attrs_initializer.py
Normal file
@ -0,0 +1,19 @@
import logging

from core.app_assets.constants import AppAssetsAttrs
from core.sandbox.sandbox import Sandbox
from services.app_asset_package_service import AppAssetPackageService

from .base import SandboxInitializeContext, SyncSandboxInitializer

logger = logging.getLogger(__name__)

APP_ASSETS_DOWNLOAD_TIMEOUT = 60 * 10


class AppAssetAttrsInitializer(SyncSandboxInitializer):
    def initialize(self, sandbox: Sandbox, ctx: SandboxInitializeContext) -> None:
        # Expose the published asset tree and build id via sandbox attrs;
        # the actual download/unzip happens in AppAssetsInitializer.
        app_assets = AppAssetPackageService.get_tenant_app_assets(ctx.tenant_id, ctx.assets_id)
        sandbox.attrs.set(AppAssetsAttrs.FILE_TREE, app_assets.asset_tree)
        sandbox.attrs.set(AppAssetsAttrs.APP_ASSETS_ID, ctx.assets_id)

59
api/core/sandbox/initializer/app_assets_initializer.py
Normal file
@ -0,0 +1,59 @@
import logging

from core.app_assets.storage import AssetPaths
from core.sandbox.sandbox import Sandbox
from core.virtual_environment.__base.helpers import pipeline

from ..entities import AppAssets
from .base import AsyncSandboxInitializer, SandboxInitializeContext

logger = logging.getLogger(__name__)

APP_ASSETS_DOWNLOAD_TIMEOUT = 60 * 10


class AppAssetsInitializer(AsyncSandboxInitializer):
    def initialize(self, sandbox: Sandbox, ctx: SandboxInitializeContext) -> None:
        from services.app_asset_service import AppAssetService

        # Load published app assets and unzip the artifact bundle.
        vm = sandbox.vm
        assets = AppAssets(sandbox.id)
        asset_storage = AppAssetService.get_storage()
        key = AssetPaths.build_zip(ctx.tenant_id, ctx.app_id, ctx.assets_id)
        download_url = asset_storage.get_download_url(key)

        (
            pipeline(vm)
            .add(
                ["mkdir", "-p", assets.root],
                error_message="Failed to create assets temp directory",
            )
            .add(
                ["sh", "-c", 'curl -fsSL "$1" -o "$2"', "sh", download_url, assets.zip_path],
                error_message="Failed to download assets zip",
            )
            # Create the assets directory first to ensure it exists even if the zip is empty
            .add(
                ["mkdir", "-p", AppAssets.PATH],
                error_message="Failed to create assets directory",
            )
            # unzip silently; treat exit code 1 (empty zip / warnings) as success
            # FIXME(Mairuis): should use a more robust way to check if the zip is empty
            .add(
                ["sh", "-c", 'unzip "$1" -d "$2" 2>/dev/null || [ $? -eq 1 ]', "sh", assets.zip_path, AppAssets.PATH],
                error_message="Failed to unzip assets",
            )
            # Ensure directories have execute permission for traversal and files are readable
            .add(
                ["sh", "-c", 'chmod -R u+rwX,go+rX "$1"', "sh", AppAssets.PATH],
                error_message="Failed to set permissions on assets",
            )
            .execute(timeout=APP_ASSETS_DOWNLOAD_TIMEOUT, raise_on_error=True)
        )

        logger.info(
            "App assets initialized for app_id=%s, published_id=%s",
            ctx.app_id,
            ctx.assets_id,
        )

46
api/core/sandbox/initializer/base.py
Normal file
@ -0,0 +1,46 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import TYPE_CHECKING

from core.sandbox.sandbox import Sandbox

if TYPE_CHECKING:
    from core.app_assets.entities.assets import AssetItem


@dataclass
class SandboxInitializeContext:
    """Shared identity context passed to every ``SandboxInitializer``.

    Carries the common identity fields that virtually every initializer
    needs, plus optional artefact slots that sync initializers populate
    for async initializers to consume.

    Identity fields are immutable by convention; artefact slots are
    written at most once during the sync phase and read during the
    async phase.
    """

    tenant_id: str
    app_id: str
    assets_id: str
    user_id: str

    # Populated by DraftAppAssetsInitializer (sync) for
    # DraftAppAssetsDownloader (async) to download into the VM.
    built_assets: list[AssetItem] | None = field(default=None)


class SandboxInitializer(ABC):
    @abstractmethod
    def initialize(self, sandbox: Sandbox, ctx: SandboxInitializeContext) -> None: ...


class SyncSandboxInitializer(SandboxInitializer):
    """Marker class for initializers that must run before async setup."""


class AsyncSandboxInitializer(SandboxInitializer):
    """Marker class for initializers that can run in the background."""
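
Defining a new initializer means subclassing one of the two marker classes and implementing initialize(). A minimal sketch (EnvFileInitializer and its behavior are hypothetical, not part of the port):

from io import BytesIO


class EnvFileInitializer(AsyncSandboxInitializer):
    """Hypothetical example: writes a .env file into the VM in the background phase."""

    def initialize(self, sandbox: Sandbox, ctx: SandboxInitializeContext) -> None:
        payload = f"TENANT_ID={ctx.tenant_id}\nAPP_ID={ctx.app_id}\n"
        sandbox.vm.upload_file(".env", BytesIO(payload.encode("utf-8")))
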

76
api/core/sandbox/initializer/dify_cli_initializer.py
Normal file
@ -0,0 +1,76 @@
from __future__ import annotations

import json
import logging
from io import BytesIO
from pathlib import Path

from core.sandbox.sandbox import Sandbox
from core.session.cli_api import CliApiSessionManager, CliContext
from core.skill.constants import SkillAttrs
from core.skill.entities import ToolAccessPolicy
from core.virtual_environment.__base.helpers import pipeline

from ..bash.dify_cli import DifyCliConfig, DifyCliLocator
from ..entities import DifyCli
from .base import AsyncSandboxInitializer, SandboxInitializeContext

logger = logging.getLogger(__name__)


class DifyCliInitializer(AsyncSandboxInitializer):
    def __init__(self, cli_root: str | Path | None = None) -> None:
        self._locator = DifyCliLocator(root=cli_root)

    def initialize(self, sandbox: Sandbox, ctx: SandboxInitializeContext) -> None:
        vm = sandbox.vm
        cli = DifyCli(sandbox.id)

        # FIXME(Mairuis): should be more robust, effectively.
        binary = self._locator.resolve(vm.metadata.os, vm.metadata.arch)

        pipeline(vm).add(["mkdir", "-p", cli.bin_dir], error_message="Failed to create dify CLI directory").execute(
            raise_on_error=True
        )

        vm.upload_file(cli.bin_path, BytesIO(binary.path.read_bytes()))

        pipeline(vm).add(
            [
                "sh",
                "-c",
                f"cat '{cli.bin_path}' > '{cli.bin_path}.tmp' && "
                f"mv '{cli.bin_path}.tmp' '{cli.bin_path}' && "
                f"chmod +x '{cli.bin_path}'",
            ],
            error_message="Failed to mark dify CLI as executable",
        ).execute(raise_on_error=True)

        logger.info("Dify CLI uploaded to sandbox, path=%s", cli.bin_path)

        bundle = sandbox.attrs.get(SkillAttrs.BUNDLE)
        if bundle is None or bundle.get_tool_dependencies().is_empty():
            logger.info("No tools found in bundle for assets_id=%s", ctx.assets_id)
            return

        global_cli_session = CliApiSessionManager().create(
            tenant_id=ctx.tenant_id,
            user_id=ctx.user_id,
            context=CliContext(tool_access=ToolAccessPolicy.from_dependencies(bundle.get_tool_dependencies())),
        )

        pipeline(vm).add(
            ["mkdir", "-p", cli.global_tools_path], error_message="Failed to create global tools dir"
        ).execute(raise_on_error=True)

        config = DifyCliConfig.create(global_cli_session, ctx.tenant_id, bundle.get_tool_dependencies())
        config_json = json.dumps(config.model_dump(mode="json"), ensure_ascii=False)
        config_path = cli.global_config_path
        vm.upload_file(config_path, BytesIO(config_json.encode("utf-8")))

        pipeline(vm, cwd=cli.global_tools_path).add(
            [cli.bin_path, "init"], error_message="Failed to initialize Dify CLI"
        ).execute(raise_on_error=True)

        logger.info(
            "Global tools initialized, path=%s, tool_count=%d",
            cli.global_tools_path,
            len(bundle.get_tool_dependencies().references),
        )

89 api/core/sandbox/initializer/draft_app_assets_initializer.py Normal file
@@ -0,0 +1,89 @@
"""Synchronous initializer that compiles draft app assets.

Unlike ``AppAssetsInitializer`` (which downloads a pre-built ZIP for
published assets), this initializer runs the build pipeline on the fly
so that ``.md`` skill documents are compiled and their resolved content
is embedded directly into the download script — avoiding the S3
round-trip that was previously required for resolved keys.

Execution order:
    ``DraftAppAssetsInitializer`` (sync) compiles assets and publishes
    the ``SkillBundle`` to ``sandbox.attrs`` in-memory, so the
    downstream ``SkillInitializer`` can skip the Redis/S3 round-trip.
    ``DraftAppAssetsDownloader`` (async) then pushes the compiled
    artefacts into the sandbox VM in the background.
"""

import logging

from core.app_assets.builder.base import BuildContext
from core.app_assets.builder.file_builder import FileBuilder
from core.app_assets.builder.pipeline import AssetBuildPipeline
from core.app_assets.builder.skill_builder import SkillBuilder
from core.app_assets.constants import AppAssetsAttrs
from core.sandbox.entities import AppAssets
from core.sandbox.sandbox import Sandbox
from core.sandbox.services import AssetDownloadService
from core.skill import SkillAttrs
from core.virtual_environment.__base.helpers import pipeline
from services.app_asset_service import AppAssetService

from .base import AsyncSandboxInitializer, SandboxInitializeContext, SyncSandboxInitializer

logger = logging.getLogger(__name__)


class DraftAppAssetsInitializer(SyncSandboxInitializer):
    """Compile draft assets and publish the ``SkillBundle`` to attrs.

    The build pipeline compiles ``.md`` skill files in-process.
    The resulting ``SkillBundle`` is persisted to Redis/S3 (by
    ``SkillBuilder``) **and** written to ``sandbox.attrs[BUNDLE]``
    so that ``SkillInitializer`` can read it without a round-trip.
    Built asset items are stored on ``ctx.built_assets`` for the
    async ``DraftAppAssetsDownloader`` to consume.
    """

    def initialize(self, sandbox: Sandbox, ctx: SandboxInitializeContext) -> None:
        tree = sandbox.attrs.get(AppAssetsAttrs.FILE_TREE)

        # --- 1. Run the build pipeline (SkillBuilder compiles .md inline) ---
        accessor = AppAssetService.get_accessor(ctx.tenant_id, ctx.app_id)
        skill_builder = SkillBuilder(accessor=accessor)
        build_pipeline = AssetBuildPipeline([skill_builder, FileBuilder()])
        build_ctx = BuildContext(tenant_id=ctx.tenant_id, app_id=ctx.app_id, build_id=ctx.assets_id)
        built_assets = build_pipeline.build_all(tree, build_ctx)
        ctx.built_assets = built_assets

        # Publish the in-memory bundle so SkillInitializer skips Redis/S3.
        if skill_builder.bundle is not None:
            sandbox.attrs.set(SkillAttrs.BUNDLE, skill_builder.bundle)


class DraftAppAssetsDownloader(AsyncSandboxInitializer):
    """Download the compiled assets into the sandbox VM.

    The download script is built from the items produced by
    ``DraftAppAssetsInitializer`` and includes inline base64 content
    for compiled skills, as well as presigned URLs for other files.
    """

    _TIMEOUT = 600  # 10 minutes

    def initialize(self, sandbox: Sandbox, ctx: SandboxInitializeContext) -> None:
        if not ctx.built_assets:
            logger.debug("No built assets found for assets_id=%s", ctx.assets_id)
            return

        download_items = AppAssetService.to_download_items(ctx.built_assets)
        script = AssetDownloadService.build_download_script(download_items, AppAssets.PATH)
        pipeline(sandbox.vm).add(
            ["sh", "-c", script],
            error_message="Failed to download draft assets",
        ).execute(timeout=self._TIMEOUT, raise_on_error=True)

        logger.info(
            "Draft app assets initialized for app_id=%s, assets_id=%s",
            ctx.app_id,
            ctx.assets_id,
        )

34 api/core/sandbox/initializer/skill_initializer.py Normal file
@@ -0,0 +1,34 @@
from __future__ import annotations

import logging

from core.sandbox.sandbox import Sandbox
from core.skill import SkillAttrs
from core.skill.skill_manager import SkillManager

from .base import SandboxInitializeContext, SyncSandboxInitializer

logger = logging.getLogger(__name__)


class SkillInitializer(SyncSandboxInitializer):
    """Ensure ``sandbox.attrs[BUNDLE]`` is populated for downstream consumers.

    In the draft path ``DraftAppAssetsInitializer`` already sets the
    bundle on attrs from the in-memory build result, so this initializer
    becomes a no-op. In the published path no prior initializer sets
    it, so we fall back to ``SkillManager.load_bundle()`` (Redis/S3).
    """

    def initialize(self, sandbox: Sandbox, ctx: SandboxInitializeContext) -> None:
        # Draft path: bundle already populated by DraftAppAssetsInitializer.
        if sandbox.attrs.has(SkillAttrs.BUNDLE):
            return

        # Published path: load from Redis/S3.
        bundle = SkillManager.load_bundle(
            ctx.tenant_id,
            ctx.app_id,
            ctx.assets_id,
        )
        sandbox.attrs.set(SkillAttrs.BUNDLE, bundle)

11 api/core/sandbox/inspector/__init__.py Normal file
@@ -0,0 +1,11 @@
from core.sandbox.inspector.archive_source import SandboxFileArchiveSource
from core.sandbox.inspector.base import SandboxFileSource
from core.sandbox.inspector.browser import SandboxFileBrowser
from core.sandbox.inspector.runtime_source import SandboxFileRuntimeSource

__all__ = [
    "SandboxFileArchiveSource",
    "SandboxFileBrowser",
    "SandboxFileRuntimeSource",
    "SandboxFileSource",
]

169 api/core/sandbox/inspector/archive_source.py Normal file
@@ -0,0 +1,169 @@
from __future__ import annotations

import logging
import os
from typing import TYPE_CHECKING
from uuid import uuid4

from core.sandbox.entities.files import SandboxFileDownloadTicket, SandboxFileNode
from core.sandbox.inspector.base import SandboxFileSource
from core.sandbox.inspector.script_utils import (
    build_detect_kind_command,
    build_list_command,
    build_upload_command,
    guess_content_type,
    parse_kind_output,
    parse_list_output,
)
from core.sandbox.storage import SandboxFilePaths
from core.virtual_environment.__base.exec import PipelineExecutionError
from core.virtual_environment.__base.helpers import pipeline
from extensions.ext_storage import storage

if TYPE_CHECKING:
    from core.zip_sandbox import ZipSandbox

logger = logging.getLogger(__name__)


class SandboxFileArchiveSource(SandboxFileSource):
    def _get_archive_download_url(self) -> str:
        """Get a pre-signed download URL for the sandbox archive."""
        from extensions.storage.file_presign_storage import FilePresignStorage

        storage_key = SandboxFilePaths.archive(self._tenant_id, self._app_id, self._sandbox_id)
        if not storage.exists(storage_key):
            raise ValueError("Sandbox archive not found")
        presign_storage = FilePresignStorage(storage.storage_runner)
        return presign_storage.get_download_url(storage_key, self._EXPORT_EXPIRES_IN_SECONDS)

    def _create_zip_sandbox(self) -> ZipSandbox:
        """Create a ZipSandbox instance for archive operations."""
        from core.zip_sandbox import ZipSandbox

        return ZipSandbox(tenant_id=self._tenant_id, user_id="system", app_id=self._app_id)

    def exists(self) -> bool:
        """Check if the sandbox archive exists in storage."""
        storage_key = SandboxFilePaths.archive(self._tenant_id, self._app_id, self._sandbox_id)
        return storage.exists(storage_key)

    def list_files(self, *, path: str, recursive: bool) -> list[SandboxFileNode]:
        archive_url = self._get_archive_download_url()
        with self._create_zip_sandbox() as zs:
            # Download and extract the archive
            archive_path = zs.download_archive(archive_url, path="workspace.tar.gz")
            zs.untar(archive_path=archive_path, dest_dir="workspace")

            # List files using Python script in sandbox
            try:
                list_path = f"workspace/{path}" if path not in (".", "") else "workspace"
                results = (
                    pipeline(zs.vm)
                    .add(
                        build_list_command(list_path, recursive),
                        error_message="Failed to list sandbox files",
                    )
                    .execute(timeout=self._LIST_TIMEOUT_SECONDS, raise_on_error=True)
                )
            except PipelineExecutionError as exc:
                raise RuntimeError(str(exc)) from exc

            raw = parse_list_output(results[0].stdout)

            entries: list[SandboxFileNode] = []
            for item in raw:
                item_path = str(item.get("path"))
                # Strip the "workspace/" prefix from paths
                if item_path.startswith("workspace/"):
                    item_path = item_path[len("workspace/") :]
                elif item_path == "workspace":
                    continue  # Skip the workspace directory itself

                item_is_dir = bool(item.get("is_dir"))
                extension = None
                if not item_is_dir:
                    ext = os.path.splitext(item_path)[1]
                    extension = ext or None
                entries.append(
                    SandboxFileNode(
                        path=item_path,
                        is_dir=item_is_dir,
                        size=item.get("size"),
                        mtime=item.get("mtime"),
                        extension=extension,
                    )
                )
            return sorted(entries, key=lambda e: e.path)

    def download_file(self, *, path: str) -> SandboxFileDownloadTicket:
        """Download a file or directory from the archived sandbox.

        Uses direct upload from sandbox to storage via presigned URL, avoiding
        data transfer through the service layer. This preserves binary integrity
        (no text encoding issues) and reduces bandwidth overhead.
        """
        from services.sandbox.sandbox_file_service import SandboxFileService

        archive_url = self._get_archive_download_url()
        export_name = os.path.basename(path.rstrip("/")) or "workspace"
        export_id = uuid4().hex

        with self._create_zip_sandbox() as zs:
            archive_path = zs.download_archive(archive_url, path="workspace.tar.gz")
            zs.untar(archive_path=archive_path, dest_dir="workspace")

            target_path = f"workspace/{path}" if path not in (".", "") else "workspace"
            try:
                results = (
                    pipeline(zs.vm)
                    .add(
                        build_detect_kind_command(target_path),
                        error_message="Failed to check path in sandbox",
                    )
                    .execute(timeout=self._LIST_TIMEOUT_SECONDS, raise_on_error=True)
                )
            except PipelineExecutionError as exc:
                raise ValueError(str(exc)) from exc

            kind = parse_kind_output(results[0].stdout, not_found_message="File not found in sandbox archive")

            sandbox_storage = SandboxFileService.get_storage()
            is_file = kind == "file"
            filename = (os.path.basename(path) or "file") if is_file else f"{export_name}.tar.gz"
            export_key = SandboxFilePaths.export(self._tenant_id, self._app_id, self._sandbox_id, export_id)
            upload_url = sandbox_storage.get_upload_url(export_key, self._EXPORT_EXPIRES_IN_SECONDS)
            content_type = guess_content_type(filename)

            # Build pipeline: for directories, tar first then upload; for files, upload directly
            archive_temp = f"/tmp/{export_id}.tar.gz"
            src_path = target_path if is_file else archive_temp
            tar_src = path if path not in (".", "") else "."

            try:
                (
                    pipeline(zs.vm)
                    .add(
                        ["tar", "-czf", archive_temp, "-C", "workspace", tar_src],
                        error_message="Failed to archive directory",
                        on=not is_file,
                    )
                    .add(
                        build_upload_command(src_path, upload_url, content_type=content_type),
                        error_message="Failed to upload file",
                    )
                    .add(["rm", "-f", archive_temp], on=not is_file)
                    .execute(timeout=self._UPLOAD_TIMEOUT_SECONDS, raise_on_error=True)
                )
            except PipelineExecutionError as exc:
                raise RuntimeError(str(exc)) from exc

        download_url = sandbox_storage.get_download_url(
            export_key, self._EXPORT_EXPIRES_IN_SECONDS, download_filename=filename
        )

        return SandboxFileDownloadTicket(
            download_url=download_url,
            expires_in=self._EXPORT_EXPIRES_IN_SECONDS,
            export_id=export_id,
        )

33 api/core/sandbox/inspector/base.py Normal file
@@ -0,0 +1,33 @@
from __future__ import annotations

import abc

from core.sandbox.entities.files import SandboxFileDownloadTicket, SandboxFileNode


class SandboxFileSource(abc.ABC):
    _LIST_TIMEOUT_SECONDS = 30
    _UPLOAD_TIMEOUT_SECONDS = 60 * 10
    _EXPORT_EXPIRES_IN_SECONDS = 60 * 10

    def __init__(self, *, tenant_id: str, app_id: str, sandbox_id: str):
        self._tenant_id = tenant_id
        self._app_id = app_id
        self._sandbox_id = sandbox_id

    @abc.abstractmethod
    def exists(self) -> bool:
        """Check if the sandbox source exists and is available.

        Returns:
            True if the sandbox source exists and can be accessed, False otherwise.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def list_files(self, *, path: str, recursive: bool) -> list[SandboxFileNode]:
        raise NotImplementedError

    @abc.abstractmethod
    def download_file(self, *, path: str) -> SandboxFileDownloadTicket:
        raise NotImplementedError

48 api/core/sandbox/inspector/browser.py Normal file
@@ -0,0 +1,48 @@
from __future__ import annotations

from pathlib import PurePosixPath

from core.sandbox.entities.files import SandboxFileDownloadTicket, SandboxFileNode
from core.sandbox.inspector.archive_source import SandboxFileArchiveSource
from core.sandbox.inspector.base import SandboxFileSource


class SandboxFileBrowser:
    def __init__(self, *, tenant_id: str, app_id: str, sandbox_id: str):
        self._tenant_id = tenant_id
        self._app_id = app_id
        self._sandbox_id = sandbox_id

    @staticmethod
    def _normalize_workspace_path(path: str | None) -> str:
        raw = (path or ".").strip()
        if raw == "":
            raw = "."

        p = PurePosixPath(raw)
        if p.is_absolute():
            raise ValueError("path must be relative")
        if any(part == ".." for part in p.parts):
            raise ValueError("path must not contain '..'")

        normalized = str(p)
        return "." if normalized in (".", "") else normalized

    def _backend(self) -> SandboxFileSource:
        return SandboxFileArchiveSource(
            tenant_id=self._tenant_id,
            app_id=self._app_id,
            sandbox_id=self._sandbox_id,
        )

    def exists(self) -> bool:
        """Check if the sandbox source exists and is available."""
        return self._backend().exists()

    def list_files(self, *, path: str | None = None, recursive: bool = False) -> list[SandboxFileNode]:
        workspace_path = self._normalize_workspace_path(path)
        return self._backend().list_files(path=workspace_path, recursive=recursive)

    def download_file(self, *, path: str) -> SandboxFileDownloadTicket:
        workspace_path = self._normalize_workspace_path(path)
        return self._backend().download_file(path=workspace_path)
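
SandboxFileBrowser is the single entry point callers need; path normalization and backend selection happen internally. A minimal usage sketch (the IDs are placeholders):

browser = SandboxFileBrowser(tenant_id="t-123", app_id="a-456", sandbox_id="s-789")
if browser.exists():
    for node in browser.list_files(path="src", recursive=True):
        print(node.path, node.is_dir, node.size)
    # Returns a presigned ticket rather than streaming bytes through the API.
    ticket = browser.download_file(path="src/main.py")
    print(ticket.download_url, ticket.expires_in)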

140 api/core/sandbox/inspector/runtime_source.py Normal file
@@ -0,0 +1,140 @@
from __future__ import annotations

import logging
import os
from uuid import uuid4

from core.sandbox.entities.files import SandboxFileDownloadTicket, SandboxFileNode
from core.sandbox.inspector.base import SandboxFileSource
from core.sandbox.inspector.script_utils import (
    build_detect_kind_command,
    build_list_command,
    build_upload_command,
    guess_content_type,
    parse_kind_output,
    parse_list_output,
)
from core.sandbox.storage import SandboxFilePaths
from core.virtual_environment.__base.exec import PipelineExecutionError
from core.virtual_environment.__base.helpers import pipeline
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment

logger = logging.getLogger(__name__)


class SandboxFileRuntimeSource(SandboxFileSource):
    def __init__(self, *, tenant_id: str, app_id: str, sandbox_id: str, runtime: VirtualEnvironment):
        super().__init__(tenant_id=tenant_id, app_id=app_id, sandbox_id=sandbox_id)
        self._runtime = runtime

    def exists(self) -> bool:
        """Check if the sandbox runtime exists and is available."""
        return self._runtime is not None

    def list_files(self, *, path: str, recursive: bool) -> list[SandboxFileNode]:
        try:
            results = (
                pipeline(self._runtime)
                .add(
                    build_list_command(path, recursive),
                    error_message="Failed to list sandbox files",
                )
                .execute(timeout=self._LIST_TIMEOUT_SECONDS, raise_on_error=True)
            )
        except PipelineExecutionError as exc:
            raise RuntimeError(str(exc)) from exc

        raw = parse_list_output(results[0].stdout)

        entries: list[SandboxFileNode] = []
        for item in raw:
            item_path = str(item.get("path"))
            item_is_dir = bool(item.get("is_dir"))
            extension = None
            if not item_is_dir:
                ext = os.path.splitext(item_path)[1]
                extension = ext or None
            entries.append(
                SandboxFileNode(
                    path=item_path,
                    is_dir=item_is_dir,
                    size=item.get("size"),
                    mtime=item.get("mtime"),
                    extension=extension,
                )
            )
        return entries

    def download_file(self, *, path: str) -> SandboxFileDownloadTicket:
        from services.sandbox.sandbox_file_service import SandboxFileService

        try:
            results = (
                pipeline(self._runtime)
                .add(
                    build_detect_kind_command(path),
                    error_message="Failed to check path in sandbox",
                )
                .execute(timeout=self._LIST_TIMEOUT_SECONDS, raise_on_error=True)
            )
        except PipelineExecutionError as exc:
            raise ValueError(str(exc)) from exc

        kind = parse_kind_output(results[0].stdout, not_found_message="File not found in sandbox")

        export_name = os.path.basename(path.rstrip("/")) or "workspace"
        filename = f"{export_name}.tar.gz" if kind == "dir" else (os.path.basename(path) or "file")
        export_id = uuid4().hex
        export_key = SandboxFilePaths.export(
            self._tenant_id,
            self._app_id,
            self._sandbox_id,
            export_id,
        )

        sandbox_storage = SandboxFileService.get_storage()
        upload_url = sandbox_storage.get_upload_url(export_key, self._EXPORT_EXPIRES_IN_SECONDS)
        content_type = guess_content_type(filename)

        if kind == "dir":
            archive_path = f"/tmp/{export_id}.tar.gz"
            try:
                (
                    pipeline(self._runtime)
                    .add(
                        ["tar", "-czf", archive_path, "-C", ".", path],
                        error_message="Failed to archive directory in sandbox",
                    )
                    .add(
                        build_upload_command(archive_path, upload_url, content_type=content_type),
                        error_message="Failed to upload directory archive from sandbox",
                    )
                    .execute(timeout=self._UPLOAD_TIMEOUT_SECONDS, raise_on_error=True)
                )
            except PipelineExecutionError as exc:
                raise RuntimeError(str(exc)) from exc
            finally:
                try:
                    pipeline(self._runtime).add(["rm", "-f", archive_path]).execute(timeout=self._LIST_TIMEOUT_SECONDS)
                except Exception as exc:
                    # Best-effort cleanup; do not fail the download on cleanup issues.
                    logger.debug("Failed to cleanup temp archive %s: %s", archive_path, exc)
        else:
            try:
                (
                    pipeline(self._runtime)
                    .add(
                        build_upload_command(path, upload_url, content_type=content_type),
                        error_message="Failed to upload file from sandbox",
                    )
                    .execute(timeout=self._UPLOAD_TIMEOUT_SECONDS, raise_on_error=True)
                )
            except PipelineExecutionError as exc:
                raise RuntimeError(str(exc)) from exc

        download_url = sandbox_storage.get_download_url(export_key, self._EXPORT_EXPIRES_IN_SECONDS)
        return SandboxFileDownloadTicket(
            download_url=download_url,
            expires_in=self._EXPORT_EXPIRES_IN_SECONDS,
            export_id=export_id,
        )

118 api/core/sandbox/inspector/script_utils.py Normal file
@@ -0,0 +1,118 @@
"""Shared helpers for sandbox inspector shell commands."""

from __future__ import annotations

import json
import mimetypes
from typing import TypedDict, cast

_PYTHON_EXEC_CMD = 'if command -v python3 >/dev/null 2>&1; then py=python3; else py=python; fi; "$py" -c "$0" "$@"'
_LIST_SCRIPT = r"""
import json
import os
import sys

path = sys.argv[1]
recursive = sys.argv[2] == "1"

def norm(rel: str) -> str:
    rel = rel.replace("\\\\", "/")
    rel = rel.lstrip("./")
    return rel or "."

def stat_entry(full_path: str, rel_path: str) -> dict[str, object]:
    st = os.stat(full_path)
    is_dir = os.path.isdir(full_path)
    return {
        "path": norm(rel_path),
        "is_dir": is_dir,
        "size": None if is_dir else int(st.st_size),
        "mtime": int(st.st_mtime),
    }

entries = []
if recursive:
    for root, dirs, files in os.walk(path):
        for d in dirs:
            fp = os.path.join(root, d)
            rp = os.path.relpath(fp, ".")
            entries.append(stat_entry(fp, rp))
        for f in files:
            fp = os.path.join(root, f)
            rp = os.path.relpath(fp, ".")
            entries.append(stat_entry(fp, rp))
else:
    if os.path.isfile(path):
        rel_path = os.path.relpath(path, ".")
        entries.append(stat_entry(path, rel_path))
    else:
        for item in os.scandir(path):
            rel_path = os.path.relpath(item.path, ".")
            entries.append(stat_entry(item.path, rel_path))

print(json.dumps(entries))
"""


class ListedEntry(TypedDict):
    path: str
    is_dir: bool
    size: int | None
    mtime: int


def build_list_command(path: str, recursive: bool) -> list[str]:
    return [
        "sh",
        "-c",
        _PYTHON_EXEC_CMD,
        _LIST_SCRIPT,
        path,
        "1" if recursive else "0",
    ]


def parse_list_output(stdout: bytes) -> list[ListedEntry]:
    try:
        raw = json.loads(stdout.decode("utf-8"))
    except Exception as exc:
        raise RuntimeError("Malformed sandbox file list output") from exc
    if not isinstance(raw, list):
        raise RuntimeError("Malformed sandbox file list output")
    return cast(list[ListedEntry], raw)


def build_detect_kind_command(path: str) -> list[str]:
    return [
        "sh",
        "-c",
        'if [ -d "$1" ]; then echo dir; elif [ -f "$1" ]; then echo file; else exit 2; fi',
        "sh",
        path,
    ]


def parse_kind_output(stdout: bytes, *, not_found_message: str) -> str:
    kind = stdout.decode("utf-8", errors="replace").strip()
    if kind not in ("dir", "file"):
        raise ValueError(not_found_message)
    return kind


def guess_content_type(filename: str) -> str | None:
    content_type, _ = mimetypes.guess_type(filename, strict=False)
    if content_type is None:
        return None
    if content_type.startswith("text/"):
        return f"{content_type}; charset=utf-8"
    if content_type == "application/json":
        return "application/json; charset=utf-8"
    return content_type


def build_upload_command(src_path: str, upload_url: str, *, content_type: str | None) -> list[str]:
    command = ["curl", "-s", "-f", "-X", "PUT", "-T", src_path]
    if content_type:
        command.extend(["-H", f"Content-Type: {content_type}"])
    command.append(upload_url)
    return command
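
Because these builders return plain argv lists, they can be sanity-checked on a POSIX host with subprocess before being wired into a VM pipeline. A quick local check (assumes sh and python3 exist on the host):

import subprocess

# Prints "dir" or "file"; exits with status 2 when the path is missing.
proc = subprocess.run(build_detect_kind_command("/tmp"), capture_output=True)
print(parse_kind_output(proc.stdout, not_found_message="missing"))  # "dir"

# Emits a single JSON array of {path, is_dir, size, mtime} entries.
proc = subprocess.run(build_list_command("/tmp", False), capture_output=True)
print(len(parse_list_output(proc.stdout)))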

133 api/core/sandbox/sandbox.py Normal file
@@ -0,0 +1,133 @@
from __future__ import annotations

import logging
import threading
from typing import TYPE_CHECKING
from uuid import uuid4

from libs.attr_map import AttrMap

if TYPE_CHECKING:
    from core.sandbox.storage.sandbox_storage import SandboxStorage
    from core.virtual_environment.__base.virtual_environment import VirtualEnvironment

logger = logging.getLogger(__name__)


class Sandbox:
    """Represents a single sandbox environment.

    Each ``Sandbox`` owns a stable, path-safe ``id`` (a 32-char hex
    UUID4) that is independent of the underlying provider's environment
    ID. Use ``sandbox.id`` for any path or resource namespacing
    (e.g. ``DifyCli(sandbox.id)``).

    The raw provider identifier is still accessible via
    ``sandbox.vm.metadata.id`` when needed (logging, API calls back to
    the provider, etc.).
    """

    def __init__(
        self,
        *,
        vm: VirtualEnvironment,
        storage: SandboxStorage,
        tenant_id: str,
        user_id: str,
        app_id: str,
        assets_id: str,
    ) -> None:
        self._id = uuid4().hex
        self._vm = vm
        self._storage = storage
        self._tenant_id = tenant_id
        self._user_id = user_id
        self._app_id = app_id
        self._assets_id = assets_id
        self._attributes = AttrMap()
        self._ready_event = threading.Event()
        self._cancel_event = threading.Event()
        self._init_error: Exception | None = None

    @property
    def id(self) -> str:
        """Stable, path-safe identifier for this sandbox (UUID4 hex)."""
        return self._id

    @property
    def attrs(self) -> AttrMap:
        return self._attributes

    @property
    def vm(self) -> VirtualEnvironment:
        return self._vm

    @property
    def storage(self) -> SandboxStorage:
        return self._storage

    @property
    def tenant_id(self) -> str:
        return self._tenant_id

    @property
    def user_id(self) -> str:
        return self._user_id

    @property
    def app_id(self) -> str:
        return self._app_id

    @property
    def assets_id(self) -> str:
        return self._assets_id

    def mark_ready(self) -> None:
        # Signal that sandbox initialization has completed successfully.
        self._ready_event.set()

    def mark_failed(self, error: Exception) -> None:
        # Capture initialization error and unblock waiters.
        self._init_error = error
        self._ready_event.set()

    def cancel_init(self) -> None:
        # Mark initialization as cancelled to stop background setup.
        self._cancel_event.set()
        self._ready_event.set()

    def is_cancelled(self) -> bool:
        return self._cancel_event.is_set()

    def wait_ready(self, timeout: float | None = None) -> None:
        # Block until initialization completes, fails, or is cancelled.
        if not self._ready_event.wait(timeout=timeout):
            raise TimeoutError("Sandbox initialization timed out")
        if self._cancel_event.is_set():
            raise RuntimeError("Sandbox initialization was cancelled")
        if self._init_error is not None:
            if isinstance(self._init_error, ValueError):
                raise RuntimeError(f"Sandbox initialization failed: {self._init_error}") from self._init_error
            else:
                raise RuntimeError("Sandbox initialization failed") from self._init_error

    def mount(self) -> bool:
        return self._storage.mount(self._vm)

    def unmount(self) -> bool:
        return self._storage.unmount(self._vm)

    def release(self) -> None:
        self.cancel_init()
        sandbox_id = self.id
        try:
            self._storage.unmount(self._vm)
            logger.info("Sandbox storage unmounted: sandbox_id=%s", sandbox_id)
        except Exception:
            logger.exception("Failed to unmount sandbox storage: sandbox_id=%s", sandbox_id)

        try:
            self._vm.release_environment()
            logger.info("Sandbox released: sandbox_id=%s", sandbox_id)
        except Exception:
            logger.exception("Failed to release sandbox: sandbox_id=%s", sandbox_id)
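
From the consumer side the contract is: wait for readiness, use the VM, always release. A short sketch (build_sandbox is a stand-in for whatever factory constructs the Sandbox; it is not an API from this commit):

sandbox = build_sandbox()  # hypothetical factory
try:
    # Raises TimeoutError on timeout, RuntimeError on failure or cancellation.
    sandbox.wait_ready(timeout=120)
    pipeline(sandbox.vm).add(["echo", "ready"]).execute(raise_on_error=True)
finally:
    # Best-effort: unmounts storage and releases the VM even if init failed.
    sandbox.release()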

3 api/core/sandbox/services/__init__.py Normal file
@@ -0,0 +1,3 @@
from .asset_download_service import AssetDownloadService

__all__ = ["AssetDownloadService"]

140 api/core/sandbox/services/asset_download_service.py Normal file
@@ -0,0 +1,140 @@
"""Shell script builder for downloading / writing assets into a sandbox VM.

Generates a self-contained POSIX shell script that handles two kinds of
``SandboxDownloadItem``:

- Items with *content* — written via base64 heredoc (sequential).
- Items with *url* — fetched via ``curl``/``wget``/``python3`` with
  auto-detection, run as parallel background jobs.

Both kinds can be mixed freely in a single call.
"""

from __future__ import annotations

import base64
import shlex
import textwrap
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from core.zip_sandbox.entities import SandboxDownloadItem


def _build_inline_commands(items: list[SandboxDownloadItem], root_var: str) -> str:
    """Generate shell commands that write base64-encoded content to files."""
    lines: list[str] = []
    for idx, item in enumerate(items):
        assert item.content is not None
        dest = f"${{{root_var}}}/{shlex.quote(item.path)}"
        encoded = base64.b64encode(item.content).decode("ascii")
        lines.append(f'mkdir -p "$(dirname "{dest}")"')
        lines.append(f"base64 -d <<'_INLINE_{idx}' > \"{dest}\"")
        lines.append(encoded)
        lines.append(f"_INLINE_{idx}")
    return "\n".join(lines)


def _render_download_script(
    root_path: str,
    inline_commands: str,
    download_commands: str,
    need_downloader: bool,
) -> str:
    python_download_cmd = (
        'python3 - "${url}" "${dest}" <<"PY"\n'
        "import sys\n"
        "import urllib.request\n"
        "url = sys.argv[1]\n"
        "dest = sys.argv[2]\n"
        "with urllib.request.urlopen(url) as resp:\n"
        "    data = resp.read()\n"
        'with open(dest, "wb") as f:\n'
        "    f.write(data)\n"
        "PY"
    )

    # Only emit the downloader-detection block when there are remote items.
    if need_downloader:
        downloader_block = f"""\
if command -v curl >/dev/null 2>&1; then
  download_cmd='curl -fsSL "${{url}}" -o "${{dest}}"'
elif command -v wget >/dev/null 2>&1; then
  download_cmd='wget -q "${{url}}" -O "${{dest}}"'
elif command -v python3 >/dev/null 2>&1; then
  download_cmd={shlex.quote(python_download_cmd)}
else
  echo 'No downloader found (curl/wget/python3)' >&2
  exit 1
fi

fail_log="$(mktemp)"

download_one() {{
  file_path="$1"
  url="$2"
  dest="${{download_root}}/${{file_path}}"
  mkdir -p "$(dirname "${{dest}}")"
  eval "${{download_cmd}}" 2>/dev/null || echo "${{file_path}}" >> "${{fail_log}}"
}}"""
    else:
        downloader_block = ""

    # The failure-check block is only meaningful when downloads occurred.
    if need_downloader:
        wait_block = textwrap.dedent("""\
            wait

            if [ -s "${fail_log}" ]; then
              mv "${fail_log}" "${download_root}/DOWNLOAD_FAILURES.txt"
            else
              rm -f "${fail_log}"
            fi""")
    else:
        wait_block = ""

    script = f"""\
download_root={shlex.quote(root_path)}
mkdir -p "${{download_root}}"

{downloader_block}

{inline_commands}

{download_commands}

{wait_block}
exit 0"""
    return script


class AssetDownloadService:
    @staticmethod
    def build_download_script(
        items: list[SandboxDownloadItem],
        root_path: str,
    ) -> str:
        """Build a portable shell script to write inline assets and download remote ones.

        Items with *content* are written first (sequential base64 decode),
        then items with *url* are fetched in parallel background jobs.
        The two kinds can be mixed freely in a single list.
        """
        inline = [item for item in items if item.content is not None]
        remote = [item for item in items if item.content is None]

        inline_commands = _build_inline_commands(inline, "download_root") if inline else ""

        commands: list[str] = []
        for item in remote:
            path = shlex.quote(item.path)
            url = shlex.quote(item.url)
            commands.append(f"download_one {path} {url} &")
        download_commands = "\n".join(commands)

        return _render_download_script(
            root_path,
            inline_commands,
            download_commands,
            need_downloader=bool(remote),
        )
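
To see the shape of the generated script, feed the builder one inline item and one remote item. The SandboxDownloadItem constructor shape below is an assumption inferred from how its fields are used above:

# Assumed shape: `path` plus either `content` (bytes) or `url`.
items = [
    SandboxDownloadItem(path="skills/readme.md", content=b"# compiled skill", url=None),
    SandboxDownloadItem(path="data/big.bin", content=None, url="https://example.com/big.bin"),
]
script = AssetDownloadService.build_download_script(items, "/workspace/assets")
# The script writes the inline file via a base64 heredoc, fetches the remote
# file as a background job, and records failed paths in DOWNLOAD_FAILURES.txt.
print(script)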

11 api/core/sandbox/storage/__init__.py Normal file
@@ -0,0 +1,11 @@
from .archive_storage import ArchiveSandboxStorage
from .noop_storage import NoopSandboxStorage
from .sandbox_file_storage import SandboxFilePaths
from .sandbox_storage import SandboxStorage

__all__ = [
    "ArchiveSandboxStorage",
    "NoopSandboxStorage",
    "SandboxFilePaths",
    "SandboxStorage",
]

83 api/core/sandbox/storage/archive_storage.py Normal file
@@ -0,0 +1,83 @@
"""Archive-based sandbox storage for persisting sandbox state."""

from __future__ import annotations

import logging

from core.virtual_environment.__base.helpers import pipeline
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
from extensions.storage.base_storage import BaseStorage
from extensions.storage.cached_presign_storage import CachedPresignStorage
from extensions.storage.file_presign_storage import FilePresignStorage

from .sandbox_file_storage import SandboxFilePaths
from .sandbox_storage import SandboxStorage

logger = logging.getLogger(__name__)

_ARCHIVE_TIMEOUT = 300  # 5 minutes


class ArchiveSandboxStorage(SandboxStorage):
    """Archive-based storage for sandbox workspace persistence."""

    def __init__(
        self,
        tenant_id: str,
        app_id: str,
        sandbox_id: str,
        storage: BaseStorage,
        exclude_patterns: list[str] | None = None,
    ):
        self._sandbox_id = sandbox_id
        self._exclude_patterns = exclude_patterns or []
        self._storage_key = SandboxFilePaths.archive(tenant_id, app_id, sandbox_id)
        self._storage = CachedPresignStorage(
            storage=FilePresignStorage(storage),
            cache_key_prefix="sandbox_archives",
        )

    def mount(self, sandbox: VirtualEnvironment) -> bool:
        """Load archive from storage into sandbox workspace."""
        if not self.exists():
            logger.debug("No archive found for sandbox %s, skipping mount", self._sandbox_id)
            return False

        download_url = self._storage.get_download_url(self._storage_key, _ARCHIVE_TIMEOUT)
        archive = "archive.tar.gz"

        (
            pipeline(sandbox)
            .add(["curl", "-fsSL", download_url, "-o", archive], error_message="Failed to download archive")
            .add(["sh", "-c", 'tar -xzf "$1" 2>/dev/null; exit $?', "sh", archive], error_message="Failed to extract")
            .add(["rm", archive], error_message="Failed to cleanup")
            .execute(timeout=_ARCHIVE_TIMEOUT, raise_on_error=True)
        )

        logger.info("Mounted archive for sandbox %s", self._sandbox_id)
        return True

    def unmount(self, sandbox: VirtualEnvironment) -> bool:
        """Save sandbox workspace to storage as archive."""
        upload_url = self._storage.get_upload_url(self._storage_key, _ARCHIVE_TIMEOUT)
        archive = f"/tmp/{self._sandbox_id}.tar.gz"
        exclude_args = [f"--exclude={p}" for p in self._exclude_patterns]

        (
            pipeline(sandbox)
            .add(["tar", "-czf", archive, *exclude_args, "-C", ".", "."], error_message="Failed to create archive")
            .add(["curl", "-sf", "-X", "PUT", "-T", archive, upload_url], error_message="Failed to upload archive")
            .execute(timeout=_ARCHIVE_TIMEOUT, raise_on_error=True)
        )
        logger.info("Unmounted archive for sandbox %s", self._sandbox_id)
        return True

    def exists(self) -> bool:
        return self._storage.exists(self._storage_key)

    def delete(self) -> None:
        try:
            self._storage.delete(self._storage_key)
            logger.info("Deleted archive for sandbox %s", self._sandbox_id)
        except Exception:
            logger.exception("Failed to delete archive for sandbox %s", self._sandbox_id)

18 api/core/sandbox/storage/noop_storage.py Normal file
@@ -0,0 +1,18 @@
from core.sandbox.storage.sandbox_storage import SandboxStorage
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment


class NoopSandboxStorage(SandboxStorage):
    """A no-op storage implementation that does nothing on mount/unmount."""

    def mount(self, sandbox: VirtualEnvironment) -> bool:
        return True

    def unmount(self, sandbox: VirtualEnvironment) -> bool:
        return True

    def exists(self) -> bool:
        return False

    def delete(self) -> None:
        return

21 api/core/sandbox/storage/sandbox_file_storage.py Normal file
@@ -0,0 +1,21 @@
"""Sandbox file storage key generation.

Provides the SandboxFilePaths facade for generating storage keys for sandbox files.
Storage instances are obtained via SandboxFileService.get_storage().
"""

from __future__ import annotations


class SandboxFilePaths:
    """Facade for generating sandbox file storage keys."""

    @staticmethod
    def export(tenant_id: str, app_id: str, sandbox_id: str, export_id: str) -> str:
        """sandbox_files/{tenant}/{app}/{sandbox}/{export_id}"""
        return f"sandbox_files/{tenant_id}/{app_id}/{sandbox_id}/{export_id}"

    @staticmethod
    def archive(tenant_id: str, app_id: str, sandbox_id: str) -> str:
        """sandbox_archives/{tenant}/{app}/{sandbox}.tar.gz"""
        return f"sandbox_archives/{tenant_id}/{app_id}/{sandbox_id}.tar.gz"

21 api/core/sandbox/storage/sandbox_storage.py Normal file
@@ -0,0 +1,21 @@
from abc import ABC, abstractmethod

from core.virtual_environment.__base.virtual_environment import VirtualEnvironment


class SandboxStorage(ABC):
    @abstractmethod
    def mount(self, sandbox: VirtualEnvironment) -> bool:
        """Load files from storage into VM. Returns True if files were loaded."""

    @abstractmethod
    def unmount(self, sandbox: VirtualEnvironment) -> bool:
        """Save files from VM to storage. Returns True if files were saved."""

    @abstractmethod
    def exists(self) -> bool:
        """Check if storage has saved data."""

    @abstractmethod
    def delete(self) -> None:
        """Delete saved data from storage."""

2 api/core/sandbox/utils/__init__.py Normal file
@@ -0,0 +1,2 @@
# Sandbox utilities
# Connection helpers have been moved to core.virtual_environment.helpers

22 api/core/sandbox/utils/debug.py Normal file
@@ -0,0 +1,22 @@
"""Sandbox debug utilities. TODO: Remove this module when sandbox debugging is complete."""

from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    pass

SANDBOX_DEBUG_ENABLED = True


def sandbox_debug(tag: str, message: str, data: Any = None) -> None:
    if not SANDBOX_DEBUG_ENABLED:
        return

    # Lazy import to avoid circular dependency
    from core.callback_handler.agent_tool_callback_handler import print_text

    print_text(f"\n[{tag}]\n", color="blue")
    if data is not None:
        print_text(f"{message}: {data}\n", color="blue")
    else:
        print_text(f"{message}\n", color="blue")

48 api/core/sandbox/utils/encryption.py Normal file
@@ -0,0 +1,48 @@
from collections.abc import Mapping
from typing import Any

from core.entities.provider_entities import BasicProviderConfig
from core.helper.provider_cache import ProviderCredentialsCache
from core.helper.provider_encryption import ProviderConfigCache, ProviderConfigEncrypter, create_provider_encrypter


class SandboxProviderConfigCache(ProviderCredentialsCache):
    def __init__(self, tenant_id: str, provider_type: str):
        super().__init__(tenant_id=tenant_id, provider_type=provider_type)

    def _generate_cache_key(self, **kwargs) -> str:
        tenant_id = kwargs["tenant_id"]
        provider_type = kwargs["provider_type"]
        return f"sandbox_config:tenant_id:{tenant_id}:provider_type:{provider_type}"


def create_sandbox_config_encrypter(
    tenant_id: str,
    config_schema: list[BasicProviderConfig],
    provider_type: str,
) -> tuple[ProviderConfigEncrypter, ProviderConfigCache]:
    cache = SandboxProviderConfigCache(tenant_id=tenant_id, provider_type=provider_type)
    return create_provider_encrypter(tenant_id=tenant_id, config=config_schema, cache=cache)


def masked_config(
    schemas: list[BasicProviderConfig],
    config: Mapping[str, Any],
) -> Mapping[str, Any]:
    masked = dict(config)
    configs = {x.name: x for x in schemas}
    for key, value in config.items():
        schema = configs.get(key)
        if not schema:
            masked[key] = value
            continue
        if schema.type == BasicProviderConfig.Type.SECRET_INPUT:
            if not isinstance(value, str):
                continue
            if len(value) <= 4:
                masked[key] = "*" * len(value)
            else:
                masked[key] = value[:2] + "*" * (len(value) - 4) + value[-2:]
        else:
            masked[key] = value
    return masked
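
The masking rule keeps the first and last two characters of any secret longer than four characters and stars the rest; secrets of four characters or fewer are fully starred. The slicing arithmetic, checked in isolation:

value = "sk-abcdef123"  # 12 characters
masked = value[:2] + "*" * (len(value) - 4) + value[-2:]
assert masked == "sk********23"  # 2 kept + 8 stars + 2 kept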

11 api/core/skill/__init__.py Normal file
@@ -0,0 +1,11 @@
from .constants import SkillAttrs
from .entities import ToolDependencies, ToolDependency, ToolReference
from .skill_manager import SkillManager

__all__ = [
    "SkillAttrs",
    "SkillManager",
    "ToolDependencies",
    "ToolDependency",
    "ToolReference",
]

6 api/core/skill/assembler/__init__.py Normal file
@@ -0,0 +1,6 @@
from core.skill.assembler.assemblers import SkillBundleAssembler, SkillDocumentAssembler

__all__ = [
    "SkillBundleAssembler",
    "SkillDocumentAssembler",
]

80 api/core/skill/assembler/assemblers.py Normal file
@@ -0,0 +1,80 @@
from collections.abc import Mapping

from core.app.entities.app_asset_entities import AppAssetFileTree
from core.skill.assembler.common import (
    build_skill_graph,
    compute_transitive_dependance,
    expand_referenced_skill_ids,
    get_metadata,
    process_skill_content,
)
from core.skill.entities.skill_bundle import Skill, SkillBundle, SkillDependance
from core.skill.entities.skill_document import SkillDocument


class SkillBundleAssembler:
    _file_tree: AppAssetFileTree

    def __init__(self, file_tree: AppAssetFileTree) -> None:
        self._file_tree = file_tree

    def assemble_bundle(
        self,
        documents: Mapping[str, SkillDocument],
        assets_id: str,
    ) -> SkillBundle:
        direct_skills: dict[str, Skill] = {}
        for skill_id, doc in documents.items():
            metadata = get_metadata(doc.content, doc.metadata)
            direct_dependance = SkillDependance.from_metadata(metadata)
            direct_skills[skill_id] = Skill(
                skill_id=skill_id,
                direct_dependance=direct_dependance,
                dependance=direct_dependance,
                content=process_skill_content(doc.content, metadata, self._file_tree, skill_id),
            )

        graph = build_skill_graph(direct_skills, self._file_tree)
        transitive_map = compute_transitive_dependance(direct_skills, graph)

        compiled_skills: dict[str, Skill] = {}
        for skill_id, skill in direct_skills.items():
            compiled_skills[skill_id] = skill.model_copy(update={"dependance": transitive_map[skill_id]})

        return SkillBundle(asset_tree=self._file_tree, assets_id=assets_id, skills=compiled_skills)


class SkillDocumentAssembler:
    _bundle: SkillBundle

    def __init__(self, bundle: SkillBundle) -> None:
        self._bundle = bundle

    def assemble_document(self, document: SkillDocument, base_path: str = "") -> Skill:
        metadata = get_metadata(document.content, document.metadata)
        direct_dependance = SkillDependance.from_metadata(metadata)
        resolved_content = process_skill_content(
            document.content,
            metadata,
            self._bundle.asset_tree,
            document.skill_id,
            base_path,
        )

        transitive_dependance = direct_dependance
        known_skill_ids = set(self._bundle.skills.keys())
        referenced_skill_ids = expand_referenced_skill_ids(
            direct_dependance.files, known_skill_ids, self._bundle.asset_tree
        )
        for skill_id in sorted(referenced_skill_ids):
            referenced_skill = self._bundle.get(skill_id)
            if referenced_skill is None:
                continue
            transitive_dependance = transitive_dependance | referenced_skill.dependance

        return Skill(
            skill_id=document.skill_id,
            direct_dependance=direct_dependance,
            dependance=transitive_dependance,
            content=resolved_content,
        )

136 api/core/skill/assembler/common.py Normal file
@@ -0,0 +1,136 @@
from collections import deque
from collections.abc import Mapping

from core.app.entities.app_asset_entities import AppAssetFileTree, AssetNodeType
from core.skill.assembler.replacers import (
    FILE_PATTERN,
    TOOL_METADATA_PATTERN,
    FileReplacer,
    Replacer,
    ToolGroupReplacer,
    ToolReplacer,
)
from core.skill.entities.skill_bundle import Skill, SkillDependance
from core.skill.entities.skill_metadata import FileReference, SkillMetadata, ToolReference


def process_skill_content(
    content: str,
    metadata: SkillMetadata,
    file_tree: AppAssetFileTree,
    current_id: str,
    base_path: str = "",
) -> str:
    """Resolve all placeholders in content through the ordered replacer pipeline."""
    replacers: list[Replacer] = [
        FileReplacer(file_tree, current_id, base_path),
        ToolGroupReplacer(metadata),
        ToolReplacer(metadata),
    ]
    for replacer in replacers:
        content = replacer.resolve(content)
    return content


def get_metadata(content: str, metadata: SkillMetadata) -> SkillMetadata:
    """Parse effective metadata from content placeholders and raw metadata."""
    tools: dict[str, ToolReference] = {}
    # find all tool refs actually used in content
    for match in TOOL_METADATA_PATTERN.finditer(content):
        provider, name, uuid = match.group(1), match.group(2), match.group(3)
        tool_ref = metadata.tools.get(uuid)
        if tool_ref is None:
            raise ValueError(f"Tool reference with UUID {uuid} not found in metadata")
        tool_ref.uuid = uuid
        tool_ref.tool_name = name
        tool_ref.provider = provider
        tools[uuid] = tool_ref

    # find all file refs
    files: set[FileReference] = set()
    for match in FILE_PATTERN.finditer(content):
        source, asset_id = match.group(1), match.group(2)
        files.add(FileReference(source=source, asset_id=asset_id))

    return SkillMetadata(tools=tools, files=files)


def build_skill_graph(skills: Mapping[str, Skill], file_tree: AppAssetFileTree) -> dict[str, set[str]]:
    """Build adjacency list: skill_id -> referenced skill IDs."""
    known_skill_ids = set(skills.keys())
    graph: dict[str, set[str]] = {skill_id: set() for skill_id in known_skill_ids}

    for skill_id, skill in skills.items():
        graph[skill_id] = expand_referenced_skill_ids(skill.direct_dependance.files, known_skill_ids, file_tree)

    return graph


def compute_transitive_dependance(
    skills: Mapping[str, Skill],
    graph: Mapping[str, set[str]],
) -> dict[str, SkillDependance]:
    """Compute transitive dependency closure with fixed-point iteration."""
    dependance_map = {skill_id: skill.direct_dependance for skill_id, skill in skills.items()}

    changed = True
    while changed:
        changed = False
        for skill_id in sorted(skills.keys()):
            merged = dependance_map[skill_id]
            for dep_skill_id in sorted(graph.get(skill_id, set())):
                if dep_skill_id == skill_id:
                    continue
                merged = merged | dependance_map[dep_skill_id]

            if merged != dependance_map[skill_id]:
                dependance_map[skill_id] = merged
                changed = True

    return dependance_map


def expand_referenced_skill_ids(
    refs: set[FileReference],
    known_skill_ids: set[str],
    file_tree: AppAssetFileTree,
) -> set[str]:
    """Resolve file/folder references to concrete known skill IDs."""
    resolved: set[str] = set()
    for ref in refs:
        node = file_tree.get(ref.asset_id)
        if node is None:
            continue

        if node.node_type == AssetNodeType.FILE:
            if node.id in known_skill_ids:
                resolved.add(node.id)
            continue

        descendant_ids = file_tree.get_descendant_ids(node.id)
        for descendant_id in descendant_ids:
            descendant = file_tree.get(descendant_id)
            if descendant is None or descendant.node_type != AssetNodeType.FILE:
                continue
            if descendant_id in known_skill_ids:
                resolved.add(descendant_id)

    return resolved


def collect_transitive_skill_ids(
    root_skill_ids: set[str],
    graph: Mapping[str, set[str]],
) -> set[str]:
    """Collect all transitively reachable skill IDs from roots via BFS."""
    visited: set[str] = set()
    queue = deque(sorted(root_skill_ids))
    while queue:
        current = queue.popleft()
        if current in visited:
            continue
        visited.add(current)
        for next_skill_id in sorted(graph.get(current, set())):
            if next_skill_id not in visited:
                queue.append(next_skill_id)
    return visited
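
The fixed-point loop in compute_transitive_dependance is easiest to see on a tiny chain. A self-contained sketch with plain sets standing in for SkillDependance (which, as above, only needs `|` and `==`):

# Chain A -> B -> C: A's closure must absorb C's entries via B.
direct = {"A": {"a"}, "B": {"b"}, "C": {"c"}}
graph = {"A": {"B"}, "B": {"C"}, "C": set()}

deps = dict(direct)
changed = True
while changed:
    changed = False
    for sid in sorted(deps):
        merged = deps[sid]
        for dep in sorted(graph[sid]):
            merged = merged | deps[dep]
        if merged != deps[sid]:
            deps[sid] = merged
            changed = True

assert deps["A"] == {"a", "b", "c"}  # the loop stops once a full pass changes nothing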
api/core/skill/assembler/replacers.py (new file, 108 lines)
@@ -0,0 +1,108 @@
"""Placeholder replacers for skill content.

Each replacer handles one category of ``§[...]§`` placeholder via the unified
``Replacer`` protocol. The shared ``resolve_content`` pipeline in
``core.skill.assembler.common`` builds a ``list[Replacer]`` and applies them
in order:

``FileReplacer`` → ``ToolGroupReplacer`` → ``ToolReplacer``

``ToolGroupReplacer`` MUST run before ``ToolReplacer`` so that group brackets
``[§[tool]...§, §[tool]...§]`` are resolved atomically; otherwise individual
tool replacement would destroy the group structure.
"""

import re
from typing import Protocol

from core.app.entities.app_asset_entities import AppAssetFileTree
from core.skill.entities.skill_metadata import SkillMetadata

TOOL_METADATA_PATTERN: re.Pattern[str] = re.compile(r"§\[tool\]\.\[([^\]]+)\]\.\[([^\]]+)\]\.\[([^\]]+)\]§")
TOOL_PATTERN: re.Pattern[str] = re.compile(r"§\[tool\]\.\[.*?\]\.\[.*?\]\.\[(.*?)\]§")
TOOL_GROUP_PATTERN: re.Pattern[str] = re.compile(
    r"\[\s*§\[tool\]\.\[[^\]]+\]\.\[[^\]]+\]\.\[[^\]]+\]§"
    r"(?:\s*,\s*§\[tool\]\.\[[^\]]+\]\.\[[^\]]+\]\.\[[^\]]+\]§)*\s*\]"
)
FILE_PATTERN: re.Pattern[str] = re.compile(r"§\[file\]\.\[([^\]]+)\]\.\[([^\]]+)\]§")


class Replacer(Protocol):
    def resolve(self, content: str) -> str: ...


class FileReplacer:
    _tree: AppAssetFileTree
    _current_id: str
    _base_path: str

    def __init__(self, tree: AppAssetFileTree, current_id: str, base_path: str = "") -> None:
        self._tree = tree
        self._current_id = current_id
        self._base_path = base_path.rstrip("/")

    def resolve(self, content: str) -> str:
        return FILE_PATTERN.sub(self._replace_match, content)

    def _replace_match(self, match: re.Match[str]) -> str:
        target_id = match.group(2)
        source_node = self._tree.get(self._current_id)
        target_node = self._tree.get(target_id)

        if target_node is None:
            return "[File not found]"

        if source_node is not None:
            return self._tree.relative_path(source_node, target_node)

        full_path = self._tree.get_path(target_node.id)
        if self._base_path:
            return f"{self._base_path}/{full_path}"
        return full_path


class ToolReplacer:
    _metadata: SkillMetadata

    def __init__(self, metadata: SkillMetadata) -> None:
        self._metadata = metadata

    def resolve(self, content: str) -> str:
        return TOOL_PATTERN.sub(self._replace_match, content)

    def _replace_match(self, match: re.Match[str]) -> str:
        tool_id = match.group(1)
        tool_ref = self._metadata.tools.get(tool_id)
        if tool_ref is None:
            return f"[Tool not found or disabled: {tool_id}]"
        if not tool_ref.enabled:
            return ""
        return f"[Executable: {tool_ref.tool_name}_{tool_ref.uuid} --help command]"


class ToolGroupReplacer:
    _metadata: SkillMetadata

    def __init__(self, metadata: SkillMetadata) -> None:
        self._metadata = metadata

    def resolve(self, content: str) -> str:
        return TOOL_GROUP_PATTERN.sub(self._replace_match, content)

    def _replace_match(self, match: re.Match[str]) -> str:
        group_text = match.group(0)
        enabled_renders: list[str] = []

        for tool_match in TOOL_PATTERN.finditer(group_text):
            tool_id = tool_match.group(1)
            tool_ref = self._metadata.tools.get(tool_id)
            if tool_ref is None:
                enabled_renders.append(f"[Tool not found or disabled: {tool_id}]")
                continue
            if not tool_ref.enabled:
                continue
            enabled_renders.append(f"[Executable: {tool_ref.tool_name}_{tool_ref.uuid} --help command]")

        if not enabled_renders:
            return ""
        return "[" + ", ".join(enabled_renders) + "]"
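A minimal sketch of the ordering constraint; the placeholder string and the metadata value below are invented for illustration (assume `metadata` is a SkillMetadata whose reference "t1" is enabled and "t2" is disabled):

content = "Tools: [§[tool].[p].[n].[t1]§, §[tool].[p].[n].[t2]§]"  # invented placeholder text

# Group-first replacement rewrites the whole bracketed group in one pass,
# dropping the disabled entry while keeping brackets and commas balanced;
# running ToolReplacer first would leave a dangling comma inside the group.
replacers: list[Replacer] = [ToolGroupReplacer(metadata), ToolReplacer(metadata)]
for replacer in replacers:
    content = replacer.resolve(content)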
api/core/skill/constants.py (new file, 6 lines)
@@ -0,0 +1,6 @@
from core.skill.entities.skill_bundle import SkillBundle
from libs.attr_map import AttrKey


class SkillAttrs:
    BUNDLE = AttrKey("skill_bundle", SkillBundle)
api/core/skill/entities/__init__.py (new file, 29 lines)
@@ -0,0 +1,29 @@
from .skill_bundle import Skill, SkillBundle, SkillDependance
from .skill_document import SkillDocument
from .skill_metadata import (
    FileReference,
    SkillMetadata,
    ToolConfiguration,
    ToolFieldConfig,
    ToolReference,
)
from .tool_access_policy import ToolAccessDescription, ToolAccessPolicy, ToolDescription, ToolInvocationRequest
from .tool_dependencies import ToolDependencies, ToolDependency

__all__ = [
    "FileReference",
    "Skill",
    "SkillBundle",
    "SkillDependance",
    "SkillDocument",
    "SkillMetadata",
    "ToolAccessDescription",
    "ToolAccessPolicy",
    "ToolConfiguration",
    "ToolDependencies",
    "ToolDependency",
    "ToolDescription",
    "ToolFieldConfig",
    "ToolInvocationRequest",
    "ToolReference",
]
api/core/skill/entities/api_entities.py (new file, 16 lines)
@@ -0,0 +1,16 @@
from pydantic import BaseModel, Field

from core.skill.entities.tool_dependencies import ToolDependency


class NodeSkillInfo(BaseModel):
    """Information about skills referenced by a workflow node.

    Used by the whole-workflow skills endpoint to return per-node
    tool dependency information.
    """

    node_id: str = Field(description="The node ID")
    tool_dependencies: list[ToolDependency] = Field(
        default_factory=list, description="Tool dependencies extracted from skill prompts"
    )
api/core/skill/entities/skill_bundle.py (new file, 94 lines)
@@ -0,0 +1,94 @@
from typing import TYPE_CHECKING

from pydantic import BaseModel, ConfigDict, Field

from core.app.entities.app_asset_entities import AppAssetFileTree
from core.skill.entities.skill_metadata import FileReference
from core.skill.entities.tool_dependencies import ToolDependencies, ToolDependency

if TYPE_CHECKING:
    from core.skill.entities.skill_metadata import SkillMetadata


class SkillDependance(BaseModel):
    model_config = ConfigDict(extra="forbid")

    tools: ToolDependencies = Field(description="Direct tool dependencies parsed from this skill only")

    files: set[FileReference] = Field(
        default_factory=set,
        description="Direct file references parsed from this skill only",
    )

    def __or__(self, other: "SkillDependance") -> "SkillDependance":
        return SkillDependance(tools=self.tools.merge(other.tools), files=self.files | other.files)

    @staticmethod
    def from_metadata(metadata: "SkillMetadata") -> "SkillDependance":
        """Convert parsed metadata into direct tool/file dependency model."""
        from core.skill.entities.skill_metadata import ToolReference

        dep_map: dict[str, ToolDependency] = {}
        ref_map: dict[str, ToolReference] = {}

        for tool_ref in metadata.tools.values():
            dep_map.setdefault(
                tool_ref.tool_id(),
                ToolDependency(
                    type=tool_ref.type,
                    provider=tool_ref.provider,
                    tool_name=tool_ref.tool_name,
                    enabled=tool_ref.enabled,
                ),
            )
            ref_map.setdefault(tool_ref.uuid, tool_ref)

        return SkillDependance(
            tools=ToolDependencies(
                dependencies=[dep_map[key] for key in sorted(dep_map.keys())],
                references=[ref_map[key] for key in sorted(ref_map.keys())],
            ),
            files=metadata.files,
        )


class Skill(BaseModel):
    model_config = ConfigDict(extra="forbid")

    skill_id: str = Field(description="Unique identifier for this skill; matches the skill asset's ID")

    direct_dependance: SkillDependance = Field(description="Direct dependencies parsed from this skill only")

    dependance: SkillDependance = Field(description="All dependencies including transitive closure")

    content: str = Field(description="Resolved content with all references replaced")

    @property
    def tools(self) -> ToolDependencies:
        return self.dependance.tools


class SkillBundle(BaseModel):
    model_config = ConfigDict(extra="forbid")

    asset_tree: AppAssetFileTree = Field(description="Asset tree for this bundle")

    assets_id: str = Field(description="Assets ID this bundle belongs to")

    skills: dict[str, Skill] = Field(default_factory=dict)

    @property
    def entries(self) -> dict[str, Skill]:
        return self.skills

    def get(self, skill_id: str) -> Skill | None:
        return self.skills.get(skill_id)

    def get_tool_dependencies(self) -> ToolDependencies:
        merged = ToolDependencies()
        for skill in self.skills.values():
            merged = merged.merge(skill.dependance.tools)
        return merged

    def put(self, skill: Skill) -> None:
        self.skills[skill.skill_id] = skill
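For orientation, a small sketch of how a bundle is assembled and queried; `tree`, `skill_a`, and `skill_b` are placeholders for an AppAssetFileTree and two compiled Skill values, not real data:

bundle = SkillBundle(asset_tree=tree, assets_id="assets-1")  # hypothetical values
bundle.put(skill_a)
bundle.put(skill_b)

# Union of every skill's transitive tool dependencies, de-duplicated by
# "provider.tool_name" for dependencies and by uuid for references.
merged = bundle.get_tool_dependencies()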
api/core/skill/entities/skill_document.py (new file, 17 lines)
@@ -0,0 +1,17 @@
from pydantic import BaseModel, ConfigDict, Field

from core.skill.entities.skill_metadata import SkillMetadata


class SkillFile(BaseModel):
    model_config = ConfigDict(extra="forbid")


class SkillDocument(BaseModel):
    """Input document for skill compilation."""

    model_config = ConfigDict(extra="forbid")

    skill_id: str = Field(description="Unique identifier, must match SkillAsset.asset_id")
    content: str = Field(description="Raw content with reference placeholders")
    metadata: SkillMetadata = Field(default_factory=SkillMetadata, description="Additional metadata for this skill")
api/core/skill/entities/skill_metadata.py (new file, 113 lines)
@@ -0,0 +1,113 @@
from typing import Any

from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator

from core.tools.entities.tool_entities import ToolProviderType


class ToolFieldConfig(BaseModel):
    model_config = ConfigDict(extra="forbid")

    id: str
    value: Any
    auto: bool = False


class ToolConfiguration(BaseModel):
    model_config = ConfigDict(extra="forbid")

    fields: list[ToolFieldConfig] = Field(
        default_factory=list, description="List of field configurations for this tool"
    )

    def default_values(self) -> dict[str, Any]:
        return {field.id: field.value for field in self.fields if field.value is not None}


def create_tool_id(provider: str, tool_name: str) -> str:
    return f"{provider}.{tool_name}"


class ToolReference(BaseModel):
    model_config = ConfigDict(extra="forbid")

    uuid: str = Field(
        default="",
        description=(
            "Unique identifier for this tool reference, used to distinguish multiple references to the same tool"
        ),
    )
    type: ToolProviderType = Field(description="The provider type of the tool")
    provider: str = Field(
        default="",
        description="The provider name of the tool plugin. Can be inferred from placeholders during compilation.",
    )
    tool_name: str = Field(
        default="",
        description=(
            "The tool name defined in the provider plugin. Can be inferred from placeholders during compilation."
        ),
    )
    enabled: bool = Field(default=True, description="Whether this tool reference is enabled")
    credential_id: str | None = Field(
        default=None,
        description="Credential ID used to resolve credentials when invoking the tool.",
    )
    configuration: ToolConfiguration | None = Field(
        default=None,
        description=(
            "Optional configuration for this tool reference, used to provide "
            "additional parameters when invoking the tool"
        ),
    )

    def reference_id(self) -> str:
        return f"{self.provider}.{self.tool_name}.{self.uuid}"

    def tool_id(self) -> str:
        return f"{self.provider}.{self.tool_name}"


class FileReference(BaseModel):
    model_config = ConfigDict(frozen=True)

    source: str = Field(default="app")
    asset_id: str

    @model_validator(mode="before")
    @classmethod
    def normalize_input(cls, data: Any) -> Any:
        if not isinstance(data, dict):
            return data
        if "asset_id" in data and "source" in data:
            return {"source": data.get("source", "app"), "asset_id": data["asset_id"]}
        # front-end payloads send {"id": ...}
        if "id" in data:
            return {"source": "app", "asset_id": data["id"]}
        return data


class SkillMetadata(BaseModel):
    model_config = ConfigDict(extra="forbid")

    tools: dict[str, ToolReference] = Field(default_factory=dict)
    files: set[FileReference] = Field(default_factory=set)

    @field_validator("files", mode="before")
    @classmethod
    def coerce_files_to_set(cls, v: Any) -> set[FileReference] | Any:
        if isinstance(v, list):
            refs: set[FileReference] = set()
            for item in v:
                if isinstance(item, dict):
                    refs.add(FileReference.model_validate(item))
                elif isinstance(item, FileReference):
                    refs.add(item)
            return refs
        if isinstance(v, dict):
            refs = set()
            for item in v.values():
                if isinstance(item, dict):
                    refs.add(FileReference.model_validate(item))
            return refs
        return v
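The `normalize_input` validator accepts both the canonical shape and the bare front-end shape; a small check (the asset ID is invented):

canonical = FileReference.model_validate({"source": "app", "asset_id": "asset-123"})
frontend = FileReference.model_validate({"id": "asset-123"})  # front-end payload shape
assert canonical == frontend  # frozen model: hashable, compared by value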
api/core/skill/entities/tool_access_policy.py (new file, 145 lines)
@@ -0,0 +1,145 @@
from collections.abc import Mapping

from pydantic import BaseModel, ConfigDict, Field

from core.skill.entities.tool_dependencies import ToolDependencies
from core.tools.entities.tool_entities import ToolProviderType


class ToolDescription(BaseModel):
    """Immutable identifier for a tool (type + provider + name)."""

    model_config = ConfigDict(frozen=True)

    tool_type: ToolProviderType
    provider: str
    tool_name: str

    def tool_id(self) -> str:
        return f"{self.tool_type.value}:{self.provider}:{self.tool_name}"


class ToolAccessDescription(BaseModel):
    """
    Per-tool access descriptor that bundles identity with allowed credentials.

    Each allowed tool is represented by exactly one ``ToolAccessDescription``.
    ``allowed_credentials`` captures the set of credential IDs that may be used
    when invoking this tool:

    * **empty set** – the tool requires no special credential; only requests
      *without* a ``credential_id`` are accepted.
    * **non-empty set** – a supplied ``credential_id`` must be a member of this
      set; requests without a credential fall back to default/ambient credentials.
    """

    model_config = ConfigDict(frozen=True)

    tool_type: ToolProviderType
    provider: str
    tool_name: str
    allowed_credentials: frozenset[str] = Field(default_factory=frozenset)

    def tool_id(self) -> str:
        return f"{self.tool_type.value}:{self.provider}:{self.tool_name}"

    def is_credential_allowed(self, credential_id: str | None) -> bool:
        """Check whether *credential_id* satisfies this tool's credential policy.

        * No credential supplied → allowed; the invocation falls back to
          default/ambient credentials.
        * Credential supplied → it must be a member of ``allowed_credentials``.
        """
        if credential_id is None or credential_id == "":
            return True

        return credential_id in self.allowed_credentials


class ToolInvocationRequest(BaseModel):
    """A request to invoke a specific tool with optional credential."""

    model_config = ConfigDict(frozen=True)

    tool_type: ToolProviderType
    provider: str
    tool_name: str
    credential_id: str | None = None

    @property
    def tool_description(self) -> ToolDescription:
        return ToolDescription(tool_type=self.tool_type, provider=self.provider, tool_name=self.tool_name)


class ToolAccessPolicy(BaseModel):
    """
    Determines whether a tool invocation is allowed based on ToolDependencies.

    The policy is built exclusively from ``ToolDependencies.references`` – each
    ``ToolReference`` declares both the tool identity *and* the credential that
    may be used. ``ToolDependencies.dependencies`` is a de-duplicated identity
    list and does not participate in access-control decisions.

    Rules:
    1. The tool must appear in at least one reference.
    2. If the request supplies a ``credential_id``, it must be one of the exact
       IDs carried by that tool's references; requests without a credential
       fall back to default/ambient credentials.
    3. If no reference for the tool carries a credential ID, the request must
       *not* supply one (use default/ambient credentials).
    """

    model_config = ConfigDict(frozen=True)

    access_map: Mapping[str, ToolAccessDescription] = Field(default_factory=dict)

    @classmethod
    def from_dependencies(cls, deps: ToolDependencies | None) -> "ToolAccessPolicy":
        """Build a policy from ``ToolDependencies``.

        Only ``deps.references`` are used. Multiple references to the same
        tool are merged – their credential IDs are unioned into a single
        ``ToolAccessDescription.allowed_credentials`` set.
        """
        if deps is None or deps.is_empty():
            return cls()

        # Accumulate credential sets keyed by tool_id so that multiple
        # references to the same tool are merged correctly.
        credentials_by_tool: dict[str, set[str]] = {}
        first_seen: dict[str, tuple[ToolProviderType, str, str]] = {}

        for ref in deps.references:
            tool_id = f"{ref.type.value}:{ref.provider}:{ref.tool_name}"
            if tool_id not in first_seen:
                first_seen[tool_id] = (ref.type, ref.provider, ref.tool_name)
                credentials_by_tool[tool_id] = set()
            if ref.credential_id is not None:
                credentials_by_tool[tool_id].add(ref.credential_id)

        access_map: dict[str, ToolAccessDescription] = {}
        for tool_id, (tool_type, provider, tool_name) in first_seen.items():
            access_map[tool_id] = ToolAccessDescription(
                tool_type=tool_type,
                provider=provider,
                tool_name=tool_name,
                allowed_credentials=frozenset(credentials_by_tool[tool_id]),
            )

        return cls(access_map=access_map)

    def is_empty(self) -> bool:
        return len(self.access_map) == 0

    def is_allowed(self, request: ToolInvocationRequest) -> bool:
        """Check if the tool invocation request is allowed."""
        # An empty policy (no references declared) permits any invocation.
        if self.is_empty():
            return True

        tool_id = request.tool_description.tool_id()
        access_desc = self.access_map.get(tool_id)
        if access_desc is None:
            return False

        return access_desc.is_credential_allowed(request.credential_id)
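A sketch of the credential rules in practice; the provider name, tool name, and credential IDs are made up, and `ToolProviderType.BUILT_IN` is used here as a representative enum member:

from core.skill.entities.skill_metadata import ToolReference
from core.skill.entities.tool_access_policy import ToolAccessPolicy, ToolInvocationRequest
from core.skill.entities.tool_dependencies import ToolDependencies
from core.tools.entities.tool_entities import ToolProviderType

deps = ToolDependencies(
    references=[
        ToolReference(uuid="r1", type=ToolProviderType.BUILT_IN, provider="p", tool_name="t", credential_id="cred-1"),
    ]
)
policy = ToolAccessPolicy.from_dependencies(deps)

matching = ToolInvocationRequest(tool_type=ToolProviderType.BUILT_IN, provider="p", tool_name="t", credential_id="cred-1")
ambient = ToolInvocationRequest(tool_type=ToolProviderType.BUILT_IN, provider="p", tool_name="t")
unknown = ToolInvocationRequest(tool_type=ToolProviderType.BUILT_IN, provider="p", tool_name="t", credential_id="cred-9")

assert policy.is_allowed(matching)     # exact credential match
assert policy.is_allowed(ambient)      # no credential supplied -> ambient credentials allowed
assert not policy.is_allowed(unknown)  # unknown credential id is rejected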
api/core/skill/entities/tool_dependencies.py (new file, 78 lines)
@@ -0,0 +1,78 @@
from pydantic import BaseModel, ConfigDict, Field

from core.skill.entities.skill_metadata import ToolReference
from core.tools.entities.tool_entities import ToolProviderType


class ToolDependency(BaseModel):
    model_config = ConfigDict(extra="forbid")

    type: ToolProviderType
    provider: str
    tool_name: str
    enabled: bool = True

    def tool_id(self) -> str:
        return f"{self.provider}.{self.tool_name}"


class ToolDependencies(BaseModel):
    model_config = ConfigDict(extra="forbid")

    dependencies: list[ToolDependency] = Field(default_factory=list)
    references: list[ToolReference] = Field(default_factory=list)

    def is_empty(self) -> bool:
        return not self.dependencies and not self.references

    def filter(self, tools: list[tuple[str, str]]) -> "ToolDependencies":
        tool_names = {f"{provider}.{tool_name}" for provider, tool_name in tools}
        return ToolDependencies(
            dependencies=[
                dependency
                for dependency in self.dependencies
                if f"{dependency.provider}.{dependency.tool_name}" in tool_names
            ],
            references=[
                reference
                for reference in self.references
                if f"{reference.provider}.{reference.tool_name}" in tool_names
            ],
        )

    def merge(self, other: "ToolDependencies") -> "ToolDependencies":
        dep_map: dict[str, ToolDependency] = {}
        for dep in self.dependencies:
            key = f"{dep.provider}.{dep.tool_name}"
            dep_map[key] = dep
        for dep in other.dependencies:
            key = f"{dep.provider}.{dep.tool_name}"
            if key not in dep_map:
                dep_map[key] = dep

        ref_map: dict[str, ToolReference] = {}
        for ref in self.references:
            ref_map[ref.uuid] = ref
        for ref in other.references:
            if ref.uuid not in ref_map:
                ref_map[ref.uuid] = ref

        return ToolDependencies(
            dependencies=list(dep_map.values()),
            references=list(ref_map.values()),
        )

    def remove_tools(self, tools: list[ToolDependency]) -> "ToolDependencies":
        tool_keys = {f"{tool.provider}.{tool.tool_name}" for tool in tools}
        return ToolDependencies(
            dependencies=[
                dependency
                for dependency in self.dependencies
                if f"{dependency.provider}.{dependency.tool_name}" not in tool_keys
            ],
            references=[
                reference
                for reference in self.references
                if f"{reference.provider}.{reference.tool_name}" not in tool_keys
            ],
        )
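The merge is left-biased: when both sides declare the same tool key, the left-hand entry is kept. A small sketch with invented provider/tool names, again using `ToolProviderType.BUILT_IN` as a stand-in member:

from core.skill.entities.tool_dependencies import ToolDependencies, ToolDependency
from core.tools.entities.tool_entities import ToolProviderType

a = ToolDependencies(dependencies=[ToolDependency(type=ToolProviderType.BUILT_IN, provider="p", tool_name="t", enabled=True)])
b = ToolDependencies(dependencies=[ToolDependency(type=ToolProviderType.BUILT_IN, provider="p", tool_name="t", enabled=False)])

merged = a.merge(b)
assert len(merged.dependencies) == 1
assert merged.dependencies[0].enabled  # key "p.t" collides; the left side wins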
api/core/skill/skill_manager.py (new file, 43 lines)
@@ -0,0 +1,43 @@
import logging

from core.app_assets.storage import AssetPaths
from core.skill.entities.skill_bundle import SkillBundle
from extensions.ext_redis import redis_client
from services.app_asset_service import AppAssetService

logger = logging.getLogger(__name__)

_CACHE_PREFIX = "skill_bundle"
_CACHE_TTL = 86400  # 24 hours


class SkillManager:
    @staticmethod
    def load_bundle(tenant_id: str, app_id: str, assets_id: str) -> SkillBundle:
        cache_key = f"{_CACHE_PREFIX}:{tenant_id}:{app_id}:{assets_id}"
        data = redis_client.get(cache_key)
        if data:
            return SkillBundle.model_validate_json(data)

        key = AssetPaths.skill_bundle(tenant_id, app_id, assets_id)
        try:
            data = AppAssetService.get_storage().load_once(key)
        except FileNotFoundError:
            logger.exception(
                "Skill bundle not found in storage: key=%s, tenant_id=%s, app_id=%s, assets_id=%s",
                key,
                tenant_id,
                app_id,
                assets_id,
            )
            raise
        bundle = SkillBundle.model_validate_json(data)
        redis_client.setex(cache_key, _CACHE_TTL, bundle.model_dump_json(indent=2).encode("utf-8"))
        return bundle

    @staticmethod
    def save_bundle(tenant_id: str, app_id: str, assets_id: str, bundle: SkillBundle) -> None:
        key = AssetPaths.skill_bundle(tenant_id, app_id, assets_id)
        AppAssetService.get_storage().save(key, data=bundle.model_dump_json(indent=2).encode("utf-8"))
        cache_key = f"{_CACHE_PREFIX}:{tenant_id}:{app_id}:{assets_id}"
        redis_client.delete(cache_key)
api/core/virtual_environment/__base/__init__.py (new file, empty)
api/core/virtual_environment/__base/command_future.py (new file, 208 lines)
@@ -0,0 +1,208 @@
import contextlib
import logging
import threading
import time
from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor

from core.virtual_environment.__base.entities import CommandResult, CommandStatus
from core.virtual_environment.__base.exec import NotSupportedOperationError
from core.virtual_environment.channel.exec import TransportEOFError
from core.virtual_environment.channel.transport import TransportReadCloser, TransportWriteCloser

logger = logging.getLogger(__name__)


class CommandTimeoutError(Exception):
    pass


class CommandCancelledError(Exception):
    pass


class CommandFuture:
    """
    Lightweight future for command execution.
    Mirrors concurrent.futures.Future API with 4 essential methods:
    result(), done(), cancel(), cancelled().

    When a command is cancelled or times out, the future asks the provider
    to terminate the underlying process/session before marking itself done.
    """

    def __init__(
        self,
        pid: str,
        stdin_transport: TransportWriteCloser,
        stdout_transport: TransportReadCloser,
        stderr_transport: TransportReadCloser,
        poll_status: Callable[[], CommandStatus],
        terminate_command: Callable[[], bool] | None = None,
        poll_interval: float = 0.1,
    ):
        self._pid = pid
        self._stdin_transport = stdin_transport
        self._stdout_transport = stdout_transport
        self._stderr_transport = stderr_transport
        self._poll_status = poll_status
        self._terminate_command = terminate_command
        self._poll_interval = poll_interval

        self._done_event = threading.Event()
        self._lock = threading.Lock()
        self._result: CommandResult | None = None
        self._exception: BaseException | None = None
        self._cancelled = False
        self._timed_out = False
        self._started = False
        self._termination_requested = False

    def result(self, timeout: float | None = None) -> CommandResult:
        """
        Block until command completes and return result.

        Args:
            timeout: Maximum seconds to wait. None means wait forever.

        Raises:
            CommandTimeoutError: If timeout exceeded.
            CommandCancelledError: If command was cancelled.

        A timeout is terminal for this future: it triggers best-effort command
        termination and subsequent ``result()`` calls keep raising timeout.
        """
        self._ensure_started()

        if not self._done_event.wait(timeout):
            self._request_stop(timed_out=True)
            raise CommandTimeoutError(f"Command timed out after {timeout}s")

        if self._cancelled:
            raise CommandCancelledError("Command was cancelled")

        if self._timed_out:
            raise CommandTimeoutError("Command timed out")

        if self._exception is not None:
            raise self._exception

        assert self._result is not None
        return self._result

    def done(self) -> bool:
        self._ensure_started()
        return self._done_event.is_set()

    def cancel(self) -> bool:
        """
        Attempt to cancel command by terminating it and closing transports.
        Returns True if cancelled, False if already completed.
        """
        return self._request_stop(cancelled=True)

    def cancelled(self) -> bool:
        return self._cancelled

    def _ensure_started(self) -> None:
        with self._lock:
            if not self._started:
                self._started = True
                thread = threading.Thread(target=self._execute, daemon=True)
                thread.start()

    def _request_stop(self, *, cancelled: bool = False, timed_out: bool = False) -> bool:
        should_terminate = False
        with self._lock:
            if self._done_event.is_set():
                return False

            if cancelled:
                self._cancelled = True
            if timed_out:
                self._timed_out = True

            should_terminate = not self._termination_requested
            if should_terminate:
                self._termination_requested = True

            self._close_transports()
            self._done_event.set()

        if should_terminate:
            self._terminate_running_command()
        return True

    def _execute(self) -> None:
        stdout_buf = bytearray()
        stderr_buf = bytearray()
        is_combined_stream = self._stdout_transport is self._stderr_transport

        try:
            with ThreadPoolExecutor(max_workers=2) as executor:
                stdout_future = executor.submit(self._drain_transport, self._stdout_transport, stdout_buf)
                stderr_future = None
                if not is_combined_stream:
                    stderr_future = executor.submit(self._drain_transport, self._stderr_transport, stderr_buf)

                exit_code = self._wait_for_completion()

                stdout_future.result()
                if stderr_future:
                    stderr_future.result()

            with self._lock:
                if not self._cancelled:
                    self._result = CommandResult(
                        stdout=bytes(stdout_buf),
                        stderr=b"" if is_combined_stream else bytes(stderr_buf),
                        exit_code=exit_code,
                        pid=self._pid,
                    )
                self._done_event.set()

        except Exception as e:
            logger.exception("Command execution failed for pid %s", self._pid)
            with self._lock:
                if not self._cancelled:
                    self._exception = e
                self._done_event.set()
        finally:
            self._close_transports()

    def _wait_for_completion(self) -> int | None:
        while not self._cancelled and not self._timed_out:
            try:
                status = self._poll_status()
            except NotSupportedOperationError:
                return None

            if status.status == CommandStatus.Status.COMPLETED:
                return status.exit_code

            time.sleep(self._poll_interval)

        return None

    def _drain_transport(self, transport: TransportReadCloser, buffer: bytearray) -> None:
        try:
            while True:
                buffer.extend(transport.read(4096))
        except TransportEOFError:
            pass
        except Exception:
            logger.exception("Failed reading transport")

    def _close_transports(self) -> None:
        for transport in (self._stdin_transport, self._stdout_transport, self._stderr_transport):
            with contextlib.suppress(Exception):
                transport.close()

    def _terminate_running_command(self) -> None:
        if self._terminate_command is None:
            return

        try:
            self._terminate_command()
        except Exception:
            logger.exception("Failed to terminate command for pid %s", self._pid)
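A consumer-side sketch of the terminal-timeout semantics; `future` here stands for a value returned by the `submit_command` helper defined later in helpers.py, so this is illustrative rather than self-starting:

from core.virtual_environment.__base.command_future import CommandTimeoutError

try:
    result = future.result(timeout=5)  # first call lazily starts the drain thread
except CommandTimeoutError:
    # Timeout is terminal: transports are closed, termination is requested
    # exactly once, and every later result() call raises CommandTimeoutError.
    assert future.done() and not future.cancelled()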
api/core/virtual_environment/__base/entities.py (new file, 100 lines)
@@ -0,0 +1,100 @@
from collections.abc import Mapping
from enum import StrEnum
from typing import Any

from pydantic import BaseModel, Field


class Arch(StrEnum):
    """
    Architecture types for virtual environments.
    """

    ARM64 = "arm64"
    AMD64 = "amd64"


class OperatingSystem(StrEnum):
    """
    Operating system types for virtual environments.
    """

    LINUX = "linux"
    DARWIN = "darwin"


class Metadata(BaseModel):
    """
    Returned metadata about a virtual environment.
    """

    id: str = Field(description="The unique identifier of the virtual environment.")
    arch: Arch = Field(description="Which architecture was used to create the virtual environment.")
    os: OperatingSystem = Field(description="The operating system of the virtual environment.")
    store: Mapping[str, Any] = Field(
        default_factory=dict, description="Additional provider-specific data stored with the virtual environment."
    )


class ConnectionHandle(BaseModel):
    """
    Handle for managing connections to the virtual environment.
    """

    id: str = Field(description="The unique identifier of the connection handle.")


class CommandStatus(BaseModel):
    """
    Status of a command executed in the virtual environment.
    """

    class Status(StrEnum):
        RUNNING = "running"
        COMPLETED = "completed"

    status: Status = Field(description="The status of the command execution.")
    exit_code: int | None = Field(description="The return code of the command execution.")


class FileState(BaseModel):
    """
    State of a file in the virtual environment.
    """

    size: int = Field(description="The size of the file in bytes.")
    path: str = Field(description="The path of the file in the virtual environment.")
    created_at: int = Field(description="The creation timestamp of the file.")
    updated_at: int = Field(description="The last modified timestamp of the file.")


class CommandResult(BaseModel):
    """
    Result of a synchronous command execution.
    """

    stdout: bytes = Field(description="Standard output content.")
    stderr: bytes = Field(description="Standard error content.")
    exit_code: int | None = Field(description="Exit code of the command. None if unavailable.")
    pid: str = Field(description="Process ID of the executed command.")

    @property
    def is_error(self) -> bool:
        return self.exit_code not in (None, 0) or bool(self.stderr.decode("utf-8", errors="replace"))

    @property
    def error_message(self) -> str:
        return self.stderr.decode("utf-8", errors="replace") if self.stderr else ""

    @property
    def info_message(self) -> str:
        return self.stdout.decode("utf-8", errors="replace") if self.stdout else ""

    @property
    def debug_message(self) -> str:
        return (
            f"stdout: {self.stdout.decode('utf-8', errors='replace')}\n"
            f"stderr: {self.stderr.decode('utf-8', errors='replace')}\n"
            f"exit_code: {self.exit_code}\n"
            f"pid: {self.pid}"
        )
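Note that `is_error` treats any stderr output as a failure even when the exit code is 0. A small check of the properties (values invented):

from core.virtual_environment.__base.entities import CommandResult

ok = CommandResult(stdout=b"ok\n", stderr=b"", exit_code=0, pid="42")
assert not ok.is_error and ok.info_message == "ok\n"

warned = CommandResult(stdout=b"", stderr=b"warning\n", exit_code=0, pid="43")
assert warned.is_error  # non-empty stderr flips is_error even with exit code 0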
api/core/virtual_environment/__base/exec.py (new file, 64 lines)
@@ -0,0 +1,64 @@
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from core.virtual_environment.__base.entities import CommandResult


class ArchNotSupportedError(Exception):
    """Exception raised when the architecture is not supported."""

    pass


class VirtualEnvironmentLaunchFailedError(Exception):
    """Exception raised when launching the virtual environment fails."""

    pass


class NotSupportedOperationError(Exception):
    """Exception raised when an operation is not supported."""

    pass


class SandboxConfigValidationError(ValueError):
    """Exception raised when sandbox configuration validation fails."""

    pass


class CommandExecutionError(ValueError):
    """Raised when a command execution fails."""

    result: CommandResult

    def __init__(self, message: str, result: CommandResult):
        super().__init__(message)
        self.result = result

    @property
    def exit_code(self) -> int | None:
        return self.result.exit_code

    @property
    def stderr(self) -> bytes:
        return self.result.stderr


class PipelineExecutionError(CommandExecutionError):
    """Raised when a pipeline command fails in strict mode."""

    index: int
    command: list[str]
    results: list[CommandResult]

    def __init__(
        self, message: str, result: CommandResult, *, index: int, command: list[str], results: list[CommandResult]
    ):
        super().__init__(message, result)
        self.index = index
        self.command = command
        self.results = results
api/core/virtual_environment/__base/helpers.py (new file, 279 lines)
@@ -0,0 +1,279 @@
from __future__ import annotations

import contextlib
import shlex
from collections.abc import Generator, Mapping
from contextlib import contextmanager
from dataclasses import dataclass, field
from functools import partial

from core.virtual_environment.__base.command_future import CommandFuture
from core.virtual_environment.__base.entities import CommandResult, ConnectionHandle
from core.virtual_environment.__base.exec import CommandExecutionError, PipelineExecutionError
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment

_PIPE_SENTINEL = "__DIFY_PIPE__"


@contextmanager
def with_connection(env: VirtualEnvironment) -> Generator[ConnectionHandle, None, None]:
    """Context manager for VirtualEnvironment connection lifecycle.

    Automatically establishes and releases connection handles.

    Usage:
        with with_connection(env) as conn:
            future = submit_command(env, conn, ["echo", "hello"])
            result = future.result(timeout=10)
    """
    connection_handle = env.establish_connection()
    try:
        yield connection_handle
    finally:
        with contextlib.suppress(Exception):
            env.release_connection(connection_handle)


def submit_command(
    env: VirtualEnvironment,
    connection: ConnectionHandle,
    command: list[str],
    environments: Mapping[str, str] | None = None,
    *,
    cwd: str | None = None,
) -> CommandFuture:
    """Execute a command and return a Future for the result.

    High-level interface that handles IO draining internally.
    For streaming output, use env.execute_command() instead.

    Args:
        env: The virtual environment to execute the command in.
        connection: The connection handle.
        command: Command as list of strings.
        environments: Environment variables.
        cwd: Working directory for the command. If None, uses the provider's default.

    Returns:
        CommandFuture that can be used to get result with timeout or cancel.

    Example:
        with with_connection(env) as conn:
            result = submit_command(env, conn, ["ls", "-la"]).result(timeout=30)
    """
    pid, stdin_transport, stdout_transport, stderr_transport = env.execute_command(
        connection, command, environments, cwd
    )

    return CommandFuture(
        pid=pid,
        stdin_transport=stdin_transport,
        stdout_transport=stdout_transport,
        stderr_transport=stderr_transport,
        poll_status=partial(env.get_command_status, connection, pid),
        terminate_command=partial(env.terminate_command, connection, pid),
    )


def _execute_with_connection(
    env: VirtualEnvironment,
    conn: ConnectionHandle,
    command: list[str],
    timeout: float | None,
    cwd: str | None,
) -> CommandResult:
    """Internal helper to execute command with given connection."""
    future = submit_command(env, conn, command, cwd=cwd)
    return future.result(timeout=timeout)


def execute(
    env: VirtualEnvironment,
    command: list[str],
    *,
    timeout: float | None = 30,
    cwd: str | None = None,
    error_message: str = "Command failed",
    connection: ConnectionHandle | None = None,
) -> CommandResult:
    """Execute a command with automatic connection management.

    Raises CommandExecutionError if the command fails (non-zero exit code).

    Args:
        env: The virtual environment to execute the command in.
        command: The command to execute as a list of strings.
        timeout: Maximum time to wait for the command to complete (seconds).
        cwd: Working directory for the command.
        error_message: Custom error message prefix for failures.
        connection: Optional connection handle to reuse. If None, creates and releases a new connection.

    Returns:
        CommandResult on success.

    Raises:
        CommandExecutionError: If the command fails.
    """
    if connection is not None:
        result = _execute_with_connection(env, connection, command, timeout, cwd)
    else:
        with with_connection(env) as conn:
            result = _execute_with_connection(env, conn, command, timeout, cwd)

    if result.is_error:
        raise CommandExecutionError(f"{error_message}: {result.error_message}", result)
    return result


def try_execute(
    env: VirtualEnvironment,
    command: list[str],
    *,
    timeout: float | None = 30,
    cwd: str | None = None,
    connection: ConnectionHandle | None = None,
) -> CommandResult:
    """Execute a command with automatic connection management.

    Does not raise on failure - returns the result for caller to handle.

    Args:
        env: The virtual environment to execute the command in.
        command: The command to execute as a list of strings.
        timeout: Maximum time to wait for the command to complete (seconds).
        cwd: Working directory for the command.
        connection: Optional connection handle to reuse. If None, creates and releases a new connection.

    Returns:
        CommandResult containing stdout, stderr, and exit_code.
    """
    if connection is not None:
        return _execute_with_connection(env, connection, command, timeout, cwd)

    with with_connection(env) as conn:
        return _execute_with_connection(env, conn, command, timeout, cwd)


@dataclass(frozen=True)
class _PipelineStep:
    argv: list[str]
    error_message: str = "Command failed"


@dataclass
class CommandPipeline:
    """Batch multiple commands into a single shell execution (Redis pipeline style).

    Example:
        results = pipeline(env).add(["echo", "hi"]).add(["ls"]).execute()
        # Strict mode: raise on first failure
        pipeline(env).add(["mkdir", "/x"], error_message="mkdir failed").execute(raise_on_error=True)
    """

    env: VirtualEnvironment
    connection: ConnectionHandle | None = None
    cwd: str | None = None
    environments: Mapping[str, str] | None = None

    _steps: list[_PipelineStep] = field(default_factory=list)  # pyright: ignore[reportUnknownVariableType]

    def add(self, command: list[str], *, error_message: str = "Command failed", on: bool = True) -> CommandPipeline:
        if on:
            self._steps.append(_PipelineStep(argv=command, error_message=error_message))
        return self

    def execute(self, *, timeout: float | None = 30, raise_on_error: bool = False) -> list[CommandResult]:
        if not self._steps:
            return []

        script = self._build_script(fail_fast=raise_on_error)
        batch_cmd = ["sh", "-c", script]

        if self.connection is not None:
            batch_result = try_execute(self.env, batch_cmd, timeout=timeout, cwd=self.cwd, connection=self.connection)
        else:
            with with_connection(self.env) as conn:
                batch_result = try_execute(self.env, batch_cmd, timeout=timeout, cwd=self.cwd, connection=conn)

        results = self._parse_results(batch_result.stdout, batch_result.pid)

        if raise_on_error:
            for i, r in enumerate(results):
                if r.is_error:
                    step = self._steps[i]
                    raise PipelineExecutionError(
                        f"{step.error_message}: {r.error_message}",
                        r,
                        index=i,
                        command=step.argv,
                        results=results,
                    )

        return results

    def _build_script(self, *, fail_fast: bool = False) -> str:
        lines = [
            "run() {",
            ' i="$1"; shift',
            ' out="$(mktemp)"; err="$(mktemp)"',
            ' ("$@") >"$out" 2>"$err"; ec="$?"',
            ' os="$(wc -c <"$out" | tr -d \' \')"',
            ' es="$(wc -c <"$err" | tr -d \' \')"',
            f' printf \'{_PIPE_SENTINEL} %s %s %s %s\\n\' "$i" "$ec" "$os" "$es"',
            ' cat "$out"',
            ' cat "$err"',
            ' rm -f "$out" "$err"',
            ' return "$ec"',
            "}",
        ]
        suffix = " || exit $?" if fail_fast else ""
        for i, step in enumerate(self._steps):
            quoted = " ".join(shlex.quote(arg) for arg in step.argv)
            lines.append(f"run {i} {quoted}{suffix}")
        return "\n".join(lines)

    @staticmethod
    def _parse_results(stdout: bytes, pid: str) -> list[CommandResult]:
        results: list[CommandResult] = []
        pos = 0
        sentinel = _PIPE_SENTINEL.encode() + b" "

        while pos < len(stdout):
            nl = stdout.find(b"\n", pos)
            if nl == -1:
                break
            header = stdout[pos : nl + 1]
            pos = nl + 1

            if not header.startswith(sentinel):
                raise ValueError("Malformed pipeline output: missing sentinel")

            parts = header.decode().strip().split(" ")
            _, idx, ec, os_len, es_len = parts
            out_len, err_len = int(os_len), int(es_len)

            out_bytes = stdout[pos : pos + out_len]
            pos += out_len
            err_bytes = stdout[pos : pos + err_len]
            pos += err_len

            results.append(
                CommandResult(
                    stdout=out_bytes,
                    stderr=err_bytes,
                    exit_code=int(ec),
                    pid=f"{pid}:{idx}",
                )
            )

        return results


def pipeline(
    env: VirtualEnvironment,
    connection: ConnectionHandle | None = None,
    *,
    cwd: str | None = None,
    environments: Mapping[str, str] | None = None,
) -> CommandPipeline:
    return CommandPipeline(env=env, connection=connection, cwd=cwd, environments=environments)
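For reference, the wire format `_parse_results` expects is one sentinel header line per step ("__DIFY_PIPE__ index exit_code stdout_len stderr_len"), followed by the raw stdout and stderr payloads. A hand-built frame (payload invented, calling the private parser purely as an illustration) parses as expected:

frame = b"__DIFY_PIPE__ 0 0 3 0\nhi\n"  # step 0, exit code 0, 3 stdout bytes, 0 stderr bytes
results = CommandPipeline._parse_results(frame, pid="7")
assert results[0].stdout == b"hi\n" and results[0].exit_code == 0
assert results[0].pid == "7:0"  # per-step pid is "<batch pid>:<step index>"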
api/core/virtual_environment/__base/virtual_environment.py (new file, 231 lines)
@@ -0,0 +1,231 @@
from abc import ABC, abstractmethod
from collections.abc import Mapping, Sequence
from io import BytesIO
from typing import Any

from core.entities.provider_entities import BasicProviderConfig
from core.virtual_environment.__base.entities import CommandStatus, ConnectionHandle, FileState, Metadata
from core.virtual_environment.channel.transport import TransportReadCloser, TransportWriteCloser


class VirtualEnvironment(ABC):
    """
    Base class for virtual environment implementations.

    ``VirtualEnvironment`` instances are configured at construction time but do
    not allocate provider resources until ``open_enviroment()`` is called.
    This keeps object construction side-effect free and gives callers a chance
    to own startup error handling explicitly.
    """

    tenant_id: str
    user_id: str | None
    options: Mapping[str, Any]
    _environments: Mapping[str, str]
    _metadata: Metadata | None

    def __init__(
        self,
        tenant_id: str,
        options: Mapping[str, Any],
        environments: Mapping[str, str] | None = None,
        user_id: str | None = None,
    ) -> None:
        """
        Initialize the virtual environment configuration.

        Args:
            tenant_id: The tenant ID associated with this environment (required).
            options: Provider-specific configuration options.
            environments: Environment variables to set in the virtual environment.
            user_id: The user ID associated with this environment (optional).

        The provider runtime itself is created later by ``open_enviroment()``.
        """

        self.tenant_id = tenant_id
        self.user_id = user_id
        self.options = options
        self._environments = dict(environments or {})
        self._metadata = None

    @property
    def metadata(self) -> Metadata:
        """Provider metadata for a started environment.

        Raises:
            RuntimeError: If the environment has not been started yet.
        """

        if self._metadata is None:
            raise RuntimeError("Virtual environment has not been started")
        return self._metadata

    def open_enviroment(self) -> Metadata:
        """Allocate provider resources and return the resulting metadata.

        Multiple calls are safe and return the existing metadata after the first
        successful start.
        """

        if self._metadata is None:
            self._metadata = self._construct_environment(self.options, self._environments)
        return self._metadata

    @abstractmethod
    def _construct_environment(self, options: Mapping[str, Any], environments: Mapping[str, str]) -> Metadata:
        """
        Construct the provider runtime for the virtual environment.

        Returns:
            Metadata: The metadata of the started virtual environment.
        """

    @abstractmethod
    def upload_file(self, path: str, content: BytesIO) -> None:
        """
        Upload a file to the virtual environment.

        Args:
            path (str): The destination path in the virtual environment.
            content (BytesIO): The content of the file to upload.

        Raises:
            Exception: If the file cannot be uploaded.
        """

    @abstractmethod
    def download_file(self, path: str) -> BytesIO:
        """
        Download a file from the virtual environment.

        Args:
            path (str): The source path in the virtual environment.
        Returns:
            BytesIO: The content of the downloaded file.
        Raises:
            Exception: If the file cannot be downloaded.
        """

    @abstractmethod
    def list_files(self, directory_path: str, limit: int) -> Sequence[FileState]:
        """
        List files in a directory of the virtual environment.

        Args:
            directory_path (str): The directory path in the virtual environment.
            limit (int): The maximum number of files (including recursively nested paths) to return.
        Returns:
            Sequence[FileState]: A list of file states in the specified directory.
        Raises:
            Exception: If the files cannot be listed.

        Example:
            If the directory structure is like:
            /dir
                /subdir1
                    file1.txt
                /subdir2
                    file2.txt
            And limit is 2, the returned list may look like:
            [
                FileState(path="/dir/subdir1/file1.txt", size=1234, created_at=..., updated_at=...),
                FileState(path="/dir/subdir2/file2.txt", size=5678, created_at=..., updated_at=...),
            ]
        """

    @abstractmethod
    def establish_connection(self) -> ConnectionHandle:
        """
        Establish a connection to the virtual environment.

        Returns:
            ConnectionHandle: Handle for managing the connection to the virtual environment.

        Raises:
            Exception: If the connection cannot be established.
        """

    @abstractmethod
    def release_connection(self, connection_handle: ConnectionHandle) -> None:
        """
        Release the connection to the virtual environment.

        Args:
            connection_handle (ConnectionHandle): The handle for managing the connection.

        Raises:
            Exception: If the connection cannot be released.
        """

    @abstractmethod
    def release_environment(self) -> None:
        """
        Release the virtual environment.

        Raises:
            Exception: If the environment cannot be released.
        Multiple calls to `release_environment` are acceptable.
        """

    def terminate_command(self, connection_handle: ConnectionHandle, pid: str) -> bool:
        """Best-effort termination hook for a running command.

        Providers that can map ``pid`` back to a real process/session should
        override this method and stop the command. The default implementation is
        a no-op so providers without a termination mechanism remain compatible.
        """

        _ = connection_handle
        _ = pid
        return False

    @abstractmethod
    def execute_command(
        self,
        connection_handle: ConnectionHandle,
        command: list[str],
        environments: Mapping[str, str] | None = None,
        cwd: str | None = None,
    ) -> tuple[str, TransportWriteCloser, TransportReadCloser, TransportReadCloser]:
        """
        Execute a command in the virtual environment.

        Args:
            connection_handle (ConnectionHandle): The handle for managing the connection.
            command (list[str]): The command to execute as a list of strings.
            environments (Mapping[str, str] | None): Environment variables for the command.
            cwd (str | None): Working directory for the command. If None, uses the provider's default.

        Returns:
            tuple[str, TransportWriteCloser, TransportReadCloser, TransportReadCloser]:
                a tuple containing the pid and 3 transport handles: (stdin, stdout, stderr).
                After execution, the 3 handles will be closed by the caller.
        """

    @classmethod
    @abstractmethod
    def validate(cls, options: Mapping[str, Any]) -> None:
        """
        Validate that options can connect to the provider.

        Raises:
            SandboxConfigValidationError: If validation fails
        """

    @abstractmethod
    def get_command_status(self, connection_handle: ConnectionHandle, pid: str) -> CommandStatus:
        """
        Get the status of a command executed in the virtual environment.

        Args:
            connection_handle (ConnectionHandle): The handle for managing the connection.
            pid (str): The process ID of the command.
        Returns:
            CommandStatus: The status of the command execution.
        """

    @classmethod
    @abstractmethod
    def get_config_schema(cls) -> list[BasicProviderConfig]:
        pass
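The intended lifecycle, as a sketch against a hypothetical concrete provider; `LocalVirtualEnvironment` is a placeholder name standing in for any subclass and is not confirmed by this diff:

env = LocalVirtualEnvironment(tenant_id="t-1", options={})  # construction is side-effect free
metadata = env.open_enviroment()                            # allocates provider resources (idempotent)
conn = env.establish_connection()
try:
    pid, stdin, stdout, stderr = env.execute_command(conn, ["echo", "hello"])
finally:
    env.release_connection(conn)
    env.release_environment()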
api/core/virtual_environment/__init__.py (new file, empty)
api/core/virtual_environment/channel/__init__.py (new file, empty)
api/core/virtual_environment/channel/exec.py (new file, 4 lines)
@@ -0,0 +1,4 @@
class TransportEOFError(Exception):
    """Exception raised when attempting to read from a closed transport."""

    pass
api/core/virtual_environment/channel/pipe_transport.py (new file, 72 lines)
@@ -0,0 +1,72 @@
import os

from core.virtual_environment.channel.exec import TransportEOFError
from core.virtual_environment.channel.transport import Transport, TransportReadCloser, TransportWriteCloser


class PipeTransport(Transport):
    """
    A Transport implementation using OS pipes. It requires two file descriptors:
    one for reading and one for writing.

    NOTE: r_fd and w_fd must be a pair created by os.pipe(), or returned from subprocess.Popen.

    NEVER FORGET TO CALL THE `close()` METHOD TO AVOID FILE DESCRIPTOR LEAKAGE.
    """

    def __init__(self, r_fd: int, w_fd: int):
        self.r_fd = r_fd
        self.w_fd = w_fd

    def write(self, data: bytes) -> None:
        try:
            os.write(self.w_fd, data)
        except OSError:
            raise TransportEOFError("Pipe write error, maybe the read end is closed")

    def read(self, n: int) -> bytes:
        data = os.read(self.r_fd, n)
        if data == b"":
            raise TransportEOFError("End of Pipe reached")
        return data

    def close(self) -> None:
        os.close(self.r_fd)
        os.close(self.w_fd)


class PipeReadCloser(TransportReadCloser):
    """
    A Transport implementation using an OS pipe for reading.
    """

    def __init__(self, r_fd: int):
        self.r_fd = r_fd

    def read(self, n: int) -> bytes:
        data = os.read(self.r_fd, n)
        if data == b"":
            raise TransportEOFError("End of Pipe reached")

        return data

    def close(self) -> None:
        os.close(self.r_fd)


class PipeWriteCloser(TransportWriteCloser):
    """
    A Transport implementation using an OS pipe for writing.
    """

    def __init__(self, w_fd: int):
        self.w_fd = w_fd

    def write(self, data: bytes) -> None:
        try:
            os.write(self.w_fd, data)
        except OSError:
            raise TransportEOFError("Pipe write error, maybe the read end is closed")

    def close(self) -> None:
        os.close(self.w_fd)
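A self-contained check of the pipe transports using a local os.pipe() pair:

import os

from core.virtual_environment.channel.pipe_transport import PipeReadCloser, PipeWriteCloser

r_fd, w_fd = os.pipe()
reader, writer = PipeReadCloser(r_fd), PipeWriteCloser(w_fd)
writer.write(b"ping")
assert reader.read(4) == b"ping"
writer.close()  # once the write end is closed, further reads raise TransportEOFError
reader.close()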
117
api/core/virtual_environment/channel/queue_transport.py
Normal file
@ -0,0 +1,117 @@
from core.virtual_environment.channel.exec import TransportEOFError
from core.virtual_environment.channel.transport import TransportReadCloser


class QueueTransportReadCloser(TransportReadCloser):
    """
    Transport implementation using queues for inter-thread communication.

    Usage:
        q_transport = QueueTransportReadCloser()
        write_handler = q_transport.get_write_handler()

        # In writer thread
        write_handler.write(b"data")

        # In reader thread
        data = q_transport.read(1024)

        # Close transport when done
        q_transport.close()
    """

    _QUEUE_GET_TIMEOUT = 5.0

    class WriteHandler:
        """
        A write handler that writes data to a queue.
        """

        from queue import Queue

        def __init__(self, queue: Queue[bytes | None]) -> None:
            self.queue = queue

        def write(self, data: bytes) -> None:
            self.queue.put(data)

    def __init__(
        self,
    ) -> None:
        """
        Initialize the QueueTransportReadCloser with an internal queue and read buffer.
        """
        from queue import Queue

        self.q = Queue[bytes | None]()
        self._read_buffer = bytearray()
        self._closed = False
        self._write_channel_closed = False

    def get_write_handler(self) -> WriteHandler:
        """
        Get a write handler that writes to the internal queue.
        """
        return QueueTransportReadCloser.WriteHandler(self.q)

    def close(self) -> None:
        """
        Close the transport by putting a sentinel value in the queue.
        """
        if self._write_channel_closed:
            raise TransportEOFError("Write channel already closed")

        self._write_channel_closed = True
        self.q.put(None)

    def read(self, n: int) -> bytes:
        """
        Read up to n bytes from the queue.

        NEVER USE IT IN A MULTI-THREADED CONTEXT WITHOUT PROPER SYNCHRONIZATION.
        """
        from queue import Empty

        if n <= 0:
            return b""

        if self._closed:
            raise TransportEOFError("Transport is closed")

        to_return = self._drain_buffer(n)

        # On the first round of reading from the queue, block to wait for data.
        # After that, return immediately if no more data is available.
        round = 0

        while len(to_return) < n and not self._closed and (self.q.qsize() > 0 or round == 0):
            try:
                chunk = self.q.get(timeout=self._QUEUE_GET_TIMEOUT)
            except Empty:
                if self._closed:
                    raise TransportEOFError("Transport is closed")
                continue
            if chunk is None:
                self._closed = True
                if len(to_return) == 0:
                    raise TransportEOFError("Transport is closed")
                else:
                    break

            self._read_buffer.extend(chunk)

            if n - len(to_return) > 0:
                # Drain the buffer if we still need more data
                to_return += self._drain_buffer(n - len(to_return))
            else:
                # No more data needed, break
                break

            round += 1

        return to_return

    def _drain_buffer(self, n: int) -> bytes:
        data = bytes(self._read_buffer[:n])
        del self._read_buffer[:n]
        return data
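A minimal producer/consumer sketch that exercises the API exactly as the docstring above describes, with the writer on a background thread and the sentinel-driven EOF on the reader side:

import threading

from core.virtual_environment.channel.exec import TransportEOFError
from core.virtual_environment.channel.queue_transport import QueueTransportReadCloser

transport = QueueTransportReadCloser()
writer = transport.get_write_handler()

def produce() -> None:
    # Producer thread: push two chunks, then close to enqueue the None sentinel.
    writer.write(b"hello, ")
    writer.write(b"world")
    transport.close()

threading.Thread(target=produce, daemon=True).start()

# Single consumer: read until the sentinel surfaces as TransportEOFError.
collected = bytearray()
while True:
    try:
        collected.extend(transport.read(1024))
    except TransportEOFError:
        break

print(collected.decode())  # hello, world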
70
api/core/virtual_environment/channel/socket_transport.py
Normal file
@ -0,0 +1,70 @@
import socket

from core.virtual_environment.channel.exec import TransportEOFError
from core.virtual_environment.channel.transport import Transport, TransportReadCloser, TransportWriteCloser


class SocketTransport(Transport):
    """
    A Transport implementation using a socket.
    """

    def __init__(self, sock: socket.SocketIO):
        self.sock = sock

    def write(self, data: bytes) -> None:
        try:
            self.sock.write(data)
        except (ConnectionResetError, BrokenPipeError):
            raise TransportEOFError("Socket write error, maybe the read end is closed")

    def read(self, n: int) -> bytes:
        try:
            data = self.sock.read(n)
            if data == b"":
                raise TransportEOFError("End of Socket reached")
        except (ConnectionResetError, BrokenPipeError):
            raise TransportEOFError("Socket connection reset")
        return data

    def close(self) -> None:
        self.sock.close()


class SocketReadCloser(TransportReadCloser):
    """
    A Transport implementation using a socket for reading.
    """

    def __init__(self, sock: socket.SocketIO):
        self.sock = sock

    def read(self, n: int) -> bytes:
        try:
            data = self.sock.read(n)
            if data == b"":
                raise TransportEOFError("End of Socket reached")
            return data
        except (ConnectionResetError, BrokenPipeError):
            raise TransportEOFError("Socket connection reset")

    def close(self) -> None:
        self.sock.close()


class SocketWriteCloser(TransportWriteCloser):
    """
    A Transport implementation using a socket for writing.
    """

    def __init__(self, sock: socket.SocketIO):
        self.sock = sock

    def write(self, data: bytes) -> None:
        try:
            self.sock.write(data)
        except (ConnectionResetError, BrokenPipeError):
            raise TransportEOFError("Socket write error, maybe the read end is closed")

    def close(self) -> None:
        self.sock.close()
80
api/core/virtual_environment/channel/transport.py
Normal file
@ -0,0 +1,80 @@
from abc import abstractmethod
from typing import Protocol


class TransportCloser(Protocol):
    """
    Transport that can be closed.
    """

    @abstractmethod
    def close(self) -> None:
        """
        Close the transport.
        """


class TransportWriter(Protocol):
    """
    Transport that can be written to.
    """

    @abstractmethod
    def write(self, data: bytes) -> None:
        """
        Write data to the transport.

        Raises TransportEOFError if the transport is closed.
        """


class TransportReader(Protocol):
    """
    Transport that can be read from.
    """

    @abstractmethod
    def read(self, n: int) -> bytes:
        """
        Read up to n bytes from the transport.

        Raises TransportEOFError if the end of the transport is reached.
        """


class TransportReadCloser(TransportReader, TransportCloser):
    """
    Transport that can be read from and closed.
    """


class TransportWriteCloser(TransportWriter, TransportCloser):
    """
    Transport that can be written to and closed.
    """


class Transport(TransportReader, TransportWriter, TransportCloser):
    """
    Transport that can be read from, written to, and closed.
    """


class NopTransportWriteCloser(TransportWriteCloser):
    """
    A no-operation TransportWriteCloser implementation.

    This transport does nothing on write and close operations.
    """

    def write(self, data: bytes) -> None:
        """
        No-operation write method.
        """
        pass

    def close(self) -> None:
        """
        No-operation close method.
        """
        pass
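Since these are the interfaces every provider and channel in this port builds on, a small sketch of a custom implementation may help: BufferTransport below is a hypothetical in-memory Transport (useful for tests), not anything shipped in this diff.

from core.virtual_environment.channel.exec import TransportEOFError
from core.virtual_environment.channel.transport import Transport


class BufferTransport(Transport):
    """Hypothetical in-memory Transport backed by a byte buffer."""

    def __init__(self) -> None:
        self._buffer = bytearray()
        self._closed = False

    def write(self, data: bytes) -> None:
        if self._closed:
            raise TransportEOFError("Transport is closed")
        self._buffer.extend(data)

    def read(self, n: int) -> bytes:
        if not self._buffer:
            raise TransportEOFError("End of buffer reached")
        data = bytes(self._buffer[:n])
        del self._buffer[:n]
        return data

    def close(self) -> None:
        self._closed = True


t = BufferTransport()
t.write(b"ping")
assert t.read(4) == b"ping"
t.close()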
10
api/core/virtual_environment/constants.py
Normal file
@ -0,0 +1,10 @@
"""
Constants for virtual environment providers.

Centralizes timeout and other configuration values used across different sandbox providers
(E2B, SSH, Docker) to ensure consistency and ease of maintenance.
"""

# Command execution timeout in seconds (5 hours)
# Used by providers to limit how long a single command can run
COMMAND_EXECUTION_TIMEOUT_SECONDS = 5 * 60 * 60  # 18000 seconds
0
api/core/virtual_environment/providers/__init__.py
Normal file
@ -0,0 +1,531 @@
"""
AWS Bedrock AgentCore Code Interpreter sandbox provider.

Uses the AgentCore Code Interpreter built-in tool to provide a sandboxed code execution
environment with shell access and file operations. The Code Interpreter runs in an
isolated microVM managed by AWS and communicates via the `InvokeCodeInterpreter` API.

Two boto3 clients are involved:
- **bedrock-agentcore-control** (Control Plane): manages the Code Interpreter resource
  (create / delete). Users must create one beforehand or use the system-provided
  ``aws.codeinterpreter.v1``.
- **bedrock-agentcore** (Data Plane): manages sessions and executes operations
  (start/stop session, execute commands, file I/O).

Key differences from other providers:
- stdin is not supported (same as E2B) — uses ``NopTransportWriteCloser``.
- ``executeCommand`` returns the full stdout/stderr once the command completes,
  rather than streaming incrementally. We wrap the result with
  ``QueueTransportReadCloser`` so the upper-layer ``CommandFuture`` works unchanged.
- ``get_command_status`` raises ``NotSupportedOperationError`` (same as E2B).
  The synchronous ``executeCommand`` path is used instead.
"""

import logging
import posixpath
import shlex
import threading
from collections.abc import Mapping, Sequence
from enum import StrEnum
from io import BytesIO
from typing import Any
from uuid import uuid4

from core.entities.provider_entities import BasicProviderConfig
from core.virtual_environment.__base.entities import (
    Arch,
    CommandStatus,
    ConnectionHandle,
    FileState,
    Metadata,
    OperatingSystem,
)
from core.virtual_environment.__base.exec import (
    ArchNotSupportedError,
    NotSupportedOperationError,
    SandboxConfigValidationError,
)
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
from core.virtual_environment.channel.queue_transport import QueueTransportReadCloser
from core.virtual_environment.channel.transport import (
    NopTransportWriteCloser,
    TransportReadCloser,
    TransportWriteCloser,
)
from core.virtual_environment.constants import COMMAND_EXECUTION_TIMEOUT_SECONDS

logger = logging.getLogger(__name__)

# Maximum time in seconds that the boto3 read on the EventStream socket is
# allowed to block. Must exceed the longest expected command execution.
_BOTO3_READ_TIMEOUT_SECONDS = COMMAND_EXECUTION_TIMEOUT_SECONDS + 60


class AWSCodeInterpreterEnvironment(VirtualEnvironment):
    """
    AWS Bedrock AgentCore Code Interpreter virtual environment provider.

    The provider maps the ``VirtualEnvironment`` protocol onto the AgentCore
    Code Interpreter Data-Plane API (``InvokeCodeInterpreter``).

    Lifecycle:
    1. ``_construct_environment`` starts a new session via
       ``StartCodeInterpreterSession``.
    2. Commands and file operations invoke ``InvokeCodeInterpreter`` with
       the appropriate ``name`` (``executeCommand``, ``writeFiles``, etc.).
    3. ``release_environment`` stops the session via
       ``StopCodeInterpreterSession``.

    Configuration (``OptionsKey``):
    - ``aws_access_key_id`` / ``aws_secret_access_key``: IAM credentials.
    - ``aws_region``: AWS region (e.g. ``us-east-1``).
    - ``code_interpreter_id``: the Code Interpreter resource identifier
      (e.g. ``aws.codeinterpreter.v1`` for the system-provided one).
    - ``session_timeout_seconds``: optional; defaults to 900 (15 min).
    """

    _WORKDIR = "/home/user"

    class OptionsKey(StrEnum):
        AWS_ACCESS_KEY_ID = "aws_access_key_id"
        AWS_SECRET_ACCESS_KEY = "aws_secret_access_key"
        AWS_REGION = "aws_region"
        CODE_INTERPRETER_ID = "code_interpreter_id"
        SESSION_TIMEOUT_SECONDS = "session_timeout_seconds"

    class StoreKey(StrEnum):
        CLIENT = "client"
        SESSION_ID = "session_id"
        CODE_INTERPRETER_ID = "code_interpreter_id"

    # ------------------------------------------------------------------
    # Config schema & validation
    # ------------------------------------------------------------------

    @classmethod
    def get_config_schema(cls) -> list[BasicProviderConfig]:
        return [
            BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=cls.OptionsKey.AWS_ACCESS_KEY_ID),
            BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=cls.OptionsKey.AWS_SECRET_ACCESS_KEY),
            BasicProviderConfig(type=BasicProviderConfig.Type.TEXT_INPUT, name=cls.OptionsKey.AWS_REGION),
            BasicProviderConfig(type=BasicProviderConfig.Type.TEXT_INPUT, name=cls.OptionsKey.CODE_INTERPRETER_ID),
        ]

    @classmethod
    def validate(cls, options: Mapping[str, Any]) -> None:
        """Validate credentials by starting then immediately stopping a session."""
        import boto3
        from botocore.exceptions import ClientError

        for key in (cls.OptionsKey.AWS_ACCESS_KEY_ID, cls.OptionsKey.AWS_SECRET_ACCESS_KEY, cls.OptionsKey.AWS_REGION):
            if not options.get(key):
                raise SandboxConfigValidationError(f"{key} is required")

        code_interpreter_id = options.get(cls.OptionsKey.CODE_INTERPRETER_ID, "")
        if not code_interpreter_id:
            raise SandboxConfigValidationError("code_interpreter_id is required")

        client = boto3.client(
            "bedrock-agentcore",
            region_name=options[cls.OptionsKey.AWS_REGION],
            aws_access_key_id=options[cls.OptionsKey.AWS_ACCESS_KEY_ID],
            aws_secret_access_key=options[cls.OptionsKey.AWS_SECRET_ACCESS_KEY],
        )

        try:
            resp = client.start_code_interpreter_session(
                codeInterpreterIdentifier=code_interpreter_id,
                sessionTimeoutSeconds=60,
            )
            session_id = resp["sessionId"]
            # Immediately stop the validation session.
            client.stop_code_interpreter_session(
                codeInterpreterIdentifier=code_interpreter_id,
                sessionId=session_id,
            )
        except ClientError as exc:
            raise SandboxConfigValidationError(f"AWS AgentCore Code Interpreter validation failed: {exc}") from exc
        except Exception as exc:
            raise SandboxConfigValidationError(f"AWS AgentCore Code Interpreter connection failed: {exc}") from exc

    # ------------------------------------------------------------------
    # Environment lifecycle
    # ------------------------------------------------------------------

    def _construct_environment(self, options: Mapping[str, Any], environments: Mapping[str, str]) -> Metadata:
        """Start a new Code Interpreter session and detect platform info."""
        import boto3
        from botocore.config import Config

        code_interpreter_id: str = options.get(self.OptionsKey.CODE_INTERPRETER_ID, "")
        timeout_seconds: int = int(options.get(self.OptionsKey.SESSION_TIMEOUT_SECONDS, 900))

        client = boto3.client(
            "bedrock-agentcore",
            region_name=options[self.OptionsKey.AWS_REGION],
            aws_access_key_id=options[self.OptionsKey.AWS_ACCESS_KEY_ID],
            aws_secret_access_key=options[self.OptionsKey.AWS_SECRET_ACCESS_KEY],
            config=Config(read_timeout=_BOTO3_READ_TIMEOUT_SECONDS),
        )

        resp = client.start_code_interpreter_session(
            codeInterpreterIdentifier=code_interpreter_id,
            sessionTimeoutSeconds=timeout_seconds,
        )
        session_id: str = resp["sessionId"]

        logger.info(
            "AgentCore Code Interpreter session started: code_interpreter_id=%s, session_id=%s",
            code_interpreter_id,
            session_id,
        )

        # Detect architecture and OS via a quick command.
        arch = Arch.AMD64
        operating_system = OperatingSystem.LINUX
        try:
            result = self._invoke(client, code_interpreter_id, session_id, "executeCommand", {"command": "uname -m -s"})
            system_info = (result.get("stdout") or "").strip()
            parts = system_info.split()
            if len(parts) >= 2:
                operating_system = self._convert_operating_system(parts[0])
                arch = self._convert_architecture(parts[1])
            elif len(parts) == 1:
                arch = self._convert_architecture(parts[0])
        except Exception:
            logger.warning("Failed to detect platform info, defaulting to Linux/AMD64")

        # Inject environment variables if provided.
        if environments:
            export_parts = [f"export {k}={shlex.quote(v)}" for k, v in environments.items()]
            export_cmd = " && ".join(export_parts)
            try:
                self._invoke(client, code_interpreter_id, session_id, "executeCommand", {"command": export_cmd})
            except Exception:
                logger.warning("Failed to inject environment variables into AgentCore session")

        return Metadata(
            id=session_id,
            arch=arch,
            os=operating_system,
            store={
                self.StoreKey.CLIENT: client,
                self.StoreKey.SESSION_ID: session_id,
                self.StoreKey.CODE_INTERPRETER_ID: code_interpreter_id,
            },
        )

    def release_environment(self) -> None:
        """Stop the Code Interpreter session and release resources."""
        client = self._client
        try:
            client.stop_code_interpreter_session(
                codeInterpreterIdentifier=self._code_interpreter_id,
                sessionId=self._session_id,
            )
            logger.info("AgentCore Code Interpreter session stopped: session_id=%s", self._session_id)
        except Exception:
            logger.warning(
                "Failed to stop AgentCore Code Interpreter session: session_id=%s",
                self._session_id,
                exc_info=True,
            )

    # ------------------------------------------------------------------
    # Connection (virtual — no real connection needed)
    # ------------------------------------------------------------------

    def establish_connection(self) -> ConnectionHandle:
        return ConnectionHandle(id=uuid4().hex)

    def release_connection(self, connection_handle: ConnectionHandle) -> None:
        pass

    # ------------------------------------------------------------------
    # Command execution
    # ------------------------------------------------------------------

    def execute_command(
        self,
        connection_handle: ConnectionHandle,
        command: list[str],
        environments: Mapping[str, str] | None = None,
        cwd: str | None = None,
    ) -> tuple[str, TransportWriteCloser, TransportReadCloser, TransportReadCloser]:
        """
        Execute a shell command via AgentCore Code Interpreter.

        The command is executed synchronously on the AWS side; a background thread
        calls ``executeCommand`` and feeds the result into queue-based transports
        so the caller sees the standard Transport interface.

        stdin is not supported — ``NopTransportWriteCloser`` is returned.
        """
        stdout_stream = QueueTransportReadCloser()
        stderr_stream = QueueTransportReadCloser()

        working_dir = self._workspace_path(cwd) if cwd else self._WORKDIR
        cmd_str = shlex.join(command)

        # Wrap env vars and cwd into the command string since the API only
        # accepts a flat ``command`` string argument.
        prefix_parts: list[str] = []
        if environments:
            for k, v in environments.items():
                prefix_parts.append(f"export {k}={shlex.quote(v)}")
        prefix_parts.append(f"cd {shlex.quote(working_dir)}")
        full_cmd = " && ".join([*prefix_parts, cmd_str])

        threading.Thread(
            target=self._cmd_thread,
            args=(full_cmd, stdout_stream, stderr_stream),
            daemon=True,
        ).start()

        return (
            "N/A",
            NopTransportWriteCloser(),
            stdout_stream,
            stderr_stream,
        )

    def get_command_status(self, connection_handle: ConnectionHandle, pid: str) -> CommandStatus:
        """Not supported — same as E2B. ``CommandFuture`` handles this gracefully."""
        raise NotSupportedOperationError("AgentCore Code Interpreter does not support getting command status.")

    # ------------------------------------------------------------------
    # File operations
    # ------------------------------------------------------------------

    def upload_file(self, path: str, content: BytesIO) -> None:
        """Upload a file to the Code Interpreter session."""
        remote_path = self._workspace_path(path)
        file_bytes = content.read()

        self._invoke(
            self._client,
            self._code_interpreter_id,
            self._session_id,
            "writeFiles",
            {"content": [{"path": remote_path, "blob": file_bytes}]},
        )

    def download_file(self, path: str) -> BytesIO:
        """Download a file from the Code Interpreter session."""
        remote_path = self._workspace_path(path)

        result = self._invoke(
            self._client,
            self._code_interpreter_id,
            self._session_id,
            "readFiles",
            {"path": remote_path},
        )

        # The response content blocks may contain blob or text data.
        content_blocks: list[dict[str, Any]] = result.get("content", [])
        for block in content_blocks:
            resource = block.get("resource")
            if resource:
                blob = resource.get("blob")
                if blob:
                    return BytesIO(blob if isinstance(blob, bytes) else blob.encode())
                text = resource.get("text")
                if text:
                    return BytesIO(text.encode("utf-8"))
            # Fallback: check top-level data/text fields.
            if block.get("data"):
                data = block["data"]
                return BytesIO(data if isinstance(data, bytes) else data.encode())
            if block.get("text"):
                return BytesIO(block["text"].encode("utf-8"))

        return BytesIO(b"")

    def list_files(self, directory_path: str, limit: int) -> Sequence[FileState]:
        """List files in a directory of the Code Interpreter session."""
        remote_dir = self._workspace_path(directory_path)

        result = self._invoke(
            self._client,
            self._code_interpreter_id,
            self._session_id,
            "listFiles",
            {"directoryPath": remote_dir},
        )

        # The API returns file information in content blocks.
        # Since the exact structure may vary, we also fall back to running
        # a shell command to list files if the content format is not parseable.
        content_blocks: list[dict[str, Any]] = result.get("content", [])
        files: list[FileState] = []

        for block in content_blocks:
            text = block.get("text", "")
            name = block.get("name", "")
            size = block.get("size", 0)
            uri = block.get("uri", "")

            file_path = uri or name or text
            if not file_path:
                continue

            # Normalise to relative path from workdir.
            if file_path.startswith(self._WORKDIR):
                file_path = posixpath.relpath(file_path, self._WORKDIR)

            files.append(
                FileState(
                    path=file_path,
                    size=size or 0,
                    created_at=0,
                    updated_at=0,
                )
            )

            if len(files) >= limit:
                break

        return files

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @property
    def _client(self) -> Any:
        return self.metadata.store[self.StoreKey.CLIENT]

    @property
    def _session_id(self) -> str:
        return str(self.metadata.store[self.StoreKey.SESSION_ID])

    @property
    def _code_interpreter_id(self) -> str:
        return str(self.metadata.store[self.StoreKey.CODE_INTERPRETER_ID])

    def _workspace_path(self, path: str) -> str:
        """Convert a path to an absolute path in the Code Interpreter session."""
        normalized = posixpath.normpath(path)
        if normalized in ("", "."):
            return self._WORKDIR
        if normalized.startswith("/"):
            return normalized
        return posixpath.join(self._WORKDIR, normalized)

    def _cmd_thread(
        self,
        command: str,
        stdout_stream: QueueTransportReadCloser,
        stderr_stream: QueueTransportReadCloser,
    ) -> None:
        """Background thread that executes a command and feeds output into queue transports."""
        stdout_writer = stdout_stream.get_write_handler()
        stderr_writer = stderr_stream.get_write_handler()

        try:
            result = self._invoke(
                self._client,
                self._code_interpreter_id,
                self._session_id,
                "executeCommand",
                {"command": command},
            )
            stdout_data = result.get("stdout", "")
            stderr_data = result.get("stderr", "")

            if stdout_data:
                stdout_writer.write(stdout_data.encode("utf-8") if isinstance(stdout_data, str) else stdout_data)
            if stderr_data:
                stderr_writer.write(stderr_data.encode("utf-8") if isinstance(stderr_data, str) else stderr_data)
        except Exception as exc:
            error_msg = f"Command execution failed: {type(exc).__name__}: {exc}\n"
            stderr_writer.write(error_msg.encode())
        finally:
            stdout_stream.close()
            stderr_stream.close()

    @staticmethod
    def _invoke(
        client: Any,
        code_interpreter_id: str,
        session_id: str,
        name: str,
        arguments: dict[str, Any],
    ) -> dict[str, Any]:
        """
        Call ``InvokeCodeInterpreter`` and extract the structured result.

        The API returns an EventStream; this helper iterates over it and merges
        ``structuredContent`` and ``content`` from all ``result`` events into a
        single dict.

        Returns a dict that may contain:
        - ``stdout``, ``stderr``, ``exitCode``, ``taskId``, ``taskStatus``,
          ``executionTime`` (from ``structuredContent``)
        - ``content`` (list of content blocks)
        - ``isError`` (bool)
        """
        response = client.invoke_code_interpreter(
            codeInterpreterIdentifier=code_interpreter_id,
            sessionId=session_id,
            name=name,
            arguments=arguments,
        )

        merged: dict[str, Any] = {"content": []}

        stream = response.get("stream")
        if stream is None:
            return merged

        for event in stream:
            result = event.get("result")
            if result is None:
                # Check for exception events.
                for exc_key in (
                    "accessDeniedException",
                    "validationException",
                    "resourceNotFoundException",
                    "throttlingException",
                    "internalServerException",
                ):
                    if event.get(exc_key):
                        raise RuntimeError(f"AgentCore error ({exc_key}): {event[exc_key]}")
                continue

            if result.get("isError"):
                merged["isError"] = True

            structured = result.get("structuredContent")
            if structured:
                merged.update(structured)

            content = result.get("content")
            if content:
                merged["content"].extend(content)

        return merged

    @staticmethod
    def _convert_architecture(arch_str: str) -> Arch:
        arch_map: dict[str, Arch] = {
            "x86_64": Arch.AMD64,
            "aarch64": Arch.ARM64,
            "armv7l": Arch.ARM64,
            "arm64": Arch.ARM64,
            "amd64": Arch.AMD64,
        }
        if arch_str in arch_map:
            return arch_map[arch_str]
        raise ArchNotSupportedError(f"Unsupported architecture: {arch_str}")

    @staticmethod
    def _convert_operating_system(os_str: str) -> OperatingSystem:
        os_map: dict[str, OperatingSystem] = {
            "Linux": OperatingSystem.LINUX,
            "Darwin": OperatingSystem.DARWIN,
        }
        if os_str in os_map:
            return os_map[os_str]
        raise ArchNotSupportedError(f"Unsupported operating system: {os_str}")
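For comparison with the Docker example later in this diff, a hedged usage sketch of this provider. The import path below is hypothetical (the new-file header for this hunk was lost), and the constructor call mirrors the `options=` pattern shown in the DockerDaemonEnvironment example rather than a confirmed base-class signature:

from core.virtual_environment.channel.exec import TransportEOFError
# Hypothetical module path for the AWS provider file:
from core.virtual_environment.providers.aws_codeinterpreter_sandbox import AWSCodeInterpreterEnvironment

options = {
    AWSCodeInterpreterEnvironment.OptionsKey.AWS_ACCESS_KEY_ID: "AKIA...",
    AWSCodeInterpreterEnvironment.OptionsKey.AWS_SECRET_ACCESS_KEY: "...",
    AWSCodeInterpreterEnvironment.OptionsKey.AWS_REGION: "us-east-1",
    AWSCodeInterpreterEnvironment.OptionsKey.CODE_INTERPRETER_ID: "aws.codeinterpreter.v1",
}

environment = AWSCodeInterpreterEnvironment(options=options)
handle = environment.establish_connection()

# stdin is a NopTransportWriteCloser; only stdout/stderr carry data.
pid, stdin, stdout, stderr = environment.execute_command(handle, ["uname", "-a"])

output = bytearray()
while True:
    try:
        output.extend(stdout.read(1024))
    except TransportEOFError:
        break
print(output.decode())

environment.release_connection(handle)
environment.release_environment()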
281
api/core/virtual_environment/providers/daytona_sandbox.py
Normal file
@ -0,0 +1,281 @@
import logging
import posixpath
import shlex
import threading
import time
from collections.abc import Mapping, Sequence
from datetime import datetime
from enum import StrEnum
from io import BytesIO
from typing import Any, TypedDict, cast
from uuid import uuid4

from daytona import (
    CodeLanguage,
    CreateSandboxFromImageParams,
    CreateSandboxFromSnapshotParams,
    Daytona,
    DaytonaConfig,
    Sandbox,
)

from core.virtual_environment.__base.entities import (
    Arch,
    CommandStatus,
    ConnectionHandle,
    FileState,
    Metadata,
    OperatingSystem,
)
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
from core.virtual_environment.channel.queue_transport import QueueTransportReadCloser
from core.virtual_environment.channel.transport import (
    NopTransportWriteCloser,
    TransportReadCloser,
    TransportWriteCloser,
)

logger = logging.getLogger(__name__)


class _CommandRecord(TypedDict):
    """Record for tracking command execution state."""

    thread: threading.Thread
    exit_code: int | None


class DaytonaEnvironment(VirtualEnvironment):
    """
    Daytona virtual environment provider backed by Daytona Sandboxes.
    """

    _DEFAULT_DAYTONA_API_URL = "https://app.daytona.io/api"

    class OptionsKey(StrEnum):
        API_KEY = "api_key"
        API_URL = "api_url"
        TARGET = "target"
        LANGUAGE = "language"
        SNAPSHOT = "snapshot"
        IMAGE = "image"
        AUTO_STOP_INTERVAL = "auto_stop_interval"
        AUTO_ARCHIVE_INTERVAL = "auto_archive_interval"
        AUTO_DELETE_INTERVAL = "auto_delete_interval"
        PUBLIC = "public"
        NAME = "name"
        LABELS = "labels"

    class StoreKey(StrEnum):
        DAYTONA = "daytona"
        SANDBOX = "sandbox"
        WORKDIR = "workdir"
        COMMANDS = "commands"
        COMMANDS_LOCK = "commands_lock"

    def _construct_environment(self, options: Mapping[str, Any], environments: Mapping[str, str]) -> Metadata:
        config = DaytonaConfig(
            api_key=cast(str | None, options.get(self.OptionsKey.API_KEY)),
            api_url=cast(str | None, options.get(self.OptionsKey.API_URL, self._DEFAULT_DAYTONA_API_URL)),
            target=cast(str | None, options.get(self.OptionsKey.TARGET)),
        )
        daytona = Daytona(config)

        language = cast(CodeLanguage, options.get(self.OptionsKey.LANGUAGE, CodeLanguage.PYTHON))
        auto_stop_interval = cast(int | None, options.get(self.OptionsKey.AUTO_STOP_INTERVAL))
        auto_archive_interval = cast(int | None, options.get(self.OptionsKey.AUTO_ARCHIVE_INTERVAL))
        auto_delete_interval = cast(int | None, options.get(self.OptionsKey.AUTO_DELETE_INTERVAL))
        public = cast(bool | None, options.get(self.OptionsKey.PUBLIC))
        name = cast(str | None, options.get(self.OptionsKey.NAME))
        labels = cast(dict[str, str] | None, options.get(self.OptionsKey.LABELS))

        image = cast(str | None, options.get(self.OptionsKey.IMAGE))
        snapshot = cast(str | None, options.get(self.OptionsKey.SNAPSHOT))

        if image is not None:
            params = CreateSandboxFromImageParams(
                image=image,
                language=language,
                env_vars=dict(environments or {}),
                auto_stop_interval=auto_stop_interval,
                auto_archive_interval=auto_archive_interval,
                auto_delete_interval=auto_delete_interval,
                public=public,
                name=name,
                labels=labels,
            )
        else:
            params = CreateSandboxFromSnapshotParams(
                snapshot=snapshot,
                language=language,
                env_vars=dict(environments or {}),
                auto_stop_interval=auto_stop_interval,
                auto_archive_interval=auto_archive_interval,
                auto_delete_interval=auto_delete_interval,
                public=public,
                name=name,
                labels=labels,
            )

        sandbox = daytona.create(params=params)
        workdir = sandbox.get_work_dir()

        return Metadata(
            id=sandbox.id,
            arch=Arch.AMD64,
            os=OperatingSystem.LINUX,
            store={
                self.StoreKey.DAYTONA: daytona,
                self.StoreKey.SANDBOX: sandbox,
                self.StoreKey.WORKDIR: workdir,
                self.StoreKey.COMMANDS: {},
                self.StoreKey.COMMANDS_LOCK: threading.Lock(),
            },
        )

    def release_environment(self) -> None:
        daytona: Daytona = self.metadata.store[self.StoreKey.DAYTONA]
        sandbox = self.metadata.store[self.StoreKey.SANDBOX]
        try:
            daytona.delete(sandbox)
        except Exception:
            logger.exception("Failed to delete Daytona sandbox %s during cleanup", sandbox.id)

    def establish_connection(self) -> ConnectionHandle:
        return ConnectionHandle(id=uuid4().hex)

    def release_connection(self, connection_handle: ConnectionHandle) -> None:
        pass

    def upload_file(self, path: str, content: BytesIO) -> None:
        remote_path = self._workspace_path(path)
        sandbox: Sandbox = self.metadata.store[self.StoreKey.SANDBOX]
        sandbox.fs.upload_file(content.getvalue(), remote_path)

    def download_file(self, path: str) -> BytesIO:
        remote_path = self._workspace_path(path)
        sandbox: Sandbox = self.metadata.store[self.StoreKey.SANDBOX]
        data = sandbox.fs.download_file(remote_path)
        return BytesIO(data)

    def list_files(self, directory_path: str, limit: int) -> Sequence[FileState]:
        remote_dir = self._workspace_path(directory_path)
        sandbox: Sandbox = self.metadata.store[self.StoreKey.SANDBOX]
        try:
            file_infos = sandbox.fs.list_files(remote_dir)
        except Exception:
            logger.exception("Failed to list files in directory %s", remote_dir)
            return []

        files: list[FileState] = []
        for info in file_infos:
            full_path = posixpath.join(remote_dir, info.name)
            relative_path = posixpath.relpath(full_path, self._working_dir)
            files.append(
                FileState(
                    path=relative_path,
                    size=info.size,
                    created_at=self._parse_mod_time(info.mod_time),
                    updated_at=self._parse_mod_time(info.mod_time),
                )
            )
            if len(files) >= limit:
                break
        return files

    def execute_command(
        self,
        connection_handle: ConnectionHandle,
        command: list[str],
        environments: Mapping[str, str] | None = None,
        cwd: str | None = None,
    ) -> tuple[str, TransportWriteCloser, TransportReadCloser, TransportReadCloser]:
        sandbox: Sandbox = self.metadata.store[self.StoreKey.SANDBOX]

        stdout_stream = QueueTransportReadCloser()
        stderr_stream = QueueTransportReadCloser()
        pid = uuid4().hex

        working_dir = self._workspace_path(cwd) if cwd else self._working_dir

        thread = threading.Thread(
            target=self._exec_thread,
            args=(pid, sandbox, command, environments or {}, working_dir, stdout_stream, stderr_stream),
            daemon=True,
        )

        # Register the command record before starting the thread so that
        # get_command_status can observe it while the command is running.
        commands: dict[str, _CommandRecord] = self.metadata.store[self.StoreKey.COMMANDS]
        commands_lock: threading.Lock = self.metadata.store[self.StoreKey.COMMANDS_LOCK]
        with commands_lock:
            commands[pid] = {"thread": thread, "exit_code": None}

        thread.start()

        return pid, NopTransportWriteCloser(), stdout_stream, stderr_stream

    def get_command_status(self, connection_handle: ConnectionHandle, pid: str) -> CommandStatus:
        commands: dict[str, _CommandRecord] = self.metadata.store[self.StoreKey.COMMANDS]
        commands_lock: threading.Lock = self.metadata.store[self.StoreKey.COMMANDS_LOCK]

        with commands_lock:
            record = commands.get(pid)
            if not record:
                return CommandStatus(status=CommandStatus.Status.COMPLETED, exit_code=None)

            thread: threading.Thread = record["thread"]
            exit_code = record.get("exit_code")
        if thread.is_alive():
            return CommandStatus(status=CommandStatus.Status.RUNNING, exit_code=None)
        return CommandStatus(status=CommandStatus.Status.COMPLETED, exit_code=exit_code)

    @property
    def _working_dir(self) -> str:
        return cast(str, self.metadata.store[self.StoreKey.WORKDIR])

    def _workspace_path(self, path: str) -> str:
        normalized = posixpath.normpath(path)
        if normalized in ("", "."):
            return self._working_dir
        if normalized.startswith("/"):
            return normalized
        return posixpath.join(self._working_dir, normalized)

    def _exec_thread(
        self,
        pid: str,
        sandbox: Sandbox,
        command: list[str],
        environments: Mapping[str, str],
        cwd: str,
        stdout_stream: QueueTransportReadCloser,
        stderr_stream: QueueTransportReadCloser,
    ) -> None:
        commands: dict[str, _CommandRecord] = self.metadata.store[self.StoreKey.COMMANDS]
        commands_lock: threading.Lock = self.metadata.store[self.StoreKey.COMMANDS_LOCK]

        stdout_writer = stdout_stream.get_write_handler()
        stderr_writer = stderr_stream.get_write_handler()
        exit_code: int | None = None
        try:
            response = sandbox.process.exec(
                command=shlex.join(command),
                env=dict(environments),
                cwd=cwd,
            )
            exit_code = response.exit_code
            output = response.artifacts.stdout if response.artifacts and response.artifacts.stdout else response.result
            if output:
                stdout_writer.write(output.encode())
        except Exception as exc:
            stderr_writer.write(str(exc).encode())
            exit_code = 1
        finally:
            stdout_stream.close()
            stderr_stream.close()
            with commands_lock:
                if pid in commands:
                    commands[pid]["exit_code"] = exit_code

    def _parse_mod_time(self, mod_time: str) -> int:
        try:
            cleaned = mod_time.replace("Z", "+00:00")
            return int(datetime.fromisoformat(cleaned).timestamp())
        except (ValueError, AttributeError, OSError):
            logger.warning("Failed to parse modification time '%s', falling back to current time", mod_time)
            return int(time.time())
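Because get_command_status reports RUNNING while the record's thread is alive and COMPLETED once it exits, a caller-side wait loop can be as small as the sketch below. This is a hedged illustration; wait_for_exit is not part of this diff:

import time

from core.virtual_environment.__base.entities import CommandStatus

def wait_for_exit(env, handle, pid: str, poll_interval: float = 0.5) -> int | None:
    """Poll get_command_status until the background exec thread finishes."""
    while True:
        status = env.get_command_status(handle, pid)
        if status.status == CommandStatus.Status.COMPLETED:
            return status.exit_code
        time.sleep(poll_interval)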
635
api/core/virtual_environment/providers/docker_daemon_sandbox.py
Normal file
@ -0,0 +1,635 @@
from __future__ import annotations

import logging
import socket
import tarfile
import threading
from collections.abc import Mapping, Sequence
from enum import IntEnum, StrEnum
from functools import lru_cache
from io import BytesIO
from pathlib import PurePosixPath
from queue import Empty, Queue
from typing import TYPE_CHECKING, Any, cast
from uuid import uuid4

if TYPE_CHECKING:
    import docker
    from docker.models.containers import Container

from configs import dify_config
from core.entities.provider_entities import BasicProviderConfig
from core.virtual_environment.__base.entities import (
    Arch,
    CommandStatus,
    ConnectionHandle,
    FileState,
    Metadata,
    OperatingSystem,
)
from core.virtual_environment.__base.exec import SandboxConfigValidationError, VirtualEnvironmentLaunchFailedError
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
from core.virtual_environment.channel.exec import TransportEOFError
from core.virtual_environment.channel.socket_transport import SocketWriteCloser
from core.virtual_environment.channel.transport import TransportReadCloser, TransportWriteCloser


class DockerStreamType(IntEnum):
    """
    Docker multiplexed stream types.

    When Docker exec runs with tty=False, it multiplexes stdout and stderr over a single
    socket connection. Each frame is prefixed with an 8-byte header:

        [stream_type (1 byte)][0][0][0][payload_size (4 bytes, big-endian)]

    This allows the client to distinguish between stdout (type=1) and stderr (type=2).
    See: https://docs.docker.com/engine/api/v1.41/#operation/ContainerAttach
    """

    STDIN = 0
    STDOUT = 1
    STDERR = 2


class DockerDemuxer:
    """
    Demultiplexes Docker's combined stdout/stderr stream using a producer-consumer pattern.

    Docker exec with tty=False sends stdout and stderr over a single socket,
    each frame prefixed with an 8-byte header:
    - Byte 0: stream type (1=stdout, 2=stderr)
    - Bytes 1-3: reserved (zeros)
    - Bytes 4-7: payload size (big-endian uint32)

    THREAD SAFETY:
    A single background thread reads frames from the socket and dispatches payloads
    to thread-safe queues. This avoids race conditions where multiple threads
    calling _read_next_frame() simultaneously caused frame header/body corruption,
    resulting in incomplete stdout/stderr output.

    TIMEOUT HANDLING:
    Queue.get() uses a timeout to prevent indefinite blocking when the socket is
    closed unexpectedly (e.g., container removed). This allows periodic checks for
    error conditions and closed state.
    """

    _HEADER_SIZE = 8
    _QUEUE_GET_TIMEOUT = 5.0  # seconds

    def __init__(self, sock: socket.SocketIO):
        self._sock = sock
        self._stdout_queue: Queue[bytes | None] = Queue()
        self._stderr_queue: Queue[bytes | None] = Queue()
        self._closed = False
        self._error: BaseException | None = None

        self._demux_thread = threading.Thread(
            target=self._demux_loop,
            daemon=True,
            name="docker-demuxer",
        )
        self._demux_thread.start()

    def _demux_loop(self) -> None:
        try:
            while not self._closed:
                header = self._read_exact(self._HEADER_SIZE)
                if not header or len(header) < self._HEADER_SIZE:
                    break

                frame_type = header[0]
                payload_size = int.from_bytes(header[4:8], "big")

                if payload_size == 0:
                    continue

                payload = self._read_exact(payload_size)
                if not payload:
                    break

                if frame_type == DockerStreamType.STDOUT:
                    self._stdout_queue.put(payload)
                elif frame_type == DockerStreamType.STDERR:
                    self._stderr_queue.put(payload)

        except BaseException as e:
            self._error = e
        finally:
            self._stdout_queue.put(None)
            self._stderr_queue.put(None)

    def _read_exact(self, size: int) -> bytes:
        data = bytearray()
        remaining = size
        while remaining > 0:
            try:
                chunk = self._sock.read(remaining)
                if not chunk:
                    return bytes(data) if data else b""
                data.extend(chunk)
                remaining -= len(chunk)
            except (ConnectionResetError, BrokenPipeError):
                return bytes(data) if data else b""
        return bytes(data)

    def read_stdout(self) -> bytes:
        return self._read_from_queue(self._stdout_queue)

    def read_stderr(self) -> bytes:
        return self._read_from_queue(self._stderr_queue)

    def _read_from_queue(self, queue: Queue[bytes | None]) -> bytes:
        """
        Read from the queue with a timeout to prevent indefinite blocking.

        When the Docker container is removed or the socket is closed unexpectedly,
        the demux thread may be stuck in socket.read(). Using a timeout allows us
        to periodically check for errors and closed state instead of blocking forever.
        """
        if self._error:
            raise TransportEOFError(f"Demuxer error: {self._error}") from self._error

        while True:
            try:
                chunk = queue.get(timeout=self._QUEUE_GET_TIMEOUT)
                if chunk is None:
                    if self._error:
                        raise TransportEOFError(f"Demuxer error: {self._error}")
                    raise TransportEOFError("End of demuxed stream")
                return chunk
            except Empty:
                # Timeout - check if we should continue waiting
                if self._closed:
                    raise TransportEOFError("Demuxer closed")
                if self._error:
                    error = cast(BaseException, self._error)
                    raise TransportEOFError(f"Demuxer error: {error}") from error
                # No error, continue waiting

    def close(self) -> None:
        if not self._closed:
            self._closed = True
            try:
                self._sock.close()
            except Exception:
                logging.error("Failed to close Docker demuxer socket", exc_info=True)


class DemuxedStdoutReader(TransportReadCloser):
    def __init__(self, demuxer: DockerDemuxer):
        self._demuxer = demuxer
        self._buffer = bytearray()

    def read(self, n: int) -> bytes:
        if self._buffer:
            data = bytes(self._buffer[:n])
            del self._buffer[:n]
            if data:
                return data

        chunk = self._demuxer.read_stdout()
        if len(chunk) <= n:
            return chunk

        self._buffer.extend(chunk[n:])
        return chunk[:n]

    def close(self) -> None:
        self._demuxer.close()


class DemuxedStderrReader(TransportReadCloser):
    def __init__(self, demuxer: DockerDemuxer):
        self._demuxer = demuxer
        self._buffer = bytearray()

    def read(self, n: int) -> bytes:
        if self._buffer:
            data = bytes(self._buffer[:n])
            del self._buffer[:n]
            if data:
                return data

        chunk = self._demuxer.read_stderr()
        if len(chunk) <= n:
            return chunk

        self._buffer.extend(chunk[n:])
        return chunk[:n]

    def close(self) -> None:
        self._demuxer.close()
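To make the framing concrete, a minimal sketch that parses one multiplexed frame from a byte string, following the header layout documented above (pure Python, no Docker daemon needed; parse_frame is illustrative, not part of this diff):

from core.virtual_environment.providers.docker_daemon_sandbox import DockerStreamType

def parse_frame(buf: bytes) -> tuple[DockerStreamType, bytes, bytes]:
    """Parse one 8-byte-header frame; return (stream_type, payload, remainder)."""
    stream_type = DockerStreamType(buf[0])  # byte 0: stream type
    payload_size = int.from_bytes(buf[4:8], "big")  # bytes 4-7: big-endian size
    payload = buf[8 : 8 + payload_size]
    return stream_type, payload, buf[8 + payload_size :]

# A stdout frame carrying b"hi\n": type=1, three reserved zero bytes, size=3.
frame = bytes([1, 0, 0, 0]) + (3).to_bytes(4, "big") + b"hi\n"
kind, payload, rest = parse_frame(frame)
assert kind is DockerStreamType.STDOUT and payload == b"hi\n" and rest == b""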
"""
|
||||
EXAMPLE:
|
||||
|
||||
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
from
|
||||
|
||||
from core.virtual_environment.providers.docker_daemon_sandbox import DockerDaemonEnvironment
|
||||
|
||||
options: Mapping[str, Any] = {
|
||||
# OptionsKey values are optional
|
||||
# DockerDaemonEnvironment.OptionsKey.DOCKER_SOCK: "unix:///var/run/docker.sock",
|
||||
# DockerDaemonEnvironment.OptionsKey.DOCKER_AGENT_IMAGE: "ubuntu:latest",
|
||||
# DockerDaemonEnvironment.OptionsKey.DOCKER_AGENT_COMMAND
|
||||
#
|
||||
"docker_sock": "unix:///var/run/docker.sock", # optional, default to unix socket
|
||||
"docker_agent_image": "ubuntu:latest", # optional, default to ubuntu:latest
|
||||
"docker_agent_command": "/bin/sh -c 'while true; do sleep 1; done'", # optional, default to None
|
||||
}
|
||||
|
||||
|
||||
environment = DockerDaemonEnvironment(options=options)
|
||||
connection_handle = environment.establish_connection()
|
||||
|
||||
pid, transport_stdout, transport_stderr, transport_stdin = environment.execute_command(
|
||||
connection_handle, ["uname", "-a"]
|
||||
)
|
||||
|
||||
print(f"Executed command with PID: {pid}")
|
||||
|
||||
# consume stdout
|
||||
# consume stdout
|
||||
while True:
|
||||
try:
|
||||
output = transport_stdout.read(1024)
|
||||
except TransportEOFError:
|
||||
logger.info("End of stdout reached")
|
||||
break
|
||||
|
||||
logger.info("Command output: %s", output.decode().strip())
|
||||
|
||||
|
||||
environment.release_connection(connection_handle)
|
||||
environment.release_environment()
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class DockerDaemonEnvironment(VirtualEnvironment):
|
||||
_WORKING_DIR = "/workspace"
|
||||
_DEAFULT_DOCKER_IMAGE = "ubuntu:latest"
|
||||
_DEFAULT_DOCKER_SOCK = (
|
||||
"unix:///var/run/docker.sock" # Use an invalid default to avoid accidental local docker usage
|
||||
)
|
||||
|
||||
class OptionsKey(StrEnum):
|
||||
DOCKER_SOCK = "docker_sock"
|
||||
DOCKER_IMAGE = "docker_image"
|
||||
DOCKER_COMMAND = "docker_command"
|
||||
|
||||
@classmethod
|
||||
def get_config_schema(cls) -> list[BasicProviderConfig]:
|
||||
return [
|
||||
BasicProviderConfig(type=BasicProviderConfig.Type.TEXT_INPUT, name=cls.OptionsKey.DOCKER_SOCK),
|
||||
BasicProviderConfig(type=BasicProviderConfig.Type.TEXT_INPUT, name=cls.OptionsKey.DOCKER_IMAGE),
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def validate(cls, options: Mapping[str, Any]) -> None:
|
||||
# Import Docker SDK lazily so it is loaded after gevent monkey-patching.
|
||||
import docker.errors
|
||||
|
||||
import docker
|
||||
|
||||
docker_sock = options.get(cls.OptionsKey.DOCKER_SOCK, cls._DEFAULT_DOCKER_SOCK)
|
||||
try:
|
||||
client = docker.DockerClient(base_url=docker_sock)
|
||||
client.ping()
|
||||
except docker.errors.DockerException as e:
|
||||
raise SandboxConfigValidationError(f"Docker connection failed: {e}") from e
|
||||
|
||||
def _construct_environment(self, options: Mapping[str, Any], environments: Mapping[str, str]) -> Metadata:
|
||||
"""
|
||||
Construct the Docker daemon virtual environment.
|
||||
"""
|
||||
|
||||
docker_client = self.get_docker_daemon(
|
||||
docker_sock=options.get(self.OptionsKey.DOCKER_SOCK, self._DEFAULT_DOCKER_SOCK)
|
||||
)
|
||||
|
||||
default_docker_image = options.get(self.OptionsKey.DOCKER_IMAGE, self._DEAFULT_DOCKER_IMAGE)
|
||||
container_command = options.get(self.OptionsKey.DOCKER_COMMAND)
|
||||
|
||||
container = docker_client.containers.run(
|
||||
image=default_docker_image,
|
||||
command=container_command,
|
||||
detach=True,
|
||||
remove=True,
|
||||
stdin_open=True,
|
||||
working_dir=self._WORKING_DIR,
|
||||
environment=dict(environments),
|
||||
)
|
||||
|
||||
# FIXME(yeuoly): For a better solution
|
||||
if dify_config.FILES_URL.startswith("http://localhost") or dify_config.FILES_URL.startswith("http://127.0.0.1"):
|
||||
logging.warning(
|
||||
"DIFY_FILES_URL is set to a localhost address. "
|
||||
"Docker containers may not be able to access the host's localhost. "
|
||||
"Consider using host.docker.internal or the host machine's IP address."
|
||||
)
|
||||
|
||||
dify_host_port = dify_config.DIFY_PORT
|
||||
# launch socat to forward 5001 port from host to container
|
||||
container.exec_run( # pyright: ignore[reportUnknownMemberType] #
|
||||
cmd=[
|
||||
"bash",
|
||||
"-c",
|
||||
f"nohup socat TCP-LISTEN:{dify_host_port},bind=127.0.0.1,fork,reuseaddr "
|
||||
f"TCP:host.docker.internal:{dify_host_port} >/tmp/socat.log 2>&1",
|
||||
],
|
||||
detach=True,
|
||||
)
|
||||
|
||||
# wait for the container to be fully started
|
||||
container.reload()
|
||||
|
||||
if not container.id:
|
||||
raise VirtualEnvironmentLaunchFailedError("Failed to start Docker container for DockerDaemonEnvironment.")
|
||||
|
||||
return Metadata(
|
||||
id=container.id,
|
||||
arch=self._get_container_architecture(container),
|
||||
os=OperatingSystem.LINUX,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@lru_cache(maxsize=5)
|
||||
def get_docker_daemon(cls, docker_sock: str) -> docker.DockerClient:
|
||||
"""
|
||||
Get the Docker daemon client.
|
||||
|
||||
NOTE: I guess nobody will use more than 5 different docker sockets in practice....
|
||||
"""
|
||||
import docker
|
||||
|
||||
return docker.DockerClient(base_url=docker_sock)
|
||||
|
||||
@classmethod
|
||||
@lru_cache(maxsize=5)
|
||||
def get_docker_api_client(cls, docker_sock: str) -> docker.APIClient:
|
||||
"""
|
||||
Get the Docker low-level API client.
|
||||
"""
|
||||
import docker
|
||||
|
||||
return docker.APIClient(base_url=docker_sock)
|
||||
|
||||
def get_docker_sock(self) -> str:
|
||||
"""
|
||||
Get the Docker socket path.
|
||||
"""
|
||||
return self.options.get(self.OptionsKey.DOCKER_SOCK, self._DEFAULT_DOCKER_SOCK)
|
||||
|
||||
@property
|
||||
def _working_dir(self) -> str:
|
||||
"""
|
||||
Get the working directory inside the Docker container.
|
||||
"""
|
||||
return self._WORKING_DIR
|
||||
|
||||
def _get_container(self) -> Container:
|
||||
"""
|
||||
Get the Docker container instance.
|
||||
"""
|
||||
docker_client = self.get_docker_daemon(self.get_docker_sock())
|
||||
return docker_client.containers.get(self.metadata.id)
|
||||
|
||||
def _normalize_relative_path(self, path: str) -> PurePosixPath:
|
||||
parts: list[str] = []
|
||||
for part in PurePosixPath(path).parts:
|
||||
if part in ("", ".", "/"):
|
||||
continue
|
||||
if part == "..":
|
||||
if not parts:
|
||||
raise ValueError("Path escapes the workspace.")
|
||||
parts.pop()
|
||||
continue
|
||||
parts.append(part)
|
||||
return PurePosixPath(*parts)
|
||||
|
||||
def _relative_path(self, path: str) -> PurePosixPath:
|
||||
normalized = self._normalize_relative_path(path)
|
||||
if normalized.parts:
|
||||
return normalized
|
||||
return PurePosixPath()
|
||||
|
||||
def _container_path(self, path: str) -> str:
|
||||
relative = self._relative_path(path)
|
||||
if not relative.parts:
|
||||
return self._working_dir
|
||||
return f"{self._working_dir}/{relative.as_posix()}"
|
||||
|
||||
def _workspace_path(self, path: str) -> str:
|
||||
"""
|
||||
Convert a path to an absolute path in the Docker container.
|
||||
Absolute paths are returned as-is, relative paths are joined with _working_dir.
|
||||
"""
|
||||
        normalized = PurePosixPath(path)
        if normalized.is_absolute():
            return str(normalized)
        return self._container_path(path)

    def upload_file(self, path: str, content: BytesIO) -> None:
        """Upload a file to the container.

        Files and intermediate directories are created with world-writable permissions
        (0o777 for directories, 0o666 for files) to avoid permission issues when the container
        runs as a non-root user but Docker's put_archive creates files as root.
        """
        container = self._get_container()
        normalized = PurePosixPath(path)

        if normalized.is_absolute():
            parent_dir = str(normalized.parent)
            file_name = normalized.name
            payload = content.getvalue()
            tar_stream = BytesIO()
            with tarfile.open(fileobj=tar_stream, mode="w") as tar:
                tar_info = tarfile.TarInfo(name=file_name)
                tar_info.size = len(payload)
                tar_info.mode = 0o666
                tar.addfile(tar_info, BytesIO(payload))
            tar_stream.seek(0)
            container.put_archive(parent_dir, tar_stream.read())  # pyright: ignore[reportUnknownMemberType]
            return

        relative_path = self._relative_path(path)
        if not relative_path.parts:
            raise ValueError("Upload path must point to a file within the workspace.")

        payload = content.getvalue()
        tar_stream = BytesIO()
        with tarfile.open(fileobj=tar_stream, mode="w") as tar:
            # Add intermediate directories with proper permissions
            for i in range(len(relative_path.parts) - 1):
                dir_path = PurePosixPath(*relative_path.parts[: i + 1])
                dir_info = tarfile.TarInfo(name=dir_path.as_posix() + "/")
                dir_info.type = tarfile.DIRTYPE
                dir_info.mode = 0o777
                tar.addfile(dir_info)

            # Add the file
            tar_info = tarfile.TarInfo(name=relative_path.as_posix())
            tar_info.size = len(payload)
            tar_info.mode = 0o666
            tar.addfile(tar_info, BytesIO(payload))
        tar_stream.seek(0)
        container.put_archive(self._working_dir, tar_stream.read())  # pyright: ignore[reportUnknownMemberType]

    def download_file(self, path: str) -> BytesIO:
        container = self._get_container()
        container_path = self._workspace_path(path)
        stream, _ = container.get_archive(container_path)
        tar_stream = BytesIO()
        for chunk in stream:
            tar_stream.write(chunk)
        tar_stream.seek(0)

        with tarfile.open(fileobj=tar_stream, mode="r:*") as tar:
            members = [member for member in tar.getmembers() if member.isfile()]
            if not members:
                return BytesIO()
            extracted = tar.extractfile(members[0])
            if extracted is None:
                return BytesIO()
            return BytesIO(extracted.read())

    def list_files(self, directory_path: str, limit: int) -> Sequence[FileState]:
        import docker.errors

        container = self._get_container()
        container_path = self._container_path(directory_path)
        relative_base = self._relative_path(directory_path)
        try:
            stream, _ = container.get_archive(container_path)
        except docker.errors.NotFound:
            return []
        tar_stream = BytesIO()
        for chunk in stream:
            tar_stream.write(chunk)
        tar_stream.seek(0)

        files: list[FileState] = []
        archive_root = PurePosixPath(container_path).name
        with tarfile.open(fileobj=tar_stream, mode="r:*") as tar:
            for member in tar.getmembers():
                if not member.isfile():
                    continue
                member_path = PurePosixPath(member.name)
                if member_path.parts and member_path.parts[0] == archive_root:
                    member_path = PurePosixPath(*member_path.parts[1:])
                if not member_path.parts:
                    continue
                relative_path = relative_base / member_path
                files.append(
                    FileState(
                        path=relative_path.as_posix(),
                        size=member.size,
                        created_at=int(member.mtime),
                        updated_at=int(member.mtime),
                    )
                )
                if len(files) >= limit:
                    break
        return files

    def establish_connection(self) -> ConnectionHandle:
        return ConnectionHandle(id=uuid4().hex)

    def release_connection(self, connection_handle: ConnectionHandle) -> None:
        # No action needed for Docker exec connections
        pass

    def release_environment(self) -> None:
        import docker.errors

        try:
            container = self._get_container()
        except docker.errors.NotFound:
            return
        try:
            container.remove(force=True)
        except docker.errors.NotFound:
            return
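    # Note on the exec flow below: exec_create registers the command with the
    # Docker daemon, exec_start(socket=True) hands back one raw multiplexed
    # socket, and DockerDemuxer splits its frames back into stdout and stderr.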
    def execute_command(
        self,
        connection_handle: ConnectionHandle,
        command: list[str],
        environments: Mapping[str, str] | None = None,
        cwd: str | None = None,
    ) -> tuple[str, TransportWriteCloser, TransportReadCloser, TransportReadCloser]:
        container = self._get_container()
        container_id = container.id
        if not isinstance(container_id, str) or not container_id:
            raise RuntimeError("Docker container ID is not available for exec.")
        api_client = self.get_docker_api_client(self.get_docker_sock())

        working_dir = self._workspace_path(cwd) if cwd else self._working_dir

        exec_info: dict[str, object] = cast(
            dict[str, object],
            api_client.exec_create(  # pyright: ignore[reportUnknownMemberType]
                container_id,
                cmd=command,
                stdin=True,
                stdout=True,
                stderr=True,
                tty=False,
                workdir=working_dir,
                environment=dict(environments) if environments else None,
            ),
        )

        if not isinstance(exec_info.get("Id"), str):
            raise RuntimeError("Failed to create Docker exec instance.")

        exec_id: str = str(exec_info.get("Id"))
        raw_sock: socket.SocketIO = cast(socket.SocketIO, api_client.exec_start(exec_id, socket=True, tty=False))  # pyright: ignore[reportUnknownMemberType]

        stdin_transport = SocketWriteCloser(raw_sock)
        demuxer = DockerDemuxer(raw_sock)
        stdout_transport = DemuxedStdoutReader(demuxer)
        stderr_transport = DemuxedStderrReader(demuxer)

        return exec_id, stdin_transport, stdout_transport, stderr_transport

    def get_command_status(self, connection_handle: ConnectionHandle, pid: str) -> CommandStatus:
        api_client = self.get_docker_api_client(self.get_docker_sock())
        inspect: dict[str, object] = cast(dict[str, object], api_client.exec_inspect(pid))  # pyright: ignore[reportUnknownMemberType]
        exit_code = inspect.get("ExitCode")
        if inspect.get("Running") or exit_code is None:
            return CommandStatus(status=CommandStatus.Status.RUNNING, exit_code=None)
        if not isinstance(exit_code, int):
            exit_code = None
        return CommandStatus(status=CommandStatus.Status.COMPLETED, exit_code=exit_code)

    def _get_container_architecture(self, container: Container) -> Arch:
        """
        Detect the container's CPU architecture from its image metadata.
        Falls back to ``uname -m`` inside the container when image attrs are unavailable.
        """
        try:
            image = container.image
            arch_str = str(image.attrs.get("Architecture", "")).lower() if image else ""
        except Exception:
            arch_str = ""

        if not arch_str:
            result = container.exec_run("uname -m")
            arch_str = result.output.decode("utf-8", errors="replace").strip().lower() if result.exit_code == 0 else ""

        match arch_str:
            case "x86_64" | "amd64":
                return Arch.AMD64
            case "aarch64" | "arm64":
                return Arch.ARM64
            case _:
                logging.warning("Unknown container architecture '%s', defaulting to AMD64", arch_str)
                return Arch.AMD64
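For readers unfamiliar with the docker SDK's archive API used by `upload_file` above, here is a minimal standalone sketch of the same tar-based upload (assumes the `docker` package and a running container; the helper name and paths are illustrative):

import tarfile
from io import BytesIO

import docker

def put_file(container, parent_dir: str, file_name: str, payload: bytes) -> None:
    # put_archive only accepts tar archives, so wrap the payload in an
    # in-memory single-member tarball before shipping it to the daemon.
    buffer = BytesIO()
    with tarfile.open(fileobj=buffer, mode="w") as tar:
        info = tarfile.TarInfo(name=file_name)
        info.size = len(payload)
        info.mode = 0o666  # world-writable, matching the provider's choice above
        tar.addfile(info, BytesIO(payload))
    buffer.seek(0)
    container.put_archive(parent_dir, buffer.read())

container = docker.from_env().containers.get("my-sandbox")  # illustrative name
put_file(container, "/tmp", "hello.txt", b"hi")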
api/core/virtual_environment/providers/e2b_sandbox.py (new file, 360 lines)
@@ -0,0 +1,360 @@
import logging
import posixpath
import shlex
import threading
from collections.abc import Mapping, Sequence
from enum import StrEnum
from functools import cached_property
from io import BytesIO
from typing import Any
from uuid import uuid4

from configs import dify_config
from core.entities.provider_entities import BasicProviderConfig
from core.virtual_environment.__base.entities import (
    Arch,
    CommandStatus,
    ConnectionHandle,
    FileState,
    Metadata,
    OperatingSystem,
)
from core.virtual_environment.__base.exec import (
    ArchNotSupportedError,
    NotSupportedOperationError,
    SandboxConfigValidationError,
)
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
from core.virtual_environment.channel.queue_transport import QueueTransportReadCloser
from core.virtual_environment.channel.transport import (
    NopTransportWriteCloser,
    TransportReadCloser,
    TransportWriteCloser,
)
from core.virtual_environment.constants import COMMAND_EXECUTION_TIMEOUT_SECONDS

logger = logging.getLogger(__name__)

"""
USAGE:

import logging
from collections.abc import Mapping
from typing import Any

from core.virtual_environment.channel.exec import TransportEOFError
from core.virtual_environment.providers.e2b_sandbox import E2BEnvironment

options: Mapping[str, Any] = {
    E2BEnvironment.OptionsKey.API_KEY: "?????????",
    E2BEnvironment.OptionsKey.E2B_DEFAULT_TEMPLATE: "code-interpreter-v1",
    E2BEnvironment.OptionsKey.E2B_LIST_FILE_DEPTH: 2,
    E2BEnvironment.OptionsKey.E2B_API_URL: "https://api.e2b.app",
}


logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.DEBUG)

# environment = DockerDaemonEnvironment(options=options)
# environment = LocalVirtualEnvironment(options=options)
environment = E2BEnvironment(options=options)

connection_handle = environment.establish_connection()

pid, transport_stdin, transport_stdout, transport_stderr = environment.execute_command(
    connection_handle, ["uname", "-a"]
)

logger.info("Executed command with PID: %s", pid)

# consume stdout
while True:
    try:
        output = transport_stdout.read(1024)
    except TransportEOFError:
        logger.info("End of stdout reached")
        break

    logger.info("Command output: %s", output.decode().strip())


environment.release_connection(connection_handle)
environment.release_environment()

"""

class E2BEnvironment(VirtualEnvironment):
    """
    E2B virtual environment provider.
    """

    _WORKDIR = "/home/user"
    _E2B_API_URL = "https://api.e2b.app"

    class OptionsKey(StrEnum):
        API_KEY = "api_key"
        E2B_LIST_FILE_DEPTH = "e2b_list_file_depth"
        E2B_DEFAULT_TEMPLATE = "e2b_default_template"
        E2B_API_URL = "e2b_api_url"

    class StoreKey(StrEnum):
        SANDBOX = "sandbox"

    @classmethod
    def get_config_schema(cls) -> list[BasicProviderConfig]:
        return [
            BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=cls.OptionsKey.API_KEY),
            BasicProviderConfig(type=BasicProviderConfig.Type.TEXT_INPUT, name=cls.OptionsKey.E2B_API_URL),
            BasicProviderConfig(type=BasicProviderConfig.Type.TEXT_INPUT, name=cls.OptionsKey.E2B_DEFAULT_TEMPLATE),
        ]

    @classmethod
    def validate(cls, options: Mapping[str, Any]) -> None:
        # Import E2B SDK lazily so it is loaded after gevent monkey-patching.
        # See `api/gunicorn.conf.py` for how we patch other third-party libs (e.g. gRPC).
        from e2b.exceptions import (
            AuthenticationException,  # type: ignore[import-untyped]
        )
        from e2b_code_interpreter import Sandbox  # type: ignore[import-untyped]

        api_key = options.get(cls.OptionsKey.API_KEY, "")
        if not api_key:
            raise SandboxConfigValidationError("E2B API key is required")

        try:
            Sandbox.list(api_key=api_key, limit=1).next_items()
        except AuthenticationException as e:
            raise SandboxConfigValidationError(f"E2B authentication failed: {e}") from e
        except Exception as e:
            raise SandboxConfigValidationError(f"E2B connection failed: {e}") from e

    def _construct_environment(self, options: Mapping[str, Any], environments: Mapping[str, str]) -> Metadata:
        """
        Construct a new E2B virtual environment.

        The sandbox lifetime is capped by ``WORKFLOW_MAX_EXECUTION_TIME`` so the
        provider can rely on E2B's native timeout instead of a background
        keepalive thread that continuously extends the session.

        E2B allocates the remote sandbox before metadata probing completes, so
        startup failures must best-effort terminate the sandbox before the
        exception escapes.
        """
        # Import E2B SDK lazily so it is loaded after gevent monkey-patching.
        from e2b_code_interpreter import Sandbox  # type: ignore[import-untyped]

        # TODO: add Dify as the user agent
        sandbox = None
        sandbox_id: str | None = None
        api_key = options.get(self.OptionsKey.API_KEY, "")
        try:
            sandbox = Sandbox.create(
                template=options.get(self.OptionsKey.E2B_DEFAULT_TEMPLATE, "code-interpreter-v1"),
                timeout=dify_config.WORKFLOW_MAX_EXECUTION_TIME,
                api_key=api_key,
                api_url=options.get(self.OptionsKey.E2B_API_URL, self._E2B_API_URL),
                envs=dict(environments),
            )
            info = sandbox.get_info(api_key=api_key)
            sandbox_id = info.sandbox_id
            system_info = sandbox.commands.run("uname -m -s").stdout.strip()
            system_parts = system_info.split()
            if len(system_parts) == 2:
                os_part, arch_part = system_parts
            else:
                arch_part = system_parts[0]
                os_part = system_parts[1] if len(system_parts) > 1 else ""

            return Metadata(
                id=info.sandbox_id,
                arch=self._convert_architecture(arch_part.strip()),
                os=self._convert_operating_system(os_part.strip()),
                store={
                    self.StoreKey.SANDBOX: sandbox,
                },
            )
        except Exception:
            if sandbox_id is None and sandbox is not None:
                sandbox_id = getattr(sandbox, "sandbox_id", None)
            if sandbox_id is not None:
                try:
                    Sandbox.kill(api_key=api_key, sandbox_id=sandbox_id)
                except Exception:
                    logger.exception("Failed to cleanup E2B sandbox after startup failure")
            raise

    def release_environment(self) -> None:
        """
        Release the E2B virtual environment.
        """
        from e2b_code_interpreter import Sandbox  # type: ignore[import-untyped]

        if not Sandbox.kill(api_key=self.api_key, sandbox_id=self.metadata.id):
            raise Exception(f"Failed to release E2B sandbox with ID: {self.metadata.id}")

    def establish_connection(self) -> ConnectionHandle:
        """
        Establish a connection to the E2B virtual environment.
        """
        return ConnectionHandle(id=uuid4().hex)

    def release_connection(self, connection_handle: ConnectionHandle) -> None:
        """
        Release the connection to the E2B virtual environment.
        """
        pass

    def upload_file(self, path: str, content: BytesIO) -> None:
        """
        Upload a file to the E2B virtual environment.

        Args:
            path (str): The path to upload the file to.
            content (BytesIO): The content of the file.
        """
        remote_path = self._workspace_path(path)
        sandbox = self.metadata.store[self.StoreKey.SANDBOX]
        sandbox.files.write(remote_path, content)  # pyright: ignore[reportUnknownMemberType]

    def download_file(self, path: str) -> BytesIO:
        """
        Download a file from the E2B virtual environment.

        Args:
            path (str): The path to download the file from.
        Returns:
            BytesIO: The content of the file.
        """
        remote_path = self._workspace_path(path)
        sandbox = self.metadata.store[self.StoreKey.SANDBOX]
        content = sandbox.files.read(remote_path, format="bytes")
        return BytesIO(bytes(content))

    def list_files(self, directory_path: str, limit: int) -> Sequence[FileState]:
        """
        List files in a directory of the E2B virtual environment.
        """
        sandbox = self.metadata.store[self.StoreKey.SANDBOX]
        remote_dir = self._workspace_path(directory_path)
        files_info = sandbox.files.list(remote_dir, depth=self.options.get(self.OptionsKey.E2B_LIST_FILE_DEPTH, 3))
        return [
            FileState(
                path=posixpath.relpath(file_info.path, self._WORKDIR),
                size=file_info.size,
                created_at=int(file_info.modified_time.timestamp()),
                updated_at=int(file_info.modified_time.timestamp()),
            )
            for file_info in files_info
        ]

    def execute_command(
        self,
        connection_handle: ConnectionHandle,
        command: list[str],
        environments: Mapping[str, str] | None = None,
        cwd: str | None = None,
    ) -> tuple[str, TransportWriteCloser, TransportReadCloser, TransportReadCloser]:
"""
|
||||
Execute a command in the E2B virtual environment.
|
||||
|
||||
STDIN is not yet supported. E2B's API is such a terrible mess... to support it may lead a bad design.
|
||||
as a result we leave it for future improvement.
|
||||
"""
|
||||
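        # The command runs on a background thread; its stdout/stderr callbacks
        # feed the queue-backed transports returned below, and close() on the
        # streams signals EOF to the reader.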
        sandbox = self.metadata.store[self.StoreKey.SANDBOX]
        stdout_stream = QueueTransportReadCloser()
        stderr_stream = QueueTransportReadCloser()

        working_dir = self._workspace_path(cwd) if cwd else self._WORKDIR

        threading.Thread(
            target=self._cmd_thread,
            args=(sandbox, command, environments, working_dir, stdout_stream, stderr_stream),
        ).start()

        return (
            "N/A",
            NopTransportWriteCloser(),  # stdin not supported yet
            stdout_stream,
            stderr_stream,
        )

    def get_command_status(self, connection_handle: ConnectionHandle, pid: str) -> CommandStatus:
        """
        Nop, E2B does not support getting command status yet.
        """
        raise NotSupportedOperationError("E2B does not support getting command status yet.")

    def _cmd_thread(
        self,
        sandbox: Any,
        command: list[str],
        environments: Mapping[str, str] | None,
        cwd: str,
        stdout_stream: QueueTransportReadCloser,
        stderr_stream: QueueTransportReadCloser,
    ) -> None:
        stdout_stream_write_handler = stdout_stream.get_write_handler()
        stderr_stream_write_handler = stderr_stream.get_write_handler()

        try:
            sandbox.commands.run(
                cmd=shlex.join(command),
                envs=dict(environments or {}),
                cwd=cwd,
                on_stdout=lambda data: stdout_stream_write_handler.write(data.encode()),
                on_stderr=lambda data: stderr_stream_write_handler.write(data.encode()),
                timeout=COMMAND_EXECUTION_TIMEOUT_SECONDS,
            )
        except Exception as e:
            # Capture exceptions and write to stderr stream so they can be retrieved via CommandFuture
            # This prevents uncaught exceptions from being printed to console
            error_msg = f"Command execution failed: {type(e).__name__}: {str(e)}\n"
            stderr_stream_write_handler.write(error_msg.encode())
        finally:
            # Close the write handlers to signal EOF
            stdout_stream.close()
            stderr_stream.close()

    @cached_property
    def api_key(self) -> str:
        """
        Get the API key for the E2B environment.
        """
        return self.options.get(self.OptionsKey.API_KEY, "")

    def _workspace_path(self, path: str) -> str:
        """
        Convert a path to an absolute path in the E2B environment.
        Absolute paths are returned as-is, relative paths are joined with _WORKDIR.
        """
        normalized = posixpath.normpath(path)
        if normalized in ("", "."):
            return self._WORKDIR
        if normalized.startswith("/"):
            return normalized
        return posixpath.join(self._WORKDIR, normalized)

    def _convert_architecture(self, arch_str: str) -> Arch:
        arch_map = {
            "x86_64": Arch.AMD64,
            "aarch64": Arch.ARM64,
            "armv7l": Arch.ARM64,
            "arm64": Arch.ARM64,
            "amd64": Arch.AMD64,
            "arm64v8": Arch.ARM64,
            "arm64v7": Arch.ARM64,
        }
        if arch_str in arch_map:
            return arch_map[arch_str]

        raise ArchNotSupportedError(f"Unsupported architecture: {arch_str}")

    def _convert_operating_system(self, os_str: str) -> OperatingSystem:
        os_map = {
            "Linux": OperatingSystem.LINUX,
            "Darwin": OperatingSystem.DARWIN,
        }
        if os_str in os_map:
            return os_map[os_str]

        raise ArchNotSupportedError(f"Unsupported operating system: {os_str}")
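A side note on the `shlex.join` call in `_cmd_thread` above: it renders an argv list as a single safely quoted shell string, which is what E2B's string-based `commands.run` expects. A minimal illustration:

import shlex

print(shlex.join(["echo", "hello world", "$HOME"]))
# prints: echo 'hello world' '$HOME' -- whitespace and shell metacharacters stay literal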
api/core/virtual_environment/providers/local_without_isolation.py (new file, 308 lines)
@@ -0,0 +1,308 @@
import os
import pathlib
import shutil
import signal
import subprocess
from collections.abc import Mapping, Sequence
from functools import cached_property
from io import BytesIO
from platform import machine, system
from typing import Any
from uuid import uuid4

from core.entities.provider_entities import BasicProviderConfig
from core.virtual_environment.__base.entities import (
    Arch,
    CommandStatus,
    ConnectionHandle,
    FileState,
    Metadata,
    OperatingSystem,
)
from core.virtual_environment.__base.exec import ArchNotSupportedError
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
from core.virtual_environment.channel.pipe_transport import PipeReadCloser, PipeWriteCloser
from core.virtual_environment.channel.transport import TransportReadCloser, TransportWriteCloser

"""
USAGE:

import logging
from collections.abc import Mapping
from typing import Any

from core.virtual_environment.channel.exec import TransportEOFError
from core.virtual_environment.providers.local_without_isolation import LocalVirtualEnvironment

options: Mapping[str, Any] = {}

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.DEBUG)

environment = LocalVirtualEnvironment(options=options)

connection_handle = environment.establish_connection()

pid, transport_stdin, transport_stdout, transport_stderr = environment.execute_command(
    connection_handle,
    ["sh", "-lc", "for i in 1 2 3 4 5; do date '+%F %T'; sleep 1; done"],
)

logger.info("Executed command with PID: %s", pid)

# consume stdout
while True:
    try:
        output = transport_stdout.read(1024)
    except TransportEOFError:
        logger.info("End of stdout reached")
        break

    logger.info("Command output: %s", output.decode().strip())


environment.release_connection(connection_handle)
environment.release_environment()

"""


class LocalVirtualEnvironment(VirtualEnvironment):
    """
    Local virtual environment provider without isolation.

    WARNING: This provider does not provide any isolation. It's only suitable for development and testing purposes.
    NEVER USE IT IN PRODUCTION ENVIRONMENTS.
    """

    @classmethod
    def get_config_schema(cls) -> list[BasicProviderConfig]:
        return [
            BasicProviderConfig(type=BasicProviderConfig.Type.TEXT_INPUT, name="base_working_path"),
        ]

    @classmethod
    def validate(cls, options: Mapping[str, Any]) -> None:
        pass

    def _construct_environment(self, options: Mapping[str, Any], environments: Mapping[str, str]) -> Metadata:
"""
|
||||
Construct the local virtual environment.
|
||||
|
||||
Under local without isolation, this method simply create a path for the environment and return the metadata.
|
||||
"""
|
||||
        id = uuid4().hex
        working_path = os.path.join(self._base_working_path, id)
        os.makedirs(working_path, exist_ok=True)
        return Metadata(
            id=id,
            arch=self._get_os_architecture(),
            os=self._get_operating_system(),
        )

    def release_environment(self) -> None:
"""
|
||||
Release the local virtual environment.
|
||||
|
||||
Just simply remove the working directory.
|
||||
"""
|
||||
        working_path = self.get_working_path()
        if os.path.exists(working_path):
            shutil.rmtree(working_path)

    def upload_file(self, path: str, content: BytesIO) -> None:
        """
        Upload a file to the local virtual environment.

        Args:
            path (str): The path to upload the file to.
            content (BytesIO): The content of the file.
        """
        working_path = self.get_working_path()
        full_path = os.path.join(working_path, path)
        os.makedirs(os.path.dirname(full_path), exist_ok=True)
        pathlib.Path(full_path).write_bytes(content.getbuffer())

    def download_file(self, path: str) -> BytesIO:
        """
        Download a file from the local virtual environment.

        Args:
            path (str): The path to download the file from.
        Returns:
            BytesIO: The content of the file.
        """
        working_path = self.get_working_path()
        full_path = os.path.join(working_path, path)
        content = pathlib.Path(full_path).read_bytes()
        return BytesIO(content)

    def list_files(self, directory_path: str, limit: int) -> Sequence[FileState]:
        """
        List files in a directory of the local virtual environment.
        """
        working_path = self.get_working_path()
        full_directory_path = os.path.join(working_path, directory_path)
        files: list[FileState] = []
        for root, _, filenames in os.walk(full_directory_path):
            for filename in filenames:
                if len(files) >= limit:
                    break
                file_path = os.path.relpath(os.path.join(root, filename), working_path)
                state = os.stat(os.path.join(root, filename))
                files.append(
                    FileState(
                        path=file_path,
                        size=state.st_size,
                        created_at=int(state.st_ctime),
                        updated_at=int(state.st_mtime),
                    )
                )
            if len(files) >= limit:
                # break the outer loop as well
                return files

        return files

    def establish_connection(self) -> ConnectionHandle:
        """
        Establish a connection to the local virtual environment.
        """
        return ConnectionHandle(
            id=uuid4().hex,
        )

    def release_connection(self, connection_handle: ConnectionHandle) -> None:
        """
        Release the connection to the local virtual environment.
        """
        # No action needed for local without isolation
        pass

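    # execute_command below wires the child process to the caller through three
    # anonymous pipes; the parent keeps only its own ends (stdin writer,
    # stdout/stderr readers) and closes the rest so readers see EOF when the
    # child exits.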
    def execute_command(
        self,
        connection_handle: ConnectionHandle,
        command: list[str],
        environments: Mapping[str, str] | None = None,
        cwd: str | None = None,
    ) -> tuple[str, TransportWriteCloser, TransportReadCloser, TransportReadCloser]:
        working_path = os.path.join(self.get_working_path(), cwd) if cwd else self.get_working_path()
        stdin_read_fd, stdin_write_fd = os.pipe()
        stdout_read_fd, stdout_write_fd = os.pipe()
        stderr_read_fd, stderr_write_fd = os.pipe()
        try:
            process = subprocess.Popen(
                command,
                stdin=stdin_read_fd,
                stdout=stdout_write_fd,
                stderr=stderr_write_fd,
                cwd=working_path,
                close_fds=True,
                env=environments,
            )
        except Exception:
            # Clean up file descriptors if process creation fails
            for fd in (
                stdin_read_fd,
                stdin_write_fd,
                stdout_read_fd,
                stdout_write_fd,
                stderr_read_fd,
                stderr_write_fd,
            ):
                try:
                    os.close(fd)
                except OSError:
                    pass
            raise

        # Close unused fds in the parent process
        os.close(stdin_read_fd)
        os.close(stdout_write_fd)
        os.close(stderr_write_fd)

        # Create PipeTransport instances for stdin, stdout, and stderr
        stdin_transport = PipeWriteCloser(w_fd=stdin_write_fd)
        stdout_transport = PipeReadCloser(r_fd=stdout_read_fd)
        stderr_transport = PipeReadCloser(r_fd=stderr_read_fd)

        # Return the process ID and file descriptors for stdin, stdout, and stderr
        return str(process.pid), stdin_transport, stdout_transport, stderr_transport

    def get_command_status(self, connection_handle: ConnectionHandle, pid: str) -> CommandStatus:
        pid_int = int(pid)
        try:
            waited_pid, wait_status = os.waitpid(pid_int, os.WNOHANG)
            if waited_pid == 0:
                return CommandStatus(status=CommandStatus.Status.RUNNING, exit_code=None)

            if os.WIFEXITED(wait_status):
                exit_code = os.WEXITSTATUS(wait_status)
            elif os.WIFSIGNALED(wait_status):
                exit_code = -os.WTERMSIG(wait_status)
            else:
                exit_code = None

            return CommandStatus(status=CommandStatus.Status.COMPLETED, exit_code=exit_code)
        except ChildProcessError:
            return CommandStatus(status=CommandStatus.Status.COMPLETED, exit_code=None)

    def terminate_command(self, connection_handle: ConnectionHandle, pid: str) -> bool:
        """Terminate a locally spawned process by PID when cancellation is requested."""

        _ = connection_handle
        try:
            os.kill(int(pid), signal.SIGTERM)
        except ProcessLookupError:
            return False
        return True

    def _get_os_architecture(self) -> Arch:
        """
        Get the operating system architecture.

        Returns:
            Arch: The operating system architecture.
        """

        arch = machine()
        match arch.lower():
            case "x86_64" | "amd64":
                return Arch.AMD64
            case "aarch64" | "arm64":
                return Arch.ARM64
            case _:
                raise ArchNotSupportedError(f"Unsupported architecture: {arch}")

    def _get_operating_system(self) -> OperatingSystem:
        os_name = system().lower()
        match os_name:
            case "linux":
                return OperatingSystem.LINUX
            case "darwin":
                return OperatingSystem.DARWIN
            case _:
                raise ArchNotSupportedError(f"Unsupported operating system: {os_name}")

    @cached_property
    def _base_working_path(self) -> str:
"""
|
||||
Get the base working path for the local virtual environment.
|
||||
|
||||
Args:
|
||||
options (Mapping[str, Any]): Options for requesting the virtual environment.
|
||||
|
||||
Returns:
|
||||
str: The base working path.
|
||||
"""
|
||||
        cwd = os.getcwd()
        return self.options.get("base_working_path", os.path.join(cwd, "local_virtual_environments"))

    def get_working_path(self) -> str:
        """
        Get the working path for the local virtual environment.

        Returns:
            str: The working path.
        """
        return os.path.join(self._base_working_path, self.metadata.id)
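The non-blocking status check in `get_command_status` above relies on POSIX wait semantics. A small self-contained sketch of the same pattern (POSIX only; the command is illustrative):

import os
import subprocess
import time

proc = subprocess.Popen(["sh", "-c", "sleep 1; exit 7"])
while True:
    # With WNOHANG, waitpid returns (0, 0) while the child is still running.
    pid, status = os.waitpid(proc.pid, os.WNOHANG)
    if pid != 0:
        if os.WIFEXITED(status):
            print("exit code:", os.WEXITSTATUS(status))  # -> 7
        elif os.WIFSIGNALED(status):
            print("killed by signal:", os.WTERMSIG(status))
        break
    time.sleep(0.1)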
api/core/virtual_environment/providers/ssh_sandbox.py (new file, 499 lines)
@@ -0,0 +1,499 @@
from __future__ import annotations

import contextlib
import logging
import shlex
import stat
import threading
import time
from collections.abc import Mapping, Sequence
from enum import StrEnum
from io import BytesIO
from pathlib import PurePosixPath
from typing import Any
from uuid import uuid4

from core.entities.provider_entities import BasicProviderConfig
from core.virtual_environment.__base.entities import (
    Arch,
    CommandStatus,
    ConnectionHandle,
    FileState,
    Metadata,
    OperatingSystem,
)
from core.virtual_environment.__base.exec import SandboxConfigValidationError, VirtualEnvironmentLaunchFailedError
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
from core.virtual_environment.channel.exec import TransportEOFError
from core.virtual_environment.channel.queue_transport import QueueTransportReadCloser
from core.virtual_environment.channel.transport import TransportWriteCloser
from core.virtual_environment.constants import COMMAND_EXECUTION_TIMEOUT_SECONDS

logger = logging.getLogger(__name__)


class _SSHStdinTransport(TransportWriteCloser):
    def __init__(self, channel: Any):
        self._channel = channel
        self._closed = False

    def write(self, data: bytes) -> None:
        if self._closed:
            raise TransportEOFError("Transport is closed")
        if not data:
            return
        self._channel.sendall(data)

    def close(self) -> None:
        if self._closed:
            return
        self._closed = True
        with contextlib.suppress(Exception):
            self._channel.shutdown_write()


class SSHSandboxEnvironment(VirtualEnvironment):
    _DEFAULT_SSH_HOST = "agentbox"
    _DEFAULT_SSH_PORT = 22
    _DEFAULT_BASE_WORKING_PATH = "/workspace/sandboxes"
    _SSH_CONNECT_TIMEOUT_SECONDS = 10
    _SSH_OPERATION_TIMEOUT_SECONDS = 30
    _COMMAND_TIMEOUT_EXIT_CODE = 124

    class OptionsKey(StrEnum):
        SSH_HOST = "ssh_host"
        SSH_PORT = "ssh_port"
        SSH_USERNAME = "ssh_username"
        SSH_PASSWORD = "ssh_password"
        BASE_WORKING_PATH = "base_working_path"

    def __init__(
        self,
        tenant_id: str,
        options: Mapping[str, Any],
        environments: Mapping[str, str] | None = None,
        user_id: str | None = None,
    ) -> None:
        self._connections: dict[str, Any] = {}
        self._commands: dict[str, CommandStatus] = {}
        self._command_channels: dict[str, Any] = {}
        self._lock = threading.Lock()
        super().__init__(tenant_id=tenant_id, options=options, environments=environments, user_id=user_id)

    @classmethod
    def get_config_schema(cls) -> list[BasicProviderConfig]:
        return [
            BasicProviderConfig(type=BasicProviderConfig.Type.TEXT_INPUT, name=cls.OptionsKey.SSH_HOST),
            BasicProviderConfig(type=BasicProviderConfig.Type.TEXT_INPUT, name=cls.OptionsKey.SSH_PORT),
            BasicProviderConfig(type=BasicProviderConfig.Type.TEXT_INPUT, name=cls.OptionsKey.SSH_USERNAME),
            BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=cls.OptionsKey.SSH_PASSWORD),
            BasicProviderConfig(type=BasicProviderConfig.Type.TEXT_INPUT, name=cls.OptionsKey.BASE_WORKING_PATH),
        ]

    @classmethod
    def validate(cls, options: Mapping[str, Any]) -> None:
        cls._require_non_empty_option(options, cls.OptionsKey.SSH_USERNAME)
        cls._require_non_empty_option(options, cls.OptionsKey.SSH_PASSWORD)
        with cls._create_ssh_client(options):
            return

    def _construct_environment(self, options: Mapping[str, Any], environments: Mapping[str, str]) -> Metadata:
        environment_id = uuid4().hex
        working_path = self._workspace_path_from_id(environment_id)

        try:
            with self._client() as client:
                self._run_command(client, f"mkdir -p {shlex.quote(working_path)}")
                arch_stdout = self._run_command(client, "uname -m")
                os_stdout = self._run_command(client, "uname -s")
        except SandboxConfigValidationError as e:
            raise ValueError(f"SSH configuration validation failed, please check sandbox provider: {e}") from e
        except Exception as e:
            raise VirtualEnvironmentLaunchFailedError(f"Failed to construct SSH environment: {e}") from e

        return Metadata(
            id=environment_id,
            arch=self._parse_arch(arch_stdout.decode("utf-8", errors="replace").strip()),
            os=self._parse_os(os_stdout.decode("utf-8", errors="replace").strip()),
            store={"working_path": working_path},
        )

    def establish_connection(self) -> ConnectionHandle:
        connection_id = uuid4().hex
        client = self._create_ssh_client(self.options)
        with self._lock:
            self._connections[connection_id] = client
        return ConnectionHandle(id=connection_id)

    def release_connection(self, connection_handle: ConnectionHandle) -> None:
        with self._lock:
            client = self._connections.pop(connection_handle.id, None)
        if client is not None:
            with contextlib.suppress(Exception):
                client.close()

    def release_environment(self) -> None:
        working_path = self.get_working_path()
        with contextlib.suppress(Exception):
            with self._client() as client:
                self._run_command(client, f"rm -rf {shlex.quote(working_path)}")

    def execute_command(
        self,
        connection_handle: ConnectionHandle,
        command: list[str],
        environments: Mapping[str, str] | None = None,
        cwd: str | None = None,
    ) -> tuple[str, TransportWriteCloser, QueueTransportReadCloser, QueueTransportReadCloser]:
        client = self._get_connection(connection_handle)
        transport = client.get_transport()
        if transport is None:
            raise RuntimeError("SSH transport is not available")

        channel = transport.open_session()
        channel.settimeout(self._SSH_OPERATION_TIMEOUT_SECONDS)
        channel.set_combine_stderr(False)

        execution_command = self._build_exec_command(command, environments, cwd)
        channel.exec_command(execution_command)

        pid = uuid4().hex
        stdin_transport = _SSHStdinTransport(channel)
        stdout_transport = QueueTransportReadCloser()
        stderr_transport = QueueTransportReadCloser()

        with self._lock:
            self._commands[pid] = CommandStatus(status=CommandStatus.Status.RUNNING, exit_code=None)
            self._command_channels[pid] = channel

        threading.Thread(
            target=self._consume_channel_output,
            args=(pid, channel, stdout_transport, stderr_transport, COMMAND_EXECUTION_TIMEOUT_SECONDS),
            daemon=True,
        ).start()

        return pid, stdin_transport, stdout_transport, stderr_transport

    def get_command_status(self, connection_handle: ConnectionHandle, pid: str) -> CommandStatus:
        with self._lock:
            status = self._commands.get(pid)
        if status is None:
            return CommandStatus(status=CommandStatus.Status.COMPLETED, exit_code=None)
        return status

    def terminate_command(self, connection_handle: ConnectionHandle, pid: str) -> bool:
        """Best-effort termination by closing the SSH channel that owns the command."""

        _ = connection_handle
        with self._lock:
            channel = self._command_channels.get(pid)
            if channel is None:
                return False
            self._commands[pid] = CommandStatus(
                status=CommandStatus.Status.COMPLETED,
                exit_code=self._COMMAND_TIMEOUT_EXIT_CODE,
            )

        with contextlib.suppress(Exception):
            channel.close()
        return True

    def upload_file(self, path: str, content: BytesIO) -> None:
        destination_path = self._workspace_path(path)
        with self._client() as client:
            sftp = client.open_sftp()
            try:
                self._set_sftp_operation_timeout(sftp)
                self._sftp_mkdirs(sftp, str(PurePosixPath(destination_path).parent))
                with sftp.file(destination_path, "wb") as remote_file:
                    remote_file.write(content.getvalue())
            finally:
                sftp.close()

    def download_file(self, path: str) -> BytesIO:
        source_path = self._workspace_path(path)
        with self._client() as client:
            sftp = client.open_sftp()
            try:
                self._set_sftp_operation_timeout(sftp)
                with sftp.file(source_path, "rb") as remote_file:
                    return BytesIO(remote_file.read())
            finally:
                sftp.close()

    def list_files(self, directory_path: str, limit: int) -> Sequence[FileState]:
        if limit <= 0:
            return []

        root_directory = self._workspace_path(directory_path)
        files: list[FileState] = []

        with self._client() as client:
            sftp = client.open_sftp()
            try:
                self._set_sftp_operation_timeout(sftp)
                pending = [root_directory]
                while pending and len(files) < limit:
                    current_directory = pending.pop(0)
                    with contextlib.suppress(FileNotFoundError):
                        for attr in sftp.listdir_attr(current_directory):
                            current_path = str(PurePosixPath(current_directory) / attr.filename)
                            mode = attr.st_mode
                            if stat.S_ISDIR(mode):
                                pending.append(current_path)
                                continue

                            files.append(
                                FileState(
                                    path=self._to_relative_workspace_path(current_path),
                                    size=attr.st_size,
                                    created_at=int(attr.st_mtime),
                                    updated_at=int(attr.st_mtime),
                                )
                            )
                            if len(files) >= limit:
                                break
            finally:
                sftp.close()

        return files

    @classmethod
    def _require_non_empty_option(cls, options: Mapping[str, Any], key: OptionsKey) -> str:
        value = options.get(key)
        if not isinstance(value, str) or not value.strip():
            raise SandboxConfigValidationError(f"Missing required option: {key}")
        return value.strip()

    @classmethod
    def _create_ssh_client(cls, options: Mapping[str, Any]) -> Any:
        import paramiko

        host = options.get(cls.OptionsKey.SSH_HOST, cls._DEFAULT_SSH_HOST)
        port = options.get(cls.OptionsKey.SSH_PORT, cls._DEFAULT_SSH_PORT)
        username = cls._require_non_empty_option(options, cls.OptionsKey.SSH_USERNAME)
        password = cls._require_non_empty_option(options, cls.OptionsKey.SSH_PASSWORD)

        if not isinstance(host, str) or not host.strip():
            raise SandboxConfigValidationError(f"Invalid option value: {cls.OptionsKey.SSH_HOST}")

        try:
            port_int = int(port)
        except (TypeError, ValueError) as e:
            raise SandboxConfigValidationError(f"Invalid option value: {cls.OptionsKey.SSH_PORT}") from e

        client = paramiko.SSHClient()
        client.set_missing_host_key_policy(paramiko.AutoAddPolicy())

        try:
            client.connect(
                hostname=host.strip(),
                port=port_int,
                username=username,
                password=password,
                look_for_keys=False,
                allow_agent=False,
                timeout=cls._SSH_CONNECT_TIMEOUT_SECONDS,
                banner_timeout=cls._SSH_CONNECT_TIMEOUT_SECONDS,
                auth_timeout=cls._SSH_CONNECT_TIMEOUT_SECONDS,
            )
            transport = client.get_transport()
            if transport is not None:
                transport.set_keepalive(30)
        except Exception as e:
            with contextlib.suppress(Exception):
                client.close()
            raise SandboxConfigValidationError(f"SSH connection failed: {e}") from e

        return client

    @contextlib.contextmanager
    def _client(self):
        client = self._create_ssh_client(self.options)
        try:
            yield client
        finally:
            with contextlib.suppress(Exception):
                client.close()

    def _get_connection(self, connection_handle: ConnectionHandle) -> Any:
        with self._lock:
            client = self._connections.get(connection_handle.id)
        if client is None:
            raise ValueError(f"Connection handle not found: {connection_handle.id}")
        return client

    def _workspace_path_from_id(self, environment_id: str) -> str:
        base_path = self.options.get(self.OptionsKey.BASE_WORKING_PATH, self._DEFAULT_BASE_WORKING_PATH)
        if not isinstance(base_path, str) or not base_path.strip():
            base_path = self._DEFAULT_BASE_WORKING_PATH
        return str(PurePosixPath(base_path) / environment_id)

    def get_working_path(self) -> str:
        working_path = self.metadata.store.get("working_path")
        if not isinstance(working_path, str) or not working_path:
            return self._workspace_path_from_id(self.metadata.id)
        return working_path

    def _workspace_path(self, path: str | None) -> str:
        if not path:
            return self.get_working_path()

        normalized = PurePosixPath(path)
        if normalized.is_absolute():
            return str(normalized)
        return str(PurePosixPath(self.get_working_path()) / self._normalize_relative_path(path))

    @staticmethod
    def _normalize_relative_path(path: str) -> PurePosixPath:
        parts: list[str] = []
        for part in PurePosixPath(path).parts:
            if part in ("", ".", "/"):
                continue
            if part == "..":
                if not parts:
                    raise ValueError("Path escapes the workspace.")
                parts.pop()
                continue
            parts.append(part)
        return PurePosixPath(*parts)

    def _to_relative_workspace_path(self, path: str) -> str:
        workspace = PurePosixPath(self.get_working_path())
        target = PurePosixPath(path)
        if target.is_relative_to(workspace):
            return target.relative_to(workspace).as_posix()
        return target.as_posix()

    def _build_exec_command(
        self, command: list[str], environments: Mapping[str, str] | None = None, cwd: str | None = None
    ) -> str:
        working_path = self._workspace_path(cwd)
        command_body = f"cd {shlex.quote(working_path)} && "

        if environments:
            env_clause = " ".join(f"{key}={shlex.quote(value)}" for key, value in environments.items())
            command_body += f"{env_clause} "

        command_body += shlex.join(command)
        return f"sh -lc {shlex.quote(command_body)}"

    @classmethod
    def _run_command(cls, client: Any, command: str) -> bytes:
        _, stdout, stderr = client.exec_command(command, timeout=cls._SSH_OPERATION_TIMEOUT_SECONDS)
        stdout.channel.settimeout(cls._SSH_OPERATION_TIMEOUT_SECONDS)

        deadline = time.monotonic() + COMMAND_EXECUTION_TIMEOUT_SECONDS
        while not stdout.channel.exit_status_ready():
            if time.monotonic() >= deadline:
                with contextlib.suppress(Exception):
                    stdout.channel.close()
                raise TimeoutError(f"SSH command timed out after {COMMAND_EXECUTION_TIMEOUT_SECONDS}s")
            time.sleep(0.05)

        exit_code = stdout.channel.recv_exit_status()
        stdout_data = stdout.read()
        stderr_data = stderr.read()

        if exit_code != 0:
            stderr_text = stderr_data.decode("utf-8", errors="replace")
            raise RuntimeError(f"SSH command failed ({exit_code}): {stderr_text}")

        return stdout_data

    def _consume_channel_output(
        self,
        pid: str,
        channel: Any,
        stdout_transport: QueueTransportReadCloser,
        stderr_transport: QueueTransportReadCloser,
        max_runtime_seconds: int,
    ) -> None:
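        # Single daemon thread multiplexes both streams: drain with
        # recv_ready()/recv_stderr_ready(), enforce the max-runtime deadline,
        # and record the exit code only once the channel reports completion
        # and both buffers are empty.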
        stdout_writer = stdout_transport.get_write_handler()
        stderr_writer = stderr_transport.get_write_handler()
        exit_code: int | None = None
        started_at = time.monotonic()

        try:
            while True:
                if time.monotonic() - started_at >= max_runtime_seconds:
                    exit_code = self._COMMAND_TIMEOUT_EXIT_CODE
                    stderr_writer.write(f"Command timed out after {max_runtime_seconds}s".encode())
                    break

                if channel.recv_ready():
                    stdout_writer.write(channel.recv(4096))
                if channel.recv_stderr_ready():
                    stderr_writer.write(channel.recv_stderr(4096))

                if channel.exit_status_ready() and not channel.recv_ready() and not channel.recv_stderr_ready():
                    exit_code = int(channel.recv_exit_status())
                    break

                time.sleep(0.05)
        except TimeoutError:
            logger.warning("SSH channel read timed out for command %s", pid)
            exit_code = self._COMMAND_TIMEOUT_EXIT_CODE
        finally:
            with contextlib.suppress(Exception):
                stdout_transport.close()
            with contextlib.suppress(Exception):
                stderr_transport.close()
            with contextlib.suppress(Exception):
                channel.close()

            with self._lock:
                self._command_channels.pop(pid, None)
                self._commands[pid] = CommandStatus(status=CommandStatus.Status.COMPLETED, exit_code=exit_code)

    def _set_sftp_operation_timeout(self, sftp: Any) -> None:
        with contextlib.suppress(Exception):
            sftp.get_channel().settimeout(self._SSH_OPERATION_TIMEOUT_SECONDS)

    @staticmethod
    def _parse_arch(raw_arch: str) -> Arch:
        arch = raw_arch.lower()
        if arch in {"x86_64", "amd64"}:
            return Arch.AMD64
        if arch in {"arm64", "aarch64"}:
            return Arch.ARM64
        return Arch.AMD64

    @staticmethod
    def _parse_os(raw_os: str) -> OperatingSystem:
        system_name = raw_os.lower()
        if system_name == "darwin":
            return OperatingSystem.DARWIN
        return OperatingSystem.LINUX

    @staticmethod
    def _sftp_mkdirs(sftp: Any, directory: str) -> None:
        if not directory or directory == "/":
            return

        path = PurePosixPath(directory)
        current = PurePosixPath("/") if path.is_absolute() else PurePosixPath()

        for part in path.parts:
            if part in ("", "/"):
                continue
            current = current / part
            current_path = str(current)
            try:
                attrs = sftp.stat(current_path)
                if not stat.S_ISDIR(attrs.st_mode):
                    raise OSError(f"Path exists but is not a directory: {current_path}")
                continue
            except OSError as e:
                missing = isinstance(e, FileNotFoundError) or getattr(e, "errno", None) == 2
                missing = missing or "no such file" in str(e).lower()
                if not missing:
                    raise

            try:
                sftp.mkdir(current_path)
            except OSError:
                # Some SFTP servers report generic "Failure" when directory already exists.
                attrs = sftp.stat(current_path)
                if not stat.S_ISDIR(attrs.st_mode):
                    raise OSError(f"Failed to create directory: {current_path}")
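To make the quoting in `_build_exec_command` concrete, here is a runnable sketch of the same construction with illustrative values:

import shlex

command = ["python", "main.py", "--name", "hello world"]
environments = {"API_KEY": "secret value"}
working_path = "/workspace/sandboxes/abc123/jobs"

body = f"cd {shlex.quote(working_path)} && "
body += " ".join(f"{k}={shlex.quote(v)}" for k, v in environments.items()) + " "
body += shlex.join(command)
print(f"sh -lc {shlex.quote(body)}")
# prints one sh -lc invocation with the path, env value, and argv tokens each safely single-quoted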
@@ -90,6 +90,10 @@ def init_app(app: DifyApp):
    app.register_blueprint(inner_api_bp)
    app.register_blueprint(mcp_bp)

    # TODO: enable after full sandbox integration
    # from controllers.cli_api import bp as cli_api_bp
    # app.register_blueprint(cli_api_bp)

    # Register trigger blueprint with CORS for webhook calls
    _apply_cors_once(
        trigger_bp,