From 70c001436e4513c38523537e334820d8089d0e7b Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Tue, 10 Sep 2024 18:13:33 +0800 Subject: [PATCH] support variable --- api/core/tools/entities/tool_entities.py | 44 +++++++++++++++++++++-- api/core/workflow/nodes/tool/tool_node.py | 38 +++++++++++--------- 2 files changed, 64 insertions(+), 18 deletions(-) diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py index 4b0961fb09..ef96207fa7 100644 --- a/api/core/tools/entities/tool_entities.py +++ b/api/core/tools/entities/tool_entities.py @@ -1,7 +1,8 @@ +import base64 from enum import Enum from typing import Any, Optional, Union, cast -from pydantic import BaseModel, Field, field_validator +from pydantic import BaseModel, Field, field_serializer, field_validator from core.entities.parameter_entities import AppSelectorScope, CommonParameterType, ModelConfigScope from core.tools.entities.common_entities import I18nObject @@ -100,6 +101,26 @@ class ToolInvokeMessage(BaseModel): class BlobMessage(BaseModel): blob: bytes + class VariableMessage(BaseModel): + variable_name: str = Field(..., description="The name of the variable") + variable_value: str = Field(..., description="The value of the variable") + stream: bool = Field(default=False, description="Whether the variable is streamed") + + @field_validator("variable_value", mode="before") + def transform_variable_value(cls, value, values) -> Any: + """ + Only basic types and lists are allowed. + """ + if not isinstance(value, dict | list | str | int | float | bool): + raise ValueError("Only basic types and lists are allowed.") + + # if stream is true, the value must be a string + if values.get('stream'): + if not isinstance(value, str): + raise ValueError("When 'stream' is True, 'variable_value' must be a string.") + + return value + class MessageType(Enum): TEXT = "text" IMAGE = "image" @@ -108,15 +129,34 @@ class ToolInvokeMessage(BaseModel): JSON = "json" IMAGE_LINK = "image_link" FILE_VAR = "file_var" + VARIABLE = "variable" type: MessageType = MessageType.TEXT """ plain text, image url or link url """ - message: JsonMessage | TextMessage | BlobMessage | None + message: JsonMessage | TextMessage | BlobMessage | VariableMessage | None meta: dict[str, Any] | None = None save_as: str = '' + @field_validator('message', mode='before') + @classmethod + def decode_blob_message(cls, v): + if isinstance(v, dict) and 'blob' in v: + try: + v['blob'] = base64.b64decode(v['blob']) + except Exception: + pass + return v + + @field_serializer('message') + def serialize_message(self, v): + if isinstance(v, self.BlobMessage): + return { + 'blob': base64.b64encode(v.blob).decode('utf-8') + } + return v + class ToolInvokeMessageBinary(BaseModel): mimetype: str = Field(..., description="The mimetype of the binary") url: str = Field(..., description="The url of the binary") diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index 3865695c71..1f32c7b8bd 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -1,7 +1,6 @@ -from collections.abc import Generator, Iterable, Mapping, Sequence +from collections.abc import Generator, Mapping, Sequence from os import path from typing import Any, cast -from urllib import response from core.app.segments import ArrayAnySegment, ArrayAnyVariable, parser from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler @@ -98,19 +97,6 @@ class ToolNode(BaseNode): # convert tool messages yield from self._transform_message(message_stream, tool_info, parameters_for_log) - # return NodeRunResult( - # status=WorkflowNodeExecutionStatus.SUCCEEDED, - # outputs={ - # 'text': plain_text, - # 'files': files, - # 'json': json - # }, - # metadata={ - # NodeRunMetadataKey.TOOL_INFO: tool_info - # }, - # inputs=parameters_for_log - # ) - def _generate_parameters( self, *, @@ -183,6 +169,8 @@ class ToolNode(BaseNode): files: list[FileVar] = [] text = "" json: list[dict] = [] + + variables: dict[str, Any] = {} for message in message_stream: if message.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \ @@ -241,6 +229,23 @@ class ToolNode(BaseNode): chunk_content=stream_text, from_variable_selector=[self.node_id, 'text'] ) + elif message.type == ToolInvokeMessage.MessageType.VARIABLE: + assert isinstance(message.message, ToolInvokeMessage.VariableMessage) + variable_name = message.message.variable_name + variable_value = message.message.variable_value + if message.message.stream: + if not isinstance(variable_value, str): + raise ValueError("When 'stream' is True, 'variable_value' must be a string.") + if variable_name not in variables: + variables[variable_name] = "" + variables[variable_name] += variable_value + + yield RunStreamChunkEvent( + chunk_content=variable_value, + from_variable_selector=[self.node_id, variable_name] + ) + else: + variables[variable_name] = variable_value yield RunCompletedEvent( run_result=NodeRunResult( @@ -248,7 +253,8 @@ class ToolNode(BaseNode): outputs={ 'text': text, 'files': files, - 'json': json + 'json': json, + **variables }, metadata={ NodeRunMetadataKey.TOOL_INFO: tool_info