mirror of
https://github.com/langgenius/dify.git
synced 2026-04-15 18:06:36 +08:00
refactor(api): migrate consumers to shared RAG domain entities from core/rag/entities/ (#34692)
This commit is contained in:
parent
cb55176612
commit
80a7843f45
@ -1,4 +1,3 @@
|
||||
from collections.abc import Sequence
|
||||
from enum import StrEnum, auto
|
||||
from typing import Any, Literal
|
||||
|
||||
@ -9,6 +8,7 @@ from graphon.variables.input_entities import VariableEntity as WorkflowVariableE
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict
|
||||
from core.rag.entities import MetadataFilteringCondition
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
@ -111,31 +111,6 @@ class ExternalDataVariableEntity(BaseModel):
|
||||
config: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
SupportedComparisonOperator = Literal[
|
||||
# for string or array
|
||||
"contains",
|
||||
"not contains",
|
||||
"start with",
|
||||
"end with",
|
||||
"is",
|
||||
"is not",
|
||||
"empty",
|
||||
"not empty",
|
||||
"in",
|
||||
"not in",
|
||||
# for number
|
||||
"=",
|
||||
"≠",
|
||||
">",
|
||||
"<",
|
||||
"≥",
|
||||
"≤",
|
||||
# for time
|
||||
"before",
|
||||
"after",
|
||||
]
|
||||
|
||||
|
||||
class ModelConfig(BaseModel):
|
||||
provider: str
|
||||
name: str
|
||||
@ -143,25 +118,6 @@ class ModelConfig(BaseModel):
|
||||
completion_params: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class Condition(BaseModel):
|
||||
"""
|
||||
Condition detail
|
||||
"""
|
||||
|
||||
name: str
|
||||
comparison_operator: SupportedComparisonOperator
|
||||
value: str | Sequence[str] | None | int | float = None
|
||||
|
||||
|
||||
class MetadataFilteringCondition(BaseModel):
|
||||
"""
|
||||
Metadata Filtering Condition.
|
||||
"""
|
||||
|
||||
logical_operator: Literal["and", "or"] | None = "and"
|
||||
conditions: list[Condition] | None = Field(default=None, deprecated=True)
|
||||
|
||||
|
||||
class DatasetRetrieveConfigEntity(BaseModel):
|
||||
"""
|
||||
Dataset Retrieve Config Entity.
|
||||
|
||||
@ -15,7 +15,7 @@ from core.rag.data_post_processor.data_post_processor import DataPostProcessor,
|
||||
from core.rag.datasource.keyword.keyword_factory import Keyword
|
||||
from core.rag.datasource.vdb.vector_factory import Vector
|
||||
from core.rag.embedding.retrieval import AttachmentInfoDict, RetrievalChildChunk, RetrievalSegments
|
||||
from core.rag.entities.metadata_entities import MetadataCondition
|
||||
from core.rag.entities import MetadataFilteringCondition
|
||||
from core.rag.index_processor.constant.doc_type import DocType
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from core.rag.index_processor.constant.query_type import QueryType
|
||||
@ -182,7 +182,9 @@ class RetrievalService:
|
||||
if not dataset:
|
||||
return []
|
||||
metadata_condition = (
|
||||
MetadataCondition.model_validate(metadata_filtering_conditions) if metadata_filtering_conditions else None
|
||||
MetadataFilteringCondition.model_validate(metadata_filtering_conditions)
|
||||
if metadata_filtering_conditions
|
||||
else None
|
||||
)
|
||||
all_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
|
||||
dataset.tenant_id,
|
||||
|
||||
@ -1,15 +1,19 @@
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.rag.entities.context_entities import DocumentContext
|
||||
from core.rag.entities.metadata_entities import Condition, MetadataFilteringCondition, SupportedComparisonOperator
|
||||
from core.rag.entities.processing_entities import ParentMode, PreProcessingRule, Rule, Segmentation
|
||||
from core.rag.entities.retrieval_settings import KeywordSetting, VectorSetting
|
||||
|
||||
__all__ = [
|
||||
"Condition",
|
||||
"DocumentContext",
|
||||
"KeywordSetting",
|
||||
"MetadataFilteringCondition",
|
||||
"ParentMode",
|
||||
"PreProcessingRule",
|
||||
"RetrievalSourceMetadata",
|
||||
"Rule",
|
||||
"Segmentation",
|
||||
"SupportedComparisonOperator",
|
||||
"VectorSetting",
|
||||
]
|
||||
|
||||
@ -38,9 +38,9 @@ class Condition(BaseModel):
|
||||
value: str | Sequence[str] | None | int | float = None
|
||||
|
||||
|
||||
class MetadataCondition(BaseModel):
|
||||
class MetadataFilteringCondition(BaseModel):
|
||||
"""
|
||||
Metadata Condition.
|
||||
Metadata Filtering Condition.
|
||||
"""
|
||||
|
||||
logical_operator: Literal["and", "or"] | None = "and"
|
||||
|
||||
@ -1,16 +1,6 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class VectorSetting(BaseModel):
|
||||
vector_weight: float
|
||||
|
||||
embedding_provider_name: str
|
||||
|
||||
embedding_model_name: str
|
||||
|
||||
|
||||
class KeywordSetting(BaseModel):
|
||||
keyword_weight: float
|
||||
from core.rag.entities import KeywordSetting, VectorSetting
|
||||
|
||||
|
||||
class Weights(BaseModel):
|
||||
|
||||
@ -39,9 +39,9 @@ from core.prompt.simple_prompt_transform import ModelMode
|
||||
from core.rag.data_post_processor.data_post_processor import DataPostProcessor, RerankingModelDict, WeightsDict
|
||||
from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
|
||||
from core.rag.datasource.retrieval_service import DefaultRetrievalModelDict, RetrievalService
|
||||
from core.rag.entities import Condition
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.rag.entities.context_entities import DocumentContext
|
||||
from core.rag.entities.metadata_entities import Condition, MetadataCondition
|
||||
from core.rag.index_processor.constant.doc_type import DocType
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
|
||||
from core.rag.index_processor.constant.query_type import QueryType
|
||||
@ -604,7 +604,7 @@ class DatasetRetrieval:
|
||||
planning_strategy: PlanningStrategy,
|
||||
message_id: str | None = None,
|
||||
metadata_filter_document_ids: dict[str, list[str]] | None = None,
|
||||
metadata_condition: MetadataCondition | None = None,
|
||||
metadata_condition: MetadataFilteringCondition | None = None,
|
||||
):
|
||||
tools = []
|
||||
for dataset in available_datasets:
|
||||
@ -743,7 +743,7 @@ class DatasetRetrieval:
|
||||
reranking_enable: bool = True,
|
||||
message_id: str | None = None,
|
||||
metadata_filter_document_ids: dict[str, list[str]] | None = None,
|
||||
metadata_condition: MetadataCondition | None = None,
|
||||
metadata_condition: MetadataFilteringCondition | None = None,
|
||||
attachment_ids: list[str] | None = None,
|
||||
):
|
||||
if not available_datasets:
|
||||
@ -1063,7 +1063,7 @@ class DatasetRetrieval:
|
||||
top_k: int,
|
||||
all_documents: list[Document],
|
||||
document_ids_filter: list[str] | None = None,
|
||||
metadata_condition: MetadataCondition | None = None,
|
||||
metadata_condition: MetadataFilteringCondition | None = None,
|
||||
attachment_ids: list[str] | None = None,
|
||||
):
|
||||
with flask_app.app_context():
|
||||
@ -1339,7 +1339,7 @@ class DatasetRetrieval:
|
||||
metadata_model_config: ModelConfig,
|
||||
metadata_filtering_conditions: MetadataFilteringCondition | None,
|
||||
inputs: dict,
|
||||
) -> tuple[dict[str, list[str]] | None, MetadataCondition | None]:
|
||||
) -> tuple[dict[str, list[str]] | None, MetadataFilteringCondition | None]:
|
||||
document_query = select(DatasetDocument).where(
|
||||
DatasetDocument.dataset_id.in_(dataset_ids),
|
||||
DatasetDocument.indexing_status == "completed",
|
||||
@ -1371,7 +1371,7 @@ class DatasetRetrieval:
|
||||
value=filter.get("value"),
|
||||
)
|
||||
)
|
||||
metadata_condition = MetadataCondition(
|
||||
metadata_condition = MetadataFilteringCondition(
|
||||
logical_operator=metadata_filtering_conditions.logical_operator
|
||||
if metadata_filtering_conditions
|
||||
else "or", # type: ignore
|
||||
@ -1400,7 +1400,7 @@ class DatasetRetrieval:
|
||||
expected_value,
|
||||
filters,
|
||||
)
|
||||
metadata_condition = MetadataCondition(
|
||||
metadata_condition = MetadataFilteringCondition(
|
||||
logical_operator=metadata_filtering_conditions.logical_operator,
|
||||
conditions=conditions,
|
||||
)
|
||||
@ -1723,7 +1723,7 @@ class DatasetRetrieval:
|
||||
self,
|
||||
flask_app: Flask,
|
||||
available_datasets: list[Dataset],
|
||||
metadata_condition: MetadataCondition | None,
|
||||
metadata_condition: MetadataFilteringCondition | None,
|
||||
metadata_filter_document_ids: dict[str, list[str]] | None,
|
||||
all_documents: list[Document],
|
||||
tenant_id: str,
|
||||
|
||||
@ -4,6 +4,7 @@ from graphon.entities.base_node_data import BaseNodeData
|
||||
from graphon.enums import NodeType
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.rag.entities import KeywordSetting, VectorSetting
|
||||
from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE
|
||||
@ -18,24 +19,6 @@ class RerankingModelConfig(BaseModel):
|
||||
reranking_model_name: str
|
||||
|
||||
|
||||
class VectorSetting(BaseModel):
|
||||
"""
|
||||
Vector Setting.
|
||||
"""
|
||||
|
||||
vector_weight: float
|
||||
embedding_provider_name: str
|
||||
embedding_model_name: str
|
||||
|
||||
|
||||
class KeywordSetting(BaseModel):
|
||||
"""
|
||||
Keyword Setting.
|
||||
"""
|
||||
|
||||
keyword_weight: float
|
||||
|
||||
|
||||
class WeightedScoreConfig(BaseModel):
|
||||
"""
|
||||
Weighted score Config.
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
from collections.abc import Sequence
|
||||
from typing import Literal
|
||||
|
||||
from graphon.entities.base_node_data import BaseNodeData
|
||||
@ -6,6 +5,10 @@ from graphon.enums import BuiltinNodeTypes, NodeType
|
||||
from graphon.nodes.llm.entities import ModelConfig, VisionConfig
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.rag.entities import Condition, KeywordSetting, MetadataFilteringCondition, VectorSetting
|
||||
|
||||
__all__ = ["Condition"]
|
||||
|
||||
|
||||
class RerankingModelConfig(BaseModel):
|
||||
"""
|
||||
@ -16,24 +19,6 @@ class RerankingModelConfig(BaseModel):
|
||||
model: str
|
||||
|
||||
|
||||
class VectorSetting(BaseModel):
|
||||
"""
|
||||
Vector Setting.
|
||||
"""
|
||||
|
||||
vector_weight: float
|
||||
embedding_provider_name: str
|
||||
embedding_model_name: str
|
||||
|
||||
|
||||
class KeywordSetting(BaseModel):
|
||||
"""
|
||||
Keyword Setting.
|
||||
"""
|
||||
|
||||
keyword_weight: float
|
||||
|
||||
|
||||
class WeightedScoreConfig(BaseModel):
|
||||
"""
|
||||
Weighted score Config.
|
||||
@ -64,50 +49,6 @@ class SingleRetrievalConfig(BaseModel):
|
||||
model: ModelConfig
|
||||
|
||||
|
||||
SupportedComparisonOperator = Literal[
|
||||
# for string or array
|
||||
"contains",
|
||||
"not contains",
|
||||
"start with",
|
||||
"end with",
|
||||
"is",
|
||||
"is not",
|
||||
"empty",
|
||||
"not empty",
|
||||
"in",
|
||||
"not in",
|
||||
# for number
|
||||
"=",
|
||||
"≠",
|
||||
">",
|
||||
"<",
|
||||
"≥",
|
||||
"≤",
|
||||
# for time
|
||||
"before",
|
||||
"after",
|
||||
]
|
||||
|
||||
|
||||
class Condition(BaseModel):
|
||||
"""
|
||||
Condition detail
|
||||
"""
|
||||
|
||||
name: str
|
||||
comparison_operator: SupportedComparisonOperator
|
||||
value: str | Sequence[str] | None | int | float = None
|
||||
|
||||
|
||||
class MetadataFilteringCondition(BaseModel):
|
||||
"""
|
||||
Metadata Filtering Condition.
|
||||
"""
|
||||
|
||||
logical_operator: Literal["and", "or"] | None = "and"
|
||||
conditions: list[Condition] | None = Field(default=None, deprecated=True)
|
||||
|
||||
|
||||
class KnowledgeRetrievalNodeData(BaseNodeData):
|
||||
"""
|
||||
Knowledge retrieval Node Data.
|
||||
|
||||
@ -9,7 +9,7 @@ from sqlalchemy import func, select
|
||||
|
||||
from constants import HIDDEN_VALUE
|
||||
from core.helper import ssrf_proxy
|
||||
from core.rag.entities.metadata_entities import MetadataCondition
|
||||
from core.rag.entities import MetadataFilteringCondition
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.dataset import (
|
||||
@ -302,7 +302,7 @@ class ExternalDatasetService:
|
||||
dataset_id: str,
|
||||
query: str,
|
||||
external_retrieval_parameters: dict,
|
||||
metadata_condition: MetadataCondition | None = None,
|
||||
metadata_condition: MetadataFilteringCondition | None = None,
|
||||
):
|
||||
external_knowledge_binding = db.session.scalar(
|
||||
select(ExternalKnowledgeBindings)
|
||||
|
||||
@ -236,7 +236,7 @@ class TestRetrievalServiceInternals:
|
||||
assert mock_retrieve.call_count == 2
|
||||
|
||||
@patch("core.rag.datasource.retrieval_service.ExternalDatasetService.fetch_external_knowledge_retrieval")
|
||||
@patch("core.rag.datasource.retrieval_service.MetadataCondition.model_validate")
|
||||
@patch("core.rag.datasource.retrieval_service.MetadataFilteringCondition.model_validate")
|
||||
@patch("core.rag.datasource.retrieval_service.db.session.scalar")
|
||||
def test_external_retrieve_with_metadata_conditions(self, mock_scalar, mock_validate, mock_fetch):
|
||||
mock_scalar.return_value = SimpleNamespace(tenant_id="tenant-1")
|
||||
|
||||
@ -10,9 +10,6 @@ from graphon.model_runtime.entities.llm_entities import LLMUsage
|
||||
from graphon.model_runtime.entities.model_entities import ModelFeature
|
||||
from sqlalchemy import column
|
||||
|
||||
from core.app.app_config.entities import (
|
||||
Condition as AppCondition,
|
||||
)
|
||||
from core.app.app_config.entities import (
|
||||
DatasetEntity,
|
||||
DatasetRetrieveConfigEntity,
|
||||
@ -29,6 +26,7 @@ from core.entities.agent_entities import PlanningStrategy
|
||||
from core.entities.model_entities import ModelStatus
|
||||
from core.rag.data_post_processor.data_post_processor import WeightsDict
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.entities import Condition as AppCondition
|
||||
from core.rag.index_processor.constant.doc_type import DocType
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from core.rag.models.document import Document
|
||||
|
||||
Loading…
Reference in New Issue
Block a user