diff --git a/api/core/llm_generator/output_parser/structured_output.py b/api/core/llm_generator/output_parser/structured_output.py
index 7986ccd2d9..8f166a5757 100644
--- a/api/core/llm_generator/output_parser/structured_output.py
+++ b/api/core/llm_generator/output_parser/structured_output.py
@@ -1,8 +1,8 @@
 import json
-from collections.abc import Generator, Mapping, Sequence
+from collections.abc import Mapping, Sequence
 from copy import deepcopy
 from enum import StrEnum
-from typing import Any, Literal, TypeVar, cast, overload
+from typing import Any, TypeVar, cast
 
 import json_repair
 from pydantic import BaseModel, TypeAdapter, ValidationError
@@ -14,13 +14,9 @@ from core.model_manager import ModelInstance
 from core.model_runtime.callbacks.base_callback import Callback
 from core.model_runtime.entities.llm_entities import (
     LLMResult,
-    LLMResultChunk,
-    LLMResultChunkDelta,
-    LLMResultChunkWithStructuredOutput,
     LLMResultWithStructuredOutput,
 )
 from core.model_runtime.entities.message_entities import (
-    AssistantPromptMessage,
     PromptMessage,
     PromptMessageTool,
     SystemPromptMessage,
@@ -52,7 +48,6 @@ TOOL_CALL_FEATURES = {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL, Mode
 T = TypeVar("T", bound=BaseModel)
 
 
-@overload
 def invoke_llm_with_structured_output(
     *,
     provider: str,
@@ -63,58 +58,10 @@ def invoke_llm_with_structured_output(
     model_parameters: Mapping | None = None,
     tools: Sequence[PromptMessageTool] | None = None,
     stop: list[str] | None = None,
-    stream: Literal[True],
     user: str | None = None,
     callbacks: list[Callback] | None = None,
     tenant_id: str | None = None,
-) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ...
-@overload
-def invoke_llm_with_structured_output(
-    *,
-    provider: str,
-    model_schema: AIModelEntity,
-    model_instance: ModelInstance,
-    prompt_messages: Sequence[PromptMessage],
-    json_schema: Mapping[str, Any],
-    model_parameters: Mapping | None = None,
-    tools: Sequence[PromptMessageTool] | None = None,
-    stop: list[str] | None = None,
-    stream: Literal[False],
-    user: str | None = None,
-    callbacks: list[Callback] | None = None,
-    tenant_id: str | None = None,
-) -> LLMResultWithStructuredOutput: ...
-@overload
-def invoke_llm_with_structured_output(
-    *,
-    provider: str,
-    model_schema: AIModelEntity,
-    model_instance: ModelInstance,
-    prompt_messages: Sequence[PromptMessage],
-    json_schema: Mapping[str, Any],
-    model_parameters: Mapping | None = None,
-    tools: Sequence[PromptMessageTool] | None = None,
-    stop: list[str] | None = None,
-    stream: bool = True,
-    user: str | None = None,
-    callbacks: list[Callback] | None = None,
-    tenant_id: str | None = None,
-) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]: ...
-def invoke_llm_with_structured_output(
-    *,
-    provider: str,
-    model_schema: AIModelEntity,
-    model_instance: ModelInstance,
-    prompt_messages: Sequence[PromptMessage],
-    json_schema: Mapping[str, Any],
-    model_parameters: Mapping | None = None,
-    tools: Sequence[PromptMessageTool] | None = None,
-    stop: list[str] | None = None,
-    stream: bool = True,
-    user: str | None = None,
-    callbacks: list[Callback] | None = None,
-    tenant_id: str | None = None,
-) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]:
+) -> LLMResultWithStructuredOutput:
     """
     Invoke large language model with structured output.
 
@@ -129,7 +76,6 @@ def invoke_llm_with_structured_output(
     :param model_parameters: model parameters
     :param tools: tools for tool calling
     :param stop: stop words
-    :param stream: is stream response
     :param user: unique user id
     :param callbacks: callbacks
     :param tenant_id: tenant ID for file reference conversion. When provided and
@@ -165,91 +111,33 @@ def invoke_llm_with_structured_output(
         model_parameters=model_parameters_with_json_schema,
         tools=tools,
         stop=stop,
-        stream=stream,
+        stream=False,
         user=user,
         callbacks=callbacks,
     )
 
-    if isinstance(llm_result, LLMResult):
-        # Non-streaming result
-        structured_output = _extract_structured_output(llm_result)
+    # Non-streaming result
+    structured_output = _extract_structured_output(llm_result)
 
-        # Fill missing fields with default values
-        structured_output = fill_defaults_from_schema(structured_output, json_schema)
+    # Fill missing fields with default values
+    structured_output = fill_defaults_from_schema(structured_output, json_schema)
 
-        # Convert file references if tenant_id is provided
-        if tenant_id is not None:
-            structured_output = convert_file_refs_in_output(
-                output=structured_output,
-                json_schema=json_schema,
-                tenant_id=tenant_id,
-            )
-
-        return LLMResultWithStructuredOutput(
-            structured_output=structured_output,
-            model=llm_result.model,
-            message=llm_result.message,
-            usage=llm_result.usage,
-            system_fingerprint=llm_result.system_fingerprint,
-            prompt_messages=llm_result.prompt_messages,
+    # Convert file references if tenant_id is provided
+    if tenant_id is not None:
+        structured_output = convert_file_refs_in_output(
+            output=structured_output,
+            json_schema=json_schema,
+            tenant_id=tenant_id,
         )
-    else:
-
-        def generator() -> Generator[LLMResultChunkWithStructuredOutput, None, None]:
-            result_text: str = ""
-            tool_call_args: dict[str, str] = {}  # tool_call_id -> arguments
-            prompt_messages: Sequence[PromptMessage] = []
-            system_fingerprint: str | None = None
-
-            for event in llm_result:
-                if isinstance(event, LLMResultChunk):
-                    prompt_messages = event.prompt_messages
-                    system_fingerprint = event.system_fingerprint
-
-                    # Collect text content
-                    result_text += event.delta.message.get_text_content()
-                    # Collect tool call arguments
-                    if event.delta.message.tool_calls:
-                        for tool_call in event.delta.message.tool_calls:
-                            call_id = tool_call.id or ""
-                            if tool_call.function.arguments:
-                                tool_call_args[call_id] = tool_call_args.get(call_id, "") + tool_call.function.arguments
-
-                    yield LLMResultChunkWithStructuredOutput(
-                        model=model_schema.model,
-                        prompt_messages=prompt_messages,
-                        system_fingerprint=system_fingerprint,
-                        delta=event.delta,
-                    )
-
-            # Extract structured output: prefer tool call, fallback to text
-            structured_output = _extract_structured_output_from_stream(result_text, tool_call_args)
-
-            # Fill missing fields with default values
-            structured_output = fill_defaults_from_schema(structured_output, json_schema)
-
-            # Convert file references if tenant_id is provided
-            if tenant_id is not None:
-                structured_output = convert_file_refs_in_output(
-                    output=structured_output,
-                    json_schema=json_schema,
-                    tenant_id=tenant_id,
-                )
-
-            yield LLMResultChunkWithStructuredOutput(
-                structured_output=structured_output,
-                model=model_schema.model,
-                prompt_messages=prompt_messages,
-                system_fingerprint=system_fingerprint,
-                delta=LLMResultChunkDelta(
-                    index=0,
-                    message=AssistantPromptMessage(content=""),
-                    usage=None,
-                    finish_reason=None,
-                ),
-            )
-
-        return generator()
+    return LLMResultWithStructuredOutput(
+        structured_output=structured_output,
+        model=llm_result.model,
+        message=llm_result.message,
+        usage=llm_result.usage,
+        system_fingerprint=llm_result.system_fingerprint,
+        prompt_messages=llm_result.prompt_messages,
+    )
 
 
 def invoke_llm_with_pydantic_model(
@@ -289,7 +177,6 @@ def invoke_llm_with_pydantic_model(
         model_parameters=model_parameters,
         tools=tools,
         stop=stop,
-        stream=False,
        user=user,
        callbacks=callbacks,
        tenant_id=tenant_id,
@@ -385,7 +272,7 @@ def _parse_tool_call_arguments(arguments: str) -> Mapping[str, Any]:
     repaired = json_repair.loads(arguments)
     if not isinstance(repaired, dict):
         raise OutputParserError(f"Failed to parse tool call arguments: {arguments}")
-    return cast(dict, repaired)
+    return repaired
 
 
 def _get_default_value_for_type(type_name: str | list[str] | None) -> Any:
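Note: with the overloads gone, invoke_llm_with_structured_output is non-streaming only and always returns a single LLMResultWithStructuredOutput. A minimal caller sketch (the provider name and schema are hypothetical, and model_schema/model_instance are assumed to be resolved by the caller; the keyword-only signature follows the new definition above):

    from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
    from core.model_runtime.entities.message_entities import UserPromptMessage

    result = invoke_llm_with_structured_output(
        provider="openai",  # hypothetical provider name
        model_schema=model_schema,
        model_instance=model_instance,
        prompt_messages=[UserPromptMessage(content="Describe one user.")],
        json_schema={"type": "object", "properties": {"name": {"type": "string"}}},
    )
    # Parsed mapping with schema defaults already filled in; no generator to drain.
    print(result.structured_output)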
diff --git a/api/core/plugin/backwards_invocation/model.py b/api/core/plugin/backwards_invocation/model.py
index 6cdc047a64..1abd9fabc7 100644
--- a/api/core/plugin/backwards_invocation/model.py
+++ b/api/core/plugin/backwards_invocation/model.py
@@ -114,46 +114,32 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
             model_instance=model_instance,
             prompt_messages=payload.prompt_messages,
             json_schema=payload.structured_output_schema,
+            model_parameters=payload.completion_params,
             tools=payload.tools,
             stop=payload.stop,
-            stream=True if payload.stream is None else payload.stream,
-            user=user_id,
-            model_parameters=payload.completion_params,
+            user=user_id
         )
 
-        if isinstance(response, Generator):
+        if response.usage:
+            llm_utils.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
 
-            def handle() -> Generator[LLMResultChunkWithStructuredOutput, None, None]:
-                for chunk in response:
-                    if chunk.delta.usage:
-                        llm_utils.deduct_llm_quota(
-                            tenant_id=tenant.id, model_instance=model_instance, usage=chunk.delta.usage
-                        )
-                    chunk.prompt_messages = []
-                    yield chunk
+        def handle_non_streaming(
+            response: LLMResultWithStructuredOutput,
+        ) -> Generator[LLMResultChunkWithStructuredOutput, None, None]:
+            yield LLMResultChunkWithStructuredOutput(
+                model=response.model,
+                prompt_messages=[],
+                system_fingerprint=response.system_fingerprint,
+                structured_output=response.structured_output,
+                delta=LLMResultChunkDelta(
+                    index=0,
+                    message=response.message,
+                    usage=response.usage,
+                    finish_reason="",
+                ),
+            )
 
-            return handle()
-        else:
-            if response.usage:
-                llm_utils.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
-
-            def handle_non_streaming(
-                response: LLMResultWithStructuredOutput,
-            ) -> Generator[LLMResultChunkWithStructuredOutput, None, None]:
-                yield LLMResultChunkWithStructuredOutput(
-                    model=response.model,
-                    prompt_messages=[],
-                    system_fingerprint=response.system_fingerprint,
-                    structured_output=response.structured_output,
-                    delta=LLMResultChunkDelta(
-                        index=0,
-                        message=response.message,
-                        usage=response.usage,
-                        finish_reason="",
-                    ),
-                )
-
-            return handle_non_streaming(response)
+        return handle_non_streaming(response)
 
     @classmethod
     def invoke_text_embedding(cls, user_id: str, tenant: Tenant, payload: RequestInvokeTextEmbedding):
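Note: plugin callers still receive a generator of chunks, so the single non-streaming result is adapted into exactly one terminal chunk carrying the full message, usage, and structured_output. A consumer sketch (the enclosing classmethod and its call site are elided from the hunk, so the names below are illustrative only):

    # `chunks` stands in for the generator returned via handle_non_streaming(response)
    for chunk in chunks:
        data = chunk.structured_output  # parsed structured payload
        usage = chunk.delta.usage       # full usage, reported once; quota already deducted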
diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py
index f780f5a54f..feb76e6510 100644
--- a/api/core/workflow/nodes/llm/node.py
+++ b/api/core/workflow/nodes/llm/node.py
@@ -522,7 +522,6 @@ class LLMNode(Node[LLMNodeData]):
            json_schema=output_schema,
            model_parameters=node_data_model.completion_params,
            stop=list(stop or []),
-           stream=False,
            user=user_id,
            tenant_id=tenant_id,
        )
diff --git a/api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py b/api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py
index fed3e923e9..df73c29004 100644
--- a/api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py
+++ b/api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py
@@ -312,7 +312,6 @@ def test_structured_output_parser():
                     model_instance=model_instance,
                     prompt_messages=prompt_messages,
                     json_schema=case["json_schema"],
-                    stream=case["stream"],
                 )
                 # Consume the generator to trigger the error
                 list(result_generator)
@@ -323,7 +322,6 @@ def test_structured_output_parser():
                     model_instance=model_instance,
                     prompt_messages=prompt_messages,
                     json_schema=case["json_schema"],
-                    stream=case["stream"],
                 )
         else:
             # Test successful cases
@@ -338,7 +336,6 @@ def test_structured_output_parser():
                 model_instance=model_instance,
                 prompt_messages=prompt_messages,
                 json_schema=case["json_schema"],
-                stream=case["stream"],
                 model_parameters={"temperature": 0.7, "max_tokens": 100},
                 user="test_user",
             )
@@ -418,7 +415,6 @@ def test_parse_structured_output_edge_cases():
        model_instance=model_instance,
        prompt_messages=prompt_messages,
        json_schema=testcase_list_with_dict["json_schema"],
-       stream=testcase_list_with_dict["stream"],
    )
 
    assert isinstance(result, LLMResultWithStructuredOutput)
@@ -456,7 +452,6 @@ def test_model_specific_schema_preparation():
        model_instance=model_instance,
        prompt_messages=prompt_messages,
        json_schema=gemini_case["json_schema"],
-       stream=gemini_case["stream"],
    )
 
    assert isinstance(result, LLMResultWithStructuredOutput)
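Note: the test cases' "stream" key is no longer passed through, since the helper no longer accepts that parameter. A regression check for the new contract might look like this (illustrative only; it assumes the suite's existing mocked model_schema, model_instance, and prompt_messages):

    def test_invoke_returns_single_structured_result():
        result = invoke_llm_with_structured_output(
            provider="test_provider",  # hypothetical provider name
            model_schema=model_schema,
            model_instance=model_instance,
            prompt_messages=prompt_messages,
            json_schema={"type": "object", "properties": {"name": {"type": "string"}}},
        )
        # Non-streaming only: a result object, never a generator of chunks.
        assert isinstance(result, LLMResultWithStructuredOutput)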