fix type check

This commit is contained in:
yunlu.wen 2026-04-30 18:55:14 +08:00
parent f8912c920e
commit 28289212fb
5 changed files with 95 additions and 21 deletions

View File

@ -16,6 +16,7 @@ from graphon.graph_engine.entities.commands import AbortCommand, CommandType
from graphon.graph_engine.layers import GraphEngineLayer
from graphon.graph_events import GraphEngineEvent, GraphNodeEventBase, NodeRunSucceededEvent
from graphon.nodes.base.node import Node
from graphon.nodes.llm.runtime_protocols import PreparedLLMProtocol
if TYPE_CHECKING:
from graphon.nodes.llm.node import LLMNode
@ -116,7 +117,8 @@ class LLMQuotaLayer(GraphEngineLayer):
case BuiltinNodeTypes.PARAMETER_EXTRACTOR:
model_instance = cast("ParameterExtractorNode", node).model_instance
case BuiltinNodeTypes.QUESTION_CLASSIFIER:
model_instance = cast("QuestionClassifierNode", node).model_instance
typed_node: QuestionClassifierNode = cast("QuestionClassifierNode", node)
model_instance = cast(PreparedLLMProtocol, getattr(typed_node, "_model_instance"))
case _:
return None
except AttributeError:

View File

@ -4,7 +4,7 @@ import hashlib
import logging
from collections.abc import Generator, Iterable, Sequence
from threading import Lock
from typing import IO, Any, Union
from typing import IO, Any, Literal, Union, overload
from pydantic import ValidationError
from redis import RedisError
@ -15,7 +15,12 @@ from core.plugin.impl.asset import PluginAssetManager
from core.plugin.impl.model import PluginModelClient
from extensions.ext_redis import redis_client
from graphon.model_runtime import ModelRuntime
from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from graphon.model_runtime.entities.llm_entities import (
LLMResult,
LLMResultChunk,
LLMResultChunkWithStructuredOutput,
LLMResultWithStructuredOutput,
)
from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelType
from graphon.model_runtime.entities.provider_entities import ProviderEntity
@ -195,6 +200,34 @@ class PluginModelRuntime(ModelRuntime):
return schema
@overload
def invoke_llm(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
model_parameters: dict[str, Any],
prompt_messages: Sequence[PromptMessage],
tools: list[PromptMessageTool] | None,
stop: Sequence[str] | None,
stream: Literal[False],
) -> LLMResult: ...
@overload
def invoke_llm(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
model_parameters: dict[str, Any],
prompt_messages: Sequence[PromptMessage],
tools: list[PromptMessageTool] | None,
stop: Sequence[str] | None,
stream: Literal[True],
) -> Generator[LLMResultChunk, None, None]: ...
def invoke_llm(
self,
*,
@ -221,6 +254,51 @@ class PluginModelRuntime(ModelRuntime):
stop=list(stop) if stop else None,
stream=stream,
)
@overload
def invoke_llm_with_structured_output(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
json_schema: dict[str, Any],
model_parameters: dict[str, Any],
prompt_messages: Sequence[PromptMessage],
stop: Sequence[str] | None,
stream: Literal[False],
) -> LLMResultWithStructuredOutput: ...
@overload
def invoke_llm_with_structured_output(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
json_schema: dict[str, Any],
model_parameters: dict[str, Any],
prompt_messages: Sequence[PromptMessage],
stop: Sequence[str] | None,
stream: Literal[True],
) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ...
def invoke_llm_with_structured_output(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
json_schema: dict[str, Any],
model_parameters: dict[str, Any],
prompt_messages: Sequence[PromptMessage],
stop: Sequence[str] | None,
stream: bool,
) -> (
LLMResultWithStructuredOutput
| Generator[LLMResultChunkWithStructuredOutput, None, None]
):
raise NotImplementedError
def get_llm_num_tokens(
self,

View File

@ -35,7 +35,7 @@ class AgentNode(Node[AgentNodeData]):
def __init__(
self,
node_id: str,
config: AgentNodeData,
data: AgentNodeData,
*,
graph_init_params: GraphInitParams,
graph_runtime_state: GraphRuntimeState,
@ -46,7 +46,7 @@ class AgentNode(Node[AgentNodeData]):
) -> None:
super().__init__(
node_id=node_id,
config=config,
data=data,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)

View File

@ -3,7 +3,7 @@ from __future__ import annotations
from collections import defaultdict
from collections.abc import Mapping, Sequence
from enum import StrEnum
from typing import Any, Protocol, cast
from typing import Any, Protocol
from uuid import uuid4
from graphon.enums import BuiltinNodeTypes
@ -82,13 +82,10 @@ def build_system_variables(values: Mapping[str, Any] | None = None, /, **kwargs:
normalized = _normalize_system_variable_values(values, **kwargs)
return [
cast(
Variable,
segment_to_variable(
segment=build_segment(value),
selector=system_variable_selector(key),
name=key,
),
segment_to_variable(
segment=build_segment(value),
selector=system_variable_selector(key),
name=key,
)
for key, value in normalized.items()
]
@ -130,13 +127,10 @@ def build_bootstrap_variables(
for node_id, value in rag_pipeline_variables_map.items():
variables.append(
cast(
Variable,
segment_to_variable(
segment=build_segment(value),
selector=(RAG_PIPELINE_VARIABLE_NODE_ID, node_id),
name=node_id,
),
segment_to_variable(
segment=build_segment(value),
selector=(RAG_PIPELINE_VARIABLE_NODE_ID, node_id),
name=node_id,
)
)

View File

@ -5,7 +5,7 @@ from collections.abc import Mapping, Sequence
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from enum import StrEnum
from typing import Any, ClassVar, NotRequired, TypedDict
from typing import Any, ClassVar, NotRequired, TypedDict, cast
from sqlalchemy import Engine, delete, orm, select
from sqlalchemy.dialects.mysql import insert as mysql_insert