Mirror of https://github.com/langgenius/dify.git (synced 2026-05-10 05:56:31 +08:00)
refactor: remove streaming structured output from invoke_llm_with_structured_output
Signed-off-by: Stream <Stream_2@qq.com>
This commit is contained in:
parent
749cebe60d
commit
edce6d4152
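
After this change, invoke_llm_with_structured_output is non-streaming only: the @overload stack keyed on the stream flag is gone, callers always receive a single LLMResultWithStructuredOutput, and the underlying model invocation is pinned to stream=False. A minimal sketch of the new calling convention, using names visible in the diff below (the provider string, schema, and surrounding setup are illustrative, not part of this commit):

    # Sketch only: model_schema / model_instance / prompt_messages come from
    # the caller's existing setup; the provider name is illustrative.
    result = invoke_llm_with_structured_output(
        provider="openai",
        model_schema=model_schema,
        model_instance=model_instance,
        prompt_messages=prompt_messages,
        json_schema={"type": "object", "properties": {"answer": {"type": "string"}}},
        model_parameters={"temperature": 0.0},
    )
    # There is no stream argument any more; the result is always one object.
    print(result.structured_output)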
@@ -1,8 +1,8 @@
 import json
-from collections.abc import Generator, Mapping, Sequence
+from collections.abc import Mapping, Sequence
 from copy import deepcopy
 from enum import StrEnum
-from typing import Any, Literal, TypeVar, cast, overload
+from typing import Any, TypeVar, cast
 
 import json_repair
 from pydantic import BaseModel, TypeAdapter, ValidationError
@@ -14,13 +14,9 @@ from core.model_manager import ModelInstance
 from core.model_runtime.callbacks.base_callback import Callback
 from core.model_runtime.entities.llm_entities import (
     LLMResult,
-    LLMResultChunk,
-    LLMResultChunkDelta,
-    LLMResultChunkWithStructuredOutput,
     LLMResultWithStructuredOutput,
 )
 from core.model_runtime.entities.message_entities import (
-    AssistantPromptMessage,
     PromptMessage,
     PromptMessageTool,
     SystemPromptMessage,
@@ -52,7 +48,6 @@ TOOL_CALL_FEATURES = {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL, Mode
 T = TypeVar("T", bound=BaseModel)
 
 
-@overload
 def invoke_llm_with_structured_output(
     *,
     provider: str,
@@ -63,58 +58,10 @@ def invoke_llm_with_structured_output(
     model_parameters: Mapping | None = None,
     tools: Sequence[PromptMessageTool] | None = None,
     stop: list[str] | None = None,
-    stream: Literal[True],
     user: str | None = None,
     callbacks: list[Callback] | None = None,
     tenant_id: str | None = None,
-) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ...
-@overload
-def invoke_llm_with_structured_output(
-    *,
-    provider: str,
-    model_schema: AIModelEntity,
-    model_instance: ModelInstance,
-    prompt_messages: Sequence[PromptMessage],
-    json_schema: Mapping[str, Any],
-    model_parameters: Mapping | None = None,
-    tools: Sequence[PromptMessageTool] | None = None,
-    stop: list[str] | None = None,
-    stream: Literal[False],
-    user: str | None = None,
-    callbacks: list[Callback] | None = None,
-    tenant_id: str | None = None,
-) -> LLMResultWithStructuredOutput: ...
-@overload
-def invoke_llm_with_structured_output(
-    *,
-    provider: str,
-    model_schema: AIModelEntity,
-    model_instance: ModelInstance,
-    prompt_messages: Sequence[PromptMessage],
-    json_schema: Mapping[str, Any],
-    model_parameters: Mapping | None = None,
-    tools: Sequence[PromptMessageTool] | None = None,
-    stop: list[str] | None = None,
-    stream: bool = True,
-    user: str | None = None,
-    callbacks: list[Callback] | None = None,
-    tenant_id: str | None = None,
-) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]: ...
-def invoke_llm_with_structured_output(
-    *,
-    provider: str,
-    model_schema: AIModelEntity,
-    model_instance: ModelInstance,
-    prompt_messages: Sequence[PromptMessage],
-    json_schema: Mapping[str, Any],
-    model_parameters: Mapping | None = None,
-    tools: Sequence[PromptMessageTool] | None = None,
-    stop: list[str] | None = None,
-    stream: bool = True,
-    user: str | None = None,
-    callbacks: list[Callback] | None = None,
-    tenant_id: str | None = None,
-) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]:
+) -> LLMResultWithStructuredOutput:
     """
     Invoke large language model with structured output.
 
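
Background on the deleted overloads: the Literal[True] / Literal[False] stack existed only so type checkers could map the stream flag to the matching return type. A toy sketch of that pattern (generic illustration, not dify code):

    from collections.abc import Generator
    from typing import Literal, overload

    @overload
    def fetch(*, stream: Literal[True]) -> Generator[str, None, None]: ...
    @overload
    def fetch(*, stream: Literal[False]) -> str: ...
    def fetch(*, stream: bool = False) -> str | Generator[str, None, None]:
        # With only one possible return type left, this whole stack collapses
        # into a single plain signature, which is what the hunk above does.
        if stream:
            return (c for c in "abc")
        return "abc"

With a single LLMResultWithStructuredOutput return type, none of that machinery is needed, which is also why Literal, overload, and Generator disappear from the imports.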
@@ -129,7 +76,6 @@ def invoke_llm_with_structured_output(
     :param model_parameters: model parameters
     :param tools: tools for tool calling
     :param stop: stop words
-    :param stream: is stream response
     :param user: unique user id
     :param callbacks: callbacks
     :param tenant_id: tenant ID for file reference conversion. When provided and
@@ -165,91 +111,33 @@ def invoke_llm_with_structured_output(
         model_parameters=model_parameters_with_json_schema,
         tools=tools,
         stop=stop,
-        stream=stream,
+        stream=False,
         user=user,
         callbacks=callbacks,
     )
 
-    if isinstance(llm_result, LLMResult):
-        # Non-streaming result
-        structured_output = _extract_structured_output(llm_result)
+    # Non-streaming result
+    structured_output = _extract_structured_output(llm_result)
 
-        # Fill missing fields with default values
-        structured_output = fill_defaults_from_schema(structured_output, json_schema)
+    # Fill missing fields with default values
+    structured_output = fill_defaults_from_schema(structured_output, json_schema)
 
-        # Convert file references if tenant_id is provided
-        if tenant_id is not None:
-            structured_output = convert_file_refs_in_output(
-                output=structured_output,
-                json_schema=json_schema,
-                tenant_id=tenant_id,
-            )
-
-        return LLMResultWithStructuredOutput(
-            structured_output=structured_output,
-            model=llm_result.model,
-            message=llm_result.message,
-            usage=llm_result.usage,
-            system_fingerprint=llm_result.system_fingerprint,
-            prompt_messages=llm_result.prompt_messages,
+    # Convert file references if tenant_id is provided
+    if tenant_id is not None:
+        structured_output = convert_file_refs_in_output(
+            output=structured_output,
+            json_schema=json_schema,
+            tenant_id=tenant_id,
         )
-    else:
-
-        def generator() -> Generator[LLMResultChunkWithStructuredOutput, None, None]:
-            result_text: str = ""
-            tool_call_args: dict[str, str] = {}  # tool_call_id -> arguments
-            prompt_messages: Sequence[PromptMessage] = []
-            system_fingerprint: str | None = None
-
-            for event in llm_result:
-                if isinstance(event, LLMResultChunk):
-                    prompt_messages = event.prompt_messages
-                    system_fingerprint = event.system_fingerprint
-
-                    # Collect text content
-                    result_text += event.delta.message.get_text_content()
-                    # Collect tool call arguments
-                    if event.delta.message.tool_calls:
-                        for tool_call in event.delta.message.tool_calls:
-                            call_id = tool_call.id or ""
-                            if tool_call.function.arguments:
-                                tool_call_args[call_id] = tool_call_args.get(call_id, "") + tool_call.function.arguments
-
-                    yield LLMResultChunkWithStructuredOutput(
-                        model=model_schema.model,
-                        prompt_messages=prompt_messages,
-                        system_fingerprint=system_fingerprint,
-                        delta=event.delta,
-                    )
-
-            # Extract structured output: prefer tool call, fallback to text
-            structured_output = _extract_structured_output_from_stream(result_text, tool_call_args)
-
-            # Fill missing fields with default values
-            structured_output = fill_defaults_from_schema(structured_output, json_schema)
-
-            # Convert file references if tenant_id is provided
-            if tenant_id is not None:
-                structured_output = convert_file_refs_in_output(
-                    output=structured_output,
-                    json_schema=json_schema,
-                    tenant_id=tenant_id,
-                )
-
-            yield LLMResultChunkWithStructuredOutput(
-                structured_output=structured_output,
-                model=model_schema.model,
-                prompt_messages=prompt_messages,
-                system_fingerprint=system_fingerprint,
-                delta=LLMResultChunkDelta(
-                    index=0,
-                    message=AssistantPromptMessage(content=""),
-                    usage=None,
-                    finish_reason=None,
-                ),
-            )
-
-        return generator()
+
+    return LLMResultWithStructuredOutput(
+        structured_output=structured_output,
+        model=llm_result.model,
+        message=llm_result.message,
+        usage=llm_result.usage,
+        system_fingerprint=llm_result.system_fingerprint,
+        prompt_messages=llm_result.prompt_messages,
+    )
 
 
 def invoke_llm_with_pydantic_model(
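
For callers migrating off the old streaming path: code that previously iterated chunks and picked structured_output off the final chunk now reads the fields directly from the returned object. A hedged before/after sketch (variable names are illustrative):

    # Before (removed): consume chunks; structured_output arrived on the
    # final chunk. Roughly:
    # for chunk in invoke_llm_with_structured_output(..., stream=True):
    #     if chunk.structured_output is not None:
    #         data = chunk.structured_output

    # After: one blocking call, one result object.
    result = invoke_llm_with_structured_output(
        provider=provider,
        model_schema=model_schema,
        model_instance=model_instance,
        prompt_messages=prompt_messages,
        json_schema=json_schema,
    )
    data = result.structured_output
    usage = result.usage  # usage arrives on the same object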
@@ -289,7 +177,6 @@ def invoke_llm_with_pydantic_model(
         model_parameters=model_parameters,
         tools=tools,
         stop=stop,
-        stream=False,
         user=user,
         callbacks=callbacks,
         tenant_id=tenant_id,
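
invoke_llm_with_pydantic_model already pinned stream=False before this commit, so only the now-nonexistent keyword is dropped here. For intuition, the validation step such a helper performs maps the structured-output dict onto a BaseModel subclass, roughly like this (sketch under assumptions, not the function's actual body):

    from pydantic import BaseModel, TypeAdapter, ValidationError

    class Answer(BaseModel):
        name: str
        score: float = 0.0

    try:
        answer = TypeAdapter(Answer).validate_python({"name": "dify"})
        print(answer.score)  # 0.0; pydantic fills declared defaults
    except ValidationError as exc:
        print(exc)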
@@ -385,7 +272,7 @@ def _parse_tool_call_arguments(arguments: str) -> Mapping[str, Any]:
     repaired = json_repair.loads(arguments)
     if not isinstance(repaired, dict):
         raise OutputParserError(f"Failed to parse tool call arguments: {arguments}")
-    return cast(dict, repaired)
+    return repaired
 
 
 def _get_default_value_for_type(type_name: str | list[str] | None) -> Any:
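
Note on the cast removal: json_repair.loads returns a best-effort Python value recovered from malformed JSON, and the isinstance guard above already narrows the type to dict for type checkers, so cast(dict, repaired) was redundant. A standalone sketch of the behavior:

    import json_repair

    # json_repair tolerates common LLM output defects such as single quotes
    # and trailing commas, returning a best-effort Python value.
    repaired = json_repair.loads("{'name': 'dify', 'tags': ['llm',]}")
    if not isinstance(repaired, dict):
        raise ValueError("expected a JSON object")  # mirrors the guard above
    print(repaired["name"])  # after the guard, repaired is a plain dict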
@@ -114,46 +114,32 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
             model_instance=model_instance,
             prompt_messages=payload.prompt_messages,
             json_schema=payload.structured_output_schema,
-            model_parameters=payload.completion_params,
             tools=payload.tools,
             stop=payload.stop,
-            stream=True if payload.stream is None else payload.stream,
-            user=user_id,
+            model_parameters=payload.completion_params,
+            user=user_id
         )
 
-        if isinstance(response, Generator):
-
-            def handle() -> Generator[LLMResultChunkWithStructuredOutput, None, None]:
-                for chunk in response:
-                    if chunk.delta.usage:
-                        llm_utils.deduct_llm_quota(
-                            tenant_id=tenant.id, model_instance=model_instance, usage=chunk.delta.usage
-                        )
-                    chunk.prompt_messages = []
-                    yield chunk
-
-            return handle()
-        else:
-            if response.usage:
-                llm_utils.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
-
-            def handle_non_streaming(
-                response: LLMResultWithStructuredOutput,
-            ) -> Generator[LLMResultChunkWithStructuredOutput, None, None]:
-                yield LLMResultChunkWithStructuredOutput(
-                    model=response.model,
-                    prompt_messages=[],
-                    system_fingerprint=response.system_fingerprint,
-                    structured_output=response.structured_output,
-                    delta=LLMResultChunkDelta(
-                        index=0,
-                        message=response.message,
-                        usage=response.usage,
-                        finish_reason="",
-                    ),
-                )
-
-            return handle_non_streaming(response)
+        if response.usage:
+            llm_utils.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
+
+        def handle_non_streaming(
+            response: LLMResultWithStructuredOutput,
+        ) -> Generator[LLMResultChunkWithStructuredOutput, None, None]:
+            yield LLMResultChunkWithStructuredOutput(
+                model=response.model,
+                prompt_messages=[],
+                system_fingerprint=response.system_fingerprint,
+                structured_output=response.structured_output,
+                delta=LLMResultChunkDelta(
+                    index=0,
+                    message=response.message,
+                    usage=response.usage,
+                    finish_reason="",
+                ),
+            )
+
+        return handle_non_streaming(response)
 
     @classmethod
     def invoke_text_embedding(cls, user_id: str, tenant: Tenant, payload: RequestInvokeTextEmbedding):
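
The plugin backwards-invocation channel still yields chunks to the plugin daemon, so the now always-non-streaming result is adapted by handle_non_streaming into a one-shot generator: a single chunk carrying the full message, usage, and structured output. The adapter boils down to this shape (generic sketch, not the dify types):

    from collections.abc import Generator
    from typing import TypeVar

    T = TypeVar("T")

    def as_single_chunk_stream(result: T) -> Generator[T, None, None]:
        # Callers that expect to iterate over a stream keep working; the
        # "stream" now contains exactly one complete item.
        yield result

    chunks = list(as_single_chunk_stream({"answer": "42"}))
    assert len(chunks) == 1

Quota handling simplifies accordingly: instead of deducting per chunk delta, deduct_llm_quota runs once against response.usage before the wrapper is returned.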
@@ -522,7 +522,6 @@ class LLMNode(Node[LLMNodeData]):
             json_schema=output_schema,
             model_parameters=node_data_model.completion_params,
             stop=list(stop or []),
-            stream=False,
             user=user_id,
             tenant_id=tenant_id,
         )
@@ -312,7 +312,6 @@ def test_structured_output_parser():
                 model_instance=model_instance,
                 prompt_messages=prompt_messages,
                 json_schema=case["json_schema"],
-                stream=case["stream"],
             )
             # Consume the generator to trigger the error
             list(result_generator)
@@ -323,7 +322,6 @@ def test_structured_output_parser():
                 model_instance=model_instance,
                 prompt_messages=prompt_messages,
                 json_schema=case["json_schema"],
-                stream=case["stream"],
             )
         else:
             # Test successful cases
@@ -338,7 +336,6 @@ def test_structured_output_parser():
                 model_instance=model_instance,
                 prompt_messages=prompt_messages,
                 json_schema=case["json_schema"],
-                stream=case["stream"],
                 model_parameters={"temperature": 0.7, "max_tokens": 100},
                 user="test_user",
             )
@@ -418,7 +415,6 @@ def test_parse_structured_output_edge_cases():
         model_instance=model_instance,
         prompt_messages=prompt_messages,
         json_schema=testcase_list_with_dict["json_schema"],
-        stream=testcase_list_with_dict["stream"],
     )
 
     assert isinstance(result, LLMResultWithStructuredOutput)
@@ -456,7 +452,6 @@ def test_model_specific_schema_preparation():
         model_instance=model_instance,
         prompt_messages=prompt_messages,
         json_schema=gemini_case["json_schema"],
-        stream=gemini_case["stream"],
     )
 
     assert isinstance(result, LLMResultWithStructuredOutput)
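
Test call sites simply drop the stream argument; the fixtures' "stream" keys can remain unused. The updated success-path shape, illustratively (mocks and fixture values are assumptions):

    result = invoke_llm_with_structured_output(
        provider="openai",              # assumed fixture value
        model_schema=model_schema,      # supplied by the test's mocks
        model_instance=model_instance,
        prompt_messages=prompt_messages,
        json_schema={"type": "object", "properties": {"name": {"type": "string"}}},
    )
    assert isinstance(result, LLMResultWithStructuredOutput)
    assert isinstance(result.structured_output, dict)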