mirror of
https://github.com/langgenius/dify.git
synced 2026-05-09 12:59:18 +08:00
fix type check
This commit is contained in:
parent
f8912c920e
commit
28289212fb
@ -16,6 +16,7 @@ from graphon.graph_engine.entities.commands import AbortCommand, CommandType
|
||||
from graphon.graph_engine.layers import GraphEngineLayer
|
||||
from graphon.graph_events import GraphEngineEvent, GraphNodeEventBase, NodeRunSucceededEvent
|
||||
from graphon.nodes.base.node import Node
|
||||
from graphon.nodes.llm.runtime_protocols import PreparedLLMProtocol
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from graphon.nodes.llm.node import LLMNode
|
||||
@ -116,7 +117,8 @@ class LLMQuotaLayer(GraphEngineLayer):
|
||||
case BuiltinNodeTypes.PARAMETER_EXTRACTOR:
|
||||
model_instance = cast("ParameterExtractorNode", node).model_instance
|
||||
case BuiltinNodeTypes.QUESTION_CLASSIFIER:
|
||||
model_instance = cast("QuestionClassifierNode", node).model_instance
|
||||
typed_node: QuestionClassifierNode = cast("QuestionClassifierNode", node)
|
||||
model_instance = cast(PreparedLLMProtocol, getattr(typed_node, "_model_instance"))
|
||||
case _:
|
||||
return None
|
||||
except AttributeError:
|
||||
|
||||
@ -4,7 +4,7 @@ import hashlib
|
||||
import logging
|
||||
from collections.abc import Generator, Iterable, Sequence
|
||||
from threading import Lock
|
||||
from typing import IO, Any, Union
|
||||
from typing import IO, Any, Literal, Union, overload
|
||||
|
||||
from pydantic import ValidationError
|
||||
from redis import RedisError
|
||||
@ -15,7 +15,12 @@ from core.plugin.impl.asset import PluginAssetManager
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
from extensions.ext_redis import redis_client
|
||||
from graphon.model_runtime import ModelRuntime
|
||||
from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
|
||||
from graphon.model_runtime.entities.llm_entities import (
|
||||
LLMResult,
|
||||
LLMResultChunk,
|
||||
LLMResultChunkWithStructuredOutput,
|
||||
LLMResultWithStructuredOutput,
|
||||
)
|
||||
from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
|
||||
from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelType
|
||||
from graphon.model_runtime.entities.provider_entities import ProviderEntity
|
||||
@ -195,6 +200,34 @@ class PluginModelRuntime(ModelRuntime):
|
||||
|
||||
return schema
|
||||
|
||||
@overload
|
||||
def invoke_llm(
|
||||
self,
|
||||
*,
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict[str, Any],
|
||||
model_parameters: dict[str, Any],
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
tools: list[PromptMessageTool] | None,
|
||||
stop: Sequence[str] | None,
|
||||
stream: Literal[False],
|
||||
) -> LLMResult: ...
|
||||
|
||||
@overload
|
||||
def invoke_llm(
|
||||
self,
|
||||
*,
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict[str, Any],
|
||||
model_parameters: dict[str, Any],
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
tools: list[PromptMessageTool] | None,
|
||||
stop: Sequence[str] | None,
|
||||
stream: Literal[True],
|
||||
) -> Generator[LLMResultChunk, None, None]: ...
|
||||
|
||||
def invoke_llm(
|
||||
self,
|
||||
*,
|
||||
@ -221,6 +254,51 @@ class PluginModelRuntime(ModelRuntime):
|
||||
stop=list(stop) if stop else None,
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
@overload
|
||||
def invoke_llm_with_structured_output(
|
||||
self,
|
||||
*,
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict[str, Any],
|
||||
json_schema: dict[str, Any],
|
||||
model_parameters: dict[str, Any],
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
stop: Sequence[str] | None,
|
||||
stream: Literal[False],
|
||||
) -> LLMResultWithStructuredOutput: ...
|
||||
|
||||
@overload
|
||||
def invoke_llm_with_structured_output(
|
||||
self,
|
||||
*,
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict[str, Any],
|
||||
json_schema: dict[str, Any],
|
||||
model_parameters: dict[str, Any],
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
stop: Sequence[str] | None,
|
||||
stream: Literal[True],
|
||||
) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ...
|
||||
|
||||
def invoke_llm_with_structured_output(
|
||||
self,
|
||||
*,
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict[str, Any],
|
||||
json_schema: dict[str, Any],
|
||||
model_parameters: dict[str, Any],
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
stop: Sequence[str] | None,
|
||||
stream: bool,
|
||||
) -> (
|
||||
LLMResultWithStructuredOutput
|
||||
| Generator[LLMResultChunkWithStructuredOutput, None, None]
|
||||
):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_llm_num_tokens(
|
||||
self,
|
||||
|
||||
@ -35,7 +35,7 @@ class AgentNode(Node[AgentNodeData]):
|
||||
def __init__(
|
||||
self,
|
||||
node_id: str,
|
||||
config: AgentNodeData,
|
||||
data: AgentNodeData,
|
||||
*,
|
||||
graph_init_params: GraphInitParams,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
@ -46,7 +46,7 @@ class AgentNode(Node[AgentNodeData]):
|
||||
) -> None:
|
||||
super().__init__(
|
||||
node_id=node_id,
|
||||
config=config,
|
||||
data=data,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
from collections import defaultdict
|
||||
from collections.abc import Mapping, Sequence
|
||||
from enum import StrEnum
|
||||
from typing import Any, Protocol, cast
|
||||
from typing import Any, Protocol
|
||||
from uuid import uuid4
|
||||
|
||||
from graphon.enums import BuiltinNodeTypes
|
||||
@ -82,13 +82,10 @@ def build_system_variables(values: Mapping[str, Any] | None = None, /, **kwargs:
|
||||
normalized = _normalize_system_variable_values(values, **kwargs)
|
||||
|
||||
return [
|
||||
cast(
|
||||
Variable,
|
||||
segment_to_variable(
|
||||
segment=build_segment(value),
|
||||
selector=system_variable_selector(key),
|
||||
name=key,
|
||||
),
|
||||
segment_to_variable(
|
||||
segment=build_segment(value),
|
||||
selector=system_variable_selector(key),
|
||||
name=key,
|
||||
)
|
||||
for key, value in normalized.items()
|
||||
]
|
||||
@ -130,13 +127,10 @@ def build_bootstrap_variables(
|
||||
|
||||
for node_id, value in rag_pipeline_variables_map.items():
|
||||
variables.append(
|
||||
cast(
|
||||
Variable,
|
||||
segment_to_variable(
|
||||
segment=build_segment(value),
|
||||
selector=(RAG_PIPELINE_VARIABLE_NODE_ID, node_id),
|
||||
name=node_id,
|
||||
),
|
||||
segment_to_variable(
|
||||
segment=build_segment(value),
|
||||
selector=(RAG_PIPELINE_VARIABLE_NODE_ID, node_id),
|
||||
name=node_id,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@ -5,7 +5,7 @@ from collections.abc import Mapping, Sequence
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
from typing import Any, ClassVar, NotRequired, TypedDict
|
||||
from typing import Any, ClassVar, NotRequired, TypedDict, cast
|
||||
|
||||
from sqlalchemy import Engine, delete, orm, select
|
||||
from sqlalchemy.dialects.mysql import insert as mysql_insert
|
||||
|
||||
Loading…
Reference in New Issue
Block a user