refactor(api): fix pyright errors in jieba, milvus, couchbase, oracle, and router (#34938)

Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
This commit is contained in:
tmimmanuel 2026-04-24 00:30:28 +02:00 committed by GitHub
parent 0c8dec3315
commit ed8d3f3e8d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 34 additions and 24 deletions

View File

@ -156,7 +156,8 @@ class Jieba(BaseKeyword):
if dataset_keyword_table: if dataset_keyword_table:
keyword_table_dict = dataset_keyword_table.keyword_table_dict keyword_table_dict = dataset_keyword_table.keyword_table_dict
if keyword_table_dict: if keyword_table_dict:
return dict(keyword_table_dict["__data__"]["table"]) data: Any = keyword_table_dict["__data__"]
return dict(data["table"])
else: else:
keyword_data_source_type = dify_config.KEYWORD_DATA_SOURCE_TYPE keyword_data_source_type = dify_config.KEYWORD_DATA_SOURCE_TYPE
dataset_keyword_table = DatasetKeywordTable( dataset_keyword_table = DatasetKeywordTable(

View File

@ -109,7 +109,7 @@ class JiebaKeywordTableHandler:
"""Extract keywords with JIEBA tfidf.""" """Extract keywords with JIEBA tfidf."""
keywords = self._tfidf.extract_tags( keywords = self._tfidf.extract_tags(
sentence=text, sentence=text,
topK=max_keywords_per_chunk, topK=max_keywords_per_chunk or 10,
) )
# jieba.analyse.extract_tags returns an untyped list when withFlag is False by default. # jieba.analyse.extract_tags returns an untyped list when withFlag is False by default.
keywords = cast(list[str], keywords) keywords = cast(list[str], keywords)

View File

@ -31,7 +31,7 @@ class FunctionCallMultiDatasetRouter:
result: LLMResult = model_instance.invoke_llm( # pyright: ignore[reportCallIssue, reportArgumentType] result: LLMResult = model_instance.invoke_llm( # pyright: ignore[reportCallIssue, reportArgumentType]
prompt_messages=prompt_messages, prompt_messages=prompt_messages,
tools=dataset_tools, tools=dataset_tools,
stream=False, stream=False, # pyright: ignore[reportArgumentType]
model_parameters={"temperature": 0.2, "top_p": 0.3, "max_tokens": 1500}, model_parameters={"temperature": 0.2, "top_p": 0.3, "max_tokens": 1500},
) )
usage = result.usage or LLMUsage.empty_usage() usage = result.usage or LLMUsage.empty_usage()

View File

@ -59,7 +59,7 @@ class CouchbaseVector(BaseVector):
auth = PasswordAuthenticator(config.user, config.password) auth = PasswordAuthenticator(config.user, config.password)
options = ClusterOptions(auth) options = ClusterOptions(auth)
self._cluster = Cluster(config.connection_string, options) self._cluster = Cluster(config.connection_string, options) # pyright: ignore[reportArgumentType]
self._bucket = self._cluster.bucket(config.bucket_name) self._bucket = self._cluster.bucket(config.bucket_name)
self._scope = self._bucket.scope(config.scope_name) self._scope = self._bucket.scope(config.scope_name)
self._bucket_name = config.bucket_name self._bucket_name = config.bucket_name
@ -306,7 +306,7 @@ class CouchbaseVector(BaseVector):
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 4) top_k = kwargs.get("top_k", 4)
try: try:
CBrequest = search.SearchRequest.create(search.QueryStringQuery("text:" + query)) CBrequest = search.SearchRequest.create(search.QueryStringQuery("text:" + query)) # pyright: ignore[reportCallIssue]
search_iter = self._scope.search( search_iter = self._scope.search(
self._collection_name + "_search", CBrequest, SearchOptions(limit=top_k, fields=["*"]) self._collection_name + "_search", CBrequest, SearchOptions(limit=top_k, fields=["*"])
) )

View File

@ -1,6 +1,6 @@
import json import json
import logging import logging
from typing import Any, TypedDict from typing import Any, TypedDict, cast
from packaging import version from packaging import version
from pydantic import BaseModel, model_validator from pydantic import BaseModel, model_validator
@ -92,7 +92,7 @@ class MilvusVector(BaseVector):
def _load_collection_fields(self, fields: list[str] | None = None): def _load_collection_fields(self, fields: list[str] | None = None):
if fields is None: if fields is None:
# Load collection fields from remote server # Load collection fields from remote server
collection_info = self._client.describe_collection(self._collection_name) collection_info = cast(dict[str, Any], self._client.describe_collection(self._collection_name))
fields = [field["name"] for field in collection_info["fields"]] fields = [field["name"] for field in collection_info["fields"]]
# Since primary field is auto-id, no need to track it # Since primary field is auto-id, no need to track it
self._fields = [f for f in fields if f != Field.PRIMARY_KEY] self._fields = [f for f in fields if f != Field.PRIMARY_KEY]
@ -106,7 +106,8 @@ class MilvusVector(BaseVector):
return False return False
try: try:
milvus_version = self._client.get_server_version() milvus_version_raw = self._client.get_server_version()
milvus_version = milvus_version_raw if isinstance(milvus_version_raw, str) else str(milvus_version_raw)
# Check if it's Zilliz Cloud - it supports full-text search with Milvus 2.5 compatibility # Check if it's Zilliz Cloud - it supports full-text search with Milvus 2.5 compatibility
if "Zilliz Cloud" in milvus_version: if "Zilliz Cloud" in milvus_version:
return True return True

View File

@ -3,7 +3,7 @@ import json
import logging import logging
import re import re
import uuid import uuid
from typing import Any from typing import Any, TypedDict
import jieba.posseg as pseg # type: ignore import jieba.posseg as pseg # type: ignore
import numpy import numpy
@ -25,6 +25,18 @@ logger = logging.getLogger(__name__)
oracledb.defaults.fetch_lobs = False oracledb.defaults.fetch_lobs = False
class _OraclePoolParams(TypedDict, total=False):
user: str
password: str
dsn: str
min: int
max: int
increment: int
config_dir: str | None
wallet_location: str | None
wallet_password: str | None
class OracleVectorConfig(BaseModel): class OracleVectorConfig(BaseModel):
user: str user: str
password: str password: str
@ -127,22 +139,18 @@ class OracleVector(BaseVector):
return connection return connection
def _create_connection_pool(self, config: OracleVectorConfig): def _create_connection_pool(self, config: OracleVectorConfig):
pool_params = { pool_params = _OraclePoolParams(
"user": config.user, user=config.user,
"password": config.password, password=config.password,
"dsn": config.dsn, dsn=config.dsn,
"min": 1, min=1,
"max": 5, max=5,
"increment": 1, increment=1,
} )
if config.is_autonomous: if config.is_autonomous:
pool_params.update( pool_params["config_dir"] = config.config_dir
{ pool_params["wallet_location"] = config.wallet_location
"config_dir": config.config_dir, pool_params["wallet_password"] = config.wallet_password
"wallet_location": config.wallet_location,
"wallet_password": config.wallet_password,
}
)
return oracledb.create_pool(**pool_params) return oracledb.create_pool(**pool_params)
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):