diff --git a/api/core/callback_handler/agent_tool_callback_handler.py b/api/core/callback_handler/agent_tool_callback_handler.py index f31b43cb7f..eb38a5fed0 100644 --- a/api/core/callback_handler/agent_tool_callback_handler.py +++ b/api/core/callback_handler/agent_tool_callback_handler.py @@ -103,9 +103,9 @@ class DifyAgentCallbackHandler(BaseModel): @property def ignore_agent(self) -> bool: """Whether to ignore agent callbacks.""" - return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != "true" + return not os.environ.get("DEBUG") or os.environ.get("DEBUG", "").lower() != "true" @property def ignore_chat_model(self) -> bool: """Whether to ignore chat model callbacks.""" - return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != "true" + return not os.environ.get("DEBUG") or os.environ.get("DEBUG", "").lower() != "true" diff --git a/api/core/callback_handler/plugin_tool_callback_handler.py b/api/core/callback_handler/plugin_tool_callback_handler.py deleted file mode 100644 index 033b8d423c..0000000000 --- a/api/core/callback_handler/plugin_tool_callback_handler.py +++ /dev/null @@ -1,5 +0,0 @@ -from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler - - -class DifyPluginCallbackHandler(DifyAgentCallbackHandler): - """Callback Handler that prints to std out.""" diff --git a/api/core/callback_handler/workflow_tool_callback_handler.py b/api/core/callback_handler/workflow_tool_callback_handler.py index 8ac12f72f2..350b18772b 100644 --- a/api/core/callback_handler/workflow_tool_callback_handler.py +++ b/api/core/callback_handler/workflow_tool_callback_handler.py @@ -1,5 +1,26 @@ -from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler +from collections.abc import Generator, Iterable, Mapping +from typing import Any, Optional + +from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler, print_text +from core.ops.ops_trace_manager import TraceQueueManager +from core.tools.entities.tool_entities import ToolInvokeMessage class DifyWorkflowCallbackHandler(DifyAgentCallbackHandler): """Callback Handler that prints to std out.""" + + def on_tool_execution( + self, + tool_name: str, + tool_inputs: Mapping[str, Any], + tool_outputs: Iterable[ToolInvokeMessage], + message_id: Optional[str] = None, + timer: Optional[Any] = None, + trace_manager: Optional[TraceQueueManager] = None, + ) -> Generator[ToolInvokeMessage, None, None]: + for tool_output in tool_outputs: + print_text("\n[on_tool_execution]\n", color=self.color) + print_text("Tool: " + tool_name + "\n", color=self.color) + print_text("Outputs: " + tool_output.model_dump_json()[:1000] + "\n", color=self.color) + print_text("\n") + yield tool_output diff --git a/api/core/tools/tool_engine.py b/api/core/tools/tool_engine.py index d8889917f0..5396acc285 100644 --- a/api/core/tools/tool_engine.py +++ b/api/core/tools/tool_engine.py @@ -9,7 +9,6 @@ from yarl import URL from core.app.entities.app_invoke_entities import InvokeFrom from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler -from core.callback_handler.plugin_tool_callback_handler import DifyPluginCallbackHandler from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler from core.file.file_obj import FileTransferMethod from core.ops.ops_trace_manager import TraceQueueManager @@ -157,7 +156,7 @@ class ToolEngine: response = tool.invoke(user_id=user_id, tool_parameters=tool_parameters) # hit the callback handler - workflow_tool_callback.on_tool_end( + response = workflow_tool_callback.on_tool_execution( tool_name=tool.entity.identity.name, tool_inputs=tool_parameters, tool_outputs=response, @@ -168,31 +167,6 @@ class ToolEngine: workflow_tool_callback.on_tool_error(e) raise e - @staticmethod - def plugin_invoke( - tool: Tool, tool_parameters: dict, user_id: str, callback: DifyPluginCallbackHandler - ) -> Generator[ToolInvokeMessage, None, None]: - """ - Plugin invokes the tool with the given arguments. - """ - try: - # hit the callback handler - callback.on_tool_start(tool_name=tool.entity.identity.name, tool_inputs=tool_parameters) - - response = tool.invoke(user_id, tool_parameters) - - # hit the callback handler - callback.on_tool_end( - tool_name=tool.entity.identity.name, - tool_inputs=tool_parameters, - tool_outputs=response, - ) - - return response - except Exception as e: - callback.on_tool_error(e) - raise e - @staticmethod def _invoke( tool: Tool, tool_parameters: dict, user_id: str