mirror of
https://github.com/langgenius/dify.git
synced 2026-05-12 07:37:09 +08:00
refactor: port DatasetProcessRule (#31004)
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
bf117dd0c8
commit
1a011dc14a
@ -324,9 +324,10 @@ class IndexingRunner:
|
||||
# one extract_setting is one source document
|
||||
for extract_setting in extract_settings:
|
||||
# extract
|
||||
processing_rule = DatasetProcessRule(
|
||||
mode=tmp_processing_rule["mode"], rules=json.dumps(tmp_processing_rule["rules"])
|
||||
)
|
||||
processing_rule = {
|
||||
"mode": tmp_processing_rule["mode"],
|
||||
"rules": tmp_processing_rule.get("rules"),
|
||||
}
|
||||
# Extract document content
|
||||
text_docs = index_processor.extract(extract_setting, process_rule_mode=tmp_processing_rule["mode"])
|
||||
# Cleaning and segmentation
|
||||
@ -334,7 +335,7 @@ class IndexingRunner:
|
||||
text_docs,
|
||||
current_user=None,
|
||||
embedding_model_instance=embedding_model_instance,
|
||||
process_rule=processing_rule.to_dict(),
|
||||
process_rule=processing_rule,
|
||||
tenant_id=tenant_id,
|
||||
doc_language=doc_language,
|
||||
preview=True,
|
||||
|
||||
@ -29,6 +29,7 @@ from libs import helper
|
||||
from models import Account
|
||||
from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment
|
||||
from models.dataset import Document as DatasetDocument
|
||||
from models.enums import ProcessRuleMode
|
||||
from services.account_service import AccountService
|
||||
from services.summary_index_service import SummaryIndexService
|
||||
|
||||
@ -325,7 +326,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
|
||||
# update document parent mode
|
||||
dataset_process_rule = DatasetProcessRule(
|
||||
dataset_id=dataset.id,
|
||||
mode="hierarchical",
|
||||
mode=ProcessRuleMode.HIERARCHICAL,
|
||||
rules=json.dumps(
|
||||
{
|
||||
"parent_mode": parent_childs.parent_mode,
|
||||
|
||||
@ -11,7 +11,7 @@ import time
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from json import JSONDecodeError
|
||||
from typing import Any, TypedDict, cast
|
||||
from typing import Any, ClassVar, TypedDict, cast
|
||||
from uuid import uuid4
|
||||
|
||||
import sqlalchemy as sa
|
||||
@ -441,23 +441,27 @@ class Dataset(Base):
|
||||
return f"{dify_config.VECTOR_INDEX_NAME_PREFIX}_{normalized_dataset_id}_Node"
|
||||
|
||||
|
||||
class DatasetProcessRule(Base): # bug
|
||||
class DatasetProcessRule(TypeBase):
|
||||
__tablename__ = "dataset_process_rules"
|
||||
__table_args__ = (
|
||||
sa.PrimaryKeyConstraint("id", name="dataset_process_rule_pkey"),
|
||||
sa.Index("dataset_process_rule_dataset_id_idx", "dataset_id"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, nullable=False, default=lambda: str(uuid4()))
|
||||
dataset_id = mapped_column(StringUUID, nullable=False)
|
||||
mode = mapped_column(EnumText(ProcessRuleMode, length=255), nullable=False, server_default=sa.text("'automatic'"))
|
||||
rules = mapped_column(LongText, nullable=True)
|
||||
created_by = mapped_column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
id: Mapped[str] = mapped_column(StringUUID, nullable=False, default_factory=lambda: str(uuid4()), init=False)
|
||||
dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
mode: Mapped[ProcessRuleMode] = mapped_column(
|
||||
EnumText(ProcessRuleMode, length=255), nullable=False, server_default=sa.text("'automatic'")
|
||||
)
|
||||
rules: Mapped[str | None] = mapped_column(LongText, nullable=True)
|
||||
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
|
||||
)
|
||||
|
||||
MODES = ["automatic", "custom", "hierarchical"]
|
||||
PRE_PROCESSING_RULES = ["remove_stopwords", "remove_extra_spaces", "remove_urls_emails"]
|
||||
AUTOMATIC_RULES: AutomaticRulesConfig = {
|
||||
AUTOMATIC_RULES: ClassVar[AutomaticRulesConfig] = {
|
||||
"pre_processing_rules": [
|
||||
{"id": "remove_extra_spaces", "enabled": True},
|
||||
{"id": "remove_urls_emails", "enabled": False},
|
||||
|
||||
@ -108,7 +108,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ProcessRulesDict(TypedDict):
|
||||
mode: str
|
||||
mode: ProcessRuleMode
|
||||
rules: dict[str, Any]
|
||||
|
||||
|
||||
@ -204,7 +204,7 @@ class DatasetService:
|
||||
mode = dataset_process_rule.mode
|
||||
rules = dataset_process_rule.rules_dict or {}
|
||||
else:
|
||||
mode = str(DocumentService.DEFAULT_RULES["mode"])
|
||||
mode = ProcessRuleMode(DocumentService.DEFAULT_RULES["mode"])
|
||||
rules = dict(DocumentService.DEFAULT_RULES.get("rules") or {})
|
||||
return {"mode": mode, "rules": rules}
|
||||
|
||||
@ -1984,7 +1984,7 @@ class DocumentService:
|
||||
if process_rule.rules:
|
||||
dataset_process_rule = DatasetProcessRule(
|
||||
dataset_id=dataset.id,
|
||||
mode=process_rule.mode,
|
||||
mode=ProcessRuleMode(process_rule.mode),
|
||||
rules=process_rule.rules.model_dump_json() if process_rule.rules else None,
|
||||
created_by=account.id,
|
||||
)
|
||||
@ -1995,7 +1995,7 @@ class DocumentService:
|
||||
elif process_rule.mode == ProcessRuleMode.AUTOMATIC:
|
||||
dataset_process_rule = DatasetProcessRule(
|
||||
dataset_id=dataset.id,
|
||||
mode=process_rule.mode,
|
||||
mode=ProcessRuleMode.AUTOMATIC,
|
||||
rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES),
|
||||
created_by=account.id,
|
||||
)
|
||||
@ -2572,14 +2572,14 @@ class DocumentService:
|
||||
if process_rule.mode in {ProcessRuleMode.CUSTOM, ProcessRuleMode.HIERARCHICAL}:
|
||||
dataset_process_rule = DatasetProcessRule(
|
||||
dataset_id=dataset.id,
|
||||
mode=process_rule.mode,
|
||||
mode=ProcessRuleMode(process_rule.mode),
|
||||
rules=process_rule.rules.model_dump_json() if process_rule.rules else None,
|
||||
created_by=account.id,
|
||||
)
|
||||
elif process_rule.mode == ProcessRuleMode.AUTOMATIC:
|
||||
dataset_process_rule = DatasetProcessRule(
|
||||
dataset_id=dataset.id,
|
||||
mode=process_rule.mode,
|
||||
mode=ProcessRuleMode.AUTOMATIC,
|
||||
rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES),
|
||||
created_by=account.id,
|
||||
)
|
||||
|
||||
@ -16,6 +16,7 @@ from uuid import uuid4
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.rag.index_processor.constant.index_type import IndexTechniqueType
|
||||
from models import AccountStatus, CreatorUserRole, TenantStatus
|
||||
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.dataset import (
|
||||
AppDatasetJoin,
|
||||
@ -25,7 +26,7 @@ from models.dataset import (
|
||||
DatasetProcessRule,
|
||||
DatasetQuery,
|
||||
)
|
||||
from models.enums import DatasetQuerySource, DataSourceType, ProcessRuleMode
|
||||
from models.enums import DatasetQuerySource, DataSourceType, ProcessRuleMode, TagType
|
||||
from models.model import Tag, TagBinding
|
||||
from services.dataset_service import DatasetService, DocumentService
|
||||
|
||||
@ -42,11 +43,11 @@ class DatasetRetrievalTestDataFactory:
|
||||
email=f"{uuid4()}@example.com",
|
||||
name=f"user-{uuid4()}",
|
||||
interface_language="en-US",
|
||||
status="active",
|
||||
status=AccountStatus.ACTIVE,
|
||||
)
|
||||
tenant = Tenant(
|
||||
name=f"tenant-{uuid4()}",
|
||||
status="normal",
|
||||
status=TenantStatus.NORMAL,
|
||||
)
|
||||
db_session_with_containers.add_all([account, tenant])
|
||||
db_session_with_containers.flush()
|
||||
@ -72,7 +73,7 @@ class DatasetRetrievalTestDataFactory:
|
||||
email=f"{uuid4()}@example.com",
|
||||
name=f"user-{uuid4()}",
|
||||
interface_language="en-US",
|
||||
status="active",
|
||||
status=AccountStatus.ACTIVE,
|
||||
)
|
||||
db_session_with_containers.add(account)
|
||||
db_session_with_containers.flush()
|
||||
@ -130,7 +131,7 @@ class DatasetRetrievalTestDataFactory:
|
||||
|
||||
@staticmethod
|
||||
def create_process_rule(
|
||||
db_session_with_containers: Session, dataset_id: str, created_by: str, mode: str, rules: dict
|
||||
db_session_with_containers: Session, dataset_id: str, created_by: str, mode: ProcessRuleMode, rules: dict
|
||||
) -> DatasetProcessRule:
|
||||
"""Create a dataset process rule."""
|
||||
process_rule = DatasetProcessRule(
|
||||
@ -153,7 +154,7 @@ class DatasetRetrievalTestDataFactory:
|
||||
content=content,
|
||||
source=DatasetQuerySource.APP,
|
||||
source_app_id=None,
|
||||
created_by_role="account",
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=created_by,
|
||||
)
|
||||
db_session_with_containers.add(dataset_query)
|
||||
@ -176,7 +177,7 @@ class DatasetRetrievalTestDataFactory:
|
||||
"""Create a knowledge tag and bind it to the target dataset."""
|
||||
tag = Tag(
|
||||
tenant_id=tenant_id,
|
||||
type="knowledge",
|
||||
type=TagType.KNOWLEDGE,
|
||||
name=f"tag-{uuid4()}",
|
||||
created_by=created_by,
|
||||
)
|
||||
|
||||
@ -7,6 +7,7 @@ The task is responsible for removing document segments from the search index whe
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
from faker import Faker
|
||||
from sqlalchemy import select
|
||||
@ -82,7 +83,7 @@ class TestDisableSegmentsFromIndexTask:
|
||||
|
||||
return account
|
||||
|
||||
def _create_test_dataset(self, db_session_with_containers: Session, account, fake: Faker | None = None):
|
||||
def _create_test_dataset(self, db_session_with_containers: Session, account: Account, fake: Faker | None = None):
|
||||
"""
|
||||
Helper method to create a test dataset with realistic data.
|
||||
|
||||
@ -117,7 +118,7 @@ class TestDisableSegmentsFromIndexTask:
|
||||
return dataset
|
||||
|
||||
def _create_test_document(
|
||||
self, db_session_with_containers: Session, dataset, account: Account, fake: Faker | None = None
|
||||
self, db_session_with_containers: Session, dataset: Dataset, account: Account, fake: Faker | None = None
|
||||
):
|
||||
"""
|
||||
Helper method to create a test document with realistic data.
|
||||
@ -164,7 +165,7 @@ class TestDisableSegmentsFromIndexTask:
|
||||
return document
|
||||
|
||||
def _create_test_segments(
|
||||
self, db_session_with_containers: Session, document, dataset, account, count=3, fake=None
|
||||
self, db_session_with_containers: Session, document, dataset: Dataset, account: Account, count=3, fake=None
|
||||
):
|
||||
"""
|
||||
Helper method to create test document segments with realistic data.
|
||||
@ -217,7 +218,9 @@ class TestDisableSegmentsFromIndexTask:
|
||||
|
||||
return segments
|
||||
|
||||
def _create_dataset_process_rule(self, db_session_with_containers: Session, dataset, fake: Faker | None = None):
|
||||
def _create_dataset_process_rule(
|
||||
self, db_session_with_containers: Session, dataset: Dataset, fake: Faker | None = None
|
||||
):
|
||||
"""
|
||||
Helper method to create a dataset process rule.
|
||||
|
||||
@ -230,21 +233,19 @@ class TestDisableSegmentsFromIndexTask:
|
||||
DatasetProcessRule: Created process rule instance
|
||||
"""
|
||||
fake = fake or Faker()
|
||||
process_rule = DatasetProcessRule()
|
||||
process_rule.id = fake.uuid4()
|
||||
process_rule.tenant_id = dataset.tenant_id
|
||||
process_rule.dataset_id = dataset.id
|
||||
process_rule.mode = ProcessRuleMode.AUTOMATIC
|
||||
process_rule.rules = (
|
||||
"{"
|
||||
'"mode": "automatic", '
|
||||
'"rules": {'
|
||||
'"pre_processing_rules": [], "segmentation": '
|
||||
'{"separator": "\\n\\n", "max_tokens": 1000, "chunk_overlap": 50}}'
|
||||
"}"
|
||||
process_rule = DatasetProcessRule(
|
||||
dataset_id=dataset.id,
|
||||
mode=ProcessRuleMode.AUTOMATIC,
|
||||
rules=(
|
||||
"{"
|
||||
'"mode": "automatic", '
|
||||
'"rules": {'
|
||||
'"pre_processing_rules": [], "segmentation": '
|
||||
'{"separator": "\\n\\n", "max_tokens": 1000, "chunk_overlap": 50}}'
|
||||
"}"
|
||||
),
|
||||
created_by=str(uuid4()),
|
||||
)
|
||||
process_rule.created_by = dataset.created_by
|
||||
process_rule.updated_by = dataset.updated_by
|
||||
|
||||
db_session_with_containers.add(process_rule)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
@ -847,9 +847,7 @@ class TestDatasetProcessRule:
|
||||
|
||||
# Act
|
||||
process_rule = DatasetProcessRule(
|
||||
dataset_id=dataset_id,
|
||||
mode=ProcessRuleMode.AUTOMATIC,
|
||||
created_by=created_by,
|
||||
dataset_id=dataset_id, mode=ProcessRuleMode.AUTOMATIC, created_by=created_by, rules=None
|
||||
)
|
||||
|
||||
# Assert
|
||||
|
||||
Loading…
Reference in New Issue
Block a user