Mirror of https://github.com/langgenius/dify.git, synced 2026-05-06 18:27:19 +08:00

Merge branch 'main' into jzh

This commit is contained in: commit 43f0c780c3
13  .gemini/config.yaml  Normal file
@@ -0,0 +1,13 @@
+have_fun: false
+memory_config:
+  disabled: false
+code_review:
+  disable: true
+  comment_severity_threshold: MEDIUM
+  max_review_comments: -1
+  pull_request_opened:
+    help: false
+    summary: false
+    code_review: false
+  include_drafts: false
+ignore_patterns: []
9  .github/actions/setup-web/action.yml  vendored
@@ -4,10 +4,9 @@ runs:
   using: composite
   steps:
     - name: Setup Vite+
-      uses: voidzero-dev/setup-vp@4a524139920f87f9f7080d3b8545acac019e1852 # v1.0.0
+      uses: voidzero-dev/setup-vp@20553a7a7429c429a74894104a2835d7fed28a72 # v1.3.0
       with:
-        node-version-file: web/.nvmrc
+        working-directory: web
+        node-version-file: .nvmrc
         cache: true
-        cache-dependency-path: web/pnpm-lock.yaml
-        run-install: |
-          cwd: ./web
+        run-install: true
29  .github/workflows/style.yml  vendored
@@ -84,20 +84,20 @@ jobs:
         if: steps.changed-files.outputs.any_changed == 'true'
         uses: ./.github/actions/setup-web

+      - name: Restore ESLint cache
+        if: steps.changed-files.outputs.any_changed == 'true'
+        id: eslint-cache-restore
+        uses: actions/cache/restore@668228422ae6a00e4ad889ee87cd7109ec5666a7 # v5.0.4
+        with:
+          path: web/.eslintcache
+          key: ${{ runner.os }}-web-eslint-${{ hashFiles('web/package.json', 'web/pnpm-lock.yaml', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-${{ github.sha }}
+          restore-keys: |
+            ${{ runner.os }}-web-eslint-${{ hashFiles('web/package.json', 'web/pnpm-lock.yaml', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-
+
       - name: Web style check
         if: steps.changed-files.outputs.any_changed == 'true'
         working-directory: ./web
-        run: |
-          vp run lint:ci
-          # pnpm run lint:report
-        # continue-on-error: true
-
-      # - name: Annotate Code
-      #   if: steps.changed-files.outputs.any_changed == 'true' && github.event_name == 'pull_request'
-      #   uses: DerLev/eslint-annotations@51347b3a0abfb503fc8734d5ae31c4b151297fae
-      #   with:
-      #     eslint-report: web/eslint_report.json
-      #     github-token: ${{ secrets.GITHUB_TOKEN }}
+        run: vp run lint:ci

       - name: Web tsslint
         if: steps.changed-files.outputs.any_changed == 'true'

@@ -114,6 +114,13 @@ jobs:
         working-directory: ./web
         run: vp run knip

+      - name: Save ESLint cache
+        if: steps.changed-files.outputs.any_changed == 'true' && success() && steps.eslint-cache-restore.outputs.cache-hit != 'true'
+        uses: actions/cache/save@668228422ae6a00e4ad889ee87cd7109ec5666a7 # v5.0.4
+        with:
+          path: web/.eslintcache
+          key: ${{ steps.eslint-cache-restore.outputs.cache-primary-key }}
+
   superlinter:
     name: SuperLinter
     runs-on: ubuntu-latest
2  .github/workflows/translate-i18n-claude.yml  vendored
@@ -120,7 +120,7 @@ jobs:

       - name: Run Claude Code for Translation Sync
         if: steps.detect_changes.outputs.CHANGED_FILES != ''
-        uses: anthropics/claude-code-action@6062f3709600659be5e47fcddf2cf76993c235c2 # v1.0.76
+        uses: anthropics/claude-code-action@ff9acae5886d41a99ed4ec14b7dc147d55834722 # v1.0.77
         with:
           anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
           github_token: ${{ secrets.GITHUB_TOKEN }}
@@ -10,6 +10,7 @@ from configs import dify_config
 from core.rag.datasource.vdb.vector_factory import Vector
 from core.rag.datasource.vdb.vector_type import VectorType
 from core.rag.index_processor.constant.built_in_field import BuiltInField
+from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
 from core.rag.models.document import ChildDocument, Document
 from extensions.ext_database import db
 from models.dataset import Dataset, DatasetCollectionBinding, DatasetMetadata, DatasetMetadataBinding, DocumentSegment

@@ -85,7 +86,7 @@ def migrate_annotation_vector_database():
             dataset = Dataset(
                 id=app.id,
                 tenant_id=app.tenant_id,
-                indexing_technique="high_quality",
+                indexing_technique=IndexTechniqueType.HIGH_QUALITY,
                 embedding_model_provider=dataset_collection_binding.provider_name,
                 embedding_model=dataset_collection_binding.model_name,
                 collection_binding_id=dataset_collection_binding.id,

@@ -177,7 +178,9 @@ def migrate_knowledge_vector_database():
     while True:
         try:
             stmt = (
-                select(Dataset).where(Dataset.indexing_technique == "high_quality").order_by(Dataset.created_at.desc())
+                select(Dataset)
+                .where(Dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY)
+                .order_by(Dataset.created_at.desc())
             )

             datasets = db.paginate(select=stmt, page=page, per_page=50, max_per_page=50, error_out=False)

@@ -269,7 +272,7 @@ def migrate_knowledge_vector_database():
                         "dataset_id": segment.dataset_id,
                     },
                 )
-                if dataset_document.doc_form == "hierarchical_model":
+                if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
                     child_chunks = segment.get_child_chunks()
                     if child_chunks:
                         child_documents = []
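A note on the literal-to-enum swaps above: they are drop-in replacements only if IndexTechniqueType is a str-backed enum. A minimal sketch of why that holds (the enum definition below is assumed for illustration; the real one lives in core.rag.index_processor.constant.index_type):

from enum import StrEnum

class IndexTechniqueType(StrEnum):  # assumed shape, for illustration only
    HIGH_QUALITY = "high_quality"
    ECONOMY = "economy"

# A StrEnum member compares and hashes like its raw value, so it can replace
# the bare literal both in Python checks and in SQLAlchemy filters:
assert IndexTechniqueType.HIGH_QUALITY == "high_quality"
assert "economy" in {IndexTechniqueType.ECONOMY, IndexTechniqueType.HIGH_QUALITY}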
@@ -95,7 +95,7 @@ class CreateAppPayload(BaseModel):
     name: str = Field(..., min_length=1, description="App name")
     description: str | None = Field(default=None, description="App description (max 400 chars)", max_length=400)
     mode: Literal["chat", "agent-chat", "advanced-chat", "workflow", "completion"] = Field(..., description="App mode")
-    icon_type: str | None = Field(default=None, description="Icon type")
+    icon_type: IconType | None = Field(default=None, description="Icon type")
     icon: str | None = Field(default=None, description="Icon")
     icon_background: str | None = Field(default=None, description="Icon background color")

@@ -103,7 +103,7 @@ class CreateAppPayload(BaseModel):
 class UpdateAppPayload(BaseModel):
     name: str = Field(..., min_length=1, description="App name")
     description: str | None = Field(default=None, description="App description (max 400 chars)", max_length=400)
-    icon_type: str | None = Field(default=None, description="Icon type")
+    icon_type: IconType | None = Field(default=None, description="Icon type")
     icon: str | None = Field(default=None, description="Icon")
     icon_background: str | None = Field(default=None, description="Icon background color")
     use_icon_as_answer_icon: bool | None = Field(default=None, description="Use icon as answer icon")

@@ -113,7 +113,7 @@ class UpdateAppPayload(BaseModel):
 class CopyAppPayload(BaseModel):
     name: str | None = Field(default=None, description="Name for the copied app")
     description: str | None = Field(default=None, description="Description for the copied app", max_length=400)
-    icon_type: str | None = Field(default=None, description="Icon type")
+    icon_type: IconType | None = Field(default=None, description="Icon type")
     icon: str | None = Field(default=None, description="Icon")
     icon_background: str | None = Field(default=None, description="Icon background color")

@@ -594,7 +594,7 @@ class AppApi(Resource):
         args_dict: AppService.ArgsDict = {
             "name": args.name,
             "description": args.description or "",
-            "icon_type": args.icon_type or "",
+            "icon_type": args.icon_type,
             "icon": args.icon or "",
             "icon_background": args.icon_background or "",
             "use_icon_as_answer_icon": args.use_icon_as_answer_icon or False,
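Typing icon_type as IconType rather than str makes Pydantic validate the value against the enum at parse time instead of accepting any string. A hedged sketch, with an assumed IconType definition (the real members may differ):

from enum import StrEnum
from pydantic import BaseModel

class IconType(StrEnum):  # assumed members, for illustration
    IMAGE = "image"
    EMOJI = "emoji"

class Payload(BaseModel):
    icon_type: IconType | None = None

Payload(icon_type="emoji")  # accepted and coerced to IconType.EMOJI
# Payload(icon_type="gif")  # would raise a pydantic.ValidationError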
@@ -1,7 +1,7 @@
 from flask import request
 from flask_restx import Resource
 from pydantic import BaseModel, Field, field_validator
-from sqlalchemy.orm import Session
+from sqlalchemy.orm import sessionmaker

 from configs import dify_config
 from constants.languages import languages

@@ -73,7 +73,7 @@ class EmailRegisterSendEmailApi(Resource):
         if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email):
             raise AccountInFreezeError()

-        with Session(db.engine) as session:
+        with sessionmaker(db.engine).begin() as session:
             account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session)
         token = AccountService.send_email_register_email(email=normalized_email, account=account, language=language)
         return {"result": "success", "data": token}

@@ -145,7 +145,7 @@ class EmailRegisterResetApi(Resource):
         email = register_data.get("email", "")
         normalized_email = email.lower()

-        with Session(db.engine) as session:
+        with sessionmaker(db.engine).begin() as session:
             account = AccountService.get_account_by_email_with_case_fallback(email, session=session)

         if account:
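The Session(db.engine) to sessionmaker(db.engine).begin() change swaps a plain session context for a transaction-scoped one: the block commits on clean exit and rolls back on an exception, which is why the explicit session.commit() disappears in the password-reset hunk below. A minimal sketch against a throwaway engine (the sqlite URL is a placeholder for illustration):

from sqlalchemy import create_engine, text
from sqlalchemy.orm import sessionmaker

engine = create_engine("sqlite://")  # placeholder engine for illustration

with sessionmaker(engine).begin() as session:
    session.execute(text("SELECT 1"))
    # committed automatically when the block exits without an exception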
@@ -4,7 +4,7 @@ import secrets
 from flask import request
 from flask_restx import Resource
 from pydantic import BaseModel, Field, field_validator
-from sqlalchemy.orm import Session
+from sqlalchemy.orm import sessionmaker

 from controllers.common.schema import register_schema_models
 from controllers.console import console_ns

@@ -102,7 +102,7 @@ class ForgotPasswordSendEmailApi(Resource):
         else:
             language = "en-US"

-        with Session(db.engine) as session:
+        with sessionmaker(db.engine).begin() as session:
             account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session)

         token = AccountService.send_reset_password_email(

@@ -201,7 +201,7 @@ class ForgotPasswordResetApi(Resource):
         password_hashed = hash_password(args.new_password, salt)

         email = reset_data.get("email", "")
-        with Session(db.engine) as session:
+        with sessionmaker(db.engine).begin() as session:
             account = AccountService.get_account_by_email_with_case_fallback(email, session=session)

             if account:

@@ -215,7 +215,6 @@ class ForgotPasswordResetApi(Resource):
                 # Update existing account credentials
                 account.password = base64.b64encode(password_hashed).decode()
                 account.password_salt = base64.b64encode(salt).decode()
-                session.commit()

                 # Create workspace if needed
                 if (
@@ -4,7 +4,7 @@ import urllib.parse
 import httpx
 from flask import current_app, redirect, request
 from flask_restx import Resource
-from sqlalchemy.orm import Session
+from sqlalchemy.orm import sessionmaker
 from werkzeug.exceptions import Unauthorized

 from configs import dify_config

@@ -180,7 +180,7 @@ def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) ->
     account: Account | None = Account.get_by_openid(provider, user_info.id)

     if not account:
-        with Session(db.engine) as session:
+        with sessionmaker(db.engine).begin() as session:
             account = AccountService.get_account_by_email_with_case_fallback(user_info.email, session=session)

     return account
@@ -3,7 +3,7 @@ from typing import Any, cast
 from flask import request
 from flask_restx import Resource, fields, marshal, marshal_with
 from pydantic import BaseModel, Field, field_validator
-from sqlalchemy import select
+from sqlalchemy import func, select
 from werkzeug.exceptions import Forbidden, NotFound

 import services

@@ -29,6 +29,7 @@ from core.provider_manager import ProviderManager
 from core.rag.datasource.vdb.vector_type import VectorType
 from core.rag.extractor.entity.datasource_type import DatasourceType
 from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo
+from core.rag.index_processor.constant.index_type import IndexTechniqueType
 from core.rag.retrieval.retrieval_methods import RetrievalMethod
 from dify_graph.model_runtime.entities.model_entities import ModelType
 from extensions.ext_database import db

@@ -355,7 +356,7 @@ class DatasetListApi(Resource):

         for item in data:
             # convert embedding_model_provider to plugin standard format
-            if item["indexing_technique"] == "high_quality" and item["embedding_model_provider"]:
+            if item["indexing_technique"] == IndexTechniqueType.HIGH_QUALITY and item["embedding_model_provider"]:
                 item["embedding_model_provider"] = str(ModelProviderID(item["embedding_model_provider"]))
                 item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}"
                 if item_model in model_names:

@@ -436,7 +437,7 @@ class DatasetApi(Resource):
         except services.errors.account.NoPermissionError as e:
             raise Forbidden(str(e))
         data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
-        if dataset.indexing_technique == "high_quality":
+        if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
             if dataset.embedding_model_provider:
                 provider_id = ModelProviderID(dataset.embedding_model_provider)
                 data["embedding_model_provider"] = str(provider_id)

@@ -454,7 +455,7 @@ class DatasetApi(Resource):
         for embedding_model in embedding_models:
             model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")

-        if data["indexing_technique"] == "high_quality":
+        if data["indexing_technique"] == IndexTechniqueType.HIGH_QUALITY:
             item_model = f"{data['embedding_model']}:{data['embedding_model_provider']}"
             if item_model in model_names:
                 data["embedding_available"] = True

@@ -485,7 +486,7 @@ class DatasetApi(Resource):
         current_user, current_tenant_id = current_account_with_tenant()
         # check embedding model setting
         if (
-            payload.indexing_technique == "high_quality"
+            payload.indexing_technique == IndexTechniqueType.HIGH_QUALITY
             and payload.embedding_model_provider is not None
             and payload.embedding_model is not None
         ):

@@ -738,20 +739,23 @@ class DatasetIndexingStatusApi(Resource):
         documents_status = []
         for document in documents:
             completed_segments = (
-                db.session.query(DocumentSegment)
-                .where(
-                    DocumentSegment.completed_at.isnot(None),
-                    DocumentSegment.document_id == str(document.id),
-                    DocumentSegment.status != SegmentStatus.RE_SEGMENT,
+                db.session.scalar(
+                    select(func.count(DocumentSegment.id)).where(
+                        DocumentSegment.completed_at.isnot(None),
+                        DocumentSegment.document_id == str(document.id),
+                        DocumentSegment.status != SegmentStatus.RE_SEGMENT,
+                    )
                 )
-                .count()
+                or 0
             )
             total_segments = (
-                db.session.query(DocumentSegment)
-                .where(
-                    DocumentSegment.document_id == str(document.id), DocumentSegment.status != SegmentStatus.RE_SEGMENT
+                db.session.scalar(
+                    select(func.count(DocumentSegment.id)).where(
+                        DocumentSegment.document_id == str(document.id),
+                        DocumentSegment.status != SegmentStatus.RE_SEGMENT,
+                    )
                 )
-                .count()
+                or 0
             )
             # Create a dictionary with document attributes and additional fields
             document_dict = {

@@ -802,9 +806,12 @@ class DatasetApiKeyApi(Resource):
         _, current_tenant_id = current_account_with_tenant()

         current_key_count = (
-            db.session.query(ApiToken)
-            .where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_tenant_id)
-            .count()
+            db.session.scalar(
+                select(func.count(ApiToken.id)).where(
+                    ApiToken.type == self.resource_type, ApiToken.tenant_id == current_tenant_id
+                )
+            )
+            or 0
         )

         if current_key_count >= self.max_keys:

@@ -839,14 +846,14 @@ class DatasetApiDeleteApi(Resource):
     def delete(self, api_key_id):
         _, current_tenant_id = current_account_with_tenant()
         api_key_id = str(api_key_id)
-        key = (
-            db.session.query(ApiToken)
+        key = db.session.scalar(
+            select(ApiToken)
             .where(
                 ApiToken.tenant_id == current_tenant_id,
                 ApiToken.type == self.resource_type,
                 ApiToken.id == api_key_id,
             )
-            .first()
+            .limit(1)
         )

         if key is None:

@@ -857,7 +864,7 @@ class DatasetApiDeleteApi(Resource):
         assert key is not None  # nosec - for type checker only
         ApiTokenCache.delete(key.token, key.type)

-        db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete()
+        db.session.delete(key)
         db.session.commit()

         return {"result": "success"}, 204
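The counting hunks above all follow the same SQLAlchemy 2.0-style pattern: compute COUNT in SQL with func.count() and read it back with Session.scalar(), instead of the legacy Query.count(), which wraps the whole query in a subquery. A condensed sketch using the names from the diff:

from sqlalchemy import func, select

completed_segments = db.session.scalar(
    select(func.count(DocumentSegment.id)).where(
        DocumentSegment.document_id == str(document.id),
        DocumentSegment.status != SegmentStatus.RE_SEGMENT,
    )
) or 0  # scalar() is typed as Optional, hence the `or 0` guard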
@@ -10,7 +10,7 @@ import sqlalchemy as sa
 from flask import request, send_file
 from flask_restx import Resource, fields, marshal, marshal_with
 from pydantic import BaseModel, Field
-from sqlalchemy import asc, desc, select
+from sqlalchemy import asc, desc, func, select
 from werkzeug.exceptions import Forbidden, NotFound

 import services

@@ -27,6 +27,7 @@ from core.model_manager import ModelManager
 from core.plugin.impl.exc import PluginDaemonClientSideError
 from core.rag.extractor.entity.datasource_type import DatasourceType
 from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo
+from core.rag.index_processor.constant.index_type import IndexTechniqueType
 from dify_graph.model_runtime.entities.model_entities import ModelType
 from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError
 from extensions.ext_database import db

@@ -211,12 +212,11 @@ class GetProcessRuleApi(Resource):
             raise Forbidden(str(e))

         # get the latest process rule
-        dataset_process_rule = (
-            db.session.query(DatasetProcessRule)
+        dataset_process_rule = db.session.scalar(
+            select(DatasetProcessRule)
             .where(DatasetProcessRule.dataset_id == document.dataset_id)
             .order_by(DatasetProcessRule.created_at.desc())
-            .one_or_none()
+            .limit(1)
         )
         if dataset_process_rule:
             mode = dataset_process_rule.mode

@@ -330,21 +330,23 @@ class DatasetDocumentListApi(Resource):
         if fetch:
             for document in documents:
                 completed_segments = (
-                    db.session.query(DocumentSegment)
-                    .where(
-                        DocumentSegment.completed_at.isnot(None),
-                        DocumentSegment.document_id == str(document.id),
-                        DocumentSegment.status != SegmentStatus.RE_SEGMENT,
+                    db.session.scalar(
+                        select(func.count(DocumentSegment.id)).where(
+                            DocumentSegment.completed_at.isnot(None),
+                            DocumentSegment.document_id == str(document.id),
+                            DocumentSegment.status != SegmentStatus.RE_SEGMENT,
+                        )
                     )
-                    .count()
+                    or 0
                 )
                 total_segments = (
-                    db.session.query(DocumentSegment)
-                    .where(
-                        DocumentSegment.document_id == str(document.id),
-                        DocumentSegment.status != SegmentStatus.RE_SEGMENT,
+                    db.session.scalar(
+                        select(func.count(DocumentSegment.id)).where(
+                            DocumentSegment.document_id == str(document.id),
+                            DocumentSegment.status != SegmentStatus.RE_SEGMENT,
+                        )
                     )
-                    .count()
+                    or 0
                 )
                 document.completed_segments = completed_segments
                 document.total_segments = total_segments

@@ -448,7 +450,7 @@ class DatasetInitApi(Resource):
             raise Forbidden()

         knowledge_config = KnowledgeConfig.model_validate(console_ns.payload or {})
-        if knowledge_config.indexing_technique == "high_quality":
+        if knowledge_config.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
             if knowledge_config.embedding_model is None or knowledge_config.embedding_model_provider is None:
                 raise ValueError("embedding model and embedding model provider are required for high quality indexing.")
         try:

@@ -462,7 +464,7 @@ class DatasetInitApi(Resource):
             is_multimodal = DatasetService.check_is_multimodal_model(
                 current_tenant_id, knowledge_config.embedding_model_provider, knowledge_config.embedding_model
             )
-            knowledge_config.is_multimodal = is_multimodal
+            knowledge_config.is_multimodal = is_multimodal  # pyrefly: ignore[bad-assignment]
         except InvokeAuthorizationError:
             raise ProviderNotInitializeError(
                 "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."

@@ -521,10 +523,10 @@ class DocumentIndexingEstimateApi(DocumentResource):
         if data_source_info and "upload_file_id" in data_source_info:
             file_id = data_source_info["upload_file_id"]

-            file = (
-                db.session.query(UploadFile)
+            file = db.session.scalar(
+                select(UploadFile)
                 .where(UploadFile.tenant_id == document.tenant_id, UploadFile.id == file_id)
-                .first()
+                .limit(1)
             )

             # raise error if file not found

@@ -586,10 +588,10 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
             if not data_source_info:
                 continue
             file_id = data_source_info["upload_file_id"]
-            file_detail = (
-                db.session.query(UploadFile)
+            file_detail = db.session.scalar(
+                select(UploadFile)
                 .where(UploadFile.tenant_id == current_tenant_id, UploadFile.id == file_id)
-                .first()
+                .limit(1)
             )

             if file_detail is None:

@@ -672,20 +674,23 @@ class DocumentBatchIndexingStatusApi(DocumentResource):
         documents_status = []
         for document in documents:
             completed_segments = (
-                db.session.query(DocumentSegment)
-                .where(
-                    DocumentSegment.completed_at.isnot(None),
-                    DocumentSegment.document_id == str(document.id),
-                    DocumentSegment.status != SegmentStatus.RE_SEGMENT,
+                db.session.scalar(
+                    select(func.count(DocumentSegment.id)).where(
+                        DocumentSegment.completed_at.isnot(None),
+                        DocumentSegment.document_id == str(document.id),
+                        DocumentSegment.status != SegmentStatus.RE_SEGMENT,
+                    )
                 )
-                .count()
+                or 0
             )
             total_segments = (
-                db.session.query(DocumentSegment)
-                .where(
-                    DocumentSegment.document_id == str(document.id), DocumentSegment.status != SegmentStatus.RE_SEGMENT
+                db.session.scalar(
+                    select(func.count(DocumentSegment.id)).where(
+                        DocumentSegment.document_id == str(document.id),
+                        DocumentSegment.status != SegmentStatus.RE_SEGMENT,
+                    )
                 )
-                .count()
+                or 0
             )
             # Create a dictionary with document attributes and additional fields
             document_dict = {

@@ -723,18 +728,23 @@ class DocumentIndexingStatusApi(DocumentResource):
         document = self.get_document(dataset_id, document_id)

         completed_segments = (
-            db.session.query(DocumentSegment)
-            .where(
-                DocumentSegment.completed_at.isnot(None),
-                DocumentSegment.document_id == str(document_id),
-                DocumentSegment.status != SegmentStatus.RE_SEGMENT,
+            db.session.scalar(
+                select(func.count(DocumentSegment.id)).where(
+                    DocumentSegment.completed_at.isnot(None),
+                    DocumentSegment.document_id == str(document_id),
+                    DocumentSegment.status != SegmentStatus.RE_SEGMENT,
+                )
             )
-            .count()
+            or 0
         )
         total_segments = (
-            db.session.query(DocumentSegment)
-            .where(DocumentSegment.document_id == str(document_id), DocumentSegment.status != SegmentStatus.RE_SEGMENT)
-            .count()
+            db.session.scalar(
+                select(func.count(DocumentSegment.id)).where(
+                    DocumentSegment.document_id == str(document_id),
+                    DocumentSegment.status != SegmentStatus.RE_SEGMENT,
+                )
+            )
+            or 0
         )

         # Create a dictionary with document attributes and additional fields

@@ -1258,11 +1268,11 @@ class DocumentPipelineExecutionLogApi(DocumentResource):
         document = DocumentService.get_document(dataset.id, document_id)
         if not document:
             raise NotFound("Document not found.")
-        log = (
-            db.session.query(DocumentPipelineExecutionLog)
-            .filter_by(document_id=document_id)
+        log = db.session.scalar(
+            select(DocumentPipelineExecutionLog)
+            .where(DocumentPipelineExecutionLog.document_id == document_id)
             .order_by(DocumentPipelineExecutionLog.created_at.desc())
-            .first()
+            .limit(1)
         )
         if not log:
             return {

@@ -1328,7 +1338,7 @@ class DocumentGenerateSummaryApi(Resource):
             raise BadRequest("document_list cannot be empty.")

         # Check if dataset configuration supports summary generation
-        if dataset.indexing_technique != "high_quality":
+        if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY:
             raise ValueError(
                 f"Summary generation is only available for 'high_quality' indexing technique. "
                 f"Current indexing technique: {dataset.indexing_technique}"
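The GetProcessRuleApi hunk above shows the companion pattern for single-row fetches: Session.scalar() plus order_by(...desc()).limit(1) returns the newest row or None, replacing the legacy Query.first() / one_or_none(). Condensed from the diff, with the same names:

from sqlalchemy import select

latest_rule = db.session.scalar(
    select(DatasetProcessRule)
    .where(DatasetProcessRule.dataset_id == document.dataset_id)
    .order_by(DatasetProcessRule.created_at.desc())
    .limit(1)  # keeps the emitted SQL equivalent to the old .first()
)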
@@ -26,6 +26,7 @@ from controllers.console.wraps import (
 )
 from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
 from core.model_manager import ModelManager
+from core.rag.index_processor.constant.index_type import IndexTechniqueType
 from dify_graph.model_runtime.entities.model_entities import ModelType
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client

@@ -45,7 +46,7 @@ def _get_segment_with_summary(segment, dataset_id):
     """Helper function to marshal segment and add summary information."""
     from services.summary_index_service import SummaryIndexService

-    segment_dict = dict(marshal(segment, segment_fields))
+    segment_dict = dict(marshal(segment, segment_fields))  # type: ignore
     # Query summary for this segment (only enabled summaries)
     summary = SummaryIndexService.get_segment_summary(segment_id=segment.id, dataset_id=dataset_id)
     segment_dict["summary"] = summary.summary_content if summary else None

@@ -206,7 +207,7 @@ class DatasetDocumentSegmentListApi(Resource):
         # Add summary to each segment
         segments_with_summary = []
         for segment in segments.items:
-            segment_dict = dict(marshal(segment, segment_fields))
+            segment_dict = dict(marshal(segment, segment_fields))  # type: ignore
             segment_dict["summary"] = summaries.get(segment.id)
             segments_with_summary.append(segment_dict)

@@ -279,7 +280,7 @@ class DatasetDocumentSegmentApi(Resource):
             DatasetService.check_dataset_permission(dataset, current_user)
         except services.errors.account.NoPermissionError as e:
             raise Forbidden(str(e))
-        if dataset.indexing_technique == "high_quality":
+        if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
             # check embedding model setting
             try:
                 model_manager = ModelManager()

@@ -333,7 +334,7 @@ class DatasetDocumentSegmentAddApi(Resource):
         if not current_user.is_dataset_editor:
             raise Forbidden()
         # check embedding model setting
-        if dataset.indexing_technique == "high_quality":
+        if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
             try:
                 model_manager = ModelManager()
                 model_manager.get_model_instance(

@@ -383,7 +384,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
         document = DocumentService.get_document(dataset_id, document_id)
         if not document:
             raise NotFound("Document not found.")
-        if dataset.indexing_technique == "high_quality":
+        if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
             # check embedding model setting
             try:
                 model_manager = ModelManager()

@@ -401,10 +402,10 @@ class DatasetDocumentSegmentUpdateApi(Resource):
                 raise ProviderNotInitializeError(ex.description)
         # check segment
         segment_id = str(segment_id)
-        segment = (
-            db.session.query(DocumentSegment)
+        segment = db.session.scalar(
+            select(DocumentSegment)
             .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
-            .first()
+            .limit(1)
         )
         if not segment:
             raise NotFound("Segment not found.")

@@ -447,10 +448,10 @@ class DatasetDocumentSegmentUpdateApi(Resource):
             raise NotFound("Document not found.")
         # check segment
         segment_id = str(segment_id)
-        segment = (
-            db.session.query(DocumentSegment)
+        segment = db.session.scalar(
+            select(DocumentSegment)
             .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
-            .first()
+            .limit(1)
         )
         if not segment:
             raise NotFound("Segment not found.")

@@ -494,7 +495,7 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
         payload = BatchImportPayload.model_validate(console_ns.payload or {})
         upload_file_id = payload.upload_file_id

-        upload_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first()
+        upload_file = db.session.scalar(select(UploadFile).where(UploadFile.id == upload_file_id).limit(1))
         if not upload_file:
             raise NotFound("UploadFile not found.")

@@ -559,17 +560,17 @@ class ChildChunkAddApi(Resource):
             raise NotFound("Document not found.")
         # check segment
         segment_id = str(segment_id)
-        segment = (
-            db.session.query(DocumentSegment)
+        segment = db.session.scalar(
+            select(DocumentSegment)
             .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
-            .first()
+            .limit(1)
         )
         if not segment:
             raise NotFound("Segment not found.")
         if not current_user.is_dataset_editor:
             raise Forbidden()
         # check embedding model setting
-        if dataset.indexing_technique == "high_quality":
+        if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
             try:
                 model_manager = ModelManager()
                 model_manager.get_model_instance(

@@ -616,10 +617,10 @@ class ChildChunkAddApi(Resource):
             raise NotFound("Document not found.")
         # check segment
         segment_id = str(segment_id)
-        segment = (
-            db.session.query(DocumentSegment)
+        segment = db.session.scalar(
+            select(DocumentSegment)
             .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
-            .first()
+            .limit(1)
         )
         if not segment:
             raise NotFound("Segment not found.")

@@ -666,10 +667,10 @@ class ChildChunkAddApi(Resource):
             raise NotFound("Document not found.")
         # check segment
         segment_id = str(segment_id)
-        segment = (
-            db.session.query(DocumentSegment)
+        segment = db.session.scalar(
+            select(DocumentSegment)
             .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
-            .first()
+            .limit(1)
         )
         if not segment:
             raise NotFound("Segment not found.")

@@ -714,24 +715,24 @@ class ChildChunkUpdateApi(Resource):
             raise NotFound("Document not found.")
         # check segment
         segment_id = str(segment_id)
-        segment = (
-            db.session.query(DocumentSegment)
+        segment = db.session.scalar(
+            select(DocumentSegment)
             .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
-            .first()
+            .limit(1)
         )
         if not segment:
             raise NotFound("Segment not found.")
         # check child chunk
         child_chunk_id = str(child_chunk_id)
-        child_chunk = (
-            db.session.query(ChildChunk)
+        child_chunk = db.session.scalar(
+            select(ChildChunk)
             .where(
                 ChildChunk.id == str(child_chunk_id),
                 ChildChunk.tenant_id == current_tenant_id,
                 ChildChunk.segment_id == segment.id,
                 ChildChunk.document_id == document_id,
             )
-            .first()
+            .limit(1)
         )
         if not child_chunk:
             raise NotFound("Child chunk not found.")

@@ -771,24 +772,24 @@ class ChildChunkUpdateApi(Resource):
             raise NotFound("Document not found.")
         # check segment
         segment_id = str(segment_id)
-        segment = (
-            db.session.query(DocumentSegment)
+        segment = db.session.scalar(
+            select(DocumentSegment)
             .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
-            .first()
+            .limit(1)
         )
         if not segment:
             raise NotFound("Segment not found.")
         # check child chunk
         child_chunk_id = str(child_chunk_id)
-        child_chunk = (
-            db.session.query(ChildChunk)
+        child_chunk = db.session.scalar(
+            select(ChildChunk)
             .where(
                 ChildChunk.id == str(child_chunk_id),
                 ChildChunk.tenant_id == current_tenant_id,
                 ChildChunk.segment_id == segment.id,
                 ChildChunk.document_id == document_id,
            )
-            .first()
+            .limit(1)
         )
         if not child_chunk:
             raise NotFound("Child chunk not found.")
@@ -2,6 +2,8 @@ from collections.abc import Callable
 from functools import wraps
 from typing import ParamSpec, TypeVar

+from sqlalchemy import select
+
 from controllers.console.datasets.error import PipelineNotFoundError
 from extensions.ext_database import db
 from libs.login import current_account_with_tenant

@@ -24,10 +26,8 @@ def get_rag_pipeline(view_func: Callable[P, R]):

         del kwargs["pipeline_id"]

-        pipeline = (
-            db.session.query(Pipeline)
-            .where(Pipeline.id == pipeline_id, Pipeline.tenant_id == current_tenant_id)
-            .first()
+        pipeline = db.session.scalar(
+            select(Pipeline).where(Pipeline.id == pipeline_id, Pipeline.tenant_id == current_tenant_id).limit(1)
         )

         if not pipeline:
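For context on the wrapper being edited: the ParamSpec/TypeVar imports at the top of this file support a decorator that preserves the wrapped view's signature for type checkers. A self-contained sketch of that shape (the body is a stand-in, not the real pipeline lookup):

from collections.abc import Callable
from functools import wraps
from typing import ParamSpec, TypeVar

P = ParamSpec("P")
R = TypeVar("R")

def get_rag_pipeline(view_func: Callable[P, R]) -> Callable[P, R]:
    @wraps(view_func)
    def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
        # ...resolve kwargs["pipeline_id"] to a Pipeline here, as in the diff...
        return view_func(*args, **kwargs)
    return decorated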
@@ -15,6 +15,7 @@ from controllers.service_api.wraps import (
     cloud_edition_billing_rate_limit_check,
 )
 from core.provider_manager import ProviderManager
+from core.rag.index_processor.constant.index_type import IndexTechniqueType
 from dify_graph.model_runtime.entities.model_entities import ModelType
 from fields.dataset_fields import dataset_detail_fields
 from fields.tag_fields import DataSetTag

@@ -153,15 +154,20 @@ class DatasetListApi(DatasetApiResource):

         data = marshal(datasets, dataset_detail_fields)
         for item in data:
-            if item["indexing_technique"] == "high_quality" and item["embedding_model_provider"]:
-                item["embedding_model_provider"] = str(ModelProviderID(item["embedding_model_provider"]))
-                item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}"
+            if (
+                item["indexing_technique"] == IndexTechniqueType.HIGH_QUALITY  # pyrefly: ignore[bad-index]
+                and item["embedding_model_provider"]  # pyrefly: ignore[bad-index]
+            ):
+                item["embedding_model_provider"] = str(  # pyrefly: ignore[unsupported-operation]
+                    ModelProviderID(item["embedding_model_provider"])  # pyrefly: ignore[bad-index]
+                )
+                item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}"  # pyrefly: ignore[bad-index]
                 if item_model in model_names:
-                    item["embedding_available"] = True
+                    item["embedding_available"] = True  # type: ignore
                 else:
-                    item["embedding_available"] = False
+                    item["embedding_available"] = False  # type: ignore
             else:
-                item["embedding_available"] = True
+                item["embedding_available"] = True  # type: ignore
         response = {
             "data": data,
             "has_more": len(datasets) == query.limit,

@@ -265,7 +271,7 @@ class DatasetApi(DatasetApiResource):
         for embedding_model in embedding_models:
             model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")

-        if data.get("indexing_technique") == "high_quality":
+        if data.get("indexing_technique") == IndexTechniqueType.HIGH_QUALITY:
             item_model = f"{data.get('embedding_model')}:{data.get('embedding_model_provider')}"
             if item_model in model_names:
                 data["embedding_available"] = True

@@ -315,7 +321,7 @@ class DatasetApi(DatasetApiResource):
         # check embedding model setting
         embedding_model_provider = payload.embedding_model_provider
         embedding_model = payload.embedding_model
-        if payload.indexing_technique == "high_quality" or embedding_model_provider:
+        if payload.indexing_technique == IndexTechniqueType.HIGH_QUALITY or embedding_model_provider:
             if embedding_model_provider and embedding_model:
                 DatasetService.check_embedding_model_setting(
                     dataset.tenant_id, embedding_model_provider, embedding_model
@@ -17,6 +17,7 @@ from controllers.service_api.wraps import (
 )
 from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
 from core.model_manager import ModelManager
+from core.rag.index_processor.constant.index_type import IndexTechniqueType
 from dify_graph.model_runtime.entities.model_entities import ModelType
 from extensions.ext_database import db
 from fields.segment_fields import child_chunk_fields, segment_fields

@@ -103,7 +104,7 @@ class SegmentApi(DatasetApiResource):
         if not document.enabled:
             raise NotFound("Document is disabled.")
         # check embedding model setting
-        if dataset.indexing_technique == "high_quality":
+        if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
             try:
                 model_manager = ModelManager()
                 model_manager.get_model_instance(

@@ -157,7 +158,7 @@ class SegmentApi(DatasetApiResource):
         if not document:
             raise NotFound("Document not found.")
         # check embedding model setting
-        if dataset.indexing_technique == "high_quality":
+        if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
             try:
                 model_manager = ModelManager()
                 model_manager.get_model_instance(

@@ -262,7 +263,7 @@ class DatasetSegmentApi(DatasetApiResource):
         document = DocumentService.get_document(dataset_id, document_id)
         if not document:
             raise NotFound("Document not found.")
-        if dataset.indexing_technique == "high_quality":
+        if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
             # check embedding model setting
             try:
                 model_manager = ModelManager()

@@ -358,7 +359,7 @@ class ChildChunkApi(DatasetApiResource):
             raise NotFound("Segment not found.")

         # check embedding model setting
-        if dataset.indexing_technique == "high_quality":
+        if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
             try:
                 model_manager = ModelManager()
                 model_manager.get_model_instance(
@@ -4,6 +4,7 @@ from sqlalchemy import select

 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.rag.datasource.vdb.vector_factory import Vector
+from core.rag.index_processor.constant.index_type import IndexTechniqueType
 from extensions.ext_database import db
 from models.dataset import Dataset
 from models.enums import CollectionBindingType, ConversationFromSource

@@ -50,7 +51,7 @@ class AnnotationReplyFeature:
             dataset = Dataset(
                 id=app_record.id,
                 tenant_id=app_record.tenant_id,
-                indexing_technique="high_quality",
+                indexing_technique=IndexTechniqueType.HIGH_QUALITY,
                 embedding_model_provider=embedding_provider_name,
                 embedding_model=embedding_model_name,
                 collection_binding_id=dataset_collection_binding.id,
@@ -19,6 +19,7 @@ class RateLimit:
     _REQUEST_MAX_ALIVE_TIME = 10 * 60  # 10 minutes
     _ACTIVE_REQUESTS_COUNT_FLUSH_INTERVAL = 5 * 60  # recalculate request_count from request_detail every 5 minutes
     _instance_dict: dict[str, "RateLimit"] = {}
+    max_active_requests: int

     def __new__(cls, client_id: str, max_active_requests: int):
         if client_id not in cls._instance_dict:

@@ -27,7 +28,13 @@ class RateLimit:
         return cls._instance_dict[client_id]

     def __init__(self, client_id: str, max_active_requests: int):
+        flush_cache = hasattr(self, "max_active_requests") and self.max_active_requests != max_active_requests
         self.max_active_requests = max_active_requests
+        # Only flush here if this instance has already been fully initialized,
+        # i.e. the Redis key attributes exist. Otherwise, rely on the flush at
+        # the end of initialization below.
+        if flush_cache and hasattr(self, "active_requests_key") and hasattr(self, "max_active_requests_key"):
+            self.flush_cache(use_local_value=True)
         # must be called after max_active_requests is set
         if self.disabled():
             return

@@ -41,8 +48,6 @@ class RateLimit:
         self.flush_cache(use_local_value=True)

     def flush_cache(self, use_local_value=False):
-        if self.disabled():
-            return
         self.last_recalculate_time = time.time()
         # flush max active requests
         if use_local_value or not redis_client.exists(self.max_active_requests_key):

@@ -50,7 +55,8 @@ class RateLimit:
         else:
             self.max_active_requests = int(redis_client.get(self.max_active_requests_key).decode("utf-8"))
             redis_client.expire(self.max_active_requests_key, timedelta(days=1))
-
+        if self.disabled():
+            return
         # flush max active requests (in-transit request list)
         if not redis_client.exists(self.active_requests_key):
             return
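The RateLimit hunks harden a per-client singleton: __new__ reuses one instance per client_id, so __init__ re-runs on every construction, and the new hasattr() checks distinguish a first initialization from a re-initialization with a changed limit. A stripped-down sketch of the pattern, without the Redis bookkeeping:

class ClientLimiter:  # illustrative stand-in for RateLimit
    _instances: dict[str, "ClientLimiter"] = {}
    max_active: int

    def __new__(cls, client_id: str, max_active: int):
        if client_id not in cls._instances:
            cls._instances[client_id] = super().__new__(cls)
        return cls._instances[client_id]

    def __init__(self, client_id: str, max_active: int):
        # True only when __init__ runs again on an already-built instance
        changed = hasattr(self, "max_active") and self.max_active != max_active
        self.max_active = max_active
        if changed:
            pass  # refresh cached state, as flush_cache(use_local_value=True) does above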
@@ -6,16 +6,23 @@ from dify_graph.graph_events.graph import GraphRunPausedEvent
 class SuspendLayer(GraphEngineLayer):
     """ """

+    def __init__(self) -> None:
+        super().__init__()
+        self._paused = False
+
     def on_graph_start(self):
-        pass
+        self._paused = False

     def on_event(self, event: GraphEngineEvent):
         """
         Handle the paused event, stash runtime state into storage and wait for resume.
         """
         if isinstance(event, GraphRunPausedEvent):
-            pass
+            self._paused = True

     def on_graph_end(self, error: Exception | None):
         """ """
-        pass
+        self._paused = False
+
+    def is_paused(self) -> bool:
+        return self._paused
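With the _paused flag in place, the layer becomes queryable after a run. Hypothetical usage only; the engine wiring below is assumed, not part of this commit:

layer = SuspendLayer()
# engine = GraphEngine(..., layers=[layer])  # assumed wiring
# engine.run()
if layer.is_paused():
    ...  # stash runtime state and schedule a resume, per the docstring above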
@@ -128,14 +128,14 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
             self._handle_graph_run_paused(event)
             return

+        if isinstance(event, NodeRunStartedEvent):
+            self._handle_node_started(event)
+            return
+
         if isinstance(event, NodeRunRetryEvent):
             self._handle_node_retry(event)
             return

-        if isinstance(event, NodeRunStartedEvent):
-            self._handle_node_started(event)
-            return
-
         if isinstance(event, NodeRunSucceededEvent):
             self._handle_node_succeeded(event)
             return
@@ -21,7 +21,7 @@ from core.rag.datasource.keyword.keyword_factory import Keyword
 from core.rag.docstore.dataset_docstore import DatasetDocumentStore
 from core.rag.extractor.entity.datasource_type import DatasourceType
 from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo
-from core.rag.index_processor.constant.index_type import IndexStructureType
+from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
 from core.rag.index_processor.index_processor_base import BaseIndexProcessor
 from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
 from core.rag.models.document import ChildDocument, Document

@@ -271,7 +271,7 @@ class IndexingRunner:
         doc_form: str | None = None,
         doc_language: str = "English",
         dataset_id: str | None = None,
-        indexing_technique: str = "economy",
+        indexing_technique: str = IndexTechniqueType.ECONOMY,
     ) -> IndexingEstimate:
         """
         Estimate the indexing for the document.

@@ -289,7 +289,7 @@ class IndexingRunner:
             dataset = db.session.query(Dataset).filter_by(id=dataset_id).first()
             if not dataset:
                 raise ValueError("Dataset not found.")
-            if dataset.indexing_technique == "high_quality" or indexing_technique == "high_quality":
+            if IndexTechniqueType.HIGH_QUALITY in {dataset.indexing_technique, indexing_technique}:
                 if dataset.embedding_model_provider:
                     embedding_model_instance = self.model_manager.get_model_instance(
                         tenant_id=tenant_id,

@@ -303,7 +303,7 @@ class IndexingRunner:
                         model_type=ModelType.TEXT_EMBEDDING,
                     )
         else:
-            if indexing_technique == "high_quality":
+            if indexing_technique == IndexTechniqueType.HIGH_QUALITY:
                 embedding_model_instance = self.model_manager.get_default_model_instance(
                     tenant_id=tenant_id,
                     model_type=ModelType.TEXT_EMBEDDING,

@@ -573,7 +573,7 @@ class IndexingRunner:
         """

         embedding_model_instance = None
-        if dataset.indexing_technique == "high_quality":
+        if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
             embedding_model_instance = self.model_manager.get_model_instance(
                 tenant_id=dataset.tenant_id,
                 provider=dataset.embedding_model_provider,

@@ -587,7 +587,7 @@ class IndexingRunner:
         create_keyword_thread = None
         if (
             dataset_document.doc_form != IndexStructureType.PARENT_CHILD_INDEX
-            and dataset.indexing_technique == "economy"
+            and dataset.indexing_technique == IndexTechniqueType.ECONOMY
         ):
             # create keyword index
             create_keyword_thread = threading.Thread(

@@ -597,7 +597,7 @@ class IndexingRunner:
             create_keyword_thread.start()

         max_workers = 10
-        if dataset.indexing_technique == "high_quality":
+        if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
             with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
                 futures = []

@@ -628,7 +628,7 @@ class IndexingRunner:
                     tokens += future.result()
         if (
             dataset_document.doc_form != IndexStructureType.PARENT_CHILD_INDEX
-            and dataset.indexing_technique == "economy"
+            and dataset.indexing_technique == IndexTechniqueType.ECONOMY
             and create_keyword_thread is not None
         ):
             create_keyword_thread.join()

@@ -654,7 +654,7 @@ class IndexingRunner:
             raise ValueError("no dataset found")
         keyword = Keyword(dataset)
         keyword.create(documents)
-        if dataset.indexing_technique != "high_quality":
+        if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY:
             document_ids = [document.metadata["doc_id"] for document in documents]
             db.session.query(DocumentSegment).where(
                 DocumentSegment.document_id == document_id,

@@ -764,7 +764,7 @@ class IndexingRunner:
     ) -> list[Document]:
         # get embedding model instance
         embedding_model_instance = None
-        if dataset.indexing_technique == "high_quality":
+        if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
             if dataset.embedding_model_provider:
                 embedding_model_instance = self.model_manager.get_model_instance(
                     tenant_id=dataset.tenant_id,
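One subtlety in the IndexingRunner hunk at -289: two equality checks collapse into a single membership test. Assuming IndexTechniqueType is a str-backed enum, this stays equivalent even when one side is still a plain string, because StrEnum members hash and compare like their values:

# Placeholder values standing in for dataset.indexing_technique and the
# indexing_technique argument:
a = "high_quality"
b = IndexTechniqueType.ECONOMY

# Equivalent to: a == "high_quality" or b == "high_quality"
if IndexTechniqueType.HIGH_QUALITY in {a, b}:
    ...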
@@ -67,7 +67,8 @@ class WeaveTraceModel(WeaveTokenUsage, WeaveMultiModel):
         if field_name == "inputs":
             data = {
                 "messages": [
-                    dict(msg, **{"usage_metadata": usage_metadata, "file_list": file_list}) for msg in v
+                    dict(msg, **{"usage_metadata": usage_metadata, "file_list": file_list})  # type: ignore
+                    for msg in v
                 ]
                 if isinstance(v, list)
                 else v,
@@ -209,8 +209,7 @@ class PluginInstaller(BasePluginClient):
             "GET",
             f"plugin/{tenant_id}/management/decode/from_identifier",
             PluginDecodeResponse,
-            data={"plugin_unique_identifier": plugin_unique_identifier},
-            headers={"Content-Type": "application/json"},
+            params={"plugin_unique_identifier": plugin_unique_identifier},
         )

     def fetch_plugin_installation_by_ids(
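The PluginInstaller fix moves the identifier from a request body (data=) to the query string (params=); bodies on GET requests are commonly ignored or rejected by servers and proxies. A hedged httpx illustration with placeholder URL and value:

import httpx

resp = httpx.get(
    "http://plugin-daemon/plugin/TENANT/management/decode/from_identifier",  # placeholder
    params={"plugin_unique_identifier": "author/name:1.0.0"},  # placeholder value
)
# params is serialized into the query string:
# ...?plugin_unique_identifier=author%2Fname%3A1.0.0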
@@ -124,13 +124,13 @@ class HuaweiCloudVector(BaseVector):
                 )
             )

+        score_threshold = float(kwargs.get("score_threshold") or 0.0)
         docs = []
         for doc, score in docs_and_scores:
-            score_threshold = float(kwargs.get("score_threshold") or 0.0)
             if score >= score_threshold:
                 if doc.metadata is not None:
                     doc.metadata["score"] = score
-                    docs.append(doc)
+                docs.append(doc)

         return docs
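Reassembled, the fixed filtering loop in HuaweiCloudVector reads as below: the threshold is computed once outside the loop, and a passing document is appended whether or not it carries metadata (previously a document whose metadata was None was silently dropped). Variable names mirror the hunk:

score_threshold = float(kwargs.get("score_threshold") or 0.0)
docs = []
for doc, score in docs_and_scores:
    if score >= score_threshold:
        if doc.metadata is not None:
            doc.metadata["score"] = score
        docs.append(doc)
return docs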
@@ -6,6 +6,7 @@ from typing import Any
 from sqlalchemy import func, select

 from core.model_manager import ModelManager
+from core.rag.index_processor.constant.index_type import IndexTechniqueType
 from core.rag.models.document import AttachmentDocument, Document
 from dify_graph.model_runtime.entities.model_entities import ModelType
 from extensions.ext_database import db

@@ -71,7 +72,7 @@ class DatasetDocumentStore:
         if max_position is None:
             max_position = 0
         embedding_model = None
-        if self._dataset.indexing_technique == "high_quality":
+        if self._dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
             model_manager = ModelManager()
             embedding_model = model_manager.get_model_instance(
                 tenant_id=self._dataset.tenant_id,
@@ -9,6 +9,7 @@ from flask import current_app
 from sqlalchemy import delete, func, select

 from core.db.session_factory import session_factory
+from core.rag.index_processor.constant.index_type import IndexTechniqueType
 from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict
 from core.workflow.nodes.knowledge_index.exc import KnowledgeIndexNodeError
 from core.workflow.nodes.knowledge_index.protocols import Preview, PreviewItem, QaPreview

@@ -159,7 +160,7 @@ class IndexProcessor:
         tenant_id = dataset.tenant_id

         preview_output = self.format_preview(chunk_structure, chunks)
-        if indexing_technique != "high_quality":
+        if indexing_technique != IndexTechniqueType.HIGH_QUALITY:
             return preview_output

         if not summary_index_setting or not summary_index_setting.get("enable"):
@@ -22,7 +22,7 @@ from core.rag.docstore.dataset_docstore import DatasetDocumentStore
 from core.rag.extractor.entity.extract_setting import ExtractSetting
 from core.rag.extractor.extract_processor import ExtractProcessor
 from core.rag.index_processor.constant.doc_type import DocType
-from core.rag.index_processor.constant.index_type import IndexStructureType
+from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
 from core.rag.index_processor.index_processor_base import BaseIndexProcessor, SummaryIndexSettingDict
 from core.rag.models.document import AttachmentDocument, Document, MultimodalGeneralStructureChunk
 from core.rag.retrieval.retrieval_methods import RetrievalMethod

@@ -117,7 +117,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
         with_keywords: bool = True,
         **kwargs,
     ) -> None:
-        if dataset.indexing_technique == "high_quality":
+        if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
             vector = Vector(dataset)
             vector.create(documents)
             if multimodal_documents and dataset.is_multimodal:

@@ -155,7 +155,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
             # Delete all summaries for the dataset
             SummaryIndexService.delete_summaries_for_segments(dataset, None)

-        if dataset.indexing_technique == "high_quality":
+        if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
             vector = Vector(dataset)
             if node_ids:
                 vector.delete_by_ids(node_ids)

@@ -253,12 +253,12 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
         doc_store = DatasetDocumentStore(dataset=dataset, user_id=document.created_by, document_id=document.id)
         # add document segments
         doc_store.add_documents(docs=documents, save_child=False)
-        if dataset.indexing_technique == "high_quality":
+        if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
             vector = Vector(dataset)
             vector.create(documents)
             if all_multimodal_documents and dataset.is_multimodal:
                 vector.create_multimodal(all_multimodal_documents)
-        elif dataset.indexing_technique == "economy":
+        elif dataset.indexing_technique == IndexTechniqueType.ECONOMY:
             keyword = Keyword(dataset)
             keyword.add_texts(documents)
@@ -18,7 +18,7 @@ from core.rag.docstore.dataset_docstore import DatasetDocumentStore
 from core.rag.extractor.entity.extract_setting import ExtractSetting
 from core.rag.extractor.extract_processor import ExtractProcessor
 from core.rag.index_processor.constant.doc_type import DocType
-from core.rag.index_processor.constant.index_type import IndexStructureType
+from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
 from core.rag.index_processor.index_processor_base import BaseIndexProcessor, SummaryIndexSettingDict
 from core.rag.models.document import AttachmentDocument, ChildDocument, Document, ParentChildStructureChunk
 from core.rag.retrieval.retrieval_methods import RetrievalMethod

@@ -128,7 +128,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
         with_keywords: bool = True,
         **kwargs,
     ) -> None:
-        if dataset.indexing_technique == "high_quality":
+        if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
             vector = Vector(dataset)
             for document in documents:
                 child_documents = document.children

@@ -166,7 +166,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
             # Delete all summaries for the dataset
             SummaryIndexService.delete_summaries_for_segments(dataset, None)

-        if dataset.indexing_technique == "high_quality":
+        if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
             delete_child_chunks = kwargs.get("delete_child_chunks") or False
             precomputed_child_node_ids = kwargs.get("precomputed_child_node_ids")
             vector = Vector(dataset)

@@ -332,7 +332,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
         doc_store = DatasetDocumentStore(dataset=dataset, user_id=document.created_by, document_id=document.id)
         # add document segments
         doc_store.add_documents(docs=documents, save_child=True)
-        if dataset.indexing_technique == "high_quality":
+        if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
             all_child_documents = []
             all_multimodal_documents = []
             for doc in documents:
@@ -21,7 +21,7 @@ from core.rag.datasource.vdb.vector_factory import Vector
 from core.rag.docstore.dataset_docstore import DatasetDocumentStore
 from core.rag.extractor.entity.extract_setting import ExtractSetting
 from core.rag.extractor.extract_processor import ExtractProcessor
-from core.rag.index_processor.constant.index_type import IndexStructureType
+from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
 from core.rag.index_processor.index_processor_base import BaseIndexProcessor, SummaryIndexSettingDict
 from core.rag.models.document import AttachmentDocument, Document, QAStructureChunk
 from core.rag.retrieval.retrieval_methods import RetrievalMethod

@@ -141,7 +141,7 @@ class QAIndexProcessor(BaseIndexProcessor):
         with_keywords: bool = True,
         **kwargs,
     ) -> None:
-        if dataset.indexing_technique == "high_quality":
+        if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
             vector = Vector(dataset)
             vector.create(documents)
             if multimodal_documents and dataset.is_multimodal:

@@ -224,7 +224,7 @@ class QAIndexProcessor(BaseIndexProcessor):
         # save node to document segment
         doc_store = DatasetDocumentStore(dataset=dataset, user_id=document.created_by, document_id=document.id)
         doc_store.add_documents(docs=documents, save_child=False)
-        if dataset.indexing_technique == "high_quality":
+        if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
             vector = Vector(dataset)
             vector.create(documents)
         else:
@@ -675,7 +675,7 @@ class DatasetRetrieval:
# get top k
top_k = retrieval_model_config["top_k"]
# get retrieval method
-if selected_dataset.indexing_technique == "economy":
+if selected_dataset.indexing_technique == IndexTechniqueType.ECONOMY:
retrieval_method = RetrievalMethod.KEYWORD_SEARCH
else:
retrieval_method = retrieval_model_config["search_method"]
@@ -752,7 +752,7 @@ class DatasetRetrieval:
"The configured knowledge base list have different indexing technique, please set reranking model."
)
index_type = available_datasets[0].indexing_technique
-if index_type == "high_quality":
+if index_type == IndexTechniqueType.HIGH_QUALITY:
embedding_model_check = all(
item.embedding_model == available_datasets[0].embedding_model for item in available_datasets
)
@@ -1068,7 +1068,7 @@ class DatasetRetrieval:
else default_retrieval_model
)

-if dataset.indexing_technique == "economy":
+if dataset.indexing_technique == IndexTechniqueType.ECONOMY:
# use keyword table query
documents = RetrievalService.retrieve(
retrieval_method=RetrievalMethod.KEYWORD_SEARCH,

@@ -2,6 +2,7 @@ import concurrent.futures
import logging

from core.db.session_factory import session_factory
+from core.rag.index_processor.constant.index_type import IndexTechniqueType
from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict
from models.dataset import Dataset, Document, DocumentSegment, DocumentSegmentSummary
from services.summary_index_service import SummaryIndexService
@@ -21,7 +22,7 @@ class SummaryIndex:
if is_preview:
with session_factory.create_session() as session:
dataset = session.query(Dataset).filter_by(id=dataset_id).first()
-if not dataset or dataset.indexing_technique != "high_quality":
+if not dataset or dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY:
return

if summary_index_setting is None:

@@ -8,6 +8,7 @@ from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCa
from core.model_manager import ModelManager
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
+from core.rag.index_processor.constant.index_type import IndexTechniqueType
from core.rag.models.document import Document as RagDocument
from core.rag.rerank.rerank_model import RerankModelRunner
from core.rag.retrieval.retrieval_methods import RetrievalMethod
@@ -169,7 +170,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
# get retrieval model , if the model is not setting , using default
retrieval_model = dataset.retrieval_model or default_retrieval_model

-if dataset.indexing_technique == "economy":
+if dataset.indexing_technique == IndexTechniqueType.ECONOMY:
# use keyword table query
documents = RetrievalService.retrieve(
retrieval_method=RetrievalMethod.KEYWORD_SEARCH,

@@ -8,6 +8,7 @@ from core.rag.data_post_processor.data_post_processor import RerankingModelDict,
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.rag.entities.context_entities import DocumentContext
+from core.rag.index_processor.constant.index_type import IndexTechniqueType
from core.rag.models.document import Document as RetrievalDocument
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.rag.retrieval.retrieval_methods import RetrievalMethod
@@ -140,7 +141,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
# get retrieval model , if the model is not setting , using default
retrieval_model = dataset.retrieval_model or default_retrieval_model
retrieval_resource_list: list[RetrievalSourceMetadata] = []
-if dataset.indexing_technique == "economy":
+if dataset.indexing_technique == IndexTechniqueType.ECONOMY:
# use keyword table query
documents = RetrievalService.retrieve(
retrieval_method=RetrievalMethod.KEYWORD_SEARCH,
@@ -173,7 +174,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
for hit_callback in self.hit_callbacks:
hit_callback.on_tool_end(documents)
document_score_list = {}
-if dataset.indexing_technique != "economy":
+if dataset.indexing_technique != IndexTechniqueType.ECONOMY:
for item in documents:
if item.metadata is not None and item.metadata.get("score"):
document_score_list[item.metadata["doc_id"]] = item.metadata["score"]

@@ -1,6 +1,9 @@
from __future__ import annotations

-from collections.abc import Sequence
+import json
+import logging
+import re
+from collections.abc import Mapping, Sequence
+from typing import Any, cast

from core.model_manager import ModelInstance
@@ -36,6 +39,11 @@ from .exc import (
)
from .protocols import TemplateRenderer

+logger = logging.getLogger(__name__)
+
+VARIABLE_PATTERN = re.compile(r"\{\{#[^#]+#\}\}")
+MAX_RESOLVED_VALUE_LENGTH = 1024


def fetch_model_schema(*, model_instance: ModelInstance) -> AIModelEntity:
model_schema = cast(LargeLanguageModel, model_instance.model_type_instance).get_model_schema(
@@ -475,3 +483,61 @@ def _append_file_prompts(
prompt_messages[-1] = UserPromptMessage(content=file_prompts + existing_contents)
else:
prompt_messages.append(UserPromptMessage(content=file_prompts))
+
+
+def _coerce_resolved_value(raw: str) -> int | float | bool | str:
+"""Try to restore the original type from a resolved template string.
+
+Variable references are always resolved to text, but completion params may
+expect numeric or boolean values (e.g. a variable that holds "0.7" mapped to
+the ``temperature`` parameter). This helper attempts a JSON parse so that
+``"0.7"`` → ``0.7``, ``"true"`` → ``True``, etc. Plain strings that are not
+valid JSON literals are returned as-is.
+"""
+stripped = raw.strip()
+if not stripped:
+return raw
+
+try:
+parsed: object = json.loads(stripped)
+except (json.JSONDecodeError, ValueError):
+return raw
+
+if isinstance(parsed, (int, float, bool)):
+return parsed
+return raw
+
+
+def resolve_completion_params_variables(
+completion_params: Mapping[str, Any],
+variable_pool: VariablePool,
+) -> dict[str, Any]:
+"""Resolve variable references (``{{#node_id.var#}}``) in string-typed completion params.
+
+Security notes:
+- Resolved values are length-capped to ``MAX_RESOLVED_VALUE_LENGTH`` to
+prevent denial-of-service through excessively large variable payloads.
+- This follows the same ``VariablePool.convert_template`` pattern used across
+Dify (Answer Node, HTTP Request Node, Agent Node, etc.). The downstream
+model plugin receives these values as structured JSON key-value pairs — they
+are never concatenated into raw HTTP headers or SQL queries.
+- Numeric/boolean coercion is applied so that variables holding ``"0.7"`` are
+restored to their native type rather than sent as a bare string.
+"""
+resolved: dict[str, Any] = {}
+for key, value in completion_params.items():
+if isinstance(value, str) and VARIABLE_PATTERN.search(value):
+segment_group = variable_pool.convert_template(value)
+text = segment_group.text
+if len(text) > MAX_RESOLVED_VALUE_LENGTH:
+logger.warning(
+"Resolved value for param '%s' truncated from %d to %d chars",
+key,
+len(text),
+MAX_RESOLVED_VALUE_LENGTH,
+)
+text = text[:MAX_RESOLVED_VALUE_LENGTH]
+resolved[key] = _coerce_resolved_value(text)
+else:
+resolved[key] = value
+return resolved

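A quick usage sketch of the coercion rules introduced above (the logic mirrors _coerce_resolved_value from the hunk; the assert outcomes follow directly from json.loads semantics):

    import json

    def coerce(raw: str):
        # Same decision tree as _coerce_resolved_value above.
        stripped = raw.strip()
        if not stripped:
            return raw
        try:
            parsed = json.loads(stripped)
        except (json.JSONDecodeError, ValueError):
            return raw
        return parsed if isinstance(parsed, (int, float, bool)) else raw

    assert coerce("0.7") == 0.7                  # numeric param restored to float
    assert coerce("true") is True                # boolean restored
    assert coerce("[1, 2]") == "[1, 2]"          # containers deliberately stay strings
    assert coerce("plain text") == "plain text"  # non-JSON passes through unchanged
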
@@ -202,6 +202,10 @@ class LLMNode(Node[LLMNodeData]):

# fetch model config
model_instance = self._model_instance
+# Resolve variable references in string-typed completion params
+model_instance.parameters = llm_utils.resolve_completion_params_variables(
+model_instance.parameters, variable_pool
+)
model_name = model_instance.model_name
model_provider = model_instance.provider
model_stop = model_instance.stop

@@ -164,6 +164,10 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
)

model_instance = self._model_instance
+# Resolve variable references in string-typed completion params
+model_instance.parameters = llm_utils.resolve_completion_params_variables(
+model_instance.parameters, variable_pool
+)
if not isinstance(model_instance.model_type_instance, LargeLanguageModel):
raise InvalidModelTypeError("Model is not a Large Language Model")

@@ -114,6 +114,10 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
variables = {"query": query}
# fetch model instance
model_instance = self._model_instance
+# Resolve variable references in string-typed completion params
+model_instance.parameters = llm_utils.resolve_completion_params_variables(
+model_instance.parameters, variable_pool
+)
memory = self._memory
# fetch instruction
node_data.instruction = node_data.instruction or ""

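All three nodes apply the same gate: only string-typed params that actually contain a {{#node_id.var#}} reference are resolved, everything else passes through untouched. A small sketch of that gate using the VARIABLE_PATTERN from the diff (the params dict is illustrative):

    import re

    VARIABLE_PATTERN = re.compile(r"\{\{#[^#]+#\}\}")

    params = {"temperature": "{{#start.temp#}}", "max_tokens": 512, "stop": "###"}
    needs_resolution = {
        k for k, v in params.items()
        if isinstance(v, str) and VARIABLE_PATTERN.search(v)
    }
    assert needs_resolution == {"temperature"}  # 512 is not a str; "###" lacks the {{#...#}} form
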
@@ -125,7 +125,8 @@ class BroadcastChannel(Protocol):
a specific topic, all subscription should receive the published message.

There are no restriction for the persistence of messages. Once a subscription is created, it
-should receive all subsequent messages published.
+should receive all subsequent messages published. However, a subscription should not receive
+any message published before the subscription is established.

`BroadcastChannel` implementations must be thread-safe and support concurrent use by multiple threads.
"""

@@ -64,7 +64,10 @@ class _StreamsSubscription(Subscription):
self._client = client
self._key = key
self._closed = threading.Event()
-self._last_id = "0-0"
+# Setting initial last id to `$` to signal redis that we only want new messages.
+#
+# ref: https://redis.io/docs/latest/commands/xread/#the-special--id
+self._last_id = "$"
self._queue: queue.Queue[object] = queue.Queue()
self._start_lock = threading.Lock()
self._listener: threading.Thread | None = None

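For context, the $ sentinel changes what a brand-new subscriber sees, which is exactly the BroadcastChannel contract amended above. A hedged redis-py sketch of the XREAD semantics the change relies on (stream key and loop structure are illustrative, not the actual listener thread):

    import redis

    r = redis.Redis()
    last_id = "$"  # only entries added after this call; "0-0" would replay all history
    while True:
        reply = r.xread({"broadcast:topic": last_id}, block=5000)
        for _stream, messages in reply:
            for message_id, fields in messages:
                last_id = message_id  # advance so the next read resumes after it
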
@@ -18,15 +18,23 @@ if TYPE_CHECKING:
from models.model import EndUser


+def _resolve_current_user() -> EndUser | Account | None:
+"""
+Resolve the current user proxy to its underlying user object.
+This keeps unit tests working when they patch `current_user` directly
+instead of bootstrapping a full Flask-Login manager.
+"""
+user_proxy = current_user
+get_current_object = getattr(user_proxy, "_get_current_object", None)
+return get_current_object() if callable(get_current_object) else user_proxy  # type: ignore
+
+
def current_account_with_tenant():
"""
Resolve the underlying account for the current user proxy and ensure tenant context exists.
Allows tests to supply plain Account mocks without the LocalProxy helper.
"""
user_proxy = current_user

-get_current_object = getattr(user_proxy, "_get_current_object", None)
-user = get_current_object() if callable(get_current_object) else user_proxy  # type: ignore
+user = _resolve_current_user()

if not isinstance(user, Account):
raise ValueError("current_user must be an Account instance")
@@ -79,9 +87,10 @@ def login_required(func: Callable[P, R]) -> Callable[P, R | ResponseReturnValue]
if request.method in EXEMPT_METHODS or dify_config.LOGIN_DISABLED:
return current_app.ensure_sync(func)(*args, **kwargs)

-user = _get_user()
+user = _resolve_current_user()
if user is None or not user.is_authenticated:
return current_app.login_manager.unauthorized()  # type: ignore
g._login_user = user
# we put csrf validation here for less conflicts
# TODO: maybe find a better place for it.
check_csrf_token(request, user.id)

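The helper leans on werkzeug's LocalProxy contract: the real current_user is a proxy exposing _get_current_object, while plain test doubles usually are not, so they pass through unchanged. A minimal sketch of that duck-typing (FakeUser is hypothetical):

    from werkzeug.local import LocalProxy

    def resolve(candidate):
        # Mirrors _resolve_current_user: unwrap a proxy, pass plain objects through.
        get_obj = getattr(candidate, "_get_current_object", None)
        return get_obj() if callable(get_obj) else candidate

    class FakeUser: ...

    user = FakeUser()
    assert resolve(LocalProxy(lambda: user)) is user  # proxy unwrapped
    assert resolve(user) is user                      # patched mock passes through
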
@@ -20,7 +20,7 @@ from sqlalchemy.orm import Mapped, Session, mapped_column

from configs import dify_config
from core.rag.index_processor.constant.built_in_field import BuiltInField, MetadataDataSource
-from core.rag.index_processor.constant.index_type import IndexStructureType
+from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from core.rag.index_processor.constant.query_type import QueryType
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.signature import sign_upload_file
@@ -137,7 +137,7 @@ class Dataset(Base):
default=DatasetPermissionEnum.ONLY_ME,
)
data_source_type = mapped_column(EnumText(DataSourceType, length=255))
-indexing_technique: Mapped[str | None] = mapped_column(String(255))
+indexing_technique: Mapped[IndexTechniqueType | None] = mapped_column(EnumText(IndexTechniqueType, length=255))
index_struct = mapped_column(LongText, nullable=True)
created_by = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
@@ -496,7 +496,9 @@ class Document(Base):
)
doc_type = mapped_column(EnumText(DocumentDocType, length=40), nullable=true)
doc_metadata = mapped_column(AdjustedJSON, nullable=True)
-doc_form = mapped_column(String(255), nullable=False, server_default=sa.text("'text_model'"))
+doc_form: Mapped[IndexStructureType] = mapped_column(
+EnumText(IndexStructureType, length=255), nullable=False, server_default=sa.text("'text_model'")
+)
doc_language = mapped_column(String(255), nullable=True)
need_summary: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))

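EnumText is a Dify-internal column type whose implementation is not part of this diff; the sketch below is an assumption of its general shape (a TypeDecorator that persists the enum's string value), shown only to clarify what the Mapped[IndexTechniqueType] annotation buys at the ORM boundary:

    import enum
    from sqlalchemy import String
    from sqlalchemy.types import TypeDecorator

    class EnumTextSketch(TypeDecorator):
        """Store enum.value as text; rehydrate to the enum on load (assumed behavior)."""
        impl = String
        cache_ok = True

        def __init__(self, enum_cls: type[enum.Enum], length: int = 255):
            super().__init__(length)
            self._enum_cls = enum_cls

        def process_bind_param(self, value, dialect):
            # Accept either the enum member or a raw string on the way in.
            return value.value if isinstance(value, enum.Enum) else value

        def process_result_value(self, value, dialect):
            return self._enum_cls(value) if value is not None else None
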
@@ -589,7 +589,9 @@ class AppModelConfig(TypeBase):
__tablename__ = "app_model_configs"
__table_args__ = (sa.PrimaryKeyConstraint("id", name="app_model_config_pkey"), sa.Index("app_app_id_idx", "app_id"))

-id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+id: Mapped[str] = mapped_column(
+StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None)
model_id: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None)
@@ -938,7 +940,9 @@ class AccountTrialAppRecord(Base):
class ExporleBanner(TypeBase):
__tablename__ = "exporle_banners"
__table_args__ = (sa.PrimaryKeyConstraint("id", name="exporler_banner_pkey"),)
-id: Mapped[str] = mapped_column(StringUUID, default=gen_uuidv4_string, init=False)
+id: Mapped[str] = mapped_column(
+StringUUID, insert_default=gen_uuidv4_string, default_factory=gen_uuidv4_string, init=False
+)
content: Mapped[dict[str, Any]] = mapped_column(sa.JSON, nullable=False)
link: Mapped[str] = mapped_column(String(255), nullable=False)
sort: Mapped[int] = mapped_column(sa.Integer, nullable=False)
@@ -1847,7 +1851,9 @@ class AppAnnotationHitHistory(TypeBase):
sa.Index("app_annotation_hit_histories_message_idx", "message_id"),
)

-id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+id: Mapped[str] = mapped_column(
+StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
annotation_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
source: Mapped[str] = mapped_column(LongText, nullable=False)

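Under SQLAlchemy 2.0's MappedAsDataclass, a bare default=lambda on an id column conflates two distinct defaults: the dataclass-level field default and the Core-level INSERT default. The new form names each explicitly. A self-contained sketch of the pattern (table name and base class are assumptions):

    import uuid
    from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, mapped_column

    class Base(MappedAsDataclass, DeclarativeBase):
        pass

    class Example(Base):
        __tablename__ = "example"
        id: Mapped[str] = mapped_column(
            primary_key=True,
            default_factory=lambda: str(uuid.uuid4()),  # fills the field at object construction
            insert_default=lambda: str(uuid.uuid4()),   # Core-side default for raw INSERTs
            init=False,                                 # callers never pass id explicitly
        )
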
@@ -145,7 +145,9 @@ class ApiToolProvider(TypeBase):
icon: Mapped[str] = mapped_column(String(255), nullable=False)
# original schema
schema: Mapped[str] = mapped_column(LongText, nullable=False)
-schema_type_str: Mapped[str] = mapped_column(String(40), nullable=False)
+schema_type_str: Mapped[ApiProviderSchemaType] = mapped_column(
+EnumText(ApiProviderSchemaType, length=40), nullable=False
+)
# who created this tool
user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
# tenant id

@@ -174,7 +174,7 @@ dev = [
"sseclient-py>=1.8.0",
"pytest-timeout>=2.4.0",
"pytest-xdist>=3.8.0",
-"pyrefly>=0.55.0",
+"pyrefly>=0.57.1",
]

############################################################

@@ -241,7 +241,7 @@ class AppService:
class ArgsDict(TypedDict):
name: str
description: str
-icon_type: str
+icon_type: IconType | str | None
icon: str
icon_background: str
use_icon_as_answer_icon: bool
@@ -257,7 +257,13 @@ class AppService:
assert current_user is not None
app.name = args["name"]
app.description = args["description"]
-app.icon_type = IconType(args["icon_type"]) if args["icon_type"] else None
+icon_type = args.get("icon_type")
+if icon_type is None:
+resolved_icon_type = app.icon_type
+else:
+resolved_icon_type = IconType(icon_type)
+
+app.icon_type = resolved_icon_type
app.icon = args["icon"]
app.icon_background = args["icon_background"]
app.use_icon_as_answer_icon = args.get("use_icon_as_answer_icon", False)

@@ -1,8 +1,16 @@
from abc import ABC, abstractmethod
+from typing import Any

+from typing_extensions import TypedDict
+
+
+class AuthCredentials(TypedDict):
+auth_type: str
+config: dict[str, Any]


class ApiKeyAuthBase(ABC):
-def __init__(self, credentials: dict):
+def __init__(self, credentials: AuthCredentials):
self.credentials = credentials

@abstractmethod

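TypedDicts are plain dicts at runtime, so this tightening is type-checker-only and existing call sites keep passing dicts. A usage sketch (provider name and key value are illustrative):

    credentials: AuthCredentials = {
        "auth_type": "bearer",
        "config": {"api_key": "<your-key>"},
    }
    auth = ApiKeyAuthFactory("firecrawl", credentials)
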
@@ -1,9 +1,9 @@
-from services.auth.api_key_auth_base import ApiKeyAuthBase
+from services.auth.api_key_auth_base import ApiKeyAuthBase, AuthCredentials
from services.auth.auth_type import AuthType


class ApiKeyAuthFactory:
-def __init__(self, provider: str, credentials: dict):
+def __init__(self, provider: str, credentials: AuthCredentials):
auth_factory = self.get_apikey_auth_factory(provider)
self.auth = auth_factory(credentials)

@@ -2,11 +2,11 @@ import json

import httpx

-from services.auth.api_key_auth_base import ApiKeyAuthBase
+from services.auth.api_key_auth_base import ApiKeyAuthBase, AuthCredentials


class FirecrawlAuth(ApiKeyAuthBase):
-def __init__(self, credentials: dict):
+def __init__(self, credentials: AuthCredentials):
super().__init__(credentials)
auth_type = credentials.get("auth_type")
if auth_type != "bearer":

@@ -2,11 +2,11 @@ import json

import httpx

-from services.auth.api_key_auth_base import ApiKeyAuthBase
+from services.auth.api_key_auth_base import ApiKeyAuthBase, AuthCredentials


class JinaAuth(ApiKeyAuthBase):
-def __init__(self, credentials: dict):
+def __init__(self, credentials: AuthCredentials):
super().__init__(credentials)
auth_type = credentials.get("auth_type")
if auth_type != "bearer":

@@ -2,11 +2,11 @@ import json

import httpx

-from services.auth.api_key_auth_base import ApiKeyAuthBase
+from services.auth.api_key_auth_base import ApiKeyAuthBase, AuthCredentials


class JinaAuth(ApiKeyAuthBase):
-def __init__(self, credentials: dict):
+def __init__(self, credentials: AuthCredentials):
super().__init__(credentials)
auth_type = credentials.get("auth_type")
if auth_type != "bearer":

@@ -3,11 +3,11 @@ from urllib.parse import urljoin

import httpx

-from services.auth.api_key_auth_base import ApiKeyAuthBase
+from services.auth.api_key_auth_base import ApiKeyAuthBase, AuthCredentials


class WatercrawlAuth(ApiKeyAuthBase):
-def __init__(self, credentials: dict):
+def __init__(self, credentials: AuthCredentials):
super().__init__(credentials)
auth_type = credentials.get("auth_type")
if auth_type != "x-api-key":

@@ -21,7 +21,7 @@ from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.helper.name_generator import generate_incremental_name
from core.model_manager import ModelManager
from core.rag.index_processor.constant.built_in_field import BuiltInField
-from core.rag.index_processor.constant.index_type import IndexStructureType
+from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from dify_graph.file import helpers as file_helpers
from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelType
@@ -228,7 +228,7 @@ class DatasetService:
if db.session.query(Dataset).filter_by(name=name, tenant_id=tenant_id).first():
raise DatasetNameDuplicateError(f"Dataset with name {name} already exists.")
embedding_model = None
-if indexing_technique == "high_quality":
+if indexing_technique == IndexTechniqueType.HIGH_QUALITY:
model_manager = ModelManager()
if embedding_model_provider and embedding_model_name:
# check if embedding model setting is valid
@@ -254,7 +254,10 @@ class DatasetService:
retrieval_model.reranking_model.reranking_provider_name,
retrieval_model.reranking_model.reranking_model_name,
)
-dataset = Dataset(name=name, indexing_technique=indexing_technique)
+dataset = Dataset(
+name=name,
+indexing_technique=IndexTechniqueType(indexing_technique) if indexing_technique else None,
+)
# dataset = Dataset(name=name, provider=provider, config=config)
dataset.description = description
dataset.created_by = account.id
@@ -349,7 +352,7 @@ class DatasetService:

@staticmethod
def check_dataset_model_setting(dataset):
-if dataset.indexing_technique == "high_quality":
+if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
try:
model_manager = ModelManager()
model_manager.get_model_instance(
@@ -717,13 +720,13 @@ class DatasetService:
if "indexing_technique" not in data:
return None
if dataset.indexing_technique != data["indexing_technique"]:
-if data["indexing_technique"] == "economy":
+if data["indexing_technique"] == IndexTechniqueType.ECONOMY:
# Remove embedding model configuration for economy mode
filtered_data["embedding_model"] = None
filtered_data["embedding_model_provider"] = None
filtered_data["collection_binding_id"] = None
return "remove"
-elif data["indexing_technique"] == "high_quality":
+elif data["indexing_technique"] == IndexTechniqueType.HIGH_QUALITY:
# Configure embedding model for high quality mode
DatasetService._configure_embedding_model_for_high_quality(data, filtered_data)
return "add"

@@ -953,8 +956,8 @@ class DatasetService:
dataset = session.merge(dataset)
if not has_published:
dataset.chunk_structure = knowledge_configuration.chunk_structure
-dataset.indexing_technique = knowledge_configuration.indexing_technique
-if knowledge_configuration.indexing_technique == "high_quality":
+dataset.indexing_technique = IndexTechniqueType(knowledge_configuration.indexing_technique)
+if knowledge_configuration.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,  # ignore type error
@@ -976,7 +979,7 @@ class DatasetService:
embedding_model_name,
)
dataset.collection_binding_id = dataset_collection_binding.id
-elif knowledge_configuration.indexing_technique == "economy":
+elif knowledge_configuration.indexing_technique == IndexTechniqueType.ECONOMY:
dataset.keyword_number = knowledge_configuration.keyword_number
else:
raise ValueError("Invalid index method")
@@ -991,9 +994,9 @@ class DatasetService:
action = None
if dataset.indexing_technique != knowledge_configuration.indexing_technique:
# if update indexing_technique
-if knowledge_configuration.indexing_technique == "economy":
+if knowledge_configuration.indexing_technique == IndexTechniqueType.ECONOMY:
raise ValueError("Knowledge base indexing technique is not allowed to be updated to economy.")
-elif knowledge_configuration.indexing_technique == "high_quality":
+elif knowledge_configuration.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
action = "add"
# get embedding model setting
try:
@@ -1018,7 +1021,7 @@ class DatasetService:
)
dataset.is_multimodal = is_multimodal
dataset.collection_binding_id = dataset_collection_binding.id
-dataset.indexing_technique = knowledge_configuration.indexing_technique
+dataset.indexing_technique = IndexTechniqueType(knowledge_configuration.indexing_technique)
except LLMBadRequestError:
raise ValueError(
"No Embedding Model available. Please configure a valid provider "
@@ -1029,7 +1032,7 @@ class DatasetService:
else:
# add default plugin id to both setting sets, to make sure the plugin model provider is consistent
# Skip embedding model checks if not provided in the update request
-if dataset.indexing_technique == "high_quality":
+if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
skip_embedding_update = False
try:
# Handle existing model provider
@@ -1089,7 +1092,7 @@ class DatasetService:
)
except ProviderTokenNotInitError as ex:
raise ValueError(ex.description)
-elif dataset.indexing_technique == "economy":
+elif dataset.indexing_technique == IndexTechniqueType.ECONOMY:
if dataset.keyword_number != knowledge_configuration.keyword_number:
dataset.keyword_number = knowledge_configuration.keyword_number
dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump()

@@ -1440,7 +1443,7 @@ class DocumentService:
.filter(
Document.id.in_(document_id_list),
Document.dataset_id == dataset_id,
-Document.doc_form != "qa_model",  # Skip qa_model documents
+Document.doc_form != IndexStructureType.QA_INDEX,  # Skip qa_model documents
)
.update({Document.need_summary: need_summary}, synchronize_session=False)
)
@@ -1907,8 +1910,8 @@ class DocumentService:
if knowledge_config.indexing_technique not in Dataset.INDEXING_TECHNIQUE_LIST:
raise ValueError("Indexing technique is invalid")

-dataset.indexing_technique = knowledge_config.indexing_technique
-if knowledge_config.indexing_technique == "high_quality":
+dataset.indexing_technique = IndexTechniqueType(knowledge_config.indexing_technique)
+if knowledge_config.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
model_manager = ModelManager()
if knowledge_config.embedding_model and knowledge_config.embedding_model_provider:
dataset_embedding_model = knowledge_config.embedding_model
@@ -2040,7 +2043,7 @@ class DocumentService:
document.dataset_process_rule_id = dataset_process_rule.id
document.updated_at = naive_utc_now()
document.created_from = created_from
-document.doc_form = knowledge_config.doc_form
+document.doc_form = IndexStructureType(knowledge_config.doc_form)
document.doc_language = knowledge_config.doc_language
document.data_source_info = json.dumps(data_source_info)
document.batch = batch
@@ -2640,7 +2643,7 @@ class DocumentService:
document.splitting_completed_at = None
document.updated_at = naive_utc_now()
document.created_from = created_from
-document.doc_form = document_data.doc_form
+document.doc_form = IndexStructureType(document_data.doc_form)
db.session.add(document)
db.session.commit()
# update document segment
@@ -2689,7 +2692,7 @@ class DocumentService:

dataset_collection_binding_id = None
retrieval_model = None
-if knowledge_config.indexing_technique == "high_quality":
+if knowledge_config.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
assert knowledge_config.embedding_model_provider
assert knowledge_config.embedding_model
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
@@ -2712,7 +2715,7 @@ class DocumentService:
tenant_id=tenant_id,
name="",
data_source_type=knowledge_config.data_source.info_list.data_source_type,
-indexing_technique=knowledge_config.indexing_technique,
+indexing_technique=IndexTechniqueType(knowledge_config.indexing_technique),
created_by=account.id,
embedding_model=knowledge_config.embedding_model,
embedding_model_provider=knowledge_config.embedding_model_provider,
@@ -3101,7 +3104,7 @@ class DocumentService:
class SegmentService:
@classmethod
def segment_create_args_validate(cls, args: dict, document: Document):
-if document.doc_form == "qa_model":
+if document.doc_form == IndexStructureType.QA_INDEX:
if "answer" not in args or not args["answer"]:
raise ValueError("Answer is required")
if not args["answer"].strip():

@@ -3125,7 +3128,7 @@ class SegmentService:
doc_id = str(uuid.uuid4())
segment_hash = helper.generate_text_hash(content)
tokens = 0
-if dataset.indexing_technique == "high_quality":
+if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
@@ -3158,7 +3161,7 @@ class SegmentService:
completed_at=naive_utc_now(),
created_by=current_user.id,
)
-if document.doc_form == "qa_model":
+if document.doc_form == IndexStructureType.QA_INDEX:
segment_document.word_count += len(args["answer"])
segment_document.answer = args["answer"]
@@ -3208,7 +3211,7 @@ class SegmentService:
try:
with redis_client.lock(lock_name, timeout=600):
embedding_model = None
-if dataset.indexing_technique == "high_quality":
+if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
@@ -3230,9 +3233,9 @@ class SegmentService:
doc_id = str(uuid.uuid4())
segment_hash = helper.generate_text_hash(content)
tokens = 0
-if dataset.indexing_technique == "high_quality" and embedding_model:
+if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY and embedding_model:
# calc embedding use tokens
-if document.doc_form == "qa_model":
+if document.doc_form == IndexStructureType.QA_INDEX:
tokens = embedding_model.get_text_embedding_num_tokens(
texts=[content + segment_item["answer"]]
)[0]
@@ -3255,7 +3258,7 @@ class SegmentService:
completed_at=naive_utc_now(),
created_by=current_user.id,
)
-if document.doc_form == "qa_model":
+if document.doc_form == IndexStructureType.QA_INDEX:
segment_document.answer = segment_item["answer"]
segment_document.word_count += len(segment_item["answer"])
increment_word_count += segment_document.word_count
@@ -3322,7 +3325,7 @@ class SegmentService:
content = args.content or segment.content
if segment.content == content:
segment.word_count = len(content)
-if document.doc_form == "qa_model":
+if document.doc_form == IndexStructureType.QA_INDEX:
segment.answer = args.answer
segment.word_count += len(args.answer) if args.answer else 0
word_count_change = segment.word_count - word_count_change
@@ -3345,7 +3348,7 @@ class SegmentService:
if document.doc_form == IndexStructureType.PARENT_CHILD_INDEX and args.regenerate_child_chunks:
# regenerate child chunks
# get embedding model instance
-if dataset.indexing_technique == "high_quality":
+if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
# check embedding model setting
model_manager = ModelManager()
@@ -3382,7 +3385,7 @@ class SegmentService:
# When user manually provides summary, allow saving even if summary_index_setting doesn't exist
# summary_index_setting is only needed for LLM generation, not for manual summary vectorization
# Vectorization uses dataset.embedding_model, which doesn't require summary_index_setting
-if dataset.indexing_technique == "high_quality":
+if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
# Query existing summary from database
from models.dataset import DocumentSegmentSummary
@@ -3409,7 +3412,7 @@ class SegmentService:
else:
segment_hash = helper.generate_text_hash(content)
tokens = 0
-if dataset.indexing_technique == "high_quality":
+if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
@@ -3419,7 +3422,7 @@ class SegmentService:
)

# calc embedding use tokens
-if document.doc_form == "qa_model":
+if document.doc_form == IndexStructureType.QA_INDEX:
segment.answer = args.answer
tokens = embedding_model.get_text_embedding_num_tokens(texts=[content + segment.answer])[0]  # type: ignore
else:
@@ -3436,7 +3439,7 @@ class SegmentService:
segment.enabled = True
segment.disabled_at = None
segment.disabled_by = None
-if document.doc_form == "qa_model":
+if document.doc_form == IndexStructureType.QA_INDEX:
segment.answer = args.answer
segment.word_count += len(args.answer) if args.answer else 0
word_count_change = segment.word_count - word_count_change
@@ -3449,7 +3452,7 @@ class SegmentService:
db.session.commit()
if document.doc_form == IndexStructureType.PARENT_CHILD_INDEX and args.regenerate_child_chunks:
# get embedding model instance
-if dataset.indexing_technique == "high_quality":
+if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
# check embedding model setting
model_manager = ModelManager()
@@ -3481,7 +3484,7 @@ class SegmentService:
# update segment vector index
VectorService.update_segment_vector(args.keywords, segment, dataset)
# Handle summary index when content changed
-if dataset.indexing_technique == "high_quality":
+if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
from models.dataset import DocumentSegmentSummary

existing_summary = (

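One behavioral consequence worth noting across these service changes: the IndexTechniqueType(...) constructor is strict, so values outside the enum now raise at assignment time instead of being stored verbatim. Sketch, assuming the string-backed members shown earlier:

    IndexTechniqueType("high_quality")  # -> IndexTechniqueType.HIGH_QUALITY
    try:
        IndexTechniqueType("hybrid")    # not a member
    except ValueError:
        pass  # invalid techniques are rejected up front, not discovered at query time
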
@@ -22,6 +22,7 @@ from sqlalchemy.orm import Session
from core.helper import ssrf_proxy
from core.helper.name_generator import generate_incremental_name
from core.plugin.entities.plugin import PluginDependency
+from core.rag.index_processor.constant.index_type import IndexTechniqueType
from core.workflow.nodes.datasource.entities import DatasourceNodeData
from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE
from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData
@@ -311,13 +312,13 @@ class RagPipelineDslService:
"icon_background": icon_background,
"icon_url": icon_url,
},
-indexing_technique=knowledge_configuration.indexing_technique,
+indexing_technique=IndexTechniqueType(knowledge_configuration.indexing_technique),
created_by=account.id,
retrieval_model=knowledge_configuration.retrieval_model.model_dump(),
runtime_mode=DatasetRuntimeMode.RAG_PIPELINE,
chunk_structure=knowledge_configuration.chunk_structure,
)
-if knowledge_configuration.indexing_technique == "high_quality":
+if knowledge_configuration.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
dataset_collection_binding = (
self._session.query(DatasetCollectionBinding)
.where(
@@ -343,7 +344,7 @@ class RagPipelineDslService:
dataset.collection_binding_id = dataset_collection_binding_id
dataset.embedding_model = knowledge_configuration.embedding_model
dataset.embedding_model_provider = knowledge_configuration.embedding_model_provider
-elif knowledge_configuration.indexing_technique == "economy":
+elif knowledge_configuration.indexing_technique == IndexTechniqueType.ECONOMY:
dataset.keyword_number = knowledge_configuration.keyword_number
# Update summary_index_setting if provided
if knowledge_configuration.summary_index_setting is not None:
@@ -443,18 +444,18 @@ class RagPipelineDslService:
"icon_background": icon_background,
"icon_url": icon_url,
},
-indexing_technique=knowledge_configuration.indexing_technique,
+indexing_technique=IndexTechniqueType(knowledge_configuration.indexing_technique),
created_by=account.id,
retrieval_model=knowledge_configuration.retrieval_model.model_dump(),
runtime_mode=DatasetRuntimeMode.RAG_PIPELINE,
chunk_structure=knowledge_configuration.chunk_structure,
)
else:
-dataset.indexing_technique = knowledge_configuration.indexing_technique
+dataset.indexing_technique = IndexTechniqueType(knowledge_configuration.indexing_technique)
dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump()
dataset.runtime_mode = DatasetRuntimeMode.RAG_PIPELINE
dataset.chunk_structure = knowledge_configuration.chunk_structure
-if knowledge_configuration.indexing_technique == "high_quality":
+if knowledge_configuration.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
dataset_collection_binding = (
self._session.query(DatasetCollectionBinding)
.where(
@@ -480,7 +481,7 @@ class RagPipelineDslService:
dataset.collection_binding_id = dataset_collection_binding_id
dataset.embedding_model = knowledge_configuration.embedding_model
dataset.embedding_model_provider = knowledge_configuration.embedding_model_provider
-elif knowledge_configuration.indexing_technique == "economy":
+elif knowledge_configuration.indexing_technique == IndexTechniqueType.ECONOMY:
dataset.keyword_number = knowledge_configuration.keyword_number
# Update summary_index_setting if provided
if knowledge_configuration.summary_index_setting is not None:
@@ -772,7 +773,7 @@ class RagPipelineDslService:
)
case _ if typ == KNOWLEDGE_INDEX_NODE_TYPE:
knowledge_index_entity = KnowledgeConfiguration.model_validate(node["data"])
-if knowledge_index_entity.indexing_technique == "high_quality":
+if knowledge_index_entity.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
if knowledge_index_entity.embedding_model_provider:
dependencies.append(
DependenciesAnalysisService.analyze_model_provider_dependency(

@@ -9,6 +9,7 @@ from flask_login import current_user

from constants import DOCUMENT_EXTENSIONS
from core.plugin.impl.plugin import PluginInstaller
+from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
from factories import variable_factory
@@ -79,9 +80,9 @@ class RagPipelineTransformService:
pipeline = self._create_pipeline(pipeline_yaml)

# save chunk structure to dataset
-if doc_form == "hierarchical_model":
+if doc_form == IndexStructureType.PARENT_CHILD_INDEX:
dataset.chunk_structure = "hierarchical_model"
-elif doc_form == "text_model":
+elif doc_form == IndexStructureType.PARAGRAPH_INDEX:
dataset.chunk_structure = "text_model"
else:
raise ValueError("Unsupported doc form")
@@ -101,38 +102,38 @@ class RagPipelineTransformService:

def _get_transform_yaml(self, doc_form: str, datasource_type: str, indexing_technique: str | None):
pipeline_yaml = {}
-if doc_form == "text_model":
+if doc_form == IndexStructureType.PARAGRAPH_INDEX:
match datasource_type:
case DataSourceType.UPLOAD_FILE:
-if indexing_technique == "high_quality":
+if indexing_technique == IndexTechniqueType.HIGH_QUALITY:
# get graph from transform.file-general-high-quality.yml
with open(f"{Path(__file__).parent}/transform/file-general-high-quality.yml") as f:
pipeline_yaml = yaml.safe_load(f)
-if indexing_technique == "economy":
+if indexing_technique == IndexTechniqueType.ECONOMY:
# get graph from transform.file-general-economy.yml
with open(f"{Path(__file__).parent}/transform/file-general-economy.yml") as f:
pipeline_yaml = yaml.safe_load(f)
case DataSourceType.NOTION_IMPORT:
-if indexing_technique == "high_quality":
+if indexing_technique == IndexTechniqueType.HIGH_QUALITY:
# get graph from transform.notion-general-high-quality.yml
with open(f"{Path(__file__).parent}/transform/notion-general-high-quality.yml") as f:
pipeline_yaml = yaml.safe_load(f)
-if indexing_technique == "economy":
+if indexing_technique == IndexTechniqueType.ECONOMY:
# get graph from transform.notion-general-economy.yml
with open(f"{Path(__file__).parent}/transform/notion-general-economy.yml") as f:
pipeline_yaml = yaml.safe_load(f)
case DataSourceType.WEBSITE_CRAWL:
-if indexing_technique == "high_quality":
+if indexing_technique == IndexTechniqueType.HIGH_QUALITY:
# get graph from transform.website-crawl-general-high-quality.yml
with open(f"{Path(__file__).parent}/transform/website-crawl-general-high-quality.yml") as f:
pipeline_yaml = yaml.safe_load(f)
-if indexing_technique == "economy":
+if indexing_technique == IndexTechniqueType.ECONOMY:
# get graph from transform.website-crawl-general-economy.yml
with open(f"{Path(__file__).parent}/transform/website-crawl-general-economy.yml") as f:
pipeline_yaml = yaml.safe_load(f)
case _:
raise ValueError("Unsupported datasource type")
-elif doc_form == "hierarchical_model":
+elif doc_form == IndexStructureType.PARENT_CHILD_INDEX:
match datasource_type:
case DataSourceType.UPLOAD_FILE:
# get graph from transform.file-parentchild.yml
@@ -169,11 +170,11 @@ class RagPipelineTransformService:
):
knowledge_configuration_dict = node.get("data", {})

-if indexing_technique == "high_quality":
+if indexing_technique == IndexTechniqueType.HIGH_QUALITY:
knowledge_configuration.embedding_model = dataset.embedding_model
knowledge_configuration.embedding_model_provider = dataset.embedding_model_provider
if retrieval_model:
-if indexing_technique == "economy":
+if indexing_technique == IndexTechniqueType.ECONOMY:
retrieval_model.search_method = RetrievalMethod.KEYWORD_SEARCH
knowledge_configuration.retrieval_model = retrieval_model
else:

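The if-cascade in _get_transform_yaml is effectively a lookup keyed on (datasource, technique). A table-driven sketch of the same selection, with file names taken from the diff and the dict structure an assumption, not a proposal the patch itself makes:

    TEMPLATE_BY_KEY = {
        (DataSourceType.UPLOAD_FILE, IndexTechniqueType.HIGH_QUALITY): "file-general-high-quality.yml",
        (DataSourceType.UPLOAD_FILE, IndexTechniqueType.ECONOMY): "file-general-economy.yml",
        (DataSourceType.NOTION_IMPORT, IndexTechniqueType.HIGH_QUALITY): "notion-general-high-quality.yml",
        (DataSourceType.NOTION_IMPORT, IndexTechniqueType.ECONOMY): "notion-general-economy.yml",
        (DataSourceType.WEBSITE_CRAWL, IndexTechniqueType.HIGH_QUALITY): "website-crawl-general-high-quality.yml",
        (DataSourceType.WEBSITE_CRAWL, IndexTechniqueType.ECONOMY): "website-crawl-general-economy.yml",
    }
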
@@ -12,6 +12,7 @@ from core.db.session_factory import session_factory
from core.model_manager import ModelManager
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.index_processor.constant.doc_type import DocType
+from core.rag.index_processor.constant.index_type import IndexTechniqueType
from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict
from core.rag.models.document import Document
from dify_graph.model_runtime.entities.llm_entities import LLMUsage
@@ -140,7 +141,7 @@ class SummaryIndexService:
session: Optional SQLAlchemy session. If provided, uses this session instead of creating a new one.
If not provided, creates a new session and commits automatically.
"""
-if dataset.indexing_technique != "high_quality":
+if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY:
logger.warning(
"Summary vectorization skipped for dataset %s: indexing_technique is not high_quality",
dataset.id,
@@ -724,7 +725,7 @@ class SummaryIndexService:
List of created DocumentSegmentSummary instances
"""
# Only generate summary index for high_quality indexing technique
-if dataset.indexing_technique != "high_quality":
+if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY:
logger.info(
"Skipping summary generation for dataset %s: indexing_technique is %s, not 'high_quality'",
dataset.id,
@@ -851,7 +852,7 @@ class SummaryIndexService:
)

# Remove from vector database (but keep records)
-if dataset.indexing_technique == "high_quality":
+if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
summary_node_ids = [s.summary_index_node_id for s in summaries if s.summary_index_node_id]
if summary_node_ids:
try:
@@ -889,7 +890,7 @@ class SummaryIndexService:
segment_ids: List of segment IDs to enable summaries for. If None, enable all.
"""
# Only enable summary index for high_quality indexing technique
-if dataset.indexing_technique != "high_quality":
+if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY:
return

with session_factory.create_session() as session:
@@ -981,7 +982,7 @@ class SummaryIndexService:
return

# Delete from vector database
-if dataset.indexing_technique == "high_quality":
+if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
summary_node_ids = [s.summary_index_node_id for s in summaries if s.summary_index_node_id]
if summary_node_ids:
vector = Vector(dataset)
@@ -1012,7 +1013,7 @@ class SummaryIndexService:
Updated DocumentSegmentSummary instance, or None if indexing technique is not high_quality
"""
# Only update summary index for high_quality indexing technique
-if dataset.indexing_technique != "high_quality":
+if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY:
return None

# When user manually provides summary, allow saving even if summary_index_setting doesn't exist

@@ -4,7 +4,7 @@ from core.model_manager import ModelInstance, ModelManager
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.index_processor.constant.doc_type import DocType
-from core.rag.index_processor.constant.index_type import IndexStructureType
+from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import AttachmentDocument, Document
@@ -45,7 +45,7 @@ class VectorService:
if not processing_rule:
raise ValueError("No processing rule found.")
# get embedding model instance
-if dataset.indexing_technique == "high_quality":
+if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
# check embedding model setting
model_manager = ModelManager()
@@ -112,7 +112,7 @@ class VectorService:
"dataset_id": segment.dataset_id,
},
)
-if dataset.indexing_technique == "high_quality":
+if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
# update vector index
vector = Vector(dataset=dataset)
vector.delete_by_ids([segment.index_node_id])
@@ -197,7 +197,7 @@ class VectorService:
"dataset_id": child_segment.dataset_id,
},
)
-if dataset.indexing_technique == "high_quality":
+if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
# save vector index
vector = Vector(dataset=dataset)
vector.add_texts([child_document], duplicate_check=True)
@@ -237,7 +237,7 @@ class VectorService:
delete_node_ids.append(update_child_chunk.index_node_id)
for delete_child_chunk in delete_child_chunks:
delete_node_ids.append(delete_child_chunk.index_node_id)
-if dataset.indexing_technique == "high_quality":
+if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
# update vector index
vector = Vector(dataset=dataset)
if delete_node_ids:
@@ -252,7 +252,7 @@ class VectorService:

@classmethod
def update_multimodel_vector(cls, segment: DocumentSegment, attachment_ids: list[str], dataset: Dataset):
-if dataset.indexing_technique != "high_quality":
+if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY:
return

attachments = segment.attachments

@@ -5,6 +5,7 @@ import click
from celery import shared_task

from core.rag.datasource.vdb.vector_factory import Vector
+from core.rag.index_processor.constant.index_type import IndexTechniqueType
from core.rag.models.document import Document
from models.dataset import Dataset
from services.dataset_service import DatasetCollectionBindingService
@@ -36,7 +37,7 @@ def add_annotation_to_index_task(
dataset = Dataset(
id=app_id,
tenant_id=tenant_id,
-indexing_technique="high_quality",
+indexing_technique=IndexTechniqueType.HIGH_QUALITY,
embedding_model_provider=dataset_collection_binding.provider_name,
embedding_model=dataset_collection_binding.model_name,
collection_binding_id=dataset_collection_binding.id,

@@ -7,6 +7,7 @@ from werkzeug.exceptions import NotFound

from core.db.session_factory import session_factory
from core.rag.datasource.vdb.vector_factory import Vector
+from core.rag.index_processor.constant.index_type import IndexTechniqueType
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset
@@ -67,7 +68,7 @@ def batch_import_annotations_task(job_id: str, content_list: list[dict], app_id:
dataset = Dataset(
id=app_id,
tenant_id=tenant_id,
-indexing_technique="high_quality",
+indexing_technique=IndexTechniqueType.HIGH_QUALITY,
embedding_model_provider=dataset_collection_binding.provider_name,
embedding_model=dataset_collection_binding.model_name,
collection_binding_id=dataset_collection_binding.id,

@@ -5,6 +5,7 @@ import click
from celery import shared_task

from core.rag.datasource.vdb.vector_factory import Vector
+from core.rag.index_processor.constant.index_type import IndexTechniqueType
from models.dataset import Dataset
from services.dataset_service import DatasetCollectionBindingService

@@ -26,7 +27,7 @@ def delete_annotation_index_task(annotation_id: str, app_id: str, tenant_id: str
dataset = Dataset(
id=app_id,
tenant_id=tenant_id,
-indexing_technique="high_quality",
+indexing_technique=IndexTechniqueType.HIGH_QUALITY,
collection_binding_id=dataset_collection_binding.id,
)

@@ -7,6 +7,7 @@ from sqlalchemy import exists, select

from core.db.session_factory import session_factory
from core.rag.datasource.vdb.vector_factory import Vector
+from core.rag.index_processor.constant.index_type import IndexTechniqueType
from extensions.ext_redis import redis_client
from models.dataset import Dataset
from models.model import App, AppAnnotationSetting, MessageAnnotation
@@ -44,7 +45,7 @@ def disable_annotation_reply_task(job_id: str, app_id: str, tenant_id: str):
dataset = Dataset(
id=app_id,
tenant_id=tenant_id,
-indexing_technique="high_quality",
+indexing_technique=IndexTechniqueType.HIGH_QUALITY,
collection_binding_id=app_annotation_setting.collection_binding_id,
)

@@ -7,6 +7,7 @@ from sqlalchemy import select

from core.db.session_factory import session_factory
from core.rag.datasource.vdb.vector_factory import Vector
+from core.rag.index_processor.constant.index_type import IndexTechniqueType
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
@@ -64,7 +65,7 @@ def enable_annotation_reply_task(
old_dataset = Dataset(
id=app_id,
tenant_id=tenant_id,
-indexing_technique="high_quality",
+indexing_technique=IndexTechniqueType.HIGH_QUALITY,
embedding_model_provider=old_dataset_collection_binding.provider_name,
embedding_model=old_dataset_collection_binding.model_name,
collection_binding_id=old_dataset_collection_binding.id,
@@ -93,7 +94,7 @@ def enable_annotation_reply_task(
dataset = Dataset(
id=app_id,
tenant_id=tenant_id,
-indexing_technique="high_quality",
+indexing_technique=IndexTechniqueType.HIGH_QUALITY,
embedding_model_provider=embedding_provider_name,
embedding_model=embedding_model_name,
collection_binding_id=dataset_collection_binding.id,

@ -5,6 +5,7 @@ import click
|
||||
from celery import shared_task
|
||||
|
||||
from core.rag.datasource.vdb.vector_factory import Vector
|
||||
from core.rag.index_processor.constant.index_type import IndexTechniqueType
|
||||
from core.rag.models.document import Document
|
||||
from models.dataset import Dataset
|
||||
from services.dataset_service import DatasetCollectionBindingService
|
||||
@ -37,7 +38,7 @@ def update_annotation_to_index_task(
|
||||
dataset = Dataset(
|
||||
id=app_id,
|
||||
tenant_id=tenant_id,
|
||||
indexing_technique="high_quality",
|
||||
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
|
||||
embedding_model_provider=dataset_collection_binding.provider_name,
|
||||
embedding_model=dataset_collection_binding.model_name,
|
||||
collection_binding_id=dataset_collection_binding.id,
|
||||
|
||||
@ -11,6 +11,7 @@ from sqlalchemy import func
|
||||
|
||||
from core.db.session_factory import session_factory
|
||||
from core.model_manager import ModelManager
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelType
|
||||
from extensions.ext_redis import redis_client
|
||||
from extensions.ext_storage import storage
|
||||
@ -109,7 +110,7 @@ def batch_create_segment_to_index_task(
|
||||
df = pd.read_csv(file_path)
|
||||
content = []
|
||||
for _, row in df.iterrows():
|
||||
if document_config["doc_form"] == "qa_model":
|
||||
if document_config["doc_form"] == IndexStructureType.QA_INDEX:
|
||||
data = {"content": row.iloc[0], "answer": row.iloc[1]}
|
||||
else:
|
||||
data = {"content": row.iloc[0]}
|
||||
@ -119,7 +120,7 @@ def batch_create_segment_to_index_task(
|
||||
|
||||
document_segments = []
|
||||
embedding_model = None
|
||||
if dataset_config["indexing_technique"] == "high_quality":
|
||||
if dataset_config["indexing_technique"] == IndexTechniqueType.HIGH_QUALITY:
|
||||
model_manager = ModelManager()
|
||||
embedding_model = model_manager.get_model_instance(
|
||||
tenant_id=dataset_config["tenant_id"],
|
||||
@ -159,7 +160,7 @@ def batch_create_segment_to_index_task(
|
||||
status="completed",
|
||||
completed_at=naive_utc_now(),
|
||||
)
|
||||
if document_config["doc_form"] == "qa_model":
|
||||
if document_config["doc_form"] == IndexStructureType.QA_INDEX:
|
||||
segment_document.answer = segment["answer"]
|
||||
segment_document.word_count += len(segment["answer"])
|
||||
word_count_change += segment_document.word_count
|
||||
|
||||
@ -10,6 +10,7 @@ from configs import dify_config
|
||||
from core.db.session_factory import session_factory
|
||||
from core.entities.document_task import DocumentTask
|
||||
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
|
||||
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
@ -126,7 +127,7 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
|
||||
logger.warning("Dataset %s not found after indexing", dataset_id)
|
||||
return
|
||||
|
||||
if dataset.indexing_technique == "high_quality":
|
||||
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
|
||||
summary_index_setting = dataset.summary_index_setting
|
||||
if summary_index_setting and summary_index_setting.get("enable"):
|
||||
# expire all session to get latest document's indexing status
|
||||
@ -150,7 +151,7 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
|
||||
)
|
||||
if (
|
||||
document.indexing_status == IndexingStatus.COMPLETED
|
||||
and document.doc_form != "qa_model"
|
||||
and document.doc_form != IndexStructureType.QA_INDEX
|
||||
and document.need_summary is True
|
||||
):
|
||||
try:
|
||||
|
||||
@ -7,6 +7,7 @@ import click
|
||||
from celery import shared_task
|
||||
|
||||
from core.db.session_factory import session_factory
|
||||
from core.rag.index_processor.constant.index_type import IndexTechniqueType
|
||||
from models.dataset import Dataset, DocumentSegment
|
||||
from models.dataset import Document as DatasetDocument
|
||||
from services.summary_index_service import SummaryIndexService
|
||||
@ -59,7 +60,7 @@ def generate_summary_index_task(dataset_id: str, document_id: str, segment_ids:
|
||||
return
|
||||
|
||||
# Only generate summary index for high_quality indexing technique
|
||||
if dataset.indexing_technique != "high_quality":
|
||||
if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY:
|
||||
logger.info(
|
||||
click.style(
|
||||
f"Skipping summary generation for dataset {dataset_id}: "
|
||||
|
||||
@ -9,6 +9,7 @@ from celery import shared_task
|
||||
from sqlalchemy import or_, select
|
||||
|
||||
from core.db.session_factory import session_factory
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
|
||||
from models.dataset import Dataset, DocumentSegment, DocumentSegmentSummary
|
||||
from models.dataset import Document as DatasetDocument
|
||||
from services.summary_index_service import SummaryIndexService
|
||||
@ -52,7 +53,7 @@ def regenerate_summary_index_task(
|
||||
return
|
||||
|
||||
# Only regenerate summary index for high_quality indexing technique
|
||||
if dataset.indexing_technique != "high_quality":
|
||||
if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY:
|
||||
logger.info(
|
||||
click.style(
|
||||
f"Skipping summary regeneration for dataset {dataset_id}: "
|
||||
@ -106,7 +107,7 @@ def regenerate_summary_index_task(
|
||||
),
|
||||
DatasetDocument.enabled == True, # Document must be enabled
|
||||
DatasetDocument.archived == False, # Document must not be archived
|
||||
DatasetDocument.doc_form != "qa_model", # Skip qa_model documents
|
||||
DatasetDocument.doc_form != IndexStructureType.QA_INDEX, # Skip qa_model documents
|
||||
)
|
||||
.order_by(DocumentSegment.document_id.asc(), DocumentSegment.position.asc())
|
||||
.all()
|
||||
@ -209,7 +210,7 @@ def regenerate_summary_index_task(
|
||||
|
||||
for dataset_document in dataset_documents:
|
||||
# Skip qa_model documents
|
||||
if dataset_document.doc_form == "qa_model":
|
||||
if dataset_document.doc_form == IndexStructureType.QA_INDEX:
|
||||
continue
|
||||
|
||||
try:
|
||||
|
||||
@ -68,7 +68,7 @@ def init_tool_node(config: dict):
|
||||
return node
|
||||
|
||||
|
||||
def test_tool_variable_invoke():
|
||||
def test_tool_variable_invoke(monkeypatch):
|
||||
node = init_tool_node(
|
||||
config={
|
||||
"id": "1",
|
||||
@ -103,7 +103,7 @@ def test_tool_variable_invoke():
|
||||
assert item.node_run_result.outputs.get("text") is not None
|
||||
|
||||
|
||||
def test_tool_mixed_invoke():
|
||||
def test_tool_mixed_invoke(monkeypatch):
|
||||
node = init_tool_node(
|
||||
config={
|
||||
"id": "1",
|
||||
|
||||
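The two tool-node tests now accept pytest's built-in monkeypatch fixture, which restores every patched attribute at teardown instead of leaking state across tests. A hedged sketch of the pattern; the patch target below is hypothetical, not taken from this diff:

def test_tool_variable_invoke(monkeypatch):
    # patch only for the duration of this test; pytest undoes it at teardown
    # "core.tools.tool_engine.ToolEngine.invoke" is a placeholder target
    monkeypatch.setattr("core.tools.tool_engine.ToolEngine.invoke", lambda *args, **kwargs: {"text": "stubbed"})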
@@ -1,8 +1,11 @@
"""Testcontainers integration tests for email register controller endpoints."""

from __future__ import annotations

from types import SimpleNamespace
from unittest.mock import MagicMock, patch

import pytest
from flask import Flask

from controllers.console.auth.email_register import (
    EmailRegisterCheckApi,

@@ -13,14 +16,11 @@ from services.account_service import AccountService


@pytest.fixture
def app():
    flask_app = Flask(__name__)
    flask_app.config["TESTING"] = True
    return flask_app
def app(flask_app_with_containers):
    return flask_app_with_containers


class TestEmailRegisterSendEmailApi:
    @patch("controllers.console.auth.email_register.Session")
    @patch("controllers.console.auth.email_register.AccountService.get_account_by_email_with_case_fallback")
    @patch("controllers.console.auth.email_register.AccountService.send_email_register_email")
    @patch("controllers.console.auth.email_register.BillingService.is_email_in_freeze")

@@ -33,20 +33,15 @@ class TestEmailRegisterSendEmailApi:
        mock_is_freeze,
        mock_send_mail,
        mock_get_account,
        mock_session_cls,
        app,
    ):
        mock_send_mail.return_value = "token-123"
        mock_is_freeze.return_value = False
        mock_account = MagicMock()

        mock_session = MagicMock()
        mock_session_cls.return_value.__enter__.return_value = mock_session
        mock_get_account.return_value = mock_account

        feature_flags = SimpleNamespace(enable_email_password_login=True, is_allow_register=True)
        with (
            patch("controllers.console.auth.email_register.db", SimpleNamespace(engine="engine")),
            patch("controllers.console.auth.email_register.dify_config", SimpleNamespace(BILLING_ENABLED=True)),
            patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")),
            patch("controllers.console.wraps.FeatureService.get_system_features", return_value=feature_flags),

@@ -61,7 +56,6 @@
        assert response == {"result": "success", "data": "token-123"}
        mock_is_freeze.assert_called_once_with("invitee@example.com")
        mock_send_mail.assert_called_once_with(email="invitee@example.com", account=mock_account, language="en-US")
        mock_get_account.assert_called_once_with("Invitee@Example.com", session=mock_session)
        mock_extract_ip.assert_called_once()
        mock_is_email_send_ip_limit.assert_called_once_with("127.0.0.1")

@@ -89,7 +83,6 @@ class TestEmailRegisterCheckApi:

        feature_flags = SimpleNamespace(enable_email_password_login=True, is_allow_register=True)
        with (
            patch("controllers.console.auth.email_register.db", SimpleNamespace(engine="engine")),
            patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")),
            patch("controllers.console.wraps.FeatureService.get_system_features", return_value=feature_flags),
        ):

@@ -114,7 +107,6 @@ class TestEmailRegisterResetApi:
    @patch("controllers.console.auth.email_register.AccountService.reset_login_error_rate_limit")
    @patch("controllers.console.auth.email_register.AccountService.login")
    @patch("controllers.console.auth.email_register.EmailRegisterResetApi._create_new_account")
    @patch("controllers.console.auth.email_register.Session")
    @patch("controllers.console.auth.email_register.AccountService.get_account_by_email_with_case_fallback")
    @patch("controllers.console.auth.email_register.AccountService.revoke_email_register_token")
    @patch("controllers.console.auth.email_register.AccountService.get_email_register_data")

@@ -125,7 +117,6 @@
        mock_get_data,
        mock_revoke_token,
        mock_get_account,
        mock_session_cls,
        mock_create_account,
        mock_login,
        mock_reset_login_rate,

@@ -136,14 +127,10 @@
        token_pair = MagicMock()
        token_pair.model_dump.return_value = {"access_token": "a", "refresh_token": "r"}
        mock_login.return_value = token_pair

        mock_session = MagicMock()
        mock_session_cls.return_value.__enter__.return_value = mock_session
        mock_get_account.return_value = None

        feature_flags = SimpleNamespace(enable_email_password_login=True, is_allow_register=True)
        with (
            patch("controllers.console.auth.email_register.db", SimpleNamespace(engine="engine")),
            patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")),
            patch("controllers.console.wraps.FeatureService.get_system_features", return_value=feature_flags),
        ):

@@ -159,19 +146,19 @@
        mock_reset_login_rate.assert_called_once_with("invitee@example.com")
        mock_revoke_token.assert_called_once_with("token-123")
        mock_extract_ip.assert_called_once()
        mock_get_account.assert_called_once_with("Invitee@Example.com", session=mock_session)


def test_get_account_by_email_with_case_fallback_uses_lowercase_lookup():
def test_get_account_by_email_with_case_fallback_falls_back_to_lowercase():
    """Test that case fallback tries lowercase when exact match fails."""
    mock_session = MagicMock()
    first_query = MagicMock()
    first_query.scalar_one_or_none.return_value = None
    first_result = MagicMock()
    first_result.scalar_one_or_none.return_value = None
    expected_account = MagicMock()
    second_query = MagicMock()
    second_query.scalar_one_or_none.return_value = expected_account
    mock_session.execute.side_effect = [first_query, second_query]
    second_result = MagicMock()
    second_result.scalar_one_or_none.return_value = expected_account
    mock_session.execute.side_effect = [first_result, second_result]

    account = AccountService.get_account_by_email_with_case_fallback("Case@Test.com", session=mock_session)
    result = AccountService.get_account_by_email_with_case_fallback("Case@Test.com", session=mock_session)

    assert account is expected_account
    assert result is expected_account
    assert mock_session.execute.call_count == 2
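The mocked session above expects exactly two execute calls: an exact-case lookup that misses, then a lowercase retry that hits. A sketch of a service method consistent with that contract, assuming names from the test; the real implementation may differ:

from sqlalchemy import select

def get_account_by_email_with_case_fallback(email: str, *, session):
    # first try the address exactly as the caller supplied it
    account = session.execute(select(Account).where(Account.email == email)).scalar_one_or_none()
    if account is None and email != email.lower():
        # fall back to the canonical lowercase form
        account = session.execute(select(Account).where(Account.email == email.lower())).scalar_one_or_none()
    return account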
@@ -1,8 +1,11 @@
"""Testcontainers integration tests for forgot password controller endpoints."""

from __future__ import annotations

from types import SimpleNamespace
from unittest.mock import MagicMock, patch

import pytest
from flask import Flask

from controllers.console.auth.forgot_password import (
    ForgotPasswordCheckApi,

@@ -13,14 +16,11 @@ from services.account_service import AccountService


@pytest.fixture
def app():
    flask_app = Flask(__name__)
    flask_app.config["TESTING"] = True
    return flask_app
def app(flask_app_with_containers):
    return flask_app_with_containers


class TestForgotPasswordSendEmailApi:
    @patch("controllers.console.auth.forgot_password.Session")
    @patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
    @patch("controllers.console.auth.forgot_password.AccountService.send_reset_password_email")
    @patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit", return_value=False)

@@ -31,19 +31,15 @@ class TestForgotPasswordSendEmailApi:
        mock_is_ip_limit,
        mock_send_email,
        mock_get_account,
        mock_session_cls,
        app,
    ):
        mock_account = MagicMock()
        mock_get_account.return_value = mock_account
        mock_send_email.return_value = "token-123"
        mock_session = MagicMock()
        mock_session_cls.return_value.__enter__.return_value = mock_session

        wraps_features = SimpleNamespace(enable_email_password_login=True, is_allow_register=True)
        controller_features = SimpleNamespace(is_allow_register=True)
        with (
            patch("controllers.console.auth.forgot_password.db", SimpleNamespace(engine="engine")),
            patch(
                "controllers.console.auth.forgot_password.FeatureService.get_system_features",
                return_value=controller_features,

@@ -59,7 +55,6 @@ class TestForgotPasswordSendEmailApi:
        response = ForgotPasswordSendEmailApi().post()

        assert response == {"result": "success", "data": "token-123"}
        mock_get_account.assert_called_once_with("User@Example.com", session=mock_session)
        mock_send_email.assert_called_once_with(
            account=mock_account,
            email="user@example.com",

@@ -117,7 +112,6 @@ class TestForgotPasswordCheckApi:

class TestForgotPasswordResetApi:
    @patch("controllers.console.auth.forgot_password.ForgotPasswordResetApi._update_existing_account")
    @patch("controllers.console.auth.forgot_password.Session")
    @patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
    @patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token")
    @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")

@@ -126,7 +120,6 @@ class TestForgotPasswordResetApi:
        mock_get_reset_data,
        mock_revoke_token,
        mock_get_account,
        mock_session_cls,
        mock_update_account,
        app,
    ):

@@ -134,12 +127,8 @@ class TestForgotPasswordResetApi:
        mock_account = MagicMock()
        mock_get_account.return_value = mock_account

        mock_session = MagicMock()
        mock_session_cls.return_value.__enter__.return_value = mock_session

        wraps_features = SimpleNamespace(enable_email_password_login=True)
        with (
            patch("controllers.console.auth.forgot_password.db", SimpleNamespace(engine="engine")),
            patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")),
            patch("controllers.console.wraps.FeatureService.get_system_features", return_value=wraps_features),
        ):

@@ -157,20 +146,22 @@
        assert response == {"result": "success"}
        mock_get_reset_data.assert_called_once_with("token-123")
        mock_revoke_token.assert_called_once_with("token-123")
        mock_get_account.assert_called_once_with("User@Example.com", session=mock_session)
        mock_update_account.assert_called_once()


def test_get_account_by_email_with_case_fallback_uses_lowercase_lookup():
def test_get_account_by_email_with_case_fallback_falls_back_to_lowercase():
    """Test that case fallback tries lowercase when exact match fails."""
    from unittest.mock import MagicMock

    mock_session = MagicMock()
    first_query = MagicMock()
    first_query.scalar_one_or_none.return_value = None
    first_result = MagicMock()
    first_result.scalar_one_or_none.return_value = None
    expected_account = MagicMock()
    second_query = MagicMock()
    second_query.scalar_one_or_none.return_value = expected_account
    mock_session.execute.side_effect = [first_query, second_query]
    second_result = MagicMock()
    second_result.scalar_one_or_none.return_value = expected_account
    mock_session.execute.side_effect = [first_result, second_result]

    account = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com", session=mock_session)
    result = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com", session=mock_session)

    assert account is expected_account
    assert result is expected_account
    assert mock_session.execute.call_count == 2

@@ -1,7 +1,10 @@
"""Testcontainers integration tests for OAuth controller endpoints."""

from __future__ import annotations

from unittest.mock import MagicMock, patch

import pytest
from flask import Flask

from controllers.console.auth.oauth import (
    OAuthCallback,

@@ -18,10 +21,8 @@ from services.errors.account import AccountRegisterError

class TestGetOAuthProviders:
    @pytest.fixture
    def app(self):
        app = Flask(__name__)
        app.config["TESTING"] = True
        return app
    def app(self, flask_app_with_containers):
        return flask_app_with_containers

    @pytest.mark.parametrize(
        ("github_config", "google_config", "expected_github", "expected_google"),

@@ -64,10 +65,8 @@ class TestOAuthLogin:
        return OAuthLogin()

    @pytest.fixture
    def app(self):
        app = Flask(__name__)
        app.config["TESTING"] = True
        return app
    def app(self, flask_app_with_containers):
        return flask_app_with_containers

    @pytest.fixture
    def mock_oauth_provider(self):

@@ -131,10 +130,8 @@ class TestOAuthCallback:
        return OAuthCallback()

    @pytest.fixture
    def app(self):
        app = Flask(__name__)
        app.config["TESTING"] = True
        return app
    def app(self, flask_app_with_containers):
        return flask_app_with_containers

    @pytest.fixture
    def oauth_setup(self):

@@ -190,15 +187,8 @@ class TestOAuthCallback:
            (KeyError("Missing key"), "OAuth process failed"),
        ],
    )
    @patch("controllers.console.auth.oauth.db")
    @patch("controllers.console.auth.oauth.get_oauth_providers")
    def test_should_handle_oauth_exceptions(
        self, mock_get_providers, mock_db, resource, app, exception, expected_error
    ):
        # Mock database session
        mock_db.session = MagicMock()
        mock_db.session.rollback = MagicMock()

    def test_should_handle_oauth_exceptions(self, mock_get_providers, resource, app, exception, expected_error):
        # Import the real httpx module to create a proper exception
        import httpx

@@ -258,7 +248,6 @@ class TestOAuthCallback:
    )
    @patch("controllers.console.auth.oauth.AccountService")
    @patch("controllers.console.auth.oauth.TenantService")
    @patch("controllers.console.auth.oauth.db")
    @patch("controllers.console.auth.oauth.dify_config")
    @patch("controllers.console.auth.oauth.get_oauth_providers")
    @patch("controllers.console.auth.oauth._generate_account")

@@ -269,7 +258,6 @@ class TestOAuthCallback:
        mock_generate_account,
        mock_get_providers,
        mock_config,
        mock_db,
        mock_tenant_service,
        mock_account_service,
        resource,

@@ -278,10 +266,6 @@ class TestOAuthCallback:
        account_status,
        expected_redirect,
    ):
        # Mock database session
        mock_db.session = MagicMock()
        mock_db.session.rollback = MagicMock()
        mock_db.session.commit = MagicMock()

        mock_config.CONSOLE_WEB_URL = "http://localhost:3000"
        mock_get_providers.return_value = {"github": oauth_setup["provider"]}

@@ -306,14 +290,12 @@ class TestOAuthCallback:
    @patch("controllers.console.auth.oauth.dify_config")
    @patch("controllers.console.auth.oauth.get_oauth_providers")
    @patch("controllers.console.auth.oauth._generate_account")
    @patch("controllers.console.auth.oauth.db")
    @patch("controllers.console.auth.oauth.TenantService")
    @patch("controllers.console.auth.oauth.AccountService")
    def test_should_activate_pending_account(
        self,
        mock_account_service,
        mock_tenant_service,
        mock_db,
        mock_generate_account,
        mock_get_providers,
        mock_config,

@@ -338,12 +320,10 @@ class TestOAuthCallback:

        assert mock_account.status == AccountStatus.ACTIVE
        assert mock_account.initialized_at is not None
        mock_db.session.commit.assert_called_once()

    @patch("controllers.console.auth.oauth.dify_config")
    @patch("controllers.console.auth.oauth.get_oauth_providers")
    @patch("controllers.console.auth.oauth._generate_account")
    @patch("controllers.console.auth.oauth.db")
    @patch("controllers.console.auth.oauth.TenantService")
    @patch("controllers.console.auth.oauth.AccountService")
    @patch("controllers.console.auth.oauth.redirect")

@@ -352,7 +332,6 @@ class TestOAuthCallback:
        mock_redirect,
        mock_account_service,
        mock_tenant_service,
        mock_db,
        mock_generate_account,
        mock_get_providers,
        mock_config,

@@ -414,6 +393,10 @@ class TestOAuthCallback:


class TestAccountGeneration:
    @pytest.fixture
    def app(self, flask_app_with_containers):
        return flask_app_with_containers

    @pytest.fixture
    def user_info(self):
        return OAuthUserInfo(id="123", name="Test User", email="test@example.com")

@@ -425,15 +408,10 @@ class TestAccountGeneration:
        return account

    @patch("controllers.console.auth.oauth.AccountService.get_account_by_email_with_case_fallback")
    @patch("controllers.console.auth.oauth.Session")
    @patch("controllers.console.auth.oauth.Account")
    @patch("controllers.console.auth.oauth.db")
    def test_should_get_account_by_openid_or_email(
        self, mock_db, mock_account_model, mock_session, mock_get_account, user_info, mock_account
        self, mock_account_model, mock_get_account, flask_req_ctx_with_containers, user_info, mock_account
    ):
        # Mock db.engine for Session creation
        mock_db.engine = MagicMock()

        # Test OpenID found
        mock_account_model.get_by_openid.return_value = mock_account
        result = _get_account_by_openid_or_email("github", user_info)

@@ -443,15 +421,14 @@ class TestAccountGeneration:

        # Test fallback to email lookup
        mock_account_model.get_by_openid.return_value = None
        mock_session_instance = MagicMock()
        mock_session.return_value.__enter__.return_value = mock_session_instance
        mock_get_account.return_value = mock_account

        result = _get_account_by_openid_or_email("github", user_info)
        assert result == mock_account
        mock_get_account.assert_called_once_with(user_info.email, session=mock_session_instance)
        mock_get_account.assert_called_once()

    def test_get_account_by_email_with_case_fallback_uses_lowercase_lookup(self):
    def test_get_account_by_email_with_case_fallback_falls_back_to_lowercase(self):
        """Test that case fallback tries lowercase when exact match fails."""
        mock_session = MagicMock()
        first_result = MagicMock()
        first_result.scalar_one_or_none.return_value = None

@@ -462,7 +439,7 @@ class TestAccountGeneration:

        result = AccountService.get_account_by_email_with_case_fallback("Case@Test.com", session=mock_session)

        assert result == expected_account
        assert result is expected_account
        assert mock_session.execute.call_count == 2

    @pytest.mark.parametrize(

@@ -478,10 +455,8 @@ class TestAccountGeneration:
    @patch("controllers.console.auth.oauth.RegisterService")
    @patch("controllers.console.auth.oauth.AccountService")
    @patch("controllers.console.auth.oauth.TenantService")
    @patch("controllers.console.auth.oauth.db")
    def test_should_handle_account_generation_scenarios(
        self,
        mock_db,
        mock_tenant_service,
        mock_account_service,
        mock_register_service,

@@ -519,10 +494,8 @@ class TestAccountGeneration:
    @patch("controllers.console.auth.oauth.RegisterService")
    @patch("controllers.console.auth.oauth.AccountService")
    @patch("controllers.console.auth.oauth.TenantService")
    @patch("controllers.console.auth.oauth.db")
    def test_should_register_with_lowercase_email(
        self,
        mock_db,
        mock_tenant_service,
        mock_account_service,
        mock_register_service,

@@ -4,6 +4,7 @@ from unittest.mock import patch
import pytest
from faker import Faker

from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.workflow.nodes.knowledge_retrieval.retrieval import KnowledgeRetrievalRequest
from models.dataset import Dataset, Document

@@ -38,7 +39,7 @@ class TestGetAvailableDatasetsIntegration:
            provider="dify",
            data_source_type=DataSourceType.UPLOAD_FILE,
            created_by=account.id,
            indexing_technique="high_quality",
            indexing_technique=IndexTechniqueType.HIGH_QUALITY,
        )
        db_session_with_containers.add(dataset)
        db_session_with_containers.flush()

@@ -55,7 +56,7 @@ class TestGetAvailableDatasetsIntegration:
                name=f"Document {i}",
                created_from=DocumentCreatedFrom.WEB,
                created_by=account.id,
                doc_form="text_model",
                doc_form=IndexStructureType.PARAGRAPH_INDEX,
                doc_language="en",
                indexing_status=IndexingStatus.COMPLETED,
                enabled=True,

@@ -112,7 +113,7 @@ class TestGetAvailableDatasetsIntegration:
                created_from=DocumentCreatedFrom.WEB,
                name=f"Archived Document {i}",
                created_by=account.id,
                doc_form="text_model",
                doc_form=IndexStructureType.PARAGRAPH_INDEX,
                indexing_status=IndexingStatus.COMPLETED,
                enabled=True,
                archived=True,  # Archived

@@ -165,7 +166,7 @@ class TestGetAvailableDatasetsIntegration:
                created_from=DocumentCreatedFrom.WEB,
                name=f"Disabled Document {i}",
                created_by=account.id,
                doc_form="text_model",
                doc_form=IndexStructureType.PARAGRAPH_INDEX,
                indexing_status=IndexingStatus.COMPLETED,
                enabled=False,  # Disabled
                archived=False,

@@ -218,7 +219,7 @@ class TestGetAvailableDatasetsIntegration:
                created_from=DocumentCreatedFrom.WEB,
                name=f"Document {status}",
                created_by=account.id,
                doc_form="text_model",
                doc_form=IndexStructureType.PARAGRAPH_INDEX,
                indexing_status=status,  # Not completed
                enabled=True,
                archived=False,

@@ -336,7 +337,7 @@ class TestGetAvailableDatasetsIntegration:
                created_from=DocumentCreatedFrom.WEB,
                name=f"Document for {dataset.name}",
                created_by=account.id,
                doc_form="text_model",
                doc_form=IndexStructureType.PARAGRAPH_INDEX,
                indexing_status=IndexingStatus.COMPLETED,
                enabled=True,
                archived=False,

@@ -416,7 +417,7 @@ class TestGetAvailableDatasetsIntegration:
                created_from=DocumentCreatedFrom.WEB,
                name=f"Document {i}",
                created_by=account.id,
                doc_form="text_model",
                doc_form=IndexStructureType.PARAGRAPH_INDEX,
                indexing_status=IndexingStatus.COMPLETED,
                enabled=True,
                archived=False,

@@ -459,7 +460,7 @@ class TestKnowledgeRetrievalIntegration:
            provider="dify",
            data_source_type=DataSourceType.UPLOAD_FILE,
            created_by=account.id,
            indexing_technique="high_quality",
            indexing_technique=IndexTechniqueType.HIGH_QUALITY,
        )
        db_session_with_containers.add(dataset)

@@ -476,7 +477,7 @@ class TestKnowledgeRetrievalIntegration:
            indexing_status=IndexingStatus.COMPLETED,
            enabled=True,
            archived=False,
            doc_form="text_model",
            doc_form=IndexStructureType.PARAGRAPH_INDEX,
        )
        db_session_with_containers.add(document)
        db_session_with_containers.commit()
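Note that PARAGRAPH_INDEX replaces the literal "text_model" and QA_INDEX replaces "qa_model", so the member names do not mirror the stored strings. A sketch of the mapping implied by this diff; treat it as an inference, not the project's actual source:

from enum import StrEnum

class IndexStructureType(StrEnum):
    PARAGRAPH_INDEX = "text_model"  # plain chunked documents
    QA_INDEX = "qa_model"           # question/answer pair documents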
@@ -0,0 +1,227 @@
"""
Integration tests for Redis Streams broadcast channel implementation using TestContainers.

This suite focuses on the semantics that differ from Redis Pub/Sub:
- Every active subscription should receive each newly published message.
- Each subscription should only observe messages published after its listener starts.
"""

import threading
import time
import uuid
from collections.abc import Iterator
from concurrent.futures import ThreadPoolExecutor, as_completed

import pytest
import redis
from testcontainers.redis import RedisContainer

from libs.broadcast_channel.channel import BroadcastChannel, Subscription, Topic
from libs.broadcast_channel.exc import SubscriptionClosedError
from libs.broadcast_channel.redis.streams_channel import StreamsBroadcastChannel


class TestRedisStreamsBroadcastChannelIntegration:
    """Integration tests for Redis Streams broadcast channel with a real Redis instance."""

    @pytest.fixture(scope="class")
    def redis_container(self) -> Iterator[RedisContainer]:
        """Create a Redis container for integration testing."""
        with RedisContainer(image="redis:6-alpine") as container:
            yield container

    @pytest.fixture(scope="class")
    def redis_client(self, redis_container: RedisContainer) -> redis.Redis:
        """Create a Redis client connected to the test container."""
        host = redis_container.get_container_host_ip()
        port = redis_container.get_exposed_port(6379)
        return redis.Redis(host=host, port=port, decode_responses=False)

    @pytest.fixture
    def broadcast_channel(self, redis_client: redis.Redis) -> BroadcastChannel:
        """Create a StreamsBroadcastChannel instance with a real Redis client."""
        return StreamsBroadcastChannel(redis_client)

    @classmethod
    def _get_test_topic_name(cls) -> str:
        return f"test_streams_topic_{uuid.uuid4()}"

    @staticmethod
    def _start_subscription(subscription: Subscription) -> None:
        """Start the background listener and confirm the subscription queue is empty."""
        assert subscription.receive(timeout=0.05) is None

    @staticmethod
    def _receive_message(subscription: Subscription, *, timeout_seconds: float = 2.0) -> bytes:
        """Poll until a message is received or the timeout expires."""
        deadline = time.monotonic() + timeout_seconds
        while time.monotonic() < deadline:
            message = subscription.receive(timeout=0.1)
            if message is not None:
                return message
        pytest.fail("Timed out waiting for a message")

    def test_close_an_active_subscription_should_stop_iteration(self, broadcast_channel: BroadcastChannel) -> None:
        """Closing an active subscription should terminate the iterator cleanly."""
        topic = broadcast_channel.topic(self._get_test_topic_name())
        subscription = topic.subscribe()
        consuming_event = threading.Event()

        def consume() -> list[bytes]:
            messages: list[bytes] = []
            consuming_event.set()
            for message in subscription:
                messages.append(message)
            return messages

        with ThreadPoolExecutor(max_workers=1) as executor:
            consumer_future = executor.submit(consume)
            assert consuming_event.wait(timeout=1.0)
            subscription.close()
            assert consumer_future.result(timeout=2.0) == []

    def test_end_to_end_messaging(self, broadcast_channel: BroadcastChannel) -> None:
        """A producer should publish a message that a live subscription can consume."""
        topic = broadcast_channel.topic(self._get_test_topic_name())
        producer = topic.as_producer()
        subscription = topic.subscribe()
        message = b"hello streams"

        try:
            self._start_subscription(subscription)
            producer.publish(message)

            assert self._receive_message(subscription) == message
            assert subscription.receive(timeout=0.1) is None
        finally:
            subscription.close()

    def test_multiple_subscriptions_each_receive_each_new_message(self, broadcast_channel: BroadcastChannel) -> None:
        """Each active subscription should receive the same newly published message."""
        topic = broadcast_channel.topic(self._get_test_topic_name())
        subscriptions = [topic.subscribe() for _ in range(3)]
        new_message = b"message-visible-to-every-subscriber"

        try:
            for subscription in subscriptions:
                self._start_subscription(subscription)

            topic.publish(new_message)

            for subscription in subscriptions:
                assert self._receive_message(subscription) == new_message
                assert subscription.receive(timeout=0.1) is None
        finally:
            for subscription in subscriptions:
                subscription.close()

    def test_each_subscription_only_receives_messages_published_after_it_starts(
        self,
        broadcast_channel: BroadcastChannel,
    ) -> None:
        """A late subscription should not replay messages that existed before its listener started."""
        topic = broadcast_channel.topic(self._get_test_topic_name())
        first_subscription = topic.subscribe()
        second_subscription = topic.subscribe()
        message_before_any_subscription = b"before-any-subscription"
        message_after_first_subscription = b"after-first-subscription"
        message_after_second_subscription = b"after-second-subscription"

        try:
            topic.publish(message_before_any_subscription)

            self._start_subscription(first_subscription)
            topic.publish(message_after_first_subscription)

            assert self._receive_message(first_subscription) == message_after_first_subscription
            assert first_subscription.receive(timeout=0.1) is None

            self._start_subscription(second_subscription)
            topic.publish(message_after_second_subscription)

            assert self._receive_message(first_subscription) == message_after_second_subscription
            assert self._receive_message(second_subscription) == message_after_second_subscription
            assert first_subscription.receive(timeout=0.1) is None
            assert second_subscription.receive(timeout=0.1) is None
        finally:
            first_subscription.close()
            second_subscription.close()

    def test_topic_isolation(self, broadcast_channel: BroadcastChannel) -> None:
        """Messages from different topics should remain isolated."""
        topic1 = broadcast_channel.topic(self._get_test_topic_name())
        topic2 = broadcast_channel.topic(self._get_test_topic_name())
        message1 = b"message-for-topic-1"
        message2 = b"message-for-topic-2"

        def consume_single_message(topic: Topic) -> bytes:
            subscription = topic.subscribe()
            try:
                self._start_subscription(subscription)
                return self._receive_message(subscription)
            finally:
                subscription.close()

        with ThreadPoolExecutor(max_workers=3) as executor:
            consumer1_future = executor.submit(consume_single_message, topic1)
            consumer2_future = executor.submit(consume_single_message, topic2)
            time.sleep(0.1)
            topic1.publish(message1)
            topic2.publish(message2)

            assert consumer1_future.result(timeout=5.0) == message1
            assert consumer2_future.result(timeout=5.0) == message2

    def test_concurrent_producers_publish_all_messages(self, broadcast_channel: BroadcastChannel) -> None:
        """Concurrent producers should not lose messages for a live subscription."""
        topic = broadcast_channel.topic(self._get_test_topic_name())
        subscription = topic.subscribe()
        producer_count = 4
        messages_per_producer = 4
        expected_total = producer_count * messages_per_producer
        consumer_ready = threading.Event()

        def produce_messages(producer_idx: int) -> set[bytes]:
            producer = topic.as_producer()
            produced: set[bytes] = set()
            for message_idx in range(messages_per_producer):
                payload = f"producer-{producer_idx}-message-{message_idx}".encode()
                produced.add(payload)
                producer.publish(payload)
                time.sleep(0.001)
            return produced

        def consume_messages() -> set[bytes]:
            received: set[bytes] = set()
            try:
                self._start_subscription(subscription)
                consumer_ready.set()
                while len(received) < expected_total:
                    message = subscription.receive(timeout=0.2)
                    if message is not None:
                        received.add(message)
                return received
            finally:
                subscription.close()

        with ThreadPoolExecutor(max_workers=producer_count + 1) as executor:
            consumer_future = executor.submit(consume_messages)
            assert consumer_ready.wait(timeout=2.0)

            producer_futures = [executor.submit(produce_messages, idx) for idx in range(producer_count)]
            expected_messages: set[bytes] = set()
            for future in as_completed(producer_futures, timeout=10.0):
                expected_messages.update(future.result())

            assert consumer_future.result(timeout=10.0) == expected_messages

    def test_receive_raises_subscription_closed_after_close(self, broadcast_channel: BroadcastChannel) -> None:
        """Calling receive on a closed subscription should raise SubscriptionClosedError."""
        topic = broadcast_channel.topic(self._get_test_topic_name())
        subscription = topic.subscribe()

        self._start_subscription(subscription)
        subscription.close()

        with pytest.raises(SubscriptionClosedError):
            subscription.receive(timeout=0.1)
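The suite above pins down the two stream-specific guarantees: fan-out to every live subscription and no replay of history. Both fall out of how Redis Streams are read. A hedged sketch of the underlying loop, using plain redis-py calls rather than the real StreamsBroadcastChannel internals:

import redis

client = redis.Redis()
stream = "test_streams_topic"

def publish(payload: bytes) -> None:
    client.xadd(stream, {"payload": payload})

def listen():
    last_id = "$"  # "$" means the current tail: only entries added after this point are seen
    while True:
        reply = client.xread({stream: last_id}, count=10, block=100)
        for _stream_name, entries in reply or []:
            for entry_id, fields in entries:
                last_id = entry_id  # each reader advances its own cursor, hence every reader gets every message
                yield fields[b"payload"]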
@@ -13,6 +13,7 @@ import pytest
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound

from core.rag.index_processor.constant.index_type import IndexTechniqueType
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import AppDatasetJoin, Dataset, DatasetPermissionEnum
from models.enums import DataSourceType

@@ -74,7 +75,7 @@ class DatasetUpdateDeleteTestDataFactory:
            name=name,
            description="Test description",
            data_source_type=DataSourceType.UPLOAD_FILE,
            indexing_technique="high_quality",
            indexing_technique=IndexTechniqueType.HIGH_QUALITY,
            created_by=created_by,
            permission=permission,
            provider="vendor",

@@ -13,6 +13,7 @@ from uuid import uuid4

import pytest

from core.rag.index_processor.constant.index_type import IndexStructureType
from extensions.storage.storage_type import StorageType
from models import Account
from models.dataset import Dataset, Document

@@ -91,7 +92,7 @@ class DocumentStatusTestDataFactory:
            name=name,
            created_from=DocumentCreatedFrom.WEB,
            created_by=created_by,
            doc_form="text_model",
            doc_form=IndexStructureType.PARAGRAPH_INDEX,
        )
        document.id = document_id
        document.indexing_status = indexing_status

@@ -6,7 +6,7 @@ from sqlalchemy.orm import Session

from constants.model_template import default_app_templates
from models import Account
from models.model import App, Site
from models.model import App, IconType, Site
from services.account_service import AccountService, TenantService
from tests.test_containers_integration_tests.helpers import generate_valid_password

@@ -463,6 +463,109 @@ class TestAppService:
        assert updated_app.tenant_id == app.tenant_id
        assert updated_app.created_by == app.created_by

    def test_update_app_should_preserve_icon_type_when_omitted(
        self, db_session_with_containers: Session, mock_external_service_dependencies
    ):
        """
        Test update_app keeps the persisted icon_type when the update payload omits it.
        """
        fake = Faker()

        account = AccountService.create_account(
            email=fake.email(),
            name=fake.name(),
            interface_language="en-US",
            password=generate_valid_password(fake),
        )
        TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
        tenant = account.current_tenant

        from services.app_service import AppService

        app_service = AppService()
        app = app_service.create_app(
            tenant.id,
            {
                "name": fake.company(),
                "description": fake.text(max_nb_chars=100),
                "mode": "chat",
                "icon_type": "emoji",
                "icon": "🎯",
                "icon_background": "#45B7D1",
            },
            account,
        )

        mock_current_user = create_autospec(Account, instance=True)
        mock_current_user.id = account.id
        mock_current_user.current_tenant_id = account.current_tenant_id

        with patch("services.app_service.current_user", mock_current_user):
            updated_app = app_service.update_app(
                app,
                {
                    "name": "Updated App Name",
                    "description": "Updated app description",
                    "icon_type": None,
                    "icon": "🔄",
                    "icon_background": "#FF8C42",
                    "use_icon_as_answer_icon": True,
                },
            )

        assert updated_app.icon_type == IconType.EMOJI

    def test_update_app_should_reject_empty_icon_type(
        self, db_session_with_containers: Session, mock_external_service_dependencies
    ):
        """
        Test update_app rejects an explicit empty icon_type.
        """
        fake = Faker()

        account = AccountService.create_account(
            email=fake.email(),
            name=fake.name(),
            interface_language="en-US",
            password=generate_valid_password(fake),
        )
        TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
        tenant = account.current_tenant

        from services.app_service import AppService

        app_service = AppService()
        app = app_service.create_app(
            tenant.id,
            {
                "name": fake.company(),
                "description": fake.text(max_nb_chars=100),
                "mode": "chat",
                "icon_type": "emoji",
                "icon": "🎯",
                "icon_background": "#45B7D1",
            },
            account,
        )

        mock_current_user = create_autospec(Account, instance=True)
        mock_current_user.id = account.id
        mock_current_user.current_tenant_id = account.current_tenant_id

        with patch("services.app_service.current_user", mock_current_user):
            with pytest.raises(ValueError):
                app_service.update_app(
                    app,
                    {
                        "name": "Updated App Name",
                        "description": "Updated app description",
                        "icon_type": "",
                        "icon": "🔄",
                        "icon_background": "#FF8C42",
                        "use_icon_as_answer_icon": True,
                    },
                )

    def test_update_app_name_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
        """
        Test successful app name update.

@@ -1142,3 +1245,51 @@ class TestAppService:
        assert paginated_apps is not None
        assert paginated_apps.total == 1
        assert all("50%" in app.name for app in paginated_apps.items)

    def test_get_app_code_by_id_not_found(
        self, db_session_with_containers: Session, mock_external_service_dependencies
    ):
        """Test get_app_code_by_id raises ValueError when site is missing."""
        from uuid import uuid4

        from services.app_service import AppService

        with pytest.raises(ValueError, match="not found"):
            AppService.get_app_code_by_id(str(uuid4()))

    def test_get_app_id_by_code_not_found(
        self, db_session_with_containers: Session, mock_external_service_dependencies
    ):
        """Test get_app_id_by_code raises ValueError when code does not exist."""
        from services.app_service import AppService

        with pytest.raises(ValueError, match="not found"):
            AppService.get_app_id_by_code("nonexistent-code")

    def test_get_app_meta_returns_empty_when_workflow_missing(
        self, db_session_with_containers: Session, mock_external_service_dependencies
    ):
        """Test get_app_meta returns empty tool_icons when workflow is None."""
        from types import SimpleNamespace

        from services.app_service import AppService

        app_service = AppService()
        workflow_app = SimpleNamespace(mode="workflow", workflow=None)

        meta = app_service.get_app_meta(workflow_app)
        assert meta == {"tool_icons": {}}

    def test_get_app_meta_returns_empty_when_model_config_missing(
        self, db_session_with_containers: Session, mock_external_service_dependencies
    ):
        """Test get_app_meta returns empty tool_icons when app_model_config is None."""
        from types import SimpleNamespace

        from services.app_service import AppService

        app_service = AppService()
        chat_app = SimpleNamespace(mode="chat", app_model_config=None)

        meta = app_service.get_app_meta(chat_app)
        assert meta == {"tool_icons": {}}
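The two icon tests pin down an asymmetry: a missing or None icon_type preserves the stored value, while an explicit empty string is rejected. A sketch of guard logic consistent with both assertions; the helper name and structure are assumptions, not the service's actual code:

def _resolve_icon_type(app, args: dict) -> str:
    if args.get("icon_type") is None:
        return app.icon_type  # omitted or None: keep the persisted value
    if not args["icon_type"]:
        raise ValueError("icon_type must not be an empty string")
    return args["icon_type"]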
@ -9,6 +9,7 @@ from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from core.rag.index_processor.constant.index_type import IndexTechniqueType
|
||||
from extensions.ext_database import db
|
||||
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.dataset import (
|
||||
@ -69,7 +70,7 @@ class DatasetPermissionTestDataFactory:
|
||||
name=name,
|
||||
description="desc",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
indexing_technique="high_quality",
|
||||
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
|
||||
created_by=created_by,
|
||||
permission=permission,
|
||||
provider="vendor",
|
||||
|
||||
@ -11,6 +11,7 @@ from uuid import uuid4
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelType
|
||||
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
@ -62,7 +63,7 @@ class DatasetServiceIntegrationDataFactory:
|
||||
name: str = "Test Dataset",
|
||||
description: str | None = "Test description",
|
||||
provider: str = "vendor",
|
||||
indexing_technique: str | None = "high_quality",
|
||||
indexing_technique: str | None = IndexTechniqueType.HIGH_QUALITY,
|
||||
permission: str = DatasetPermissionEnum.ONLY_ME,
|
||||
retrieval_model: dict | None = None,
|
||||
embedding_model_provider: str | None = None,
|
||||
@ -106,7 +107,7 @@ class DatasetServiceIntegrationDataFactory:
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=created_by,
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
doc_form="text_model",
|
||||
doc_form=IndexStructureType.PARAGRAPH_INDEX,
|
||||
)
|
||||
db_session_with_containers.add(document)
|
||||
db_session_with_containers.flush()
|
||||
@ -156,13 +157,13 @@ class TestDatasetServiceCreateDataset:
|
||||
tenant_id=tenant.id,
|
||||
name="Economy Dataset",
|
||||
description=None,
|
||||
indexing_technique="economy",
|
||||
indexing_technique=IndexTechniqueType.ECONOMY,
|
||||
account=account,
|
||||
)
|
||||
|
||||
# Assert
|
||||
db_session_with_containers.refresh(result)
|
||||
assert result.indexing_technique == "economy"
|
||||
assert result.indexing_technique == IndexTechniqueType.ECONOMY
|
||||
assert result.embedding_model_provider is None
|
||||
assert result.embedding_model is None
|
||||
|
||||
@ -180,13 +181,13 @@ class TestDatasetServiceCreateDataset:
|
||||
tenant_id=tenant.id,
|
||||
name="High Quality Dataset",
|
||||
description=None,
|
||||
indexing_technique="high_quality",
|
||||
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
|
||||
account=account,
|
||||
)
|
||||
|
||||
# Assert
|
||||
db_session_with_containers.refresh(result)
|
||||
assert result.indexing_technique == "high_quality"
|
||||
assert result.indexing_technique == IndexTechniqueType.HIGH_QUALITY
|
||||
assert result.embedding_model_provider == embedding_model.provider
|
||||
assert result.embedding_model == embedding_model.model_name
|
||||
mock_model_manager.return_value.get_default_model_instance.assert_called_once_with(
|
||||
@ -272,7 +273,7 @@ class TestDatasetServiceCreateDataset:
|
||||
tenant_id=tenant.id,
|
||||
name="Dataset With Reranking",
|
||||
description=None,
|
||||
indexing_technique="high_quality",
|
||||
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
|
||||
account=account,
|
||||
retrieval_model=retrieval_model,
|
||||
)
|
||||
@ -305,7 +306,7 @@ class TestDatasetServiceCreateDataset:
|
||||
tenant_id=tenant.id,
|
||||
name="Custom Embedding Dataset",
|
||||
description=None,
|
||||
indexing_technique="high_quality",
|
||||
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
|
||||
account=account,
|
||||
embedding_model_provider=embedding_provider,
|
||||
embedding_model_name=embedding_model_name,
|
||||
@ -313,7 +314,7 @@ class TestDatasetServiceCreateDataset:
|
||||
|
||||
# Assert
|
||||
db_session_with_containers.refresh(result)
|
||||
assert result.indexing_technique == "high_quality"
|
||||
assert result.indexing_technique == IndexTechniqueType.HIGH_QUALITY
|
||||
assert result.embedding_model_provider == embedding_provider
|
||||
assert result.embedding_model == embedding_model_name
|
||||
mock_check_embedding.assert_called_once_with(tenant.id, embedding_provider, embedding_model_name)
|
||||
@ -588,7 +589,7 @@ class TestDatasetServiceUpdateAndDeleteDataset:
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant.id,
created_by=account.id,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
chunk_structure="text_model",
)
DatasetServiceIntegrationDataFactory.create_document(
@@ -684,14 +685,14 @@ class TestDatasetServiceRetrievalConfiguration:
db_session_with_containers,
tenant_id=tenant.id,
created_by=account.id,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
retrieval_model={"search_method": "semantic_search", "top_k": 2, "score_threshold": 0.0},
embedding_model_provider="openai",
embedding_model="text-embedding-ada-002",
collection_binding_id=str(uuid4()),
)
update_data = {
"indexing_technique": "high_quality",
"indexing_technique": IndexTechniqueType.HIGH_QUALITY,
"retrieval_model": {
"search_method": "full_text_search",
"top_k": 10,
@@ -706,3 +707,104 @@ class TestDatasetServiceRetrievalConfiguration:
db_session_with_containers.refresh(dataset)
assert result.id == dataset.id
assert dataset.retrieval_model == update_data["retrieval_model"]


class TestDocumentServicePauseRecoverRetry:
"""Tests for pause/recover/retry orchestration using real DB and Redis."""

def _create_indexing_document(self, db_session_with_containers, indexing_status="indexing"):
factory = DatasetServiceIntegrationDataFactory
account, tenant = factory.create_account_with_tenant(db_session_with_containers)
dataset = factory.create_dataset(db_session_with_containers, tenant.id, account.id)
doc = factory.create_document(db_session_with_containers, dataset, account.id)
doc.indexing_status = indexing_status
db_session_with_containers.commit()
return doc, account

def test_pause_document_success(self, db_session_with_containers):
from extensions.ext_redis import redis_client
from services.dataset_service import DocumentService

doc, account = self._create_indexing_document(db_session_with_containers, indexing_status="indexing")

with patch("services.dataset_service.current_user") as mock_user:
mock_user.id = account.id
DocumentService.pause_document(doc)

db_session_with_containers.refresh(doc)
assert doc.is_paused is True
assert doc.paused_by == account.id
assert doc.paused_at is not None

cache_key = f"document_{doc.id}_is_paused"
assert redis_client.get(cache_key) is not None
redis_client.delete(cache_key)

def test_pause_document_invalid_status_error(self, db_session_with_containers):
from services.dataset_service import DocumentService
from services.errors.document import DocumentIndexingError

doc, account = self._create_indexing_document(db_session_with_containers, indexing_status="completed")

with patch("services.dataset_service.current_user") as mock_user:
mock_user.id = account.id
with pytest.raises(DocumentIndexingError):
DocumentService.pause_document(doc)

def test_recover_document_success(self, db_session_with_containers):
from extensions.ext_redis import redis_client
from services.dataset_service import DocumentService

doc, account = self._create_indexing_document(db_session_with_containers, indexing_status="indexing")

# Pause first
with patch("services.dataset_service.current_user") as mock_user:
mock_user.id = account.id
DocumentService.pause_document(doc)

# Recover
with patch("services.dataset_service.recover_document_indexing_task") as recover_task:
DocumentService.recover_document(doc)

db_session_with_containers.refresh(doc)
assert doc.is_paused is False
assert doc.paused_by is None
assert doc.paused_at is None

cache_key = f"document_{doc.id}_is_paused"
assert redis_client.get(cache_key) is None
recover_task.delay.assert_called_once_with(doc.dataset_id, doc.id)

def test_retry_document_indexing_success(self, db_session_with_containers):
from extensions.ext_redis import redis_client
from services.dataset_service import DocumentService

factory = DatasetServiceIntegrationDataFactory
account, tenant = factory.create_account_with_tenant(db_session_with_containers)
dataset = factory.create_dataset(db_session_with_containers, tenant.id, account.id)
doc1 = factory.create_document(db_session_with_containers, dataset, account.id, name="doc1.txt")
doc2 = factory.create_document(db_session_with_containers, dataset, account.id, name="doc2.txt")
doc2.position = 2
doc1.indexing_status = "error"
doc2.indexing_status = "error"
db_session_with_containers.commit()

with (
patch("services.dataset_service.current_user") as mock_user,
patch("services.dataset_service.retry_document_indexing_task") as retry_task,
):
mock_user.id = account.id
DocumentService.retry_document(dataset.id, [doc1, doc2])

db_session_with_containers.refresh(doc1)
db_session_with_containers.refresh(doc2)
assert doc1.indexing_status == "waiting"
assert doc2.indexing_status == "waiting"

# Verify redis keys were set
assert redis_client.get(f"document_{doc1.id}_is_retried") is not None
assert redis_client.get(f"document_{doc2.id}_is_retried") is not None
retry_task.delay.assert_called_once_with(dataset.id, [doc1.id, doc2.id], account.id)

# Cleanup
redis_client.delete(f"document_{doc1.id}_is_retried", f"document_{doc2.id}_is_retried")

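A minimal sketch of the service-side contract these pause/recover tests pin down, assuming only what the assertions check (the Redis key name, the cleared fields, and the recover task signature). The pausable-status guard and the TTL here are illustrative assumptions, not the actual DocumentService code:

from datetime import UTC, datetime

from extensions.ext_redis import redis_client
from services.errors.document import DocumentIndexingError

PAUSE_FLAG_TTL = 600  # assumption: the real TTL is not visible in this diff


def pause_document(doc, current_user):
    # Guard list is an assumption; the tests only show "indexing" passing
    # and "completed" raising DocumentIndexingError.
    if doc.indexing_status not in ("waiting", "parsing", "indexing"):
        raise DocumentIndexingError()
    doc.is_paused = True
    doc.paused_by = current_user.id
    doc.paused_at = datetime.now(UTC)
    # Tests only assert the key exists, so any non-empty value works here.
    redis_client.setex(f"document_{doc.id}_is_paused", PAUSE_FLAG_TTL, "1")


def recover_document(doc):
    doc.is_paused = False
    doc.paused_by = None
    doc.paused_at = None
    redis_client.delete(f"document_{doc.id}_is_paused")
    # Signature matches recover_task.delay.assert_called_once_with(...) above.
    recover_document_indexing_task.delay(doc.dataset_id, doc.id)

Keeping the flag in Redis lets indexing workers poll a cheap key instead of re-reading the document row on every step.
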
@@ -13,6 +13,7 @@ from uuid import uuid4
import pytest
from sqlalchemy.orm import Session

from core.rag.index_processor.constant.index_type import IndexStructureType
from models.dataset import Dataset, Document
from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus
from services.dataset_service import DocumentService
@@ -79,7 +80,7 @@ class DocumentBatchUpdateIntegrationDataFactory:
name=name,
created_from=DocumentCreatedFrom.WEB,
created_by=created_by or str(uuid4()),
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
)
document.id = document_id or str(uuid4())
document.enabled = enabled

@@ -3,6 +3,7 @@
from unittest.mock import patch
from uuid import uuid4

from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document
from models.enums import DataSourceType, DocumentCreatedFrom
@@ -78,7 +79,7 @@ class DatasetDeleteIntegrationDataFactory:
tenant_id: str,
dataset_id: str,
created_by: str,
doc_form: str = "text_model",
doc_form: str = IndexStructureType.PARAGRAPH_INDEX,
) -> Document:
"""Persist a document so dataset.doc_form resolves through the real document path."""
document = Document(
@@ -108,7 +109,7 @@ class TestDatasetServiceDeleteDataset:
db_session_with_containers,
tenant_id=tenant.id,
created_by=owner.id,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
chunk_structure=None,
index_struct='{"type": "paragraph"}',
collection_binding_id=str(uuid4()),
@@ -119,7 +120,7 @@ class TestDatasetServiceDeleteDataset:
tenant_id=tenant.id,
dataset_id=dataset.id,
created_by=owner.id,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
)

# Act
@@ -207,7 +208,7 @@ class TestDatasetServiceDeleteDataset:
db_session_with_containers,
tenant_id=tenant.id,
created_by=owner.id,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
chunk_structure=None,
index_struct='{"type": "paragraph"}',
collection_binding_id=str(uuid4()),

@@ -12,6 +12,7 @@ from uuid import uuid4

from sqlalchemy.orm import Session

from core.rag.index_processor.constant.index_type import IndexTechniqueType
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, DatasetPermissionEnum, Document, DocumentSegment
from models.enums import DataSourceType, DocumentCreatedFrom
@@ -64,7 +65,7 @@ class SegmentServiceTestDataFactory:
name=f"Test Dataset {uuid4()}",
description="Test description",
data_source_type=DataSourceType.UPLOAD_FILE,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
created_by=created_by,
permission=DatasetPermissionEnum.ONLY_ME,
provider="vendor",

@@ -15,6 +15,7 @@ from uuid import uuid4

from sqlalchemy.orm import Session

from core.rag.index_processor.constant.index_type import IndexTechniqueType
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import (
AppDatasetJoin,
@@ -102,7 +103,7 @@ class DatasetRetrievalTestDataFactory:
name=name,
description="desc",
data_source_type=DataSourceType.UPLOAD_FILE,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
created_by=created_by,
permission=permission,
provider="vendor",

@@ -4,6 +4,7 @@ from uuid import uuid4
import pytest
from sqlalchemy.orm import Session

from core.rag.index_processor.constant.index_type import IndexTechniqueType
from dify_graph.model_runtime.entities.model_entities import ModelType
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, ExternalKnowledgeBindings
@@ -53,7 +54,7 @@ class DatasetUpdateTestDataFactory:
provider: str = "vendor",
name: str = "old_name",
description: str = "old_description",
indexing_technique: str = "high_quality",
indexing_technique: str = IndexTechniqueType.HIGH_QUALITY,
retrieval_model: str = "old_model",
permission: str = "only_me",
embedding_model_provider: str | None = None,
@@ -241,7 +242,7 @@ class TestDatasetServiceUpdateDataset:
tenant_id=tenant.id,
created_by=user.id,
provider="vendor",
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
embedding_model_provider="openai",
embedding_model="text-embedding-ada-002",
collection_binding_id=existing_binding_id,
@@ -250,7 +251,7 @@ class TestDatasetServiceUpdateDataset:
update_data = {
"name": "new_name",
"description": "new_description",
"indexing_technique": "high_quality",
"indexing_technique": IndexTechniqueType.HIGH_QUALITY,
"retrieval_model": "new_model",
"embedding_model_provider": "openai",
"embedding_model": "text-embedding-ada-002",
@@ -261,7 +262,7 @@ class TestDatasetServiceUpdateDataset:

assert dataset.name == "new_name"
assert dataset.description == "new_description"
assert dataset.indexing_technique == "high_quality"
assert dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY
assert dataset.retrieval_model == "new_model"
assert dataset.embedding_model_provider == "openai"
assert dataset.embedding_model == "text-embedding-ada-002"
@@ -276,7 +277,7 @@ class TestDatasetServiceUpdateDataset:
tenant_id=tenant.id,
created_by=user.id,
provider="vendor",
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
embedding_model_provider="openai",
embedding_model="text-embedding-ada-002",
collection_binding_id=existing_binding_id,
@@ -285,7 +286,7 @@ class TestDatasetServiceUpdateDataset:
update_data = {
"name": "new_name",
"description": None,
"indexing_technique": "high_quality",
"indexing_technique": IndexTechniqueType.HIGH_QUALITY,
"retrieval_model": "new_model",
"embedding_model_provider": None,
"embedding_model": None,
@@ -312,14 +313,14 @@ class TestDatasetServiceUpdateDataset:
tenant_id=tenant.id,
created_by=user.id,
provider="vendor",
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
embedding_model_provider="openai",
embedding_model="text-embedding-ada-002",
collection_binding_id=existing_binding_id,
)

update_data = {
"indexing_technique": "economy",
"indexing_technique": IndexTechniqueType.ECONOMY,
"retrieval_model": "new_model",
}

@@ -328,7 +329,7 @@ class TestDatasetServiceUpdateDataset:
mock_task.delay.assert_called_once_with(dataset.id, "remove")

db_session_with_containers.refresh(dataset)
assert dataset.indexing_technique == "economy"
assert dataset.indexing_technique == IndexTechniqueType.ECONOMY
assert dataset.embedding_model is None
assert dataset.embedding_model_provider is None
assert dataset.collection_binding_id is None
@@ -343,7 +344,7 @@ class TestDatasetServiceUpdateDataset:
tenant_id=tenant.id,
created_by=user.id,
provider="vendor",
indexing_technique="economy",
indexing_technique=IndexTechniqueType.ECONOMY,
)

embedding_model = Mock()
@@ -354,7 +355,7 @@ class TestDatasetServiceUpdateDataset:
binding.id = str(uuid4())

update_data = {
"indexing_technique": "high_quality",
"indexing_technique": IndexTechniqueType.HIGH_QUALITY,
"embedding_model_provider": "openai",
"embedding_model": "text-embedding-ada-002",
"retrieval_model": "new_model",
@@ -383,7 +384,7 @@ class TestDatasetServiceUpdateDataset:
mock_task.delay.assert_called_once_with(dataset.id, "add")

db_session_with_containers.refresh(dataset)
assert dataset.indexing_technique == "high_quality"
assert dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY
assert dataset.embedding_model == "text-embedding-ada-002"
assert dataset.embedding_model_provider == "openai"
assert dataset.collection_binding_id == binding.id
@@ -403,7 +404,7 @@ class TestDatasetServiceUpdateDataset:
tenant_id=tenant.id,
created_by=user.id,
provider="vendor",
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
embedding_model_provider="openai",
embedding_model="text-embedding-ada-002",
collection_binding_id=existing_binding_id,
@@ -411,7 +412,7 @@ class TestDatasetServiceUpdateDataset:

update_data = {
"name": "new_name",
"indexing_technique": "high_quality",
"indexing_technique": IndexTechniqueType.HIGH_QUALITY,
"retrieval_model": "new_model",
}

@@ -419,7 +420,7 @@ class TestDatasetServiceUpdateDataset:
db_session_with_containers.refresh(dataset)

assert dataset.name == "new_name"
assert dataset.indexing_technique == "high_quality"
assert dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY
assert dataset.embedding_model_provider == "openai"
assert dataset.embedding_model == "text-embedding-ada-002"
assert dataset.collection_binding_id == existing_binding_id
@@ -435,7 +436,7 @@ class TestDatasetServiceUpdateDataset:
tenant_id=tenant.id,
created_by=user.id,
provider="vendor",
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
embedding_model_provider="openai",
embedding_model="text-embedding-ada-002",
collection_binding_id=existing_binding_id,
@@ -449,7 +450,7 @@ class TestDatasetServiceUpdateDataset:
binding.id = str(uuid4())

update_data = {
"indexing_technique": "high_quality",
"indexing_technique": IndexTechniqueType.HIGH_QUALITY,
"embedding_model_provider": "openai",
"embedding_model": "text-embedding-3-small",
"retrieval_model": "new_model",
@@ -531,11 +532,11 @@ class TestDatasetServiceUpdateDataset:
tenant_id=tenant.id,
created_by=user.id,
provider="vendor",
indexing_technique="economy",
indexing_technique=IndexTechniqueType.ECONOMY,
)

update_data = {
"indexing_technique": "high_quality",
"indexing_technique": IndexTechniqueType.HIGH_QUALITY,
"embedding_model_provider": "invalid_provider",
"embedding_model": "invalid_model",
"retrieval_model": "new_model",

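Throughout these hunks the bare strings are swapped for IndexTechniqueType members, yet the assertions still compare them against values read back from the database. That only holds if the enum is string-valued. A minimal sketch, with member values taken from the old/new pairs visible in this diff (any member not shown in the diff is omitted, and the real class may derive differently):

from enum import StrEnum


class IndexTechniqueType(StrEnum):
    HIGH_QUALITY = "high_quality"
    ECONOMY = "economy"


# A str-valued enum compares equal to its raw string, so
# dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY passes even
# when SQLAlchemy returns the plain column value "high_quality".
assert IndexTechniqueType.HIGH_QUALITY == "high_quality"

The same reasoning covers IndexStructureType, where the pairs show PARAGRAPH_INDEX replacing "text_model" and QA_INDEX replacing "qa_model".
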
@@ -141,3 +141,73 @@ class TestArchivedWorkflowRunDeletion:
db_session_with_containers.expunge_all()
deleted_run = db_session_with_containers.get(WorkflowRun, run_id)
assert deleted_run is None

def test_delete_run_dry_run(self, db_session_with_containers):
"""Dry run should return success without actually deleting."""
tenant_id = str(uuid4())
run = self._create_workflow_run(
db_session_with_containers,
tenant_id=tenant_id,
created_at=datetime.now(UTC),
)
run_id = run.id
deleter = ArchivedWorkflowRunDeletion(dry_run=True)

result = deleter._delete_run(run)

assert result.success is True
assert result.run_id == run_id
# Run should still exist because it's a dry run
db_session_with_containers.expire_all()
assert db_session_with_containers.get(WorkflowRun, run_id) is not None

def test_delete_run_exception_returns_error(self, db_session_with_containers):
"""Exception during deletion should return failure result."""
from unittest.mock import MagicMock, patch

tenant_id = str(uuid4())
run = self._create_workflow_run(
db_session_with_containers,
tenant_id=tenant_id,
created_at=datetime.now(UTC),
)
deleter = ArchivedWorkflowRunDeletion(dry_run=False)

with patch.object(deleter, "_get_workflow_run_repo") as mock_get_repo:
mock_repo = MagicMock()
mock_get_repo.return_value = mock_repo
mock_repo.delete_runs_with_related.side_effect = Exception("Database error")

result = deleter._delete_run(run)

assert result.success is False
assert result.error == "Database error"

def test_delete_by_run_id_success(self, db_session_with_containers):
"""Successfully delete an archived workflow run by ID."""
tenant_id = str(uuid4())
base_time = datetime.now(UTC)
run = self._create_workflow_run(
db_session_with_containers,
tenant_id=tenant_id,
created_at=base_time,
)
self._create_archive_log(db_session_with_containers, run=run)
run_id = run.id

deleter = ArchivedWorkflowRunDeletion()
result = deleter.delete_by_run_id(run_id)

assert result.success is True
db_session_with_containers.expunge_all()
assert db_session_with_containers.get(WorkflowRun, run_id) is None

def test_get_workflow_run_repo_caches_instance(self, db_session_with_containers):
"""_get_workflow_run_repo should return a cached repo on subsequent calls."""
deleter = ArchivedWorkflowRunDeletion()

repo1 = deleter._get_workflow_run_repo()
repo2 = deleter._get_workflow_run_repo()

assert repo1 is repo2
assert deleter.workflow_run_repo is repo1

@@ -3,6 +3,7 @@ from uuid import uuid4

from sqlalchemy import select

from core.rag.index_processor.constant.index_type import IndexStructureType
from models.dataset import Dataset, Document
from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus
from services.dataset_service import DocumentService
@@ -42,7 +43,7 @@ def _create_document(
name=f"doc-{uuid4()}",
created_from=DocumentCreatedFrom.WEB,
created_by=str(uuid4()),
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
)
document.id = str(uuid4())
document.indexing_status = indexing_status

@@ -7,6 +7,7 @@ from uuid import uuid4

import pytest

from core.rag.index_processor.constant.index_type import IndexStructureType
from extensions.storage.storage_type import StorageType
from models import Account
from models.dataset import Dataset, Document
@@ -69,7 +70,7 @@ def make_document(
name=name,
created_from=DocumentCreatedFrom.WEB,
created_by=str(uuid4()),
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
)
doc.id = document_id
doc.indexing_status = "completed"

@@ -5,6 +5,7 @@ from faker import Faker
from sqlalchemy.orm import Session

from core.rag.index_processor.constant.built_in_field import BuiltInField
from core.rag.index_processor.constant.index_type import IndexStructureType
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, DatasetMetadata, DatasetMetadataBinding, Document
from models.enums import DatasetMetadataType, DataSourceType, DocumentCreatedFrom
@@ -139,7 +140,7 @@ class TestMetadataService:
name=fake.file_name(),
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
doc_language="en",
)


@@ -7,6 +7,7 @@ from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound

from core.rag.index_processor.constant.index_type import IndexTechniqueType
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset
from models.enums import DataSourceType, TagType
@@ -102,7 +103,7 @@ class TestTagService:
provider="vendor",
permission="only_me",
data_source_type=DataSourceType.UPLOAD_FILE,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
tenant_id=tenant_id,
created_by=mock_external_service_dependencies["current_user"].id,
)

@@ -1,6 +1,9 @@
from __future__ import annotations

import json
import uuid
from datetime import UTC, datetime, timedelta
from types import SimpleNamespace
from unittest.mock import patch

import pytest
@@ -8,14 +11,14 @@ from faker import Faker
from sqlalchemy.orm import Session

from dify_graph.entities.workflow_execution import WorkflowExecutionStatus
from models import EndUser, Workflow, WorkflowAppLog, WorkflowRun
from models.enums import CreatorUserRole
from models import EndUser, Workflow, WorkflowAppLog, WorkflowArchiveLog, WorkflowRun
from models.enums import AppTriggerType, CreatorUserRole, WorkflowRunTriggeredFrom
from models.workflow import WorkflowAppLogCreatedFrom
from services.account_service import AccountService, TenantService

# Delay import of AppService to avoid circular dependency
# from services.app_service import AppService
from services.workflow_app_service import WorkflowAppService
from services.workflow_app_service import LogView, WorkflowAppService
from tests.test_containers_integration_tests.helpers import generate_valid_password


@@ -1525,3 +1528,168 @@ class TestWorkflowAppService:

# Should not find tenant2's data when searching from tenant1's context
assert result_cross_tenant["total"] == 0

def test_get_paginate_workflow_app_logs_raises_when_account_filter_email_not_found(
self, db_session_with_containers, mock_external_service_dependencies
):
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
service = WorkflowAppService()

with pytest.raises(ValueError, match="Account not found: nonexistent@example.com"):
service.get_paginate_workflow_app_logs(
session=db_session_with_containers,
app_model=app,
created_by_account="nonexistent@example.com",
)

def test_get_paginate_workflow_app_logs_filters_by_account(
self, db_session_with_containers, mock_external_service_dependencies
):
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
service = WorkflowAppService()
workflow, workflow_run, _log = self._create_test_workflow_data(db_session_with_containers, app, account)

result = service.get_paginate_workflow_app_logs(
session=db_session_with_containers,
app_model=app,
created_by_account=account.email,
)

assert result["total"] >= 0
assert isinstance(result["data"], list)

def test_get_paginate_workflow_archive_logs(self, db_session_with_containers, mock_external_service_dependencies):
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
service = WorkflowAppService()

end_user = EndUser(
tenant_id=app.tenant_id,
app_id=app.id,
type="browser",
is_anonymous=False,
session_id="session-1",
)
db_session_with_containers.add(end_user)
db_session_with_containers.commit()

now = datetime.now(UTC)
archive_defaults = {
"workflow_id": str(uuid.uuid4()),
"run_version": "1.0.0",
"run_status": WorkflowExecutionStatus.SUCCEEDED,
"run_triggered_from": WorkflowRunTriggeredFrom.APP_RUN,
"run_error": None,
"run_elapsed_time": 1.0,
"run_total_tokens": 0,
"run_total_steps": 0,
"run_created_at": now,
"run_finished_at": now,
"run_exceptions_count": 0,
"trigger_metadata": '{"type":"trigger-webhook"}',
"log_created_at": now,
"log_created_from": WorkflowAppLogCreatedFrom.SERVICE_API,
}
archive_account = WorkflowArchiveLog(
tenant_id=app.tenant_id,
app_id=app.id,
workflow_run_id=str(uuid.uuid4()),
log_id=str(uuid.uuid4()),
created_by=account.id,
created_by_role=CreatorUserRole.ACCOUNT,
**archive_defaults,
)
archive_end_user = WorkflowArchiveLog(
tenant_id=app.tenant_id,
app_id=app.id,
workflow_run_id=str(uuid.uuid4()),
log_id=str(uuid.uuid4()),
created_by=end_user.id,
created_by_role=CreatorUserRole.END_USER,
**archive_defaults,
)
db_session_with_containers.add_all([archive_account, archive_end_user])
db_session_with_containers.commit()

result = service.get_paginate_workflow_archive_logs(
session=db_session_with_containers,
app_model=app,
page=1,
limit=20,
)

assert result["total"] == 2
assert len(result["data"]) == 2
account_item = next(d for d in result["data"] if d["created_by_account"] is not None)
end_user_item = next(d for d in result["data"] if d["created_by_end_user"] is not None)
assert account_item["created_by_account"].id == account.id
assert end_user_item["created_by_end_user"].id == end_user.id


class TestLogView:
def test_details_and_proxy_attributes(self):
log = SimpleNamespace(id="log-1", status="succeeded")
view = LogView(log=log, details={"trigger_metadata": {"type": "plugin"}})

assert view.details == {"trigger_metadata": {"type": "plugin"}}
assert view.status == "succeeded"


class TestHandleTriggerMetadata:
def test_returns_empty_dict_when_metadata_missing(self):
service = WorkflowAppService()
assert service.handle_trigger_metadata("tenant-1", None) == {}

def test_enriches_plugin_icons(self):
service = WorkflowAppService()
meta = {
"type": AppTriggerType.TRIGGER_PLUGIN.value,
"icon_filename": "light.png",
"icon_dark_filename": "dark.png",
}
with patch(
"services.workflow_app_service.PluginService.get_plugin_icon_url",
side_effect=["https://cdn/light.png", "https://cdn/dark.png"],
) as mock_icon:
result = service.handle_trigger_metadata("tenant-1", json.dumps(meta))

assert result["icon"] == "https://cdn/light.png"
assert result["icon_dark"] == "https://cdn/dark.png"
assert mock_icon.call_count == 2

def test_non_plugin_metadata_without_icon_lookup(self):
service = WorkflowAppService()
meta = {"type": AppTriggerType.TRIGGER_WEBHOOK.value}
with patch("services.workflow_app_service.PluginService.get_plugin_icon_url") as mock_icon:
result = service.handle_trigger_metadata("tenant-1", json.dumps(meta))

assert result["type"] == AppTriggerType.TRIGGER_WEBHOOK.value
mock_icon.assert_not_called()


class TestSafeJsonLoads:
@pytest.mark.parametrize(
("value", "expected"),
[
(None, None),
("", None),
('{"k":"v"}', {"k": "v"}),
("not-json", None),
({"raw": True}, {"raw": True}),
],
)
def test_handles_various_inputs(self, value, expected):
assert WorkflowAppService._safe_json_loads(value) == expected


class TestSafeParseUuid:
def test_returns_none_for_short_or_invalid_values(self):
service = WorkflowAppService()
assert service._safe_parse_uuid("short") is None
assert service._safe_parse_uuid("x" * 40) is None

def test_returns_uuid_for_valid_string(self):
service = WorkflowAppService()
raw = str(uuid.uuid4())
result = service._safe_parse_uuid(raw)
assert result is not None
assert str(result) == raw

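The parametrized table in TestSafeJsonLoads fully determines the helper's behaviour. A sketch consistent with it (in the real code this is a static method on WorkflowAppService and may differ in detail):

import json


def _safe_json_loads(value):
    if isinstance(value, dict):
        return value  # dict passthrough: {"raw": True} -> {"raw": True}
    if not value:
        return None  # covers both None and ""
    try:
        return json.loads(value)  # '{"k":"v"}' -> {"k": "v"}
    except (TypeError, ValueError):
        return None  # "not-json" -> None
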
@@ -1,12 +1,24 @@
from __future__ import annotations

from unittest.mock import Mock, patch

import pytest
from faker import Faker
from sqlalchemy.orm import Session

from core.tools.entities.api_entities import ToolProviderApiEntity
from core.tools.__base.tool import Tool
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.entities.tool_entities import (
ApiProviderSchemaType,
ToolDescription,
ToolEntity,
ToolIdentity,
ToolParameter,
ToolProviderEntity,
ToolProviderIdentity,
ToolProviderType,
)
from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider
from services.plugin.plugin_service import PluginService
from services.tools.tools_transform_service import ToolTransformService
@@ -52,7 +64,7 @@ class TestToolTransformService:
user_id="test_user_id",
credentials_str='{"auth_type": "api_key_header", "api_key": "test_key"}',
schema="{}",
schema_type_str="openapi",
schema_type_str=ApiProviderSchemaType.OPENAPI,
tools_str="[]",
)
elif provider_type == "builtin":
@@ -659,7 +671,7 @@ class TestToolTransformService:
user_id=fake.uuid4(),
credentials_str='{"auth_type": "api_key_header", "api_key": "test_key"}',
schema="{}",
schema_type_str="openapi",
schema_type_str=ApiProviderSchemaType.OPENAPI,
tools_str="[]",
)

@@ -695,7 +707,7 @@ class TestToolTransformService:
user_id=fake.uuid4(),
credentials_str='{"auth_type": "api_key_query", "api_key": "test_key"}',
schema="{}",
schema_type_str="openapi",
schema_type_str=ApiProviderSchemaType.OPENAPI,
tools_str="[]",
)

@@ -731,7 +743,7 @@ class TestToolTransformService:
user_id=fake.uuid4(),
credentials_str='{"auth_type": "api_key", "api_key": "test_key"}',
schema="{}",
schema_type_str="openapi",
schema_type_str=ApiProviderSchemaType.OPENAPI,
tools_str="[]",
)

@@ -786,3 +798,192 @@ class TestToolTransformService:
assert result is not None
assert result == mock_controller
mock_from_db.assert_called_once_with(provider)


def _mock_tool(*, base_params, runtime_params):
"""Helper to build a Mock tool with real entity objects.

Tool is abstract and requires runtime behaviour (fork_tool_runtime,
get_runtime_parameters), so it stays as a Mock. Everything else uses
real Pydantic instances.
"""
entity = ToolEntity(
identity=ToolIdentity(
author="test_author",
name="test_tool",
label=I18nObject(en_US="Test Tool"),
provider="test_provider",
),
parameters=base_params or [],
description=ToolDescription(
human=I18nObject(en_US="Test description"),
llm="Test description for LLM",
),
output_schema={},
)
mock_tool = Mock(spec=Tool)
mock_tool.entity = entity
mock_tool.get_runtime_parameters.return_value = runtime_params
mock_tool.fork_tool_runtime.return_value = mock_tool
return mock_tool


def _param(name, *, form=ToolParameter.ToolParameterForm.FORM, label=None):
return ToolParameter(
name=name,
label=I18nObject(en_US=label or name),
human_description=I18nObject(en_US=name),
type=ToolParameter.ToolParameterType.STRING,
form=form,
)

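The next test class pins down a merge rule for tool parameters: runtime FORM parameters override base parameters by name, extra FORM parameters are appended, non-FORM (e.g. LLM) runtime parameters are dropped, and base ordering is preserved. A sketch of that rule, not the actual convert_tool_entity_to_api_entity body:

def merge_parameters(base_params, runtime_params):
    merged = {p.name: p for p in (base_params or [])}
    for rp in runtime_params or []:
        if rp.form != ToolParameter.ToolParameterForm.FORM:
            continue  # LLM-form runtime params never reach the API entity
        merged[rp.name] = rp  # dicts keep insertion order, so overrides stay in place
    return list(merged.values())
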
class TestConvertToolEntityToApiEntity:
"""Tests for ToolTransformService.convert_tool_entity_to_api_entity."""

def test_parameter_override(self):
base = [_param("param1", label="Base 1"), _param("param2", label="Base 2")]
runtime = [_param("param1", label="Runtime 1")]
tool = _mock_tool(base_params=base, runtime_params=runtime)

result = ToolTransformService.convert_tool_entity_to_api_entity(tool, "t", None)

assert isinstance(result, ToolApiEntity)
assert len(result.parameters) == 2
assert next(p for p in result.parameters if p.name == "param1").label.en_US == "Runtime 1"
assert next(p for p in result.parameters if p.name == "param2").label.en_US == "Base 2"

def test_additional_runtime_parameters(self):
base = [_param("param1", label="Base 1")]
runtime = [_param("param1", label="Runtime 1"), _param("runtime_only", label="Runtime Only")]
tool = _mock_tool(base_params=base, runtime_params=runtime)

result = ToolTransformService.convert_tool_entity_to_api_entity(tool, "t", None)

assert len(result.parameters) == 2
names = [p.name for p in result.parameters]
assert "param1" in names
assert "runtime_only" in names

def test_non_form_runtime_parameters_excluded(self):
base = [_param("param1")]
runtime = [
_param("param1", label="Runtime 1"),
_param("llm_param", form=ToolParameter.ToolParameterForm.LLM),
]
tool = _mock_tool(base_params=base, runtime_params=runtime)

result = ToolTransformService.convert_tool_entity_to_api_entity(tool, "t", None)

assert len(result.parameters) == 1
assert result.parameters[0].name == "param1"

def test_empty_parameters(self):
tool = _mock_tool(base_params=[], runtime_params=[])

result = ToolTransformService.convert_tool_entity_to_api_entity(tool, "t", None)

assert isinstance(result, ToolApiEntity)
assert len(result.parameters) == 0

def test_none_parameters(self):
tool = _mock_tool(base_params=None, runtime_params=[])

result = ToolTransformService.convert_tool_entity_to_api_entity(tool, "t", None)

assert isinstance(result, ToolApiEntity)
assert len(result.parameters) == 0

def test_parameter_order_preserved(self):
base = [_param("p1", label="B1"), _param("p2", label="B2"), _param("p3", label="B3")]
runtime = [_param("p2", label="R2"), _param("p4", label="R4")]
tool = _mock_tool(base_params=base, runtime_params=runtime)

result = ToolTransformService.convert_tool_entity_to_api_entity(tool, "t", None)

assert [p.name for p in result.parameters] == ["p1", "p2", "p3", "p4"]
assert result.parameters[1].label.en_US == "R2"


class TestWorkflowProviderToUserProvider:
"""Tests for ToolTransformService.workflow_provider_to_user_provider."""

@staticmethod
def _make_controller(provider_id="provider_123", **identity_overrides):
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController

defaults = {
"author": "test_author",
"name": "test_workflow_tool",
"description": I18nObject(en_US="Test description"),
"icon": '{"type": "emoji", "content": "🔧"}',
"icon_dark": None,
"label": I18nObject(en_US="Test Workflow Tool"),
}
defaults.update(identity_overrides)
identity = ToolProviderIdentity(**defaults)
entity = ToolProviderEntity(identity=identity)
return WorkflowToolProviderController(entity=entity, provider_id=provider_id)

def test_with_workflow_app_id(self):
ctrl = self._make_controller()

result = ToolTransformService.workflow_provider_to_user_provider(
provider_controller=ctrl,
labels=["l1", "l2"],
workflow_app_id="app_123",
)

assert isinstance(result, ToolProviderApiEntity)
assert result.id == "provider_123"
assert result.type == ToolProviderType.WORKFLOW
assert result.workflow_app_id == "app_123"
assert result.labels == ["l1", "l2"]
assert result.is_team_authorization is True

def test_without_workflow_app_id(self):
ctrl = self._make_controller()

result = ToolTransformService.workflow_provider_to_user_provider(
provider_controller=ctrl,
labels=["l1"],
)

assert result.workflow_app_id is None

def test_workflow_app_id_none_explicit(self):
ctrl = self._make_controller()

result = ToolTransformService.workflow_provider_to_user_provider(
provider_controller=ctrl,
labels=None,
workflow_app_id=None,
)

assert result.workflow_app_id is None
assert result.labels == []

def test_preserves_other_fields(self):
ctrl = self._make_controller(
"provider_456",
author="another_author",
name="another_workflow_tool",
description=I18nObject(en_US="Another desc", zh_Hans="Another desc"),
icon='{"type": "emoji", "content": "⚙️"}',
icon_dark='{"type": "emoji", "content": "🔧"}',
label=I18nObject(en_US="Another Tool", zh_Hans="Another Tool"),
)

result = ToolTransformService.workflow_provider_to_user_provider(
provider_controller=ctrl,
labels=["automation"],
workflow_app_id="app_456",
)

assert result.id == "provider_456"
assert result.author == "another_author"
assert result.name == "another_workflow_tool"
assert result.type == ToolProviderType.WORKFLOW
assert result.workflow_app_id == "app_456"
assert result.is_team_authorization is True
assert result.allow_delete is True

@@ -4,7 +4,7 @@ import pytest
from faker import Faker
from sqlalchemy.orm import Session

from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from extensions.ext_redis import redis_client
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, DatasetAutoDisableLog, Document, DocumentSegment
@@ -81,7 +81,7 @@ class TestAddDocumentToIndexTask:
name=fake.company(),
description=fake.text(max_nb_chars=100),
data_source_type=DataSourceType.UPLOAD_FILE,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
created_by=account.id,
)
db_session_with_containers.add(dataset)

@@ -13,6 +13,7 @@ import pytest
from faker import Faker
from sqlalchemy.orm import Session

from core.rag.index_processor.constant.index_type import IndexStructureType
from extensions.storage.storage_type import StorageType
from libs.datetime_utils import naive_utc_now
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
@@ -152,7 +153,7 @@ class TestBatchCleanDocumentTask:
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
indexing_status=IndexingStatus.COMPLETED,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
)

db_session_with_containers.add(document)
@@ -392,7 +393,12 @@ class TestBatchCleanDocumentTask:
db_session_with_containers.commit()

# Execute the task with non-existent dataset
batch_clean_document_task(document_ids=[document_id], dataset_id=dataset_id, doc_form="text_model", file_ids=[])
batch_clean_document_task(
document_ids=[document_id],
dataset_id=dataset_id,
doc_form=IndexStructureType.PARAGRAPH_INDEX,
file_ids=[],
)

# Verify that no index processing occurred
mock_external_service_dependencies["index_processor"].clean.assert_not_called()
@@ -525,7 +531,11 @@ class TestBatchCleanDocumentTask:
account = self._create_test_account(db_session_with_containers)

# Test different doc_form types
doc_forms = ["text_model", "qa_model", "hierarchical_model"]
doc_forms = [
IndexStructureType.PARAGRAPH_INDEX,
IndexStructureType.QA_INDEX,
IndexStructureType.PARENT_CHILD_INDEX,
]

for doc_form in doc_forms:
dataset = self._create_test_dataset(db_session_with_containers, account)

@@ -19,6 +19,7 @@ import pytest
from faker import Faker
from sqlalchemy.orm import Session

from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from extensions.storage.storage_type import StorageType
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document, DocumentSegment
@@ -141,7 +142,7 @@ class TestBatchCreateSegmentToIndexTask:
name=fake.company(),
description=fake.text(),
data_source_type=DataSourceType.UPLOAD_FILE,
indexing_technique="high_quality",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
embedding_model="text-embedding-ada-002",
embedding_model_provider="openai",
created_by=account.id,
@@ -179,7 +180,7 @@ class TestBatchCreateSegmentToIndexTask:
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=False,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
word_count=0,
)

@@ -221,17 +222,17 @@ class TestBatchCreateSegmentToIndexTask:

return upload_file

def _create_test_csv_content(self, content_type="text_model"):
def _create_test_csv_content(self, content_type=IndexStructureType.PARAGRAPH_INDEX):
"""
Helper method to create test CSV content.

Args:
content_type: Type of content to create ("text_model" or "qa_model")
content_type: Type of content to create (IndexStructureType.PARAGRAPH_INDEX or IndexStructureType.QA_INDEX)

Returns:
str: CSV content as string
"""
if content_type == "qa_model":
if content_type == IndexStructureType.QA_INDEX:
csv_content = "content,answer\n"
csv_content += "This is the first segment content,This is the first answer\n"
csv_content += "This is the second segment content,This is the second answer\n"
@@ -264,7 +265,7 @@ class TestBatchCreateSegmentToIndexTask:
upload_file = self._create_test_upload_file(db_session_with_containers, account, tenant)

# Create CSV content
csv_content = self._create_test_csv_content("text_model")
csv_content = self._create_test_csv_content(IndexStructureType.PARAGRAPH_INDEX)

# Mock storage to return our CSV content
mock_storage = mock_external_service_dependencies["storage"]
@@ -451,7 +452,7 @@ class TestBatchCreateSegmentToIndexTask:
indexing_status=IndexingStatus.COMPLETED,
enabled=False, # Document is disabled
archived=False,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
word_count=0,
),
# Archived document
@@ -467,7 +468,7 @@ class TestBatchCreateSegmentToIndexTask:
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=True, # Document is archived
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
word_count=0,
),
# Document with incomplete indexing
@@ -483,7 +484,7 @@ class TestBatchCreateSegmentToIndexTask:
indexing_status=IndexingStatus.INDEXING, # Not completed
enabled=True,
archived=False,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
word_count=0,
),
]
@@ -655,7 +656,7 @@ class TestBatchCreateSegmentToIndexTask:
db_session_with_containers.commit()

# Create CSV content
csv_content = self._create_test_csv_content("text_model")
csv_content = self._create_test_csv_content(IndexStructureType.PARAGRAPH_INDEX)

# Mock storage to return our CSV content
mock_storage = mock_external_service_dependencies["storage"]

@ -18,6 +18,7 @@ import pytest
|
||||
from faker import Faker
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
|
||||
from extensions.storage.storage_type import StorageType
|
||||
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.dataset import (
|
||||
@ -153,7 +154,7 @@ class TestCleanDatasetTask:
|
||||
tenant_id=tenant.id,
|
||||
name="test_dataset",
|
||||
description="Test dataset for cleanup testing",
|
||||
indexing_technique="high_quality",
|
||||
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
|
||||
index_struct='{"type": "paragraph"}',
|
||||
collection_binding_id=str(uuid.uuid4()),
|
||||
created_by=account.id,
|
||||
@ -192,7 +193,7 @@ class TestCleanDatasetTask:
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=True,
|
||||
archived=False,
|
||||
doc_form="paragraph_index",
|
||||
doc_form=IndexStructureType.PARAGRAPH_INDEX,
|
||||
word_count=100,
|
||||
created_at=datetime.now(),
|
||||
updated_at=datetime.now(),
|
||||
@ -869,7 +870,7 @@ class TestCleanDatasetTask:
|
||||
tenant_id=tenant.id,
|
||||
name=long_name,
|
||||
description=long_description,
|
||||
indexing_technique="high_quality",
|
||||
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
|
||||
index_struct='{"type": "paragraph", "max_length": 10000}',
|
||||
collection_binding_id=str(uuid.uuid4()),
|
||||
created_by=account.id,
|
||||
|
||||
@ -12,6 +12,7 @@ from unittest.mock import Mock, patch
|
||||
import pytest
|
||||
from faker import Faker
|
||||
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus
|
||||
from services.account_service import AccountService, TenantService
|
||||
@ -114,7 +115,7 @@ class TestCleanNotionDocumentTask:
|
||||
name=f"Notion Page {i}",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=account.id,
|
||||
doc_form="text_model", # Set doc_form to ensure dataset.doc_form works
|
||||
doc_form=IndexStructureType.PARAGRAPH_INDEX, # Set doc_form to ensure dataset.doc_form works
|
||||
doc_language="en",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
)
|
||||
@ -261,7 +262,7 @@ class TestCleanNotionDocumentTask:
|
||||
|
||||
# Test different index types
|
||||
# Note: Only testing text_model to avoid dependency on external services
|
||||
index_types = ["text_model"]
|
||||
index_types = [IndexStructureType.PARAGRAPH_INDEX]
|
||||
|
||||
for index_type in index_types:
|
||||
# Create dataset (doc_form will be set via document creation)
|
||||
|
||||
@ -12,6 +12,7 @@ from uuid import uuid4
|
||||
import pytest
|
||||
from faker import Faker
|
||||
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
|
||||
from extensions.ext_redis import redis_client
|
||||
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
@ -120,7 +121,7 @@ class TestCreateSegmentToIndexTask:
|
||||
description=fake.text(max_nb_chars=100),
|
||||
tenant_id=tenant_id,
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
indexing_technique="high_quality",
|
||||
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
|
||||
embedding_model_provider="openai",
|
||||
embedding_model="text-embedding-ada-002",
|
||||
created_by=account_id,
|
||||
@ -141,7 +142,7 @@ class TestCreateSegmentToIndexTask:
|
||||
enabled=True,
|
||||
archived=False,
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
doc_form="qa_model",
|
||||
doc_form=IndexStructureType.QA_INDEX,
|
||||
)
|
||||
db_session_with_containers.add(document)
|
||||
db_session_with_containers.commit()
|
||||
@ -301,7 +302,7 @@ class TestCreateSegmentToIndexTask:
|
||||
enabled=True,
|
||||
archived=False,
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
doc_form="text_model",
|
||||
doc_form=IndexStructureType.PARAGRAPH_INDEX,
|
||||
)
|
||||
db_session_with_containers.add(document)
|
||||
db_session_with_containers.commit()
|
||||
@ -552,7 +553,11 @@ class TestCreateSegmentToIndexTask:
|
||||
- Processing completes successfully for different forms
|
||||
"""
|
||||
# Arrange: Test different doc_forms
|
||||
doc_forms = ["qa_model", "text_model", "web_model"]
|
||||
doc_forms = [
|
||||
IndexStructureType.QA_INDEX,
|
||||
IndexStructureType.PARAGRAPH_INDEX,
|
||||
IndexStructureType.PARAGRAPH_INDEX,
|
||||
]
|
||||
|
||||
for doc_form in doc_forms:
|
||||
# Create fresh test data for each form
|
||||
|
||||
@ -8,6 +8,7 @@ import pytest
|
||||
from faker import Faker
|
||||
|
||||
from core.indexing_runner import DocumentIsPausedError
|
||||
from core.rag.index_processor.constant.index_type import IndexTechniqueType
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.dataset import Dataset, Document
|
||||
@ -141,7 +142,7 @@ class TestDatasetIndexingTaskIntegration:
|
||||
name=fake.company(),
|
||||
description=fake.text(max_nb_chars=100),
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
indexing_technique="high_quality",
|
||||
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
|
||||
created_by=account.id,
|
||||
)
|
||||
db_session_with_containers.add(dataset)
|
||||
|
||||
@ -12,6 +12,7 @@ from unittest.mock import ANY, Mock, patch
|
||||
import pytest
|
||||
from faker import Faker
|
||||
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus
|
||||
from services.account_service import AccountService, TenantService
|
||||
@ -107,7 +108,7 @@ class TestDealDatasetVectorIndexTask:
|
||||
name="Document for doc_form",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=account.id,
|
||||
doc_form="text_model",
|
||||
doc_form=IndexStructureType.PARAGRAPH_INDEX,
|
||||
doc_language="en",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=True,
|
||||
@ -167,7 +168,7 @@ class TestDealDatasetVectorIndexTask:
|
||||
name="Document for doc_form",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=account.id,
|
||||
doc_form="text_model",
|
||||
doc_form=IndexStructureType.PARAGRAPH_INDEX,
|
||||
doc_language="en",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=True,
|
||||
@ -187,7 +188,7 @@ class TestDealDatasetVectorIndexTask:
|
||||
name="Test Document",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=account.id,
|
||||
doc_form="text_model",
|
||||
doc_form=IndexStructureType.PARAGRAPH_INDEX,
|
||||
doc_language="en",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=True,
|
||||
@ -268,7 +269,7 @@ class TestDealDatasetVectorIndexTask:
|
||||
name="Document for doc_form",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=account.id,
|
||||
doc_form="parent_child_index",
|
||||
doc_form=IndexStructureType.PARENT_CHILD_INDEX,
|
||||
doc_language="en",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=True,
|
||||
@ -288,7 +289,7 @@ class TestDealDatasetVectorIndexTask:
|
||||
name="Test Document",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=account.id,
|
||||
doc_form="parent_child_index",
|
||||
doc_form=IndexStructureType.PARENT_CHILD_INDEX,
|
||||
doc_language="en",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=True,
|
||||
@ -416,7 +417,7 @@ class TestDealDatasetVectorIndexTask:
|
||||
name="Test Document",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=account.id,
|
||||
doc_form="text_model",
|
||||
doc_form=IndexStructureType.PARAGRAPH_INDEX,
|
||||
doc_language="en",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=True,
|
||||
@ -505,7 +506,7 @@ class TestDealDatasetVectorIndexTask:
|
||||
name="Document for doc_form",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=account.id,
|
||||
doc_form="text_model",
|
||||
doc_form=IndexStructureType.PARAGRAPH_INDEX,
|
||||
doc_language="en",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=True,
|
||||
@ -525,7 +526,7 @@ class TestDealDatasetVectorIndexTask:
|
||||
name="Test Document",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=account.id,
|
||||
doc_form="text_model",
|
||||
doc_form=IndexStructureType.PARAGRAPH_INDEX,
|
||||
doc_language="en",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=True,
|
||||
@ -601,7 +602,7 @@ class TestDealDatasetVectorIndexTask:
|
||||
name="Test Document",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=account.id,
|
||||
doc_form="qa_index",
|
||||
doc_form=IndexStructureType.QA_INDEX,
|
||||
doc_language="en",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=True,
|
||||
@ -638,7 +639,7 @@ class TestDealDatasetVectorIndexTask:
|
||||
assert updated_document.indexing_status == IndexingStatus.COMPLETED
|
||||
|
||||
# Verify index processor was initialized with custom index type
|
||||
mock_index_processor_factory.assert_called_once_with("qa_index")
|
||||
mock_index_processor_factory.assert_called_once_with(IndexStructureType.QA_INDEX)
|
||||
mock_factory = mock_index_processor_factory.return_value
|
||||
mock_processor = mock_factory.init_index_processor.return_value
|
||||
mock_processor.load.assert_called_once()
|
||||
@ -677,7 +678,7 @@ class TestDealDatasetVectorIndexTask:
|
||||
name="Test Document",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=account.id,
|
||||
doc_form="text_model",
|
||||
doc_form=IndexStructureType.PARAGRAPH_INDEX,
|
||||
doc_language="en",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=True,
|
||||
@ -714,7 +715,7 @@ class TestDealDatasetVectorIndexTask:
|
||||
assert updated_document.indexing_status == IndexingStatus.COMPLETED
|
||||
|
||||
# Verify index processor was initialized with the document's index type
|
||||
mock_index_processor_factory.assert_called_once_with("text_model")
|
||||
mock_index_processor_factory.assert_called_once_with(IndexStructureType.PARAGRAPH_INDEX)
|
||||
mock_factory = mock_index_processor_factory.return_value
|
||||
mock_processor = mock_factory.init_index_processor.return_value
|
||||
mock_processor.load.assert_called_once()
|
||||
@ -753,7 +754,7 @@ class TestDealDatasetVectorIndexTask:
|
||||
name="Document for doc_form",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=account.id,
|
||||
doc_form="text_model",
|
||||
doc_form=IndexStructureType.PARAGRAPH_INDEX,
|
||||
doc_language="en",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=True,
|
||||
@ -775,7 +776,7 @@ class TestDealDatasetVectorIndexTask:
|
||||
name=f"Test Document {i}",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=account.id,
|
||||
doc_form="text_model",
|
||||
doc_form=IndexStructureType.PARAGRAPH_INDEX,
|
||||
doc_language="en",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=True,
|
||||
@ -856,7 +857,7 @@ class TestDealDatasetVectorIndexTask:
|
||||
name="Document for doc_form",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=account.id,
|
||||
doc_form="text_model",
|
||||
doc_form=IndexStructureType.PARAGRAPH_INDEX,
|
||||
doc_language="en",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=True,
|
||||
@ -876,7 +877,7 @@ class TestDealDatasetVectorIndexTask:
|
||||
name="Test Document",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=account.id,
|
||||
doc_form="text_model",
|
||||
doc_form=IndexStructureType.PARAGRAPH_INDEX,
|
||||
doc_language="en",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=True,
|
||||
@ -953,7 +954,7 @@ class TestDealDatasetVectorIndexTask:
|
||||
name="Document for doc_form",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=account.id,
|
||||
doc_form="text_model",
|
||||
doc_form=IndexStructureType.PARAGRAPH_INDEX,
|
||||
doc_language="en",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=True,
|
||||
@@ -973,7 +974,7 @@ class TestDealDatasetVectorIndexTask:
             name="Enabled Document",
             created_from=DocumentCreatedFrom.WEB,
             created_by=account.id,
-            doc_form="text_model",
+            doc_form=IndexStructureType.PARAGRAPH_INDEX,
             doc_language="en",
             indexing_status=IndexingStatus.COMPLETED,
             enabled=True,
@@ -992,7 +993,7 @@ class TestDealDatasetVectorIndexTask:
             name="Disabled Document",
             created_from=DocumentCreatedFrom.WEB,
             created_by=account.id,
-            doc_form="text_model",
+            doc_form=IndexStructureType.PARAGRAPH_INDEX,
             doc_language="en",
             indexing_status=IndexingStatus.COMPLETED,
             enabled=False,  # This document should be skipped
@@ -1074,7 +1075,7 @@ class TestDealDatasetVectorIndexTask:
             name="Document for doc_form",
             created_from=DocumentCreatedFrom.WEB,
             created_by=account.id,
-            doc_form="text_model",
+            doc_form=IndexStructureType.PARAGRAPH_INDEX,
             doc_language="en",
             indexing_status=IndexingStatus.COMPLETED,
             enabled=True,
@@ -1094,7 +1095,7 @@ class TestDealDatasetVectorIndexTask:
             name="Active Document",
             created_from=DocumentCreatedFrom.WEB,
             created_by=account.id,
-            doc_form="text_model",
+            doc_form=IndexStructureType.PARAGRAPH_INDEX,
             doc_language="en",
             indexing_status=IndexingStatus.COMPLETED,
             enabled=True,
@@ -1113,7 +1114,7 @@ class TestDealDatasetVectorIndexTask:
             name="Archived Document",
             created_from=DocumentCreatedFrom.WEB,
             created_by=account.id,
-            doc_form="text_model",
+            doc_form=IndexStructureType.PARAGRAPH_INDEX,
             doc_language="en",
             indexing_status=IndexingStatus.COMPLETED,
             enabled=True,
@@ -1195,7 +1196,7 @@ class TestDealDatasetVectorIndexTask:
             name="Document for doc_form",
             created_from=DocumentCreatedFrom.WEB,
             created_by=account.id,
-            doc_form="text_model",
+            doc_form=IndexStructureType.PARAGRAPH_INDEX,
             doc_language="en",
             indexing_status=IndexingStatus.COMPLETED,
             enabled=True,
@@ -1215,7 +1216,7 @@ class TestDealDatasetVectorIndexTask:
             name="Completed Document",
             created_from=DocumentCreatedFrom.WEB,
             created_by=account.id,
-            doc_form="text_model",
+            doc_form=IndexStructureType.PARAGRAPH_INDEX,
             doc_language="en",
             indexing_status=IndexingStatus.COMPLETED,
             enabled=True,
@@ -1234,7 +1235,7 @@ class TestDealDatasetVectorIndexTask:
             name="Incomplete Document",
             created_from=DocumentCreatedFrom.WEB,
             created_by=account.id,
-            doc_form="text_model",
+            doc_form=IndexStructureType.PARAGRAPH_INDEX,
             doc_language="en",
             indexing_status=IndexingStatus.INDEXING,  # This document should be skipped
             enabled=True,

@@ -12,7 +12,7 @@ from unittest.mock import MagicMock, patch

 from faker import Faker

-from core.rag.index_processor.constant.index_type import IndexStructureType
+from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
 from models import Account, Dataset, Document, DocumentSegment, Tenant
 from models.enums import DataSourceType, DocumentCreatedFrom, DocumentDocType, IndexingStatus, SegmentStatus
 from tasks.delete_segment_from_index_task import delete_segment_from_index_task
@@ -108,7 +108,7 @@ class TestDeleteSegmentFromIndexTask:
         dataset.provider = "vendor"
         dataset.permission = "only_me"
         dataset.data_source_type = DataSourceType.UPLOAD_FILE
-        dataset.indexing_technique = "high_quality"
+        dataset.indexing_technique = IndexTechniqueType.HIGH_QUALITY
         dataset.index_struct = '{"type": "paragraph"}'
         dataset.created_by = account.id
         dataset.created_at = fake.date_time_this_year()

@@ -15,6 +15,7 @@ import pytest
 from faker import Faker
 from sqlalchemy.orm import Session

+from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
 from extensions.ext_redis import redis_client
 from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
 from models.dataset import Dataset, Document, DocumentSegment
@@ -99,7 +100,7 @@ class TestDisableSegmentFromIndexTask:
             name=fake.sentence(nb_words=3),
             description=fake.text(max_nb_chars=200),
             data_source_type=DataSourceType.UPLOAD_FILE,
-            indexing_technique="high_quality",
+            indexing_technique=IndexTechniqueType.HIGH_QUALITY,
             created_by=account.id,
         )
         db_session_with_containers.add(dataset)
@@ -113,7 +114,7 @@ class TestDisableSegmentFromIndexTask:
         dataset: Dataset,
         tenant: Tenant,
         account: Account,
-        doc_form: str = "text_model",
+        doc_form: str = IndexStructureType.PARAGRAPH_INDEX,
     ) -> Document:
         """
         Helper method to create a test document.
@@ -476,7 +477,11 @@ class TestDisableSegmentFromIndexTask:
         - Index processor clean method is called correctly
         """
         # Test different document forms
-        doc_forms = ["text_model", "qa_model", "table_model"]
+        doc_forms = [
+            IndexStructureType.PARAGRAPH_INDEX,
+            IndexStructureType.QA_INDEX,
+            IndexStructureType.PARENT_CHILD_INDEX,
+        ]

         for doc_form in doc_forms:
             # Arrange: Create test data for each form

@@ -11,6 +11,7 @@ from unittest.mock import MagicMock, patch
 from faker import Faker
 from sqlalchemy.orm import Session

+from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
 from models import Account, Dataset, DocumentSegment
 from models import Document as DatasetDocument
 from models.dataset import DatasetProcessRule
@@ -102,7 +103,7 @@ class TestDisableSegmentsFromIndexTask:
             provider="vendor",
             permission="only_me",
             data_source_type=DataSourceType.UPLOAD_FILE,
-            indexing_technique="high_quality",
+            indexing_technique=IndexTechniqueType.HIGH_QUALITY,
             created_by=account.id,
             updated_by=account.id,
             embedding_model="text-embedding-ada-002",
@@ -153,7 +154,7 @@ class TestDisableSegmentsFromIndexTask:
         document.indexing_status = "completed"
         document.enabled = True
         document.archived = False
-        document.doc_form = "text_model"  # Use text_model form for testing
+        document.doc_form = IndexStructureType.PARAGRAPH_INDEX  # Use text_model form for testing
         document.doc_language = "en"
         db_session_with_containers.add(document)
         db_session_with_containers.commit()
@@ -500,7 +501,11 @@ class TestDisableSegmentsFromIndexTask:
         segment_ids = [segment.id for segment in segments]

         # Test different document forms
-        doc_forms = ["text_model", "qa_model", "hierarchical_model"]
+        doc_forms = [
+            IndexStructureType.PARAGRAPH_INDEX,
+            IndexStructureType.QA_INDEX,
+            IndexStructureType.PARENT_CHILD_INDEX,
+        ]

         for doc_form in doc_forms:
             # Update document form

@@ -14,6 +14,7 @@ from uuid import uuid4
 import pytest

 from core.indexing_runner import DocumentIsPausedError, IndexingRunner
+from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
 from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
 from models.dataset import Dataset, Document, DocumentSegment
 from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus
@@ -56,7 +57,7 @@ class DocumentIndexingSyncTaskTestDataFactory:
             name=f"dataset-{uuid4()}",
             description="sync test dataset",
             data_source_type=DataSourceType.NOTION_IMPORT,
-            indexing_technique="high_quality",
+            indexing_technique=IndexTechniqueType.HIGH_QUALITY,
             created_by=created_by,
         )
         db_session_with_containers.add(dataset)
@@ -85,7 +86,7 @@ class DocumentIndexingSyncTaskTestDataFactory:
             created_by=created_by,
             indexing_status=indexing_status,
             enabled=True,
-            doc_form="text_model",
+            doc_form=IndexStructureType.PARAGRAPH_INDEX,
             doc_language="en",
         )
         db_session_with_containers.add(document)

Some files were not shown because too many files have changed in this diff.