mirror of https://github.com/langgenius/dify.git
support variable
This commit is contained in:
parent
cf73374c1b
commit
70c001436e
|
|
@ -1,7 +1,8 @@
|
|||
import base64
|
||||
from enum import Enum
|
||||
from typing import Any, Optional, Union, cast
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from pydantic import BaseModel, Field, field_serializer, field_validator
|
||||
|
||||
from core.entities.parameter_entities import AppSelectorScope, CommonParameterType, ModelConfigScope
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
|
|
@ -100,6 +101,26 @@ class ToolInvokeMessage(BaseModel):
|
|||
class BlobMessage(BaseModel):
|
||||
blob: bytes
|
||||
|
||||
class VariableMessage(BaseModel):
|
||||
variable_name: str = Field(..., description="The name of the variable")
|
||||
variable_value: str = Field(..., description="The value of the variable")
|
||||
stream: bool = Field(default=False, description="Whether the variable is streamed")
|
||||
|
||||
@field_validator("variable_value", mode="before")
|
||||
def transform_variable_value(cls, value, values) -> Any:
|
||||
"""
|
||||
Only basic types and lists are allowed.
|
||||
"""
|
||||
if not isinstance(value, dict | list | str | int | float | bool):
|
||||
raise ValueError("Only basic types and lists are allowed.")
|
||||
|
||||
# if stream is true, the value must be a string
|
||||
if values.get('stream'):
|
||||
if not isinstance(value, str):
|
||||
raise ValueError("When 'stream' is True, 'variable_value' must be a string.")
|
||||
|
||||
return value
|
||||
|
||||
class MessageType(Enum):
|
||||
TEXT = "text"
|
||||
IMAGE = "image"
|
||||
|
|
@ -108,15 +129,34 @@ class ToolInvokeMessage(BaseModel):
|
|||
JSON = "json"
|
||||
IMAGE_LINK = "image_link"
|
||||
FILE_VAR = "file_var"
|
||||
VARIABLE = "variable"
|
||||
|
||||
type: MessageType = MessageType.TEXT
|
||||
"""
|
||||
plain text, image url or link url
|
||||
"""
|
||||
message: JsonMessage | TextMessage | BlobMessage | None
|
||||
message: JsonMessage | TextMessage | BlobMessage | VariableMessage | None
|
||||
meta: dict[str, Any] | None = None
|
||||
save_as: str = ''
|
||||
|
||||
@field_validator('message', mode='before')
|
||||
@classmethod
|
||||
def decode_blob_message(cls, v):
|
||||
if isinstance(v, dict) and 'blob' in v:
|
||||
try:
|
||||
v['blob'] = base64.b64decode(v['blob'])
|
||||
except Exception:
|
||||
pass
|
||||
return v
|
||||
|
||||
@field_serializer('message')
|
||||
def serialize_message(self, v):
|
||||
if isinstance(v, self.BlobMessage):
|
||||
return {
|
||||
'blob': base64.b64encode(v.blob).decode('utf-8')
|
||||
}
|
||||
return v
|
||||
|
||||
class ToolInvokeMessageBinary(BaseModel):
|
||||
mimetype: str = Field(..., description="The mimetype of the binary")
|
||||
url: str = Field(..., description="The url of the binary")
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
from collections.abc import Generator, Iterable, Mapping, Sequence
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from os import path
|
||||
from typing import Any, cast
|
||||
from urllib import response
|
||||
|
||||
from core.app.segments import ArrayAnySegment, ArrayAnyVariable, parser
|
||||
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
|
||||
|
|
@ -98,19 +97,6 @@ class ToolNode(BaseNode):
|
|||
# convert tool messages
|
||||
yield from self._transform_message(message_stream, tool_info, parameters_for_log)
|
||||
|
||||
# return NodeRunResult(
|
||||
# status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
# outputs={
|
||||
# 'text': plain_text,
|
||||
# 'files': files,
|
||||
# 'json': json
|
||||
# },
|
||||
# metadata={
|
||||
# NodeRunMetadataKey.TOOL_INFO: tool_info
|
||||
# },
|
||||
# inputs=parameters_for_log
|
||||
# )
|
||||
|
||||
def _generate_parameters(
|
||||
self,
|
||||
*,
|
||||
|
|
@ -183,6 +169,8 @@ class ToolNode(BaseNode):
|
|||
files: list[FileVar] = []
|
||||
text = ""
|
||||
json: list[dict] = []
|
||||
|
||||
variables: dict[str, Any] = {}
|
||||
|
||||
for message in message_stream:
|
||||
if message.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \
|
||||
|
|
@ -241,6 +229,23 @@ class ToolNode(BaseNode):
|
|||
chunk_content=stream_text,
|
||||
from_variable_selector=[self.node_id, 'text']
|
||||
)
|
||||
elif message.type == ToolInvokeMessage.MessageType.VARIABLE:
|
||||
assert isinstance(message.message, ToolInvokeMessage.VariableMessage)
|
||||
variable_name = message.message.variable_name
|
||||
variable_value = message.message.variable_value
|
||||
if message.message.stream:
|
||||
if not isinstance(variable_value, str):
|
||||
raise ValueError("When 'stream' is True, 'variable_value' must be a string.")
|
||||
if variable_name not in variables:
|
||||
variables[variable_name] = ""
|
||||
variables[variable_name] += variable_value
|
||||
|
||||
yield RunStreamChunkEvent(
|
||||
chunk_content=variable_value,
|
||||
from_variable_selector=[self.node_id, variable_name]
|
||||
)
|
||||
else:
|
||||
variables[variable_name] = variable_value
|
||||
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
|
|
@ -248,7 +253,8 @@ class ToolNode(BaseNode):
|
|||
outputs={
|
||||
'text': text,
|
||||
'files': files,
|
||||
'json': json
|
||||
'json': json,
|
||||
**variables
|
||||
},
|
||||
metadata={
|
||||
NodeRunMetadataKey.TOOL_INFO: tool_info
|
||||
|
|
|
|||
Loading…
Reference in New Issue