fix RetrievalMethod StrEnum (#26768)

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Asuka Minato 2025-10-13 11:29:37 +09:00 committed by GitHub
parent d299e75e1b
commit 24cd7bbc62
25 changed files with 65 additions and 43 deletions

@@ -1,5 +1,5 @@
 import enum
-from enum import Enum
+from enum import StrEnum
 from typing import Any

 from pydantic import BaseModel, Field, ValidationInfo, field_validator
@@ -218,7 +218,7 @@ class DatasourceLabel(BaseModel):
     icon: str = Field(..., description="The icon of the tool")


-class DatasourceInvokeFrom(Enum):
+class DatasourceInvokeFrom(StrEnum):
     """
     Enum class for datasource invoke
     """

@@ -1,5 +1,5 @@
 from collections.abc import Sequence
-from enum import Enum, StrEnum, auto
+from enum import StrEnum, auto

 from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
@@ -7,7 +7,7 @@ from core.model_runtime.entities.common_entities import I18nObject
 from core.model_runtime.entities.model_entities import AIModelEntity, ModelType


-class ConfigurateMethod(Enum):
+class ConfigurateMethod(StrEnum):
     """
     Enum class for configurate method of provider model.
     """

@@ -34,7 +34,7 @@ class RetrievalService:
     @classmethod
     def retrieve(
         cls,
-        retrieval_method: str,
+        retrieval_method: RetrievalMethod,
         dataset_id: str,
         query: str,
         top_k: int,
@@ -56,7 +56,7 @@ class RetrievalService:
         # Optimize multithreading with thread pools
         with ThreadPoolExecutor(max_workers=dify_config.RETRIEVAL_SERVICE_EXECUTORS) as executor:  # type: ignore
             futures = []
-            if retrieval_method == "keyword_search":
+            if retrieval_method == RetrievalMethod.KEYWORD_SEARCH:
                 futures.append(
                     executor.submit(
                         cls.keyword_search,
@@ -220,7 +220,7 @@ class RetrievalService:
         score_threshold: float | None,
         reranking_model: dict | None,
         all_documents: list,
-        retrieval_method: str,
+        retrieval_method: RetrievalMethod,
         exceptions: list,
         document_ids_filter: list[str] | None = None,
     ):
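Typing retrieval_method as RetrievalMethod instead of str lets the branch above compare against the enum member directly, while callers that still pass the raw string keep matching thanks to StrEnum equality. A rough sketch of that dispatch, assuming the enum also defines KEYWORD_SEARCH (it is referenced throughout this diff even though it falls outside the hunk shown later):

    from enum import StrEnum

    class RetrievalMethod(StrEnum):  # stand-in for core.rag.retrieval.retrieval_methods
        SEMANTIC_SEARCH = "semantic_search"
        FULL_TEXT_SEARCH = "full_text_search"
        HYBRID_SEARCH = "hybrid_search"
        KEYWORD_SEARCH = "keyword_search"

    def pick_branch(retrieval_method: RetrievalMethod) -> str:
        # Hypothetical stand-in for the branching inside RetrievalService.retrieve.
        if retrieval_method == RetrievalMethod.KEYWORD_SEARCH:
            return "keyword"
        return "vector_or_fulltext"

    # The enum member and the raw string it wraps select the same branch,
    # so call sites not yet migrated to the enum do not change behaviour.
    assert pick_branch(RetrievalMethod.KEYWORD_SEARCH) == "keyword"
    assert pick_branch("keyword_search") == "keyword"  # type: ignore[arg-type]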

@@ -1,11 +1,11 @@
 from collections.abc import Mapping
-from enum import Enum
+from enum import StrEnum
 from typing import Any

 from pydantic import BaseModel, Field


-class DatasourceStreamEvent(Enum):
+class DatasourceStreamEvent(StrEnum):
     """
     Datasource Stream event
     """

@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, Optional
 from configs import dify_config
 from core.rag.extractor.entity.extract_setting import ExtractSetting
 from core.rag.models.document import Document
+from core.rag.retrieval.retrieval_methods import RetrievalMethod
 from core.rag.splitter.fixed_text_splitter import (
     EnhanceRecursiveCharacterTextSplitter,
     FixedRecursiveCharacterTextSplitter,
@@ -49,7 +50,7 @@ class BaseIndexProcessor(ABC):
     @abstractmethod
     def retrieve(
         self,
-        retrieval_method: str,
+        retrieval_method: RetrievalMethod,
         query: str,
         dataset: Dataset,
         top_k: int,

@@ -14,6 +14,7 @@ from core.rag.extractor.extract_processor import ExtractProcessor
 from core.rag.index_processor.constant.index_type import IndexType
 from core.rag.index_processor.index_processor_base import BaseIndexProcessor
 from core.rag.models.document import Document
+from core.rag.retrieval.retrieval_methods import RetrievalMethod
 from core.tools.utils.text_processing_utils import remove_leading_symbols
 from libs import helper
 from models.dataset import Dataset, DatasetProcessRule
@@ -106,7 +107,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
     def retrieve(
         self,
-        retrieval_method: str,
+        retrieval_method: RetrievalMethod,
         query: str,
         dataset: Dataset,
         top_k: int,

@@ -16,6 +16,7 @@ from core.rag.extractor.extract_processor import ExtractProcessor
 from core.rag.index_processor.constant.index_type import IndexType
 from core.rag.index_processor.index_processor_base import BaseIndexProcessor
 from core.rag.models.document import ChildDocument, Document, ParentChildStructureChunk
+from core.rag.retrieval.retrieval_methods import RetrievalMethod
 from extensions.ext_database import db
 from libs import helper
 from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment
@@ -161,7 +162,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
     def retrieve(
         self,
-        retrieval_method: str,
+        retrieval_method: RetrievalMethod,
         query: str,
         dataset: Dataset,
         top_k: int,

@@ -21,6 +21,7 @@ from core.rag.extractor.extract_processor import ExtractProcessor
 from core.rag.index_processor.constant.index_type import IndexType
 from core.rag.index_processor.index_processor_base import BaseIndexProcessor
 from core.rag.models.document import Document, QAStructureChunk
+from core.rag.retrieval.retrieval_methods import RetrievalMethod
 from core.tools.utils.text_processing_utils import remove_leading_symbols
 from libs import helper
 from models.dataset import Dataset
@@ -141,7 +142,7 @@ class QAIndexProcessor(BaseIndexProcessor):
     def retrieve(
         self,
-        retrieval_method: str,
+        retrieval_method: RetrievalMethod,
         query: str,
         dataset: Dataset,
         top_k: int,

@@ -364,7 +364,7 @@ class DatasetRetrieval:
             top_k = retrieval_model_config["top_k"]
             # get retrieval method
             if dataset.indexing_technique == "economy":
-                retrieval_method = "keyword_search"
+                retrieval_method = RetrievalMethod.KEYWORD_SEARCH
             else:
                 retrieval_method = retrieval_model_config["search_method"]
             # get reranking model
@@ -623,7 +623,7 @@ class DatasetRetrieval:
         if dataset.indexing_technique == "economy":
             # use keyword table query
             documents = RetrievalService.retrieve(
-                retrieval_method="keyword_search",
+                retrieval_method=RetrievalMethod.KEYWORD_SEARCH,
                 dataset_id=dataset.id,
                 query=query,
                 top_k=top_k,

@@ -1,7 +1,7 @@
-from enum import Enum
+from enum import StrEnum


-class RetrievalMethod(Enum):
+class RetrievalMethod(StrEnum):
     SEMANTIC_SEARCH = "semantic_search"
     FULL_TEXT_SEARCH = "full_text_search"
     HYBRID_SEARCH = "hybrid_search"
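With RetrievalMethod as a StrEnum, values stored as plain strings round-trip cleanly and members render as their string values, which is why the explicit `.value` accesses are dropped in the DocumentService hunks later in this diff. A small illustration; KEYWORD_SEARCH is included here on the assumption it sits just below the hunk shown above, since it is used elsewhere in the diff:

    import json
    from enum import StrEnum

    class RetrievalMethod(StrEnum):  # local copy for illustration only
        SEMANTIC_SEARCH = "semantic_search"
        FULL_TEXT_SEARCH = "full_text_search"
        HYBRID_SEARCH = "hybrid_search"
        KEYWORD_SEARCH = "keyword_search"

    # Round-trip from a stored string back to the enum member.
    assert RetrievalMethod("semantic_search") is RetrievalMethod.SEMANTIC_SEARCH

    # A StrEnum member renders and JSON-serializes as its plain value (Python 3.11+),
    # so RetrievalMethod.SEMANTIC_SEARCH can replace RetrievalMethod.SEMANTIC_SEARCH.value.
    assert str(RetrievalMethod.SEMANTIC_SEARCH) == "semantic_search"
    assert json.dumps({"search_method": RetrievalMethod.SEMANTIC_SEARCH}) == '{"search_method": "semantic_search"}'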

@@ -172,7 +172,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
             if dataset.indexing_technique == "economy":
                 # use keyword table query
                 documents = RetrievalService.retrieve(
-                    retrieval_method="keyword_search",
+                    retrieval_method=RetrievalMethod.KEYWORD_SEARCH,
                     dataset_id=dataset.id,
                     query=query,
                     top_k=retrieval_model.get("top_k") or 4,

@@ -130,7 +130,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
         if dataset.indexing_technique == "economy":
             # use keyword table query
             documents = RetrievalService.retrieve(
-                retrieval_method="keyword_search",
+                retrieval_method=RetrievalMethod.KEYWORD_SEARCH,
                 dataset_id=dataset.id,
                 query=query,
                 top_k=self.top_k,

@@ -1,7 +1,7 @@
-from enum import Enum, StrEnum
+from enum import StrEnum


-class NodeState(Enum):
+class NodeState(StrEnum):
     """State of a node or edge during workflow execution."""

     UNKNOWN = "unknown"

@@ -10,7 +10,7 @@ When limits are exceeded, the layer automatically aborts execution.

 import logging
 import time
-from enum import Enum
+from enum import StrEnum
 from typing import final

 from typing_extensions import override
@@ -24,7 +24,7 @@ from core.workflow.graph_events import (
 from core.workflow.graph_events.node import NodeRunFailedEvent, NodeRunSucceededEvent


-class LimitType(Enum):
+class LimitType(StrEnum):
     """Types of execution limits that can be exceeded."""

     STEP_LIMIT = "step_limit"

@@ -2,6 +2,7 @@ from typing import Literal, Union

 from pydantic import BaseModel

+from core.rag.retrieval.retrieval_methods import RetrievalMethod
 from core.workflow.nodes.base import BaseNodeData
@@ -63,7 +64,7 @@ class RetrievalSetting(BaseModel):
     Retrieval Setting.
     """

-    search_method: Literal["semantic_search", "keyword_search", "full_text_search", "hybrid_search"]
+    search_method: RetrievalMethod
     top_k: int
     score_threshold: float | None = 0.5
     score_threshold_enabled: bool = False
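Replacing the Literal[...] union with the RetrievalMethod enum keeps existing payloads valid, because Pydantic v2 coerces a raw string into the enum member during validation and dumps it back to a plain string in JSON mode. A minimal sketch, assuming Pydantic v2 and using a trimmed model rather than the real RetrievalSetting:

    from enum import StrEnum
    from pydantic import BaseModel

    class RetrievalMethod(StrEnum):  # stand-in for the imported enum
        SEMANTIC_SEARCH = "semantic_search"
        FULL_TEXT_SEARCH = "full_text_search"
        HYBRID_SEARCH = "hybrid_search"
        KEYWORD_SEARCH = "keyword_search"

    class RetrievalSetting(BaseModel):  # trimmed to the field under discussion
        search_method: RetrievalMethod
        top_k: int

    # Raw strings coming from stored configs validate into the enum member...
    setting = RetrievalSetting.model_validate({"search_method": "hybrid_search", "top_k": 4})
    assert setting.search_method is RetrievalMethod.HYBRID_SEARCH
    # ...and dumping in JSON mode yields the plain string again.
    assert setting.model_dump(mode="json")["search_method"] == "hybrid_search"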

@@ -1470,7 +1470,7 @@ class DocumentService:
                 dataset.collection_binding_id = dataset_collection_binding.id
             if not dataset.retrieval_model:
                 default_retrieval_model = {
-                    "search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
+                    "search_method": RetrievalMethod.SEMANTIC_SEARCH,
                     "reranking_enable": False,
                     "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
                     "top_k": 4,
@@ -1752,7 +1752,7 @@ class DocumentService:
             #     dataset.collection_binding_id = dataset_collection_binding.id
             # if not dataset.retrieval_model:
             #     default_retrieval_model = {
-            #         "search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
+            #         "search_method": RetrievalMethod.SEMANTIC_SEARCH,
             #         "reranking_enable": False,
             #         "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
             #         "top_k": 2,
@@ -2205,7 +2205,7 @@ class DocumentService:
             retrieval_model = knowledge_config.retrieval_model
         else:
             retrieval_model = RetrievalModel(
-                search_method=RetrievalMethod.SEMANTIC_SEARCH.value,
+                search_method=RetrievalMethod.SEMANTIC_SEARCH,
                 reranking_enable=False,
                 reranking_model=RerankingModel(reranking_provider_name="", reranking_model_name=""),
                 top_k=4,

@@ -3,6 +3,8 @@ from typing import Literal

 from pydantic import BaseModel

+from core.rag.retrieval.retrieval_methods import RetrievalMethod
+

 class ParentMode(StrEnum):
     FULL_DOC = "full-doc"
@@ -95,7 +97,7 @@ class WeightModel(BaseModel):


 class RetrievalModel(BaseModel):
-    search_method: Literal["hybrid_search", "semantic_search", "full_text_search", "keyword_search"]
+    search_method: RetrievalMethod
     reranking_enable: bool
     reranking_model: RerankingModel | None = None
     reranking_mode: str | None = None

@@ -2,6 +2,8 @@ from typing import Literal

 from pydantic import BaseModel, field_validator

+from core.rag.retrieval.retrieval_methods import RetrievalMethod
+

 class IconInfo(BaseModel):
     icon: str
@@ -83,7 +85,7 @@ class RetrievalSetting(BaseModel):
     Retrieval Setting.
     """

-    search_method: Literal["semantic_search", "full_text_search", "keyword_search", "hybrid_search"]
+    search_method: RetrievalMethod
     top_k: int
     score_threshold: float | None = 0.5
     score_threshold_enabled: bool = False

@@ -1,5 +1,5 @@
 from collections.abc import Sequence
-from enum import Enum
+from enum import StrEnum

 from pydantic import BaseModel, ConfigDict, model_validator
@@ -27,7 +27,7 @@ from core.model_runtime.entities.provider_entities import (
 from models.provider import ProviderType


-class CustomConfigurationStatus(Enum):
+class CustomConfigurationStatus(StrEnum):
     """
     Enum class for custom configuration status.
     """

@@ -63,7 +63,7 @@ class HitTestingService:
         if metadata_condition and not document_ids_filter:
             return cls.compact_retrieve_response(query, [])
         all_documents = RetrievalService.retrieve(
-            retrieval_method=retrieval_model.get("search_method", "semantic_search"),
+            retrieval_method=RetrievalMethod(retrieval_model.get("search_method", RetrievalMethod.SEMANTIC_SEARCH)),
             dataset_id=dataset.id,
             query=query,
             top_k=retrieval_model.get("top_k", 4),
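Wrapping the lookup in RetrievalMethod(...) works for both shapes the dict can hold: calling an Enum with a value returns the matching member, and calling it with a member already in the enum returns that member unchanged, so the SEMANTIC_SEARCH default needs no special casing. A quick, purely illustrative check of that idempotence:

    from enum import StrEnum

    class RetrievalMethod(StrEnum):  # stand-in for the real enum
        SEMANTIC_SEARCH = "semantic_search"
        KEYWORD_SEARCH = "keyword_search"

    empty_model: dict = {}                                 # no search_method stored
    stored_model = {"search_method": "keyword_search"}     # legacy string value

    # Default path: .get() returns the member, and RetrievalMethod(member) is a no-op.
    assert RetrievalMethod(empty_model.get("search_method", RetrievalMethod.SEMANTIC_SEARCH)) is RetrievalMethod.SEMANTIC_SEARCH
    # Stored path: .get() returns the raw string, which resolves to the member.
    assert RetrievalMethod(stored_model.get("search_method", RetrievalMethod.SEMANTIC_SEARCH)) is RetrievalMethod.KEYWORD_SEARCH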

@@ -9,6 +9,7 @@ from flask_login import current_user

 from constants import DOCUMENT_EXTENSIONS
 from core.plugin.impl.plugin import PluginInstaller
+from core.rag.retrieval.retrieval_methods import RetrievalMethod
 from extensions.ext_database import db
 from factories import variable_factory
 from models.dataset import Dataset, Document, DocumentPipelineExecutionLog, Pipeline
@@ -164,7 +165,7 @@ class RagPipelineTransformService:
             if retrieval_model:
                 retrieval_setting = RetrievalSetting.model_validate(retrieval_model)
                 if indexing_technique == "economy":
-                    retrieval_setting.search_method = "keyword_search"
+                    retrieval_setting.search_method = RetrievalMethod.KEYWORD_SEARCH
                 knowledge_configuration.retrieval_model = retrieval_setting
         else:
             dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump()

@@ -1,10 +1,12 @@
 import os

+from pytest_mock import MockerFixture
+
 from core.rag.extractor.firecrawl.firecrawl_app import FirecrawlApp
 from tests.unit_tests.core.rag.extractor.test_notion_extractor import _mock_response


-def test_firecrawl_web_extractor_crawl_mode(mocker):
+def test_firecrawl_web_extractor_crawl_mode(mocker: MockerFixture):
     url = "https://firecrawl.dev"
     api_key = os.getenv("FIRECRAWL_API_KEY") or "fc-"
     base_url = "https://api.firecrawl.dev"
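This and the following test hunks only annotate the pytest-mock `mocker` fixture; the annotation changes nothing at runtime but gives type checkers and IDEs the MockerFixture API. A generic sketch of the pattern, not code from this repository and with an illustrative patch target:

    import os

    from pytest_mock import MockerFixture

    def test_reads_api_key(mocker: MockerFixture) -> None:
        # mocker.patch is now fully typed via the annotation; the target is illustrative.
        mocker.patch("os.getenv", return_value="fc-test-key")
        assert os.getenv("FIRECRAWL_API_KEY") == "fc-test-key"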

@@ -1,5 +1,7 @@
 from unittest import mock

+from pytest_mock import MockerFixture
+
 from core.rag.extractor import notion_extractor

 user_id = "user1"
@@ -57,7 +59,7 @@ def _remove_multiple_new_lines(text):
     return text.strip()


-def test_notion_page(mocker):
+def test_notion_page(mocker: MockerFixture):
     texts = ["Head 1", "1.1", "paragraph 1", "1.1.1"]
     mocked_notion_page = {
         "object": "list",
@@ -77,7 +79,7 @@ def test_notion_page(mocker):
     assert content == "# Head 1\n## 1.1\nparagraph 1\n### 1.1.1"


-def test_notion_database(mocker):
+def test_notion_database(mocker: MockerFixture):
     page_title_list = ["page1", "page2", "page3"]
     mocked_notion_database = {
         "object": "list",

@@ -2,6 +2,7 @@ from unittest.mock import MagicMock, patch

 import pytest
 import redis
+from pytest_mock import MockerFixture

 from core.entities.provider_entities import ModelLoadBalancingConfiguration
 from core.model_manager import LBModelManager
@@ -39,7 +40,7 @@ def lb_model_manager():
     return lb_model_manager


-def test_lb_model_manager_fetch_next(mocker, lb_model_manager):
+def test_lb_model_manager_fetch_next(mocker: MockerFixture, lb_model_manager: LBModelManager):
     # initialize redis client
     redis_client.initialize(redis.Redis())

@@ -1,4 +1,5 @@
 import pytest
+from pytest_mock import MockerFixture

 from core.entities.provider_entities import ModelSettings
 from core.model_runtime.entities.model_entities import ModelType
@@ -7,19 +8,25 @@ from models.provider import LoadBalancingModelConfig, ProviderModelSetting

 @pytest.fixture
-def mock_provider_entity(mocker):
+def mock_provider_entity(mocker: MockerFixture):
     mock_entity = mocker.Mock()
     mock_entity.provider = "openai"
     mock_entity.configurate_methods = ["predefined-model"]
     mock_entity.supported_model_types = [ModelType.LLM]
-    mock_entity.model_credential_schema = mocker.Mock()
-    mock_entity.model_credential_schema.credential_form_schemas = []
+    # Use PropertyMock to ensure credential_form_schemas is iterable
+    provider_credential_schema = mocker.Mock()
+    type(provider_credential_schema).credential_form_schemas = mocker.PropertyMock(return_value=[])
+    mock_entity.provider_credential_schema = provider_credential_schema
+
+    model_credential_schema = mocker.Mock()
+    type(model_credential_schema).credential_form_schemas = mocker.PropertyMock(return_value=[])
+    mock_entity.model_credential_schema = model_credential_schema
     return mock_entity


-def test__to_model_settings(mocker, mock_provider_entity):
+def test__to_model_settings(mocker: MockerFixture, mock_provider_entity):
     # Mocking the inputs
     provider_model_settings = [
         ProviderModelSetting(
@@ -79,7 +86,7 @@ def test__to_model_settings(mocker, mock_provider_entity):
     assert result[0].load_balancing_configs[1].name == "first"


-def test__to_model_settings_only_one_lb(mocker, mock_provider_entity):
+def test__to_model_settings_only_one_lb(mocker: MockerFixture, mock_provider_entity):
     # Mocking the inputs
     provider_model_settings = [
         ProviderModelSetting(
@@ -127,7 +134,7 @@ def test__to_model_settings_only_one_lb(mocker, mock_provider_entity):
     assert len(result[0].load_balancing_configs) == 0


-def test__to_model_settings_lb_disabled(mocker, mock_provider_entity):
+def test__to_model_settings_lb_disabled(mocker: MockerFixture, mock_provider_entity):
     # Mocking the inputs
     provider_model_settings = [
         ProviderModelSetting(
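Beyond the annotations, the fixture above switches to PropertyMock so that `credential_form_schemas` yields a real, iterable list instead of a child Mock. A minimal illustration of that mechanism using unittest.mock directly (pytest-mock exposes the same Mock and PropertyMock classes through the fixture):

    from unittest.mock import Mock, PropertyMock

    schema = Mock()
    # Plain attribute access on a Mock returns another Mock, which is not iterable,
    # so code that does `for form in schema.credential_form_schemas:` would raise TypeError.
    # Attaching a PropertyMock to the *type* turns the attribute into a property
    # that returns a real list on every access.
    type(schema).credential_form_schemas = PropertyMock(return_value=[])
    assert list(schema.credential_form_schemas) == []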