fix(tools): fix ToolInvokeMessage Union type parsing issue (#31450)

Co-authored-by: qiaofenglin <qiaofenglin@baidu.com>
This commit is contained in:
fenglin 2026-01-24 10:18:06 +08:00 committed by GitHub
parent 1f8c730259
commit e8f9d64651
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 39 additions and 15 deletions

View File

@@ -130,7 +130,7 @@ class ToolInvokeMessage(BaseModel):
text: str text: str
class JsonMessage(BaseModel): class JsonMessage(BaseModel):
json_object: dict json_object: dict | list
suppress_output: bool = Field(default=False, description="Whether to suppress JSON output in result string") suppress_output: bool = Field(default=False, description="Whether to suppress JSON output in result string")
class BlobMessage(BaseModel): class BlobMessage(BaseModel):
@@ -144,7 +144,14 @@ class ToolInvokeMessage(BaseModel):
end: bool = Field(..., description="Whether the chunk is the last chunk") end: bool = Field(..., description="Whether the chunk is the last chunk")
class FileMessage(BaseModel): class FileMessage(BaseModel):
pass file_marker: str = Field(default="file_marker")
@model_validator(mode="before")
@classmethod
def validate_file_message(cls, values):
if isinstance(values, dict) and "file_marker" not in values:
raise ValueError("Invalid FileMessage: missing file_marker")
return values
class VariableMessage(BaseModel): class VariableMessage(BaseModel):
variable_name: str = Field(..., description="The name of the variable") variable_name: str = Field(..., description="The name of the variable")
@@ -234,10 +241,22 @@ class ToolInvokeMessage(BaseModel):
@field_validator("message", mode="before") @field_validator("message", mode="before")
@classmethod @classmethod
def decode_blob_message(cls, v): def decode_blob_message(cls, v, info: ValidationInfo):
# 处理 blob 解码
if isinstance(v, dict) and "blob" in v: if isinstance(v, dict) and "blob" in v:
with contextlib.suppress(Exception): with contextlib.suppress(Exception):
v["blob"] = base64.b64decode(v["blob"]) v["blob"] = base64.b64decode(v["blob"])
# Force correct message type based on type field
# Only wrap dict types to avoid wrapping already parsed Pydantic model objects
if info.data and isinstance(info.data, dict) and isinstance(v, dict):
msg_type = info.data.get("type")
if msg_type == cls.MessageType.JSON:
if "json_object" not in v:
v = {"json_object": v}
elif msg_type == cls.MessageType.FILE:
v = {"file_marker": "file_marker"}
return v return v
@field_serializer("message") @field_serializer("message")

View File

@@ -494,7 +494,7 @@ class AgentNode(Node[AgentNodeData]):
text = "" text = ""
files: list[File] = [] files: list[File] = []
json_list: list[dict] = [] json_list: list[dict | list] = []
agent_logs: list[AgentLogEvent] = [] agent_logs: list[AgentLogEvent] = []
agent_execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] = {} agent_execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] = {}
@@ -568,13 +568,18 @@ class AgentNode(Node[AgentNodeData]):
elif message.type == ToolInvokeMessage.MessageType.JSON: elif message.type == ToolInvokeMessage.MessageType.JSON:
assert isinstance(message.message, ToolInvokeMessage.JsonMessage) assert isinstance(message.message, ToolInvokeMessage.JsonMessage)
if node_type == NodeType.AGENT: if node_type == NodeType.AGENT:
msg_metadata: dict[str, Any] = message.message.json_object.pop("execution_metadata", {}) if isinstance(message.message.json_object, dict):
llm_usage = LLMUsage.from_metadata(cast(LLMUsageMetadata, msg_metadata)) msg_metadata: dict[str, Any] = message.message.json_object.pop("execution_metadata", {})
agent_execution_metadata = { llm_usage = LLMUsage.from_metadata(cast(LLMUsageMetadata, msg_metadata))
WorkflowNodeExecutionMetadataKey(key): value agent_execution_metadata = {
for key, value in msg_metadata.items() WorkflowNodeExecutionMetadataKey(key): value
if key in WorkflowNodeExecutionMetadataKey.__members__.values() for key, value in msg_metadata.items()
} if key in WorkflowNodeExecutionMetadataKey.__members__.values()
}
else:
msg_metadata = {}
llm_usage = LLMUsage.empty_usage()
agent_execution_metadata = {}
if message.message.json_object: if message.message.json_object:
json_list.append(message.message.json_object) json_list.append(message.message.json_object)
elif message.type == ToolInvokeMessage.MessageType.LINK: elif message.type == ToolInvokeMessage.MessageType.LINK:
@@ -683,7 +688,7 @@ class AgentNode(Node[AgentNodeData]):
yield agent_log yield agent_log
# Add agent_logs to outputs['json'] to ensure frontend can access thinking process # Add agent_logs to outputs['json'] to ensure frontend can access thinking process
json_output: list[dict[str, Any]] = [] json_output: list[dict[str, Any] | list[Any]] = []
# Step 1: append each agent log as its own dict. # Step 1: append each agent log as its own dict.
if agent_logs: if agent_logs:

View File

@@ -301,7 +301,7 @@ class DatasourceNode(Node[DatasourceNodeData]):
text = "" text = ""
files: list[File] = [] files: list[File] = []
json: list[dict] = [] json: list[dict | list] = []
variables: dict[str, Any] = {} variables: dict[str, Any] = {}

View File

@@ -244,7 +244,7 @@ class ToolNode(Node[ToolNodeData]):
text = "" text = ""
files: list[File] = [] files: list[File] = []
json: list[dict] = [] json: list[dict | list] = []
variables: dict[str, Any] = {} variables: dict[str, Any] = {}
@@ -400,7 +400,7 @@ class ToolNode(Node[ToolNodeData]):
message.message.metadata = dict_metadata message.message.metadata = dict_metadata
# Add agent_logs to outputs['json'] to ensure frontend can access thinking process # Add agent_logs to outputs['json'] to ensure frontend can access thinking process
json_output: list[dict[str, Any]] = [] json_output: list[dict[str, Any] | list[Any]] = []
# Step 2: normalize JSON into {"data": [...]}.change json to list[dict] # Step 2: normalize JSON into {"data": [...]}.change json to list[dict]
if json: if json: