Mirror of https://github.com/langgenius/dify.git (synced 2026-05-13 08:57:28 +08:00)
refactor(api): fix pyright errors in jieba, milvus, couchbase, oracle, and router (#34938)
Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
This commit is contained in:
parent: 0c8dec3315
commit: ed8d3f3e8d
@@ -156,7 +156,8 @@ class Jieba(BaseKeyword):
|
|||||||
if dataset_keyword_table:
|
if dataset_keyword_table:
|
||||||
keyword_table_dict = dataset_keyword_table.keyword_table_dict
|
keyword_table_dict = dataset_keyword_table.keyword_table_dict
|
||||||
if keyword_table_dict:
|
if keyword_table_dict:
|
||||||
return dict(keyword_table_dict["__data__"]["table"])
|
data: Any = keyword_table_dict["__data__"]
|
||||||
|
return dict(data["table"])
|
||||||
else:
|
else:
|
||||||
keyword_data_source_type = dify_config.KEYWORD_DATA_SOURCE_TYPE
|
keyword_data_source_type = dify_config.KEYWORD_DATA_SOURCE_TYPE
|
||||||
dataset_keyword_table = DatasetKeywordTable(
|
dataset_keyword_table = DatasetKeywordTable(
|
||||||
|
|||||||
@@ -109,7 +109,7 @@ class JiebaKeywordTableHandler:
|
|||||||
"""Extract keywords with JIEBA tfidf."""
|
"""Extract keywords with JIEBA tfidf."""
|
||||||
keywords = self._tfidf.extract_tags(
|
keywords = self._tfidf.extract_tags(
|
||||||
sentence=text,
|
sentence=text,
|
||||||
topK=max_keywords_per_chunk,
|
topK=max_keywords_per_chunk or 10,
|
||||||
)
|
)
|
||||||
# jieba.analyse.extract_tags returns an untyped list when withFlag is False by default.
|
# jieba.analyse.extract_tags returns an untyped list when withFlag is False by default.
|
||||||
keywords = cast(list[str], keywords)
|
keywords = cast(list[str], keywords)
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ class FunctionCallMultiDatasetRouter:
|
|||||||
result: LLMResult = model_instance.invoke_llm( # pyright: ignore[reportCallIssue, reportArgumentType]
|
result: LLMResult = model_instance.invoke_llm( # pyright: ignore[reportCallIssue, reportArgumentType]
|
||||||
prompt_messages=prompt_messages,
|
prompt_messages=prompt_messages,
|
||||||
tools=dataset_tools,
|
tools=dataset_tools,
|
||||||
stream=False,
|
stream=False, # pyright: ignore[reportArgumentType]
|
||||||
model_parameters={"temperature": 0.2, "top_p": 0.3, "max_tokens": 1500},
|
model_parameters={"temperature": 0.2, "top_p": 0.3, "max_tokens": 1500},
|
||||||
)
|
)
|
||||||
usage = result.usage or LLMUsage.empty_usage()
|
usage = result.usage or LLMUsage.empty_usage()
|
||||||
|
|||||||
@@ -59,7 +59,7 @@ class CouchbaseVector(BaseVector):
|
|||||||
|
|
||||||
auth = PasswordAuthenticator(config.user, config.password)
|
auth = PasswordAuthenticator(config.user, config.password)
|
||||||
options = ClusterOptions(auth)
|
options = ClusterOptions(auth)
|
||||||
self._cluster = Cluster(config.connection_string, options)
|
self._cluster = Cluster(config.connection_string, options) # pyright: ignore[reportArgumentType]
|
||||||
self._bucket = self._cluster.bucket(config.bucket_name)
|
self._bucket = self._cluster.bucket(config.bucket_name)
|
||||||
self._scope = self._bucket.scope(config.scope_name)
|
self._scope = self._bucket.scope(config.scope_name)
|
||||||
self._bucket_name = config.bucket_name
|
self._bucket_name = config.bucket_name
|
||||||
@@ -306,7 +306,7 @@ class CouchbaseVector(BaseVector):
|
|||||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||||
top_k = kwargs.get("top_k", 4)
|
top_k = kwargs.get("top_k", 4)
|
||||||
try:
|
try:
|
||||||
CBrequest = search.SearchRequest.create(search.QueryStringQuery("text:" + query))
|
CBrequest = search.SearchRequest.create(search.QueryStringQuery("text:" + query)) # pyright: ignore[reportCallIssue]
|
||||||
search_iter = self._scope.search(
|
search_iter = self._scope.search(
|
||||||
self._collection_name + "_search", CBrequest, SearchOptions(limit=top_k, fields=["*"])
|
self._collection_name + "_search", CBrequest, SearchOptions(limit=top_k, fields=["*"])
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, TypedDict
|
from typing import Any, TypedDict, cast
|
||||||
|
|
||||||
from packaging import version
|
from packaging import version
|
||||||
from pydantic import BaseModel, model_validator
|
from pydantic import BaseModel, model_validator
|
||||||
@@ -92,7 +92,7 @@ class MilvusVector(BaseVector):
|
|||||||
def _load_collection_fields(self, fields: list[str] | None = None):
|
def _load_collection_fields(self, fields: list[str] | None = None):
|
||||||
if fields is None:
|
if fields is None:
|
||||||
# Load collection fields from remote server
|
# Load collection fields from remote server
|
||||||
collection_info = self._client.describe_collection(self._collection_name)
|
collection_info = cast(dict[str, Any], self._client.describe_collection(self._collection_name))
|
||||||
fields = [field["name"] for field in collection_info["fields"]]
|
fields = [field["name"] for field in collection_info["fields"]]
|
||||||
# Since primary field is auto-id, no need to track it
|
# Since primary field is auto-id, no need to track it
|
||||||
self._fields = [f for f in fields if f != Field.PRIMARY_KEY]
|
self._fields = [f for f in fields if f != Field.PRIMARY_KEY]
|
||||||
@@ -106,7 +106,8 @@ class MilvusVector(BaseVector):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
milvus_version = self._client.get_server_version()
|
milvus_version_raw = self._client.get_server_version()
|
||||||
|
milvus_version = milvus_version_raw if isinstance(milvus_version_raw, str) else str(milvus_version_raw)
|
||||||
# Check if it's Zilliz Cloud - it supports full-text search with Milvus 2.5 compatibility
|
# Check if it's Zilliz Cloud - it supports full-text search with Milvus 2.5 compatibility
|
||||||
if "Zilliz Cloud" in milvus_version:
|
if "Zilliz Cloud" in milvus_version:
|
||||||
return True
|
return True
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Any
|
from typing import Any, TypedDict
|
||||||
|
|
||||||
import jieba.posseg as pseg # type: ignore
|
import jieba.posseg as pseg # type: ignore
|
||||||
import numpy
|
import numpy
|
||||||
@@ -25,6 +25,18 @@ logger = logging.getLogger(__name__)
|
|||||||
oracledb.defaults.fetch_lobs = False
|
oracledb.defaults.fetch_lobs = False
|
||||||
|
|
||||||
|
|
||||||
|
class _OraclePoolParams(TypedDict, total=False):
|
||||||
|
user: str
|
||||||
|
password: str
|
||||||
|
dsn: str
|
||||||
|
min: int
|
||||||
|
max: int
|
||||||
|
increment: int
|
||||||
|
config_dir: str | None
|
||||||
|
wallet_location: str | None
|
||||||
|
wallet_password: str | None
|
||||||
|
|
||||||
|
|
||||||
class OracleVectorConfig(BaseModel):
|
class OracleVectorConfig(BaseModel):
|
||||||
user: str
|
user: str
|
||||||
password: str
|
password: str
|
||||||
@@ -127,22 +139,18 @@ class OracleVector(BaseVector):
|
|||||||
return connection
|
return connection
|
||||||
|
|
||||||
def _create_connection_pool(self, config: OracleVectorConfig):
|
def _create_connection_pool(self, config: OracleVectorConfig):
|
||||||
pool_params = {
|
pool_params = _OraclePoolParams(
|
||||||
"user": config.user,
|
user=config.user,
|
||||||
"password": config.password,
|
password=config.password,
|
||||||
"dsn": config.dsn,
|
dsn=config.dsn,
|
||||||
"min": 1,
|
min=1,
|
||||||
"max": 5,
|
max=5,
|
||||||
"increment": 1,
|
increment=1,
|
||||||
}
|
)
|
||||||
if config.is_autonomous:
|
if config.is_autonomous:
|
||||||
pool_params.update(
|
pool_params["config_dir"] = config.config_dir
|
||||||
{
|
pool_params["wallet_location"] = config.wallet_location
|
||||||
"config_dir": config.config_dir,
|
pool_params["wallet_password"] = config.wallet_password
|
||||||
"wallet_location": config.wallet_location,
|
|
||||||
"wallet_password": config.wallet_password,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
return oracledb.create_pool(**pool_params)
|
return oracledb.create_pool(**pool_params)
|
||||||
|
|
||||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user