refactor: remove streaming structured output from invoke_llm_with_structured_output
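
invoke_llm_with_structured_output() now always calls the model with
stream=False and returns a single LLMResultWithStructuredOutput. The
Literal[True]/Literal[False] overloads, the chunk-accumulating generator,
and the stream parameter are removed; callers read the parsed result from
the structured_output field instead of consuming
LLMResultChunkWithStructuredOutput chunks.

A minimal call-site sketch (the import path and the provider/model objects
are assumed placeholders from the caller's existing wiring, not part of
this diff):

    # Hypothetical setup: provider, model_schema, model_instance,
    # prompt_messages, and user_id come from the caller's model-manager code.
    result = invoke_llm_with_structured_output(
        provider=provider,
        model_schema=model_schema,
        model_instance=model_instance,
        prompt_messages=prompt_messages,
        json_schema={"type": "object", "properties": {"answer": {"type": "string"}}},
        user=user_id,
    )
    # The parsed (and schema-defaulted) mapping is available directly:
    print(result.structured_output)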

Signed-off-by: Stream <Stream_2@qq.com>
Stream committed 2026-01-29 23:41:08 +08:00
parent 749cebe60d
commit edce6d4152
4 changed files with 43 additions and 176 deletions

View File

@@ -1,8 +1,8 @@
 import json
-from collections.abc import Generator, Mapping, Sequence
+from collections.abc import Mapping, Sequence
 from copy import deepcopy
 from enum import StrEnum
-from typing import Any, Literal, TypeVar, cast, overload
+from typing import Any, TypeVar, cast
 
 import json_repair
 from pydantic import BaseModel, TypeAdapter, ValidationError
@@ -14,13 +14,9 @@ from core.model_manager import ModelInstance
 from core.model_runtime.callbacks.base_callback import Callback
 from core.model_runtime.entities.llm_entities import (
     LLMResult,
-    LLMResultChunk,
-    LLMResultChunkDelta,
-    LLMResultChunkWithStructuredOutput,
     LLMResultWithStructuredOutput,
 )
 from core.model_runtime.entities.message_entities import (
-    AssistantPromptMessage,
     PromptMessage,
     PromptMessageTool,
     SystemPromptMessage,
@@ -52,7 +48,6 @@ TOOL_CALL_FEATURES = {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL, Mode
 T = TypeVar("T", bound=BaseModel)
 
 
-@overload
 def invoke_llm_with_structured_output(
     *,
     provider: str,
@@ -63,58 +58,10 @@ def invoke_llm_with_structured_output(
     model_parameters: Mapping | None = None,
     tools: Sequence[PromptMessageTool] | None = None,
     stop: list[str] | None = None,
-    stream: Literal[True],
     user: str | None = None,
     callbacks: list[Callback] | None = None,
     tenant_id: str | None = None,
-) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ...
-
-
-@overload
-def invoke_llm_with_structured_output(
-    *,
-    provider: str,
-    model_schema: AIModelEntity,
-    model_instance: ModelInstance,
-    prompt_messages: Sequence[PromptMessage],
-    json_schema: Mapping[str, Any],
-    model_parameters: Mapping | None = None,
-    tools: Sequence[PromptMessageTool] | None = None,
-    stop: list[str] | None = None,
-    stream: Literal[False],
-    user: str | None = None,
-    callbacks: list[Callback] | None = None,
-    tenant_id: str | None = None,
-) -> LLMResultWithStructuredOutput: ...
-
-
-@overload
-def invoke_llm_with_structured_output(
-    *,
-    provider: str,
-    model_schema: AIModelEntity,
-    model_instance: ModelInstance,
-    prompt_messages: Sequence[PromptMessage],
-    json_schema: Mapping[str, Any],
-    model_parameters: Mapping | None = None,
-    tools: Sequence[PromptMessageTool] | None = None,
-    stop: list[str] | None = None,
-    stream: bool = True,
-    user: str | None = None,
-    callbacks: list[Callback] | None = None,
-    tenant_id: str | None = None,
-) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]: ...
-
-
-def invoke_llm_with_structured_output(
-    *,
-    provider: str,
-    model_schema: AIModelEntity,
-    model_instance: ModelInstance,
-    prompt_messages: Sequence[PromptMessage],
-    json_schema: Mapping[str, Any],
-    model_parameters: Mapping | None = None,
-    tools: Sequence[PromptMessageTool] | None = None,
-    stop: list[str] | None = None,
-    stream: bool = True,
-    user: str | None = None,
-    callbacks: list[Callback] | None = None,
-    tenant_id: str | None = None,
-) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]:
+) -> LLMResultWithStructuredOutput:
     """
     Invoke large language model with structured output.
@@ -129,7 +76,6 @@ def invoke_llm_with_structured_output(
     :param model_parameters: model parameters
     :param tools: tools for tool calling
     :param stop: stop words
-    :param stream: is stream response
     :param user: unique user id
     :param callbacks: callbacks
     :param tenant_id: tenant ID for file reference conversion. When provided and
@@ -165,91 +111,33 @@ def invoke_llm_with_structured_output(
         model_parameters=model_parameters_with_json_schema,
         tools=tools,
         stop=stop,
-        stream=stream,
+        stream=False,
         user=user,
         callbacks=callbacks,
     )
-    if isinstance(llm_result, LLMResult):
-        # Non-streaming result
-        structured_output = _extract_structured_output(llm_result)
+    # Non-streaming result
+    structured_output = _extract_structured_output(llm_result)
-        # Fill missing fields with default values
-        structured_output = fill_defaults_from_schema(structured_output, json_schema)
+    # Fill missing fields with default values
+    structured_output = fill_defaults_from_schema(structured_output, json_schema)
-        # Convert file references if tenant_id is provided
-        if tenant_id is not None:
-            structured_output = convert_file_refs_in_output(
-                output=structured_output,
-                json_schema=json_schema,
-                tenant_id=tenant_id,
-            )
-        return LLMResultWithStructuredOutput(
-            structured_output=structured_output,
-            model=llm_result.model,
-            message=llm_result.message,
-            usage=llm_result.usage,
-            system_fingerprint=llm_result.system_fingerprint,
-            prompt_messages=llm_result.prompt_messages,
-        )
+    # Convert file references if tenant_id is provided
+    if tenant_id is not None:
+        structured_output = convert_file_refs_in_output(
+            output=structured_output,
+            json_schema=json_schema,
+            tenant_id=tenant_id,
+        )
-    else:
-
-        def generator() -> Generator[LLMResultChunkWithStructuredOutput, None, None]:
-            result_text: str = ""
-            tool_call_args: dict[str, str] = {}  # tool_call_id -> arguments
-            prompt_messages: Sequence[PromptMessage] = []
-            system_fingerprint: str | None = None
-            for event in llm_result:
-                if isinstance(event, LLMResultChunk):
-                    prompt_messages = event.prompt_messages
-                    system_fingerprint = event.system_fingerprint
-                    # Collect text content
-                    result_text += event.delta.message.get_text_content()
-                    # Collect tool call arguments
-                    if event.delta.message.tool_calls:
-                        for tool_call in event.delta.message.tool_calls:
-                            call_id = tool_call.id or ""
-                            if tool_call.function.arguments:
-                                tool_call_args[call_id] = tool_call_args.get(call_id, "") + tool_call.function.arguments
-                    yield LLMResultChunkWithStructuredOutput(
-                        model=model_schema.model,
-                        prompt_messages=prompt_messages,
-                        system_fingerprint=system_fingerprint,
-                        delta=event.delta,
-                    )
-            # Extract structured output: prefer tool call, fallback to text
-            structured_output = _extract_structured_output_from_stream(result_text, tool_call_args)
-            # Fill missing fields with default values
-            structured_output = fill_defaults_from_schema(structured_output, json_schema)
-            # Convert file references if tenant_id is provided
-            if tenant_id is not None:
-                structured_output = convert_file_refs_in_output(
-                    output=structured_output,
-                    json_schema=json_schema,
-                    tenant_id=tenant_id,
-                )
-            yield LLMResultChunkWithStructuredOutput(
-                structured_output=structured_output,
-                model=model_schema.model,
-                prompt_messages=prompt_messages,
-                system_fingerprint=system_fingerprint,
-                delta=LLMResultChunkDelta(
-                    index=0,
-                    message=AssistantPromptMessage(content=""),
-                    usage=None,
-                    finish_reason=None,
-                ),
-            )
-
-        return generator()
+    return LLMResultWithStructuredOutput(
+        structured_output=structured_output,
+        model=llm_result.model,
+        message=llm_result.message,
+        usage=llm_result.usage,
+        system_fingerprint=llm_result.system_fingerprint,
+        prompt_messages=llm_result.prompt_messages,
+    )
 
 
 def invoke_llm_with_pydantic_model(
@@ -289,7 +177,6 @@ def invoke_llm_with_pydantic_model(
         model_parameters=model_parameters,
         tools=tools,
         stop=stop,
-        stream=False,
         user=user,
         callbacks=callbacks,
         tenant_id=tenant_id,
@@ -385,7 +272,7 @@ def _parse_tool_call_arguments(arguments: str) -> Mapping[str, Any]:
     repaired = json_repair.loads(arguments)
     if not isinstance(repaired, dict):
         raise OutputParserError(f"Failed to parse tool call arguments: {arguments}")
-    return cast(dict, repaired)
+    return repaired
 
 
 def _get_default_value_for_type(type_name: str | list[str] | None) -> Any:

View File

@@ -114,46 +114,32 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
             model_instance=model_instance,
             prompt_messages=payload.prompt_messages,
             json_schema=payload.structured_output_schema,
-            model_parameters=payload.completion_params,
-            tools=payload.tools,
-            stop=payload.stop,
-            stream=True if payload.stream is None else payload.stream,
-            user=user_id,
+            model_parameters=payload.completion_params,
+            user=user_id
         )
-        if isinstance(response, Generator):
+        if response.usage:
+            llm_utils.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
-
-            def handle() -> Generator[LLMResultChunkWithStructuredOutput, None, None]:
-                for chunk in response:
-                    if chunk.delta.usage:
-                        llm_utils.deduct_llm_quota(
-                            tenant_id=tenant.id, model_instance=model_instance, usage=chunk.delta.usage
-                        )
-                    chunk.prompt_messages = []
-                    yield chunk
+
+        def handle_non_streaming(
+            response: LLMResultWithStructuredOutput,
+        ) -> Generator[LLMResultChunkWithStructuredOutput, None, None]:
+            yield LLMResultChunkWithStructuredOutput(
+                model=response.model,
+                prompt_messages=[],
+                system_fingerprint=response.system_fingerprint,
+                structured_output=response.structured_output,
+                delta=LLMResultChunkDelta(
+                    index=0,
+                    message=response.message,
+                    usage=response.usage,
+                    finish_reason="",
+                ),
+            )
-
-            return handle()
-        else:
-            if response.usage:
-                llm_utils.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
-
-            def handle_non_streaming(
-                response: LLMResultWithStructuredOutput,
-            ) -> Generator[LLMResultChunkWithStructuredOutput, None, None]:
-                yield LLMResultChunkWithStructuredOutput(
-                    model=response.model,
-                    prompt_messages=[],
-                    system_fingerprint=response.system_fingerprint,
-                    structured_output=response.structured_output,
-                    delta=LLMResultChunkDelta(
-                        index=0,
-                        message=response.message,
-                        usage=response.usage,
-                        finish_reason="",
-                    ),
-                )
-
-            return handle_non_streaming(response)
+
+        return handle_non_streaming(response)
 
     @classmethod
     def invoke_text_embedding(cls, user_id: str, tenant: Tenant, payload: RequestInvokeTextEmbedding):

View File

@@ -522,7 +522,6 @@ class LLMNode(Node[LLMNodeData]):
             json_schema=output_schema,
             model_parameters=node_data_model.completion_params,
             stop=list(stop or []),
-            stream=False,
             user=user_id,
             tenant_id=tenant_id,
         )

View File

@@ -312,7 +312,6 @@ def test_structured_output_parser():
                 model_instance=model_instance,
                 prompt_messages=prompt_messages,
                 json_schema=case["json_schema"],
-                stream=case["stream"],
             )
             # Consume the generator to trigger the error
             list(result_generator)
@@ -323,7 +322,6 @@ def test_structured_output_parser():
                 model_instance=model_instance,
                 prompt_messages=prompt_messages,
                 json_schema=case["json_schema"],
-                stream=case["stream"],
             )
         else:
             # Test successful cases
@@ -338,7 +336,6 @@ def test_structured_output_parser():
                 model_instance=model_instance,
                 prompt_messages=prompt_messages,
                 json_schema=case["json_schema"],
-                stream=case["stream"],
                 model_parameters={"temperature": 0.7, "max_tokens": 100},
                 user="test_user",
             )
@@ -418,7 +415,6 @@ def test_parse_structured_output_edge_cases():
            model_instance=model_instance,
            prompt_messages=prompt_messages,
            json_schema=testcase_list_with_dict["json_schema"],
-           stream=testcase_list_with_dict["stream"],
        )
 
        assert isinstance(result, LLMResultWithStructuredOutput)
@@ -456,7 +452,6 @@ def test_model_specific_schema_preparation():
            model_instance=model_instance,
            prompt_messages=prompt_messages,
            json_schema=gemini_case["json_schema"],
-           stream=gemini_case["stream"],
        )
 
        assert isinstance(result, LLMResultWithStructuredOutput)