refactor(models): replace Any with precise types in Tenant and MCPToo… (#34880)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
corevibe555 2026-04-10 06:12:38 +03:00 committed by GitHub
parent 40e23ce8dc
commit d826ac7099
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 90 additions and 160 deletions

View File

@ -122,7 +122,7 @@ class MCPClientWithAuthRetry(MCPClient):
logger.exception("Authentication retry failed")
raise MCPAuthError(f"Authentication retry failed: {e}") from e
def _execute_with_retry(self, func: Callable[..., Any], *args, **kwargs) -> Any:
def _execute_with_retry[**P, R](self, func: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R:
"""
Execute a function with authentication retry logic.

View File

@ -1,6 +1,6 @@
from dataclasses import dataclass
from enum import StrEnum
from typing import Any, TypeVar
from typing import Any
from pydantic import BaseModel
@ -9,12 +9,9 @@ from core.mcp.types import LATEST_PROTOCOL_VERSION, OAuthClientInformation, OAut
SUPPORTED_PROTOCOL_VERSIONS: list[str] = ["2024-11-05", "2025-03-26", LATEST_PROTOCOL_VERSION]
SessionT = TypeVar("SessionT", bound=BaseSession[Any, Any, Any, Any, Any])
LifespanContextT = TypeVar("LifespanContextT")
@dataclass
class RequestContext[SessionT: BaseSession[Any, Any, Any, Any, Any], LifespanContextT]:
class RequestContext[SessionT: BaseSession, LifespanContextT]:
request_id: RequestId
meta: RequestParams.Meta | None
session: SessionT

View File

@ -6,7 +6,7 @@ from graphon.model_runtime.callbacks.base_callback import Callback
from graphon.model_runtime.entities.llm_entities import LLMResult
from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelFeature, ModelType
from graphon.model_runtime.entities.rerank_entities import RerankResult
from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult
from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult
from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeConnectionError, InvokeRateLimitError
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
@ -172,10 +172,10 @@ class ModelInstance:
function=self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
prompt_messages=prompt_messages,
prompt_messages=list(prompt_messages),
model_parameters=model_parameters,
tools=tools,
stop=stop,
tools=list(tools) if tools else None,
stop=list(stop) if stop else None,
stream=stream,
callbacks=callbacks,
),
@ -193,15 +193,12 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, LargeLanguageModel):
raise Exception("Model type instance is not LargeLanguageModel")
return cast(
int,
self._round_robin_invoke(
function=self.model_type_instance.get_num_tokens,
model=self.model_name,
credentials=self.credentials,
prompt_messages=prompt_messages,
tools=tools,
),
return self._round_robin_invoke(
function=self.model_type_instance.get_num_tokens,
model=self.model_name,
credentials=self.credentials,
prompt_messages=list(prompt_messages),
tools=list(tools) if tools else None,
)
def invoke_text_embedding(
@ -216,15 +213,12 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, TextEmbeddingModel):
raise Exception("Model type instance is not TextEmbeddingModel")
return cast(
EmbeddingResult,
self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
texts=texts,
input_type=input_type,
),
return self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
texts=texts,
input_type=input_type,
)
def invoke_multimodal_embedding(
@ -241,15 +235,12 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, TextEmbeddingModel):
raise Exception("Model type instance is not TextEmbeddingModel")
return cast(
EmbeddingResult,
self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
multimodel_documents=multimodel_documents,
input_type=input_type,
),
return self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
multimodel_documents=multimodel_documents,
input_type=input_type,
)
def get_text_embedding_num_tokens(self, texts: list[str]) -> list[int]:
@ -261,14 +252,11 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, TextEmbeddingModel):
raise Exception("Model type instance is not TextEmbeddingModel")
return cast(
list[int],
self._round_robin_invoke(
function=self.model_type_instance.get_num_tokens,
model=self.model_name,
credentials=self.credentials,
texts=texts,
),
return self._round_robin_invoke(
function=self.model_type_instance.get_num_tokens,
model=self.model_name,
credentials=self.credentials,
texts=texts,
)
def invoke_rerank(
@ -289,23 +277,20 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, RerankModel):
raise Exception("Model type instance is not RerankModel")
return cast(
RerankResult,
self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
query=query,
docs=docs,
score_threshold=score_threshold,
top_n=top_n,
),
return self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
query=query,
docs=docs,
score_threshold=score_threshold,
top_n=top_n,
)
def invoke_multimodal_rerank(
self,
query: dict,
docs: list[dict],
query: MultimodalRerankInput,
docs: list[MultimodalRerankInput],
score_threshold: float | None = None,
top_n: int | None = None,
) -> RerankResult:
@ -320,17 +305,14 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, RerankModel):
raise Exception("Model type instance is not RerankModel")
return cast(
RerankResult,
self._round_robin_invoke(
function=self.model_type_instance.invoke_multimodal_rerank,
model=self.model_name,
credentials=self.credentials,
query=query,
docs=docs,
score_threshold=score_threshold,
top_n=top_n,
),
return self._round_robin_invoke(
function=self.model_type_instance.invoke_multimodal_rerank,
model=self.model_name,
credentials=self.credentials,
query=query,
docs=docs,
score_threshold=score_threshold,
top_n=top_n,
)
def invoke_moderation(self, text: str) -> bool:
@ -342,14 +324,11 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, ModerationModel):
raise Exception("Model type instance is not ModerationModel")
return cast(
bool,
self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
text=text,
),
return self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
text=text,
)
def invoke_speech2text(self, file: IO[bytes]) -> str:
@ -361,14 +340,11 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, Speech2TextModel):
raise Exception("Model type instance is not Speech2TextModel")
return cast(
str,
self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
file=file,
),
return self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
file=file,
)
def invoke_tts(self, content_text: str, voice: str = "") -> Iterable[bytes]:
@ -381,18 +357,15 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, TTSModel):
raise Exception("Model type instance is not TTSModel")
return cast(
Iterable[bytes],
self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
content_text=content_text,
voice=voice,
),
return self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
content_text=content_text,
voice=voice,
)
def _round_robin_invoke(self, function: Callable[..., Any], *args, **kwargs):
def _round_robin_invoke[**P, R](self, function: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R:
"""
Round-robin invoke
:param function: function to invoke
@ -430,9 +403,8 @@ class ModelInstance:
continue
try:
if "credentials" in kwargs:
del kwargs["credentials"]
return function(*args, **kwargs, credentials=lb_config.credentials)
kwargs["credentials"] = lb_config.credentials
return function(*args, **kwargs)
except InvokeRateLimitError as e:
# expire in 60 seconds
self.load_balancing_manager.cooldown(lb_config, expire=60)

View File

@ -1,7 +1,7 @@
import base64
from graphon.model_runtime.entities.model_entities import ModelType
from graphon.model_runtime.entities.rerank_entities import RerankResult
from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult
from core.model_manager import ModelInstance, ModelManager
from core.rag.index_processor.constant.doc_type import DocType
@ -123,7 +123,7 @@ class RerankModelRunner(BaseRerankRunner):
:param query_type: query type
:return: rerank result
"""
docs = []
docs: list[MultimodalRerankInput] = []
doc_ids = set()
unique_documents = []
for document in documents:
@ -138,26 +138,28 @@ class RerankModelRunner(BaseRerankRunner):
if upload_file:
blob = storage.load_once(upload_file.key)
document_file_base64 = base64.b64encode(blob).decode()
document_file_dict = {
"content": document_file_base64,
"content_type": document.metadata["doc_type"],
}
docs.append(document_file_dict)
docs.append(
MultimodalRerankInput(
content=document_file_base64,
content_type=document.metadata["doc_type"],
)
)
else:
document_text_dict = {
"content": document.page_content,
"content_type": document.metadata.get("doc_type") or DocType.TEXT,
}
docs.append(document_text_dict)
docs.append(
MultimodalRerankInput(
content=document.page_content,
content_type=document.metadata.get("doc_type") or DocType.TEXT,
)
)
doc_ids.add(document.metadata["doc_id"])
unique_documents.append(document)
elif document.provider == "external":
if document not in unique_documents:
docs.append(
{
"content": document.page_content,
"content_type": document.metadata.get("doc_type") or DocType.TEXT,
}
MultimodalRerankInput(
content=document.page_content,
content_type=document.metadata.get("doc_type") or DocType.TEXT,
)
)
unique_documents.append(document)
@ -171,12 +173,12 @@ class RerankModelRunner(BaseRerankRunner):
if upload_file:
blob = storage.load_once(upload_file.key)
file_query = base64.b64encode(blob).decode()
file_query_dict = {
"content": file_query,
"content_type": DocType.IMAGE,
}
file_query_input = MultimodalRerankInput(
content=file_query,
content_type=DocType.IMAGE,
)
rerank_result = self.rerank_model_instance.invoke_multimodal_rerank(
query=file_query_dict, docs=docs, score_threshold=score_threshold, top_n=top_n
query=file_query_input, docs=docs, score_threshold=score_threshold, top_n=top_n
)
return rerank_result, unique_documents
else:

View File

@ -4,9 +4,7 @@ from unittest.mock import Mock
from core.mcp.entities import (
SUPPORTED_PROTOCOL_VERSIONS,
LifespanContextT,
RequestContext,
SessionT,
)
from core.mcp.session.base_session import BaseSession
from core.mcp.types import LATEST_PROTOCOL_VERSION, RequestParams
@ -198,42 +196,3 @@ class TestRequestContext:
assert "RequestContext" in repr_str
assert "test-123" in repr_str
assert "MockSession" in repr_str
class TestTypeVariables:
"""Test type variables defined in the module."""
def test_session_type_var(self):
"""Test SessionT type variable."""
# Create a custom session class
class CustomSession(BaseSession):
pass
# Use in generic context
def process_session(session: SessionT) -> SessionT:
return session
mock_session = Mock(spec=CustomSession)
result = process_session(mock_session)
assert result == mock_session
def test_lifespan_context_type_var(self):
"""Test LifespanContextT type variable."""
# Use in generic context
def process_lifespan(context: LifespanContextT) -> LifespanContextT:
return context
# Test with different types
str_context = "string-context"
assert process_lifespan(str_context) == str_context
dict_context = {"key": "value"}
assert process_lifespan(dict_context) == dict_context
class CustomContext:
pass
custom_context = CustomContext()
assert process_lifespan(custom_context) == custom_context