diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index d37cff63e9..f884489f5e 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -271,6 +271,17 @@ class PluginConfig(BaseSettings): ) +class CliApiConfig(BaseSettings): + """ + Configuration for CLI API (for dify-cli to call back from external sandbox environments) + """ + + CLI_API_URL: str = Field( + description="CLI API URL for external sandbox (e.g., e2b) to call back.", + default="http://localhost:5001", + ) + + class MarketplaceConfig(BaseSettings): """ Configuration for marketplace @@ -287,6 +298,27 @@ class MarketplaceConfig(BaseSettings): ) +class CreatorsPlatformConfig(BaseSettings): + """ + Configuration for creators platform + """ + + CREATORS_PLATFORM_FEATURES_ENABLED: bool = Field( + description="Enable or disable creators platform features", + default=True, + ) + + CREATORS_PLATFORM_API_URL: HttpUrl = Field( + description="Creators Platform API URL", + default=HttpUrl("https://creators.dify.ai"), + ) + + CREATORS_PLATFORM_OAUTH_CLIENT_ID: str = Field( + description="OAuth client_id for the Creators Platform app registered in Dify", + default="", + ) + + class EndpointConfig(BaseSettings): """ Configuration for various application endpoints and URLs @@ -341,6 +373,15 @@ class FileAccessConfig(BaseSettings): default="", ) + FILES_API_URL: str = Field( + description="Base URL for storage file ticket API endpoints." + " Used by sandbox containers (internal or external like e2b) that need" + " an absolute, routable address to upload/download files via the API." + " For all-in-one Docker deployments, set to http://localhost." + " For public sandbox environments, set to a public domain or IP.", + default="", + ) + FILES_ACCESS_TIMEOUT: int = Field( description="Expiration time in seconds for file access URLs", default=300, @@ -1274,6 +1315,13 @@ class PositionConfig(BaseSettings): return {item.strip() for item in self.POSITION_TOOL_EXCLUDES.split(",") if item.strip() != ""} +class CollaborationConfig(BaseSettings): + ENABLE_COLLABORATION_MODE: bool = Field( + description="Whether to enable collaboration mode features across the workspace", + default=False, + ) + + class LoginConfig(BaseSettings): ENABLE_EMAIL_CODE_LOGIN: bool = Field( description="whether to enable email code login", @@ -1375,7 +1423,9 @@ class FeatureConfig( TriggerConfig, AsyncWorkflowConfig, PluginConfig, + CliApiConfig, MarketplaceConfig, + CreatorsPlatformConfig, DataSetConfig, EndpointConfig, FileAccessConfig, @@ -1399,6 +1449,7 @@ class FeatureConfig( WorkflowConfig, WorkflowNodeExecutionConfig, WorkspaceConfig, + CollaborationConfig, LoginConfig, AccountConfig, SwaggerUIConfig, diff --git a/api/controllers/cli_api/__init__.py b/api/controllers/cli_api/__init__.py new file mode 100644 index 0000000000..9cd044b24e --- /dev/null +++ b/api/controllers/cli_api/__init__.py @@ -0,0 +1,27 @@ +from flask import Blueprint +from flask_restx import Namespace + +from libs.external_api import ExternalApi + +bp = Blueprint("cli_api", __name__, url_prefix="/cli/api") + +api = ExternalApi( + bp, + version="1.0", + title="CLI API", + description="APIs for Dify CLI to call back from external sandbox environments (e.g., e2b)", +) + +# Create namespace +cli_api_ns = Namespace("cli_api", description="CLI API operations", path="/") + +from .dify_cli import cli_api as _plugin + +api.add_namespace(cli_api_ns) + +__all__ = [ + "_plugin", + "api", + "bp", + "cli_api_ns", +] diff --git 
a/api/controllers/cli_api/dify_cli/__init__.py b/api/controllers/cli_api/dify_cli/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/controllers/cli_api/dify_cli/cli_api.py b/api/controllers/cli_api/dify_cli/cli_api.py new file mode 100644 index 0000000000..efcaaf0bf6 --- /dev/null +++ b/api/controllers/cli_api/dify_cli/cli_api.py @@ -0,0 +1,192 @@ +from flask import abort +from flask_restx import Resource +from pydantic import BaseModel + +from controllers.cli_api import cli_api_ns +from controllers.cli_api.dify_cli.wraps import get_cli_user_tenant, plugin_data +from controllers.cli_api.wraps import cli_api_only +from controllers.console.wraps import setup_required +from core.app.entities.app_invoke_entities import InvokeFrom +from core.plugin.backwards_invocation.app import PluginAppBackwardsInvocation +from core.plugin.backwards_invocation.base import BaseBackwardsInvocationResponse +from core.plugin.backwards_invocation.model import PluginModelBackwardsInvocation +from core.plugin.backwards_invocation.tool import PluginToolBackwardsInvocation +from core.plugin.entities.request import ( + RequestInvokeApp, + RequestInvokeLLM, + RequestInvokeTool, + RequestRequestUploadFile, +) +from core.sandbox.bash.dify_cli import DifyCliToolConfig +from core.session.cli_api import CliContext +from core.skill.entities import ToolInvocationRequest +from core.tools.entities.tool_entities import ToolProviderType +from core.tools.tool_manager import ToolManager +from graphon.file.helpers import get_signed_file_url_for_plugin +from libs.helper import length_prefixed_response +from models.account import Account +from models.model import EndUser, Tenant + + +class FetchToolItem(BaseModel): + tool_type: str + tool_provider: str + tool_name: str + credential_id: str | None = None + + +class FetchToolBatchRequest(BaseModel): + tools: list[FetchToolItem] + + +@cli_api_ns.route("/invoke/llm") +class CliInvokeLLMApi(Resource): + @cli_api_only + @get_cli_user_tenant + @setup_required + @plugin_data(payload_type=RequestInvokeLLM) + def post( + self, + user_model: Account | EndUser, + tenant_model: Tenant, + payload: RequestInvokeLLM, + cli_context: CliContext, + ): + def generator(): + response = PluginModelBackwardsInvocation.invoke_llm(user_model.id, tenant_model, payload) + return PluginModelBackwardsInvocation.convert_to_event_stream(response) + + return length_prefixed_response(0xF, generator()) + + +@cli_api_ns.route("/invoke/tool") +class CliInvokeToolApi(Resource): + @cli_api_only + @get_cli_user_tenant + @setup_required + @plugin_data(payload_type=RequestInvokeTool) + def post( + self, + user_model: Account | EndUser, + tenant_model: Tenant, + payload: RequestInvokeTool, + cli_context: CliContext, + ): + tool_type = ToolProviderType.value_of(payload.tool_type) + + request = ToolInvocationRequest( + tool_type=tool_type, + provider=payload.provider, + tool_name=payload.tool, + credential_id=payload.credential_id, + ) + if cli_context.tool_access and not cli_context.tool_access.is_allowed(request): + abort(403, description=f"Access denied for tool: {payload.provider}/{payload.tool}") + + def generator(): + return PluginToolBackwardsInvocation.convert_to_event_stream( + PluginToolBackwardsInvocation.invoke_tool( + tenant_id=tenant_model.id, + user_id=user_model.id, + tool_type=tool_type, + provider=payload.provider, + tool_name=payload.tool, + tool_parameters=payload.tool_parameters, + credential_id=payload.credential_id, + ), + ) + + return length_prefixed_response(0xF, generator()) 
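(For context on how these endpoints are called: requests are authenticated by the cli_api_only decorator in api/controllers/cli_api/wraps.py, further down in this diff, which expects an HMAC-SHA256 signature over the timestamped request body. Below is a minimal sketch of a sandbox-side caller, assuming the requests library is available in the sandbox; the session id/secret would be issued via CliApiSessionManager, and the tool payload values are purely illustrative.)

import hashlib
import hmac
import json
import time

import requests  # assumed available in the sandbox; any HTTP client works

CLI_API_URL = "http://localhost:5001"  # matches the CliApiConfig default above


def sign_cli_request(session_id: str, session_secret: str, body: bytes) -> dict[str, str]:
    # Mirrors _verify_signature in cli_api/wraps.py: HMAC-SHA256 over
    # b"{timestamp}." + body, keyed with the session secret and sent as
    # "sha256=<hexdigest>". The server rejects timestamps outside
    # SIGNATURE_TTL_SECONDS (300s).
    timestamp = str(int(time.time()))
    digest = hmac.new(
        session_secret.encode(),
        f"{timestamp}.".encode() + body,
        hashlib.sha256,
    ).hexdigest()
    return {
        "X-Cli-Api-Session-Id": session_id,
        "X-Cli-Api-Timestamp": timestamp,
        "X-Cli-Api-Signature": f"sha256={digest}",
    }


# Hypothetical usage; session credentials are provisioned by CliApiSessionManager.
body = json.dumps(
    {"tool_type": "builtin", "provider": "time", "tool": "current_time", "tool_parameters": {}}
).encode()
headers = {"Content-Type": "application/json", **sign_cli_request("<session-id>", "<session-secret>", body)}
requests.post(f"{CLI_API_URL}/cli/api/invoke/tool", data=body, headers=headers)
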
+ + +@cli_api_ns.route("/invoke/app") +class CliInvokeAppApi(Resource): + @cli_api_only + @get_cli_user_tenant + @setup_required + @plugin_data(payload_type=RequestInvokeApp) + def post( + self, + user_model: Account | EndUser, + tenant_model: Tenant, + payload: RequestInvokeApp, + cli_context: CliContext, + ): + response = PluginAppBackwardsInvocation.invoke_app( + app_id=payload.app_id, + user_id=user_model.id, + tenant_id=tenant_model.id, + conversation_id=payload.conversation_id, + query=payload.query, + stream=payload.response_mode == "streaming", + inputs=payload.inputs, + files=payload.files, + ) + + return length_prefixed_response(0xF, PluginAppBackwardsInvocation.convert_to_event_stream(response)) + + +@cli_api_ns.route("/upload/file/request") +class CliUploadFileRequestApi(Resource): + @cli_api_only + @get_cli_user_tenant + @setup_required + @plugin_data(payload_type=RequestRequestUploadFile) + def post( + self, + user_model: Account | EndUser, + tenant_model: Tenant, + payload: RequestRequestUploadFile, + cli_context: CliContext, + ): + url = get_signed_file_url_for_plugin( + filename=payload.filename, + mimetype=payload.mimetype, + tenant_id=tenant_model.id, + user_id=user_model.id, + ) + return BaseBackwardsInvocationResponse(data={"url": url}).model_dump() + + +@cli_api_ns.route("/fetch/tools/batch") +class CliFetchToolsBatchApi(Resource): + @cli_api_only + @get_cli_user_tenant + @setup_required + @plugin_data(payload_type=FetchToolBatchRequest) + def post( + self, + user_model: Account | EndUser, + tenant_model: Tenant, + payload: FetchToolBatchRequest, + cli_context: CliContext, + ): + tools: list[dict] = [] + + for item in payload.tools: + provider_type = ToolProviderType.value_of(item.tool_type) + + request = ToolInvocationRequest( + tool_type=provider_type, + provider=item.tool_provider, + tool_name=item.tool_name, + credential_id=item.credential_id, + ) + if cli_context.tool_access and not cli_context.tool_access.is_allowed(request): + abort(403, description=f"Access denied for tool: {item.tool_provider}/{item.tool_name}") + + try: + tool_runtime = ToolManager.get_tool_runtime( + tenant_id=tenant_model.id, + provider_type=provider_type, + provider_id=item.tool_provider, + tool_name=item.tool_name, + invoke_from=InvokeFrom.AGENT, + credential_id=item.credential_id, + ) + tool_config = DifyCliToolConfig.create_from_tool(tool_runtime) + tools.append(tool_config.model_dump()) + except Exception: + continue + + return BaseBackwardsInvocationResponse(data={"tools": tools}).model_dump() diff --git a/api/controllers/cli_api/dify_cli/wraps.py b/api/controllers/cli_api/dify_cli/wraps.py new file mode 100644 index 0000000000..4b37400043 --- /dev/null +++ b/api/controllers/cli_api/dify_cli/wraps.py @@ -0,0 +1,137 @@ +from collections.abc import Callable +from functools import wraps +from typing import ParamSpec, TypeVar + +from flask import current_app, g, request +from flask_login import user_logged_in +from pydantic import BaseModel +from sqlalchemy.orm import Session + +from core.session.cli_api import CliApiSession, CliContext +from extensions.ext_database import db +from libs.login import current_user +from models.account import Tenant +from models.model import DefaultEndUserSessionID, EndUser + +P = ParamSpec("P") +R = TypeVar("R") + + +class TenantUserPayload(BaseModel): + tenant_id: str + user_id: str + + +def get_user(tenant_id: str, user_id: str | None) -> EndUser: + """ + Get current user + + NOTE: user_id is not trusted, it could be maliciously set to any value. 
+ As a result, it could only be considered as an end user id. + """ + if not user_id: + user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID + is_anonymous = user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID + try: + with Session(db.engine) as session: + user_model = None + + if is_anonymous: + user_model = ( + session.query(EndUser) + .where( + EndUser.session_id == user_id, + EndUser.tenant_id == tenant_id, + ) + .first() + ) + else: + user_model = ( + session.query(EndUser) + .where( + EndUser.id == user_id, + EndUser.tenant_id == tenant_id, + ) + .first() + ) + + if not user_model: + user_model = EndUser( + tenant_id=tenant_id, + type="service_api", + is_anonymous=is_anonymous, + session_id=user_id, + ) + session.add(user_model) + session.commit() + session.refresh(user_model) + + except Exception: + raise ValueError("user not found") + + return user_model + + +def get_cli_user_tenant(view_func: Callable[P, R]): + @wraps(view_func) + def decorated_view(*args: P.args, **kwargs: P.kwargs): + session: CliApiSession | None = getattr(g, "cli_api_session", None) + if session is None: + raise ValueError("session not found") + + user_id = session.user_id + tenant_id = session.tenant_id + cli_context = CliContext.model_validate(session.context) + + if not user_id: + user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID + + try: + tenant_model = ( + db.session.query(Tenant) + .where( + Tenant.id == tenant_id, + ) + .first() + ) + except Exception: + raise ValueError("tenant not found") + + if not tenant_model: + raise ValueError("tenant not found") + + kwargs["tenant_model"] = tenant_model + kwargs["user_model"] = get_user(tenant_id, user_id) + kwargs["cli_context"] = cli_context + + current_app.login_manager._update_request_context_with_user(kwargs["user_model"]) # type: ignore + user_logged_in.send(current_app._get_current_object(), user=current_user) # type: ignore + + return view_func(*args, **kwargs) + + return decorated_view + + +def plugin_data(view: Callable[P, R] | None = None, *, payload_type: type[BaseModel]): + def decorator(view_func: Callable[P, R]): + @wraps(view_func) + def decorated_view(*args: P.args, **kwargs: P.kwargs): + try: + data = request.get_json() + except Exception: + raise ValueError("invalid json") + + try: + payload = payload_type.model_validate(data) + except Exception as e: + raise ValueError(f"invalid payload: {str(e)}") + + kwargs["payload"] = payload + return view_func(*args, **kwargs) + + return decorated_view + + if view is None: + return decorator + else: + return decorator(view) diff --git a/api/controllers/cli_api/wraps.py b/api/controllers/cli_api/wraps.py new file mode 100644 index 0000000000..d4f1cdb522 --- /dev/null +++ b/api/controllers/cli_api/wraps.py @@ -0,0 +1,56 @@ +import hashlib +import hmac +import time +from collections.abc import Callable +from functools import wraps +from typing import ParamSpec, TypeVar + +from flask import abort, g, request + +from core.session.cli_api import CliApiSessionManager + +P = ParamSpec("P") +R = TypeVar("R") + +SIGNATURE_TTL_SECONDS = 300 + + +def _verify_signature(session_secret: str, timestamp: str, body: bytes, signature: str) -> bool: + expected = hmac.new( + session_secret.encode(), + f"{timestamp}.".encode() + body, + hashlib.sha256, + ).hexdigest() + return hmac.compare_digest(f"sha256={expected}", signature) + + +def cli_api_only(view: Callable[P, R]): + @wraps(view) + def decorated(*args: P.args, **kwargs: P.kwargs): + session_id = request.headers.get("X-Cli-Api-Session-Id") + timestamp = 
request.headers.get("X-Cli-Api-Timestamp") + signature = request.headers.get("X-Cli-Api-Signature") + + if not session_id or not timestamp or not signature: + abort(401) + + try: + ts = int(timestamp) + if abs(time.time() - ts) > SIGNATURE_TTL_SECONDS: + abort(401) + except ValueError: + abort(401) + + session = CliApiSessionManager().get(session_id) + if not session: + abort(401) + + body = request.get_data() + if not _verify_signature(session.secret, timestamp, body, signature): + abort(401) + + g.cli_api_session = session + + return view(*args, **kwargs) + + return decorated diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index d624b10b22..c26631574d 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -41,6 +41,7 @@ from . import ( init_validate, notification, ping, + # sandbox_files, # TODO: enable after full sandbox integration setup, spec, version, @@ -52,6 +53,7 @@ from .app import ( agent, annotation, app, + # app_asset, # TODO: enable after full sandbox integration audio, completion, conversation, @@ -62,6 +64,7 @@ from .app import ( model_config, ops_trace, site, + # skills, # TODO: enable after full sandbox integration statistic, workflow, workflow_app_log, @@ -130,6 +133,7 @@ from .workspace import ( model_providers, models, plugin, + # sandbox_providers, # TODO: enable after full sandbox integration tool_providers, trigger_providers, workspace, diff --git a/api/controllers/console/app/app_asset.py b/api/controllers/console/app/app_asset.py new file mode 100644 index 0000000000..e1a004eb76 --- /dev/null +++ b/api/controllers/console/app/app_asset.py @@ -0,0 +1,333 @@ +from flask import request +from flask_restx import Resource +from pydantic import BaseModel, Field, field_validator + +from controllers.console import console_ns +from controllers.console.app.error import ( + AppAssetNodeNotFoundError, + AppAssetPathConflictError, +) +from controllers.console.app.wraps import get_app_model +from controllers.console.wraps import account_initialization_required, setup_required +from core.app.entities.app_asset_entities import BatchUploadNode +from libs.login import current_account_with_tenant, login_required +from models import App +from models.model import AppMode +from services.app_asset_service import AppAssetService +from services.errors.app_asset import ( + AppAssetNodeNotFoundError as ServiceNodeNotFoundError, +) +from services.errors.app_asset import ( + AppAssetParentNotFoundError, +) +from services.errors.app_asset import ( + AppAssetPathConflictError as ServicePathConflictError, +) + +DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" + + +class CreateFolderPayload(BaseModel): + name: str = Field(..., min_length=1, max_length=255) + parent_id: str | None = None + + +class CreateFilePayload(BaseModel): + name: str = Field(..., min_length=1, max_length=255) + parent_id: str | None = None + + @field_validator("name", mode="before") + @classmethod + def strip_name(cls, v: str) -> str: + return v.strip() if isinstance(v, str) else v + + @field_validator("parent_id", mode="before") + @classmethod + def empty_to_none(cls, v: str | None) -> str | None: + return v or None + + +class GetUploadUrlPayload(BaseModel): + name: str = Field(..., min_length=1, max_length=255) + size: int = Field(..., ge=0) + parent_id: str | None = None + + @field_validator("name", mode="before") + @classmethod + def strip_name(cls, v: str) -> str: + return v.strip() if isinstance(v, str) else v + + 
@field_validator("parent_id", mode="before") + @classmethod + def empty_to_none(cls, v: str | None) -> str | None: + return v or None + + +class BatchUploadPayload(BaseModel): + children: list[BatchUploadNode] = Field(..., min_length=1) + parent_id: str | None = None + + @field_validator("parent_id", mode="before") + @classmethod + def empty_to_none(cls, v: str | None) -> str | None: + return v or None + + +class UpdateFileContentPayload(BaseModel): + content: str + + +class RenameNodePayload(BaseModel): + name: str = Field(..., min_length=1, max_length=255) + + +class MoveNodePayload(BaseModel): + parent_id: str | None = None + + +class ReorderNodePayload(BaseModel): + after_node_id: str | None = Field(default=None, description="Place after this node, None for first position") + + +def reg(cls: type[BaseModel]) -> None: + console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) + + +reg(CreateFolderPayload) +reg(CreateFilePayload) +reg(GetUploadUrlPayload) +reg(BatchUploadNode) +reg(BatchUploadPayload) +reg(UpdateFileContentPayload) +reg(RenameNodePayload) +reg(MoveNodePayload) +reg(ReorderNodePayload) + + +@console_ns.route("/apps//assets/tree") +class AppAssetTreeResource(Resource): + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + def get(self, app_model: App): + current_user, _ = current_account_with_tenant() + tree = AppAssetService.get_asset_tree(app_model, current_user.id) + return {"children": [view.model_dump() for view in tree.transform()]} + + +@console_ns.route("/apps//assets/folders") +class AppAssetFolderResource(Resource): + @console_ns.expect(console_ns.models[CreateFolderPayload.__name__]) + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + def post(self, app_model: App): + current_user, _ = current_account_with_tenant() + payload = CreateFolderPayload.model_validate(console_ns.payload or {}) + + try: + node = AppAssetService.create_folder(app_model, current_user.id, payload.name, payload.parent_id) + return node.model_dump(), 201 + except AppAssetParentNotFoundError: + raise AppAssetNodeNotFoundError() + except ServicePathConflictError: + raise AppAssetPathConflictError() + + +@console_ns.route("/apps//assets/files/") +class AppAssetFileDetailResource(Resource): + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + def get(self, app_model: App, node_id: str): + current_user, _ = current_account_with_tenant() + try: + content = AppAssetService.get_file_content(app_model, current_user.id, node_id) + return {"content": content.decode("utf-8", errors="replace")} + except ServiceNodeNotFoundError: + raise AppAssetNodeNotFoundError() + + @console_ns.expect(console_ns.models[UpdateFileContentPayload.__name__]) + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + def put(self, app_model: App, node_id: str): + current_user, _ = current_account_with_tenant() + + file = request.files.get("file") + if file: + content = file.read() + else: + payload = UpdateFileContentPayload.model_validate(console_ns.payload or {}) + content = payload.content.encode("utf-8") + + try: + node = AppAssetService.update_file_content(app_model, current_user.id, node_id, content) + return node.model_dump() + except 
ServiceNodeNotFoundError: + raise AppAssetNodeNotFoundError() + + +@console_ns.route("/apps//assets/nodes/") +class AppAssetNodeResource(Resource): + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + def delete(self, app_model: App, node_id: str): + current_user, _ = current_account_with_tenant() + try: + AppAssetService.delete_node(app_model, current_user.id, node_id) + return {"result": "success"}, 200 + except ServiceNodeNotFoundError: + raise AppAssetNodeNotFoundError() + + +@console_ns.route("/apps//assets/nodes//rename") +class AppAssetNodeRenameResource(Resource): + @console_ns.expect(console_ns.models[RenameNodePayload.__name__]) + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + def post(self, app_model: App, node_id: str): + current_user, _ = current_account_with_tenant() + payload = RenameNodePayload.model_validate(console_ns.payload or {}) + + try: + node = AppAssetService.rename_node(app_model, current_user.id, node_id, payload.name) + return node.model_dump() + except ServiceNodeNotFoundError: + raise AppAssetNodeNotFoundError() + except ServicePathConflictError: + raise AppAssetPathConflictError() + + +@console_ns.route("/apps//assets/nodes//move") +class AppAssetNodeMoveResource(Resource): + @console_ns.expect(console_ns.models[MoveNodePayload.__name__]) + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + def post(self, app_model: App, node_id: str): + current_user, _ = current_account_with_tenant() + payload = MoveNodePayload.model_validate(console_ns.payload or {}) + + try: + node = AppAssetService.move_node(app_model, current_user.id, node_id, payload.parent_id) + return node.model_dump() + except ServiceNodeNotFoundError: + raise AppAssetNodeNotFoundError() + except AppAssetParentNotFoundError: + raise AppAssetNodeNotFoundError() + except ServicePathConflictError: + raise AppAssetPathConflictError() + + +@console_ns.route("/apps//assets/nodes//reorder") +class AppAssetNodeReorderResource(Resource): + @console_ns.expect(console_ns.models[ReorderNodePayload.__name__]) + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + def post(self, app_model: App, node_id: str): + current_user, _ = current_account_with_tenant() + payload = ReorderNodePayload.model_validate(console_ns.payload or {}) + + try: + node = AppAssetService.reorder_node(app_model, current_user.id, node_id, payload.after_node_id) + return node.model_dump() + except ServiceNodeNotFoundError: + raise AppAssetNodeNotFoundError() + + +@console_ns.route("/apps//assets/files//download-url") +class AppAssetFileDownloadUrlResource(Resource): + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + def get(self, app_model: App, node_id: str): + current_user, _ = current_account_with_tenant() + try: + download_url = AppAssetService.get_file_download_url(app_model, current_user.id, node_id) + return {"download_url": download_url} + except ServiceNodeNotFoundError: + raise AppAssetNodeNotFoundError() + + +@console_ns.route("/apps//assets/files/upload") +class AppAssetFileUploadUrlResource(Resource): + @console_ns.expect(console_ns.models[GetUploadUrlPayload.__name__]) + @setup_required + @login_required + 
@account_initialization_required + @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + def post(self, app_model: App): + current_user, _ = current_account_with_tenant() + payload = GetUploadUrlPayload.model_validate(console_ns.payload or {}) + + try: + node, upload_url = AppAssetService.get_file_upload_url( + app_model, current_user.id, payload.name, payload.size, payload.parent_id + ) + return {"node": node.model_dump(), "upload_url": upload_url}, 201 + except AppAssetParentNotFoundError: + raise AppAssetNodeNotFoundError() + except ServicePathConflictError: + raise AppAssetPathConflictError() + + +@console_ns.route("/apps//assets/batch-upload") +class AppAssetBatchUploadResource(Resource): + @console_ns.expect(console_ns.models[BatchUploadPayload.__name__]) + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + def post(self, app_model: App): + """ + Create nodes from tree structure and return upload URLs. + + Input: + { + "parent_id": "optional-target-folder-id", + "children": [ + {"name": "folder1", "node_type": "folder", "children": [ + {"name": "file1.txt", "node_type": "file", "size": 1024} + ]}, + {"name": "root.txt", "node_type": "file", "size": 512} + ] + } + + Output: + { + "children": [ + {"id": "xxx", "name": "folder1", "node_type": "folder", "children": [ + {"id": "yyy", "name": "file1.txt", "node_type": "file", "size": 1024, "upload_url": "..."} + ]}, + {"id": "zzz", "name": "root.txt", "node_type": "file", "size": 512, "upload_url": "..."} + ] + } + """ + current_user, _ = current_account_with_tenant() + payload = BatchUploadPayload.model_validate(console_ns.payload or {}) + + try: + result_children = AppAssetService.batch_create_from_tree( + app_model, + current_user.id, + payload.children, + parent_id=payload.parent_id, + ) + return {"children": [child.model_dump() for child in result_children]}, 201 + except AppAssetParentNotFoundError: + raise AppAssetNodeNotFoundError() + except ServicePathConflictError: + raise AppAssetPathConflictError() diff --git a/api/controllers/console/app/skills.py b/api/controllers/console/app/skills.py new file mode 100644 index 0000000000..39ad903bc2 --- /dev/null +++ b/api/controllers/console/app/skills.py @@ -0,0 +1,38 @@ +from flask import request +from flask_restx import Resource + +from controllers.console import console_ns +from controllers.console.app.wraps import get_app_model +from controllers.console.wraps import account_initialization_required, current_account_with_tenant, setup_required +from libs.login import login_required +from models import App +from models.model import AppMode +from services.skill_service import SkillService + + +@console_ns.route("/apps//workflows/draft/nodes/llm/skills") +class NodeSkillsApi(Resource): + """Extract tool dependencies from an LLM node's skill prompts. + + The client sends the full node ``data`` object in the request body. + The server builds a ``SkillBundle`` in real time from the current draft + ``.md`` assets and resolves transitive tool dependencies — no cached + bundle is used. 
+ """ + + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + def post(self, app_model: App): + current_user, _ = current_account_with_tenant() + node_data = request.get_json(force=True) + if not isinstance(node_data, dict): + return {"tool_dependencies": []} + + tool_deps = SkillService.extract_tool_dependencies( + app=app_model, + node_data=node_data, + user_id=current_user.id, + ) + return {"tool_dependencies": [d.model_dump() for d in tool_deps]} diff --git a/api/controllers/console/sandbox_files.py b/api/controllers/console/sandbox_files.py new file mode 100644 index 0000000000..4a9c8d0632 --- /dev/null +++ b/api/controllers/console/sandbox_files.py @@ -0,0 +1,103 @@ +from __future__ import annotations + +from fastapi.encoders import jsonable_encoder +from flask import request +from flask_restx import Resource, fields +from pydantic import BaseModel, Field + +from controllers.console import console_ns +from controllers.console.wraps import account_initialization_required, setup_required +from libs.login import current_account_with_tenant, login_required +from services.sandbox.sandbox_file_service import SandboxFileService + +DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" + + +class SandboxFileListQuery(BaseModel): + path: str | None = Field(default=None, description="Workspace relative path") + recursive: bool = Field(default=False, description="List recursively") + + +class SandboxFileDownloadRequest(BaseModel): + path: str = Field(..., description="Workspace relative file path") + + +console_ns.schema_model( + SandboxFileListQuery.__name__, + SandboxFileListQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), +) +console_ns.schema_model( + SandboxFileDownloadRequest.__name__, + SandboxFileDownloadRequest.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), +) + + +SANDBOX_FILE_NODE_FIELDS = { + "path": fields.String, + "is_dir": fields.Boolean, + "size": fields.Raw, + "mtime": fields.Raw, + "extension": fields.String, +} + + +SANDBOX_FILE_DOWNLOAD_TICKET_FIELDS = { + "download_url": fields.String, + "expires_in": fields.Integer, + "export_id": fields.String, +} + + +sandbox_file_node_model = console_ns.model("SandboxFileNode", SANDBOX_FILE_NODE_FIELDS) +sandbox_file_download_ticket_model = console_ns.model("SandboxFileDownloadTicket", SANDBOX_FILE_DOWNLOAD_TICKET_FIELDS) + + +@console_ns.route("/apps//sandbox/files") +class SandboxFilesApi(Resource): + """List sandbox files for the current user. + + The sandbox_id is derived from the current user's ID, as each user has + their own sandbox workspace per app. + """ + + @setup_required + @login_required + @account_initialization_required + @console_ns.expect(console_ns.models[SandboxFileListQuery.__name__]) + @console_ns.marshal_list_with(sandbox_file_node_model) + def get(self, app_id: str): + args = SandboxFileListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore[arg-type] + account, tenant_id = current_account_with_tenant() + sandbox_id = account.id + return jsonable_encoder( + SandboxFileService.list_files( + tenant_id=tenant_id, + app_id=app_id, + sandbox_id=sandbox_id, + path=args.path, + recursive=args.recursive, + ) + ) + + +@console_ns.route("/apps//sandbox/files/download") +class SandboxFileDownloadApi(Resource): + """Download a sandbox file for the current user. + + The sandbox_id is derived from the current user's ID, as each user has + their own sandbox workspace per app. 
+ """ + + @setup_required + @login_required + @account_initialization_required + @console_ns.expect(console_ns.models[SandboxFileDownloadRequest.__name__]) + @console_ns.marshal_with(sandbox_file_download_ticket_model) + def post(self, app_id: str): + payload = SandboxFileDownloadRequest.model_validate(console_ns.payload or {}) + account, tenant_id = current_account_with_tenant() + sandbox_id = account.id + res = SandboxFileService.download_file( + tenant_id=tenant_id, app_id=app_id, sandbox_id=sandbox_id, path=payload.path + ) + return jsonable_encoder(res) diff --git a/api/controllers/console/workspace/sandbox_providers.py b/api/controllers/console/workspace/sandbox_providers.py new file mode 100644 index 0000000000..cb7515f6a8 --- /dev/null +++ b/api/controllers/console/workspace/sandbox_providers.py @@ -0,0 +1,104 @@ +import logging + +from flask import request +from flask_restx import Resource, fields +from pydantic import BaseModel + +from controllers.console import console_ns +from controllers.console.wraps import account_initialization_required, setup_required +from graphon.model_runtime.utils.encoders import jsonable_encoder +from libs.login import current_account_with_tenant, login_required +from services.sandbox.sandbox_provider_service import SandboxProviderService + +logger = logging.getLogger(__name__) + + +class SandboxProviderConfigRequest(BaseModel): + config: dict + activate: bool = False + + +class SandboxProviderActivateRequest(BaseModel): + type: str + + +@console_ns.route("/workspaces/current/sandbox-providers") +class SandboxProviderListApi(Resource): + @console_ns.doc("list_sandbox_providers") + @console_ns.doc(description="Get list of available sandbox providers with configuration status") + @console_ns.response(200, "Success", fields.List(fields.Raw(description="Sandbox provider information"))) + @setup_required + @login_required + @account_initialization_required + def get(self): + _, current_tenant_id = current_account_with_tenant() + providers = SandboxProviderService.list_providers(current_tenant_id) + return jsonable_encoder([p.model_dump() for p in providers]) + + +@console_ns.route("/workspaces/current/sandbox-provider//config") +class SandboxProviderConfigApi(Resource): + @console_ns.doc("save_sandbox_provider_config") + @console_ns.doc(description="Save or update configuration for a sandbox provider") + @console_ns.response(200, "Success") + @setup_required + @login_required + @account_initialization_required + def post(self, provider_type: str): + _, current_tenant_id = current_account_with_tenant() + args = SandboxProviderConfigRequest.model_validate(request.get_json()) + + try: + result = SandboxProviderService.save_config( + tenant_id=current_tenant_id, + provider_type=provider_type, + config=args.config, + activate=args.activate, + ) + return result + except ValueError as e: + return {"message": str(e)}, 400 + + @console_ns.doc("delete_sandbox_provider_config") + @console_ns.doc(description="Delete configuration for a sandbox provider") + @console_ns.response(200, "Success") + @setup_required + @login_required + @account_initialization_required + def delete(self, provider_type: str): + _, current_tenant_id = current_account_with_tenant() + + try: + result = SandboxProviderService.delete_config( + tenant_id=current_tenant_id, + provider_type=provider_type, + ) + return result + except ValueError as e: + return {"message": str(e)}, 400 + + +@console_ns.route("/workspaces/current/sandbox-provider//activate") +class SandboxProviderActivateApi(Resource): 
+ """Activate a sandbox provider.""" + + @console_ns.doc("activate_sandbox_provider") + @console_ns.doc(description="Activate a sandbox provider for the current workspace") + @console_ns.response(200, "Success") + @setup_required + @login_required + @account_initialization_required + def post(self, provider_type: str): + """Activate a sandbox provider.""" + _, current_tenant_id = current_account_with_tenant() + + try: + args = SandboxProviderActivateRequest.model_validate(request.get_json()) + result = SandboxProviderService.activate_provider( + tenant_id=current_tenant_id, + provider_type=provider_type, + type=args.type, + ) + return result + except ValueError as e: + return {"message": str(e)}, 400 diff --git a/api/controllers/files/storage_files.py b/api/controllers/files/storage_files.py new file mode 100644 index 0000000000..1623395e9b --- /dev/null +++ b/api/controllers/files/storage_files.py @@ -0,0 +1,80 @@ +"""Token-based file proxy controller for storage operations. + +This controller handles file download and upload operations using opaque UUID tokens. +The token maps to the real storage key in Redis, so the actual storage path is never +exposed in the URL. + +Routes: + GET /files/storage-files/{token} - Download a file + PUT /files/storage-files/{token} - Upload a file + +The operation type (download/upload) is determined by the ticket stored in Redis, +not by the HTTP method. This ensures a download ticket cannot be used for upload +and vice versa. +""" + +from urllib.parse import quote + +from flask import Response, request +from flask_restx import Resource +from werkzeug.exceptions import Forbidden, NotFound, RequestEntityTooLarge + +from controllers.files import files_ns +from extensions.ext_storage import storage +from services.storage_ticket_service import StorageTicketService + + +@files_ns.route("/storage-files/") +class StorageFilesApi(Resource): + """Handle file operations through token-based URLs.""" + + def get(self, token: str): + """Download a file using a token. + + The ticket must have op="download", otherwise returns 403. + """ + ticket = StorageTicketService.get_ticket(token) + if ticket is None: + raise Forbidden("Invalid or expired token") + + if ticket.op != "download": + raise Forbidden("This token is not valid for download") + + try: + generator = storage.load_stream(ticket.storage_key) + except FileNotFoundError: + raise NotFound("File not found") + + filename = ticket.filename or ticket.storage_key.rsplit("/", 1)[-1] + encoded_filename = quote(filename) + + return Response( + generator, + mimetype="application/octet-stream", + direct_passthrough=True, + headers={ + "Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}", + }, + ) + + def put(self, token: str): + """Upload a file using a token. + + The ticket must have op="upload", otherwise returns 403. + If the request body exceeds max_bytes, returns 413. 
+ """ + ticket = StorageTicketService.get_ticket(token) + if ticket is None: + raise Forbidden("Invalid or expired token") + + if ticket.op != "upload": + raise Forbidden("This token is not valid for upload") + + content = request.get_data() + + if ticket.max_bytes is not None and len(content) > ticket.max_bytes: + raise RequestEntityTooLarge(f"Upload exceeds maximum size of {ticket.max_bytes} bytes") + + storage.save(ticket.storage_key, content) + + return Response(status=204) diff --git a/api/core/app/layers/__init__.py b/api/core/app/layers/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/app/layers/sandbox_layer.py b/api/core/app/layers/sandbox_layer.py new file mode 100644 index 0000000000..6f0405a6cd --- /dev/null +++ b/api/core/app/layers/sandbox_layer.py @@ -0,0 +1,22 @@ +import logging + +from core.sandbox import Sandbox +from graphon.graph_engine.layers.base import GraphEngineLayer +from graphon.graph_events.base import GraphEngineEvent + +logger = logging.getLogger(__name__) + + +class SandboxLayer(GraphEngineLayer): + def __init__(self, sandbox: Sandbox) -> None: + super().__init__() + self._sandbox = sandbox + + def on_graph_start(self) -> None: + pass + + def on_event(self, event: GraphEngineEvent) -> None: + pass + + def on_graph_end(self, error: Exception | None) -> None: + self._sandbox.release() diff --git a/api/core/app_assets/__init__.py b/api/core/app_assets/__init__.py new file mode 100644 index 0000000000..2e000f4432 --- /dev/null +++ b/api/core/app_assets/__init__.py @@ -0,0 +1,13 @@ +from .constants import AppAssetsAttrs +from .entities import ( + AssetItem, + SkillAsset, +) +from .storage import AssetPaths + +__all__ = [ + "AppAssetsAttrs", + "AssetItem", + "AssetPaths", + "SkillAsset", +] diff --git a/api/core/app_assets/accessor.py b/api/core/app_assets/accessor.py new file mode 100644 index 0000000000..e32e02c2c3 --- /dev/null +++ b/api/core/app_assets/accessor.py @@ -0,0 +1,180 @@ +"""Unified content accessor for app asset nodes. + +Accessor is scoped to a single app (tenant_id + app_id), not a single node. +All methods accept an AppAssetNode parameter to identify the target. + +CachedContentAccessor is the primary entry point: +- Reads DB first, misses fall through to S3 with sync backfill. +- Writes go to both DB and S3 (dual-write). +- resolve_items() batch-enriches AssetItem lists with DB-cached content + (extension-agnostic), so callers never need to filter by extension. +- Wraps an internal _StorageAccessor for S3 I/O. 
+ +Collaborators: + - services.asset_content_service.AssetContentService (DB layer) + - core.app_assets.storage.AssetPaths (S3 key generation) + - extensions.storage.cached_presign_storage.CachedPresignStorage (S3 I/O) +""" + +from __future__ import annotations + +import logging + +from core.app.entities.app_asset_entities import AppAssetNode +from core.app_assets.entities.assets import AssetItem +from core.app_assets.storage import AssetPaths +from extensions.storage.cached_presign_storage import CachedPresignStorage +from services.asset_content_service import AssetContentService + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# S3-only implementation (internal, used as inner delegate) +# --------------------------------------------------------------------------- + + +class _StorageAccessor: + """Reads/writes draft content via object storage (S3) only.""" + + _storage: CachedPresignStorage + _tenant_id: str + _app_id: str + + def __init__(self, storage: CachedPresignStorage, tenant_id: str, app_id: str) -> None: + self._storage = storage + self._tenant_id = tenant_id + self._app_id = app_id + + def _key(self, node: AppAssetNode) -> str: + return AssetPaths.draft(self._tenant_id, self._app_id, node.id) + + def load(self, node: AppAssetNode) -> bytes: + return self._storage.load_once(self._key(node)) + + def save(self, node: AppAssetNode, content: bytes) -> None: + self._storage.save(self._key(node), content) + + def delete(self, node: AppAssetNode) -> None: + try: + self._storage.delete(self._key(node)) + except Exception: + logger.warning("Failed to delete storage key %s", self._key(node), exc_info=True) + + +# --------------------------------------------------------------------------- +# DB-cached implementation (the public API) +# --------------------------------------------------------------------------- + + +class CachedContentAccessor: + """App-level content accessor with DB read-through cache over S3. + + Read path: DB first -> miss -> S3 fallback -> sync backfill DB + Write path: DB upsert + S3 save (dual-write) + Delete path: DB delete + S3 delete + + bulk_load uses a single SQL query for all nodes, with S3 fallback per miss. + + Usage: + accessor = CachedContentAccessor(storage, tenant_id, app_id) + content = accessor.load(node) + accessor.save(node, content) + results = accessor.bulk_load(nodes) + """ + + _inner: _StorageAccessor + _tenant_id: str + _app_id: str + + def __init__(self, storage: CachedPresignStorage, tenant_id: str, app_id: str) -> None: + self._inner = _StorageAccessor(storage, tenant_id, app_id) + self._tenant_id = tenant_id + self._app_id = app_id + + def load(self, node: AppAssetNode) -> bytes: + # 1. Try DB + cached = AssetContentService.get(self._tenant_id, self._app_id, node.id) + if cached is not None: + return cached.encode("utf-8") + + # 2. Fallback to S3 + data = self._inner.load(node) + + # 3. 
Sync backfill DB + AssetContentService.upsert( + tenant_id=self._tenant_id, + app_id=self._app_id, + node_id=node.id, + content=data.decode("utf-8"), + size=len(data), + ) + return data + + def bulk_load(self, nodes: list[AppAssetNode]) -> dict[str, bytes]: + """Single SQL for all nodes, S3 fallback + backfill per miss.""" + result: dict[str, bytes] = {} + node_ids = [n.id for n in nodes] + cached = AssetContentService.get_many(self._tenant_id, self._app_id, node_ids) + + for node in nodes: + if node.id in cached: + result[node.id] = cached[node.id].encode("utf-8") + else: + # S3 fallback + sync backfill + data = self._inner.load(node) + AssetContentService.upsert( + tenant_id=self._tenant_id, + app_id=self._app_id, + node_id=node.id, + content=data.decode("utf-8"), + size=len(data), + ) + result[node.id] = data + return result + + def save(self, node: AppAssetNode, content: bytes) -> None: + # Dual-write: DB + S3 + AssetContentService.upsert( + tenant_id=self._tenant_id, + app_id=self._app_id, + node_id=node.id, + content=content.decode("utf-8"), + size=len(content), + ) + self._inner.save(node, content) + + def resolve_items(self, items: list[AssetItem]) -> list[AssetItem]: + """Batch-enrich asset items with DB-cached content. + + Queries by ``asset_id`` only — extension-agnostic. Items without + a DB cache row keep their original *content* value (typically + ``None``), so only genuinely cached assets (e.g. ``.md`` skill + documents) get populated. + + This eliminates the need for callers to filter by file extension + before deciding whether to read from the DB cache. + """ + if not items: + return items + + node_ids = [a.asset_id for a in items] + cached = AssetContentService.get_many(self._tenant_id, self._app_id, node_ids) + + if not cached: + return items + + return [ + AssetItem( + asset_id=a.asset_id, + path=a.path, + file_name=a.file_name, + extension=a.extension, + storage_key=a.storage_key, + content=cached[a.asset_id].encode("utf-8") if a.asset_id in cached else a.content, + ) + for a in items + ] + + def delete(self, node: AppAssetNode) -> None: + AssetContentService.delete(self._tenant_id, self._app_id, node.id) + self._inner.delete(node) diff --git a/api/core/app_assets/builder/__init__.py b/api/core/app_assets/builder/__init__.py new file mode 100644 index 0000000000..9e64a31884 --- /dev/null +++ b/api/core/app_assets/builder/__init__.py @@ -0,0 +1,12 @@ +from .base import AssetBuilder, BuildContext +from .file_builder import FileBuilder +from .pipeline import AssetBuildPipeline +from .skill_builder import SkillBuilder + +__all__ = [ + "AssetBuildPipeline", + "AssetBuilder", + "BuildContext", + "FileBuilder", + "SkillBuilder", +] diff --git a/api/core/app_assets/builder/base.py b/api/core/app_assets/builder/base.py new file mode 100644 index 0000000000..595ce84882 --- /dev/null +++ b/api/core/app_assets/builder/base.py @@ -0,0 +1,20 @@ +from dataclasses import dataclass +from typing import Protocol + +from core.app.entities.app_asset_entities import AppAssetFileTree, AppAssetNode +from core.app_assets.entities import AssetItem + + +@dataclass +class BuildContext: + tenant_id: str + app_id: str + build_id: str + + +class AssetBuilder(Protocol): + def accept(self, node: AppAssetNode) -> bool: ... + + def collect(self, node: AppAssetNode, path: str, ctx: BuildContext) -> None: ... + + def build(self, tree: AppAssetFileTree, ctx: BuildContext) -> list[AssetItem]: ... 
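(The accept/collect/build protocol above is driven by AssetBuildPipeline, shown further down: each file node is offered to the builders in order and handed to the first one whose accept() returns True, so specialized builders must be registered ahead of the catch-all FileBuilder, whose accept() always returns True. A minimal sketch of a conforming custom builder follows; JsonConfigBuilder is hypothetical and simply mirrors FileBuilder's shape for .json files.)

from core.app.entities.app_asset_entities import AppAssetFileTree, AppAssetNode
from core.app_assets.builder import AssetBuildPipeline, BuildContext, FileBuilder
from core.app_assets.entities import AssetItem
from core.app_assets.storage import AssetPaths


class JsonConfigBuilder:
    """Hypothetical builder that claims only .json files."""

    def __init__(self) -> None:
        self._nodes: list[tuple[AppAssetNode, str]] = []

    def accept(self, node: AppAssetNode) -> bool:
        return node.extension == "json"

    def collect(self, node: AppAssetNode, path: str, ctx: BuildContext) -> None:
        self._nodes.append((node, path))

    def build(self, tree: AppAssetFileTree, ctx: BuildContext) -> list[AssetItem]:
        # Same shape as FileBuilder below: point each item at its draft storage key.
        return [
            AssetItem(
                asset_id=node.id,
                path=path,
                file_name=node.name,
                extension=node.extension or "",
                storage_key=AssetPaths.draft(ctx.tenant_id, ctx.app_id, node.id),
            )
            for node, path in self._nodes
        ]


# Order matters: the catch-all FileBuilder must come last.
# pipeline = AssetBuildPipeline([JsonConfigBuilder(), FileBuilder()])
# items = pipeline.build_all(tree, BuildContext(tenant_id=..., app_id=..., build_id=...))
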
diff --git a/api/core/app_assets/builder/file_builder.py b/api/core/app_assets/builder/file_builder.py new file mode 100644 index 0000000000..7f49327ad4 --- /dev/null +++ b/api/core/app_assets/builder/file_builder.py @@ -0,0 +1,30 @@ +from core.app.entities.app_asset_entities import AppAssetFileTree, AppAssetNode +from core.app_assets.entities import AssetItem +from core.app_assets.storage import AssetPaths + +from .base import BuildContext + + +class FileBuilder: + _nodes: list[tuple[AppAssetNode, str]] + + def __init__(self) -> None: + self._nodes = [] + + def accept(self, node: AppAssetNode) -> bool: + return True + + def collect(self, node: AppAssetNode, path: str, ctx: BuildContext) -> None: + self._nodes.append((node, path)) + + def build(self, tree: AppAssetFileTree, ctx: BuildContext) -> list[AssetItem]: + return [ + AssetItem( + asset_id=node.id, + path=path, + file_name=node.name, + extension=node.extension or "", + storage_key=AssetPaths.draft(ctx.tenant_id, ctx.app_id, node.id), + ) + for node, path in self._nodes + ] diff --git a/api/core/app_assets/builder/pipeline.py b/api/core/app_assets/builder/pipeline.py new file mode 100644 index 0000000000..f8db220c0a --- /dev/null +++ b/api/core/app_assets/builder/pipeline.py @@ -0,0 +1,27 @@ +from core.app.entities.app_asset_entities import AppAssetFileTree +from core.app_assets.entities import AssetItem + +from .base import AssetBuilder, BuildContext + + +class AssetBuildPipeline: + _builders: list[AssetBuilder] + + def __init__(self, builders: list[AssetBuilder]) -> None: + self._builders = builders + + def build_all(self, tree: AppAssetFileTree, ctx: BuildContext) -> list[AssetItem]: + # 1. Distribute: each node goes to first accepting builder + for node in tree.walk_files(): + path = tree.get_path(node.id) + for builder in self._builders: + if builder.accept(node): + builder.collect(node, path, ctx) + break + + # 2. Each builder builds its collected nodes + results: list[AssetItem] = [] + for builder in self._builders: + results.extend(builder.build(tree, ctx)) + + return results diff --git a/api/core/app_assets/builder/skill_builder.py b/api/core/app_assets/builder/skill_builder.py new file mode 100644 index 0000000000..fd2a2fb946 --- /dev/null +++ b/api/core/app_assets/builder/skill_builder.py @@ -0,0 +1,96 @@ +"""Builder that compiles ``.md`` skill documents into resolved content. + +The builder reads raw draft content from the DB-backed accessor, parses +each into a ``SkillDocument``, assembles a ``SkillBundle`` (with +transitive tool/file dependency resolution), and returns ``AssetItem`` +objects whose *content* field carries the resolved bytes in-process. + +The assembled ``SkillBundle`` is persisted via ``SkillManager`` +(S3 + Redis) **and** retained on the ``bundle`` property so that +callers (e.g. ``DraftAppAssetsInitializer``) can pass it directly to +``sandbox.attrs`` without a redundant Redis/S3 round-trip. 
+""" + +import json +import logging + +from core.app.entities.app_asset_entities import AppAssetFileTree, AppAssetNode +from core.app_assets.accessor import CachedContentAccessor +from core.app_assets.entities import AssetItem +from core.skill.assembler import SkillBundleAssembler +from core.skill.entities.skill_bundle import SkillBundle +from core.skill.entities.skill_document import SkillDocument + +from .base import BuildContext + +logger = logging.getLogger(__name__) + + +class SkillBuilder: + _nodes: list[tuple[AppAssetNode, str]] + _accessor: CachedContentAccessor + _bundle: SkillBundle | None + + def __init__(self, accessor: CachedContentAccessor) -> None: + self._nodes = [] + self._accessor = accessor + self._bundle = None + + @property + def bundle(self) -> SkillBundle | None: + """The ``SkillBundle`` produced by the last ``build()`` call, or *None*.""" + return self._bundle + + def accept(self, node: AppAssetNode) -> bool: + return node.extension == "md" + + def collect(self, node: AppAssetNode, path: str, ctx: BuildContext) -> None: + self._nodes.append((node, path)) + + def build(self, tree: AppAssetFileTree, ctx: BuildContext) -> list[AssetItem]: + from core.skill.skill_manager import SkillManager + + if not self._nodes: + bundle = SkillBundle(assets_id=ctx.build_id, asset_tree=tree) + SkillManager.save_bundle(ctx.tenant_id, ctx.app_id, ctx.build_id, bundle) + self._bundle = bundle + return [] + + # Batch-load all skill draft content in one DB query (with S3 fallback on miss). + nodes_only = [node for node, _ in self._nodes] + raw_contents = self._accessor.bulk_load(nodes_only) + + # Parse documents — skip nodes whose draft content is still the empty + # placeholder written at creation time. + documents: dict[str, SkillDocument] = {} + for node, _ in self._nodes: + try: + raw = raw_contents.get(node.id) + if not raw: + continue + data = {"skill_id": node.id, **json.loads(raw)} + documents[node.id] = SkillDocument.model_validate(data) + except (FileNotFoundError, json.JSONDecodeError, TypeError, ValueError) as e: + logger.exception("Failed to load or parse skill document for node %s", node.id) + raise ValueError(f"Failed to load or parse skill document for node {node.id}") from e + + bundle = SkillBundleAssembler(tree).assemble_bundle(documents, ctx.build_id) + SkillManager.save_bundle(ctx.tenant_id, ctx.app_id, ctx.build_id, bundle) + self._bundle = bundle + + items: list[AssetItem] = [] + for node, path in self._nodes: + skill = bundle.get(node.id) + if skill is None: + continue + items.append( + AssetItem( + asset_id=node.id, + path=path, + file_name=node.name, + extension=node.extension or "", + storage_key="", + content=skill.content.encode("utf-8"), + ) + ) + return items diff --git a/api/core/app_assets/constants.py b/api/core/app_assets/constants.py new file mode 100644 index 0000000000..c6583989e8 --- /dev/null +++ b/api/core/app_assets/constants.py @@ -0,0 +1,8 @@ +from core.app.entities.app_asset_entities import AppAssetFileTree +from libs.attr_map import AttrKey + + +class AppAssetsAttrs: + # Skill artifact set + FILE_TREE = AttrKey("file_tree", AppAssetFileTree) + APP_ASSETS_ID = AttrKey("app_assets_id", str) diff --git a/api/core/app_assets/converters.py b/api/core/app_assets/converters.py new file mode 100644 index 0000000000..07ecf90775 --- /dev/null +++ b/api/core/app_assets/converters.py @@ -0,0 +1,20 @@ +from __future__ import annotations + +from core.app.entities.app_asset_entities import AppAssetFileTree, AssetNodeType +from core.app_assets.entities import 
AssetItem +from core.app_assets.storage import AssetPaths + + +def tree_to_asset_items(tree: AppAssetFileTree, tenant_id: str, app_id: str) -> list[AssetItem]: + """Convert AppAssetFileTree to list of AssetItem for packaging.""" + return [ + AssetItem( + asset_id=node.id, + path=tree.get_path(node.id), + file_name=node.name, + extension=node.extension or "", + storage_key=AssetPaths.draft(tenant_id, app_id, node.id), + ) + for node in tree.nodes + if node.node_type == AssetNodeType.FILE + ] diff --git a/api/core/app_assets/entities/__init__.py b/api/core/app_assets/entities/__init__.py new file mode 100644 index 0000000000..f33286fafb --- /dev/null +++ b/api/core/app_assets/entities/__init__.py @@ -0,0 +1,7 @@ +from .assets import AssetItem +from .skill import SkillAsset + +__all__ = [ + "AssetItem", + "SkillAsset", +] diff --git a/api/core/app_assets/entities/assets.py b/api/core/app_assets/entities/assets.py new file mode 100644 index 0000000000..fdd0c89768 --- /dev/null +++ b/api/core/app_assets/entities/assets.py @@ -0,0 +1,20 @@ +from dataclasses import dataclass, field + + +@dataclass +class AssetItem: + """A single asset file produced by the build pipeline. + + When *content* is set the payload is available in-process and can be + written directly into a ZIP or uploaded to a sandbox VM without an + extra S3 round-trip. When *content* is ``None`` the caller should + fetch the bytes from *storage_key* (the traditional presigned-URL + path). + """ + + asset_id: str + path: str + file_name: str + extension: str + storage_key: str + content: bytes | None = field(default=None, repr=False) diff --git a/api/core/app_assets/entities/skill.py b/api/core/app_assets/entities/skill.py new file mode 100644 index 0000000000..c3442ebee3 --- /dev/null +++ b/api/core/app_assets/entities/skill.py @@ -0,0 +1,10 @@ +from collections.abc import Mapping +from dataclasses import dataclass, field +from typing import Any + +from .assets import AssetItem + + +@dataclass +class SkillAsset(AssetItem): + metadata: Mapping[str, Any] = field(default_factory=dict) diff --git a/api/core/app_assets/storage.py b/api/core/app_assets/storage.py new file mode 100644 index 0000000000..380f8daef1 --- /dev/null +++ b/api/core/app_assets/storage.py @@ -0,0 +1,68 @@ +"""App assets storage key generation. + +Provides AssetPaths facade for generating storage keys for app assets. +Storage instances are obtained via AppAssetService.get_storage(). 
+""" + +from __future__ import annotations + +from uuid import UUID + +_BASE = "app_assets" + + +def _check_uuid(value: str, name: str) -> None: + try: + UUID(value) + except (ValueError, TypeError) as e: + raise ValueError(f"{name} must be a valid UUID") from e + + +class AssetPaths: + """Facade for generating app asset storage keys.""" + + @staticmethod + def draft(tenant_id: str, app_id: str, node_id: str) -> str: + """app_assets/{tenant}/{app}/draft/{node_id}""" + _check_uuid(tenant_id, "tenant_id") + _check_uuid(app_id, "app_id") + _check_uuid(node_id, "node_id") + return f"{_BASE}/{tenant_id}/{app_id}/draft/{node_id}" + + @staticmethod + def build_zip(tenant_id: str, app_id: str, assets_id: str) -> str: + """app_assets/{tenant}/{app}/artifacts/{assets_id}.zip""" + _check_uuid(tenant_id, "tenant_id") + _check_uuid(app_id, "app_id") + _check_uuid(assets_id, "assets_id") + return f"{_BASE}/{tenant_id}/{app_id}/artifacts/{assets_id}.zip" + + @staticmethod + def skill_bundle(tenant_id: str, app_id: str, assets_id: str) -> str: + """app_assets/{tenant}/{app}/artifacts/{assets_id}/skill_artifact_set.json""" + _check_uuid(tenant_id, "tenant_id") + _check_uuid(app_id, "app_id") + _check_uuid(assets_id, "assets_id") + return f"{_BASE}/{tenant_id}/{app_id}/artifacts/{assets_id}/skill_artifact_set.json" + + @staticmethod + def source_zip(tenant_id: str, app_id: str, workflow_id: str) -> str: + """app_assets/{tenant}/{app}/sources/{workflow_id}.zip""" + _check_uuid(tenant_id, "tenant_id") + _check_uuid(app_id, "app_id") + _check_uuid(workflow_id, "workflow_id") + return f"{_BASE}/{tenant_id}/{app_id}/sources/{workflow_id}.zip" + + @staticmethod + def bundle_export(tenant_id: str, app_id: str, export_id: str) -> str: + """app_assets/{tenant}/{app}/bundle_exports/{export_id}.zip""" + _check_uuid(tenant_id, "tenant_id") + _check_uuid(app_id, "app_id") + _check_uuid(export_id, "export_id") + return f"{_BASE}/{tenant_id}/{app_id}/bundle_exports/{export_id}.zip" + + @staticmethod + def bundle_import(tenant_id: str, import_id: str) -> str: + """app_assets/{tenant}/imports/{import_id}.zip""" + _check_uuid(tenant_id, "tenant_id") + return f"{_BASE}/{tenant_id}/imports/{import_id}.zip" diff --git a/api/core/app_bundle/__init__.py b/api/core/app_bundle/__init__.py new file mode 100644 index 0000000000..7fb33b2b6d --- /dev/null +++ b/api/core/app_bundle/__init__.py @@ -0,0 +1 @@ +# App bundle utilities - manifest-driven import/export handled by AppBundleService diff --git a/api/core/sandbox/__init__.py b/api/core/sandbox/__init__.py new file mode 100644 index 0000000000..b2eded2c87 --- /dev/null +++ b/api/core/sandbox/__init__.py @@ -0,0 +1,96 @@ +from __future__ import annotations + +import importlib +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from .bash.dify_cli import ( + DifyCliBinary, + DifyCliConfig, + DifyCliEnvConfig, + DifyCliLocator, + DifyCliToolConfig, + ) + from .bash.session import SandboxBashSession + from .builder import SandboxBuilder, VMConfig + from .entities import AppAssets, DifyCli, SandboxProviderApiEntity, SandboxType + from .initializer import ( + AsyncSandboxInitializer, + SandboxInitializeContext, + SandboxInitializer, + SyncSandboxInitializer, + ) + from .initializer.app_assets_initializer import AppAssetsInitializer + from .initializer.dify_cli_initializer import DifyCliInitializer + from .initializer.draft_app_assets_initializer import DraftAppAssetsInitializer + from .sandbox import Sandbox + from .storage import ArchiveSandboxStorage, SandboxStorage + from 
.utils.debug import sandbox_debug + from .utils.encryption import create_sandbox_config_encrypter, masked_config + +__all__ = [ + "AppAssets", + "AppAssetsInitializer", + "ArchiveSandboxStorage", + "AsyncSandboxInitializer", + "DifyCli", + "DifyCliBinary", + "DifyCliConfig", + "DifyCliEnvConfig", + "DifyCliInitializer", + "DifyCliLocator", + "DifyCliToolConfig", + "DraftAppAssetsInitializer", + "Sandbox", + "SandboxBashSession", + "SandboxBuilder", + "SandboxInitializeContext", + "SandboxInitializer", + "SandboxProviderApiEntity", + "SandboxStorage", + "SandboxType", + "SyncSandboxInitializer", + "VMConfig", + "create_sandbox_config_encrypter", + "masked_config", + "sandbox_debug", +] + +_LAZY_IMPORTS = { + "AppAssets": ("core.sandbox.entities", "AppAssets"), + "AppAssetsInitializer": ("core.sandbox.initializer.app_assets_initializer", "AppAssetsInitializer"), + "AsyncSandboxInitializer": ("core.sandbox.initializer", "AsyncSandboxInitializer"), + "ArchiveSandboxStorage": ("core.sandbox.storage", "ArchiveSandboxStorage"), + "DifyCli": ("core.sandbox.entities", "DifyCli"), + "DifyCliBinary": ("core.sandbox.bash.dify_cli", "DifyCliBinary"), + "DifyCliConfig": ("core.sandbox.bash.dify_cli", "DifyCliConfig"), + "DifyCliEnvConfig": ("core.sandbox.bash.dify_cli", "DifyCliEnvConfig"), + "DifyCliInitializer": ("core.sandbox.initializer.dify_cli_initializer", "DifyCliInitializer"), + "DifyCliLocator": ("core.sandbox.bash.dify_cli", "DifyCliLocator"), + "DifyCliToolConfig": ("core.sandbox.bash.dify_cli", "DifyCliToolConfig"), + "DraftAppAssetsInitializer": ("core.sandbox.initializer.draft_app_assets_initializer", "DraftAppAssetsInitializer"), + "Sandbox": ("core.sandbox.sandbox", "Sandbox"), + "SandboxBashSession": ("core.sandbox.bash.session", "SandboxBashSession"), + "SandboxBuilder": ("core.sandbox.builder", "SandboxBuilder"), + "SandboxInitializeContext": ("core.sandbox.initializer", "SandboxInitializeContext"), + "SandboxInitializer": ("core.sandbox.initializer", "SandboxInitializer"), + "SandboxManager": ("core.sandbox.manager", "SandboxManager"), + "SandboxProviderApiEntity": ("core.sandbox.entities", "SandboxProviderApiEntity"), + "SandboxStorage": ("core.sandbox.storage", "SandboxStorage"), + "SandboxType": ("core.sandbox.entities", "SandboxType"), + "SyncSandboxInitializer": ("core.sandbox.initializer", "SyncSandboxInitializer"), + "VMConfig": ("core.sandbox.builder", "VMConfig"), + "create_sandbox_config_encrypter": ("core.sandbox.utils.encryption", "create_sandbox_config_encrypter"), + "masked_config": ("core.sandbox.utils.encryption", "masked_config"), + "sandbox_debug": ("core.sandbox.utils.debug", "sandbox_debug"), +} + + +def __getattr__(name: str): + if name not in _LAZY_IMPORTS: + raise AttributeError(f"module 'core.sandbox' has no attribute {name}") + module_path, attr_name = _LAZY_IMPORTS[name] + module = importlib.import_module(module_path) + value = getattr(module, attr_name) + globals()[name] = value + return value diff --git a/api/core/sandbox/bash/__init__.py b/api/core/sandbox/bash/__init__.py new file mode 100644 index 0000000000..fd69e39833 --- /dev/null +++ b/api/core/sandbox/bash/__init__.py @@ -0,0 +1,15 @@ +from .dify_cli import ( + DifyCliBinary, + DifyCliConfig, + DifyCliEnvConfig, + DifyCliLocator, + DifyCliToolConfig, +) + +__all__ = [ + "DifyCliBinary", + "DifyCliConfig", + "DifyCliEnvConfig", + "DifyCliLocator", + "DifyCliToolConfig", +] diff --git a/api/core/sandbox/bash/bash_tool.py b/api/core/sandbox/bash/bash_tool.py new file mode 100644 index 
0000000000..f2573c9382 --- /dev/null +++ b/api/core/sandbox/bash/bash_tool.py @@ -0,0 +1,139 @@ +from collections.abc import Generator +from typing import Any + +from core.sandbox.entities import DifyCli +from core.tools.__base.tool import Tool +from core.tools.__base.tool_runtime import ToolRuntime +from core.tools.entities.common_entities import I18nObject +from core.tools.entities.tool_entities import ( + ToolDescription, + ToolEntity, + ToolIdentity, + ToolInvokeMessage, + ToolParameter, + ToolProviderType, +) +from core.virtual_environment.__base.helpers import submit_command, with_connection +from core.virtual_environment.__base.virtual_environment import VirtualEnvironment + +from ..utils.debug import sandbox_debug + +COMMAND_TIMEOUT_SECONDS = 60 * 60 * 2 # 2 hours, can be adjusted based on expected command execution times + +# Output truncation settings to avoid overwhelming model context +# 8000 chars ≈ 2000-2700 tokens, safe for models with 8K+ context +MAX_OUTPUT_LENGTH = 8000 +TRUNCATE_HEAD_LENGTH = 2500 # Keep beginning for context +TRUNCATE_TAIL_LENGTH = 2500 # Keep end for results/errors + + +def _truncate_output(output: str, name: str = "output") -> str: + """Truncate output if it exceeds the maximum length. + + Keeps the head and tail of the output to preserve context and final results. + """ + if len(output) <= MAX_OUTPUT_LENGTH: + return output + + omitted_length = len(output) - TRUNCATE_HEAD_LENGTH - TRUNCATE_TAIL_LENGTH + head = output[:TRUNCATE_HEAD_LENGTH] + tail = output[-TRUNCATE_TAIL_LENGTH:] + + return f"{head}\n\n... [{omitted_length} characters omitted from {name}] ...\n\n{tail}" + + +class SandboxBashTool(Tool): + def __init__(self, sandbox: VirtualEnvironment, tenant_id: str, tools_path: str) -> None: + self._sandbox = sandbox + self._tools_path = tools_path + + entity = ToolEntity( + identity=ToolIdentity( + author="Dify", + name="bash", + label=I18nObject(en_US="Bash", zh_Hans="Bash"), + provider="sandbox", + ), + parameters=[ + ToolParameter.get_simple_instance( + name="bash", + llm_description="The bash command to execute in current working directory", + typ=ToolParameter.ToolParameterType.STRING, + required=True, + ), + ], + description=ToolDescription( + human=I18nObject( + en_US="Execute bash commands in current working directory", + ), + llm="Execute bash commands in current working directory. " + "Use this tool to run shell commands, scripts, or interact with the system. " + "The command will be executed in the current working directory. " + "IMPORTANT: If you generate any output files (images, documents, etc.) that need to be " + "returned or referenced later, you MUST save them to the 'output/' directory " + "(e.g., 'mkdir -p output && cp result.png output/'). Only files in output/ will be collected.", + ), + ) + + runtime = ToolRuntime(tenant_id=tenant_id) + super().__init__(entity=entity, runtime=runtime) + + def tool_provider_type(self) -> ToolProviderType: + return ToolProviderType.BUILT_IN + + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, + ) -> Generator[ToolInvokeMessage, None, None]: + command = tool_parameters.get("bash", "") + if not command: + sandbox_debug("bash_tool", "parameters", tool_parameters) + yield self.create_text_message( + 'Error: No command provided. The "bash" parameter is required and must contain ' + 'the shell command to execute. 
Example: {"bash": "ls -la"}' + ) + return + + try: + with with_connection(self._sandbox) as conn: + # Build command with embedded environment variables + env_exports = ( + f"export PATH={self._tools_path}:/usr/local/bin:/usr/bin:/bin && " + f"export DIFY_CLI_CONFIG={self._tools_path}/{DifyCli.CONFIG_FILENAME} && " + ) + full_command = env_exports + command + + cmd_list = ["bash", "-c", full_command] + sandbox_debug("bash_tool", "cmd_list", cmd_list) + + future = submit_command( + self._sandbox, + conn, + cmd_list, + ) + timeout = COMMAND_TIMEOUT_SECONDS if COMMAND_TIMEOUT_SECONDS > 0 else None + result = future.result(timeout=timeout) + + stdout = result.stdout.decode("utf-8", errors="replace") if result.stdout else "" + stderr = result.stderr.decode("utf-8", errors="replace") if result.stderr else "" + + # Truncate long outputs to avoid overwhelming the model + stdout = _truncate_output(stdout, "stdout") + stderr = _truncate_output(stderr, "stderr") + + output_parts: list[str] = [] + if stdout: + output_parts.append(f"\n{stdout}") + if stderr: + output_parts.append(f"\n{stderr}") + + yield self.create_text_message("\n".join(output_parts)) + + except TimeoutError: + yield self.create_text_message(f"Error: Command timed out after {COMMAND_TIMEOUT_SECONDS}s") + except Exception as e: + yield self.create_text_message(f"Error: {e!s}") diff --git a/api/core/sandbox/bash/dify_cli.py b/api/core/sandbox/bash/dify_cli.py new file mode 100644 index 0000000000..a01b863486 --- /dev/null +++ b/api/core/sandbox/bash/dify_cli.py @@ -0,0 +1,164 @@ +from __future__ import annotations + +from pathlib import Path +from typing import TYPE_CHECKING, Any + +from pydantic import BaseModel, Field + +from core.session.cli_api import CliApiSession +from core.skill.entities import ToolDependencies, ToolReference +from core.tools.entities.tool_entities import ToolParameter, ToolProviderType +from core.virtual_environment.__base.entities import Arch, OperatingSystem +from graphon.model_runtime.utils.encoders import jsonable_encoder + +from ..entities import DifyCli + +if TYPE_CHECKING: + from core.tools.__base.tool import Tool + + +class DifyCliBinary(BaseModel): + operating_system: OperatingSystem = Field(alias="os") + arch: Arch + path: Path + + model_config = { + "populate_by_name": True, + "arbitrary_types_allowed": True, + } + + +class DifyCliLocator: + def __init__(self, root: str | Path | None = None) -> None: + from configs import dify_config + + if root is not None: + self._root = Path(root) + elif dify_config.SANDBOX_DIFY_CLI_ROOT: + self._root = Path(dify_config.SANDBOX_DIFY_CLI_ROOT) + else: + api_root = Path(__file__).resolve().parents[3] + self._root = api_root / "bin" + + def resolve(self, operating_system: OperatingSystem, arch: Arch) -> DifyCliBinary: + filename = DifyCli.PATH_PATTERN.format(os=operating_system.value, arch=arch.value) + candidate = self._root / filename + if not candidate.is_file(): + raise FileNotFoundError( + f"dify CLI binary not found: {candidate}. Configure SANDBOX_DIFY_CLI_ROOT or ensure the file exists." 
+ ) + + return DifyCliBinary(os=operating_system, arch=arch, path=candidate) + + +class DifyCliEnvConfig(BaseModel): + files_url: str + cli_api_url: str + cli_api_session_id: str + cli_api_secret: str + + +class DifyCliToolConfig(BaseModel): + provider_type: str + enabled: bool = True + identity: dict[str, Any] + description: dict[str, Any] + parameters: list[dict[str, Any]] + + @classmethod + def transform_provider_type(cls, tool_provider_type: ToolProviderType) -> str: + provider_type = tool_provider_type + match tool_provider_type: + case ToolProviderType.BUILT_IN | ToolProviderType.PLUGIN: + provider_type = "builtin" + case ToolProviderType.MCP | ToolProviderType.WORKFLOW | ToolProviderType.API: + provider_type = provider_type + case _: + raise ValueError(f"Invalid tool provider type: {tool_provider_type}") + return provider_type + + @classmethod + def create_from_tool(cls, tool: Tool) -> DifyCliToolConfig: + return cls( + identity=to_json(tool.entity.identity), + provider_type=cls.transform_provider_type(tool.tool_provider_type()), + description=to_json(tool.entity.description), + parameters=[cls.transform_parameter(parameter) for parameter in tool.entity.parameters], + ) + + @classmethod + def transform_parameter(cls, parameter: ToolParameter) -> dict[str, Any]: + transformed_parameter = to_json(parameter) + transformed_parameter.pop("input_schema", None) + transformed_parameter.pop("form", None) + match parameter.type: + case ( + ToolParameter.ToolParameterType.SYSTEM_FILES + | ToolParameter.ToolParameterType.FILE + | ToolParameter.ToolParameterType.FILES + ): + return transformed_parameter + case _: + return transformed_parameter + + +class DifyCliToolReference(BaseModel): + id: str + tool_type: str + tool_name: str + tool_provider: str + credential_id: str | None = None + default_value: dict[str, Any] | None = None + + @classmethod + def create_from_tool_reference(cls, reference: ToolReference) -> DifyCliToolReference: + return cls( + id=reference.uuid, + tool_type=reference.type.value, + tool_name=reference.tool_name, + tool_provider=reference.provider, + credential_id=reference.credential_id, + default_value=reference.configuration.default_values() if reference.configuration else None, + ) + + +class DifyCliConfig(BaseModel): + env: DifyCliEnvConfig + tool_references: list[DifyCliToolReference] + tools: list[DifyCliToolConfig] + + @classmethod + def create( + cls, + session: CliApiSession, + tenant_id: str, + tool_deps: ToolDependencies, + ) -> DifyCliConfig: + from configs import dify_config + + cli_api_url = dify_config.CLI_API_URL + + return cls( + env=DifyCliEnvConfig( + files_url=dify_config.FILES_API_URL, + cli_api_url=cli_api_url, + cli_api_session_id=session.id, + cli_api_secret=session.secret, + ), + tool_references=[DifyCliToolReference.create_from_tool_reference(ref) for ref in tool_deps.references], + tools=[], + ) + + +def to_json(obj: Any) -> dict[str, Any]: + return jsonable_encoder(obj, exclude_unset=True, exclude_defaults=True, exclude_none=True) + + +__all__ = [ + "DifyCliBinary", + "DifyCliConfig", + "DifyCliEnvConfig", + "DifyCliLocator", + "DifyCliToolConfig", + "DifyCliToolReference", +] diff --git a/api/core/sandbox/bash/session.py b/api/core/sandbox/bash/session.py new file mode 100644 index 0000000000..8f8cc4b2d7 --- /dev/null +++ b/api/core/sandbox/bash/session.py @@ -0,0 +1,239 @@ +from __future__ import annotations + +import json +import logging +import mimetypes +import os +import shlex +from types import TracebackType + +from core.sandbox.sandbox 
import Sandbox +from core.session.cli_api import CliApiSession, CliApiSessionManager, CliContext +from core.skill.entities import ToolAccessPolicy +from core.skill.entities.tool_dependencies import ToolDependencies +from core.tools.signature import sign_tool_file +from core.tools.tool_file_manager import ToolFileManager +from core.virtual_environment.__base.helpers import pipeline +from graphon.file import File, FileTransferMethod, FileType + +from ..bash.dify_cli import DifyCliConfig +from ..entities import DifyCli +from .bash_tool import SandboxBashTool + +logger = logging.getLogger(__name__) + +SANDBOX_READY_TIMEOUT = 60 * 10 + +# Default output directory for sandbox-generated files +SANDBOX_OUTPUT_DIR = "output" +# Maximum number of files to collect from sandbox output +MAX_OUTPUT_FILES = 50 +# Maximum file size to collect (10MB) +MAX_OUTPUT_FILE_SIZE = 10 * 1024 * 1024 + + +class SandboxBashSession: + def __init__(self, *, sandbox: Sandbox, node_id: str, tools: ToolDependencies | None) -> None: + self._sandbox = sandbox + self._node_id = node_id + self._tools = tools + self._bash_tool: SandboxBashTool | None = None + self._cli_api_session: CliApiSession | None = None + self._tenant_id = sandbox.tenant_id + self._user_id = sandbox.user_id + self._app_id = sandbox.app_id + self._assets_id = sandbox.assets_id + + def __enter__(self) -> SandboxBashSession: + # Ensure sandbox initialization completes before any bash commands run. + self._sandbox.wait_ready(timeout=SANDBOX_READY_TIMEOUT) + cli = DifyCli(self._sandbox.id) + self._cli_api_session = CliApiSessionManager().create( + tenant_id=self._tenant_id, + user_id=self._user_id, + context=CliContext(tool_access=ToolAccessPolicy.from_dependencies(self._tools)), + ) + if self._tools is not None and not self._tools.is_empty(): + tools_path = self._setup_node_tools_directory(cli, self._node_id, self._tools, self._cli_api_session) + else: + tools_path = cli.global_tools_path + + self._bash_tool = SandboxBashTool( + sandbox=self._sandbox.vm, + tenant_id=self._tenant_id, + tools_path=tools_path, + ) + return self + + def _setup_node_tools_directory( + self, + cli: DifyCli, + node_id: str, + tools: ToolDependencies, + cli_api_session: CliApiSession, + ) -> str: + node_tools_path = cli.node_tools_path(node_id) + config_json = json.dumps( + DifyCliConfig.create(session=cli_api_session, tenant_id=self._tenant_id, tool_deps=tools).model_dump( + mode="json" + ), + ensure_ascii=False, + ) + config_path = shlex.quote(cli.node_config_path(node_id)) + + vm = self._sandbox.vm + # Merge mkdir + config write into a single pipeline to reduce round-trips. + ( + pipeline(vm) + .add(["mkdir", "-p", cli.global_tools_path], error_message="Failed to create global tools dir") + .add(["mkdir", "-p", node_tools_path], error_message="Failed to create node tools dir") + # Use a quoted heredoc (<<'EOF') so the shell performs no expansion on the + # content — safe regardless of $, `, \, or quotes inside the JSON. 
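+            # A generated step looks roughly like this (illustrative env/node
+            # ids; the JSON body is the serialized DifyCliConfig):
+            #   cat > /tmp/.dify/<env_id>/tools/<node_id>/.dify_cli.json << '__DIFY_CFG__'
+            #   {"env": {...}, "tool_references": [...], "tools": [...]}
+            #   __DIFY_CFG__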
+ .add( + ["sh", "-c", f"cat > {config_path} << '__DIFY_CFG__'\n{config_json}\n__DIFY_CFG__"], + error_message="Failed to write CLI config", + ) + .execute(raise_on_error=True) + ) + + pipeline(vm, cwd=node_tools_path).add( + [cli.bin_path, "init"], error_message="Failed to initialize Dify CLI" + ).execute(raise_on_error=True) + + logger.info( + "Node %s tools initialized, path=%s, tool_count=%d", node_id, node_tools_path, len(tools.references) + ) + return node_tools_path + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + tb: TracebackType | None, + ) -> bool: + try: + if self._cli_api_session is not None: + CliApiSessionManager().delete(self._cli_api_session.id) + logger.debug("Cleaned up SandboxSession session_id=%s", self._cli_api_session.id) + self._cli_api_session = None + except Exception: + logger.exception("Failed to cleanup SandboxSession") + return False + + @property + def bash_tool(self) -> SandboxBashTool: + if self._bash_tool is None: + raise RuntimeError("SandboxSession is not initialized") + return self._bash_tool + + def collect_output_files(self, output_dir: str = SANDBOX_OUTPUT_DIR) -> list[File]: + """ + Collect files from sandbox output directory and save them as ToolFiles. + + Scans the specified output directory in sandbox, downloads each file, + saves it as a ToolFile, and returns a list of File objects. The File + objects will have valid tool_file_id that can be referenced by subsequent + nodes via structured output. + + Args: + output_dir: Directory path in sandbox to scan for output files. + Defaults to "output" (relative to workspace). + + Returns: + List of File objects representing the collected files. + """ + vm = self._sandbox.vm + collected_files: list[File] = [] + + try: + file_states = vm.list_files(output_dir, limit=MAX_OUTPUT_FILES) + except Exception as exc: + # Output directory may not exist if no files were generated + logger.debug("Failed to list sandbox output files in %s: %s", output_dir, exc) + return collected_files + + tool_file_manager = ToolFileManager() + + for file_state in file_states: + # Skip files that are too large + if file_state.size > MAX_OUTPUT_FILE_SIZE: + logger.warning( + "Skipping sandbox output file %s: size %d exceeds limit %d", + file_state.path, + file_state.size, + MAX_OUTPUT_FILE_SIZE, + ) + continue + + try: + # file_state.path is already relative to working_path (e.g., "output/file.png") + file_content = vm.download_file(file_state.path) + file_binary = file_content.getvalue() + + # Determine mime type from extension + filename = os.path.basename(file_state.path) + mime_type, _ = mimetypes.guess_type(filename) + if not mime_type: + mime_type = "application/octet-stream" + + # Save as ToolFile + tool_file = tool_file_manager.create_file_by_raw( + user_id=self._user_id, + tenant_id=self._tenant_id, + conversation_id=None, + file_binary=file_binary, + mimetype=mime_type, + filename=filename, + ) + + # Determine file type from mime type + file_type = _get_file_type_from_mime(mime_type) + extension = os.path.splitext(filename)[1] if "." 
in filename else ".bin" + url = sign_tool_file(tool_file.id, extension) + + # Create File object with tool_file_id as related_id + file_obj = File( + id=tool_file.id, # Use tool_file_id as the File id for easy reference + tenant_id=self._tenant_id, + type=file_type, + transfer_method=FileTransferMethod.TOOL_FILE, + filename=filename, + extension=extension, + mime_type=mime_type, + size=len(file_binary), + related_id=tool_file.id, + url=url, + storage_key=tool_file.file_key, + ) + collected_files.append(file_obj) + + logger.info( + "Collected sandbox output file: %s -> tool_file_id=%s", + file_state.path, + tool_file.id, + ) + + except Exception as exc: + logger.warning("Failed to collect sandbox output file %s: %s", file_state.path, exc) + continue + + logger.info( + "Collected %d files from sandbox output directory %s", + len(collected_files), + output_dir, + ) + return collected_files + + +def _get_file_type_from_mime(mime_type: str) -> FileType: + """Determine FileType from mime type.""" + if mime_type.startswith("image/"): + return FileType.IMAGE + elif mime_type.startswith("video/"): + return FileType.VIDEO + elif mime_type.startswith("audio/"): + return FileType.AUDIO + elif "text" in mime_type or "pdf" in mime_type: + return FileType.DOCUMENT + else: + return FileType.CUSTOM diff --git a/api/core/sandbox/builder.py b/api/core/sandbox/builder.py new file mode 100644 index 0000000000..97b9f96fb8 --- /dev/null +++ b/api/core/sandbox/builder.py @@ -0,0 +1,222 @@ +from __future__ import annotations + +import logging +import threading +from collections.abc import Mapping, Sequence +from contextlib import nullcontext +from typing import TYPE_CHECKING, Any, cast + +from flask import Flask, current_app, has_app_context + +from core.entities.provider_entities import BasicProviderConfig +from core.virtual_environment.__base.virtual_environment import VirtualEnvironment + +from .entities.sandbox_type import SandboxType +from .initializer import AsyncSandboxInitializer, SandboxInitializeContext, SandboxInitializer, SyncSandboxInitializer +from .sandbox import Sandbox + +if TYPE_CHECKING: + from .storage.sandbox_storage import SandboxStorage + +logger = logging.getLogger(__name__) + + +def _get_sandbox_class(sandbox_type: SandboxType) -> type[VirtualEnvironment]: + match sandbox_type: + case SandboxType.DOCKER: + from core.virtual_environment.providers.docker_daemon_sandbox import DockerDaemonEnvironment + + return DockerDaemonEnvironment + case SandboxType.E2B: + from core.virtual_environment.providers.e2b_sandbox import E2BEnvironment + + return E2BEnvironment + case SandboxType.LOCAL: + from core.virtual_environment.providers.local_without_isolation import LocalVirtualEnvironment + + return LocalVirtualEnvironment + case SandboxType.SSH: + from core.virtual_environment.providers.ssh_sandbox import SSHSandboxEnvironment + + return SSHSandboxEnvironment + case SandboxType.AWS_CODE_INTERPRETER: + from core.virtual_environment.providers.aws_code_interpreter_sandbox import ( + AWSCodeInterpreterEnvironment, + ) + + return AWSCodeInterpreterEnvironment + case _: + raise ValueError(f"Unsupported sandbox type: {sandbox_type}") + + +class SandboxBuilder: + _tenant_id: str + _sandbox_type: SandboxType + _user_id: str | None + _app_id: str | None + _options: dict[str, Any] + _environments: dict[str, str] + _initializers: list[SandboxInitializer] + _storage: SandboxStorage | None + _assets_id: str | None + + def __init__(self, tenant_id: str, sandbox_type: SandboxType) -> None: + self._tenant_id = 
tenant_id + self._sandbox_type = sandbox_type + self._user_id = None + self._app_id = None + self._options = {} + self._environments = {} + self._initializers = [] + self._storage = None + self._assets_id = None + + def user(self, user_id: str) -> SandboxBuilder: + self._user_id = user_id + return self + + def app(self, app_id: str) -> SandboxBuilder: + self._app_id = app_id + return self + + def options(self, options: Mapping[str, Any]) -> SandboxBuilder: + self._options = dict(options) + return self + + def environments(self, environments: Mapping[str, str]) -> SandboxBuilder: + self._environments = dict(environments) + return self + + def initializer(self, initializer: SandboxInitializer) -> SandboxBuilder: + self._initializers.append(initializer) + return self + + def initializers(self, initializers: Sequence[SandboxInitializer]) -> SandboxBuilder: + self._initializers.extend(initializers) + return self + + def storage(self, storage: SandboxStorage, assets_id: str) -> SandboxBuilder: + self._storage = storage + self._assets_id = assets_id + return self + + def build(self) -> Sandbox: + """Create a sandbox and start background initialization. + + The builder is responsible for cleaning up any VM or sandbox that was + successfully created if a later setup step fails. + """ + if self._storage is None: + raise ValueError("storage is required, call .storage() before .build()") + if self._assets_id is None: + raise ValueError("assets_id is required, call .storage() before .build()") + if self._user_id is None: + raise ValueError("user_id is required, call .user() before .build()") + if self._app_id is None: + raise ValueError("app_id is required, call .app() before .build()") + + ctx = SandboxInitializeContext( + tenant_id=self._tenant_id, + app_id=self._app_id, + assets_id=self._assets_id, + user_id=self._user_id, + ) + vm: VirtualEnvironment | None = None + sandbox: Sandbox | None = None + try: + vm_class = _get_sandbox_class(self._sandbox_type) + vm = vm_class( + tenant_id=self._tenant_id, + options=self._options, + environments=self._environments, + user_id=self._user_id, + ) + vm.open_enviroment() + sandbox = Sandbox( + vm=vm, + storage=self._storage, + tenant_id=self._tenant_id, + user_id=self._user_id, + app_id=self._app_id, + assets_id=self._assets_id, + ) + + for init in self._initializers: + if isinstance(init, SyncSandboxInitializer): + init.initialize(sandbox, ctx) + except Exception as exc: + logger.exception( + "Failed to initialize sandbox synchronously: tenant_id=%s, app_id=%s", self._tenant_id, self._app_id + ) + if sandbox is not None: + sandbox.release() + elif vm is not None: + try: + vm.release_environment() + except Exception: + logger.exception("Failed to release sandbox VM during builder cleanup") + raise RuntimeError("Sandbox initialization failed") from exc + + # Run sandbox setup asynchronously so workflow execution can proceed. + # Capture the Flask app before starting the thread for database access. 
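+        # current_app is a werkzeug LocalProxy bound to the active app context;
+        # _get_current_object() unwraps it to the concrete Flask app so the
+        # background thread can push its own app context for database access.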
+ flask_app: Flask | None = cast(Any, current_app)._get_current_object() if has_app_context() else None + + _sandbox: Sandbox = sandbox + + def initialize() -> None: + try: + app_context = flask_app.app_context() if flask_app is not None else nullcontext() + with app_context: + for init in self._initializers: + if not isinstance(init, AsyncSandboxInitializer): + continue + + if _sandbox.is_cancelled(): + return + init.initialize(_sandbox, ctx) + + if _sandbox.is_cancelled(): + return + _sandbox.mount() + _sandbox.mark_ready() + except Exception as exc: + try: + logger.exception( + "Failed to initialize sandbox: tenant_id=%s, app_id=%s", self._tenant_id, self._app_id + ) + _sandbox.release() + _sandbox.mark_failed(exc) + except Exception: + logger.exception( + "Failed to mark sandbox initialization failure: tenant_id=%s, app_id=%s", + self._tenant_id, + self._app_id, + ) + + # Background init completes or signals failure via sandbox state. + try: + threading.Thread(target=initialize, daemon=True).start() + except Exception: + logger.exception( + "Failed to start sandbox initialization thread: tenant_id=%s, app_id=%s", + self._tenant_id, + self._app_id, + ) + sandbox.release() + raise RuntimeError("Sandbox initialization failed") + return sandbox + + @staticmethod + def validate(vm_type: SandboxType, options: Mapping[str, Any]) -> None: + vm_class = _get_sandbox_class(vm_type) + vm_class.validate(options) + + @classmethod + def draft_id(cls, user_id: str) -> str: + return user_id + + +class VMConfig: + @staticmethod + def get_schema(vm_type: SandboxType) -> list[BasicProviderConfig]: + return _get_sandbox_class(vm_type).get_config_schema() diff --git a/api/core/sandbox/entities/__init__.py b/api/core/sandbox/entities/__init__.py new file mode 100644 index 0000000000..be562f4e5b --- /dev/null +++ b/api/core/sandbox/entities/__init__.py @@ -0,0 +1,13 @@ +from .config import AppAssets, DifyCli +from .files import SandboxFileDownloadTicket, SandboxFileNode +from .providers import SandboxProviderApiEntity +from .sandbox_type import SandboxType + +__all__ = [ + "AppAssets", + "DifyCli", + "SandboxFileDownloadTicket", + "SandboxFileNode", + "SandboxProviderApiEntity", + "SandboxType", +] diff --git a/api/core/sandbox/entities/config.py b/api/core/sandbox/entities/config.py new file mode 100644 index 0000000000..3bc5dd2512 --- /dev/null +++ b/api/core/sandbox/entities/config.py @@ -0,0 +1,58 @@ +from typing import Final + + +class DifyCli: + """Per-sandbox Dify CLI paths, namespaced under ``/tmp/.dify/{env_id}``. + + Every sandbox environment gets its own directory tree so that + concurrent sessions on the same host (e.g. SSH provider) never + collide on config files or CLI binaries. + + Class-level constants (``CONFIG_FILENAME``, ``PATH_PATTERN``) are + safe to share; all path attributes are instance-level and derived + from the ``env_id`` passed at construction time. 
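+
+    Resulting layout for an environment id ``abc`` (derived from
+    ``__init__`` below):
+
+        /tmp/.dify/abc/bin/dify
+        /tmp/.dify/abc/tools/global/.dify_cli.json
+        /tmp/.dify/abc/tools/{node_id}/.dify_cli.json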
+ """ + + # --- class-level constants (no path component) --- + CONFIG_FILENAME: Final[str] = ".dify_cli.json" + PATH_PATTERN: Final[str] = "dify-cli-{os}-{arch}" + + # --- instance attributes --- + root: str + bin_dir: str + bin_path: str + tools_root: str + global_tools_path: str + global_config_path: str + + def __init__(self, env_id: str) -> None: + self.root = f"/tmp/.dify/{env_id}" + self.bin_dir = f"{self.root}/bin" + self.bin_path = f"{self.bin_dir}/dify" + self.tools_root = f"{self.root}/tools" + self.global_tools_path = f"{self.root}/tools/global" + self.global_config_path = f"{self.global_tools_path}/{DifyCli.CONFIG_FILENAME}" + + def node_tools_path(self, node_id: str) -> str: + return f"{self.tools_root}/{node_id}" + + def node_config_path(self, node_id: str) -> str: + return f"{self.node_tools_path(node_id)}/{DifyCli.CONFIG_FILENAME}" + + +class AppAssets: + """App Assets constants. + + ``PATH`` is a relative path resolved by each provider against its + own workspace root — already isolated. ``zip_path`` is an absolute + temp path and must be namespaced per environment to avoid collisions. + """ + + PATH: Final[str] = "skills" + + root: str + zip_path: str + + def __init__(self, env_id: str) -> None: + self.root = f"/tmp/.dify/{env_id}" + self.zip_path = f"{self.root}/assets.zip" diff --git a/api/core/sandbox/entities/files.py b/api/core/sandbox/entities/files.py new file mode 100644 index 0000000000..489e49595c --- /dev/null +++ b/api/core/sandbox/entities/files.py @@ -0,0 +1,19 @@ +from __future__ import annotations + +from dataclasses import dataclass + + +@dataclass(frozen=True) +class SandboxFileNode: + path: str + is_dir: bool + size: int | None + mtime: int | None + extension: str | None + + +@dataclass(frozen=True) +class SandboxFileDownloadTicket: + download_url: str + expires_in: int + export_id: str diff --git a/api/core/sandbox/entities/providers.py b/api/core/sandbox/entities/providers.py new file mode 100644 index 0000000000..82c00cb144 --- /dev/null +++ b/api/core/sandbox/entities/providers.py @@ -0,0 +1,21 @@ +from collections.abc import Mapping +from typing import Any + +from pydantic import BaseModel, Field + + +class SandboxProviderApiEntity(BaseModel): + provider_type: str = Field(..., description="Provider type identifier") + is_system_configured: bool = Field(default=False) + is_tenant_configured: bool = Field(default=False) + is_active: bool = Field(default=False) + config: Mapping[str, Any] = Field(default_factory=dict) + config_schema: list[dict[str, Any]] = Field(default_factory=list) + + +class SandboxProviderEntity(BaseModel): + id: str = Field(..., description="Provider identifier") + provider_type: str = Field(..., description="Provider type identifier") + is_active: bool = Field(default=False) + config: Mapping[str, Any] = Field(default_factory=dict) + config_schema: list[dict[str, Any]] = Field(default_factory=list) diff --git a/api/core/sandbox/entities/sandbox_type.py b/api/core/sandbox/entities/sandbox_type.py new file mode 100644 index 0000000000..ed8694fbc8 --- /dev/null +++ b/api/core/sandbox/entities/sandbox_type.py @@ -0,0 +1,18 @@ +from enum import StrEnum + +from configs import dify_config + + +class SandboxType(StrEnum): + DOCKER = "docker" + E2B = "e2b" + LOCAL = "local" + SSH = "ssh" + AWS_CODE_INTERPRETER = "aws_code_interpreter" + + @classmethod + def get_all(cls) -> list[str]: + if dify_config.EDITION == "SELF_HOSTED": + return [p.value for p in cls] + else: + return [p.value for p in cls if p != SandboxType.LOCAL] diff --git 
a/api/core/sandbox/initializer/__init__.py b/api/core/sandbox/initializer/__init__.py
new file mode 100644
index 0000000000..0d1f476cc4
--- /dev/null
+++ b/api/core/sandbox/initializer/__init__.py
@@ -0,0 +1,8 @@
+from .base import AsyncSandboxInitializer, SandboxInitializeContext, SandboxInitializer, SyncSandboxInitializer
+
+__all__ = [
+    "AsyncSandboxInitializer",
+    "SandboxInitializeContext",
+    "SandboxInitializer",
+    "SyncSandboxInitializer",
+]
diff --git a/api/core/sandbox/initializer/app_asset_attrs_initializer.py b/api/core/sandbox/initializer/app_asset_attrs_initializer.py
new file mode 100644
index 0000000000..ed4b51a01b
--- /dev/null
+++ b/api/core/sandbox/initializer/app_asset_attrs_initializer.py
@@ -0,0 +1,19 @@
+import logging
+
+from core.app_assets.constants import AppAssetsAttrs
+from core.sandbox.sandbox import Sandbox
+from services.app_asset_package_service import AppAssetPackageService
+
+from .base import SandboxInitializeContext, SyncSandboxInitializer
+
+logger = logging.getLogger(__name__)
+
+APP_ASSETS_DOWNLOAD_TIMEOUT = 60 * 10
+
+
+class AppAssetAttrsInitializer(SyncSandboxInitializer):
+    def initialize(self, sandbox: Sandbox, ctx: SandboxInitializeContext) -> None:
+        # Record the published asset tree and assets id on sandbox attrs.
+        app_assets = AppAssetPackageService.get_tenant_app_assets(ctx.tenant_id, ctx.assets_id)
+        sandbox.attrs.set(AppAssetsAttrs.FILE_TREE, app_assets.asset_tree)
+        sandbox.attrs.set(AppAssetsAttrs.APP_ASSETS_ID, ctx.assets_id)
diff --git a/api/core/sandbox/initializer/app_assets_initializer.py b/api/core/sandbox/initializer/app_assets_initializer.py
new file mode 100644
index 0000000000..a4d281dae7
--- /dev/null
+++ b/api/core/sandbox/initializer/app_assets_initializer.py
@@ -0,0 +1,59 @@
+import logging
+
+from core.app_assets.storage import AssetPaths
+from core.sandbox.sandbox import Sandbox
+from core.virtual_environment.__base.helpers import pipeline
+
+from ..entities import AppAssets
+from .base import AsyncSandboxInitializer, SandboxInitializeContext
+
+logger = logging.getLogger(__name__)
+
+APP_ASSETS_DOWNLOAD_TIMEOUT = 60 * 10
+
+
+class AppAssetsInitializer(AsyncSandboxInitializer):
+    def initialize(self, sandbox: Sandbox, ctx: SandboxInitializeContext) -> None:
+        from services.app_asset_service import AppAssetService
+
+        # Load published app assets and unzip the artifact bundle.
+        vm = sandbox.vm
+        assets = AppAssets(sandbox.id)
+        asset_storage = AppAssetService.get_storage()
+        key = AssetPaths.build_zip(ctx.tenant_id, ctx.app_id, ctx.assets_id)
+        download_url = asset_storage.get_download_url(key)
+
+        (
+            pipeline(vm)
+            .add(
+                ["mkdir", "-p", assets.root],
+                error_message="Failed to create assets temp directory",
+            )
+            .add(
+                ["sh", "-c", 'curl -fsSL "$1" -o "$2"', "sh", download_url, assets.zip_path],
+                error_message="Failed to download assets zip",
+            )
+            # Create the assets directory first to ensure it exists even if zip is empty
+            .add(
+                ["mkdir", "-p", AppAssets.PATH],
+                error_message="Failed to create assets directory",
+            )
+            # Silence unzip's stderr and tolerate exit status 1 (unzip exits 1 on warnings, e.g. an empty zip)
+            # FIXME(Mairuis): should use a more robust way to check if the zip is empty
+            .add(
+                ["sh", "-c", 'unzip "$1" -d "$2" 2>/dev/null || [ $?
-eq 1 ]', "sh", assets.zip_path, AppAssets.PATH], + error_message="Failed to unzip assets", + ) + # Ensure directories have execute permission for traversal and files are readable + .add( + ["sh", "-c", 'chmod -R u+rwX,go+rX "$1"', "sh", AppAssets.PATH], + error_message="Failed to set permissions on assets", + ) + .execute(timeout=APP_ASSETS_DOWNLOAD_TIMEOUT, raise_on_error=True) + ) + + logger.info( + "App assets initialized for app_id=%s, published_id=%s", + ctx.app_id, + ctx.assets_id, + ) diff --git a/api/core/sandbox/initializer/base.py b/api/core/sandbox/initializer/base.py new file mode 100644 index 0000000000..7021a814e5 --- /dev/null +++ b/api/core/sandbox/initializer/base.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from typing import TYPE_CHECKING + +from core.sandbox.sandbox import Sandbox + +if TYPE_CHECKING: + from core.app_assets.entities.assets import AssetItem + + +@dataclass +class SandboxInitializeContext: + """Shared identity context passed to every ``SandboxInitializer``. + + Carries the common identity fields that virtually every initializer + needs, plus optional artefact slots that sync initializers populate + for async initializers to consume. + + Identity fields are immutable by convention; artefact slots are + written at most once during the sync phase and read during the + async phase. + """ + + tenant_id: str + app_id: str + assets_id: str + user_id: str + + # Populated by DraftAppAssetsInitializer (sync) for + # DraftAppAssetsDownloader (async) to download into the VM. + built_assets: list[AssetItem] | None = field(default=None) + + +class SandboxInitializer(ABC): + @abstractmethod + def initialize(self, sandbox: Sandbox, ctx: SandboxInitializeContext) -> None: ... + + +class SyncSandboxInitializer(SandboxInitializer): + """Marker class for initializers that must run before async setup.""" + + +class AsyncSandboxInitializer(SandboxInitializer): + """Marker class for initializers that can run in the background.""" diff --git a/api/core/sandbox/initializer/dify_cli_initializer.py b/api/core/sandbox/initializer/dify_cli_initializer.py new file mode 100644 index 0000000000..49b71f75f0 --- /dev/null +++ b/api/core/sandbox/initializer/dify_cli_initializer.py @@ -0,0 +1,76 @@ +from __future__ import annotations + +import json +import logging +from io import BytesIO +from pathlib import Path + +from core.sandbox.sandbox import Sandbox +from core.session.cli_api import CliApiSessionManager, CliContext +from core.skill.constants import SkillAttrs +from core.skill.entities import ToolAccessPolicy +from core.virtual_environment.__base.helpers import pipeline + +from ..bash.dify_cli import DifyCliConfig, DifyCliLocator +from ..entities import DifyCli +from .base import AsyncSandboxInitializer, SandboxInitializeContext + +logger = logging.getLogger(__name__) + + +class DifyCliInitializer(AsyncSandboxInitializer): + def __init__(self, cli_root: str | Path | None = None) -> None: + self._locator = DifyCliLocator(root=cli_root) + self._tools: list[object] = [] + + def initialize(self, sandbox: Sandbox, ctx: SandboxInitializeContext) -> None: + vm = sandbox.vm + cli = DifyCli(sandbox.id) + + # FIXME(Mairuis): should be more robust, effectively. 
+        binary = self._locator.resolve(vm.metadata.os, vm.metadata.arch)
+
+        pipeline(vm).add(["mkdir", "-p", cli.bin_dir], error_message="Failed to create dify CLI directory").execute(
+            raise_on_error=True
+        )
+
+        vm.upload_file(cli.bin_path, BytesIO(binary.path.read_bytes()))
+
+        pipeline(vm).add(
+            [
+                "sh",
+                "-c",
+                f"cat '{cli.bin_path}' > '{cli.bin_path}.tmp' && "
+                f"mv '{cli.bin_path}.tmp' '{cli.bin_path}' && "
+                f"chmod +x '{cli.bin_path}'",
+            ],
+            error_message="Failed to mark dify CLI as executable",
+        ).execute(raise_on_error=True)
+
+        logger.info("Dify CLI uploaded to sandbox, path=%s", cli.bin_path)
+
+        bundle = sandbox.attrs.get(SkillAttrs.BUNDLE)
+        if bundle is None or bundle.get_tool_dependencies().is_empty():
+            logger.info("No tools found in bundle for assets_id=%s", ctx.assets_id)
+            return
+
+        global_cli_session = CliApiSessionManager().create(
+            tenant_id=ctx.tenant_id,
+            user_id=ctx.user_id,
+            context=CliContext(tool_access=ToolAccessPolicy.from_dependencies(bundle.get_tool_dependencies())),
+        )
+
+        pipeline(vm).add(
+            ["mkdir", "-p", cli.global_tools_path], error_message="Failed to create global tools dir"
+        ).execute(raise_on_error=True)
+
+        config = DifyCliConfig.create(global_cli_session, ctx.tenant_id, bundle.get_tool_dependencies())
+        config_json = json.dumps(config.model_dump(mode="json"), ensure_ascii=False)
+        config_path = cli.global_config_path
+        vm.upload_file(config_path, BytesIO(config_json.encode("utf-8")))
+
+        pipeline(vm, cwd=cli.global_tools_path).add(
+            [cli.bin_path, "init"], error_message="Failed to initialize Dify CLI"
+        ).execute(raise_on_error=True)
+
+        logger.info("Global tools initialized, path=%s, tool_count=%d", cli.global_tools_path, len(config.tool_references))
diff --git a/api/core/sandbox/initializer/draft_app_assets_initializer.py b/api/core/sandbox/initializer/draft_app_assets_initializer.py
new file mode 100644
index 0000000000..02219a3374
--- /dev/null
+++ b/api/core/sandbox/initializer/draft_app_assets_initializer.py
@@ -0,0 +1,89 @@
+"""Synchronous initializer that compiles draft app assets.
+
+Unlike ``AppAssetsInitializer`` (which downloads a pre-built ZIP for
+published assets), this initializer runs the build pipeline on the fly
+so that ``.md`` skill documents are compiled and their resolved content
+is embedded directly into the download script — avoiding the S3
+round-trip that was previously required for resolved keys.
+
+Execution order:
+    ``DraftAppAssetsInitializer`` (sync) compiles assets and publishes
+    the ``SkillBundle`` to ``sandbox.attrs`` in-memory, so the
+    downstream ``SkillInitializer`` can skip the Redis/S3 round-trip.
+    ``DraftAppAssetsDownloader`` (async) then pushes the compiled
+    artefacts into the sandbox VM in the background.
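+
+Illustrative wiring through ``SandboxBuilder`` (a sketch; it assumes the
+file tree was already placed on ``sandbox.attrs`` by an earlier
+initializer, and the exact initializer list used by the service layer is
+outside this diff):
+
+    sandbox = (
+        SandboxBuilder(tenant_id, SandboxType.E2B)
+        .user(user_id)
+        .app(app_id)
+        .storage(storage, assets_id)
+        .initializer(DraftAppAssetsInitializer())  # sync: compile, publish bundle
+        .initializer(DraftAppAssetsDownloader())   # async: push artefacts to the VM
+        .build()
+    )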
+""" + +import logging + +from core.app_assets.builder.base import BuildContext +from core.app_assets.builder.file_builder import FileBuilder +from core.app_assets.builder.pipeline import AssetBuildPipeline +from core.app_assets.builder.skill_builder import SkillBuilder +from core.app_assets.constants import AppAssetsAttrs +from core.sandbox.entities import AppAssets +from core.sandbox.sandbox import Sandbox +from core.sandbox.services import AssetDownloadService +from core.skill import SkillAttrs +from core.virtual_environment.__base.helpers import pipeline +from services.app_asset_service import AppAssetService + +from .base import AsyncSandboxInitializer, SandboxInitializeContext, SyncSandboxInitializer + +logger = logging.getLogger(__name__) + + +class DraftAppAssetsInitializer(SyncSandboxInitializer): + """Compile draft assets and publish the ``SkillBundle`` to attrs. + + The build pipeline compiles ``.md`` skill files in-process. + The resulting ``SkillBundle`` is persisted to Redis/S3 (by + ``SkillBuilder``) **and** written to ``sandbox.attrs[BUNDLE]`` + so that ``SkillInitializer`` can read it without a round-trip. + Built asset items are stored on ``ctx.built_assets`` for the + async ``DraftAppAssetsDownloader`` to consume. + """ + + def initialize(self, sandbox: Sandbox, ctx: SandboxInitializeContext) -> None: + tree = sandbox.attrs.get(AppAssetsAttrs.FILE_TREE) + + # --- 1. Run the build pipeline (SkillBuilder compiles .md inline) --- + accessor = AppAssetService.get_accessor(ctx.tenant_id, ctx.app_id) + skill_builder = SkillBuilder(accessor=accessor) + build_pipeline = AssetBuildPipeline([skill_builder, FileBuilder()]) + build_ctx = BuildContext(tenant_id=ctx.tenant_id, app_id=ctx.app_id, build_id=ctx.assets_id) + built_assets = build_pipeline.build_all(tree, build_ctx) + ctx.built_assets = built_assets + + # Publish the in-memory bundle so SkillInitializer skips Redis/S3. + if skill_builder.bundle is not None: + sandbox.attrs.set(SkillAttrs.BUNDLE, skill_builder.bundle) + + +class DraftAppAssetsDownloader(AsyncSandboxInitializer): + """Download the compiled assets into the sandbox VM. + + The download script is generated by ``DraftAppAssetsInitializer`` and + includes inline base64 content for compiled skills, as well as + presigned URLs for other files. 
+ """ + + _TIMEOUT = 600 # 10 minutes + + def initialize(self, sandbox: Sandbox, ctx: SandboxInitializeContext) -> None: + if not ctx.built_assets: + logger.debug("No built assets found for assets_id=%s", ctx.assets_id) + return + + download_items = AppAssetService.to_download_items(ctx.built_assets) + script = AssetDownloadService.build_download_script(download_items, AppAssets.PATH) + pipeline(sandbox.vm).add( + ["sh", "-c", script], + error_message="Failed to download draft assets", + ).execute(timeout=self._TIMEOUT, raise_on_error=True) + + logger.info( + "Draft app assets initialized for app_id=%s, assets_id=%s", + ctx.app_id, + ctx.assets_id, + ) diff --git a/api/core/sandbox/initializer/skill_initializer.py b/api/core/sandbox/initializer/skill_initializer.py new file mode 100644 index 0000000000..b6983d92c8 --- /dev/null +++ b/api/core/sandbox/initializer/skill_initializer.py @@ -0,0 +1,34 @@ +from __future__ import annotations + +import logging + +from core.sandbox.sandbox import Sandbox +from core.skill import SkillAttrs +from core.skill.skill_manager import SkillManager + +from .base import SandboxInitializeContext, SyncSandboxInitializer + +logger = logging.getLogger(__name__) + + +class SkillInitializer(SyncSandboxInitializer): + """Ensure ``sandbox.attrs[BUNDLE]`` is populated for downstream consumers. + + In the draft path ``DraftAppAssetsInitializer`` already sets the + bundle on attrs from the in-memory build result, so this initializer + becomes a no-op. In the published path no prior initializer sets + it, so we fall back to ``SkillManager.load_bundle()`` (Redis/S3). + """ + + def initialize(self, sandbox: Sandbox, ctx: SandboxInitializeContext) -> None: + # Draft path: bundle already populated by DraftAppAssetsInitializer. + if sandbox.attrs.has(SkillAttrs.BUNDLE): + return + + # Published path: load from Redis/S3. 
+ bundle = SkillManager.load_bundle( + ctx.tenant_id, + ctx.app_id, + ctx.assets_id, + ) + sandbox.attrs.set(SkillAttrs.BUNDLE, bundle) diff --git a/api/core/sandbox/inspector/__init__.py b/api/core/sandbox/inspector/__init__.py new file mode 100644 index 0000000000..e259a158a2 --- /dev/null +++ b/api/core/sandbox/inspector/__init__.py @@ -0,0 +1,11 @@ +from core.sandbox.inspector.archive_source import SandboxFileArchiveSource +from core.sandbox.inspector.base import SandboxFileSource +from core.sandbox.inspector.browser import SandboxFileBrowser +from core.sandbox.inspector.runtime_source import SandboxFileRuntimeSource + +__all__ = [ + "SandboxFileArchiveSource", + "SandboxFileBrowser", + "SandboxFileRuntimeSource", + "SandboxFileSource", +] diff --git a/api/core/sandbox/inspector/archive_source.py b/api/core/sandbox/inspector/archive_source.py new file mode 100644 index 0000000000..92c7213826 --- /dev/null +++ b/api/core/sandbox/inspector/archive_source.py @@ -0,0 +1,169 @@ +from __future__ import annotations + +import logging +import os +from typing import TYPE_CHECKING +from uuid import uuid4 + +from core.sandbox.entities.files import SandboxFileDownloadTicket, SandboxFileNode +from core.sandbox.inspector.base import SandboxFileSource +from core.sandbox.inspector.script_utils import ( + build_detect_kind_command, + build_list_command, + build_upload_command, + guess_content_type, + parse_kind_output, + parse_list_output, +) +from core.sandbox.storage import SandboxFilePaths +from core.virtual_environment.__base.exec import PipelineExecutionError +from core.virtual_environment.__base.helpers import pipeline +from extensions.ext_storage import storage + +if TYPE_CHECKING: + from core.zip_sandbox import ZipSandbox + +logger = logging.getLogger(__name__) + + +class SandboxFileArchiveSource(SandboxFileSource): + def _get_archive_download_url(self) -> str: + """Get a pre-signed download URL for the sandbox archive.""" + from extensions.storage.file_presign_storage import FilePresignStorage + + storage_key = SandboxFilePaths.archive(self._tenant_id, self._app_id, self._sandbox_id) + if not storage.exists(storage_key): + raise ValueError("Sandbox archive not found") + presign_storage = FilePresignStorage(storage.storage_runner) + return presign_storage.get_download_url(storage_key, self._EXPORT_EXPIRES_IN_SECONDS) + + def _create_zip_sandbox(self) -> ZipSandbox: + """Create a ZipSandbox instance for archive operations.""" + from core.zip_sandbox import ZipSandbox + + return ZipSandbox(tenant_id=self._tenant_id, user_id="system", app_id=self._app_id) + + def exists(self) -> bool: + """Check if the sandbox archive exists in storage.""" + storage_key = SandboxFilePaths.archive(self._tenant_id, self._app_id, self._sandbox_id) + return storage.exists(storage_key) + + def list_files(self, *, path: str, recursive: bool) -> list[SandboxFileNode]: + archive_url = self._get_archive_download_url() + with self._create_zip_sandbox() as zs: + # Download and extract the archive + archive_path = zs.download_archive(archive_url, path="workspace.tar.gz") + zs.untar(archive_path=archive_path, dest_dir="workspace") + + # List files using Python script in sandbox + try: + list_path = f"workspace/{path}" if path not in (".", "") else "workspace" + results = ( + pipeline(zs.vm) + .add( + build_list_command(list_path, recursive), + error_message="Failed to list sandbox files", + ) + .execute(timeout=self._LIST_TIMEOUT_SECONDS, raise_on_error=True) + ) + except PipelineExecutionError as exc: + raise 
RuntimeError(str(exc)) from exc + + raw = parse_list_output(results[0].stdout) + + entries: list[SandboxFileNode] = [] + for item in raw: + item_path = str(item.get("path")) + # Strip the "workspace/" prefix from paths + if item_path.startswith("workspace/"): + item_path = item_path[len("workspace/") :] + elif item_path == "workspace": + continue # Skip the workspace directory itself + + item_is_dir = bool(item.get("is_dir")) + extension = None + if not item_is_dir: + ext = os.path.splitext(item_path)[1] + extension = ext or None + entries.append( + SandboxFileNode( + path=item_path, + is_dir=item_is_dir, + size=item.get("size"), + mtime=item.get("mtime"), + extension=extension, + ) + ) + return sorted(entries, key=lambda e: e.path) + + def download_file(self, *, path: str) -> SandboxFileDownloadTicket: + """Download a file or directory from the archived sandbox. + + Uses direct upload from sandbox to storage via presigned URL, avoiding + data transfer through the service layer. This preserves binary integrity + (no text encoding issues) and reduces bandwidth overhead. + """ + from services.sandbox.sandbox_file_service import SandboxFileService + + archive_url = self._get_archive_download_url() + export_name = os.path.basename(path.rstrip("/")) or "workspace" + export_id = uuid4().hex + + with self._create_zip_sandbox() as zs: + archive_path = zs.download_archive(archive_url, path="workspace.tar.gz") + zs.untar(archive_path=archive_path, dest_dir="workspace") + + target_path = f"workspace/{path}" if path not in (".", "") else "workspace" + try: + results = ( + pipeline(zs.vm) + .add( + build_detect_kind_command(target_path), + error_message="Failed to check path in sandbox", + ) + .execute(timeout=self._LIST_TIMEOUT_SECONDS, raise_on_error=True) + ) + except PipelineExecutionError as exc: + raise ValueError(str(exc)) from exc + + kind = parse_kind_output(results[0].stdout, not_found_message="File not found in sandbox archive") + + sandbox_storage = SandboxFileService.get_storage() + is_file = kind == "file" + filename = (os.path.basename(path) or "file") if is_file else f"{export_name}.tar.gz" + export_key = SandboxFilePaths.export(self._tenant_id, self._app_id, self._sandbox_id, export_id) + upload_url = sandbox_storage.get_upload_url(export_key, self._EXPORT_EXPIRES_IN_SECONDS) + content_type = guess_content_type(filename) + + # Build pipeline: for directories, tar first then upload; for files, upload directly + archive_temp = f"/tmp/{export_id}.tar.gz" + src_path = target_path if is_file else archive_temp + tar_src = path if path not in (".", "") else "." 
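+            # One conditional pipeline: directories are tarred to a temp archive
+            # first, files are uploaded as-is. The upload step expands to roughly
+            #   curl -s -f -X PUT -T <src> [-H "Content-Type: ..."] <presigned-url>
+            # (see build_upload_command in script_utils).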
+ + try: + ( + pipeline(zs.vm) + .add( + ["tar", "-czf", archive_temp, "-C", "workspace", tar_src], + error_message="Failed to archive directory", + on=not is_file, + ) + .add( + build_upload_command(src_path, upload_url, content_type=content_type), + error_message="Failed to upload file", + ) + .add(["rm", "-f", archive_temp], on=not is_file) + .execute(timeout=self._UPLOAD_TIMEOUT_SECONDS, raise_on_error=True) + ) + except PipelineExecutionError as exc: + raise RuntimeError(str(exc)) from exc + + download_url = sandbox_storage.get_download_url( + export_key, self._EXPORT_EXPIRES_IN_SECONDS, download_filename=filename + ) + + return SandboxFileDownloadTicket( + download_url=download_url, + expires_in=self._EXPORT_EXPIRES_IN_SECONDS, + export_id=export_id, + ) diff --git a/api/core/sandbox/inspector/base.py b/api/core/sandbox/inspector/base.py new file mode 100644 index 0000000000..76f92a4f97 --- /dev/null +++ b/api/core/sandbox/inspector/base.py @@ -0,0 +1,33 @@ +from __future__ import annotations + +import abc + +from core.sandbox.entities.files import SandboxFileDownloadTicket, SandboxFileNode + + +class SandboxFileSource(abc.ABC): + _LIST_TIMEOUT_SECONDS = 30 + _UPLOAD_TIMEOUT_SECONDS = 60 * 10 + _EXPORT_EXPIRES_IN_SECONDS = 60 * 10 + + def __init__(self, *, tenant_id: str, app_id: str, sandbox_id: str): + self._tenant_id = tenant_id + self._app_id = app_id + self._sandbox_id = sandbox_id + + @abc.abstractmethod + def exists(self) -> bool: + """Check if the sandbox source exists and is available. + + Returns: + True if the sandbox source exists and can be accessed, False otherwise. + """ + raise NotImplementedError + + @abc.abstractmethod + def list_files(self, *, path: str, recursive: bool) -> list[SandboxFileNode]: + raise NotImplementedError + + @abc.abstractmethod + def download_file(self, *, path: str) -> SandboxFileDownloadTicket: + raise NotImplementedError diff --git a/api/core/sandbox/inspector/browser.py b/api/core/sandbox/inspector/browser.py new file mode 100644 index 0000000000..d94945606b --- /dev/null +++ b/api/core/sandbox/inspector/browser.py @@ -0,0 +1,48 @@ +from __future__ import annotations + +from pathlib import PurePosixPath + +from core.sandbox.entities.files import SandboxFileDownloadTicket, SandboxFileNode +from core.sandbox.inspector.archive_source import SandboxFileArchiveSource +from core.sandbox.inspector.base import SandboxFileSource + + +class SandboxFileBrowser: + def __init__(self, *, tenant_id: str, app_id: str, sandbox_id: str): + self._tenant_id = tenant_id + self._app_id = app_id + self._sandbox_id = sandbox_id + + @staticmethod + def _normalize_workspace_path(path: str | None) -> str: + raw = (path or ".").strip() + if raw == "": + raw = "." + + p = PurePosixPath(raw) + if p.is_absolute(): + raise ValueError("path must be relative") + if any(part == ".." for part in p.parts): + raise ValueError("path must not contain '..'") + + normalized = str(p) + return "." 
if normalized in (".", "") else normalized + + def _backend(self) -> SandboxFileSource: + return SandboxFileArchiveSource( + tenant_id=self._tenant_id, + app_id=self._app_id, + sandbox_id=self._sandbox_id, + ) + + def exists(self) -> bool: + """Check if the sandbox source exists and is available.""" + return self._backend().exists() + + def list_files(self, *, path: str | None = None, recursive: bool = False) -> list[SandboxFileNode]: + workspace_path = self._normalize_workspace_path(path) + return self._backend().list_files(path=workspace_path, recursive=recursive) + + def download_file(self, *, path: str) -> SandboxFileDownloadTicket: + workspace_path = self._normalize_workspace_path(path) + return self._backend().download_file(path=workspace_path) diff --git a/api/core/sandbox/inspector/runtime_source.py b/api/core/sandbox/inspector/runtime_source.py new file mode 100644 index 0000000000..ae8294aca9 --- /dev/null +++ b/api/core/sandbox/inspector/runtime_source.py @@ -0,0 +1,140 @@ +from __future__ import annotations + +import logging +import os +from uuid import uuid4 + +from core.sandbox.entities.files import SandboxFileDownloadTicket, SandboxFileNode +from core.sandbox.inspector.base import SandboxFileSource +from core.sandbox.inspector.script_utils import ( + build_detect_kind_command, + build_list_command, + build_upload_command, + guess_content_type, + parse_kind_output, + parse_list_output, +) +from core.sandbox.storage import SandboxFilePaths +from core.virtual_environment.__base.exec import PipelineExecutionError +from core.virtual_environment.__base.helpers import pipeline +from core.virtual_environment.__base.virtual_environment import VirtualEnvironment + +logger = logging.getLogger(__name__) + + +class SandboxFileRuntimeSource(SandboxFileSource): + def __init__(self, *, tenant_id: str, app_id: str, sandbox_id: str, runtime: VirtualEnvironment): + super().__init__(tenant_id=tenant_id, app_id=app_id, sandbox_id=sandbox_id) + self._runtime = runtime + + def exists(self) -> bool: + """Check if the sandbox runtime exists and is available.""" + return self._runtime is not None + + def list_files(self, *, path: str, recursive: bool) -> list[SandboxFileNode]: + try: + results = ( + pipeline(self._runtime) + .add( + build_list_command(path, recursive), + error_message="Failed to list sandbox files", + ) + .execute(timeout=self._LIST_TIMEOUT_SECONDS, raise_on_error=True) + ) + except PipelineExecutionError as exc: + raise RuntimeError(str(exc)) from exc + + raw = parse_list_output(results[0].stdout) + + entries: list[SandboxFileNode] = [] + for item in raw: + item_path = str(item.get("path")) + item_is_dir = bool(item.get("is_dir")) + extension = None + if not item_is_dir: + ext = os.path.splitext(item_path)[1] + extension = ext or None + entries.append( + SandboxFileNode( + path=item_path, + is_dir=item_is_dir, + size=item.get("size"), + mtime=item.get("mtime"), + extension=extension, + ) + ) + return entries + + def download_file(self, *, path: str) -> SandboxFileDownloadTicket: + from services.sandbox.sandbox_file_service import SandboxFileService + + try: + results = ( + pipeline(self._runtime) + .add( + build_detect_kind_command(path), + error_message="Failed to check path in sandbox", + ) + .execute(timeout=self._LIST_TIMEOUT_SECONDS, raise_on_error=True) + ) + except PipelineExecutionError as exc: + raise ValueError(str(exc)) from exc + + kind = parse_kind_output(results[0].stdout, not_found_message="File not found in sandbox") + + export_name = 
os.path.basename(path.rstrip("/")) or "workspace" + filename = f"{export_name}.tar.gz" if kind == "dir" else (os.path.basename(path) or "file") + export_id = uuid4().hex + export_key = SandboxFilePaths.export( + self._tenant_id, + self._app_id, + self._sandbox_id, + export_id, + ) + + sandbox_storage = SandboxFileService.get_storage() + upload_url = sandbox_storage.get_upload_url(export_key, self._EXPORT_EXPIRES_IN_SECONDS) + content_type = guess_content_type(filename) + + if kind == "dir": + archive_path = f"/tmp/{export_id}.tar.gz" + try: + ( + pipeline(self._runtime) + .add( + ["tar", "-czf", archive_path, "-C", ".", path], + error_message="Failed to archive directory in sandbox", + ) + .add( + build_upload_command(archive_path, upload_url, content_type=content_type), + error_message="Failed to upload directory archive from sandbox", + ) + .execute(timeout=self._UPLOAD_TIMEOUT_SECONDS, raise_on_error=True) + ) + except PipelineExecutionError as exc: + raise RuntimeError(str(exc)) from exc + finally: + try: + pipeline(self._runtime).add(["rm", "-f", archive_path]).execute(timeout=self._LIST_TIMEOUT_SECONDS) + except Exception as exc: + # Best-effort cleanup; do not fail the download on cleanup issues. + logger.debug("Failed to cleanup temp archive %s: %s", archive_path, exc) + else: + try: + ( + pipeline(self._runtime) + .add( + build_upload_command(path, upload_url, content_type=content_type), + error_message="Failed to upload file from sandbox", + ) + .execute(timeout=self._UPLOAD_TIMEOUT_SECONDS, raise_on_error=True) + ) + except PipelineExecutionError as exc: + raise RuntimeError(str(exc)) from exc + + download_url = sandbox_storage.get_download_url(export_key, self._EXPORT_EXPIRES_IN_SECONDS) + return SandboxFileDownloadTicket( + download_url=download_url, + expires_in=self._EXPORT_EXPIRES_IN_SECONDS, + export_id=export_id, + ) diff --git a/api/core/sandbox/inspector/script_utils.py b/api/core/sandbox/inspector/script_utils.py new file mode 100644 index 0000000000..10c21d5d8e --- /dev/null +++ b/api/core/sandbox/inspector/script_utils.py @@ -0,0 +1,118 @@ +"""Shared helpers for sandbox inspector shell commands.""" + +from __future__ import annotations + +import json +import mimetypes +from typing import TypedDict, cast + +_PYTHON_EXEC_CMD = 'if command -v python3 >/dev/null 2>&1; then py=python3; else py=python; fi; "$py" -c "$0" "$@"' +_LIST_SCRIPT = r""" +import json +import os +import sys + +path = sys.argv[1] +recursive = sys.argv[2] == "1" + +def norm(rel: str) -> str: + rel = rel.replace("\\\\", "/") + rel = rel.lstrip("./") + return rel or "." 
+ +def stat_entry(full_path: str, rel_path: str) -> dict[str, object]: + st = os.stat(full_path) + is_dir = os.path.isdir(full_path) + return { + "path": norm(rel_path), + "is_dir": is_dir, + "size": None if is_dir else int(st.st_size), + "mtime": int(st.st_mtime), + } + +entries = [] +if recursive: + for root, dirs, files in os.walk(path): + for d in dirs: + fp = os.path.join(root, d) + rp = os.path.relpath(fp, ".") + entries.append(stat_entry(fp, rp)) + for f in files: + fp = os.path.join(root, f) + rp = os.path.relpath(fp, ".") + entries.append(stat_entry(fp, rp)) +else: + if os.path.isfile(path): + rel_path = os.path.relpath(path, ".") + entries.append(stat_entry(path, rel_path)) + else: + for item in os.scandir(path): + rel_path = os.path.relpath(item.path, ".") + entries.append(stat_entry(item.path, rel_path)) + +print(json.dumps(entries)) +""" + + +class ListedEntry(TypedDict): + path: str + is_dir: bool + size: int | None + mtime: int + + +def build_list_command(path: str, recursive: bool) -> list[str]: + return [ + "sh", + "-c", + _PYTHON_EXEC_CMD, + _LIST_SCRIPT, + path, + "1" if recursive else "0", + ] + + +def parse_list_output(stdout: bytes) -> list[ListedEntry]: + try: + raw = json.loads(stdout.decode("utf-8")) + except Exception as exc: + raise RuntimeError("Malformed sandbox file list output") from exc + if not isinstance(raw, list): + raise RuntimeError("Malformed sandbox file list output") + return cast(list[ListedEntry], raw) + + +def build_detect_kind_command(path: str) -> list[str]: + return [ + "sh", + "-c", + 'if [ -d "$1" ]; then echo dir; elif [ -f "$1" ]; then echo file; else exit 2; fi', + "sh", + path, + ] + + +def parse_kind_output(stdout: bytes, *, not_found_message: str) -> str: + kind = stdout.decode("utf-8", errors="replace").strip() + if kind not in ("dir", "file"): + raise ValueError(not_found_message) + return kind + + +def guess_content_type(filename: str) -> str | None: + content_type, _ = mimetypes.guess_type(filename, strict=False) + if content_type is None: + return None + if content_type.startswith("text/"): + return f"{content_type}; charset=utf-8" + if content_type == "application/json": + return "application/json; charset=utf-8" + return content_type + + +def build_upload_command(src_path: str, upload_url: str, *, content_type: str | None) -> list[str]: + command = ["curl", "-s", "-f", "-X", "PUT", "-T", src_path] + if content_type: + command.extend(["-H", f"Content-Type: {content_type}"]) + command.append(upload_url) + return command diff --git a/api/core/sandbox/sandbox.py b/api/core/sandbox/sandbox.py new file mode 100644 index 0000000000..7964acd8c0 --- /dev/null +++ b/api/core/sandbox/sandbox.py @@ -0,0 +1,133 @@ +from __future__ import annotations + +import logging +import threading +from typing import TYPE_CHECKING +from uuid import uuid4 + +from libs.attr_map import AttrMap + +if TYPE_CHECKING: + from core.sandbox.storage.sandbox_storage import SandboxStorage + from core.virtual_environment.__base.virtual_environment import VirtualEnvironment + +logger = logging.getLogger(__name__) + + +class Sandbox: + """Represents a single sandbox environment. + + Each ``Sandbox`` owns a stable, path-safe ``id`` (a 32-char hex + UUID4) that is independent of the underlying provider's environment + ID. Use ``sandbox.id`` for any path or resource namespacing + (e.g. ``DifyCli(sandbox.id)``). + + The raw provider identifier is still accessible via + ``sandbox.vm.metadata.id`` when needed (logging, API calls back to + the provider, etc.). 
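+
+    A minimal lifecycle sketch (the ``vm`` and ``storage`` values below are
+    hypothetical placeholders for provider-specific objects)::
+
+        sandbox = Sandbox(
+            vm=vm,
+            storage=storage,
+            tenant_id="tenant-1",
+            user_id="user-1",
+            app_id="app-1",
+            assets_id="assets-1",
+        )
+        try:
+            sandbox.mount()       # best effort: load persisted workspace
+            sandbox.mark_ready()  # unblock wait_ready() callers
+        finally:
+            sandbox.release()     # unmount storage and release the VM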
+ """ + + def __init__( + self, + *, + vm: VirtualEnvironment, + storage: SandboxStorage, + tenant_id: str, + user_id: str, + app_id: str, + assets_id: str, + ) -> None: + self._id = uuid4().hex + self._vm = vm + self._storage = storage + self._tenant_id = tenant_id + self._user_id = user_id + self._app_id = app_id + self._assets_id = assets_id + self._attributes = AttrMap() + self._ready_event = threading.Event() + self._cancel_event = threading.Event() + self._init_error: Exception | None = None + + @property + def id(self) -> str: + """Stable, path-safe identifier for this sandbox (UUID4 hex).""" + return self._id + + @property + def attrs(self) -> AttrMap: + return self._attributes + + @property + def vm(self) -> VirtualEnvironment: + return self._vm + + @property + def storage(self) -> SandboxStorage: + return self._storage + + @property + def tenant_id(self) -> str: + return self._tenant_id + + @property + def user_id(self) -> str: + return self._user_id + + @property + def app_id(self) -> str: + return self._app_id + + @property + def assets_id(self) -> str: + return self._assets_id + + def mark_ready(self) -> None: + # Signal that sandbox initialization has completed successfully. + self._ready_event.set() + + def mark_failed(self, error: Exception) -> None: + # Capture initialization error and unblock waiters. + self._init_error = error + self._ready_event.set() + + def cancel_init(self) -> None: + # Mark initialization as cancelled to stop background setup. + self._cancel_event.set() + self._ready_event.set() + + def is_cancelled(self) -> bool: + return self._cancel_event.is_set() + + def wait_ready(self, timeout: float | None = None) -> None: + # Block until initialization completes, fails, or is cancelled. + if not self._ready_event.wait(timeout=timeout): + raise TimeoutError("Sandbox initialization timed out") + if self._cancel_event.is_set(): + raise RuntimeError("Sandbox initialization was cancelled") + if self._init_error is not None: + if isinstance(self._init_error, ValueError): + raise RuntimeError(f"Sandbox initialization failed: {self._init_error}") from self._init_error + else: + raise RuntimeError("Sandbox initialization failed") from self._init_error + + def mount(self) -> bool: + return self._storage.mount(self._vm) + + def unmount(self) -> bool: + return self._storage.unmount(self._vm) + + def release(self) -> None: + self.cancel_init() + sandbox_id = self.id + try: + self._storage.unmount(self._vm) + logger.info("Sandbox storage unmounted: sandbox_id=%s", sandbox_id) + except Exception: + logger.exception("Failed to unmount sandbox storage: sandbox_id=%s", sandbox_id) + + try: + self._vm.release_environment() + logger.info("Sandbox released: sandbox_id=%s", sandbox_id) + except Exception: + logger.exception("Failed to release sandbox: sandbox_id=%s", sandbox_id) diff --git a/api/core/sandbox/services/__init__.py b/api/core/sandbox/services/__init__.py new file mode 100644 index 0000000000..70e8a4359f --- /dev/null +++ b/api/core/sandbox/services/__init__.py @@ -0,0 +1,3 @@ +from .asset_download_service import AssetDownloadService + +__all__ = ["AssetDownloadService"] diff --git a/api/core/sandbox/services/asset_download_service.py b/api/core/sandbox/services/asset_download_service.py new file mode 100644 index 0000000000..96a05526f2 --- /dev/null +++ b/api/core/sandbox/services/asset_download_service.py @@ -0,0 +1,140 @@ +"""Shell script builder for downloading / writing assets into a sandbox VM. 
+ +Generates a self-contained POSIX shell script that handles two kinds of +``SandboxDownloadItem``: + +- Items with *content* — written via base64 heredoc (sequential). +- Items with *url* — fetched via ``curl``/``wget``/``python3`` with + auto-detection, run as parallel background jobs. + +Both kinds can be mixed freely in a single call. +""" + +from __future__ import annotations + +import base64 +import shlex +import textwrap +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from core.zip_sandbox.entities import SandboxDownloadItem + + +def _build_inline_commands(items: list[SandboxDownloadItem], root_var: str) -> str: + """Generate shell commands that write base64-encoded content to files.""" + lines: list[str] = [] + for idx, item in enumerate(items): + assert item.content is not None + dest = f"${{{root_var}}}/{shlex.quote(item.path)}" + encoded = base64.b64encode(item.content).decode("ascii") + lines.append(f'mkdir -p "$(dirname "{dest}")"') + lines.append(f"base64 -d <<'_INLINE_{idx}' > \"{dest}\"") + lines.append(encoded) + lines.append(f"_INLINE_{idx}") + return "\n".join(lines) + + +def _render_download_script( + root_path: str, + inline_commands: str, + download_commands: str, + need_downloader: bool, +) -> str: + python_download_cmd = ( + 'python3 - "${url}" "${dest}" <<"PY"\n' + "import sys\n" + "import urllib.request\n" + "url = sys.argv[1]\n" + "dest = sys.argv[2]\n" + "with urllib.request.urlopen(url) as resp:\n" + " data = resp.read()\n" + 'with open(dest, "wb") as f:\n' + " f.write(data)\n" + "PY" + ) + + # Only emit the downloader-detection block when there are remote items. + if need_downloader: + downloader_block = f"""\ +if command -v curl >/dev/null 2>&1; then + download_cmd='curl -fsSL "${{url}}" -o "${{dest}}"' +elif command -v wget >/dev/null 2>&1; then + download_cmd='wget -q "${{url}}" -O "${{dest}}"' +elif command -v python3 >/dev/null 2>&1; then + download_cmd={shlex.quote(python_download_cmd)} +else + echo 'No downloader found (curl/wget/python3)' >&2 + exit 1 +fi + +fail_log="$(mktemp)" + +download_one() {{ + file_path="$1" + url="$2" + dest="${{download_root}}/${{file_path}}" + mkdir -p "$(dirname "${{dest}}")" + eval "${{download_cmd}}" 2>/dev/null || echo "${{file_path}}" >> "${{fail_log}}" +}}""" + else: + downloader_block = "" + + # The failure-check block is only meaningful when downloads occurred. + if need_downloader: + wait_block = textwrap.dedent("""\ + wait + + if [ -s "${fail_log}" ]; then + mv "${fail_log}" "${download_root}/DOWNLOAD_FAILURES.txt" + else + rm -f "${fail_log}" + fi""") + else: + wait_block = "" + + script = f"""\ +download_root={shlex.quote(root_path)} +mkdir -p "${{download_root}}" + +{downloader_block} + +{inline_commands} + +{download_commands} + +{wait_block} +exit 0""" + return script + + +class AssetDownloadService: + @staticmethod + def build_download_script( + items: list[SandboxDownloadItem], + root_path: str, + ) -> str: + """Build a portable shell script to write inline assets and download remote ones. + + Items with *content* are written first (sequential base64 decode), + then items with *url* are fetched in parallel background jobs. + The two kinds can be mixed freely in a single list. 
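+
+        A rough usage sketch (the ``SandboxDownloadItem`` construction shown
+        here is assumed, not verified against its schema)::
+
+            items = [
+                SandboxDownloadItem(path="config.json", content=b"{}"),
+                SandboxDownloadItem(path="data/model.bin", url="https://example.com/model.bin"),
+            ]
+            script = AssetDownloadService.build_download_script(items, "/workspace/assets")
+            # `script` is a POSIX shell snippet; the caller runs it inside the VM.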
+ """ + inline = [item for item in items if item.content is not None] + remote = [item for item in items if item.content is None] + + inline_commands = _build_inline_commands(inline, "download_root") if inline else "" + + commands: list[str] = [] + for item in remote: + path = shlex.quote(item.path) + url = shlex.quote(item.url) + commands.append(f"download_one {path} {url} &") + download_commands = "\n".join(commands) + + return _render_download_script( + root_path, + inline_commands, + download_commands, + need_downloader=bool(remote), + ) diff --git a/api/core/sandbox/storage/__init__.py b/api/core/sandbox/storage/__init__.py new file mode 100644 index 0000000000..62859ea724 --- /dev/null +++ b/api/core/sandbox/storage/__init__.py @@ -0,0 +1,11 @@ +from .archive_storage import ArchiveSandboxStorage +from .noop_storage import NoopSandboxStorage +from .sandbox_file_storage import SandboxFilePaths +from .sandbox_storage import SandboxStorage + +__all__ = [ + "ArchiveSandboxStorage", + "NoopSandboxStorage", + "SandboxFilePaths", + "SandboxStorage", +] diff --git a/api/core/sandbox/storage/archive_storage.py b/api/core/sandbox/storage/archive_storage.py new file mode 100644 index 0000000000..d225e45832 --- /dev/null +++ b/api/core/sandbox/storage/archive_storage.py @@ -0,0 +1,83 @@ +"""Archive-based sandbox storage for persisting sandbox state.""" + +from __future__ import annotations + +import logging + +from core.virtual_environment.__base.helpers import pipeline +from core.virtual_environment.__base.virtual_environment import VirtualEnvironment +from extensions.storage.base_storage import BaseStorage +from extensions.storage.cached_presign_storage import CachedPresignStorage +from extensions.storage.file_presign_storage import FilePresignStorage + +from .sandbox_file_storage import SandboxFilePaths +from .sandbox_storage import SandboxStorage + +logger = logging.getLogger(__name__) + +_ARCHIVE_TIMEOUT = 300 # 5 minutes + + +class ArchiveSandboxStorage(SandboxStorage): + """Archive-based storage for sandbox workspace persistence.""" + + def __init__( + self, + tenant_id: str, + app_id: str, + sandbox_id: str, + storage: BaseStorage, + exclude_patterns: list[str] | None = None, + ): + self._sandbox_id = sandbox_id + self._exclude_patterns = exclude_patterns or [] + self._storage_key = SandboxFilePaths.archive(tenant_id, app_id, sandbox_id) + self._storage = CachedPresignStorage( + storage=FilePresignStorage(storage), + cache_key_prefix="sandbox_archives", + ) + + def mount(self, sandbox: VirtualEnvironment) -> bool: + """Load archive from storage into sandbox workspace.""" + if not self.exists(): + logger.debug("No archive found for sandbox %s, skipping mount", self._sandbox_id) + return False + + download_url = self._storage.get_download_url(self._storage_key, _ARCHIVE_TIMEOUT) + archive = "archive.tar.gz" + + ( + pipeline(sandbox) + .add(["curl", "-fsSL", download_url, "-o", archive], error_message="Failed to download archive") + .add(["sh", "-c", 'tar -xzf "$1" 2>/dev/null; exit $?', "sh", archive], error_message="Failed to extract") + .add(["rm", archive], error_message="Failed to cleanup") + .execute(timeout=_ARCHIVE_TIMEOUT, raise_on_error=True) + ) + + logger.info("Mounted archive for sandbox %s", self._sandbox_id) + return True + + def unmount(self, sandbox: VirtualEnvironment) -> bool: + """Save sandbox workspace to storage as archive.""" + upload_url = self._storage.get_upload_url(self._storage_key, _ARCHIVE_TIMEOUT) + archive = f"/tmp/{self._sandbox_id}.tar.gz" + exclude_args = 
[f"--exclude={p}" for p in self._exclude_patterns] + + ( + pipeline(sandbox) + .add(["tar", "-czf", archive, *exclude_args, "-C", ".", "."], error_message="Failed to create archive") + .add(["curl", "-sf", "-X", "PUT", "-T", archive, upload_url], error_message="Failed to upload archive") + .execute(timeout=_ARCHIVE_TIMEOUT, raise_on_error=True) + ) + logger.info("Unmounted archive for sandbox %s", self._sandbox_id) + return True + + def exists(self) -> bool: + return self._storage.exists(self._storage_key) + + def delete(self) -> None: + try: + self._storage.delete(self._storage_key) + logger.info("Deleted archive for sandbox %s", self._sandbox_id) + except Exception: + logger.exception("Failed to delete archive for sandbox %s", self._sandbox_id) diff --git a/api/core/sandbox/storage/noop_storage.py b/api/core/sandbox/storage/noop_storage.py new file mode 100644 index 0000000000..d4f39ea9a7 --- /dev/null +++ b/api/core/sandbox/storage/noop_storage.py @@ -0,0 +1,18 @@ +from core.sandbox.storage.sandbox_storage import SandboxStorage +from core.virtual_environment.__base.virtual_environment import VirtualEnvironment + + +class NoopSandboxStorage(SandboxStorage): + """A no-op storage implementation that does nothing on mount/unmount.""" + + def mount(self, sandbox: VirtualEnvironment) -> bool: + return True + + def unmount(self, sandbox: VirtualEnvironment) -> bool: + return True + + def exists(self) -> bool: + return False + + def delete(self) -> None: + return diff --git a/api/core/sandbox/storage/sandbox_file_storage.py b/api/core/sandbox/storage/sandbox_file_storage.py new file mode 100644 index 0000000000..dec5c24199 --- /dev/null +++ b/api/core/sandbox/storage/sandbox_file_storage.py @@ -0,0 +1,21 @@ +"""Sandbox file storage key generation. + +Provides SandboxFilePaths facade for generating storage keys for sandbox files. +Storage instances are obtained via SandboxFileService.get_storage(). +""" + +from __future__ import annotations + + +class SandboxFilePaths: + """Facade for generating sandbox file storage keys.""" + + @staticmethod + def export(tenant_id: str, app_id: str, sandbox_id: str, export_id: str) -> str: + """sandbox_files/{tenant}/{app}/{sandbox}/{export_id}/{filename}""" + return f"sandbox_files/{tenant_id}/{app_id}/{sandbox_id}/{export_id}" + + @staticmethod + def archive(tenant_id: str, app_id: str, sandbox_id: str) -> str: + """sandbox_archives/{tenant}/{app}/{sandbox}.tar.gz""" + return f"sandbox_archives/{tenant_id}/{app_id}/{sandbox_id}.tar.gz" diff --git a/api/core/sandbox/storage/sandbox_storage.py b/api/core/sandbox/storage/sandbox_storage.py new file mode 100644 index 0000000000..16f65207a8 --- /dev/null +++ b/api/core/sandbox/storage/sandbox_storage.py @@ -0,0 +1,21 @@ +from abc import ABC, abstractmethod + +from core.virtual_environment.__base.virtual_environment import VirtualEnvironment + + +class SandboxStorage(ABC): + @abstractmethod + def mount(self, sandbox: VirtualEnvironment) -> bool: + """Load files from storage into VM. Returns True if files were loaded.""" + + @abstractmethod + def unmount(self, sandbox: VirtualEnvironment) -> bool: + """Save files from VM to storage. 
Returns True if files were saved.""" + + @abstractmethod + def exists(self) -> bool: + """Check if storage has saved data.""" + + @abstractmethod + def delete(self) -> None: + """Delete saved data from storage.""" diff --git a/api/core/sandbox/utils/__init__.py b/api/core/sandbox/utils/__init__.py new file mode 100644 index 0000000000..c1d71c108c --- /dev/null +++ b/api/core/sandbox/utils/__init__.py @@ -0,0 +1,2 @@ +# Sandbox utilities +# Connection helpers have been moved to core.virtual_environment.helpers diff --git a/api/core/sandbox/utils/debug.py b/api/core/sandbox/utils/debug.py new file mode 100644 index 0000000000..397d0ee322 --- /dev/null +++ b/api/core/sandbox/utils/debug.py @@ -0,0 +1,22 @@ +"""Sandbox debug utilities. TODO: Remove this module when sandbox debugging is complete.""" + +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + pass + +SANDBOX_DEBUG_ENABLED = True + + +def sandbox_debug(tag: str, message: str, data: Any = None) -> None: + if not SANDBOX_DEBUG_ENABLED: + return + + # Lazy import to avoid circular dependency + from core.callback_handler.agent_tool_callback_handler import print_text + + print_text(f"\n[{tag}]\n", color="blue") + if data is not None: + print_text(f"{message}: {data}\n", color="blue") + else: + print_text(f"{message}\n", color="blue") diff --git a/api/core/sandbox/utils/encryption.py b/api/core/sandbox/utils/encryption.py new file mode 100644 index 0000000000..a6007a8ccf --- /dev/null +++ b/api/core/sandbox/utils/encryption.py @@ -0,0 +1,48 @@ +from collections.abc import Mapping +from typing import Any + +from core.entities.provider_entities import BasicProviderConfig +from core.helper.provider_cache import ProviderCredentialsCache +from core.helper.provider_encryption import ProviderConfigCache, ProviderConfigEncrypter, create_provider_encrypter + + +class SandboxProviderConfigCache(ProviderCredentialsCache): + def __init__(self, tenant_id: str, provider_type: str): + super().__init__(tenant_id=tenant_id, provider_type=provider_type) + + def _generate_cache_key(self, **kwargs) -> str: + tenant_id = kwargs["tenant_id"] + provider_type = kwargs["provider_type"] + return f"sandbox_config:tenant_id:{tenant_id}:provider_type:{provider_type}" + + +def create_sandbox_config_encrypter( + tenant_id: str, + config_schema: list[BasicProviderConfig], + provider_type: str, +) -> tuple[ProviderConfigEncrypter, ProviderConfigCache]: + cache = SandboxProviderConfigCache(tenant_id=tenant_id, provider_type=provider_type) + return create_provider_encrypter(tenant_id=tenant_id, config=config_schema, cache=cache) + + +def masked_config( + schemas: list[BasicProviderConfig], + config: Mapping[str, Any], +) -> Mapping[str, Any]: + masked = dict(config) + configs = {x.name: x for x in schemas} + for key, value in config.items(): + schema = configs.get(key) + if not schema: + masked[key] = value + continue + if schema.type == BasicProviderConfig.Type.SECRET_INPUT: + if not isinstance(value, str): + continue + if len(value) <= 4: + masked[key] = "*" * len(value) + else: + masked[key] = value[:2] + "*" * (len(value) - 4) + value[-2:] + else: + masked[key] = value + return masked diff --git a/api/core/skill/__init__.py b/api/core/skill/__init__.py new file mode 100644 index 0000000000..c68e7d83f0 --- /dev/null +++ b/api/core/skill/__init__.py @@ -0,0 +1,11 @@ +from .constants import SkillAttrs +from .entities import ToolDependencies, ToolDependency, ToolReference +from .skill_manager import SkillManager + +__all__ = [ + "SkillAttrs", + "SkillManager", + 
"ToolDependencies", + "ToolDependency", + "ToolReference", +] diff --git a/api/core/skill/assembler/__init__.py b/api/core/skill/assembler/__init__.py new file mode 100644 index 0000000000..173c1163ec --- /dev/null +++ b/api/core/skill/assembler/__init__.py @@ -0,0 +1,6 @@ +from core.skill.assembler.assemblers import SkillBundleAssembler, SkillDocumentAssembler + +__all__ = [ + "SkillBundleAssembler", + "SkillDocumentAssembler", +] diff --git a/api/core/skill/assembler/assemblers.py b/api/core/skill/assembler/assemblers.py new file mode 100644 index 0000000000..de37e9dfed --- /dev/null +++ b/api/core/skill/assembler/assemblers.py @@ -0,0 +1,80 @@ +from collections.abc import Mapping + +from core.app.entities.app_asset_entities import AppAssetFileTree +from core.skill.assembler.common import ( + build_skill_graph, + compute_transitive_dependance, + expand_referenced_skill_ids, + get_metadata, + process_skill_content, +) +from core.skill.entities.skill_bundle import Skill, SkillBundle, SkillDependance +from core.skill.entities.skill_document import SkillDocument + + +class SkillBundleAssembler: + _file_tree: AppAssetFileTree + + def __init__(self, file_tree: AppAssetFileTree) -> None: + self._file_tree = file_tree + + def assemble_bundle( + self, + documents: Mapping[str, SkillDocument], + assets_id: str, + ) -> SkillBundle: + direct_skills: dict[str, Skill] = {} + for skill_id, doc in documents.items(): + metadata = get_metadata(doc.content, doc.metadata) + direct_dependance = SkillDependance.from_metadata(metadata) + direct_skills[skill_id] = Skill( + skill_id=skill_id, + direct_dependance=direct_dependance, + dependance=direct_dependance, + content=process_skill_content(doc.content, metadata, self._file_tree, skill_id), + ) + + graph = build_skill_graph(direct_skills, self._file_tree) + transitive_map = compute_transitive_dependance(direct_skills, graph) + + compiled_skills: dict[str, Skill] = {} + for skill_id, skill in direct_skills.items(): + compiled_skills[skill_id] = skill.model_copy(update={"dependance": transitive_map[skill_id]}) + + return SkillBundle(asset_tree=self._file_tree, assets_id=assets_id, skills=compiled_skills) + + +class SkillDocumentAssembler: + _bundle: SkillBundle + + def __init__(self, bundle: SkillBundle) -> None: + self._bundle = bundle + + def assemble_document(self, document: SkillDocument, base_path: str = "") -> Skill: + metadata = get_metadata(document.content, document.metadata) + direct_dependance = SkillDependance.from_metadata(metadata) + resolved_content = process_skill_content( + document.content, + metadata, + self._bundle.asset_tree, + document.skill_id, + base_path, + ) + + transitive_dependance = direct_dependance + known_skill_ids = set(self._bundle.skills.keys()) + referenced_skill_ids = expand_referenced_skill_ids( + direct_dependance.files, known_skill_ids, self._bundle.asset_tree + ) + for skill_id in sorted(referenced_skill_ids): + referenced_skill = self._bundle.get(skill_id) + if referenced_skill is None: + continue + transitive_dependance = transitive_dependance | referenced_skill.dependance + + return Skill( + skill_id=document.skill_id, + direct_dependance=direct_dependance, + dependance=transitive_dependance, + content=resolved_content, + ) diff --git a/api/core/skill/assembler/common.py b/api/core/skill/assembler/common.py new file mode 100644 index 0000000000..6a910e0287 --- /dev/null +++ b/api/core/skill/assembler/common.py @@ -0,0 +1,136 @@ +from collections import deque +from collections.abc import Mapping + +from 
core.app.entities.app_asset_entities import AppAssetFileTree, AssetNodeType +from core.skill.assembler.replacers import ( + FILE_PATTERN, + TOOL_METADATA_PATTERN, + FileReplacer, + Replacer, + ToolGroupReplacer, + ToolReplacer, +) +from core.skill.entities.skill_bundle import Skill, SkillDependance +from core.skill.entities.skill_metadata import FileReference, SkillMetadata, ToolReference + + +def process_skill_content( + content: str, + metadata: SkillMetadata, + file_tree: AppAssetFileTree, + current_id: str, + base_path: str = "", +) -> str: + """Resolve all placeholders in content through the ordered replacer pipeline.""" + replacers: list[Replacer] = [ + FileReplacer(file_tree, current_id, base_path), + ToolGroupReplacer(metadata), + ToolReplacer(metadata), + ] + for replacer in replacers: + content = replacer.resolve(content) + return content + + +def get_metadata(content: str, metadata: SkillMetadata) -> SkillMetadata: + """Parse effective metadata from content placeholders and raw metadata.""" + tools: dict[str, ToolReference] = {} + # find all tool refs actually used in content + for match in TOOL_METADATA_PATTERN.finditer(content): + provider, name, uuid = match.group(1), match.group(2), match.group(3) + tool_ref = metadata.tools.get(uuid) + if tool_ref is None: + raise ValueError(f"Tool reference with UUID {uuid} not found in metadata") + tool_ref.uuid = uuid + tool_ref.tool_name = name + tool_ref.provider = provider + tools[uuid] = tool_ref + + # find all file refs + files: set[FileReference] = set() + for match in FILE_PATTERN.finditer(content): + source, asset_id = match.group(1), match.group(2) + files.add(FileReference(source=source, asset_id=asset_id)) + + return SkillMetadata(tools=tools, files=files) + + +def build_skill_graph(skills: Mapping[str, Skill], file_tree: AppAssetFileTree) -> dict[str, set[str]]: + """Build adjacency list: skill_id -> referenced skill IDs.""" + known_skill_ids = set(skills.keys()) + graph: dict[str, set[str]] = {skill_id: set() for skill_id in known_skill_ids} + + for skill_id, skill in skills.items(): + graph[skill_id] = expand_referenced_skill_ids(skill.direct_dependance.files, known_skill_ids, file_tree) + + return graph + + +def compute_transitive_dependance( + skills: Mapping[str, Skill], + graph: Mapping[str, set[str]], +) -> dict[str, SkillDependance]: + """Compute transitive dependency closure with fixed-point iteration.""" + dependance_map = {skill_id: skill.direct_dependance for skill_id, skill in skills.items()} + + changed = True + while changed: + changed = False + for skill_id in sorted(skills.keys()): + merged = dependance_map[skill_id] + for dep_skill_id in sorted(graph.get(skill_id, set())): + if dep_skill_id == skill_id: + continue + merged = merged | dependance_map[dep_skill_id] + + if merged != dependance_map[skill_id]: + dependance_map[skill_id] = merged + changed = True + + return dependance_map + + +def expand_referenced_skill_ids( + refs: set[FileReference], + known_skill_ids: set[str], + file_tree: AppAssetFileTree, +) -> set[str]: + """Resolve file/folder references to concrete known skill IDs.""" + resolved: set[str] = set() + for ref in refs: + node = file_tree.get(ref.asset_id) + if node is None: + continue + + if node.node_type == AssetNodeType.FILE: + if node.id in known_skill_ids: + resolved.add(node.id) + continue + + descendant_ids = file_tree.get_descendant_ids(node.id) + for descendant_id in descendant_ids: + descendant = file_tree.get(descendant_id) + if descendant is None or descendant.node_type != 
AssetNodeType.FILE:
+                continue
+            if descendant_id in known_skill_ids:
+                resolved.add(descendant_id)
+
+    return resolved
+
+
+def collect_transitive_skill_ids(
+    root_skill_ids: set[str],
+    graph: Mapping[str, set[str]],
+) -> set[str]:
+    """Collect all transitively reachable skill IDs from roots via BFS."""
+    visited: set[str] = set()
+    queue = deque(sorted(root_skill_ids))
+    while queue:
+        current = queue.popleft()
+        if current in visited:
+            continue
+        visited.add(current)
+        for next_skill_id in sorted(graph.get(current, set())):
+            if next_skill_id not in visited:
+                queue.append(next_skill_id)
+    return visited
diff --git a/api/core/skill/assembler/replacers.py b/api/core/skill/assembler/replacers.py
new file mode 100644
index 0000000000..01d76b72db
--- /dev/null
+++ b/api/core/skill/assembler/replacers.py
@@ -0,0 +1,108 @@
+"""Placeholder replacers for skill content.
+
+Each replacer handles one category of ``§[...]§`` placeholder via the unified
+``Replacer`` protocol. The shared ``process_skill_content`` pipeline in
+``core.skill.assembler.common`` builds a ``list[Replacer]`` and applies them
+in order:
+
+    ``FileReplacer`` → ``ToolGroupReplacer`` → ``ToolReplacer``
+
+``ToolGroupReplacer`` MUST run before ``ToolReplacer`` so that group brackets
+``[§[tool]...§, §[tool]...§]`` are resolved atomically; otherwise individual
+tool replacement would destroy the group structure.
+"""
+
+import re
+from typing import Protocol
+
+from core.app.entities.app_asset_entities import AppAssetFileTree
+from core.skill.entities.skill_metadata import SkillMetadata
+
+TOOL_METADATA_PATTERN: re.Pattern[str] = re.compile(r"§\[tool\]\.\[([^\]]+)\]\.\[([^\]]+)\]\.\[([^\]]+)\]§")
+TOOL_PATTERN: re.Pattern[str] = re.compile(r"§\[tool\]\.\[.*?\]\.\[.*?\]\.\[(.*?)\]§")
+TOOL_GROUP_PATTERN: re.Pattern[str] = re.compile(
+    r"\[\s*§\[tool\]\.\[[^\]]+\]\.\[[^\]]+\]\.\[[^\]]+\]§"
+    r"(?:\s*,\s*§\[tool\]\.\[[^\]]+\]\.\[[^\]]+\]\.\[[^\]]+\]§)*\s*\]"
+)
+FILE_PATTERN: re.Pattern[str] = re.compile(r"§\[file\]\.\[([^\]]+)\]\.\[([^\]]+)\]§")
+
+
+class Replacer(Protocol):
+    def resolve(self, content: str) -> str: ...
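+
+# Illustrative only: the ordered pipeline that ``process_skill_content`` in
+# ``core.skill.assembler.common`` builds from these replacers; ``tree`` and
+# ``metadata`` are hypothetical placeholders.
+#
+#     replacers: list[Replacer] = [
+#         FileReplacer(tree, current_id="skill-1"),
+#         ToolGroupReplacer(metadata),  # groups must resolve before single tools
+#         ToolReplacer(metadata),
+#     ]
+#     for replacer in replacers:
+#         content = replacer.resolve(content)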
+ + +class FileReplacer: + _tree: AppAssetFileTree + _current_id: str + _base_path: str + + def __init__(self, tree: AppAssetFileTree, current_id: str, base_path: str = "") -> None: + self._tree = tree + self._current_id = current_id + self._base_path = base_path.rstrip("/") + + def resolve(self, content: str) -> str: + return FILE_PATTERN.sub(self._replace_match, content) + + def _replace_match(self, match: re.Match[str]) -> str: + target_id = match.group(2) + source_node = self._tree.get(self._current_id) + target_node = self._tree.get(target_id) + + if target_node is None: + return "[File not found]" + + if source_node is not None: + return self._tree.relative_path(source_node, target_node) + + full_path = self._tree.get_path(target_node.id) + if self._base_path: + return f"{self._base_path}/{full_path}" + return full_path + + +class ToolReplacer: + _metadata: SkillMetadata + + def __init__(self, metadata: SkillMetadata) -> None: + self._metadata = metadata + + def resolve(self, content: str) -> str: + return TOOL_PATTERN.sub(self._replace_match, content) + + def _replace_match(self, match: re.Match[str]) -> str: + tool_id = match.group(1) + tool_ref = self._metadata.tools.get(tool_id) + if tool_ref is None: + return f"[Tool not found or disabled: {tool_id}]" + if not tool_ref.enabled: + return "" + return f"[Executable: {tool_ref.tool_name}_{tool_ref.uuid} --help command]" + + +class ToolGroupReplacer: + _metadata: SkillMetadata + + def __init__(self, metadata: SkillMetadata) -> None: + self._metadata = metadata + + def resolve(self, content: str) -> str: + return TOOL_GROUP_PATTERN.sub(self._replace_match, content) + + def _replace_match(self, match: re.Match[str]) -> str: + group_text = match.group(0) + enabled_renders: list[str] = [] + + for tool_match in TOOL_PATTERN.finditer(group_text): + tool_id = tool_match.group(1) + tool_ref = self._metadata.tools.get(tool_id) + if tool_ref is None: + enabled_renders.append(f"[Tool not found or disabled: {tool_id}]") + continue + if not tool_ref.enabled: + continue + enabled_renders.append(f"[Executable: {tool_ref.tool_name}_{tool_ref.uuid} --help command]") + + if not enabled_renders: + return "" + return "[" + ", ".join(enabled_renders) + "]" diff --git a/api/core/skill/constants.py b/api/core/skill/constants.py new file mode 100644 index 0000000000..d16a8237ac --- /dev/null +++ b/api/core/skill/constants.py @@ -0,0 +1,6 @@ +from core.skill.entities.skill_bundle import SkillBundle +from libs.attr_map import AttrKey + + +class SkillAttrs: + BUNDLE = AttrKey("skill_bundle", SkillBundle) diff --git a/api/core/skill/entities/__init__.py b/api/core/skill/entities/__init__.py new file mode 100644 index 0000000000..710f5c73d4 --- /dev/null +++ b/api/core/skill/entities/__init__.py @@ -0,0 +1,29 @@ +from .skill_bundle import Skill, SkillBundle, SkillDependance +from .skill_document import SkillDocument +from .skill_metadata import ( + FileReference, + SkillMetadata, + ToolConfiguration, + ToolFieldConfig, + ToolReference, +) +from .tool_access_policy import ToolAccessDescription, ToolAccessPolicy, ToolDescription, ToolInvocationRequest +from .tool_dependencies import ToolDependencies, ToolDependency + +__all__ = [ + "FileReference", + "Skill", + "SkillBundle", + "SkillDependance", + "SkillDocument", + "SkillMetadata", + "ToolAccessDescription", + "ToolAccessPolicy", + "ToolConfiguration", + "ToolDependencies", + "ToolDependency", + "ToolDescription", + "ToolFieldConfig", + "ToolInvocationRequest", + "ToolReference", +] diff --git 
a/api/core/skill/entities/api_entities.py b/api/core/skill/entities/api_entities.py
new file mode 100644
index 0000000000..aeb54d503f
--- /dev/null
+++ b/api/core/skill/entities/api_entities.py
@@ -0,0 +1,16 @@
+from pydantic import BaseModel, Field
+
+from core.skill.entities.tool_dependencies import ToolDependency
+
+
+class NodeSkillInfo(BaseModel):
+    """Information about skills referenced by a workflow node.
+
+    Used by the whole-workflow skills endpoint to return per-node
+    tool dependency information.
+    """
+
+    node_id: str = Field(description="The node ID")
+    tool_dependencies: list[ToolDependency] = Field(
+        default_factory=list, description="Tool dependencies extracted from skill prompts"
+    )
diff --git a/api/core/skill/entities/skill_bundle.py b/api/core/skill/entities/skill_bundle.py
new file mode 100644
index 0000000000..509a8c4746
--- /dev/null
+++ b/api/core/skill/entities/skill_bundle.py
@@ -0,0 +1,94 @@
+from typing import TYPE_CHECKING
+
+from pydantic import BaseModel, ConfigDict, Field
+
+from core.app.entities.app_asset_entities import AppAssetFileTree
+from core.skill.entities.skill_metadata import FileReference
+from core.skill.entities.tool_dependencies import ToolDependencies, ToolDependency
+
+if TYPE_CHECKING:
+    from core.skill.entities.skill_metadata import SkillMetadata
+
+
+class SkillDependance(BaseModel):
+    model_config = ConfigDict(extra="forbid")
+
+    tools: ToolDependencies = Field(description="Direct tool dependencies parsed from this skill only")
+
+    files: set[FileReference] = Field(
+        default_factory=set,
+        description="Direct file references parsed from this skill only",
+    )
+
+    def __or__(self, other: "SkillDependance") -> "SkillDependance":
+        return SkillDependance(tools=self.tools.merge(other.tools), files=self.files | other.files)
+
+    @staticmethod
+    def from_metadata(metadata: "SkillMetadata") -> "SkillDependance":
+        """Convert parsed metadata into direct tool/file dependency model."""
+        from core.skill.entities.skill_metadata import ToolReference
+
+        dep_map: dict[str, ToolDependency] = {}
+        ref_map: dict[str, ToolReference] = {}
+
+        for tool_ref in metadata.tools.values():
+            dep_map.setdefault(
+                tool_ref.tool_id(),
+                ToolDependency(
+                    type=tool_ref.type,
+                    provider=tool_ref.provider,
+                    tool_name=tool_ref.tool_name,
+                    enabled=tool_ref.enabled,
+                ),
+            )
+            ref_map.setdefault(tool_ref.uuid, tool_ref)
+
+        return SkillDependance(
+            tools=ToolDependencies(
+                dependencies=[dep_map[key] for key in sorted(dep_map.keys())],
+                references=[ref_map[key] for key in sorted(ref_map.keys())],
+            ),
+            files=metadata.files,
+        )
+
+
+class Skill(BaseModel):
+    model_config = ConfigDict(extra="forbid")
+
+    skill_id: str = Field(description="Unique identifier for this skill; matches the source SkillDocument.skill_id")
+
+    direct_dependance: SkillDependance = Field(description="Direct dependencies parsed from this skill only")
+
+    dependance: SkillDependance = Field(description="All dependencies including transitive closure")
+
+    content: str = Field(description="Resolved content with all references replaced")
+
+    @property
+    def tools(self) -> ToolDependencies:
+        return self.dependance.tools
+
+
+class SkillBundle(BaseModel):
+    model_config = ConfigDict(extra="forbid")
+
+    asset_tree: AppAssetFileTree = Field(description="Asset tree for this bundle")
+
+    assets_id: str = Field(description="Assets ID this bundle belongs to")
+
+    skills: dict[str, Skill] = Field(default_factory=dict)
+
+    @property
+    def entries(self) -> dict[str, Skill]:
+        return self.skills
+
+    def get(self, skill_id: str) ->
Skill | None: + return self.skills.get(skill_id) + + def get_tool_dependencies(self) -> ToolDependencies: + merged = ToolDependencies() + for skill in self.skills.values(): + merged = merged.merge(skill.dependance.tools) + return merged + + def put(self, skill: Skill) -> None: + self.skills[skill.skill_id] = skill diff --git a/api/core/skill/entities/skill_document.py b/api/core/skill/entities/skill_document.py new file mode 100644 index 0000000000..76c79d2b9a --- /dev/null +++ b/api/core/skill/entities/skill_document.py @@ -0,0 +1,17 @@ +from pydantic import BaseModel, ConfigDict, Field + +from core.skill.entities.skill_metadata import SkillMetadata + + +class SkillFile(BaseModel): + model_config = ConfigDict(extra="forbid") + + +class SkillDocument(BaseModel): + """Input document for skill compilation.""" + + model_config = ConfigDict(extra="forbid") + + skill_id: str = Field(description="Unique identifier, must match SkillAsset.asset_id") + content: str = Field(description="Raw content with reference placeholders") + metadata: SkillMetadata = Field(default_factory=SkillMetadata, description="Additional metadata for this skill") diff --git a/api/core/skill/entities/skill_metadata.py b/api/core/skill/entities/skill_metadata.py new file mode 100644 index 0000000000..cf8b140e1f --- /dev/null +++ b/api/core/skill/entities/skill_metadata.py @@ -0,0 +1,113 @@ +from typing import Any + +from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator + +from core.tools.entities.tool_entities import ToolProviderType + + +class ToolFieldConfig(BaseModel): + model_config = ConfigDict(extra="forbid") + + id: str + value: Any + auto: bool = False + + +class ToolConfiguration(BaseModel): + model_config = ConfigDict(extra="forbid") + + fields: list[ToolFieldConfig] = Field( + default_factory=list, description="List of field configurations for this tool" + ) + + def default_values(self) -> dict[str, Any]: + return {field.id: field.value for field in self.fields if field.value is not None} + + +def create_tool_id(provider: str, tool_name: str) -> str: + return f"{provider}.{tool_name}" + + +class ToolReference(BaseModel): + model_config = ConfigDict(extra="forbid") + + uuid: str = Field( + default="", + description=( + "Unique identifier for this tool reference, used to distinguish multiple references to the same tool" + ), + ) + type: ToolProviderType = Field(description="The provider type of the tool") + provider: str = Field( + default="", + description="The provider name of the tool plugin. Can be inferred from placeholders during compilation.", + ) + tool_name: str = Field( + default="", + description=( + "The tool name defined in the provider plugin. Can be inferred from placeholders during compilation." 
+ ), + ) + enabled: bool = Field(default=True, description="Whether this tool reference is enabled") + credential_id: str | None = Field( + default=None, + description="Credential ID used to resolve credentials when invoking the tool.", + ) + configuration: ToolConfiguration | None = Field( + default=None, + description=( + "Optional configuration for this tool reference, used to provide " + "additional parameters when invoking the tool" + ), + ) + + def reference_id(self) -> str: + return f"{self.provider}.{self.tool_name}.{self.uuid}" + + def tool_id(self) -> str: + return f"{self.provider}.{self.tool_name}" + + +class FileReference(BaseModel): + model_config = ConfigDict(frozen=True) + + source: str = Field(default="app") + asset_id: str + + @model_validator(mode="before") + @classmethod + def normalize_input(cls, data: Any) -> Any: + if not isinstance(data, dict): + return data + if "asset_id" in data and "source" in data: + return {"source": data.get("source", "app"), "asset_id": data["asset_id"]} + # front end support + if "id" in data: + return {"source": "app", "asset_id": data["id"]} + return data + + +class SkillMetadata(BaseModel): + model_config = ConfigDict(extra="forbid") + + tools: dict[str, ToolReference] = Field(default_factory=dict) + files: set[FileReference] = Field(default_factory=set) + + @field_validator("files", mode="before") + @classmethod + def coerce_files_to_set(cls, v: Any) -> set[FileReference] | Any: + if isinstance(v, list): + refs: set[FileReference] = set() + for item in v: + if isinstance(item, dict): + refs.add(FileReference.model_validate(item)) + elif isinstance(item, FileReference): + refs.add(item) + return refs + if isinstance(v, dict): + refs = set() + for item in v.values(): + if isinstance(item, dict): + refs.add(FileReference.model_validate(item)) + return refs + return v diff --git a/api/core/skill/entities/tool_access_policy.py b/api/core/skill/entities/tool_access_policy.py new file mode 100644 index 0000000000..0ca20f5a90 --- /dev/null +++ b/api/core/skill/entities/tool_access_policy.py @@ -0,0 +1,145 @@ +from collections.abc import Mapping + +from pydantic import BaseModel, ConfigDict, Field + +from core.skill.entities.tool_dependencies import ToolDependencies +from core.tools.entities.tool_entities import ToolProviderType + + +class ToolDescription(BaseModel): + """Immutable identifier for a tool (type + provider + name).""" + + model_config = ConfigDict(frozen=True) + + tool_type: ToolProviderType + provider: str + tool_name: str + + def tool_id(self) -> str: + return f"{self.tool_type.value}:{self.provider}:{self.tool_name}" + + +class ToolAccessDescription(BaseModel): + """ + Per-tool access descriptor that bundles identity with allowed credentials. + + Each allowed tool is represented by exactly one ``ToolAccessDescription``. + ``allowed_credentials`` captures the set of credential IDs that may be used + when invoking this tool: + + * **empty set** – the tool requires no special credential; only requests + *without* a ``credential_id`` are accepted. + * **non-empty set** – the tool requires an explicit credential; the + request's ``credential_id`` must be a member of this set. 
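+
+    Note that ``is_credential_allowed`` always accepts a request that carries
+    no ``credential_id`` at all: a non-empty set restricts which *explicit*
+    credential may be used, it does not forbid falling back to default
+    credentials.
+
+    A small example (IDs are hypothetical)::
+
+        desc = ToolAccessDescription(
+            tool_type=tool_type,  # some ToolProviderType member
+            provider="google",
+            tool_name="search",
+            allowed_credentials=frozenset({"cred-1"}),
+        )
+        desc.is_credential_allowed("cred-1")  # True
+        desc.is_credential_allowed("cred-2")  # False
+        desc.is_credential_allowed(None)      # True (fall back to defaults)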
+ """ + + model_config = ConfigDict(frozen=True) + + tool_type: ToolProviderType + provider: str + tool_name: str + allowed_credentials: frozenset[str] = Field(default_factory=frozenset) + + def tool_id(self) -> str: + return f"{self.tool_type.value}:{self.provider}:{self.tool_name}" + + def is_credential_allowed(self, credential_id: str | None) -> bool: + """Check whether *credential_id* satisfies this tool's credential policy. + + * No credentials registered (``allowed_credentials`` is empty) → + only requests *without* a credential are accepted. + * Credentials registered → the supplied ``credential_id`` must be in + the set. + """ + if credential_id is None or credential_id == "": + return True + + return credential_id in self.allowed_credentials + + +class ToolInvocationRequest(BaseModel): + """A request to invoke a specific tool with optional credential.""" + + model_config = ConfigDict(frozen=True) + + tool_type: ToolProviderType + provider: str + tool_name: str + credential_id: str | None = None + + @property + def tool_description(self) -> ToolDescription: + return ToolDescription(tool_type=self.tool_type, provider=self.provider, tool_name=self.tool_name) + + +class ToolAccessPolicy(BaseModel): + """ + Determines whether a tool invocation is allowed based on ToolDependencies. + + The policy is built exclusively from ``ToolDependencies.references`` – each + ``ToolReference`` declares both the tool identity *and* the credential that + may be used. ``ToolDependencies.dependencies`` is a de-duplicated identity + list and does not participate in access-control decisions. + + Rules: + 1. The tool must appear in at least one reference. + 2. If references for the tool carry credential IDs, the request must supply + one of those exact IDs. + 3. If no reference for the tool carries a credential ID, the request must + *not* supply one (use default/ambient credentials). + """ + + model_config = ConfigDict(frozen=True) + + access_map: Mapping[str, ToolAccessDescription] = Field(default_factory=dict) + + @classmethod + def from_dependencies(cls, deps: ToolDependencies | None) -> "ToolAccessPolicy": + """Build a policy from ``ToolDependencies``. + + Only ``deps.references`` are used. Multiple references to the same + tool are merged – their credential IDs are unioned into a single + ``ToolAccessDescription.allowed_credentials`` set. + """ + if deps is None or deps.is_empty(): + return cls() + + # Accumulate credential sets keyed by tool_id so that multiple + # references to the same tool are merged correctly. 
+ credentials_by_tool: dict[str, set[str]] = {} + first_seen: dict[str, tuple[ToolProviderType, str, str]] = {} + + for ref in deps.references: + tool_id = f"{ref.type.value}:{ref.provider}:{ref.tool_name}" + if tool_id not in first_seen: + first_seen[tool_id] = (ref.type, ref.provider, ref.tool_name) + credentials_by_tool[tool_id] = set() + if ref.credential_id is not None: + credentials_by_tool[tool_id].add(ref.credential_id) + + access_map: dict[str, ToolAccessDescription] = {} + for tool_id, (tool_type, provider, tool_name) in first_seen.items(): + access_map[tool_id] = ToolAccessDescription( + tool_type=tool_type, + provider=provider, + tool_name=tool_name, + allowed_credentials=frozenset(credentials_by_tool[tool_id]), + ) + + return cls(access_map=access_map) + + def is_empty(self) -> bool: + return len(self.access_map) == 0 + + def is_allowed(self, request: ToolInvocationRequest) -> bool: + """Check if the tool invocation request is allowed.""" + # An empty policy (no references declared) permits any invocation. + if self.is_empty(): + return True + + tool_id = request.tool_description.tool_id() + access_desc = self.access_map.get(tool_id) + if access_desc is None: + return False + + return access_desc.is_credential_allowed(request.credential_id) diff --git a/api/core/skill/entities/tool_dependencies.py b/api/core/skill/entities/tool_dependencies.py new file mode 100644 index 0000000000..259476c661 --- /dev/null +++ b/api/core/skill/entities/tool_dependencies.py @@ -0,0 +1,78 @@ +from pydantic import BaseModel, ConfigDict, Field + +from core.skill.entities.skill_metadata import ToolReference +from core.tools.entities.tool_entities import ToolProviderType + + +class ToolDependency(BaseModel): + model_config = ConfigDict(extra="forbid") + + type: ToolProviderType + provider: str + tool_name: str + enabled: bool = True + + def tool_id(self) -> str: + return f"{self.provider}.{self.tool_name}" + + +class ToolDependencies(BaseModel): + model_config = ConfigDict(extra="forbid") + + dependencies: list[ToolDependency] = Field(default_factory=list) + references: list[ToolReference] = Field(default_factory=list) + + def is_empty(self) -> bool: + return not self.dependencies and not self.references + + def filter(self, tools: list[tuple[str, str]]) -> "ToolDependencies": + tool_names = {f"{provider}.{tool_name}" for provider, tool_name in tools} + return ToolDependencies( + dependencies=[ + dependency + for dependency in self.dependencies + if f"{dependency.provider}.{dependency.tool_name}" in tool_names + ], + references=[ + reference + for reference in self.references + if f"{reference.provider}.{reference.tool_name}" in tool_names + ], + ) + + def merge(self, other: "ToolDependencies") -> "ToolDependencies": + dep_map: dict[str, ToolDependency] = {} + for dep in self.dependencies: + key = f"{dep.provider}.{dep.tool_name}" + dep_map[key] = dep + for dep in other.dependencies: + key = f"{dep.provider}.{dep.tool_name}" + if key not in dep_map: + dep_map[key] = dep + + ref_map: dict[str, ToolReference] = {} + for ref in self.references: + ref_map[ref.uuid] = ref + for ref in other.references: + if ref.uuid not in ref_map: + ref_map[ref.uuid] = ref + + return ToolDependencies( + dependencies=list(dep_map.values()), + references=list(ref_map.values()), + ) + + def remove_tools(self, tools: list[ToolDependency]) -> "ToolDependencies": + tool_keys = {f"{tool.provider}.{tool.tool_name}" for tool in tools} + return ToolDependencies( + dependencies=[ + dependency + for dependency in self.dependencies 
+ if f"{dependency.provider}.{dependency.tool_name}" not in tool_keys + ], + references=[ + reference + for reference in self.references + if f"{reference.provider}.{reference.tool_name}" not in tool_keys + ], + ) diff --git a/api/core/skill/skill_manager.py b/api/core/skill/skill_manager.py new file mode 100644 index 0000000000..786dc589b1 --- /dev/null +++ b/api/core/skill/skill_manager.py @@ -0,0 +1,43 @@ +import logging + +from core.app_assets.storage import AssetPaths +from core.skill.entities.skill_bundle import SkillBundle +from extensions.ext_redis import redis_client +from services.app_asset_service import AppAssetService + +logger = logging.getLogger(__name__) + +_CACHE_PREFIX = "skill_bundle" +_CACHE_TTL = 86400 # 24 hours + + +class SkillManager: + @staticmethod + def load_bundle(tenant_id: str, app_id: str, assets_id: str) -> SkillBundle: + cache_key = f"{_CACHE_PREFIX}:{tenant_id}:{app_id}:{assets_id}" + data = redis_client.get(cache_key) + if data: + return SkillBundle.model_validate_json(data) + + key = AssetPaths.skill_bundle(tenant_id, app_id, assets_id) + try: + data = AppAssetService.get_storage().load_once(key) + except FileNotFoundError: + logger.exception( + "Skill bundle not found in storage: key=%s, tenant_id=%s, app_id=%s, assets_id=%s", + key, + tenant_id, + app_id, + assets_id, + ) + raise + bundle = SkillBundle.model_validate_json(data) + redis_client.setex(cache_key, _CACHE_TTL, bundle.model_dump_json(indent=2).encode("utf-8")) + return bundle + + @staticmethod + def save_bundle(tenant_id: str, app_id: str, assets_id: str, bundle: SkillBundle) -> None: + key = AssetPaths.skill_bundle(tenant_id, app_id, assets_id) + AppAssetService.get_storage().save(key, data=bundle.model_dump_json(indent=2).encode("utf-8")) + cache_key = f"{_CACHE_PREFIX}:{tenant_id}:{app_id}:{assets_id}" + redis_client.delete(cache_key) diff --git a/api/core/virtual_environment/__base/__init__.py b/api/core/virtual_environment/__base/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/virtual_environment/__base/command_future.py b/api/core/virtual_environment/__base/command_future.py new file mode 100644 index 0000000000..f5363fcf4f --- /dev/null +++ b/api/core/virtual_environment/__base/command_future.py @@ -0,0 +1,208 @@ +import contextlib +import logging +import threading +import time +from collections.abc import Callable +from concurrent.futures import ThreadPoolExecutor + +from core.virtual_environment.__base.entities import CommandResult, CommandStatus +from core.virtual_environment.__base.exec import NotSupportedOperationError +from core.virtual_environment.channel.exec import TransportEOFError +from core.virtual_environment.channel.transport import TransportReadCloser, TransportWriteCloser + +logger = logging.getLogger(__name__) + + +class CommandTimeoutError(Exception): + pass + + +class CommandCancelledError(Exception): + pass + + +class CommandFuture: + """ + Lightweight future for command execution. + Mirrors concurrent.futures.Future API with 4 essential methods: + result(), done(), cancel(), cancelled(). + + When a command is cancelled or times out the future now asks the provider + to terminate the underlying process/session before marking itself done. 
+ """ + + def __init__( + self, + pid: str, + stdin_transport: TransportWriteCloser, + stdout_transport: TransportReadCloser, + stderr_transport: TransportReadCloser, + poll_status: Callable[[], CommandStatus], + terminate_command: Callable[[], bool] | None = None, + poll_interval: float = 0.1, + ): + self._pid = pid + self._stdin_transport = stdin_transport + self._stdout_transport = stdout_transport + self._stderr_transport = stderr_transport + self._poll_status = poll_status + self._terminate_command = terminate_command + self._poll_interval = poll_interval + + self._done_event = threading.Event() + self._lock = threading.Lock() + self._result: CommandResult | None = None + self._exception: BaseException | None = None + self._cancelled = False + self._timed_out = False + self._started = False + self._termination_requested = False + + def result(self, timeout: float | None = None) -> CommandResult: + """ + Block until command completes and return result. + + Args: + timeout: Maximum seconds to wait. None means wait forever. + + Raises: + CommandTimeoutError: If timeout exceeded. + CommandCancelledError: If command was cancelled. + + A timeout is terminal for this future: it triggers best-effort command + termination and subsequent ``result()`` calls keep raising timeout. + """ + self._ensure_started() + + if not self._done_event.wait(timeout): + self._request_stop(timed_out=True) + raise CommandTimeoutError(f"Command timed out after {timeout}s") + + if self._cancelled: + raise CommandCancelledError("Command was cancelled") + + if self._timed_out: + raise CommandTimeoutError("Command timed out") + + if self._exception is not None: + raise self._exception + + assert self._result is not None + return self._result + + def done(self) -> bool: + self._ensure_started() + return self._done_event.is_set() + + def cancel(self) -> bool: + """ + Attempt to cancel command by terminating it and closing transports. + Returns True if cancelled, False if already completed. 
+ """ + return self._request_stop(cancelled=True) + + def cancelled(self) -> bool: + return self._cancelled + + def _ensure_started(self) -> None: + with self._lock: + if not self._started: + self._started = True + thread = threading.Thread(target=self._execute, daemon=True) + thread.start() + + def _request_stop(self, *, cancelled: bool = False, timed_out: bool = False) -> bool: + should_terminate = False + with self._lock: + if self._done_event.is_set(): + return False + + if cancelled: + self._cancelled = True + if timed_out: + self._timed_out = True + + should_terminate = not self._termination_requested + if should_terminate: + self._termination_requested = True + + self._close_transports() + self._done_event.set() + + if should_terminate: + self._terminate_running_command() + return True + + def _execute(self) -> None: + stdout_buf = bytearray() + stderr_buf = bytearray() + is_combined_stream = self._stdout_transport is self._stderr_transport + + try: + with ThreadPoolExecutor(max_workers=2) as executor: + stdout_future = executor.submit(self._drain_transport, self._stdout_transport, stdout_buf) + stderr_future = None + if not is_combined_stream: + stderr_future = executor.submit(self._drain_transport, self._stderr_transport, stderr_buf) + + exit_code = self._wait_for_completion() + + stdout_future.result() + if stderr_future: + stderr_future.result() + + with self._lock: + if not self._cancelled: + self._result = CommandResult( + stdout=bytes(stdout_buf), + stderr=b"" if is_combined_stream else bytes(stderr_buf), + exit_code=exit_code, + pid=self._pid, + ) + self._done_event.set() + + except Exception as e: + logger.exception("Command execution failed for pid %s", self._pid) + with self._lock: + if not self._cancelled: + self._exception = e + self._done_event.set() + finally: + self._close_transports() + + def _wait_for_completion(self) -> int | None: + while not self._cancelled and not self._timed_out: + try: + status = self._poll_status() + except NotSupportedOperationError: + return None + + if status.status == CommandStatus.Status.COMPLETED: + return status.exit_code + + time.sleep(self._poll_interval) + + return None + + def _drain_transport(self, transport: TransportReadCloser, buffer: bytearray) -> None: + try: + while True: + buffer.extend(transport.read(4096)) + except TransportEOFError: + pass + except Exception: + logger.exception("Failed reading transport") + + def _close_transports(self) -> None: + for transport in (self._stdin_transport, self._stdout_transport, self._stderr_transport): + with contextlib.suppress(Exception): + transport.close() + + def _terminate_running_command(self) -> None: + if self._terminate_command is None: + return + + try: + self._terminate_command() + except Exception: + logger.exception("Failed to terminate command for pid %s", self._pid) diff --git a/api/core/virtual_environment/__base/entities.py b/api/core/virtual_environment/__base/entities.py new file mode 100644 index 0000000000..463f1c7c01 --- /dev/null +++ b/api/core/virtual_environment/__base/entities.py @@ -0,0 +1,100 @@ +from collections.abc import Mapping +from enum import StrEnum +from typing import Any + +from pydantic import BaseModel, Field + + +class Arch(StrEnum): + """ + Architecture types for virtual environments. + """ + + ARM64 = "arm64" + AMD64 = "amd64" + + +class OperatingSystem(StrEnum): + """ + Operating system types for virtual environments. 
+    """
+
+    LINUX = "linux"
+    DARWIN = "darwin"
+
+
+class Metadata(BaseModel):
+    """
+    Returned metadata about a virtual environment.
+    """
+
+    id: str = Field(description="The unique identifier of the virtual environment.")
+    arch: Arch = Field(description="Which architecture was used to create the virtual environment.")
+    os: OperatingSystem = Field(description="The operating system of the virtual environment.")
+    store: Mapping[str, Any] = Field(
+        default_factory=dict, description="Provider-specific store of the virtual environment, holding additional data."
+    )
+
+
+class ConnectionHandle(BaseModel):
+    """
+    Handle for managing connections to the virtual environment.
+    """
+
+    id: str = Field(description="The unique identifier of the connection handle.")
+
+
+class CommandStatus(BaseModel):
+    """
+    Status of a command executed in the virtual environment.
+    """
+
+    class Status(StrEnum):
+        RUNNING = "running"
+        COMPLETED = "completed"
+
+    status: Status = Field(description="The status of the command execution.")
+    exit_code: int | None = Field(description="The return code of the command execution.")
+
+
+class FileState(BaseModel):
+    """
+    State of a file in the virtual environment.
+    """
+
+    size: int = Field(description="The size of the file in bytes.")
+    path: str = Field(description="The path of the file in the virtual environment.")
+    created_at: int = Field(description="The creation timestamp of the file.")
+    updated_at: int = Field(description="The last modified timestamp of the file.")
+
+
+class CommandResult(BaseModel):
+    """
+    Result of a synchronous command execution.
+    """
+
+    stdout: bytes = Field(description="Standard output content.")
+    stderr: bytes = Field(description="Standard error content.")
+    exit_code: int | None = Field(description="Exit code of the command.
None if unavailable.") + pid: str = Field(description="Process ID of the executed command.") + + @property + def is_error(self) -> bool: + return self.exit_code not in (None, 0) or bool(self.stderr.decode("utf-8", errors="replace")) + + @property + def error_message(self) -> str: + return self.stderr.decode("utf-8", errors="replace") if self.stderr else "" + + @property + def info_message(self) -> str: + return self.stdout.decode("utf-8", errors="replace") if self.stdout else "" + + @property + def debug_message(self) -> str: + return ( + f"stdout: {self.stdout.decode('utf-8', errors='replace')}\n" + f"stderr: {self.stderr.decode('utf-8', errors='replace')}\n" + f"exit_code: {self.exit_code}\n" + f"pid: {self.pid}" + ) diff --git a/api/core/virtual_environment/__base/exec.py b/api/core/virtual_environment/__base/exec.py new file mode 100644 index 0000000000..3a833045a7 --- /dev/null +++ b/api/core/virtual_environment/__base/exec.py @@ -0,0 +1,64 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from core.virtual_environment.__base.entities import CommandResult + + +class ArchNotSupportedError(Exception): + """Exception raised when the architecture is not supported.""" + + pass + + +class VirtualEnvironmentLaunchFailedError(Exception): + """Exception raised when launching the virtual environment fails.""" + + pass + + +class NotSupportedOperationError(Exception): + """Exception raised when an operation is not supported.""" + + pass + + +class SandboxConfigValidationError(ValueError): + """Exception raised when sandbox configuration validation fails.""" + + pass + + +class CommandExecutionError(ValueError): + """Raised when a command execution fails.""" + + result: CommandResult + + def __init__(self, message: str, result: CommandResult): + super().__init__(message) + self.result = result + + @property + def exit_code(self) -> int | None: + return self.result.exit_code + + @property + def stderr(self) -> bytes: + return self.result.stderr + + +class PipelineExecutionError(CommandExecutionError): + """Raised when a pipeline command fails in strict mode.""" + + index: int + command: list[str] + results: list[CommandResult] + + def __init__( + self, message: str, result: CommandResult, *, index: int, command: list[str], results: list[CommandResult] + ): + super().__init__(message, result) + self.index = index + self.command = command + self.results = results diff --git a/api/core/virtual_environment/__base/helpers.py b/api/core/virtual_environment/__base/helpers.py new file mode 100644 index 0000000000..e8094f4ba7 --- /dev/null +++ b/api/core/virtual_environment/__base/helpers.py @@ -0,0 +1,279 @@ +from __future__ import annotations + +import contextlib +import shlex +from collections.abc import Generator, Mapping +from contextlib import contextmanager +from dataclasses import dataclass, field +from functools import partial + +from core.virtual_environment.__base.command_future import CommandFuture +from core.virtual_environment.__base.entities import CommandResult, ConnectionHandle +from core.virtual_environment.__base.exec import CommandExecutionError, PipelineExecutionError +from core.virtual_environment.__base.virtual_environment import VirtualEnvironment + +_PIPE_SENTINEL = "__DIFY_PIPE__" + + +@contextmanager +def with_connection(env: VirtualEnvironment) -> Generator[ConnectionHandle, None, None]: + """Context manager for VirtualEnvironment connection lifecycle. + + Automatically establishes and releases connection handles. 
+
+    Usage:
+        with with_connection(env) as conn:
+            future = submit_command(env, conn, ["echo", "hello"])
+            result = future.result(timeout=10)
+    """
+    connection_handle = env.establish_connection()
+    try:
+        yield connection_handle
+    finally:
+        with contextlib.suppress(Exception):
+            env.release_connection(connection_handle)
+
+
+def submit_command(
+    env: VirtualEnvironment,
+    connection: ConnectionHandle,
+    command: list[str],
+    environments: Mapping[str, str] | None = None,
+    *,
+    cwd: str | None = None,
+) -> CommandFuture:
+    """Execute a command and return a Future for the result.
+
+    High-level interface that handles IO draining internally.
+    For streaming output, use env.execute_command() instead.
+
+    Args:
+        env: The virtual environment to execute the command in.
+        connection: The connection handle.
+        command: Command as list of strings.
+        environments: Environment variables.
+        cwd: Working directory for the command. If None, uses the provider's default.
+
+    Returns:
+        CommandFuture that can be used to get result with timeout or cancel.
+
+    Example:
+        with with_connection(env) as conn:
+            result = submit_command(env, conn, ["ls", "-la"]).result(timeout=30)
+    """
+    pid, stdin_transport, stdout_transport, stderr_transport = env.execute_command(
+        connection, command, environments, cwd
+    )
+
+    return CommandFuture(
+        pid=pid,
+        stdin_transport=stdin_transport,
+        stdout_transport=stdout_transport,
+        stderr_transport=stderr_transport,
+        poll_status=partial(env.get_command_status, connection, pid),
+        terminate_command=partial(env.terminate_command, connection, pid),
+    )
+
+
+def _execute_with_connection(
+    env: VirtualEnvironment,
+    conn: ConnectionHandle,
+    command: list[str],
+    timeout: float | None,
+    cwd: str | None,
+) -> CommandResult:
+    """Internal helper to execute command with given connection."""
+    future = submit_command(env, conn, command, cwd=cwd)
+    return future.result(timeout=timeout)
+
+
+def execute(
+    env: VirtualEnvironment,
+    command: list[str],
+    *,
+    timeout: float | None = 30,
+    cwd: str | None = None,
+    error_message: str = "Command failed",
+    connection: ConnectionHandle | None = None,
+) -> CommandResult:
+    """Execute a command with automatic connection management.
+
+    Raises CommandExecutionError if the command fails (non-zero exit code
+    or stderr output).
+
+    Args:
+        env: The virtual environment to execute the command in.
+        command: The command to execute as a list of strings.
+        timeout: Maximum time to wait for the command to complete (seconds).
+        cwd: Working directory for the command.
+        error_message: Custom error message prefix for failures.
+        connection: Optional connection handle to reuse. If None, creates and releases a new connection.
+
+    Returns:
+        CommandResult on success.
+
+    Raises:
+        CommandExecutionError: If the command fails.
+    """
+    if connection is not None:
+        result = _execute_with_connection(env, connection, command, timeout, cwd)
+    else:
+        with with_connection(env) as conn:
+            result = _execute_with_connection(env, conn, command, timeout, cwd)
+
+    if result.is_error:
+        raise CommandExecutionError(f"{error_message}: {result.error_message}", result)
+    return result
+
+
+def try_execute(
+    env: VirtualEnvironment,
+    command: list[str],
+    *,
+    timeout: float | None = 30,
+    cwd: str | None = None,
+    connection: ConnectionHandle | None = None,
+) -> CommandResult:
+    """Execute a command with automatic connection management.
+
+    Does not raise on failure - returns the result for caller to handle.
+
+    Args:
+        env: The virtual environment to execute the command in.
+ command: The command to execute as a list of strings. + timeout: Maximum time to wait for the command to complete (seconds). + cwd: Working directory for the command. + connection: Optional connection handle to reuse. If None, creates and releases a new connection. + + Returns: + CommandResult containing stdout, stderr, and exit_code. + """ + if connection is not None: + return _execute_with_connection(env, connection, command, timeout, cwd) + + with with_connection(env) as conn: + return _execute_with_connection(env, conn, command, timeout, cwd) + + +@dataclass(frozen=True) +class _PipelineStep: + argv: list[str] + error_message: str = "Command failed" + + +@dataclass +class CommandPipeline: + """Batch multiple commands into a single shell execution (Redis pipeline style). + + Example: + results = pipeline(env).add(["echo", "hi"]).add(["ls"]).execute() + # Strict mode: raise on first failure + pipeline(env).add(["mkdir", "/x"], error_message="mkdir failed").execute(raise_on_error=True) + """ + + env: VirtualEnvironment + connection: ConnectionHandle | None = None + cwd: str | None = None + environments: Mapping[str, str] | None = None + + _steps: list[_PipelineStep] = field(default_factory=list) # pyright: ignore[reportUnknownVariableType] + + def add(self, command: list[str], *, error_message: str = "Command failed", on: bool = True) -> CommandPipeline: + if on: + self._steps.append(_PipelineStep(argv=command, error_message=error_message)) + return self + + def execute(self, *, timeout: float | None = 30, raise_on_error: bool = False) -> list[CommandResult]: + if not self._steps: + return [] + + script = self._build_script(fail_fast=raise_on_error) + batch_cmd = ["sh", "-c", script] + + if self.connection is not None: + batch_result = try_execute(self.env, batch_cmd, timeout=timeout, cwd=self.cwd, connection=self.connection) + else: + with with_connection(self.env) as conn: + batch_result = try_execute(self.env, batch_cmd, timeout=timeout, cwd=self.cwd, connection=conn) + + results = self._parse_results(batch_result.stdout, batch_result.pid) + + if raise_on_error: + for i, r in enumerate(iterable=results): + if r.is_error: + step = self._steps[i] + raise PipelineExecutionError( + f"{step.error_message}: {r.error_message}", + r, + index=i, + command=step.argv, + results=results, + ) + + return results + + def _build_script(self, *, fail_fast: bool = False) -> str: + lines = [ + "run() {", + ' i="$1"; shift', + ' out="$(mktemp)"; err="$(mktemp)"', + ' ("$@") >"$out" 2>"$err"; ec="$?"', + ' os="$(wc -c <"$out" | tr -d \' \')"', + ' es="$(wc -c <"$err" | tr -d \' \')"', + f' printf \'{_PIPE_SENTINEL} %s %s %s %s\\n\' "$i" "$ec" "$os" "$es"', + ' cat "$out"', + ' cat "$err"', + ' rm -f "$out" "$err"', + ' return "$ec"', + "}", + ] + suffix = " || exit $?" 
if fail_fast else "" + for i, step in enumerate(self._steps): + quoted = " ".join(shlex.quote(arg) for arg in step.argv) + lines.append(f"run {i} {quoted}{suffix}") + return "\n".join(lines) + + @staticmethod + def _parse_results(stdout: bytes, pid: str) -> list[CommandResult]: + results: list[CommandResult] = [] + pos = 0 + sentinel = _PIPE_SENTINEL.encode() + b" " + + while pos < len(stdout): + nl = stdout.find(b"\n", pos) + if nl == -1: + break + header = stdout[pos : nl + 1] + pos = nl + 1 + + if not header.startswith(sentinel): + raise ValueError("Malformed pipeline output: missing sentinel") + + parts = header.decode().strip().split(" ") + _, idx, ec, os_len, es_len = parts + out_len, err_len = int(os_len), int(es_len) + + out_bytes = stdout[pos : pos + out_len] + pos += out_len + err_bytes = stdout[pos : pos + err_len] + pos += err_len + + results.append( + CommandResult( + stdout=out_bytes, + stderr=err_bytes, + exit_code=int(ec), + pid=f"{pid}:{idx}", + ) + ) + + return results + + +def pipeline( + env: VirtualEnvironment, + connection: ConnectionHandle | None = None, + *, + cwd: str | None = None, + environments: Mapping[str, str] | None = None, +) -> CommandPipeline: + return CommandPipeline(env=env, connection=connection, cwd=cwd, environments=environments) diff --git a/api/core/virtual_environment/__base/virtual_environment.py b/api/core/virtual_environment/__base/virtual_environment.py new file mode 100644 index 0000000000..5332f1dd2e --- /dev/null +++ b/api/core/virtual_environment/__base/virtual_environment.py @@ -0,0 +1,231 @@ +from abc import ABC, abstractmethod +from collections.abc import Mapping, Sequence +from io import BytesIO +from typing import Any + +from core.entities.provider_entities import BasicProviderConfig +from core.virtual_environment.__base.entities import CommandStatus, ConnectionHandle, FileState, Metadata +from core.virtual_environment.channel.transport import TransportReadCloser, TransportWriteCloser + + +class VirtualEnvironment(ABC): + """ + Base class for virtual environment implementations. + + ``VirtualEnvironment`` instances are configured at construction time but do + not allocate provider resources until ``open_enviroment()`` is called. + This keeps object construction side-effect free and gives callers a chance + to own startup error handling explicitly. + """ + + tenant_id: str + user_id: str | None + options: Mapping[str, Any] + _environments: Mapping[str, str] + _metadata: Metadata | None + + def __init__( + self, + tenant_id: str, + options: Mapping[str, Any], + environments: Mapping[str, str] | None = None, + user_id: str | None = None, + ) -> None: + """ + Initialize the virtual environment configuration. + + Args: + tenant_id: The tenant ID associated with this environment (required). + options: Provider-specific configuration options. + environments: Environment variables to set in the virtual environment. + user_id: The user ID associated with this environment (optional). + + The provider runtime itself is created later by ``open_enviroment()``. + """ + + self.tenant_id = tenant_id + self.user_id = user_id + self.options = options + self._environments = dict(environments or {}) + self._metadata = None + + @property + def metadata(self) -> Metadata: + """Provider metadata for a started environment. + + Raises: + RuntimeError: If the environment has not been started yet. 
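+
+        A minimal sketch (illustrative; ``SomeEnvironment`` stands in for a
+        concrete provider subclass):
+
+            env = SomeEnvironment(tenant_id="t1", options={})
+            env.open_enviroment()     # allocate provider resources first
+            env_id = env.metadata.id  # safe only after a successful start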
+        """
+
+        if self._metadata is None:
+            raise RuntimeError("Virtual environment has not been started")
+        return self._metadata
+
+    def open_enviroment(self) -> Metadata:
+        """Allocate provider resources and return the resulting metadata.
+
+        Multiple calls are safe and return the existing metadata after the first
+        successful start.
+        """
+
+        if self._metadata is None:
+            self._metadata = self._construct_environment(self.options, self._environments)
+        return self._metadata
+
+    @abstractmethod
+    def _construct_environment(self, options: Mapping[str, Any], environments: Mapping[str, str]) -> Metadata:
+        """
+        Construct the provider runtime for the virtual environment.
+
+        Returns:
+            Metadata: The metadata of the started virtual environment.
+        """
+
+    @abstractmethod
+    def upload_file(self, path: str, content: BytesIO) -> None:
+        """
+        Upload a file to the virtual environment.
+
+        Args:
+            path (str): The destination path in the virtual environment.
+            content (BytesIO): The content of the file to upload.
+
+        Raises:
+            Exception: If the file cannot be uploaded.
+        """
+
+    @abstractmethod
+    def download_file(self, path: str) -> BytesIO:
+        """
+        Download a file from the virtual environment.
+
+        Args:
+            path (str): The source path in the virtual environment.
+        Returns:
+            BytesIO: The content of the downloaded file.
+        Raises:
+            Exception: If the file cannot be downloaded.
+        """
+
+    @abstractmethod
+    def list_files(self, directory_path: str, limit: int) -> Sequence[FileState]:
+        """
+        List files in a directory of the virtual environment.
+
+        Args:
+            directory_path (str): The directory path in the virtual environment.
+            limit (int): The maximum number of files (including recursive paths) to return.
+        Returns:
+            Sequence[FileState]: A list of file states in the specified directory.
+        Raises:
+            Exception: If the files cannot be listed.
+
+        Example:
+            If the directory structure is like:
+                /dir
+                    /subdir1
+                        file1.txt
+                    /subdir2
+                        file2.txt
+            And limit is 2, the returned list may look like:
+            [
+                FileState(path="/dir/subdir1/file1.txt", size=1234, created_at=..., updated_at=...),
+                FileState(path="/dir/subdir2/file2.txt", size=5678, created_at=..., updated_at=...),
+            ]
+        """
+
+    @abstractmethod
+    def establish_connection(self) -> ConnectionHandle:
+        """
+        Establish a connection to the virtual environment.
+
+        Returns:
+            ConnectionHandle: Handle for managing the connection to the virtual environment.
+
+        Raises:
+            Exception: If the connection cannot be established.
+        """
+
+    @abstractmethod
+    def release_connection(self, connection_handle: ConnectionHandle) -> None:
+        """
+        Release the connection to the virtual environment.
+
+        Args:
+            connection_handle (ConnectionHandle): The handle for managing the connection.
+
+        Raises:
+            Exception: If the connection cannot be released.
+        """
+
+    @abstractmethod
+    def release_environment(self) -> None:
+        """
+        Release the virtual environment.
+
+        Calling `release_environment` multiple times is acceptable.
+
+        Raises:
+            Exception: If the environment cannot be released.
+        """
+
+    def terminate_command(self, connection_handle: ConnectionHandle, pid: str) -> bool:
+        """Best-effort termination hook for a running command.
+
+        Providers that can map ``pid`` back to a real process/session should
+        override this method and stop the command. The default implementation is
+        a no-op so providers without a termination mechanism remain compatible.
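+
+        A provider override might look like this (sketch only;
+        ``self._sessions`` is a hypothetical pid-to-session map):
+
+            def terminate_command(self, connection_handle, pid):
+                session = self._sessions.get(pid)
+                if session is None:
+                    return False
+                session.kill()
+                return True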
+        """
+
+        _ = connection_handle
+        _ = pid
+        return False
+
+    @abstractmethod
+    def execute_command(
+        self,
+        connection_handle: ConnectionHandle,
+        command: list[str],
+        environments: Mapping[str, str] | None = None,
+        cwd: str | None = None,
+    ) -> tuple[str, TransportWriteCloser, TransportReadCloser, TransportReadCloser]:
+        """
+        Execute a command in the virtual environment.
+
+        Args:
+            connection_handle (ConnectionHandle): The handle for managing the connection.
+            command (list[str]): The command to execute as a list of strings.
+            environments (Mapping[str, str] | None): Environment variables for the command.
+            cwd (str | None): Working directory for the command. If None, uses the provider's default.
+
+        Returns:
+            tuple[str, TransportWriteCloser, TransportReadCloser, TransportReadCloser]:
+            a tuple containing the pid and three transport handles: (stdin, stdout, stderr).
+            After execution, the three handles will be closed by the caller.
+        """
+
+    @classmethod
+    @abstractmethod
+    def validate(cls, options: Mapping[str, Any]) -> None:
+        """
+        Validate that options can connect to the provider.
+
+        Raises:
+            SandboxConfigValidationError: If validation fails
+        """
+
+    @abstractmethod
+    def get_command_status(self, connection_handle: ConnectionHandle, pid: str) -> CommandStatus:
+        """
+        Get the status of a command executed in the virtual environment.
+
+        Args:
+            connection_handle (ConnectionHandle): The handle for managing the connection.
+            pid (str): The process ID of the command.
+        Returns:
+            CommandStatus: The status of the command execution.
+        """
+
+    @classmethod
+    @abstractmethod
+    def get_config_schema(cls) -> list[BasicProviderConfig]:
+        pass
diff --git a/api/core/virtual_environment/__init__.py b/api/core/virtual_environment/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/virtual_environment/channel/__init__.py b/api/core/virtual_environment/channel/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/virtual_environment/channel/exec.py b/api/core/virtual_environment/channel/exec.py
new file mode 100644
index 0000000000..6a03e2f766
--- /dev/null
+++ b/api/core/virtual_environment/channel/exec.py
@@ -0,0 +1,4 @@
+class TransportEOFError(Exception):
+    """Exception raised when attempting to read from a closed transport."""
+
+    pass
diff --git a/api/core/virtual_environment/channel/pipe_transport.py b/api/core/virtual_environment/channel/pipe_transport.py
new file mode 100644
index 0000000000..aecddeb6fc
--- /dev/null
+++ b/api/core/virtual_environment/channel/pipe_transport.py
@@ -0,0 +1,72 @@
+import os
+
+from core.virtual_environment.channel.exec import TransportEOFError
+from core.virtual_environment.channel.transport import Transport, TransportReadCloser, TransportWriteCloser
+
+
+class PipeTransport(Transport):
+    """
+    A Transport implementation using OS pipes. It requires two file descriptors:
+    one for reading and one for writing.
+
+    NOTE: r_fd and w_fd must be a pair created by os.pipe(), or descriptors returned from subprocess.Popen.
+
+    NEVER FORGET TO CALL THE `close()` METHOD TO AVOID FILE DESCRIPTOR LEAKAGE.
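+
+    A minimal usage sketch (illustrative; the two descriptors form one pipe
+    pair, so bytes written come back on read):
+
+        r_fd, w_fd = os.pipe()
+        transport = PipeTransport(r_fd, w_fd)
+        transport.write(b"ping")
+        assert transport.read(4) == b"ping"
+        transport.close()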
+    """
+
+    def __init__(self, r_fd: int, w_fd: int):
+        self.r_fd = r_fd
+        self.w_fd = w_fd
+
+    def write(self, data: bytes) -> None:
+        try:
+            os.write(self.w_fd, data)
+        except OSError:
+            raise TransportEOFError("Pipe write error, maybe the read end is closed")
+
+    def read(self, n: int) -> bytes:
+        data = os.read(self.r_fd, n)
+        if data == b"":
+            raise TransportEOFError("End of Pipe reached")
+        return data
+
+    def close(self) -> None:
+        os.close(self.r_fd)
+        os.close(self.w_fd)
+
+
+class PipeReadCloser(TransportReadCloser):
+    """
+    A Transport implementation using an OS pipe for reading.
+    """
+
+    def __init__(self, r_fd: int):
+        self.r_fd = r_fd
+
+    def read(self, n: int) -> bytes:
+        data = os.read(self.r_fd, n)
+        if data == b"":
+            raise TransportEOFError("End of Pipe reached")
+
+        return data
+
+    def close(self) -> None:
+        os.close(self.r_fd)
+
+
+class PipeWriteCloser(TransportWriteCloser):
+    """
+    A Transport implementation using an OS pipe for writing.
+    """
+
+    def __init__(self, w_fd: int):
+        self.w_fd = w_fd
+
+    def write(self, data: bytes) -> None:
+        try:
+            os.write(self.w_fd, data)
+        except OSError:
+            raise TransportEOFError("Pipe write error, maybe the read end is closed")
+
+    def close(self) -> None:
+        os.close(self.w_fd)
diff --git a/api/core/virtual_environment/channel/queue_transport.py b/api/core/virtual_environment/channel/queue_transport.py
new file mode 100644
index 0000000000..7cf524316a
--- /dev/null
+++ b/api/core/virtual_environment/channel/queue_transport.py
@@ -0,0 +1,117 @@
+from core.virtual_environment.channel.exec import TransportEOFError
+from core.virtual_environment.channel.transport import TransportReadCloser
+
+
+class QueueTransportReadCloser(TransportReadCloser):
+    """
+    Transport implementation using queues for inter-thread communication.
+
+    Usage:
+        q_transport = QueueTransportReadCloser()
+        write_handler = q_transport.get_write_handler()
+
+        # In writer thread
+        write_handler.write(b"data")
+
+        # In reader thread
+        data = q_transport.read(1024)
+
+        # Close transport when done
+        q_transport.close()
+    """
+
+    _QUEUE_GET_TIMEOUT = 5.0
+
+    class WriteHandler:
+        """
+        A write handler that writes data to a queue.
+        """
+
+        from queue import Queue
+
+        def __init__(self, queue: Queue[bytes | None]) -> None:
+            self.queue = queue
+
+        def write(self, data: bytes) -> None:
+            self.queue.put(data)
+
+    def __init__(self) -> None:
+        """
+        Initialize the internal queue and read buffer.
+        """
+        from queue import Queue
+
+        self.q = Queue[bytes | None]()
+        self._read_buffer = bytearray()
+        self._closed = False
+        self._write_channel_closed = False
+
+    def get_write_handler(self) -> WriteHandler:
+        """
+        Get a write handler that writes to the internal queue.
+        """
+        return QueueTransportReadCloser.WriteHandler(self.q)
+
+    def close(self) -> None:
+        """
+        Close the transport by putting a sentinel value in the queue.
+        """
+        if self._write_channel_closed:
+            raise TransportEOFError("Write channel already closed")
+
+        self._write_channel_closed = True
+        self.q.put(None)
+
+    def read(self, n: int) -> bytes:
+        """
+        Read up to n bytes from the queue.
+
+        NEVER USE IT IN A MULTI-THREADED CONTEXT WITHOUT PROPER SYNCHRONIZATION.
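+
+        Illustrative sketch of the EOF contract (buffered bytes are drained
+        before the closed state raises):
+
+            writer = q_transport.get_write_handler()
+            writer.write(b"abc")
+            q_transport.close()
+            q_transport.read(3)  # returns b"abc"
+            q_transport.read(1)  # raises TransportEOFError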
+ """ + from queue import Empty + + if n <= 0: + return b"" + + if self._closed: + raise TransportEOFError("Transport is closed") + + to_return = self._drain_buffer(n) + + # At the first round reading from queue, hanging is required to wait for the data + # But after that, return immediately if no data is available + round = 0 + + while len(to_return) < n and not self._closed and (self.q.qsize() > 0 or round == 0): + try: + chunk = self.q.get(timeout=self._QUEUE_GET_TIMEOUT) + except Empty: + if self._closed: + raise TransportEOFError("Transport is closed") + continue + if chunk is None: + self._closed = True + if len(to_return) == 0: + raise TransportEOFError("Transport is closed") + else: + break + + self._read_buffer.extend(chunk) + + if n - len(to_return) > 0: + # Drain the buffer if we still need more data + to_return += self._drain_buffer(n - len(to_return)) + else: + # No more data needed, break + break + + round += 1 + + return to_return + + def _drain_buffer(self, n: int) -> bytes: + data = bytes(self._read_buffer[:n]) + del self._read_buffer[:n] + return data diff --git a/api/core/virtual_environment/channel/socket_transport.py b/api/core/virtual_environment/channel/socket_transport.py new file mode 100644 index 0000000000..904e42df37 --- /dev/null +++ b/api/core/virtual_environment/channel/socket_transport.py @@ -0,0 +1,70 @@ +import socket + +from core.virtual_environment.channel.exec import TransportEOFError +from core.virtual_environment.channel.transport import Transport, TransportReadCloser, TransportWriteCloser + + +class SocketTransport(Transport): + """ + A Transport implementation using a socket. + """ + + def __init__(self, sock: socket.SocketIO): + self.sock = sock + + def write(self, data: bytes) -> None: + try: + self.sock.write(data) + except (ConnectionResetError, BrokenPipeError): + raise TransportEOFError("Socket write error, maybe the read end is closed") + + def read(self, n: int) -> bytes: + try: + data = self.sock.read(n) + if data == b"": + raise TransportEOFError("End of Socket reached") + except (ConnectionResetError, BrokenPipeError): + raise TransportEOFError("Socket connection reset") + return data + + def close(self) -> None: + self.sock.close() + + +class SocketReadCloser(TransportReadCloser): + """ + A Transport implementation using a socket for reading. + """ + + def __init__(self, sock: socket.SocketIO): + self.sock = sock + + def read(self, n: int) -> bytes: + try: + data = self.sock.read(n) + if data == b"": + raise TransportEOFError("End of Socket reached") + return data + except (ConnectionResetError, BrokenPipeError): + raise TransportEOFError("Socket connection reset") + + def close(self) -> None: + self.sock.close() + + +class SocketWriteCloser(TransportWriteCloser): + """ + A Transport implementation using a socket for writing. + """ + + def __init__(self, sock: socket.SocketIO): + self.sock = sock + + def write(self, data: bytes) -> None: + try: + self.sock.write(data) + except (ConnectionResetError, BrokenPipeError): + raise TransportEOFError("Socket write error, maybe the read end is closed") + + def close(self) -> None: + self.sock.close() diff --git a/api/core/virtual_environment/channel/transport.py b/api/core/virtual_environment/channel/transport.py new file mode 100644 index 0000000000..130538ab63 --- /dev/null +++ b/api/core/virtual_environment/channel/transport.py @@ -0,0 +1,80 @@ +from abc import abstractmethod +from typing import Protocol + + +class TransportCloser(Protocol): + """ + Transport that can be closed. 
+ """ + + @abstractmethod + def close(self) -> None: + """ + Close the transport. + """ + + +class TransportWriter(Protocol): + """ + Transport that can be written to. + """ + + @abstractmethod + def write(self, data: bytes) -> None: + """ + Write data to the transport. + + Raises TransportEOFError if the transport is closed. + """ + + +class TransportReader(Protocol): + """ + Transport that can be read from. + """ + + @abstractmethod + def read(self, n: int) -> bytes: + """ + Read up to n bytes from the transport. + + Raises TransportEOFError if the end of the transport is reached. + """ + + +class TransportReadCloser(TransportReader, TransportCloser): + """ + Transport that can be read from and closed. + """ + + +class TransportWriteCloser(TransportWriter, TransportCloser): + """ + Transport that can be written to and closed. + """ + + +class Transport(TransportReader, TransportWriter, TransportCloser): + """ + Transport that can be read from, written to, and closed. + """ + + +class NopTransportWriteCloser(TransportWriteCloser): + """ + A no-operation TransportWriteCloser implementation. + + This transport does nothing on write and close operations. + """ + + def write(self, data: bytes) -> None: + """ + No-operation write method. + """ + pass + + def close(self) -> None: + """ + No-operation close method. + """ + pass diff --git a/api/core/virtual_environment/constants.py b/api/core/virtual_environment/constants.py new file mode 100644 index 0000000000..11662bb817 --- /dev/null +++ b/api/core/virtual_environment/constants.py @@ -0,0 +1,10 @@ +""" +Constants for virtual environment providers. + +Centralizes timeout and other configuration values used across different sandbox providers +(E2B, SSH, Docker) to ensure consistency and ease of maintenance. +""" + +# Command execution timeout in seconds (5 hours) +# Used by providers to limit how long a single command can run +COMMAND_EXECUTION_TIMEOUT_SECONDS = 5 * 60 * 60 # 18000 seconds diff --git a/api/core/virtual_environment/providers/__init__.py b/api/core/virtual_environment/providers/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/virtual_environment/providers/aws_code_interpreter_sandbox.py b/api/core/virtual_environment/providers/aws_code_interpreter_sandbox.py new file mode 100644 index 0000000000..4ead65c126 --- /dev/null +++ b/api/core/virtual_environment/providers/aws_code_interpreter_sandbox.py @@ -0,0 +1,531 @@ +""" +AWS Bedrock AgentCore Code Interpreter sandbox provider. + +Uses the AgentCore Code Interpreter built-in tool to provide a sandboxed code execution +environment with shell access and file operations. The Code Interpreter runs in an +isolated microVM managed by AWS and communicates via the `InvokeCodeInterpreter` API. + +Two boto3 clients are involved: +- **bedrock-agentcore-control** (Control Plane): manages the Code Interpreter resource + (create / delete). Users must create one beforehand or use the system-provided + ``aws.codeinterpreter.v1``. +- **bedrock-agentcore** (Data Plane): manages sessions and executes operations + (start/stop session, execute commands, file I/O). + +Key differences from other providers: +- stdin is not supported (same as E2B) — uses ``NopTransportWriteCloser``. +- ``executeCommand`` returns the full stdout/stderr once the command completes, + rather than streaming incrementally. We wrap the result with + ``QueueTransportReadCloser`` so the upper-layer ``CommandFuture`` works unchanged. 
+- ``get_command_status`` raises ``NotSupportedOperationError`` (same as E2B). + The synchronous ``executeCommand`` path is used instead. +""" + +import logging +import posixpath +import shlex +import threading +from collections.abc import Mapping, Sequence +from enum import StrEnum +from io import BytesIO +from typing import Any +from uuid import uuid4 + +from core.entities.provider_entities import BasicProviderConfig +from core.virtual_environment.__base.entities import ( + Arch, + CommandStatus, + ConnectionHandle, + FileState, + Metadata, + OperatingSystem, +) +from core.virtual_environment.__base.exec import ( + ArchNotSupportedError, + NotSupportedOperationError, + SandboxConfigValidationError, +) +from core.virtual_environment.__base.virtual_environment import VirtualEnvironment +from core.virtual_environment.channel.queue_transport import QueueTransportReadCloser +from core.virtual_environment.channel.transport import ( + NopTransportWriteCloser, + TransportReadCloser, + TransportWriteCloser, +) +from core.virtual_environment.constants import COMMAND_EXECUTION_TIMEOUT_SECONDS + +logger = logging.getLogger(__name__) + +# Maximum time in seconds that the boto3 read on the EventStream socket is +# allowed to block. Must exceed the longest expected command execution. +_BOTO3_READ_TIMEOUT_SECONDS = COMMAND_EXECUTION_TIMEOUT_SECONDS + 60 + + +class AWSCodeInterpreterEnvironment(VirtualEnvironment): + """ + AWS Bedrock AgentCore Code Interpreter virtual environment provider. + + The provider maps the ``VirtualEnvironment`` protocol onto the AgentCore + Code Interpreter Data-Plane API (``InvokeCodeInterpreter``). + + Lifecycle: + 1. ``_construct_environment`` starts a new session via + ``StartCodeInterpreterSession``. + 2. Commands and file operations invoke ``InvokeCodeInterpreter`` with + the appropriate ``name`` (``executeCommand``, ``writeFiles``, etc.). + 3. ``release_environment`` stops the session via + ``StopCodeInterpreterSession``. + + Configuration (``OptionsKey``): + - ``aws_access_key_id`` / ``aws_secret_access_key``: IAM credentials. + - ``aws_region``: AWS region (e.g. ``us-east-1``). + - ``code_interpreter_id``: the Code Interpreter resource identifier + (e.g. ``aws.codeinterpreter.v1`` for the system-provided one). + - ``session_timeout_seconds``: optional; defaults to 900 (15 min). 
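+
+    Example ``options`` mapping (illustrative values only):
+
+        {
+            "aws_access_key_id": "AKIA...",
+            "aws_secret_access_key": "<secret>",
+            "aws_region": "us-east-1",
+            "code_interpreter_id": "aws.codeinterpreter.v1",
+            "session_timeout_seconds": 900,
+        }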
+ """ + + _WORKDIR = "/home/user" + + class OptionsKey(StrEnum): + AWS_ACCESS_KEY_ID = "aws_access_key_id" + AWS_SECRET_ACCESS_KEY = "aws_secret_access_key" + AWS_REGION = "aws_region" + CODE_INTERPRETER_ID = "code_interpreter_id" + SESSION_TIMEOUT_SECONDS = "session_timeout_seconds" + + class StoreKey(StrEnum): + CLIENT = "client" + SESSION_ID = "session_id" + CODE_INTERPRETER_ID = "code_interpreter_id" + + # ------------------------------------------------------------------ + # Config schema & validation + # ------------------------------------------------------------------ + + @classmethod + def get_config_schema(cls) -> list[BasicProviderConfig]: + return [ + BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=cls.OptionsKey.AWS_ACCESS_KEY_ID), + BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=cls.OptionsKey.AWS_SECRET_ACCESS_KEY), + BasicProviderConfig(type=BasicProviderConfig.Type.TEXT_INPUT, name=cls.OptionsKey.AWS_REGION), + BasicProviderConfig(type=BasicProviderConfig.Type.TEXT_INPUT, name=cls.OptionsKey.CODE_INTERPRETER_ID), + ] + + @classmethod + def validate(cls, options: Mapping[str, Any]) -> None: + """Validate credentials by starting then immediately stopping a session.""" + import boto3 + from botocore.exceptions import ClientError + + for key in (cls.OptionsKey.AWS_ACCESS_KEY_ID, cls.OptionsKey.AWS_SECRET_ACCESS_KEY, cls.OptionsKey.AWS_REGION): + if not options.get(key): + raise SandboxConfigValidationError(f"{key} is required") + + code_interpreter_id = options.get(cls.OptionsKey.CODE_INTERPRETER_ID, "") + if not code_interpreter_id: + raise SandboxConfigValidationError("code_interpreter_id is required") + + client = boto3.client( + "bedrock-agentcore", + region_name=options[cls.OptionsKey.AWS_REGION], + aws_access_key_id=options[cls.OptionsKey.AWS_ACCESS_KEY_ID], + aws_secret_access_key=options[cls.OptionsKey.AWS_SECRET_ACCESS_KEY], + ) + + try: + resp = client.start_code_interpreter_session( + codeInterpreterIdentifier=code_interpreter_id, + sessionTimeoutSeconds=60, + ) + session_id = resp["sessionId"] + # Immediately stop the validation session. 
+ client.stop_code_interpreter_session( + codeInterpreterIdentifier=code_interpreter_id, + sessionId=session_id, + ) + except ClientError as exc: + raise SandboxConfigValidationError(f"AWS AgentCore Code Interpreter validation failed: {exc}") from exc + except Exception as exc: + raise SandboxConfigValidationError(f"AWS AgentCore Code Interpreter connection failed: {exc}") from exc + + # ------------------------------------------------------------------ + # Environment lifecycle + # ------------------------------------------------------------------ + + def _construct_environment(self, options: Mapping[str, Any], environments: Mapping[str, str]) -> Metadata: + """Start a new Code Interpreter session and detect platform info.""" + import boto3 + from botocore.config import Config + + code_interpreter_id: str = options.get(self.OptionsKey.CODE_INTERPRETER_ID, "") + timeout_seconds: int = int(options.get(self.OptionsKey.SESSION_TIMEOUT_SECONDS, 900)) + + client = boto3.client( + "bedrock-agentcore", + region_name=options[self.OptionsKey.AWS_REGION], + aws_access_key_id=options[self.OptionsKey.AWS_ACCESS_KEY_ID], + aws_secret_access_key=options[self.OptionsKey.AWS_SECRET_ACCESS_KEY], + config=Config(read_timeout=_BOTO3_READ_TIMEOUT_SECONDS), + ) + + resp = client.start_code_interpreter_session( + codeInterpreterIdentifier=code_interpreter_id, + sessionTimeoutSeconds=timeout_seconds, + ) + session_id: str = resp["sessionId"] + + logger.info( + "AgentCore Code Interpreter session started: code_interpreter_id=%s, session_id=%s", + code_interpreter_id, + session_id, + ) + + # Detect architecture and OS via a quick command. + arch = Arch.AMD64 + operating_system = OperatingSystem.LINUX + try: + result = self._invoke(client, code_interpreter_id, session_id, "executeCommand", {"command": "uname -m -s"}) + system_info = (result.get("stdout") or "").strip() + parts = system_info.split() + if len(parts) >= 2: + operating_system = self._convert_operating_system(parts[0]) + arch = self._convert_architecture(parts[1]) + elif len(parts) == 1: + arch = self._convert_architecture(parts[0]) + except Exception: + logger.warning("Failed to detect platform info, defaulting to Linux/AMD64") + + # Inject environment variables if provided. 
+ if environments: + export_parts = [f"export {k}={shlex.quote(v)}" for k, v in environments.items()] + export_cmd = " && ".join(export_parts) + try: + self._invoke(client, code_interpreter_id, session_id, "executeCommand", {"command": export_cmd}) + except Exception: + logger.warning("Failed to inject environment variables into AgentCore session") + + return Metadata( + id=session_id, + arch=arch, + os=operating_system, + store={ + self.StoreKey.CLIENT: client, + self.StoreKey.SESSION_ID: session_id, + self.StoreKey.CODE_INTERPRETER_ID: code_interpreter_id, + }, + ) + + def release_environment(self) -> None: + """Stop the Code Interpreter session and release resources.""" + client = self._client + try: + client.stop_code_interpreter_session( + codeInterpreterIdentifier=self._code_interpreter_id, + sessionId=self._session_id, + ) + logger.info("AgentCore Code Interpreter session stopped: session_id=%s", self._session_id) + except Exception: + logger.warning( + "Failed to stop AgentCore Code Interpreter session: session_id=%s", + self._session_id, + exc_info=True, + ) + + # ------------------------------------------------------------------ + # Connection (virtual — no real connection needed) + # ------------------------------------------------------------------ + + def establish_connection(self) -> ConnectionHandle: + return ConnectionHandle(id=uuid4().hex) + + def release_connection(self, connection_handle: ConnectionHandle) -> None: + pass + + # ------------------------------------------------------------------ + # Command execution + # ------------------------------------------------------------------ + + def execute_command( + self, + connection_handle: ConnectionHandle, + command: list[str], + environments: Mapping[str, str] | None = None, + cwd: str | None = None, + ) -> tuple[str, TransportWriteCloser, TransportReadCloser, TransportReadCloser]: + """ + Execute a shell command via AgentCore Code Interpreter. + + The command is executed synchronously on the AWS side; a background thread + calls ``executeCommand`` and feeds the result into queue-based transports + so the caller sees the standard Transport interface. + + stdin is not supported — ``NopTransportWriteCloser`` is returned. + """ + stdout_stream = QueueTransportReadCloser() + stderr_stream = QueueTransportReadCloser() + + working_dir = self._workspace_path(cwd) if cwd else self._WORKDIR + cmd_str = shlex.join(command) + + # Wrap env vars and cwd into the command string since the API only + # accepts a flat ``command`` string argument. + prefix_parts: list[str] = [] + if environments: + for k, v in environments.items(): + prefix_parts.append(f"export {k}={shlex.quote(v)}") + prefix_parts.append(f"cd {shlex.quote(working_dir)}") + full_cmd = " && ".join([*prefix_parts, cmd_str]) + + threading.Thread( + target=self._cmd_thread, + args=(full_cmd, stdout_stream, stderr_stream), + daemon=True, + ).start() + + return ( + "N/A", + NopTransportWriteCloser(), + stdout_stream, + stderr_stream, + ) + + def get_command_status(self, connection_handle: ConnectionHandle, pid: str) -> CommandStatus: + """Not supported — same as E2B. 
``CommandFuture`` handles this gracefully."""
+        raise NotSupportedOperationError("AgentCore Code Interpreter does not support getting command status.")
+
+    # ------------------------------------------------------------------
+    # File operations
+    # ------------------------------------------------------------------
+
+    def upload_file(self, path: str, content: BytesIO) -> None:
+        """Upload a file to the Code Interpreter session."""
+        remote_path = self._workspace_path(path)
+        file_bytes = content.read()
+
+        self._invoke(
+            self._client,
+            self._code_interpreter_id,
+            self._session_id,
+            "writeFiles",
+            {"content": [{"path": remote_path, "blob": file_bytes}]},
+        )
+
+    def download_file(self, path: str) -> BytesIO:
+        """Download a file from the Code Interpreter session."""
+        remote_path = self._workspace_path(path)
+
+        result = self._invoke(
+            self._client,
+            self._code_interpreter_id,
+            self._session_id,
+            "readFiles",
+            {"path": remote_path},
+        )
+
+        # The response content blocks may contain blob or text data.
+        content_blocks: list[dict[str, Any]] = result.get("content", [])
+        for block in content_blocks:
+            resource = block.get("resource")
+            if resource:
+                blob = resource.get("blob")
+                if blob:
+                    return BytesIO(blob if isinstance(blob, bytes) else blob.encode())
+                text = resource.get("text")
+                if text:
+                    return BytesIO(text.encode("utf-8"))
+            # Fallback: check top-level data/text fields.
+            if block.get("data"):
+                data = block["data"]
+                return BytesIO(data if isinstance(data, bytes) else data.encode())
+            if block.get("text"):
+                return BytesIO(block["text"].encode("utf-8"))
+
+        return BytesIO(b"")
+
+    def list_files(self, directory_path: str, limit: int) -> Sequence[FileState]:
+        """List files in a directory of the Code Interpreter session."""
+        remote_dir = self._workspace_path(directory_path)
+
+        result = self._invoke(
+            self._client,
+            self._code_interpreter_id,
+            self._session_id,
+            "listFiles",
+            {"directoryPath": remote_dir},
+        )
+
+        # The API returns file information in content blocks. The exact
+        # structure may vary, so each block is parsed defensively and
+        # blocks without a usable path are skipped.
+        content_blocks: list[dict[str, Any]] = result.get("content", [])
+        files: list[FileState] = []
+
+        for block in content_blocks:
+            text = block.get("text", "")
+            name = block.get("name", "")
+            size = block.get("size", 0)
+            uri = block.get("uri", "")
+
+            file_path = uri or name or text
+            if not file_path:
+                continue
+
+            # Normalise to relative path from workdir.
+ if file_path.startswith(self._WORKDIR): + file_path = posixpath.relpath(file_path, self._WORKDIR) + + files.append( + FileState( + path=file_path, + size=size or 0, + created_at=0, + updated_at=0, + ) + ) + + if len(files) >= limit: + break + + return files + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + @property + def _client(self) -> Any: + return self.metadata.store[self.StoreKey.CLIENT] + + @property + def _session_id(self) -> str: + return str(self.metadata.store[self.StoreKey.SESSION_ID]) + + @property + def _code_interpreter_id(self) -> str: + return str(self.metadata.store[self.StoreKey.CODE_INTERPRETER_ID]) + + def _workspace_path(self, path: str) -> str: + """Convert a path to an absolute path in the Code Interpreter session.""" + normalized = posixpath.normpath(path) + if normalized in ("", "."): + return self._WORKDIR + if normalized.startswith("/"): + return normalized + return posixpath.join(self._WORKDIR, normalized) + + def _cmd_thread( + self, + command: str, + stdout_stream: QueueTransportReadCloser, + stderr_stream: QueueTransportReadCloser, + ) -> None: + """Background thread that executes a command and feeds output into queue transports.""" + stdout_writer = stdout_stream.get_write_handler() + stderr_writer = stderr_stream.get_write_handler() + + try: + result = self._invoke( + self._client, + self._code_interpreter_id, + self._session_id, + "executeCommand", + {"command": command}, + ) + stdout_data = result.get("stdout", "") + stderr_data = result.get("stderr", "") + + if stdout_data: + stdout_writer.write(stdout_data.encode("utf-8") if isinstance(stdout_data, str) else stdout_data) + if stderr_data: + stderr_writer.write(stderr_data.encode("utf-8") if isinstance(stderr_data, str) else stderr_data) + except Exception as exc: + error_msg = f"Command execution failed: {type(exc).__name__}: {exc}\n" + stderr_writer.write(error_msg.encode()) + finally: + stdout_stream.close() + stderr_stream.close() + + @staticmethod + def _invoke( + client: Any, + code_interpreter_id: str, + session_id: str, + name: str, + arguments: dict[str, Any], + ) -> dict[str, Any]: + """ + Call ``InvokeCodeInterpreter`` and extract the structured result. + + The API returns an EventStream; this helper iterates over it and merges + ``structuredContent`` and ``content`` from all ``result`` events into a + single dict. + + Returns a dict that may contain: + - ``stdout``, ``stderr``, ``exitCode``, ``taskId``, ``taskStatus``, + ``executionTime`` (from ``structuredContent``) + - ``content`` (list of content blocks) + - ``isError`` (bool) + """ + response = client.invoke_code_interpreter( + codeInterpreterIdentifier=code_interpreter_id, + sessionId=session_id, + name=name, + arguments=arguments, + ) + + merged: dict[str, Any] = {"content": []} + + stream = response.get("stream") + if stream is None: + return merged + + for event in stream: + result = event.get("result") + if result is None: + # Check for exception events. 
+ for exc_key in ( + "accessDeniedException", + "validationException", + "resourceNotFoundException", + "throttlingException", + "internalServerException", + ): + if event.get(exc_key): + raise RuntimeError(f"AgentCore error ({exc_key}): {event[exc_key]}") + continue + + if result.get("isError"): + merged["isError"] = True + + structured = result.get("structuredContent") + if structured: + merged.update(structured) + + content = result.get("content") + if content: + merged["content"].extend(content) + + return merged + + @staticmethod + def _convert_architecture(arch_str: str) -> Arch: + arch_map: dict[str, Arch] = { + "x86_64": Arch.AMD64, + "aarch64": Arch.ARM64, + "armv7l": Arch.ARM64, + "arm64": Arch.ARM64, + "amd64": Arch.AMD64, + } + if arch_str in arch_map: + return arch_map[arch_str] + raise ArchNotSupportedError(f"Unsupported architecture: {arch_str}") + + @staticmethod + def _convert_operating_system(os_str: str) -> OperatingSystem: + os_map: dict[str, OperatingSystem] = { + "Linux": OperatingSystem.LINUX, + "Darwin": OperatingSystem.DARWIN, + } + if os_str in os_map: + return os_map[os_str] + raise ArchNotSupportedError(f"Unsupported operating system: {os_str}") diff --git a/api/core/virtual_environment/providers/daytona_sandbox.py b/api/core/virtual_environment/providers/daytona_sandbox.py new file mode 100644 index 0000000000..f9d30d1574 --- /dev/null +++ b/api/core/virtual_environment/providers/daytona_sandbox.py @@ -0,0 +1,281 @@ +import logging +import posixpath +import shlex +import threading +import time +from collections.abc import Mapping, Sequence +from datetime import datetime +from enum import StrEnum +from io import BytesIO +from typing import Any, TypedDict, cast +from uuid import uuid4 + +from daytona import ( + CodeLanguage, + CreateSandboxFromImageParams, + CreateSandboxFromSnapshotParams, + Daytona, + DaytonaConfig, + Sandbox, +) + +logger = logging.getLogger(__name__) + + +class _CommandRecord(TypedDict): + """Record for tracking command execution state.""" + + thread: threading.Thread + exit_code: int | None + + +from core.virtual_environment.__base.entities import ( + Arch, + CommandStatus, + ConnectionHandle, + FileState, + Metadata, + OperatingSystem, +) +from core.virtual_environment.__base.virtual_environment import VirtualEnvironment +from core.virtual_environment.channel.queue_transport import QueueTransportReadCloser +from core.virtual_environment.channel.transport import ( + NopTransportWriteCloser, + TransportReadCloser, + TransportWriteCloser, +) + + +class DaytonaEnvironment(VirtualEnvironment): + """ + Daytona virtual environment provider backed by Daytona Sandboxes. 
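+
+    Example ``options`` mapping (illustrative; ``image`` and ``snapshot`` are
+    alternatives, with ``image`` taking precedence when both are set):
+
+        {
+            "api_key": "<daytona-api-key>",
+            "snapshot": "my-snapshot",
+            "language": "python",
+            "auto_stop_interval": 30,
+        }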
+ """ + + _DEFAULT_DAYTONA_API_URL = "https://app.daytona.io/api" + + class OptionsKey(StrEnum): + API_KEY = "api_key" + API_URL = "api_url" + TARGET = "target" + LANGUAGE = "language" + SNAPSHOT = "snapshot" + IMAGE = "image" + AUTO_STOP_INTERVAL = "auto_stop_interval" + AUTO_ARCHIVE_INTERVAL = "auto_archive_interval" + AUTO_DELETE_INTERVAL = "auto_delete_interval" + PUBLIC = "public" + NAME = "name" + LABELS = "labels" + + class StoreKey(StrEnum): + DAYTONA = "daytona" + SANDBOX = "sandbox" + WORKDIR = "workdir" + COMMANDS = "commands" + COMMANDS_LOCK = "commands_lock" + + def _construct_environment(self, options: Mapping[str, Any], environments: Mapping[str, str]) -> Metadata: + config = DaytonaConfig( + api_key=cast(str | None, options.get(self.OptionsKey.API_KEY)), + api_url=cast(str | None, options.get(self.OptionsKey.API_URL, self._DEFAULT_DAYTONA_API_URL)), + target=cast(str | None, options.get(self.OptionsKey.TARGET)), + ) + daytona = Daytona(config) + + language = cast(CodeLanguage, options.get(self.OptionsKey.LANGUAGE, CodeLanguage.PYTHON)) + auto_stop_interval = cast(int | None, options.get(self.OptionsKey.AUTO_STOP_INTERVAL)) + auto_archive_interval = cast(int | None, options.get(self.OptionsKey.AUTO_ARCHIVE_INTERVAL)) + auto_delete_interval = cast(int | None, options.get(self.OptionsKey.AUTO_DELETE_INTERVAL)) + public = cast(bool | None, options.get(self.OptionsKey.PUBLIC)) + name = cast(str | None, options.get(self.OptionsKey.NAME)) + labels = cast(dict[str, str] | None, options.get(self.OptionsKey.LABELS)) + + image = cast(str | None, options.get(self.OptionsKey.IMAGE)) + snapshot = cast(str | None, options.get(self.OptionsKey.SNAPSHOT)) + + if image is not None: + params = CreateSandboxFromImageParams( + image=image, + language=language, + env_vars=dict(environments or {}), + auto_stop_interval=auto_stop_interval, + auto_archive_interval=auto_archive_interval, + auto_delete_interval=auto_delete_interval, + public=public, + name=name, + labels=labels, + ) + else: + params = CreateSandboxFromSnapshotParams( + snapshot=snapshot, + language=language, + env_vars=dict(environments or {}), + auto_stop_interval=auto_stop_interval, + auto_archive_interval=auto_archive_interval, + auto_delete_interval=auto_delete_interval, + public=public, + name=name, + labels=labels, + ) + + sandbox = daytona.create(params=params) + workdir = sandbox.get_work_dir() + + return Metadata( + id=sandbox.id, + arch=Arch.AMD64, + os=OperatingSystem.LINUX, + store={ + self.StoreKey.DAYTONA: daytona, + self.StoreKey.SANDBOX: sandbox, + self.StoreKey.WORKDIR: workdir, + self.StoreKey.COMMANDS: {}, + self.StoreKey.COMMANDS_LOCK: threading.Lock(), + }, + ) + + def release_environment(self) -> None: + daytona: Daytona = self.metadata.store[self.StoreKey.DAYTONA] + sandbox = self.metadata.store[self.StoreKey.SANDBOX] + try: + daytona.delete(sandbox) + except Exception: + logger.exception("Failed to delete Daytona sandbox %s during cleanup", sandbox.id) + + def establish_connection(self) -> ConnectionHandle: + return ConnectionHandle(id=uuid4().hex) + + def release_connection(self, connection_handle: ConnectionHandle) -> None: + pass + + def upload_file(self, path: str, content: BytesIO) -> None: + remote_path = self._workspace_path(path) + sandbox: Sandbox = self.metadata.store[self.StoreKey.SANDBOX] + sandbox.fs.upload_file(content.getvalue(), remote_path) + + def download_file(self, path: str) -> BytesIO: + remote_path = self._workspace_path(path) + sandbox: Sandbox = 
self.metadata.store[self.StoreKey.SANDBOX]
+        data = sandbox.fs.download_file(remote_path)
+        return BytesIO(data)
+
+    def list_files(self, directory_path: str, limit: int) -> Sequence[FileState]:
+        remote_dir = self._workspace_path(directory_path)
+        sandbox: Sandbox = self.metadata.store[self.StoreKey.SANDBOX]
+        try:
+            file_infos = sandbox.fs.list_files(remote_dir)
+        except Exception:
+            logger.exception("Failed to list files in directory %s", remote_dir)
+            return []
+
+        files: list[FileState] = []
+        for info in file_infos:
+            full_path = posixpath.join(remote_dir, info.name)
+            relative_path = posixpath.relpath(full_path, self._working_dir)
+            files.append(
+                FileState(
+                    path=relative_path,
+                    size=info.size,
+                    created_at=self._parse_mod_time(info.mod_time),
+                    updated_at=self._parse_mod_time(info.mod_time),
+                )
+            )
+            if len(files) >= limit:
+                break
+        return files
+
+    def execute_command(
+        self,
+        connection_handle: ConnectionHandle,
+        command: list[str],
+        environments: Mapping[str, str] | None = None,
+        cwd: str | None = None,
+    ) -> tuple[str, TransportWriteCloser, TransportReadCloser, TransportReadCloser]:
+        sandbox: Sandbox = self.metadata.store[self.StoreKey.SANDBOX]
+
+        stdout_stream = QueueTransportReadCloser()
+        stderr_stream = QueueTransportReadCloser()
+        pid = uuid4().hex
+
+        working_dir = self._workspace_path(cwd) if cwd else self._working_dir
+
+        thread = threading.Thread(
+            target=self._exec_thread,
+            args=(pid, sandbox, command, environments or {}, working_dir, stdout_stream, stderr_stream),
+            daemon=True,
+        )
+
+        # Register the command record before starting the thread so that
+        # get_command_status can find it while the command is running and
+        # _exec_thread can record the exit code on completion.
+        commands: dict[str, _CommandRecord] = self.metadata.store[self.StoreKey.COMMANDS]
+        commands_lock: threading.Lock = self.metadata.store[self.StoreKey.COMMANDS_LOCK]
+        with commands_lock:
+            commands[pid] = {"thread": thread, "exit_code": None}
+
+        thread.start()
+
+        return pid, NopTransportWriteCloser(), stdout_stream, stderr_stream
+
+    def get_command_status(self, connection_handle: ConnectionHandle, pid: str) -> CommandStatus:
+        commands: dict[str, _CommandRecord] = self.metadata.store[self.StoreKey.COMMANDS]
+        commands_lock: threading.Lock = self.metadata.store[self.StoreKey.COMMANDS_LOCK]
+
+        with commands_lock:
+            record = commands.get(pid)
+            if not record:
+                return CommandStatus(status=CommandStatus.Status.COMPLETED, exit_code=None)
+
+            thread: threading.Thread = record["thread"]
+            exit_code = record.get("exit_code")
+            if thread.is_alive():
+                return CommandStatus(status=CommandStatus.Status.RUNNING, exit_code=None)
+            return CommandStatus(status=CommandStatus.Status.COMPLETED, exit_code=exit_code)
+
+    @property
+    def _working_dir(self) -> str:
+        return cast(str, self.metadata.store[self.StoreKey.WORKDIR])
+
+    def _workspace_path(self, path: str) -> str:
+        normalized = posixpath.normpath(path)
+        if normalized in ("", "."):
+            return self._working_dir
+        if normalized.startswith("/"):
+            return normalized
+        return posixpath.join(self._working_dir, normalized)
+
+    def _exec_thread(
+        self,
+        pid: str,
+        sandbox: Sandbox,
+        command: list[str],
+        environments: Mapping[str, str],
+        cwd: str,
+        stdout_stream: QueueTransportReadCloser,
+        stderr_stream: QueueTransportReadCloser,
+    ) -> None:
+        commands: dict[str, _CommandRecord] = self.metadata.store[self.StoreKey.COMMANDS]
+        commands_lock: threading.Lock = self.metadata.store[self.StoreKey.COMMANDS_LOCK]
+
+        stdout_writer = stdout_stream.get_write_handler()
+        stderr_writer = stderr_stream.get_write_handler()
+        exit_code: int | None = None
+        try:
+            response = sandbox.process.exec(
+                command=shlex.join(command),
+                env=dict(environments),
+                cwd=cwd,
+            )
+            exit_code = response.exit_code
+            output = response.artifacts.stdout if response.artifacts and response.artifacts.stdout else response.result
+            if output:
+                stdout_writer.write(output.encode())
+        except Exception as exc:
+ stderr_writer.write(str(exc).encode()) + exit_code = 1 + finally: + stdout_stream.close() + stderr_stream.close() + with commands_lock: + if pid in commands: + commands[pid]["exit_code"] = exit_code + + def _parse_mod_time(self, mod_time: str) -> int: + try: + cleaned = mod_time.replace("Z", "+00:00") + return int(datetime.fromisoformat(cleaned).timestamp()) + except (ValueError, AttributeError, OSError): + logger.warning("Failed to parse modification time '%s', falling back to current time", mod_time) + return int(time.time()) diff --git a/api/core/virtual_environment/providers/dify_simple_sandbox.py b/api/core/virtual_environment/providers/dify_simple_sandbox.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/virtual_environment/providers/docker_daemon_sandbox.py b/api/core/virtual_environment/providers/docker_daemon_sandbox.py new file mode 100644 index 0000000000..3e87e2c453 --- /dev/null +++ b/api/core/virtual_environment/providers/docker_daemon_sandbox.py @@ -0,0 +1,635 @@ +from __future__ import annotations + +import logging +import socket +import tarfile +import threading +from collections.abc import Mapping, Sequence +from enum import IntEnum, StrEnum +from functools import lru_cache +from io import BytesIO +from pathlib import PurePosixPath +from queue import Empty, Queue +from typing import TYPE_CHECKING, Any, cast +from uuid import uuid4 + +if TYPE_CHECKING: + from docker.models.containers import Container + + import docker +from configs import dify_config +from core.entities.provider_entities import BasicProviderConfig +from core.virtual_environment.__base.entities import ( + Arch, + CommandStatus, + ConnectionHandle, + FileState, + Metadata, + OperatingSystem, +) +from core.virtual_environment.__base.exec import SandboxConfigValidationError, VirtualEnvironmentLaunchFailedError +from core.virtual_environment.__base.virtual_environment import VirtualEnvironment +from core.virtual_environment.channel.exec import TransportEOFError +from core.virtual_environment.channel.socket_transport import SocketWriteCloser +from core.virtual_environment.channel.transport import TransportReadCloser, TransportWriteCloser + + +class DockerStreamType(IntEnum): + """ + Docker multiplexed stream types. + + When Docker exec runs with tty=False, it multiplexes stdout and stderr over a single + socket connection. Each frame is prefixed with an 8-byte header: + + [stream_type (1 byte)][0][0][0][payload_size (4 bytes, big-endian)] + + This allows the client to distinguish between stdout (type=1) and stderr (type=2). + See: https://docs.docker.com/engine/api/v1.41/#operation/ContainerAttach + """ + + STDIN = 0 + STDOUT = 1 + STDERR = 2 + + +class DockerDemuxer: + """ + Demultiplexes Docker's combined stdout/stderr stream using producer-consumer pattern. + + Docker exec with tty=False sends stdout and stderr over a single socket, + each frame prefixed with an 8-byte header: + - Byte 0: stream type (1=stdout, 2=stderr) + - Bytes 1-3: reserved (zeros) + - Bytes 4-7: payload size (big-endian uint32) + + THREAD SAFETY: + A single background thread reads frames from the socket and dispatches payloads + to thread-safe queues. This avoids race conditions where multiple threads + calling _read_next_frame() simultaneously caused frame header/body corruption, + resulting in incomplete stdout/stderr output. + + TIMEOUT HANDLING: + Queue.get() uses a timeout to prevent indefinite blocking when the socket is + closed unexpectedly (e.g., container removed). 
This allows periodic checks for + error conditions and closed state. + """ + + _HEADER_SIZE = 8 + _QUEUE_GET_TIMEOUT = 5.0 # seconds + + def __init__(self, sock: socket.SocketIO): + self._sock = sock + self._stdout_queue: Queue[bytes | None] = Queue() + self._stderr_queue: Queue[bytes | None] = Queue() + self._closed = False + self._error: BaseException | None = None + + self._demux_thread = threading.Thread( + target=self._demux_loop, + daemon=True, + name="docker-demuxer", + ) + self._demux_thread.start() + + def _demux_loop(self) -> None: + try: + while not self._closed: + header = self._read_exact(self._HEADER_SIZE) + if not header or len(header) < self._HEADER_SIZE: + break + + frame_type = header[0] + payload_size = int.from_bytes(header[4:8], "big") + + if payload_size == 0: + continue + + payload = self._read_exact(payload_size) + if not payload: + break + + if frame_type == DockerStreamType.STDOUT: + self._stdout_queue.put(payload) + elif frame_type == DockerStreamType.STDERR: + self._stderr_queue.put(payload) + + except BaseException as e: + self._error = e + finally: + self._stdout_queue.put(None) + self._stderr_queue.put(None) + + def _read_exact(self, size: int) -> bytes: + data = bytearray() + remaining = size + while remaining > 0: + try: + chunk = self._sock.read(remaining) + if not chunk: + return bytes(data) if data else b"" + data.extend(chunk) + remaining -= len(chunk) + except (ConnectionResetError, BrokenPipeError): + return bytes(data) if data else b"" + return bytes(data) + + def read_stdout(self) -> bytes: + return self._read_from_queue(self._stdout_queue) + + def read_stderr(self) -> bytes: + return self._read_from_queue(self._stderr_queue) + + def _read_from_queue(self, queue: Queue[bytes | None]) -> bytes: + """ + Read from queue with timeout to prevent indefinite blocking. + + When the Docker container is removed or the socket is closed unexpectedly, + the demux thread may be stuck in socket.read(). Using a timeout allows us + to periodically check for errors and closed state instead of blocking forever. 
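+ + Illustrative example: a raw frame b"\x01\x00\x00\x00\x00\x00\x00\x05hello" is parsed + by the demux loop as header byte 0 == 1 (stdout) with a 5-byte payload, so a + subsequent read from the stdout queue yields b"hello".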
+ """ + if self._error: + raise TransportEOFError(f"Demuxer error: {self._error}") from self._error + + while True: + try: + chunk = queue.get(timeout=self._QUEUE_GET_TIMEOUT) + if chunk is None: + if self._error: + raise TransportEOFError(f"Demuxer error: {str(self._error)}") + raise TransportEOFError("End of demuxed stream") + return chunk + except Empty: + # Timeout - check if we should continue waiting + if self._closed: + raise TransportEOFError("Demuxer closed") + if self._error: + error = cast(BaseException, self._error) + raise TransportEOFError(f"Demuxer error: {error}") from error + # No error, continue waiting + + def close(self) -> None: + if not self._closed: + self._closed = True + try: + self._sock.close() + except Exception: + logging.error("Failed to close Docker demuxer socket", exc_info=True) + + +class DemuxedStdoutReader(TransportReadCloser): + def __init__(self, demuxer: DockerDemuxer): + self._demuxer = demuxer + self._buffer = bytearray() + + def read(self, n: int) -> bytes: + if self._buffer: + data = bytes(self._buffer[:n]) + del self._buffer[:n] + if data: + return data + + chunk = self._demuxer.read_stdout() + if len(chunk) <= n: + return chunk + + self._buffer.extend(chunk[n:]) + return chunk[:n] + + def close(self) -> None: + self._demuxer.close() + + +class DemuxedStderrReader(TransportReadCloser): + def __init__(self, demuxer: DockerDemuxer): + self._demuxer = demuxer + self._buffer = bytearray() + + def read(self, n: int) -> bytes: + if self._buffer: + data = bytes(self._buffer[:n]) + del self._buffer[:n] + if data: + return data + + chunk = self._demuxer.read_stderr() + if len(chunk) <= n: + return chunk + + self._buffer.extend(chunk[n:]) + return chunk[:n] + + def close(self) -> None: + self._demuxer.close() + + +""" +EXAMPLE: + + +from collections.abc import Mapping +from typing import Any +from + +from core.virtual_environment.providers.docker_daemon_sandbox import DockerDaemonEnvironment + +options: Mapping[str, Any] = { + # OptionsKey values are optional + # DockerDaemonEnvironment.OptionsKey.DOCKER_SOCK: "unix:///var/run/docker.sock", + # DockerDaemonEnvironment.OptionsKey.DOCKER_AGENT_IMAGE: "ubuntu:latest", + # DockerDaemonEnvironment.OptionsKey.DOCKER_AGENT_COMMAND + # + "docker_sock": "unix:///var/run/docker.sock", # optional, default to unix socket + "docker_agent_image": "ubuntu:latest", # optional, default to ubuntu:latest + "docker_agent_command": "/bin/sh -c 'while true; do sleep 1; done'", # optional, default to None +} + + +environment = DockerDaemonEnvironment(options=options) +connection_handle = environment.establish_connection() + +pid, transport_stdout, transport_stderr, transport_stdin = environment.execute_command( + connection_handle, ["uname", "-a"] +) + +print(f"Executed command with PID: {pid}") + +# consume stdout +# consume stdout +while True: + try: + output = transport_stdout.read(1024) + except TransportEOFError: + logger.info("End of stdout reached") + break + + logger.info("Command output: %s", output.decode().strip()) + + +environment.release_connection(connection_handle) +environment.release_environment() + +""" + + +class DockerDaemonEnvironment(VirtualEnvironment): + _WORKING_DIR = "/workspace" + _DEAFULT_DOCKER_IMAGE = "ubuntu:latest" + _DEFAULT_DOCKER_SOCK = ( + "unix:///var/run/docker.sock" # Use an invalid default to avoid accidental local docker usage + ) + + class OptionsKey(StrEnum): + DOCKER_SOCK = "docker_sock" + DOCKER_IMAGE = "docker_image" + DOCKER_COMMAND = "docker_command" + + @classmethod + def 
get_config_schema(cls) -> list[BasicProviderConfig]: + return [ + BasicProviderConfig(type=BasicProviderConfig.Type.TEXT_INPUT, name=cls.OptionsKey.DOCKER_SOCK), + BasicProviderConfig(type=BasicProviderConfig.Type.TEXT_INPUT, name=cls.OptionsKey.DOCKER_IMAGE), + ] + + @classmethod + def validate(cls, options: Mapping[str, Any]) -> None: + # Import Docker SDK lazily so it is loaded after gevent monkey-patching. + import docker + import docker.errors + + docker_sock = options.get(cls.OptionsKey.DOCKER_SOCK, cls._DEFAULT_DOCKER_SOCK) + try: + client = docker.DockerClient(base_url=docker_sock) + client.ping() + except docker.errors.DockerException as e: + raise SandboxConfigValidationError(f"Docker connection failed: {e}") from e + + def _construct_environment(self, options: Mapping[str, Any], environments: Mapping[str, str]) -> Metadata: + """ + Construct the Docker daemon virtual environment. + """ + + docker_client = self.get_docker_daemon( + docker_sock=options.get(self.OptionsKey.DOCKER_SOCK, self._DEFAULT_DOCKER_SOCK) + ) + + default_docker_image = options.get(self.OptionsKey.DOCKER_IMAGE, self._DEFAULT_DOCKER_IMAGE) + container_command = options.get(self.OptionsKey.DOCKER_COMMAND) + + container = docker_client.containers.run( + image=default_docker_image, + command=container_command, + detach=True, + remove=True, + stdin_open=True, + working_dir=self._WORKING_DIR, + environment=dict(environments), + ) + + # FIXME(yeuoly): find a better solution than this warning-only check + if dify_config.FILES_URL.startswith("http://localhost") or dify_config.FILES_URL.startswith("http://127.0.0.1"): + logging.warning( + "DIFY_FILES_URL is set to a localhost address. " + "Docker containers may not be able to access the host's localhost. " + "Consider using host.docker.internal or the host machine's IP address." + ) + + dify_host_port = dify_config.DIFY_PORT + # launch socat inside the container to forward the Dify API port (e.g. 5001) + # from the container's localhost back to the host + container.exec_run( # pyright: ignore[reportUnknownMemberType] # + cmd=[ + "bash", + "-c", + f"nohup socat TCP-LISTEN:{dify_host_port},bind=127.0.0.1,fork,reuseaddr " + f"TCP:host.docker.internal:{dify_host_port} >/tmp/socat.log 2>&1", + ], + detach=True, + ) + + # refresh container attributes now that it has started + container.reload() + + if not container.id: + raise VirtualEnvironmentLaunchFailedError("Failed to start Docker container for DockerDaemonEnvironment.") + + return Metadata( + id=container.id, + arch=self._get_container_architecture(container), + os=OperatingSystem.LINUX, + ) + + @classmethod + @lru_cache(maxsize=5) + def get_docker_daemon(cls, docker_sock: str) -> docker.DockerClient: + """ + Get the Docker daemon client. + + NOTE: clients are cached per socket path; more than 5 distinct Docker sockets + is unlikely in practice. + """ + import docker + + return docker.DockerClient(base_url=docker_sock) + + @classmethod + @lru_cache(maxsize=5) + def get_docker_api_client(cls, docker_sock: str) -> docker.APIClient: + """ + Get the Docker low-level API client. + """ + import docker + + return docker.APIClient(base_url=docker_sock) + + def get_docker_sock(self) -> str: + """ + Get the Docker socket path. + """ + return self.options.get(self.OptionsKey.DOCKER_SOCK, self._DEFAULT_DOCKER_SOCK) + + @property + def _working_dir(self) -> str: + """ + Get the working directory inside the Docker container. + """ + return self._WORKING_DIR + + def _get_container(self) -> Container: + """ + Get the Docker container instance.
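+ + Resolved via ``metadata.id`` on every call rather than cached, so a container + removed out-of-band surfaces as ``docker.errors.NotFound`` instead of a stale handle.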
+ """ + docker_client = self.get_docker_daemon(self.get_docker_sock()) + return docker_client.containers.get(self.metadata.id) + + def _normalize_relative_path(self, path: str) -> PurePosixPath: + parts: list[str] = [] + for part in PurePosixPath(path).parts: + if part in ("", ".", "/"): + continue + if part == "..": + if not parts: + raise ValueError("Path escapes the workspace.") + parts.pop() + continue + parts.append(part) + return PurePosixPath(*parts) + + def _relative_path(self, path: str) -> PurePosixPath: + normalized = self._normalize_relative_path(path) + if normalized.parts: + return normalized + return PurePosixPath() + + def _container_path(self, path: str) -> str: + relative = self._relative_path(path) + if not relative.parts: + return self._working_dir + return f"{self._working_dir}/{relative.as_posix()}" + + def _workspace_path(self, path: str) -> str: + """ + Convert a path to an absolute path in the Docker container. + Absolute paths are returned as-is, relative paths are joined with _working_dir. + """ + normalized = PurePosixPath(path) + if normalized.is_absolute(): + return str(normalized) + return self._container_path(path) + + def upload_file(self, path: str, content: BytesIO) -> None: + """Upload a file to the container. + + Files and intermediate directories are created with world-writable permissions + (0o777 for directories, 0o666 for files) to avoid permission issues when the container + runs as a non-root user but Docker's put_archive creates files as root. + """ + container = self._get_container() + normalized = PurePosixPath(path) + + if normalized.is_absolute(): + parent_dir = str(normalized.parent) + file_name = normalized.name + payload = content.getvalue() + tar_stream = BytesIO() + with tarfile.open(fileobj=tar_stream, mode="w") as tar: + tar_info = tarfile.TarInfo(name=file_name) + tar_info.size = len(payload) + tar_info.mode = 0o666 + tar.addfile(tar_info, BytesIO(payload)) + tar_stream.seek(0) + container.put_archive(parent_dir, tar_stream.read()) # pyright: ignore[reportUnknownMemberType] # + return + + relative_path = self._relative_path(path) + if not relative_path.parts: + raise ValueError("Upload path must point to a file within the workspace.") + + payload = content.getvalue() + tar_stream = BytesIO() + with tarfile.open(fileobj=tar_stream, mode="w") as tar: + # Add intermediate directories with proper permissions + for i in range(len(relative_path.parts) - 1): + dir_path = PurePosixPath(*relative_path.parts[: i + 1]) + dir_info = tarfile.TarInfo(name=dir_path.as_posix() + "/") + dir_info.type = tarfile.DIRTYPE + dir_info.mode = 0o777 + tar.addfile(dir_info) + + # Add the file + tar_info = tarfile.TarInfo(name=relative_path.as_posix()) + tar_info.size = len(payload) + tar_info.mode = 0o666 + tar.addfile(tar_info, BytesIO(payload)) + tar_stream.seek(0) + container.put_archive(self._working_dir, tar_stream.read()) # pyright: ignore[reportUnknownMemberType] # + + def download_file(self, path: str) -> BytesIO: + container = self._get_container() + container_path = self._workspace_path(path) + stream, _ = container.get_archive(container_path) + tar_stream = BytesIO() + for chunk in stream: + tar_stream.write(chunk) + tar_stream.seek(0) + + with tarfile.open(fileobj=tar_stream, mode="r:*") as tar: + members = [member for member in tar.getmembers() if member.isfile()] + if not members: + return BytesIO() + extracted = tar.extractfile(members[0]) + if extracted is None: + return BytesIO() + return BytesIO(extracted.read()) + + def list_files(self, 
directory_path: str, limit: int) -> Sequence[FileState]: + import docker.errors + + container = self._get_container() + container_path = self._container_path(directory_path) + relative_base = self._relative_path(directory_path) + try: + stream, _ = container.get_archive(container_path) + except docker.errors.NotFound: + return [] + tar_stream = BytesIO() + for chunk in stream: + tar_stream.write(chunk) + tar_stream.seek(0) + + files: list[FileState] = [] + archive_root = PurePosixPath(container_path).name + with tarfile.open(fileobj=tar_stream, mode="r:*") as tar: + for member in tar.getmembers(): + if not member.isfile(): + continue + member_path = PurePosixPath(member.name) + if member_path.parts and member_path.parts[0] == archive_root: + member_path = PurePosixPath(*member_path.parts[1:]) + if not member_path.parts: + continue + relative_path = relative_base / member_path + files.append( + FileState( + path=relative_path.as_posix(), + size=member.size, + created_at=int(member.mtime), + updated_at=int(member.mtime), + ) + ) + if len(files) >= limit: + break + return files + + def establish_connection(self) -> ConnectionHandle: + return ConnectionHandle(id=uuid4().hex) + + def release_connection(self, connection_handle: ConnectionHandle) -> None: + # No action needed for Docker exec connections + pass + + def release_environment(self) -> None: + import docker.errors + + try: + container = self._get_container() + except docker.errors.NotFound: + return + try: + container.remove(force=True) + except docker.errors.NotFound: + return + + def execute_command( + self, + connection_handle: ConnectionHandle, + command: list[str], + environments: Mapping[str, str] | None = None, + cwd: str | None = None, + ) -> tuple[str, TransportWriteCloser, TransportReadCloser, TransportReadCloser]: + container = self._get_container() + container_id = container.id + if not isinstance(container_id, str) or not container_id: + raise RuntimeError("Docker container ID is not available for exec.") + api_client = self.get_docker_api_client(self.get_docker_sock()) + + working_dir = self._workspace_path(cwd) if cwd else self._working_dir + + exec_info: dict[str, object] = cast( + dict[str, object], + api_client.exec_create( # pyright: ignore[reportUnknownMemberType] # + container_id, + cmd=command, + stdin=True, + stdout=True, + stderr=True, + tty=False, + workdir=working_dir, + environment=dict(environments) if environments else None, + ), + ) + + if not isinstance(exec_info.get("Id"), str): + raise RuntimeError("Failed to create Docker exec instance.") + + exec_id: str = str(exec_info.get("Id")) + raw_sock: socket.SocketIO = cast(socket.SocketIO, api_client.exec_start(exec_id, socket=True, tty=False)) # pyright: ignore[reportUnknownMemberType] # + + stdin_transport = SocketWriteCloser(raw_sock) + demuxer = DockerDemuxer(raw_sock) + stdout_transport = DemuxedStdoutReader(demuxer) + stderr_transport = DemuxedStderrReader(demuxer) + + return exec_id, stdin_transport, stdout_transport, stderr_transport + + def get_command_status(self, connection_handle: ConnectionHandle, pid: str) -> CommandStatus: + api_client = self.get_docker_api_client(self.get_docker_sock()) + inspect: dict[str, object] = cast(dict[str, object], api_client.exec_inspect(pid)) # pyright: ignore[reportUnknownMemberType] # + exit_code = inspect.get("ExitCode") + if inspect.get("Running") or exit_code is None: + return CommandStatus(status=CommandStatus.Status.RUNNING, exit_code=None) + if not isinstance(exit_code, int): + exit_code = None + return 
CommandStatus(status=CommandStatus.Status.COMPLETED, exit_code=exit_code) + + def _get_container_architecture(self, container: Container) -> Arch: + """ + Detect the container's CPU architecture from its image metadata. + Falls back to ``uname -m`` inside the container when image attrs are unavailable. + """ + try: + image = container.image + arch_str = str(image.attrs.get("Architecture", "")).lower() if image else "" + except Exception: + arch_str = "" + + if not arch_str: + result = container.exec_run("uname -m") + arch_str = result.output.decode("utf-8", errors="replace").strip().lower() if result.exit_code == 0 else "" + + match arch_str: + case "x86_64" | "amd64": + return Arch.AMD64 + case "aarch64" | "arm64": + return Arch.ARM64 + case _: + logging.warning("Unknown container architecture '%s', defaulting to AMD64", arch_str) + return Arch.AMD64 diff --git a/api/core/virtual_environment/providers/e2b_sandbox.py b/api/core/virtual_environment/providers/e2b_sandbox.py new file mode 100644 index 0000000000..ab2b5fbe6e --- /dev/null +++ b/api/core/virtual_environment/providers/e2b_sandbox.py @@ -0,0 +1,360 @@ +import logging +import posixpath +import shlex +import threading +from collections.abc import Mapping, Sequence +from enum import StrEnum +from functools import cached_property +from io import BytesIO +from typing import Any +from uuid import uuid4 + +from configs import dify_config +from core.entities.provider_entities import BasicProviderConfig +from core.virtual_environment.__base.entities import ( + Arch, + CommandStatus, + ConnectionHandle, + FileState, + Metadata, + OperatingSystem, +) +from core.virtual_environment.__base.exec import ( + ArchNotSupportedError, + NotSupportedOperationError, + SandboxConfigValidationError, +) +from core.virtual_environment.__base.virtual_environment import VirtualEnvironment +from core.virtual_environment.channel.queue_transport import QueueTransportReadCloser +from core.virtual_environment.channel.transport import ( + NopTransportWriteCloser, + TransportReadCloser, + TransportWriteCloser, +) +from core.virtual_environment.constants import COMMAND_EXECUTION_TIMEOUT_SECONDS + +logger = logging.getLogger(__name__) + +""" +USAGE: + +import logging +from collections.abc import Mapping +from typing import Any + +from core.virtual_environment.channel.exec import TransportEOFError +from core.virtual_environment.providers.e2b_sandbox import E2BEnvironment + +options: Mapping[str, Any] = { + E2BEnvironment.OptionsKey.API_KEY: "?????????", + E2BEnvironment.OptionsKey.E2B_DEFAULT_TEMPLATE: "code-interpreter-v1", + E2BEnvironment.OptionsKey.E2B_LIST_FILE_DEPTH: 2, + E2BEnvironment.OptionsKey.E2B_API_URL: "https://api.e2b.app", +} + +logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.DEBUG) + +environment = E2BEnvironment(options=options) + +connection_handle = environment.establish_connection() + +pid, transport_stdin, transport_stdout, transport_stderr = environment.execute_command( + connection_handle, ["uname", "-a"] +) + +logger.info("Executed command with PID: %s", pid) + +# consume stdout +while True: + try: + output = transport_stdout.read(1024) + except TransportEOFError: + logger.info("End of stdout reached") + break + + logger.info("Command output: %s", output.decode().strip()) + + +environment.release_connection(connection_handle) +environment.release_environment() + +""" + + +class E2BEnvironment(VirtualEnvironment): + """ + E2B virtual environment provider.
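+ + The live ``Sandbox`` handle is kept in ``metadata.store[StoreKey.SANDBOX]``; its + lifetime is capped by ``WORKFLOW_MAX_EXECUTION_TIME`` rather than extended by a + background keepalive thread (see ``_construct_environment``).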
+ """ + + _WORKDIR = "/home/user" + _E2B_API_URL = "https://api.e2b.app" + + class OptionsKey(StrEnum): + API_KEY = "api_key" + E2B_LIST_FILE_DEPTH = "e2b_list_file_depth" + E2B_DEFAULT_TEMPLATE = "e2b_default_template" + E2B_API_URL = "e2b_api_url" + + class StoreKey(StrEnum): + SANDBOX = "sandbox" + + @classmethod + def get_config_schema(cls) -> list[BasicProviderConfig]: + return [ + BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=cls.OptionsKey.API_KEY), + BasicProviderConfig(type=BasicProviderConfig.Type.TEXT_INPUT, name=cls.OptionsKey.E2B_API_URL), + BasicProviderConfig(type=BasicProviderConfig.Type.TEXT_INPUT, name=cls.OptionsKey.E2B_DEFAULT_TEMPLATE), + ] + + @classmethod + def validate(cls, options: Mapping[str, Any]) -> None: + # Import E2B SDK lazily so it is loaded after gevent monkey-patching. + # See `api/gunicorn.conf.py` for how we patch other third-party libs (e.g. gRPC). + from e2b.exceptions import ( + AuthenticationException, # type: ignore[import-untyped] + ) + from e2b_code_interpreter import Sandbox # type: ignore[import-untyped] + + api_key = options.get(cls.OptionsKey.API_KEY, "") + if not api_key: + raise SandboxConfigValidationError("E2B API key is required") + + try: + Sandbox.list(api_key=api_key, limit=1).next_items() + except AuthenticationException as e: + raise SandboxConfigValidationError(f"E2B authentication failed: {e}") from e + except Exception as e: + raise SandboxConfigValidationError(f"E2B connection failed: {e}") from e + + def _construct_environment(self, options: Mapping[str, Any], environments: Mapping[str, str]) -> Metadata: + """ + Construct a new E2B virtual environment. + + The sandbox lifetime is capped by ``WORKFLOW_MAX_EXECUTION_TIME`` so the + provider can rely on E2B's native timeout instead of a background + keepalive thread that continuously extends the session. + + E2B allocates the remote sandbox before metadata probing completes, so + startup failures must best-effort terminate the sandbox before the + exception escapes. + """ + # Import E2B SDK lazily so it is loaded after gevent monkey-patching. 
+ from e2b_code_interpreter import Sandbox # type: ignore[import-untyped] + + # TODO: add Dify as the user agent + sandbox = None + sandbox_id: str | None = None + api_key = options.get(self.OptionsKey.API_KEY, "") + try: + sandbox = Sandbox.create( + template=options.get(self.OptionsKey.E2B_DEFAULT_TEMPLATE, "code-interpreter-v1"), + timeout=dify_config.WORKFLOW_MAX_EXECUTION_TIME, + api_key=api_key, + api_url=options.get(self.OptionsKey.E2B_API_URL, self._E2B_API_URL), + envs=dict(environments), + ) + info = sandbox.get_info(api_key=api_key) + sandbox_id = info.sandbox_id + system_info = sandbox.commands.run("uname -m -s").stdout.strip() + system_parts = system_info.split() + if len(system_parts) == 2: + os_part, arch_part = system_parts + else: + arch_part = system_parts[0] + os_part = system_parts[1] if len(system_parts) > 1 else "" + + return Metadata( + id=info.sandbox_id, + arch=self._convert_architecture(arch_part.strip()), + os=self._convert_operating_system(os_part.strip()), + store={ + self.StoreKey.SANDBOX: sandbox, + }, + ) + except Exception: + if sandbox_id is None and sandbox is not None: + sandbox_id = getattr(sandbox, "sandbox_id", None) + if sandbox_id is not None: + try: + Sandbox.kill(api_key=api_key, sandbox_id=sandbox_id) + except Exception: + logger.exception("Failed to cleanup E2B sandbox after startup failure") + raise + + def release_environment(self) -> None: + """ + Release the E2B virtual environment. + """ + from e2b_code_interpreter import Sandbox # type: ignore[import-untyped] + + if not Sandbox.kill(api_key=self.api_key, sandbox_id=self.metadata.id): + raise Exception(f"Failed to release E2B sandbox with ID: {self.metadata.id}") + + def establish_connection(self) -> ConnectionHandle: + """ + Establish a connection to the E2B virtual environment. + """ + return ConnectionHandle(id=uuid4().hex) + + def release_connection(self, connection_handle: ConnectionHandle) -> None: + """ + Release the connection to the E2B virtual environment. + """ + pass + + def upload_file(self, path: str, content: BytesIO) -> None: + """ + Upload a file to the E2B virtual environment. + + Args: + path (str): The path to upload the file to. + content (BytesIO): The content of the file. + """ + remote_path = self._workspace_path(path) + sandbox = self.metadata.store[self.StoreKey.SANDBOX] + sandbox.files.write(remote_path, content) # pyright: ignore[reportUnknownMemberType] # + + def download_file(self, path: str) -> BytesIO: + """ + Download a file from the E2B virtual environment. + + Args: + path (str): The path to download the file from. + Returns: + BytesIO: The content of the file. + """ + remote_path = self._workspace_path(path) + sandbox = self.metadata.store[self.StoreKey.SANDBOX] + content = sandbox.files.read(remote_path, format="bytes") + return BytesIO(bytes(content)) + + def list_files(self, directory_path: str, limit: int) -> Sequence[FileState]: + """ + List files in a directory of the E2B virtual environment. 
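+ + Traversal depth comes from the ``e2b_list_file_depth`` option (default 3), and + returned paths are relative to the sandbox workdir (``/home/user``).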
+ """ + sandbox = self.metadata.store[self.StoreKey.SANDBOX] + remote_dir = self._workspace_path(directory_path) + files_info = sandbox.files.list(remote_dir, depth=self.options.get(self.OptionsKey.E2B_LIST_FILE_DEPTH, 3)) + return [ + FileState( + path=posixpath.relpath(file_info.path, self._WORKDIR), + size=file_info.size, + created_at=int(file_info.modified_time.timestamp()), + updated_at=int(file_info.modified_time.timestamp()), + ) + for file_info in files_info + ] + + def execute_command( + self, + connection_handle: ConnectionHandle, + command: list[str], + environments: Mapping[str, str] | None = None, + cwd: str | None = None, + ) -> tuple[str, TransportWriteCloser, TransportReadCloser, TransportReadCloser]: + """ + Execute a command in the E2B virtual environment. + + STDIN is not yet supported. E2B's API is such a terrible mess... to support it may lead a bad design. + as a result we leave it for future improvement. + """ + sandbox = self.metadata.store[self.StoreKey.SANDBOX] + stdout_stream = QueueTransportReadCloser() + stderr_stream = QueueTransportReadCloser() + + working_dir = self._workspace_path(cwd) if cwd else self._WORKDIR + + threading.Thread( + target=self._cmd_thread, + args=(sandbox, command, environments, working_dir, stdout_stream, stderr_stream), + ).start() + + return ( + "N/A", + NopTransportWriteCloser(), # stdin not supported yet + stdout_stream, + stderr_stream, + ) + + def get_command_status(self, connection_handle: ConnectionHandle, pid: str) -> CommandStatus: + """ + Nop, E2B does not support getting command status yet. + """ + raise NotSupportedOperationError("E2B does not support getting command status yet.") + + def _cmd_thread( + self, + sandbox: Any, + command: list[str], + environments: Mapping[str, str] | None, + cwd: str, + stdout_stream: QueueTransportReadCloser, + stderr_stream: QueueTransportReadCloser, + ) -> None: + stdout_stream_write_handler = stdout_stream.get_write_handler() + stderr_stream_write_handler = stderr_stream.get_write_handler() + + try: + sandbox.commands.run( + cmd=shlex.join(command), + envs=dict(environments or {}), + cwd=cwd, + on_stdout=lambda data: stdout_stream_write_handler.write(data.encode()), + on_stderr=lambda data: stderr_stream_write_handler.write(data.encode()), + timeout=COMMAND_EXECUTION_TIMEOUT_SECONDS, + ) + except Exception as e: + # Capture exceptions and write to stderr stream so they can be retrieved via CommandFuture + # This prevents uncaught exceptions from being printed to console + error_msg = f"Command execution failed: {type(e).__name__}: {str(e)}\n" + stderr_stream_write_handler.write(error_msg.encode()) + finally: + # Close the write handlers to signal EOF + stdout_stream.close() + stderr_stream.close() + + @cached_property + def api_key(self) -> str: + """ + Get the API key for the E2B environment. + """ + return self.options.get(self.OptionsKey.API_KEY, "") + + def _workspace_path(self, path: str) -> str: + """ + Convert a path to an absolute path in the E2B environment. + Absolute paths are returned as-is, relative paths are joined with _WORKDIR. 
+ """ + normalized = posixpath.normpath(path) + if normalized in ("", "."): + return self._WORKDIR + if normalized.startswith("/"): + return normalized + return posixpath.join(self._WORKDIR, normalized) + + def _convert_architecture(self, arch_str: str) -> Arch: + arch_map = { + "x86_64": Arch.AMD64, + "aarch64": Arch.ARM64, + "armv7l": Arch.ARM64, + "arm64": Arch.ARM64, + "amd64": Arch.AMD64, + "arm64v8": Arch.ARM64, + "arm64v7": Arch.ARM64, + } + if arch_str in arch_map: + return arch_map[arch_str] + + raise ArchNotSupportedError(f"Unsupported architecture: {arch_str}") + + def _convert_operating_system(self, os_str: str) -> OperatingSystem: + os_map = { + "Linux": OperatingSystem.LINUX, + "Darwin": OperatingSystem.DARWIN, + } + if os_str in os_map: + return os_map[os_str] + + raise ArchNotSupportedError(f"Unsupported operating system: {os_str}") diff --git a/api/core/virtual_environment/providers/local_without_isolation.py b/api/core/virtual_environment/providers/local_without_isolation.py new file mode 100644 index 0000000000..494de05738 --- /dev/null +++ b/api/core/virtual_environment/providers/local_without_isolation.py @@ -0,0 +1,308 @@ +import os +import pathlib +import shutil +import signal +import subprocess +from collections.abc import Mapping, Sequence +from functools import cached_property +from io import BytesIO +from platform import machine, system +from typing import Any +from uuid import uuid4 + +from core.entities.provider_entities import BasicProviderConfig +from core.virtual_environment.__base.entities import ( + Arch, + CommandStatus, + ConnectionHandle, + FileState, + Metadata, + OperatingSystem, +) +from core.virtual_environment.__base.exec import ArchNotSupportedError +from core.virtual_environment.__base.virtual_environment import VirtualEnvironment +from core.virtual_environment.channel.pipe_transport import PipeReadCloser, PipeWriteCloser +from core.virtual_environment.channel.transport import TransportReadCloser, TransportWriteCloser + +""" +USAGE: + +import logging +from collections.abc import Mapping +from typing import Any + +from core.virtual_environment.channel.exec import TransportEOFError +from core.virtual_environment.providers.local_without_isolation import LocalVirtualEnvironment + +options: Mapping[str, Any] = {} + +logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.DEBUG) + +environment = LocalVirtualEnvironment(options=options) + +connection_handle = environment.establish_connection() + +pid, transport_stdin, transport_stdout, transport_stderr = environment.execute_command( + connection_handle, + ["sh", "-lc", "for i in 1 2 3 4 5; do date '+%F %T'; sleep 1; done"], +) + +logger.info("Executed command with PID: %s", pid) + +# consume stdout +while True: + try: + output = transport_stdout.read(1024) + except TransportEOFError: + logger.info("End of stdout reached") + break + + logger.info("Command output: %s", output.decode().strip()) + + +environment.release_connection(connection_handle) +environment.release_environment() + +""" + + +class LocalVirtualEnvironment(VirtualEnvironment): + """ + Local virtual environment provider without isolation. + + WARNING: This provider does not provide any isolation. It's only suitable for development and testing purposes. + NEVER USE IT IN PRODUCTION ENVIRONMENTS. 
+ """ + + @classmethod + def get_config_schema(cls) -> list[BasicProviderConfig]: + return [ + BasicProviderConfig(type=BasicProviderConfig.Type.TEXT_INPUT, name="base_working_path"), + ] + + @classmethod + def validate(cls, options: Mapping[str, Any]) -> None: + pass + + def _construct_environment(self, options: Mapping[str, Any], environments: Mapping[str, str]) -> Metadata: + """ + Construct the local virtual environment. + + Under local without isolation, this method simply create a path for the environment and return the metadata. + """ + id = uuid4().hex + working_path = os.path.join(self._base_working_path, id) + os.makedirs(working_path, exist_ok=True) + return Metadata( + id=id, + arch=self._get_os_architecture(), + os=self._get_operating_system(), + ) + + def release_environment(self) -> None: + """ + Release the local virtual environment. + + Just simply remove the working directory. + """ + working_path = self.get_working_path() + if os.path.exists(working_path): + shutil.rmtree(working_path) + + def upload_file(self, path: str, content: BytesIO) -> None: + """ + Upload a file to the local virtual environment. + + Args: + path (str): The path to upload the file to. + content (BytesIO): The content of the file. + """ + working_path = self.get_working_path() + full_path = os.path.join(working_path, path) + os.makedirs(os.path.dirname(full_path), exist_ok=True) + pathlib.Path(full_path).write_bytes(content.getbuffer()) + + def download_file(self, path: str) -> BytesIO: + """ + Download a file from the local virtual environment. + + Args: + path (str): The path to download the file from. + Returns: + BytesIO: The content of the file. + """ + working_path = self.get_working_path() + full_path = os.path.join(working_path, path) + content = pathlib.Path(full_path).read_bytes() + return BytesIO(content) + + def list_files(self, directory_path: str, limit: int) -> Sequence[FileState]: + """ + List files in a directory of the local virtual environment. + """ + working_path = self.get_working_path() + full_directory_path = os.path.join(working_path, directory_path) + files: list[FileState] = [] + for root, _, filenames in os.walk(full_directory_path): + for filename in filenames: + if len(files) >= limit: + break + file_path = os.path.relpath(os.path.join(root, filename), working_path) + state = os.stat(os.path.join(root, filename)) + files.append( + FileState( + path=file_path, + size=state.st_size, + created_at=int(state.st_ctime), + updated_at=int(state.st_mtime), + ) + ) + if len(files) >= limit: + # break the outer loop as well + return files + + return files + + def establish_connection(self) -> ConnectionHandle: + """ + Establish a connection to the local virtual environment. + """ + return ConnectionHandle( + id=uuid4().hex, + ) + + def release_connection(self, connection_handle: ConnectionHandle) -> None: + """ + Release the connection to the local virtual environment. 
+ """ + # No action needed for local without isolation + pass + + def execute_command( + self, + connection_handle: ConnectionHandle, + command: list[str], + environments: Mapping[str, str] | None = None, + cwd: str | None = None, + ) -> tuple[str, TransportWriteCloser, TransportReadCloser, TransportReadCloser]: + working_path = os.path.join(self.get_working_path(), cwd) if cwd else self.get_working_path() + stdin_read_fd, stdin_write_fd = os.pipe() + stdout_read_fd, stdout_write_fd = os.pipe() + stderr_read_fd, stderr_write_fd = os.pipe() + try: + process = subprocess.Popen( + command, + stdin=stdin_read_fd, + stdout=stdout_write_fd, + stderr=stderr_write_fd, + cwd=working_path, + close_fds=True, + env=environments, + ) + except Exception: + # Clean up file descriptors if process creation fails + for fd in ( + stdin_read_fd, + stdin_write_fd, + stdout_read_fd, + stdout_write_fd, + stderr_read_fd, + stderr_write_fd, + ): + try: + os.close(fd) + except OSError: + pass + raise + + # Close unused fds in the parent process + os.close(stdin_read_fd) + os.close(stdout_write_fd) + os.close(stderr_write_fd) + + # Create PipeTransport instances for stdin, stdout, and stderr + stdin_transport = PipeWriteCloser(w_fd=stdin_write_fd) + stdout_transport = PipeReadCloser(r_fd=stdout_read_fd) + stderr_transport = PipeReadCloser(r_fd=stderr_read_fd) + + # Return the process ID and file descriptors for stdin, stdout, and stderr + return str(process.pid), stdin_transport, stdout_transport, stderr_transport + + def get_command_status(self, connection_handle: ConnectionHandle, pid: str) -> CommandStatus: + pid_int = int(pid) + try: + waited_pid, wait_status = os.waitpid(pid_int, os.WNOHANG) + if waited_pid == 0: + return CommandStatus(status=CommandStatus.Status.RUNNING, exit_code=None) + + if os.WIFEXITED(wait_status): + exit_code = os.WEXITSTATUS(wait_status) + elif os.WIFSIGNALED(wait_status): + exit_code = -os.WTERMSIG(wait_status) + else: + exit_code = None + + return CommandStatus(status=CommandStatus.Status.COMPLETED, exit_code=exit_code) + except ChildProcessError: + return CommandStatus(status=CommandStatus.Status.COMPLETED, exit_code=None) + + def terminate_command(self, connection_handle: ConnectionHandle, pid: str) -> bool: + """Terminate a locally spawned process by PID when cancellation is requested.""" + + _ = connection_handle + try: + os.kill(int(pid), signal.SIGTERM) + except ProcessLookupError: + return False + return True + + def _get_os_architecture(self) -> Arch: + """ + Get the operating system architecture. + + Returns: + Arch: The operating system architecture. + """ + + arch = machine() + match arch.lower(): + case "x86_64" | "amd64": + return Arch.AMD64 + case "aarch64" | "arm64": + return Arch.ARM64 + case _: + raise ArchNotSupportedError(f"Unsupported architecture: {arch}") + + def _get_operating_system(self) -> OperatingSystem: + os_name = system().lower() + match os_name: + case "linux": + return OperatingSystem.LINUX + case "darwin": + return OperatingSystem.DARWIN + case _: + raise ArchNotSupportedError(f"Unsupported operating system: {os_name}") + + @cached_property + def _base_working_path(self) -> str: + """ + Get the base working path for the local virtual environment. + + Args: + options (Mapping[str, Any]): Options for requesting the virtual environment. + + Returns: + str: The base working path. 
+ """ + cwd = os.getcwd() + return self.options.get("base_working_path", os.path.join(cwd, "local_virtual_environments")) + + def get_working_path(self) -> str: + """ + Get the working path for the local virtual environment. + + Returns: + str: The working path. + """ + return os.path.join(self._base_working_path, self.metadata.id) diff --git a/api/core/virtual_environment/providers/ssh_sandbox.py b/api/core/virtual_environment/providers/ssh_sandbox.py new file mode 100644 index 0000000000..dd2c095509 --- /dev/null +++ b/api/core/virtual_environment/providers/ssh_sandbox.py @@ -0,0 +1,499 @@ +from __future__ import annotations + +import contextlib +import logging +import shlex +import stat +import threading +import time +from collections.abc import Mapping, Sequence +from enum import StrEnum +from io import BytesIO +from pathlib import PurePosixPath +from typing import Any +from uuid import uuid4 + +from core.entities.provider_entities import BasicProviderConfig +from core.virtual_environment.__base.entities import ( + Arch, + CommandStatus, + ConnectionHandle, + FileState, + Metadata, + OperatingSystem, +) +from core.virtual_environment.__base.exec import SandboxConfigValidationError, VirtualEnvironmentLaunchFailedError +from core.virtual_environment.__base.virtual_environment import VirtualEnvironment +from core.virtual_environment.channel.exec import TransportEOFError +from core.virtual_environment.channel.queue_transport import QueueTransportReadCloser +from core.virtual_environment.channel.transport import TransportWriteCloser +from core.virtual_environment.constants import COMMAND_EXECUTION_TIMEOUT_SECONDS + +logger = logging.getLogger(__name__) + + +class _SSHStdinTransport(TransportWriteCloser): + def __init__(self, channel: Any): + self._channel = channel + self._closed = False + + def write(self, data: bytes) -> None: + if self._closed: + raise TransportEOFError("Transport is closed") + if not data: + return + self._channel.sendall(data) + + def close(self) -> None: + if self._closed: + return + self._closed = True + with contextlib.suppress(Exception): + self._channel.shutdown_write() + + +class SSHSandboxEnvironment(VirtualEnvironment): + _DEFAULT_SSH_HOST = "agentbox" + _DEFAULT_SSH_PORT = 22 + _DEFAULT_BASE_WORKING_PATH = "/workspace/sandboxes" + _SSH_CONNECT_TIMEOUT_SECONDS = 10 + _SSH_OPERATION_TIMEOUT_SECONDS = 30 + _COMMAND_TIMEOUT_EXIT_CODE = 124 + + class OptionsKey(StrEnum): + SSH_HOST = "ssh_host" + SSH_PORT = "ssh_port" + SSH_USERNAME = "ssh_username" + SSH_PASSWORD = "ssh_password" + BASE_WORKING_PATH = "base_working_path" + + def __init__( + self, + tenant_id: str, + options: Mapping[str, Any], + environments: Mapping[str, str] | None = None, + user_id: str | None = None, + ) -> None: + self._connections: dict[str, Any] = {} + self._commands: dict[str, CommandStatus] = {} + self._command_channels: dict[str, Any] = {} + self._lock = threading.Lock() + super().__init__(tenant_id=tenant_id, options=options, environments=environments, user_id=user_id) + + @classmethod + def get_config_schema(cls) -> list[BasicProviderConfig]: + return [ + BasicProviderConfig(type=BasicProviderConfig.Type.TEXT_INPUT, name=cls.OptionsKey.SSH_HOST), + BasicProviderConfig(type=BasicProviderConfig.Type.TEXT_INPUT, name=cls.OptionsKey.SSH_PORT), + BasicProviderConfig(type=BasicProviderConfig.Type.TEXT_INPUT, name=cls.OptionsKey.SSH_USERNAME), + BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=cls.OptionsKey.SSH_PASSWORD), + 
BasicProviderConfig(type=BasicProviderConfig.Type.TEXT_INPUT, name=cls.OptionsKey.BASE_WORKING_PATH), + ] + + @classmethod + def validate(cls, options: Mapping[str, Any]) -> None: + cls._require_non_empty_option(options, cls.OptionsKey.SSH_USERNAME) + cls._require_non_empty_option(options, cls.OptionsKey.SSH_PASSWORD) + with cls._create_ssh_client(options): + return + + def _construct_environment(self, options: Mapping[str, Any], environments: Mapping[str, str]) -> Metadata: + environment_id = uuid4().hex + working_path = self._workspace_path_from_id(environment_id) + + try: + with self._client() as client: + self._run_command(client, f"mkdir -p {shlex.quote(working_path)}") + arch_stdout = self._run_command(client, "uname -m") + os_stdout = self._run_command(client, "uname -s") + except SandboxConfigValidationError as e: + raise ValueError(f"SSH configuration validation failed, please check sandbox provider: {e}") from e + except Exception as e: + raise VirtualEnvironmentLaunchFailedError(f"Failed to construct SSH environment: {e}") from e + + return Metadata( + id=environment_id, + arch=self._parse_arch(arch_stdout.decode("utf-8", errors="replace").strip()), + os=self._parse_os(os_stdout.decode("utf-8", errors="replace").strip()), + store={"working_path": working_path}, + ) + + def establish_connection(self) -> ConnectionHandle: + connection_id = uuid4().hex + client = self._create_ssh_client(self.options) + with self._lock: + self._connections[connection_id] = client + return ConnectionHandle(id=connection_id) + + def release_connection(self, connection_handle: ConnectionHandle) -> None: + with self._lock: + client = self._connections.pop(connection_handle.id, None) + if client is not None: + with contextlib.suppress(Exception): + client.close() + + def release_environment(self) -> None: + working_path = self.get_working_path() + with contextlib.suppress(Exception): + with self._client() as client: + self._run_command(client, f"rm -rf {shlex.quote(working_path)}") + + def execute_command( + self, + connection_handle: ConnectionHandle, + command: list[str], + environments: Mapping[str, str] | None = None, + cwd: str | None = None, + ) -> tuple[str, TransportWriteCloser, QueueTransportReadCloser, QueueTransportReadCloser]: + client = self._get_connection(connection_handle) + transport = client.get_transport() + if transport is None: + raise RuntimeError("SSH transport is not available") + + channel = transport.open_session() + channel.settimeout(self._SSH_OPERATION_TIMEOUT_SECONDS) + channel.set_combine_stderr(False) + + execution_command = self._build_exec_command(command, environments, cwd) + channel.exec_command(execution_command) + + pid = uuid4().hex + stdin_transport = _SSHStdinTransport(channel) + stdout_transport = QueueTransportReadCloser() + stderr_transport = QueueTransportReadCloser() + + with self._lock: + self._commands[pid] = CommandStatus(status=CommandStatus.Status.RUNNING, exit_code=None) + self._command_channels[pid] = channel + + threading.Thread( + target=self._consume_channel_output, + args=(pid, channel, stdout_transport, stderr_transport, COMMAND_EXECUTION_TIMEOUT_SECONDS), + daemon=True, + ).start() + + return pid, stdin_transport, stdout_transport, stderr_transport + + def get_command_status(self, connection_handle: ConnectionHandle, pid: str) -> CommandStatus: + with self._lock: + status = self._commands.get(pid) + if status is None: + return CommandStatus(status=CommandStatus.Status.COMPLETED, exit_code=None) + return status + + def terminate_command(self, 
connection_handle: ConnectionHandle, pid: str) -> bool: + """Best-effort termination by closing the SSH channel that owns the command.""" + + _ = connection_handle + with self._lock: + channel = self._command_channels.get(pid) + if channel is None: + return False + self._commands[pid] = CommandStatus( + status=CommandStatus.Status.COMPLETED, + exit_code=self._COMMAND_TIMEOUT_EXIT_CODE, + ) + + with contextlib.suppress(Exception): + channel.close() + return True + + def upload_file(self, path: str, content: BytesIO) -> None: + destination_path = self._workspace_path(path) + with self._client() as client: + sftp = client.open_sftp() + try: + self._set_sftp_operation_timeout(sftp) + self._sftp_mkdirs(sftp, str(PurePosixPath(destination_path).parent)) + with sftp.file(destination_path, "wb") as remote_file: + remote_file.write(content.getvalue()) + finally: + sftp.close() + + def download_file(self, path: str) -> BytesIO: + source_path = self._workspace_path(path) + with self._client() as client: + sftp = client.open_sftp() + try: + self._set_sftp_operation_timeout(sftp) + with sftp.file(source_path, "rb") as remote_file: + return BytesIO(remote_file.read()) + finally: + sftp.close() + + def list_files(self, directory_path: str, limit: int) -> Sequence[FileState]: + if limit <= 0: + return [] + + root_directory = self._workspace_path(directory_path) + files: list[FileState] = [] + + with self._client() as client: + sftp = client.open_sftp() + try: + self._set_sftp_operation_timeout(sftp) + pending = [root_directory] + while pending and len(files) < limit: + current_directory = pending.pop(0) + with contextlib.suppress(FileNotFoundError): + for attr in sftp.listdir_attr(current_directory): + current_path = str(PurePosixPath(current_directory) / attr.filename) + mode = attr.st_mode + if stat.S_ISDIR(mode): + pending.append(current_path) + continue + + files.append( + FileState( + path=self._to_relative_workspace_path(current_path), + size=attr.st_size, + created_at=int(attr.st_mtime), + updated_at=int(attr.st_mtime), + ) + ) + if len(files) >= limit: + break + finally: + sftp.close() + + return files + + @classmethod + def _require_non_empty_option(cls, options: Mapping[str, Any], key: OptionsKey) -> str: + value = options.get(key) + if not isinstance(value, str) or not value.strip(): + raise SandboxConfigValidationError(f"Missing required option: {key}") + return value.strip() + + @classmethod + def _create_ssh_client(cls, options: Mapping[str, Any]) -> Any: + import paramiko + + host = options.get(cls.OptionsKey.SSH_HOST, cls._DEFAULT_SSH_HOST) + port = options.get(cls.OptionsKey.SSH_PORT, cls._DEFAULT_SSH_PORT) + username = cls._require_non_empty_option(options, cls.OptionsKey.SSH_USERNAME) + password = cls._require_non_empty_option(options, cls.OptionsKey.SSH_PASSWORD) + + if not isinstance(host, str) or not host.strip(): + raise SandboxConfigValidationError(f"Invalid option value: {cls.OptionsKey.SSH_HOST}") + + try: + port_int = int(port) + except (TypeError, ValueError) as e: + raise SandboxConfigValidationError(f"Invalid option value: {cls.OptionsKey.SSH_PORT}") from e + + client = paramiko.SSHClient() + client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + + try: + client.connect( + hostname=host.strip(), + port=port_int, + username=username, + password=password, + look_for_keys=False, + allow_agent=False, + timeout=cls._SSH_CONNECT_TIMEOUT_SECONDS, + banner_timeout=cls._SSH_CONNECT_TIMEOUT_SECONDS, + auth_timeout=cls._SSH_CONNECT_TIMEOUT_SECONDS, + ) + transport = 
client.get_transport() + if transport is not None: + transport.set_keepalive(30) + except Exception as e: + with contextlib.suppress(Exception): + client.close() + raise SandboxConfigValidationError(f"SSH connection failed: {e}") from e + + return client + + @contextlib.contextmanager + def _client(self): + client = self._create_ssh_client(self.options) + try: + yield client + finally: + with contextlib.suppress(Exception): + client.close() + + def _get_connection(self, connection_handle: ConnectionHandle) -> Any: + with self._lock: + client = self._connections.get(connection_handle.id) + if client is None: + raise ValueError(f"Connection handle not found: {connection_handle.id}") + return client + + def _workspace_path_from_id(self, environment_id: str) -> str: + base_path = self.options.get(self.OptionsKey.BASE_WORKING_PATH, self._DEFAULT_BASE_WORKING_PATH) + if not isinstance(base_path, str) or not base_path.strip(): + base_path = self._DEFAULT_BASE_WORKING_PATH + return str(PurePosixPath(base_path) / environment_id) + + def get_working_path(self) -> str: + working_path = self.metadata.store.get("working_path") + if not isinstance(working_path, str) or not working_path: + return self._workspace_path_from_id(self.metadata.id) + return working_path + + def _workspace_path(self, path: str | None) -> str: + if not path: + return self.get_working_path() + + normalized = PurePosixPath(path) + if normalized.is_absolute(): + return str(normalized) + return str(PurePosixPath(self.get_working_path()) / self._normalize_relative_path(path)) + + @staticmethod + def _normalize_relative_path(path: str) -> PurePosixPath: + parts: list[str] = [] + for part in PurePosixPath(path).parts: + if part in ("", ".", "/"): + continue + if part == "..": + if not parts: + raise ValueError("Path escapes the workspace.") + parts.pop() + continue + parts.append(part) + return PurePosixPath(*parts) + + def _to_relative_workspace_path(self, path: str) -> str: + workspace = PurePosixPath(self.get_working_path()) + target = PurePosixPath(path) + if target.is_relative_to(workspace): + return target.relative_to(workspace).as_posix() + return target.as_posix() + + def _build_exec_command( + self, command: list[str], environments: Mapping[str, str] | None = None, cwd: str | None = None + ) -> str: + working_path = self._workspace_path(cwd) + command_body = f"cd {shlex.quote(working_path)} && " + + if environments: + env_clause = " ".join(f"{key}={shlex.quote(value)}" for key, value in environments.items()) + command_body += f"{env_clause} " + + command_body += shlex.join(command) + return f"sh -lc {shlex.quote(command_body)}" + + @classmethod + def _run_command(cls, client: Any, command: str) -> bytes: + _, stdout, stderr = client.exec_command(command, timeout=cls._SSH_OPERATION_TIMEOUT_SECONDS) + stdout.channel.settimeout(cls._SSH_OPERATION_TIMEOUT_SECONDS) + + deadline = time.monotonic() + COMMAND_EXECUTION_TIMEOUT_SECONDS + while not stdout.channel.exit_status_ready(): + if time.monotonic() >= deadline: + with contextlib.suppress(Exception): + stdout.channel.close() + raise TimeoutError(f"SSH command timed out after {COMMAND_EXECUTION_TIMEOUT_SECONDS}s") + time.sleep(0.05) + + exit_code = stdout.channel.recv_exit_status() + stdout_data = stdout.read() + stderr_data = stderr.read() + + if exit_code != 0: + stderr_text = stderr_data.decode("utf-8", errors="replace") + raise RuntimeError(f"SSH command failed ({exit_code}): {stderr_text}") + + return stdout_data + + def _consume_channel_output( + self, + pid: str, + channel: 
Any, + stdout_transport: QueueTransportReadCloser, + stderr_transport: QueueTransportReadCloser, + max_runtime_seconds: int, + ) -> None: + stdout_writer = stdout_transport.get_write_handler() + stderr_writer = stderr_transport.get_write_handler() + exit_code: int | None = None + started_at = time.monotonic() + + try: + while True: + if time.monotonic() - started_at >= max_runtime_seconds: + exit_code = self._COMMAND_TIMEOUT_EXIT_CODE + stderr_writer.write(f"Command timed out after {max_runtime_seconds}s".encode()) + break + + if channel.recv_ready(): + stdout_writer.write(channel.recv(4096)) + if channel.recv_stderr_ready(): + stderr_writer.write(channel.recv_stderr(4096)) + + if channel.exit_status_ready() and not channel.recv_ready() and not channel.recv_stderr_ready(): + exit_code = int(channel.recv_exit_status()) + break + + time.sleep(0.05) + except TimeoutError: + logger.warning("SSH channel read timed out for command %s", pid) + exit_code = self._COMMAND_TIMEOUT_EXIT_CODE + finally: + with contextlib.suppress(Exception): + stdout_transport.close() + with contextlib.suppress(Exception): + stderr_transport.close() + with contextlib.suppress(Exception): + channel.close() + + with self._lock: + self._command_channels.pop(pid, None) + self._commands[pid] = CommandStatus(status=CommandStatus.Status.COMPLETED, exit_code=exit_code) + + def _set_sftp_operation_timeout(self, sftp: Any) -> None: + with contextlib.suppress(Exception): + sftp.get_channel().settimeout(self._SSH_OPERATION_TIMEOUT_SECONDS) + + @staticmethod + def _parse_arch(raw_arch: str) -> Arch: + arch = raw_arch.lower() + if arch in {"x86_64", "amd64"}: + return Arch.AMD64 + if arch in {"arm64", "aarch64"}: + return Arch.ARM64 + return Arch.AMD64 + + @staticmethod + def _parse_os(raw_os: str) -> OperatingSystem: + system_name = raw_os.lower() + if system_name == "darwin": + return OperatingSystem.DARWIN + return OperatingSystem.LINUX + + @staticmethod + def _sftp_mkdirs(sftp: Any, directory: str) -> None: + if not directory or directory == "/": + return + + path = PurePosixPath(directory) + current = PurePosixPath("/") if path.is_absolute() else PurePosixPath() + + for part in path.parts: + if part in ("", "/"): + continue + current = current / part + current_path = str(current) + try: + attrs = sftp.stat(current_path) + if not stat.S_ISDIR(attrs.st_mode): + raise OSError(f"Path exists but is not a directory: {current_path}") + continue + except OSError as e: + missing = isinstance(e, FileNotFoundError) or getattr(e, "errno", None) == 2 + missing = missing or "no such file" in str(e).lower() + if not missing: + raise + + try: + sftp.mkdir(current_path) + except OSError: + # Some SFTP servers report generic "Failure" when directory already exists. 
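+ # Re-stat the path: if it is (now) a directory, the mkdir race is benign; + # anything else is re-raised as a genuine failure below.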
+ attrs = sftp.stat(current_path) + if not stat.S_ISDIR(attrs.st_mode): + raise OSError(f"Failed to create directory: {current_path}") diff --git a/api/extensions/ext_blueprints.py b/api/extensions/ext_blueprints.py index 7d13f0c061..6a1bebffcd 100644 --- a/api/extensions/ext_blueprints.py +++ b/api/extensions/ext_blueprints.py @@ -90,6 +90,10 @@ def init_app(app: DifyApp): app.register_blueprint(inner_api_bp) app.register_blueprint(mcp_bp) + # TODO: enable after full sandbox integration + # from controllers.cli_api import bp as cli_api_bp + # app.register_blueprint(cli_api_bp) + # Register trigger blueprint with CORS for webhook calls _apply_cors_once( trigger_bp, diff --git a/api/extensions/storage/cached_presign_storage.py b/api/extensions/storage/cached_presign_storage.py new file mode 100644 index 0000000000..079a43aa70 --- /dev/null +++ b/api/extensions/storage/cached_presign_storage.py @@ -0,0 +1,209 @@ +"""Storage wrapper that caches presigned download URLs.""" + +import hashlib +import logging +from itertools import starmap + +from extensions.ext_redis import redis_client +from extensions.storage.base_storage import BaseStorage +from extensions.storage.storage_wrapper import StorageWrapper + +logger = logging.getLogger(__name__) + + +class CachedPresignStorage(StorageWrapper): + """Storage wrapper that caches presigned download URLs. + + Wraps a storage with presign capability and caches the generated URLs + in Redis to reduce repeated presign API calls. + + Example: + cached_storage = CachedPresignStorage( + storage=FilePresignStorage(base_storage), + cache_key_prefix="app_asset:draft_download", + ) + url = cached_storage.get_download_url("path/to/file.txt", expires_in=3600) + """ + + TTL_BUFFER_SECONDS = 60 + MIN_TTL_SECONDS = 60 + + def __init__( + self, + storage: BaseStorage, + cache_key_prefix: str = "presign_cache", + ): + super().__init__(storage) + self._redis = redis_client + self._cache_key_prefix = cache_key_prefix + + def delete(self, filename: str): + super().delete(filename) + self.invalidate([filename]) + + def get_download_url( + self, + filename: str, + expires_in: int = 3600, + *, + download_filename: str | None = None, + ) -> str: + """Get a presigned download URL, using cache when available. + + Args: + filename: The file path/key in storage + expires_in: URL validity duration in seconds (default: 1 hour) + download_filename: If provided, the browser will use this as the downloaded + file name. Cache keys include this value to avoid conflicts. + + Returns: + Presigned URL string + """ + cache_key = self._cache_key(filename, download_filename) + + cached = self._get_cached(cache_key) + if cached: + return cached + + url = self._storage.get_download_url(filename, expires_in, download_filename=download_filename) + self._set_cached(cache_key, url, expires_in) + + return url + + def get_download_urls( + self, + filenames: list[str], + expires_in: int = 3600, + *, + download_filenames: list[str] | None = None, + ) -> list[str]: + """Batch get download URLs with cache. + + Args: + filenames: List of file paths/keys in storage + expires_in: URL validity duration in seconds (default: 1 hour) + download_filenames: If provided, must match len(filenames). Each element + specifies the download filename for the corresponding file. 
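+                A length mismatch raises ValueError (enforced via strict zip).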
+ + Returns: + List of presigned URLs in the same order as filenames + """ + if not filenames: + return [] + + # Build cache keys including download_filename for uniqueness + if download_filenames is None: + cache_keys = [self._cache_key(f, None) for f in filenames] + else: + cache_keys = list(starmap(self._cache_key, zip(filenames, download_filenames, strict=True))) + + cached_values = self._get_cached_batch(cache_keys) + + # Build results list, tracking which indices need fetching + results: list[str | None] = list(cached_values) + uncached_indices: list[int] = [] + uncached_filenames: list[str] = [] + uncached_download_filenames: list[str | None] = [] + + for i, (filename, cached) in enumerate(zip(filenames, cached_values)): + if not cached: + uncached_indices.append(i) + uncached_filenames.append(filename) + uncached_download_filenames.append(download_filenames[i] if download_filenames else None) + + # Batch fetch uncached URLs from storage + if uncached_filenames: + uncached_urls = [ + self._storage.get_download_url(f, expires_in, download_filename=df) + for f, df in zip(uncached_filenames, uncached_download_filenames, strict=True) + ] + + # Fill results at correct positions + for idx, url in zip(uncached_indices, uncached_urls): + results[idx] = url + + # Batch set cache + uncached_cache_keys = [cache_keys[i] for i in uncached_indices] + self._set_cached_batch(uncached_cache_keys, uncached_urls, expires_in) + + return results # type: ignore[return-value] + + def invalidate(self, filenames: list[str]) -> None: + """Invalidate cached URLs for given filenames. + + Args: + filenames: List of file paths/keys to invalidate + """ + if not filenames: + return + + cache_keys = [self._cache_key(f) for f in filenames] + try: + self._redis.delete(*cache_keys) + except Exception: + logger.warning("Failed to invalidate presign cache", exc_info=True) + + def _cache_key(self, filename: str, download_filename: str | None = None) -> str: + """Generate cache key for a filename. + + When download_filename is provided, its hash is appended to the key to ensure + different download names for the same storage key get separate cache entries. + We use a hash (truncated MD5) instead of the raw string because: + - download_filename may contain special characters unsafe for Redis keys + - Hash collisions only cause a cache miss, no functional impact + """ + if download_filename: + # Use first 16 chars of MD5 hex digest (64 bits) - sufficient for cache key uniqueness + name_hash = hashlib.md5(download_filename.encode("utf-8")).hexdigest()[:16] + return f"{self._cache_key_prefix}:{filename}::{name_hash}" + return f"{self._cache_key_prefix}:{filename}" + + def _compute_ttl(self, expires_in: int) -> int: + """Compute cache TTL from presign expiration. + + Returns TTL slightly shorter than presign expiry to ensure + cached URLs are refreshed before they expire. 
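+
+        For example, with the defaults, expires_in=3600 yields a 3540s TTL,
+        while expires_in=90 is clamped up to the 60s floor.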
+        """
+        return max(expires_in - self.TTL_BUFFER_SECONDS, self.MIN_TTL_SECONDS)
+
+    def _get_cached(self, cache_key: str) -> str | None:
+        """Get a single cached URL."""
+        try:
+            cached = self._redis.get(cache_key)
+            if cached:
+                return cached.decode("utf-8") if isinstance(cached, (bytes, bytearray)) else cached
+            return None
+        except Exception:
+            logger.warning("Failed to read presign cache", exc_info=True)
+            return None
+
+    def _get_cached_batch(self, cache_keys: list[str]) -> list[str | None]:
+        """Get multiple cached URLs."""
+        try:
+            cached_values = self._redis.mget(cache_keys)
+            return [v.decode("utf-8") if isinstance(v, (bytes, bytearray)) else v for v in cached_values]
+        except Exception:
+            logger.warning("Failed to read presign cache batch", exc_info=True)
+            return [None] * len(cache_keys)
+
+    def _set_cached(self, cache_key: str, url: str, expires_in: int) -> None:
+        """Store a URL in cache with computed TTL."""
+        ttl = self._compute_ttl(expires_in)
+        try:
+            self._redis.setex(cache_key, ttl, url)
+        except Exception:
+            logger.warning("Failed to write presign cache", exc_info=True)
+
+    def _set_cached_batch(self, cache_keys: list[str], urls: list[str], expires_in: int) -> None:
+        """Store multiple URLs in cache with computed TTL using pipeline."""
+        if not cache_keys:
+            return
+        ttl = self._compute_ttl(expires_in)
+        try:
+            pipe = self._redis.pipeline()
+            for cache_key, url in zip(cache_keys, urls, strict=True):
+                pipe.setex(cache_key, ttl, url)
+            pipe.execute()
+        except Exception:
+            logger.warning("Failed to write presign cache batch", exc_info=True)
diff --git a/api/extensions/storage/file_presign_storage.py b/api/extensions/storage/file_presign_storage.py
new file mode 100644
index 0000000000..7594fa0d52
--- /dev/null
+++ b/api/extensions/storage/file_presign_storage.py
@@ -0,0 +1,73 @@
+"""Storage wrapper that provides presigned URL support with fallback to ticket-based URLs.
+
+This is the unified presign wrapper for all storage operations. When the underlying
+storage backend doesn't support presigned URLs (raises NotImplementedError), it falls
+back to generating ticket-based URLs that route through Dify's file proxy endpoints.
+
+Usage:
+    from extensions.storage.file_presign_storage import FilePresignStorage
+
+    # Wrap any BaseStorage to add presign support
+    presign_storage = FilePresignStorage(base_storage)
+    download_url = presign_storage.get_download_url("path/to/file.txt", expires_in=3600)
+    upload_url = presign_storage.get_upload_url("path/to/file.txt", expires_in=3600)
+
+When the underlying storage doesn't support presigned URLs, the fallback URLs follow the format:
+    {FILES_API_URL}/files/storage-files/{token} (falls back to FILES_URL)
+
+The token is a UUID that maps to the real storage key in Redis.
+"""
+
+from extensions.storage.storage_wrapper import StorageWrapper
+
+
+class FilePresignStorage(StorageWrapper):
+    """Storage wrapper that provides presigned URL support with ticket fallback.
+
+    If the wrapped storage supports presigned URLs, delegates to it.
+    Otherwise, generates ticket-based URLs for both download and upload operations.
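+
+    Example (mirroring the module-level usage above):
+        storage = FilePresignStorage(base_storage)
+        url = storage.get_upload_url("path/to/file.txt", expires_in=3600)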
+ """ + + def get_download_url( + self, + filename: str, + expires_in: int = 3600, + *, + download_filename: str | None = None, + ) -> str: + """Get a presigned download URL, falling back to ticket URL if not supported.""" + try: + return self._storage.get_download_url(filename, expires_in, download_filename=download_filename) + except NotImplementedError: + from services.storage_ticket_service import StorageTicketService + + return StorageTicketService.create_download_url(filename, expires_in=expires_in, filename=download_filename) + + def get_download_urls( + self, + filenames: list[str], + expires_in: int = 3600, + *, + download_filenames: list[str] | None = None, + ) -> list[str]: + """Get presigned download URLs for multiple files.""" + try: + return self._storage.get_download_urls(filenames, expires_in, download_filenames=download_filenames) + except NotImplementedError: + from services.storage_ticket_service import StorageTicketService + + if download_filenames is None: + return [StorageTicketService.create_download_url(f, expires_in=expires_in) for f in filenames] + return [ + StorageTicketService.create_download_url(f, expires_in=expires_in, filename=df) + for f, df in zip(filenames, download_filenames, strict=True) + ] + + def get_upload_url(self, filename: str, expires_in: int = 3600) -> str: + """Get a presigned upload URL, falling back to ticket URL if not supported.""" + try: + return self._storage.get_upload_url(filename, expires_in) + except NotImplementedError: + from services.storage_ticket_service import StorageTicketService + + return StorageTicketService.create_upload_url(filename, expires_in=expires_in) diff --git a/api/extensions/storage/storage_wrapper.py b/api/extensions/storage/storage_wrapper.py new file mode 100644 index 0000000000..db472f3c47 --- /dev/null +++ b/api/extensions/storage/storage_wrapper.py @@ -0,0 +1,66 @@ +"""Base class for storage wrappers that delegate to an inner storage.""" + +from collections.abc import Generator + +from extensions.storage.base_storage import BaseStorage + + +class StorageWrapper(BaseStorage): + """Base class for storage wrappers using the decorator pattern. + + Forwards all BaseStorage methods to the wrapped storage by default. + Subclasses can override specific methods to customize behavior. 
+ + Example: + class MyCustomStorage(StorageWrapper): + def save(self, filename: str, data: bytes): + # Custom logic before save + super().save(filename, data) + # Custom logic after save + """ + + def __init__(self, storage: BaseStorage): + super().__init__() + self._storage = storage + + def save(self, filename: str, data: bytes): + self._storage.save(filename, data) + + def load_once(self, filename: str) -> bytes: + return self._storage.load_once(filename) + + def load_stream(self, filename: str) -> Generator: + return self._storage.load_stream(filename) + + def download(self, filename: str, target_filepath: str): + self._storage.download(filename, target_filepath) + + def exists(self, filename: str) -> bool: + return self._storage.exists(filename) + + def delete(self, filename: str): + self._storage.delete(filename) + + def scan(self, path: str, files: bool = True, directories: bool = False) -> list[str]: + return self._storage.scan(path, files=files, directories=directories) + + def get_download_url( + self, + filename: str, + expires_in: int = 3600, + *, + download_filename: str | None = None, + ) -> str: + return self._storage.get_download_url(filename, expires_in, download_filename=download_filename) + + def get_download_urls( + self, + filenames: list[str], + expires_in: int = 3600, + *, + download_filenames: list[str] | None = None, + ) -> list[str]: + return self._storage.get_download_urls(filenames, expires_in, download_filenames=download_filenames) + + def get_upload_url(self, filename: str, expires_in: int = 3600) -> str: + return self._storage.get_upload_url(filename, expires_in) diff --git a/api/libs/attr_map.py b/api/libs/attr_map.py new file mode 100644 index 0000000000..c7dd61a820 --- /dev/null +++ b/api/libs/attr_map.py @@ -0,0 +1,163 @@ +""" +Type-safe attribute storage inspired by Netty's AttributeKey/AttributeMap pattern. + +Provides loosely-coupled typed attribute storage where only code with access +to the same AttrKey instance can read/write the corresponding attribute. + + SESSION_KEY: AttrKey[Session] = AttrKey("session", Session) + attrs = AttrMap() + attrs.set(SESSION_KEY, session) + session = attrs.get(SESSION_KEY) # -> Session (raises if not set) + session = attrs.get_or_none(SESSION_KEY) # -> Session | None + +Note: AttrMap is NOT thread-safe. Each instance should be confined to a single +thread/context (e.g., one AttrMap per Sandbox/VirtualEnvironment instance). +""" + +from __future__ import annotations + +from typing import Any, Generic, TypeVar, cast, final, overload + +T = TypeVar("T") +D = TypeVar("D") + + +@final +class AttrKey(Generic[T]): + """ + A type-safe key for attribute storage. + + Identity-based: different AttrKey instances with same name are distinct keys. + This enables different modules to define keys independently without collision. 
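+
+    Example:
+        KEY_A = AttrKey("session", Session)
+        KEY_B = AttrKey("session", Session)
+        assert KEY_A != KEY_B  # identity equality: same name, still distinct keys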
+ """ + + __slots__ = ("_name", "_type") + + def __init__(self, name: str, type_: type[T]) -> None: + self._name = name + self._type = type_ + + @property + def name(self) -> str: + return self._name + + @property + def type_(self) -> type[T]: + return self._type + + def __repr__(self) -> str: + return f"AttrKey({self._name!r}, {self._type.__name__})" + + def __hash__(self) -> int: + return id(self) + + def __eq__(self, other: object) -> bool: + return self is other + + +class AttrMapKeyError(KeyError): + """Raised when a required attribute is not set.""" + + key: AttrKey[Any] + + def __init__(self, key: AttrKey[Any]) -> None: + self.key = key + super().__init__(f"Required attribute '{key.name}' (type: {key.type_.__name__}) is not set") + + +class AttrMapTypeError(TypeError): + """Raised when attribute value type doesn't match the key's declared type.""" + + key: AttrKey[Any] + expected_type: type[Any] + actual_type: type[Any] + + def __init__(self, key: AttrKey[Any], expected_type: type[Any], actual_type: type[Any]) -> None: + self.key = key + self.expected_type = expected_type + self.actual_type = actual_type + super().__init__( + f"Attribute '{key.name}' expects type '{expected_type.__name__}', got '{actual_type.__name__}'" + ) + + +@final +class AttrMap: + """ + Thread-confined container for storing typed attributes using AttrKey instances. + + NOT thread-safe. Each instance should be owned by a single context + (e.g., one AttrMap per Sandbox/VirtualEnvironment instance). + """ + + __slots__ = ("_data",) + + def __init__(self) -> None: + self._data: dict[AttrKey[Any], Any] = {} + + def set(self, key: AttrKey[T], value: T, *, validate: bool = True) -> None: + """ + Store an attribute. Raises AttrMapTypeError if validate=True and type mismatches. + + Note: Runtime validation only checks outer type (e.g., `list` not `list[str]`). + """ + if validate and not isinstance(value, key.type_): + raise AttrMapTypeError(key, key.type_, type(value)) + self._data[key] = value + + def get(self, key: AttrKey[T]) -> T: + """Retrieve an attribute. Raises AttrMapKeyError if not set.""" + if key not in self._data: + raise AttrMapKeyError(key) + return cast(T, self._data[key]) + + def get_or_none(self, key: AttrKey[T]) -> T | None: + """Retrieve an attribute, returning None if not set.""" + return cast(T | None, self._data.get(key)) + + @overload + def get_or_default(self, key: AttrKey[T], default: T) -> T: ... + + @overload + def get_or_default(self, key: AttrKey[T], default: D) -> T | D: ... + + def get_or_default(self, key: AttrKey[T], default: T | D) -> T | D: + """Retrieve an attribute, returning default if not set.""" + if key in self._data: + return cast(T, self._data[key]) + return default + + def has(self, key: AttrKey[Any]) -> bool: + """Check if an attribute is set.""" + return key in self._data + + def remove(self, key: AttrKey[Any]) -> bool: + """Remove an attribute. Returns True if it was present.""" + if key in self._data: + del self._data[key] + return True + return False + + def set_if_absent(self, key: AttrKey[T], value: T, *, validate: bool = True) -> T: + """ + Set attribute only if not already set. Returns existing or newly set value. + + Raises AttrMapTypeError if validate=True and type mismatches. 
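+
+        For example, attrs.set(AttrKey("xs", list), [1, "a"]) passes validation,
+        since only isinstance(value, list) is checked.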
+ """ + if key in self._data: + return cast(T, self._data[key]) + if validate and not isinstance(value, key.type_): + raise AttrMapTypeError(key, key.type_, type(value)) + self._data[key] = value + return value + + def clear(self) -> None: + """Remove all attributes.""" + self._data.clear() + + def __len__(self) -> int: + return len(self._data) + + def __repr__(self) -> str: + keys = [k.name for k in self._data] + return f"AttrMap({keys})" diff --git a/api/services/app_asset_service.py b/api/services/app_asset_service.py new file mode 100644 index 0000000000..471bef7cf9 --- /dev/null +++ b/api/services/app_asset_service.py @@ -0,0 +1,579 @@ +from __future__ import annotations + +import logging +import threading +from uuid import uuid4 + +from sqlalchemy.orm import Session + +from core.app.entities.app_asset_entities import ( + AppAssetFileTree, + AppAssetNode, + AssetNodeType, + BatchUploadNode, + TreeNodeNotFoundError, + TreeParentNotFoundError, + TreePathConflictError, +) +from core.app_assets.accessor import CachedContentAccessor +from core.app_assets.entities.assets import AssetItem +from core.app_assets.storage import AssetPaths +from core.zip_sandbox import SandboxDownloadItem +from extensions.ext_database import db +from extensions.ext_redis import redis_client +from extensions.ext_storage import storage +from extensions.storage.cached_presign_storage import CachedPresignStorage +from extensions.storage.file_presign_storage import FilePresignStorage +from models.app_asset import AppAssets +from models.model import App +from services.asset_content_service import AssetContentService + +from .errors.app_asset import ( + AppAssetNodeNotFoundError, + AppAssetNodeTooLargeError, + AppAssetParentNotFoundError, + AppAssetPathConflictError, +) + +logger = logging.getLogger(__name__) + + +class AppAssetService: + MAX_PREVIEW_CONTENT_SIZE = 5 * 1024 * 1024 # 5MB + _LOCK_TIMEOUT_SECONDS = 60 + + @staticmethod + def get_storage() -> CachedPresignStorage: + """Get a lazily-initialized storage instance for app assets. + + Returns a CachedPresignStorage wrapping FilePresignStorage, + providing presign fallback and URL caching. 
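+
+        Example:
+            url = AppAssetService.get_storage().get_download_url(key, expires_in=600)
+            # repeat calls within the TTL are served from the Redis cache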
+ """ + return CachedPresignStorage( + storage=FilePresignStorage(storage.storage_runner), + cache_key_prefix="app_assets", + ) + + @staticmethod + def _lock(app_id: str): + return redis_client.lock(f"app_asset:lock:{app_id}", timeout=AppAssetService._LOCK_TIMEOUT_SECONDS) + + @staticmethod + def get_assets_by_version(tenant_id: str, app_id: str, workflow_id: str | None = None) -> AppAssets: + """Get asset tree by workflow_id (published) or draft if workflow_id is None.""" + with Session(db.engine) as session: + version = workflow_id or AppAssets.VERSION_DRAFT + assets = ( + session.query(AppAssets) + .filter( + AppAssets.tenant_id == tenant_id, + AppAssets.app_id == app_id, + AppAssets.version == version, + ) + .first() + ) + return assets or AppAssets(tenant_id=tenant_id, app_id=app_id, version=version) + + @staticmethod + def get_draft_assets(tenant_id: str, app_id: str) -> list[AssetItem]: + with Session(db.engine) as session: + assets = ( + session.query(AppAssets) + .filter( + AppAssets.tenant_id == tenant_id, + AppAssets.app_id == app_id, + AppAssets.version == AppAssets.VERSION_DRAFT, + ) + .first() + ) + if not assets: + return [] + return AppAssetService.get_draft_asset_items(assets.tenant_id, assets.app_id, assets.asset_tree) + + @staticmethod + def get_draft_asset_items(tenant_id: str, app_id: str, file_tree: AppAssetFileTree) -> list[AssetItem]: + files = file_tree.walk_files() + return [ + AssetItem( + asset_id=f.id, + path=file_tree.get_path(f.id), + file_name=f.name, + extension=f.extension, + storage_key=AssetPaths.draft(tenant_id, app_id, f.id), + ) + for f in files + ] + + @staticmethod + def get_or_create_assets(session: Session, app_model: App, account_id: str) -> AppAssets: + assets = ( + session.query(AppAssets) + .filter( + AppAssets.tenant_id == app_model.tenant_id, + AppAssets.app_id == app_model.id, + AppAssets.version == AppAssets.VERSION_DRAFT, + ) + .first() + ) + if not assets: + assets = AppAssets( + id=str(uuid4()), + tenant_id=app_model.tenant_id, + app_id=app_model.id, + version=AppAssets.VERSION_DRAFT, + created_by=account_id, + ) + session.add(assets) + session.commit() + return assets + + @staticmethod + def get_tenant_app_assets(tenant_id: str, assets_id: str) -> AppAssets: + with Session(db.engine, expire_on_commit=False) as session: + app_assets = ( + session.query(AppAssets) + .filter( + AppAssets.tenant_id == tenant_id, + AppAssets.id == assets_id, + ) + .first() + ) + if not app_assets: + raise ValueError(f"App assets not found for tenant_id={tenant_id}, assets_id={assets_id}") + + return app_assets + + @staticmethod + def get_assets(tenant_id: str, app_id: str, user_id: str, *, is_draft: bool) -> AppAssets | None: + with Session(db.engine, expire_on_commit=False) as session: + if is_draft: + stmt = session.query(AppAssets).filter( + AppAssets.tenant_id == tenant_id, + AppAssets.app_id == app_id, + AppAssets.version == AppAssets.VERSION_DRAFT, + ) + if not stmt.first(): + assets = AppAssets( + id=str(uuid4()), + tenant_id=tenant_id, + app_id=app_id, + version=AppAssets.VERSION_DRAFT, + created_by=user_id, + ) + session.add(assets) + session.commit() + else: + stmt = ( + session.query(AppAssets) + .filter( + AppAssets.tenant_id == tenant_id, + AppAssets.app_id == app_id, + AppAssets.version != AppAssets.VERSION_DRAFT, + ) + .order_by(AppAssets.created_at.desc()) + ) + return stmt.first() + + @staticmethod + def get_asset_tree(app_model: App, account_id: str) -> AppAssetFileTree: + with Session(db.engine) as session: + assets = 
AppAssetService.get_or_create_assets(session, app_model, account_id) + return assets.asset_tree + + @staticmethod + def create_folder( + app_model: App, + account_id: str, + name: str, + parent_id: str | None = None, + ) -> AppAssetNode: + with AppAssetService._lock(app_model.id): + with Session(db.engine, expire_on_commit=False) as session: + assets = AppAssetService.get_or_create_assets(session, app_model, account_id) + tree = assets.asset_tree + + unique_name = tree.ensure_unique_name( + parent_id, + name, + is_file=False, + ) + node = AppAssetNode.create_folder(str(uuid4()), unique_name, parent_id) + + try: + tree.add(node) + except TreeParentNotFoundError as e: + raise AppAssetParentNotFoundError(str(e)) from e + except TreePathConflictError as e: + raise AppAssetPathConflictError(str(e)) from e + + assets.asset_tree = tree + assets.updated_by = account_id + session.commit() + + return node + + @staticmethod + def get_accessor(tenant_id: str, app_id: str) -> CachedContentAccessor: + """Get a content accessor with DB caching for the given app.""" + return CachedContentAccessor(AppAssetService.get_storage(), tenant_id, app_id) + + # Default TTL for presigned download URLs generated by to_download_items(). + _DOWNLOAD_URL_TTL_SECONDS = 600 + + @staticmethod + def to_download_items( + items: list[AssetItem], + *, + path_prefix: str = "", + ) -> list[SandboxDownloadItem]: + """Convert asset items to unified download items. + + Items with *content* become inline ``SandboxDownloadItem`` instances + (no presigned URL needed). Items without *content* get presigned + download URLs from storage. + + *path_prefix*, when set, is prepended to every item path + (e.g. ``"my-app"`` → ``"my-app/skills/foo.md"``). + """ + from core.zip_sandbox import SandboxDownloadItem + + inline: list[SandboxDownloadItem] = [] + remote_items: list[tuple[AssetItem, str]] = [] # (item, path) + + for item in items: + path = f"{path_prefix}/{item.path}" if path_prefix else item.path + if item.content is not None: + inline.append(SandboxDownloadItem(path=path, content=item.content)) + else: + remote_items.append((item, path)) + + result = list(inline) + if remote_items: + asset_storage = AppAssetService.get_storage() + keys = [a.storage_key for a, _ in remote_items] + urls = asset_storage.get_download_urls(keys, AppAssetService._DOWNLOAD_URL_TTL_SECONDS) + for (_, path), url in zip(remote_items, urls, strict=True): + result.append(SandboxDownloadItem(path=path, url=url)) + + return result + + @staticmethod + def get_file_content(app_model: App, account_id: str, node_id: str) -> bytes: + with Session(db.engine) as session: + assets = AppAssetService.get_or_create_assets(session, app_model, account_id) + tree = assets.asset_tree + + node = tree.get(node_id) + if not node or node.node_type != AssetNodeType.FILE: + raise AppAssetNodeNotFoundError(f"File node {node_id} not found") + + if node.size > AppAssetService.MAX_PREVIEW_CONTENT_SIZE: + max_size_mb = AppAssetService.MAX_PREVIEW_CONTENT_SIZE / 1024 / 1024 + raise AppAssetNodeTooLargeError(f"File node {node_id} size exceeded the limit: {max_size_mb} MB") + + accessor = AppAssetService.get_accessor(app_model.tenant_id, app_model.id) + return accessor.load(node) + + @staticmethod + def update_file_content( + app_model: App, + account_id: str, + node_id: str, + content: bytes, + ) -> AppAssetNode: + with AppAssetService._lock(app_model.id): + with Session(db.engine, expire_on_commit=False) as session: + assets = AppAssetService.get_or_create_assets(session, app_model, 
account_id) + tree = assets.asset_tree + + try: + node = tree.update(node_id, len(content)) + except TreeNodeNotFoundError as e: + raise AppAssetNodeNotFoundError(str(e)) from e + + accessor = AppAssetService.get_accessor(app_model.tenant_id, app_model.id) + accessor.save(node, content) + + assets.asset_tree = tree + assets.updated_by = account_id + session.commit() + + return node + + @staticmethod + def rename_node( + app_model: App, + account_id: str, + node_id: str, + new_name: str, + ) -> AppAssetNode: + with AppAssetService._lock(app_model.id): + with Session(db.engine, expire_on_commit=False) as session: + assets = AppAssetService.get_or_create_assets(session, app_model, account_id) + tree = assets.asset_tree + try: + node = tree.rename(node_id, new_name) + except TreeNodeNotFoundError as e: + raise AppAssetNodeNotFoundError(str(e)) from e + except TreePathConflictError as e: + raise AppAssetPathConflictError(str(e)) from e + + assets.asset_tree = tree + assets.updated_by = account_id + session.commit() + return node + + @staticmethod + def move_node( + app_model: App, + account_id: str, + node_id: str, + new_parent_id: str | None, + ) -> AppAssetNode: + with AppAssetService._lock(app_model.id): + with Session(db.engine, expire_on_commit=False) as session: + assets = AppAssetService.get_or_create_assets(session, app_model, account_id) + tree = assets.asset_tree + + try: + node = tree.move(node_id, new_parent_id) + except TreeNodeNotFoundError as e: + raise AppAssetNodeNotFoundError(str(e)) from e + except TreeParentNotFoundError as e: + raise AppAssetParentNotFoundError(str(e)) from e + except TreePathConflictError as e: + raise AppAssetPathConflictError(str(e)) from e + + assets.asset_tree = tree + assets.updated_by = account_id + session.commit() + + return node + + @staticmethod + def reorder_node( + app_model: App, + account_id: str, + node_id: str, + after_node_id: str | None, + ) -> AppAssetNode: + with AppAssetService._lock(app_model.id): + with Session(db.engine, expire_on_commit=False) as session: + assets = AppAssetService.get_or_create_assets(session, app_model, account_id=account_id) + tree = assets.asset_tree + + try: + node = tree.reorder(node_id, after_node_id) + except TreeNodeNotFoundError as e: + raise AppAssetNodeNotFoundError(str(e)) from e + + assets.asset_tree = tree + assets.updated_by = account_id + session.commit() + + return node + + @staticmethod + def delete_node(app_model: App, account_id: str, node_id: str) -> None: + with AppAssetService._lock(app_model.id): + with Session(db.engine) as session: + assets = AppAssetService.get_or_create_assets(session, app_model, account_id) + tree = assets.asset_tree + + try: + removed_ids = tree.remove(node_id) + except TreeNodeNotFoundError as e: + raise AppAssetNodeNotFoundError(str(e)) from e + assets.asset_tree = tree + assets.updated_by = account_id + session.commit() + + # Delete from both DB cache and S3 in background; failures are non-fatal. 
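+        # The tree mutation has already been committed, so a failed blob delete
+        # at worst leaves an orphaned object in storage (logged below).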
+ def _delete_files(tenant_id: str, app_id: str, node_ids: list[str]) -> None: + AssetContentService.delete_many(tenant_id, app_id, node_ids) + asset_storage = AppAssetService.get_storage() + for nid in node_ids: + key = AssetPaths.draft(tenant_id, app_id, nid) + try: + asset_storage.delete(key) + except Exception: + logger.warning("Failed to delete storage file %s", key, exc_info=True) + + threading.Thread(target=lambda: _delete_files(app_model.tenant_id, app_model.id, removed_ids)).start() + + @staticmethod + def get_file_download_url( + app_model: App, + account_id: str, + node_id: str, + expires_in: int = 3600, + ) -> str: + with Session(db.engine) as session: + assets = AppAssetService.get_or_create_assets(session, app_model, account_id) + node = assets.asset_tree.get(node_id) + if not node or node.node_type != AssetNodeType.FILE: + raise AppAssetNodeNotFoundError(f"File node {node_id} not found") + + asset_storage = AppAssetService.get_storage() + key = AssetPaths.draft(app_model.tenant_id, app_model.id, node_id) + return asset_storage.get_download_url(key, expires_in, download_filename=node.name) + + @staticmethod + def get_source_zip_bytes(tenant_id: str, app_id: str, workflow_id: str) -> bytes | None: + asset_storage = AppAssetService.get_storage() + key = AssetPaths.source_zip(tenant_id, app_id, workflow_id) + try: + return asset_storage.load_once(key) + except FileNotFoundError: + logger.warning("Source zip not found: %s", key) + return None + + @staticmethod + def set_draft_assets( + app_model: App, + account_id: str, + new_tree: AppAssetFileTree, + ) -> AppAssets: + with AppAssetService._lock(app_model.id): + with Session(db.engine, expire_on_commit=False) as session: + assets = AppAssetService.get_or_create_assets(session, app_model, account_id) + assets.asset_tree = new_tree + assets.updated_by = account_id + session.commit() + + return assets + + @staticmethod + def get_file_upload_url( + app_model: App, + account_id: str, + name: str, + size: int, + parent_id: str | None = None, + expires_in: int = 3600, + ) -> tuple[AppAssetNode, str]: + """ + Create a file node with metadata and return a pre-signed upload URL. + + The file metadata is saved immediately. If the user doesn't upload, + the download will fail when the file is accessed. + + If a sibling with the same name exists, a numeric suffix is appended + to make the name unique (e.g. "report 1.txt"). 
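+
+        Example:
+            node, url = AppAssetService.get_file_upload_url(app, account_id, "report.txt", size=1024)
+            # the caller uploads the file bytes to `url` before `expires_in` elapses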
+
+        Returns:
+            tuple of (node, upload_url)
+        """
+        with AppAssetService._lock(app_model.id):
+            with Session(db.engine, expire_on_commit=False) as session:
+                assets = AppAssetService.get_or_create_assets(session, app_model, account_id)
+                tree = assets.asset_tree
+
+                unique_name = tree.ensure_unique_name(
+                    parent_id,
+                    name,
+                    is_file=True,
+                )
+                node_id = str(uuid4())
+                node = AppAssetNode.create_file(node_id, unique_name, parent_id, size)
+
+                try:
+                    tree.add(node)
+                except TreeParentNotFoundError as e:
+                    raise AppAssetParentNotFoundError(str(e)) from e
+                except TreePathConflictError as e:
+                    raise AppAssetPathConflictError(str(e)) from e
+
+                assets.asset_tree = tree
+                assets.updated_by = account_id
+                session.commit()
+
+        key = AssetPaths.draft(app_model.tenant_id, app_model.id, node_id)
+        asset_storage = AppAssetService.get_storage()
+
+        # Save empty content up front so the storage object always exists. If the
+        # client never uploads via the presigned URL, later reads would otherwise
+        # fail with file-not-found, leaving a node without a backing blob.
+        asset_storage.save(key, b"")
+
+        upload_url = asset_storage.get_upload_url(key, expires_in)
+
+        return node, upload_url
+
+    @staticmethod
+    def batch_create_from_tree(
+        app_model: App,
+        account_id: str,
+        input_children: list[BatchUploadNode],
+        parent_id: str | None = None,
+        expires_in: int = 3600,
+    ) -> list[BatchUploadNode]:
+        """
+        Create a nested batch-upload tree under one parent in a single tree mutation.
+
+        The full metadata tree is added to the draft asset tree before the method
+        returns any upload URLs. That preserves sibling name de-duplication and
+        keeps nested uploads atomic for both root and subfolder targets.
+        """
+        if not input_children:
+            return []
+
+        with AppAssetService._lock(app_model.id):
+            with Session(db.engine, expire_on_commit=False) as session:
+                assets = AppAssetService.get_or_create_assets(session, app_model, account_id)
+                tree = assets.asset_tree
+
+                taken_by_parent: dict[str | None, set[str]] = {}
+                stack: list[tuple[BatchUploadNode, str | None]] = [
+                    (child, parent_id) for child in reversed(input_children)
+                ]
+                while stack:
+                    node, current_parent_id = stack.pop()
+                    if node.id is None:
+                        node.id = str(uuid4())
+                    if current_parent_id not in taken_by_parent:
+                        taken_by_parent[current_parent_id] = {
+                            child.name for child in tree.get_children(current_parent_id)
+                        }
+                    taken = taken_by_parent[current_parent_id]
+                    unique_name = tree.ensure_unique_name(
+                        current_parent_id,
+                        node.name,
+                        is_file=node.node_type == AssetNodeType.FILE,
+                        extra_taken=taken,
+                    )
+                    node.name = unique_name
+                    taken.add(unique_name)
+                    if node.node_type == AssetNodeType.FOLDER and node.children:
+                        for child in reversed(node.children):
+                            stack.append((child, node.id))
+
+                new_nodes: list[AppAssetNode] = []
+                for child in input_children:
+                    new_nodes.extend(child.to_app_asset_nodes(parent_id))
+
+                try:
+                    for node in new_nodes:
+                        tree.add(node)
+                except TreeParentNotFoundError as e:
+                    raise AppAssetParentNotFoundError(str(e)) from e
+                except TreePathConflictError as e:
+                    raise AppAssetPathConflictError(str(e)) from e
+
+                assets.asset_tree = tree
+                assets.updated_by = account_id
+                session.commit()
+
+        asset_storage = AppAssetService.get_storage()
+
+        def fill_urls(node: BatchUploadNode) -> None:
+            if node.node_type == AssetNodeType.FILE and node.id:
+                key = AssetPaths.draft(app_model.tenant_id, app_model.id, node.id)
+                node.upload_url = asset_storage.get_upload_url(key, expires_in)
+            for child in node.children:
+                fill_urls(child)
+
+        for child in input_children:
+            fill_urls(child)
+
+        
return input_children diff --git a/api/services/app_bundle_service.py b/api/services/app_bundle_service.py new file mode 100644 index 0000000000..bdc57c308b --- /dev/null +++ b/api/services/app_bundle_service.py @@ -0,0 +1,288 @@ +"""Service for exporting and importing App Bundles (DSL + assets). + +Bundle structure: + bundle.zip/ + {app_name}.yml # DSL file + manifest.json # Asset manifest (required for import) + {app_name}/ # Asset files + folder/file.txt + ... + +Import flow (sandbox-based): + 1. prepare_import: Frontend gets upload URL, stores import_id in Redis + 2. Frontend uploads zip to storage + 3. confirm_import: Sandbox downloads zip, extracts, uploads assets via presigned URLs + +Manifest format (schema_version 1.0): + - app_assets.tree: Full AppAssetFileTree for 100% ID restoration + - files: node_id -> path mapping for file nodes + - integrity.file_count: Basic validation +""" + +from __future__ import annotations + +import json +import logging +import re +from dataclasses import dataclass +from uuid import uuid4 + +from pydantic import ValidationError +from sqlalchemy.orm import Session + +from core.app.entities.app_bundle_entities import ( + MANIFEST_FILENAME, + BundleExportResult, + BundleFormatError, + BundleManifest, +) +from core.app_assets.storage import AssetPaths +from core.zip_sandbox import SandboxUploadItem, ZipSandbox +from extensions.ext_database import db +from extensions.ext_redis import redis_client +from extensions.storage.cached_presign_storage import CachedPresignStorage +from models.account import Account +from models.model import App + +from .app_asset_package_service import AppAssetPackageService +from .app_asset_service import AppAssetService +from .app_dsl_service import AppDslService, Import + +logger = logging.getLogger(__name__) + +_IMPORT_REDIS_PREFIX = "app_bundle:import:" +_IMPORT_TTL_SECONDS = 3600 # 1 hour + + +@dataclass +class ImportPrepareResult: + import_id: str + upload_url: str + + +class AppBundleService: + @staticmethod + def publish( + session: Session, + app_model: App, + account: Account, + marked_name: str = "", + marked_comment: str = "", + ): + """Publish App Bundle (workflow + assets) in a single transaction.""" + from models.workflow import Workflow + from services.workflow_service import WorkflowService + + workflow: Workflow = WorkflowService().publish_workflow( + session=session, + app_model=app_model, + account=account, + marked_name=marked_name, + marked_comment=marked_comment, + ) + AppAssetPackageService.publish( + session=session, + app_model=app_model, + account_id=account.id, + workflow_id=workflow.id, + ) + return workflow + + # ========== Export ========== + + @staticmethod + def export_bundle( + *, + app_model: App, + account_id: str, + include_secret: bool = False, + workflow_id: str | None = None, + expires_in: int = 10 * 60, + ) -> BundleExportResult: + """Export bundle with manifest.json and return a temporary download URL.""" + tenant_id = app_model.tenant_id + app_id = app_model.id + safe_name = AppBundleService._sanitize_filename(app_model.name) + + dsl_filename = f"{safe_name}.yml" + app_assets = AppAssetService.get_assets_by_version(tenant_id, app_id, workflow_id) + manifest = BundleManifest.from_tree(app_assets.asset_tree, dsl_filename) + + export_id = uuid4().hex + export_key = AssetPaths.bundle_export(tenant_id, app_id, export_id) + asset_storage = AppAssetService.get_storage() + upload_url = asset_storage.get_upload_url(export_key, expires_in) + + dsl_content = AppDslService.export_dsl( + 
app_model=app_model, + include_secret=include_secret, + workflow_id=workflow_id, + ) + + with ZipSandbox(tenant_id=tenant_id, user_id=account_id, app_id="app-bundle-export") as zs: + zs.write_file(f"bundle_root/{safe_name}.yml", dsl_content.encode("utf-8")) + zs.write_file(f"bundle_root/{MANIFEST_FILENAME}", manifest.model_dump_json(indent=2).encode("utf-8")) + + if workflow_id is not None: + source_key = AssetPaths.source_zip(tenant_id, app_id, workflow_id) + source_url = asset_storage.get_download_url(source_key, expires_in) + zs.download_archive(source_url, path="tmp/source_assets.zip") + zs.unzip(archive_path="tmp/source_assets.zip", dest_dir=f"bundle_root/{safe_name}") + else: + asset_items = AppAssetService.get_draft_assets(tenant_id, app_id) + if asset_items: + accessor = AppAssetService.get_accessor(tenant_id, app_id) + resolved = accessor.resolve_items(asset_items) + download_items = AppAssetService.to_download_items(resolved, path_prefix=safe_name) + zs.download_items(download_items, dest_dir="bundle_root") + + archive = zs.zip(src="bundle_root", include_base=False) + zs.upload(archive, upload_url) + + bundle_filename = f"{safe_name}.zip" + download_url = asset_storage.get_download_url(export_key, expires_in, download_filename=bundle_filename) + return BundleExportResult(download_url=download_url, filename=bundle_filename) + + # ========== Import ========== + + @staticmethod + def prepare_import(tenant_id: str, account_id: str) -> ImportPrepareResult: + """Prepare import: generate import_id and upload URL.""" + import_id = uuid4().hex + import_key = AssetPaths.bundle_import(tenant_id, import_id) + asset_storage = AppAssetService.get_storage() + upload_url = asset_storage.get_upload_url(import_key, _IMPORT_TTL_SECONDS) + + redis_client.setex( + f"{_IMPORT_REDIS_PREFIX}{import_id}", + _IMPORT_TTL_SECONDS, + json.dumps({"tenant_id": tenant_id, "account_id": account_id}), + ) + + return ImportPrepareResult(import_id=import_id, upload_url=upload_url) + + @staticmethod + def confirm_import( + import_id: str, + account: Account, + *, + name: str | None = None, + description: str | None = None, + icon_type: str | None = None, + icon: str | None = None, + icon_background: str | None = None, + ) -> Import: + """Confirm import: download zip in sandbox, extract, and upload assets.""" + redis_key = f"{_IMPORT_REDIS_PREFIX}{import_id}" + redis_data = redis_client.get(redis_key) + if not redis_data: + raise BundleFormatError("Import session expired or not found") + + import_meta = json.loads(redis_data) + tenant_id: str = import_meta["tenant_id"] + + if tenant_id != account.current_tenant_id: + raise BundleFormatError("Import session tenant mismatch") + + import_key = AssetPaths.bundle_import(tenant_id, import_id) + asset_storage = AppAssetService.get_storage() + + try: + result = AppBundleService.import_bundle( + tenant_id=tenant_id, + account=account, + import_key=import_key, + asset_storage=asset_storage, + name=name, + description=description, + icon_type=icon_type, + icon=icon, + icon_background=icon_background, + ) + finally: + redis_client.delete(redis_key) + try: + asset_storage.delete(import_key) + except Exception: # noqa: S110 + pass + + return result + + @staticmethod + def import_bundle( + *, + tenant_id: str, + account: Account, + import_key: str, + asset_storage: CachedPresignStorage, + name: str | None, + description: str | None, + icon_type: str | None, + icon: str | None, + icon_background: str | None, + ) -> Import: + """Execute import in sandbox.""" + download_url = 
asset_storage.get_download_url(import_key, _IMPORT_TTL_SECONDS) + + with ZipSandbox(tenant_id=tenant_id, user_id=account.id, app_id="app-bundle-import") as zs: + zs.download_archive(download_url, path="import.zip") + zs.unzip(archive_path="import.zip", dest_dir="bundle") + + manifest_bytes = zs.read_file(f"bundle/{MANIFEST_FILENAME}") + try: + manifest = BundleManifest.model_validate_json(manifest_bytes) + except ValidationError as e: + raise BundleFormatError(f"Invalid manifest.json: {e}") from e + + dsl_content = zs.read_file(f"bundle/{manifest.dsl_filename}").decode("utf-8") + + with Session(db.engine) as session: + dsl_service = AppDslService(session) + import_result = dsl_service.import_app( + account=account, + import_mode="yaml-content", + yaml_content=dsl_content, + name=name, + description=description, + icon_type=icon_type, + icon=icon, + icon_background=icon_background, + app_id=None, + ) + session.commit() + + if not import_result.app_id: + return import_result + + app_id = import_result.app_id + tree = manifest.app_assets.tree + + upload_items: list[SandboxUploadItem] = [] + for file_entry in manifest.files: + key = AssetPaths.draft(tenant_id, app_id, file_entry.node_id) + file_upload_url = asset_storage.get_upload_url(key, _IMPORT_TTL_SECONDS) + src_path = f"{manifest.assets_prefix}/{file_entry.path}" + upload_items.append(SandboxUploadItem(path=src_path, url=file_upload_url)) + + if upload_items: + zs.upload_items(upload_items, src_dir="bundle") + + # Tree sizes are already set from manifest; no need to update + app_model = db.session.query(App).where(App.id == app_id).first() + if app_model: + AppAssetService.set_draft_assets( + app_model=app_model, + account_id=account.id, + new_tree=tree, + ) + + return import_result + + # ========== Helpers ========== + + @staticmethod + def _sanitize_filename(name: str) -> str: + """Sanitize app name for use as filename.""" + safe = re.sub(r'[<>:"/\\|?*\x00-\x1f]', "_", name) + safe = safe.strip(". ") + return safe[:100] if safe else "app" diff --git a/api/services/sandbox/__init__.py b/api/services/sandbox/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/services/sandbox/sandbox_file_service.py b/api/services/sandbox/sandbox_file_service.py new file mode 100644 index 0000000000..2e388df9e3 --- /dev/null +++ b/api/services/sandbox/sandbox_file_service.py @@ -0,0 +1,49 @@ +from __future__ import annotations + +from core.sandbox.entities.files import SandboxFileDownloadTicket, SandboxFileNode +from core.sandbox.inspector import SandboxFileBrowser +from extensions.ext_storage import storage +from extensions.storage.cached_presign_storage import CachedPresignStorage +from extensions.storage.file_presign_storage import FilePresignStorage + + +class SandboxFileService: + @staticmethod + def get_storage() -> CachedPresignStorage: + """Get a lazily-initialized storage instance for sandbox files. + + Returns a CachedPresignStorage wrapping FilePresignStorage, + providing presign fallback and URL caching. 
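+
+        The "sandbox_files" cache prefix keeps these cached URLs separate from
+        the "app_assets" entries used by AppAssetService.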
+ """ + return CachedPresignStorage( + storage=FilePresignStorage(storage.storage_runner), + cache_key_prefix="sandbox_files", + ) + + @classmethod + def exists(cls, *, tenant_id: str, app_id: str, sandbox_id: str) -> bool: + """Check if the sandbox source exists and is available.""" + browser = SandboxFileBrowser(tenant_id=tenant_id, app_id=app_id, sandbox_id=sandbox_id) + return browser.exists() + + @classmethod + def list_files( + cls, + *, + tenant_id: str, + app_id: str, + sandbox_id: str, + path: str | None = None, + recursive: bool = False, + ) -> list[SandboxFileNode]: + browser = SandboxFileBrowser(tenant_id=tenant_id, app_id=app_id, sandbox_id=sandbox_id) + if not browser.exists(): + return [] + return browser.list_files(path=path, recursive=recursive) + + @classmethod + def download_file(cls, *, tenant_id: str, app_id: str, sandbox_id: str, path: str) -> SandboxFileDownloadTicket: + browser = SandboxFileBrowser(tenant_id=tenant_id, app_id=app_id, sandbox_id=sandbox_id) + if not browser.exists(): + raise ValueError("Sandbox source not found") + return browser.download_file(path=path) diff --git a/api/services/sandbox/sandbox_provider_service.py b/api/services/sandbox/sandbox_provider_service.py new file mode 100644 index 0000000000..ac970c4d02 --- /dev/null +++ b/api/services/sandbox/sandbox_provider_service.py @@ -0,0 +1,233 @@ +import json +import logging +from collections.abc import Mapping +from typing import Any + +from sqlalchemy.orm import Session + +from constants import HIDDEN_VALUE +from core.sandbox import ( + SandboxBuilder, + SandboxProviderApiEntity, + SandboxType, + VMConfig, + create_sandbox_config_encrypter, + masked_config, +) +from core.sandbox.entities.providers import SandboxProviderEntity +from core.tools.utils.system_encryption import decrypt_system_params +from extensions.ext_database import db +from models.sandbox import SandboxProvider, SandboxProviderSystemConfig + +logger = logging.getLogger(__name__) + + +def _get_encrypter(tenant_id: str, provider_type: str): + return create_sandbox_config_encrypter(tenant_id, VMConfig.get_schema(SandboxType(provider_type)), provider_type)[0] + + +def _query_tenant_config(session: Session, tenant_id: str, provider_type: str) -> SandboxProvider | None: + return ( + session.query(SandboxProvider) + .filter(SandboxProvider.tenant_id == tenant_id, SandboxProvider.provider_type == provider_type) + .first() + ) + + +class SandboxProviderService: + @classmethod + def list_providers(cls, tenant_id: str) -> list[SandboxProviderApiEntity]: + with Session(db.engine, expire_on_commit=False) as session: + provider_types = SandboxType.get_all() + tenant_configs = { + config.provider_type: config + for config in session.query(SandboxProvider).where(SandboxProvider.tenant_id == tenant_id).all() + } + system_configs = { + config.provider_type: config + for config in session.query(SandboxProviderSystemConfig) + .where(SandboxProviderSystemConfig.provider_type.in_(provider_types)) + .all() + } + + providers: list[SandboxProviderApiEntity] = [] + current_provider = cls.get_active_sandbox_config(session, tenant_id) + for provider_type in SandboxType.get_all(): + tenant_config = tenant_configs.get(provider_type) + schema = VMConfig.get_schema(SandboxType(provider_type)) + if tenant_config: + is_tenant_configured = tenant_config.configure_type == "user" + if is_tenant_configured: + decrypted_config = _get_encrypter(tenant_id, provider_type).decrypt(data=tenant_config.config) + config = masked_config(schemas=schema, config=decrypted_config) 
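+                        # The returned config is masked; save_config recognizes
+                        # HIDDEN_VALUE placeholders on round-trip and restores the
+                        # previously stored secrets.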
+ else: + config = {} + providers.append( + SandboxProviderApiEntity( + provider_type=provider_type, + is_system_configured=system_configs.get(provider_type) is not None, + is_tenant_configured=is_tenant_configured, + is_active=current_provider.id == tenant_config.id, + config=config, + config_schema=[c.model_dump() for c in schema], + ) + ) + else: + system_config = system_configs.get(provider_type) + providers.append( + SandboxProviderApiEntity( + provider_type=provider_type, + is_active=system_config is not None and system_config.id == current_provider.id, + is_system_configured=system_config is not None, + config_schema=[c.model_dump() for c in schema], + ) + ) + return providers + + @classmethod + def validate_config(cls, provider_type: str, config: Mapping[str, Any]) -> None: + SandboxBuilder.validate(SandboxType(provider_type), config) + + @classmethod + def save_config( + cls, tenant_id: str, provider_type: str, config: Mapping[str, Any], activate: bool + ) -> dict[str, Any]: + if provider_type not in SandboxType.get_all(): + raise ValueError(f"Invalid provider type: {provider_type}") + + with Session(db.engine) as session: + provider = _query_tenant_config(session, tenant_id, provider_type) + encrypter, cache = create_sandbox_config_encrypter( + tenant_id, VMConfig.get_schema(SandboxType(provider_type)), provider_type + ) + if not provider: + provider = SandboxProvider( + tenant_id=tenant_id, + provider_type=provider_type, + encrypted_config=json.dumps({}), + ) + session.add(provider) + + new_config = dict(config) + old_config = encrypter.decrypt(provider.config) + for key, value in new_config.items(): + if value == HIDDEN_VALUE: + new_config[key] = old_config.get(key, "") + + cls.validate_config(provider_type, new_config) + + provider.encrypted_config = json.dumps(encrypter.encrypt(new_config)) + provider.is_active = activate or provider.is_active or cls.is_system_default_config(session, tenant_id) + provider.configure_type = "user" + session.commit() + + cache.delete() + return {"result": "success"} + + @classmethod + def delete_config(cls, tenant_id: str, provider_type: str) -> dict[str, Any]: + with Session(db.engine) as session: + if config := _query_tenant_config(session, tenant_id, provider_type): + session.delete(config) + session.commit() + return {"result": "success"} + + @classmethod + def is_system_default_config(cls, session: Session, tenant_id: str) -> bool: + system_configed: SandboxProviderSystemConfig | None = session.query(SandboxProviderSystemConfig).first() + if not system_configed: + return False + active_config = cls.get_active_sandbox_config(session, tenant_id) + return active_config.id == system_configed.id + + @classmethod + def activate_provider(cls, tenant_id: str, provider_type: str, type: str | None = None) -> dict[str, Any]: + if provider_type not in SandboxType.get_all(): + raise ValueError(f"Invalid provider type: {provider_type}") + + with Session(db.engine) as session: + tenant_config = _query_tenant_config(session, tenant_id, provider_type) + system_config = session.query(SandboxProviderSystemConfig).filter_by(provider_type=provider_type).first() + + session.query(SandboxProvider).where(SandboxProvider.tenant_id == tenant_id).update({"is_active": False}) + + # using tenant config + if tenant_config: + tenant_config.is_active = True + tenant_config.configure_type = type or tenant_config.configure_type + session.commit() + return {"result": "success"} + + # using system config + if system_config: + session.add( + SandboxProvider( + is_active=True, 
+ tenant_id=tenant_id, + configure_type="system", + provider_type=provider_type, + encrypted_config=json.dumps({}), + ) + ) + session.commit() + return {"result": "success"} + + raise ValueError(f"No sandbox provider configured for tenant {tenant_id} and provider type {provider_type}") + + @classmethod + def get_active_sandbox_config(cls, session: Session, tenant_id: str) -> SandboxProviderEntity: + tenant_configed = ( + session.query(SandboxProvider) + .filter(SandboxProvider.tenant_id == tenant_id, SandboxProvider.is_active.is_(True)) + .first() + ) + if tenant_configed: + if tenant_configed.configure_type == "user": + config = _get_encrypter(tenant_id, tenant_configed.provider_type).decrypt(tenant_configed.config) + return SandboxProviderEntity( + id=tenant_configed.id, provider_type=tenant_configed.provider_type, config=config + ) + else: + system_configed: SandboxProviderSystemConfig | None = ( + session.query(SandboxProviderSystemConfig) + .filter_by(provider_type=tenant_configed.provider_type) + .first() + ) + if not system_configed: + raise ValueError( + f"No system default provider configured for provider type {tenant_configed.provider_type}" + ) + return SandboxProviderEntity( + id=tenant_configed.id, + provider_type=system_configed.provider_type, + config=decrypt_system_params(system_configed.encrypted_config), + ) + + # fallback to system default config + system_configed = session.query(SandboxProviderSystemConfig).first() + if system_configed: + return SandboxProviderEntity( + id=system_configed.id, + provider_type=system_configed.provider_type, + config=decrypt_system_params(system_configed.encrypted_config), + ) + + raise ValueError(f"No sandbox provider configured for tenant {tenant_id}") + + @classmethod + def get_system_default_config(cls, session: Session, tenant_id: str, provider_type: str) -> SandboxProviderEntity: + system_configed: SandboxProviderSystemConfig | None = ( + session.query(SandboxProviderSystemConfig).filter_by(provider_type=provider_type).first() + ) + if system_configed: + return SandboxProviderEntity( + id=system_configed.id, + provider_type=system_configed.provider_type, + config=decrypt_system_params(system_configed.encrypted_config), + ) + raise ValueError(f"No system default provider configured for tenant {tenant_id}") + + @classmethod + def get_sandbox_provider(cls, tenant_id: str) -> SandboxProviderEntity: + with Session(db.engine, expire_on_commit=False) as session: + return cls.get_active_sandbox_config(session, tenant_id) diff --git a/api/services/sandbox/sandbox_service.py b/api/services/sandbox/sandbox_service.py new file mode 100644 index 0000000000..93a3a8065a --- /dev/null +++ b/api/services/sandbox/sandbox_service.py @@ -0,0 +1,139 @@ +"""Service for creating and managing sandbox instances. + +Three creation paths: + +- ``create()`` — published runtime. Downloads the pre-built ZIP via + ``AppAssetsInitializer`` and loads the ``SkillBundle`` via + ``SkillInitializer``. + +- ``create_draft()`` / ``create_for_single_step()`` — draft runtime. + ``DraftAppAssetsInitializer`` runs the build pipeline on the fly, + compiles ``.md`` skills (saving the ``SkillBundle`` to Redis/S3 as a + side-effect), and pushes resolved content as inline base64 into the + sandbox. ``SkillInitializer`` then loads the bundle from Redis/S3. + No separate ``build_assets()`` call is needed. 
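+
+Illustrative wiring (``provider`` resolved via ``SandboxProviderService``):
+
+    provider = SandboxProviderService.get_sandbox_provider(tenant_id)
+    sandbox = SandboxService.create(tenant_id, app_id, user_id, sandbox_id, provider)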
+""" + +import logging + +from core.sandbox.builder import SandboxBuilder +from core.sandbox.entities import AppAssets, SandboxType +from core.sandbox.entities.providers import SandboxProviderEntity +from core.sandbox.initializer.app_asset_attrs_initializer import AppAssetAttrsInitializer +from core.sandbox.initializer.app_assets_initializer import AppAssetsInitializer +from core.sandbox.initializer.dify_cli_initializer import DifyCliInitializer +from core.sandbox.initializer.draft_app_assets_initializer import DraftAppAssetsDownloader, DraftAppAssetsInitializer +from core.sandbox.initializer.skill_initializer import SkillInitializer +from core.sandbox.sandbox import Sandbox +from core.sandbox.storage.archive_storage import ArchiveSandboxStorage +from extensions.ext_storage import storage +from services.app_asset_service import AppAssetService + +logger = logging.getLogger(__name__) + + +class SandboxService: + @classmethod + def create( + cls, + tenant_id: str, + app_id: str, + user_id: str, + sandbox_id: str, + sandbox_provider: SandboxProviderEntity, + ) -> Sandbox: + assets = AppAssetService.get_assets(tenant_id, app_id, user_id, is_draft=False) + if not assets: + raise ValueError(f"No assets found for tid={tenant_id}, app_id={app_id}") + + archive_storage = ArchiveSandboxStorage(tenant_id, app_id, sandbox_id, storage.storage_runner) + sandbox = ( + SandboxBuilder(tenant_id, SandboxType(sandbox_provider.provider_type)) + .options(sandbox_provider.config) + .user(user_id) + .app(app_id) + .initializer(AppAssetAttrsInitializer()) + .initializer(AppAssetsInitializer()) + .initializer(SkillInitializer()) + .initializer(DifyCliInitializer()) + .storage(archive_storage, assets.id) + .build() + ) + + logger.info("Sandbox created: id=%s, assets=%s", sandbox.id, sandbox.assets_id) + return sandbox + + @classmethod + def delete_draft_storage(cls, tenant_id: str, app_id: str, user_id: str) -> None: + archive_storage = ArchiveSandboxStorage( + tenant_id, app_id, SandboxBuilder.draft_id(user_id), storage.storage_runner + ) + archive_storage.delete() + + @classmethod + def create_draft( + cls, + tenant_id: str, + app_id: str, + user_id: str, + sandbox_provider: SandboxProviderEntity, + ) -> Sandbox: + assets = AppAssetService.get_assets(tenant_id, app_id, user_id, is_draft=True) + if not assets: + raise ValueError(f"No assets found for tid={tenant_id}, app_id={app_id}") + + sandbox_id = SandboxBuilder.draft_id(user_id) + archive_storage = ArchiveSandboxStorage( + tenant_id, app_id, sandbox_id, storage.storage_runner, exclude_patterns=[AppAssets.PATH] + ) + + sandbox = ( + SandboxBuilder(tenant_id, SandboxType(sandbox_provider.provider_type)) + .options(sandbox_provider.config) + .user(user_id) + .app(app_id) + .initializer(AppAssetAttrsInitializer()) + .initializer(DraftAppAssetsInitializer()) + .initializer(DraftAppAssetsDownloader()) + .initializer(SkillInitializer()) + .initializer(DifyCliInitializer()) + .storage(archive_storage, assets.id) + .build() + ) + + logger.info("Draft sandbox created: id=%s, assets=%s", sandbox.id, sandbox.assets_id) + return sandbox + + @classmethod + def create_for_single_step( + cls, + tenant_id: str, + app_id: str, + user_id: str, + sandbox_provider: SandboxProviderEntity, + ) -> Sandbox: + assets = AppAssetService.get_assets(tenant_id, app_id, user_id, is_draft=True) + if not assets: + raise ValueError(f"No assets found for tid={tenant_id}, app_id={app_id}") + + sandbox_id = SandboxBuilder.draft_id(user_id) + archive_storage = ArchiveSandboxStorage( + tenant_id, 
app_id, sandbox_id, storage.storage_runner, exclude_patterns=[AppAssets.PATH] + ) + + sandbox = ( + SandboxBuilder(tenant_id, SandboxType(sandbox_provider.provider_type)) + .options(sandbox_provider.config) + .user(user_id) + .app(app_id) + .initializer(AppAssetAttrsInitializer()) + .initializer(DraftAppAssetsInitializer()) + .initializer(DraftAppAssetsDownloader()) + .initializer(SkillInitializer()) + .initializer(DifyCliInitializer()) + .storage(archive_storage, assets.id) + .build() + ) + + logger.info("Single-step sandbox created: id=%s, assets=%s", sandbox.id, sandbox.assets_id) + return sandbox