From 956436b9439803b63e25c9fbe8992e57f1f16d88 Mon Sep 17 00:00:00 2001 From: Harry Date: Mon, 19 Jan 2026 18:15:24 +0800 Subject: [PATCH] feat(sandbox): skill initialize & draft run --- .../app/apps/advanced_chat/app_generator.py | 1 + api/core/app/apps/workflow/app_generator.py | 1 + api/core/app/layers/sandbox_layer.py | 49 +++- api/core/app_assets/entities/skill.py | 1 - api/core/app_assets/parser/asset_parser.py | 17 +- api/core/app_assets/parser/base.py | 2 - api/core/app_assets/parser/skill_parser.py | 114 ++++++---- api/core/app_assets/paths.py | 12 +- api/core/sandbox/__init__.py | 8 +- api/core/sandbox/bash/bash_tool.py | 6 +- api/core/sandbox/constants.py | 7 +- .../initializer/app_assets_initializer.py | 43 +--- .../initializer/dify_cli_initializer.py | 146 ++++++++++-- api/core/sandbox/session.py | 81 ++++--- api/core/skill/skill_manager.py | 17 +- api/core/workflow/nodes/llm/node.py | 42 ++-- api/services/app_asset_service.py | 54 ++++- .../core/app/layers/test_sandbox_layer.py | 210 ++++++++++++------ 18 files changed, 531 insertions(+), 280 deletions(-) diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index bb1b02656b..79d50699e5 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -523,6 +523,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): SandboxLayer( tenant_id=application_generate_entity.app_config.tenant_id, app_id=application_generate_entity.app_config.app_id, + workflow_version=workflow.version, sandbox_id=application_generate_entity.workflow_run_id, sandbox_storage=ArchiveSandboxStorage( tenant_id=application_generate_entity.app_config.tenant_id, diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index b7f359bbd1..e6cdf3a4d2 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -497,6 +497,7 @@ class WorkflowAppGenerator(BaseAppGenerator): SandboxLayer( tenant_id=application_generate_entity.app_config.tenant_id, app_id=application_generate_entity.app_config.app_id, + workflow_version=workflow.version, sandbox_id=application_generate_entity.workflow_execution_id, sandbox_storage=ArchiveSandboxStorage( tenant_id=application_generate_entity.app_config.tenant_id, diff --git a/api/core/app/layers/sandbox_layer.py b/api/core/app/layers/sandbox_layer.py index c1b6071eb5..5156f8c896 100644 --- a/api/core/app/layers/sandbox_layer.py +++ b/api/core/app/layers/sandbox_layer.py @@ -1,11 +1,14 @@ import logging -from core.sandbox import SandboxManager +from core.sandbox import AppAssetsInitializer, DifyCliInitializer, SandboxManager from core.sandbox.storage.sandbox_storage import SandboxStorage from core.virtual_environment.__base.virtual_environment import VirtualEnvironment from core.workflow.graph_engine.layers.base import GraphEngineLayer from core.workflow.graph_events.base import GraphEngineEvent from core.workflow.graph_events.graph import GraphRunPausedEvent +from models.workflow import Workflow +from services.app_asset_service import AppAssetService +from services.sandbox.sandbox_provider_service import SandboxProviderService logger = logging.getLogger(__name__) @@ -15,10 +18,18 @@ class SandboxInitializationError(Exception): class SandboxLayer(GraphEngineLayer): - def __init__(self, tenant_id: str, app_id: str, sandbox_id: str, sandbox_storage: SandboxStorage) -> None: + def __init__( + self, + tenant_id: str, + app_id: str, + workflow_version: str, + sandbox_id: str, + sandbox_storage: SandboxStorage, + ) -> None: super().__init__() self._tenant_id = tenant_id self._app_id = app_id + self._workflow_version = workflow_version self._sandbox_id = sandbox_id self._sandbox_storage = sandbox_storage @@ -31,16 +42,34 @@ class SandboxLayer(GraphEngineLayer): def on_graph_start(self) -> None: try: - # Initialize sandbox - from core.sandbox import AppAssetsInitializer, DifyCliInitializer - from services.sandbox.sandbox_provider_service import SandboxProviderService + is_draft = self._workflow_version == Workflow.VERSION_DRAFT + assets = AppAssetService.get_assets(self._tenant_id, self._app_id, is_draft=is_draft) + if not assets: + raise ValueError( + f"No assets found for tid={self._tenant_id}, app_id={self._app_id}, wf={self._workflow_version}" + ) + if is_draft: + logger.info( + "Building draft assets for tenant_id=%s, app_id=%s, workflow_version=%s, assets_id=%s", + self._tenant_id, + self._app_id, + self._workflow_version, + assets.id, + ) + AppAssetService.build_assets(self._tenant_id, self._app_id, assets) - logger.info("Initializing sandbox for tenant_id=%s, app_id=%s", self._tenant_id, self._app_id) + logger.info( + "Initializing sandbox for tenant_id=%s, app_id=%s, workflow_version=%s, assets_id=%s", + self._tenant_id, + self._app_id, + self._workflow_version, + assets.id, + ) builder = ( SandboxProviderService.create_sandbox_builder(self._tenant_id) - .initializer(DifyCliInitializer()) - .initializer(AppAssetsInitializer(self._tenant_id, self._app_id)) + .initializer(AppAssetsInitializer(self._tenant_id, self._app_id, assets.id)) + .initializer(DifyCliInitializer(self._tenant_id, self._app_id, assets.id)) ) sandbox = builder.build() @@ -65,10 +94,6 @@ class SandboxLayer(GraphEngineLayer): return def on_graph_end(self, error: Exception | None) -> None: - if self._sandbox_id is None: - logger.debug("No workflow_execution_id set, nothing to release") - return - sandbox = SandboxManager.unregister(self._sandbox_id) if sandbox is None: logger.debug("No sandbox to release for sandbox_id=%s", self._sandbox_id) diff --git a/api/core/app_assets/entities/skill.py b/api/core/app_assets/entities/skill.py index 3658d06537..3c5691f62b 100644 --- a/api/core/app_assets/entities/skill.py +++ b/api/core/app_assets/entities/skill.py @@ -61,7 +61,6 @@ class SkillMetadata(BaseModel): class SkillAsset(AssetItem): storage_key: str metadata: SkillMetadata - content: str tool_references: list[ToolReference] = field(default_factory=list) file_references: list[FileReference] = field(default_factory=list) diff --git a/api/core/app_assets/parser/asset_parser.py b/api/core/app_assets/parser/asset_parser.py index fc14ce5dfd..3a934520cc 100644 --- a/api/core/app_assets/parser/asset_parser.py +++ b/api/core/app_assets/parser/asset_parser.py @@ -1,34 +1,20 @@ -from typing import TYPE_CHECKING - from core.app.entities.app_asset_entities import AppAssetFileTree from core.app_assets.entities import AssetItem from core.app_assets.paths import AssetPaths from .base import AssetItemParser, FileAssetParser -if TYPE_CHECKING: - from extensions.ext_storage import Storage - class AssetParser: - _tree: AppAssetFileTree - _tenant_id: str - _app_id: str - _storage: "Storage" - _parsers: dict[str, AssetItemParser] - _default_parser: AssetItemParser - def __init__( self, tree: AppAssetFileTree, tenant_id: str, app_id: str, - storage: "Storage", ) -> None: self._tree = tree self._tenant_id = tenant_id self._app_id = app_id - self._storage = storage self._parsers = {} self._default_parser = FileAssetParser() @@ -41,11 +27,10 @@ class AssetParser: for node in self._tree.walk_files(): path = self._tree.get_path(node.id).lstrip("/") storage_key = AssetPaths.draft_file(self._tenant_id, self._app_id, node.id) - raw_bytes = self._storage.load_once(storage_key) extension = node.extension or "" parser = self._parsers.get(extension, self._default_parser) - asset = parser.parse(node.id, path, node.name, extension, storage_key, raw_bytes) + asset = parser.parse(node.id, path, node.name, extension, storage_key) assets.append(asset) return assets diff --git a/api/core/app_assets/parser/base.py b/api/core/app_assets/parser/base.py index 4ced242868..01696c5c7e 100644 --- a/api/core/app_assets/parser/base.py +++ b/api/core/app_assets/parser/base.py @@ -12,7 +12,6 @@ class AssetItemParser(ABC): file_name: str, extension: str, storage_key: str, - raw_bytes: bytes, ) -> AssetItem: raise NotImplementedError @@ -25,7 +24,6 @@ class FileAssetParser(AssetItemParser): file_name: str, extension: str, storage_key: str, - raw_bytes: bytes, ) -> FileAsset: return FileAsset( node_id=node_id, diff --git a/api/core/app_assets/parser/skill_parser.py b/api/core/app_assets/parser/skill_parser.py index 9fe783c213..bd67e3ba54 100644 --- a/api/core/app_assets/parser/skill_parser.py +++ b/api/core/app_assets/parser/skill_parser.py @@ -1,6 +1,7 @@ import json +import logging import re -from typing import TYPE_CHECKING, Any +from typing import Any from core.app_assets.entities import ( FileReference, @@ -9,36 +10,26 @@ from core.app_assets.entities import ( ToolReference, ) from core.app_assets.paths import AssetPaths +from extensions.ext_storage import storage from .base import AssetItemParser -if TYPE_CHECKING: - from extensions.ext_storage import Storage - TOOL_REFERENCE_PATTERN = re.compile(r"§\[tool\]\.\[([^\]]+)\]\.\[([^\]]+)\]\.\[([^\]]+)\]§") FILE_REFERENCE_PATTERN = re.compile(r"§\[file\]\.\[([^\]]+)\]\.\[([^\]]+)\]§") +logger = logging.getLogger(__name__) + class SkillAssetParser(AssetItemParser): - _tenant_id: str - _app_id: str - _publish_id: str - _storage: "Storage" - def __init__( self, tenant_id: str, app_id: str, - publish_id: str, - storage: "Storage", + assets_id: str, ) -> None: self._tenant_id = tenant_id self._app_id = app_id - self._publish_id = publish_id - self._storage = storage - - def _get_resolved_key(self, node_id: str) -> str: - return AssetPaths.published_resolved_file(self._tenant_id, self._app_id, self._publish_id, node_id) + self._assets_id = assets_id def parse( self, @@ -47,12 +38,40 @@ class SkillAssetParser(AssetItemParser): file_name: str, extension: str, storage_key: str, - raw_bytes: bytes, ) -> SkillAsset: try: - data = json.loads(raw_bytes.decode("utf-8")) - except (json.JSONDecodeError, UnicodeDecodeError) as e: - raise ValueError(f"Invalid skill document JSON for {node_id}: {e}") from e + return self._parse_skill_asset(node_id, path, file_name, extension, storage_key) + except Exception: + logger.exception("Failed to parse skill asset %s: %s", node_id) + # handle as plain text + return SkillAsset( + node_id=node_id, + path=path, + file_name=file_name, + extension=extension, + storage_key=storage_key, + metadata=SkillMetadata(), + tool_references=[], + file_references=[], + ) + + def _parse_skill_asset( + self, node_id: str, path: str, file_name: str, extension: str, storage_key: str + ) -> SkillAsset: + try: + data = json.loads(storage.load_once(storage_key)) + except (json.JSONDecodeError, UnicodeDecodeError): + # handle as plain text + return SkillAsset( + node_id=node_id, + path=path, + file_name=file_name, + extension=extension, + storage_key=storage_key, + metadata=SkillMetadata(), + tool_references=[], + file_references=[], + ) if not isinstance(data, dict): raise ValueError(f"Skill document {node_id} must be a JSON object") @@ -66,30 +85,12 @@ class SkillAssetParser(AssetItemParser): metadata = SkillMetadata.model_validate(metadata_raw) - tool_references: list[ToolReference] = [] - for match in TOOL_REFERENCE_PATTERN.finditer(content): - tool_references.append( - ToolReference( - provider=match.group(1), - tool_name=match.group(2), - uuid=match.group(3), - raw=match.group(0), - ) - ) - - file_references: list[FileReference] = [] - for match in FILE_REFERENCE_PATTERN.finditer(content): - file_references.append( - FileReference( - source=match.group(1), - uuid=match.group(2), - raw=match.group(0), - ) - ) + tool_references: list[ToolReference] = self._parse_tool_references(content) + file_references: list[FileReference] = self._parse_file_references(content) resolved_content = self._resolve_content(content, tool_references, file_references) - resolved_key = self._get_resolved_key(node_id) - self._storage.save(resolved_key, resolved_content.encode("utf-8")) + resolved_key = AssetPaths.build_resolved_file(self._tenant_id, self._app_id, self._assets_id, node_id) + storage.save(resolved_key, resolved_content.encode("utf-8")) return SkillAsset( node_id=node_id, @@ -98,7 +99,6 @@ class SkillAssetParser(AssetItemParser): extension=extension, storage_key=resolved_key, metadata=metadata, - content=resolved_content, tool_references=tool_references, file_references=file_references, ) @@ -110,7 +110,7 @@ class SkillAssetParser(AssetItemParser): file_references: list[FileReference], ) -> str: for ref in tool_references: - replacement = f"{ref.provider}/{ref.tool_name}" + replacement = f"{ref.tool_name}" content = content.replace(ref.raw, replacement) for ref in file_references: @@ -118,3 +118,29 @@ class SkillAssetParser(AssetItemParser): content = content.replace(ref.raw, replacement) return content + + def _parse_tool_references(self, content: str) -> list[ToolReference]: + tool_references: list[ToolReference] = [] + for match in TOOL_REFERENCE_PATTERN.finditer(content): + tool_references.append( + ToolReference( + provider=match.group(1), + tool_name=match.group(2), + uuid=match.group(3), + raw=match.group(0), + ) + ) + + return tool_references + + def _parse_file_references(self, content: str) -> list[FileReference]: + file_references: list[FileReference] = [] + for match in FILE_REFERENCE_PATTERN.finditer(content): + file_references.append( + FileReference( + source=match.group(1), + uuid=match.group(2), + raw=match.group(0), + ) + ) + return file_references diff --git a/api/core/app_assets/paths.py b/api/core/app_assets/paths.py index 00644bea7c..736b567ab8 100644 --- a/api/core/app_assets/paths.py +++ b/api/core/app_assets/paths.py @@ -6,13 +6,13 @@ class AssetPaths: return f"{AssetPaths._BASE}/{tenant_id}/{app_id}/draft/{node_id}" @staticmethod - def published_zip(tenant_id: str, app_id: str, publish_id: str) -> str: - return f"{AssetPaths._BASE}/{tenant_id}/{app_id}/published/{publish_id}.zip" + def build_zip(tenant_id: str, app_id: str, assets_id: str) -> str: + return f"{AssetPaths._BASE}/{tenant_id}/{app_id}/build/{assets_id}.zip" @staticmethod - def published_resolved_file(tenant_id: str, app_id: str, publish_id: str, node_id: str) -> str: - return f"{AssetPaths._BASE}/{tenant_id}/{app_id}/published/{publish_id}/resolved/{node_id}" + def build_resolved_file(tenant_id: str, app_id: str, assets_id: str, node_id: str) -> str: + return f"{AssetPaths._BASE}/{tenant_id}/{app_id}/build/{assets_id}/resolved/{node_id}" @staticmethod - def published_tool_manifest(tenant_id: str, app_id: str, publish_id: str) -> str: - return f"{AssetPaths._BASE}/{tenant_id}/{app_id}/published/{publish_id}/tools.json" + def build_tool_manifest(tenant_id: str, app_id: str, assets_id: str) -> str: + return f"{AssetPaths._BASE}/{tenant_id}/{app_id}/build/{assets_id}/tools.json" diff --git a/api/core/sandbox/__init__.py b/api/core/sandbox/__init__.py index c7fcd8c899..1bf69d9d75 100644 --- a/api/core/sandbox/__init__.py +++ b/api/core/sandbox/__init__.py @@ -8,10 +8,12 @@ from .bash.dify_cli import ( from .constants import ( APP_ASSETS_PATH, APP_ASSETS_ZIP_PATH, - DIFY_CLI_CONFIG_PATH, + DIFY_CLI_CONFIG_FILENAME, + DIFY_CLI_GLOBAL_TOOLS_PATH, DIFY_CLI_PATH, DIFY_CLI_PATH_PATTERN, DIFY_CLI_ROOT, + DIFY_CLI_TOOLS_ROOT, ) from .initializer import AppAssetsInitializer, DifyCliInitializer, SandboxInitializer from .manager import SandboxManager @@ -24,10 +26,12 @@ from .vm import SandboxBuilder, SandboxType, VMConfig __all__ = [ "APP_ASSETS_PATH", "APP_ASSETS_ZIP_PATH", - "DIFY_CLI_CONFIG_PATH", + "DIFY_CLI_CONFIG_FILENAME", + "DIFY_CLI_GLOBAL_TOOLS_PATH", "DIFY_CLI_PATH", "DIFY_CLI_PATH_PATTERN", "DIFY_CLI_ROOT", + "DIFY_CLI_TOOLS_ROOT", "AppAssetsInitializer", "ArchiveSandboxStorage", "DifyCliBinary", diff --git a/api/core/sandbox/bash/bash_tool.py b/api/core/sandbox/bash/bash_tool.py index cb86c2f3f8..632974e4e4 100644 --- a/api/core/sandbox/bash/bash_tool.py +++ b/api/core/sandbox/bash/bash_tool.py @@ -21,8 +21,9 @@ COMMAND_TIMEOUT_SECONDS = 60 class SandboxBashTool(Tool): - def __init__(self, sandbox: VirtualEnvironment, tenant_id: str): + def __init__(self, sandbox: VirtualEnvironment, tenant_id: str, tools_path: str) -> None: self._sandbox = sandbox + self._tools_path = tools_path entity = ToolEntity( identity=ToolIdentity( @@ -71,9 +72,10 @@ class SandboxBashTool(Tool): try: with with_connection(self._sandbox) as conn: cmd_list = ["bash", "-c", command] + env_vars = {"PATH": f"{self._tools_path}:/usr/local/bin:/usr/bin:/bin"} sandbox_debug("bash_tool", "cmd_list", cmd_list) - future = submit_command(self._sandbox, conn, cmd_list) + future = submit_command(self._sandbox, conn, cmd_list, environments=env_vars) timeout = COMMAND_TIMEOUT_SECONDS if COMMAND_TIMEOUT_SECONDS > 0 else None result = future.result(timeout=timeout) diff --git a/api/core/sandbox/constants.py b/api/core/sandbox/constants.py index 35227227f5..2ef922c7b0 100644 --- a/api/core/sandbox/constants.py +++ b/api/core/sandbox/constants.py @@ -6,8 +6,11 @@ DIFY_CLI_PATH: Final[str] = "/tmp/.dify/bin/dify" DIFY_CLI_PATH_PATTERN: Final[str] = "dify-cli-{os}-{arch}" -DIFY_CLI_CONFIG_PATH: Final[str] = "/tmp/.dify/.dify_cli.json" +DIFY_CLI_CONFIG_FILENAME: Final[str] = ".dify_cli.json" + +DIFY_CLI_TOOLS_ROOT: Final[str] = "/tmp/.dify/tools" +DIFY_CLI_GLOBAL_TOOLS_PATH: Final[str] = "/tmp/.dify/tools/global" # App Assets (relative path - stays in sandbox workdir) APP_ASSETS_PATH: Final[str] = "assets" -APP_ASSETS_ZIP_PATH: Final[str] = "/tmp/.dify/tmp/assets.zip" +APP_ASSETS_ZIP_PATH: Final[str] = "/tmp/assets.zip" diff --git a/api/core/sandbox/initializer/app_assets_initializer.py b/api/core/sandbox/initializer/app_assets_initializer.py index 4e0fd24307..c7b54b85c5 100644 --- a/api/core/sandbox/initializer/app_assets_initializer.py +++ b/api/core/sandbox/initializer/app_assets_initializer.py @@ -1,33 +1,25 @@ import logging from io import BytesIO -from sqlalchemy.orm import Session - from core.app_assets.paths import AssetPaths from core.virtual_environment.__base.helpers import execute, with_connection from core.virtual_environment.__base.virtual_environment import VirtualEnvironment -from extensions.ext_database import db from extensions.ext_storage import storage -from models.app_asset import AppAssets -from ..constants import APP_ASSETS_PATH, APP_ASSETS_ZIP_PATH, DIFY_CLI_ROOT +from ..constants import APP_ASSETS_PATH, APP_ASSETS_ZIP_PATH from .base import SandboxInitializer logger = logging.getLogger(__name__) class AppAssetsInitializer(SandboxInitializer): - def __init__(self, tenant_id: str, app_id: str) -> None: + def __init__(self, tenant_id: str, app_id: str, assets_id: str) -> None: self._tenant_id = tenant_id self._app_id = app_id + self._assets_id = assets_id def initialize(self, env: VirtualEnvironment) -> None: - published = self._get_latest_published() - if not published: - logger.debug("No published assets for app_id=%s, skipping", self._app_id) - return - - zip_key = AssetPaths.published_zip(self._tenant_id, self._app_id, published.id) + zip_key = AssetPaths.build_zip(self._tenant_id, self._app_id, self._assets_id) try: zip_data = storage.load_once(zip_key) except Exception: @@ -42,18 +34,6 @@ class AppAssetsInitializer(SandboxInitializer): env.upload_file(APP_ASSETS_ZIP_PATH, BytesIO(zip_data)) with with_connection(env) as conn: - execute( - env, - ["mkdir", "-p", f"{DIFY_CLI_ROOT}/tmp"], - connection=conn, - error_message="Failed to create temp directory", - ) - execute( - env, - ["mkdir", "-p", APP_ASSETS_PATH], - connection=conn, - error_message="Failed to create assets directory", - ) execute( env, ["unzip", "-o", APP_ASSETS_ZIP_PATH, "-d", APP_ASSETS_PATH], @@ -71,18 +51,5 @@ class AppAssetsInitializer(SandboxInitializer): logger.info( "App assets initialized for app_id=%s, published_id=%s", self._app_id, - published.id, + self._assets_id, ) - - def _get_latest_published(self) -> AppAssets | None: - with Session(db.engine) as session: - return ( - session.query(AppAssets) - .filter( - AppAssets.tenant_id == self._tenant_id, - AppAssets.app_id == self._app_id, - AppAssets.version != AppAssets.VERSION_DRAFT, - ) - .order_by(AppAssets.created_at.desc()) - .first() - ) diff --git a/api/core/sandbox/initializer/dify_cli_initializer.py b/api/core/sandbox/initializer/dify_cli_initializer.py index 68c02c1c3b..311f87dd0c 100644 --- a/api/core/sandbox/initializer/dify_cli_initializer.py +++ b/api/core/sandbox/initializer/dify_cli_initializer.py @@ -1,37 +1,147 @@ +from __future__ import annotations + +import json import logging from io import BytesIO from pathlib import Path -from core.virtual_environment.__base.helpers import execute +from core.app.entities.app_invoke_entities import InvokeFrom +from core.app_assets.entities import ToolType +from core.session.cli_api import CliApiSessionManager +from core.skill.entities import ToolManifest +from core.skill.skill_manager import SkillManager +from core.tools.__base.tool import Tool +from core.tools.entities.tool_entities import ToolProviderType +from core.tools.tool_manager import ToolManager +from core.virtual_environment.__base.helpers import execute, with_connection from core.virtual_environment.__base.virtual_environment import VirtualEnvironment -from ..bash.dify_cli import DifyCliLocator -from ..constants import DIFY_CLI_PATH, DIFY_CLI_ROOT +from ..bash.dify_cli import DifyCliConfig, DifyCliLocator +from ..constants import ( + DIFY_CLI_CONFIG_FILENAME, + DIFY_CLI_GLOBAL_TOOLS_PATH, + DIFY_CLI_PATH, + DIFY_CLI_ROOT, +) from .base import SandboxInitializer logger = logging.getLogger(__name__) class DifyCliInitializer(SandboxInitializer): - def __init__(self, cli_root: str | Path | None = None) -> None: + def __init__( + self, + tenant_id: str, + app_id: str, + assets_id: str, + cli_root: str | Path | None = None, + ) -> None: + self._tenant_id = tenant_id + self._app_id = app_id + self._assets_id = assets_id self._locator = DifyCliLocator(root=cli_root) + self._tools = [] + self._cli_api_session = None + def initialize(self, env: VirtualEnvironment) -> None: binary = self._locator.resolve(env.metadata.os, env.metadata.arch) - execute( - env, - ["mkdir", "-p", f"{DIFY_CLI_ROOT}/bin"], - timeout=10, - error_message="Failed to create dify CLI directory", - ) + with with_connection(env) as conn: + execute( + env, + ["mkdir", "-p", f"{DIFY_CLI_ROOT}/bin"], + connection=conn, + timeout=10, + error_message="Failed to create dify CLI directory", + ) - env.upload_file(DIFY_CLI_PATH, BytesIO(binary.path.read_bytes())) + env.upload_file(DIFY_CLI_PATH, BytesIO(binary.path.read_bytes())) - execute( - env, - ["chmod", "+x", DIFY_CLI_PATH], - timeout=10, - error_message="Failed to mark dify CLI as executable", - ) - logger.info("Dify CLI uploaded to sandbox, path=%s", DIFY_CLI_PATH) + execute( + env, + ["chmod", "+x", DIFY_CLI_PATH], + connection=conn, + timeout=10, + error_message="Failed to mark dify CLI as executable", + ) + + logger.info("Dify CLI uploaded to sandbox, path=%s", DIFY_CLI_PATH) + + manifest = SkillManager.load_tool_manifest( + self._tenant_id, + self._app_id, + self._assets_id, + ) + + if manifest is None or not manifest.tools: + logger.info("No tools found in manifest for assets_id=%s", self._assets_id) + return + + self._tools = self._resolve_tools_from_manifest(manifest) + self._cli_api_session = CliApiSessionManager().create(tenant_id=self._tenant_id, user_id="system") + + execute( + env, + ["mkdir", "-p", DIFY_CLI_GLOBAL_TOOLS_PATH], + connection=conn, + timeout=10, + error_message="Failed to create global tools directory", + ) + + config_json = json.dumps( + DifyCliConfig.create(self._cli_api_session, self._tools).model_dump(mode="json"), ensure_ascii=False + ) + env.upload_file( + f"{DIFY_CLI_GLOBAL_TOOLS_PATH}/{DIFY_CLI_CONFIG_FILENAME}", BytesIO(config_json.encode("utf-8")) + ) + + execute( + env, + [DIFY_CLI_PATH, "init"], + connection=conn, + timeout=30, + cwd=DIFY_CLI_GLOBAL_TOOLS_PATH, + error_message="Failed to initialize Dify CLI", + ) + + logger.info( + "Global tools initialized, path=%s, tool_count=%d", + DIFY_CLI_GLOBAL_TOOLS_PATH, + len(self._tools), + ) + + def _resolve_tools_from_manifest(self, manifest: ToolManifest) -> list[Tool]: + tools: list[Tool] = [] + + for entry in manifest.tools.values(): + if entry.provider is None or entry.tool_name is None: + logger.warning("Skipping tool entry with missing provider or tool_name: %s", entry.uuid) + continue + + try: + provider_type = self._convert_tool_type(entry.type) + tool = ToolManager.get_tool_runtime( + tenant_id=self._tenant_id, + provider_type=provider_type, + provider_id=entry.provider, + tool_name=entry.tool_name, + invoke_from=InvokeFrom.AGENT, + credential_id=entry.credential_id, + ) + tools.append(tool) + except Exception as e: + logger.warning("Failed to resolve tool %s/%s: %s", entry.provider, entry.tool_name, e) + continue + + return tools + + @staticmethod + def _convert_tool_type(tool_type: ToolType) -> ToolProviderType: + match tool_type: + case ToolType.BUILTIN: + return ToolProviderType.BUILT_IN + case ToolType.MCP: + return ToolProviderType.MCP + case _: + raise ValueError(f"Unsupported tool type: {tool_type}") diff --git a/api/core/sandbox/session.py b/api/core/sandbox/session.py index bd6a416888..e7765332f2 100644 --- a/api/core/sandbox/session.py +++ b/api/core/sandbox/session.py @@ -1,79 +1,90 @@ from __future__ import annotations -import json import logging -from io import BytesIO from types import TracebackType from typing import TYPE_CHECKING from core.session.cli_api import CliApiSessionManager -from core.virtual_environment.__base.helpers import execute from core.virtual_environment.__base.virtual_environment import VirtualEnvironment -from .bash.dify_cli import DifyCliConfig -from .constants import DIFY_CLI_CONFIG_PATH, DIFY_CLI_PATH +from .constants import ( + DIFY_CLI_GLOBAL_TOOLS_PATH, +) from .manager import SandboxManager -from .utils.debug import sandbox_debug if TYPE_CHECKING: - from core.tools.__base.tool import Tool - from .bash.bash_tool import SandboxBashTool logger = logging.getLogger(__name__) class SandboxSession: + _workflow_execution_id: str + _tenant_id: str + _user_id: str + _node_id: str | None + _allow_tools: list[str] | None + + _sandbox: VirtualEnvironment | None + _bash_tool: SandboxBashTool | None + _session_id: str | None + _tools_path: str + def __init__( self, *, workflow_execution_id: str, tenant_id: str, user_id: str, - tools: list[Tool], + node_id: str | None = None, + allow_tools: list[str] | None = None, ) -> None: self._workflow_execution_id = workflow_execution_id self._tenant_id = tenant_id self._user_id = user_id - self._tools = tools + self._node_id = node_id + self._allow_tools = allow_tools - self._sandbox: VirtualEnvironment | None = None - self._bash_tool: SandboxBashTool | None = None - self._session_id: str | None = None + self._sandbox = None + self._bash_tool = None + self._session_id = None + self._tools_path = DIFY_CLI_GLOBAL_TOOLS_PATH def __enter__(self) -> SandboxSession: sandbox = SandboxManager.get(self._workflow_execution_id) if sandbox is None: raise RuntimeError(f"Sandbox not found for workflow_execution_id={self._workflow_execution_id}") - session = CliApiSessionManager().create(tenant_id=self._tenant_id, user_id=self._user_id) - self._session_id = session.id + self._sandbox = sandbox - try: - config = DifyCliConfig.create(session, self._tools) - config_json = json.dumps(config.model_dump(mode="json"), ensure_ascii=False) - - sandbox_debug("sandbox", "config_json", config_json) - sandbox.upload_file(DIFY_CLI_CONFIG_PATH, BytesIO(config_json.encode("utf-8"))) - - execute( - sandbox, - [DIFY_CLI_PATH, "init"], - timeout=30, - error_message="Failed to initialize Dify CLI in sandbox", - ) - - except Exception: - CliApiSessionManager().delete(session.id) - self._session_id = None - raise + if self._allow_tools is not None: + # TODO: Implement node tools directory setup + if self._node_id is None: + raise ValueError("node_id is required when allow_tools is specified") + # self._tools_path = self._setup_node_tools_directory(sandbox, self._node_id, self._allow_tools) + else: + self._tools_path = DIFY_CLI_GLOBAL_TOOLS_PATH from .bash.bash_tool import SandboxBashTool - self._sandbox = sandbox - self._bash_tool = SandboxBashTool(sandbox=sandbox, tenant_id=self._tenant_id) + self._bash_tool = SandboxBashTool(sandbox=sandbox, tenant_id=self._tenant_id, tools_path=self._tools_path) return self + def _setup_node_tools_directory( + self, + sandbox: VirtualEnvironment, + node_id: str, + allow_tools: list[str], + ) -> None: + pass + + @staticmethod + def _get_tool_name_from_config(tool_config: dict) -> str: + identity = tool_config.get("identity", {}) + provider = identity.get("provider", "") + name = identity.get("name", "") + return f"{provider}__{name}" + def __exit__( self, exc_type: type[BaseException] | None, diff --git a/api/core/skill/skill_manager.py b/api/core/skill/skill_manager.py index 487a37acb7..fea0e269d1 100644 --- a/api/core/skill/skill_manager.py +++ b/api/core/skill/skill_manager.py @@ -22,15 +22,28 @@ class SkillManager: def save_tool_manifest( tenant_id: str, app_id: str, - publish_id: str, + assets_id: str, manifest: ToolManifest, ) -> None: if not manifest.tools: return - key = AssetPaths.published_tool_manifest(tenant_id, app_id, publish_id) + key = AssetPaths.build_tool_manifest(tenant_id, app_id, assets_id) storage.save(key, manifest.model_dump_json(indent=2).encode("utf-8")) + @staticmethod + def load_tool_manifest( + tenant_id: str, + app_id: str, + assets_id: str, + ) -> ToolManifest | None: + key = AssetPaths.build_tool_manifest(tenant_id, app_id, assets_id) + try: + data = storage.load_once(key) + return ToolManifest.model_validate_json(data) + except Exception: + return None + @staticmethod def _collect_asset_manifest(asset: SkillAsset) -> ToolManifest: tools: dict[str, ToolManifestEntry] = {} diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index b7d79ee68a..8972b73833 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -13,7 +13,7 @@ from sqlalchemy import select from core.agent.entities import AgentEntity, AgentLog, AgentResult, AgentToolEntity, ExecutionContext from core.agent.patterns import StrategyFactory -from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity +from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.file import File, FileTransferMethod, FileType, file_manager from core.helper.code_executor import CodeExecutor, CodeLanguage from core.llm_generator.output_parser.errors import OutputParserError @@ -1580,34 +1580,17 @@ class LLMNode(Node[LLMNodeData]): result = yield from self._process_tool_outputs(outputs) return result - def _prepare_sandbox_tools(self) -> list[Tool]: - """Prepare sandbox tools.""" - tool_instances = [] + def _get_allow_tools_list(self) -> list[str] | None: + if not self._node_data.tools: + return None - for tool in self._node_data.tools or []: - try: - # Get tool runtime from ToolManager - tool_runtime = ToolManager.get_tool_runtime( - tenant_id=self.tenant_id, - tool_name=tool.tool_name, - provider_id=tool.provider_name, - provider_type=tool.type, - invoke_from=InvokeFrom.AGENT, - credential_id=tool.credential_id, - ) + allow_tools = [] + for tool in self._node_data.tools: + if tool.enabled: + tool_name = f"{tool.tool_name}" + allow_tools.append(tool_name) - # Apply custom description from extra field if available - if tool.extra.get("description") and tool_runtime.entity.description: - tool_runtime.entity.description.llm = ( - tool.extra.get("description") or tool_runtime.entity.description.llm - ) - - tool_instances.append(tool_runtime) - except Exception as e: - logger.warning("Failed to load tool %s: %s", tool, str(e)) - continue - - return tool_instances + return allow_tools or None def _invoke_llm_with_sandbox( self, @@ -1620,7 +1603,7 @@ class LLMNode(Node[LLMNodeData]): if not workflow_execution_id: raise LLMNodeError("workflow_execution_id is required for sandbox runtime mode") - configured_tools = self._prepare_sandbox_tools() + allow_tools = self._get_allow_tools_list() result: LLMGenerationData | None = None @@ -1628,7 +1611,8 @@ class LLMNode(Node[LLMNodeData]): workflow_execution_id=workflow_execution_id, tenant_id=self.tenant_id, user_id=self.user_id, - tools=configured_tools, + node_id=self.id, + allow_tools=allow_tools, ) as sandbox_session: prompt_files = self._extract_prompt_files(variable_pool) model_features = self._get_model_features(model_instance) diff --git a/api/services/app_asset_service.py b/api/services/app_asset_service.py index 3f012ca34e..7b5ee53bdb 100644 --- a/api/services/app_asset_service.py +++ b/api/services/app_asset_service.py @@ -61,6 +61,27 @@ class AppAssetService: session.commit() return assets + @staticmethod + def get_assets(tenant_id: str, app_id: str, *, is_draft: bool) -> AppAssets | None: + with Session(db.engine) as session: + if is_draft: + stmt = session.query(AppAssets).filter( + AppAssets.tenant_id == tenant_id, + AppAssets.app_id == app_id, + AppAssets.version == AppAssets.VERSION_DRAFT, + ) + else: + stmt = ( + session.query(AppAssets) + .filter( + AppAssets.tenant_id == tenant_id, + AppAssets.app_id == app_id, + AppAssets.version != AppAssets.VERSION_DRAFT, + ) + .order_by(AppAssets.created_at.desc()) + ) + return stmt.first() + @staticmethod def get_asset_tree(app_model: App, account_id: str) -> AppAssetFileTree: with Session(db.engine) as session: @@ -284,10 +305,10 @@ class AppAssetService: session.add(published) session.flush() - parser = AssetParser(tree, tenant_id, app_id, storage) + parser = AssetParser(tree, tenant_id, app_id) parser.register( "md", - SkillAssetParser(tenant_id, app_id, publish_id, storage), + SkillAssetParser(tenant_id, app_id, publish_id), ) assets = parser.parse() @@ -306,13 +327,40 @@ class AppAssetService: packager = ZipPackager(storage) zip_bytes = packager.package(assets) - zip_key = AssetPaths.published_zip(tenant_id, app_id, publish_id) + zip_key = AssetPaths.build_zip(tenant_id, app_id, publish_id) storage.save(zip_key, zip_bytes) session.commit() return published + @staticmethod + def build_assets(tenant_id: str, app_id: str, assets: AppAssets) -> None: + tree = assets.asset_tree + + parser = AssetParser(tree, tenant_id, app_id) + parser.register( + "md", + SkillAssetParser(tenant_id, app_id, assets.id), + ) + + parsed_assets = parser.parse() + manifest = SkillManager.generate_tool_manifest( + assets=[asset for asset in parsed_assets if isinstance(asset, SkillAsset)] + ) + + SkillManager.save_tool_manifest( + tenant_id, + app_id, + assets.id, + manifest, + ) + + packager = ZipPackager(storage) + zip_bytes = packager.package(parsed_assets) + zip_key = AssetPaths.build_zip(tenant_id, app_id, assets.id) + storage.save(zip_key, zip_bytes) + @staticmethod def get_file_download_url( app_model: App, diff --git a/api/tests/unit_tests/core/app/layers/test_sandbox_layer.py b/api/tests/unit_tests/core/app/layers/test_sandbox_layer.py index eab45cdb21..76fd68cfcb 100644 --- a/api/tests/unit_tests/core/app/layers/test_sandbox_layer.py +++ b/api/tests/unit_tests/core/app/layers/test_sandbox_layer.py @@ -1,9 +1,11 @@ +from typing import Any from unittest.mock import MagicMock, patch import pytest from core.app.layers.sandbox_layer import SandboxInitializationError, SandboxLayer from core.sandbox import SandboxManager +from core.sandbox.storage.sandbox_storage import SandboxStorage from core.virtual_environment.__base.entities import Arch from core.virtual_environment.__base.virtual_environment import VirtualEnvironment from core.workflow.graph_engine.layers.base import GraphEngineLayerNotInitializedError @@ -12,6 +14,7 @@ from core.workflow.graph_events.graph import ( GraphRunStartedEvent, GraphRunSucceededEvent, ) +from models.app_asset import AppAssets class MockMetadata: @@ -30,16 +33,18 @@ class MockVirtualEnvironment: class MockVMBuilder: - def __init__(self, sandbox: VirtualEnvironment): + _sandbox: VirtualEnvironment + + def __init__(self, sandbox: VirtualEnvironment) -> None: self._sandbox = sandbox - def environments(self, _): + def environments(self, _: object) -> "MockVMBuilder": return self - def initializer(self, _): + def initializer(self, _: object) -> "MockVMBuilder": return self - def build(self): + def build(self) -> VirtualEnvironment: return self._sandbox @@ -51,68 +56,107 @@ def clean_sandbox_manager(): @pytest.fixture -def mock_archive_storage(): - with patch("core.app.layers.sandbox_layer.ArchiveSandboxStorage") as mock_class: - mock_instance = MagicMock() - mock_instance.mount.return_value = False - mock_instance.unmount.return_value = True - mock_class.return_value = mock_instance - yield mock_instance +def mock_sandbox_storage() -> MagicMock: + mock_storage = MagicMock(spec=SandboxStorage) + mock_storage.mount.return_value = False + mock_storage.unmount.return_value = True + return mock_storage -def create_mock_builder(sandbox): +def create_mock_builder(sandbox: Any) -> MockVMBuilder: return MockVMBuilder(sandbox) +def create_layer( + tenant_id: str = "test-tenant", + app_id: str = "test-app", + workflow_version: str = AppAssets.VERSION_DRAFT, + sandbox_id: str = "test-sandbox", + sandbox_storage: Any = None, +) -> SandboxLayer: + if sandbox_storage is None: + sandbox_storage = MagicMock(spec=SandboxStorage) + sandbox_storage.mount.return_value = False + sandbox_storage.unmount.return_value = True + return SandboxLayer( + tenant_id=tenant_id, + app_id=app_id, + workflow_version=workflow_version, + sandbox_id=sandbox_id, + sandbox_storage=sandbox_storage, + ) + + class TestSandboxLayer: - def test_init_with_parameters(self): - layer = SandboxLayer(tenant_id="test-tenant", app_id="test-app", sandbox_id="test-sandbox") + def test_init_with_parameters(self, mock_sandbox_storage: MagicMock) -> None: + layer = create_layer( + tenant_id="test-tenant", + app_id="test-app", + sandbox_id="test-sandbox", + sandbox_storage=mock_sandbox_storage, + ) assert layer._tenant_id == "test-tenant" # pyright: ignore[reportPrivateUsage] assert layer._app_id == "test-app" # pyright: ignore[reportPrivateUsage] assert layer._sandbox_id == "test-sandbox" # pyright: ignore[reportPrivateUsage] - def test_sandbox_property_raises_when_not_initialized(self): - layer = SandboxLayer(tenant_id="test-tenant", app_id="test-app", sandbox_id="test-sandbox") + def test_sandbox_property_raises_when_not_initialized(self, mock_sandbox_storage: MagicMock) -> None: + layer = create_layer(sandbox_storage=mock_sandbox_storage) with pytest.raises(RuntimeError) as exc_info: _ = layer.sandbox assert "Sandbox not found" in str(exc_info.value) - def test_sandbox_property_returns_sandbox_after_initialization(self, mock_archive_storage): + def test_sandbox_property_returns_sandbox_after_initialization(self, mock_sandbox_storage: MagicMock) -> None: sandbox_id = "test-exec-id" - layer = SandboxLayer(tenant_id="test-tenant", app_id="test-app", sandbox_id=sandbox_id) + layer = create_layer(sandbox_id=sandbox_id, sandbox_storage=mock_sandbox_storage) mock_sandbox = MockVirtualEnvironment() - with patch( - "services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox_builder", - return_value=create_mock_builder(mock_sandbox), + with ( + patch( + "services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox_builder", + return_value=create_mock_builder(mock_sandbox), + ), + patch("services.app_asset_service.AppAssetService.get_assets", return_value=None), ): layer.on_graph_start() assert layer.sandbox is mock_sandbox - def test_on_graph_start_creates_sandbox_and_registers_with_manager(self, mock_archive_storage): + def test_on_graph_start_creates_sandbox_and_registers_with_manager(self, mock_sandbox_storage: MagicMock) -> None: sandbox_id = "test-exec-123" - layer = SandboxLayer(tenant_id="test-tenant-123", app_id="test-app-123", sandbox_id=sandbox_id) + layer = create_layer( + tenant_id="test-tenant-123", + app_id="test-app-123", + sandbox_id=sandbox_id, + sandbox_storage=mock_sandbox_storage, + ) mock_sandbox = MockVirtualEnvironment() - with patch( - "services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox_builder", - return_value=create_mock_builder(mock_sandbox), - ) as mock_create: + with ( + patch( + "services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox_builder", + return_value=create_mock_builder(mock_sandbox), + ) as mock_create, + patch("services.app_asset_service.AppAssetService.get_assets", return_value=None), + ): layer.on_graph_start() mock_create.assert_called_once_with("test-tenant-123") assert SandboxManager.get(sandbox_id) is mock_sandbox - def test_on_graph_start_raises_sandbox_initialization_error_on_failure(self): - layer = SandboxLayer(tenant_id="test-tenant", app_id="test-app", sandbox_id="test-sandbox") + def test_on_graph_start_raises_sandbox_initialization_error_on_failure( + self, mock_sandbox_storage: MagicMock + ) -> None: + layer = create_layer(sandbox_storage=mock_sandbox_storage) - with patch( - "services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox_builder", - side_effect=Exception("Sandbox provider not available"), + with ( + patch( + "services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox_builder", + side_effect=Exception("Sandbox provider not available"), + ), + patch("services.app_asset_service.AppAssetService.get_assets", return_value=None), ): with pytest.raises(SandboxInitializationError) as exc_info: layer.on_graph_start() @@ -120,22 +164,27 @@ class TestSandboxLayer: assert "Failed to initialize sandbox" in str(exc_info.value) assert "Sandbox provider not available" in str(exc_info.value) - def test_on_event_is_noop(self): - layer = SandboxLayer(tenant_id="test-tenant", app_id="test-app", sandbox_id="test-sandbox") + def test_on_event_is_noop(self, mock_sandbox_storage: MagicMock) -> None: + layer = create_layer(sandbox_storage=mock_sandbox_storage) layer.on_event(GraphRunStartedEvent()) layer.on_event(GraphRunSucceededEvent(outputs={})) layer.on_event(GraphRunFailedEvent(error="test error", exceptions_count=1)) - def test_on_graph_end_releases_sandbox_and_unregisters_from_manager(self, mock_archive_storage): + def test_on_graph_end_releases_sandbox_and_unregisters_from_manager( + self, mock_sandbox_storage: MagicMock + ) -> None: sandbox_id = "test-exec-456" - layer = SandboxLayer(tenant_id="test-tenant", app_id="test-app", sandbox_id=sandbox_id) + layer = create_layer(sandbox_id=sandbox_id, sandbox_storage=mock_sandbox_storage) mock_sandbox = MagicMock(spec=VirtualEnvironment) mock_sandbox.metadata = MockMetadata() - with patch( - "services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox_builder", - return_value=create_mock_builder(mock_sandbox), + with ( + patch( + "services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox_builder", + return_value=create_mock_builder(mock_sandbox), + ), + patch("services.app_asset_service.AppAssetService.get_assets", return_value=None), ): layer.on_graph_start() @@ -146,15 +195,18 @@ class TestSandboxLayer: mock_sandbox.release_environment.assert_called_once() assert not SandboxManager.has(sandbox_id) - def test_on_graph_end_releases_sandbox_even_on_error(self, mock_archive_storage): + def test_on_graph_end_releases_sandbox_even_on_error(self, mock_sandbox_storage: MagicMock) -> None: sandbox_id = "test-exec-789" - layer = SandboxLayer(tenant_id="test-tenant", app_id="test-app", sandbox_id=sandbox_id) + layer = create_layer(sandbox_id=sandbox_id, sandbox_storage=mock_sandbox_storage) mock_sandbox = MagicMock(spec=VirtualEnvironment) mock_sandbox.metadata = MockMetadata() - with patch( - "services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox_builder", - return_value=create_mock_builder(mock_sandbox), + with ( + patch( + "services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox_builder", + return_value=create_mock_builder(mock_sandbox), + ), + patch("services.app_asset_service.AppAssetService.get_assets", return_value=None), ): layer.on_graph_start() @@ -163,16 +215,19 @@ class TestSandboxLayer: mock_sandbox.release_environment.assert_called_once() assert not SandboxManager.has(sandbox_id) - def test_on_graph_end_handles_release_failure_gracefully(self, mock_archive_storage): + def test_on_graph_end_handles_release_failure_gracefully(self, mock_sandbox_storage: MagicMock) -> None: sandbox_id = "test-exec-fail" - layer = SandboxLayer(tenant_id="test-tenant", app_id="test-app", sandbox_id=sandbox_id) + layer = create_layer(sandbox_id=sandbox_id, sandbox_storage=mock_sandbox_storage) mock_sandbox = MagicMock(spec=VirtualEnvironment) mock_sandbox.metadata = MockMetadata() mock_sandbox.release_environment.side_effect = Exception("Container already removed") - with patch( - "services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox_builder", - return_value=create_mock_builder(mock_sandbox), + with ( + patch( + "services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox_builder", + return_value=create_mock_builder(mock_sandbox), + ), + patch("services.app_asset_service.AppAssetService.get_assets", return_value=None), ): layer.on_graph_start() @@ -180,20 +235,23 @@ class TestSandboxLayer: mock_sandbox.release_environment.assert_called_once() - def test_on_graph_end_noop_when_sandbox_not_registered(self): - layer = SandboxLayer(tenant_id="test-tenant", app_id="test-app", sandbox_id="nonexistent-sandbox") + def test_on_graph_end_noop_when_sandbox_not_registered(self, mock_sandbox_storage: MagicMock) -> None: + layer = create_layer(sandbox_id="nonexistent-sandbox", sandbox_storage=mock_sandbox_storage) layer.on_graph_end(error=None) - def test_on_graph_end_is_idempotent(self, mock_archive_storage): + def test_on_graph_end_is_idempotent(self, mock_sandbox_storage: MagicMock) -> None: sandbox_id = "test-exec-idempotent" - layer = SandboxLayer(tenant_id="test-tenant", app_id="test-app", sandbox_id=sandbox_id) + layer = create_layer(sandbox_id=sandbox_id, sandbox_storage=mock_sandbox_storage) mock_sandbox = MagicMock(spec=VirtualEnvironment) mock_sandbox.metadata = MockMetadata() - with patch( - "services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox_builder", - return_value=create_mock_builder(mock_sandbox), + with ( + patch( + "services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox_builder", + return_value=create_mock_builder(mock_sandbox), + ), + patch("services.app_asset_service.AppAssetService.get_assets", return_value=None), ): layer.on_graph_start() @@ -202,8 +260,8 @@ class TestSandboxLayer: mock_sandbox.release_environment.assert_called_once() - def test_layer_inherits_from_graph_engine_layer(self): - layer = SandboxLayer(tenant_id="test-tenant", app_id="test-app", sandbox_id="test-sandbox") + def test_layer_inherits_from_graph_engine_layer(self, mock_sandbox_storage: MagicMock) -> None: + layer = create_layer(sandbox_storage=mock_sandbox_storage) with pytest.raises(GraphEngineLayerNotInitializedError): _ = layer.graph_runtime_state @@ -212,15 +270,23 @@ class TestSandboxLayer: class TestSandboxLayerIntegration: - def test_full_lifecycle_with_mocked_provider(self, mock_archive_storage): + def test_full_lifecycle_with_mocked_provider(self, mock_sandbox_storage: MagicMock) -> None: sandbox_id = "integration-test-exec" - layer = SandboxLayer(tenant_id="integration-tenant", app_id="integration-app", sandbox_id=sandbox_id) + layer = create_layer( + tenant_id="integration-tenant", + app_id="integration-app", + sandbox_id=sandbox_id, + sandbox_storage=mock_sandbox_storage, + ) mock_sandbox = MagicMock(spec=VirtualEnvironment) mock_sandbox.metadata = MockMetadata(sandbox_id="integration-sandbox") - with patch( - "services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox_builder", - return_value=create_mock_builder(mock_sandbox), + with ( + patch( + "services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox_builder", + return_value=create_mock_builder(mock_sandbox), + ), + patch("services.app_asset_service.AppAssetService.get_assets", return_value=None), ): layer.on_graph_start() @@ -232,15 +298,23 @@ class TestSandboxLayerIntegration: assert not SandboxManager.has(sandbox_id) mock_sandbox.release_environment.assert_called_once() - def test_lifecycle_with_workflow_error(self, mock_archive_storage): + def test_lifecycle_with_workflow_error(self, mock_sandbox_storage: MagicMock) -> None: sandbox_id = "integration-error-test" - layer = SandboxLayer(tenant_id="error-tenant", app_id="error-app", sandbox_id=sandbox_id) + layer = create_layer( + tenant_id="error-tenant", + app_id="error-app", + sandbox_id=sandbox_id, + sandbox_storage=mock_sandbox_storage, + ) mock_sandbox = MagicMock(spec=VirtualEnvironment) mock_sandbox.metadata = MockMetadata() - with patch( - "services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox_builder", - return_value=create_mock_builder(mock_sandbox), + with ( + patch( + "services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox_builder", + return_value=create_mock_builder(mock_sandbox), + ), + patch("services.app_asset_service.AppAssetService.get_assets", return_value=None), ): layer.on_graph_start()