SANDBOX_BASH_TOOL_NAME = "bash"
SANDBOX_BASH_TOOL_PROVIDER = "sandbox"
COMMAND_TIMEOUT_SECONDS = 60


class SandboxBashTool(Tool):
    """Built-in tool that runs a shell command inside a sandbox environment.

    The tool declares a single required string parameter, ``command``, and
    reports the command's stdout, stderr and exit code as one text message.
    """

    def __init__(self, sandbox: VirtualEnvironment, tenant_id: str):
        """Bind the tool to ``sandbox`` and assemble its static tool entity.

        Args:
            sandbox: Environment used to open connections and run commands.
            tenant_id: Tenant identifier stored on the tool runtime.
        """
        self._sandbox = sandbox

        identity = ToolIdentity(
            author="Dify",
            name=SANDBOX_BASH_TOOL_NAME,
            label=I18nObject(en_US="Bash", zh_Hans="Bash"),
            provider=SANDBOX_BASH_TOOL_PROVIDER,
        )
        command_parameter = ToolParameter.get_simple_instance(
            name="command",
            llm_description="The bash command to execute in the sandbox environment",
            typ=ToolParameter.ToolParameterType.STRING,
            required=True,
        )
        description = ToolDescription(
            human=I18nObject(
                en_US="Execute bash commands in the sandbox environment",
                zh_Hans="在沙盒环境中执行 bash 命令",
            ),
            llm=(
                "Execute bash commands in the sandbox environment. "
                "Use this tool to run shell commands, scripts, or interact with the system. "
                "The command will be executed in an isolated sandbox environment."
            ),
        )

        super().__init__(
            entity=ToolEntity(
                identity=identity,
                parameters=[command_parameter],
                description=description,
            ),
            runtime=ToolRuntime(tenant_id=tenant_id),
        )

    def tool_provider_type(self) -> ToolProviderType:
        """Report this tool as a built-in provider tool."""
        return ToolProviderType.BUILT_IN

    @staticmethod
    def _decode_stream(raw: Any) -> str:
        """Decode captured output as UTF-8 (replacing bad bytes); falsy input yields ''."""
        if not raw:
            return ""
        return raw.decode("utf-8", errors="replace")

    def _invoke(
        self,
        user_id: str,
        tool_parameters: dict[str, Any],
        conversation_id: str | None = None,
        app_id: str | None = None,
        message_id: str | None = None,
    ) -> Generator[ToolInvokeMessage, None, None]:
        """Run ``tool_parameters["command"]`` through ``sh -c`` in the sandbox.

        Yields exactly one text message: either the collected output
        (non-empty stdout/stderr sections followed by the exit code) or an
        ``Error: ...`` description. The connection opened for the command is
        released in every case once it has been established.
        """
        command = tool_parameters.get("command", "")
        if not command:
            yield self.create_text_message("Error: No command provided")
            return

        handle = self._sandbox.establish_connection()
        try:
            pending = self._sandbox.run_command(handle, ["sh", "-c", command])
            # A non-positive configured timeout means "wait indefinitely".
            wait_limit = COMMAND_TIMEOUT_SECONDS if COMMAND_TIMEOUT_SECONDS > 0 else None
            outcome = pending.result(timeout=wait_limit)

            streams = (
                ("stdout", self._decode_stream(outcome.stdout)),
                ("stderr", self._decode_stream(outcome.stderr)),
            )
            # Empty streams are omitted; the exit code is always reported last.
            report = [f"{label}:\n{text}" for label, text in streams if text]
            report.append(f"exit_code: {outcome.exit_code}")

            yield self.create_text_message("\n".join(report))

        except TimeoutError:
            yield self.create_text_message(f"Error: Command timed out after {COMMAND_TIMEOUT_SECONDS}s")
        except Exception as exc:
            # Failures are surfaced to the caller as text instead of propagating.
            yield self.create_text_message(f"Error: {exc!s}")
        finally:
            self._sandbox.release_connection(handle)
@classmethod
def is_sandbox_runtime(cls, workflow_execution_id: str) -> bool:
    """Return whether a sandbox is registered for ``workflow_execution_id``.

    Delegates directly to ``cls.has``.
    """
    return cls.has(workflow_execution_id)


