From ea37904c756bd4a02e75013b959a12b27742dada Mon Sep 17 00:00:00 2001
From: Stream
Date: Wed, 21 Jan 2026 19:30:46 +0800
Subject: [PATCH] refactor: unify structured output with pydantic model
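
Hand-written JSON Schema dicts in LLMGenerator are replaced with Pydantic
output models: invoke_llm_with_pydantic_model() derives the JSON schema
from the model, reuses the existing structured-output invocation path,
and validates the parsed result against the model before returning it.
Callers now receive a schema-checked dict instead of hand-parsing the
raw completion with json_repair.

A sketch of the new call pattern (the provider name and the surrounding
model_schema/model_instance/prompt_messages variables are illustrative;
see the diff below for the exact signatures):

    from pydantic import BaseModel, ConfigDict

    from core.llm_generator.output_parser.structured_output import (
        invoke_llm_with_pydantic_model,
    )

    class SuggestedQuestionsOutput(BaseModel):
        model_config = ConfigDict(extra="forbid")

        questions: list[str]

    response = invoke_llm_with_pydantic_model(
        provider="openai",  # illustrative
        model_schema=model_schema,
        model_instance=model_instance,
        prompt_messages=prompt_messages,
        output_model=SuggestedQuestionsOutput,
        stream=False,  # only non-streaming invocation is supported
    )
    questions = (response.structured_output or {}).get("questions", [])
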
"items": {"type": "string"}, - "minItems": 3, - "maxItems": 3, - "description": "3 suggested questions", - }, - }, - "required": ["questions"], - } - - @classmethod - def _get_code_node_json_schema(cls) -> dict: - """Return JSON Schema for structured output.""" - return { - "type": "object", - "properties": { - "variables": { - "type": "array", - "items": { - "type": "object", - "properties": { - "variable": {"type": "string", "description": "Variable name in code"}, - "value_selector": { - "type": "array", - "items": {"type": "string"}, - "description": "Path like [node_id, output_name]", - }, - }, - "required": ["variable", "value_selector"], - }, - }, - "code": {"type": "string", "description": "Generated code with main function"}, - "outputs": { - "type": "object", - "additionalProperties": { - "type": "object", - "properties": {"type": {"type": "string"}}, - }, - "description": "Output definitions, key is output name", - }, - "explanation": {"type": "string", "description": "Brief explanation of the code"}, - }, - "required": ["variables", "code", "outputs", "explanation"], - } - @classmethod def _get_upstream_nodes(cls, graph_dict: Mapping[str, Any], node_id: str) -> list[dict]: """ @@ -1011,6 +960,10 @@ Parameter: {parameter_info.get("name")} ({param_type}) - {parameter_info.get("de provider=model_config.get("provider", ""), model=model_config.get("name", ""), ) + model_name = model_config.get("name", "") + model_schema = model_instance.model_type_instance.get_model_schema(model_name, model_instance.credentials) + if not model_schema: + return {"error": f"Model schema not found for {model_name}"} match node_type: case "llm" | "agent": system_prompt = LLM_MODIFY_PROMPT_SYSTEM @@ -1034,20 +987,18 @@ Parameter: {parameter_info.get("name")} ({param_type}) - {parameter_info.get("de model_parameters = {"temperature": 0.4} try: - response: LLMResult = model_instance.invoke_llm( - prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False - ) + from core.llm_generator.output_parser.structured_output import invoke_llm_with_pydantic_model - generated_raw = response.message.get_text_content() - first_brace = generated_raw.find("{") - last_brace = generated_raw.rfind("}") - if first_brace == -1 or last_brace == -1 or last_brace < first_brace: - raise ValueError(f"Could not find a valid JSON object in response: {generated_raw}") - json_str = generated_raw[first_brace : last_brace + 1] - data = json_repair.loads(json_str) - if not isinstance(data, dict): - raise TypeError(f"Expected a JSON object, but got {type(data).__name__}") - return data + response = invoke_llm_with_pydantic_model( + provider=model_instance.provider, + model_schema=model_schema, + model_instance=model_instance, + prompt_messages=list(prompt_messages), + output_model=InstructionModifyOutput, + model_parameters=model_parameters, + stream=False, + ) + return response.structured_output or {} except InvokeError as e: error = str(e) return {"error": f"Failed to generate code. 
Error: {error}"} diff --git a/api/core/llm_generator/output_models.py b/api/core/llm_generator/output_models.py new file mode 100644 index 0000000000..22c4f5e411 --- /dev/null +++ b/api/core/llm_generator/output_models.py @@ -0,0 +1,34 @@ +from __future__ import annotations + +from pydantic import BaseModel, ConfigDict, Field + +from core.variables.types import SegmentType +from core.workflow.nodes.base.entities import VariableSelector + + +class SuggestedQuestionsOutput(BaseModel): + model_config = ConfigDict(extra="forbid") + + questions: list[str] = Field(min_length=3, max_length=3) + + +class CodeNodeOutput(BaseModel): + model_config = ConfigDict(extra="forbid") + + type: SegmentType + + +class CodeNodeStructuredOutput(BaseModel): + model_config = ConfigDict(extra="forbid") + + variables: list[VariableSelector] + code: str + outputs: dict[str, CodeNodeOutput] + explanation: str + + +class InstructionModifyOutput(BaseModel): + model_config = ConfigDict(extra="forbid") + + modified: str + message: str diff --git a/api/core/llm_generator/output_parser/structured_output.py b/api/core/llm_generator/output_parser/structured_output.py index 250acf14fd..7e931fed32 100644 --- a/api/core/llm_generator/output_parser/structured_output.py +++ b/api/core/llm_generator/output_parser/structured_output.py @@ -2,10 +2,10 @@ import json from collections.abc import Generator, Mapping, Sequence from copy import deepcopy from enum import StrEnum -from typing import Any, Literal, cast, overload +from typing import Any, Literal, TypeVar, cast, overload import json_repair -from pydantic import TypeAdapter, ValidationError +from pydantic import BaseModel, TypeAdapter, ValidationError from core.llm_generator.output_parser.errors import OutputParserError from core.llm_generator.output_parser.file_ref import convert_file_refs_in_output @@ -44,6 +44,9 @@ class SpecialModelType(StrEnum): OLLAMA = "ollama" +T = TypeVar("T", bound=BaseModel) + + @overload def invoke_llm_with_structured_output( *, @@ -129,7 +132,6 @@ def invoke_llm_with_structured_output( file IDs in the output will be automatically converted to File objects. :return: full response or stream response chunk generator result """ - # handle native json schema model_parameters_with_json_schema: dict[str, Any] = { **(model_parameters or {}), @@ -234,6 +236,87 @@ def invoke_llm_with_structured_output( return generator() +@overload +def invoke_llm_with_pydantic_model( + *, + provider: str, + model_schema: AIModelEntity, + model_instance: ModelInstance, + prompt_messages: Sequence[PromptMessage], + output_model: type[T], + model_parameters: Mapping | None = None, + tools: Sequence[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: Literal[False] = False, + user: str | None = None, + callbacks: list[Callback] | None = None, + tenant_id: str | None = None, +) -> LLMResultWithStructuredOutput: ... + + +def invoke_llm_with_pydantic_model( + *, + provider: str, + model_schema: AIModelEntity, + model_instance: ModelInstance, + prompt_messages: Sequence[PromptMessage], + output_model: type[T], + model_parameters: Mapping | None = None, + tools: Sequence[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = False, + user: str | None = None, + callbacks: list[Callback] | None = None, + tenant_id: str | None = None, +) -> LLMResultWithStructuredOutput: + """ + Invoke large language model with a Pydantic output model. 
+ + This helper generates a JSON schema from the Pydantic model, invokes the + structured-output LLM path, and validates the result in non-streaming mode. + """ + if stream: + raise ValueError("invoke_llm_with_pydantic_model only supports stream=False") + + json_schema = _schema_from_pydantic(output_model) + result = invoke_llm_with_structured_output( + provider=provider, + model_schema=model_schema, + model_instance=model_instance, + prompt_messages=prompt_messages, + json_schema=json_schema, + model_parameters=model_parameters, + tools=tools, + stop=stop, + stream=False, + user=user, + callbacks=callbacks, + tenant_id=tenant_id, + ) + + structured_output = result.structured_output + if structured_output is None: + raise OutputParserError("Structured output is empty") + + validated_output = _validate_structured_output(output_model, structured_output) + return result.model_copy(update={"structured_output": validated_output}) + + +def _schema_from_pydantic(output_model: type[BaseModel]) -> dict[str, Any]: + return output_model.model_json_schema() + + +def _validate_structured_output( + output_model: type[T], + structured_output: Mapping[str, Any], +) -> dict[str, Any]: + try: + validated_output = output_model.model_validate(structured_output) + except ValidationError as exc: + raise OutputParserError(f"Structured output validation failed: {exc}") from exc + return validated_output.model_dump(mode="python") + + def _handle_native_json_schema( provider: str, model_schema: AIModelEntity, diff --git a/api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py b/api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py index 9046f785d2..9742590cd4 100644 --- a/api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py +++ b/api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py @@ -2,9 +2,13 @@ from decimal import Decimal from unittest.mock import MagicMock, patch import pytest +from pydantic import BaseModel, ConfigDict from core.llm_generator.output_parser.errors import OutputParserError -from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output +from core.llm_generator.output_parser.structured_output import ( + invoke_llm_with_pydantic_model, + invoke_llm_with_structured_output, +) from core.model_runtime.entities.llm_entities import ( LLMResult, LLMResultChunk, @@ -461,3 +465,68 @@ def test_model_specific_schema_preparation(): # For Gemini, the schema should not have additionalProperties and boolean should be converted to string assert "json_schema" in call_args.kwargs["model_parameters"] + + +class ExampleOutput(BaseModel): + model_config = ConfigDict(extra="forbid") + + name: str + + +def test_structured_output_with_pydantic_model(): + model_schema = get_model_entity("openai", "gpt-4o", support_structure_output=True) + model_instance = get_model_instance() + model_instance.invoke_llm.return_value = LLMResult( + model="gpt-4o", + message=AssistantPromptMessage(content='{"name": "test"}'), + usage=create_mock_usage(prompt_tokens=8, completion_tokens=4), + ) + + prompt_messages = [UserPromptMessage(content="Return a JSON object with name.")] + + result = invoke_llm_with_pydantic_model( + provider="openai", + model_schema=model_schema, + model_instance=model_instance, + prompt_messages=prompt_messages, + output_model=ExampleOutput, + stream=False, + ) + + assert isinstance(result, LLMResultWithStructuredOutput) + assert result.structured_output == 
{"name": "test"} + + +def test_structured_output_with_pydantic_model_streaming_rejected(): + model_schema = get_model_entity("openai", "gpt-4o", support_structure_output=True) + model_instance = get_model_instance() + + with pytest.raises(ValueError): + invoke_llm_with_pydantic_model( + provider="openai", + model_schema=model_schema, + model_instance=model_instance, + prompt_messages=[UserPromptMessage(content="test")], + output_model=ExampleOutput, + stream=True, + ) + + +def test_structured_output_with_pydantic_model_validation_error(): + model_schema = get_model_entity("openai", "gpt-4o", support_structure_output=True) + model_instance = get_model_instance() + model_instance.invoke_llm.return_value = LLMResult( + model="gpt-4o", + message=AssistantPromptMessage(content='{"name": 123}'), + usage=create_mock_usage(prompt_tokens=8, completion_tokens=4), + ) + + with pytest.raises(OutputParserError): + invoke_llm_with_pydantic_model( + provider="openai", + model_schema=model_schema, + model_instance=model_instance, + prompt_messages=[UserPromptMessage(content="test")], + output_model=ExampleOutput, + stream=False, + )