mirror of
https://github.com/langgenius/dify.git
synced 2026-06-07 16:32:01 +08:00
refactor: add missing @override decorator to remaining MCP, Jieba, embeddings, and misc subclasses (#36528)
This commit is contained in:
parent
473c945839
commit
4d8b6c7dc0
@ -1,7 +1,7 @@
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
from datetime import datetime
|
||||
from typing import Literal
|
||||
from typing import Literal, override
|
||||
|
||||
from dateutil.parser import isoparse
|
||||
from flask import request
|
||||
@ -76,11 +76,13 @@ def _enum_value(value):
|
||||
|
||||
|
||||
class WorkflowRunStatusField(fields.Raw):
|
||||
@override
|
||||
def output(self, key, obj: WorkflowRun, **kwargs):
|
||||
return _enum_value(obj.status)
|
||||
|
||||
|
||||
class WorkflowRunOutputsField(fields.Raw):
|
||||
@override
|
||||
def output(self, key, obj: WorkflowRun, **kwargs):
|
||||
status = _enum_value(obj.status)
|
||||
if status == WorkflowExecutionStatus.PAUSED.value:
|
||||
|
||||
@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import abc
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Protocol
|
||||
from typing import Any, Protocol, override
|
||||
|
||||
from graphon.enums import NodeType
|
||||
|
||||
@ -29,5 +29,6 @@ class DraftVariableSaverFactory(Protocol):
|
||||
|
||||
|
||||
class NoopDraftVariableSaver(DraftVariableSaver):
|
||||
@override
|
||||
def save(self, process_data: Mapping[str, Any] | None, outputs: Mapping[str, Any] | None) -> None:
|
||||
return None
|
||||
|
||||
@ -6,7 +6,7 @@ import re
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterator, Sequence
|
||||
from json import JSONDecodeError
|
||||
from typing import Any
|
||||
from typing import Any, override
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator
|
||||
from sqlalchemy import func, select
|
||||
@ -1889,6 +1889,7 @@ class ProviderConfigurations(BaseModel):
|
||||
key = str(ModelProviderID(key))
|
||||
return key in self.configurations
|
||||
|
||||
@override
|
||||
def __iter__(self):
|
||||
# Return an iterator of (key, value) tuples to match BaseModel's __iter__
|
||||
yield from self.configurations.items()
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, TypedDict
|
||||
from typing import Any, TypedDict, override
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
@ -29,6 +29,7 @@ class ApiExternalDataTool(ExternalDataTool):
|
||||
"""the unique name of external data tool"""
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def validate_config(cls, tenant_id: str, config: dict[str, Any]):
|
||||
"""
|
||||
Validate the incoming form config data.
|
||||
@ -50,6 +51,7 @@ class ApiExternalDataTool(ExternalDataTool):
|
||||
if not api_based_extension:
|
||||
raise ValueError("api_based_extension_id is invalid")
|
||||
|
||||
@override
|
||||
def query(self, inputs: Mapping[str, Any], query: str | None = None) -> str:
|
||||
"""
|
||||
Query the external data tool.
|
||||
|
||||
@ -7,7 +7,7 @@ authentication failures and retries operations after refreshing tokens.
|
||||
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
from typing import Any, override
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@ -159,6 +159,7 @@ class MCPClientWithAuthRetry(MCPClient):
|
||||
# Reset retry flag after operation completes
|
||||
self._has_retried = False
|
||||
|
||||
@override
|
||||
def __enter__(self):
|
||||
"""Enter the context manager with retry support."""
|
||||
|
||||
@ -168,6 +169,7 @@ class MCPClientWithAuthRetry(MCPClient):
|
||||
|
||||
return self._execute_with_retry(initialize_with_retry)
|
||||
|
||||
@override
|
||||
def list_tools(self) -> list[Tool]:
|
||||
"""
|
||||
List available tools from the MCP server with auth retry.
|
||||
@ -180,6 +182,7 @@ class MCPClientWithAuthRetry(MCPClient):
|
||||
"""
|
||||
return self._execute_with_retry(super().list_tools)
|
||||
|
||||
@override
|
||||
def invoke_tool(self, tool_name: str, tool_args: dict[str, Any]) -> CallToolResult:
|
||||
"""
|
||||
Invoke a tool on the MCP server with auth retry.
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import queue
|
||||
from datetime import timedelta
|
||||
from typing import Any, Protocol
|
||||
from typing import Any, Protocol, override
|
||||
|
||||
from pydantic import AnyUrl, TypeAdapter
|
||||
|
||||
@ -159,6 +159,7 @@ class ClientSession(
|
||||
types.EmptyResult,
|
||||
)
|
||||
|
||||
@override
|
||||
def send_progress_notification(self, progress_token: str | int, progress: float, total: float | None = None):
|
||||
"""Send a progress notification."""
|
||||
self.send_notification(
|
||||
@ -326,6 +327,7 @@ class ClientSession(
|
||||
)
|
||||
)
|
||||
|
||||
@override
|
||||
def _received_request(self, responder: RequestResponder[types.ServerRequest, types.ClientResult]):
|
||||
ctx = RequestContext[ClientSession, Any](
|
||||
request_id=responder.request_id,
|
||||
@ -351,6 +353,7 @@ class ClientSession(
|
||||
with responder:
|
||||
return responder.respond(types.ClientResult(root=types.EmptyResult()))
|
||||
|
||||
@override
|
||||
def _handle_incoming(
|
||||
self,
|
||||
req: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
|
||||
@ -358,6 +361,7 @@ class ClientSession(
|
||||
"""Handle incoming messages by forwarding to the message handler."""
|
||||
self._message_handler(req)
|
||||
|
||||
@override
|
||||
def _received_notification(self, notification: types.ServerNotification):
|
||||
"""Handle notifications from the server."""
|
||||
# Process specific notification types
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import override
|
||||
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
@ -11,6 +12,7 @@ class PluginDaemonError(Exception):
|
||||
def __init__(self, description: str):
|
||||
self.description = description
|
||||
|
||||
@override
|
||||
def __str__(self) -> str:
|
||||
# returns the class name and description
|
||||
return f"req_id: {get_request_id()} {self.__class__.__name__}: {self.description}"
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections import defaultdict
|
||||
from typing import Any, TypedDict
|
||||
from typing import Any, TypedDict, override
|
||||
|
||||
import orjson
|
||||
from pydantic import BaseModel
|
||||
@ -29,6 +29,7 @@ class Jieba(BaseKeyword):
|
||||
super().__init__(dataset)
|
||||
self._config = KeywordTableConfig()
|
||||
|
||||
@override
|
||||
def create(self, texts: list[Document], **kwargs) -> BaseKeyword:
|
||||
lock_name = f"keyword_indexing_lock_{self.dataset.id}"
|
||||
with redis_client.lock(lock_name, timeout=600):
|
||||
@ -48,6 +49,7 @@ class Jieba(BaseKeyword):
|
||||
|
||||
return self
|
||||
|
||||
@override
|
||||
def add_texts(self, texts: list[Document], **kwargs):
|
||||
lock_name = f"keyword_indexing_lock_{self.dataset.id}"
|
||||
with redis_client.lock(lock_name, timeout=600):
|
||||
@ -72,12 +74,14 @@ class Jieba(BaseKeyword):
|
||||
|
||||
self._save_dataset_keyword_table(keyword_table)
|
||||
|
||||
@override
|
||||
def text_exists(self, id: str) -> bool:
|
||||
keyword_table = self._get_dataset_keyword_table()
|
||||
if keyword_table is None:
|
||||
return False
|
||||
return id in set.union(*keyword_table.values())
|
||||
|
||||
@override
|
||||
def delete_by_ids(self, ids: list[str]):
|
||||
lock_name = f"keyword_indexing_lock_{self.dataset.id}"
|
||||
with redis_client.lock(lock_name, timeout=600):
|
||||
@ -87,6 +91,7 @@ class Jieba(BaseKeyword):
|
||||
|
||||
self._save_dataset_keyword_table(keyword_table)
|
||||
|
||||
@override
|
||||
def search(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
keyword_table = self._get_dataset_keyword_table()
|
||||
|
||||
@ -122,6 +127,7 @@ class Jieba(BaseKeyword):
|
||||
|
||||
return documents
|
||||
|
||||
@override
|
||||
def delete(self):
|
||||
lock_name = f"keyword_indexing_lock_{self.dataset.id}"
|
||||
with redis_client.lock(lock_name, timeout=600):
|
||||
|
||||
@ -2,7 +2,7 @@ import base64
|
||||
import logging
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
from typing import Any, override
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
@ -72,21 +72,27 @@ class _LazyEmbeddings(Embeddings):
|
||||
self._real = CacheEmbedding(embedding_model)
|
||||
return self._real
|
||||
|
||||
@override
|
||||
def embed_documents(self, texts: list[str]) -> list[list[float]]:
|
||||
return self._ensure().embed_documents(texts)
|
||||
|
||||
@override
|
||||
def embed_multimodal_documents(self, multimodel_documents: list[dict[str, Any]]) -> list[list[float]]:
|
||||
return self._ensure().embed_multimodal_documents(multimodel_documents)
|
||||
|
||||
@override
|
||||
def embed_query(self, text: str) -> list[float]:
|
||||
return self._ensure().embed_query(text)
|
||||
|
||||
@override
|
||||
def embed_multimodal_query(self, multimodel_document: dict[str, Any]) -> list[float]:
|
||||
return self._ensure().embed_multimodal_query(multimodel_document)
|
||||
|
||||
@override
|
||||
async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
|
||||
return await self._ensure().aembed_documents(texts)
|
||||
|
||||
@override
|
||||
async def aembed_query(self, text: str) -> list[float]:
|
||||
return await self._ensure().aembed_query(text)
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user