diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index 2534c553f2..bb1b02656b 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -30,6 +30,7 @@ from core.model_runtime.errors.invoke import InvokeAuthorizationError from core.ops.ops_trace_manager import TraceQueueManager from core.prompt.utils.get_thread_messages_length import get_thread_messages_length from core.repositories import DifyCoreRepositoryFactory +from core.sandbox.storage.archive_storage import ArchiveSandboxStorage from core.workflow.repositories.draft_variable_repository import ( DraftVariableSaverFactory, ) @@ -516,7 +517,19 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): graph_engine_layers: tuple = () if workflow.get_feature(WorkflowFeatures.SANDBOX).enabled: - graph_engine_layers = (SandboxLayer(tenant_id=application_generate_entity.app_config.tenant_id),) + if application_generate_entity.workflow_run_id is None: + raise ValueError("workflow_run_id is required when sandbox is enabled") + graph_engine_layers = ( + SandboxLayer( + tenant_id=application_generate_entity.app_config.tenant_id, + app_id=application_generate_entity.app_config.app_id, + sandbox_id=application_generate_entity.workflow_run_id, + sandbox_storage=ArchiveSandboxStorage( + tenant_id=application_generate_entity.app_config.tenant_id, + sandbox_id=application_generate_entity.workflow_run_id, + ), + ), + ) # Determine system_user_id based on invocation source is_external_api_call = application_generate_entity.invoke_from in { diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index 9da4e71ba3..c9d5ca46e8 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -28,6 +28,7 @@ from core.helper.trace_id_helper import extract_external_trace_id_from_args from core.model_runtime.errors.invoke import InvokeAuthorizationError from core.ops.ops_trace_manager import TraceQueueManager from core.repositories import DifyCoreRepositoryFactory +from core.sandbox.storage.archive_storage import ArchiveSandboxStorage from core.workflow.graph_engine.layers.base import GraphEngineLayer from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository @@ -496,6 +497,10 @@ class WorkflowAppGenerator(BaseAppGenerator): tenant_id=application_generate_entity.app_config.tenant_id, app_id=application_generate_entity.app_config.app_id, sandbox_id=application_generate_entity.workflow_execution_id, + sandbox_storage=ArchiveSandboxStorage( + tenant_id=application_generate_entity.app_config.tenant_id, + sandbox_id=application_generate_entity.workflow_execution_id, + ), ), ) diff --git a/api/core/app/layers/sandbox_layer.py b/api/core/app/layers/sandbox_layer.py index b9cebb4c8d..6ea1e1d3cb 100644 --- a/api/core/app/layers/sandbox_layer.py +++ b/api/core/app/layers/sandbox_layer.py @@ -1,6 +1,7 @@ import logging from core.sandbox import ArchiveSandboxStorage, SandboxManager +from core.sandbox.storage.sandbox_storage import SandboxStorage from core.virtual_environment.__base.virtual_environment import VirtualEnvironment from core.workflow.graph_engine.layers.base import GraphEngineLayer from core.workflow.graph_events.base import GraphEngineEvent @@ -14,11 +15,12 @@ class SandboxInitializationError(Exception): class SandboxLayer(GraphEngineLayer): - def __init__(self, tenant_id: str, app_id: str, sandbox_id: str) -> None: + def __init__(self, tenant_id: str, app_id: str, sandbox_id: str, sandbox_storage: SandboxStorage) -> None: super().__init__() self._tenant_id = tenant_id self._app_id = app_id self._sandbox_id = sandbox_id + self._sandbox_storage = sandbox_storage @property def sandbox(self) -> VirtualEnvironment: @@ -81,12 +83,7 @@ class SandboxLayer(GraphEngineLayer): ) try: - sandbox_storage = ArchiveSandboxStorage( - storage=storage, - tenant_id=self._tenant_id, - sandbox_id=self._sandbox_id, - ) - sandbox_storage.unmount(sandbox) + self._sandbox_storage.unmount(sandbox) logger.info("Sandbox files persisted, sandbox_id=%s", self._sandbox_id) except Exception: logger.exception("Failed to persist sandbox files, sandbox_id=%s", self._sandbox_id) diff --git a/api/core/sandbox/initializer/app_assets_initializer.py b/api/core/sandbox/initializer/app_assets_initializer.py index c07b073627..60cd209857 100644 --- a/api/core/sandbox/initializer/app_assets_initializer.py +++ b/api/core/sandbox/initializer/app_assets_initializer.py @@ -7,7 +7,7 @@ from core.virtual_environment.__base.helpers import execute, with_connection from core.virtual_environment.__base.virtual_environment import VirtualEnvironment from extensions.ext_database import db from extensions.ext_storage import storage -from models.app_asset import AppAssetDraft +from models.app_asset import AppAssets from ..constants import APP_ASSETS_PATH, APP_ASSETS_ZIP_PATH from .base import SandboxInitializer @@ -26,7 +26,7 @@ class AppAssetsInitializer(SandboxInitializer): logger.debug("No published assets for app_id=%s, skipping", self._app_id) return - zip_key = AppAssetDraft.get_published_storage_key(self._tenant_id, self._app_id, published.id) + zip_key = AppAssets.get_published_storage_key(self._tenant_id, self._app_id, published.id) try: zip_data = storage.load_once(zip_key) except Exception: @@ -73,15 +73,15 @@ class AppAssetsInitializer(SandboxInitializer): published.id, ) - def _get_latest_published(self) -> AppAssetDraft | None: + def _get_latest_published(self) -> AppAssets | None: with Session(db.engine) as session: return ( - session.query(AppAssetDraft) + session.query(AppAssets) .filter( - AppAssetDraft.tenant_id == self._tenant_id, - AppAssetDraft.app_id == self._app_id, - AppAssetDraft.version != AppAssetDraft.VERSION_DRAFT, + AppAssets.tenant_id == self._tenant_id, + AppAssets.app_id == self._app_id, + AppAssets.version != AppAssets.VERSION_DRAFT, ) - .order_by(AppAssetDraft.created_at.desc()) + .order_by(AppAssets.created_at.desc()) .first() ) diff --git a/api/core/sandbox/storage/archive_storage.py b/api/core/sandbox/storage/archive_storage.py index 61d5c10643..fd0836159f 100644 --- a/api/core/sandbox/storage/archive_storage.py +++ b/api/core/sandbox/storage/archive_storage.py @@ -3,7 +3,7 @@ from io import BytesIO from core.virtual_environment.__base.helpers import try_execute from core.virtual_environment.__base.virtual_environment import VirtualEnvironment -from extensions.ext_storage import Storage +from extensions.ext_storage import storage from .sandbox_storage import SandboxStorage @@ -14,7 +14,7 @@ WORKSPACE_DIR = "." class ArchiveSandboxStorage(SandboxStorage): - def __init__(self, storage: Storage, tenant_id: str, sandbox_id: str): + def __init__(self, tenant_id: str, sandbox_id: str): self._storage = storage self._tenant_id = tenant_id self._sandbox_id = sandbox_id diff --git a/api/models/__init__.py b/api/models/__init__.py index 44eecac4ba..8b64802107 100644 --- a/api/models/__init__.py +++ b/api/models/__init__.py @@ -9,7 +9,7 @@ from .account import ( TenantStatus, ) from .api_based_extension import APIBasedExtension, APIBasedExtensionPoint -from .app_asset import AppAssetDraft +from .app_asset import AppAssets from .dataset import ( AppDatasetJoin, Dataset, @@ -124,7 +124,7 @@ __all__ = [ "App", "AppAnnotationHitHistory", "AppAnnotationSetting", - "AppAssetDraft", + "AppAssets", "AppDatasetJoin", "AppMCPServer", "AppMode", diff --git a/api/models/app_asset.py b/api/models/app_asset.py index 2d66de9ee2..4c03cc3195 100644 --- a/api/models/app_asset.py +++ b/api/models/app_asset.py @@ -11,11 +11,11 @@ from .base import Base from .types import LongText, StringUUID -class AppAssetDraft(Base): - __tablename__ = "app_asset_drafts" +class AppAssets(Base): + __tablename__ = "app_assets" __table_args__ = ( - sa.PrimaryKeyConstraint("id", name="app_asset_draft_pkey"), - sa.Index("app_asset_draft_version_idx", "tenant_id", "app_id", "version"), + sa.PrimaryKeyConstraint("id", name="app_assets_pkey"), + sa.Index("app_assets_version_idx", "tenant_id", "app_id", "version"), ) VERSION_DRAFT = "draft" @@ -51,8 +51,8 @@ class AppAssetDraft(Base): return f"app_assets/{tenant_id}/{app_id}/draft/{node_id}" @staticmethod - def get_published_storage_key(tenant_id: str, app_id: str, draft_id: str) -> str: - return f"app_assets/{tenant_id}/{app_id}/published/{draft_id}.zip" + def get_published_storage_key(tenant_id: str, app_id: str, assets_id: str) -> str: + return f"app_assets/{tenant_id}/{app_id}/published/{assets_id}.zip" def __repr__(self) -> str: - return f"" + return f"" diff --git a/api/services/app_asset_service.py b/api/services/app_asset_service.py index 297adfd9b0..156df7516f 100644 --- a/api/services/app_asset_service.py +++ b/api/services/app_asset_service.py @@ -17,7 +17,7 @@ from core.app.entities.app_asset_entities import ( from extensions.ext_database import db from extensions.ext_storage import storage from libs.datetime_utils import naive_utc_now -from models.app_asset import AppAssetDraft +from models.app_asset import AppAssets from models.model import App from .errors.app_asset import ( @@ -31,22 +31,22 @@ logger = logging.getLogger(__name__) class AppAssetService: @staticmethod - def get_or_create_draft(session: Session, app_model: App, account_id: str) -> AppAssetDraft: + def get_or_create_draft(session: Session, app_model: App, account_id: str) -> AppAssets: draft = ( - session.query(AppAssetDraft) + session.query(AppAssets) .filter( - AppAssetDraft.tenant_id == app_model.tenant_id, - AppAssetDraft.app_id == app_model.id, - AppAssetDraft.version == AppAssetDraft.VERSION_DRAFT, + AppAssets.tenant_id == app_model.tenant_id, + AppAssets.app_id == app_model.id, + AppAssets.version == AppAssets.VERSION_DRAFT, ) .first() ) if not draft: - draft = AppAssetDraft( + draft = AppAssets( id=str(uuid4()), tenant_id=app_model.tenant_id, app_id=app_model.id, - version=AppAssetDraft.VERSION_DRAFT, + version=AppAssets.VERSION_DRAFT, created_by=account_id, ) session.add(draft) @@ -108,7 +108,7 @@ class AppAssetService: except TreePathConflictError as e: raise AppAssetPathConflictError(str(e)) from e - storage_key = AppAssetDraft.get_storage_key(app_model.tenant_id, app_model.id, node_id) + storage_key = AppAssets.get_storage_key(app_model.tenant_id, app_model.id, node_id) storage.save(storage_key, content) draft.asset_tree = tree @@ -127,7 +127,7 @@ class AppAssetService: if not node or node.node_type != AssetNodeType.FILE: raise AppAssetNodeNotFoundError(f"File node {node_id} not found") - storage_key = AppAssetDraft.get_storage_key(app_model.tenant_id, app_model.id, node_id) + storage_key = AppAssets.get_storage_key(app_model.tenant_id, app_model.id, node_id) return storage.load_once(storage_key) @staticmethod @@ -148,7 +148,7 @@ class AppAssetService: except TreeNodeNotFoundError as e: raise AppAssetNodeNotFoundError(str(e)) from e - storage_key = AppAssetDraft.get_storage_key(app_model.tenant_id, app_model.id, node_id) + storage_key = AppAssets.get_storage_key(app_model.tenant_id, app_model.id, node_id) storage.save(storage_key, content) draft.asset_tree = tree @@ -241,7 +241,7 @@ class AppAssetService: raise AppAssetNodeNotFoundError(str(e)) from e for nid in removed_ids: - storage_key = AppAssetDraft.get_storage_key(app_model.tenant_id, app_model.id, nid) + storage_key = AppAssets.get_storage_key(app_model.tenant_id, app_model.id, nid) try: storage.delete(storage_key) except Exception: @@ -252,7 +252,7 @@ class AppAssetService: session.commit() @staticmethod - def publish(app_model: App, account_id: str) -> AppAssetDraft: + def publish(app_model: App, account_id: str) -> AppAssets: with Session(db.engine, expire_on_commit=False) as session: draft = AppAssetService.get_or_create_draft(session, app_model, account_id) tree = draft.asset_tree @@ -261,12 +261,12 @@ class AppAssetService: zip_buffer = io.BytesIO() with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zf: for file_node in tree.walk_files(): - storage_key = AppAssetDraft.get_storage_key(app_model.tenant_id, app_model.id, file_node.id) + storage_key = AppAssets.get_storage_key(app_model.tenant_id, app_model.id, file_node.id) content = storage.load_once(storage_key) archive_path = tree.get_path(file_node.id).lstrip("/") zf.writestr(archive_path, content) - published = AppAssetDraft( + published = AppAssets( id=str(uuid4()), tenant_id=app_model.tenant_id, app_id=app_model.id, @@ -277,7 +277,7 @@ class AppAssetService: session.add(published) session.flush() - zip_key = AppAssetDraft.get_published_storage_key(app_model.tenant_id, app_model.id, published.id) + zip_key = AppAssets.get_published_storage_key(app_model.tenant_id, app_model.id, published.id) storage.save(zip_key, zip_buffer.getvalue()) session.commit() @@ -287,23 +287,23 @@ class AppAssetService: @staticmethod def get_published_file_content( app_model: App, - draft_id: str, + assets_id: str, file_path: str, ) -> bytes: with Session(db.engine) as session: published = ( - session.query(AppAssetDraft) + session.query(AppAssets) .filter( - AppAssetDraft.tenant_id == app_model.tenant_id, - AppAssetDraft.app_id == app_model.id, - AppAssetDraft.id == draft_id, + AppAssets.tenant_id == app_model.tenant_id, + AppAssets.app_id == app_model.id, + AppAssets.id == assets_id, ) .first() ) - if not published or published.version == AppAssetDraft.VERSION_DRAFT: - raise AppAssetNodeNotFoundError(f"Published version {draft_id} not found") + if not published or published.version == AppAssets.VERSION_DRAFT: + raise AppAssetNodeNotFoundError(f"Published version {assets_id} not found") - zip_key = AppAssetDraft.get_published_storage_key(app_model.tenant_id, app_model.id, draft_id) + zip_key = AppAssets.get_published_storage_key(app_model.tenant_id, app_model.id, assets_id) zip_data = storage.load_once(zip_key) archive_path = file_path.lstrip("/")