mirror of
https://github.com/langgenius/dify.git
synced 2026-04-15 18:06:36 +08:00
refactor(models): replace Any with precise types in Tenant and MCPToo… (#34880)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
40e23ce8dc
commit
d826ac7099
@ -122,7 +122,7 @@ class MCPClientWithAuthRetry(MCPClient):
|
||||
logger.exception("Authentication retry failed")
|
||||
raise MCPAuthError(f"Authentication retry failed: {e}") from e
|
||||
|
||||
def _execute_with_retry(self, func: Callable[..., Any], *args, **kwargs) -> Any:
|
||||
def _execute_with_retry[**P, R](self, func: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R:
|
||||
"""
|
||||
Execute a function with authentication retry logic.
|
||||
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from dataclasses import dataclass
|
||||
from enum import StrEnum
|
||||
from typing import Any, TypeVar
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@ -9,12 +9,9 @@ from core.mcp.types import LATEST_PROTOCOL_VERSION, OAuthClientInformation, OAut
|
||||
|
||||
SUPPORTED_PROTOCOL_VERSIONS: list[str] = ["2024-11-05", "2025-03-26", LATEST_PROTOCOL_VERSION]
|
||||
|
||||
SessionT = TypeVar("SessionT", bound=BaseSession[Any, Any, Any, Any, Any])
|
||||
LifespanContextT = TypeVar("LifespanContextT")
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequestContext[SessionT: BaseSession[Any, Any, Any, Any, Any], LifespanContextT]:
|
||||
class RequestContext[SessionT: BaseSession, LifespanContextT]:
|
||||
request_id: RequestId
|
||||
meta: RequestParams.Meta | None
|
||||
session: SessionT
|
||||
|
||||
@ -6,7 +6,7 @@ from graphon.model_runtime.callbacks.base_callback import Callback
|
||||
from graphon.model_runtime.entities.llm_entities import LLMResult
|
||||
from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
|
||||
from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelFeature, ModelType
|
||||
from graphon.model_runtime.entities.rerank_entities import RerankResult
|
||||
from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult
|
||||
from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult
|
||||
from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeConnectionError, InvokeRateLimitError
|
||||
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
@ -172,10 +172,10 @@ class ModelInstance:
|
||||
function=self.model_type_instance.invoke,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
prompt_messages=prompt_messages,
|
||||
prompt_messages=list(prompt_messages),
|
||||
model_parameters=model_parameters,
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
tools=list(tools) if tools else None,
|
||||
stop=list(stop) if stop else None,
|
||||
stream=stream,
|
||||
callbacks=callbacks,
|
||||
),
|
||||
@ -193,15 +193,12 @@ class ModelInstance:
|
||||
"""
|
||||
if not isinstance(self.model_type_instance, LargeLanguageModel):
|
||||
raise Exception("Model type instance is not LargeLanguageModel")
|
||||
return cast(
|
||||
int,
|
||||
self._round_robin_invoke(
|
||||
function=self.model_type_instance.get_num_tokens,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
prompt_messages=prompt_messages,
|
||||
tools=tools,
|
||||
),
|
||||
return self._round_robin_invoke(
|
||||
function=self.model_type_instance.get_num_tokens,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
prompt_messages=list(prompt_messages),
|
||||
tools=list(tools) if tools else None,
|
||||
)
|
||||
|
||||
def invoke_text_embedding(
|
||||
@ -216,15 +213,12 @@ class ModelInstance:
|
||||
"""
|
||||
if not isinstance(self.model_type_instance, TextEmbeddingModel):
|
||||
raise Exception("Model type instance is not TextEmbeddingModel")
|
||||
return cast(
|
||||
EmbeddingResult,
|
||||
self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
texts=texts,
|
||||
input_type=input_type,
|
||||
),
|
||||
return self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
texts=texts,
|
||||
input_type=input_type,
|
||||
)
|
||||
|
||||
def invoke_multimodal_embedding(
|
||||
@ -241,15 +235,12 @@ class ModelInstance:
|
||||
"""
|
||||
if not isinstance(self.model_type_instance, TextEmbeddingModel):
|
||||
raise Exception("Model type instance is not TextEmbeddingModel")
|
||||
return cast(
|
||||
EmbeddingResult,
|
||||
self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
multimodel_documents=multimodel_documents,
|
||||
input_type=input_type,
|
||||
),
|
||||
return self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
multimodel_documents=multimodel_documents,
|
||||
input_type=input_type,
|
||||
)
|
||||
|
||||
def get_text_embedding_num_tokens(self, texts: list[str]) -> list[int]:
|
||||
@ -261,14 +252,11 @@ class ModelInstance:
|
||||
"""
|
||||
if not isinstance(self.model_type_instance, TextEmbeddingModel):
|
||||
raise Exception("Model type instance is not TextEmbeddingModel")
|
||||
return cast(
|
||||
list[int],
|
||||
self._round_robin_invoke(
|
||||
function=self.model_type_instance.get_num_tokens,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
texts=texts,
|
||||
),
|
||||
return self._round_robin_invoke(
|
||||
function=self.model_type_instance.get_num_tokens,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
texts=texts,
|
||||
)
|
||||
|
||||
def invoke_rerank(
|
||||
@ -289,23 +277,20 @@ class ModelInstance:
|
||||
"""
|
||||
if not isinstance(self.model_type_instance, RerankModel):
|
||||
raise Exception("Model type instance is not RerankModel")
|
||||
return cast(
|
||||
RerankResult,
|
||||
self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
query=query,
|
||||
docs=docs,
|
||||
score_threshold=score_threshold,
|
||||
top_n=top_n,
|
||||
),
|
||||
return self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
query=query,
|
||||
docs=docs,
|
||||
score_threshold=score_threshold,
|
||||
top_n=top_n,
|
||||
)
|
||||
|
||||
def invoke_multimodal_rerank(
|
||||
self,
|
||||
query: dict,
|
||||
docs: list[dict],
|
||||
query: MultimodalRerankInput,
|
||||
docs: list[MultimodalRerankInput],
|
||||
score_threshold: float | None = None,
|
||||
top_n: int | None = None,
|
||||
) -> RerankResult:
|
||||
@ -320,17 +305,14 @@ class ModelInstance:
|
||||
"""
|
||||
if not isinstance(self.model_type_instance, RerankModel):
|
||||
raise Exception("Model type instance is not RerankModel")
|
||||
return cast(
|
||||
RerankResult,
|
||||
self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke_multimodal_rerank,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
query=query,
|
||||
docs=docs,
|
||||
score_threshold=score_threshold,
|
||||
top_n=top_n,
|
||||
),
|
||||
return self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke_multimodal_rerank,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
query=query,
|
||||
docs=docs,
|
||||
score_threshold=score_threshold,
|
||||
top_n=top_n,
|
||||
)
|
||||
|
||||
def invoke_moderation(self, text: str) -> bool:
|
||||
@ -342,14 +324,11 @@ class ModelInstance:
|
||||
"""
|
||||
if not isinstance(self.model_type_instance, ModerationModel):
|
||||
raise Exception("Model type instance is not ModerationModel")
|
||||
return cast(
|
||||
bool,
|
||||
self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
text=text,
|
||||
),
|
||||
return self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
text=text,
|
||||
)
|
||||
|
||||
def invoke_speech2text(self, file: IO[bytes]) -> str:
|
||||
@ -361,14 +340,11 @@ class ModelInstance:
|
||||
"""
|
||||
if not isinstance(self.model_type_instance, Speech2TextModel):
|
||||
raise Exception("Model type instance is not Speech2TextModel")
|
||||
return cast(
|
||||
str,
|
||||
self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
file=file,
|
||||
),
|
||||
return self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
file=file,
|
||||
)
|
||||
|
||||
def invoke_tts(self, content_text: str, voice: str = "") -> Iterable[bytes]:
|
||||
@ -381,18 +357,15 @@ class ModelInstance:
|
||||
"""
|
||||
if not isinstance(self.model_type_instance, TTSModel):
|
||||
raise Exception("Model type instance is not TTSModel")
|
||||
return cast(
|
||||
Iterable[bytes],
|
||||
self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
content_text=content_text,
|
||||
voice=voice,
|
||||
),
|
||||
return self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
content_text=content_text,
|
||||
voice=voice,
|
||||
)
|
||||
|
||||
def _round_robin_invoke(self, function: Callable[..., Any], *args, **kwargs):
|
||||
def _round_robin_invoke[**P, R](self, function: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R:
|
||||
"""
|
||||
Round-robin invoke
|
||||
:param function: function to invoke
|
||||
@ -430,9 +403,8 @@ class ModelInstance:
|
||||
continue
|
||||
|
||||
try:
|
||||
if "credentials" in kwargs:
|
||||
del kwargs["credentials"]
|
||||
return function(*args, **kwargs, credentials=lb_config.credentials)
|
||||
kwargs["credentials"] = lb_config.credentials
|
||||
return function(*args, **kwargs)
|
||||
except InvokeRateLimitError as e:
|
||||
# expire in 60 seconds
|
||||
self.load_balancing_manager.cooldown(lb_config, expire=60)
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import base64
|
||||
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from graphon.model_runtime.entities.rerank_entities import RerankResult
|
||||
from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult
|
||||
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
from core.rag.index_processor.constant.doc_type import DocType
|
||||
@ -123,7 +123,7 @@ class RerankModelRunner(BaseRerankRunner):
|
||||
:param query_type: query type
|
||||
:return: rerank result
|
||||
"""
|
||||
docs = []
|
||||
docs: list[MultimodalRerankInput] = []
|
||||
doc_ids = set()
|
||||
unique_documents = []
|
||||
for document in documents:
|
||||
@ -138,26 +138,28 @@ class RerankModelRunner(BaseRerankRunner):
|
||||
if upload_file:
|
||||
blob = storage.load_once(upload_file.key)
|
||||
document_file_base64 = base64.b64encode(blob).decode()
|
||||
document_file_dict = {
|
||||
"content": document_file_base64,
|
||||
"content_type": document.metadata["doc_type"],
|
||||
}
|
||||
docs.append(document_file_dict)
|
||||
docs.append(
|
||||
MultimodalRerankInput(
|
||||
content=document_file_base64,
|
||||
content_type=document.metadata["doc_type"],
|
||||
)
|
||||
)
|
||||
else:
|
||||
document_text_dict = {
|
||||
"content": document.page_content,
|
||||
"content_type": document.metadata.get("doc_type") or DocType.TEXT,
|
||||
}
|
||||
docs.append(document_text_dict)
|
||||
docs.append(
|
||||
MultimodalRerankInput(
|
||||
content=document.page_content,
|
||||
content_type=document.metadata.get("doc_type") or DocType.TEXT,
|
||||
)
|
||||
)
|
||||
doc_ids.add(document.metadata["doc_id"])
|
||||
unique_documents.append(document)
|
||||
elif document.provider == "external":
|
||||
if document not in unique_documents:
|
||||
docs.append(
|
||||
{
|
||||
"content": document.page_content,
|
||||
"content_type": document.metadata.get("doc_type") or DocType.TEXT,
|
||||
}
|
||||
MultimodalRerankInput(
|
||||
content=document.page_content,
|
||||
content_type=document.metadata.get("doc_type") or DocType.TEXT,
|
||||
)
|
||||
)
|
||||
unique_documents.append(document)
|
||||
|
||||
@ -171,12 +173,12 @@ class RerankModelRunner(BaseRerankRunner):
|
||||
if upload_file:
|
||||
blob = storage.load_once(upload_file.key)
|
||||
file_query = base64.b64encode(blob).decode()
|
||||
file_query_dict = {
|
||||
"content": file_query,
|
||||
"content_type": DocType.IMAGE,
|
||||
}
|
||||
file_query_input = MultimodalRerankInput(
|
||||
content=file_query,
|
||||
content_type=DocType.IMAGE,
|
||||
)
|
||||
rerank_result = self.rerank_model_instance.invoke_multimodal_rerank(
|
||||
query=file_query_dict, docs=docs, score_threshold=score_threshold, top_n=top_n
|
||||
query=file_query_input, docs=docs, score_threshold=score_threshold, top_n=top_n
|
||||
)
|
||||
return rerank_result, unique_documents
|
||||
else:
|
||||
|
||||
@ -4,9 +4,7 @@ from unittest.mock import Mock
|
||||
|
||||
from core.mcp.entities import (
|
||||
SUPPORTED_PROTOCOL_VERSIONS,
|
||||
LifespanContextT,
|
||||
RequestContext,
|
||||
SessionT,
|
||||
)
|
||||
from core.mcp.session.base_session import BaseSession
|
||||
from core.mcp.types import LATEST_PROTOCOL_VERSION, RequestParams
|
||||
@ -198,42 +196,3 @@ class TestRequestContext:
|
||||
assert "RequestContext" in repr_str
|
||||
assert "test-123" in repr_str
|
||||
assert "MockSession" in repr_str
|
||||
|
||||
|
||||
class TestTypeVariables:
|
||||
"""Test type variables defined in the module."""
|
||||
|
||||
def test_session_type_var(self):
|
||||
"""Test SessionT type variable."""
|
||||
|
||||
# Create a custom session class
|
||||
class CustomSession(BaseSession):
|
||||
pass
|
||||
|
||||
# Use in generic context
|
||||
def process_session(session: SessionT) -> SessionT:
|
||||
return session
|
||||
|
||||
mock_session = Mock(spec=CustomSession)
|
||||
result = process_session(mock_session)
|
||||
assert result == mock_session
|
||||
|
||||
def test_lifespan_context_type_var(self):
|
||||
"""Test LifespanContextT type variable."""
|
||||
|
||||
# Use in generic context
|
||||
def process_lifespan(context: LifespanContextT) -> LifespanContextT:
|
||||
return context
|
||||
|
||||
# Test with different types
|
||||
str_context = "string-context"
|
||||
assert process_lifespan(str_context) == str_context
|
||||
|
||||
dict_context = {"key": "value"}
|
||||
assert process_lifespan(dict_context) == dict_context
|
||||
|
||||
class CustomContext:
|
||||
pass
|
||||
|
||||
custom_context = CustomContext()
|
||||
assert process_lifespan(custom_context) == custom_context
|
||||
|
||||
Loading…
Reference in New Issue
Block a user