feat: implement DSL prediction API and virtual environment base classes

This commit is contained in:
Yeuoly 2025-12-30 15:24:54 +08:00
parent 2399d00d86
commit a513ab9a59
4 changed files with 206 additions and 0 deletions

View File

@ -0,0 +1,60 @@
import json
import httpx
import yaml
from flask_restx import Resource, reqparse
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from core.plugin.impl.exc import PluginPermissionDeniedError
from extensions.ext_database import db
from libs.login import current_account_with_tenant, login_required
from models.model import App
from models.workflow import Workflow
from services.app_dsl_service import AppDslService
@console_ns.route("/workspaces/current/dsl/predict")
class DSLPredictApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
parser = (
reqparse.RequestParser()
.add_argument("app_id", type=str, required=True, location="json")
.add_argument("current_node_id", type=str, required=True, location="json")
)
args = parser.parse_args()
app_id: str = args["app_id"]
current_node_id: str = args["current_node_id"]
with Session(db.engine) as session:
app = session.query(App).filter_by(id=app_id).first()
workflow = session.query(Workflow).filter_by(app_id=app_id, version=Workflow.VERSION_DRAFT).first()
try:
i = 0
for node_id, _ in workflow.walk_nodes():
if node_id == current_node_id:
break
i += 1
dsl = yaml.safe_load(AppDslService.export_dsl(app_model=app))
response = httpx.post(
"http://spark-832c:8000/predict",
json={"graph_data": dsl, "source_node_index": i},
)
return {
"nodes": json.loads(response.json()),
}
except PluginPermissionDeniedError as e:
raise ValueError(e.description) from e

View File

@ -0,0 +1,38 @@
from enum import StrEnum
from pydantic import BaseModel, Field
class Arch(StrEnum):
"""
Architecture types for virtual environments.
"""
ARM64 = "arm64"
AMD64 = "amd64"
class Metadata(BaseModel):
"""
Returned metadata about a virtual environment.
"""
id: str = Field(description="The unique identifier of the virtual environment.")
arch: Arch = Field(description="Which architecture was used to create the virtual environment.")
class ConnectionHandle(BaseModel):
"""
Handle for managing connections to the virtual environment.
"""
id: str = Field(description="The unique identifier of the connection handle.")
class CommandStatus(BaseModel):
"""
Status of a command executed in the virtual environment.
"""
pid: int = Field(description="The process ID of the command.")
return_code: int = Field(description="The return code of the command execution.")

View File

@ -0,0 +1,108 @@
from abc import ABC, abstractmethod
from collections.abc import Mapping
from io import BytesIO
from typing import Any
from core.virtual_environment.__base.entities import CommandStatus, ConnectionHandle, Metadata
class VirtualEnvironment(ABC):
"""
Base class for virtual environment implementations.
"""
@abstractmethod
def request_environment(self, options: Mapping[str, Any]) -> Metadata:
"""
Request a virtual environment with the given options.
Args:
options (Mapping[str, Any]): Options for requesting the virtual environment.
Those options are implementation-specific, which can be defined in environment
Returns:
Metadata: Metadata about the requested virtual environment.
Raises:
Exception: If the environment cannot be requested.
"""
@abstractmethod
def upload_file(self, environment_id: str, destination_path: str, content: BytesIO) -> None:
"""
Upload a file to the virtual environment.
Args:
environment_id (str): The unique identifier of the virtual environment.
destination_path (str): The destination path in the virtual environment.
content (BytesIO): The content of the file to upload.
Raises:
Exception: If the file cannot be uploaded.
"""
@abstractmethod
def establish_connection(self, environment_id: str) -> ConnectionHandle:
"""
Establish a connection to the virtual environment.
Args:
environment_id (str): The unique identifier of the virtual environment.
Returns:
ConnectionHandle: Handle for managing the connection to the virtual environment.
Raises:
Exception: If the connection cannot be established.
"""
@abstractmethod
def release_connection(self, connection_handle: ConnectionHandle) -> None:
"""
Release the connection to the virtual environment.
Args:
connection_handle (ConnectionHandle): The handle for managing the connection.
Raises:
Exception: If the connection cannot be released.
"""
@abstractmethod
def release_environment(self, environment_id: str) -> None:
"""
Release the virtual environment.
Args:
environment_id (str): The unique identifier of the virtual environment.
Raises:
Exception: If the environment cannot be released.
Multiple calls to `release_environment` with the same `environment_id` is acceptable.
"""
@abstractmethod
def execute_command(self, connection_handle: ConnectionHandle, command: list[str]) -> tuple[int, int, int, int]:
"""
Execute a command in the virtual environment.
Args:
connection_handle (ConnectionHandle): The handle for managing the connection.
command (list[str]): The command to execute as a list of strings.
Returns:
tuple[int, int, int, int]: A tuple containing pid and 3 handle to os.pipe(): (stdin, stdout, stderr).
After exuection, the 3 handles will be closed by `execute_command` itself.
"""
@abstractmethod
def get_command_status(self, connection_handle: ConnectionHandle, pid: int) -> CommandStatus:
"""
Get the status of a command executed in the virtual environment.
Args:
connection_handle (ConnectionHandle): The handle for managing the connection.
pid (int): The process ID of the command.
Returns:
CommandStatus: The status of the command execution.
"""