fix RetrievalMethod StrEnum (#26768)

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Asuka Minato authored 2025-10-13 11:29:37 +09:00, committed by GitHub
parent d299e75e1b
commit 24cd7bbc62
25 changed files with 65 additions and 43 deletions

View File

@@ -1,5 +1,5 @@
import enum
from enum import Enum
from enum import StrEnum
from typing import Any
from pydantic import BaseModel, Field, ValidationInfo, field_validator
@@ -218,7 +218,7 @@ class DatasourceLabel(BaseModel):
icon: str = Field(..., description="The icon of the tool")
class DatasourceInvokeFrom(Enum):
class DatasourceInvokeFrom(StrEnum):
"""
Enum class for datasource invoke
"""

View File

@@ -1,5 +1,5 @@
from collections.abc import Sequence
from enum import Enum, StrEnum, auto
from enum import StrEnum, auto
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
@@ -7,7 +7,7 @@ from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
class ConfigurateMethod(Enum):
class ConfigurateMethod(StrEnum):
"""
Enum class for configurate method of provider model.
"""

View File

@@ -34,7 +34,7 @@ class RetrievalService:
@classmethod
def retrieve(
cls,
retrieval_method: str,
retrieval_method: RetrievalMethod,
dataset_id: str,
query: str,
top_k: int,
@@ -56,7 +56,7 @@ class RetrievalService:
# Optimize multithreading with thread pools
with ThreadPoolExecutor(max_workers=dify_config.RETRIEVAL_SERVICE_EXECUTORS) as executor: # type: ignore
futures = []
if retrieval_method == "keyword_search":
if retrieval_method == RetrievalMethod.KEYWORD_SEARCH:
futures.append(
executor.submit(
cls.keyword_search,
@@ -220,7 +220,7 @@ class RetrievalService:
score_threshold: float | None,
reranking_model: dict | None,
all_documents: list,
retrieval_method: str,
retrieval_method: RetrievalMethod,
exceptions: list,
document_ids_filter: list[str] | None = None,
):
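Note: typing retrieval_method as RetrievalMethod does not break callers that still pass plain strings, because StrEnum members are str instances and compare equal to their values. A minimal sketch of the interop this dispatch relies on (standalone abridged copy of the enum, for illustration only):

from enum import StrEnum

class RetrievalMethod(StrEnum):
    # abridged; the real enum lives in core.rag.retrieval.retrieval_methods
    SEMANTIC_SEARCH = "semantic_search"
    KEYWORD_SEARCH = "keyword_search"

def retrieve(retrieval_method: RetrievalMethod) -> str:
    # the comparison holds whether the caller passed the member or the raw string
    if retrieval_method == RetrievalMethod.KEYWORD_SEARCH:
        return "keyword"
    return "vector"

assert retrieve(RetrievalMethod.KEYWORD_SEARCH) == "keyword"
assert retrieve("keyword_search") == "keyword"  # a legacy string-passing caller still dispatches correctly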

View File

@@ -1,11 +1,11 @@
from collections.abc import Mapping
from enum import Enum
from enum import StrEnum
from typing import Any
from pydantic import BaseModel, Field
class DatasourceStreamEvent(Enum):
class DatasourceStreamEvent(StrEnum):
"""
Datasource Stream event
"""

View File

@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, Optional
from configs import dify_config
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.models.document import Document
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.rag.splitter.fixed_text_splitter import (
EnhanceRecursiveCharacterTextSplitter,
FixedRecursiveCharacterTextSplitter,
@@ -49,7 +50,7 @@ class BaseIndexProcessor(ABC):
@abstractmethod
def retrieve(
self,
retrieval_method: str,
retrieval_method: RetrievalMethod,
query: str,
dataset: Dataset,
top_k: int,

View File

@@ -14,6 +14,7 @@ from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.models.document import Document
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.utils.text_processing_utils import remove_leading_symbols
from libs import helper
from models.dataset import Dataset, DatasetProcessRule
@@ -106,7 +107,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
def retrieve(
self,
retrieval_method: str,
retrieval_method: RetrievalMethod,
query: str,
dataset: Dataset,
top_k: int,

View File

@@ -16,6 +16,7 @@ from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.models.document import ChildDocument, Document, ParentChildStructureChunk
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
from libs import helper
from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment
@@ -161,7 +162,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
def retrieve(
self,
retrieval_method: str,
retrieval_method: RetrievalMethod,
query: str,
dataset: Dataset,
top_k: int,

View File

@@ -21,6 +21,7 @@ from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.models.document import Document, QAStructureChunk
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.utils.text_processing_utils import remove_leading_symbols
from libs import helper
from models.dataset import Dataset
@@ -141,7 +142,7 @@ class QAIndexProcessor(BaseIndexProcessor):
def retrieve(
self,
retrieval_method: str,
retrieval_method: RetrievalMethod,
query: str,
dataset: Dataset,
top_k: int,

View File

@@ -364,7 +364,7 @@ class DatasetRetrieval:
top_k = retrieval_model_config["top_k"]
# get retrieval method
if dataset.indexing_technique == "economy":
retrieval_method = "keyword_search"
retrieval_method = RetrievalMethod.KEYWORD_SEARCH
else:
retrieval_method = retrieval_model_config["search_method"]
# get reranking model
@@ -623,7 +623,7 @@ class DatasetRetrieval:
if dataset.indexing_technique == "economy":
# use keyword table query
documents = RetrievalService.retrieve(
retrieval_method="keyword_search",
retrieval_method=RetrievalMethod.KEYWORD_SEARCH,
dataset_id=dataset.id,
query=query,
top_k=top_k,

View File

@@ -1,7 +1,7 @@
from enum import Enum
from enum import StrEnum
class RetrievalMethod(Enum):
class RetrievalMethod(StrEnum):
SEMANTIC_SEARCH = "semantic_search"
FULL_TEXT_SEARCH = "full_text_search"
HYBRID_SEARCH = "hybrid_search"
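Switching this enum from Enum to StrEnum is what makes the rest of the commit work: a plain Enum member never compares equal to its value and str() returns its qualified name, whereas a StrEnum member is the value. A minimal sketch of the difference (illustrative class names, not from the codebase):

from enum import Enum, StrEnum

class PlainMethod(Enum):      # pre-commit behaviour (illustrative)
    SEMANTIC_SEARCH = "semantic_search"

class StrMethod(StrEnum):     # post-commit behaviour (illustrative)
    SEMANTIC_SEARCH = "semantic_search"

assert PlainMethod.SEMANTIC_SEARCH != "semantic_search"
assert str(PlainMethod.SEMANTIC_SEARCH) == "PlainMethod.SEMANTIC_SEARCH"
assert StrMethod.SEMANTIC_SEARCH == "semantic_search"
assert str(StrMethod.SEMANTIC_SEARCH) == "semantic_search"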

View File

@@ -172,7 +172,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
if dataset.indexing_technique == "economy":
# use keyword table query
documents = RetrievalService.retrieve(
retrieval_method="keyword_search",
retrieval_method=RetrievalMethod.KEYWORD_SEARCH,
dataset_id=dataset.id,
query=query,
top_k=retrieval_model.get("top_k") or 4,

View File

@@ -130,7 +130,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
if dataset.indexing_technique == "economy":
# use keyword table query
documents = RetrievalService.retrieve(
retrieval_method="keyword_search",
retrieval_method=RetrievalMethod.KEYWORD_SEARCH,
dataset_id=dataset.id,
query=query,
top_k=self.top_k,

View File

@@ -1,7 +1,7 @@
from enum import Enum, StrEnum
from enum import StrEnum
class NodeState(Enum):
class NodeState(StrEnum):
"""State of a node or edge during workflow execution."""
UNKNOWN = "unknown"

View File

@@ -10,7 +10,7 @@ When limits are exceeded, the layer automatically aborts execution.
import logging
import time
from enum import Enum
from enum import StrEnum
from typing import final
from typing_extensions import override
@@ -24,7 +24,7 @@ from core.workflow.graph_events import (
from core.workflow.graph_events.node import NodeRunFailedEvent, NodeRunSucceededEvent
class LimitType(Enum):
class LimitType(StrEnum):
"""Types of execution limits that can be exceeded."""
STEP_LIMIT = "step_limit"

View File

@@ -2,6 +2,7 @@ from typing import Literal, Union
from pydantic import BaseModel
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.workflow.nodes.base import BaseNodeData
@@ -63,7 +64,7 @@ class RetrievalSetting(BaseModel):
Retrieval Setting.
"""
search_method: Literal["semantic_search", "keyword_search", "full_text_search", "hybrid_search"]
search_method: RetrievalMethod
top_k: int
score_threshold: float | None = 0.5
score_threshold_enabled: bool = False
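Replacing the Literal with RetrievalMethod keeps the same validation while normalising the field: pydantic coerces an incoming string into the corresponding enum member and still rejects values outside the enum. A rough sketch under those assumptions (models trimmed to the relevant fields, for illustration only):

from enum import StrEnum
from pydantic import BaseModel, ValidationError

class RetrievalMethod(StrEnum):      # abridged for illustration
    SEMANTIC_SEARCH = "semantic_search"
    KEYWORD_SEARCH = "keyword_search"

class RetrievalSetting(BaseModel):   # trimmed for illustration
    search_method: RetrievalMethod
    top_k: int

setting = RetrievalSetting(search_method="keyword_search", top_k=4)
assert setting.search_method is RetrievalMethod.KEYWORD_SEARCH  # coerced to the member

try:
    RetrievalSetting(search_method="bm25", top_k=4)
except ValidationError:
    pass  # unknown search methods still fail validation, as with the old Literal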

View File

@@ -1470,7 +1470,7 @@ class DocumentService:
dataset.collection_binding_id = dataset_collection_binding.id
if not dataset.retrieval_model:
default_retrieval_model = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
"search_method": RetrievalMethod.SEMANTIC_SEARCH,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
"top_k": 4,
@@ -1752,7 +1752,7 @@ class DocumentService:
# dataset.collection_binding_id = dataset_collection_binding.id
# if not dataset.retrieval_model:
# default_retrieval_model = {
# "search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
# "search_method": RetrievalMethod.SEMANTIC_SEARCH,
# "reranking_enable": False,
# "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
# "top_k": 2,
@@ -2205,7 +2205,7 @@ class DocumentService:
retrieval_model = knowledge_config.retrieval_model
else:
retrieval_model = RetrievalModel(
search_method=RetrievalMethod.SEMANTIC_SEARCH.value,
search_method=RetrievalMethod.SEMANTIC_SEARCH,
reranking_enable=False,
reranking_model=RerankingModel(reranking_provider_name="", reranking_model_name=""),
top_k=4,
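Dropping .value in the default_retrieval_model dicts above is safe because a StrEnum member is itself the string, so persisted or JSON-encoded payloads come out the same as before. A small sanity check (abridged enum for illustration):

import json
from enum import StrEnum

class RetrievalMethod(StrEnum):      # abridged for illustration
    SEMANTIC_SEARCH = "semantic_search"

default_retrieval_model = {
    "search_method": RetrievalMethod.SEMANTIC_SEARCH,  # member, no .value needed
    "top_k": 4,
}
# json encodes str subclasses as plain strings, so the payload matches the old one
assert json.dumps(default_retrieval_model) == '{"search_method": "semantic_search", "top_k": 4}'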

View File

@@ -3,6 +3,8 @@ from typing import Literal
from pydantic import BaseModel
from core.rag.retrieval.retrieval_methods import RetrievalMethod
class ParentMode(StrEnum):
FULL_DOC = "full-doc"
@@ -95,7 +97,7 @@ class WeightModel(BaseModel):
class RetrievalModel(BaseModel):
search_method: Literal["hybrid_search", "semantic_search", "full_text_search", "keyword_search"]
search_method: RetrievalMethod
reranking_enable: bool
reranking_model: RerankingModel | None = None
reranking_mode: str | None = None

View File

@@ -2,6 +2,8 @@ from typing import Literal
from pydantic import BaseModel, field_validator
from core.rag.retrieval.retrieval_methods import RetrievalMethod
class IconInfo(BaseModel):
icon: str
@@ -83,7 +85,7 @@ class RetrievalSetting(BaseModel):
Retrieval Setting.
"""
search_method: Literal["semantic_search", "full_text_search", "keyword_search", "hybrid_search"]
search_method: RetrievalMethod
top_k: int
score_threshold: float | None = 0.5
score_threshold_enabled: bool = False

View File

@@ -1,5 +1,5 @@
from collections.abc import Sequence
from enum import Enum
from enum import StrEnum
from pydantic import BaseModel, ConfigDict, model_validator
@@ -27,7 +27,7 @@ from core.model_runtime.entities.provider_entities import (
from models.provider import ProviderType
class CustomConfigurationStatus(Enum):
class CustomConfigurationStatus(StrEnum):
"""
Enum class for custom configuration status.
"""

View File

@@ -63,7 +63,7 @@ class HitTestingService:
if metadata_condition and not document_ids_filter:
return cls.compact_retrieve_response(query, [])
all_documents = RetrievalService.retrieve(
retrieval_method=retrieval_model.get("search_method", "semantic_search"),
retrieval_method=RetrievalMethod(retrieval_model.get("search_method", RetrievalMethod.SEMANTIC_SEARCH)),
dataset_id=dataset.id,
query=query,
top_k=retrieval_model.get("top_k", 4),
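Wrapping the stored value in RetrievalMethod(...) normalises whatever is in the retrieval_model dict: a raw string is looked up by value, an existing member passes through unchanged, and an unrecognised value raises ValueError instead of silently reaching the retrieval layer. Roughly (abridged enum for illustration):

from enum import StrEnum

class RetrievalMethod(StrEnum):      # abridged for illustration
    SEMANTIC_SEARCH = "semantic_search"
    KEYWORD_SEARCH = "keyword_search"

retrieval_model = {"search_method": "keyword_search"}  # e.g. loaded from the dataset record
method = RetrievalMethod(retrieval_model.get("search_method", RetrievalMethod.SEMANTIC_SEARCH))
assert method is RetrievalMethod.KEYWORD_SEARCH
assert RetrievalMethod(RetrievalMethod.SEMANTIC_SEARCH) is RetrievalMethod.SEMANTIC_SEARCH  # member passes through
# RetrievalMethod("bm25") would raise ValueError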

View File

@@ -9,6 +9,7 @@ from flask_login import current_user
from constants import DOCUMENT_EXTENSIONS
from core.plugin.impl.plugin import PluginInstaller
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
from factories import variable_factory
from models.dataset import Dataset, Document, DocumentPipelineExecutionLog, Pipeline
@@ -164,7 +165,7 @@ class RagPipelineTransformService:
if retrieval_model:
retrieval_setting = RetrievalSetting.model_validate(retrieval_model)
if indexing_technique == "economy":
retrieval_setting.search_method = "keyword_search"
retrieval_setting.search_method = RetrievalMethod.KEYWORD_SEARCH
knowledge_configuration.retrieval_model = retrieval_setting
else:
dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump()

View File

@@ -1,10 +1,12 @@
import os
from pytest_mock import MockerFixture
from core.rag.extractor.firecrawl.firecrawl_app import FirecrawlApp
from tests.unit_tests.core.rag.extractor.test_notion_extractor import _mock_response
def test_firecrawl_web_extractor_crawl_mode(mocker):
def test_firecrawl_web_extractor_crawl_mode(mocker: MockerFixture):
url = "https://firecrawl.dev"
api_key = os.getenv("FIRECRAWL_API_KEY") or "fc-"
base_url = "https://api.firecrawl.dev"

View File

@@ -1,5 +1,7 @@
from unittest import mock
from pytest_mock import MockerFixture
from core.rag.extractor import notion_extractor
user_id = "user1"
@@ -57,7 +59,7 @@ def _remove_multiple_new_lines(text):
return text.strip()
def test_notion_page(mocker):
def test_notion_page(mocker: MockerFixture):
texts = ["Head 1", "1.1", "paragraph 1", "1.1.1"]
mocked_notion_page = {
"object": "list",
@@ -77,7 +79,7 @@ def test_notion_page(mocker):
assert content == "# Head 1\n## 1.1\nparagraph 1\n### 1.1.1"
def test_notion_database(mocker):
def test_notion_database(mocker: MockerFixture):
page_title_list = ["page1", "page2", "page3"]
mocked_notion_database = {
"object": "list",

View File

@@ -2,6 +2,7 @@ from unittest.mock import MagicMock, patch
import pytest
import redis
from pytest_mock import MockerFixture
from core.entities.provider_entities import ModelLoadBalancingConfiguration
from core.model_manager import LBModelManager
@@ -39,7 +40,7 @@ def lb_model_manager():
return lb_model_manager
def test_lb_model_manager_fetch_next(mocker, lb_model_manager):
def test_lb_model_manager_fetch_next(mocker: MockerFixture, lb_model_manager: LBModelManager):
# initialize redis client
redis_client.initialize(redis.Redis())

View File

@@ -1,4 +1,5 @@
import pytest
from pytest_mock import MockerFixture
from core.entities.provider_entities import ModelSettings
from core.model_runtime.entities.model_entities import ModelType
@@ -7,19 +8,25 @@ from models.provider import LoadBalancingModelConfig, ProviderModelSetting
@pytest.fixture
def mock_provider_entity(mocker):
def mock_provider_entity(mocker: MockerFixture):
mock_entity = mocker.Mock()
mock_entity.provider = "openai"
mock_entity.configurate_methods = ["predefined-model"]
mock_entity.supported_model_types = [ModelType.LLM]
mock_entity.model_credential_schema = mocker.Mock()
mock_entity.model_credential_schema.credential_form_schemas = []
# Use PropertyMock to ensure credential_form_schemas is iterable
provider_credential_schema = mocker.Mock()
type(provider_credential_schema).credential_form_schemas = mocker.PropertyMock(return_value=[])
mock_entity.provider_credential_schema = provider_credential_schema
model_credential_schema = mocker.Mock()
type(model_credential_schema).credential_form_schemas = mocker.PropertyMock(return_value=[])
mock_entity.model_credential_schema = model_credential_schema
return mock_entity
def test__to_model_settings(mocker, mock_provider_entity):
def test__to_model_settings(mocker: MockerFixture, mock_provider_entity):
# Mocking the inputs
provider_model_settings = [
ProviderModelSetting(
@@ -79,7 +86,7 @@ def test__to_model_settings(mocker, mock_provider_entity):
assert result[0].load_balancing_configs[1].name == "first"
def test__to_model_settings_only_one_lb(mocker, mock_provider_entity):
def test__to_model_settings_only_one_lb(mocker: MockerFixture, mock_provider_entity):
# Mocking the inputs
provider_model_settings = [
ProviderModelSetting(
@@ -127,7 +134,7 @@ def test__to_model_settings_only_one_lb(mocker, mock_provider_entity):
assert len(result[0].load_balancing_configs) == 0
def test__to_model_settings_lb_disabled(mocker, mock_provider_entity):
def test__to_model_settings_lb_disabled(mocker: MockerFixture, mock_provider_entity):
# Mocking the inputs
provider_model_settings = [
ProviderModelSetting(
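The PropertyMock lines in the mock_provider_entity fixture above are what make credential_form_schemas iterable: an attribute on a bare Mock is just another Mock, so code that loops over the schemas would raise TypeError. A minimal illustration of the recipe (variable names are illustrative, not from the test suite):

from unittest import mock

bare = mock.Mock()
# bare.credential_form_schemas is itself a Mock; iterating it raises TypeError.

schema = mock.Mock()
# Attaching a PropertyMock to the mock's type makes attribute access return a real list.
# Every Mock instance gets its own class, so this does not leak to other mocks.
type(schema).credential_form_schemas = mock.PropertyMock(return_value=[])
assert list(schema.credential_form_schemas) == []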