From 28289212fb533c6c19eb5ecff59747b980530565 Mon Sep 17 00:00:00 2001 From: "yunlu.wen" Date: Thu, 30 Apr 2026 18:55:14 +0800 Subject: [PATCH] fix type check --- api/core/app/workflow/layers/llm_quota.py | 4 +- api/core/plugin/impl/model_runtime.py | 82 ++++++++++++++++++- api/core/workflow/nodes/agent/agent_node.py | 4 +- api/core/workflow/system_variables.py | 24 ++---- .../workflow_draft_variable_service.py | 2 +- 5 files changed, 95 insertions(+), 21 deletions(-) diff --git a/api/core/app/workflow/layers/llm_quota.py b/api/core/app/workflow/layers/llm_quota.py index 4a7918032e..cf5a43eac2 100644 --- a/api/core/app/workflow/layers/llm_quota.py +++ b/api/core/app/workflow/layers/llm_quota.py @@ -16,6 +16,7 @@ from graphon.graph_engine.entities.commands import AbortCommand, CommandType from graphon.graph_engine.layers import GraphEngineLayer from graphon.graph_events import GraphEngineEvent, GraphNodeEventBase, NodeRunSucceededEvent from graphon.nodes.base.node import Node +from graphon.nodes.llm.runtime_protocols import PreparedLLMProtocol if TYPE_CHECKING: from graphon.nodes.llm.node import LLMNode @@ -116,7 +117,8 @@ class LLMQuotaLayer(GraphEngineLayer): case BuiltinNodeTypes.PARAMETER_EXTRACTOR: model_instance = cast("ParameterExtractorNode", node).model_instance case BuiltinNodeTypes.QUESTION_CLASSIFIER: - model_instance = cast("QuestionClassifierNode", node).model_instance + typed_node: QuestionClassifierNode = cast("QuestionClassifierNode", node) + model_instance = cast(PreparedLLMProtocol, getattr(typed_node, "_model_instance")) case _: return None except AttributeError: diff --git a/api/core/plugin/impl/model_runtime.py b/api/core/plugin/impl/model_runtime.py index 668aa4159c..7d0f7fd5ba 100644 --- a/api/core/plugin/impl/model_runtime.py +++ b/api/core/plugin/impl/model_runtime.py @@ -4,7 +4,7 @@ import hashlib import logging from collections.abc import Generator, Iterable, Sequence from threading import Lock -from typing import IO, Any, Union +from typing import IO, Any, Literal, Union, overload from pydantic import ValidationError from redis import RedisError @@ -15,7 +15,12 @@ from core.plugin.impl.asset import PluginAssetManager from core.plugin.impl.model import PluginModelClient from extensions.ext_redis import redis_client from graphon.model_runtime import ModelRuntime -from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk +from graphon.model_runtime.entities.llm_entities import ( + LLMResult, + LLMResultChunk, + LLMResultChunkWithStructuredOutput, + LLMResultWithStructuredOutput, +) from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelType from graphon.model_runtime.entities.provider_entities import ProviderEntity @@ -195,6 +200,34 @@ class PluginModelRuntime(ModelRuntime): return schema + @overload + def invoke_llm( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + model_parameters: dict[str, Any], + prompt_messages: Sequence[PromptMessage], + tools: list[PromptMessageTool] | None, + stop: Sequence[str] | None, + stream: Literal[False], + ) -> LLMResult: ... + + @overload + def invoke_llm( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + model_parameters: dict[str, Any], + prompt_messages: Sequence[PromptMessage], + tools: list[PromptMessageTool] | None, + stop: Sequence[str] | None, + stream: Literal[True], + ) -> Generator[LLMResultChunk, None, None]: ... + def invoke_llm( self, *, @@ -221,6 +254,51 @@ class PluginModelRuntime(ModelRuntime): stop=list(stop) if stop else None, stream=stream, ) + + @overload + def invoke_llm_with_structured_output( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + json_schema: dict[str, Any], + model_parameters: dict[str, Any], + prompt_messages: Sequence[PromptMessage], + stop: Sequence[str] | None, + stream: Literal[False], + ) -> LLMResultWithStructuredOutput: ... + + @overload + def invoke_llm_with_structured_output( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + json_schema: dict[str, Any], + model_parameters: dict[str, Any], + prompt_messages: Sequence[PromptMessage], + stop: Sequence[str] | None, + stream: Literal[True], + ) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ... + + def invoke_llm_with_structured_output( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + json_schema: dict[str, Any], + model_parameters: dict[str, Any], + prompt_messages: Sequence[PromptMessage], + stop: Sequence[str] | None, + stream: bool, + ) -> ( + LLMResultWithStructuredOutput + | Generator[LLMResultChunkWithStructuredOutput, None, None] + ): + raise NotImplementedError def get_llm_num_tokens( self, diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py index 68a24e86b1..17d71668cb 100644 --- a/api/core/workflow/nodes/agent/agent_node.py +++ b/api/core/workflow/nodes/agent/agent_node.py @@ -35,7 +35,7 @@ class AgentNode(Node[AgentNodeData]): def __init__( self, node_id: str, - config: AgentNodeData, + data: AgentNodeData, *, graph_init_params: GraphInitParams, graph_runtime_state: GraphRuntimeState, @@ -46,7 +46,7 @@ class AgentNode(Node[AgentNodeData]): ) -> None: super().__init__( node_id=node_id, - config=config, + data=data, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, ) diff --git a/api/core/workflow/system_variables.py b/api/core/workflow/system_variables.py index 9d15a3fcea..77ef3826e9 100644 --- a/api/core/workflow/system_variables.py +++ b/api/core/workflow/system_variables.py @@ -3,7 +3,7 @@ from __future__ import annotations from collections import defaultdict from collections.abc import Mapping, Sequence from enum import StrEnum -from typing import Any, Protocol, cast +from typing import Any, Protocol from uuid import uuid4 from graphon.enums import BuiltinNodeTypes @@ -82,13 +82,10 @@ def build_system_variables(values: Mapping[str, Any] | None = None, /, **kwargs: normalized = _normalize_system_variable_values(values, **kwargs) return [ - cast( - Variable, - segment_to_variable( - segment=build_segment(value), - selector=system_variable_selector(key), - name=key, - ), + segment_to_variable( + segment=build_segment(value), + selector=system_variable_selector(key), + name=key, ) for key, value in normalized.items() ] @@ -130,13 +127,10 @@ def build_bootstrap_variables( for node_id, value in rag_pipeline_variables_map.items(): variables.append( - cast( - Variable, - segment_to_variable( - segment=build_segment(value), - selector=(RAG_PIPELINE_VARIABLE_NODE_ID, node_id), - name=node_id, - ), + segment_to_variable( + segment=build_segment(value), + selector=(RAG_PIPELINE_VARIABLE_NODE_ID, node_id), + name=node_id, ) ) diff --git a/api/services/workflow_draft_variable_service.py b/api/services/workflow_draft_variable_service.py index a55448e352..fa0b6fa3fa 100644 --- a/api/services/workflow_draft_variable_service.py +++ b/api/services/workflow_draft_variable_service.py @@ -5,7 +5,7 @@ from collections.abc import Mapping, Sequence from concurrent.futures import ThreadPoolExecutor from datetime import datetime from enum import StrEnum -from typing import Any, ClassVar, NotRequired, TypedDict +from typing import Any, ClassVar, NotRequired, TypedDict, cast from sqlalchemy import Engine, delete, orm, select from sqlalchemy.dialects.mysql import insert as mysql_insert