From 1a4eb47e1d0ebb3a6535163dcb36f95c52fe1b93 Mon Sep 17 00:00:00 2001 From: tmimmanuel <14046872+tmimmanuel@users.noreply.github.com> Date: Thu, 9 Apr 2026 01:14:44 +0200 Subject: [PATCH] refactor(api): tighten types in trivial lint and config fixes (#34773) Co-authored-by: tmimmanuel Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- .../datasource/vdb/analyticdb/analyticdb_vector.py | 5 +++-- .../vdb/analyticdb/analyticdb_vector_openapi.py | 2 +- .../vdb/analyticdb/analyticdb_vector_sql.py | 5 +++-- api/core/rag/datasource/vdb/chroma/chroma_vector.py | 11 ++++++----- api/core/rag/datasource/vdb/qdrant/qdrant_vector.py | 7 +++---- api/core/rag/datasource/vdb/relyt/relyt_vector.py | 2 +- .../unstructured/unstructured_doc_extractor.py | 7 +++++-- .../vdb/analyticdb/test_analyticdb_vector.py | 2 +- .../vdb/analyticdb/test_analyticdb_vector_openapi.py | 6 +++--- .../vdb/analyticdb/test_analyticdb_vector_sql.py | 4 ++-- 10 files changed, 28 insertions(+), 23 deletions(-) diff --git a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py index ddb549ba9d..79cc5f0344 100644 --- a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py +++ b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py @@ -37,11 +37,12 @@ class AnalyticdbVector(BaseVector): def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): dimension = len(embeddings[0]) - self.analyticdb_vector._create_collection_if_not_exists(dimension) + self.analyticdb_vector.create_collection_if_not_exists(dimension) self.analyticdb_vector.add_texts(texts, embeddings) - def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): + def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs) -> list[str]: self.analyticdb_vector.add_texts(documents, embeddings) + return [] def text_exists(self, id: str) -> bool: return self.analyticdb_vector.text_exists(id) diff --git a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py index fb6eaa370a..726ee8c050 100644 --- a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py +++ b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py @@ -123,7 +123,7 @@ class AnalyticdbVectorOpenAPI: else: raise ValueError(f"failed to create namespace {self.config.namespace}: {e}") - def _create_collection_if_not_exists(self, embedding_dimension: int): + def create_collection_if_not_exists(self, embedding_dimension: int): from alibabacloud_gpdb20160503 import models as gpdb_20160503_models from Tea.exceptions import TeaException diff --git a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py index 12126f32d6..41c33a3ab1 100644 --- a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py +++ b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py @@ -1,5 +1,6 @@ import json import uuid +from collections.abc import Iterator from contextlib import contextmanager from typing import Any @@ -74,7 +75,7 @@ class AnalyticdbVectorBySql: ) @contextmanager - def _get_cursor(self): + def _get_cursor(self) -> Iterator[Any]: assert self.pool is not None, "Connection pool is not initialized" conn = self.pool.getconn() cur = conn.cursor() @@ -130,7 +131,7 @@ class AnalyticdbVectorBySql: ) cur.execute(f"CREATE SCHEMA IF NOT EXISTS {self.config.namespace}") - def _create_collection_if_not_exists(self, embedding_dimension: int): + def create_collection_if_not_exists(self, embedding_dimension: int): cache_key = f"vector_indexing_{self._collection_name}" lock_name = f"{cache_key}_lock" with redis_client.lock(lock_name, timeout=20): diff --git a/api/core/rag/datasource/vdb/chroma/chroma_vector.py b/api/core/rag/datasource/vdb/chroma/chroma_vector.py index 73787c2f00..5b0cfbea15 100644 --- a/api/core/rag/datasource/vdb/chroma/chroma_vector.py +++ b/api/core/rag/datasource/vdb/chroma/chroma_vector.py @@ -2,7 +2,7 @@ import json from typing import Any, TypedDict import chromadb -from chromadb import QueryResult, Settings +from chromadb import QueryResult, Settings # pyright: ignore[reportPrivateImportUsage] from pydantic import BaseModel from configs import dify_config @@ -106,14 +106,15 @@ class ChromaVector(BaseVector): def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: collection = self._client.get_or_create_collection(self._collection_name) document_ids_filter = kwargs.get("document_ids_filter") + results: QueryResult if document_ids_filter: - results: QueryResult = collection.query( + results = collection.query( query_embeddings=query_vector, n_results=kwargs.get("top_k", 4), where={"document_id": {"$in": document_ids_filter}}, # type: ignore ) else: - results: QueryResult = collection.query(query_embeddings=query_vector, n_results=kwargs.get("top_k", 4)) # type: ignore + results = collection.query(query_embeddings=query_vector, n_results=kwargs.get("top_k", 4)) # type: ignore score_threshold = float(kwargs.get("score_threshold") or 0.0) # Check if results contain data @@ -165,8 +166,8 @@ class ChromaVectorFactory(AbstractVectorFactory): config=ChromaConfig( host=dify_config.CHROMA_HOST or "", port=dify_config.CHROMA_PORT, - tenant=dify_config.CHROMA_TENANT or chromadb.DEFAULT_TENANT, - database=dify_config.CHROMA_DATABASE or chromadb.DEFAULT_DATABASE, + tenant=dify_config.CHROMA_TENANT or chromadb.DEFAULT_TENANT, # pyright: ignore[reportPrivateImportUsage] + database=dify_config.CHROMA_DATABASE or chromadb.DEFAULT_DATABASE, # pyright: ignore[reportPrivateImportUsage] auth_provider=dify_config.CHROMA_AUTH_PROVIDER, auth_credentials=dify_config.CHROMA_AUTH_CREDENTIALS, ), diff --git a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py index f4fcb975c3..b5ff87fc5d 100644 --- a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py +++ b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py @@ -3,7 +3,7 @@ import os import uuid from collections.abc import Generator, Iterable, Sequence from itertools import islice -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast import qdrant_client from flask import current_app @@ -32,7 +32,6 @@ from extensions.ext_redis import redis_client from models.dataset import Dataset, DatasetCollectionBinding if TYPE_CHECKING: - from qdrant_client import grpc # noqa from qdrant_client.conversions import common_types from qdrant_client.http import models as rest @@ -180,7 +179,7 @@ class QdrantVector(BaseVector): for batch_ids, points in self._generate_rest_batches( texts, embeddings, filtered_metadatas, uuids, 64, self._group_id ): - self._client.upsert(collection_name=self._collection_name, points=points) + self._client.upsert(collection_name=self._collection_name, points=cast("common_types.Points", points)) added_ids.extend(batch_ids) return added_ids @@ -472,7 +471,7 @@ class QdrantVector(BaseVector): def _reload_if_needed(self): if isinstance(self._client, QdrantLocal): - self._client._load() + self._client._load() # pyright: ignore[reportPrivateUsage] @classmethod def _document_from_scored_point( diff --git a/api/core/rag/datasource/vdb/relyt/relyt_vector.py b/api/core/rag/datasource/vdb/relyt/relyt_vector.py index e486375ec2..3ecc9867fa 100644 --- a/api/core/rag/datasource/vdb/relyt/relyt_vector.py +++ b/api/core/rag/datasource/vdb/relyt/relyt_vector.py @@ -26,7 +26,7 @@ from extensions.ext_redis import redis_client logger = logging.getLogger(__name__) -Base = declarative_base() # type: Any +Base: Any = declarative_base() class RelytConfig(BaseModel): diff --git a/api/core/rag/extractor/unstructured/unstructured_doc_extractor.py b/api/core/rag/extractor/unstructured/unstructured_doc_extractor.py index 7dd8beaa46..f9fbfbc409 100644 --- a/api/core/rag/extractor/unstructured/unstructured_doc_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_doc_extractor.py @@ -19,12 +19,15 @@ class UnstructuredWordExtractor(BaseExtractor): def extract(self) -> list[Document]: from unstructured.__version__ import __version__ as __unstructured_version__ - from unstructured.file_utils.filetype import FileType, detect_filetype + from unstructured.file_utils.filetype import ( # pyright: ignore[reportPrivateImportUsage] + FileType, + detect_filetype, + ) unstructured_version = tuple(int(x) for x in __unstructured_version__.split(".")) # check the file extension try: - import magic # noqa: F401 + import magic # noqa: F401 # pyright: ignore[reportUnusedImport] is_doc = detect_filetype(self._file_path) == FileType.DOC except ImportError: diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/analyticdb/test_analyticdb_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/analyticdb/test_analyticdb_vector.py index 545565cdf4..d4fa4b3e8e 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/analyticdb/test_analyticdb_vector.py +++ b/api/tests/unit_tests/core/rag/datasource/vdb/analyticdb/test_analyticdb_vector.py @@ -71,7 +71,7 @@ def test_vector_methods_delegate_to_underlying_implementation(): assert vector.search_by_full_text("hello", top_k=2) == runner.search_by_full_text.return_value vector.delete() - runner._create_collection_if_not_exists.assert_called_once_with(2) + runner.create_collection_if_not_exists.assert_called_once_with(2) runner.add_texts.assert_any_call(texts, [[0.1, 0.2]]) runner.delete_by_ids.assert_called_once_with(["d1"]) runner.delete_by_metadata_field.assert_called_once_with("document_id", "doc-1") diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/analyticdb/test_analyticdb_vector_openapi.py b/api/tests/unit_tests/core/rag/datasource/vdb/analyticdb/test_analyticdb_vector_openapi.py index 45777774d0..4f8653a926 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/analyticdb/test_analyticdb_vector_openapi.py +++ b/api/tests/unit_tests/core/rag/datasource/vdb/analyticdb/test_analyticdb_vector_openapi.py @@ -249,7 +249,7 @@ def test_create_collection_if_not_exists_creates_when_missing(monkeypatch): vector._client = MagicMock() vector._client.describe_collection.side_effect = stubs.TeaException(statusCode=404) - vector._create_collection_if_not_exists(embedding_dimension=1024) + vector.create_collection_if_not_exists(embedding_dimension=1024) vector._client.create_collection.assert_called_once() openapi_module.redis_client.set.assert_called_once() @@ -268,7 +268,7 @@ def test_create_collection_if_not_exists_skips_when_cached(monkeypatch): vector.config = _config() vector._client = MagicMock() - vector._create_collection_if_not_exists(embedding_dimension=1024) + vector.create_collection_if_not_exists(embedding_dimension=1024) vector._client.describe_collection.assert_not_called() vector._client.create_collection.assert_not_called() @@ -290,7 +290,7 @@ def test_create_collection_if_not_exists_raises_on_non_404_errors(monkeypatch): vector._client.describe_collection.side_effect = stubs.TeaException(statusCode=500) with pytest.raises(ValueError, match="failed to create collection collection_1"): - vector._create_collection_if_not_exists(embedding_dimension=512) + vector.create_collection_if_not_exists(embedding_dimension=512) def test_openapi_add_delete_and_search_methods(monkeypatch): diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/analyticdb/test_analyticdb_vector_sql.py b/api/tests/unit_tests/core/rag/datasource/vdb/analyticdb/test_analyticdb_vector_sql.py index 8f1206696b..f798ef8bd1 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/analyticdb/test_analyticdb_vector_sql.py +++ b/api/tests/unit_tests/core/rag/datasource/vdb/analyticdb/test_analyticdb_vector_sql.py @@ -374,7 +374,7 @@ def test_create_collection_if_not_exists_creates_table_indexes_and_cache(monkeyp vector._get_cursor = _cursor_context - vector._create_collection_if_not_exists(embedding_dimension=3) + vector.create_collection_if_not_exists(embedding_dimension=3) assert any("CREATE TABLE IF NOT EXISTS dify.collection" in call.args[0] for call in cursor.execute.call_args_list) assert any("CREATE INDEX collection_embedding_idx" in call.args[0] for call in cursor.execute.call_args_list) @@ -404,7 +404,7 @@ def test_create_collection_if_not_exists_raises_for_non_existing_error(monkeypat vector._get_cursor = _cursor_context with pytest.raises(RuntimeError, match="permission denied"): - vector._create_collection_if_not_exists(embedding_dimension=3) + vector.create_collection_if_not_exists(embedding_dimension=3) def test_delete_methods_raise_when_error_is_not_missing_table():