mirror of https://github.com/langgenius/dify.git
fix RetrievalMethod StrEnum (#26768)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
parent d299e75e1b
commit 24cd7bbc62
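The change converts RetrievalMethod and several neighboring enums (DatasourceInvokeFrom, ConfigurateMethod, DatasourceStreamEvent, NodeState, LimitType, CustomConfigurationStatus) from plain Enum to StrEnum, and tightens retrieval_method parameters and search_method fields from bare strings or Literal[...] types to the enum itself. A minimal sketch of why StrEnum makes that safe (illustrative only; the class names below are placeholders, not the project's code):

from enum import Enum, StrEnum


class PlainMethod(Enum):  # hypothetical pre-fix shape
    KEYWORD_SEARCH = "keyword_search"


class StrMethod(StrEnum):  # shape after this commit
    KEYWORD_SEARCH = "keyword_search"


# A plain Enum member never equals its string value, so a comparison such as
# retrieval_method == "keyword_search" silently fails once an enum is passed in.
assert PlainMethod.KEYWORD_SEARCH != "keyword_search"

# A StrEnum member is a str, so equality checks, dict keys, and JSON
# serialization against the raw value keep working.
assert StrMethod.KEYWORD_SEARCH == "keyword_search"
assert str(StrMethod.KEYWORD_SEARCH) == "keyword_search"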
@@ -1,5 +1,5 @@
 import enum
-from enum import Enum
+from enum import StrEnum
 from typing import Any
 
 from pydantic import BaseModel, Field, ValidationInfo, field_validator
@@ -218,7 +218,7 @@ class DatasourceLabel(BaseModel):
     icon: str = Field(..., description="The icon of the tool")
 
 
-class DatasourceInvokeFrom(Enum):
+class DatasourceInvokeFrom(StrEnum):
     """
     Enum class for datasource invoke
     """
@@ -1,5 +1,5 @@
 from collections.abc import Sequence
-from enum import Enum, StrEnum, auto
+from enum import StrEnum, auto
 
 from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
 
@@ -7,7 +7,7 @@ from core.model_runtime.entities.common_entities import I18nObject
 from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
 
 
-class ConfigurateMethod(Enum):
+class ConfigurateMethod(StrEnum):
     """
     Enum class for configurate method of provider model.
     """
@@ -34,7 +34,7 @@ class RetrievalService:
     @classmethod
     def retrieve(
         cls,
-        retrieval_method: str,
+        retrieval_method: RetrievalMethod,
         dataset_id: str,
         query: str,
         top_k: int,
@@ -56,7 +56,7 @@ class RetrievalService:
         # Optimize multithreading with thread pools
         with ThreadPoolExecutor(max_workers=dify_config.RETRIEVAL_SERVICE_EXECUTORS) as executor:  # type: ignore
             futures = []
-            if retrieval_method == "keyword_search":
+            if retrieval_method == RetrievalMethod.KEYWORD_SEARCH:
                 futures.append(
                     executor.submit(
                         cls.keyword_search,
@@ -220,7 +220,7 @@ class RetrievalService:
         score_threshold: float | None,
         reranking_model: dict | None,
         all_documents: list,
-        retrieval_method: str,
+        retrieval_method: RetrievalMethod,
         exceptions: list,
         document_ids_filter: list[str] | None = None,
     ):
@@ -1,11 +1,11 @@
 from collections.abc import Mapping
-from enum import Enum
+from enum import StrEnum
 from typing import Any
 
 from pydantic import BaseModel, Field
 
 
-class DatasourceStreamEvent(Enum):
+class DatasourceStreamEvent(StrEnum):
     """
     Datasource Stream event
     """
@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, Optional
 from configs import dify_config
 from core.rag.extractor.entity.extract_setting import ExtractSetting
 from core.rag.models.document import Document
+from core.rag.retrieval.retrieval_methods import RetrievalMethod
 from core.rag.splitter.fixed_text_splitter import (
     EnhanceRecursiveCharacterTextSplitter,
     FixedRecursiveCharacterTextSplitter,
@@ -49,7 +50,7 @@ class BaseIndexProcessor(ABC):
     @abstractmethod
     def retrieve(
         self,
-        retrieval_method: str,
+        retrieval_method: RetrievalMethod,
         query: str,
         dataset: Dataset,
         top_k: int,
@@ -14,6 +14,7 @@ from core.rag.extractor.extract_processor import ExtractProcessor
 from core.rag.index_processor.constant.index_type import IndexType
 from core.rag.index_processor.index_processor_base import BaseIndexProcessor
 from core.rag.models.document import Document
+from core.rag.retrieval.retrieval_methods import RetrievalMethod
 from core.tools.utils.text_processing_utils import remove_leading_symbols
 from libs import helper
 from models.dataset import Dataset, DatasetProcessRule
@@ -106,7 +107,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
 
     def retrieve(
         self,
-        retrieval_method: str,
+        retrieval_method: RetrievalMethod,
         query: str,
         dataset: Dataset,
         top_k: int,
@@ -16,6 +16,7 @@ from core.rag.extractor.extract_processor import ExtractProcessor
 from core.rag.index_processor.constant.index_type import IndexType
 from core.rag.index_processor.index_processor_base import BaseIndexProcessor
 from core.rag.models.document import ChildDocument, Document, ParentChildStructureChunk
+from core.rag.retrieval.retrieval_methods import RetrievalMethod
 from extensions.ext_database import db
 from libs import helper
 from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment
@@ -161,7 +162,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
 
     def retrieve(
         self,
-        retrieval_method: str,
+        retrieval_method: RetrievalMethod,
         query: str,
         dataset: Dataset,
         top_k: int,
@@ -21,6 +21,7 @@ from core.rag.extractor.extract_processor import ExtractProcessor
 from core.rag.index_processor.constant.index_type import IndexType
 from core.rag.index_processor.index_processor_base import BaseIndexProcessor
 from core.rag.models.document import Document, QAStructureChunk
+from core.rag.retrieval.retrieval_methods import RetrievalMethod
 from core.tools.utils.text_processing_utils import remove_leading_symbols
 from libs import helper
 from models.dataset import Dataset
@@ -141,7 +142,7 @@ class QAIndexProcessor(BaseIndexProcessor):
 
     def retrieve(
         self,
-        retrieval_method: str,
+        retrieval_method: RetrievalMethod,
         query: str,
         dataset: Dataset,
         top_k: int,
@@ -364,7 +364,7 @@ class DatasetRetrieval:
             top_k = retrieval_model_config["top_k"]
             # get retrieval method
             if dataset.indexing_technique == "economy":
-                retrieval_method = "keyword_search"
+                retrieval_method = RetrievalMethod.KEYWORD_SEARCH
             else:
                 retrieval_method = retrieval_model_config["search_method"]
             # get reranking model
@@ -623,7 +623,7 @@ class DatasetRetrieval:
         if dataset.indexing_technique == "economy":
             # use keyword table query
             documents = RetrievalService.retrieve(
-                retrieval_method="keyword_search",
+                retrieval_method=RetrievalMethod.KEYWORD_SEARCH,
                 dataset_id=dataset.id,
                 query=query,
                 top_k=top_k,
@@ -1,7 +1,7 @@
-from enum import Enum
+from enum import StrEnum
 
 
-class RetrievalMethod(Enum):
+class RetrievalMethod(StrEnum):
     SEMANTIC_SEARCH = "semantic_search"
     FULL_TEXT_SEARCH = "full_text_search"
     HYBRID_SEARCH = "hybrid_search"
@@ -172,7 +172,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
             if dataset.indexing_technique == "economy":
                 # use keyword table query
                 documents = RetrievalService.retrieve(
-                    retrieval_method="keyword_search",
+                    retrieval_method=RetrievalMethod.KEYWORD_SEARCH,
                     dataset_id=dataset.id,
                     query=query,
                     top_k=retrieval_model.get("top_k") or 4,
@@ -130,7 +130,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
         if dataset.indexing_technique == "economy":
             # use keyword table query
             documents = RetrievalService.retrieve(
-                retrieval_method="keyword_search",
+                retrieval_method=RetrievalMethod.KEYWORD_SEARCH,
                 dataset_id=dataset.id,
                 query=query,
                 top_k=self.top_k,
@@ -1,7 +1,7 @@
-from enum import Enum, StrEnum
+from enum import StrEnum
 
 
-class NodeState(Enum):
+class NodeState(StrEnum):
     """State of a node or edge during workflow execution."""
 
     UNKNOWN = "unknown"
@@ -10,7 +10,7 @@ When limits are exceeded, the layer automatically aborts execution.
 
 import logging
 import time
-from enum import Enum
+from enum import StrEnum
 from typing import final
 
 from typing_extensions import override
@@ -24,7 +24,7 @@ from core.workflow.graph_events import (
 from core.workflow.graph_events.node import NodeRunFailedEvent, NodeRunSucceededEvent
 
 
-class LimitType(Enum):
+class LimitType(StrEnum):
     """Types of execution limits that can be exceeded."""
 
     STEP_LIMIT = "step_limit"
@@ -2,6 +2,7 @@ from typing import Literal, Union
 
 from pydantic import BaseModel
 
+from core.rag.retrieval.retrieval_methods import RetrievalMethod
 from core.workflow.nodes.base import BaseNodeData
 
 
@@ -63,7 +64,7 @@ class RetrievalSetting(BaseModel):
     Retrieval Setting.
     """
 
-    search_method: Literal["semantic_search", "keyword_search", "full_text_search", "hybrid_search"]
+    search_method: RetrievalMethod
    top_k: int
    score_threshold: float | None = 0.5
    score_threshold_enabled: bool = False
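Note on the RetrievalSetting change above: with search_method typed as RetrievalMethod rather than a Literal of strings, Pydantic coerces incoming string values into enum members during validation, so existing payloads keep working. A hedged sketch of that behavior (simplified stand-in model, not the project's class):

from enum import StrEnum

from pydantic import BaseModel


class RetrievalMethod(StrEnum):
    SEMANTIC_SEARCH = "semantic_search"
    KEYWORD_SEARCH = "keyword_search"


class Setting(BaseModel):  # simplified stand-in for RetrievalSetting
    search_method: RetrievalMethod
    top_k: int = 4


s = Setting(search_method="semantic_search")  # raw string from an API payload
assert s.search_method is RetrievalMethod.SEMANTIC_SEARCH
# JSON-mode dumps still emit the plain string value.
assert s.model_dump(mode="json")["search_method"] == "semantic_search"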
@@ -1470,7 +1470,7 @@ class DocumentService:
             dataset.collection_binding_id = dataset_collection_binding.id
             if not dataset.retrieval_model:
                 default_retrieval_model = {
-                    "search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
+                    "search_method": RetrievalMethod.SEMANTIC_SEARCH,
                     "reranking_enable": False,
                     "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
                     "top_k": 4,
@@ -1752,7 +1752,7 @@ class DocumentService:
         #     dataset.collection_binding_id = dataset_collection_binding.id
         #     if not dataset.retrieval_model:
         #         default_retrieval_model = {
-        #             "search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
+        #             "search_method": RetrievalMethod.SEMANTIC_SEARCH,
         #             "reranking_enable": False,
         #             "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
         #             "top_k": 2,
@@ -2205,7 +2205,7 @@ class DocumentService:
             retrieval_model = knowledge_config.retrieval_model
         else:
             retrieval_model = RetrievalModel(
-                search_method=RetrievalMethod.SEMANTIC_SEARCH.value,
+                search_method=RetrievalMethod.SEMANTIC_SEARCH,
                 reranking_enable=False,
                 reranking_model=RerankingModel(reranking_provider_name="", reranking_model_name=""),
                 top_k=4,
@@ -3,6 +3,8 @@ from typing import Literal
 
 from pydantic import BaseModel
 
+from core.rag.retrieval.retrieval_methods import RetrievalMethod
+
 
 class ParentMode(StrEnum):
     FULL_DOC = "full-doc"
@@ -95,7 +97,7 @@ class WeightModel(BaseModel):
 
 
 class RetrievalModel(BaseModel):
-    search_method: Literal["hybrid_search", "semantic_search", "full_text_search", "keyword_search"]
+    search_method: RetrievalMethod
     reranking_enable: bool
     reranking_model: RerankingModel | None = None
     reranking_mode: str | None = None
@@ -2,6 +2,8 @@ from typing import Literal
 
 from pydantic import BaseModel, field_validator
 
+from core.rag.retrieval.retrieval_methods import RetrievalMethod
+
 
 class IconInfo(BaseModel):
     icon: str
@@ -83,7 +85,7 @@ class RetrievalSetting(BaseModel):
     Retrieval Setting.
     """
 
-    search_method: Literal["semantic_search", "full_text_search", "keyword_search", "hybrid_search"]
+    search_method: RetrievalMethod
     top_k: int
     score_threshold: float | None = 0.5
     score_threshold_enabled: bool = False
@@ -1,5 +1,5 @@
 from collections.abc import Sequence
-from enum import Enum
+from enum import StrEnum
 
 from pydantic import BaseModel, ConfigDict, model_validator
 
@@ -27,7 +27,7 @@ from core.model_runtime.entities.provider_entities import (
 from models.provider import ProviderType
 
 
-class CustomConfigurationStatus(Enum):
+class CustomConfigurationStatus(StrEnum):
     """
     Enum class for custom configuration status.
     """
@@ -63,7 +63,7 @@ class HitTestingService:
         if metadata_condition and not document_ids_filter:
             return cls.compact_retrieve_response(query, [])
         all_documents = RetrievalService.retrieve(
-            retrieval_method=retrieval_model.get("search_method", "semantic_search"),
+            retrieval_method=RetrievalMethod(retrieval_model.get("search_method", RetrievalMethod.SEMANTIC_SEARCH)),
             dataset_id=dataset.id,
             query=query,
             top_k=retrieval_model.get("top_k", 4),
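The hit-testing call now passes the stored search_method through the RetrievalMethod constructor, which accepts either a raw string or an existing member and rejects unknown values. A short illustrative sketch (standalone, assuming a KEYWORD_SEARCH member as used elsewhere in this diff):

from enum import StrEnum


class RetrievalMethod(StrEnum):
    SEMANTIC_SEARCH = "semantic_search"
    KEYWORD_SEARCH = "keyword_search"


stored = {"search_method": "keyword_search"}  # e.g. a persisted retrieval_model dict
assert RetrievalMethod(stored.get("search_method", RetrievalMethod.SEMANTIC_SEARCH)) is RetrievalMethod.KEYWORD_SEARCH

# Passing a member through the constructor is a no-op, so the enum default also works.
assert RetrievalMethod(RetrievalMethod.SEMANTIC_SEARCH) is RetrievalMethod.SEMANTIC_SEARCH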
@@ -9,6 +9,7 @@ from flask_login import current_user
 
 from constants import DOCUMENT_EXTENSIONS
 from core.plugin.impl.plugin import PluginInstaller
+from core.rag.retrieval.retrieval_methods import RetrievalMethod
 from extensions.ext_database import db
 from factories import variable_factory
 from models.dataset import Dataset, Document, DocumentPipelineExecutionLog, Pipeline
@@ -164,7 +165,7 @@ class RagPipelineTransformService:
         if retrieval_model:
             retrieval_setting = RetrievalSetting.model_validate(retrieval_model)
             if indexing_technique == "economy":
-                retrieval_setting.search_method = "keyword_search"
+                retrieval_setting.search_method = RetrievalMethod.KEYWORD_SEARCH
             knowledge_configuration.retrieval_model = retrieval_setting
         else:
             dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump()
@@ -1,10 +1,12 @@
 import os
 
+from pytest_mock import MockerFixture
+
 from core.rag.extractor.firecrawl.firecrawl_app import FirecrawlApp
 from tests.unit_tests.core.rag.extractor.test_notion_extractor import _mock_response
 
 
-def test_firecrawl_web_extractor_crawl_mode(mocker):
+def test_firecrawl_web_extractor_crawl_mode(mocker: MockerFixture):
     url = "https://firecrawl.dev"
     api_key = os.getenv("FIRECRAWL_API_KEY") or "fc-"
     base_url = "https://api.firecrawl.dev"
@@ -1,5 +1,7 @@
 from unittest import mock
 
+from pytest_mock import MockerFixture
+
 from core.rag.extractor import notion_extractor
 
 user_id = "user1"
@@ -57,7 +59,7 @@ def _remove_multiple_new_lines(text):
     return text.strip()
 
 
-def test_notion_page(mocker):
+def test_notion_page(mocker: MockerFixture):
     texts = ["Head 1", "1.1", "paragraph 1", "1.1.1"]
     mocked_notion_page = {
         "object": "list",
@@ -77,7 +79,7 @@ def test_notion_page(mocker):
     assert content == "# Head 1\n## 1.1\nparagraph 1\n### 1.1.1"
 
 
-def test_notion_database(mocker):
+def test_notion_database(mocker: MockerFixture):
     page_title_list = ["page1", "page2", "page3"]
     mocked_notion_database = {
         "object": "list",
@@ -2,6 +2,7 @@ from unittest.mock import MagicMock, patch
 
 import pytest
 import redis
+from pytest_mock import MockerFixture
 
 from core.entities.provider_entities import ModelLoadBalancingConfiguration
 from core.model_manager import LBModelManager
@@ -39,7 +40,7 @@ def lb_model_manager():
     return lb_model_manager
 
 
-def test_lb_model_manager_fetch_next(mocker, lb_model_manager):
+def test_lb_model_manager_fetch_next(mocker: MockerFixture, lb_model_manager: LBModelManager):
     # initialize redis client
     redis_client.initialize(redis.Redis())
 
@@ -1,4 +1,5 @@
 import pytest
+from pytest_mock import MockerFixture
 
 from core.entities.provider_entities import ModelSettings
 from core.model_runtime.entities.model_entities import ModelType
@@ -7,19 +8,25 @@ from models.provider import LoadBalancingModelConfig, ProviderModelSetting
 
 
 @pytest.fixture
-def mock_provider_entity(mocker):
+def mock_provider_entity(mocker: MockerFixture):
     mock_entity = mocker.Mock()
     mock_entity.provider = "openai"
     mock_entity.configurate_methods = ["predefined-model"]
     mock_entity.supported_model_types = [ModelType.LLM]
 
-    mock_entity.model_credential_schema = mocker.Mock()
-    mock_entity.model_credential_schema.credential_form_schemas = []
+    # Use PropertyMock to ensure credential_form_schemas is iterable
+    provider_credential_schema = mocker.Mock()
+    type(provider_credential_schema).credential_form_schemas = mocker.PropertyMock(return_value=[])
+    mock_entity.provider_credential_schema = provider_credential_schema
+
+    model_credential_schema = mocker.Mock()
+    type(model_credential_schema).credential_form_schemas = mocker.PropertyMock(return_value=[])
+    mock_entity.model_credential_schema = model_credential_schema
 
     return mock_entity
 
 
-def test__to_model_settings(mocker, mock_provider_entity):
+def test__to_model_settings(mocker: MockerFixture, mock_provider_entity):
     # Mocking the inputs
     provider_model_settings = [
         ProviderModelSetting(
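The fixture rewrite above matches its added comment: attributes looked up on a bare Mock are themselves Mocks, which breaks code that iterates over credential_form_schemas, while a PropertyMock attached to the mock's type resolves to a real list. A small standalone sketch of the pattern (not the project's fixture):

from unittest import mock

plain = mock.Mock()
# A bare Mock attribute is another Mock, so iterating over it raises TypeError.
try:
    list(plain.credential_form_schemas)
except TypeError:
    pass

schema = mock.Mock()
# Attaching a PropertyMock to the instance's type makes the attribute a real list.
type(schema).credential_form_schemas = mock.PropertyMock(return_value=[])
assert list(schema.credential_form_schemas) == []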
@@ -79,7 +86,7 @@ def test__to_model_settings(mocker, mock_provider_entity):
     assert result[0].load_balancing_configs[1].name == "first"
 
 
-def test__to_model_settings_only_one_lb(mocker, mock_provider_entity):
+def test__to_model_settings_only_one_lb(mocker: MockerFixture, mock_provider_entity):
     # Mocking the inputs
     provider_model_settings = [
         ProviderModelSetting(
@@ -127,7 +134,7 @@ def test__to_model_settings_only_one_lb(mocker, mock_provider_entity):
     assert len(result[0].load_balancing_configs) == 0
 
 
-def test__to_model_settings_lb_disabled(mocker, mock_provider_entity):
+def test__to_model_settings_lb_disabled(mocker: MockerFixture, mock_provider_entity):
     # Mocking the inputs
     provider_model_settings = [
         ProviderModelSetting(