@classmethod
def get_bash_tool(
    cls,
    workflow_execution_id: str,
    tenant_id: str,
    configured_tools: list[Tool],
) -> SandboxBashTool:
    """Build a ``SandboxBashTool`` bound to the sandbox of a workflow execution.

    Args:
        workflow_execution_id: Key used to look up the registered sandbox.
        tenant_id: Tenant identifier forwarded to the bash tool.
        configured_tools: Tools to initialize inside the sandbox before the
            bash tool is handed out; may be empty.

    Returns:
        A bash tool that executes commands in the looked-up sandbox.

    Raises:
        RuntimeError: No sandbox is registered for ``workflow_execution_id``.
        NotImplementedError: ``configured_tools`` is non-empty (initializing
            tools inside the sandbox is not implemented yet).
    """
    # Imported lazily: at module level this name is only imported under TYPE_CHECKING.
    from core.tools.builtin_tool.providers.sandbox.bash_tool import SandboxBashTool

    sandbox = cls.get(workflow_execution_id)
    if sandbox is None:
        raise RuntimeError(f"Sandbox not found for workflow_execution_id={workflow_execution_id}")

    cls._initialize_tools_in_sandbox(sandbox, configured_tools)

    return SandboxBashTool(sandbox=sandbox, tenant_id=tenant_id)


@classmethod
def _initialize_tools_in_sandbox(
    cls,
    sandbox: VirtualEnvironment,
    configured_tools: list[Tool],
) -> None:
    """Prepare ``configured_tools`` inside ``sandbox``.

    With no tools there is nothing to prepare, so the call is a no-op; this
    keeps ``get_bash_tool`` usable instead of failing unconditionally.

    Raises:
        NotImplementedError: ``configured_tools`` is non-empty.
    """
    if not configured_tools:
        return
    raise NotImplementedError("TODO: Initialize configured tools in sandbox environment")
def _invoke_llm_with_sandbox(
    self,
    model_instance: ModelInstance,
    prompt_messages: Sequence[PromptMessage],
    stop: Sequence[str] | None,
    files: Sequence[File],
    variable_pool: VariablePool,
    node_inputs: dict[str, Any],
    process_data: dict[str, Any],
) -> Generator[NodeEventBase, None, LLMGenerationData]:
    """Run the LLM with the sandbox bash tool as its only tool.

    The node's configured tools are passed to ``SandboxManager.get_bash_tool``;
    the strategy itself receives just the returned bash tool and is created
    with the chain-of-thought agent strategy and an empty model-feature list.

    Raises:
        LLMNodeError: The variable pool carries no workflow execution id.
    """
    from core.agent.entities import AgentEntity

    execution_id = variable_pool.system_variables.workflow_execution_id
    if not execution_id:
        raise LLMNodeError("workflow_execution_id is required for sandbox runtime mode")

    node_tools = self._prepare_tool_instances(variable_pool)
    sandbox_bash = SandboxManager.get_bash_tool(
        workflow_execution_id=execution_id,
        tenant_id=self.tenant_id,
        configured_tools=node_tools,
    )

    attached_files = self._extract_prompt_files(variable_pool)

    # Use 10 iterations when the node leaves max_iterations unset (or zero).
    iteration_cap = self._node_data.max_iterations or 10
    run_context = ExecutionContext(
        user_id=self.user_id,
        app_id=self.app_id,
        tenant_id=self.tenant_id,
    )
    strategy = StrategyFactory.create_strategy(
        model_features=[],
        model_instance=model_instance,
        tools=[sandbox_bash],
        files=attached_files,
        max_iterations=iteration_cap,
        context=run_context,
        agent_strategy=AgentEntity.Strategy.CHAIN_OF_THOUGHT,
    )

    strategy_outputs = strategy.run(
        prompt_messages=list(prompt_messages),
        model_parameters=self._node_data.model.completion_params,
        stop=list(stop or []),
        stream=True,
    )

    # Relay every event from output processing and return its final value.
    return (yield from self._process_tool_outputs(strategy_outputs))