mirror of
https://github.com/langgenius/dify.git
synced 2026-05-10 14:14:17 +08:00
feat: implement DSL prediction API and virtual environment base classes
This commit is contained in:
parent
2399d00d86
commit
a513ab9a59
60
api/controllers/console/workspace/dsl.py
Normal file
60
api/controllers/console/workspace/dsl.py
Normal file
@ -0,0 +1,60 @@
|
||||
import json
|
||||
|
||||
import httpx
|
||||
import yaml
|
||||
from flask_restx import Resource, reqparse
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.plugin.impl.exc import PluginPermissionDeniedError
|
||||
from extensions.ext_database import db
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.model import App
|
||||
from models.workflow import Workflow
|
||||
from services.app_dsl_service import AppDslService
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/dsl/predict")
|
||||
class DSLPredictApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("app_id", type=str, required=True, location="json")
|
||||
.add_argument("current_node_id", type=str, required=True, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
app_id: str = args["app_id"]
|
||||
current_node_id: str = args["current_node_id"]
|
||||
|
||||
with Session(db.engine) as session:
|
||||
app = session.query(App).filter_by(id=app_id).first()
|
||||
workflow = session.query(Workflow).filter_by(app_id=app_id, version=Workflow.VERSION_DRAFT).first()
|
||||
|
||||
try:
|
||||
i = 0
|
||||
for node_id, _ in workflow.walk_nodes():
|
||||
if node_id == current_node_id:
|
||||
break
|
||||
i += 1
|
||||
|
||||
dsl = yaml.safe_load(AppDslService.export_dsl(app_model=app))
|
||||
|
||||
response = httpx.post(
|
||||
"http://spark-832c:8000/predict",
|
||||
json={"graph_data": dsl, "source_node_index": i},
|
||||
)
|
||||
return {
|
||||
"nodes": json.loads(response.json()),
|
||||
}
|
||||
except PluginPermissionDeniedError as e:
|
||||
raise ValueError(e.description) from e
|
||||
38
api/core/virtual_environment/__base/entities.py
Normal file
38
api/core/virtual_environment/__base/entities.py
Normal file
@ -0,0 +1,38 @@
|
||||
from enum import StrEnum
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class Arch(StrEnum):
|
||||
"""
|
||||
Architecture types for virtual environments.
|
||||
"""
|
||||
|
||||
ARM64 = "arm64"
|
||||
AMD64 = "amd64"
|
||||
|
||||
|
||||
class Metadata(BaseModel):
|
||||
"""
|
||||
Returned metadata about a virtual environment.
|
||||
"""
|
||||
|
||||
id: str = Field(description="The unique identifier of the virtual environment.")
|
||||
arch: Arch = Field(description="Which architecture was used to create the virtual environment.")
|
||||
|
||||
|
||||
class ConnectionHandle(BaseModel):
|
||||
"""
|
||||
Handle for managing connections to the virtual environment.
|
||||
"""
|
||||
|
||||
id: str = Field(description="The unique identifier of the connection handle.")
|
||||
|
||||
|
||||
class CommandStatus(BaseModel):
|
||||
"""
|
||||
Status of a command executed in the virtual environment.
|
||||
"""
|
||||
|
||||
pid: int = Field(description="The process ID of the command.")
|
||||
return_code: int = Field(description="The return code of the command execution.")
|
||||
108
api/core/virtual_environment/__base/virtual_environment.py
Normal file
108
api/core/virtual_environment/__base/virtual_environment.py
Normal file
@ -0,0 +1,108 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Mapping
|
||||
from io import BytesIO
|
||||
from typing import Any
|
||||
|
||||
from core.virtual_environment.__base.entities import CommandStatus, ConnectionHandle, Metadata
|
||||
|
||||
|
||||
class VirtualEnvironment(ABC):
|
||||
"""
|
||||
Base class for virtual environment implementations.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def request_environment(self, options: Mapping[str, Any]) -> Metadata:
|
||||
"""
|
||||
Request a virtual environment with the given options.
|
||||
|
||||
Args:
|
||||
options (Mapping[str, Any]): Options for requesting the virtual environment.
|
||||
Those options are implementation-specific, which can be defined in environment
|
||||
|
||||
Returns:
|
||||
Metadata: Metadata about the requested virtual environment.
|
||||
|
||||
Raises:
|
||||
Exception: If the environment cannot be requested.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def upload_file(self, environment_id: str, destination_path: str, content: BytesIO) -> None:
|
||||
"""
|
||||
Upload a file to the virtual environment.
|
||||
|
||||
Args:
|
||||
environment_id (str): The unique identifier of the virtual environment.
|
||||
destination_path (str): The destination path in the virtual environment.
|
||||
content (BytesIO): The content of the file to upload.
|
||||
|
||||
Raises:
|
||||
Exception: If the file cannot be uploaded.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def establish_connection(self, environment_id: str) -> ConnectionHandle:
|
||||
"""
|
||||
Establish a connection to the virtual environment.
|
||||
|
||||
Args:
|
||||
environment_id (str): The unique identifier of the virtual environment.
|
||||
|
||||
Returns:
|
||||
ConnectionHandle: Handle for managing the connection to the virtual environment.
|
||||
|
||||
Raises:
|
||||
Exception: If the connection cannot be established.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def release_connection(self, connection_handle: ConnectionHandle) -> None:
|
||||
"""
|
||||
Release the connection to the virtual environment.
|
||||
|
||||
Args:
|
||||
connection_handle (ConnectionHandle): The handle for managing the connection.
|
||||
|
||||
Raises:
|
||||
Exception: If the connection cannot be released.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def release_environment(self, environment_id: str) -> None:
|
||||
"""
|
||||
Release the virtual environment.
|
||||
|
||||
Args:
|
||||
environment_id (str): The unique identifier of the virtual environment.
|
||||
|
||||
Raises:
|
||||
Exception: If the environment cannot be released.
|
||||
Multiple calls to `release_environment` with the same `environment_id` is acceptable.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def execute_command(self, connection_handle: ConnectionHandle, command: list[str]) -> tuple[int, int, int, int]:
|
||||
"""
|
||||
Execute a command in the virtual environment.
|
||||
|
||||
Args:
|
||||
connection_handle (ConnectionHandle): The handle for managing the connection.
|
||||
command (list[str]): The command to execute as a list of strings.
|
||||
|
||||
Returns:
|
||||
tuple[int, int, int, int]: A tuple containing pid and 3 handle to os.pipe(): (stdin, stdout, stderr).
|
||||
After exuection, the 3 handles will be closed by `execute_command` itself.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_command_status(self, connection_handle: ConnectionHandle, pid: int) -> CommandStatus:
|
||||
"""
|
||||
Get the status of a command executed in the virtual environment.
|
||||
|
||||
Args:
|
||||
connection_handle (ConnectionHandle): The handle for managing the connection.
|
||||
pid (int): The process ID of the command.
|
||||
Returns:
|
||||
CommandStatus: The status of the command execution.
|
||||
"""
|
||||
Loading…
Reference in New Issue
Block a user