mirror of https://github.com/langgenius/dify.git
fix RetrievalMethod StrEnum (#26768)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
This commit is contained in:
parent
d299e75e1b
commit
24cd7bbc62
|
|
@ -1,5 +1,5 @@
|
|||
import enum
|
||||
from enum import Enum
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, ValidationInfo, field_validator
|
||||
|
|
@ -218,7 +218,7 @@ class DatasourceLabel(BaseModel):
|
|||
icon: str = Field(..., description="The icon of the tool")
|
||||
|
||||
|
||||
class DatasourceInvokeFrom(Enum):
|
||||
class DatasourceInvokeFrom(StrEnum):
|
||||
"""
|
||||
Enum class for datasource invoke
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from collections.abc import Sequence
|
||||
from enum import Enum, StrEnum, auto
|
||||
from enum import StrEnum, auto
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
|
||||
|
||||
|
|
@ -7,7 +7,7 @@ from core.model_runtime.entities.common_entities import I18nObject
|
|||
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
|
||||
|
||||
|
||||
class ConfigurateMethod(Enum):
|
||||
class ConfigurateMethod(StrEnum):
|
||||
"""
|
||||
Enum class for configurate method of provider model.
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ class RetrievalService:
|
|||
@classmethod
|
||||
def retrieve(
|
||||
cls,
|
||||
retrieval_method: str,
|
||||
retrieval_method: RetrievalMethod,
|
||||
dataset_id: str,
|
||||
query: str,
|
||||
top_k: int,
|
||||
|
|
@ -56,7 +56,7 @@ class RetrievalService:
|
|||
# Optimize multithreading with thread pools
|
||||
with ThreadPoolExecutor(max_workers=dify_config.RETRIEVAL_SERVICE_EXECUTORS) as executor: # type: ignore
|
||||
futures = []
|
||||
if retrieval_method == "keyword_search":
|
||||
if retrieval_method == RetrievalMethod.KEYWORD_SEARCH:
|
||||
futures.append(
|
||||
executor.submit(
|
||||
cls.keyword_search,
|
||||
|
|
@ -220,7 +220,7 @@ class RetrievalService:
|
|||
score_threshold: float | None,
|
||||
reranking_model: dict | None,
|
||||
all_documents: list,
|
||||
retrieval_method: str,
|
||||
retrieval_method: RetrievalMethod,
|
||||
exceptions: list,
|
||||
document_ids_filter: list[str] | None = None,
|
||||
):
|
||||
|
|
|
|||
|
|
@ -1,11 +1,11 @@
|
|||
from collections.abc import Mapping
|
||||
from enum import Enum
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class DatasourceStreamEvent(Enum):
|
||||
class DatasourceStreamEvent(StrEnum):
|
||||
"""
|
||||
Datasource Stream event
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, Optional
|
|||
from configs import dify_config
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
from core.rag.models.document import Document
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.rag.splitter.fixed_text_splitter import (
|
||||
EnhanceRecursiveCharacterTextSplitter,
|
||||
FixedRecursiveCharacterTextSplitter,
|
||||
|
|
@ -49,7 +50,7 @@ class BaseIndexProcessor(ABC):
|
|||
@abstractmethod
|
||||
def retrieve(
|
||||
self,
|
||||
retrieval_method: str,
|
||||
retrieval_method: RetrievalMethod,
|
||||
query: str,
|
||||
dataset: Dataset,
|
||||
top_k: int,
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ from core.rag.extractor.extract_processor import ExtractProcessor
|
|||
from core.rag.index_processor.constant.index_type import IndexType
|
||||
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
|
||||
from core.rag.models.document import Document
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.tools.utils.text_processing_utils import remove_leading_symbols
|
||||
from libs import helper
|
||||
from models.dataset import Dataset, DatasetProcessRule
|
||||
|
|
@ -106,7 +107,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
|||
|
||||
def retrieve(
|
||||
self,
|
||||
retrieval_method: str,
|
||||
retrieval_method: RetrievalMethod,
|
||||
query: str,
|
||||
dataset: Dataset,
|
||||
top_k: int,
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ from core.rag.extractor.extract_processor import ExtractProcessor
|
|||
from core.rag.index_processor.constant.index_type import IndexType
|
||||
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
|
||||
from core.rag.models.document import ChildDocument, Document, ParentChildStructureChunk
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from extensions.ext_database import db
|
||||
from libs import helper
|
||||
from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment
|
||||
|
|
@ -161,7 +162,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
|
|||
|
||||
def retrieve(
|
||||
self,
|
||||
retrieval_method: str,
|
||||
retrieval_method: RetrievalMethod,
|
||||
query: str,
|
||||
dataset: Dataset,
|
||||
top_k: int,
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ from core.rag.extractor.extract_processor import ExtractProcessor
|
|||
from core.rag.index_processor.constant.index_type import IndexType
|
||||
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
|
||||
from core.rag.models.document import Document, QAStructureChunk
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.tools.utils.text_processing_utils import remove_leading_symbols
|
||||
from libs import helper
|
||||
from models.dataset import Dataset
|
||||
|
|
@ -141,7 +142,7 @@ class QAIndexProcessor(BaseIndexProcessor):
|
|||
|
||||
def retrieve(
|
||||
self,
|
||||
retrieval_method: str,
|
||||
retrieval_method: RetrievalMethod,
|
||||
query: str,
|
||||
dataset: Dataset,
|
||||
top_k: int,
|
||||
|
|
|
|||
|
|
@ -364,7 +364,7 @@ class DatasetRetrieval:
|
|||
top_k = retrieval_model_config["top_k"]
|
||||
# get retrieval method
|
||||
if dataset.indexing_technique == "economy":
|
||||
retrieval_method = "keyword_search"
|
||||
retrieval_method = RetrievalMethod.KEYWORD_SEARCH
|
||||
else:
|
||||
retrieval_method = retrieval_model_config["search_method"]
|
||||
# get reranking model
|
||||
|
|
@ -623,7 +623,7 @@ class DatasetRetrieval:
|
|||
if dataset.indexing_technique == "economy":
|
||||
# use keyword table query
|
||||
documents = RetrievalService.retrieve(
|
||||
retrieval_method="keyword_search",
|
||||
retrieval_method=RetrievalMethod.KEYWORD_SEARCH,
|
||||
dataset_id=dataset.id,
|
||||
query=query,
|
||||
top_k=top_k,
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from enum import Enum
|
||||
from enum import StrEnum
|
||||
|
||||
|
||||
class RetrievalMethod(Enum):
|
||||
class RetrievalMethod(StrEnum):
|
||||
SEMANTIC_SEARCH = "semantic_search"
|
||||
FULL_TEXT_SEARCH = "full_text_search"
|
||||
HYBRID_SEARCH = "hybrid_search"
|
||||
|
|
|
|||
|
|
@ -172,7 +172,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
|
|||
if dataset.indexing_technique == "economy":
|
||||
# use keyword table query
|
||||
documents = RetrievalService.retrieve(
|
||||
retrieval_method="keyword_search",
|
||||
retrieval_method=RetrievalMethod.KEYWORD_SEARCH,
|
||||
dataset_id=dataset.id,
|
||||
query=query,
|
||||
top_k=retrieval_model.get("top_k") or 4,
|
||||
|
|
|
|||
|
|
@ -130,7 +130,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
|||
if dataset.indexing_technique == "economy":
|
||||
# use keyword table query
|
||||
documents = RetrievalService.retrieve(
|
||||
retrieval_method="keyword_search",
|
||||
retrieval_method=RetrievalMethod.KEYWORD_SEARCH,
|
||||
dataset_id=dataset.id,
|
||||
query=query,
|
||||
top_k=self.top_k,
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from enum import Enum, StrEnum
|
||||
from enum import StrEnum
|
||||
|
||||
|
||||
class NodeState(Enum):
|
||||
class NodeState(StrEnum):
|
||||
"""State of a node or edge during workflow execution."""
|
||||
|
||||
UNKNOWN = "unknown"
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ When limits are exceeded, the layer automatically aborts execution.
|
|||
|
||||
import logging
|
||||
import time
|
||||
from enum import Enum
|
||||
from enum import StrEnum
|
||||
from typing import final
|
||||
|
||||
from typing_extensions import override
|
||||
|
|
@ -24,7 +24,7 @@ from core.workflow.graph_events import (
|
|||
from core.workflow.graph_events.node import NodeRunFailedEvent, NodeRunSucceededEvent
|
||||
|
||||
|
||||
class LimitType(Enum):
|
||||
class LimitType(StrEnum):
|
||||
"""Types of execution limits that can be exceeded."""
|
||||
|
||||
STEP_LIMIT = "step_limit"
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ from typing import Literal, Union
|
|||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
|
||||
|
||||
|
|
@ -63,7 +64,7 @@ class RetrievalSetting(BaseModel):
|
|||
Retrieval Setting.
|
||||
"""
|
||||
|
||||
search_method: Literal["semantic_search", "keyword_search", "full_text_search", "hybrid_search"]
|
||||
search_method: RetrievalMethod
|
||||
top_k: int
|
||||
score_threshold: float | None = 0.5
|
||||
score_threshold_enabled: bool = False
|
||||
|
|
|
|||
|
|
@ -1470,7 +1470,7 @@ class DocumentService:
|
|||
dataset.collection_binding_id = dataset_collection_binding.id
|
||||
if not dataset.retrieval_model:
|
||||
default_retrieval_model = {
|
||||
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||
"search_method": RetrievalMethod.SEMANTIC_SEARCH,
|
||||
"reranking_enable": False,
|
||||
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
|
||||
"top_k": 4,
|
||||
|
|
@ -1752,7 +1752,7 @@ class DocumentService:
|
|||
# dataset.collection_binding_id = dataset_collection_binding.id
|
||||
# if not dataset.retrieval_model:
|
||||
# default_retrieval_model = {
|
||||
# "search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||
# "search_method": RetrievalMethod.SEMANTIC_SEARCH,
|
||||
# "reranking_enable": False,
|
||||
# "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
|
||||
# "top_k": 2,
|
||||
|
|
@ -2205,7 +2205,7 @@ class DocumentService:
|
|||
retrieval_model = knowledge_config.retrieval_model
|
||||
else:
|
||||
retrieval_model = RetrievalModel(
|
||||
search_method=RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||
search_method=RetrievalMethod.SEMANTIC_SEARCH,
|
||||
reranking_enable=False,
|
||||
reranking_model=RerankingModel(reranking_provider_name="", reranking_model_name=""),
|
||||
top_k=4,
|
||||
|
|
|
|||
|
|
@ -3,6 +3,8 @@ from typing import Literal
|
|||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
|
||||
|
||||
class ParentMode(StrEnum):
|
||||
FULL_DOC = "full-doc"
|
||||
|
|
@ -95,7 +97,7 @@ class WeightModel(BaseModel):
|
|||
|
||||
|
||||
class RetrievalModel(BaseModel):
|
||||
search_method: Literal["hybrid_search", "semantic_search", "full_text_search", "keyword_search"]
|
||||
search_method: RetrievalMethod
|
||||
reranking_enable: bool
|
||||
reranking_model: RerankingModel | None = None
|
||||
reranking_mode: str | None = None
|
||||
|
|
|
|||
|
|
@ -2,6 +2,8 @@ from typing import Literal
|
|||
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
|
||||
|
||||
class IconInfo(BaseModel):
|
||||
icon: str
|
||||
|
|
@ -83,7 +85,7 @@ class RetrievalSetting(BaseModel):
|
|||
Retrieval Setting.
|
||||
"""
|
||||
|
||||
search_method: Literal["semantic_search", "full_text_search", "keyword_search", "hybrid_search"]
|
||||
search_method: RetrievalMethod
|
||||
top_k: int
|
||||
score_threshold: float | None = 0.5
|
||||
score_threshold_enabled: bool = False
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from collections.abc import Sequence
|
||||
from enum import Enum
|
||||
from enum import StrEnum
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, model_validator
|
||||
|
||||
|
|
@ -27,7 +27,7 @@ from core.model_runtime.entities.provider_entities import (
|
|||
from models.provider import ProviderType
|
||||
|
||||
|
||||
class CustomConfigurationStatus(Enum):
|
||||
class CustomConfigurationStatus(StrEnum):
|
||||
"""
|
||||
Enum class for custom configuration status.
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -63,7 +63,7 @@ class HitTestingService:
|
|||
if metadata_condition and not document_ids_filter:
|
||||
return cls.compact_retrieve_response(query, [])
|
||||
all_documents = RetrievalService.retrieve(
|
||||
retrieval_method=retrieval_model.get("search_method", "semantic_search"),
|
||||
retrieval_method=RetrievalMethod(retrieval_model.get("search_method", RetrievalMethod.SEMANTIC_SEARCH)),
|
||||
dataset_id=dataset.id,
|
||||
query=query,
|
||||
top_k=retrieval_model.get("top_k", 4),
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from flask_login import current_user
|
|||
|
||||
from constants import DOCUMENT_EXTENSIONS
|
||||
from core.plugin.impl.plugin import PluginInstaller
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from extensions.ext_database import db
|
||||
from factories import variable_factory
|
||||
from models.dataset import Dataset, Document, DocumentPipelineExecutionLog, Pipeline
|
||||
|
|
@ -164,7 +165,7 @@ class RagPipelineTransformService:
|
|||
if retrieval_model:
|
||||
retrieval_setting = RetrievalSetting.model_validate(retrieval_model)
|
||||
if indexing_technique == "economy":
|
||||
retrieval_setting.search_method = "keyword_search"
|
||||
retrieval_setting.search_method = RetrievalMethod.KEYWORD_SEARCH
|
||||
knowledge_configuration.retrieval_model = retrieval_setting
|
||||
else:
|
||||
dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump()
|
||||
|
|
|
|||
|
|
@ -1,10 +1,12 @@
|
|||
import os
|
||||
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from core.rag.extractor.firecrawl.firecrawl_app import FirecrawlApp
|
||||
from tests.unit_tests.core.rag.extractor.test_notion_extractor import _mock_response
|
||||
|
||||
|
||||
def test_firecrawl_web_extractor_crawl_mode(mocker):
|
||||
def test_firecrawl_web_extractor_crawl_mode(mocker: MockerFixture):
|
||||
url = "https://firecrawl.dev"
|
||||
api_key = os.getenv("FIRECRAWL_API_KEY") or "fc-"
|
||||
base_url = "https://api.firecrawl.dev"
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
from unittest import mock
|
||||
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from core.rag.extractor import notion_extractor
|
||||
|
||||
user_id = "user1"
|
||||
|
|
@ -57,7 +59,7 @@ def _remove_multiple_new_lines(text):
|
|||
return text.strip()
|
||||
|
||||
|
||||
def test_notion_page(mocker):
|
||||
def test_notion_page(mocker: MockerFixture):
|
||||
texts = ["Head 1", "1.1", "paragraph 1", "1.1.1"]
|
||||
mocked_notion_page = {
|
||||
"object": "list",
|
||||
|
|
@ -77,7 +79,7 @@ def test_notion_page(mocker):
|
|||
assert content == "# Head 1\n## 1.1\nparagraph 1\n### 1.1.1"
|
||||
|
||||
|
||||
def test_notion_database(mocker):
|
||||
def test_notion_database(mocker: MockerFixture):
|
||||
page_title_list = ["page1", "page2", "page3"]
|
||||
mocked_notion_database = {
|
||||
"object": "list",
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ from unittest.mock import MagicMock, patch
|
|||
|
||||
import pytest
|
||||
import redis
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from core.entities.provider_entities import ModelLoadBalancingConfiguration
|
||||
from core.model_manager import LBModelManager
|
||||
|
|
@ -39,7 +40,7 @@ def lb_model_manager():
|
|||
return lb_model_manager
|
||||
|
||||
|
||||
def test_lb_model_manager_fetch_next(mocker, lb_model_manager):
|
||||
def test_lb_model_manager_fetch_next(mocker: MockerFixture, lb_model_manager: LBModelManager):
|
||||
# initialize redis client
|
||||
redis_client.initialize(redis.Redis())
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from core.entities.provider_entities import ModelSettings
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
|
|
@ -7,19 +8,25 @@ from models.provider import LoadBalancingModelConfig, ProviderModelSetting
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_provider_entity(mocker):
|
||||
def mock_provider_entity(mocker: MockerFixture):
|
||||
mock_entity = mocker.Mock()
|
||||
mock_entity.provider = "openai"
|
||||
mock_entity.configurate_methods = ["predefined-model"]
|
||||
mock_entity.supported_model_types = [ModelType.LLM]
|
||||
|
||||
mock_entity.model_credential_schema = mocker.Mock()
|
||||
mock_entity.model_credential_schema.credential_form_schemas = []
|
||||
# Use PropertyMock to ensure credential_form_schemas is iterable
|
||||
provider_credential_schema = mocker.Mock()
|
||||
type(provider_credential_schema).credential_form_schemas = mocker.PropertyMock(return_value=[])
|
||||
mock_entity.provider_credential_schema = provider_credential_schema
|
||||
|
||||
model_credential_schema = mocker.Mock()
|
||||
type(model_credential_schema).credential_form_schemas = mocker.PropertyMock(return_value=[])
|
||||
mock_entity.model_credential_schema = model_credential_schema
|
||||
|
||||
return mock_entity
|
||||
|
||||
|
||||
def test__to_model_settings(mocker, mock_provider_entity):
|
||||
def test__to_model_settings(mocker: MockerFixture, mock_provider_entity):
|
||||
# Mocking the inputs
|
||||
provider_model_settings = [
|
||||
ProviderModelSetting(
|
||||
|
|
@ -79,7 +86,7 @@ def test__to_model_settings(mocker, mock_provider_entity):
|
|||
assert result[0].load_balancing_configs[1].name == "first"
|
||||
|
||||
|
||||
def test__to_model_settings_only_one_lb(mocker, mock_provider_entity):
|
||||
def test__to_model_settings_only_one_lb(mocker: MockerFixture, mock_provider_entity):
|
||||
# Mocking the inputs
|
||||
provider_model_settings = [
|
||||
ProviderModelSetting(
|
||||
|
|
@ -127,7 +134,7 @@ def test__to_model_settings_only_one_lb(mocker, mock_provider_entity):
|
|||
assert len(result[0].load_balancing_configs) == 0
|
||||
|
||||
|
||||
def test__to_model_settings_lb_disabled(mocker, mock_provider_entity):
|
||||
def test__to_model_settings_lb_disabled(mocker: MockerFixture, mock_provider_entity):
|
||||
# Mocking the inputs
|
||||
provider_model_settings = [
|
||||
ProviderModelSetting(
|
||||
|
|
|
|||
Loading…
Reference in New Issue