support variable

This commit is contained in:
Yeuoly 2024-09-10 18:13:33 +08:00
parent cf73374c1b
commit 70c001436e
No known key found for this signature in database
GPG Key ID: A66E7E320FB19F61
2 changed files with 64 additions and 18 deletions

View File

@ -1,7 +1,8 @@
import base64
from enum import Enum
from typing import Any, Optional, Union, cast
from pydantic import BaseModel, Field, field_validator
from pydantic import BaseModel, Field, field_serializer, field_validator
from core.entities.parameter_entities import AppSelectorScope, CommonParameterType, ModelConfigScope
from core.tools.entities.common_entities import I18nObject
@ -100,6 +101,26 @@ class ToolInvokeMessage(BaseModel):
class BlobMessage(BaseModel):
blob: bytes
class VariableMessage(BaseModel):
variable_name: str = Field(..., description="The name of the variable")
variable_value: str = Field(..., description="The value of the variable")
stream: bool = Field(default=False, description="Whether the variable is streamed")
@field_validator("variable_value", mode="before")
def transform_variable_value(cls, value, values) -> Any:
"""
Only basic types and lists are allowed.
"""
if not isinstance(value, dict | list | str | int | float | bool):
raise ValueError("Only basic types and lists are allowed.")
# if stream is true, the value must be a string
if values.get('stream'):
if not isinstance(value, str):
raise ValueError("When 'stream' is True, 'variable_value' must be a string.")
return value
class MessageType(Enum):
TEXT = "text"
IMAGE = "image"
@ -108,15 +129,34 @@ class ToolInvokeMessage(BaseModel):
JSON = "json"
IMAGE_LINK = "image_link"
FILE_VAR = "file_var"
VARIABLE = "variable"
type: MessageType = MessageType.TEXT
"""
plain text, image url or link url
"""
message: JsonMessage | TextMessage | BlobMessage | None
message: JsonMessage | TextMessage | BlobMessage | VariableMessage | None
meta: dict[str, Any] | None = None
save_as: str = ''
@field_validator('message', mode='before')
@classmethod
def decode_blob_message(cls, v):
if isinstance(v, dict) and 'blob' in v:
try:
v['blob'] = base64.b64decode(v['blob'])
except Exception:
pass
return v
@field_serializer('message')
def serialize_message(self, v):
if isinstance(v, self.BlobMessage):
return {
'blob': base64.b64encode(v.blob).decode('utf-8')
}
return v
class ToolInvokeMessageBinary(BaseModel):
mimetype: str = Field(..., description="The mimetype of the binary")
url: str = Field(..., description="The url of the binary")

View File

@ -1,7 +1,6 @@
from collections.abc import Generator, Iterable, Mapping, Sequence
from collections.abc import Generator, Mapping, Sequence
from os import path
from typing import Any, cast
from urllib import response
from core.app.segments import ArrayAnySegment, ArrayAnyVariable, parser
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
@ -98,19 +97,6 @@ class ToolNode(BaseNode):
# convert tool messages
yield from self._transform_message(message_stream, tool_info, parameters_for_log)
# return NodeRunResult(
# status=WorkflowNodeExecutionStatus.SUCCEEDED,
# outputs={
# 'text': plain_text,
# 'files': files,
# 'json': json
# },
# metadata={
# NodeRunMetadataKey.TOOL_INFO: tool_info
# },
# inputs=parameters_for_log
# )
def _generate_parameters(
self,
*,
@ -183,6 +169,8 @@ class ToolNode(BaseNode):
files: list[FileVar] = []
text = ""
json: list[dict] = []
variables: dict[str, Any] = {}
for message in message_stream:
if message.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \
@ -241,6 +229,23 @@ class ToolNode(BaseNode):
chunk_content=stream_text,
from_variable_selector=[self.node_id, 'text']
)
elif message.type == ToolInvokeMessage.MessageType.VARIABLE:
assert isinstance(message.message, ToolInvokeMessage.VariableMessage)
variable_name = message.message.variable_name
variable_value = message.message.variable_value
if message.message.stream:
if not isinstance(variable_value, str):
raise ValueError("When 'stream' is True, 'variable_value' must be a string.")
if variable_name not in variables:
variables[variable_name] = ""
variables[variable_name] += variable_value
yield RunStreamChunkEvent(
chunk_content=variable_value,
from_variable_selector=[self.node_id, variable_name]
)
else:
variables[variable_name] = variable_value
yield RunCompletedEvent(
run_result=NodeRunResult(
@ -248,7 +253,8 @@ class ToolNode(BaseNode):
outputs={
'text': text,
'files': files,
'json': json
'json': json,
**variables
},
metadata={
NodeRunMetadataKey.TOOL_INFO: tool_info