diff --git a/.agents/skills/frontend-testing/SKILL.md b/.agents/skills/frontend-testing/SKILL.md index 69c099a262..4da070bdbf 100644 --- a/.agents/skills/frontend-testing/SKILL.md +++ b/.agents/skills/frontend-testing/SKILL.md @@ -63,7 +63,8 @@ pnpm analyze-component --review ### File Naming -- Test files: `ComponentName.spec.tsx` (same directory as component) +- Test files: `ComponentName.spec.tsx` inside a same-level `__tests__/` directory +- Placement rule: Component, hook, and utility tests must live in a sibling `__tests__/` folder at the same level as the source under test. For example, `foo/index.tsx` maps to `foo/__tests__/index.spec.tsx`, and `foo/bar.ts` maps to `foo/__tests__/bar.spec.ts`. - Integration tests: `web/__tests__/` directory ## Test Structure Template diff --git a/.agents/skills/frontend-testing/assets/component-test.template.tsx b/.agents/skills/frontend-testing/assets/component-test.template.tsx index 6b7803bd4b..ff38f88d23 100644 --- a/.agents/skills/frontend-testing/assets/component-test.template.tsx +++ b/.agents/skills/frontend-testing/assets/component-test.template.tsx @@ -41,7 +41,7 @@ import userEvent from '@testing-library/user-event' // Router (if component uses useRouter, usePathname, useSearchParams) // WHY: Isolates tests from Next.js routing, enables testing navigation behavior // const mockPush = vi.fn() -// vi.mock('next/navigation', () => ({ +// vi.mock('@/next/navigation', () => ({ // useRouter: () => ({ push: mockPush }), // usePathname: () => '/test-path', // })) diff --git a/.github/workflows/web-tests.yml b/.github/workflows/web-tests.yml index 11222146cf..be2595a599 100644 --- a/.github/workflows/web-tests.yml +++ b/.github/workflows/web-tests.yml @@ -29,8 +29,8 @@ jobs: strategy: fail-fast: false matrix: - shardIndex: [1, 2, 3, 4] - shardTotal: [4] + shardIndex: [1, 2, 3, 4, 5, 6] + shardTotal: [6] defaults: run: shell: bash diff --git a/api/AGENTS.md b/api/AGENTS.md index d43d2528b8..8e5d9f600d 100644 --- a/api/AGENTS.md +++ b/api/AGENTS.md @@ -78,7 +78,7 @@ class UserProfile(TypedDict): nickname: NotRequired[str] ``` -- For classes, declare member variables at the top of the class body (before `__init__`) so the class shape is obvious at a glance: +- For classes, declare all member variables explicitly with types at the top of the class body (before `__init__`), even when the class is not a dataclass or Pydantic model, so the class shape is obvious at a glance: ```python from datetime import datetime diff --git a/api/commands/retention.py b/api/commands/retention.py index 5a91c1cc70..82a77ea77a 100644 --- a/api/commands/retention.py +++ b/api/commands/retention.py @@ -88,6 +88,8 @@ def clean_workflow_runs( """ Clean workflow runs and related workflow data for free tenants. """ + from extensions.otel.runtime import flush_telemetry + if (start_from is None) ^ (end_before is None): raise click.UsageError("--start-from and --end-before must be provided together.") @@ -104,16 +106,27 @@ def clean_workflow_runs( end_before = now - datetime.timedelta(days=to_days_ago) before_days = 0 + if from_days_ago is not None and to_days_ago is not None: + task_label = f"{from_days_ago}to{to_days_ago}" + elif start_from is None: + task_label = f"before-{before_days}" + else: + task_label = "custom" + start_time = datetime.datetime.now(datetime.UTC) click.echo(click.style(f"Starting workflow run cleanup at {start_time.isoformat()}.", fg="white")) - WorkflowRunCleanup( - days=before_days, - batch_size=batch_size, - start_from=start_from, - end_before=end_before, - dry_run=dry_run, - ).run() + try: + WorkflowRunCleanup( + days=before_days, + batch_size=batch_size, + start_from=start_from, + end_before=end_before, + dry_run=dry_run, + task_label=task_label, + ).run() + finally: + flush_telemetry() end_time = datetime.datetime.now(datetime.UTC) elapsed = end_time - start_time @@ -659,6 +672,8 @@ def clean_expired_messages( """ Clean expired messages and related data for tenants based on clean policy. """ + from extensions.otel.runtime import flush_telemetry + click.echo(click.style("clean_messages: start clean messages.", fg="green")) start_at = time.perf_counter() @@ -698,6 +713,13 @@ def clean_expired_messages( # NOTE: graceful_period will be ignored when billing is disabled. policy = create_message_clean_policy(graceful_period_days=graceful_period) + if from_days_ago is not None and before_days is not None: + task_label = f"{from_days_ago}to{before_days}" + elif start_from is None and before_days is not None: + task_label = f"before-{before_days}" + else: + task_label = "custom" + # Create and run the cleanup service if abs_mode: assert start_from is not None @@ -708,6 +730,7 @@ def clean_expired_messages( end_before=end_before, batch_size=batch_size, dry_run=dry_run, + task_label=task_label, ) elif from_days_ago is None: assert before_days is not None @@ -716,6 +739,7 @@ def clean_expired_messages( days=before_days, batch_size=batch_size, dry_run=dry_run, + task_label=task_label, ) else: assert before_days is not None @@ -727,6 +751,7 @@ def clean_expired_messages( end_before=now - datetime.timedelta(days=before_days), batch_size=batch_size, dry_run=dry_run, + task_label=task_label, ) stats = service.run() @@ -752,6 +777,8 @@ def clean_expired_messages( ) ) raise + finally: + flush_telemetry() click.echo(click.style("messages cleanup completed.", fg="green")) diff --git a/api/commands/vector.py b/api/commands/vector.py index 5f41d469c8..52ce26c26d 100644 --- a/api/commands/vector.py +++ b/api/commands/vector.py @@ -14,6 +14,7 @@ from core.rag.models.document import ChildDocument, Document from extensions.ext_database import db from models.dataset import Dataset, DatasetCollectionBinding, DatasetMetadata, DatasetMetadataBinding, DocumentSegment from models.dataset import Document as DatasetDocument +from models.enums import DatasetMetadataType, IndexingStatus, SegmentStatus from models.model import App, AppAnnotationSetting, MessageAnnotation @@ -242,7 +243,7 @@ def migrate_knowledge_vector_database(): dataset_documents = db.session.scalars( select(DatasetDocument).where( DatasetDocument.dataset_id == dataset.id, - DatasetDocument.indexing_status == "completed", + DatasetDocument.indexing_status == IndexingStatus.COMPLETED, DatasetDocument.enabled == True, DatasetDocument.archived == False, ) @@ -254,7 +255,7 @@ def migrate_knowledge_vector_database(): segments = db.session.scalars( select(DocumentSegment).where( DocumentSegment.document_id == dataset_document.id, - DocumentSegment.status == "completed", + DocumentSegment.status == SegmentStatus.COMPLETED, DocumentSegment.enabled == True, ) ).all() @@ -430,7 +431,7 @@ def old_metadata_migration(): tenant_id=document.tenant_id, dataset_id=document.dataset_id, name=key, - type="string", + type=DatasetMetadataType.STRING, created_by=document.created_by, ) db.session.add(dataset_metadata) diff --git a/api/configs/middleware/cache/redis_config.py b/api/configs/middleware/cache/redis_config.py index 367cb52731..3b91207545 100644 --- a/api/configs/middleware/cache/redis_config.py +++ b/api/configs/middleware/cache/redis_config.py @@ -1,4 +1,4 @@ -from pydantic import Field, NonNegativeInt, PositiveFloat, PositiveInt +from pydantic import Field, NonNegativeInt, PositiveFloat, PositiveInt, field_validator from pydantic_settings import BaseSettings @@ -116,3 +116,13 @@ class RedisConfig(BaseSettings): description="Maximum connections in the Redis connection pool (unset for library default)", default=None, ) + + @field_validator("REDIS_MAX_CONNECTIONS", mode="before") + @classmethod + def _empty_string_to_none_for_max_conns(cls, v): + """Allow empty string in env/.env to mean 'unset' (None).""" + if v is None: + return None + if isinstance(v, str) and v.strip() == "": + return None + return v diff --git a/api/configs/middleware/cache/redis_pubsub_config.py b/api/configs/middleware/cache/redis_pubsub_config.py index d30831a0ec..0a166818b3 100644 --- a/api/configs/middleware/cache/redis_pubsub_config.py +++ b/api/configs/middleware/cache/redis_pubsub_config.py @@ -1,4 +1,4 @@ -from typing import Literal, Protocol +from typing import Literal, Protocol, cast from urllib.parse import quote_plus, urlunparse from pydantic import AliasChoices, Field @@ -12,16 +12,13 @@ class RedisConfigDefaults(Protocol): REDIS_PASSWORD: str | None REDIS_DB: int REDIS_USE_SSL: bool - REDIS_USE_SENTINEL: bool | None - REDIS_USE_CLUSTERS: bool -class RedisConfigDefaultsMixin: - def _redis_defaults(self: RedisConfigDefaults) -> RedisConfigDefaults: - return self +def _redis_defaults(config: object) -> RedisConfigDefaults: + return cast(RedisConfigDefaults, config) -class RedisPubSubConfig(BaseSettings, RedisConfigDefaultsMixin): +class RedisPubSubConfig(BaseSettings): """ Configuration settings for event transport between API and workers. @@ -74,7 +71,7 @@ class RedisPubSubConfig(BaseSettings, RedisConfigDefaultsMixin): ) def _build_default_pubsub_url(self) -> str: - defaults = self._redis_defaults() + defaults = _redis_defaults(self) if not defaults.REDIS_HOST or not defaults.REDIS_PORT: raise ValueError("PUBSUB_REDIS_URL must be set when default Redis URL cannot be constructed") @@ -91,11 +88,9 @@ class RedisPubSubConfig(BaseSettings, RedisConfigDefaultsMixin): if userinfo: userinfo = f"{userinfo}@" - host = defaults.REDIS_HOST - port = defaults.REDIS_PORT db = defaults.REDIS_DB - netloc = f"{userinfo}{host}:{port}" + netloc = f"{userinfo}{defaults.REDIS_HOST}:{defaults.REDIS_PORT}" return urlunparse((scheme, netloc, f"/{db}", "", "", "")) @property diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index eebba57fa3..725a8380cd 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -54,6 +54,7 @@ from fields.document_fields import document_status_fields from libs.login import current_account_with_tenant, login_required from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile from models.dataset import DatasetPermission, DatasetPermissionEnum +from models.enums import SegmentStatus from models.provider_ids import ModelProviderID from services.api_token_service import ApiTokenCache from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService @@ -741,13 +742,15 @@ class DatasetIndexingStatusApi(Resource): .where( DocumentSegment.completed_at.isnot(None), DocumentSegment.document_id == str(document.id), - DocumentSegment.status != "re_segment", + DocumentSegment.status != SegmentStatus.RE_SEGMENT, ) .count() ) total_segments = ( db.session.query(DocumentSegment) - .where(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment") + .where( + DocumentSegment.document_id == str(document.id), DocumentSegment.status != SegmentStatus.RE_SEGMENT + ) .count() ) # Create a dictionary with document attributes and additional fields diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index ee726bc470..0c441553be 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -42,6 +42,7 @@ from libs.datetime_utils import naive_utc_now from libs.login import current_account_with_tenant, login_required from models import DatasetProcessRule, Document, DocumentSegment, UploadFile from models.dataset import DocumentPipelineExecutionLog +from models.enums import IndexingStatus, SegmentStatus from services.dataset_service import DatasetService, DocumentService from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig, ProcessRule, RetrievalModel from services.file_service import FileService @@ -332,13 +333,16 @@ class DatasetDocumentListApi(Resource): .where( DocumentSegment.completed_at.isnot(None), DocumentSegment.document_id == str(document.id), - DocumentSegment.status != "re_segment", + DocumentSegment.status != SegmentStatus.RE_SEGMENT, ) .count() ) total_segments = ( db.session.query(DocumentSegment) - .where(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment") + .where( + DocumentSegment.document_id == str(document.id), + DocumentSegment.status != SegmentStatus.RE_SEGMENT, + ) .count() ) document.completed_segments = completed_segments @@ -503,7 +507,7 @@ class DocumentIndexingEstimateApi(DocumentResource): document_id = str(document_id) document = self.get_document(dataset_id, document_id) - if document.indexing_status in {"completed", "error"}: + if document.indexing_status in {IndexingStatus.COMPLETED, IndexingStatus.ERROR}: raise DocumentAlreadyFinishedError() data_process_rule = document.dataset_process_rule @@ -573,7 +577,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource): data_process_rule_dict = data_process_rule.to_dict() if data_process_rule else {} extract_settings = [] for document in documents: - if document.indexing_status in {"completed", "error"}: + if document.indexing_status in {IndexingStatus.COMPLETED, IndexingStatus.ERROR}: raise DocumentAlreadyFinishedError() data_source_info = document.data_source_info_dict match document.data_source_type: @@ -671,19 +675,21 @@ class DocumentBatchIndexingStatusApi(DocumentResource): .where( DocumentSegment.completed_at.isnot(None), DocumentSegment.document_id == str(document.id), - DocumentSegment.status != "re_segment", + DocumentSegment.status != SegmentStatus.RE_SEGMENT, ) .count() ) total_segments = ( db.session.query(DocumentSegment) - .where(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment") + .where( + DocumentSegment.document_id == str(document.id), DocumentSegment.status != SegmentStatus.RE_SEGMENT + ) .count() ) # Create a dictionary with document attributes and additional fields document_dict = { "id": document.id, - "indexing_status": "paused" if document.is_paused else document.indexing_status, + "indexing_status": IndexingStatus.PAUSED if document.is_paused else document.indexing_status, "processing_started_at": document.processing_started_at, "parsing_completed_at": document.parsing_completed_at, "cleaning_completed_at": document.cleaning_completed_at, @@ -720,20 +726,20 @@ class DocumentIndexingStatusApi(DocumentResource): .where( DocumentSegment.completed_at.isnot(None), DocumentSegment.document_id == str(document_id), - DocumentSegment.status != "re_segment", + DocumentSegment.status != SegmentStatus.RE_SEGMENT, ) .count() ) total_segments = ( db.session.query(DocumentSegment) - .where(DocumentSegment.document_id == str(document_id), DocumentSegment.status != "re_segment") + .where(DocumentSegment.document_id == str(document_id), DocumentSegment.status != SegmentStatus.RE_SEGMENT) .count() ) # Create a dictionary with document attributes and additional fields document_dict = { "id": document.id, - "indexing_status": "paused" if document.is_paused else document.indexing_status, + "indexing_status": IndexingStatus.PAUSED if document.is_paused else document.indexing_status, "processing_started_at": document.processing_started_at, "parsing_completed_at": document.parsing_completed_at, "cleaning_completed_at": document.cleaning_completed_at, @@ -955,7 +961,7 @@ class DocumentProcessingApi(DocumentResource): match action: case "pause": - if document.indexing_status != "indexing": + if document.indexing_status != IndexingStatus.INDEXING: raise InvalidActionError("Document not in indexing state.") document.paused_by = current_user.id @@ -964,7 +970,7 @@ class DocumentProcessingApi(DocumentResource): db.session.commit() case "resume": - if document.indexing_status not in {"paused", "error"}: + if document.indexing_status not in {IndexingStatus.PAUSED, IndexingStatus.ERROR}: raise InvalidActionError("Document not in paused or error state.") document.paused_by = None @@ -1169,7 +1175,7 @@ class DocumentRetryApi(DocumentResource): raise ArchivedDocumentImmutableError() # 400 if document is completed - if document.indexing_status == "completed": + if document.indexing_status == IndexingStatus.COMPLETED: raise DocumentAlreadyFinishedError() retry_documents.append(document) except Exception: diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py index 6e0cd31b8d..4f31093cfe 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py @@ -46,6 +46,8 @@ class PipelineTemplateDetailApi(Resource): type = request.args.get("type", default="built-in", type=str) rag_pipeline_service = RagPipelineService() pipeline_template = rag_pipeline_service.get_pipeline_template_detail(template_id, type) + if pipeline_template is None: + return {"error": "Pipeline template not found from upstream service."}, 404 return pipeline_template, 200 diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py index 5a1d28ea1d..d34b4124ae 100644 --- a/api/controllers/service_api/dataset/document.py +++ b/api/controllers/service_api/dataset/document.py @@ -36,6 +36,7 @@ from extensions.ext_database import db from fields.document_fields import document_fields, document_status_fields from libs.login import current_user from models.dataset import Dataset, Document, DocumentSegment +from models.enums import SegmentStatus from services.dataset_service import DatasetService, DocumentService from services.entities.knowledge_entities.knowledge_entities import ( KnowledgeConfig, @@ -622,13 +623,15 @@ class DocumentIndexingStatusApi(DatasetApiResource): .where( DocumentSegment.completed_at.isnot(None), DocumentSegment.document_id == str(document.id), - DocumentSegment.status != "re_segment", + DocumentSegment.status != SegmentStatus.RE_SEGMENT, ) .count() ) total_segments = ( db.session.query(DocumentSegment) - .where(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment") + .where( + DocumentSegment.document_id == str(document.id), DocumentSegment.status != SegmentStatus.RE_SEGMENT + ) .count() ) # Create a dictionary with document attributes and additional fields diff --git a/api/controllers/trigger/webhook.py b/api/controllers/trigger/webhook.py index 22b24271c6..eb579da5d4 100644 --- a/api/controllers/trigger/webhook.py +++ b/api/controllers/trigger/webhook.py @@ -70,7 +70,14 @@ def handle_webhook(webhook_id: str): @bp.route("/webhook-debug/", methods=["GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"]) def handle_webhook_debug(webhook_id: str): - """Handle webhook debug calls without triggering production workflow execution.""" + """Handle webhook debug calls without triggering production workflow execution. + + The debug webhook endpoint is only for draft inspection flows. It never enqueues + Celery work for the published workflow; instead it dispatches an in-memory debug + event to an active Variable Inspector listener. Returning a clear error when no + listener is registered prevents a misleading 200 response for requests that are + effectively dropped. + """ try: webhook_trigger, _, node_config, webhook_data, error = _prepare_webhook_execution(webhook_id, is_debug=True) if error: @@ -94,11 +101,32 @@ def handle_webhook_debug(webhook_id: str): "method": webhook_data.get("method"), }, ) - TriggerDebugEventBus.dispatch( + dispatch_count = TriggerDebugEventBus.dispatch( tenant_id=webhook_trigger.tenant_id, event=event, pool_key=pool_key, ) + if dispatch_count == 0: + logger.warning( + "Webhook debug request dropped without an active listener for webhook %s (tenant=%s, app=%s, node=%s)", + webhook_trigger.webhook_id, + webhook_trigger.tenant_id, + webhook_trigger.app_id, + webhook_trigger.node_id, + ) + return ( + jsonify( + { + "error": "No active debug listener", + "message": ( + "The webhook debug URL only works while the Variable Inspector is listening. " + "Use the published webhook URL to execute the workflow in Celery." + ), + "execution_url": webhook_trigger.webhook_url, + } + ), + 409, + ) response_data, status_code = WebhookService.generate_webhook_response(node_config) return jsonify(response_data), status_code diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index 4a8b5f3549..1bdc8df813 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -441,7 +441,7 @@ class BaseAgentRunner(AppRunner): continue result.append(self.organize_agent_user_prompt(message)) - agent_thoughts: list[MessageAgentThought] = message.agent_thoughts + agent_thoughts = message.agent_thoughts if agent_thoughts: for agent_thought in agent_thoughts: tool_names_raw = agent_thought.tool diff --git a/api/core/app/app_config/common/parameters_mapping/__init__.py b/api/core/app/app_config/common/parameters_mapping/__init__.py index 6f1a3bf045..460fdfb3ba 100644 --- a/api/core/app/app_config/common/parameters_mapping/__init__.py +++ b/api/core/app/app_config/common/parameters_mapping/__init__.py @@ -1,13 +1,36 @@ from collections.abc import Mapping -from typing import Any +from typing import Any, TypedDict from configs import dify_config from constants import DEFAULT_FILE_NUMBER_LIMITS +class SystemParametersDict(TypedDict): + image_file_size_limit: int + video_file_size_limit: int + audio_file_size_limit: int + file_size_limit: int + workflow_file_upload_limit: int + + +class AppParametersDict(TypedDict): + opening_statement: str | None + suggested_questions: list[str] + suggested_questions_after_answer: dict[str, Any] + speech_to_text: dict[str, Any] + text_to_speech: dict[str, Any] + retriever_resource: dict[str, Any] + annotation_reply: dict[str, Any] + more_like_this: dict[str, Any] + user_input_form: list[dict[str, Any]] + sensitive_word_avoidance: dict[str, Any] + file_upload: dict[str, Any] + system_parameters: SystemParametersDict + + def get_parameters_from_feature_dict( *, features_dict: Mapping[str, Any], user_input_form: list[dict[str, Any]] -) -> Mapping[str, Any]: +) -> AppParametersDict: """ Mapping from feature dict to webapp parameters """ diff --git a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py index 70f43b2c83..f04a8df119 100644 --- a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py @@ -8,6 +8,7 @@ from core.app.app_config.entities import ( ModelConfig, ) from core.entities.agent_entities import PlanningStrategy +from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict from models.model import AppMode, AppModelConfigDict from services.dataset_service import DatasetService @@ -117,8 +118,10 @@ class DatasetConfigManager: score_threshold=float(score_threshold_val) if dataset_configs.get("score_threshold_enabled", False) and score_threshold_val is not None else None, - reranking_model=reranking_model_val if isinstance(reranking_model_val, dict) else None, - weights=weights_val if isinstance(weights_val, dict) else None, + reranking_model=cast(RerankingModelDict, reranking_model_val) + if isinstance(reranking_model_val, dict) + else None, + weights=cast(WeightsDict, weights_val) if isinstance(weights_val, dict) else None, reranking_enabled=bool(dataset_configs.get("reranking_enabled", True)), rerank_mode=dataset_configs.get("reranking_mode", "reranking_model"), metadata_filtering_mode=cast( diff --git a/api/core/app/app_config/entities.py b/api/core/app/app_config/entities.py index ac21577d57..95ea70bc40 100644 --- a/api/core/app/app_config/entities.py +++ b/api/core/app/app_config/entities.py @@ -4,6 +4,7 @@ from typing import Any, Literal from pydantic import BaseModel, Field +from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict from dify_graph.file import FileUploadConfig from dify_graph.model_runtime.entities.llm_entities import LLMMode from dify_graph.model_runtime.entities.message_entities import PromptMessageRole @@ -194,8 +195,8 @@ class DatasetRetrieveConfigEntity(BaseModel): top_k: int | None = None score_threshold: float | None = 0.0 rerank_mode: str | None = "reranking_model" - reranking_model: dict | None = None - weights: dict | None = None + reranking_model: RerankingModelDict | None = None + weights: WeightsDict | None = None reranking_enabled: bool | None = True metadata_filtering_mode: Literal["disabled", "automatic", "manual"] | None = "disabled" metadata_model_config: ModelConfig | None = None diff --git a/api/core/app/apps/common/workflow_response_converter.py b/api/core/app/apps/common/workflow_response_converter.py index 5665a2b76c..5509764508 100644 --- a/api/core/app/apps/common/workflow_response_converter.py +++ b/api/core/app/apps/common/workflow_response_converter.py @@ -3,7 +3,7 @@ import time from collections.abc import Mapping, Sequence from dataclasses import dataclass from datetime import datetime -from typing import Any, NewType, Union +from typing import Any, NewType, TypedDict, Union from sqlalchemy import select from sqlalchemy.orm import Session @@ -76,6 +76,20 @@ NodeExecutionId = NewType("NodeExecutionId", str) logger = logging.getLogger(__name__) +class AccountCreatedByDict(TypedDict): + id: str + name: str + email: str + + +class EndUserCreatedByDict(TypedDict): + id: str + user: str + + +CreatedByDict = AccountCreatedByDict | EndUserCreatedByDict + + @dataclass(slots=True) class _NodeSnapshot: """In-memory cache for node metadata between start and completion events.""" @@ -249,19 +263,19 @@ class WorkflowResponseConverter: outputs_mapping = graph_runtime_state.outputs or {} encoded_outputs = WorkflowRuntimeTypeConverter().to_json_encodable(outputs_mapping) - created_by: Mapping[str, object] | None + created_by: CreatedByDict | dict[str, object] = {} user = self._user if isinstance(user, Account): - created_by = { - "id": user.id, - "name": user.name, - "email": user.email, - } - else: - created_by = { - "id": user.id, - "user": user.session_id, - } + created_by = AccountCreatedByDict( + id=user.id, + name=user.name, + email=user.email, + ) + elif isinstance(user, EndUser): + created_by = EndUserCreatedByDict( + id=user.id, + user=user.session_id, + ) return WorkflowFinishStreamResponse( task_id=task_id, diff --git a/api/core/app/features/annotation_reply/annotation_reply.py b/api/core/app/features/annotation_reply/annotation_reply.py index 3f9f3da9b2..50aed37163 100644 --- a/api/core/app/features/annotation_reply/annotation_reply.py +++ b/api/core/app/features/annotation_reply/annotation_reply.py @@ -6,6 +6,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom from core.rag.datasource.vdb.vector_factory import Vector from extensions.ext_database import db from models.dataset import Dataset +from models.enums import CollectionBindingType from models.model import App, AppAnnotationSetting, Message, MessageAnnotation from services.annotation_service import AppAnnotationService from services.dataset_service import DatasetCollectionBindingService @@ -43,7 +44,7 @@ class AnnotationReplyFeature: embedding_model_name = collection_binding_detail.model_name dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( - embedding_provider_name, embedding_model_name, "annotation" + embedding_provider_name, embedding_model_name, CollectionBindingType.ANNOTATION ) dataset = Dataset( diff --git a/api/core/app/task_pipeline/message_file_utils.py b/api/core/app/task_pipeline/message_file_utils.py index 843e9eea30..fc8b6c6b5a 100644 --- a/api/core/app/task_pipeline/message_file_utils.py +++ b/api/core/app/task_pipeline/message_file_utils.py @@ -1,3 +1,5 @@ +from typing import TypedDict + from core.tools.signature import sign_tool_file from dify_graph.file import helpers as file_helpers from dify_graph.file.enums import FileTransferMethod @@ -6,7 +8,20 @@ from models.model import MessageFile, UploadFile MAX_TOOL_FILE_EXTENSION_LENGTH = 10 -def prepare_file_dict(message_file: MessageFile, upload_files_map: dict[str, UploadFile]) -> dict: +class MessageFileInfoDict(TypedDict): + related_id: str + extension: str + filename: str + size: int + mime_type: str + transfer_method: str + type: str + url: str + upload_file_id: str + remote_url: str | None + + +def prepare_file_dict(message_file: MessageFile, upload_files_map: dict[str, UploadFile]) -> MessageFileInfoDict: """ Prepare file dictionary for message end stream response. diff --git a/api/core/callback_handler/index_tool_callback_handler.py b/api/core/callback_handler/index_tool_callback_handler.py index b054409681..8de5cb1690 100644 --- a/api/core/callback_handler/index_tool_callback_handler.py +++ b/api/core/callback_handler/index_tool_callback_handler.py @@ -12,7 +12,7 @@ from core.rag.models.document import Document from extensions.ext_database import db from models.dataset import ChildChunk, DatasetQuery, DocumentSegment from models.dataset import Document as DatasetDocument -from models.enums import CreatorUserRole +from models.enums import CreatorUserRole, DatasetQuerySource _logger = logging.getLogger(__name__) @@ -36,7 +36,7 @@ class DatasetIndexToolCallbackHandler: dataset_query = DatasetQuery( dataset_id=dataset_id, content=query, - source="app", + source=DatasetQuerySource.APP, source_app_id=self._app_id, created_by_role=( CreatorUserRole.ACCOUNT diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index 9f8d06e322..c6a270e470 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -30,6 +30,7 @@ from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from libs.datetime_utils import naive_utc_now from models.engine import db +from models.enums import CredentialSourceType from models.provider import ( LoadBalancingModelConfig, Provider, @@ -473,9 +474,21 @@ class ProviderConfiguration(BaseModel): self.switch_preferred_provider_type(provider_type=ProviderType.CUSTOM, session=session) else: - # some historical data may have a provider record but not be set as valid provider_record.is_valid = True + if provider_record.credential_id is None: + provider_record.credential_id = new_record.id + provider_record.updated_at = naive_utc_now() + + provider_model_credentials_cache = ProviderCredentialsCache( + tenant_id=self.tenant_id, + identity_id=provider_record.id, + cache_type=ProviderCredentialsCacheType.PROVIDER, + ) + provider_model_credentials_cache.delete() + + self.switch_preferred_provider_type(provider_type=ProviderType.CUSTOM, session=session) + session.commit() except Exception: session.rollback() @@ -534,7 +547,7 @@ class ProviderConfiguration(BaseModel): self._update_load_balancing_configs_with_credential( credential_id=credential_id, credential_record=credential_record, - credential_source="provider", + credential_source=CredentialSourceType.PROVIDER, session=session, ) except Exception: @@ -611,7 +624,7 @@ class ProviderConfiguration(BaseModel): LoadBalancingModelConfig.tenant_id == self.tenant_id, LoadBalancingModelConfig.provider_name.in_(self._get_provider_names()), LoadBalancingModelConfig.credential_id == credential_id, - LoadBalancingModelConfig.credential_source_type == "provider", + LoadBalancingModelConfig.credential_source_type == CredentialSourceType.PROVIDER, ) lb_configs_using_credential = session.execute(lb_stmt).scalars().all() try: @@ -1031,7 +1044,7 @@ class ProviderConfiguration(BaseModel): self._update_load_balancing_configs_with_credential( credential_id=credential_id, credential_record=credential_record, - credential_source="custom_model", + credential_source=CredentialSourceType.CUSTOM_MODEL, session=session, ) except Exception: @@ -1061,7 +1074,7 @@ class ProviderConfiguration(BaseModel): LoadBalancingModelConfig.tenant_id == self.tenant_id, LoadBalancingModelConfig.provider_name.in_(self._get_provider_names()), LoadBalancingModelConfig.credential_id == credential_id, - LoadBalancingModelConfig.credential_source_type == "custom_model", + LoadBalancingModelConfig.credential_source_type == CredentialSourceType.CUSTOM_MODEL, ) lb_configs_using_credential = session.execute(lb_stmt).scalars().all() @@ -1699,7 +1712,7 @@ class ProviderConfiguration(BaseModel): provider_model_lb_configs = [ config for config in model_setting.load_balancing_configs - if config.credential_source_type != "custom_model" + if config.credential_source_type != CredentialSourceType.CUSTOM_MODEL ] load_balancing_enabled = model_setting.load_balancing_enabled @@ -1757,7 +1770,7 @@ class ProviderConfiguration(BaseModel): custom_model_lb_configs = [ config for config in model_setting.load_balancing_configs - if config.credential_source_type != "provider" + if config.credential_source_type != CredentialSourceType.PROVIDER ] load_balancing_enabled = model_setting.load_balancing_enabled diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index 275c1fc110..52776ee626 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -40,6 +40,7 @@ from libs.datetime_utils import naive_utc_now from models import Account from models.dataset import AutomaticRulesConfig, ChildChunk, Dataset, DatasetProcessRule, DocumentSegment from models.dataset import Document as DatasetDocument +from models.enums import DataSourceType, IndexingStatus, ProcessRuleMode, SegmentStatus from models.model import UploadFile from services.feature_service import FeatureService @@ -56,7 +57,7 @@ class IndexingRunner: logger.exception("consume document failed") document = db.session.get(DatasetDocument, document_id) if document: - document.indexing_status = "error" + document.indexing_status = IndexingStatus.ERROR error_message = getattr(error, "description", str(error)) document.error = str(error_message) document.stopped_at = naive_utc_now() @@ -219,7 +220,7 @@ class IndexingRunner: if document_segments: for document_segment in document_segments: # transform segment to node - if document_segment.status != "completed": + if document_segment.status != SegmentStatus.COMPLETED: document = Document( page_content=document_segment.content, metadata={ @@ -382,7 +383,7 @@ class IndexingRunner: data_source_info = dataset_document.data_source_info_dict text_docs = [] match dataset_document.data_source_type: - case "upload_file": + case DataSourceType.UPLOAD_FILE: if not data_source_info or "upload_file_id" not in data_source_info: raise ValueError("no upload file found") stmt = select(UploadFile).where(UploadFile.id == data_source_info["upload_file_id"]) @@ -395,7 +396,7 @@ class IndexingRunner: document_model=dataset_document.doc_form, ) text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"]) - case "notion_import": + case DataSourceType.NOTION_IMPORT: if ( not data_source_info or "notion_workspace_id" not in data_source_info @@ -417,7 +418,7 @@ class IndexingRunner: document_model=dataset_document.doc_form, ) text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"]) - case "website_crawl": + case DataSourceType.WEBSITE_CRAWL: if ( not data_source_info or "provider" not in data_source_info @@ -445,7 +446,7 @@ class IndexingRunner: # update document status to splitting self._update_document_index_status( document_id=dataset_document.id, - after_indexing_status="splitting", + after_indexing_status=IndexingStatus.SPLITTING, extra_update_params={ DatasetDocument.parsing_completed_at: naive_utc_now(), }, @@ -545,7 +546,7 @@ class IndexingRunner: Clean the document text according to the processing rules. """ rules: AutomaticRulesConfig | dict[str, Any] - if processing_rule.mode == "automatic": + if processing_rule.mode == ProcessRuleMode.AUTOMATIC: rules = DatasetProcessRule.AUTOMATIC_RULES else: rules = json.loads(processing_rule.rules) if processing_rule.rules else {} @@ -636,7 +637,7 @@ class IndexingRunner: # update document status to completed self._update_document_index_status( document_id=dataset_document.id, - after_indexing_status="completed", + after_indexing_status=IndexingStatus.COMPLETED, extra_update_params={ DatasetDocument.tokens: tokens, DatasetDocument.completed_at: naive_utc_now(), @@ -659,10 +660,10 @@ class IndexingRunner: DocumentSegment.document_id == document_id, DocumentSegment.dataset_id == dataset_id, DocumentSegment.index_node_id.in_(document_ids), - DocumentSegment.status == "indexing", + DocumentSegment.status == SegmentStatus.INDEXING, ).update( { - DocumentSegment.status: "completed", + DocumentSegment.status: SegmentStatus.COMPLETED, DocumentSegment.enabled: True, DocumentSegment.completed_at: naive_utc_now(), } @@ -703,10 +704,10 @@ class IndexingRunner: DocumentSegment.document_id == dataset_document.id, DocumentSegment.dataset_id == dataset.id, DocumentSegment.index_node_id.in_(document_ids), - DocumentSegment.status == "indexing", + DocumentSegment.status == SegmentStatus.INDEXING, ).update( { - DocumentSegment.status: "completed", + DocumentSegment.status: SegmentStatus.COMPLETED, DocumentSegment.enabled: True, DocumentSegment.completed_at: naive_utc_now(), } @@ -725,7 +726,7 @@ class IndexingRunner: @staticmethod def _update_document_index_status( - document_id: str, after_indexing_status: str, extra_update_params: dict | None = None + document_id: str, after_indexing_status: IndexingStatus, extra_update_params: dict | None = None ): """ Update the document indexing status. @@ -803,7 +804,7 @@ class IndexingRunner: cur_time = naive_utc_now() self._update_document_index_status( document_id=dataset_document.id, - after_indexing_status="indexing", + after_indexing_status=IndexingStatus.INDEXING, extra_update_params={ DatasetDocument.cleaning_completed_at: cur_time, DatasetDocument.splitting_completed_at: cur_time, @@ -815,7 +816,7 @@ class IndexingRunner: self._update_segments_by_document( dataset_document_id=dataset_document.id, update_params={ - DocumentSegment.status: "indexing", + DocumentSegment.status: SegmentStatus.INDEXING, DocumentSegment.indexing_at: naive_utc_now(), }, ) diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index c538a557fb..ed34922346 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -196,6 +196,8 @@ class ProviderManager: if preferred_provider_type_record: preferred_provider_type = ProviderType.value_of(preferred_provider_type_record.preferred_provider_type) + elif dify_config.EDITION == "CLOUD" and system_configuration.enabled: + preferred_provider_type = ProviderType.SYSTEM elif custom_configuration.provider or custom_configuration.models: preferred_provider_type = ProviderType.CUSTOM elif system_configuration.enabled: diff --git a/api/core/rag/data_post_processor/data_post_processor.py b/api/core/rag/data_post_processor/data_post_processor.py index 2b73ef5f26..33eb5f963a 100644 --- a/api/core/rag/data_post_processor/data_post_processor.py +++ b/api/core/rag/data_post_processor/data_post_processor.py @@ -1,3 +1,5 @@ +from typing_extensions import TypedDict + from core.model_manager import ModelInstance, ModelManager from core.rag.data_post_processor.reorder import ReorderRunner from core.rag.index_processor.constant.query_type import QueryType @@ -10,6 +12,26 @@ from dify_graph.model_runtime.entities.model_entities import ModelType from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError +class RerankingModelDict(TypedDict): + reranking_provider_name: str + reranking_model_name: str + + +class VectorSettingDict(TypedDict): + vector_weight: float + embedding_provider_name: str + embedding_model_name: str + + +class KeywordSettingDict(TypedDict): + keyword_weight: float + + +class WeightsDict(TypedDict): + vector_setting: VectorSettingDict + keyword_setting: KeywordSettingDict + + class DataPostProcessor: """Interface for data post-processing document.""" @@ -17,8 +39,8 @@ class DataPostProcessor: self, tenant_id: str, reranking_mode: str, - reranking_model: dict | None = None, - weights: dict | None = None, + reranking_model: RerankingModelDict | None = None, + weights: WeightsDict | None = None, reorder_enabled: bool = False, ): self.rerank_runner = self._get_rerank_runner(reranking_mode, tenant_id, reranking_model, weights) @@ -45,8 +67,8 @@ class DataPostProcessor: self, reranking_mode: str, tenant_id: str, - reranking_model: dict | None = None, - weights: dict | None = None, + reranking_model: RerankingModelDict | None = None, + weights: WeightsDict | None = None, ) -> BaseRerankRunner | None: if reranking_mode == RerankMode.WEIGHTED_SCORE and weights: runner = RerankRunnerFactory.create_rerank_runner( @@ -79,12 +101,14 @@ class DataPostProcessor: return ReorderRunner() return None - def _get_rerank_model_instance(self, tenant_id: str, reranking_model: dict | None) -> ModelInstance | None: + def _get_rerank_model_instance( + self, tenant_id: str, reranking_model: RerankingModelDict | None + ) -> ModelInstance | None: if reranking_model: try: model_manager = ModelManager() - reranking_provider_name = reranking_model.get("reranking_provider_name") - reranking_model_name = reranking_model.get("reranking_model_name") + reranking_provider_name = reranking_model["reranking_provider_name"] + reranking_model_name = reranking_model["reranking_model_name"] if not reranking_provider_name or not reranking_model_name: return None rerank_model_instance = model_manager.get_model_instance( diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index e8a3a05e19..7f6ecc3d3f 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -1,19 +1,20 @@ import concurrent.futures import logging from concurrent.futures import ThreadPoolExecutor -from typing import Any +from typing import Any, NotRequired from flask import Flask, current_app from sqlalchemy import select from sqlalchemy.orm import Session, load_only +from typing_extensions import TypedDict from configs import dify_config from core.db.session_factory import session_factory from core.model_manager import ModelManager -from core.rag.data_post_processor.data_post_processor import DataPostProcessor +from core.rag.data_post_processor.data_post_processor import DataPostProcessor, RerankingModelDict, WeightsDict from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.datasource.vdb.vector_factory import Vector -from core.rag.embedding.retrieval import RetrievalChildChunk, RetrievalSegments +from core.rag.embedding.retrieval import AttachmentInfoDict, RetrievalChildChunk, RetrievalSegments from core.rag.entities.metadata_entities import MetadataCondition from core.rag.index_processor.constant.doc_type import DocType from core.rag.index_processor.constant.index_type import IndexStructureType @@ -35,7 +36,46 @@ from models.dataset import Document as DatasetDocument from models.model import UploadFile from services.external_knowledge_service import ExternalDatasetService -default_retrieval_model = { + +class SegmentAttachmentResult(TypedDict): + attachment_info: AttachmentInfoDict + segment_id: str + + +class SegmentAttachmentInfoResult(TypedDict): + attachment_id: str + attachment_info: AttachmentInfoDict + segment_id: str + + +class ChildChunkDetail(TypedDict): + id: str + content: str + position: int + score: float + + +class SegmentChildMapDetail(TypedDict): + max_score: float + child_chunks: list[ChildChunkDetail] + + +class SegmentRecord(TypedDict): + segment: DocumentSegment + score: NotRequired[float] + child_chunks: NotRequired[list[ChildChunkDetail]] + files: NotRequired[list[AttachmentInfoDict]] + + +class DefaultRetrievalModelDict(TypedDict): + search_method: RetrievalMethod | str + reranking_enable: bool + reranking_model: RerankingModelDict + top_k: int + score_threshold_enabled: bool + + +default_retrieval_model: DefaultRetrievalModelDict = { "search_method": RetrievalMethod.SEMANTIC_SEARCH, "reranking_enable": False, "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, @@ -56,9 +96,9 @@ class RetrievalService: query: str, top_k: int = 4, score_threshold: float | None = 0.0, - reranking_model: dict | None = None, + reranking_model: RerankingModelDict | None = None, reranking_mode: str = "reranking_model", - weights: dict | None = None, + weights: WeightsDict | None = None, document_ids_filter: list[str] | None = None, attachment_ids: list | None = None, ): @@ -235,7 +275,7 @@ class RetrievalService: query: str, top_k: int, score_threshold: float | None, - reranking_model: dict | None, + reranking_model: RerankingModelDict | None, all_documents: list, retrieval_method: RetrievalMethod, exceptions: list, @@ -277,8 +317,8 @@ class RetrievalService: if documents: if ( reranking_model - and reranking_model.get("reranking_model_name") - and reranking_model.get("reranking_provider_name") + and reranking_model["reranking_model_name"] + and reranking_model["reranking_provider_name"] and retrieval_method == RetrievalMethod.SEMANTIC_SEARCH ): data_post_processor = DataPostProcessor( @@ -288,8 +328,8 @@ class RetrievalService: model_manager = ModelManager() is_support_vision = model_manager.check_model_support_vision( tenant_id=dataset.tenant_id, - provider=reranking_model.get("reranking_provider_name") or "", - model=reranking_model.get("reranking_model_name") or "", + provider=reranking_model["reranking_provider_name"], + model=reranking_model["reranking_model_name"], model_type=ModelType.RERANK, ) if is_support_vision: @@ -329,7 +369,7 @@ class RetrievalService: query: str, top_k: int, score_threshold: float | None, - reranking_model: dict | None, + reranking_model: RerankingModelDict | None, all_documents: list, retrieval_method: str, exceptions: list, @@ -349,8 +389,8 @@ class RetrievalService: if documents: if ( reranking_model - and reranking_model.get("reranking_model_name") - and reranking_model.get("reranking_provider_name") + and reranking_model["reranking_model_name"] + and reranking_model["reranking_provider_name"] and retrieval_method == RetrievalMethod.FULL_TEXT_SEARCH ): data_post_processor = DataPostProcessor( @@ -459,7 +499,7 @@ class RetrievalService: segment_ids: list[str] = [] index_node_segments: list[DocumentSegment] = [] segments: list[DocumentSegment] = [] - attachment_map: dict[str, list[dict[str, Any]]] = {} + attachment_map: dict[str, list[AttachmentInfoDict]] = {} child_chunk_map: dict[str, list[ChildChunk]] = {} doc_segment_map: dict[str, list[str]] = {} segment_summary_map: dict[str, str] = {} # Map segment_id to summary content @@ -544,12 +584,12 @@ class RetrievalService: segment_summary_map[summary.chunk_id] = summary.summary_content include_segment_ids = set() - segment_child_map: dict[str, dict[str, Any]] = {} - records: list[dict[str, Any]] = [] + segment_child_map: dict[str, SegmentChildMapDetail] = {} + records: list[SegmentRecord] = [] for segment in segments: child_chunks: list[ChildChunk] = child_chunk_map.get(segment.id, []) - attachment_infos: list[dict[str, Any]] = attachment_map.get(segment.id, []) + attachment_infos: list[AttachmentInfoDict] = attachment_map.get(segment.id, []) ds_dataset_document: DatasetDocument | None = valid_dataset_documents.get(segment.document_id) if ds_dataset_document and ds_dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX: @@ -560,14 +600,14 @@ class RetrievalService: max_score = summary_score_map.get(segment.id, 0.0) if child_chunks or attachment_infos: - child_chunk_details = [] + child_chunk_details: list[ChildChunkDetail] = [] for child_chunk in child_chunks: child_document: Document | None = doc_to_document_map.get(child_chunk.index_node_id) if child_document: child_score = child_document.metadata.get("score", 0.0) else: child_score = 0.0 - child_chunk_detail = { + child_chunk_detail: ChildChunkDetail = { "id": child_chunk.id, "content": child_chunk.content, "position": child_chunk.position, @@ -580,7 +620,7 @@ class RetrievalService: if file_document: max_score = max(max_score, file_document.metadata.get("score", 0.0)) - map_detail = { + map_detail: SegmentChildMapDetail = { "max_score": max_score, "child_chunks": child_chunk_details, } @@ -593,7 +633,7 @@ class RetrievalService: "max_score": summary_score, "child_chunks": [], } - record: dict[str, Any] = { + record: SegmentRecord = { "segment": segment, } records.append(record) @@ -617,19 +657,19 @@ class RetrievalService: if file_doc: max_score = max(max_score, file_doc.metadata.get("score", 0.0)) - record = { + another_record: SegmentRecord = { "segment": segment, "score": max_score, } - records.append(record) + records.append(another_record) # Add child chunks information to records for record in records: if record["segment"].id in segment_child_map: - record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks") # type: ignore - record["score"] = segment_child_map[record["segment"].id]["max_score"] # type: ignore + record["child_chunks"] = segment_child_map[record["segment"].id]["child_chunks"] + record["score"] = segment_child_map[record["segment"].id]["max_score"] if record["segment"].id in attachment_map: - record["files"] = attachment_map[record["segment"].id] # type: ignore[assignment] + record["files"] = attachment_map[record["segment"].id] result: list[RetrievalSegments] = [] for record in records: @@ -693,9 +733,9 @@ class RetrievalService: query: str | None = None, top_k: int = 4, score_threshold: float | None = 0.0, - reranking_model: dict | None = None, + reranking_model: RerankingModelDict | None = None, reranking_mode: str = "reranking_model", - weights: dict | None = None, + weights: WeightsDict | None = None, document_ids_filter: list[str] | None = None, attachment_id: str | None = None, ): @@ -807,7 +847,7 @@ class RetrievalService: @classmethod def get_segment_attachment_info( cls, dataset_id: str, tenant_id: str, attachment_id: str, session: Session - ) -> dict[str, Any] | None: + ) -> SegmentAttachmentResult | None: upload_file = session.query(UploadFile).where(UploadFile.id == attachment_id).first() if upload_file: attachment_binding = ( @@ -816,7 +856,7 @@ class RetrievalService: .first() ) if attachment_binding: - attachment_info = { + attachment_info: AttachmentInfoDict = { "id": upload_file.id, "name": upload_file.name, "extension": "." + upload_file.extension, @@ -828,8 +868,10 @@ class RetrievalService: return None @classmethod - def get_segment_attachment_infos(cls, attachment_ids: list[str], session: Session) -> list[dict[str, Any]]: - attachment_infos = [] + def get_segment_attachment_infos( + cls, attachment_ids: list[str], session: Session + ) -> list[SegmentAttachmentInfoResult]: + attachment_infos: list[SegmentAttachmentInfoResult] = [] upload_files = session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).all() if upload_files: upload_file_ids = [upload_file.id for upload_file in upload_files] @@ -843,7 +885,7 @@ class RetrievalService: if attachment_bindings: for upload_file in upload_files: attachment_binding = attachment_binding_map.get(upload_file.id) - attachment_info = { + info: AttachmentInfoDict = { "id": upload_file.id, "name": upload_file.name, "extension": "." + upload_file.extension, @@ -855,7 +897,7 @@ class RetrievalService: attachment_infos.append( { "attachment_id": attachment_binding.attachment_id, - "attachment_info": attachment_info, + "attachment_info": info, "segment_id": attachment_binding.segment_id, } ) diff --git a/api/core/rag/embedding/retrieval.py b/api/core/rag/embedding/retrieval.py index f6834ab87b..030237559d 100644 --- a/api/core/rag/embedding/retrieval.py +++ b/api/core/rag/embedding/retrieval.py @@ -1,8 +1,18 @@ from pydantic import BaseModel +from typing_extensions import TypedDict from models.dataset import DocumentSegment +class AttachmentInfoDict(TypedDict): + id: str + name: str + extension: str + mime_type: str + source_url: str + size: int + + class RetrievalChildChunk(BaseModel): """Retrieval segments.""" @@ -19,5 +29,5 @@ class RetrievalSegments(BaseModel): segment: DocumentSegment child_chunks: list[RetrievalChildChunk] | None = None score: float | None = None - files: list[dict[str, str | int]] | None = None + files: list[AttachmentInfoDict] | None = None summary: str | None = None # Summary content if retrieved via summary index diff --git a/api/core/rag/index_processor/index_processor.py b/api/core/rag/index_processor/index_processor.py index a7c42c5a4e..d9145023ac 100644 --- a/api/core/rag/index_processor/index_processor.py +++ b/api/core/rag/index_processor/index_processor.py @@ -9,6 +9,7 @@ from flask import current_app from sqlalchemy import delete, func, select from core.db.session_factory import session_factory +from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict from core.workflow.nodes.knowledge_index.exc import KnowledgeIndexNodeError from core.workflow.nodes.knowledge_index.protocols import Preview, PreviewItem, QaPreview from models.dataset import Dataset, Document, DocumentSegment @@ -51,7 +52,7 @@ class IndexProcessor: original_document_id: str, chunks: Mapping[str, Any], batch: Any, - summary_index_setting: dict | None = None, + summary_index_setting: SummaryIndexSettingDict | None = None, ): with session_factory.create_session() as session: document = session.query(Document).filter_by(id=document_id).first() @@ -131,7 +132,12 @@ class IndexProcessor: } def get_preview_output( - self, chunks: Any, dataset_id: str, document_id: str, chunk_structure: str, summary_index_setting: dict | None + self, + chunks: Any, + dataset_id: str, + document_id: str, + chunk_structure: str, + summary_index_setting: SummaryIndexSettingDict | None, ) -> Preview: doc_language = None with session_factory.create_session() as session: diff --git a/api/core/rag/index_processor/index_processor_base.py b/api/core/rag/index_processor/index_processor_base.py index f2191f3702..a435dfc46a 100644 --- a/api/core/rag/index_processor/index_processor_base.py +++ b/api/core/rag/index_processor/index_processor_base.py @@ -7,14 +7,16 @@ import os import re from abc import ABC, abstractmethod from collections.abc import Mapping -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any, NotRequired, Optional from urllib.parse import unquote, urlparse import httpx +from typing_extensions import TypedDict from configs import dify_config from core.entities.knowledge_entities import PreviewDetail from core.helper import ssrf_proxy +from core.rag.data_post_processor.data_post_processor import RerankingModelDict from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.index_processor.constant.doc_type import DocType from core.rag.models.document import AttachmentDocument, Document @@ -35,6 +37,13 @@ if TYPE_CHECKING: from core.model_manager import ModelInstance +class SummaryIndexSettingDict(TypedDict): + enable: bool + model_name: NotRequired[str] + model_provider_name: NotRequired[str] + summary_prompt: NotRequired[str] + + class BaseIndexProcessor(ABC): """Interface for extract files.""" @@ -51,7 +60,7 @@ class BaseIndexProcessor(ABC): self, tenant_id: str, preview_texts: list[PreviewDetail], - summary_index_setting: dict, + summary_index_setting: SummaryIndexSettingDict, doc_language: str | None = None, ) -> list[PreviewDetail]: """ @@ -98,7 +107,7 @@ class BaseIndexProcessor(ABC): dataset: Dataset, top_k: int, score_threshold: float, - reranking_model: dict, + reranking_model: RerankingModelDict, ) -> list[Document]: raise NotImplementedError diff --git a/api/core/rag/index_processor/processor/paragraph_index_processor.py b/api/core/rag/index_processor/processor/paragraph_index_processor.py index 9c21dad488..80163b1707 100644 --- a/api/core/rag/index_processor/processor/paragraph_index_processor.py +++ b/api/core/rag/index_processor/processor/paragraph_index_processor.py @@ -14,6 +14,7 @@ from core.llm_generator.prompts import DEFAULT_GENERATOR_SUMMARY_PROMPT from core.model_manager import ModelInstance from core.provider_manager import ProviderManager from core.rag.cleaner.clean_processor import CleanProcessor +from core.rag.data_post_processor.data_post_processor import RerankingModelDict from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.datasource.retrieval_service import RetrievalService from core.rag.datasource.vdb.vector_factory import Vector @@ -22,7 +23,7 @@ from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.extract_processor import ExtractProcessor from core.rag.index_processor.constant.doc_type import DocType from core.rag.index_processor.constant.index_type import IndexStructureType -from core.rag.index_processor.index_processor_base import BaseIndexProcessor +from core.rag.index_processor.index_processor_base import BaseIndexProcessor, SummaryIndexSettingDict from core.rag.models.document import AttachmentDocument, Document, MultimodalGeneralStructureChunk from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.tools.utils.text_processing_utils import remove_leading_symbols @@ -175,7 +176,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor): dataset: Dataset, top_k: int, score_threshold: float, - reranking_model: dict, + reranking_model: RerankingModelDict, ) -> list[Document]: # Set search parameters. results = RetrievalService.retrieve( @@ -278,7 +279,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor): self, tenant_id: str, preview_texts: list[PreviewDetail], - summary_index_setting: dict, + summary_index_setting: SummaryIndexSettingDict, doc_language: str | None = None, ) -> list[PreviewDetail]: """ @@ -362,7 +363,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor): def generate_summary( tenant_id: str, text: str, - summary_index_setting: dict | None = None, + summary_index_setting: SummaryIndexSettingDict | None = None, segment_id: str | None = None, document_language: str | None = None, ) -> tuple[str, LLMUsage]: diff --git a/api/core/rag/index_processor/processor/parent_child_index_processor.py b/api/core/rag/index_processor/processor/parent_child_index_processor.py index 367f0aec00..df0761ca73 100644 --- a/api/core/rag/index_processor/processor/parent_child_index_processor.py +++ b/api/core/rag/index_processor/processor/parent_child_index_processor.py @@ -11,6 +11,7 @@ from core.db.session_factory import session_factory from core.entities.knowledge_entities import PreviewDetail from core.model_manager import ModelInstance from core.rag.cleaner.clean_processor import CleanProcessor +from core.rag.data_post_processor.data_post_processor import RerankingModelDict from core.rag.datasource.retrieval_service import RetrievalService from core.rag.datasource.vdb.vector_factory import Vector from core.rag.docstore.dataset_docstore import DatasetDocumentStore @@ -18,7 +19,7 @@ from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.extract_processor import ExtractProcessor from core.rag.index_processor.constant.doc_type import DocType from core.rag.index_processor.constant.index_type import IndexStructureType -from core.rag.index_processor.index_processor_base import BaseIndexProcessor +from core.rag.index_processor.index_processor_base import BaseIndexProcessor, SummaryIndexSettingDict from core.rag.models.document import AttachmentDocument, ChildDocument, Document, ParentChildStructureChunk from core.rag.retrieval.retrieval_methods import RetrievalMethod from extensions.ext_database import db @@ -215,7 +216,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor): dataset: Dataset, top_k: int, score_threshold: float, - reranking_model: dict, + reranking_model: RerankingModelDict, ) -> list[Document]: # Set search parameters. results = RetrievalService.retrieve( @@ -361,7 +362,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor): self, tenant_id: str, preview_texts: list[PreviewDetail], - summary_index_setting: dict, + summary_index_setting: SummaryIndexSettingDict, doc_language: str | None = None, ) -> list[PreviewDetail]: """ diff --git a/api/core/rag/index_processor/processor/qa_index_processor.py b/api/core/rag/index_processor/processor/qa_index_processor.py index 503cce2132..62f88b7760 100644 --- a/api/core/rag/index_processor/processor/qa_index_processor.py +++ b/api/core/rag/index_processor/processor/qa_index_processor.py @@ -15,13 +15,14 @@ from core.db.session_factory import session_factory from core.entities.knowledge_entities import PreviewDetail from core.llm_generator.llm_generator import LLMGenerator from core.rag.cleaner.clean_processor import CleanProcessor +from core.rag.data_post_processor.data_post_processor import RerankingModelDict from core.rag.datasource.retrieval_service import RetrievalService from core.rag.datasource.vdb.vector_factory import Vector from core.rag.docstore.dataset_docstore import DatasetDocumentStore from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.extract_processor import ExtractProcessor from core.rag.index_processor.constant.index_type import IndexStructureType -from core.rag.index_processor.index_processor_base import BaseIndexProcessor +from core.rag.index_processor.index_processor_base import BaseIndexProcessor, SummaryIndexSettingDict from core.rag.models.document import AttachmentDocument, Document, QAStructureChunk from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.tools.utils.text_processing_utils import remove_leading_symbols @@ -185,7 +186,7 @@ class QAIndexProcessor(BaseIndexProcessor): dataset: Dataset, top_k: int, score_threshold: float, - reranking_model: dict, + reranking_model: RerankingModelDict, ): # Set search parameters. results = RetrievalService.retrieve( @@ -244,7 +245,7 @@ class QAIndexProcessor(BaseIndexProcessor): self, tenant_id: str, preview_texts: list[PreviewDetail], - summary_index_setting: dict, + summary_index_setting: SummaryIndexSettingDict, doc_language: str | None = None, ) -> list[PreviewDetail]: """ diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index 4c96b63f25..c44e9b847b 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -31,7 +31,7 @@ from core.ops.utils import measure_time from core.prompt.advanced_prompt_transform import AdvancedPromptTransform from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate from core.prompt.simple_prompt_transform import ModelMode -from core.rag.data_post_processor.data_post_processor import DataPostProcessor +from core.rag.data_post_processor.data_post_processor import DataPostProcessor, RerankingModelDict, WeightsDict from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler from core.rag.datasource.retrieval_service import RetrievalService from core.rag.entities.citation_metadata import RetrievalSourceMetadata @@ -83,7 +83,7 @@ from models.dataset import ( ) from models.dataset import Document as DatasetDocument from models.dataset import Document as DocumentModel -from models.enums import CreatorUserRole +from models.enums import CreatorUserRole, DatasetQuerySource from services.external_knowledge_service import ExternalDatasetService from services.feature_service import FeatureService @@ -727,8 +727,8 @@ class DatasetRetrieval: top_k: int, score_threshold: float, reranking_mode: str, - reranking_model: dict | None = None, - weights: dict[str, Any] | None = None, + reranking_model: RerankingModelDict | None = None, + weights: WeightsDict | None = None, reranking_enable: bool = True, message_id: str | None = None, metadata_filter_document_ids: dict[str, list[str]] | None = None, @@ -1008,7 +1008,7 @@ class DatasetRetrieval: dataset_query = DatasetQuery( dataset_id=dataset_id, content=json.dumps(contents), - source="app", + source=DatasetQuerySource.APP, source_app_id=app_id, created_by_role=CreatorUserRole(user_from), created_by=user_id, @@ -1181,8 +1181,8 @@ class DatasetRetrieval: hit_callbacks=[hit_callback], return_resource=return_resource, retriever_from=invoke_from.to_source(), - reranking_provider_name=retrieve_config.reranking_model.get("reranking_provider_name"), - reranking_model_name=retrieve_config.reranking_model.get("reranking_model_name"), + reranking_provider_name=retrieve_config.reranking_model["reranking_provider_name"], + reranking_model_name=retrieve_config.reranking_model["reranking_model_name"], ) tools.append(tool) @@ -1685,8 +1685,8 @@ class DatasetRetrieval: tenant_id: str, reranking_enable: bool, reranking_mode: str, - reranking_model: dict | None, - weights: dict[str, Any] | None, + reranking_model: RerankingModelDict | None, + weights: WeightsDict | None, top_k: int, score_threshold: float, query: str | None, diff --git a/api/core/rag/summary_index/summary_index.py b/api/core/rag/summary_index/summary_index.py index 79d7821b4e..31d21dbeee 100644 --- a/api/core/rag/summary_index/summary_index.py +++ b/api/core/rag/summary_index/summary_index.py @@ -2,6 +2,7 @@ import concurrent.futures import logging from core.db.session_factory import session_factory +from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict from models.dataset import Dataset, Document, DocumentSegment, DocumentSegmentSummary from services.summary_index_service import SummaryIndexService from tasks.generate_summary_index_task import generate_summary_index_task @@ -11,7 +12,11 @@ logger = logging.getLogger(__name__) class SummaryIndex: def generate_and_vectorize_summary( - self, dataset_id: str, document_id: str, is_preview: bool, summary_index_setting: dict | None = None + self, + dataset_id: str, + document_id: str, + is_preview: bool, + summary_index_setting: SummaryIndexSettingDict | None = None, ) -> None: if is_preview: with session_factory.create_session() as session: diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 7f7787b92a..23a877b7e3 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -72,6 +72,11 @@ class ApiProviderControllerItem(TypedDict): controller: ApiToolProviderController +class EmojiIconDict(TypedDict): + background: str + content: str + + class ToolManager: _builtin_provider_lock = Lock() _hardcoded_providers: dict[str, BuiltinToolProviderController] = {} @@ -916,7 +921,7 @@ class ToolManager: ) @classmethod - def generate_workflow_tool_icon_url(cls, tenant_id: str, provider_id: str) -> Mapping[str, str]: + def generate_workflow_tool_icon_url(cls, tenant_id: str, provider_id: str) -> EmojiIconDict: try: workflow_provider: WorkflowToolProvider | None = ( db.session.query(WorkflowToolProvider) @@ -933,7 +938,7 @@ class ToolManager: return {"background": "#252525", "content": "\ud83d\ude01"} @classmethod - def generate_api_tool_icon_url(cls, tenant_id: str, provider_id: str) -> Mapping[str, str]: + def generate_api_tool_icon_url(cls, tenant_id: str, provider_id: str) -> EmojiIconDict: try: api_provider: ApiToolProvider | None = ( db.session.query(ApiToolProvider) @@ -950,7 +955,7 @@ class ToolManager: return {"background": "#252525", "content": "\ud83d\ude01"} @classmethod - def generate_mcp_tool_icon_url(cls, tenant_id: str, provider_id: str) -> Mapping[str, str] | str: + def generate_mcp_tool_icon_url(cls, tenant_id: str, provider_id: str) -> EmojiIconDict | dict[str, str] | str: try: with Session(db.engine) as session: mcp_service = MCPToolManageService(session=session) @@ -970,7 +975,7 @@ class ToolManager: tenant_id: str, provider_type: ToolProviderType, provider_id: str, - ) -> str | Mapping[str, str]: + ) -> str | EmojiIconDict | dict[str, str]: """ get the tool icon diff --git a/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py b/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py index 3dbbbe6563..c2b520fa99 100644 --- a/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py @@ -1,5 +1,4 @@ import threading -from typing import Any from flask import Flask, current_app from pydantic import BaseModel, Field @@ -13,11 +12,12 @@ from core.rag.models.document import Document as RagDocument from core.rag.rerank.rerank_model import RerankModelRunner from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool +from core.tools.utils.dataset_retriever.dataset_retriever_tool import DefaultRetrievalModelDict from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db from models.dataset import Dataset, Document, DocumentSegment -default_retrieval_model: dict[str, Any] = { +default_retrieval_model: DefaultRetrievalModelDict = { "search_method": RetrievalMethod.SEMANTIC_SEARCH, "reranking_enable": False, "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, diff --git a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py index 057ec41f65..429b7e6622 100644 --- a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py @@ -1,9 +1,10 @@ -from typing import Any, cast +from typing import NotRequired, TypedDict, cast from pydantic import BaseModel, Field from sqlalchemy import select from core.app.app_config.entities import DatasetRetrieveConfigEntity, ModelConfig +from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict from core.rag.datasource.retrieval_service import RetrievalService from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.rag.entities.context_entities import DocumentContext @@ -16,7 +17,19 @@ from models.dataset import Dataset from models.dataset import Document as DatasetDocument from services.external_knowledge_service import ExternalDatasetService -default_retrieval_model: dict[str, Any] = { + +class DefaultRetrievalModelDict(TypedDict): + search_method: RetrievalMethod + reranking_enable: bool + reranking_model: RerankingModelDict + reranking_mode: NotRequired[str] + weights: NotRequired[WeightsDict | None] + score_threshold: NotRequired[float] + top_k: int + score_threshold_enabled: bool + + +default_retrieval_model: DefaultRetrievalModelDict = { "search_method": RetrievalMethod.SEMANTIC_SEARCH, "reranking_enable": False, "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, @@ -125,7 +138,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): if metadata_condition and not document_ids_filter: return "" # get retrieval model , if the model is not setting , using default - retrieval_model: dict[str, Any] = dataset.retrieval_model or default_retrieval_model + retrieval_model = dataset.retrieval_model or default_retrieval_model retrieval_resource_list: list[RetrievalSourceMetadata] = [] if dataset.indexing_technique == "economy": # use keyword table query diff --git a/api/core/tools/utils/parser.py b/api/core/tools/utils/parser.py index fc2b41d960..f7484b93fb 100644 --- a/api/core/tools/utils/parser.py +++ b/api/core/tools/utils/parser.py @@ -1,4 +1,5 @@ import re +from collections.abc import Mapping from json import dumps as json_dumps from json import loads as json_loads from json.decoder import JSONDecodeError @@ -20,10 +21,18 @@ class InterfaceDict(TypedDict): operation: dict[str, Any] +class OpenAPISpecDict(TypedDict): + openapi: str + info: dict[str, str] + servers: list[dict[str, Any]] + paths: dict[str, Any] + components: dict[str, Any] + + class ApiBasedToolSchemaParser: @staticmethod def parse_openapi_to_tool_bundle( - openapi: dict, extra_info: dict | None = None, warning: dict | None = None + openapi: Mapping[str, Any], extra_info: dict | None = None, warning: dict | None = None ) -> list[ApiToolBundle]: warning = warning if warning is not None else {} extra_info = extra_info if extra_info is not None else {} @@ -277,7 +286,7 @@ class ApiBasedToolSchemaParser: @staticmethod def parse_swagger_to_openapi( swagger: dict, extra_info: dict | None = None, warning: dict | None = None - ) -> dict[str, Any]: + ) -> OpenAPISpecDict: warning = warning or {} """ parse swagger to openapi @@ -293,7 +302,7 @@ class ApiBasedToolSchemaParser: if len(servers) == 0: raise ToolApiSchemaError("No server found in the swagger yaml.") - converted_openapi: dict[str, Any] = { + converted_openapi: OpenAPISpecDict = { "openapi": "3.0.0", "info": { "title": info.get("title", "Swagger"), diff --git a/api/core/workflow/nodes/knowledge_index/entities.py b/api/core/workflow/nodes/knowledge_index/entities.py index 8b00746268..8d2e9bf3cb 100644 --- a/api/core/workflow/nodes/knowledge_index/entities.py +++ b/api/core/workflow/nodes/knowledge_index/entities.py @@ -2,6 +2,7 @@ from typing import Literal, Union from pydantic import BaseModel +from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE from dify_graph.entities.base_node_data import BaseNodeData @@ -161,4 +162,4 @@ class KnowledgeIndexNodeData(BaseNodeData): chunk_structure: str index_chunk_variable_selector: list[str] indexing_technique: str | None = None - summary_index_setting: dict | None = None + summary_index_setting: SummaryIndexSettingDict | None = None diff --git a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py index 0a74847bc1..4ea9091c5b 100644 --- a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py +++ b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py @@ -3,6 +3,7 @@ from collections.abc import Mapping from typing import TYPE_CHECKING, Any from core.rag.index_processor.index_processor import IndexProcessor +from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict from core.rag.summary_index.summary_index import SummaryIndex from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE from dify_graph.entities.graph_config import NodeConfigDict @@ -127,7 +128,7 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]): is_preview: bool, batch: Any, chunks: Mapping[str, Any], - summary_index_setting: dict | None = None, + summary_index_setting: SummaryIndexSettingDict | None = None, ): if not document_id: raise KnowledgeIndexNodeError("document_id is required.") diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 9c3b9aacbf..80f59140be 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -9,6 +9,7 @@ from collections.abc import Mapping, Sequence from typing import TYPE_CHECKING, Any, Literal from core.app.app_config.entities import DatasetRetrieveConfigEntity +from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from dify_graph.entities import GraphInitParams from dify_graph.entities.graph_config import NodeConfigDict @@ -201,8 +202,8 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD elif str(node_data.retrieval_mode) == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE: if node_data.multiple_retrieval_config is None: raise ValueError("multiple_retrieval_config is required") - reranking_model = None - weights = None + reranking_model: RerankingModelDict | None = None + weights: WeightsDict | None = None match node_data.multiple_retrieval_config.reranking_mode: case "reranking_model": if node_data.multiple_retrieval_config.reranking_model: diff --git a/api/core/workflow/nodes/knowledge_retrieval/retrieval.py b/api/core/workflow/nodes/knowledge_retrieval/retrieval.py index f964f79582..e1311ab962 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/retrieval.py +++ b/api/core/workflow/nodes/knowledge_retrieval/retrieval.py @@ -2,6 +2,7 @@ from typing import Any, Literal, Protocol from pydantic import BaseModel, Field +from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict from dify_graph.model_runtime.entities import LLMUsage from dify_graph.nodes.llm.entities import ModelConfig @@ -75,8 +76,8 @@ class KnowledgeRetrievalRequest(BaseModel): top_k: int = Field(default=0, description="Number of top results to return") score_threshold: float = Field(default=0.0, description="Minimum relevance score threshold") reranking_mode: str = Field(default="reranking_model", description="Reranking strategy") - reranking_model: dict | None = Field(default=None, description="Reranking model configuration") - weights: dict[str, Any] | None = Field(default=None, description="Weights for weighted score reranking") + reranking_model: RerankingModelDict | None = Field(default=None, description="Reranking model configuration") + weights: WeightsDict | None = Field(default=None, description="Weights for weighted score reranking") reranking_enable: bool = Field(default=True, description="Whether reranking is enabled") attachment_ids: list[str] | None = Field(default=None, description="List of attachment file IDs for retrieval") diff --git a/api/dify_graph/nodes/http_request/node.py b/api/dify_graph/nodes/http_request/node.py index b17c820a80..486ae241ee 100644 --- a/api/dify_graph/nodes/http_request/node.py +++ b/api/dify_graph/nodes/http_request/node.py @@ -101,7 +101,6 @@ class HttpRequestNode(Node[HttpRequestNodeData]): timeout=self._get_request_timeout(self.node_data), variable_pool=self.graph_runtime_state.variable_pool, http_request_config=self._http_request_config, - max_retries=0, ssl_verify=self.node_data.ssl_verify, http_client=self._http_client, file_manager=self._file_manager, diff --git a/api/dify_graph/variables/types.py b/api/dify_graph/variables/types.py index df8430de5d..53bf495a27 100644 --- a/api/dify_graph/variables/types.py +++ b/api/dify_graph/variables/types.py @@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Any from dify_graph.file.models import File if TYPE_CHECKING: - pass + from dify_graph.variables.segments import Segment class ArrayValidation(StrEnum): @@ -219,7 +219,7 @@ class SegmentType(StrEnum): return _ARRAY_ELEMENT_TYPES_MAPPING.get(self) @staticmethod - def get_zero_value(t: SegmentType): + def get_zero_value(t: SegmentType) -> Segment: # Lazy import to avoid circular dependency from factories import variable_factory diff --git a/api/events/event_handlers/create_document_index.py b/api/events/event_handlers/create_document_index.py index 8778f5cafe..76de5a0740 100644 --- a/api/events/event_handlers/create_document_index.py +++ b/api/events/event_handlers/create_document_index.py @@ -10,6 +10,7 @@ from events.document_index_event import document_index_created from extensions.ext_database import db from libs.datetime_utils import naive_utc_now from models.dataset import Document +from models.enums import IndexingStatus logger = logging.getLogger(__name__) @@ -35,7 +36,7 @@ def handle(sender, **kwargs): if not document: raise NotFound("Document not found") - document.indexing_status = "parsing" + document.indexing_status = IndexingStatus.PARSING document.processing_started_at = naive_utc_now() documents.append(document) db.session.add(document) diff --git a/api/extensions/ext_fastopenapi.py b/api/extensions/ext_fastopenapi.py index ab4d23a072..569203e974 100644 --- a/api/extensions/ext_fastopenapi.py +++ b/api/extensions/ext_fastopenapi.py @@ -1,3 +1,5 @@ +from typing import Protocol, cast + from fastopenapi.routers import FlaskRouter from flask_cors import CORS @@ -9,6 +11,10 @@ from extensions.ext_blueprints import AUTHENTICATED_HEADERS, EXPOSED_HEADERS DOCS_PREFIX = "/fastopenapi" +class SupportsIncludeRouter(Protocol): + def include_router(self, router: object, *, prefix: str = "") -> None: ... + + def init_app(app: DifyApp) -> None: docs_enabled = dify_config.SWAGGER_UI_ENABLED docs_url = f"{DOCS_PREFIX}/docs" if docs_enabled else None @@ -36,7 +42,7 @@ def init_app(app: DifyApp) -> None: _ = remote_files _ = setup - router.include_router(console_router, prefix="/console/api") + cast(SupportsIncludeRouter, router).include_router(console_router, prefix="/console/api") CORS( app, resources={r"/console/api/.*": {"origins": dify_config.CONSOLE_CORS_ALLOW_ORIGINS}}, diff --git a/api/extensions/otel/runtime.py b/api/extensions/otel/runtime.py index b1c703f944..149d76b07b 100644 --- a/api/extensions/otel/runtime.py +++ b/api/extensions/otel/runtime.py @@ -5,7 +5,7 @@ from typing import Union from celery.signals import worker_init from flask_login import user_loaded_from_request, user_logged_in -from opentelemetry import trace +from opentelemetry import metrics, trace from opentelemetry.propagate import set_global_textmap from opentelemetry.propagators.b3 import B3MultiFormat from opentelemetry.propagators.composite import CompositePropagator @@ -31,9 +31,29 @@ def setup_context_propagation() -> None: def shutdown_tracer() -> None: + flush_telemetry() + + +def flush_telemetry() -> None: + """ + Best-effort flush for telemetry providers. + + This is mainly used by short-lived command processes (e.g. Kubernetes CronJob) + so counters/histograms are exported before the process exits. + """ provider = trace.get_tracer_provider() if hasattr(provider, "force_flush"): - provider.force_flush() + try: + provider.force_flush() + except Exception: + logger.exception("otel: failed to flush trace provider") + + metric_provider = metrics.get_meter_provider() + if hasattr(metric_provider, "force_flush"): + try: + metric_provider.force_flush() + except Exception: + logger.exception("otel: failed to flush metric provider") def is_celery_worker(): diff --git a/api/factories/variable_factory.py b/api/factories/variable_factory.py index 255e5cde83..14a56bf4a2 100644 --- a/api/factories/variable_factory.py +++ b/api/factories/variable_factory.py @@ -55,7 +55,7 @@ class TypeMismatchError(Exception): # Define the constant -SEGMENT_TO_VARIABLE_MAP = { +SEGMENT_TO_VARIABLE_MAP: Mapping[type[Segment], type[VariableBase]] = { ArrayAnySegment: ArrayAnyVariable, ArrayBooleanSegment: ArrayBooleanVariable, ArrayFileSegment: ArrayFileVariable, @@ -296,13 +296,11 @@ def segment_to_variable( raise UnsupportedSegmentTypeError(f"not supported segment type {segment_type}") variable_class = SEGMENT_TO_VARIABLE_MAP[segment_type] - return cast( - VariableBase, - variable_class( - id=id, - name=name, - description=description, - value=segment.value, - selector=list(selector), - ), + return variable_class( + id=id, + name=name, + description=description, + value_type=segment.value_type, + value=segment.value, + selector=list(selector), ) diff --git a/api/libs/helper.py b/api/libs/helper.py index 6151eb0940..e7572cc025 100644 --- a/api/libs/helper.py +++ b/api/libs/helper.py @@ -32,6 +32,11 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +def _stream_with_request_context(response: object) -> Any: + """Bridge Flask's loosely-typed streaming helper without leaking casts into callers.""" + return cast(Any, stream_with_context)(response) + + def escape_like_pattern(pattern: str) -> str: """ Escape special characters in a string for safe use in SQL LIKE patterns. @@ -286,22 +291,32 @@ def generate_text_hash(text: str) -> str: return sha256(hash_text.encode()).hexdigest() -def compact_generate_response(response: Union[Mapping, Generator, RateLimitGenerator]) -> Response: - if isinstance(response, dict): +def compact_generate_response( + response: Mapping[str, Any] | Generator[str, None, None] | RateLimitGenerator, +) -> Response: + if isinstance(response, Mapping): return Response( response=json.dumps(jsonable_encoder(response)), status=200, content_type="application/json; charset=utf-8", ) else: + stream_response = response - def generate() -> Generator: - yield from response + def generate() -> Generator[str, None, None]: + yield from stream_response - return Response(stream_with_context(generate()), status=200, mimetype="text/event-stream") + return Response( + _stream_with_request_context(generate()), + status=200, + mimetype="text/event-stream", + ) -def length_prefixed_response(magic_number: int, response: Union[Mapping, Generator, RateLimitGenerator]) -> Response: +def length_prefixed_response( + magic_number: int, + response: Mapping[str, Any] | BaseModel | Generator[str | bytes, None, None] | RateLimitGenerator, +) -> Response: """ This function is used to return a response with a length prefix. Magic number is a one byte number that indicates the type of the response. @@ -332,7 +347,7 @@ def length_prefixed_response(magic_number: int, response: Union[Mapping, Generat # | Magic Number 1byte | Reserved 1byte | Header Length 2bytes | Data Length 4bytes | Reserved 6bytes | Data return struct.pack(" Generator: - for chunk in response: + stream_response = response + + def generate() -> Generator[bytes, None, None]: + for chunk in stream_response: if isinstance(chunk, str): yield pack_response_with_length_prefix(chunk.encode("utf-8")) else: yield pack_response_with_length_prefix(chunk) - return Response(stream_with_context(generate()), status=200, mimetype="text/event-stream") + return Response( + _stream_with_request_context(generate()), + status=200, + mimetype="text/event-stream", + ) class TokenManager: diff --git a/api/libs/login.py b/api/libs/login.py index 69e2b58426..bd5cb5f30d 100644 --- a/api/libs/login.py +++ b/api/libs/login.py @@ -77,12 +77,14 @@ def login_required(func: Callable[P, R]) -> Callable[P, R | ResponseReturnValue] @wraps(func) def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R | ResponseReturnValue: if request.method in EXEMPT_METHODS or dify_config.LOGIN_DISABLED: - pass - elif current_user is not None and not current_user.is_authenticated: + return current_app.ensure_sync(func)(*args, **kwargs) + + user = _get_user() + if user is None or not user.is_authenticated: return current_app.login_manager.unauthorized() # type: ignore # we put csrf validation here for less conflicts # TODO: maybe find a better place for it. - check_csrf_token(request, current_user.id) + check_csrf_token(request, user.id) return current_app.ensure_sync(func)(*args, **kwargs) return decorated_view diff --git a/api/libs/module_loading.py b/api/libs/module_loading.py index 9f74943433..7063a115b0 100644 --- a/api/libs/module_loading.py +++ b/api/libs/module_loading.py @@ -7,9 +7,10 @@ https://github.com/django/django/blob/main/django/utils/module_loading.py import sys from importlib import import_module +from typing import Any -def cached_import(module_path: str, class_name: str): +def cached_import(module_path: str, class_name: str) -> Any: """ Import a module and return the named attribute/class from it, with caching. @@ -20,16 +21,14 @@ def cached_import(module_path: str, class_name: str): Returns: The imported attribute/class """ - if not ( - (module := sys.modules.get(module_path)) - and (spec := getattr(module, "__spec__", None)) - and getattr(spec, "_initializing", False) is False - ): + module = sys.modules.get(module_path) + spec = getattr(module, "__spec__", None) if module is not None else None + if module is None or getattr(spec, "_initializing", False): module = import_module(module_path) return getattr(module, class_name) -def import_string(dotted_path: str): +def import_string(dotted_path: str) -> Any: """ Import a dotted module path and return the attribute/class designated by the last name in the path. Raise ImportError if the import failed. diff --git a/api/libs/oauth.py b/api/libs/oauth.py index 889a5a3248..efce13f6f1 100644 --- a/api/libs/oauth.py +++ b/api/libs/oauth.py @@ -1,7 +1,48 @@ +import sys import urllib.parse from dataclasses import dataclass +from typing import NotRequired import httpx +from pydantic import TypeAdapter + +if sys.version_info >= (3, 12): + from typing import TypedDict +else: + from typing_extensions import TypedDict + +JsonObject = dict[str, object] +JsonObjectList = list[JsonObject] + +JSON_OBJECT_ADAPTER = TypeAdapter(JsonObject) +JSON_OBJECT_LIST_ADAPTER = TypeAdapter(JsonObjectList) + + +class AccessTokenResponse(TypedDict, total=False): + access_token: str + + +class GitHubEmailRecord(TypedDict, total=False): + email: str + primary: bool + + +class GitHubRawUserInfo(TypedDict): + id: int | str + login: str + name: NotRequired[str] + email: NotRequired[str] + + +class GoogleRawUserInfo(TypedDict): + sub: str + email: str + + +ACCESS_TOKEN_RESPONSE_ADAPTER = TypeAdapter(AccessTokenResponse) +GITHUB_RAW_USER_INFO_ADAPTER = TypeAdapter(GitHubRawUserInfo) +GITHUB_EMAIL_RECORDS_ADAPTER = TypeAdapter(list[GitHubEmailRecord]) +GOOGLE_RAW_USER_INFO_ADAPTER = TypeAdapter(GoogleRawUserInfo) @dataclass @@ -11,26 +52,38 @@ class OAuthUserInfo: email: str +def _json_object(response: httpx.Response) -> JsonObject: + return JSON_OBJECT_ADAPTER.validate_python(response.json()) + + +def _json_list(response: httpx.Response) -> JsonObjectList: + return JSON_OBJECT_LIST_ADAPTER.validate_python(response.json()) + + class OAuth: + client_id: str + client_secret: str + redirect_uri: str + def __init__(self, client_id: str, client_secret: str, redirect_uri: str): self.client_id = client_id self.client_secret = client_secret self.redirect_uri = redirect_uri - def get_authorization_url(self): + def get_authorization_url(self, invite_token: str | None = None) -> str: raise NotImplementedError() - def get_access_token(self, code: str): + def get_access_token(self, code: str) -> str: raise NotImplementedError() - def get_raw_user_info(self, token: str): + def get_raw_user_info(self, token: str) -> JsonObject: raise NotImplementedError() def get_user_info(self, token: str) -> OAuthUserInfo: raw_info = self.get_raw_user_info(token) return self._transform_user_info(raw_info) - def _transform_user_info(self, raw_info: dict) -> OAuthUserInfo: + def _transform_user_info(self, raw_info: JsonObject) -> OAuthUserInfo: raise NotImplementedError() @@ -40,7 +93,7 @@ class GitHubOAuth(OAuth): _USER_INFO_URL = "https://api.github.com/user" _EMAIL_INFO_URL = "https://api.github.com/user/emails" - def get_authorization_url(self, invite_token: str | None = None): + def get_authorization_url(self, invite_token: str | None = None) -> str: params = { "client_id": self.client_id, "redirect_uri": self.redirect_uri, @@ -50,7 +103,7 @@ class GitHubOAuth(OAuth): params["state"] = invite_token return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}" - def get_access_token(self, code: str): + def get_access_token(self, code: str) -> str: data = { "client_id": self.client_id, "client_secret": self.client_secret, @@ -60,7 +113,7 @@ class GitHubOAuth(OAuth): headers = {"Accept": "application/json"} response = httpx.post(self._TOKEN_URL, data=data, headers=headers) - response_json = response.json() + response_json = ACCESS_TOKEN_RESPONSE_ADAPTER.validate_python(_json_object(response)) access_token = response_json.get("access_token") if not access_token: @@ -68,23 +121,24 @@ class GitHubOAuth(OAuth): return access_token - def get_raw_user_info(self, token: str): + def get_raw_user_info(self, token: str) -> JsonObject: headers = {"Authorization": f"token {token}"} response = httpx.get(self._USER_INFO_URL, headers=headers) response.raise_for_status() - user_info = response.json() + user_info = GITHUB_RAW_USER_INFO_ADAPTER.validate_python(_json_object(response)) email_response = httpx.get(self._EMAIL_INFO_URL, headers=headers) - email_info = email_response.json() - primary_email: dict = next((email for email in email_info if email["primary"] == True), {}) + email_info = GITHUB_EMAIL_RECORDS_ADAPTER.validate_python(_json_list(email_response)) + primary_email = next((email for email in email_info if email.get("primary") is True), None) - return {**user_info, "email": primary_email.get("email", "")} + return {**user_info, "email": primary_email.get("email", "") if primary_email else ""} - def _transform_user_info(self, raw_info: dict) -> OAuthUserInfo: - email = raw_info.get("email") + def _transform_user_info(self, raw_info: JsonObject) -> OAuthUserInfo: + payload = GITHUB_RAW_USER_INFO_ADAPTER.validate_python(raw_info) + email = payload.get("email") if not email: - email = f"{raw_info['id']}+{raw_info['login']}@users.noreply.github.com" - return OAuthUserInfo(id=str(raw_info["id"]), name=raw_info["name"], email=email) + email = f"{payload['id']}+{payload['login']}@users.noreply.github.com" + return OAuthUserInfo(id=str(payload["id"]), name=str(payload.get("name", "")), email=email) class GoogleOAuth(OAuth): @@ -92,7 +146,7 @@ class GoogleOAuth(OAuth): _TOKEN_URL = "https://oauth2.googleapis.com/token" _USER_INFO_URL = "https://www.googleapis.com/oauth2/v3/userinfo" - def get_authorization_url(self, invite_token: str | None = None): + def get_authorization_url(self, invite_token: str | None = None) -> str: params = { "client_id": self.client_id, "response_type": "code", @@ -103,7 +157,7 @@ class GoogleOAuth(OAuth): params["state"] = invite_token return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}" - def get_access_token(self, code: str): + def get_access_token(self, code: str) -> str: data = { "client_id": self.client_id, "client_secret": self.client_secret, @@ -114,7 +168,7 @@ class GoogleOAuth(OAuth): headers = {"Accept": "application/json"} response = httpx.post(self._TOKEN_URL, data=data, headers=headers) - response_json = response.json() + response_json = ACCESS_TOKEN_RESPONSE_ADAPTER.validate_python(_json_object(response)) access_token = response_json.get("access_token") if not access_token: @@ -122,11 +176,12 @@ class GoogleOAuth(OAuth): return access_token - def get_raw_user_info(self, token: str): + def get_raw_user_info(self, token: str) -> JsonObject: headers = {"Authorization": f"Bearer {token}"} response = httpx.get(self._USER_INFO_URL, headers=headers) response.raise_for_status() - return response.json() + return _json_object(response) - def _transform_user_info(self, raw_info: dict) -> OAuthUserInfo: - return OAuthUserInfo(id=str(raw_info["sub"]), name="", email=raw_info["email"]) + def _transform_user_info(self, raw_info: JsonObject) -> OAuthUserInfo: + payload = GOOGLE_RAW_USER_INFO_ADAPTER.validate_python(raw_info) + return OAuthUserInfo(id=str(payload["sub"]), name="", email=payload["email"]) diff --git a/api/libs/oauth_data_source.py b/api/libs/oauth_data_source.py index ae0ae3bcb6..d5dc35ac97 100644 --- a/api/libs/oauth_data_source.py +++ b/api/libs/oauth_data_source.py @@ -1,25 +1,57 @@ +import sys import urllib.parse -from typing import Any +from typing import Any, Literal import httpx from flask_login import current_user +from pydantic import TypeAdapter from sqlalchemy import select from extensions.ext_database import db from libs.datetime_utils import naive_utc_now from models.source import DataSourceOauthBinding +if sys.version_info >= (3, 12): + from typing import TypedDict +else: + from typing_extensions import TypedDict + + +class NotionPageSummary(TypedDict): + page_id: str + page_name: str + page_icon: dict[str, str] | None + parent_id: str + type: Literal["page", "database"] + + +class NotionSourceInfo(TypedDict): + workspace_name: str | None + workspace_icon: str | None + workspace_id: str | None + pages: list[NotionPageSummary] + total: int + + +SOURCE_INFO_STORAGE_ADAPTER = TypeAdapter(dict[str, object]) +NOTION_SOURCE_INFO_ADAPTER = TypeAdapter(NotionSourceInfo) +NOTION_PAGE_SUMMARY_ADAPTER = TypeAdapter(NotionPageSummary) + class OAuthDataSource: + client_id: str + client_secret: str + redirect_uri: str + def __init__(self, client_id: str, client_secret: str, redirect_uri: str): self.client_id = client_id self.client_secret = client_secret self.redirect_uri = redirect_uri - def get_authorization_url(self): + def get_authorization_url(self) -> str: raise NotImplementedError() - def get_access_token(self, code: str): + def get_access_token(self, code: str) -> None: raise NotImplementedError() @@ -30,7 +62,7 @@ class NotionOAuth(OAuthDataSource): _NOTION_BLOCK_SEARCH = "https://api.notion.com/v1/blocks" _NOTION_BOT_USER = "https://api.notion.com/v1/users/me" - def get_authorization_url(self): + def get_authorization_url(self) -> str: params = { "client_id": self.client_id, "response_type": "code", @@ -39,7 +71,7 @@ class NotionOAuth(OAuthDataSource): } return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}" - def get_access_token(self, code: str): + def get_access_token(self, code: str) -> None: data = {"code": code, "grant_type": "authorization_code", "redirect_uri": self.redirect_uri} headers = {"Accept": "application/json"} auth = (self.client_id, self.client_secret) @@ -54,13 +86,12 @@ class NotionOAuth(OAuthDataSource): workspace_id = response_json.get("workspace_id") # get all authorized pages pages = self.get_authorized_pages(access_token) - source_info = { - "workspace_name": workspace_name, - "workspace_icon": workspace_icon, - "workspace_id": workspace_id, - "pages": pages, - "total": len(pages), - } + source_info = self._build_source_info( + workspace_name=workspace_name, + workspace_icon=workspace_icon, + workspace_id=workspace_id, + pages=pages, + ) # save data source binding data_source_binding = db.session.scalar( select(DataSourceOauthBinding).where( @@ -70,7 +101,7 @@ class NotionOAuth(OAuthDataSource): ) ) if data_source_binding: - data_source_binding.source_info = source_info + data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info) data_source_binding.disabled = False data_source_binding.updated_at = naive_utc_now() db.session.commit() @@ -78,25 +109,24 @@ class NotionOAuth(OAuthDataSource): new_data_source_binding = DataSourceOauthBinding( tenant_id=current_user.current_tenant_id, access_token=access_token, - source_info=source_info, + source_info=SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info), provider="notion", ) db.session.add(new_data_source_binding) db.session.commit() - def save_internal_access_token(self, access_token: str): + def save_internal_access_token(self, access_token: str) -> None: workspace_name = self.notion_workspace_name(access_token) workspace_icon = None workspace_id = current_user.current_tenant_id # get all authorized pages pages = self.get_authorized_pages(access_token) - source_info = { - "workspace_name": workspace_name, - "workspace_icon": workspace_icon, - "workspace_id": workspace_id, - "pages": pages, - "total": len(pages), - } + source_info = self._build_source_info( + workspace_name=workspace_name, + workspace_icon=workspace_icon, + workspace_id=workspace_id, + pages=pages, + ) # save data source binding data_source_binding = db.session.scalar( select(DataSourceOauthBinding).where( @@ -106,7 +136,7 @@ class NotionOAuth(OAuthDataSource): ) ) if data_source_binding: - data_source_binding.source_info = source_info + data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info) data_source_binding.disabled = False data_source_binding.updated_at = naive_utc_now() db.session.commit() @@ -114,13 +144,13 @@ class NotionOAuth(OAuthDataSource): new_data_source_binding = DataSourceOauthBinding( tenant_id=current_user.current_tenant_id, access_token=access_token, - source_info=source_info, + source_info=SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info), provider="notion", ) db.session.add(new_data_source_binding) db.session.commit() - def sync_data_source(self, binding_id: str): + def sync_data_source(self, binding_id: str) -> None: # save data source binding data_source_binding = db.session.scalar( select(DataSourceOauthBinding).where( @@ -134,23 +164,22 @@ class NotionOAuth(OAuthDataSource): if data_source_binding: # get all authorized pages pages = self.get_authorized_pages(data_source_binding.access_token) - source_info = data_source_binding.source_info - new_source_info = { - "workspace_name": source_info["workspace_name"], - "workspace_icon": source_info["workspace_icon"], - "workspace_id": source_info["workspace_id"], - "pages": pages, - "total": len(pages), - } - data_source_binding.source_info = new_source_info + source_info = NOTION_SOURCE_INFO_ADAPTER.validate_python(data_source_binding.source_info) + new_source_info = self._build_source_info( + workspace_name=source_info["workspace_name"], + workspace_icon=source_info["workspace_icon"], + workspace_id=source_info["workspace_id"], + pages=pages, + ) + data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(new_source_info) data_source_binding.disabled = False data_source_binding.updated_at = naive_utc_now() db.session.commit() else: raise ValueError("Data source binding not found") - def get_authorized_pages(self, access_token: str): - pages = [] + def get_authorized_pages(self, access_token: str) -> list[NotionPageSummary]: + pages: list[NotionPageSummary] = [] page_results = self.notion_page_search(access_token) database_results = self.notion_database_search(access_token) # get page detail @@ -187,7 +216,7 @@ class NotionOAuth(OAuthDataSource): "parent_id": parent_id, "type": "page", } - pages.append(page) + pages.append(NOTION_PAGE_SUMMARY_ADAPTER.validate_python(page)) # get database detail for database_result in database_results: page_id = database_result["id"] @@ -220,11 +249,11 @@ class NotionOAuth(OAuthDataSource): "parent_id": parent_id, "type": "database", } - pages.append(page) + pages.append(NOTION_PAGE_SUMMARY_ADAPTER.validate_python(page)) return pages - def notion_page_search(self, access_token: str): - results = [] + def notion_page_search(self, access_token: str) -> list[dict[str, Any]]: + results: list[dict[str, Any]] = [] next_cursor = None has_more = True @@ -249,7 +278,7 @@ class NotionOAuth(OAuthDataSource): return results - def notion_block_parent_page_id(self, access_token: str, block_id: str): + def notion_block_parent_page_id(self, access_token: str, block_id: str) -> str: headers = { "Authorization": f"Bearer {access_token}", "Notion-Version": "2022-06-28", @@ -265,7 +294,7 @@ class NotionOAuth(OAuthDataSource): return self.notion_block_parent_page_id(access_token, parent[parent_type]) return parent[parent_type] - def notion_workspace_name(self, access_token: str): + def notion_workspace_name(self, access_token: str) -> str: headers = { "Authorization": f"Bearer {access_token}", "Notion-Version": "2022-06-28", @@ -279,8 +308,8 @@ class NotionOAuth(OAuthDataSource): return user_info["workspace_name"] return "workspace" - def notion_database_search(self, access_token: str): - results = [] + def notion_database_search(self, access_token: str) -> list[dict[str, Any]]: + results: list[dict[str, Any]] = [] next_cursor = None has_more = True @@ -303,3 +332,19 @@ class NotionOAuth(OAuthDataSource): next_cursor = response_json.get("next_cursor", None) return results + + @staticmethod + def _build_source_info( + *, + workspace_name: str | None, + workspace_icon: str | None, + workspace_id: str | None, + pages: list[NotionPageSummary], + ) -> NotionSourceInfo: + return { + "workspace_name": workspace_name, + "workspace_icon": workspace_icon, + "workspace_id": workspace_id, + "pages": pages, + "total": len(pages), + } diff --git a/api/models/account.py b/api/models/account.py index 1a43c9ca17..5960ac6564 100644 --- a/api/models/account.py +++ b/api/models/account.py @@ -177,13 +177,11 @@ class Account(UserMixin, TypeBase): @classmethod def get_by_openid(cls, provider: str, open_id: str): - account_integrate = ( - db.session.query(AccountIntegrate) - .where(AccountIntegrate.provider == provider, AccountIntegrate.open_id == open_id) - .one_or_none() - ) + account_integrate = db.session.execute( + select(AccountIntegrate).where(AccountIntegrate.provider == provider, AccountIntegrate.open_id == open_id) + ).scalar_one_or_none() if account_integrate: - return db.session.query(Account).where(Account.id == account_integrate.account_id).one_or_none() + return db.session.scalar(select(Account).where(Account.id == account_integrate.account_id)) return None # check current_user.current_tenant.current_role in ['admin', 'owner'] diff --git a/api/models/dataset.py b/api/models/dataset.py index 8438fda25f..d0163e6984 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -8,6 +8,7 @@ import os import pickle import re import time +from collections.abc import Sequence from datetime import datetime from json import JSONDecodeError from typing import Any, TypedDict, cast @@ -30,7 +31,20 @@ from services.entities.knowledge_entities.knowledge_entities import ParentMode, from .account import Account from .base import Base, TypeBase from .engine import db -from .enums import CreatorUserRole +from .enums import ( + CollectionBindingType, + CreatorUserRole, + DatasetMetadataType, + DatasetQuerySource, + DatasetRuntimeMode, + DataSourceType, + DocumentCreatedFrom, + DocumentDocType, + IndexingStatus, + ProcessRuleMode, + SegmentStatus, + SummaryStatus, +) from .model import App, Tag, TagBinding, UploadFile from .types import AdjustedJSON, BinaryData, EnumText, LongText, StringUUID, adjusted_json_index @@ -120,7 +134,7 @@ class Dataset(Base): server_default=sa.text("'only_me'"), default=DatasetPermissionEnum.ONLY_ME, ) - data_source_type = mapped_column(String(255)) + data_source_type = mapped_column(EnumText(DataSourceType, length=255)) indexing_technique: Mapped[str | None] = mapped_column(String(255)) index_struct = mapped_column(LongText, nullable=True) created_by = mapped_column(StringUUID, nullable=False) @@ -137,7 +151,9 @@ class Dataset(Base): summary_index_setting = mapped_column(AdjustedJSON, nullable=True) built_in_field_enabled = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) icon_info = mapped_column(AdjustedJSON, nullable=True) - runtime_mode = mapped_column(sa.String(255), nullable=True, server_default=sa.text("'general'")) + runtime_mode = mapped_column( + EnumText(DatasetRuntimeMode, length=255), nullable=True, server_default=sa.text("'general'") + ) pipeline_id = mapped_column(StringUUID, nullable=True) chunk_structure = mapped_column(sa.String(255), nullable=True) enable_api = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true")) @@ -145,30 +161,25 @@ class Dataset(Base): @property def total_documents(self): - return db.session.query(func.count(Document.id)).where(Document.dataset_id == self.id).scalar() + return db.session.scalar(select(func.count(Document.id)).where(Document.dataset_id == self.id)) or 0 @property def total_available_documents(self): return ( - db.session.query(func.count(Document.id)) - .where( - Document.dataset_id == self.id, - Document.indexing_status == "completed", - Document.enabled == True, - Document.archived == False, + db.session.scalar( + select(func.count(Document.id)).where( + Document.dataset_id == self.id, + Document.indexing_status == "completed", + Document.enabled == True, + Document.archived == False, + ) ) - .scalar() + or 0 ) @property def dataset_keyword_table(self): - dataset_keyword_table = ( - db.session.query(DatasetKeywordTable).where(DatasetKeywordTable.dataset_id == self.id).first() - ) - if dataset_keyword_table: - return dataset_keyword_table - - return None + return db.session.scalar(select(DatasetKeywordTable).where(DatasetKeywordTable.dataset_id == self.id)) @property def index_struct_dict(self): @@ -195,64 +206,66 @@ class Dataset(Base): @property def latest_process_rule(self): - return ( - db.session.query(DatasetProcessRule) + return db.session.scalar( + select(DatasetProcessRule) .where(DatasetProcessRule.dataset_id == self.id) .order_by(DatasetProcessRule.created_at.desc()) - .first() + .limit(1) ) @property def app_count(self): return ( - db.session.query(func.count(AppDatasetJoin.id)) - .where(AppDatasetJoin.dataset_id == self.id, App.id == AppDatasetJoin.app_id) - .scalar() + db.session.scalar( + select(func.count(AppDatasetJoin.id)).where( + AppDatasetJoin.dataset_id == self.id, App.id == AppDatasetJoin.app_id + ) + ) + or 0 ) @property def document_count(self): - return db.session.query(func.count(Document.id)).where(Document.dataset_id == self.id).scalar() + return db.session.scalar(select(func.count(Document.id)).where(Document.dataset_id == self.id)) or 0 @property def available_document_count(self): return ( - db.session.query(func.count(Document.id)) - .where( - Document.dataset_id == self.id, - Document.indexing_status == "completed", - Document.enabled == True, - Document.archived == False, + db.session.scalar( + select(func.count(Document.id)).where( + Document.dataset_id == self.id, + Document.indexing_status == "completed", + Document.enabled == True, + Document.archived == False, + ) ) - .scalar() + or 0 ) @property def available_segment_count(self): return ( - db.session.query(func.count(DocumentSegment.id)) - .where( - DocumentSegment.dataset_id == self.id, - DocumentSegment.status == "completed", - DocumentSegment.enabled == True, + db.session.scalar( + select(func.count(DocumentSegment.id)).where( + DocumentSegment.dataset_id == self.id, + DocumentSegment.status == "completed", + DocumentSegment.enabled == True, + ) ) - .scalar() + or 0 ) @property def word_count(self): - return ( - db.session.query(Document) - .with_entities(func.coalesce(func.sum(Document.word_count), 0)) - .where(Document.dataset_id == self.id) - .scalar() + return db.session.scalar( + select(func.coalesce(func.sum(Document.word_count), 0)).where(Document.dataset_id == self.id) ) @property def doc_form(self) -> str | None: if self.chunk_structure: return self.chunk_structure - document = db.session.query(Document).where(Document.dataset_id == self.id).first() + document = db.session.scalar(select(Document).where(Document.dataset_id == self.id).limit(1)) if document: return document.doc_form return None @@ -270,8 +283,8 @@ class Dataset(Base): @property def tags(self): - tags = ( - db.session.query(Tag) + tags = db.session.scalars( + select(Tag) .join(TagBinding, Tag.id == TagBinding.tag_id) .where( TagBinding.target_id == self.id, @@ -279,8 +292,7 @@ class Dataset(Base): Tag.tenant_id == self.tenant_id, Tag.type == "knowledge", ) - .all() - ) + ).all() return tags or [] @@ -288,8 +300,8 @@ class Dataset(Base): def external_knowledge_info(self): if self.provider != "external": return None - external_knowledge_binding = ( - db.session.query(ExternalKnowledgeBindings).where(ExternalKnowledgeBindings.dataset_id == self.id).first() + external_knowledge_binding = db.session.scalar( + select(ExternalKnowledgeBindings).where(ExternalKnowledgeBindings.dataset_id == self.id) ) if not external_knowledge_binding: return None @@ -310,7 +322,7 @@ class Dataset(Base): @property def is_published(self): if self.pipeline_id: - pipeline = db.session.query(Pipeline).where(Pipeline.id == self.pipeline_id).first() + pipeline = db.session.scalar(select(Pipeline).where(Pipeline.id == self.pipeline_id)) if pipeline: return pipeline.is_published return False @@ -382,7 +394,7 @@ class DatasetProcessRule(Base): # bug id = mapped_column(StringUUID, nullable=False, default=lambda: str(uuid4())) dataset_id = mapped_column(StringUUID, nullable=False) - mode = mapped_column(String(255), nullable=False, server_default=sa.text("'automatic'")) + mode = mapped_column(EnumText(ProcessRuleMode, length=255), nullable=False, server_default=sa.text("'automatic'")) rules = mapped_column(LongText, nullable=True) created_by = mapped_column(StringUUID, nullable=False) created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) @@ -428,12 +440,12 @@ class Document(Base): tenant_id = mapped_column(StringUUID, nullable=False) dataset_id = mapped_column(StringUUID, nullable=False) position: Mapped[int] = mapped_column(sa.Integer, nullable=False) - data_source_type: Mapped[str] = mapped_column(String(255), nullable=False) + data_source_type: Mapped[str] = mapped_column(EnumText(DataSourceType, length=255), nullable=False) data_source_info = mapped_column(LongText, nullable=True) dataset_process_rule_id = mapped_column(StringUUID, nullable=True) batch: Mapped[str] = mapped_column(String(255), nullable=False) name: Mapped[str] = mapped_column(String(255), nullable=False) - created_from: Mapped[str] = mapped_column(String(255), nullable=False) + created_from: Mapped[str] = mapped_column(EnumText(DocumentCreatedFrom, length=255), nullable=False) created_by = mapped_column(StringUUID, nullable=False) created_api_request_id = mapped_column(StringUUID, nullable=True) created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) @@ -467,7 +479,9 @@ class Document(Base): stopped_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) # basic fields - indexing_status = mapped_column(String(255), nullable=False, server_default=sa.text("'waiting'")) + indexing_status = mapped_column( + EnumText(IndexingStatus, length=255), nullable=False, server_default=sa.text("'waiting'") + ) enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true")) disabled_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) disabled_by = mapped_column(StringUUID, nullable=True) @@ -478,7 +492,7 @@ class Document(Base): updated_at: Mapped[datetime] = mapped_column( DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() ) - doc_type = mapped_column(String(40), nullable=True) + doc_type = mapped_column(EnumText(DocumentDocType, length=40), nullable=True) doc_metadata = mapped_column(AdjustedJSON, nullable=True) doc_form = mapped_column(String(255), nullable=False, server_default=sa.text("'text_model'")) doc_language = mapped_column(String(255), nullable=True) @@ -521,10 +535,8 @@ class Document(Base): if self.data_source_info: if self.data_source_type == "upload_file": data_source_info_dict: dict[str, Any] = json.loads(self.data_source_info) - file_detail = ( - db.session.query(UploadFile) - .where(UploadFile.id == data_source_info_dict["upload_file_id"]) - .one_or_none() + file_detail = db.session.scalar( + select(UploadFile).where(UploadFile.id == data_source_info_dict["upload_file_id"]) ) if file_detail: return { @@ -557,24 +569,23 @@ class Document(Base): @property def dataset(self): - return db.session.query(Dataset).where(Dataset.id == self.dataset_id).one_or_none() + return db.session.scalar(select(Dataset).where(Dataset.id == self.dataset_id)) @property def segment_count(self): - return db.session.query(DocumentSegment).where(DocumentSegment.document_id == self.id).count() + return ( + db.session.scalar(select(func.count(DocumentSegment.id)).where(DocumentSegment.document_id == self.id)) or 0 + ) @property def hit_count(self): - return ( - db.session.query(DocumentSegment) - .with_entities(func.coalesce(func.sum(DocumentSegment.hit_count), 0)) - .where(DocumentSegment.document_id == self.id) - .scalar() + return db.session.scalar( + select(func.coalesce(func.sum(DocumentSegment.hit_count), 0)).where(DocumentSegment.document_id == self.id) ) @property def uploader(self): - user = db.session.query(Account).where(Account.id == self.created_by).first() + user = db.session.scalar(select(Account).where(Account.id == self.created_by)) return user.name if user else None @property @@ -588,14 +599,13 @@ class Document(Base): @property def doc_metadata_details(self) -> list[DocMetadataDetailItem] | None: if self.doc_metadata: - document_metadatas = ( - db.session.query(DatasetMetadata) + document_metadatas = db.session.scalars( + select(DatasetMetadata) .join(DatasetMetadataBinding, DatasetMetadataBinding.metadata_id == DatasetMetadata.id) .where( DatasetMetadataBinding.dataset_id == self.dataset_id, DatasetMetadataBinding.document_id == self.id ) - .all() - ) + ).all() metadata_list: list[DocMetadataDetailItem] = [] for metadata in document_metadatas: metadata_dict: DocMetadataDetailItem = { @@ -791,7 +801,7 @@ class DocumentSegment(Base): enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true")) disabled_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) disabled_by = mapped_column(StringUUID, nullable=True) - status: Mapped[str] = mapped_column(String(255), server_default=sa.text("'waiting'")) + status: Mapped[str] = mapped_column(EnumText(SegmentStatus, length=255), server_default=sa.text("'waiting'")) created_by = mapped_column(StringUUID, nullable=False) created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = mapped_column(StringUUID, nullable=True) @@ -826,7 +836,7 @@ class DocumentSegment(Base): ) @property - def child_chunks(self) -> list[Any]: + def child_chunks(self) -> Sequence[Any]: if not self.document: return [] process_rule = self.document.dataset_process_rule @@ -835,16 +845,13 @@ class DocumentSegment(Base): if rules_dict: rules = Rule.model_validate(rules_dict) if rules.parent_mode and rules.parent_mode != ParentMode.FULL_DOC: - child_chunks = ( - db.session.query(ChildChunk) - .where(ChildChunk.segment_id == self.id) - .order_by(ChildChunk.position.asc()) - .all() - ) + child_chunks = db.session.scalars( + select(ChildChunk).where(ChildChunk.segment_id == self.id).order_by(ChildChunk.position.asc()) + ).all() return child_chunks or [] return [] - def get_child_chunks(self) -> list[Any]: + def get_child_chunks(self) -> Sequence[Any]: if not self.document: return [] process_rule = self.document.dataset_process_rule @@ -853,12 +860,9 @@ class DocumentSegment(Base): if rules_dict: rules = Rule.model_validate(rules_dict) if rules.parent_mode: - child_chunks = ( - db.session.query(ChildChunk) - .where(ChildChunk.segment_id == self.id) - .order_by(ChildChunk.position.asc()) - .all() - ) + child_chunks = db.session.scalars( + select(ChildChunk).where(ChildChunk.segment_id == self.id).order_by(ChildChunk.position.asc()) + ).all() return child_chunks or [] return [] @@ -1007,15 +1011,15 @@ class ChildChunk(Base): @property def dataset(self): - return db.session.query(Dataset).where(Dataset.id == self.dataset_id).first() + return db.session.scalar(select(Dataset).where(Dataset.id == self.dataset_id)) @property def document(self): - return db.session.query(Document).where(Document.id == self.document_id).first() + return db.session.scalar(select(Document).where(Document.id == self.document_id)) @property def segment(self): - return db.session.query(DocumentSegment).where(DocumentSegment.id == self.segment_id).first() + return db.session.scalar(select(DocumentSegment).where(DocumentSegment.id == self.segment_id)) class AppDatasetJoin(TypeBase): @@ -1061,7 +1065,7 @@ class DatasetQuery(TypeBase): ) dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False) content: Mapped[str] = mapped_column(LongText, nullable=False) - source: Mapped[str] = mapped_column(String(255), nullable=False) + source: Mapped[str] = mapped_column(EnumText(DatasetQuerySource, length=255), nullable=False) source_app_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True) created_by_role: Mapped[CreatorUserRole] = mapped_column(EnumText(CreatorUserRole, length=255), nullable=False) created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) @@ -1076,7 +1080,7 @@ class DatasetQuery(TypeBase): if isinstance(queries, list): for query in queries: if query["content_type"] == QueryType.IMAGE_QUERY: - file_info = db.session.query(UploadFile).filter_by(id=query["content"]).first() + file_info = db.session.scalar(select(UploadFile).where(UploadFile.id == query["content"])) if file_info: query["file_info"] = { "id": file_info.id, @@ -1141,7 +1145,7 @@ class DatasetKeywordTable(TypeBase): super().__init__(object_hook=object_hook, *args, **kwargs) # get dataset - dataset = db.session.query(Dataset).filter_by(id=self.dataset_id).first() + dataset = db.session.scalar(select(Dataset).where(Dataset.id == self.dataset_id)) if not dataset: return None if self.data_source_type == "database": @@ -1206,7 +1210,9 @@ class DatasetCollectionBinding(TypeBase): ) provider_name: Mapped[str] = mapped_column(String(255), nullable=False) model_name: Mapped[str] = mapped_column(String(255), nullable=False) - type: Mapped[str] = mapped_column(String(40), server_default=sa.text("'dataset'"), nullable=False) + type: Mapped[str] = mapped_column( + EnumText(CollectionBindingType, length=40), server_default=sa.text("'dataset'"), nullable=False + ) collection_name: Mapped[str] = mapped_column(String(64), nullable=False) created_at: Mapped[datetime] = mapped_column( DateTime, nullable=False, server_default=func.current_timestamp(), init=False @@ -1433,7 +1439,7 @@ class DatasetMetadata(TypeBase): ) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - type: Mapped[str] = mapped_column(String(255), nullable=False) + type: Mapped[str] = mapped_column(EnumText(DatasetMetadataType, length=255), nullable=False) name: Mapped[str] = mapped_column(String(255), nullable=False) created_at: Mapped[datetime] = mapped_column( DateTime, nullable=False, server_default=sa.func.current_timestamp(), init=False @@ -1535,7 +1541,7 @@ class PipelineCustomizedTemplate(TypeBase): @property def created_user_name(self): - account = db.session.query(Account).where(Account.id == self.created_by).first() + account = db.session.scalar(select(Account).where(Account.id == self.created_by)) if account: return account.name return "" @@ -1570,7 +1576,7 @@ class Pipeline(TypeBase): ) def retrieve_dataset(self, session: Session): - return session.query(Dataset).where(Dataset.pipeline_id == self.id).first() + return session.scalar(select(Dataset).where(Dataset.pipeline_id == self.id)) class DocumentPipelineExecutionLog(TypeBase): @@ -1660,7 +1666,9 @@ class DocumentSegmentSummary(Base): summary_index_node_id: Mapped[str] = mapped_column(String(255), nullable=True) summary_index_node_hash: Mapped[str] = mapped_column(String(255), nullable=True) tokens: Mapped[int | None] = mapped_column(sa.Integer, nullable=True) - status: Mapped[str] = mapped_column(String(32), nullable=False, server_default=sa.text("'generating'")) + status: Mapped[str] = mapped_column( + EnumText(SummaryStatus, length=32), nullable=False, server_default=sa.text("'generating'") + ) error: Mapped[str] = mapped_column(LongText, nullable=True) enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true")) disabled_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) diff --git a/api/models/enums.py b/api/models/enums.py index eb478fe02c..6499c5b443 100644 --- a/api/models/enums.py +++ b/api/models/enums.py @@ -11,6 +11,13 @@ class CreatorUserRole(StrEnum): ACCOUNT = "account" END_USER = "end_user" + @classmethod + def _missing_(cls, value): + if value == "end-user": + return cls.END_USER + else: + return super()._missing_(value) + class WorkflowRunTriggeredFrom(StrEnum): DEBUGGING = "debugging" @@ -96,3 +103,216 @@ class ConversationStatus(StrEnum): """Conversation Status Enum""" NORMAL = "normal" + + +class DataSourceType(StrEnum): + """Data Source Type for Dataset and Document""" + + UPLOAD_FILE = "upload_file" + NOTION_IMPORT = "notion_import" + WEBSITE_CRAWL = "website_crawl" + LOCAL_FILE = "local_file" + ONLINE_DOCUMENT = "online_document" + + +class ProcessRuleMode(StrEnum): + """Dataset Process Rule Mode""" + + AUTOMATIC = "automatic" + CUSTOM = "custom" + HIERARCHICAL = "hierarchical" + + +class IndexingStatus(StrEnum): + """Document Indexing Status""" + + WAITING = "waiting" + PARSING = "parsing" + CLEANING = "cleaning" + SPLITTING = "splitting" + INDEXING = "indexing" + PAUSED = "paused" + COMPLETED = "completed" + ERROR = "error" + + +class DocumentCreatedFrom(StrEnum): + """Document Created From""" + + WEB = "web" + API = "api" + RAG_PIPELINE = "rag-pipeline" + + +class ConversationFromSource(StrEnum): + """Conversation / Message from_source""" + + API = "api" + CONSOLE = "console" + + +class FeedbackFromSource(StrEnum): + """MessageFeedback from_source""" + + USER = "user" + ADMIN = "admin" + + +class InvokeFrom(StrEnum): + """How a conversation/message was invoked""" + + SERVICE_API = "service-api" + WEB_APP = "web-app" + TRIGGER = "trigger" + EXPLORE = "explore" + DEBUGGER = "debugger" + PUBLISHED_PIPELINE = "published" + VALIDATION = "validation" + + @classmethod + def value_of(cls, value: str) -> "InvokeFrom": + return cls(value) + + def to_source(self) -> str: + source_mapping = { + InvokeFrom.WEB_APP: "web_app", + InvokeFrom.DEBUGGER: "dev", + InvokeFrom.EXPLORE: "explore_app", + InvokeFrom.TRIGGER: "trigger", + InvokeFrom.SERVICE_API: "api", + } + return source_mapping.get(self, "dev") + + +class DocumentDocType(StrEnum): + """Document doc_type classification""" + + BOOK = "book" + WEB_PAGE = "web_page" + PAPER = "paper" + SOCIAL_MEDIA_POST = "social_media_post" + WIKIPEDIA_ENTRY = "wikipedia_entry" + PERSONAL_DOCUMENT = "personal_document" + BUSINESS_DOCUMENT = "business_document" + IM_CHAT_LOG = "im_chat_log" + SYNCED_FROM_NOTION = "synced_from_notion" + SYNCED_FROM_GITHUB = "synced_from_github" + OTHERS = "others" + + +class TagType(StrEnum): + """Tag type""" + + KNOWLEDGE = "knowledge" + APP = "app" + + +class DatasetMetadataType(StrEnum): + """Dataset metadata value type""" + + STRING = "string" + NUMBER = "number" + TIME = "time" + + +class SegmentStatus(StrEnum): + """Document segment status""" + + WAITING = "waiting" + INDEXING = "indexing" + COMPLETED = "completed" + ERROR = "error" + PAUSED = "paused" + RE_SEGMENT = "re_segment" + + +class DatasetRuntimeMode(StrEnum): + """Dataset runtime mode""" + + GENERAL = "general" + RAG_PIPELINE = "rag_pipeline" + + +class CollectionBindingType(StrEnum): + """Dataset collection binding type""" + + DATASET = "dataset" + ANNOTATION = "annotation" + + +class DatasetQuerySource(StrEnum): + """Dataset query source""" + + HIT_TESTING = "hit_testing" + APP = "app" + + +class TidbAuthBindingStatus(StrEnum): + """TiDB auth binding status""" + + CREATING = "CREATING" + ACTIVE = "ACTIVE" + + +class MessageFileBelongsTo(StrEnum): + """MessageFile belongs_to""" + + USER = "user" + ASSISTANT = "assistant" + + +class CredentialSourceType(StrEnum): + """Load balancing credential source type""" + + PROVIDER = "provider" + CUSTOM_MODEL = "custom_model" + + +class PaymentStatus(StrEnum): + """Provider order payment status""" + + WAIT_PAY = "wait_pay" + PAID = "paid" + FAILED = "failed" + REFUNDED = "refunded" + + +class BannerStatus(StrEnum): + """ExporleBanner status""" + + ENABLED = "enabled" + DISABLED = "disabled" + + +class SummaryStatus(StrEnum): + """Document segment summary status""" + + NOT_STARTED = "not_started" + GENERATING = "generating" + COMPLETED = "completed" + ERROR = "error" + TIMEOUT = "timeout" + + +class MessageChainType(StrEnum): + """Message chain type""" + + SYSTEM = "system" + + +class ProviderQuotaType(StrEnum): + PAID = "paid" + """hosted paid quota""" + + FREE = "free" + """third-party free quota""" + + TRIAL = "trial" + """hosted trial quota""" + + @staticmethod + def value_of(value: str) -> "ProviderQuotaType": + for member in ProviderQuotaType: + if member.value == value: + return member + raise ValueError(f"No matching enum found for value '{value}'") diff --git a/api/models/model.py b/api/models/model.py index 2e747df2c7..fe70fcd401 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -380,13 +380,12 @@ class App(Base): @property def site(self) -> Site | None: - site = db.session.query(Site).where(Site.app_id == self.id).first() - return site + return db.session.scalar(select(Site).where(Site.app_id == self.id)) @property def app_model_config(self) -> AppModelConfig | None: if self.app_model_config_id: - return db.session.query(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id).first() + return db.session.scalar(select(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id)) return None @@ -395,7 +394,7 @@ class App(Base): if self.workflow_id: from .workflow import Workflow - return db.session.query(Workflow).where(Workflow.id == self.workflow_id).first() + return db.session.scalar(select(Workflow).where(Workflow.id == self.workflow_id)) return None @@ -405,8 +404,7 @@ class App(Base): @property def tenant(self) -> Tenant | None: - tenant = db.session.query(Tenant).where(Tenant.id == self.tenant_id).first() - return tenant + return db.session.scalar(select(Tenant).where(Tenant.id == self.tenant_id)) @property def is_agent(self) -> bool: @@ -546,9 +544,9 @@ class App(Base): return deleted_tools @property - def tags(self) -> list[Tag]: - tags = ( - db.session.query(Tag) + def tags(self) -> Sequence[Tag]: + tags = db.session.scalars( + select(Tag) .join(TagBinding, Tag.id == TagBinding.tag_id) .where( TagBinding.target_id == self.id, @@ -556,15 +554,14 @@ class App(Base): Tag.tenant_id == self.tenant_id, Tag.type == "app", ) - .all() - ) + ).all() return tags or [] @property def author_name(self) -> str | None: if self.created_by: - account = db.session.query(Account).where(Account.id == self.created_by).first() + account = db.session.scalar(select(Account).where(Account.id == self.created_by)) if account: return account.name @@ -616,8 +613,7 @@ class AppModelConfig(TypeBase): @property def app(self) -> App | None: - app = db.session.query(App).where(App.id == self.app_id).first() - return app + return db.session.scalar(select(App).where(App.id == self.app_id)) @property def model_dict(self) -> ModelConfig: @@ -652,8 +648,8 @@ class AppModelConfig(TypeBase): @property def annotation_reply_dict(self) -> AnnotationReplyConfig: - annotation_setting = ( - db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == self.app_id).first() + annotation_setting = db.session.scalar( + select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == self.app_id) ) if annotation_setting: collection_binding_detail = annotation_setting.collection_binding_detail @@ -845,8 +841,7 @@ class RecommendedApp(Base): # bug @property def app(self) -> App | None: - app = db.session.query(App).where(App.id == self.app_id).first() - return app + return db.session.scalar(select(App).where(App.id == self.app_id)) class InstalledApp(TypeBase): @@ -873,13 +868,11 @@ class InstalledApp(TypeBase): @property def app(self) -> App | None: - app = db.session.query(App).where(App.id == self.app_id).first() - return app + return db.session.scalar(select(App).where(App.id == self.app_id)) @property def tenant(self) -> Tenant | None: - tenant = db.session.query(Tenant).where(Tenant.id == self.tenant_id).first() - return tenant + return db.session.scalar(select(Tenant).where(Tenant.id == self.tenant_id)) class TrialApp(Base): @@ -899,8 +892,7 @@ class TrialApp(Base): @property def app(self) -> App | None: - app = db.session.query(App).where(App.id == self.app_id).first() - return app + return db.session.scalar(select(App).where(App.id == self.app_id)) class AccountTrialAppRecord(Base): @@ -919,13 +911,11 @@ class AccountTrialAppRecord(Base): @property def app(self) -> App | None: - app = db.session.query(App).where(App.id == self.app_id).first() - return app + return db.session.scalar(select(App).where(App.id == self.app_id)) @property def user(self) -> Account | None: - user = db.session.query(Account).where(Account.id == self.account_id).first() - return user + return db.session.scalar(select(Account).where(Account.id == self.account_id)) class ExporleBanner(TypeBase): @@ -1117,8 +1107,8 @@ class Conversation(Base): else: model_config["configs"] = override_model_configs # type: ignore[typeddict-unknown-key] else: - app_model_config = ( - db.session.query(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id).first() + app_model_config = db.session.scalar( + select(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id) ) if app_model_config: model_config = app_model_config.to_dict() @@ -1141,36 +1131,43 @@ class Conversation(Base): @property def annotated(self): - return db.session.query(MessageAnnotation).where(MessageAnnotation.conversation_id == self.id).count() > 0 + return ( + db.session.scalar( + select(func.count(MessageAnnotation.id)).where(MessageAnnotation.conversation_id == self.id) + ) + or 0 + ) > 0 @property def annotation(self): - return db.session.query(MessageAnnotation).where(MessageAnnotation.conversation_id == self.id).first() + return db.session.scalar(select(MessageAnnotation).where(MessageAnnotation.conversation_id == self.id).limit(1)) @property def message_count(self): - return db.session.query(Message).where(Message.conversation_id == self.id).count() + return db.session.scalar(select(func.count(Message.id)).where(Message.conversation_id == self.id)) or 0 @property def user_feedback_stats(self): like = ( - db.session.query(MessageFeedback) - .where( - MessageFeedback.conversation_id == self.id, - MessageFeedback.from_source == "user", - MessageFeedback.rating == "like", + db.session.scalar( + select(func.count(MessageFeedback.id)).where( + MessageFeedback.conversation_id == self.id, + MessageFeedback.from_source == "user", + MessageFeedback.rating == "like", + ) ) - .count() + or 0 ) dislike = ( - db.session.query(MessageFeedback) - .where( - MessageFeedback.conversation_id == self.id, - MessageFeedback.from_source == "user", - MessageFeedback.rating == "dislike", + db.session.scalar( + select(func.count(MessageFeedback.id)).where( + MessageFeedback.conversation_id == self.id, + MessageFeedback.from_source == "user", + MessageFeedback.rating == "dislike", + ) ) - .count() + or 0 ) return {"like": like, "dislike": dislike} @@ -1178,23 +1175,25 @@ class Conversation(Base): @property def admin_feedback_stats(self): like = ( - db.session.query(MessageFeedback) - .where( - MessageFeedback.conversation_id == self.id, - MessageFeedback.from_source == "admin", - MessageFeedback.rating == "like", + db.session.scalar( + select(func.count(MessageFeedback.id)).where( + MessageFeedback.conversation_id == self.id, + MessageFeedback.from_source == "admin", + MessageFeedback.rating == "like", + ) ) - .count() + or 0 ) dislike = ( - db.session.query(MessageFeedback) - .where( - MessageFeedback.conversation_id == self.id, - MessageFeedback.from_source == "admin", - MessageFeedback.rating == "dislike", + db.session.scalar( + select(func.count(MessageFeedback.id)).where( + MessageFeedback.conversation_id == self.id, + MessageFeedback.from_source == "admin", + MessageFeedback.rating == "dislike", + ) ) - .count() + or 0 ) return {"like": like, "dislike": dislike} @@ -1256,22 +1255,19 @@ class Conversation(Base): @property def first_message(self): - return ( - db.session.query(Message) - .where(Message.conversation_id == self.id) - .order_by(Message.created_at.asc()) - .first() + return db.session.scalar( + select(Message).where(Message.conversation_id == self.id).order_by(Message.created_at.asc()) ) @property def app(self) -> App | None: with Session(db.engine, expire_on_commit=False) as session: - return session.query(App).where(App.id == self.app_id).first() + return session.scalar(select(App).where(App.id == self.app_id)) @property def from_end_user_session_id(self): if self.from_end_user_id: - end_user = db.session.query(EndUser).where(EndUser.id == self.from_end_user_id).first() + end_user = db.session.scalar(select(EndUser).where(EndUser.id == self.from_end_user_id)) if end_user: return end_user.session_id @@ -1280,7 +1276,7 @@ class Conversation(Base): @property def from_account_name(self) -> str | None: if self.from_account_id: - account = db.session.query(Account).where(Account.id == self.from_account_id).first() + account = db.session.scalar(select(Account).where(Account.id == self.from_account_id)) if account: return account.name @@ -1505,21 +1501,15 @@ class Message(Base): @property def user_feedback(self): - feedback = ( - db.session.query(MessageFeedback) - .where(MessageFeedback.message_id == self.id, MessageFeedback.from_source == "user") - .first() + return db.session.scalar( + select(MessageFeedback).where(MessageFeedback.message_id == self.id, MessageFeedback.from_source == "user") ) - return feedback @property def admin_feedback(self): - feedback = ( - db.session.query(MessageFeedback) - .where(MessageFeedback.message_id == self.id, MessageFeedback.from_source == "admin") - .first() + return db.session.scalar( + select(MessageFeedback).where(MessageFeedback.message_id == self.id, MessageFeedback.from_source == "admin") ) - return feedback @property def feedbacks(self): @@ -1528,28 +1518,27 @@ class Message(Base): @property def annotation(self): - annotation = db.session.query(MessageAnnotation).where(MessageAnnotation.message_id == self.id).first() + annotation = db.session.scalar(select(MessageAnnotation).where(MessageAnnotation.message_id == self.id)) return annotation @property def annotation_hit_history(self): - annotation_history = ( - db.session.query(AppAnnotationHitHistory).where(AppAnnotationHitHistory.message_id == self.id).first() + annotation_history = db.session.scalar( + select(AppAnnotationHitHistory).where(AppAnnotationHitHistory.message_id == self.id) ) if annotation_history: - annotation = ( - db.session.query(MessageAnnotation) - .where(MessageAnnotation.id == annotation_history.annotation_id) - .first() + return db.session.scalar( + select(MessageAnnotation).where(MessageAnnotation.id == annotation_history.annotation_id) ) - return annotation return None @property def app_model_config(self): - conversation = db.session.query(Conversation).where(Conversation.id == self.conversation_id).first() + conversation = db.session.scalar(select(Conversation).where(Conversation.id == self.conversation_id)) if conversation: - return db.session.query(AppModelConfig).where(AppModelConfig.id == conversation.app_model_config_id).first() + return db.session.scalar( + select(AppModelConfig).where(AppModelConfig.id == conversation.app_model_config_id) + ) return None @@ -1562,13 +1551,12 @@ class Message(Base): return json.loads(self.message_metadata) if self.message_metadata else {} @property - def agent_thoughts(self) -> list[MessageAgentThought]: - return ( - db.session.query(MessageAgentThought) + def agent_thoughts(self) -> Sequence[MessageAgentThought]: + return db.session.scalars( + select(MessageAgentThought) .where(MessageAgentThought.message_id == self.id) .order_by(MessageAgentThought.position.asc()) - .all() - ) + ).all() @property def retriever_resources(self) -> Any: @@ -1579,7 +1567,7 @@ class Message(Base): from factories import file_factory message_files = db.session.scalars(select(MessageFile).where(MessageFile.message_id == self.id)).all() - current_app = db.session.query(App).where(App.id == self.app_id).first() + current_app = db.session.scalar(select(App).where(App.id == self.app_id)) if not current_app: raise ValueError(f"App {self.app_id} not found") @@ -1743,8 +1731,7 @@ class MessageFeedback(TypeBase): @property def from_account(self) -> Account | None: - account = db.session.query(Account).where(Account.id == self.from_account_id).first() - return account + return db.session.scalar(select(Account).where(Account.id == self.from_account_id)) def to_dict(self) -> MessageFeedbackDict: return { @@ -1817,13 +1804,11 @@ class MessageAnnotation(Base): @property def account(self): - account = db.session.query(Account).where(Account.id == self.account_id).first() - return account + return db.session.scalar(select(Account).where(Account.id == self.account_id)) @property def annotation_create_account(self): - account = db.session.query(Account).where(Account.id == self.account_id).first() - return account + return db.session.scalar(select(Account).where(Account.id == self.account_id)) class AppAnnotationHitHistory(TypeBase): @@ -1852,18 +1837,15 @@ class AppAnnotationHitHistory(TypeBase): @property def account(self): - account = ( - db.session.query(Account) + return db.session.scalar( + select(Account) .join(MessageAnnotation, MessageAnnotation.account_id == Account.id) .where(MessageAnnotation.id == self.annotation_id) - .first() ) - return account @property def annotation_create_account(self): - account = db.session.query(Account).where(Account.id == self.account_id).first() - return account + return db.session.scalar(select(Account).where(Account.id == self.account_id)) class AppAnnotationSetting(TypeBase): @@ -1896,12 +1878,9 @@ class AppAnnotationSetting(TypeBase): def collection_binding_detail(self): from .dataset import DatasetCollectionBinding - collection_binding_detail = ( - db.session.query(DatasetCollectionBinding) - .where(DatasetCollectionBinding.id == self.collection_binding_id) - .first() + return db.session.scalar( + select(DatasetCollectionBinding).where(DatasetCollectionBinding.id == self.collection_binding_id) ) - return collection_binding_detail class OperationLog(TypeBase): @@ -2007,7 +1986,9 @@ class AppMCPServer(TypeBase): def generate_server_code(n: int) -> str: while True: result = generate_string(n) - while db.session.query(AppMCPServer).where(AppMCPServer.server_code == result).count() > 0: + while ( + db.session.scalar(select(func.count(AppMCPServer.id)).where(AppMCPServer.server_code == result)) or 0 + ) > 0: result = generate_string(n) return result @@ -2068,7 +2049,7 @@ class Site(Base): def generate_code(n: int) -> str: while True: result = generate_string(n) - while db.session.query(Site).where(Site.code == result).count() > 0: + while (db.session.scalar(select(func.count(Site.id)).where(Site.code == result)) or 0) > 0: result = generate_string(n) return result diff --git a/api/models/provider.py b/api/models/provider.py index 18a0fe92c8..4e114bb034 100644 --- a/api/models/provider.py +++ b/api/models/provider.py @@ -6,13 +6,14 @@ from functools import cached_property from uuid import uuid4 import sqlalchemy as sa -from sqlalchemy import DateTime, String, func, text +from sqlalchemy import DateTime, String, func, select, text from sqlalchemy.orm import Mapped, mapped_column from libs.uuid_utils import uuidv7 from .base import TypeBase from .engine import db +from .enums import CredentialSourceType, PaymentStatus from .types import EnumText, LongText, StringUUID @@ -96,7 +97,7 @@ class Provider(TypeBase): @cached_property def credential(self): if self.credential_id: - return db.session.query(ProviderCredential).where(ProviderCredential.id == self.credential_id).first() + return db.session.scalar(select(ProviderCredential).where(ProviderCredential.id == self.credential_id)) @property def credential_name(self): @@ -159,10 +160,8 @@ class ProviderModel(TypeBase): @cached_property def credential(self): if self.credential_id: - return ( - db.session.query(ProviderModelCredential) - .where(ProviderModelCredential.id == self.credential_id) - .first() + return db.session.scalar( + select(ProviderModelCredential).where(ProviderModelCredential.id == self.credential_id) ) @property @@ -239,7 +238,9 @@ class ProviderOrder(TypeBase): quantity: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=text("1")) currency: Mapped[str | None] = mapped_column(String(40)) total_amount: Mapped[int | None] = mapped_column(sa.Integer) - payment_status: Mapped[str] = mapped_column(String(40), nullable=False, server_default=text("'wait_pay'")) + payment_status: Mapped[PaymentStatus] = mapped_column( + EnumText(PaymentStatus, length=40), nullable=False, server_default=text("'wait_pay'") + ) paid_at: Mapped[datetime | None] = mapped_column(DateTime) pay_failed_at: Mapped[datetime | None] = mapped_column(DateTime) refunded_at: Mapped[datetime | None] = mapped_column(DateTime) @@ -302,7 +303,9 @@ class LoadBalancingModelConfig(TypeBase): name: Mapped[str] = mapped_column(String(255), nullable=False) encrypted_config: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None) credential_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None) - credential_source_type: Mapped[str | None] = mapped_column(String(40), nullable=True, default=None) + credential_source_type: Mapped[CredentialSourceType | None] = mapped_column( + EnumText(CredentialSourceType, length=40), nullable=True, default=None + ) enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("true"), default=True) created_at: Mapped[datetime] = mapped_column( DateTime, nullable=False, server_default=func.current_timestamp(), init=False diff --git a/api/models/tools.py b/api/models/tools.py index e7b98dcf27..c09f054e7d 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -8,7 +8,7 @@ from uuid import uuid4 import sqlalchemy as sa from deprecated import deprecated -from sqlalchemy import ForeignKey, String, func +from sqlalchemy import ForeignKey, String, func, select from sqlalchemy.orm import Mapped, mapped_column from core.tools.entities.common_entities import I18nObject @@ -184,11 +184,11 @@ class ApiToolProvider(TypeBase): def user(self) -> Account | None: if not self.user_id: return None - return db.session.query(Account).where(Account.id == self.user_id).first() + return db.session.scalar(select(Account).where(Account.id == self.user_id)) @property def tenant(self) -> Tenant | None: - return db.session.query(Tenant).where(Tenant.id == self.tenant_id).first() + return db.session.scalar(select(Tenant).where(Tenant.id == self.tenant_id)) class ToolLabelBinding(TypeBase): @@ -262,11 +262,11 @@ class WorkflowToolProvider(TypeBase): @property def user(self) -> Account | None: - return db.session.query(Account).where(Account.id == self.user_id).first() + return db.session.scalar(select(Account).where(Account.id == self.user_id)) @property def tenant(self) -> Tenant | None: - return db.session.query(Tenant).where(Tenant.id == self.tenant_id).first() + return db.session.scalar(select(Tenant).where(Tenant.id == self.tenant_id)) @property def parameter_configurations(self) -> list[WorkflowToolParameterConfiguration]: @@ -277,7 +277,7 @@ class WorkflowToolProvider(TypeBase): @property def app(self) -> App | None: - return db.session.query(App).where(App.id == self.app_id).first() + return db.session.scalar(select(App).where(App.id == self.app_id)) class MCPToolProvider(TypeBase): @@ -334,7 +334,7 @@ class MCPToolProvider(TypeBase): encrypted_headers: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None) def load_user(self) -> Account | None: - return db.session.query(Account).where(Account.id == self.user_id).first() + return db.session.scalar(select(Account).where(Account.id == self.user_id)) @property def credentials(self) -> dict[str, Any]: diff --git a/api/models/trigger.py b/api/models/trigger.py index 43d7fc5b24..627b854060 100644 --- a/api/models/trigger.py +++ b/api/models/trigger.py @@ -3,7 +3,7 @@ import time from collections.abc import Mapping from datetime import datetime from functools import cached_property -from typing import Any, cast +from typing import Any, TypedDict, cast from uuid import uuid4 import sqlalchemy as sa @@ -23,6 +23,47 @@ from .enums import AppTriggerStatus, AppTriggerType, CreatorUserRole, WorkflowTr from .model import Account from .types import EnumText, LongText, StringUUID +TriggerJsonObject = dict[str, object] +TriggerCredentials = dict[str, str] + + +class WorkflowTriggerLogDict(TypedDict): + id: str + tenant_id: str + app_id: str + workflow_id: str + workflow_run_id: str | None + root_node_id: str | None + trigger_metadata: Any + trigger_type: str + trigger_data: Any + inputs: Any + outputs: Any + status: str + error: str | None + queue_name: str + celery_task_id: str | None + retry_count: int + elapsed_time: float | None + total_tokens: int | None + created_by_role: str + created_by: str + created_at: str | None + triggered_at: str | None + finished_at: str | None + + +class WorkflowSchedulePlanDict(TypedDict): + id: str + app_id: str + node_id: str + tenant_id: str + cron_expression: str + timezone: str + next_run_at: str | None + created_at: str + updated_at: str + class TriggerSubscription(TypeBase): """ @@ -51,10 +92,14 @@ class TriggerSubscription(TypeBase): String(255), nullable=False, comment="Provider identifier (e.g., plugin_id/provider_name)" ) endpoint_id: Mapped[str] = mapped_column(String(255), nullable=False, comment="Subscription endpoint") - parameters: Mapped[dict[str, Any]] = mapped_column(sa.JSON, nullable=False, comment="Subscription parameters JSON") - properties: Mapped[dict[str, Any]] = mapped_column(sa.JSON, nullable=False, comment="Subscription properties JSON") + parameters: Mapped[TriggerJsonObject] = mapped_column( + sa.JSON, nullable=False, comment="Subscription parameters JSON" + ) + properties: Mapped[TriggerJsonObject] = mapped_column( + sa.JSON, nullable=False, comment="Subscription properties JSON" + ) - credentials: Mapped[dict[str, Any]] = mapped_column( + credentials: Mapped[TriggerCredentials] = mapped_column( sa.JSON, nullable=False, comment="Subscription credentials JSON" ) credential_type: Mapped[str] = mapped_column(String(50), nullable=False, comment="oauth or api_key") @@ -162,8 +207,8 @@ class TriggerOAuthTenantClient(TypeBase): ) @property - def oauth_params(self) -> Mapping[str, Any]: - return cast(Mapping[str, Any], json.loads(self.encrypted_oauth_params or "{}")) + def oauth_params(self) -> Mapping[str, object]: + return cast(TriggerJsonObject, json.loads(self.encrypted_oauth_params or "{}")) class WorkflowTriggerLog(TypeBase): @@ -250,7 +295,7 @@ class WorkflowTriggerLog(TypeBase): created_by_role = CreatorUserRole(self.created_by_role) return db.session.get(EndUser, self.created_by) if created_by_role == CreatorUserRole.END_USER else None - def to_dict(self) -> dict[str, Any]: + def to_dict(self) -> WorkflowTriggerLogDict: """Convert to dictionary for API responses""" return { "id": self.id, @@ -481,7 +526,7 @@ class WorkflowSchedulePlan(TypeBase): DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp(), init=False ) - def to_dict(self) -> dict[str, Any]: + def to_dict(self) -> WorkflowSchedulePlanDict: """Convert to dictionary representation""" return { "id": self.id, diff --git a/api/models/web.py b/api/models/web.py index a1cc11c375..1fb37340d7 100644 --- a/api/models/web.py +++ b/api/models/web.py @@ -2,7 +2,7 @@ from datetime import datetime from uuid import uuid4 import sqlalchemy as sa -from sqlalchemy import DateTime, func +from sqlalchemy import DateTime, func, select from sqlalchemy.orm import Mapped, mapped_column from .base import TypeBase @@ -38,7 +38,7 @@ class SavedMessage(TypeBase): @property def message(self): - return db.session.query(Message).where(Message.id == self.message_id).first() + return db.session.scalar(select(Message).where(Message.id == self.message_id)) class PinnedConversation(TypeBase): diff --git a/api/models/workflow.py b/api/models/workflow.py index 32cbd50648..9bb249481f 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -3,7 +3,7 @@ import logging from collections.abc import Generator, Mapping, Sequence from datetime import datetime from enum import StrEnum -from typing import TYPE_CHECKING, Any, Optional, Union, cast +from typing import TYPE_CHECKING, Any, Optional, TypedDict, Union, cast from uuid import uuid4 import sqlalchemy as sa @@ -19,7 +19,7 @@ from sqlalchemy import ( orm, select, ) -from sqlalchemy.orm import Mapped, declared_attr, mapped_column +from sqlalchemy.orm import Mapped, mapped_column from typing_extensions import deprecated from core.trigger.constants import TRIGGER_INFO_METADATA_KEY, TRIGGER_PLUGIN_NODE_TYPE @@ -33,7 +33,7 @@ from dify_graph.enums import BuiltinNodeTypes, NodeType, WorkflowExecutionStatus from dify_graph.file.constants import maybe_file_object from dify_graph.file.models import File from dify_graph.variables import utils as variable_utils -from dify_graph.variables.variables import FloatVariable, IntegerVariable, StringVariable +from dify_graph.variables.variables import FloatVariable, IntegerVariable, RAGPipelineVariable, StringVariable from extensions.ext_storage import Storage from factories.variable_factory import TypeMismatchError, build_segment_with_type from libs.datetime_utils import naive_utc_now @@ -59,6 +59,25 @@ from .types import EnumText, LongText, StringUUID logger = logging.getLogger(__name__) +SerializedWorkflowValue = dict[str, Any] +SerializedWorkflowVariables = dict[str, SerializedWorkflowValue] + + +class WorkflowContentDict(TypedDict): + graph: Mapping[str, Any] + features: dict[str, Any] + environment_variables: list[dict[str, Any]] + conversation_variables: list[dict[str, Any]] + rag_pipeline_variables: list[dict[str, Any]] + + +class WorkflowRunSummaryDict(TypedDict): + id: str + status: str + triggered_from: str + elapsed_time: float + total_tokens: int + class WorkflowType(StrEnum): """ @@ -389,7 +408,7 @@ class Workflow(Base): # bug def rag_pipeline_user_input_form(self) -> list: # get user_input_form from start node - variables: list[Any] = self.rag_pipeline_variables + variables: list[SerializedWorkflowValue] = self.rag_pipeline_variables return variables @@ -432,17 +451,13 @@ class Workflow(Base): # bug def environment_variables( self, ) -> Sequence[StringVariable | IntegerVariable | FloatVariable | SecretVariable]: - # TODO: find some way to init `self._environment_variables` when instance created. - if self._environment_variables is None: - self._environment_variables = "{}" - # Use workflow.tenant_id to avoid relying on request user in background threads tenant_id = self.tenant_id if not tenant_id: return [] - environment_variables_dict: dict[str, Any] = json.loads(self._environment_variables or "{}") + environment_variables_dict = cast(SerializedWorkflowVariables, json.loads(self._environment_variables or "{}")) results = [ variable_factory.build_environment_variable_from_mapping(v) for v in environment_variables_dict.values() ] @@ -502,14 +517,14 @@ class Workflow(Base): # bug ) self._environment_variables = environment_variables_json - def to_dict(self, *, include_secret: bool = False) -> Mapping[str, Any]: + def to_dict(self, *, include_secret: bool = False) -> WorkflowContentDict: environment_variables = list(self.environment_variables) environment_variables = [ v if not isinstance(v, SecretVariable) or include_secret else v.model_copy(update={"value": ""}) for v in environment_variables ] - result = { + result: WorkflowContentDict = { "graph": self.graph_dict, "features": self.features_dict, "environment_variables": [var.model_dump(mode="json") for var in environment_variables], @@ -520,11 +535,7 @@ class Workflow(Base): # bug @property def conversation_variables(self) -> Sequence[VariableBase]: - # TODO: find some way to init `self._conversation_variables` when instance created. - if self._conversation_variables is None: - self._conversation_variables = "{}" - - variables_dict: dict[str, Any] = json.loads(self._conversation_variables) + variables_dict = cast(SerializedWorkflowVariables, json.loads(self._conversation_variables or "{}")) results = [variable_factory.build_conversation_variable_from_mapping(v) for v in variables_dict.values()] return results @@ -536,19 +547,20 @@ class Workflow(Base): # bug ) @property - def rag_pipeline_variables(self) -> list[dict]: - # TODO: find some way to init `self._conversation_variables` when instance created. - if self._rag_pipeline_variables is None: - self._rag_pipeline_variables = "{}" - - variables_dict: dict[str, Any] = json.loads(self._rag_pipeline_variables) - results = list(variables_dict.values()) - return results + def rag_pipeline_variables(self) -> list[SerializedWorkflowValue]: + variables_dict = cast(SerializedWorkflowVariables, json.loads(self._rag_pipeline_variables or "{}")) + return [RAGPipelineVariable.model_validate(item).model_dump(mode="json") for item in variables_dict.values()] @rag_pipeline_variables.setter - def rag_pipeline_variables(self, values: list[dict]) -> None: + def rag_pipeline_variables(self, values: Sequence[Mapping[str, Any] | RAGPipelineVariable]) -> None: self._rag_pipeline_variables = json.dumps( - {item["variable"]: item for item in values}, + { + rag_pipeline_variable.variable: rag_pipeline_variable.model_dump(mode="json") + for rag_pipeline_variable in ( + item if isinstance(item, RAGPipelineVariable) else RAGPipelineVariable.model_validate(item) + for item in values + ) + }, ensure_ascii=False, ) @@ -667,14 +679,14 @@ class WorkflowRun(Base): def message(self): from .model import Message - return ( - db.session.query(Message).where(Message.app_id == self.app_id, Message.workflow_run_id == self.id).first() + return db.session.scalar( + select(Message).where(Message.app_id == self.app_id, Message.workflow_run_id == self.id) ) @property @deprecated("This method is retained for historical reasons; avoid using it if possible.") def workflow(self): - return db.session.query(Workflow).where(Workflow.id == self.workflow_id).first() + return db.session.scalar(select(Workflow).where(Workflow.id == self.workflow_id)) def to_dict(self): return { @@ -786,44 +798,36 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo __tablename__ = "workflow_node_executions" - @declared_attr.directive - @classmethod - def __table_args__(cls) -> Any: - return ( - PrimaryKeyConstraint("id", name="workflow_node_execution_pkey"), - Index( - "workflow_node_execution_workflow_run_id_idx", - "workflow_run_id", - ), - Index( - "workflow_node_execution_node_run_idx", - "tenant_id", - "app_id", - "workflow_id", - "triggered_from", - "node_id", - ), - Index( - "workflow_node_execution_id_idx", - "tenant_id", - "app_id", - "workflow_id", - "triggered_from", - "node_execution_id", - ), - Index( - # The first argument is the index name, - # which we leave as `None`` to allow auto-generation by the ORM. - None, - cls.tenant_id, - cls.workflow_id, - cls.node_id, - # MyPy may flag the following line because it doesn't recognize that - # the `declared_attr` decorator passes the receiving class as the first - # argument to this method, allowing us to reference class attributes. - cls.created_at.desc(), - ), - ) + __table_args__ = ( + PrimaryKeyConstraint("id", name="workflow_node_execution_pkey"), + Index( + "workflow_node_execution_workflow_run_id_idx", + "workflow_run_id", + ), + Index( + "workflow_node_execution_node_run_idx", + "tenant_id", + "app_id", + "workflow_id", + "triggered_from", + "node_id", + ), + Index( + "workflow_node_execution_id_idx", + "tenant_id", + "app_id", + "workflow_id", + "triggered_from", + "node_execution_id", + ), + Index( + None, + "tenant_id", + "workflow_id", + "node_id", + sa.desc("created_at"), + ), + ) id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4())) tenant_id: Mapped[str] = mapped_column(StringUUID) @@ -1231,7 +1235,7 @@ class WorkflowArchiveLog(TypeBase): ) @property - def workflow_run_summary(self) -> dict[str, Any]: + def workflow_run_summary(self) -> WorkflowRunSummaryDict: return { "id": self.workflow_run_id, "status": self.run_status, diff --git a/api/pyproject.toml b/api/pyproject.toml index ac51d10513..31b778ab8c 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "dify-api" -version = "1.13.0" +version = "1.13.1" requires-python = ">=3.11,<3.13" dependencies = [ diff --git a/api/pyrefly-local-excludes.txt b/api/pyrefly-local-excludes.txt index c044824a82..ad3c1e8389 100644 --- a/api/pyrefly-local-excludes.txt +++ b/api/pyrefly-local-excludes.txt @@ -1,4 +1,3 @@ -configs/middleware/cache/redis_pubsub_config.py controllers/console/app/annotation.py controllers/console/app/app.py controllers/console/app/app_import.py @@ -138,8 +137,6 @@ dify_graph/nodes/trigger_webhook/node.py dify_graph/nodes/variable_aggregator/variable_aggregator_node.py dify_graph/nodes/variable_assigner/v1/node.py dify_graph/nodes/variable_assigner/v2/node.py -dify_graph/variables/types.py -extensions/ext_fastopenapi.py extensions/logstore/repositories/logstore_api_workflow_run_repository.py extensions/otel/instrumentation.py extensions/otel/runtime.py @@ -156,19 +153,7 @@ extensions/storage/oracle_oci_storage.py extensions/storage/supabase_storage.py extensions/storage/tencent_cos_storage.py extensions/storage/volcengine_tos_storage.py -factories/variable_factory.py -libs/external_api.py libs/gmpy2_pkcs10aep_cipher.py -libs/helper.py -libs/login.py -libs/module_loading.py -libs/oauth.py -libs/oauth_data_source.py -models/trigger.py -models/workflow.py -repositories/sqlalchemy_api_workflow_node_execution_repository.py -repositories/sqlalchemy_api_workflow_run_repository.py -repositories/sqlalchemy_execution_extra_content_repository.py schedule/queue_monitor_task.py services/account_service.py services/audio_service.py @@ -197,4 +182,9 @@ tasks/app_generate/workflow_execute_task.py tasks/regenerate_summary_index_task.py tasks/trigger_processing_tasks.py tasks/workflow_cfs_scheduler/cfs_scheduler.py +tasks/add_document_to_index_task.py +tasks/create_segment_to_index_task.py +tasks/disable_segment_from_index_task.py +tasks/enable_segment_to_index_task.py +tasks/remove_document_from_index_task.py tasks/workflow_execution_tasks.py diff --git a/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py b/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py index 2266c2e646..77e40fc6fc 100644 --- a/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py +++ b/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py @@ -8,7 +8,7 @@ using SQLAlchemy 2.0 style queries for WorkflowNodeExecutionModel operations. import json from collections.abc import Sequence from datetime import datetime -from typing import cast +from typing import Protocol, cast from sqlalchemy import asc, delete, desc, func, select from sqlalchemy.engine import CursorResult @@ -22,6 +22,20 @@ from repositories.api_workflow_node_execution_repository import ( ) +class _WorkflowNodeExecutionSnapshotRow(Protocol): + id: str + node_execution_id: str | None + node_id: str + node_type: str + title: str + index: int + status: WorkflowNodeExecutionStatus + elapsed_time: float | None + created_at: datetime + finished_at: datetime | None + execution_metadata: str | None + + class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRepository): """ SQLAlchemy implementation of DifyAPIWorkflowNodeExecutionRepository. @@ -40,6 +54,8 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut - Thread-safe database operations using session-per-request pattern """ + _session_maker: sessionmaker[Session] + def __init__(self, session_maker: sessionmaker[Session]): """ Initialize the repository with a sessionmaker. @@ -156,12 +172,12 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut ) with self._session_maker() as session: - rows = session.execute(stmt).all() + rows = cast(Sequence[_WorkflowNodeExecutionSnapshotRow], session.execute(stmt).all()) return [self._row_to_snapshot(row) for row in rows] @staticmethod - def _row_to_snapshot(row: object) -> WorkflowNodeExecutionSnapshot: + def _row_to_snapshot(row: _WorkflowNodeExecutionSnapshotRow) -> WorkflowNodeExecutionSnapshot: metadata: dict[str, object] = {} execution_metadata = getattr(row, "execution_metadata", None) if execution_metadata: diff --git a/api/services/agent_service.py b/api/services/agent_service.py index b2db895a5a..2b8a3ee594 100644 --- a/api/services/agent_service.py +++ b/api/services/agent_service.py @@ -11,7 +11,7 @@ from core.tools.tool_manager import ToolManager from extensions.ext_database import db from libs.login import current_user from models import Account -from models.model import App, Conversation, EndUser, Message, MessageAgentThought +from models.model import App, Conversation, EndUser, Message class AgentService: @@ -47,7 +47,7 @@ class AgentService: if not message: raise ValueError(f"Message not found: {message_id}") - agent_thoughts: list[MessageAgentThought] = message.agent_thoughts + agent_thoughts = message.agent_thoughts if conversation.from_end_user_id: # only select name field diff --git a/api/services/async_workflow_service.py b/api/services/async_workflow_service.py index 94452482b3..0133634e5a 100644 --- a/api/services/async_workflow_service.py +++ b/api/services/async_workflow_service.py @@ -18,7 +18,7 @@ from extensions.ext_database import db from models.account import Account from models.enums import CreatorUserRole, WorkflowTriggerStatus from models.model import App, EndUser -from models.trigger import WorkflowTriggerLog +from models.trigger import WorkflowTriggerLog, WorkflowTriggerLogDict from models.workflow import Workflow from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository from services.errors.app import QuotaExceededError, WorkflowNotFoundError, WorkflowQuotaLimitError @@ -224,7 +224,9 @@ class AsyncWorkflowService: return cls.trigger_workflow_async(session, user, trigger_data) @classmethod - def get_trigger_log(cls, workflow_trigger_log_id: str, tenant_id: str | None = None) -> dict[str, Any] | None: + def get_trigger_log( + cls, workflow_trigger_log_id: str, tenant_id: str | None = None + ) -> WorkflowTriggerLogDict | None: """ Get trigger log by ID @@ -247,7 +249,7 @@ class AsyncWorkflowService: @classmethod def get_recent_logs( cls, tenant_id: str, app_id: str, hours: int = 24, limit: int = 100, offset: int = 0 - ) -> list[dict[str, Any]]: + ) -> list[WorkflowTriggerLogDict]: """ Get recent trigger logs @@ -272,7 +274,7 @@ class AsyncWorkflowService: @classmethod def get_failed_logs_for_retry( cls, tenant_id: str, max_retry_count: int = 3, limit: int = 100 - ) -> list[dict[str, Any]]: + ) -> list[WorkflowTriggerLogDict]: """ Get failed logs eligible for retry diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index c527c71d7b..cdab90a3dc 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -51,6 +51,14 @@ from models.dataset import ( Pipeline, SegmentAttachmentBinding, ) +from models.enums import ( + DatasetRuntimeMode, + DataSourceType, + DocumentCreatedFrom, + IndexingStatus, + ProcessRuleMode, + SegmentStatus, +) from models.model import UploadFile from models.provider_ids import ModelProviderID from models.source import DataSourceOauthBinding @@ -319,7 +327,7 @@ class DatasetService: description=rag_pipeline_dataset_create_entity.description, permission=rag_pipeline_dataset_create_entity.permission, provider="vendor", - runtime_mode="rag_pipeline", + runtime_mode=DatasetRuntimeMode.RAG_PIPELINE, icon_info=rag_pipeline_dataset_create_entity.icon_info.model_dump(), created_by=current_user.id, pipeline_id=pipeline.id, @@ -614,7 +622,7 @@ class DatasetService: """ Update pipeline knowledge base node data. """ - if dataset.runtime_mode != "rag_pipeline": + if dataset.runtime_mode != DatasetRuntimeMode.RAG_PIPELINE: return pipeline = db.session.query(Pipeline).filter_by(id=dataset.pipeline_id).first() @@ -1229,10 +1237,15 @@ class DocumentService: "enabled": "available", } - _INDEXING_STATUSES: tuple[str, ...] = ("parsing", "cleaning", "splitting", "indexing") + _INDEXING_STATUSES: tuple[IndexingStatus, ...] = ( + IndexingStatus.PARSING, + IndexingStatus.CLEANING, + IndexingStatus.SPLITTING, + IndexingStatus.INDEXING, + ) DISPLAY_STATUS_FILTERS: dict[str, tuple[Any, ...]] = { - "queuing": (Document.indexing_status == "waiting",), + "queuing": (Document.indexing_status == IndexingStatus.WAITING,), "indexing": ( Document.indexing_status.in_(_INDEXING_STATUSES), Document.is_paused.is_not(True), @@ -1241,19 +1254,19 @@ class DocumentService: Document.indexing_status.in_(_INDEXING_STATUSES), Document.is_paused.is_(True), ), - "error": (Document.indexing_status == "error",), + "error": (Document.indexing_status == IndexingStatus.ERROR,), "available": ( - Document.indexing_status == "completed", + Document.indexing_status == IndexingStatus.COMPLETED, Document.archived.is_(False), Document.enabled.is_(True), ), "disabled": ( - Document.indexing_status == "completed", + Document.indexing_status == IndexingStatus.COMPLETED, Document.archived.is_(False), Document.enabled.is_(False), ), "archived": ( - Document.indexing_status == "completed", + Document.indexing_status == IndexingStatus.COMPLETED, Document.archived.is_(True), ), } @@ -1536,7 +1549,7 @@ class DocumentService: """ Normalize and validate `Document -> UploadFile` linkage for download flows. """ - if document.data_source_type != "upload_file": + if document.data_source_type != DataSourceType.UPLOAD_FILE: raise NotFound(invalid_source_message) data_source_info: dict[str, Any] = document.data_source_info_dict or {} @@ -1617,7 +1630,7 @@ class DocumentService: select(Document).where( Document.id.in_(document_ids), Document.enabled == True, - Document.indexing_status == "completed", + Document.indexing_status == IndexingStatus.COMPLETED, Document.archived == False, ) ).all() @@ -1640,7 +1653,7 @@ class DocumentService: select(Document).where( Document.dataset_id == dataset_id, Document.enabled == True, - Document.indexing_status == "completed", + Document.indexing_status == IndexingStatus.COMPLETED, Document.archived == False, ) ).all() @@ -1650,7 +1663,10 @@ class DocumentService: @staticmethod def get_error_documents_by_dataset_id(dataset_id: str) -> Sequence[Document]: documents = db.session.scalars( - select(Document).where(Document.dataset_id == dataset_id, Document.indexing_status.in_(["error", "paused"])) + select(Document).where( + Document.dataset_id == dataset_id, + Document.indexing_status.in_([IndexingStatus.ERROR, IndexingStatus.PAUSED]), + ) ).all() return documents @@ -1683,7 +1699,7 @@ class DocumentService: def delete_document(document): # trigger document_was_deleted signal file_id = None - if document.data_source_type == "upload_file": + if document.data_source_type == DataSourceType.UPLOAD_FILE: if document.data_source_info: data_source_info = document.data_source_info_dict if data_source_info and "upload_file_id" in data_source_info: @@ -1704,7 +1720,7 @@ class DocumentService: file_ids = [ document.data_source_info_dict.get("upload_file_id", "") for document in documents - if document.data_source_type == "upload_file" and document.data_source_info_dict + if document.data_source_type == DataSourceType.UPLOAD_FILE and document.data_source_info_dict ] # Delete documents first, then dispatch cleanup task after commit @@ -1753,7 +1769,13 @@ class DocumentService: @staticmethod def pause_document(document): - if document.indexing_status not in {"waiting", "parsing", "cleaning", "splitting", "indexing"}: + if document.indexing_status not in { + IndexingStatus.WAITING, + IndexingStatus.PARSING, + IndexingStatus.CLEANING, + IndexingStatus.SPLITTING, + IndexingStatus.INDEXING, + }: raise DocumentIndexingError() # update document to be paused assert current_user is not None @@ -1793,7 +1815,7 @@ class DocumentService: if cache_result is not None: raise ValueError("Document is being retried, please try again later") # retry document indexing - document.indexing_status = "waiting" + document.indexing_status = IndexingStatus.WAITING db.session.add(document) db.session.commit() @@ -1812,7 +1834,7 @@ class DocumentService: if cache_result is not None: raise ValueError("Document is being synced, please try again later") # sync document indexing - document.indexing_status = "waiting" + document.indexing_status = IndexingStatus.WAITING data_source_info = document.data_source_info_dict if data_source_info: data_source_info["mode"] = "scrape" @@ -1840,7 +1862,7 @@ class DocumentService: knowledge_config: KnowledgeConfig, account: Account | Any, dataset_process_rule: DatasetProcessRule | None = None, - created_from: str = "web", + created_from: str = DocumentCreatedFrom.WEB, ) -> tuple[list[Document], str]: # check doc_form DatasetService.check_doc_form(dataset, knowledge_config.doc_form) @@ -1932,7 +1954,7 @@ class DocumentService: if not dataset_process_rule: process_rule = knowledge_config.process_rule if process_rule: - if process_rule.mode in ("custom", "hierarchical"): + if process_rule.mode in (ProcessRuleMode.CUSTOM, ProcessRuleMode.HIERARCHICAL): if process_rule.rules: dataset_process_rule = DatasetProcessRule( dataset_id=dataset.id, @@ -1944,7 +1966,7 @@ class DocumentService: dataset_process_rule = dataset.latest_process_rule if not dataset_process_rule: raise ValueError("No process rule found.") - elif process_rule.mode == "automatic": + elif process_rule.mode == ProcessRuleMode.AUTOMATIC: dataset_process_rule = DatasetProcessRule( dataset_id=dataset.id, mode=process_rule.mode, @@ -1967,7 +1989,7 @@ class DocumentService: if not dataset_process_rule: dataset_process_rule = DatasetProcessRule( dataset_id=dataset.id, - mode="automatic", + mode=ProcessRuleMode.AUTOMATIC, rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES), created_by=account.id, ) @@ -2001,7 +2023,7 @@ class DocumentService: .where( Document.dataset_id == dataset.id, Document.tenant_id == current_user.current_tenant_id, - Document.data_source_type == "upload_file", + Document.data_source_type == DataSourceType.UPLOAD_FILE, Document.enabled == True, Document.name.in_(file_names), ) @@ -2021,7 +2043,7 @@ class DocumentService: document.doc_language = knowledge_config.doc_language document.data_source_info = json.dumps(data_source_info) document.batch = batch - document.indexing_status = "waiting" + document.indexing_status = IndexingStatus.WAITING db.session.add(document) documents.append(document) duplicate_document_ids.append(document.id) @@ -2056,7 +2078,7 @@ class DocumentService: .filter_by( dataset_id=dataset.id, tenant_id=current_user.current_tenant_id, - data_source_type="notion_import", + data_source_type=DataSourceType.NOTION_IMPORT, enabled=True, ) .all() @@ -2507,7 +2529,7 @@ class DocumentService: document_data: KnowledgeConfig, account: Account, dataset_process_rule: DatasetProcessRule | None = None, - created_from: str = "web", + created_from: str = DocumentCreatedFrom.WEB, ): assert isinstance(current_user, Account) @@ -2520,14 +2542,14 @@ class DocumentService: # save process rule if document_data.process_rule: process_rule = document_data.process_rule - if process_rule.mode in {"custom", "hierarchical"}: + if process_rule.mode in {ProcessRuleMode.CUSTOM, ProcessRuleMode.HIERARCHICAL}: dataset_process_rule = DatasetProcessRule( dataset_id=dataset.id, mode=process_rule.mode, rules=process_rule.rules.model_dump_json() if process_rule.rules else None, created_by=account.id, ) - elif process_rule.mode == "automatic": + elif process_rule.mode == ProcessRuleMode.AUTOMATIC: dataset_process_rule = DatasetProcessRule( dataset_id=dataset.id, mode=process_rule.mode, @@ -2609,7 +2631,7 @@ class DocumentService: if document_data.name: document.name = document_data.name # update document to be waiting - document.indexing_status = "waiting" + document.indexing_status = IndexingStatus.WAITING document.completed_at = None document.processing_started_at = None document.parsing_completed_at = None @@ -2623,7 +2645,7 @@ class DocumentService: # update document segment db.session.query(DocumentSegment).filter_by(document_id=document.id).update( - {DocumentSegment.status: "re_segment"} + {DocumentSegment.status: SegmentStatus.RE_SEGMENT} ) db.session.commit() # trigger async task @@ -2754,7 +2776,7 @@ class DocumentService: if knowledge_config.process_rule.mode not in DatasetProcessRule.MODES: raise ValueError("Process rule mode is invalid") - if knowledge_config.process_rule.mode == "automatic": + if knowledge_config.process_rule.mode == ProcessRuleMode.AUTOMATIC: knowledge_config.process_rule.rules = None else: if not knowledge_config.process_rule.rules: @@ -2785,7 +2807,7 @@ class DocumentService: raise ValueError("Process rule segmentation separator is invalid") if not ( - knowledge_config.process_rule.mode == "hierarchical" + knowledge_config.process_rule.mode == ProcessRuleMode.HIERARCHICAL and knowledge_config.process_rule.rules.parent_mode == "full-doc" ): if not knowledge_config.process_rule.rules.segmentation.max_tokens: @@ -2814,7 +2836,7 @@ class DocumentService: if args["process_rule"]["mode"] not in DatasetProcessRule.MODES: raise ValueError("Process rule mode is invalid") - if args["process_rule"]["mode"] == "automatic": + if args["process_rule"]["mode"] == ProcessRuleMode.AUTOMATIC: args["process_rule"]["rules"] = {} else: if "rules" not in args["process_rule"] or not args["process_rule"]["rules"]: @@ -3021,7 +3043,7 @@ class DocumentService: @staticmethod def _prepare_disable_update(document, user, now): """Prepare updates for disabling a document.""" - if not document.completed_at or document.indexing_status != "completed": + if not document.completed_at or document.indexing_status != IndexingStatus.COMPLETED: raise DocumentIndexingError(f"Document: {document.name} is not completed.") if not document.enabled: @@ -3130,7 +3152,7 @@ class SegmentService: content=content, word_count=len(content), tokens=tokens, - status="completed", + status=SegmentStatus.COMPLETED, indexing_at=naive_utc_now(), completed_at=naive_utc_now(), created_by=current_user.id, @@ -3167,7 +3189,7 @@ class SegmentService: logger.exception("create segment index failed") segment_document.enabled = False segment_document.disabled_at = naive_utc_now() - segment_document.status = "error" + segment_document.status = SegmentStatus.ERROR segment_document.error = str(e) db.session.commit() segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment_document.id).first() @@ -3227,7 +3249,7 @@ class SegmentService: word_count=len(content), tokens=tokens, keywords=segment_item.get("keywords", []), - status="completed", + status=SegmentStatus.COMPLETED, indexing_at=naive_utc_now(), completed_at=naive_utc_now(), created_by=current_user.id, @@ -3259,7 +3281,7 @@ class SegmentService: for segment_document in segment_data_list: segment_document.enabled = False segment_document.disabled_at = naive_utc_now() - segment_document.status = "error" + segment_document.status = SegmentStatus.ERROR segment_document.error = str(e) db.session.commit() return segment_data_list @@ -3405,7 +3427,7 @@ class SegmentService: segment.index_node_hash = segment_hash segment.word_count = len(content) segment.tokens = tokens - segment.status = "completed" + segment.status = SegmentStatus.COMPLETED segment.indexing_at = naive_utc_now() segment.completed_at = naive_utc_now() segment.updated_by = current_user.id @@ -3530,7 +3552,7 @@ class SegmentService: logger.exception("update segment index failed") segment.enabled = False segment.disabled_at = naive_utc_now() - segment.status = "error" + segment.status = SegmentStatus.ERROR segment.error = str(e) db.session.commit() new_segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment.id).first() diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py index d85b290534..9993d24c70 100644 --- a/api/services/hit_testing_service.py +++ b/api/services/hit_testing_service.py @@ -13,7 +13,7 @@ from dify_graph.model_runtime.entities import LLMMode from extensions.ext_database import db from models import Account from models.dataset import Dataset, DatasetQuery -from models.enums import CreatorUserRole +from models.enums import CreatorUserRole, DatasetQuerySource logger = logging.getLogger(__name__) @@ -97,7 +97,7 @@ class HitTestingService: dataset_query = DatasetQuery( dataset_id=dataset.id, content=json.dumps(dataset_queries), - source="hit_testing", + source=DatasetQuerySource.HIT_TESTING, source_app_id=None, created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, @@ -137,7 +137,7 @@ class HitTestingService: dataset_query = DatasetQuery( dataset_id=dataset.id, content=query, - source="hit_testing", + source=DatasetQuerySource.HIT_TESTING, source_app_id=None, created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, diff --git a/api/services/metadata_service.py b/api/services/metadata_service.py index 859fc1902b..2f47a647a8 100644 --- a/api/services/metadata_service.py +++ b/api/services/metadata_service.py @@ -7,6 +7,7 @@ from extensions.ext_redis import redis_client from libs.datetime_utils import naive_utc_now from libs.login import current_account_with_tenant from models.dataset import Dataset, DatasetMetadata, DatasetMetadataBinding +from models.enums import DatasetMetadataType from services.dataset_service import DocumentService from services.entities.knowledge_entities.knowledge_entities import ( MetadataArgs, @@ -130,11 +131,11 @@ class MetadataService: @staticmethod def get_built_in_fields(): return [ - {"name": BuiltInField.document_name, "type": "string"}, - {"name": BuiltInField.uploader, "type": "string"}, - {"name": BuiltInField.upload_date, "type": "time"}, - {"name": BuiltInField.last_update_date, "type": "time"}, - {"name": BuiltInField.source, "type": "string"}, + {"name": BuiltInField.document_name, "type": DatasetMetadataType.STRING}, + {"name": BuiltInField.uploader, "type": DatasetMetadataType.STRING}, + {"name": BuiltInField.upload_date, "type": DatasetMetadataType.TIME}, + {"name": BuiltInField.last_update_date, "type": DatasetMetadataType.TIME}, + {"name": BuiltInField.source, "type": DatasetMetadataType.STRING}, ] @staticmethod diff --git a/api/services/model_load_balancing_service.py b/api/services/model_load_balancing_service.py index 2133dc5b3a..bf3b6db3ed 100644 --- a/api/services/model_load_balancing_service.py +++ b/api/services/model_load_balancing_service.py @@ -19,6 +19,7 @@ from dify_graph.model_runtime.entities.provider_entities import ( from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from extensions.ext_database import db from libs.datetime_utils import naive_utc_now +from models.enums import CredentialSourceType from models.provider import LoadBalancingModelConfig, ProviderCredential, ProviderModelCredential logger = logging.getLogger(__name__) @@ -103,9 +104,9 @@ class ModelLoadBalancingService: is_load_balancing_enabled = True if config_from == "predefined-model": - credential_source_type = "provider" + credential_source_type = CredentialSourceType.PROVIDER else: - credential_source_type = "custom_model" + credential_source_type = CredentialSourceType.CUSTOM_MODEL # Get load balancing configurations load_balancing_configs = ( @@ -421,7 +422,11 @@ class ModelLoadBalancingService: raise ValueError("Invalid load balancing config name") if credential_id: - credential_source = "provider" if config_from == "predefined-model" else "custom_model" + credential_source = ( + CredentialSourceType.PROVIDER + if config_from == "predefined-model" + else CredentialSourceType.CUSTOM_MODEL + ) assert credential_record is not None load_balancing_model_config = LoadBalancingModelConfig( tenant_id=tenant_id, diff --git a/api/services/plugin/plugin_service.py b/api/services/plugin/plugin_service.py index 55a3ffde78..ca83742d65 100644 --- a/api/services/plugin/plugin_service.py +++ b/api/services/plugin/plugin_service.py @@ -30,7 +30,7 @@ from core.plugin.impl.debugging import PluginDebuggingClient from core.plugin.impl.plugin import PluginInstaller from extensions.ext_database import db from extensions.ext_redis import redis_client -from models.provider import Provider, ProviderCredential +from models.provider import Provider, ProviderCredential, TenantPreferredModelProvider from models.provider_ids import GenericProviderID from services.enterprise.plugin_manager_service import ( PluginManagerService, @@ -534,6 +534,13 @@ class PluginService: plugin_id = plugin.plugin_id logger.info("Deleting credentials for plugin: %s", plugin_id) + session.execute( + delete(TenantPreferredModelProvider).where( + TenantPreferredModelProvider.tenant_id == tenant_id, + TenantPreferredModelProvider.provider_name.like(f"{plugin_id}/%"), + ) + ) + # Delete provider credentials that match this plugin credential_ids = session.scalars( select(ProviderCredential.id).where( diff --git a/api/services/rag_pipeline/pipeline_generate_service.py b/api/services/rag_pipeline/pipeline_generate_service.py index f397b28283..07e1b8f20e 100644 --- a/api/services/rag_pipeline/pipeline_generate_service.py +++ b/api/services/rag_pipeline/pipeline_generate_service.py @@ -6,6 +6,7 @@ from core.app.apps.pipeline.pipeline_generator import PipelineGenerator from core.app.entities.app_invoke_entities import InvokeFrom from extensions.ext_database import db from models.dataset import Document, Pipeline +from models.enums import IndexingStatus from models.model import Account, App, EndUser from models.workflow import Workflow from services.rag_pipeline.rag_pipeline import RagPipelineService @@ -111,6 +112,6 @@ class PipelineGenerateService: """ document = db.session.query(Document).where(Document.id == document_id).first() if document: - document.indexing_status = "waiting" + document.indexing_status = IndexingStatus.WAITING db.session.add(document) db.session.commit() diff --git a/api/services/rag_pipeline/pipeline_template/remote/remote_retrieval.py b/api/services/rag_pipeline/pipeline_template/remote/remote_retrieval.py index 571ca6c7a6..f996db11dc 100644 --- a/api/services/rag_pipeline/pipeline_template/remote/remote_retrieval.py +++ b/api/services/rag_pipeline/pipeline_template/remote/remote_retrieval.py @@ -15,7 +15,8 @@ class RemotePipelineTemplateRetrieval(PipelineTemplateRetrievalBase): Retrieval recommended app from dify official """ - def get_pipeline_template_detail(self, template_id: str): + def get_pipeline_template_detail(self, template_id: str) -> dict | None: + result: dict | None try: result = self.fetch_pipeline_template_detail_from_dify_official(template_id) except Exception as e: @@ -35,17 +36,23 @@ class RemotePipelineTemplateRetrieval(PipelineTemplateRetrievalBase): return PipelineTemplateType.REMOTE @classmethod - def fetch_pipeline_template_detail_from_dify_official(cls, template_id: str) -> dict | None: + def fetch_pipeline_template_detail_from_dify_official(cls, template_id: str) -> dict: """ Fetch pipeline template detail from dify official. - :param template_id: Pipeline ID - :return: + + :param template_id: Pipeline template ID + :return: Template detail dict + :raises ValueError: When upstream returns a non-200 status code """ domain = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_REMOTE_DOMAIN url = f"{domain}/pipeline-templates/{template_id}" response = httpx.get(url, timeout=httpx.Timeout(10.0, connect=3.0)) if response.status_code != 200: - return None + raise ValueError( + "fetch pipeline template detail failed," + + f" status_code: {response.status_code}," + + f" response: {response.text[:1000]}" + ) data: dict = response.json() return data diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index 2118043a98..f3aedafac9 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -64,7 +64,7 @@ from models.dataset import ( # type: ignore PipelineCustomizedTemplate, PipelineRecommendedPlugin, ) -from models.enums import WorkflowRunTriggeredFrom +from models.enums import IndexingStatus, WorkflowRunTriggeredFrom from models.model import EndUser from models.workflow import ( Workflow, @@ -117,13 +117,21 @@ class RagPipelineService: def get_pipeline_template_detail(cls, template_id: str, type: str = "built-in") -> dict | None: """ Get pipeline template detail. + :param template_id: template id - :return: + :param type: template type, "built-in" or "customized" + :return: template detail dict, or None if not found """ if type == "built-in": mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)() built_in_result: dict | None = retrieval_instance.get_pipeline_template_detail(template_id) + if built_in_result is None: + logger.warning( + "pipeline template retrieval returned empty result, template_id: %s, mode: %s", + template_id, + mode, + ) return built_in_result else: mode = "customized" @@ -906,7 +914,7 @@ class RagPipelineService: if document_id: document = db.session.query(Document).where(Document.id == document_id.value).first() if document: - document.indexing_status = "error" + document.indexing_status = IndexingStatus.ERROR document.error = error db.session.add(document) db.session.commit() diff --git a/api/services/rag_pipeline/rag_pipeline_dsl_service.py b/api/services/rag_pipeline/rag_pipeline_dsl_service.py index c7da1afe1b..deb59da8d3 100644 --- a/api/services/rag_pipeline/rag_pipeline_dsl_service.py +++ b/api/services/rag_pipeline/rag_pipeline_dsl_service.py @@ -35,6 +35,7 @@ from extensions.ext_redis import redis_client from factories import variable_factory from models import Account from models.dataset import Dataset, DatasetCollectionBinding, Pipeline +from models.enums import CollectionBindingType, DatasetRuntimeMode from models.workflow import Workflow, WorkflowType from services.entities.knowledge_entities.rag_pipeline_entities import ( IconInfo, @@ -313,7 +314,7 @@ class RagPipelineDslService: indexing_technique=knowledge_configuration.indexing_technique, created_by=account.id, retrieval_model=knowledge_configuration.retrieval_model.model_dump(), - runtime_mode="rag_pipeline", + runtime_mode=DatasetRuntimeMode.RAG_PIPELINE, chunk_structure=knowledge_configuration.chunk_structure, ) if knowledge_configuration.indexing_technique == "high_quality": @@ -323,7 +324,7 @@ class RagPipelineDslService: DatasetCollectionBinding.provider_name == knowledge_configuration.embedding_model_provider, DatasetCollectionBinding.model_name == knowledge_configuration.embedding_model, - DatasetCollectionBinding.type == "dataset", + DatasetCollectionBinding.type == CollectionBindingType.DATASET, ) .order_by(DatasetCollectionBinding.created_at) .first() @@ -334,7 +335,7 @@ class RagPipelineDslService: provider_name=knowledge_configuration.embedding_model_provider, model_name=knowledge_configuration.embedding_model, collection_name=Dataset.gen_collection_name_by_id(str(uuid.uuid4())), - type="dataset", + type=CollectionBindingType.DATASET, ) self._session.add(dataset_collection_binding) self._session.commit() @@ -445,13 +446,13 @@ class RagPipelineDslService: indexing_technique=knowledge_configuration.indexing_technique, created_by=account.id, retrieval_model=knowledge_configuration.retrieval_model.model_dump(), - runtime_mode="rag_pipeline", + runtime_mode=DatasetRuntimeMode.RAG_PIPELINE, chunk_structure=knowledge_configuration.chunk_structure, ) else: dataset.indexing_technique = knowledge_configuration.indexing_technique dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump() - dataset.runtime_mode = "rag_pipeline" + dataset.runtime_mode = DatasetRuntimeMode.RAG_PIPELINE dataset.chunk_structure = knowledge_configuration.chunk_structure if knowledge_configuration.indexing_technique == "high_quality": dataset_collection_binding = ( @@ -460,7 +461,7 @@ class RagPipelineDslService: DatasetCollectionBinding.provider_name == knowledge_configuration.embedding_model_provider, DatasetCollectionBinding.model_name == knowledge_configuration.embedding_model, - DatasetCollectionBinding.type == "dataset", + DatasetCollectionBinding.type == CollectionBindingType.DATASET, ) .order_by(DatasetCollectionBinding.created_at) .first() @@ -471,7 +472,7 @@ class RagPipelineDslService: provider_name=knowledge_configuration.embedding_model_provider, model_name=knowledge_configuration.embedding_model, collection_name=Dataset.gen_collection_name_by_id(str(uuid.uuid4())), - type="dataset", + type=CollectionBindingType.DATASET, ) self._session.add(dataset_collection_binding) self._session.commit() diff --git a/api/services/rag_pipeline/rag_pipeline_transform_service.py b/api/services/rag_pipeline/rag_pipeline_transform_service.py index cee18387b3..1d0aafd5fd 100644 --- a/api/services/rag_pipeline/rag_pipeline_transform_service.py +++ b/api/services/rag_pipeline/rag_pipeline_transform_service.py @@ -13,6 +13,7 @@ from core.rag.retrieval.retrieval_methods import RetrievalMethod from extensions.ext_database import db from factories import variable_factory from models.dataset import Dataset, Document, DocumentPipelineExecutionLog, Pipeline +from models.enums import DatasetRuntimeMode, DataSourceType from models.model import UploadFile from models.workflow import Workflow, WorkflowType from services.entities.knowledge_entities.rag_pipeline_entities import KnowledgeConfiguration, RetrievalSetting @@ -27,7 +28,7 @@ class RagPipelineTransformService: dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first() if not dataset: raise ValueError("Dataset not found") - if dataset.pipeline_id and dataset.runtime_mode == "rag_pipeline": + if dataset.pipeline_id and dataset.runtime_mode == DatasetRuntimeMode.RAG_PIPELINE: return { "pipeline_id": dataset.pipeline_id, "dataset_id": dataset_id, @@ -85,7 +86,7 @@ class RagPipelineTransformService: else: raise ValueError("Unsupported doc form") - dataset.runtime_mode = "rag_pipeline" + dataset.runtime_mode = DatasetRuntimeMode.RAG_PIPELINE dataset.pipeline_id = pipeline.id # deal document data @@ -102,7 +103,7 @@ class RagPipelineTransformService: pipeline_yaml = {} if doc_form == "text_model": match datasource_type: - case "upload_file": + case DataSourceType.UPLOAD_FILE: if indexing_technique == "high_quality": # get graph from transform.file-general-high-quality.yml with open(f"{Path(__file__).parent}/transform/file-general-high-quality.yml") as f: @@ -111,7 +112,7 @@ class RagPipelineTransformService: # get graph from transform.file-general-economy.yml with open(f"{Path(__file__).parent}/transform/file-general-economy.yml") as f: pipeline_yaml = yaml.safe_load(f) - case "notion_import": + case DataSourceType.NOTION_IMPORT: if indexing_technique == "high_quality": # get graph from transform.notion-general-high-quality.yml with open(f"{Path(__file__).parent}/transform/notion-general-high-quality.yml") as f: @@ -120,7 +121,7 @@ class RagPipelineTransformService: # get graph from transform.notion-general-economy.yml with open(f"{Path(__file__).parent}/transform/notion-general-economy.yml") as f: pipeline_yaml = yaml.safe_load(f) - case "website_crawl": + case DataSourceType.WEBSITE_CRAWL: if indexing_technique == "high_quality": # get graph from transform.website-crawl-general-high-quality.yml with open(f"{Path(__file__).parent}/transform/website-crawl-general-high-quality.yml") as f: @@ -133,15 +134,15 @@ class RagPipelineTransformService: raise ValueError("Unsupported datasource type") elif doc_form == "hierarchical_model": match datasource_type: - case "upload_file": + case DataSourceType.UPLOAD_FILE: # get graph from transform.file-parentchild.yml with open(f"{Path(__file__).parent}/transform/file-parentchild.yml") as f: pipeline_yaml = yaml.safe_load(f) - case "notion_import": + case DataSourceType.NOTION_IMPORT: # get graph from transform.notion-parentchild.yml with open(f"{Path(__file__).parent}/transform/notion-parentchild.yml") as f: pipeline_yaml = yaml.safe_load(f) - case "website_crawl": + case DataSourceType.WEBSITE_CRAWL: # get graph from transform.website-crawl-parentchild.yml with open(f"{Path(__file__).parent}/transform/website-crawl-parentchild.yml") as f: pipeline_yaml = yaml.safe_load(f) @@ -287,7 +288,7 @@ class RagPipelineTransformService: db.session.flush() dataset.pipeline_id = pipeline.id - dataset.runtime_mode = "rag_pipeline" + dataset.runtime_mode = DatasetRuntimeMode.RAG_PIPELINE dataset.updated_by = current_user.id dataset.updated_at = datetime.now(UTC).replace(tzinfo=None) db.session.add(dataset) @@ -310,8 +311,8 @@ class RagPipelineTransformService: data_source_info_dict = document.data_source_info_dict if not data_source_info_dict: continue - if document.data_source_type == "upload_file": - document.data_source_type = "local_file" + if document.data_source_type == DataSourceType.UPLOAD_FILE: + document.data_source_type = DataSourceType.LOCAL_FILE file_id = data_source_info_dict.get("upload_file_id") if file_id: file = db.session.query(UploadFile).where(UploadFile.id == file_id).first() @@ -331,7 +332,7 @@ class RagPipelineTransformService: document_pipeline_execution_log = DocumentPipelineExecutionLog( document_id=document.id, pipeline_id=dataset.pipeline_id, - datasource_type="local_file", + datasource_type=DataSourceType.LOCAL_FILE, datasource_info=data_source_info, input_data={}, created_by=document.created_by, @@ -340,8 +341,8 @@ class RagPipelineTransformService: document_pipeline_execution_log.created_at = document.created_at db.session.add(document) db.session.add(document_pipeline_execution_log) - elif document.data_source_type == "notion_import": - document.data_source_type = "online_document" + elif document.data_source_type == DataSourceType.NOTION_IMPORT: + document.data_source_type = DataSourceType.ONLINE_DOCUMENT data_source_info = json.dumps( { "workspace_id": data_source_info_dict.get("notion_workspace_id"), @@ -359,7 +360,7 @@ class RagPipelineTransformService: document_pipeline_execution_log = DocumentPipelineExecutionLog( document_id=document.id, pipeline_id=dataset.pipeline_id, - datasource_type="online_document", + datasource_type=DataSourceType.ONLINE_DOCUMENT, datasource_info=data_source_info, input_data={}, created_by=document.created_by, @@ -368,8 +369,7 @@ class RagPipelineTransformService: document_pipeline_execution_log.created_at = document.created_at db.session.add(document) db.session.add(document_pipeline_execution_log) - elif document.data_source_type == "website_crawl": - document.data_source_type = "website_crawl" + elif document.data_source_type == DataSourceType.WEBSITE_CRAWL: data_source_info = json.dumps( { "source_url": data_source_info_dict.get("url"), @@ -388,7 +388,7 @@ class RagPipelineTransformService: document_pipeline_execution_log = DocumentPipelineExecutionLog( document_id=document.id, pipeline_id=dataset.pipeline_id, - datasource_type="website_crawl", + datasource_type=DataSourceType.WEBSITE_CRAWL, datasource_info=data_source_info, input_data={}, created_by=document.created_by, diff --git a/api/services/retention/conversation/messages_clean_service.py b/api/services/retention/conversation/messages_clean_service.py index 04265817d7..48c3e72af0 100644 --- a/api/services/retention/conversation/messages_clean_service.py +++ b/api/services/retention/conversation/messages_clean_service.py @@ -1,16 +1,16 @@ import datetime import logging -import os import random import time from collections.abc import Sequence -from typing import cast +from typing import TYPE_CHECKING, cast import sqlalchemy as sa from sqlalchemy import delete, select, tuple_ from sqlalchemy.engine import CursorResult from sqlalchemy.orm import Session +from configs import dify_config from extensions.ext_database import db from libs.datetime_utils import naive_utc_now from models.model import ( @@ -33,6 +33,131 @@ from services.retention.conversation.messages_clean_policy import ( logger = logging.getLogger(__name__) +if TYPE_CHECKING: + from opentelemetry.metrics import Counter, Histogram + + +class MessagesCleanupMetrics: + """ + Records low-cardinality OpenTelemetry metrics for expired message cleanup jobs. + + We keep labels stable (dry_run/window_mode/task_label/status) so these metrics remain + dashboard-friendly for long-running CronJob executions. + """ + + _job_runs_total: "Counter | None" + _batches_total: "Counter | None" + _messages_scanned_total: "Counter | None" + _messages_filtered_total: "Counter | None" + _messages_deleted_total: "Counter | None" + _job_duration_seconds: "Histogram | None" + _batch_duration_seconds: "Histogram | None" + _base_attributes: dict[str, str] + + def __init__(self, *, dry_run: bool, has_window: bool, task_label: str) -> None: + self._job_runs_total = None + self._batches_total = None + self._messages_scanned_total = None + self._messages_filtered_total = None + self._messages_deleted_total = None + self._job_duration_seconds = None + self._batch_duration_seconds = None + self._base_attributes = { + "job_name": "messages_cleanup", + "dry_run": str(dry_run).lower(), + "window_mode": "between" if has_window else "before_cutoff", + "task_label": task_label, + } + self._init_instruments() + + def _init_instruments(self) -> None: + if not dify_config.ENABLE_OTEL: + return + + try: + from opentelemetry.metrics import get_meter + + meter = get_meter("messages_cleanup", version=dify_config.project.version) + self._job_runs_total = meter.create_counter( + "messages_cleanup_jobs_total", + description="Total number of expired message cleanup jobs by status.", + unit="{job}", + ) + self._batches_total = meter.create_counter( + "messages_cleanup_batches_total", + description="Total number of message cleanup batches processed.", + unit="{batch}", + ) + self._messages_scanned_total = meter.create_counter( + "messages_cleanup_scanned_messages_total", + description="Total messages scanned by cleanup jobs.", + unit="{message}", + ) + self._messages_filtered_total = meter.create_counter( + "messages_cleanup_filtered_messages_total", + description="Total messages selected by cleanup policy.", + unit="{message}", + ) + self._messages_deleted_total = meter.create_counter( + "messages_cleanup_deleted_messages_total", + description="Total messages deleted by cleanup jobs.", + unit="{message}", + ) + self._job_duration_seconds = meter.create_histogram( + "messages_cleanup_job_duration_seconds", + description="Duration of expired message cleanup jobs in seconds.", + unit="s", + ) + self._batch_duration_seconds = meter.create_histogram( + "messages_cleanup_batch_duration_seconds", + description="Duration of expired message cleanup batch processing in seconds.", + unit="s", + ) + except Exception: + logger.exception("messages_cleanup_metrics: failed to initialize instruments") + + def _attrs(self, **extra: str) -> dict[str, str]: + return {**self._base_attributes, **extra} + + @staticmethod + def _add(counter: "Counter | None", value: int, attributes: dict[str, str]) -> None: + if not counter or value <= 0: + return + try: + counter.add(value, attributes) + except Exception: + logger.exception("messages_cleanup_metrics: failed to add counter value") + + @staticmethod + def _record(histogram: "Histogram | None", value: float, attributes: dict[str, str]) -> None: + if not histogram: + return + try: + histogram.record(value, attributes) + except Exception: + logger.exception("messages_cleanup_metrics: failed to record histogram value") + + def record_batch( + self, + *, + scanned_messages: int, + filtered_messages: int, + deleted_messages: int, + batch_duration_seconds: float, + ) -> None: + attributes = self._attrs() + self._add(self._batches_total, 1, attributes) + self._add(self._messages_scanned_total, scanned_messages, attributes) + self._add(self._messages_filtered_total, filtered_messages, attributes) + self._add(self._messages_deleted_total, deleted_messages, attributes) + self._record(self._batch_duration_seconds, batch_duration_seconds, attributes) + + def record_completion(self, *, status: str, job_duration_seconds: float) -> None: + attributes = self._attrs(status=status) + self._add(self._job_runs_total, 1, attributes) + self._record(self._job_duration_seconds, job_duration_seconds, attributes) + + class MessagesCleanService: """ Service for cleaning expired messages based on retention policies. @@ -48,6 +173,7 @@ class MessagesCleanService: start_from: datetime.datetime | None = None, batch_size: int = 1000, dry_run: bool = False, + task_label: str = "custom", ) -> None: """ Initialize the service with cleanup parameters. @@ -58,12 +184,18 @@ class MessagesCleanService: start_from: Optional start time (inclusive) of the range batch_size: Number of messages to process per batch dry_run: Whether to perform a dry run (no actual deletion) + task_label: Optional task label for retention metrics """ self._policy = policy self._end_before = end_before self._start_from = start_from self._batch_size = batch_size self._dry_run = dry_run + self._metrics = MessagesCleanupMetrics( + dry_run=dry_run, + has_window=bool(start_from), + task_label=task_label, + ) @classmethod def from_time_range( @@ -73,6 +205,7 @@ class MessagesCleanService: end_before: datetime.datetime, batch_size: int = 1000, dry_run: bool = False, + task_label: str = "custom", ) -> "MessagesCleanService": """ Create a service instance for cleaning messages within a specific time range. @@ -85,6 +218,7 @@ class MessagesCleanService: end_before: End time (exclusive) of the range batch_size: Number of messages to process per batch dry_run: Whether to perform a dry run (no actual deletion) + task_label: Optional task label for retention metrics Returns: MessagesCleanService instance @@ -112,6 +246,7 @@ class MessagesCleanService: start_from=start_from, batch_size=batch_size, dry_run=dry_run, + task_label=task_label, ) @classmethod @@ -121,6 +256,7 @@ class MessagesCleanService: days: int = 30, batch_size: int = 1000, dry_run: bool = False, + task_label: str = "custom", ) -> "MessagesCleanService": """ Create a service instance for cleaning messages older than specified days. @@ -130,6 +266,7 @@ class MessagesCleanService: days: Number of days to look back from now batch_size: Number of messages to process per batch dry_run: Whether to perform a dry run (no actual deletion) + task_label: Optional task label for retention metrics Returns: MessagesCleanService instance @@ -153,7 +290,14 @@ class MessagesCleanService: policy.__class__.__name__, ) - return cls(policy=policy, end_before=end_before, start_from=None, batch_size=batch_size, dry_run=dry_run) + return cls( + policy=policy, + end_before=end_before, + start_from=None, + batch_size=batch_size, + dry_run=dry_run, + task_label=task_label, + ) def run(self) -> dict[str, int]: """ @@ -162,7 +306,18 @@ class MessagesCleanService: Returns: Dict with statistics: batches, filtered_messages, total_deleted """ - return self._clean_messages_by_time_range() + status = "success" + run_start = time.monotonic() + try: + return self._clean_messages_by_time_range() + except Exception: + status = "failed" + raise + finally: + self._metrics.record_completion( + status=status, + job_duration_seconds=time.monotonic() - run_start, + ) def _clean_messages_by_time_range(self) -> dict[str, int]: """ @@ -197,11 +352,14 @@ class MessagesCleanService: self._end_before, ) - max_batch_interval_ms = int(os.environ.get("SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_MAX_INTERVAL", 200)) + max_batch_interval_ms = dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_MAX_INTERVAL while True: stats["batches"] += 1 batch_start = time.monotonic() + batch_scanned_messages = 0 + batch_filtered_messages = 0 + batch_deleted_messages = 0 # Step 1: Fetch a batch of messages using cursor with Session(db.engine, expire_on_commit=False) as session: @@ -240,9 +398,16 @@ class MessagesCleanService: # Track total messages fetched across all batches stats["total_messages"] += len(messages) + batch_scanned_messages = len(messages) if not messages: logger.info("clean_messages (batch %s): no more messages to process", stats["batches"]) + self._metrics.record_batch( + scanned_messages=batch_scanned_messages, + filtered_messages=batch_filtered_messages, + deleted_messages=batch_deleted_messages, + batch_duration_seconds=time.monotonic() - batch_start, + ) break # Update cursor to the last message's (created_at, id) @@ -268,6 +433,12 @@ class MessagesCleanService: if not apps: logger.info("clean_messages (batch %s): no apps found, skip", stats["batches"]) + self._metrics.record_batch( + scanned_messages=batch_scanned_messages, + filtered_messages=batch_filtered_messages, + deleted_messages=batch_deleted_messages, + batch_duration_seconds=time.monotonic() - batch_start, + ) continue # Build app_id -> tenant_id mapping @@ -286,9 +457,16 @@ class MessagesCleanService: if not message_ids_to_delete: logger.info("clean_messages (batch %s): no messages to delete, skip", stats["batches"]) + self._metrics.record_batch( + scanned_messages=batch_scanned_messages, + filtered_messages=batch_filtered_messages, + deleted_messages=batch_deleted_messages, + batch_duration_seconds=time.monotonic() - batch_start, + ) continue stats["filtered_messages"] += len(message_ids_to_delete) + batch_filtered_messages = len(message_ids_to_delete) # Step 4: Batch delete messages and their relations if not self._dry_run: @@ -309,6 +487,7 @@ class MessagesCleanService: commit_ms = int((time.monotonic() - commit_start) * 1000) stats["total_deleted"] += messages_deleted + batch_deleted_messages = messages_deleted logger.info( "clean_messages (batch %s): processed %s messages, deleted %s messages", @@ -343,6 +522,13 @@ class MessagesCleanService: for msg_id in sampled_ids: logger.info("clean_messages (batch %s, dry_run) sample: message_id=%s", stats["batches"], msg_id) + self._metrics.record_batch( + scanned_messages=batch_scanned_messages, + filtered_messages=batch_filtered_messages, + deleted_messages=batch_deleted_messages, + batch_duration_seconds=time.monotonic() - batch_start, + ) + logger.info( "clean_messages completed: total batches: %s, total messages: %s, filtered messages: %s, total deleted: %s", stats["batches"], diff --git a/api/services/retention/workflow_run/clear_free_plan_expired_workflow_run_logs.py b/api/services/retention/workflow_run/clear_free_plan_expired_workflow_run_logs.py index 2c94cb5324..62bc9f5f10 100644 --- a/api/services/retention/workflow_run/clear_free_plan_expired_workflow_run_logs.py +++ b/api/services/retention/workflow_run/clear_free_plan_expired_workflow_run_logs.py @@ -1,9 +1,9 @@ import datetime import logging -import os import random import time from collections.abc import Iterable, Sequence +from typing import TYPE_CHECKING import click from sqlalchemy.orm import Session, sessionmaker @@ -20,6 +20,159 @@ from services.billing_service import BillingService, SubscriptionPlan logger = logging.getLogger(__name__) +if TYPE_CHECKING: + from opentelemetry.metrics import Counter, Histogram + + +class WorkflowRunCleanupMetrics: + """ + Records low-cardinality OpenTelemetry metrics for workflow run cleanup jobs. + + Metrics are emitted with stable labels only (dry_run/window_mode/task_label/status) + to keep dashboard and alert cardinality predictable in production clusters. + """ + + _job_runs_total: "Counter | None" + _batches_total: "Counter | None" + _runs_scanned_total: "Counter | None" + _runs_targeted_total: "Counter | None" + _runs_deleted_total: "Counter | None" + _runs_skipped_total: "Counter | None" + _related_records_total: "Counter | None" + _job_duration_seconds: "Histogram | None" + _batch_duration_seconds: "Histogram | None" + _base_attributes: dict[str, str] + + def __init__(self, *, dry_run: bool, has_window: bool, task_label: str) -> None: + self._job_runs_total = None + self._batches_total = None + self._runs_scanned_total = None + self._runs_targeted_total = None + self._runs_deleted_total = None + self._runs_skipped_total = None + self._related_records_total = None + self._job_duration_seconds = None + self._batch_duration_seconds = None + self._base_attributes = { + "job_name": "workflow_run_cleanup", + "dry_run": str(dry_run).lower(), + "window_mode": "between" if has_window else "before_cutoff", + "task_label": task_label, + } + self._init_instruments() + + def _init_instruments(self) -> None: + if not dify_config.ENABLE_OTEL: + return + + try: + from opentelemetry.metrics import get_meter + + meter = get_meter("workflow_run_cleanup", version=dify_config.project.version) + self._job_runs_total = meter.create_counter( + "workflow_run_cleanup_jobs_total", + description="Total number of workflow run cleanup jobs by status.", + unit="{job}", + ) + self._batches_total = meter.create_counter( + "workflow_run_cleanup_batches_total", + description="Total number of processed cleanup batches.", + unit="{batch}", + ) + self._runs_scanned_total = meter.create_counter( + "workflow_run_cleanup_scanned_runs_total", + description="Total workflow runs scanned by cleanup jobs.", + unit="{run}", + ) + self._runs_targeted_total = meter.create_counter( + "workflow_run_cleanup_targeted_runs_total", + description="Total workflow runs targeted by cleanup policy.", + unit="{run}", + ) + self._runs_deleted_total = meter.create_counter( + "workflow_run_cleanup_deleted_runs_total", + description="Total workflow runs deleted by cleanup jobs.", + unit="{run}", + ) + self._runs_skipped_total = meter.create_counter( + "workflow_run_cleanup_skipped_runs_total", + description="Total workflow runs skipped because tenant is paid/unknown.", + unit="{run}", + ) + self._related_records_total = meter.create_counter( + "workflow_run_cleanup_related_records_total", + description="Total related records processed by cleanup jobs.", + unit="{record}", + ) + self._job_duration_seconds = meter.create_histogram( + "workflow_run_cleanup_job_duration_seconds", + description="Duration of workflow run cleanup jobs in seconds.", + unit="s", + ) + self._batch_duration_seconds = meter.create_histogram( + "workflow_run_cleanup_batch_duration_seconds", + description="Duration of workflow run cleanup batch processing in seconds.", + unit="s", + ) + except Exception: + logger.exception("workflow_run_cleanup_metrics: failed to initialize instruments") + + def _attrs(self, **extra: str) -> dict[str, str]: + return {**self._base_attributes, **extra} + + @staticmethod + def _add(counter: "Counter | None", value: int, attributes: dict[str, str]) -> None: + if not counter or value <= 0: + return + try: + counter.add(value, attributes) + except Exception: + logger.exception("workflow_run_cleanup_metrics: failed to add counter value") + + @staticmethod + def _record(histogram: "Histogram | None", value: float, attributes: dict[str, str]) -> None: + if not histogram: + return + try: + histogram.record(value, attributes) + except Exception: + logger.exception("workflow_run_cleanup_metrics: failed to record histogram value") + + def record_batch( + self, + *, + batch_rows: int, + targeted_runs: int, + skipped_runs: int, + deleted_runs: int, + related_counts: dict[str, int] | None, + related_action: str | None, + batch_duration_seconds: float, + ) -> None: + attributes = self._attrs() + self._add(self._batches_total, 1, attributes) + self._add(self._runs_scanned_total, batch_rows, attributes) + self._add(self._runs_targeted_total, targeted_runs, attributes) + self._add(self._runs_skipped_total, skipped_runs, attributes) + self._add(self._runs_deleted_total, deleted_runs, attributes) + self._record(self._batch_duration_seconds, batch_duration_seconds, attributes) + + if not related_counts or not related_action: + return + + for record_type, count in related_counts.items(): + self._add( + self._related_records_total, + count, + self._attrs(action=related_action, record_type=record_type), + ) + + def record_completion(self, *, status: str, job_duration_seconds: float) -> None: + attributes = self._attrs(status=status) + self._add(self._job_runs_total, 1, attributes) + self._record(self._job_duration_seconds, job_duration_seconds, attributes) + + class WorkflowRunCleanup: def __init__( self, @@ -29,6 +182,7 @@ class WorkflowRunCleanup: end_before: datetime.datetime | None = None, workflow_run_repo: APIWorkflowRunRepository | None = None, dry_run: bool = False, + task_label: str = "custom", ): if (start_from is None) ^ (end_before is None): raise ValueError("start_from and end_before must be both set or both omitted.") @@ -46,6 +200,11 @@ class WorkflowRunCleanup: self.batch_size = batch_size self._cleanup_whitelist: set[str] | None = None self.dry_run = dry_run + self._metrics = WorkflowRunCleanupMetrics( + dry_run=dry_run, + has_window=bool(start_from), + task_label=task_label, + ) self.free_plan_grace_period_days = dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD self.workflow_run_repo: APIWorkflowRunRepository if workflow_run_repo: @@ -74,153 +233,193 @@ class WorkflowRunCleanup: related_totals = self._empty_related_counts() if self.dry_run else None batch_index = 0 last_seen: tuple[datetime.datetime, str] | None = None + status = "success" + run_start = time.monotonic() + max_batch_interval_ms = dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_MAX_INTERVAL - max_batch_interval_ms = int(os.environ.get("SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_MAX_INTERVAL", 200)) + try: + while True: + batch_start = time.monotonic() - while True: - batch_start = time.monotonic() - - fetch_start = time.monotonic() - run_rows = self.workflow_run_repo.get_runs_batch_by_time_range( - start_from=self.window_start, - end_before=self.window_end, - last_seen=last_seen, - batch_size=self.batch_size, - ) - if not run_rows: - logger.info("workflow_run_cleanup (batch #%s): no more rows to process", batch_index + 1) - break - - batch_index += 1 - last_seen = (run_rows[-1].created_at, run_rows[-1].id) - logger.info( - "workflow_run_cleanup (batch #%s): fetched %s rows in %sms", - batch_index, - len(run_rows), - int((time.monotonic() - fetch_start) * 1000), - ) - - tenant_ids = {row.tenant_id for row in run_rows} - - filter_start = time.monotonic() - free_tenants = self._filter_free_tenants(tenant_ids) - logger.info( - "workflow_run_cleanup (batch #%s): filtered %s free tenants from %s tenants in %sms", - batch_index, - len(free_tenants), - len(tenant_ids), - int((time.monotonic() - filter_start) * 1000), - ) - - free_runs = [row for row in run_rows if row.tenant_id in free_tenants] - paid_or_skipped = len(run_rows) - len(free_runs) - - if not free_runs: - skipped_message = ( - f"[batch #{batch_index}] skipped (no sandbox runs in batch, {paid_or_skipped} paid/unknown)" + fetch_start = time.monotonic() + run_rows = self.workflow_run_repo.get_runs_batch_by_time_range( + start_from=self.window_start, + end_before=self.window_end, + last_seen=last_seen, + batch_size=self.batch_size, ) - click.echo( - click.style( - skipped_message, - fg="yellow", - ) - ) - continue + if not run_rows: + logger.info("workflow_run_cleanup (batch #%s): no more rows to process", batch_index + 1) + break - total_runs_targeted += len(free_runs) - - if self.dry_run: - count_start = time.monotonic() - batch_counts = self.workflow_run_repo.count_runs_with_related( - free_runs, - count_node_executions=self._count_node_executions, - count_trigger_logs=self._count_trigger_logs, - ) + batch_index += 1 + last_seen = (run_rows[-1].created_at, run_rows[-1].id) logger.info( - "workflow_run_cleanup (batch #%s, dry_run): counted related records in %sms", + "workflow_run_cleanup (batch #%s): fetched %s rows in %sms", batch_index, - int((time.monotonic() - count_start) * 1000), + len(run_rows), + int((time.monotonic() - fetch_start) * 1000), ) - if related_totals is not None: - for key in related_totals: - related_totals[key] += batch_counts.get(key, 0) - sample_ids = ", ".join(run.id for run in free_runs[:5]) + + tenant_ids = {row.tenant_id for row in run_rows} + + filter_start = time.monotonic() + free_tenants = self._filter_free_tenants(tenant_ids) + logger.info( + "workflow_run_cleanup (batch #%s): filtered %s free tenants from %s tenants in %sms", + batch_index, + len(free_tenants), + len(tenant_ids), + int((time.monotonic() - filter_start) * 1000), + ) + + free_runs = [row for row in run_rows if row.tenant_id in free_tenants] + paid_or_skipped = len(run_rows) - len(free_runs) + + if not free_runs: + skipped_message = ( + f"[batch #{batch_index}] skipped (no sandbox runs in batch, {paid_or_skipped} paid/unknown)" + ) + click.echo( + click.style( + skipped_message, + fg="yellow", + ) + ) + self._metrics.record_batch( + batch_rows=len(run_rows), + targeted_runs=0, + skipped_runs=paid_or_skipped, + deleted_runs=0, + related_counts=None, + related_action=None, + batch_duration_seconds=time.monotonic() - batch_start, + ) + continue + + total_runs_targeted += len(free_runs) + + if self.dry_run: + count_start = time.monotonic() + batch_counts = self.workflow_run_repo.count_runs_with_related( + free_runs, + count_node_executions=self._count_node_executions, + count_trigger_logs=self._count_trigger_logs, + ) + logger.info( + "workflow_run_cleanup (batch #%s, dry_run): counted related records in %sms", + batch_index, + int((time.monotonic() - count_start) * 1000), + ) + if related_totals is not None: + for key in related_totals: + related_totals[key] += batch_counts.get(key, 0) + sample_ids = ", ".join(run.id for run in free_runs[:5]) + click.echo( + click.style( + f"[batch #{batch_index}] would delete {len(free_runs)} runs " + f"(sample ids: {sample_ids}) and skip {paid_or_skipped} paid/unknown", + fg="yellow", + ) + ) + logger.info( + "workflow_run_cleanup (batch #%s, dry_run): batch total %sms", + batch_index, + int((time.monotonic() - batch_start) * 1000), + ) + self._metrics.record_batch( + batch_rows=len(run_rows), + targeted_runs=len(free_runs), + skipped_runs=paid_or_skipped, + deleted_runs=0, + related_counts={key: batch_counts.get(key, 0) for key in self._empty_related_counts()}, + related_action="would_delete", + batch_duration_seconds=time.monotonic() - batch_start, + ) + continue + + try: + delete_start = time.monotonic() + counts = self.workflow_run_repo.delete_runs_with_related( + free_runs, + delete_node_executions=self._delete_node_executions, + delete_trigger_logs=self._delete_trigger_logs, + ) + delete_ms = int((time.monotonic() - delete_start) * 1000) + except Exception: + logger.exception("Failed to delete workflow runs batch ending at %s", last_seen[0]) + raise + + total_runs_deleted += counts["runs"] click.echo( click.style( - f"[batch #{batch_index}] would delete {len(free_runs)} runs " - f"(sample ids: {sample_ids}) and skip {paid_or_skipped} paid/unknown", - fg="yellow", + f"[batch #{batch_index}] deleted runs: {counts['runs']} " + f"(nodes {counts['node_executions']}, offloads {counts['offloads']}, " + f"app_logs {counts['app_logs']}, trigger_logs {counts['trigger_logs']}, " + f"pauses {counts['pauses']}, pause_reasons {counts['pause_reasons']}); " + f"skipped {paid_or_skipped} paid/unknown", + fg="green", ) ) logger.info( - "workflow_run_cleanup (batch #%s, dry_run): batch total %sms", + "workflow_run_cleanup (batch #%s): delete %sms, batch total %sms", batch_index, + delete_ms, int((time.monotonic() - batch_start) * 1000), ) - continue - - try: - delete_start = time.monotonic() - counts = self.workflow_run_repo.delete_runs_with_related( - free_runs, - delete_node_executions=self._delete_node_executions, - delete_trigger_logs=self._delete_trigger_logs, + self._metrics.record_batch( + batch_rows=len(run_rows), + targeted_runs=len(free_runs), + skipped_runs=paid_or_skipped, + deleted_runs=counts["runs"], + related_counts={key: counts.get(key, 0) for key in self._empty_related_counts()}, + related_action="deleted", + batch_duration_seconds=time.monotonic() - batch_start, ) - delete_ms = int((time.monotonic() - delete_start) * 1000) - except Exception: - logger.exception("Failed to delete workflow runs batch ending at %s", last_seen[0]) - raise - total_runs_deleted += counts["runs"] - click.echo( - click.style( - f"[batch #{batch_index}] deleted runs: {counts['runs']} " - f"(nodes {counts['node_executions']}, offloads {counts['offloads']}, " - f"app_logs {counts['app_logs']}, trigger_logs {counts['trigger_logs']}, " - f"pauses {counts['pauses']}, pause_reasons {counts['pause_reasons']}); " - f"skipped {paid_or_skipped} paid/unknown", - fg="green", - ) - ) - logger.info( - "workflow_run_cleanup (batch #%s): delete %sms, batch total %sms", - batch_index, - delete_ms, - int((time.monotonic() - batch_start) * 1000), - ) + # Random sleep between batches to avoid overwhelming the database + sleep_ms = random.uniform(0, max_batch_interval_ms) # noqa: S311 + logger.info("workflow_run_cleanup (batch #%s): sleeping for %.2fms", batch_index, sleep_ms) + time.sleep(sleep_ms / 1000) - # Random sleep between batches to avoid overwhelming the database - sleep_ms = random.uniform(0, max_batch_interval_ms) # noqa: S311 - logger.info("workflow_run_cleanup (batch #%s): sleeping for %.2fms", batch_index, sleep_ms) - time.sleep(sleep_ms / 1000) - - if self.dry_run: - if self.window_start: - summary_message = ( - f"Dry run complete. Would delete {total_runs_targeted} workflow runs " - f"between {self.window_start.isoformat()} and {self.window_end.isoformat()}" - ) + if self.dry_run: + if self.window_start: + summary_message = ( + f"Dry run complete. Would delete {total_runs_targeted} workflow runs " + f"between {self.window_start.isoformat()} and {self.window_end.isoformat()}" + ) + else: + summary_message = ( + f"Dry run complete. Would delete {total_runs_targeted} workflow runs " + f"before {self.window_end.isoformat()}" + ) + if related_totals is not None: + summary_message = ( + f"{summary_message}; related records: {self._format_related_counts(related_totals)}" + ) + summary_color = "yellow" else: - summary_message = ( - f"Dry run complete. Would delete {total_runs_targeted} workflow runs " - f"before {self.window_end.isoformat()}" - ) - if related_totals is not None: - summary_message = f"{summary_message}; related records: {self._format_related_counts(related_totals)}" - summary_color = "yellow" - else: - if self.window_start: - summary_message = ( - f"Cleanup complete. Deleted {total_runs_deleted} workflow runs " - f"between {self.window_start.isoformat()} and {self.window_end.isoformat()}" - ) - else: - summary_message = ( - f"Cleanup complete. Deleted {total_runs_deleted} workflow runs before {self.window_end.isoformat()}" - ) - summary_color = "white" + if self.window_start: + summary_message = ( + f"Cleanup complete. Deleted {total_runs_deleted} workflow runs " + f"between {self.window_start.isoformat()} and {self.window_end.isoformat()}" + ) + else: + summary_message = ( + f"Cleanup complete. Deleted {total_runs_deleted} workflow runs " + f"before {self.window_end.isoformat()}" + ) + summary_color = "white" - click.echo(click.style(summary_message, fg=summary_color)) + click.echo(click.style(summary_message, fg=summary_color)) + except Exception: + status = "failed" + raise + finally: + self._metrics.record_completion( + status=status, + job_duration_seconds=time.monotonic() - run_start, + ) def _filter_free_tenants(self, tenant_ids: Iterable[str]) -> set[str]: tenant_id_list = list(tenant_ids) diff --git a/api/services/summary_index_service.py b/api/services/summary_index_service.py index eb78be8f88..943dfc972b 100644 --- a/api/services/summary_index_service.py +++ b/api/services/summary_index_service.py @@ -12,12 +12,14 @@ from core.db.session_factory import session_factory from core.model_manager import ModelManager from core.rag.datasource.vdb.vector_factory import Vector from core.rag.index_processor.constant.doc_type import DocType +from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict from core.rag.models.document import Document from dify_graph.model_runtime.entities.llm_entities import LLMUsage from dify_graph.model_runtime.entities.model_entities import ModelType from libs import helper from models.dataset import Dataset, DocumentSegment, DocumentSegmentSummary from models.dataset import Document as DatasetDocument +from models.enums import SummaryStatus logger = logging.getLogger(__name__) @@ -29,7 +31,7 @@ class SummaryIndexService: def generate_summary_for_segment( segment: DocumentSegment, dataset: Dataset, - summary_index_setting: dict, + summary_index_setting: SummaryIndexSettingDict, ) -> tuple[str, LLMUsage]: """ Generate summary for a single segment. @@ -73,7 +75,7 @@ class SummaryIndexService: segment: DocumentSegment, dataset: Dataset, summary_content: str, - status: str = "generating", + status: SummaryStatus = SummaryStatus.GENERATING, ) -> DocumentSegmentSummary: """ Create or update a DocumentSegmentSummary record. @@ -83,7 +85,7 @@ class SummaryIndexService: segment: DocumentSegment to create summary for dataset: Dataset containing the segment summary_content: Generated summary content - status: Summary status (default: "generating") + status: Summary status (default: SummaryStatus.GENERATING) Returns: Created or updated DocumentSegmentSummary instance @@ -326,7 +328,7 @@ class SummaryIndexService: summary_index_node_id=summary_index_node_id, summary_index_node_hash=summary_hash, tokens=embedding_tokens, - status="completed", + status=SummaryStatus.COMPLETED, enabled=True, ) session.add(summary_record_in_session) @@ -362,7 +364,7 @@ class SummaryIndexService: summary_record_in_session.summary_index_node_id = summary_index_node_id summary_record_in_session.summary_index_node_hash = summary_hash summary_record_in_session.tokens = embedding_tokens # Save embedding tokens - summary_record_in_session.status = "completed" + summary_record_in_session.status = SummaryStatus.COMPLETED # Ensure summary_content is preserved (use the latest from summary_record parameter) # This is critical: use the parameter value, not the database value summary_record_in_session.summary_content = summary_content @@ -400,7 +402,7 @@ class SummaryIndexService: summary_record.summary_index_node_id = summary_index_node_id summary_record.summary_index_node_hash = summary_hash summary_record.tokens = embedding_tokens - summary_record.status = "completed" + summary_record.status = SummaryStatus.COMPLETED summary_record.summary_content = summary_content if summary_record_in_session.updated_at: summary_record.updated_at = summary_record_in_session.updated_at @@ -487,7 +489,7 @@ class SummaryIndexService: ) if summary_record_in_session: - summary_record_in_session.status = "error" + summary_record_in_session.status = SummaryStatus.ERROR summary_record_in_session.error = f"Vectorization failed: {str(e)}" summary_record_in_session.updated_at = datetime.now(UTC).replace(tzinfo=None) error_session.add(summary_record_in_session) @@ -498,7 +500,7 @@ class SummaryIndexService: summary_record_in_session.id, ) # Update the original object for consistency - summary_record.status = "error" + summary_record.status = SummaryStatus.ERROR summary_record.error = summary_record_in_session.error summary_record.updated_at = summary_record_in_session.updated_at else: @@ -514,7 +516,7 @@ class SummaryIndexService: def batch_create_summary_records( segments: list[DocumentSegment], dataset: Dataset, - status: str = "not_started", + status: SummaryStatus = SummaryStatus.NOT_STARTED, ) -> None: """ Batch create summary records for segments with specified status. @@ -523,7 +525,7 @@ class SummaryIndexService: Args: segments: List of DocumentSegment instances dataset: Dataset containing the segments - status: Initial status for the records (default: "not_started") + status: Initial status for the records (default: SummaryStatus.NOT_STARTED) """ segment_ids = [segment.id for segment in segments] if not segment_ids: @@ -588,7 +590,7 @@ class SummaryIndexService: ) if summary_record: - summary_record.status = "error" + summary_record.status = SummaryStatus.ERROR summary_record.error = error session.add(summary_record) session.commit() @@ -599,7 +601,7 @@ class SummaryIndexService: def generate_and_vectorize_summary( segment: DocumentSegment, dataset: Dataset, - summary_index_setting: dict, + summary_index_setting: SummaryIndexSettingDict, ) -> DocumentSegmentSummary: """ Generate summary for a segment and vectorize it. @@ -631,14 +633,14 @@ class SummaryIndexService: document_id=segment.document_id, chunk_id=segment.id, summary_content="", - status="generating", + status=SummaryStatus.GENERATING, enabled=True, ) session.add(summary_record_in_session) session.flush() # Update status to "generating" - summary_record_in_session.status = "generating" + summary_record_in_session.status = SummaryStatus.GENERATING summary_record_in_session.error = None # type: ignore[assignment] session.add(summary_record_in_session) # Don't flush here - wait until after vectorization succeeds @@ -681,7 +683,7 @@ class SummaryIndexService: except Exception as vectorize_error: # If vectorization fails, update status to error in current session logger.exception("Failed to vectorize summary for segment %s", segment.id) - summary_record_in_session.status = "error" + summary_record_in_session.status = SummaryStatus.ERROR summary_record_in_session.error = f"Vectorization failed: {str(vectorize_error)}" session.add(summary_record_in_session) session.commit() @@ -694,7 +696,7 @@ class SummaryIndexService: session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first() ) if summary_record_in_session: - summary_record_in_session.status = "error" + summary_record_in_session.status = SummaryStatus.ERROR summary_record_in_session.error = str(e) session.add(summary_record_in_session) session.commit() @@ -704,7 +706,7 @@ class SummaryIndexService: def generate_summaries_for_document( dataset: Dataset, document: DatasetDocument, - summary_index_setting: dict, + summary_index_setting: SummaryIndexSettingDict, segment_ids: list[str] | None = None, only_parent_chunks: bool = False, ) -> list[DocumentSegmentSummary]: @@ -770,7 +772,7 @@ class SummaryIndexService: SummaryIndexService.batch_create_summary_records( segments=segments, dataset=dataset, - status="not_started", + status=SummaryStatus.NOT_STARTED, ) summary_records = [] @@ -1067,7 +1069,7 @@ class SummaryIndexService: # Update summary content summary_record.summary_content = summary_content - summary_record.status = "generating" + summary_record.status = SummaryStatus.GENERATING summary_record.error = None # type: ignore[assignment] # Clear any previous errors session.add(summary_record) # Flush to ensure summary_content is saved before vectorize_summary queries it @@ -1102,7 +1104,7 @@ class SummaryIndexService: # If vectorization fails, update status to error in current session # Don't raise the exception - just log it and return the record with error status # This allows the segment update to complete even if vectorization fails - summary_record.status = "error" + summary_record.status = SummaryStatus.ERROR summary_record.error = f"Vectorization failed: {str(e)}" session.commit() logger.exception("Failed to vectorize summary for segment %s", segment.id) @@ -1112,7 +1114,7 @@ class SummaryIndexService: else: # Create new summary record if doesn't exist summary_record = SummaryIndexService.create_summary_record( - segment, dataset, summary_content, status="generating" + segment, dataset, summary_content, status=SummaryStatus.GENERATING ) # Re-vectorize summary (this will update status to "completed" and tokens in its own session) # Note: summary_record was created in a different session, @@ -1132,7 +1134,7 @@ class SummaryIndexService: # If vectorization fails, update status to error in current session # Merge the record into current session first error_record = session.merge(summary_record) - error_record.status = "error" + error_record.status = SummaryStatus.ERROR error_record.error = f"Vectorization failed: {str(e)}" session.commit() logger.exception("Failed to vectorize summary for segment %s", segment.id) @@ -1146,7 +1148,7 @@ class SummaryIndexService: session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first() ) if summary_record: - summary_record.status = "error" + summary_record.status = SummaryStatus.ERROR summary_record.error = str(e) session.add(summary_record) session.commit() @@ -1266,7 +1268,7 @@ class SummaryIndexService: # Check if there are any "not_started" or "generating" status summaries has_pending_summaries = any( summary_status_map.get(segment_id) is not None # Ensure summary exists (enabled=True) - and summary_status_map[segment_id] in ("not_started", "generating") + and summary_status_map[segment_id] in (SummaryStatus.NOT_STARTED, SummaryStatus.GENERATING) for segment_id in segment_ids ) @@ -1330,7 +1332,7 @@ class SummaryIndexService: # it means the summary is disabled (enabled=False) or not created yet, ignore it has_pending_summaries = any( summary_status_map.get(segment_id) is not None # Ensure summary exists (enabled=True) - and summary_status_map[segment_id] in ("not_started", "generating") + and summary_status_map[segment_id] in (SummaryStatus.NOT_STARTED, SummaryStatus.GENERATING) for segment_id in segment_ids ) @@ -1393,17 +1395,17 @@ class SummaryIndexService: # Count statuses status_counts = { - "completed": 0, - "generating": 0, - "error": 0, - "not_started": 0, + SummaryStatus.COMPLETED: 0, + SummaryStatus.GENERATING: 0, + SummaryStatus.ERROR: 0, + SummaryStatus.NOT_STARTED: 0, } summary_list = [] for segment in segments: summary = summary_map.get(segment.id) if summary: - status = summary.status + status = SummaryStatus(summary.status) status_counts[status] = status_counts.get(status, 0) + 1 summary_list.append( { @@ -1421,12 +1423,12 @@ class SummaryIndexService: } ) else: - status_counts["not_started"] += 1 + status_counts[SummaryStatus.NOT_STARTED] += 1 summary_list.append( { "segment_id": segment.id, "segment_position": segment.position, - "status": "not_started", + "status": SummaryStatus.NOT_STARTED, "summary_preview": None, "error": None, "created_at": None, diff --git a/api/tasks/add_document_to_index_task.py b/api/tasks/add_document_to_index_task.py index 2d3d00cd50..ae55c9ee03 100644 --- a/api/tasks/add_document_to_index_task.py +++ b/api/tasks/add_document_to_index_task.py @@ -13,6 +13,7 @@ from extensions.ext_redis import redis_client from libs.datetime_utils import naive_utc_now from models.dataset import DatasetAutoDisableLog, DocumentSegment from models.dataset import Document as DatasetDocument +from models.enums import IndexingStatus, SegmentStatus logger = logging.getLogger(__name__) @@ -34,7 +35,7 @@ def add_document_to_index_task(dataset_document_id: str): logger.info(click.style(f"Document not found: {dataset_document_id}", fg="red")) return - if dataset_document.indexing_status != "completed": + if dataset_document.indexing_status != IndexingStatus.COMPLETED: return indexing_cache_key = f"document_{dataset_document.id}_indexing" @@ -48,7 +49,7 @@ def add_document_to_index_task(dataset_document_id: str): session.query(DocumentSegment) .where( DocumentSegment.document_id == dataset_document.id, - DocumentSegment.status == "completed", + DocumentSegment.status == SegmentStatus.COMPLETED, ) .order_by(DocumentSegment.position.asc()) .all() @@ -139,7 +140,7 @@ def add_document_to_index_task(dataset_document_id: str): logger.exception("add document to index failed") dataset_document.enabled = False dataset_document.disabled_at = naive_utc_now() - dataset_document.indexing_status = "error" + dataset_document.indexing_status = IndexingStatus.ERROR dataset_document.error = str(e) session.commit() finally: diff --git a/api/tasks/annotation/enable_annotation_reply_task.py b/api/tasks/annotation/enable_annotation_reply_task.py index 4f8e2fec7a..1fe43c3d62 100644 --- a/api/tasks/annotation/enable_annotation_reply_task.py +++ b/api/tasks/annotation/enable_annotation_reply_task.py @@ -11,6 +11,7 @@ from core.rag.models.document import Document from extensions.ext_redis import redis_client from libs.datetime_utils import naive_utc_now from models.dataset import Dataset +from models.enums import CollectionBindingType from models.model import App, AppAnnotationSetting, MessageAnnotation from services.dataset_service import DatasetCollectionBindingService @@ -47,7 +48,7 @@ def enable_annotation_reply_task( try: documents = [] dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( - embedding_provider_name, embedding_model_name, "annotation" + embedding_provider_name, embedding_model_name, CollectionBindingType.ANNOTATION ) annotation_setting = ( session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first() @@ -56,7 +57,7 @@ def enable_annotation_reply_task( if dataset_collection_binding.id != annotation_setting.collection_binding_id: old_dataset_collection_binding = ( DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type( - annotation_setting.collection_binding_id, "annotation" + annotation_setting.collection_binding_id, CollectionBindingType.ANNOTATION ) ) if old_dataset_collection_binding and annotations: diff --git a/api/tasks/create_segment_to_index_task.py b/api/tasks/create_segment_to_index_task.py index b5e472d71e..b3cbc73d6e 100644 --- a/api/tasks/create_segment_to_index_task.py +++ b/api/tasks/create_segment_to_index_task.py @@ -10,6 +10,7 @@ from core.rag.models.document import Document from extensions.ext_redis import redis_client from libs.datetime_utils import naive_utc_now from models.dataset import DocumentSegment +from models.enums import IndexingStatus, SegmentStatus logger = logging.getLogger(__name__) @@ -31,7 +32,7 @@ def create_segment_to_index_task(segment_id: str, keywords: list[str] | None = N logger.info(click.style(f"Segment not found: {segment_id}", fg="red")) return - if segment.status != "waiting": + if segment.status != SegmentStatus.WAITING: return indexing_cache_key = f"segment_{segment.id}_indexing" @@ -40,7 +41,7 @@ def create_segment_to_index_task(segment_id: str, keywords: list[str] | None = N # update segment status to indexing session.query(DocumentSegment).filter_by(id=segment.id).update( { - DocumentSegment.status: "indexing", + DocumentSegment.status: SegmentStatus.INDEXING, DocumentSegment.indexing_at: naive_utc_now(), } ) @@ -70,7 +71,7 @@ def create_segment_to_index_task(segment_id: str, keywords: list[str] | None = N if ( not dataset_document.enabled or dataset_document.archived - or dataset_document.indexing_status != "completed" + or dataset_document.indexing_status != IndexingStatus.COMPLETED ): logger.info(click.style(f"Segment {segment.id} document status is invalid, pass.", fg="cyan")) return @@ -82,7 +83,7 @@ def create_segment_to_index_task(segment_id: str, keywords: list[str] | None = N # update segment to completed session.query(DocumentSegment).filter_by(id=segment.id).update( { - DocumentSegment.status: "completed", + DocumentSegment.status: SegmentStatus.COMPLETED, DocumentSegment.completed_at: naive_utc_now(), } ) @@ -94,7 +95,7 @@ def create_segment_to_index_task(segment_id: str, keywords: list[str] | None = N logger.exception("create segment to index failed") segment.enabled = False segment.disabled_at = naive_utc_now() - segment.status = "error" + segment.status = SegmentStatus.ERROR segment.error = str(e) session.commit() finally: diff --git a/api/tasks/document_indexing_sync_task.py b/api/tasks/document_indexing_sync_task.py index fddd9199d1..f99e90062f 100644 --- a/api/tasks/document_indexing_sync_task.py +++ b/api/tasks/document_indexing_sync_task.py @@ -12,6 +12,7 @@ from core.rag.extractor.notion_extractor import NotionExtractor from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from libs.datetime_utils import naive_utc_now from models.dataset import Dataset, Document, DocumentSegment +from models.enums import IndexingStatus from services.datasource_provider_service import DatasourceProviderService logger = logging.getLogger(__name__) @@ -37,7 +38,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): logger.info(click.style(f"Document not found: {document_id}", fg="red")) return - if document.indexing_status == "parsing": + if document.indexing_status == IndexingStatus.PARSING: logger.info(click.style(f"Document {document_id} is already being processed, skipping", fg="yellow")) return @@ -88,7 +89,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): with session_factory.create_session() as session, session.begin(): document = session.query(Document).filter_by(id=document_id).first() if document: - document.indexing_status = "error" + document.indexing_status = IndexingStatus.ERROR document.error = "Datasource credential not found. Please reconnect your Notion workspace." document.stopped_at = naive_utc_now() return @@ -128,7 +129,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): data_source_info["last_edited_time"] = last_edited_time document.data_source_info = json.dumps(data_source_info) - document.indexing_status = "parsing" + document.indexing_status = IndexingStatus.PARSING document.processing_started_at = naive_utc_now() segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.document_id == document_id) @@ -151,6 +152,6 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): with session_factory.create_session() as session, session.begin(): document = session.query(Document).filter_by(id=document_id).first() if document: - document.indexing_status = "error" + document.indexing_status = IndexingStatus.ERROR document.error = str(e) document.stopped_at = naive_utc_now() diff --git a/api/tasks/document_indexing_task.py b/api/tasks/document_indexing_task.py index b3f36d8f44..e05d63426c 100644 --- a/api/tasks/document_indexing_task.py +++ b/api/tasks/document_indexing_task.py @@ -14,6 +14,7 @@ from core.rag.pipeline.queue import TenantIsolatedTaskQueue from enums.cloud_plan import CloudPlan from libs.datetime_utils import naive_utc_now from models.dataset import Dataset, Document +from models.enums import IndexingStatus from services.feature_service import FeatureService from tasks.generate_summary_index_task import generate_summary_index_task @@ -81,7 +82,7 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]): session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first() ) if document: - document.indexing_status = "error" + document.indexing_status = IndexingStatus.ERROR document.error = str(e) document.stopped_at = naive_utc_now() session.add(document) @@ -96,7 +97,7 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]): for document in documents: if document: - document.indexing_status = "parsing" + document.indexing_status = IndexingStatus.PARSING document.processing_started_at = naive_utc_now() session.add(document) # Transaction committed and closed @@ -148,7 +149,7 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]): document.need_summary, ) if ( - document.indexing_status == "completed" + document.indexing_status == IndexingStatus.COMPLETED and document.doc_form != "qa_model" and document.need_summary is True ): diff --git a/api/tasks/document_indexing_update_task.py b/api/tasks/document_indexing_update_task.py index c7508c6d05..62bce24de4 100644 --- a/api/tasks/document_indexing_update_task.py +++ b/api/tasks/document_indexing_update_task.py @@ -10,6 +10,7 @@ from core.indexing_runner import DocumentIsPausedError, IndexingRunner from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from libs.datetime_utils import naive_utc_now from models.dataset import Dataset, Document, DocumentSegment +from models.enums import IndexingStatus logger = logging.getLogger(__name__) @@ -33,7 +34,7 @@ def document_indexing_update_task(dataset_id: str, document_id: str): logger.info(click.style(f"Document not found: {document_id}", fg="red")) return - document.indexing_status = "parsing" + document.indexing_status = IndexingStatus.PARSING document.processing_started_at = naive_utc_now() dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() diff --git a/api/tasks/duplicate_document_indexing_task.py b/api/tasks/duplicate_document_indexing_task.py index 00a963255b..13c651753f 100644 --- a/api/tasks/duplicate_document_indexing_task.py +++ b/api/tasks/duplicate_document_indexing_task.py @@ -15,6 +15,7 @@ from core.rag.pipeline.queue import TenantIsolatedTaskQueue from enums.cloud_plan import CloudPlan from libs.datetime_utils import naive_utc_now from models.dataset import Dataset, Document, DocumentSegment +from models.enums import IndexingStatus from services.feature_service import FeatureService logger = logging.getLogger(__name__) @@ -112,7 +113,7 @@ def _duplicate_document_indexing_task(dataset_id: str, document_ids: Sequence[st ) for document in documents: if document: - document.indexing_status = "error" + document.indexing_status = IndexingStatus.ERROR document.error = str(e) document.stopped_at = naive_utc_now() session.add(document) @@ -146,7 +147,7 @@ def _duplicate_document_indexing_task(dataset_id: str, document_ids: Sequence[st session.execute(segment_delete_stmt) session.commit() - document.indexing_status = "parsing" + document.indexing_status = IndexingStatus.PARSING document.processing_started_at = naive_utc_now() session.add(document) session.commit() diff --git a/api/tasks/enable_segment_to_index_task.py b/api/tasks/enable_segment_to_index_task.py index 41ebb0b076..5ad17d75d4 100644 --- a/api/tasks/enable_segment_to_index_task.py +++ b/api/tasks/enable_segment_to_index_task.py @@ -12,6 +12,7 @@ from core.rag.models.document import AttachmentDocument, ChildDocument, Document from extensions.ext_redis import redis_client from libs.datetime_utils import naive_utc_now from models.dataset import DocumentSegment +from models.enums import IndexingStatus, SegmentStatus logger = logging.getLogger(__name__) @@ -33,7 +34,7 @@ def enable_segment_to_index_task(segment_id: str): logger.info(click.style(f"Segment not found: {segment_id}", fg="red")) return - if segment.status != "completed": + if segment.status != SegmentStatus.COMPLETED: logger.info(click.style(f"Segment is not completed, enable is not allowed: {segment_id}", fg="red")) return @@ -65,7 +66,7 @@ def enable_segment_to_index_task(segment_id: str): if ( not dataset_document.enabled or dataset_document.archived - or dataset_document.indexing_status != "completed" + or dataset_document.indexing_status != IndexingStatus.COMPLETED ): logger.info(click.style(f"Segment {segment.id} document status is invalid, pass.", fg="cyan")) return @@ -123,7 +124,7 @@ def enable_segment_to_index_task(segment_id: str): logger.exception("enable segment to index failed") segment.enabled = False segment.disabled_at = naive_utc_now() - segment.status = "error" + segment.status = SegmentStatus.ERROR segment.error = str(e) session.commit() finally: diff --git a/api/tasks/retry_document_indexing_task.py b/api/tasks/retry_document_indexing_task.py index f20b15ac83..4fcb0cf804 100644 --- a/api/tasks/retry_document_indexing_task.py +++ b/api/tasks/retry_document_indexing_task.py @@ -12,6 +12,7 @@ from extensions.ext_redis import redis_client from libs.datetime_utils import naive_utc_now from models import Account, Tenant from models.dataset import Dataset, Document, DocumentSegment +from models.enums import IndexingStatus from services.feature_service import FeatureService from services.rag_pipeline.rag_pipeline import RagPipelineService @@ -63,7 +64,7 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str], user_ .first() ) if document: - document.indexing_status = "error" + document.indexing_status = IndexingStatus.ERROR document.error = str(e) document.stopped_at = naive_utc_now() session.add(document) @@ -95,7 +96,7 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str], user_ session.execute(segment_delete_stmt) session.commit() - document.indexing_status = "parsing" + document.indexing_status = IndexingStatus.PARSING document.processing_started_at = naive_utc_now() session.add(document) session.commit() @@ -108,7 +109,7 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str], user_ indexing_runner.run([document]) redis_client.delete(retry_indexing_cache_key) except Exception as ex: - document.indexing_status = "error" + document.indexing_status = IndexingStatus.ERROR document.error = str(ex) document.stopped_at = naive_utc_now() session.add(document) diff --git a/api/tasks/sync_website_document_indexing_task.py b/api/tasks/sync_website_document_indexing_task.py index f1c8c56995..aa6bce958b 100644 --- a/api/tasks/sync_website_document_indexing_task.py +++ b/api/tasks/sync_website_document_indexing_task.py @@ -11,6 +11,7 @@ from core.rag.index_processor.index_processor_factory import IndexProcessorFacto from extensions.ext_redis import redis_client from libs.datetime_utils import naive_utc_now from models.dataset import Dataset, Document, DocumentSegment +from models.enums import IndexingStatus from services.feature_service import FeatureService logger = logging.getLogger(__name__) @@ -48,7 +49,7 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str): session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first() ) if document: - document.indexing_status = "error" + document.indexing_status = IndexingStatus.ERROR document.error = str(e) document.stopped_at = naive_utc_now() session.add(document) @@ -76,7 +77,7 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str): session.execute(segment_delete_stmt) session.commit() - document.indexing_status = "parsing" + document.indexing_status = IndexingStatus.PARSING document.processing_started_at = naive_utc_now() session.add(document) session.commit() @@ -85,7 +86,7 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str): indexing_runner.run([document]) redis_client.delete(sync_indexing_cache_key) except Exception as ex: - document.indexing_status = "error" + document.indexing_status = IndexingStatus.ERROR document.error = str(ex) document.stopped_at = naive_utc_now() session.add(document) diff --git a/api/tests/test_containers_integration_tests/core/rag/retrieval/test_dataset_retrieval_integration.py b/api/tests/test_containers_integration_tests/core/rag/retrieval/test_dataset_retrieval_integration.py index 75471afef8..781e297fa4 100644 --- a/api/tests/test_containers_integration_tests/core/rag/retrieval/test_dataset_retrieval_integration.py +++ b/api/tests/test_containers_integration_tests/core/rag/retrieval/test_dataset_retrieval_integration.py @@ -7,6 +7,7 @@ from faker import Faker from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.workflow.nodes.knowledge_retrieval.retrieval import KnowledgeRetrievalRequest from models.dataset import Dataset, Document +from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus from services.account_service import AccountService, TenantService from tests.test_containers_integration_tests.helpers import generate_valid_password @@ -35,7 +36,7 @@ class TestGetAvailableDatasetsIntegration: name=fake.company(), description=fake.text(max_nb_chars=100), provider="dify", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=account.id, indexing_technique="high_quality", ) @@ -49,14 +50,14 @@ class TestGetAvailableDatasetsIntegration: tenant_id=tenant.id, dataset_id=dataset.id, position=i, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch=str(uuid.uuid4()), # Required field name=f"Document {i}", - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_form="text_model", doc_language="en", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, ) @@ -94,7 +95,7 @@ class TestGetAvailableDatasetsIntegration: tenant_id=tenant.id, name=fake.company(), provider="dify", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -106,13 +107,13 @@ class TestGetAvailableDatasetsIntegration: tenant_id=tenant.id, dataset_id=dataset.id, position=i, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch=str(uuid.uuid4()), # Required field - created_from="web", + created_from=DocumentCreatedFrom.WEB, name=f"Archived Document {i}", created_by=account.id, doc_form="text_model", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=True, # Archived ) @@ -147,7 +148,7 @@ class TestGetAvailableDatasetsIntegration: tenant_id=tenant.id, name=fake.company(), provider="dify", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -159,13 +160,13 @@ class TestGetAvailableDatasetsIntegration: tenant_id=tenant.id, dataset_id=dataset.id, position=i, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch=str(uuid.uuid4()), # Required field - created_from="web", + created_from=DocumentCreatedFrom.WEB, name=f"Disabled Document {i}", created_by=account.id, doc_form="text_model", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=False, # Disabled archived=False, ) @@ -200,21 +201,21 @@ class TestGetAvailableDatasetsIntegration: tenant_id=tenant.id, name=fake.company(), provider="dify", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=account.id, ) db_session_with_containers.add(dataset) # Create documents with non-completed status - for i, status in enumerate(["indexing", "parsing", "splitting"]): + for i, status in enumerate([IndexingStatus.INDEXING, IndexingStatus.PARSING, IndexingStatus.SPLITTING]): document = Document( id=str(uuid.uuid4()), tenant_id=tenant.id, dataset_id=dataset.id, position=i, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch=str(uuid.uuid4()), # Required field - created_from="web", + created_from=DocumentCreatedFrom.WEB, name=f"Document {status}", created_by=account.id, doc_form="text_model", @@ -263,7 +264,7 @@ class TestGetAvailableDatasetsIntegration: tenant_id=tenant.id, name=fake.company(), provider="external", # External provider - data_source_type="external", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -307,7 +308,7 @@ class TestGetAvailableDatasetsIntegration: tenant_id=tenant1.id, name="Tenant 1 Dataset", provider="dify", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=account1.id, ) db_session_with_containers.add(dataset1) @@ -318,7 +319,7 @@ class TestGetAvailableDatasetsIntegration: tenant_id=tenant2.id, name="Tenant 2 Dataset", provider="dify", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=account2.id, ) db_session_with_containers.add(dataset2) @@ -330,13 +331,13 @@ class TestGetAvailableDatasetsIntegration: tenant_id=dataset.tenant_id, dataset_id=dataset.id, position=0, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch=str(uuid.uuid4()), # Required field - created_from="web", + created_from=DocumentCreatedFrom.WEB, name=f"Document for {dataset.name}", created_by=account.id, doc_form="text_model", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, ) @@ -398,7 +399,7 @@ class TestGetAvailableDatasetsIntegration: tenant_id=tenant.id, name=f"Dataset {i}", provider="dify", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -410,13 +411,13 @@ class TestGetAvailableDatasetsIntegration: tenant_id=tenant.id, dataset_id=dataset.id, position=0, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch=str(uuid.uuid4()), # Required field - created_from="web", + created_from=DocumentCreatedFrom.WEB, name=f"Document {i}", created_by=account.id, doc_form="text_model", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, ) @@ -456,7 +457,7 @@ class TestKnowledgeRetrievalIntegration: tenant_id=tenant.id, name=fake.company(), provider="dify", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=account.id, indexing_technique="high_quality", ) @@ -467,12 +468,12 @@ class TestKnowledgeRetrievalIntegration: tenant_id=tenant.id, dataset_id=dataset.id, position=0, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch=str(uuid.uuid4()), # Required field - created_from="web", + created_from=DocumentCreatedFrom.WEB, name=fake.sentence(), created_by=account.id, - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, doc_form="text_model", @@ -525,7 +526,7 @@ class TestKnowledgeRetrievalIntegration: tenant_id=tenant.id, name=fake.company(), provider="dify", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -572,7 +573,7 @@ class TestKnowledgeRetrievalIntegration: tenant_id=tenant.id, name=fake.company(), provider="dify", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=account.id, ) db_session_with_containers.add(dataset) diff --git a/api/tests/test_containers_integration_tests/models/test_dataset_models.py b/api/tests/test_containers_integration_tests/models/test_dataset_models.py index 6c541a8ad2..a3bbf19657 100644 --- a/api/tests/test_containers_integration_tests/models/test_dataset_models.py +++ b/api/tests/test_containers_integration_tests/models/test_dataset_models.py @@ -12,6 +12,7 @@ import pytest from sqlalchemy.orm import Session from models.dataset import Dataset, Document, DocumentSegment +from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus class TestDatasetDocumentProperties: @@ -29,7 +30,7 @@ class TestDatasetDocumentProperties: created_by = str(uuid4()) dataset = Dataset( - tenant_id=tenant_id, name="Test Dataset", data_source_type="upload_file", created_by=created_by + tenant_id=tenant_id, name="Test Dataset", data_source_type=DataSourceType.UPLOAD_FILE, created_by=created_by ) db_session_with_containers.add(dataset) db_session_with_containers.flush() @@ -39,10 +40,10 @@ class TestDatasetDocumentProperties: tenant_id=tenant_id, dataset_id=dataset.id, position=i + 1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="batch_001", name=f"doc_{i}.pdf", - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=created_by, ) db_session_with_containers.add(doc) @@ -56,7 +57,7 @@ class TestDatasetDocumentProperties: created_by = str(uuid4()) dataset = Dataset( - tenant_id=tenant_id, name="Test Dataset", data_source_type="upload_file", created_by=created_by + tenant_id=tenant_id, name="Test Dataset", data_source_type=DataSourceType.UPLOAD_FILE, created_by=created_by ) db_session_with_containers.add(dataset) db_session_with_containers.flush() @@ -65,12 +66,12 @@ class TestDatasetDocumentProperties: tenant_id=tenant_id, dataset_id=dataset.id, position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="batch_001", name="available.pdf", - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=created_by, - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, ) @@ -78,12 +79,12 @@ class TestDatasetDocumentProperties: tenant_id=tenant_id, dataset_id=dataset.id, position=2, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="batch_001", name="pending.pdf", - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=created_by, - indexing_status="waiting", + indexing_status=IndexingStatus.WAITING, enabled=True, archived=False, ) @@ -91,12 +92,12 @@ class TestDatasetDocumentProperties: tenant_id=tenant_id, dataset_id=dataset.id, position=3, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="batch_001", name="disabled.pdf", - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=created_by, - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=False, archived=False, ) @@ -111,7 +112,7 @@ class TestDatasetDocumentProperties: created_by = str(uuid4()) dataset = Dataset( - tenant_id=tenant_id, name="Test Dataset", data_source_type="upload_file", created_by=created_by + tenant_id=tenant_id, name="Test Dataset", data_source_type=DataSourceType.UPLOAD_FILE, created_by=created_by ) db_session_with_containers.add(dataset) db_session_with_containers.flush() @@ -121,10 +122,10 @@ class TestDatasetDocumentProperties: tenant_id=tenant_id, dataset_id=dataset.id, position=i + 1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="batch_001", name=f"doc_{i}.pdf", - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=created_by, word_count=wc, ) @@ -139,7 +140,7 @@ class TestDatasetDocumentProperties: created_by = str(uuid4()) dataset = Dataset( - tenant_id=tenant_id, name="Test Dataset", data_source_type="upload_file", created_by=created_by + tenant_id=tenant_id, name="Test Dataset", data_source_type=DataSourceType.UPLOAD_FILE, created_by=created_by ) db_session_with_containers.add(dataset) db_session_with_containers.flush() @@ -148,10 +149,10 @@ class TestDatasetDocumentProperties: tenant_id=tenant_id, dataset_id=dataset.id, position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="batch_001", name="doc.pdf", - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=created_by, ) db_session_with_containers.add(doc) @@ -166,7 +167,7 @@ class TestDatasetDocumentProperties: content=f"segment {i}", word_count=100, tokens=50, - status="completed", + status=SegmentStatus.COMPLETED, enabled=True, created_by=created_by, ) @@ -180,7 +181,7 @@ class TestDatasetDocumentProperties: content="waiting segment", word_count=100, tokens=50, - status="waiting", + status=SegmentStatus.WAITING, enabled=True, created_by=created_by, ) @@ -195,7 +196,7 @@ class TestDatasetDocumentProperties: created_by = str(uuid4()) dataset = Dataset( - tenant_id=tenant_id, name="Test Dataset", data_source_type="upload_file", created_by=created_by + tenant_id=tenant_id, name="Test Dataset", data_source_type=DataSourceType.UPLOAD_FILE, created_by=created_by ) db_session_with_containers.add(dataset) db_session_with_containers.flush() @@ -204,10 +205,10 @@ class TestDatasetDocumentProperties: tenant_id=tenant_id, dataset_id=dataset.id, position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="batch_001", name="doc.pdf", - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=created_by, ) db_session_with_containers.add(doc) @@ -235,7 +236,7 @@ class TestDatasetDocumentProperties: created_by = str(uuid4()) dataset = Dataset( - tenant_id=tenant_id, name="Test Dataset", data_source_type="upload_file", created_by=created_by + tenant_id=tenant_id, name="Test Dataset", data_source_type=DataSourceType.UPLOAD_FILE, created_by=created_by ) db_session_with_containers.add(dataset) db_session_with_containers.flush() @@ -244,10 +245,10 @@ class TestDatasetDocumentProperties: tenant_id=tenant_id, dataset_id=dataset.id, position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="batch_001", name="doc.pdf", - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=created_by, ) db_session_with_containers.add(doc) @@ -288,7 +289,7 @@ class TestDocumentSegmentNavigationProperties: dataset = Dataset( tenant_id=tenant_id, name="Test Dataset", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=created_by, ) db_session_with_containers.add(dataset) @@ -298,10 +299,10 @@ class TestDocumentSegmentNavigationProperties: tenant_id=tenant_id, dataset_id=dataset.id, position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="batch_001", name="test.pdf", - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=created_by, ) db_session_with_containers.add(document) @@ -335,7 +336,7 @@ class TestDocumentSegmentNavigationProperties: dataset = Dataset( tenant_id=tenant_id, name="Test Dataset", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=created_by, ) db_session_with_containers.add(dataset) @@ -345,10 +346,10 @@ class TestDocumentSegmentNavigationProperties: tenant_id=tenant_id, dataset_id=dataset.id, position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="batch_001", name="test.pdf", - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=created_by, ) db_session_with_containers.add(document) @@ -382,7 +383,7 @@ class TestDocumentSegmentNavigationProperties: dataset = Dataset( tenant_id=tenant_id, name="Test Dataset", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=created_by, ) db_session_with_containers.add(dataset) @@ -392,10 +393,10 @@ class TestDocumentSegmentNavigationProperties: tenant_id=tenant_id, dataset_id=dataset.id, position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="batch_001", name="test.pdf", - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=created_by, ) db_session_with_containers.add(document) @@ -439,7 +440,7 @@ class TestDocumentSegmentNavigationProperties: dataset = Dataset( tenant_id=tenant_id, name="Test Dataset", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=created_by, ) db_session_with_containers.add(dataset) @@ -449,10 +450,10 @@ class TestDocumentSegmentNavigationProperties: tenant_id=tenant_id, dataset_id=dataset.id, position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="batch_001", name="test.pdf", - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=created_by, ) db_session_with_containers.add(document) diff --git a/api/tests/test_containers_integration_tests/services/dataset_collection_binding.py b/api/tests/test_containers_integration_tests/services/dataset_collection_binding.py index 191c161613..638a61c815 100644 --- a/api/tests/test_containers_integration_tests/services/dataset_collection_binding.py +++ b/api/tests/test_containers_integration_tests/services/dataset_collection_binding.py @@ -12,6 +12,7 @@ import pytest from sqlalchemy.orm import Session from models.dataset import DatasetCollectionBinding +from models.enums import CollectionBindingType from services.dataset_service import DatasetCollectionBindingService @@ -32,7 +33,7 @@ class DatasetCollectionBindingTestDataFactory: provider_name: str = "openai", model_name: str = "text-embedding-ada-002", collection_name: str = "collection-abc", - collection_type: str = "dataset", + collection_type: str = CollectionBindingType.DATASET, ) -> DatasetCollectionBinding: """ Create a DatasetCollectionBinding with specified attributes. @@ -41,7 +42,7 @@ class DatasetCollectionBindingTestDataFactory: provider_name: Name of the embedding model provider (e.g., "openai", "cohere") model_name: Name of the embedding model (e.g., "text-embedding-ada-002") collection_name: Name of the vector database collection - collection_type: Type of collection (default: "dataset") + collection_type: Type of collection (default: CollectionBindingType.DATASET) Returns: DatasetCollectionBinding instance @@ -76,7 +77,7 @@ class TestDatasetCollectionBindingServiceGetBinding: # Arrange provider_name = "openai" model_name = "text-embedding-ada-002" - collection_type = "dataset" + collection_type = CollectionBindingType.DATASET existing_binding = DatasetCollectionBindingTestDataFactory.create_collection_binding( db_session_with_containers, provider_name=provider_name, @@ -104,7 +105,7 @@ class TestDatasetCollectionBindingServiceGetBinding: # Arrange provider_name = f"provider-{uuid4()}" model_name = f"model-{uuid4()}" - collection_type = "dataset" + collection_type = CollectionBindingType.DATASET # Act result = DatasetCollectionBindingService.get_dataset_collection_binding( @@ -145,7 +146,7 @@ class TestDatasetCollectionBindingServiceGetBinding: result = DatasetCollectionBindingService.get_dataset_collection_binding(provider_name, model_name) # Assert - assert result.type == "dataset" + assert result.type == CollectionBindingType.DATASET assert result.provider_name == provider_name assert result.model_name == model_name @@ -186,18 +187,20 @@ class TestDatasetCollectionBindingServiceGetBindingByIdAndType: provider_name="openai", model_name="text-embedding-ada-002", collection_name="test-collection", - collection_type="dataset", + collection_type=CollectionBindingType.DATASET, ) # Act - result = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(binding.id, "dataset") + result = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type( + binding.id, CollectionBindingType.DATASET + ) # Assert assert result.id == binding.id assert result.provider_name == "openai" assert result.model_name == "text-embedding-ada-002" assert result.collection_name == "test-collection" - assert result.type == "dataset" + assert result.type == CollectionBindingType.DATASET def test_get_dataset_collection_binding_by_id_and_type_not_found_error(self, db_session_with_containers: Session): """Test error handling when collection binding is not found by ID and type.""" @@ -206,7 +209,9 @@ class TestDatasetCollectionBindingServiceGetBindingByIdAndType: # Act & Assert with pytest.raises(ValueError, match="Dataset collection binding not found"): - DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(non_existent_id, "dataset") + DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type( + non_existent_id, CollectionBindingType.DATASET + ) def test_get_dataset_collection_binding_by_id_and_type_different_collection_type( self, db_session_with_containers: Session @@ -240,7 +245,7 @@ class TestDatasetCollectionBindingServiceGetBindingByIdAndType: provider_name="openai", model_name="text-embedding-ada-002", collection_name="test-collection", - collection_type="dataset", + collection_type=CollectionBindingType.DATASET, ) # Act @@ -248,7 +253,7 @@ class TestDatasetCollectionBindingServiceGetBindingByIdAndType: # Assert assert result.id == binding.id - assert result.type == "dataset" + assert result.type == CollectionBindingType.DATASET def test_get_dataset_collection_binding_by_id_and_type_wrong_type_error(self, db_session_with_containers: Session): """Test error when binding exists but with wrong collection type.""" @@ -258,7 +263,7 @@ class TestDatasetCollectionBindingServiceGetBindingByIdAndType: provider_name="openai", model_name="text-embedding-ada-002", collection_name="test-collection", - collection_type="dataset", + collection_type=CollectionBindingType.DATASET, ) # Act & Assert diff --git a/api/tests/test_containers_integration_tests/services/dataset_service_update_delete.py b/api/tests/test_containers_integration_tests/services/dataset_service_update_delete.py index 4b98bddd26..6b35f867d7 100644 --- a/api/tests/test_containers_integration_tests/services/dataset_service_update_delete.py +++ b/api/tests/test_containers_integration_tests/services/dataset_service_update_delete.py @@ -15,6 +15,7 @@ from werkzeug.exceptions import NotFound from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import AppDatasetJoin, Dataset, DatasetPermissionEnum +from models.enums import DataSourceType from models.model import App from services.dataset_service import DatasetService from services.errors.account import NoPermissionError @@ -72,7 +73,7 @@ class DatasetUpdateDeleteTestDataFactory: tenant_id=tenant_id, name=name, description="Test description", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, indexing_technique="high_quality", created_by=created_by, permission=permission, diff --git a/api/tests/test_containers_integration_tests/services/document_service_status.py b/api/tests/test_containers_integration_tests/services/document_service_status.py index c08ea2a93b..251f17dd03 100644 --- a/api/tests/test_containers_integration_tests/services/document_service_status.py +++ b/api/tests/test_containers_integration_tests/services/document_service_status.py @@ -15,7 +15,7 @@ import pytest from models import Account from models.dataset import Dataset, Document -from models.enums import CreatorUserRole +from models.enums import CreatorUserRole, DataSourceType, DocumentCreatedFrom, IndexingStatus from models.model import UploadFile from services.dataset_service import DocumentService from services.errors.document import DocumentIndexingError @@ -88,7 +88,7 @@ class DocumentStatusTestDataFactory: data_source_info=json.dumps(data_source_info or {}), batch=f"batch-{uuid4()}", name=name, - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=created_by, doc_form="text_model", ) @@ -100,7 +100,7 @@ class DocumentStatusTestDataFactory: document.paused_by = paused_by document.paused_at = paused_at document.doc_metadata = doc_metadata or {} - if indexing_status == "completed" and "completed_at" not in kwargs: + if indexing_status == IndexingStatus.COMPLETED and "completed_at" not in kwargs: document.completed_at = FIXED_TIME for key, value in kwargs.items(): @@ -139,7 +139,7 @@ class DocumentStatusTestDataFactory: dataset = Dataset( tenant_id=tenant_id, name=name, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=created_by, ) dataset.id = dataset_id @@ -291,7 +291,7 @@ class TestDocumentServicePauseDocument: db_session_with_containers, dataset_id=dataset.id, tenant_id=dataset.tenant_id, - indexing_status="waiting", + indexing_status=IndexingStatus.WAITING, is_paused=False, ) @@ -326,7 +326,7 @@ class TestDocumentServicePauseDocument: db_session_with_containers, dataset_id=dataset.id, tenant_id=dataset.tenant_id, - indexing_status="indexing", + indexing_status=IndexingStatus.INDEXING, is_paused=False, ) @@ -354,7 +354,7 @@ class TestDocumentServicePauseDocument: db_session_with_containers, dataset_id=dataset.id, tenant_id=dataset.tenant_id, - indexing_status="parsing", + indexing_status=IndexingStatus.PARSING, is_paused=False, ) @@ -383,7 +383,7 @@ class TestDocumentServicePauseDocument: db_session_with_containers, dataset_id=dataset.id, tenant_id=dataset.tenant_id, - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, is_paused=False, ) @@ -412,7 +412,7 @@ class TestDocumentServicePauseDocument: db_session_with_containers, dataset_id=dataset.id, tenant_id=dataset.tenant_id, - indexing_status="error", + indexing_status=IndexingStatus.ERROR, is_paused=False, ) @@ -487,7 +487,7 @@ class TestDocumentServiceRecoverDocument: db_session_with_containers, dataset_id=dataset.id, tenant_id=dataset.tenant_id, - indexing_status="indexing", + indexing_status=IndexingStatus.INDEXING, is_paused=True, paused_by=str(uuid4()), paused_at=paused_time, @@ -526,7 +526,7 @@ class TestDocumentServiceRecoverDocument: db_session_with_containers, dataset_id=dataset.id, tenant_id=dataset.tenant_id, - indexing_status="indexing", + indexing_status=IndexingStatus.INDEXING, is_paused=False, ) @@ -609,7 +609,7 @@ class TestDocumentServiceRetryDocument: dataset_id=dataset.id, tenant_id=dataset.tenant_id, document_id=str(uuid4()), - indexing_status="error", + indexing_status=IndexingStatus.ERROR, ) mock_document_service_dependencies["redis_client"].get.return_value = None @@ -619,7 +619,7 @@ class TestDocumentServiceRetryDocument: # Assert db_session_with_containers.refresh(document) - assert document.indexing_status == "waiting" + assert document.indexing_status == IndexingStatus.WAITING expected_cache_key = f"document_{document.id}_is_retried" mock_document_service_dependencies["redis_client"].setex.assert_called_once_with(expected_cache_key, 600, 1) @@ -646,14 +646,14 @@ class TestDocumentServiceRetryDocument: dataset_id=dataset.id, tenant_id=dataset.tenant_id, document_id=str(uuid4()), - indexing_status="error", + indexing_status=IndexingStatus.ERROR, ) document2 = DocumentStatusTestDataFactory.create_document( db_session_with_containers, dataset_id=dataset.id, tenant_id=dataset.tenant_id, document_id=str(uuid4()), - indexing_status="error", + indexing_status=IndexingStatus.ERROR, position=2, ) @@ -665,8 +665,8 @@ class TestDocumentServiceRetryDocument: # Assert db_session_with_containers.refresh(document1) db_session_with_containers.refresh(document2) - assert document1.indexing_status == "waiting" - assert document2.indexing_status == "waiting" + assert document1.indexing_status == IndexingStatus.WAITING + assert document2.indexing_status == IndexingStatus.WAITING mock_document_service_dependencies["retry_task"].delay.assert_called_once_with( dataset.id, [document1.id, document2.id], mock_document_service_dependencies["user_id"] @@ -693,7 +693,7 @@ class TestDocumentServiceRetryDocument: dataset_id=dataset.id, tenant_id=dataset.tenant_id, document_id=str(uuid4()), - indexing_status="error", + indexing_status=IndexingStatus.ERROR, ) mock_document_service_dependencies["redis_client"].get.return_value = "1" @@ -703,7 +703,7 @@ class TestDocumentServiceRetryDocument: DocumentService.retry_document(dataset.id, [document]) db_session_with_containers.refresh(document) - assert document.indexing_status == "error" + assert document.indexing_status == IndexingStatus.ERROR def test_retry_document_missing_current_user_error( self, db_session_with_containers, mock_document_service_dependencies @@ -726,7 +726,7 @@ class TestDocumentServiceRetryDocument: dataset_id=dataset.id, tenant_id=dataset.tenant_id, document_id=str(uuid4()), - indexing_status="error", + indexing_status=IndexingStatus.ERROR, ) mock_document_service_dependencies["redis_client"].get.return_value = None @@ -816,7 +816,7 @@ class TestDocumentServiceBatchUpdateDocumentStatus: tenant_id=dataset.tenant_id, document_id=str(uuid4()), enabled=False, - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, ) document2 = DocumentStatusTestDataFactory.create_document( db_session_with_containers, @@ -824,7 +824,7 @@ class TestDocumentServiceBatchUpdateDocumentStatus: tenant_id=dataset.tenant_id, document_id=str(uuid4()), enabled=False, - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, position=2, ) document_ids = [document1.id, document2.id] @@ -866,7 +866,7 @@ class TestDocumentServiceBatchUpdateDocumentStatus: tenant_id=dataset.tenant_id, document_id=str(uuid4()), enabled=True, - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, completed_at=FIXED_TIME, ) document_ids = [document.id] @@ -909,7 +909,7 @@ class TestDocumentServiceBatchUpdateDocumentStatus: document_id=str(uuid4()), archived=False, enabled=True, - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, ) document_ids = [document.id] @@ -951,7 +951,7 @@ class TestDocumentServiceBatchUpdateDocumentStatus: document_id=str(uuid4()), archived=True, enabled=True, - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, ) document_ids = [document.id] @@ -1015,7 +1015,7 @@ class TestDocumentServiceBatchUpdateDocumentStatus: dataset_id=dataset.id, tenant_id=dataset.tenant_id, document_id=str(uuid4()), - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, ) document_ids = [document.id] @@ -1098,7 +1098,7 @@ class TestDocumentServiceRenameDocument: document_id=document_id, dataset_id=dataset.id, tenant_id=tenant_id, - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, ) # Act @@ -1139,7 +1139,7 @@ class TestDocumentServiceRenameDocument: dataset_id=dataset.id, tenant_id=tenant_id, doc_metadata={"existing_key": "existing_value"}, - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, ) # Act @@ -1187,7 +1187,7 @@ class TestDocumentServiceRenameDocument: dataset_id=dataset.id, tenant_id=tenant_id, data_source_info={"upload_file_id": upload_file.id}, - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, ) # Act @@ -1277,7 +1277,7 @@ class TestDocumentServiceRenameDocument: document_id=document_id, dataset_id=dataset.id, tenant_id=str(uuid4()), - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, ) # Act & Assert diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_permission_service.py b/api/tests/test_containers_integration_tests/services/test_dataset_permission_service.py index 44525e0036..975af3d428 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_permission_service.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_permission_service.py @@ -16,6 +16,7 @@ from models.dataset import ( DatasetPermission, DatasetPermissionEnum, ) +from models.enums import DataSourceType from services.dataset_service import DatasetPermissionService, DatasetService from services.errors.account import NoPermissionError @@ -67,7 +68,7 @@ class DatasetPermissionTestDataFactory: tenant_id=tenant_id, name=name, description="desc", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, indexing_technique="high_quality", created_by=created_by, permission=permission, diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service.py b/api/tests/test_containers_integration_tests/services/test_dataset_service.py index 102c1a1eb5..ac3d9f9604 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service.py @@ -15,6 +15,7 @@ from core.rag.retrieval.retrieval_methods import RetrievalMethod from dify_graph.model_runtime.entities.model_entities import ModelType from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, DatasetPermissionEnum, Document, ExternalKnowledgeBindings, Pipeline +from models.enums import DatasetRuntimeMode, DataSourceType, DocumentCreatedFrom, IndexingStatus from services.dataset_service import DatasetService from services.entities.knowledge_entities.knowledge_entities import RerankingModel, RetrievalModel from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo, RagPipelineDatasetCreateEntity @@ -74,7 +75,7 @@ class DatasetServiceIntegrationDataFactory: tenant_id=tenant_id, name=name, description=description, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, indexing_technique=indexing_technique, created_by=created_by, provider=provider, @@ -98,13 +99,13 @@ class DatasetServiceIntegrationDataFactory: tenant_id=dataset.tenant_id, dataset_id=dataset.id, position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, data_source_info='{"upload_file_id": "upload-file-id"}', batch=str(uuid4()), name=name, - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=created_by, - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, doc_form="text_model", ) db_session_with_containers.add(document) @@ -437,7 +438,7 @@ class TestDatasetServiceCreateRagPipelineDataset: created_pipeline = db_session_with_containers.get(Pipeline, result.pipeline_id) assert created_dataset is not None assert created_dataset.name == entity.name - assert created_dataset.runtime_mode == "rag_pipeline" + assert created_dataset.runtime_mode == DatasetRuntimeMode.RAG_PIPELINE assert created_dataset.created_by == account.id assert created_dataset.permission == DatasetPermissionEnum.ONLY_ME assert created_pipeline is not None diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_batch_update_document_status.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_batch_update_document_status.py index 322b67d373..7983b1cd93 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service_batch_update_document_status.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_batch_update_document_status.py @@ -14,6 +14,7 @@ import pytest from sqlalchemy.orm import Session from models.dataset import Dataset, Document +from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus from services.dataset_service import DocumentService from services.errors.document import DocumentIndexingError @@ -42,7 +43,7 @@ class DocumentBatchUpdateIntegrationDataFactory: dataset = Dataset( tenant_id=tenant_id or str(uuid4()), name=name, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=created_by or str(uuid4()), ) if dataset_id: @@ -72,11 +73,11 @@ class DocumentBatchUpdateIntegrationDataFactory: tenant_id=dataset.tenant_id, dataset_id=dataset.id, position=position, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, data_source_info=json.dumps({"upload_file_id": str(uuid4())}), batch=f"batch-{uuid4()}", name=name, - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=created_by or str(uuid4()), doc_form="text_model", ) @@ -85,7 +86,9 @@ class DocumentBatchUpdateIntegrationDataFactory: document.archived = archived document.indexing_status = indexing_status document.completed_at = ( - completed_at if completed_at is not None else (FIXED_TIME if indexing_status == "completed" else None) + completed_at + if completed_at is not None + else (FIXED_TIME if indexing_status == IndexingStatus.COMPLETED else None) ) for key, value in kwargs.items(): @@ -243,7 +246,7 @@ class TestDatasetServiceBatchUpdateDocumentStatus: dataset=dataset, document_ids=document_ids, enabled=True, - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, ) # Act @@ -277,7 +280,7 @@ class TestDatasetServiceBatchUpdateDocumentStatus: db_session_with_containers, dataset=dataset, enabled=False, - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, completed_at=FIXED_TIME, ) @@ -306,7 +309,7 @@ class TestDatasetServiceBatchUpdateDocumentStatus: db_session_with_containers, dataset=dataset, enabled=True, - indexing_status="indexing", + indexing_status=IndexingStatus.INDEXING, completed_at=None, ) diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_delete_dataset.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_delete_dataset.py index c47e35791d..ed070527c9 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service_delete_dataset.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_delete_dataset.py @@ -5,6 +5,7 @@ from uuid import uuid4 from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document +from models.enums import DataSourceType, DocumentCreatedFrom from services.dataset_service import DatasetService @@ -58,7 +59,7 @@ class DatasetDeleteIntegrationDataFactory: dataset = Dataset( tenant_id=tenant_id, name=f"dataset-{uuid4()}", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, indexing_technique=indexing_technique, index_struct=index_struct, created_by=created_by, @@ -84,10 +85,10 @@ class DatasetDeleteIntegrationDataFactory: tenant_id=tenant_id, dataset_id=dataset_id, position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch=f"batch-{uuid4()}", name="Document", - created_from="upload_file", + created_from=DocumentCreatedFrom.WEB, created_by=created_by, doc_form=doc_form, ) diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_get_segments.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_get_segments.py index e78894fcae..c4b3a57bb2 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service_get_segments.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_get_segments.py @@ -14,6 +14,7 @@ from sqlalchemy.orm import Session from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, DatasetPermissionEnum, Document, DocumentSegment +from models.enums import DataSourceType, DocumentCreatedFrom from services.dataset_service import SegmentService @@ -62,7 +63,7 @@ class SegmentServiceTestDataFactory: tenant_id=tenant_id, name=f"Test Dataset {uuid4()}", description="Test description", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, indexing_technique="high_quality", created_by=created_by, permission=DatasetPermissionEnum.ONLY_ME, @@ -82,10 +83,10 @@ class SegmentServiceTestDataFactory: tenant_id=tenant_id, dataset_id=dataset_id, position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch=f"batch-{uuid4()}", name=f"test-doc-{uuid4()}.txt", - created_from="api", + created_from=DocumentCreatedFrom.API, created_by=created_by, ) db_session_with_containers.add(document) diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_retrieval.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_retrieval.py index 8bd994937a..3021d8984d 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service_retrieval.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_retrieval.py @@ -24,6 +24,7 @@ from models.dataset import ( DatasetProcessRule, DatasetQuery, ) +from models.enums import DatasetQuerySource, DataSourceType, ProcessRuleMode from models.model import Tag, TagBinding from services.dataset_service import DatasetService, DocumentService @@ -100,7 +101,7 @@ class DatasetRetrievalTestDataFactory: tenant_id=tenant_id, name=name, description="desc", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, indexing_technique="high_quality", created_by=created_by, permission=permission, @@ -149,7 +150,7 @@ class DatasetRetrievalTestDataFactory: dataset_query = DatasetQuery( dataset_id=dataset_id, content=content, - source="web", + source=DatasetQuerySource.APP, source_app_id=None, created_by_role="account", created_by=created_by, @@ -601,7 +602,7 @@ class TestDatasetServiceGetProcessRules: db_session_with_containers, dataset_id=dataset.id, created_by=account.id, - mode="custom", + mode=ProcessRuleMode.CUSTOM, rules=rules_data, ) diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py index ebaa3b4637..fd81948247 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py @@ -7,6 +7,7 @@ from sqlalchemy.orm import Session from dify_graph.model_runtime.entities.model_entities import ModelType from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, ExternalKnowledgeBindings +from models.enums import DataSourceType from services.dataset_service import DatasetService from services.errors.account import NoPermissionError @@ -64,7 +65,7 @@ class DatasetUpdateTestDataFactory: tenant_id=tenant_id, name=name, description=description, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, indexing_technique=indexing_technique, created_by=created_by, provider=provider, diff --git a/api/tests/test_containers_integration_tests/services/test_document_service_display_status.py b/api/tests/test_containers_integration_tests/services/test_document_service_display_status.py index 124056e10f..c6aa89c733 100644 --- a/api/tests/test_containers_integration_tests/services/test_document_service_display_status.py +++ b/api/tests/test_containers_integration_tests/services/test_document_service_display_status.py @@ -4,6 +4,7 @@ from uuid import uuid4 from sqlalchemy import select from models.dataset import Dataset, Document +from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus from services.dataset_service import DocumentService @@ -11,7 +12,7 @@ def _create_dataset(db_session_with_containers) -> Dataset: dataset = Dataset( tenant_id=str(uuid4()), name=f"dataset-{uuid4()}", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=str(uuid4()), ) dataset.id = str(uuid4()) @@ -35,11 +36,11 @@ def _create_document( tenant_id=tenant_id, dataset_id=dataset_id, position=position, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, data_source_info="{}", batch=f"batch-{uuid4()}", name=f"doc-{uuid4()}", - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=str(uuid4()), doc_form="text_model", ) @@ -48,7 +49,7 @@ def _create_document( document.enabled = enabled document.archived = archived document.is_paused = is_paused - if indexing_status == "completed": + if indexing_status == IndexingStatus.COMPLETED: document.completed_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) db_session_with_containers.add(document) @@ -62,7 +63,7 @@ def test_build_display_status_filters_available(db_session_with_containers): db_session_with_containers, dataset_id=dataset.id, tenant_id=dataset.tenant_id, - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, position=1, @@ -71,7 +72,7 @@ def test_build_display_status_filters_available(db_session_with_containers): db_session_with_containers, dataset_id=dataset.id, tenant_id=dataset.tenant_id, - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=False, archived=False, position=2, @@ -80,7 +81,7 @@ def test_build_display_status_filters_available(db_session_with_containers): db_session_with_containers, dataset_id=dataset.id, tenant_id=dataset.tenant_id, - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=True, position=3, @@ -101,14 +102,14 @@ def test_apply_display_status_filter_applies_when_status_present(db_session_with db_session_with_containers, dataset_id=dataset.id, tenant_id=dataset.tenant_id, - indexing_status="waiting", + indexing_status=IndexingStatus.WAITING, position=1, ) _create_document( db_session_with_containers, dataset_id=dataset.id, tenant_id=dataset.tenant_id, - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, position=2, ) @@ -125,14 +126,14 @@ def test_apply_display_status_filter_returns_same_when_invalid(db_session_with_c db_session_with_containers, dataset_id=dataset.id, tenant_id=dataset.tenant_id, - indexing_status="waiting", + indexing_status=IndexingStatus.WAITING, position=1, ) doc2 = _create_document( db_session_with_containers, dataset_id=dataset.id, tenant_id=dataset.tenant_id, - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, position=2, ) diff --git a/api/tests/test_containers_integration_tests/services/test_document_service_rename_document.py b/api/tests/test_containers_integration_tests/services/test_document_service_rename_document.py index f641da6576..b159af0090 100644 --- a/api/tests/test_containers_integration_tests/services/test_document_service_rename_document.py +++ b/api/tests/test_containers_integration_tests/services/test_document_service_rename_document.py @@ -9,7 +9,7 @@ import pytest from models import Account from models.dataset import Dataset, Document -from models.enums import CreatorUserRole +from models.enums import CreatorUserRole, DataSourceType, DocumentCreatedFrom from models.model import UploadFile from services.dataset_service import DocumentService @@ -33,7 +33,7 @@ def make_dataset(db_session_with_containers, dataset_id=None, tenant_id=None, bu dataset = Dataset( tenant_id=tenant_id, name=f"dataset-{uuid4()}", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=str(uuid4()), ) dataset.id = dataset_id @@ -62,11 +62,11 @@ def make_document( tenant_id=tenant_id, dataset_id=dataset_id, position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, data_source_info=json.dumps(data_source_info or {}), batch=f"batch-{uuid4()}", name=name, - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=str(uuid4()), doc_form="text_model", ) diff --git a/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py b/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py index 6fe40c0744..ef1f31d36b 100644 --- a/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py +++ b/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py @@ -11,6 +11,7 @@ from sqlalchemy.orm import Session from enums.cloud_plan import CloudPlan from extensions.ext_redis import redis_client from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models.enums import DataSourceType from models.model import ( App, AppAnnotationHitHistory, @@ -287,7 +288,7 @@ class TestMessagesCleanServiceIntegration: dataset_name="Test dataset", document_id=str(uuid.uuid4()), document_name="Test document", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, segment_id=str(uuid.uuid4()), score=0.9, content="Test content", diff --git a/api/tests/test_containers_integration_tests/services/test_metadata_service.py b/api/tests/test_containers_integration_tests/services/test_metadata_service.py index 694dc1c1b9..e847329c5b 100644 --- a/api/tests/test_containers_integration_tests/services/test_metadata_service.py +++ b/api/tests/test_containers_integration_tests/services/test_metadata_service.py @@ -7,6 +7,7 @@ from sqlalchemy.orm import Session from core.rag.index_processor.constant.built_in_field import BuiltInField from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, DatasetMetadata, DatasetMetadataBinding, Document +from models.enums import DatasetMetadataType, DataSourceType, DocumentCreatedFrom from services.entities.knowledge_entities.knowledge_entities import MetadataArgs from services.metadata_service import MetadataService @@ -101,7 +102,7 @@ class TestMetadataService: tenant_id=tenant.id, name=fake.company(), description=fake.text(max_nb_chars=100), - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=account.id, built_in_field_enabled=False, ) @@ -132,11 +133,11 @@ class TestMetadataService: tenant_id=dataset.tenant_id, dataset_id=dataset.id, position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, data_source_info="{}", batch="test-batch", name=fake.file_name(), - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_form="text", doc_language="en", @@ -163,7 +164,7 @@ class TestMetadataService: mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id mock_external_service_dependencies["current_user"].id = account.id - metadata_args = MetadataArgs(type="string", name="test_metadata") + metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="test_metadata") # Act: Execute the method under test result = MetadataService.create_metadata(dataset.id, metadata_args) @@ -201,7 +202,7 @@ class TestMetadataService: mock_external_service_dependencies["current_user"].id = account.id long_name = "a" * 256 # 256 characters, exceeding 255 limit - metadata_args = MetadataArgs(type="string", name=long_name) + metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name=long_name) # Act & Assert: Verify proper error handling with pytest.raises(ValueError, match="Metadata name cannot exceed 255 characters."): @@ -226,11 +227,11 @@ class TestMetadataService: mock_external_service_dependencies["current_user"].id = account.id # Create first metadata - first_metadata_args = MetadataArgs(type="string", name="duplicate_name") + first_metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="duplicate_name") MetadataService.create_metadata(dataset.id, first_metadata_args) # Try to create second metadata with same name - second_metadata_args = MetadataArgs(type="number", name="duplicate_name") + second_metadata_args = MetadataArgs(type=DatasetMetadataType.NUMBER, name="duplicate_name") # Act & Assert: Verify proper error handling with pytest.raises(ValueError, match="Metadata name already exists."): @@ -256,7 +257,7 @@ class TestMetadataService: # Try to create metadata with built-in field name built_in_field_name = BuiltInField.document_name - metadata_args = MetadataArgs(type="string", name=built_in_field_name) + metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name=built_in_field_name) # Act & Assert: Verify proper error handling with pytest.raises(ValueError, match="Metadata name already exists in Built-in fields."): @@ -281,7 +282,7 @@ class TestMetadataService: mock_external_service_dependencies["current_user"].id = account.id # Create metadata first - metadata_args = MetadataArgs(type="string", name="old_name") + metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="old_name") metadata = MetadataService.create_metadata(dataset.id, metadata_args) # Act: Execute the method under test @@ -318,7 +319,7 @@ class TestMetadataService: mock_external_service_dependencies["current_user"].id = account.id # Create metadata first - metadata_args = MetadataArgs(type="string", name="old_name") + metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="old_name") metadata = MetadataService.create_metadata(dataset.id, metadata_args) # Try to update with too long name @@ -347,10 +348,10 @@ class TestMetadataService: mock_external_service_dependencies["current_user"].id = account.id # Create two metadata entries - first_metadata_args = MetadataArgs(type="string", name="first_metadata") + first_metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="first_metadata") first_metadata = MetadataService.create_metadata(dataset.id, first_metadata_args) - second_metadata_args = MetadataArgs(type="number", name="second_metadata") + second_metadata_args = MetadataArgs(type=DatasetMetadataType.NUMBER, name="second_metadata") second_metadata = MetadataService.create_metadata(dataset.id, second_metadata_args) # Try to update first metadata with second metadata's name @@ -376,7 +377,7 @@ class TestMetadataService: mock_external_service_dependencies["current_user"].id = account.id # Create metadata first - metadata_args = MetadataArgs(type="string", name="old_name") + metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="old_name") metadata = MetadataService.create_metadata(dataset.id, metadata_args) # Try to update with built-in field name @@ -432,7 +433,7 @@ class TestMetadataService: mock_external_service_dependencies["current_user"].id = account.id # Create metadata first - metadata_args = MetadataArgs(type="string", name="to_be_deleted") + metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="to_be_deleted") metadata = MetadataService.create_metadata(dataset.id, metadata_args) # Act: Execute the method under test @@ -496,7 +497,7 @@ class TestMetadataService: mock_external_service_dependencies["current_user"].id = account.id # Create metadata - metadata_args = MetadataArgs(type="string", name="test_metadata") + metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="test_metadata") metadata = MetadataService.create_metadata(dataset.id, metadata_args) # Create metadata binding @@ -798,7 +799,7 @@ class TestMetadataService: mock_external_service_dependencies["current_user"].id = account.id # Create metadata - metadata_args = MetadataArgs(type="string", name="test_metadata") + metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="test_metadata") metadata = MetadataService.create_metadata(dataset.id, metadata_args) # Mock DocumentService.get_document @@ -866,7 +867,7 @@ class TestMetadataService: mock_external_service_dependencies["current_user"].id = account.id # Create metadata - metadata_args = MetadataArgs(type="string", name="test_metadata") + metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="test_metadata") metadata = MetadataService.create_metadata(dataset.id, metadata_args) # Mock DocumentService.get_document @@ -917,7 +918,7 @@ class TestMetadataService: mock_external_service_dependencies["current_user"].id = account.id # Create metadata - metadata_args = MetadataArgs(type="string", name="test_metadata") + metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="test_metadata") metadata = MetadataService.create_metadata(dataset.id, metadata_args) # Create metadata operation data @@ -1038,7 +1039,7 @@ class TestMetadataService: mock_external_service_dependencies["current_user"].id = account.id # Create metadata - metadata_args = MetadataArgs(type="string", name="test_metadata") + metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="test_metadata") metadata = MetadataService.create_metadata(dataset.id, metadata_args) # Create document and metadata binding @@ -1101,7 +1102,7 @@ class TestMetadataService: mock_external_service_dependencies["current_user"].id = account.id # Create metadata - metadata_args = MetadataArgs(type="string", name="test_metadata") + metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="test_metadata") metadata = MetadataService.create_metadata(dataset.id, metadata_args) # Act: Execute the method under test diff --git a/api/tests/test_containers_integration_tests/services/test_tag_service.py b/api/tests/test_containers_integration_tests/services/test_tag_service.py index 597ba6b75b..fa6e651529 100644 --- a/api/tests/test_containers_integration_tests/services/test_tag_service.py +++ b/api/tests/test_containers_integration_tests/services/test_tag_service.py @@ -9,6 +9,7 @@ from werkzeug.exceptions import NotFound from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset +from models.enums import DataSourceType from models.model import App, Tag, TagBinding from services.tag_service import TagService @@ -100,7 +101,7 @@ class TestTagService: description=fake.text(max_nb_chars=100), provider="vendor", permission="only_me", - data_source_type="upload", + data_source_type=DataSourceType.UPLOAD_FILE, indexing_technique="high_quality", tenant_id=tenant_id, created_by=mock_external_service_dependencies["current_user"].id, diff --git a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py index 8c007877fd..c3fe6a2950 100644 --- a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py +++ b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py @@ -510,7 +510,7 @@ class TestWorkflowConverter: retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE, top_k=10, score_threshold=0.8, - reranking_model={"provider": "cohere", "model": "rerank-v2"}, + reranking_model={"reranking_provider_name": "cohere", "reranking_model_name": "rerank-v2"}, reranking_enabled=True, ), ) @@ -543,8 +543,8 @@ class TestWorkflowConverter: multiple_config = node["data"]["multiple_retrieval_config"] assert multiple_config["top_k"] == 10 assert multiple_config["score_threshold"] == 0.8 - assert multiple_config["reranking_model"]["provider"] == "cohere" - assert multiple_config["reranking_model"]["model"] == "rerank-v2" + assert multiple_config["reranking_model"]["reranking_provider_name"] == "cohere" + assert multiple_config["reranking_model"]["reranking_model_name"] == "rerank-v2" # Verify single retrieval config is None for multiple strategy assert node["data"]["single_retrieval_config"] is None diff --git a/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py index efeb29cf20..94173c34bf 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py @@ -8,6 +8,7 @@ from core.rag.index_processor.constant.index_type import IndexStructureType from extensions.ext_redis import redis_client from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, DatasetAutoDisableLog, Document, DocumentSegment +from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus from tasks.add_document_to_index_task import add_document_to_index_task @@ -79,7 +80,7 @@ class TestAddDocumentToIndexTask: tenant_id=tenant.id, name=fake.company(), description=fake.text(max_nb_chars=100), - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, indexing_technique="high_quality", created_by=account.id, ) @@ -92,12 +93,12 @@ class TestAddDocumentToIndexTask: tenant_id=tenant.id, dataset_id=dataset.id, position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="test_batch", name=fake.file_name(), - created_from="upload_file", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, doc_form=IndexStructureType.PARAGRAPH_INDEX, ) @@ -137,7 +138,7 @@ class TestAddDocumentToIndexTask: index_node_id=f"node_{i}", index_node_hash=f"hash_{i}", enabled=False, - status="completed", + status=SegmentStatus.COMPLETED, created_by=document.created_by, ) db_session_with_containers.add(segment) @@ -297,7 +298,7 @@ class TestAddDocumentToIndexTask: ) # Set invalid indexing status - document.indexing_status = "processing" + document.indexing_status = IndexingStatus.INDEXING db_session_with_containers.commit() # Act: Execute the task @@ -339,7 +340,7 @@ class TestAddDocumentToIndexTask: # Assert: Verify error handling db_session_with_containers.refresh(document) assert document.enabled is False - assert document.indexing_status == "error" + assert document.indexing_status == IndexingStatus.ERROR assert document.error is not None assert "doesn't exist" in document.error assert document.disabled_at is not None @@ -434,7 +435,7 @@ class TestAddDocumentToIndexTask: Test document indexing when segments are already enabled. This test verifies: - - Segments with status="completed" are processed regardless of enabled status + - Segments with status=SegmentStatus.COMPLETED are processed regardless of enabled status - Index processing occurs with all completed segments - Auto disable log deletion still occurs - Redis cache is cleared @@ -460,7 +461,7 @@ class TestAddDocumentToIndexTask: index_node_id=f"node_{i}", index_node_hash=f"hash_{i}", enabled=True, # Already enabled - status="completed", + status=SegmentStatus.COMPLETED, created_by=document.created_by, ) db_session_with_containers.add(segment) @@ -482,7 +483,7 @@ class TestAddDocumentToIndexTask: mock_external_service_dependencies["index_processor"].load.assert_called_once() # Verify the load method was called with all completed segments - # (implementation doesn't filter by enabled status, only by status="completed") + # (implementation doesn't filter by enabled status, only by status=SegmentStatus.COMPLETED) call_args = mock_external_service_dependencies["index_processor"].load.call_args assert call_args is not None documents = call_args[0][1] # Second argument should be documents list @@ -594,7 +595,7 @@ class TestAddDocumentToIndexTask: # Assert: Verify error handling db_session_with_containers.refresh(document) assert document.enabled is False - assert document.indexing_status == "error" + assert document.indexing_status == IndexingStatus.ERROR assert document.error is not None assert "Index processing failed" in document.error assert document.disabled_at is not None @@ -614,7 +615,7 @@ class TestAddDocumentToIndexTask: Test segment filtering with various edge cases. This test verifies: - - Only segments with status="completed" are processed (regardless of enabled status) + - Only segments with status=SegmentStatus.COMPLETED are processed (regardless of enabled status) - Segments with status!="completed" are NOT processed - Segments are ordered by position correctly - Mixed segment states are handled properly @@ -630,7 +631,7 @@ class TestAddDocumentToIndexTask: fake = Faker() segments = [] - # Segment 1: Should be processed (enabled=False, status="completed") + # Segment 1: Should be processed (enabled=False, status=SegmentStatus.COMPLETED) segment1 = DocumentSegment( id=fake.uuid4(), tenant_id=document.tenant_id, @@ -643,14 +644,14 @@ class TestAddDocumentToIndexTask: index_node_id="node_0", index_node_hash="hash_0", enabled=False, - status="completed", + status=SegmentStatus.COMPLETED, created_by=document.created_by, ) db_session_with_containers.add(segment1) segments.append(segment1) - # Segment 2: Should be processed (enabled=True, status="completed") - # Note: Implementation doesn't filter by enabled status, only by status="completed" + # Segment 2: Should be processed (enabled=True, status=SegmentStatus.COMPLETED) + # Note: Implementation doesn't filter by enabled status, only by status=SegmentStatus.COMPLETED segment2 = DocumentSegment( id=fake.uuid4(), tenant_id=document.tenant_id, @@ -663,7 +664,7 @@ class TestAddDocumentToIndexTask: index_node_id="node_1", index_node_hash="hash_1", enabled=True, # Already enabled, but will still be processed - status="completed", + status=SegmentStatus.COMPLETED, created_by=document.created_by, ) db_session_with_containers.add(segment2) @@ -682,13 +683,13 @@ class TestAddDocumentToIndexTask: index_node_id="node_2", index_node_hash="hash_2", enabled=False, - status="processing", # Not completed + status=SegmentStatus.INDEXING, # Not completed created_by=document.created_by, ) db_session_with_containers.add(segment3) segments.append(segment3) - # Segment 4: Should be processed (enabled=False, status="completed") + # Segment 4: Should be processed (enabled=False, status=SegmentStatus.COMPLETED) segment4 = DocumentSegment( id=fake.uuid4(), tenant_id=document.tenant_id, @@ -701,7 +702,7 @@ class TestAddDocumentToIndexTask: index_node_id="node_3", index_node_hash="hash_3", enabled=False, - status="completed", + status=SegmentStatus.COMPLETED, created_by=document.created_by, ) db_session_with_containers.add(segment4) @@ -726,7 +727,7 @@ class TestAddDocumentToIndexTask: call_args = mock_external_service_dependencies["index_processor"].load.call_args assert call_args is not None documents = call_args[0][1] # Second argument should be documents list - assert len(documents) == 3 # 3 segments with status="completed" should be processed + assert len(documents) == 3 # 3 segments with status=SegmentStatus.COMPLETED should be processed # Verify correct segments were processed (by position order) # Segments 1, 2, 4 should be processed (positions 0, 1, 3) @@ -799,7 +800,7 @@ class TestAddDocumentToIndexTask: # Assert: Verify consistent error handling db_session_with_containers.refresh(document) assert document.enabled is False, f"Document should be disabled for {error_name}" - assert document.indexing_status == "error", f"Document status should be error for {error_name}" + assert document.indexing_status == IndexingStatus.ERROR, f"Document status should be error for {error_name}" assert document.error is not None, f"Error should be recorded for {error_name}" assert str(exception) in document.error, f"Error message should contain exception for {error_name}" assert document.disabled_at is not None, f"Disabled timestamp should be set for {error_name}" diff --git a/api/tests/test_containers_integration_tests/tasks/test_batch_clean_document_task.py b/api/tests/test_containers_integration_tests/tasks/test_batch_clean_document_task.py index ec789418a8..6adefd59be 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_batch_clean_document_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_batch_clean_document_task.py @@ -16,6 +16,7 @@ from sqlalchemy.orm import Session from libs.datetime_utils import naive_utc_now from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment +from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus from models.model import UploadFile from tasks.batch_clean_document_task import batch_clean_document_task @@ -113,7 +114,7 @@ class TestBatchCleanDocumentTask: tenant_id=account.current_tenant.id, name=fake.word(), description=fake.sentence(), - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=account.id, embedding_model="text-embedding-ada-002", embedding_model_provider="openai", @@ -144,12 +145,12 @@ class TestBatchCleanDocumentTask: dataset_id=dataset.id, position=0, name=fake.word(), - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, data_source_info=json.dumps({"upload_file_id": str(uuid.uuid4())}), batch="test_batch", - created_from="test", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, doc_form="text_model", ) @@ -183,7 +184,7 @@ class TestBatchCleanDocumentTask: tokens=50, index_node_id=str(uuid.uuid4()), created_by=account.id, - status="completed", + status=SegmentStatus.COMPLETED, ) db_session_with_containers.add(segment) @@ -297,7 +298,7 @@ class TestBatchCleanDocumentTask: tokens=50, index_node_id=str(uuid.uuid4()), created_by=account.id, - status="completed", + status=SegmentStatus.COMPLETED, ) db_session_with_containers.add(segment) @@ -671,7 +672,7 @@ class TestBatchCleanDocumentTask: tokens=25 + i * 5, index_node_id=str(uuid.uuid4()), created_by=account.id, - status="completed", + status=SegmentStatus.COMPLETED, ) segments.append(segment) diff --git a/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py index a2324979db..ebe5ff1d96 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py @@ -21,7 +21,7 @@ from sqlalchemy.orm import Session from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment -from models.enums import CreatorUserRole +from models.enums import CreatorUserRole, DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus from models.model import UploadFile from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task @@ -139,7 +139,7 @@ class TestBatchCreateSegmentToIndexTask: tenant_id=tenant.id, name=fake.company(), description=fake.text(), - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, indexing_technique="high_quality", embedding_model="text-embedding-ada-002", embedding_model_provider="openai", @@ -170,12 +170,12 @@ class TestBatchCreateSegmentToIndexTask: tenant_id=tenant.id, dataset_id=dataset.id, position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="test_batch", name=fake.file_name(), - created_from="upload_file", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, doc_form="text_model", @@ -301,7 +301,7 @@ class TestBatchCreateSegmentToIndexTask: assert segment.dataset_id == dataset.id assert segment.document_id == document.id assert segment.position == i + 1 - assert segment.status == "completed" + assert segment.status == SegmentStatus.COMPLETED assert segment.indexing_at is not None assert segment.completed_at is not None assert segment.answer is None # text_model doesn't have answers @@ -442,12 +442,12 @@ class TestBatchCreateSegmentToIndexTask: tenant_id=tenant.id, dataset_id=dataset.id, position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="test_batch", name="disabled_document", - created_from="upload_file", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=False, # Document is disabled archived=False, doc_form="text_model", @@ -458,12 +458,12 @@ class TestBatchCreateSegmentToIndexTask: tenant_id=tenant.id, dataset_id=dataset.id, position=2, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="test_batch", name="archived_document", - created_from="upload_file", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=True, # Document is archived doc_form="text_model", @@ -474,12 +474,12 @@ class TestBatchCreateSegmentToIndexTask: tenant_id=tenant.id, dataset_id=dataset.id, position=3, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="test_batch", name="incomplete_document", - created_from="upload_file", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, - indexing_status="indexing", # Not completed + indexing_status=IndexingStatus.INDEXING, # Not completed enabled=True, archived=False, doc_form="text_model", @@ -643,7 +643,7 @@ class TestBatchCreateSegmentToIndexTask: word_count=len(f"Existing segment {i + 1}"), tokens=10, created_by=account.id, - status="completed", + status=SegmentStatus.COMPLETED, index_node_id=str(uuid.uuid4()), index_node_hash=f"hash_{i}", ) @@ -694,7 +694,7 @@ class TestBatchCreateSegmentToIndexTask: for i, segment in enumerate(new_segments): expected_position = 4 + i # Should start at position 4 assert segment.position == expected_position - assert segment.status == "completed" + assert segment.status == SegmentStatus.COMPLETED assert segment.indexing_at is not None assert segment.completed_at is not None diff --git a/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py b/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py index 41d9fc8a29..638752cf8b 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py @@ -29,7 +29,14 @@ from models.dataset import ( Document, DocumentSegment, ) -from models.enums import CreatorUserRole +from models.enums import ( + CreatorUserRole, + DatasetMetadataType, + DataSourceType, + DocumentCreatedFrom, + IndexingStatus, + SegmentStatus, +) from models.model import UploadFile from tasks.clean_dataset_task import clean_dataset_task @@ -176,12 +183,12 @@ class TestCleanDatasetTask: tenant_id=tenant.id, dataset_id=dataset.id, position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="test_batch", name="test_document", - created_from="upload_file", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, doc_form="paragraph_index", @@ -219,7 +226,7 @@ class TestCleanDatasetTask: word_count=20, tokens=30, created_by=account.id, - status="completed", + status=SegmentStatus.COMPLETED, index_node_id=str(uuid.uuid4()), index_node_hash="test_hash", created_at=datetime.now(), @@ -373,7 +380,7 @@ class TestCleanDatasetTask: dataset_id=dataset.id, tenant_id=tenant.id, name="test_metadata", - type="string", + type=DatasetMetadataType.STRING, created_by=account.id, ) metadata.id = str(uuid.uuid4()) @@ -587,7 +594,7 @@ class TestCleanDatasetTask: word_count=len(segment_content), tokens=50, created_by=account.id, - status="completed", + status=SegmentStatus.COMPLETED, index_node_id=str(uuid.uuid4()), index_node_hash="test_hash", created_at=datetime.now(), @@ -686,7 +693,7 @@ class TestCleanDatasetTask: dataset_id=dataset.id, tenant_id=tenant.id, name=f"test_metadata_{i}", - type="string", + type=DatasetMetadataType.STRING, created_by=account.id, ) metadata.id = str(uuid.uuid4()) @@ -880,11 +887,11 @@ class TestCleanDatasetTask: tenant_id=tenant.id, dataset_id=dataset.id, position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, data_source_info="{}", batch="test_batch", name=f"test_doc_{special_content}", - created_from="test", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, created_at=datetime.now(), updated_at=datetime.now(), @@ -905,7 +912,7 @@ class TestCleanDatasetTask: word_count=len(segment_content.split()), tokens=len(segment_content) // 4, # Rough token estimation created_by=account.id, - status="completed", + status=SegmentStatus.COMPLETED, index_node_id=str(uuid.uuid4()), index_node_hash="test_hash_" + "x" * 50, # Long hash within limits created_at=datetime.now(), @@ -946,7 +953,7 @@ class TestCleanDatasetTask: dataset_id=dataset.id, tenant_id=tenant.id, name=f"metadata_{special_content}", - type="string", + type=DatasetMetadataType.STRING, created_by=account.id, ) special_metadata.id = str(uuid.uuid4()) diff --git a/api/tests/test_containers_integration_tests/tasks/test_clean_notion_document_task.py b/api/tests/test_containers_integration_tests/tasks/test_clean_notion_document_task.py index 3ce199c602..a2a190fd69 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_clean_notion_document_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_clean_notion_document_task.py @@ -13,6 +13,7 @@ import pytest from faker import Faker from models.dataset import Dataset, Document, DocumentSegment +from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus from services.account_service import AccountService, TenantService from tasks.clean_notion_document_task import clean_notion_document_task from tests.test_containers_integration_tests.helpers import generate_valid_password @@ -88,7 +89,7 @@ class TestCleanNotionDocumentTask: tenant_id=tenant.id, name=fake.company(), description=fake.text(max_nb_chars=100), - data_source_type="notion_import", + data_source_type=DataSourceType.NOTION_IMPORT, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -105,17 +106,17 @@ class TestCleanNotionDocumentTask: tenant_id=tenant.id, dataset_id=dataset.id, position=i, - data_source_type="notion_import", + data_source_type=DataSourceType.NOTION_IMPORT, data_source_info=json.dumps( {"notion_workspace_id": f"workspace_{i}", "notion_page_id": f"page_{i}", "type": "page"} ), batch="test_batch", name=f"Notion Page {i}", - created_from="notion_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_form="text_model", # Set doc_form to ensure dataset.doc_form works doc_language="en", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, ) db_session_with_containers.add(document) db_session_with_containers.flush() @@ -134,7 +135,7 @@ class TestCleanNotionDocumentTask: tokens=50, index_node_id=f"node_{i}_{j}", created_by=account.id, - status="completed", + status=SegmentStatus.COMPLETED, ) db_session_with_containers.add(segment) segments.append(segment) @@ -220,7 +221,7 @@ class TestCleanNotionDocumentTask: tenant_id=tenant.id, name=fake.company(), description=fake.text(max_nb_chars=100), - data_source_type="notion_import", + data_source_type=DataSourceType.NOTION_IMPORT, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -269,7 +270,7 @@ class TestCleanNotionDocumentTask: tenant_id=tenant.id, name=f"{fake.company()}_{index_type}", description=fake.text(max_nb_chars=100), - data_source_type="notion_import", + data_source_type=DataSourceType.NOTION_IMPORT, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -281,17 +282,17 @@ class TestCleanNotionDocumentTask: tenant_id=tenant.id, dataset_id=dataset.id, position=0, - data_source_type="notion_import", + data_source_type=DataSourceType.NOTION_IMPORT, data_source_info=json.dumps( {"notion_workspace_id": "workspace_test", "notion_page_id": "page_test", "type": "page"} ), batch="test_batch", name="Test Notion Page", - created_from="notion_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_form=index_type, doc_language="en", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, ) db_session_with_containers.add(document) db_session_with_containers.flush() @@ -308,7 +309,7 @@ class TestCleanNotionDocumentTask: tokens=50, index_node_id="test_node", created_by=account.id, - status="completed", + status=SegmentStatus.COMPLETED, ) db_session_with_containers.add(segment) db_session_with_containers.commit() @@ -357,7 +358,7 @@ class TestCleanNotionDocumentTask: tenant_id=tenant.id, name=fake.company(), description=fake.text(max_nb_chars=100), - data_source_type="notion_import", + data_source_type=DataSourceType.NOTION_IMPORT, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -369,16 +370,16 @@ class TestCleanNotionDocumentTask: tenant_id=tenant.id, dataset_id=dataset.id, position=0, - data_source_type="notion_import", + data_source_type=DataSourceType.NOTION_IMPORT, data_source_info=json.dumps( {"notion_workspace_id": "workspace_test", "notion_page_id": "page_test", "type": "page"} ), batch="test_batch", name="Test Notion Page", - created_from="notion_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_language="en", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, ) db_session_with_containers.add(document) db_session_with_containers.flush() @@ -397,7 +398,7 @@ class TestCleanNotionDocumentTask: tokens=50, index_node_id=None, # No index node ID created_by=account.id, - status="completed", + status=SegmentStatus.COMPLETED, ) db_session_with_containers.add(segment) segments.append(segment) @@ -443,7 +444,7 @@ class TestCleanNotionDocumentTask: tenant_id=tenant.id, name=fake.company(), description=fake.text(max_nb_chars=100), - data_source_type="notion_import", + data_source_type=DataSourceType.NOTION_IMPORT, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -460,16 +461,16 @@ class TestCleanNotionDocumentTask: tenant_id=tenant.id, dataset_id=dataset.id, position=i, - data_source_type="notion_import", + data_source_type=DataSourceType.NOTION_IMPORT, data_source_info=json.dumps( {"notion_workspace_id": f"workspace_{i}", "notion_page_id": f"page_{i}", "type": "page"} ), batch="test_batch", name=f"Notion Page {i}", - created_from="notion_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_language="en", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, ) db_session_with_containers.add(document) db_session_with_containers.flush() @@ -488,7 +489,7 @@ class TestCleanNotionDocumentTask: tokens=50, index_node_id=f"node_{i}_{j}", created_by=account.id, - status="completed", + status=SegmentStatus.COMPLETED, ) db_session_with_containers.add(segment) all_segments.append(segment) @@ -558,7 +559,7 @@ class TestCleanNotionDocumentTask: tenant_id=tenant.id, name=fake.company(), description=fake.text(max_nb_chars=100), - data_source_type="notion_import", + data_source_type=DataSourceType.NOTION_IMPORT, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -570,22 +571,22 @@ class TestCleanNotionDocumentTask: tenant_id=tenant.id, dataset_id=dataset.id, position=0, - data_source_type="notion_import", + data_source_type=DataSourceType.NOTION_IMPORT, data_source_info=json.dumps( {"notion_workspace_id": "workspace_test", "notion_page_id": "page_test", "type": "page"} ), batch="test_batch", name="Test Notion Page", - created_from="notion_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_language="en", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, ) db_session_with_containers.add(document) db_session_with_containers.flush() # Create segments with different statuses - segment_statuses = ["waiting", "processing", "completed", "error"] + segment_statuses = [SegmentStatus.WAITING, SegmentStatus.INDEXING, SegmentStatus.COMPLETED, SegmentStatus.ERROR] segments = [] index_node_ids = [] @@ -654,7 +655,7 @@ class TestCleanNotionDocumentTask: tenant_id=tenant.id, name=fake.company(), description=fake.text(max_nb_chars=100), - data_source_type="notion_import", + data_source_type=DataSourceType.NOTION_IMPORT, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -666,16 +667,16 @@ class TestCleanNotionDocumentTask: tenant_id=tenant.id, dataset_id=dataset.id, position=0, - data_source_type="notion_import", + data_source_type=DataSourceType.NOTION_IMPORT, data_source_info=json.dumps( {"notion_workspace_id": "workspace_test", "notion_page_id": "page_test", "type": "page"} ), batch="test_batch", name="Test Notion Page", - created_from="notion_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_language="en", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, ) db_session_with_containers.add(document) db_session_with_containers.flush() @@ -692,7 +693,7 @@ class TestCleanNotionDocumentTask: tokens=50, index_node_id="test_node", created_by=account.id, - status="completed", + status=SegmentStatus.COMPLETED, ) db_session_with_containers.add(segment) db_session_with_containers.commit() @@ -736,7 +737,7 @@ class TestCleanNotionDocumentTask: tenant_id=tenant.id, name=fake.company(), description=fake.text(max_nb_chars=100), - data_source_type="notion_import", + data_source_type=DataSourceType.NOTION_IMPORT, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -754,16 +755,16 @@ class TestCleanNotionDocumentTask: tenant_id=tenant.id, dataset_id=dataset.id, position=i, - data_source_type="notion_import", + data_source_type=DataSourceType.NOTION_IMPORT, data_source_info=json.dumps( {"notion_workspace_id": f"workspace_{i}", "notion_page_id": f"page_{i}", "type": "page"} ), batch="test_batch", name=f"Notion Page {i}", - created_from="notion_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_language="en", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, ) db_session_with_containers.add(document) db_session_with_containers.flush() @@ -783,7 +784,7 @@ class TestCleanNotionDocumentTask: tokens=50, index_node_id=f"node_{i}_{j}", created_by=account.id, - status="completed", + status=SegmentStatus.COMPLETED, ) db_session_with_containers.add(segment) all_segments.append(segment) @@ -848,7 +849,7 @@ class TestCleanNotionDocumentTask: tenant_id=tenant.id, name=f"{fake.company()}_{i}", description=fake.text(max_nb_chars=100), - data_source_type="notion_import", + data_source_type=DataSourceType.NOTION_IMPORT, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -866,16 +867,16 @@ class TestCleanNotionDocumentTask: tenant_id=account.current_tenant.id, dataset_id=dataset.id, position=0, - data_source_type="notion_import", + data_source_type=DataSourceType.NOTION_IMPORT, data_source_info=json.dumps( {"notion_workspace_id": f"workspace_{i}", "notion_page_id": f"page_{i}", "type": "page"} ), batch="test_batch", name=f"Notion Page {i}", - created_from="notion_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_language="en", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, ) db_session_with_containers.add(document) db_session_with_containers.flush() @@ -894,7 +895,7 @@ class TestCleanNotionDocumentTask: tokens=50, index_node_id=f"node_{i}_{j}", created_by=account.id, - status="completed", + status=SegmentStatus.COMPLETED, ) db_session_with_containers.add(segment) all_segments.append(segment) @@ -963,14 +964,22 @@ class TestCleanNotionDocumentTask: tenant_id=tenant.id, name=fake.company(), description=fake.text(max_nb_chars=100), - data_source_type="notion_import", + data_source_type=DataSourceType.NOTION_IMPORT, created_by=account.id, ) db_session_with_containers.add(dataset) db_session_with_containers.flush() # Create documents with different indexing statuses - document_statuses = ["waiting", "parsing", "cleaning", "splitting", "indexing", "completed", "error"] + document_statuses = [ + IndexingStatus.WAITING, + IndexingStatus.PARSING, + IndexingStatus.CLEANING, + IndexingStatus.SPLITTING, + IndexingStatus.INDEXING, + IndexingStatus.COMPLETED, + IndexingStatus.ERROR, + ] documents = [] all_segments = [] all_index_node_ids = [] @@ -981,13 +990,13 @@ class TestCleanNotionDocumentTask: tenant_id=tenant.id, dataset_id=dataset.id, position=i, - data_source_type="notion_import", + data_source_type=DataSourceType.NOTION_IMPORT, data_source_info=json.dumps( {"notion_workspace_id": f"workspace_{i}", "notion_page_id": f"page_{i}", "type": "page"} ), batch="test_batch", name=f"Notion Page {i}", - created_from="notion_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_language="en", indexing_status=status, @@ -1009,7 +1018,7 @@ class TestCleanNotionDocumentTask: tokens=50, index_node_id=f"node_{i}_{j}", created_by=account.id, - status="completed", + status=SegmentStatus.COMPLETED, ) db_session_with_containers.add(segment) all_segments.append(segment) @@ -1066,7 +1075,7 @@ class TestCleanNotionDocumentTask: tenant_id=tenant.id, name=fake.company(), description=fake.text(max_nb_chars=100), - data_source_type="notion_import", + data_source_type=DataSourceType.NOTION_IMPORT, created_by=account.id, built_in_field_enabled=True, ) @@ -1079,7 +1088,7 @@ class TestCleanNotionDocumentTask: tenant_id=tenant.id, dataset_id=dataset.id, position=0, - data_source_type="notion_import", + data_source_type=DataSourceType.NOTION_IMPORT, data_source_info=json.dumps( { "notion_workspace_id": "workspace_test", @@ -1091,10 +1100,10 @@ class TestCleanNotionDocumentTask: ), batch="test_batch", name="Test Notion Page with Metadata", - created_from="notion_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_language="en", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, doc_metadata={ "document_name": "Test Notion Page with Metadata", "uploader": account.name, @@ -1122,7 +1131,7 @@ class TestCleanNotionDocumentTask: tokens=75, index_node_id=f"node_{i}", created_by=account.id, - status="completed", + status=SegmentStatus.COMPLETED, keywords={"key1": ["value1", "value2"], "key2": ["value3"]}, ) db_session_with_containers.add(segment) diff --git a/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py index 4fa52ff2a9..132f43c320 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py @@ -15,6 +15,7 @@ from faker import Faker from extensions.ext_redis import redis_client from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment +from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus from tasks.create_segment_to_index_task import create_segment_to_index_task @@ -118,7 +119,7 @@ class TestCreateSegmentToIndexTask: name=fake.company(), description=fake.text(max_nb_chars=100), tenant_id=tenant_id, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, indexing_technique="high_quality", embedding_model_provider="openai", embedding_model="text-embedding-ada-002", @@ -133,13 +134,13 @@ class TestCreateSegmentToIndexTask: dataset_id=dataset.id, tenant_id=tenant_id, position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="test_batch", - created_from="upload_file", + created_from=DocumentCreatedFrom.WEB, created_by=account_id, enabled=True, archived=False, - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, doc_form="qa_model", ) db_session_with_containers.add(document) @@ -148,7 +149,7 @@ class TestCreateSegmentToIndexTask: return dataset, document def _create_test_segment( - self, db_session_with_containers, dataset_id, document_id, tenant_id, account_id, status="waiting" + self, db_session_with_containers, dataset_id, document_id, tenant_id, account_id, status=SegmentStatus.WAITING ): """ Helper method to create a test document segment for testing. @@ -200,7 +201,7 @@ class TestCreateSegmentToIndexTask: account, tenant = self._create_test_account_and_tenant(db_session_with_containers) dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) segment = self._create_test_segment( - db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.WAITING ) # Act: Execute the task @@ -208,7 +209,7 @@ class TestCreateSegmentToIndexTask: # Assert: Verify segment status changes db_session_with_containers.refresh(segment) - assert segment.status == "completed" + assert segment.status == SegmentStatus.COMPLETED assert segment.indexing_at is not None assert segment.completed_at is not None assert segment.error is None @@ -257,7 +258,7 @@ class TestCreateSegmentToIndexTask: account, tenant = self._create_test_account_and_tenant(db_session_with_containers) dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) segment = self._create_test_segment( - db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="completed" + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.COMPLETED ) # Act: Execute the task @@ -268,7 +269,7 @@ class TestCreateSegmentToIndexTask: # Verify segment status unchanged db_session_with_containers.refresh(segment) - assert segment.status == "completed" + assert segment.status == SegmentStatus.COMPLETED assert segment.indexing_at is None # Verify no index processor calls were made @@ -293,20 +294,25 @@ class TestCreateSegmentToIndexTask: dataset_id=invalid_dataset_id, tenant_id=tenant.id, position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="test_batch", - created_from="upload_file", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, enabled=True, archived=False, - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, doc_form="text_model", ) db_session_with_containers.add(document) db_session_with_containers.commit() segment = self._create_test_segment( - db_session_with_containers, invalid_dataset_id, document.id, tenant.id, account.id, status="waiting" + db_session_with_containers, + invalid_dataset_id, + document.id, + tenant.id, + account.id, + status=SegmentStatus.WAITING, ) # Act: Execute the task @@ -317,7 +323,7 @@ class TestCreateSegmentToIndexTask: # Verify segment status changed to indexing (task updates status before checking document) db_session_with_containers.refresh(segment) - assert segment.status == "indexing" + assert segment.status == SegmentStatus.INDEXING # Verify no index processor calls were made mock_external_service_dependencies["index_processor_factory"].assert_not_called() @@ -337,7 +343,12 @@ class TestCreateSegmentToIndexTask: invalid_document_id = str(uuid4()) segment = self._create_test_segment( - db_session_with_containers, dataset.id, invalid_document_id, tenant.id, account.id, status="waiting" + db_session_with_containers, + dataset.id, + invalid_document_id, + tenant.id, + account.id, + status=SegmentStatus.WAITING, ) # Act: Execute the task @@ -348,7 +359,7 @@ class TestCreateSegmentToIndexTask: # Verify segment status changed to indexing (task updates status before checking document) db_session_with_containers.refresh(segment) - assert segment.status == "indexing" + assert segment.status == SegmentStatus.INDEXING # Verify no index processor calls were made mock_external_service_dependencies["index_processor_factory"].assert_not_called() @@ -373,7 +384,7 @@ class TestCreateSegmentToIndexTask: db_session_with_containers.commit() segment = self._create_test_segment( - db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.WAITING ) # Act: Execute the task @@ -384,7 +395,7 @@ class TestCreateSegmentToIndexTask: # Verify segment status changed to indexing (task updates status before checking document) db_session_with_containers.refresh(segment) - assert segment.status == "indexing" + assert segment.status == SegmentStatus.INDEXING # Verify no index processor calls were made mock_external_service_dependencies["index_processor_factory"].assert_not_called() @@ -409,7 +420,7 @@ class TestCreateSegmentToIndexTask: db_session_with_containers.commit() segment = self._create_test_segment( - db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.WAITING ) # Act: Execute the task @@ -420,7 +431,7 @@ class TestCreateSegmentToIndexTask: # Verify segment status changed to indexing (task updates status before checking document) db_session_with_containers.refresh(segment) - assert segment.status == "indexing" + assert segment.status == SegmentStatus.INDEXING # Verify no index processor calls were made mock_external_service_dependencies["index_processor_factory"].assert_not_called() @@ -445,7 +456,7 @@ class TestCreateSegmentToIndexTask: db_session_with_containers.commit() segment = self._create_test_segment( - db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.WAITING ) # Act: Execute the task @@ -456,7 +467,7 @@ class TestCreateSegmentToIndexTask: # Verify segment status changed to indexing (task updates status before checking document) db_session_with_containers.refresh(segment) - assert segment.status == "indexing" + assert segment.status == SegmentStatus.INDEXING # Verify no index processor calls were made mock_external_service_dependencies["index_processor_factory"].assert_not_called() @@ -477,7 +488,7 @@ class TestCreateSegmentToIndexTask: account, tenant = self._create_test_account_and_tenant(db_session_with_containers) dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) segment = self._create_test_segment( - db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.WAITING ) # Mock processor to raise exception @@ -488,7 +499,7 @@ class TestCreateSegmentToIndexTask: # Assert: Verify error handling db_session_with_containers.refresh(segment) - assert segment.status == "error" + assert segment.status == SegmentStatus.ERROR assert segment.enabled is False assert segment.disabled_at is not None assert segment.error == "Processor failed" @@ -512,7 +523,7 @@ class TestCreateSegmentToIndexTask: account, tenant = self._create_test_account_and_tenant(db_session_with_containers) dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) segment = self._create_test_segment( - db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.WAITING ) custom_keywords = ["custom", "keywords", "test"] @@ -521,7 +532,7 @@ class TestCreateSegmentToIndexTask: # Assert: Verify successful indexing db_session_with_containers.refresh(segment) - assert segment.status == "completed" + assert segment.status == SegmentStatus.COMPLETED assert segment.indexing_at is not None assert segment.completed_at is not None @@ -555,7 +566,7 @@ class TestCreateSegmentToIndexTask: db_session_with_containers.commit() segment = self._create_test_segment( - db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.WAITING ) # Act: Execute the task @@ -563,7 +574,7 @@ class TestCreateSegmentToIndexTask: # Assert: Verify successful indexing db_session_with_containers.refresh(segment) - assert segment.status == "completed" + assert segment.status == SegmentStatus.COMPLETED # Verify correct doc_form was passed to factory mock_external_service_dependencies["index_processor_factory"].assert_called_with(doc_form) @@ -583,7 +594,7 @@ class TestCreateSegmentToIndexTask: account, tenant = self._create_test_account_and_tenant(db_session_with_containers) dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) segment = self._create_test_segment( - db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.WAITING ) # Act: Execute the task and measure time @@ -597,7 +608,7 @@ class TestCreateSegmentToIndexTask: # Verify successful completion db_session_with_containers.refresh(segment) - assert segment.status == "completed" + assert segment.status == SegmentStatus.COMPLETED def test_create_segment_to_index_concurrent_execution( self, db_session_with_containers, mock_external_service_dependencies @@ -617,7 +628,7 @@ class TestCreateSegmentToIndexTask: segments = [] for i in range(3): segment = self._create_test_segment( - db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.WAITING ) segments.append(segment) @@ -629,7 +640,7 @@ class TestCreateSegmentToIndexTask: # Assert: Verify all segments processed for segment in segments: db_session_with_containers.refresh(segment) - assert segment.status == "completed" + assert segment.status == SegmentStatus.COMPLETED assert segment.indexing_at is not None assert segment.completed_at is not None @@ -665,7 +676,7 @@ class TestCreateSegmentToIndexTask: keywords=["large", "content", "test"], index_node_id=str(uuid4()), index_node_hash=str(uuid4()), - status="waiting", + status=SegmentStatus.WAITING, created_by=account.id, ) db_session_with_containers.add(segment) @@ -681,7 +692,7 @@ class TestCreateSegmentToIndexTask: assert execution_time < 10.0 # Should complete within 10 seconds db_session_with_containers.refresh(segment) - assert segment.status == "completed" + assert segment.status == SegmentStatus.COMPLETED assert segment.indexing_at is not None assert segment.completed_at is not None @@ -700,7 +711,7 @@ class TestCreateSegmentToIndexTask: account, tenant = self._create_test_account_and_tenant(db_session_with_containers) dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) segment = self._create_test_segment( - db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.WAITING ) # Set up Redis cache key to simulate indexing in progress @@ -718,7 +729,7 @@ class TestCreateSegmentToIndexTask: # Assert: Verify indexing still completed successfully despite Redis failure db_session_with_containers.refresh(segment) - assert segment.status == "completed" + assert segment.status == SegmentStatus.COMPLETED assert segment.indexing_at is not None assert segment.completed_at is not None @@ -740,7 +751,7 @@ class TestCreateSegmentToIndexTask: account, tenant = self._create_test_account_and_tenant(db_session_with_containers) dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) segment = self._create_test_segment( - db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.WAITING ) # Simulate an error during indexing to trigger rollback path @@ -752,7 +763,7 @@ class TestCreateSegmentToIndexTask: # Assert: Verify error handling and rollback db_session_with_containers.refresh(segment) - assert segment.status == "error" + assert segment.status == SegmentStatus.ERROR assert segment.enabled is False assert segment.disabled_at is not None assert segment.error is not None @@ -772,7 +783,7 @@ class TestCreateSegmentToIndexTask: account, tenant = self._create_test_account_and_tenant(db_session_with_containers) dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) segment = self._create_test_segment( - db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.WAITING ) # Act: Execute the task @@ -780,7 +791,7 @@ class TestCreateSegmentToIndexTask: # Assert: Verify successful indexing db_session_with_containers.refresh(segment) - assert segment.status == "completed" + assert segment.status == SegmentStatus.COMPLETED # Verify index processor was called with correct metadata mock_processor = mock_external_service_dependencies["index_processor"] @@ -814,11 +825,11 @@ class TestCreateSegmentToIndexTask: account, tenant = self._create_test_account_and_tenant(db_session_with_containers) dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) segment = self._create_test_segment( - db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.WAITING ) # Verify initial state - assert segment.status == "waiting" + assert segment.status == SegmentStatus.WAITING assert segment.indexing_at is None assert segment.completed_at is None @@ -827,7 +838,7 @@ class TestCreateSegmentToIndexTask: # Assert: Verify final state db_session_with_containers.refresh(segment) - assert segment.status == "completed" + assert segment.status == SegmentStatus.COMPLETED assert segment.indexing_at is not None assert segment.completed_at is not None @@ -861,7 +872,7 @@ class TestCreateSegmentToIndexTask: keywords=[], index_node_id=str(uuid4()), index_node_hash=str(uuid4()), - status="waiting", + status=SegmentStatus.WAITING, created_by=account.id, ) db_session_with_containers.add(segment) @@ -872,7 +883,7 @@ class TestCreateSegmentToIndexTask: # Assert: Verify successful indexing db_session_with_containers.refresh(segment) - assert segment.status == "completed" + assert segment.status == SegmentStatus.COMPLETED assert segment.indexing_at is not None assert segment.completed_at is not None @@ -907,7 +918,7 @@ class TestCreateSegmentToIndexTask: keywords=["special", "unicode", "test"], index_node_id=str(uuid4()), index_node_hash=str(uuid4()), - status="waiting", + status=SegmentStatus.WAITING, created_by=account.id, ) db_session_with_containers.add(segment) @@ -918,7 +929,7 @@ class TestCreateSegmentToIndexTask: # Assert: Verify successful indexing db_session_with_containers.refresh(segment) - assert segment.status == "completed" + assert segment.status == SegmentStatus.COMPLETED assert segment.indexing_at is not None assert segment.completed_at is not None @@ -937,7 +948,7 @@ class TestCreateSegmentToIndexTask: account, tenant = self._create_test_account_and_tenant(db_session_with_containers) dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) segment = self._create_test_segment( - db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.WAITING ) # Create long keyword list @@ -948,7 +959,7 @@ class TestCreateSegmentToIndexTask: # Assert: Verify successful indexing db_session_with_containers.refresh(segment) - assert segment.status == "completed" + assert segment.status == SegmentStatus.COMPLETED assert segment.indexing_at is not None assert segment.completed_at is not None @@ -979,10 +990,10 @@ class TestCreateSegmentToIndexTask: ) segment1 = self._create_test_segment( - db_session_with_containers, dataset1.id, document1.id, tenant1.id, account1.id, status="waiting" + db_session_with_containers, dataset1.id, document1.id, tenant1.id, account1.id, status=SegmentStatus.WAITING ) segment2 = self._create_test_segment( - db_session_with_containers, dataset2.id, document2.id, tenant2.id, account2.id, status="waiting" + db_session_with_containers, dataset2.id, document2.id, tenant2.id, account2.id, status=SegmentStatus.WAITING ) # Act: Execute tasks for both tenants @@ -993,8 +1004,8 @@ class TestCreateSegmentToIndexTask: db_session_with_containers.refresh(segment1) db_session_with_containers.refresh(segment2) - assert segment1.status == "completed" - assert segment2.status == "completed" + assert segment1.status == SegmentStatus.COMPLETED + assert segment2.status == SegmentStatus.COMPLETED assert segment1.tenant_id == tenant1.id assert segment2.tenant_id == tenant2.id assert segment1.tenant_id != segment2.tenant_id @@ -1014,7 +1025,7 @@ class TestCreateSegmentToIndexTask: account, tenant = self._create_test_account_and_tenant(db_session_with_containers) dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) segment = self._create_test_segment( - db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.WAITING ) # Act: Execute the task with None keywords @@ -1022,7 +1033,7 @@ class TestCreateSegmentToIndexTask: # Assert: Verify successful indexing db_session_with_containers.refresh(segment) - assert segment.status == "completed" + assert segment.status == SegmentStatus.COMPLETED assert segment.indexing_at is not None assert segment.completed_at is not None @@ -1050,7 +1061,7 @@ class TestCreateSegmentToIndexTask: segments = [] for i in range(5): segment = self._create_test_segment( - db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.WAITING ) segments.append(segment) @@ -1067,7 +1078,7 @@ class TestCreateSegmentToIndexTask: # Verify all segments processed successfully for segment in segments: db_session_with_containers.refresh(segment) - assert segment.status == "completed" + assert segment.status == SegmentStatus.COMPLETED assert segment.indexing_at is not None assert segment.completed_at is not None assert segment.error is None diff --git a/api/tests/test_containers_integration_tests/tasks/test_dataset_indexing_task.py b/api/tests/test_containers_integration_tests/tasks/test_dataset_indexing_task.py index 4a62383590..67f9dc7011 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_dataset_indexing_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_dataset_indexing_task.py @@ -11,6 +11,7 @@ from core.indexing_runner import DocumentIsPausedError from enums.cloud_plan import CloudPlan from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document +from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus from tasks.document_indexing_task import ( _document_indexing, _document_indexing_with_tenant_queue, @@ -139,7 +140,7 @@ class TestDatasetIndexingTaskIntegration: tenant_id=tenant.id, name=fake.company(), description=fake.text(max_nb_chars=100), - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, indexing_technique="high_quality", created_by=account.id, ) @@ -155,12 +156,12 @@ class TestDatasetIndexingTaskIntegration: tenant_id=tenant.id, dataset_id=dataset.id, position=position, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="test_batch", name=f"doc-{position}.txt", - created_from="upload_file", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, - indexing_status="waiting", + indexing_status=IndexingStatus.WAITING, enabled=True, ) db_session_with_containers.add(document) @@ -181,7 +182,7 @@ class TestDatasetIndexingTaskIntegration: for document_id in document_ids: updated = self._query_document(db_session_with_containers, document_id) assert updated is not None - assert updated.indexing_status == "parsing" + assert updated.indexing_status == IndexingStatus.PARSING assert updated.processing_started_at is not None def _assert_documents_error_contains( @@ -195,7 +196,7 @@ class TestDatasetIndexingTaskIntegration: for document_id in document_ids: updated = self._query_document(db_session_with_containers, document_id) assert updated is not None - assert updated.indexing_status == "error" + assert updated.indexing_status == IndexingStatus.ERROR assert updated.error is not None assert expected_error_substring in updated.error assert updated.stopped_at is not None diff --git a/api/tests/test_containers_integration_tests/tasks/test_deal_dataset_vector_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_deal_dataset_vector_index_task.py index 10c719fb6d..e80b37ac1b 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_deal_dataset_vector_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_deal_dataset_vector_index_task.py @@ -13,6 +13,7 @@ import pytest from faker import Faker from models.dataset import Dataset, Document, DocumentSegment +from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus from services.account_service import AccountService, TenantService from tasks.deal_dataset_vector_index_task import deal_dataset_vector_index_task from tests.test_containers_integration_tests.helpers import generate_valid_password @@ -90,7 +91,7 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, name=fake.company(), description=fake.text(max_nb_chars=100), - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -102,13 +103,13 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, dataset_id=dataset.id, position=0, - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, name="Document for doc_form", - created_from="file_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_form="text_model", doc_language="en", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, batch="test_batch", @@ -150,7 +151,7 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, name=fake.company(), description=fake.text(max_nb_chars=100), - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -162,13 +163,13 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, dataset_id=dataset.id, position=0, - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, name="Document for doc_form", - created_from="file_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_form="text_model", doc_language="en", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, batch="test_batch", @@ -182,13 +183,13 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, dataset_id=dataset.id, position=0, - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, name="Test Document", - created_from="file_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_form="text_model", doc_language="en", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, batch="test_batch", @@ -209,7 +210,7 @@ class TestDealDatasetVectorIndexTask: index_node_id=f"node_{uuid.uuid4()}", index_node_hash=f"hash_{uuid.uuid4()}", created_by=account.id, - status="completed", + status=SegmentStatus.COMPLETED, enabled=True, ) db_session_with_containers.add(segment) @@ -220,7 +221,7 @@ class TestDealDatasetVectorIndexTask: # Verify document status was updated to indexing then completed updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first() - assert updated_document.indexing_status == "completed" + assert updated_document.indexing_status == IndexingStatus.COMPLETED # Verify index processor load method was called mock_factory = mock_index_processor_factory.return_value @@ -251,7 +252,7 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, name=fake.company(), description=fake.text(max_nb_chars=100), - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -263,13 +264,13 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, dataset_id=dataset.id, position=0, - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, name="Document for doc_form", - created_from="file_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_form="parent_child_index", doc_language="en", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, batch="test_batch", @@ -283,13 +284,13 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, dataset_id=dataset.id, position=0, - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, name="Test Document", - created_from="file_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_form="parent_child_index", doc_language="en", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, batch="test_batch", @@ -310,7 +311,7 @@ class TestDealDatasetVectorIndexTask: index_node_id=f"node_{uuid.uuid4()}", index_node_hash=f"hash_{uuid.uuid4()}", created_by=account.id, - status="completed", + status=SegmentStatus.COMPLETED, enabled=True, ) db_session_with_containers.add(segment) @@ -321,7 +322,7 @@ class TestDealDatasetVectorIndexTask: # Verify document status was updated to indexing then completed updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first() - assert updated_document.indexing_status == "completed" + assert updated_document.indexing_status == IndexingStatus.COMPLETED # Verify index processor clean and load methods were called mock_factory = mock_index_processor_factory.return_value @@ -367,7 +368,7 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, name=fake.company(), description=fake.text(max_nb_chars=100), - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -399,7 +400,7 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, name=fake.company(), description=fake.text(max_nb_chars=100), - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -411,13 +412,13 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, dataset_id=dataset.id, position=0, - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, name="Test Document", - created_from="file_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_form="text_model", doc_language="en", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, batch="test_batch", @@ -430,7 +431,7 @@ class TestDealDatasetVectorIndexTask: # Verify document status was updated to indexing then completed updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first() - assert updated_document.indexing_status == "completed" + assert updated_document.indexing_status == IndexingStatus.COMPLETED # Verify that no index processor load was called since no segments exist mock_factory = mock_index_processor_factory.return_value @@ -455,7 +456,7 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, name=fake.company(), description=fake.text(max_nb_chars=100), - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -488,7 +489,7 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, name=fake.company(), description=fake.text(max_nb_chars=100), - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -500,13 +501,13 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, dataset_id=dataset.id, position=0, - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, name="Document for doc_form", - created_from="file_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_form="text_model", doc_language="en", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, batch="test_batch", @@ -520,13 +521,13 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, dataset_id=dataset.id, position=0, - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, name="Test Document", - created_from="file_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_form="text_model", doc_language="en", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, batch="test_batch", @@ -547,7 +548,7 @@ class TestDealDatasetVectorIndexTask: index_node_id=f"node_{uuid.uuid4()}", index_node_hash=f"hash_{uuid.uuid4()}", created_by=account.id, - status="completed", + status=SegmentStatus.COMPLETED, enabled=True, ) db_session_with_containers.add(segment) @@ -563,7 +564,7 @@ class TestDealDatasetVectorIndexTask: # Verify document status was updated to error updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first() - assert updated_document.indexing_status == "error" + assert updated_document.indexing_status == IndexingStatus.ERROR assert "Test exception during indexing" in updated_document.error def test_deal_dataset_vector_index_task_with_custom_index_type( @@ -584,7 +585,7 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, name=fake.company(), description=fake.text(max_nb_chars=100), - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -596,13 +597,13 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, dataset_id=dataset.id, position=0, - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, name="Test Document", - created_from="file_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_form="qa_index", doc_language="en", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, batch="test_batch", @@ -623,7 +624,7 @@ class TestDealDatasetVectorIndexTask: index_node_id=f"node_{uuid.uuid4()}", index_node_hash=f"hash_{uuid.uuid4()}", created_by=account.id, - status="completed", + status=SegmentStatus.COMPLETED, enabled=True, ) db_session_with_containers.add(segment) @@ -634,7 +635,7 @@ class TestDealDatasetVectorIndexTask: # Verify document status was updated to indexing then completed updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first() - assert updated_document.indexing_status == "completed" + assert updated_document.indexing_status == IndexingStatus.COMPLETED # Verify index processor was initialized with custom index type mock_index_processor_factory.assert_called_once_with("qa_index") @@ -660,7 +661,7 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, name=fake.company(), description=fake.text(max_nb_chars=100), - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -672,13 +673,13 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, dataset_id=dataset.id, position=0, - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, name="Test Document", - created_from="file_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_form="text_model", doc_language="en", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, batch="test_batch", @@ -699,7 +700,7 @@ class TestDealDatasetVectorIndexTask: index_node_id=f"node_{uuid.uuid4()}", index_node_hash=f"hash_{uuid.uuid4()}", created_by=account.id, - status="completed", + status=SegmentStatus.COMPLETED, enabled=True, ) db_session_with_containers.add(segment) @@ -710,7 +711,7 @@ class TestDealDatasetVectorIndexTask: # Verify document status was updated to indexing then completed updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first() - assert updated_document.indexing_status == "completed" + assert updated_document.indexing_status == IndexingStatus.COMPLETED # Verify index processor was initialized with the document's index type mock_index_processor_factory.assert_called_once_with("text_model") @@ -736,7 +737,7 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, name=fake.company(), description=fake.text(max_nb_chars=100), - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -748,13 +749,13 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, dataset_id=dataset.id, position=0, - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, name="Document for doc_form", - created_from="file_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_form="text_model", doc_language="en", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, batch="test_batch", @@ -770,13 +771,13 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, dataset_id=dataset.id, position=i, - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, name=f"Test Document {i}", - created_from="file_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_form="text_model", doc_language="en", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, batch="test_batch", @@ -801,7 +802,7 @@ class TestDealDatasetVectorIndexTask: index_node_id=f"node_{i}_{j}", index_node_hash=f"hash_{i}_{j}", created_by=account.id, - status="completed", + status=SegmentStatus.COMPLETED, enabled=True, ) db_session_with_containers.add(segment) @@ -814,7 +815,7 @@ class TestDealDatasetVectorIndexTask: # Verify all documents were processed for document in documents: updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first() - assert updated_document.indexing_status == "completed" + assert updated_document.indexing_status == IndexingStatus.COMPLETED # Verify index processor load was called multiple times mock_factory = mock_index_processor_factory.return_value @@ -839,7 +840,7 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, name=fake.company(), description=fake.text(max_nb_chars=100), - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -851,13 +852,13 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, dataset_id=dataset.id, position=0, - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, name="Document for doc_form", - created_from="file_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_form="text_model", doc_language="en", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, batch="test_batch", @@ -871,13 +872,13 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, dataset_id=dataset.id, position=0, - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, name="Test Document", - created_from="file_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_form="text_model", doc_language="en", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, batch="test_batch", @@ -898,7 +899,7 @@ class TestDealDatasetVectorIndexTask: index_node_id=f"node_{uuid.uuid4()}", index_node_hash=f"hash_{uuid.uuid4()}", created_by=account.id, - status="completed", + status=SegmentStatus.COMPLETED, enabled=True, ) db_session_with_containers.add(segment) @@ -916,7 +917,7 @@ class TestDealDatasetVectorIndexTask: # Verify final document status updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first() - assert updated_document.indexing_status == "completed" + assert updated_document.indexing_status == IndexingStatus.COMPLETED def test_deal_dataset_vector_index_task_with_disabled_documents( self, db_session_with_containers, mock_index_processor_factory, account_and_tenant @@ -936,7 +937,7 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, name=fake.company(), description=fake.text(max_nb_chars=100), - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -948,13 +949,13 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, dataset_id=dataset.id, position=0, - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, name="Document for doc_form", - created_from="file_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_form="text_model", doc_language="en", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, batch="test_batch", @@ -968,13 +969,13 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, dataset_id=dataset.id, position=0, - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, name="Enabled Document", - created_from="file_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_form="text_model", doc_language="en", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, batch="test_batch", @@ -987,13 +988,13 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, dataset_id=dataset.id, position=1, - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, name="Disabled Document", - created_from="file_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_form="text_model", doc_language="en", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=False, # This document should be skipped archived=False, batch="test_batch", @@ -1015,7 +1016,7 @@ class TestDealDatasetVectorIndexTask: index_node_id=f"node_{uuid.uuid4()}", index_node_hash=f"hash_{uuid.uuid4()}", created_by=account.id, - status="completed", + status=SegmentStatus.COMPLETED, enabled=True, ) db_session_with_containers.add(segment) @@ -1026,13 +1027,13 @@ class TestDealDatasetVectorIndexTask: # Verify only enabled document was processed updated_enabled_document = db_session_with_containers.query(Document).filter_by(id=enabled_document.id).first() - assert updated_enabled_document.indexing_status == "completed" + assert updated_enabled_document.indexing_status == IndexingStatus.COMPLETED # Verify disabled document status remains unchanged updated_disabled_document = ( db_session_with_containers.query(Document).filter_by(id=disabled_document.id).first() ) - assert updated_disabled_document.indexing_status == "completed" # Should not change + assert updated_disabled_document.indexing_status == IndexingStatus.COMPLETED # Should not change # Verify index processor load was called only once (for enabled document) mock_factory = mock_index_processor_factory.return_value @@ -1057,7 +1058,7 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, name=fake.company(), description=fake.text(max_nb_chars=100), - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -1069,13 +1070,13 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, dataset_id=dataset.id, position=0, - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, name="Document for doc_form", - created_from="file_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_form="text_model", doc_language="en", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, batch="test_batch", @@ -1089,13 +1090,13 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, dataset_id=dataset.id, position=0, - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, name="Active Document", - created_from="file_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_form="text_model", doc_language="en", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, batch="test_batch", @@ -1108,13 +1109,13 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, dataset_id=dataset.id, position=1, - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, name="Archived Document", - created_from="file_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_form="text_model", doc_language="en", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=True, # This document should be skipped batch="test_batch", @@ -1136,7 +1137,7 @@ class TestDealDatasetVectorIndexTask: index_node_id=f"node_{uuid.uuid4()}", index_node_hash=f"hash_{uuid.uuid4()}", created_by=account.id, - status="completed", + status=SegmentStatus.COMPLETED, enabled=True, ) db_session_with_containers.add(segment) @@ -1147,13 +1148,13 @@ class TestDealDatasetVectorIndexTask: # Verify only active document was processed updated_active_document = db_session_with_containers.query(Document).filter_by(id=active_document.id).first() - assert updated_active_document.indexing_status == "completed" + assert updated_active_document.indexing_status == IndexingStatus.COMPLETED # Verify archived document status remains unchanged updated_archived_document = ( db_session_with_containers.query(Document).filter_by(id=archived_document.id).first() ) - assert updated_archived_document.indexing_status == "completed" # Should not change + assert updated_archived_document.indexing_status == IndexingStatus.COMPLETED # Should not change # Verify index processor load was called only once (for active document) mock_factory = mock_index_processor_factory.return_value @@ -1178,7 +1179,7 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, name=fake.company(), description=fake.text(max_nb_chars=100), - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -1190,13 +1191,13 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, dataset_id=dataset.id, position=0, - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, name="Document for doc_form", - created_from="file_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_form="text_model", doc_language="en", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, batch="test_batch", @@ -1210,13 +1211,13 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, dataset_id=dataset.id, position=0, - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, name="Completed Document", - created_from="file_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_form="text_model", doc_language="en", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, batch="test_batch", @@ -1229,13 +1230,13 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, dataset_id=dataset.id, position=1, - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, name="Incomplete Document", - created_from="file_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_form="text_model", doc_language="en", - indexing_status="indexing", # This document should be skipped + indexing_status=IndexingStatus.INDEXING, # This document should be skipped enabled=True, archived=False, batch="test_batch", @@ -1257,7 +1258,7 @@ class TestDealDatasetVectorIndexTask: index_node_id=f"node_{uuid.uuid4()}", index_node_hash=f"hash_{uuid.uuid4()}", created_by=account.id, - status="completed", + status=SegmentStatus.COMPLETED, enabled=True, ) db_session_with_containers.add(segment) @@ -1270,13 +1271,13 @@ class TestDealDatasetVectorIndexTask: updated_completed_document = ( db_session_with_containers.query(Document).filter_by(id=completed_document.id).first() ) - assert updated_completed_document.indexing_status == "completed" + assert updated_completed_document.indexing_status == IndexingStatus.COMPLETED # Verify incomplete document status remains unchanged updated_incomplete_document = ( db_session_with_containers.query(Document).filter_by(id=incomplete_document.id).first() ) - assert updated_incomplete_document.indexing_status == "indexing" # Should not change + assert updated_incomplete_document.indexing_status == IndexingStatus.INDEXING # Should not change # Verify index processor load was called only once (for completed document) mock_factory = mock_index_processor_factory.return_value diff --git a/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py index 69ed5b632d..6fc2a53f9c 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py @@ -14,6 +14,7 @@ from faker import Faker from core.rag.index_processor.constant.index_type import IndexStructureType from models import Account, Dataset, Document, DocumentSegment, Tenant +from models.enums import DataSourceType, DocumentCreatedFrom, DocumentDocType, IndexingStatus, SegmentStatus from tasks.delete_segment_from_index_task import delete_segment_from_index_task logger = logging.getLogger(__name__) @@ -106,7 +107,7 @@ class TestDeleteSegmentFromIndexTask: dataset.description = fake.text(max_nb_chars=200) dataset.provider = "vendor" dataset.permission = "only_me" - dataset.data_source_type = "upload_file" + dataset.data_source_type = DataSourceType.UPLOAD_FILE dataset.indexing_technique = "high_quality" dataset.index_struct = '{"type": "paragraph"}' dataset.created_by = account.id @@ -145,7 +146,7 @@ class TestDeleteSegmentFromIndexTask: document.data_source_info = kwargs.get("data_source_info", "{}") document.batch = kwargs.get("batch", fake.uuid4()) document.name = kwargs.get("name", f"Test Document {fake.word()}") - document.created_from = kwargs.get("created_from", "api") + document.created_from = kwargs.get("created_from", DocumentCreatedFrom.API) document.created_by = account.id document.created_at = fake.date_time_this_year() document.processing_started_at = kwargs.get("processing_started_at", fake.date_time_this_year()) @@ -162,7 +163,7 @@ class TestDeleteSegmentFromIndexTask: document.enabled = kwargs.get("enabled", True) document.archived = kwargs.get("archived", False) document.updated_at = fake.date_time_this_year() - document.doc_type = kwargs.get("doc_type", "text") + document.doc_type = kwargs.get("doc_type", DocumentDocType.PERSONAL_DOCUMENT) document.doc_metadata = kwargs.get("doc_metadata", {}) document.doc_form = kwargs.get("doc_form", IndexStructureType.PARAGRAPH_INDEX) document.doc_language = kwargs.get("doc_language", "en") @@ -204,7 +205,7 @@ class TestDeleteSegmentFromIndexTask: segment.index_node_hash = fake.sha256() segment.hit_count = 0 segment.enabled = True - segment.status = "completed" + segment.status = SegmentStatus.COMPLETED segment.created_by = account.id segment.created_at = fake.date_time_this_year() segment.updated_by = account.id @@ -386,7 +387,7 @@ class TestDeleteSegmentFromIndexTask: account = self._create_test_account(db_session_with_containers, tenant, fake) dataset = self._create_test_dataset(db_session_with_containers, tenant, account, fake) document = self._create_test_document( - db_session_with_containers, dataset, account, fake, indexing_status="indexing" + db_session_with_containers, dataset, account, fake, indexing_status=IndexingStatus.INDEXING ) segments = self._create_test_document_segments(db_session_with_containers, document, account, 3, fake) diff --git a/api/tests/test_containers_integration_tests/tasks/test_disable_segment_from_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_disable_segment_from_index_task.py index ab9e5b639a..da42fc7167 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_disable_segment_from_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_disable_segment_from_index_task.py @@ -18,6 +18,7 @@ from sqlalchemy.orm import Session from extensions.ext_redis import redis_client from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment +from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus from tasks.disable_segment_from_index_task import disable_segment_from_index_task logger = logging.getLogger(__name__) @@ -97,7 +98,7 @@ class TestDisableSegmentFromIndexTask: tenant_id=tenant.id, name=fake.sentence(nb_words=3), description=fake.text(max_nb_chars=200), - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, indexing_technique="high_quality", created_by=account.id, ) @@ -132,12 +133,12 @@ class TestDisableSegmentFromIndexTask: tenant_id=tenant.id, dataset_id=dataset.id, position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch=fake.uuid4(), name=fake.file_name(), - created_from="api", + created_from=DocumentCreatedFrom.API, created_by=account.id, - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, doc_form=doc_form, @@ -189,7 +190,7 @@ class TestDisableSegmentFromIndexTask: status=status, enabled=enabled, created_by=account.id, - completed_at=datetime.now(UTC) if status == "completed" else None, + completed_at=datetime.now(UTC) if status == SegmentStatus.COMPLETED else None, ) db_session_with_containers.add(segment) db_session_with_containers.commit() @@ -271,7 +272,7 @@ class TestDisableSegmentFromIndexTask: dataset = self._create_test_dataset(db_session_with_containers, tenant, account) document = self._create_test_document(db_session_with_containers, dataset, tenant, account) segment = self._create_test_segment( - db_session_with_containers, document, dataset, tenant, account, status="indexing", enabled=True + db_session_with_containers, document, dataset, tenant, account, status=SegmentStatus.INDEXING, enabled=True ) # Act: Execute the task diff --git a/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py index 6f7d2c28b5..4bc9bb4749 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py @@ -14,6 +14,7 @@ from sqlalchemy.orm import Session from models import Account, Dataset, DocumentSegment from models import Document as DatasetDocument from models.dataset import DatasetProcessRule +from models.enums import DataSourceType, DocumentCreatedFrom, ProcessRuleMode, SegmentStatus from tasks.disable_segments_from_index_task import disable_segments_from_index_task @@ -100,7 +101,7 @@ class TestDisableSegmentsFromIndexTask: description=fake.text(max_nb_chars=200), provider="vendor", permission="only_me", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, indexing_technique="high_quality", created_by=account.id, updated_by=account.id, @@ -134,11 +135,11 @@ class TestDisableSegmentsFromIndexTask: document.tenant_id = dataset.tenant_id document.dataset_id = dataset.id document.position = 1 - document.data_source_type = "upload_file" + document.data_source_type = DataSourceType.UPLOAD_FILE document.data_source_info = '{"upload_file_id": "test_file_id"}' document.batch = fake.uuid4() document.name = f"Test Document {fake.word()}.txt" - document.created_from = "upload_file" + document.created_from = DocumentCreatedFrom.WEB document.created_by = account.id document.created_api_request_id = fake.uuid4() document.processing_started_at = fake.date_time_this_year() @@ -197,7 +198,7 @@ class TestDisableSegmentsFromIndexTask: segment.enabled = True segment.disabled_at = None segment.disabled_by = None - segment.status = "completed" + segment.status = SegmentStatus.COMPLETED segment.created_by = account.id segment.updated_by = account.id segment.indexing_at = fake.date_time_this_year() @@ -230,7 +231,7 @@ class TestDisableSegmentsFromIndexTask: process_rule.id = fake.uuid4() process_rule.tenant_id = dataset.tenant_id process_rule.dataset_id = dataset.id - process_rule.mode = "automatic" + process_rule.mode = ProcessRuleMode.AUTOMATIC process_rule.rules = ( "{" '"mode": "automatic", ' diff --git a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_sync_task.py b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_sync_task.py index df5c5dc54b..6a17a19a54 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_sync_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_sync_task.py @@ -16,6 +16,7 @@ import pytest from core.indexing_runner import DocumentIsPausedError, IndexingRunner from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment +from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus from tasks.document_indexing_sync_task import document_indexing_sync_task @@ -54,7 +55,7 @@ class DocumentIndexingSyncTaskTestDataFactory: tenant_id=tenant_id, name=f"dataset-{uuid4()}", description="sync test dataset", - data_source_type="notion_import", + data_source_type=DataSourceType.NOTION_IMPORT, indexing_technique="high_quality", created_by=created_by, ) @@ -76,11 +77,11 @@ class DocumentIndexingSyncTaskTestDataFactory: tenant_id=tenant_id, dataset_id=dataset_id, position=0, - data_source_type="notion_import", + data_source_type=DataSourceType.NOTION_IMPORT, data_source_info=json.dumps(data_source_info) if data_source_info is not None else None, batch="test-batch", name=f"doc-{uuid4()}", - created_from="notion_import", + created_from=DocumentCreatedFrom.WEB, created_by=created_by, indexing_status=indexing_status, enabled=True, @@ -113,7 +114,7 @@ class DocumentIndexingSyncTaskTestDataFactory: word_count=10, tokens=5, index_node_id=f"node-{document_id}-{i}", - status="completed", + status=SegmentStatus.COMPLETED, created_by=created_by, ) db_session_with_containers.add(segment) @@ -181,7 +182,7 @@ class TestDocumentIndexingSyncTask: dataset_id=dataset.id, created_by=account.id, data_source_info=notion_info, - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, ) segments = DocumentIndexingSyncTaskTestDataFactory.create_segments( @@ -276,7 +277,7 @@ class TestDocumentIndexingSyncTask: db_session_with_containers.query(Document).where(Document.id == context["document"].id).first() ) assert updated_document is not None - assert updated_document.indexing_status == "error" + assert updated_document.indexing_status == IndexingStatus.ERROR assert "Datasource credential not found" in updated_document.error assert updated_document.stopped_at is not None mock_external_dependencies["indexing_runner"].run.assert_not_called() @@ -301,7 +302,7 @@ class TestDocumentIndexingSyncTask: .count() ) assert updated_document is not None - assert updated_document.indexing_status == "completed" + assert updated_document.indexing_status == IndexingStatus.COMPLETED assert updated_document.processing_started_at is None assert remaining_segments == 3 mock_external_dependencies["index_processor"].clean.assert_not_called() @@ -327,7 +328,7 @@ class TestDocumentIndexingSyncTask: ) assert updated_document is not None - assert updated_document.indexing_status == "parsing" + assert updated_document.indexing_status == IndexingStatus.PARSING assert updated_document.processing_started_at is not None assert updated_document.data_source_info_dict.get("last_edited_time") == "2024-01-02T00:00:00Z" assert remaining_segments == 0 @@ -369,7 +370,7 @@ class TestDocumentIndexingSyncTask: db_session_with_containers.query(Document).where(Document.id == context["document"].id).first() ) assert updated_document is not None - assert updated_document.indexing_status == "parsing" + assert updated_document.indexing_status == IndexingStatus.PARSING mock_external_dependencies["index_processor"].clean.assert_not_called() mock_external_dependencies["indexing_runner"].run.assert_called_once() @@ -393,7 +394,7 @@ class TestDocumentIndexingSyncTask: .count() ) assert updated_document is not None - assert updated_document.indexing_status == "parsing" + assert updated_document.indexing_status == IndexingStatus.PARSING assert remaining_segments == 0 mock_external_dependencies["indexing_runner"].run.assert_called_once() @@ -412,7 +413,7 @@ class TestDocumentIndexingSyncTask: db_session_with_containers.query(Document).where(Document.id == context["document"].id).first() ) assert updated_document is not None - assert updated_document.indexing_status == "parsing" + assert updated_document.indexing_status == IndexingStatus.PARSING assert updated_document.error is None def test_indexing_runner_general_error(self, db_session_with_containers, mock_external_dependencies): @@ -430,7 +431,7 @@ class TestDocumentIndexingSyncTask: db_session_with_containers.query(Document).where(Document.id == context["document"].id).first() ) assert updated_document is not None - assert updated_document.indexing_status == "error" + assert updated_document.indexing_status == IndexingStatus.ERROR assert "Indexing error" in updated_document.error assert updated_document.stopped_at is not None diff --git a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py index 5dc1f6bee0..9421b07285 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py @@ -8,6 +8,7 @@ from core.entities.document_task import DocumentTask from enums.cloud_plan import CloudPlan from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document +from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus from tasks.document_indexing_task import ( _document_indexing, # Core function _document_indexing_with_tenant_queue, # Tenant queue wrapper function @@ -97,7 +98,7 @@ class TestDocumentIndexingTasks: tenant_id=tenant.id, name=fake.company(), description=fake.text(max_nb_chars=100), - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, indexing_technique="high_quality", created_by=account.id, ) @@ -112,12 +113,12 @@ class TestDocumentIndexingTasks: tenant_id=tenant.id, dataset_id=dataset.id, position=i, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="test_batch", name=fake.file_name(), - created_from="upload_file", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, - indexing_status="waiting", + indexing_status=IndexingStatus.WAITING, enabled=True, ) db_session_with_containers.add(document) @@ -179,7 +180,7 @@ class TestDocumentIndexingTasks: tenant_id=tenant.id, name=fake.company(), description=fake.text(max_nb_chars=100), - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, indexing_technique="high_quality", created_by=account.id, ) @@ -194,12 +195,12 @@ class TestDocumentIndexingTasks: tenant_id=tenant.id, dataset_id=dataset.id, position=i, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="test_batch", name=fake.file_name(), - created_from="upload_file", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, - indexing_status="waiting", + indexing_status=IndexingStatus.WAITING, enabled=True, ) db_session_with_containers.add(document) @@ -250,7 +251,7 @@ class TestDocumentIndexingTasks: # Re-query documents from database since _document_indexing uses a different session for doc_id in document_ids: updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() - assert updated_document.indexing_status == "parsing" + assert updated_document.indexing_status == IndexingStatus.PARSING assert updated_document.processing_started_at is not None # Verify the run method was called with correct documents @@ -320,7 +321,7 @@ class TestDocumentIndexingTasks: # Re-query documents from database since _document_indexing uses a different session for doc_id in existing_document_ids: updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() - assert updated_document.indexing_status == "parsing" + assert updated_document.indexing_status == IndexingStatus.PARSING assert updated_document.processing_started_at is not None # Verify the run method was called with only existing documents @@ -367,7 +368,7 @@ class TestDocumentIndexingTasks: # Re-query documents from database since _document_indexing close the session for doc_id in document_ids: updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() - assert updated_document.indexing_status == "parsing" + assert updated_document.indexing_status == IndexingStatus.PARSING assert updated_document.processing_started_at is not None def test_document_indexing_task_mixed_document_states( @@ -397,12 +398,12 @@ class TestDocumentIndexingTasks: tenant_id=dataset.tenant_id, dataset_id=dataset.id, position=2, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="test_batch", name=fake.file_name(), - created_from="upload_file", + created_from=DocumentCreatedFrom.WEB, created_by=dataset.created_by, - indexing_status="completed", # Already completed + indexing_status=IndexingStatus.COMPLETED, # Already completed enabled=True, ) db_session_with_containers.add(doc1) @@ -414,12 +415,12 @@ class TestDocumentIndexingTasks: tenant_id=dataset.tenant_id, dataset_id=dataset.id, position=3, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="test_batch", name=fake.file_name(), - created_from="upload_file", + created_from=DocumentCreatedFrom.WEB, created_by=dataset.created_by, - indexing_status="waiting", + indexing_status=IndexingStatus.WAITING, enabled=False, # Disabled ) db_session_with_containers.add(doc2) @@ -444,7 +445,7 @@ class TestDocumentIndexingTasks: # Re-query documents from database since _document_indexing uses a different session for doc_id in document_ids: updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() - assert updated_document.indexing_status == "parsing" + assert updated_document.indexing_status == IndexingStatus.PARSING assert updated_document.processing_started_at is not None # Verify the run method was called with all documents @@ -482,12 +483,12 @@ class TestDocumentIndexingTasks: tenant_id=dataset.tenant_id, dataset_id=dataset.id, position=i + 3, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="test_batch", name=fake.file_name(), - created_from="upload_file", + created_from=DocumentCreatedFrom.WEB, created_by=dataset.created_by, - indexing_status="waiting", + indexing_status=IndexingStatus.WAITING, enabled=True, ) db_session_with_containers.add(document) @@ -507,7 +508,7 @@ class TestDocumentIndexingTasks: # Re-query documents from database since _document_indexing uses a different session for doc_id in document_ids: updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() - assert updated_document.indexing_status == "error" + assert updated_document.indexing_status == IndexingStatus.ERROR assert updated_document.error is not None assert "batch upload" in updated_document.error assert updated_document.stopped_at is not None @@ -548,7 +549,7 @@ class TestDocumentIndexingTasks: # Re-query documents from database since _document_indexing uses a different session for doc_id in document_ids: updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() - assert updated_document.indexing_status == "parsing" + assert updated_document.indexing_status == IndexingStatus.PARSING assert updated_document.processing_started_at is not None def test_document_indexing_task_document_is_paused_error( @@ -591,7 +592,7 @@ class TestDocumentIndexingTasks: # Re-query documents from database since _document_indexing uses a different session for doc_id in document_ids: updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() - assert updated_document.indexing_status == "parsing" + assert updated_document.indexing_status == IndexingStatus.PARSING assert updated_document.processing_started_at is not None # ==================== NEW TESTS FOR REFACTORED FUNCTIONS ==================== @@ -702,7 +703,7 @@ class TestDocumentIndexingTasks: # Re-query documents from database since _document_indexing uses a different session for doc_id in document_ids: updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() - assert updated_document.indexing_status == "parsing" + assert updated_document.indexing_status == IndexingStatus.PARSING assert updated_document.processing_started_at is not None # Verify the run method was called with correct documents @@ -827,7 +828,7 @@ class TestDocumentIndexingTasks: # Re-query documents from database since _document_indexing uses a different session for doc_id in document_ids: updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() - assert updated_document.indexing_status == "parsing" + assert updated_document.indexing_status == IndexingStatus.PARSING assert updated_document.processing_started_at is not None # Verify waiting task was still processed despite core processing error diff --git a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_update_task.py b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_update_task.py index 9da9a4132e..2fbea1388c 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_update_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_update_task.py @@ -5,6 +5,7 @@ from faker import Faker from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment +from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus from tasks.document_indexing_update_task import document_indexing_update_task @@ -61,7 +62,7 @@ class TestDocumentIndexingUpdateTask: tenant_id=tenant.id, name=fake.company(), description=fake.text(max_nb_chars=64), - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, indexing_technique="high_quality", created_by=account.id, ) @@ -72,12 +73,12 @@ class TestDocumentIndexingUpdateTask: tenant_id=tenant.id, dataset_id=dataset.id, position=0, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="test_batch", name=fake.file_name(), - created_from="upload_file", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, - indexing_status="waiting", + indexing_status=IndexingStatus.WAITING, enabled=True, doc_form="text_model", ) @@ -98,7 +99,7 @@ class TestDocumentIndexingUpdateTask: word_count=10, tokens=5, index_node_id=node_id, - status="completed", + status=SegmentStatus.COMPLETED, created_by=account.id, ) db_session_with_containers.add(seg) @@ -122,7 +123,7 @@ class TestDocumentIndexingUpdateTask: # Assert document status updated before reindex updated = db_session_with_containers.query(Document).where(Document.id == document.id).first() - assert updated.indexing_status == "parsing" + assert updated.indexing_status == IndexingStatus.PARSING assert updated.processing_started_at is not None # Segments should be deleted diff --git a/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py b/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py index c61e37b1e9..f1f5a4b105 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py @@ -7,6 +7,7 @@ from core.indexing_runner import DocumentIsPausedError from enums.cloud_plan import CloudPlan from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment +from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus from tasks.duplicate_document_indexing_task import ( _duplicate_document_indexing_task, # Core function _duplicate_document_indexing_task_with_tenant_queue, # Tenant queue wrapper function @@ -107,7 +108,7 @@ class TestDuplicateDocumentIndexingTasks: tenant_id=tenant.id, name=fake.company(), description=fake.text(max_nb_chars=100), - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, indexing_technique="high_quality", created_by=account.id, ) @@ -122,12 +123,12 @@ class TestDuplicateDocumentIndexingTasks: tenant_id=tenant.id, dataset_id=dataset.id, position=i, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="test_batch", name=fake.file_name(), - created_from="upload_file", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, - indexing_status="waiting", + indexing_status=IndexingStatus.WAITING, enabled=True, doc_form="text_model", ) @@ -177,7 +178,7 @@ class TestDuplicateDocumentIndexingTasks: content=fake.text(max_nb_chars=200), word_count=50, tokens=100, - status="completed", + status=SegmentStatus.COMPLETED, enabled=True, indexing_at=fake.date_time_this_year(), created_by=dataset.created_by, # Add required field @@ -242,7 +243,7 @@ class TestDuplicateDocumentIndexingTasks: tenant_id=tenant.id, name=fake.company(), description=fake.text(max_nb_chars=100), - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, indexing_technique="high_quality", created_by=account.id, ) @@ -257,12 +258,12 @@ class TestDuplicateDocumentIndexingTasks: tenant_id=tenant.id, dataset_id=dataset.id, position=i, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="test_batch", name=fake.file_name(), - created_from="upload_file", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, - indexing_status="waiting", + indexing_status=IndexingStatus.WAITING, enabled=True, doc_form="text_model", ) @@ -316,7 +317,7 @@ class TestDuplicateDocumentIndexingTasks: # Re-query documents from database since _duplicate_document_indexing_task uses a different session for doc_id in document_ids: updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() - assert updated_document.indexing_status == "parsing" + assert updated_document.indexing_status == IndexingStatus.PARSING assert updated_document.processing_started_at is not None # Verify the run method was called with correct documents @@ -368,7 +369,7 @@ class TestDuplicateDocumentIndexingTasks: # Verify documents were updated to parsing status for doc_id in document_ids: updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() - assert updated_document.indexing_status == "parsing" + assert updated_document.indexing_status == IndexingStatus.PARSING assert updated_document.processing_started_at is not None # Verify indexing runner was called @@ -437,7 +438,7 @@ class TestDuplicateDocumentIndexingTasks: # Re-query documents from database since _duplicate_document_indexing_task uses a different session for doc_id in existing_document_ids: updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() - assert updated_document.indexing_status == "parsing" + assert updated_document.indexing_status == IndexingStatus.PARSING assert updated_document.processing_started_at is not None # Verify the run method was called with only existing documents @@ -484,7 +485,7 @@ class TestDuplicateDocumentIndexingTasks: # Re-query documents from database since _duplicate_document_indexing_task close the session for doc_id in document_ids: updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() - assert updated_document.indexing_status == "parsing" + assert updated_document.indexing_status == IndexingStatus.PARSING assert updated_document.processing_started_at is not None def _test_duplicate_document_indexing_task_billing_sandbox_plan_batch_limit( @@ -516,12 +517,12 @@ class TestDuplicateDocumentIndexingTasks: tenant_id=dataset.tenant_id, dataset_id=dataset.id, position=i + 3, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="test_batch", name=fake.file_name(), - created_from="upload_file", + created_from=DocumentCreatedFrom.WEB, created_by=dataset.created_by, - indexing_status="waiting", + indexing_status=IndexingStatus.WAITING, enabled=True, doc_form="text_model", ) @@ -542,7 +543,7 @@ class TestDuplicateDocumentIndexingTasks: # Re-query documents from database since _duplicate_document_indexing_task uses a different session for doc_id in document_ids: updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() - assert updated_document.indexing_status == "error" + assert updated_document.indexing_status == IndexingStatus.ERROR assert updated_document.error is not None assert "batch upload" in updated_document.error.lower() assert updated_document.stopped_at is not None @@ -584,7 +585,7 @@ class TestDuplicateDocumentIndexingTasks: # Re-query documents from database since _duplicate_document_indexing_task uses a different session for doc_id in document_ids: updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() - assert updated_document.indexing_status == "error" + assert updated_document.indexing_status == IndexingStatus.ERROR assert updated_document.error is not None assert "limit" in updated_document.error.lower() assert updated_document.stopped_at is not None @@ -648,7 +649,7 @@ class TestDuplicateDocumentIndexingTasks: # Verify documents were processed for doc_id in document_ids: updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() - assert updated_document.indexing_status == "parsing" + assert updated_document.indexing_status == IndexingStatus.PARSING @patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue", autospec=True) def test_normal_duplicate_document_indexing_task_with_tenant_queue( @@ -691,7 +692,7 @@ class TestDuplicateDocumentIndexingTasks: # Verify documents were processed for doc_id in document_ids: updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() - assert updated_document.indexing_status == "parsing" + assert updated_document.indexing_status == IndexingStatus.PARSING @patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue", autospec=True) def test_priority_duplicate_document_indexing_task_with_tenant_queue( @@ -735,7 +736,7 @@ class TestDuplicateDocumentIndexingTasks: # Verify documents were processed for doc_id in document_ids: updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() - assert updated_document.indexing_status == "parsing" + assert updated_document.indexing_status == IndexingStatus.PARSING @patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue", autospec=True) def test_tenant_queue_wrapper_processes_next_tasks( @@ -851,7 +852,7 @@ class TestDuplicateDocumentIndexingTasks: for doc_id in document_ids: updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() assert updated_document.is_paused is True - assert updated_document.indexing_status == "parsing" + assert updated_document.indexing_status == IndexingStatus.PARSING assert updated_document.display_status == "paused" assert updated_document.processing_started_at is not None mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() diff --git a/api/tests/test_containers_integration_tests/tasks/test_enable_segments_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_enable_segments_to_index_task.py index bc29395545..54b50016a8 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_enable_segments_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_enable_segments_to_index_task.py @@ -8,6 +8,7 @@ from core.rag.index_processor.constant.index_type import IndexStructureType from extensions.ext_redis import redis_client from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment +from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus from tasks.enable_segments_to_index_task import enable_segments_to_index_task @@ -79,7 +80,7 @@ class TestEnableSegmentsToIndexTask: tenant_id=tenant.id, name=fake.company(), description=fake.text(max_nb_chars=100), - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, indexing_technique="high_quality", created_by=account.id, ) @@ -92,12 +93,12 @@ class TestEnableSegmentsToIndexTask: tenant_id=tenant.id, dataset_id=dataset.id, position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="test_batch", name=fake.file_name(), - created_from="upload_file", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, doc_form=IndexStructureType.PARAGRAPH_INDEX, ) @@ -110,7 +111,13 @@ class TestEnableSegmentsToIndexTask: return dataset, document def _create_test_segments( - self, db_session_with_containers: Session, document, dataset, count=3, enabled=False, status="completed" + self, + db_session_with_containers: Session, + document, + dataset, + count=3, + enabled=False, + status=SegmentStatus.COMPLETED, ): """ Helper method to create test document segments. @@ -278,7 +285,7 @@ class TestEnableSegmentsToIndexTask: invalid_statuses = [ ("disabled", {"enabled": False}), ("archived", {"archived": True}), - ("not_completed", {"indexing_status": "processing"}), + ("not_completed", {"indexing_status": IndexingStatus.INDEXING}), ] for _, status_attrs in invalid_statuses: @@ -447,7 +454,7 @@ class TestEnableSegmentsToIndexTask: for segment in segments: db_session_with_containers.refresh(segment) assert segment.enabled is False - assert segment.status == "error" + assert segment.status == SegmentStatus.ERROR assert segment.error is not None assert "Index processing failed" in segment.error assert segment.disabled_at is not None diff --git a/api/tests/unit_tests/commands/test_clean_expired_messages.py b/api/tests/unit_tests/commands/test_clean_expired_messages.py index 60173f723d..5375988a69 100644 --- a/api/tests/unit_tests/commands/test_clean_expired_messages.py +++ b/api/tests/unit_tests/commands/test_clean_expired_messages.py @@ -46,6 +46,7 @@ def test_absolute_mode_calls_from_time_range(): end_before=end_before, batch_size=200, dry_run=True, + task_label="custom", ) mock_from_days.assert_not_called() @@ -74,6 +75,7 @@ def test_relative_mode_before_days_only_calls_from_days(): days=30, batch_size=500, dry_run=False, + task_label="before-30", ) mock_from_time_range.assert_not_called() @@ -105,6 +107,7 @@ def test_relative_mode_with_from_days_ago_calls_from_time_range(): end_before=fixed_now - datetime.timedelta(days=30), batch_size=1000, dry_run=False, + task_label="60to30", ) mock_from_days.assert_not_called() diff --git a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline.py b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline.py index 3b8679f4ec..ebbb34e069 100644 --- a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline.py +++ b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline.py @@ -59,6 +59,44 @@ class TestPipelineTemplateDetailApi: assert status == 200 assert response == template + def test_get_returns_404_when_template_not_found(self, app): + api = PipelineTemplateDetailApi() + method = unwrap(api.get) + + service = MagicMock() + service.get_pipeline_template_detail.return_value = None + + with ( + app.test_request_context("/?type=built-in"), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline.RagPipelineService", + return_value=service, + ), + ): + response, status = method(api, "non-existent-id") + + assert status == 404 + assert "error" in response + + def test_get_returns_404_for_customized_type_not_found(self, app): + api = PipelineTemplateDetailApi() + method = unwrap(api.get) + + service = MagicMock() + service.get_pipeline_template_detail.return_value = None + + with ( + app.test_request_context("/?type=customized"), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline.RagPipelineService", + return_value=service, + ), + ): + response, status = method(api, "non-existent-id") + + assert status == 404 + assert "error" in response + class TestCustomizedPipelineTemplateApi: def test_patch_success(self, app): diff --git a/api/tests/unit_tests/controllers/console/datasets/test_datasets_document.py b/api/tests/unit_tests/controllers/console/datasets/test_datasets_document.py index dbe54ccb99..f23dd5b44a 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_datasets_document.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_datasets_document.py @@ -30,6 +30,7 @@ from controllers.console.datasets.error import ( InvalidActionError, InvalidMetadataError, ) +from models.enums import DataSourceType, IndexingStatus def unwrap(func): @@ -62,8 +63,8 @@ def document(): return MagicMock( id="doc-1", tenant_id="tenant-1", - indexing_status="indexing", - data_source_type="upload_file", + indexing_status=IndexingStatus.INDEXING, + data_source_type=DataSourceType.UPLOAD_FILE, data_source_info_dict={"upload_file_id": "file-1"}, doc_form="text", archived=False, @@ -407,7 +408,7 @@ class TestDocumentProcessingApi: api = DocumentProcessingApi() method = unwrap(api.patch) - doc = MagicMock(indexing_status="error", is_paused=True) + doc = MagicMock(indexing_status=IndexingStatus.ERROR, is_paused=True) with ( app.test_request_context("/"), @@ -425,7 +426,7 @@ class TestDocumentProcessingApi: api = DocumentProcessingApi() method = unwrap(api.patch) - document = MagicMock(indexing_status="paused", is_paused=True) + document = MagicMock(indexing_status=IndexingStatus.PAUSED, is_paused=True) with ( app.test_request_context("/"), @@ -461,7 +462,7 @@ class TestDocumentProcessingApi: api = DocumentProcessingApi() method = unwrap(api.patch) - document = MagicMock(indexing_status="completed") + document = MagicMock(indexing_status=IndexingStatus.COMPLETED) with app.test_request_context("/"), patch.object(api, "get_document", return_value=document): with pytest.raises(InvalidActionError): @@ -630,7 +631,7 @@ class TestDocumentRetryApi: payload = {"document_ids": ["doc-1"]} - document = MagicMock(indexing_status="indexing", archived=False) + document = MagicMock(indexing_status=IndexingStatus.INDEXING, archived=False) with ( app.test_request_context("/", json=payload), @@ -659,7 +660,7 @@ class TestDocumentRetryApi: payload = {"document_ids": ["doc-1"]} - document = MagicMock(indexing_status="completed", archived=False) + document = MagicMock(indexing_status=IndexingStatus.COMPLETED, archived=False) with ( app.test_request_context("/", json=payload), @@ -817,8 +818,8 @@ class TestDocumentIndexingEstimateApi: method = unwrap(api.get) document = MagicMock( - indexing_status="indexing", - data_source_type="upload_file", + indexing_status=IndexingStatus.INDEXING, + data_source_type=DataSourceType.UPLOAD_FILE, data_source_info_dict={"upload_file_id": "file-1"}, tenant_id="tenant-1", doc_form="text", @@ -844,8 +845,8 @@ class TestDocumentIndexingEstimateApi: method = unwrap(api.get) document = MagicMock( - indexing_status="indexing", - data_source_type="upload_file", + indexing_status=IndexingStatus.INDEXING, + data_source_type=DataSourceType.UPLOAD_FILE, data_source_info_dict={"upload_file_id": "file-1"}, tenant_id="tenant-1", doc_form="text", @@ -882,7 +883,7 @@ class TestDocumentIndexingEstimateApi: api = DocumentIndexingEstimateApi() method = unwrap(api.get) - document = MagicMock(indexing_status="completed") + document = MagicMock(indexing_status=IndexingStatus.COMPLETED) with app.test_request_context("/"), patch.object(api, "get_document", return_value=document): with pytest.raises(DocumentAlreadyFinishedError): @@ -963,8 +964,8 @@ class TestDocumentBatchIndexingEstimateApi: method = unwrap(api.get) doc = MagicMock( - indexing_status="indexing", - data_source_type="website_crawl", + indexing_status=IndexingStatus.INDEXING, + data_source_type=DataSourceType.WEBSITE_CRAWL, data_source_info_dict={ "provider": "firecrawl", "job_id": "j1", @@ -992,8 +993,8 @@ class TestDocumentBatchIndexingEstimateApi: method = unwrap(api.get) doc = MagicMock( - indexing_status="indexing", - data_source_type="notion_import", + indexing_status=IndexingStatus.INDEXING, + data_source_type=DataSourceType.NOTION_IMPORT, data_source_info_dict={ "credential_id": "c1", "notion_workspace_id": "w1", @@ -1020,7 +1021,7 @@ class TestDocumentBatchIndexingEstimateApi: method = unwrap(api.get) document = MagicMock( - indexing_status="indexing", + indexing_status=IndexingStatus.INDEXING, data_source_type="unknown", data_source_info_dict={}, doc_form="text", @@ -1130,7 +1131,7 @@ class TestDocumentProcessingApiResume: api = DocumentProcessingApi() method = unwrap(api.patch) - document = MagicMock(indexing_status="completed", is_paused=False) + document = MagicMock(indexing_status=IndexingStatus.COMPLETED, is_paused=False) with app.test_request_context("/"), patch.object(api, "get_document", return_value=document): with pytest.raises(InvalidActionError): @@ -1348,8 +1349,8 @@ class TestDocumentIndexingEdgeCases: method = unwrap(api.get) document = MagicMock( - indexing_status="indexing", - data_source_type="upload_file", + indexing_status=IndexingStatus.INDEXING, + data_source_type=DataSourceType.UPLOAD_FILE, data_source_info_dict={"upload_file_id": "file-1"}, tenant_id="tenant-1", doc_form="text", diff --git a/api/tests/unit_tests/controllers/service_api/dataset/test_dataset_segment.py b/api/tests/unit_tests/controllers/service_api/dataset/test_dataset_segment.py index dc651a1627..5c48ef1804 100644 --- a/api/tests/unit_tests/controllers/service_api/dataset/test_dataset_segment.py +++ b/api/tests/unit_tests/controllers/service_api/dataset/test_dataset_segment.py @@ -32,6 +32,7 @@ from controllers.service_api.dataset.segment import ( SegmentListQuery, ) from models.dataset import ChildChunk, Dataset, Document, DocumentSegment +from models.enums import IndexingStatus from services.dataset_service import DocumentService, SegmentService @@ -657,12 +658,27 @@ class TestSegmentIndexingRequirements: dataset.indexing_technique = technique assert dataset.indexing_technique in ["high_quality", "economy"] - @pytest.mark.parametrize("status", ["waiting", "parsing", "indexing", "completed", "error"]) + @pytest.mark.parametrize( + "status", + [ + IndexingStatus.WAITING, + IndexingStatus.PARSING, + IndexingStatus.INDEXING, + IndexingStatus.COMPLETED, + IndexingStatus.ERROR, + ], + ) def test_valid_indexing_statuses(self, status): """Test valid document indexing statuses.""" document = Mock(spec=Document) document.indexing_status = status - assert document.indexing_status in ["waiting", "parsing", "indexing", "completed", "error"] + assert document.indexing_status in { + IndexingStatus.WAITING, + IndexingStatus.PARSING, + IndexingStatus.INDEXING, + IndexingStatus.COMPLETED, + IndexingStatus.ERROR, + } def test_completed_status_required_for_segments(self): """Test that completed status is required for segment operations.""" diff --git a/api/tests/unit_tests/controllers/service_api/dataset/test_document.py b/api/tests/unit_tests/controllers/service_api/dataset/test_document.py index f98109af79..e6e841be19 100644 --- a/api/tests/unit_tests/controllers/service_api/dataset/test_document.py +++ b/api/tests/unit_tests/controllers/service_api/dataset/test_document.py @@ -35,6 +35,7 @@ from controllers.service_api.dataset.document import ( InvalidMetadataError, ) from controllers.service_api.dataset.error import ArchivedDocumentImmutableError +from models.enums import IndexingStatus from services.dataset_service import DocumentService from services.entities.knowledge_entities.knowledge_entities import ProcessRule, RetrievalModel @@ -244,23 +245,26 @@ class TestDocumentService: class TestDocumentIndexingStatus: """Test document indexing status values.""" + _VALID_STATUSES = { + IndexingStatus.WAITING, + IndexingStatus.PARSING, + IndexingStatus.INDEXING, + IndexingStatus.COMPLETED, + IndexingStatus.ERROR, + IndexingStatus.PAUSED, + } + def test_completed_status(self): """Test completed status.""" - status = "completed" - valid_statuses = ["waiting", "parsing", "indexing", "completed", "error", "paused"] - assert status in valid_statuses + assert IndexingStatus.COMPLETED in self._VALID_STATUSES def test_indexing_status(self): """Test indexing status.""" - status = "indexing" - valid_statuses = ["waiting", "parsing", "indexing", "completed", "error", "paused"] - assert status in valid_statuses + assert IndexingStatus.INDEXING in self._VALID_STATUSES def test_error_status(self): """Test error status.""" - status = "error" - valid_statuses = ["waiting", "parsing", "indexing", "completed", "error", "paused"] - assert status in valid_statuses + assert IndexingStatus.ERROR in self._VALID_STATUSES class TestDocumentDocForm: diff --git a/api/tests/unit_tests/controllers/trigger/test_webhook.py b/api/tests/unit_tests/controllers/trigger/test_webhook.py index d633365f2b..91c793d292 100644 --- a/api/tests/unit_tests/controllers/trigger/test_webhook.py +++ b/api/tests/unit_tests/controllers/trigger/test_webhook.py @@ -23,6 +23,7 @@ def mock_jsonify(): class DummyWebhookTrigger: webhook_id = "wh-1" + webhook_url = "http://localhost:5001/triggers/webhook/wh-1" tenant_id = "tenant-1" app_id = "app-1" node_id = "node-1" @@ -104,7 +105,32 @@ class TestHandleWebhookDebug: @patch.object(module.WebhookService, "get_webhook_trigger_and_workflow") @patch.object(module.WebhookService, "extract_and_validate_webhook_data") @patch.object(module.WebhookService, "build_workflow_inputs", return_value={"x": 1}) - @patch.object(module.TriggerDebugEventBus, "dispatch") + @patch.object(module.TriggerDebugEventBus, "dispatch", return_value=0) + def test_debug_requires_active_listener( + self, + mock_dispatch, + mock_build_inputs, + mock_extract, + mock_get, + ): + mock_get.return_value = (DummyWebhookTrigger(), None, "node_config") + mock_extract.return_value = {"method": "POST"} + + response, status = module.handle_webhook_debug("wh-1") + + assert status == 409 + assert response["error"] == "No active debug listener" + assert response["message"] == ( + "The webhook debug URL only works while the Variable Inspector is listening. " + "Use the published webhook URL to execute the workflow in Celery." + ) + assert response["execution_url"] == DummyWebhookTrigger.webhook_url + mock_dispatch.assert_called_once() + + @patch.object(module.WebhookService, "get_webhook_trigger_and_workflow") + @patch.object(module.WebhookService, "extract_and_validate_webhook_data") + @patch.object(module.WebhookService, "build_workflow_inputs", return_value={"x": 1}) + @patch.object(module.TriggerDebugEventBus, "dispatch", return_value=1) @patch.object(module.WebhookService, "generate_webhook_response") def test_debug_success( self, diff --git a/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py b/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py index 82f98d07a3..75473fc89a 100644 --- a/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py +++ b/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py @@ -35,6 +35,7 @@ from dify_graph.model_runtime.entities.provider_entities import ( ProviderCredentialSchema, ProviderEntity, ) +from models.enums import CredentialSourceType from models.provider import ProviderType from models.provider_ids import ModelProviderID @@ -514,7 +515,7 @@ def test_get_custom_provider_models_sets_status_for_removed_credentials_and_inva id="lb-base", name="LB Base", credentials={}, - credential_source_type="provider", + credential_source_type=CredentialSourceType.PROVIDER, ) ], ), @@ -528,7 +529,7 @@ def test_get_custom_provider_models_sets_status_for_removed_credentials_and_inva id="lb-custom", name="LB Custom", credentials={}, - credential_source_type="custom_model", + credential_source_type=CredentialSourceType.CUSTOM_MODEL, ) ], ), @@ -734,7 +735,7 @@ def test_create_provider_credential_creates_provider_record_when_missing() -> No def test_create_provider_credential_marks_existing_provider_as_valid() -> None: configuration = _build_provider_configuration() session = Mock() - provider_record = SimpleNamespace(is_valid=False) + provider_record = SimpleNamespace(id="provider-1", is_valid=False, credential_id="existing-cred") with _patched_session(session): with patch.object(ProviderConfiguration, "_check_provider_credential_name_exists", return_value=False): @@ -743,6 +744,25 @@ def test_create_provider_credential_marks_existing_provider_as_valid() -> None: configuration.create_provider_credential({"api_key": "raw"}, "Main") assert provider_record.is_valid is True + assert provider_record.credential_id == "existing-cred" + session.commit.assert_called_once() + + +def test_create_provider_credential_auto_activates_when_no_active_credential() -> None: + configuration = _build_provider_configuration() + session = Mock() + provider_record = SimpleNamespace(id="provider-1", is_valid=False, credential_id=None, updated_at=None) + + with _patched_session(session): + with patch.object(ProviderConfiguration, "_check_provider_credential_name_exists", return_value=False): + with patch.object(ProviderConfiguration, "validate_provider_credentials", return_value={"api_key": "enc"}): + with patch.object(ProviderConfiguration, "_get_provider_record", return_value=provider_record): + with patch("core.entities.provider_configuration.ProviderCredentialsCache"): + with patch.object(ProviderConfiguration, "switch_preferred_provider_type"): + configuration.create_provider_credential({"api_key": "raw"}, "Main") + + assert provider_record.is_valid is True + assert provider_record.credential_id is not None session.commit.assert_called_once() @@ -807,7 +827,7 @@ def test_update_load_balancing_configs_updates_all_matching_configs() -> None: configuration._update_load_balancing_configs_with_credential( credential_id="cred-1", credential_record=credential_record, - credential_source="provider", + credential_source=CredentialSourceType.PROVIDER, session=session, ) @@ -825,7 +845,7 @@ def test_update_load_balancing_configs_returns_when_no_matching_configs() -> Non configuration._update_load_balancing_configs_with_credential( credential_id="cred-1", credential_record=SimpleNamespace(encrypted_config="{}", credential_name="Main"), - credential_source="provider", + credential_source=CredentialSourceType.PROVIDER, session=session, ) diff --git a/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py b/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py index 2451db70b6..e6cc582398 100644 --- a/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py +++ b/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py @@ -236,7 +236,8 @@ class TestParagraphIndexProcessor: "core.rag.index_processor.processor.paragraph_index_processor.RetrievalService.retrieve" ) as mock_retrieve: mock_retrieve.return_value = [accepted, rejected] - docs = processor.retrieve("semantic_search", "query", dataset, 5, 0.5, {}) + reranking_model = {"reranking_provider_name": "", "reranking_model_name": ""} + docs = processor.retrieve("semantic_search", "query", dataset, 5, 0.5, reranking_model) assert len(docs) == 1 assert docs[0].metadata["score"] == 0.9 diff --git a/api/tests/unit_tests/core/rag/indexing/processor/test_parent_child_index_processor.py b/api/tests/unit_tests/core/rag/indexing/processor/test_parent_child_index_processor.py index abe40f05d1..5c78cae7c1 100644 --- a/api/tests/unit_tests/core/rag/indexing/processor/test_parent_child_index_processor.py +++ b/api/tests/unit_tests/core/rag/indexing/processor/test_parent_child_index_processor.py @@ -307,7 +307,8 @@ class TestParentChildIndexProcessor: "core.rag.index_processor.processor.parent_child_index_processor.RetrievalService.retrieve" ) as mock_retrieve: mock_retrieve.return_value = [ok_result, low_result] - docs = processor.retrieve("semantic_search", "query", dataset, 3, 0.5, {}) + reranking_model = {"reranking_provider_name": "", "reranking_model_name": ""} + docs = processor.retrieve("semantic_search", "query", dataset, 3, 0.5, reranking_model) assert len(docs) == 1 assert docs[0].page_content == "keep" diff --git a/api/tests/unit_tests/core/rag/indexing/processor/test_qa_index_processor.py b/api/tests/unit_tests/core/rag/indexing/processor/test_qa_index_processor.py index 8596647ef3..99323eeec9 100644 --- a/api/tests/unit_tests/core/rag/indexing/processor/test_qa_index_processor.py +++ b/api/tests/unit_tests/core/rag/indexing/processor/test_qa_index_processor.py @@ -262,7 +262,8 @@ class TestQAIndexProcessor: with patch("core.rag.index_processor.processor.qa_index_processor.RetrievalService.retrieve") as mock_retrieve: mock_retrieve.return_value = [result_ok, result_low] - docs = processor.retrieve("semantic_search", "query", dataset, 5, 0.5, {}) + reranking_model = {"reranking_provider_name": "", "reranking_model_name": ""} + docs = processor.retrieve("semantic_search", "query", dataset, 5, 0.5, reranking_model) assert len(docs) == 1 assert docs[0].page_content == "accepted" diff --git a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py index d61f01c616..665e98bd9c 100644 --- a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py +++ b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py @@ -25,6 +25,7 @@ from core.app.app_config.entities import ModelConfig as WorkflowModelConfig from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity from core.entities.agent_entities import PlanningStrategy from core.entities.model_entities import ModelStatus +from core.rag.data_post_processor.data_post_processor import WeightsDict from core.rag.datasource.retrieval_service import RetrievalService from core.rag.index_processor.constant.doc_type import DocType from core.rag.index_processor.constant.index_type import IndexStructureType @@ -4686,7 +4687,10 @@ class TestSingleAndMultipleRetrieveCoverage: extra={"dataset_name": "Ext", "title": "Ext"}, ) app = Flask(__name__) - weights = {"vector_setting": {}} + weights: WeightsDict = { + "vector_setting": {"vector_weight": 0.5, "embedding_provider_name": "", "embedding_model_name": ""}, + "keyword_setting": {"keyword_weight": 0.5}, + } def fake_multiple_thread(**kwargs): if kwargs["query"]: diff --git a/api/tests/unit_tests/models/test_account_models.py b/api/tests/unit_tests/models/test_account_models.py index 1726fc2e8b..f48db77bb5 100644 --- a/api/tests/unit_tests/models/test_account_models.py +++ b/api/tests/unit_tests/models/test_account_models.py @@ -622,28 +622,10 @@ class TestAccountGetByOpenId: mock_account = Account(name="Test User", email="test@example.com") mock_account.id = account_id - # Mock the query chain - mock_query = MagicMock() - mock_where = MagicMock() - mock_where.one_or_none.return_value = mock_account_integrate - mock_query.where.return_value = mock_where - mock_db.session.query.return_value = mock_query - - # Mock the second query for account - mock_account_query = MagicMock() - mock_account_where = MagicMock() - mock_account_where.one_or_none.return_value = mock_account - mock_account_query.where.return_value = mock_account_where - - # Setup query to return different results based on model - def query_side_effect(model): - if model.__name__ == "AccountIntegrate": - return mock_query - elif model.__name__ == "Account": - return mock_account_query - return MagicMock() - - mock_db.session.query.side_effect = query_side_effect + # Mock db.session.execute().scalar_one_or_none() for AccountIntegrate lookup + mock_db.session.execute.return_value.scalar_one_or_none.return_value = mock_account_integrate + # Mock db.session.scalar() for Account lookup + mock_db.session.scalar.return_value = mock_account # Act result = Account.get_by_openid(provider, open_id) @@ -658,12 +640,8 @@ class TestAccountGetByOpenId: provider = "github" open_id = "github_user_456" - # Mock the query chain to return None - mock_query = MagicMock() - mock_where = MagicMock() - mock_where.one_or_none.return_value = None - mock_query.where.return_value = mock_where - mock_db.session.query.return_value = mock_query + # Mock db.session.execute().scalar_one_or_none() to return None + mock_db.session.execute.return_value.scalar_one_or_none.return_value = None # Act result = Account.get_by_openid(provider, open_id) diff --git a/api/tests/unit_tests/models/test_app_models.py b/api/tests/unit_tests/models/test_app_models.py index 6c619dcf98..329fe554ea 100644 --- a/api/tests/unit_tests/models/test_app_models.py +++ b/api/tests/unit_tests/models/test_app_models.py @@ -300,10 +300,8 @@ class TestAppModelConfig: created_by=str(uuid4()), ) - # Mock database query to return None - with patch("models.model.db.session.query", autospec=True) as mock_query: - mock_query.return_value.where.return_value.first.return_value = None - + # Mock database scalar to return None (no annotation setting found) + with patch("models.model.db.session.scalar", return_value=None): # Act result = config.annotation_reply_dict @@ -951,10 +949,8 @@ class TestSiteModel: def test_site_generate_code(self): """Test Site.generate_code static method.""" - # Mock database query to return 0 (no existing codes) - with patch("models.model.db.session.query", autospec=True) as mock_query: - mock_query.return_value.where.return_value.count.return_value = 0 - + # Mock database scalar to return 0 (no existing codes) + with patch("models.model.db.session.scalar", return_value=0): # Act code = Site.generate_code(8) diff --git a/api/tests/unit_tests/models/test_dataset_models.py b/api/tests/unit_tests/models/test_dataset_models.py index 9bb7c05a91..98dd07907a 100644 --- a/api/tests/unit_tests/models/test_dataset_models.py +++ b/api/tests/unit_tests/models/test_dataset_models.py @@ -25,6 +25,13 @@ from models.dataset import ( DocumentSegment, Embedding, ) +from models.enums import ( + DataSourceType, + DocumentCreatedFrom, + IndexingStatus, + ProcessRuleMode, + SegmentStatus, +) class TestDatasetModelValidation: @@ -40,14 +47,14 @@ class TestDatasetModelValidation: dataset = Dataset( tenant_id=tenant_id, name="Test Dataset", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=created_by, ) # Assert assert dataset.name == "Test Dataset" assert dataset.tenant_id == tenant_id - assert dataset.data_source_type == "upload_file" + assert dataset.data_source_type == DataSourceType.UPLOAD_FILE assert dataset.created_by == created_by # Note: Default values are set by database, not by model instantiation @@ -57,7 +64,7 @@ class TestDatasetModelValidation: dataset = Dataset( tenant_id=str(uuid4()), name="Test Dataset", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=str(uuid4()), description="Test description", indexing_technique="high_quality", @@ -77,14 +84,14 @@ class TestDatasetModelValidation: dataset_high_quality = Dataset( tenant_id=str(uuid4()), name="High Quality Dataset", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=str(uuid4()), indexing_technique="high_quality", ) dataset_economy = Dataset( tenant_id=str(uuid4()), name="Economy Dataset", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=str(uuid4()), indexing_technique="economy", ) @@ -101,14 +108,14 @@ class TestDatasetModelValidation: dataset_vendor = Dataset( tenant_id=str(uuid4()), name="Vendor Dataset", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=str(uuid4()), provider="vendor", ) dataset_external = Dataset( tenant_id=str(uuid4()), name="External Dataset", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=str(uuid4()), provider="external", ) @@ -126,7 +133,7 @@ class TestDatasetModelValidation: dataset = Dataset( tenant_id=str(uuid4()), name="Test Dataset", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=str(uuid4()), index_struct=json.dumps(index_struct_data), ) @@ -145,7 +152,7 @@ class TestDatasetModelValidation: dataset = Dataset( tenant_id=str(uuid4()), name="Test Dataset", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=str(uuid4()), ) @@ -161,7 +168,7 @@ class TestDatasetModelValidation: dataset = Dataset( tenant_id=str(uuid4()), name="Test Dataset", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=str(uuid4()), ) @@ -178,7 +185,7 @@ class TestDatasetModelValidation: dataset = Dataset( tenant_id=str(uuid4()), name="Test Dataset", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=str(uuid4()), ) @@ -218,10 +225,10 @@ class TestDocumentModelRelationships: tenant_id=tenant_id, dataset_id=dataset_id, position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="batch_001", name="test_document.pdf", - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=created_by, ) @@ -229,10 +236,10 @@ class TestDocumentModelRelationships: assert document.tenant_id == tenant_id assert document.dataset_id == dataset_id assert document.position == 1 - assert document.data_source_type == "upload_file" + assert document.data_source_type == DataSourceType.UPLOAD_FILE assert document.batch == "batch_001" assert document.name == "test_document.pdf" - assert document.created_from == "web" + assert document.created_from == DocumentCreatedFrom.WEB assert document.created_by == created_by # Note: Default values are set by database, not by model instantiation @@ -250,12 +257,12 @@ class TestDocumentModelRelationships: tenant_id=str(uuid4()), dataset_id=str(uuid4()), position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="batch_001", name="test.pdf", - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=str(uuid4()), - indexing_status="waiting", + indexing_status=IndexingStatus.WAITING, ) # Act @@ -271,12 +278,12 @@ class TestDocumentModelRelationships: tenant_id=str(uuid4()), dataset_id=str(uuid4()), position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="batch_001", name="test.pdf", - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=str(uuid4()), - indexing_status="parsing", + indexing_status=IndexingStatus.PARSING, is_paused=True, ) @@ -289,15 +296,20 @@ class TestDocumentModelRelationships: def test_document_display_status_indexing(self): """Test document display_status property for indexing state.""" # Arrange - for indexing_status in ["parsing", "cleaning", "splitting", "indexing"]: + for indexing_status in [ + IndexingStatus.PARSING, + IndexingStatus.CLEANING, + IndexingStatus.SPLITTING, + IndexingStatus.INDEXING, + ]: document = Document( tenant_id=str(uuid4()), dataset_id=str(uuid4()), position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="batch_001", name="test.pdf", - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=str(uuid4()), indexing_status=indexing_status, ) @@ -315,12 +327,12 @@ class TestDocumentModelRelationships: tenant_id=str(uuid4()), dataset_id=str(uuid4()), position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="batch_001", name="test.pdf", - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=str(uuid4()), - indexing_status="error", + indexing_status=IndexingStatus.ERROR, ) # Act @@ -336,12 +348,12 @@ class TestDocumentModelRelationships: tenant_id=str(uuid4()), dataset_id=str(uuid4()), position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="batch_001", name="test.pdf", - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=str(uuid4()), - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, ) @@ -359,12 +371,12 @@ class TestDocumentModelRelationships: tenant_id=str(uuid4()), dataset_id=str(uuid4()), position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="batch_001", name="test.pdf", - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=str(uuid4()), - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=False, archived=False, ) @@ -382,12 +394,12 @@ class TestDocumentModelRelationships: tenant_id=str(uuid4()), dataset_id=str(uuid4()), position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="batch_001", name="test.pdf", - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=str(uuid4()), - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, archived=True, ) @@ -405,10 +417,10 @@ class TestDocumentModelRelationships: tenant_id=str(uuid4()), dataset_id=str(uuid4()), position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="batch_001", name="test.pdf", - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=str(uuid4()), data_source_info=json.dumps(data_source_info), ) @@ -428,10 +440,10 @@ class TestDocumentModelRelationships: tenant_id=str(uuid4()), dataset_id=str(uuid4()), position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="batch_001", name="test.pdf", - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=str(uuid4()), ) @@ -448,10 +460,10 @@ class TestDocumentModelRelationships: tenant_id=str(uuid4()), dataset_id=str(uuid4()), position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="batch_001", name="test.pdf", - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=str(uuid4()), word_count=1000, ) @@ -471,10 +483,10 @@ class TestDocumentModelRelationships: tenant_id=str(uuid4()), dataset_id=str(uuid4()), position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="batch_001", name="test.pdf", - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=str(uuid4()), word_count=0, ) @@ -582,7 +594,7 @@ class TestDocumentSegmentIndexing: word_count=1, tokens=2, created_by=str(uuid4()), - status="waiting", + status=SegmentStatus.WAITING, ) segment_completed = DocumentSegment( tenant_id=str(uuid4()), @@ -593,12 +605,12 @@ class TestDocumentSegmentIndexing: word_count=1, tokens=2, created_by=str(uuid4()), - status="completed", + status=SegmentStatus.COMPLETED, ) # Assert - assert segment_waiting.status == "waiting" - assert segment_completed.status == "completed" + assert segment_waiting.status == SegmentStatus.WAITING + assert segment_completed.status == SegmentStatus.COMPLETED def test_document_segment_enabled_disabled_tracking(self): """Test document segment enabled/disabled state tracking.""" @@ -769,13 +781,13 @@ class TestDatasetProcessRule: # Act process_rule = DatasetProcessRule( dataset_id=dataset_id, - mode="automatic", + mode=ProcessRuleMode.AUTOMATIC, created_by=created_by, ) # Assert assert process_rule.dataset_id == dataset_id - assert process_rule.mode == "automatic" + assert process_rule.mode == ProcessRuleMode.AUTOMATIC assert process_rule.created_by == created_by def test_dataset_process_rule_modes(self): @@ -797,7 +809,7 @@ class TestDatasetProcessRule: } process_rule = DatasetProcessRule( dataset_id=str(uuid4()), - mode="custom", + mode=ProcessRuleMode.CUSTOM, created_by=str(uuid4()), rules=json.dumps(rules_data), ) @@ -817,7 +829,7 @@ class TestDatasetProcessRule: rules_data = {"test": "data"} process_rule = DatasetProcessRule( dataset_id=dataset_id, - mode="automatic", + mode=ProcessRuleMode.AUTOMATIC, created_by=str(uuid4()), rules=json.dumps(rules_data), ) @@ -827,7 +839,7 @@ class TestDatasetProcessRule: # Assert assert result["dataset_id"] == dataset_id - assert result["mode"] == "automatic" + assert result["mode"] == ProcessRuleMode.AUTOMATIC assert result["rules"] == rules_data def test_dataset_process_rule_automatic_rules(self): @@ -969,7 +981,7 @@ class TestModelIntegration: dataset = Dataset( tenant_id=tenant_id, name="Test Dataset", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=created_by, indexing_technique="high_quality", ) @@ -980,10 +992,10 @@ class TestModelIntegration: tenant_id=tenant_id, dataset_id=dataset_id, position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="batch_001", name="test.pdf", - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=created_by, word_count=100, ) @@ -999,7 +1011,7 @@ class TestModelIntegration: word_count=3, tokens=5, created_by=created_by, - status="completed", + status=SegmentStatus.COMPLETED, ) # Assert @@ -1009,7 +1021,7 @@ class TestModelIntegration: assert segment.document_id == document_id assert dataset.indexing_technique == "high_quality" assert document.word_count == 100 - assert segment.status == "completed" + assert segment.status == SegmentStatus.COMPLETED def test_document_to_dict_serialization(self): """Test document to_dict method for serialization.""" @@ -1022,13 +1034,13 @@ class TestModelIntegration: tenant_id=tenant_id, dataset_id=dataset_id, position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="batch_001", name="test.pdf", - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=created_by, word_count=100, - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, ) # Mock segment_count and hit_count @@ -1044,6 +1056,6 @@ class TestModelIntegration: assert result["dataset_id"] == dataset_id assert result["name"] == "test.pdf" assert result["word_count"] == 100 - assert result["indexing_status"] == "completed" + assert result["indexing_status"] == IndexingStatus.COMPLETED assert result["segment_count"] == 5 assert result["hit_count"] == 10 diff --git a/api/tests/unit_tests/models/test_enums_creator_user_role.py b/api/tests/unit_tests/models/test_enums_creator_user_role.py new file mode 100644 index 0000000000..6317166fdc --- /dev/null +++ b/api/tests/unit_tests/models/test_enums_creator_user_role.py @@ -0,0 +1,19 @@ +import pytest + +from models.enums import CreatorUserRole + + +def test_creator_user_role_missing_maps_hyphen_to_enum(): + # given an alias with hyphen + value = "end-user" + + # when converting to enum (invokes StrEnum._missing_ override) + role = CreatorUserRole(value) + + # then it should map to END_USER + assert role is CreatorUserRole.END_USER + + +def test_creator_user_role_missing_raises_for_unknown(): + with pytest.raises(ValueError): + CreatorUserRole("unknown") diff --git a/api/tests/unit_tests/models/test_provider_models.py b/api/tests/unit_tests/models/test_provider_models.py index ec84a61c8e..f628e54a4d 100644 --- a/api/tests/unit_tests/models/test_provider_models.py +++ b/api/tests/unit_tests/models/test_provider_models.py @@ -19,6 +19,7 @@ from uuid import uuid4 import pytest +from models.enums import CredentialSourceType, PaymentStatus from models.provider import ( LoadBalancingModelConfig, Provider, @@ -158,7 +159,7 @@ class TestProviderModel: # Assert assert provider.tenant_id == tenant_id assert provider.provider_name == provider_name - assert provider.provider_type == "custom" + assert provider.provider_type == ProviderType.CUSTOM assert provider.is_valid is False assert provider.quota_used == 0 @@ -172,10 +173,10 @@ class TestProviderModel: provider = Provider( tenant_id=tenant_id, provider_name="anthropic", - provider_type="system", + provider_type=ProviderType.SYSTEM, is_valid=True, credential_id=credential_id, - quota_type="paid", + quota_type=ProviderQuotaType.PAID, quota_limit=10000, quota_used=500, ) @@ -183,10 +184,10 @@ class TestProviderModel: # Assert assert provider.tenant_id == tenant_id assert provider.provider_name == "anthropic" - assert provider.provider_type == "system" + assert provider.provider_type == ProviderType.SYSTEM assert provider.is_valid is True assert provider.credential_id == credential_id - assert provider.quota_type == "paid" + assert provider.quota_type == ProviderQuotaType.PAID assert provider.quota_limit == 10000 assert provider.quota_used == 500 @@ -199,7 +200,7 @@ class TestProviderModel: ) # Assert - assert provider.provider_type == "custom" + assert provider.provider_type == ProviderType.CUSTOM assert provider.is_valid is False assert provider.quota_type == "" assert provider.quota_limit is None @@ -213,7 +214,7 @@ class TestProviderModel: provider = Provider( tenant_id=tenant_id, provider_name="openai", - provider_type="custom", + provider_type=ProviderType.CUSTOM, ) # Act @@ -253,7 +254,7 @@ class TestProviderModel: provider = Provider( tenant_id=str(uuid4()), provider_name="openai", - provider_type=ProviderType.SYSTEM.value, + provider_type=ProviderType.SYSTEM, is_valid=True, ) @@ -266,13 +267,13 @@ class TestProviderModel: provider = Provider( tenant_id=str(uuid4()), provider_name="openai", - quota_type="trial", + quota_type=ProviderQuotaType.TRIAL, quota_limit=1000, quota_used=250, ) # Assert - assert provider.quota_type == "trial" + assert provider.quota_type == ProviderQuotaType.TRIAL assert provider.quota_limit == 1000 assert provider.quota_used == 250 remaining = provider.quota_limit - provider.quota_used @@ -429,13 +430,13 @@ class TestTenantPreferredModelProvider: preferred = TenantPreferredModelProvider( tenant_id=tenant_id, provider_name="openai", - preferred_provider_type="custom", + preferred_provider_type=ProviderType.CUSTOM, ) # Assert assert preferred.tenant_id == tenant_id assert preferred.provider_name == "openai" - assert preferred.preferred_provider_type == "custom" + assert preferred.preferred_provider_type == ProviderType.CUSTOM def test_tenant_preferred_provider_system_type(self): """Test tenant preferred provider with system type.""" @@ -443,11 +444,11 @@ class TestTenantPreferredModelProvider: preferred = TenantPreferredModelProvider( tenant_id=str(uuid4()), provider_name="anthropic", - preferred_provider_type="system", + preferred_provider_type=ProviderType.SYSTEM, ) # Assert - assert preferred.preferred_provider_type == "system" + assert preferred.preferred_provider_type == ProviderType.SYSTEM class TestProviderOrder: @@ -470,7 +471,7 @@ class TestProviderOrder: quantity=1, currency=None, total_amount=None, - payment_status="wait_pay", + payment_status=PaymentStatus.WAIT_PAY, paid_at=None, pay_failed_at=None, refunded_at=None, @@ -481,7 +482,7 @@ class TestProviderOrder: assert order.provider_name == "openai" assert order.account_id == account_id assert order.payment_product_id == "prod_123" - assert order.payment_status == "wait_pay" + assert order.payment_status == PaymentStatus.WAIT_PAY assert order.quantity == 1 def test_provider_order_with_payment_details(self): @@ -502,7 +503,7 @@ class TestProviderOrder: quantity=5, currency="USD", total_amount=9999, - payment_status="paid", + payment_status=PaymentStatus.PAID, paid_at=paid_time, pay_failed_at=None, refunded_at=None, @@ -514,7 +515,7 @@ class TestProviderOrder: assert order.quantity == 5 assert order.currency == "USD" assert order.total_amount == 9999 - assert order.payment_status == "paid" + assert order.payment_status == PaymentStatus.PAID assert order.paid_at == paid_time def test_provider_order_payment_statuses(self): @@ -536,23 +537,23 @@ class TestProviderOrder: } # Act & Assert - Wait pay status - wait_order = ProviderOrder(**base_params, payment_status="wait_pay") - assert wait_order.payment_status == "wait_pay" + wait_order = ProviderOrder(**base_params, payment_status=PaymentStatus.WAIT_PAY) + assert wait_order.payment_status == PaymentStatus.WAIT_PAY # Act & Assert - Paid status - paid_order = ProviderOrder(**base_params, payment_status="paid") - assert paid_order.payment_status == "paid" + paid_order = ProviderOrder(**base_params, payment_status=PaymentStatus.PAID) + assert paid_order.payment_status == PaymentStatus.PAID # Act & Assert - Failed status failed_params = {**base_params, "pay_failed_at": datetime.now(UTC)} - failed_order = ProviderOrder(**failed_params, payment_status="failed") - assert failed_order.payment_status == "failed" + failed_order = ProviderOrder(**failed_params, payment_status=PaymentStatus.FAILED) + assert failed_order.payment_status == PaymentStatus.FAILED assert failed_order.pay_failed_at is not None # Act & Assert - Refunded status refunded_params = {**base_params, "refunded_at": datetime.now(UTC)} - refunded_order = ProviderOrder(**refunded_params, payment_status="refunded") - assert refunded_order.payment_status == "refunded" + refunded_order = ProviderOrder(**refunded_params, payment_status=PaymentStatus.REFUNDED) + assert refunded_order.payment_status == PaymentStatus.REFUNDED assert refunded_order.refunded_at is not None @@ -650,13 +651,13 @@ class TestLoadBalancingModelConfig: name="Secondary API Key", encrypted_config='{"api_key": "encrypted_value"}', credential_id=credential_id, - credential_source_type="custom", + credential_source_type=CredentialSourceType.CUSTOM_MODEL, ) # Assert assert config.encrypted_config == '{"api_key": "encrypted_value"}' assert config.credential_id == credential_id - assert config.credential_source_type == "custom" + assert config.credential_source_type == CredentialSourceType.CUSTOM_MODEL def test_load_balancing_config_disabled(self): """Test disabled load balancing config.""" diff --git a/api/tests/unit_tests/services/retention/conversation/test_messages_clean_service.py b/api/tests/unit_tests/services/retention/conversation/test_messages_clean_service.py index a34defeba9..f9d901fca2 100644 --- a/api/tests/unit_tests/services/retention/conversation/test_messages_clean_service.py +++ b/api/tests/unit_tests/services/retention/conversation/test_messages_clean_service.py @@ -1,5 +1,4 @@ import datetime -import os from unittest.mock import MagicMock, patch import pytest @@ -282,7 +281,6 @@ class TestMessagesCleanService: MessagesCleanService._batch_delete_message_relations(mock_db_session, ["msg1", "msg2"]) assert mock_db_session.execute.call_count == 8 # 8 tables to clean up - @patch.dict(os.environ, {"SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_MAX_INTERVAL": "500"}) def test_clean_messages_interval_from_env(self, mock_db_session, mock_policy): service = MessagesCleanService( policy=mock_policy, @@ -301,9 +299,13 @@ class TestMessagesCleanService: mock_db_session.execute.side_effect = mock_returns mock_policy.filter_message_ids.return_value = ["msg1"] - with patch("services.retention.conversation.messages_clean_service.time.sleep") as mock_sleep: - with patch("services.retention.conversation.messages_clean_service.random.uniform") as mock_uniform: - mock_uniform.return_value = 300.0 - service.run() - mock_uniform.assert_called_with(0, 500) - mock_sleep.assert_called_with(0.3) + with patch( + "services.retention.conversation.messages_clean_service.dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_MAX_INTERVAL", + 500, + ): + with patch("services.retention.conversation.messages_clean_service.time.sleep") as mock_sleep: + with patch("services.retention.conversation.messages_clean_service.random.uniform") as mock_uniform: + mock_uniform.return_value = 300.0 + service.run() + mock_uniform.assert_called_with(0, 500) + mock_sleep.assert_called_with(0.3) diff --git a/api/tests/unit_tests/services/retention/workflow_run/test_clear_free_plan_expired_workflow_run_logs.py b/api/tests/unit_tests/services/retention/workflow_run/test_clear_free_plan_expired_workflow_run_logs.py index 0013cde79e..7d30645d38 100644 --- a/api/tests/unit_tests/services/retention/workflow_run/test_clear_free_plan_expired_workflow_run_logs.py +++ b/api/tests/unit_tests/services/retention/workflow_run/test_clear_free_plan_expired_workflow_run_logs.py @@ -80,7 +80,13 @@ class TestWorkflowRunCleanupInit: cfg.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD = 0 cfg.BILLING_ENABLED = False with pytest.raises(ValueError): - WorkflowRunCleanup(days=30, batch_size=10, start_from=dt, end_before=dt, workflow_run_repo=mock_repo) + WorkflowRunCleanup( + days=30, + batch_size=10, + start_from=dt, + end_before=dt, + workflow_run_repo=mock_repo, + ) def test_zero_batch_size_raises(self, mock_repo): with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.dify_config") as cfg: @@ -102,10 +108,24 @@ class TestWorkflowRunCleanupInit: cfg.BILLING_ENABLED = False start = datetime.datetime(2024, 1, 1) end = datetime.datetime(2024, 6, 1) - c = WorkflowRunCleanup(days=30, batch_size=5, start_from=start, end_before=end, workflow_run_repo=mock_repo) + c = WorkflowRunCleanup( + days=30, + batch_size=5, + start_from=start, + end_before=end, + workflow_run_repo=mock_repo, + ) assert c.window_start == start assert c.window_end == end + def test_default_task_label_is_custom(self, mock_repo): + with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.dify_config") as cfg: + cfg.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD = 0 + cfg.BILLING_ENABLED = False + c = WorkflowRunCleanup(days=30, batch_size=10, workflow_run_repo=mock_repo) + + assert c._metrics._base_attributes["task_label"] == "custom" + # --------------------------------------------------------------------------- # _empty_related_counts / _format_related_counts @@ -393,7 +413,12 @@ class TestRunDryRunMode: with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.dify_config") as cfg: cfg.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD = 0 cfg.BILLING_ENABLED = False - return WorkflowRunCleanup(days=30, batch_size=10, workflow_run_repo=mock_repo, dry_run=True) + return WorkflowRunCleanup( + days=30, + batch_size=10, + workflow_run_repo=mock_repo, + dry_run=True, + ) def test_dry_run_no_delete_called(self, mock_repo): run = make_run("t1") diff --git a/api/tests/unit_tests/services/test_clear_free_plan_expired_workflow_run_logs.py b/api/tests/unit_tests/services/test_clear_free_plan_expired_workflow_run_logs.py index 50826d6798..6bf78d3411 100644 --- a/api/tests/unit_tests/services/test_clear_free_plan_expired_workflow_run_logs.py +++ b/api/tests/unit_tests/services/test_clear_free_plan_expired_workflow_run_logs.py @@ -265,6 +265,61 @@ def test_run_exits_on_empty_batch(monkeypatch: pytest.MonkeyPatch) -> None: cleanup.run() +def test_run_records_metrics_on_success(monkeypatch: pytest.MonkeyPatch) -> None: + cutoff = datetime.datetime.now() + repo = FakeRepo( + batches=[[FakeRun("run-free", "t_free", cutoff)]], + delete_result={ + "runs": 0, + "node_executions": 2, + "offloads": 1, + "app_logs": 3, + "trigger_logs": 4, + "pauses": 5, + "pause_reasons": 6, + }, + ) + cleanup = create_cleanup(monkeypatch, repo=repo, days=30, batch_size=10) + monkeypatch.setattr(cleanup_module.dify_config, "BILLING_ENABLED", False) + + batch_calls: list[dict[str, object]] = [] + completion_calls: list[dict[str, object]] = [] + monkeypatch.setattr(cleanup._metrics, "record_batch", lambda **kwargs: batch_calls.append(kwargs)) + monkeypatch.setattr(cleanup._metrics, "record_completion", lambda **kwargs: completion_calls.append(kwargs)) + + cleanup.run() + + assert len(batch_calls) == 1 + assert batch_calls[0]["batch_rows"] == 1 + assert batch_calls[0]["targeted_runs"] == 1 + assert batch_calls[0]["deleted_runs"] == 1 + assert batch_calls[0]["related_action"] == "deleted" + assert len(completion_calls) == 1 + assert completion_calls[0]["status"] == "success" + + +def test_run_records_failed_metrics(monkeypatch: pytest.MonkeyPatch) -> None: + class FailingRepo(FakeRepo): + def delete_runs_with_related( + self, runs: list[FakeRun], delete_node_executions=None, delete_trigger_logs=None + ) -> dict[str, int]: + raise RuntimeError("delete failed") + + cutoff = datetime.datetime.now() + repo = FailingRepo(batches=[[FakeRun("run-free", "t_free", cutoff)]]) + cleanup = create_cleanup(monkeypatch, repo=repo, days=30, batch_size=10) + monkeypatch.setattr(cleanup_module.dify_config, "BILLING_ENABLED", False) + + completion_calls: list[dict[str, object]] = [] + monkeypatch.setattr(cleanup._metrics, "record_completion", lambda **kwargs: completion_calls.append(kwargs)) + + with pytest.raises(RuntimeError, match="delete failed"): + cleanup.run() + + assert len(completion_calls) == 1 + assert completion_calls[0]["status"] == "failed" + + def test_run_dry_run_skips_deletions(monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture[str]) -> None: cutoff = datetime.datetime.now() repo = FakeRepo( diff --git a/api/tests/unit_tests/services/test_messages_clean_service.py b/api/tests/unit_tests/services/test_messages_clean_service.py index 4449b442d6..f3efc4463e 100644 --- a/api/tests/unit_tests/services/test_messages_clean_service.py +++ b/api/tests/unit_tests/services/test_messages_clean_service.py @@ -540,6 +540,20 @@ class TestMessagesCleanServiceFromTimeRange: assert service._batch_size == 1000 # default assert service._dry_run is False # default + def test_explicit_task_label(self): + start_from = datetime.datetime(2024, 1, 1) + end_before = datetime.datetime(2024, 1, 2) + policy = BillingDisabledPolicy() + + service = MessagesCleanService.from_time_range( + policy=policy, + start_from=start_from, + end_before=end_before, + task_label="60to30", + ) + + assert service._metrics._base_attributes["task_label"] == "60to30" + class TestMessagesCleanServiceFromDays: """Unit tests for MessagesCleanService.from_days factory method.""" @@ -619,3 +633,54 @@ class TestMessagesCleanServiceFromDays: assert service._end_before == expected_end_before assert service._batch_size == 1000 # default assert service._dry_run is False # default + assert service._metrics._base_attributes["task_label"] == "custom" + + +class TestMessagesCleanServiceRun: + """Unit tests for MessagesCleanService.run instrumentation behavior.""" + + def test_run_records_completion_metrics_on_success(self): + # Arrange + service = MessagesCleanService( + policy=BillingDisabledPolicy(), + start_from=datetime.datetime(2024, 1, 1), + end_before=datetime.datetime(2024, 1, 2), + batch_size=100, + dry_run=False, + ) + expected_stats = { + "batches": 1, + "total_messages": 10, + "filtered_messages": 5, + "total_deleted": 5, + } + service._clean_messages_by_time_range = MagicMock(return_value=expected_stats) # type: ignore[method-assign] + completion_calls: list[dict[str, object]] = [] + service._metrics.record_completion = lambda **kwargs: completion_calls.append(kwargs) # type: ignore[method-assign] + + # Act + result = service.run() + + # Assert + assert result == expected_stats + assert len(completion_calls) == 1 + assert completion_calls[0]["status"] == "success" + + def test_run_records_completion_metrics_on_failure(self): + # Arrange + service = MessagesCleanService( + policy=BillingDisabledPolicy(), + start_from=datetime.datetime(2024, 1, 1), + end_before=datetime.datetime(2024, 1, 2), + batch_size=100, + dry_run=False, + ) + service._clean_messages_by_time_range = MagicMock(side_effect=RuntimeError("clean failed")) # type: ignore[method-assign] + completion_calls: list[dict[str, object]] = [] + service._metrics.record_completion = lambda **kwargs: completion_calls.append(kwargs) # type: ignore[method-assign] + + # Act & Assert + with pytest.raises(RuntimeError, match="clean failed"): + service.run() + assert len(completion_calls) == 1 + assert completion_calls[0]["status"] == "failed" diff --git a/api/tests/unit_tests/services/test_summary_index_service.py b/api/tests/unit_tests/services/test_summary_index_service.py index c7e1fed21f..be64e431ba 100644 --- a/api/tests/unit_tests/services/test_summary_index_service.py +++ b/api/tests/unit_tests/services/test_summary_index_service.py @@ -11,6 +11,7 @@ from unittest.mock import MagicMock import pytest import services.summary_index_service as summary_module +from models.enums import SegmentStatus, SummaryStatus from services.summary_index_service import SummaryIndexService @@ -42,7 +43,7 @@ def _segment(*, has_document: bool = True) -> MagicMock: segment.dataset_id = "dataset-1" segment.content = "hello world" segment.enabled = True - segment.status = "completed" + segment.status = SegmentStatus.COMPLETED segment.position = 1 if has_document: doc = MagicMock(name="document") @@ -64,7 +65,7 @@ def _summary_record(*, summary_content: str = "summary", node_id: str | None = N record.summary_index_node_id = node_id record.summary_index_node_hash = None record.tokens = None - record.status = "generating" + record.status = SummaryStatus.GENERATING record.error = None record.enabled = True record.created_at = datetime(2024, 1, 1, tzinfo=UTC) @@ -133,10 +134,10 @@ def test_create_summary_record_updates_existing_and_reenables(monkeypatch: pytes segment = _segment() dataset = _dataset() - result = SummaryIndexService.create_summary_record(segment, dataset, "new", status="generating") + result = SummaryIndexService.create_summary_record(segment, dataset, "new", status=SummaryStatus.GENERATING) assert result is existing assert existing.summary_content == "new" - assert existing.status == "generating" + assert existing.status == SummaryStatus.GENERATING assert existing.enabled is True assert existing.disabled_at is None assert existing.disabled_by is None @@ -155,7 +156,7 @@ def test_create_summary_record_creates_new(monkeypatch: pytest.MonkeyPatch) -> N create_session_mock = MagicMock(return_value=_SessionContext(session)) monkeypatch.setattr(summary_module, "session_factory", SimpleNamespace(create_session=create_session_mock)) - record = SummaryIndexService.create_summary_record(_segment(), _dataset(), "new", status="generating") + record = SummaryIndexService.create_summary_record(_segment(), _dataset(), "new", status=SummaryStatus.GENERATING) assert record.dataset_id == "dataset-1" assert record.chunk_id == "seg-1" assert record.summary_content == "new" @@ -204,7 +205,7 @@ def test_vectorize_summary_retries_connection_errors_then_succeeds(monkeypatch: assert vector_instance.add_texts.call_count == 2 summary_module.time.sleep.assert_called_once() # type: ignore[attr-defined] session.flush.assert_called_once() - assert summary.status == "completed" + assert summary.status == SummaryStatus.COMPLETED assert summary.summary_index_node_id == "uuid-1" assert summary.summary_index_node_hash == "hash-1" assert summary.tokens == 5 @@ -245,7 +246,7 @@ def test_vectorize_summary_without_session_creates_record_when_missing(monkeypat create_session_mock.assert_called() session.add.assert_called() session.commit.assert_called_once() - assert summary.status == "completed" + assert summary.status == SummaryStatus.COMPLETED assert summary.summary_index_node_id == "old-node" # reused @@ -275,7 +276,7 @@ def test_vectorize_summary_final_failure_updates_error_status(monkeypatch: pytes with pytest.raises(RuntimeError, match="boom"): SummaryIndexService.vectorize_summary(summary, segment, dataset, session=None) - assert summary.status == "error" + assert summary.status == SummaryStatus.ERROR assert "Vectorization failed" in (summary.error or "") error_session.commit.assert_called_once() @@ -310,7 +311,7 @@ def test_batch_create_summary_records_creates_and_updates(monkeypatch: pytest.Mo SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), ) - SummaryIndexService.batch_create_summary_records([s1, s2], dataset, status="not_started") + SummaryIndexService.batch_create_summary_records([s1, s2], dataset, status=SummaryStatus.NOT_STARTED) session.commit.assert_called_once() assert existing.enabled is True @@ -332,7 +333,7 @@ def test_update_summary_record_error_updates_when_exists(monkeypatch: pytest.Mon ) SummaryIndexService.update_summary_record_error(segment, dataset, "err") - assert record.status == "error" + assert record.status == SummaryStatus.ERROR assert record.error == "err" session.commit.assert_called_once() @@ -387,7 +388,7 @@ def test_generate_and_vectorize_summary_vectorize_failure_sets_error(monkeypatch with pytest.raises(RuntimeError, match="boom"): SummaryIndexService.generate_and_vectorize_summary(segment, dataset, {"enable": True}) - assert record.status == "error" + assert record.status == SummaryStatus.ERROR # Outer exception handler overwrites the error with the raw exception message. assert record.error == "boom" @@ -614,7 +615,7 @@ def test_generate_and_vectorize_summary_creates_missing_record_and_logs_usage(mo monkeypatch.setattr(summary_module, "logger", logger_mock) result = SummaryIndexService.generate_and_vectorize_summary(segment, dataset, {"enable": True}) - assert result.status in {"generating", "completed"} + assert result.status in {SummaryStatus.GENERATING, SummaryStatus.COMPLETED} logger_mock.info.assert_called() @@ -787,7 +788,7 @@ def test_enable_summaries_for_segments_revectorizes_and_enables(monkeypatch: pyt segment = _segment() segment.id = summary.chunk_id segment.enabled = True - segment.status = "completed" + segment.status = SegmentStatus.COMPLETED session = MagicMock() summary_query = MagicMock() @@ -850,11 +851,11 @@ def test_enable_summaries_for_segments_skips_segment_or_content_and_handles_vect bad_segment = _segment() bad_segment.enabled = False - bad_segment.status = "completed" + bad_segment.status = SegmentStatus.COMPLETED good_segment = _segment() good_segment.enabled = True - good_segment.status = "completed" + good_segment.status = SegmentStatus.COMPLETED session = MagicMock() summary_query = MagicMock() @@ -1084,7 +1085,7 @@ def test_update_summary_for_segment_existing_vectorize_failure_returns_error_rec out = SummaryIndexService.update_summary_for_segment(segment, dataset, "new") assert out is record - assert out.status == "error" + assert out.status == SummaryStatus.ERROR assert "Vectorization failed" in (out.error or "") @@ -1133,7 +1134,7 @@ def test_update_summary_for_segment_outer_exception_sets_error_and_reraises(monk with pytest.raises(RuntimeError, match="flush boom"): SummaryIndexService.update_summary_for_segment(segment, dataset, "new") - assert record.status == "error" + assert record.status == SummaryStatus.ERROR assert record.error == "flush boom" session.commit.assert_called() @@ -1222,7 +1223,7 @@ def test_get_documents_summary_index_status_no_pending_sets_none(monkeypatch: py monkeypatch.setattr( SummaryIndexService, "get_segments_summaries", - MagicMock(return_value={"seg-1": SimpleNamespace(status="completed")}), + MagicMock(return_value={"seg-1": SimpleNamespace(status=SummaryStatus.COMPLETED)}), ) result = SummaryIndexService.get_documents_summary_index_status(["doc-1"], "dataset-1", "tenant-1") assert result["doc-1"] is None @@ -1254,7 +1255,7 @@ def test_update_summary_for_segment_creates_new_and_vectorize_fails_returns_erro monkeypatch.setattr(SummaryIndexService, "vectorize_summary", vectorize_mock) out = SummaryIndexService.update_summary_for_segment(segment, dataset, "new") - assert out.status == "error" + assert out.status == SummaryStatus.ERROR assert "Vectorization failed" in (out.error or "") @@ -1276,7 +1277,7 @@ def test_get_document_summary_index_status_and_documents_status(monkeypatch: pyt monkeypatch.setattr( SummaryIndexService, "get_segments_summaries", - MagicMock(return_value={"seg-1": SimpleNamespace(status="generating")}), + MagicMock(return_value={"seg-1": SimpleNamespace(status=SummaryStatus.GENERATING)}), ) assert SummaryIndexService.get_document_summary_index_status("doc-1", "dataset-1", "tenant-1") == "SUMMARIZING" @@ -1294,7 +1295,7 @@ def test_get_document_summary_index_status_and_documents_status(monkeypatch: pyt monkeypatch.setattr( SummaryIndexService, "get_segments_summaries", - MagicMock(return_value={"seg-1": SimpleNamespace(status="not_started")}), + MagicMock(return_value={"seg-1": SimpleNamespace(status=SummaryStatus.NOT_STARTED)}), ) result = SummaryIndexService.get_documents_summary_index_status(["doc-1", "doc-2"], "dataset-1", "tenant-1") assert result["doc-1"] == "SUMMARIZING" @@ -1311,7 +1312,7 @@ def test_get_document_summary_status_detail_counts_and_previews(monkeypatch: pyt summary1 = _summary_record(summary_content="x" * 150, node_id="n1") summary1.chunk_id = "seg-1" - summary1.status = "completed" + summary1.status = SummaryStatus.COMPLETED summary1.error = None summary1.created_at = datetime(2024, 1, 1, tzinfo=UTC) summary1.updated_at = datetime(2024, 1, 2, tzinfo=UTC) diff --git a/api/tests/unit_tests/tasks/test_clean_dataset_task.py b/api/tests/unit_tests/tasks/test_clean_dataset_task.py index df33f20c9b..74ba7f9c34 100644 --- a/api/tests/unit_tests/tasks/test_clean_dataset_task.py +++ b/api/tests/unit_tests/tasks/test_clean_dataset_task.py @@ -16,6 +16,7 @@ from unittest.mock import MagicMock, patch import pytest +from models.enums import DataSourceType from tasks.clean_dataset_task import clean_dataset_task # ============================================================================ @@ -116,7 +117,7 @@ def mock_document(): doc.id = str(uuid.uuid4()) doc.tenant_id = str(uuid.uuid4()) doc.dataset_id = str(uuid.uuid4()) - doc.data_source_type = "upload_file" + doc.data_source_type = DataSourceType.UPLOAD_FILE doc.data_source_info = '{"upload_file_id": "test-file-id"}' doc.data_source_info_dict = {"upload_file_id": "test-file-id"} return doc diff --git a/api/tests/unit_tests/tasks/test_dataset_indexing_task.py b/api/tests/unit_tests/tasks/test_dataset_indexing_task.py index 67e0a8efaf..8a721124d6 100644 --- a/api/tests/unit_tests/tasks/test_dataset_indexing_task.py +++ b/api/tests/unit_tests/tasks/test_dataset_indexing_task.py @@ -19,6 +19,7 @@ from core.rag.pipeline.queue import TenantIsolatedTaskQueue from enums.cloud_plan import CloudPlan from extensions.ext_redis import redis_client from models.dataset import Dataset, Document +from models.enums import IndexingStatus from services.document_indexing_proxy.document_indexing_task_proxy import DocumentIndexingTaskProxy from tasks.document_indexing_task import ( _document_indexing, @@ -424,7 +425,7 @@ class TestBatchProcessing: # Assert - All documents should be set to 'parsing' status for doc in mock_documents: - assert doc.indexing_status == "parsing" + assert doc.indexing_status == IndexingStatus.PARSING assert doc.processing_started_at is not None # IndexingRunner should be called with all documents @@ -573,7 +574,7 @@ class TestProgressTracking: # Assert - Status should be 'parsing' for doc in mock_documents: - assert doc.indexing_status == "parsing" + assert doc.indexing_status == IndexingStatus.PARSING assert doc.processing_started_at is not None # Verify commit was called to persist status @@ -1158,7 +1159,7 @@ class TestAdvancedScenarios: # Assert # All documents should be set to parsing (no limit errors) for doc in mock_documents: - assert doc.indexing_status == "parsing" + assert doc.indexing_status == IndexingStatus.PARSING # IndexingRunner should be called with all documents mock_indexing_runner.run.assert_called_once() @@ -1377,7 +1378,7 @@ class TestPerformanceScenarios: # Assert for doc in mock_documents: - assert doc.indexing_status == "parsing" + assert doc.indexing_status == IndexingStatus.PARSING mock_indexing_runner.run.assert_called_once() call_args = mock_indexing_runner.run.call_args[0][0] diff --git a/api/uv.lock b/api/uv.lock index 4bb86aa762..ddb70f6b54 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -1533,7 +1533,7 @@ wheels = [ [[package]] name = "dify-api" -version = "1.13.0" +version = "1.13.1" source = { virtual = "." } dependencies = [ { name = "aliyun-log-python-sdk" }, @@ -5154,11 +5154,11 @@ wheels = [ [[package]] name = "pyasn1" -version = "0.6.2" +version = "0.6.3" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/fe/b6/6e630dff89739fcd427e3f72b3d905ce0acb85a45d4ec3e2678718a3487f/pyasn1-0.6.2.tar.gz", hash = "sha256:9b59a2b25ba7e4f8197db7686c09fb33e658b98339fadb826e9512629017833b", size = 146586, upload-time = "2026-01-16T18:04:18.534Z" } +sdist = { url = "https://files.pythonhosted.org/packages/5c/5f/6583902b6f79b399c9c40674ac384fd9cd77805f9e6205075f828ef11fb2/pyasn1-0.6.3.tar.gz", hash = "sha256:697a8ecd6d98891189184ca1fa05d1bb00e2f84b5977c481452050549c8a72cf", size = 148685, upload-time = "2026-03-17T01:06:53.382Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/44/b5/a96872e5184f354da9c84ae119971a0a4c221fe9b27a4d94bd43f2596727/pyasn1-0.6.2-py3-none-any.whl", hash = "sha256:1eb26d860996a18e9b6ed05e7aae0e9fc21619fcee6af91cca9bad4fbea224bf", size = 83371, upload-time = "2026-01-16T18:04:17.174Z" }, + { url = "https://files.pythonhosted.org/packages/5d/a0/7d793dce3fa811fe047d6ae2431c672364b462850c6235ae306c0efd025f/pyasn1-0.6.3-py3-none-any.whl", hash = "sha256:a80184d120f0864a52a073acc6fc642847d0be408e7c7252f31390c0f4eadcde", size = 83997, upload-time = "2026-03-17T01:06:52.036Z" }, ] [[package]] diff --git a/docker/docker-compose-template.yaml b/docker/docker-compose-template.yaml index 1804592c0e..939f23136a 100644 --- a/docker/docker-compose-template.yaml +++ b/docker/docker-compose-template.yaml @@ -21,7 +21,7 @@ services: # API service api: - image: langgenius/dify-api:1.13.0 + image: langgenius/dify-api:1.13.1 restart: always environment: # Use the shared environment variables. @@ -63,7 +63,7 @@ services: # worker service # The Celery worker for processing all queues (dataset, workflow, mail, etc.) worker: - image: langgenius/dify-api:1.13.0 + image: langgenius/dify-api:1.13.1 restart: always environment: # Use the shared environment variables. @@ -102,7 +102,7 @@ services: # worker_beat service # Celery beat for scheduling periodic tasks. worker_beat: - image: langgenius/dify-api:1.13.0 + image: langgenius/dify-api:1.13.1 restart: always environment: # Use the shared environment variables. @@ -132,7 +132,7 @@ services: # Frontend web application. web: - image: langgenius/dify-web:1.13.0 + image: langgenius/dify-web:1.13.1 restart: always environment: CONSOLE_API_URL: ${CONSOLE_API_URL:-} diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index d14f0503e7..b6b6f299cf 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -728,7 +728,7 @@ services: # API service api: - image: langgenius/dify-api:1.13.0 + image: langgenius/dify-api:1.13.1 restart: always environment: # Use the shared environment variables. @@ -770,7 +770,7 @@ services: # worker service # The Celery worker for processing all queues (dataset, workflow, mail, etc.) worker: - image: langgenius/dify-api:1.13.0 + image: langgenius/dify-api:1.13.1 restart: always environment: # Use the shared environment variables. @@ -809,7 +809,7 @@ services: # worker_beat service # Celery beat for scheduling periodic tasks. worker_beat: - image: langgenius/dify-api:1.13.0 + image: langgenius/dify-api:1.13.1 restart: always environment: # Use the shared environment variables. @@ -839,7 +839,7 @@ services: # Frontend web application. web: - image: langgenius/dify-web:1.13.0 + image: langgenius/dify-web:1.13.1 restart: always environment: CONSOLE_API_URL: ${CONSOLE_API_URL:-} diff --git a/web/__tests__/apps/app-card-operations-flow.test.tsx b/web/__tests__/apps/app-card-operations-flow.test.tsx index c3e8410955..c5766878a1 100644 --- a/web/__tests__/apps/app-card-operations-flow.test.tsx +++ b/web/__tests__/apps/app-card-operations-flow.test.tsx @@ -29,7 +29,7 @@ const mockOnPlanInfoChanged = vi.fn() const mockDeleteAppMutation = vi.fn().mockResolvedValue(undefined) let mockDeleteMutationPending = false -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => ({ push: mockRouterPush, }), @@ -57,7 +57,7 @@ vi.mock('@headlessui/react', async () => { } }) -vi.mock('next/dynamic', () => ({ +vi.mock('@/next/dynamic', () => ({ default: (loader: () => Promise<{ default: React.ComponentType }>) => { let Component: React.ComponentType> | null = null loader().then((mod) => { diff --git a/web/__tests__/apps/app-list-browsing-flow.test.tsx b/web/__tests__/apps/app-list-browsing-flow.test.tsx index 079f667dbc..1be7e56086 100644 --- a/web/__tests__/apps/app-list-browsing-flow.test.tsx +++ b/web/__tests__/apps/app-list-browsing-flow.test.tsx @@ -38,7 +38,7 @@ let mockShowTagManagementModal = false const mockRouterPush = vi.fn() const mockRouterReplace = vi.fn() -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => ({ push: mockRouterPush, replace: mockRouterReplace, @@ -46,7 +46,7 @@ vi.mock('next/navigation', () => ({ useSearchParams: () => new URLSearchParams(), })) -vi.mock('next/dynamic', () => ({ +vi.mock('@/next/dynamic', () => ({ default: (_loader: () => Promise<{ default: React.ComponentType }>) => { const LazyComponent = (props: Record) => { return
diff --git a/web/__tests__/apps/create-app-flow.test.tsx b/web/__tests__/apps/create-app-flow.test.tsx index 4ac9824ddd..bc1f7a3a06 100644 --- a/web/__tests__/apps/create-app-flow.test.tsx +++ b/web/__tests__/apps/create-app-flow.test.tsx @@ -35,7 +35,7 @@ const mockRouterPush = vi.fn() const mockRouterReplace = vi.fn() const mockOnPlanInfoChanged = vi.fn() -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => ({ push: mockRouterPush, replace: mockRouterReplace, @@ -117,7 +117,7 @@ vi.mock('ahooks', async () => { }) // Mock dynamically loaded modals with test stubs -vi.mock('next/dynamic', () => ({ +vi.mock('@/next/dynamic', () => ({ default: (loader: () => Promise<{ default: React.ComponentType }>) => { let Component: React.ComponentType> | null = null loader().then((mod) => { diff --git a/web/__tests__/billing/billing-integration.test.tsx b/web/__tests__/billing/billing-integration.test.tsx index 4891760df4..64d358cbe6 100644 --- a/web/__tests__/billing/billing-integration.test.tsx +++ b/web/__tests__/billing/billing-integration.test.tsx @@ -64,7 +64,7 @@ vi.mock('@/service/use-education', () => ({ // ─── Navigation mocks ─────────────────────────────────────────────────────── const mockRouterPush = vi.fn() -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => ({ push: mockRouterPush }), usePathname: () => '/billing', useSearchParams: () => new URLSearchParams(), diff --git a/web/__tests__/billing/cloud-plan-payment-flow.test.tsx b/web/__tests__/billing/cloud-plan-payment-flow.test.tsx index e01d9250fd..84653cd68c 100644 --- a/web/__tests__/billing/cloud-plan-payment-flow.test.tsx +++ b/web/__tests__/billing/cloud-plan-payment-flow.test.tsx @@ -11,6 +11,7 @@ import type { BasicPlan } from '@/app/components/billing/type' import { cleanup, render, screen, waitFor } from '@testing-library/react' import userEvent from '@testing-library/user-event' import * as React from 'react' +import { toast, ToastHost } from '@/app/components/base/ui/toast' import { ALL_PLANS } from '@/app/components/billing/config' import { PlanRange } from '@/app/components/billing/pricing/plan-switcher/plan-range-switcher' import CloudPlanItem from '@/app/components/billing/pricing/plans/cloud-plan-item' @@ -21,7 +22,6 @@ let mockAppCtx: Record = {} const mockFetchSubscriptionUrls = vi.fn() const mockInvoices = vi.fn() const mockOpenAsyncWindow = vi.fn() -const mockToastNotify = vi.fn() // ─── Context mocks ─────────────────────────────────────────────────────────── vi.mock('@/context/app-context', () => ({ @@ -49,12 +49,8 @@ vi.mock('@/hooks/use-async-window-open', () => ({ useAsyncWindowOpen: () => mockOpenAsyncWindow, })) -vi.mock('@/app/components/base/toast', () => ({ - default: { notify: (args: unknown) => mockToastNotify(args) }, -})) - // ─── Navigation mocks ─────────────────────────────────────────────────────── -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => ({ push: vi.fn() }), usePathname: () => '/billing', useSearchParams: () => new URLSearchParams(), @@ -82,12 +78,15 @@ const renderCloudPlanItem = ({ canPay = true, }: RenderCloudPlanItemOptions = {}) => { return render( - , + <> + + + , ) } @@ -96,6 +95,7 @@ describe('Cloud Plan Payment Flow', () => { beforeEach(() => { vi.clearAllMocks() cleanup() + toast.close() setupAppContext() mockFetchSubscriptionUrls.mockResolvedValue({ url: 'https://pay.example.com/checkout' }) mockInvoices.mockResolvedValue({ url: 'https://billing.example.com/invoices' }) @@ -283,11 +283,7 @@ describe('Cloud Plan Payment Flow', () => { await user.click(button) await waitFor(() => { - expect(mockToastNotify).toHaveBeenCalledWith( - expect.objectContaining({ - type: 'error', - }), - ) + expect(screen.getByText('billing.buyPermissionDeniedTip')).toBeInTheDocument() }) // Should not proceed with payment expect(mockFetchSubscriptionUrls).not.toHaveBeenCalled() diff --git a/web/__tests__/billing/education-verification-flow.test.tsx b/web/__tests__/billing/education-verification-flow.test.tsx index 8c35cd9a8c..707f1d690a 100644 --- a/web/__tests__/billing/education-verification-flow.test.tsx +++ b/web/__tests__/billing/education-verification-flow.test.tsx @@ -63,7 +63,7 @@ vi.mock('@/service/use-billing', () => ({ })) // ─── Navigation mocks ─────────────────────────────────────────────────────── -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => ({ push: mockRouterPush }), usePathname: () => '/billing', useSearchParams: () => new URLSearchParams(), diff --git a/web/__tests__/billing/partner-stack-flow.test.tsx b/web/__tests__/billing/partner-stack-flow.test.tsx index 4f265478cd..fe642ac70b 100644 --- a/web/__tests__/billing/partner-stack-flow.test.tsx +++ b/web/__tests__/billing/partner-stack-flow.test.tsx @@ -18,7 +18,7 @@ let mockSearchParams = new URLSearchParams() const mockMutateAsync = vi.fn() // ─── Module mocks ──────────────────────────────────────────────────────────── -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useSearchParams: () => mockSearchParams, useRouter: () => ({ push: vi.fn() }), usePathname: () => '/', diff --git a/web/__tests__/billing/pricing-modal-flow.test.tsx b/web/__tests__/billing/pricing-modal-flow.test.tsx index 6b8fb57f83..2ec7298618 100644 --- a/web/__tests__/billing/pricing-modal-flow.test.tsx +++ b/web/__tests__/billing/pricing-modal-flow.test.tsx @@ -51,7 +51,7 @@ vi.mock('@/hooks/use-async-window-open', () => ({ })) // ─── Navigation mocks ─────────────────────────────────────────────────────── -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => ({ push: vi.fn() }), usePathname: () => '/billing', useSearchParams: () => new URLSearchParams(), @@ -295,24 +295,7 @@ describe('Pricing Modal Flow', () => { }) }) - // ─── 6. Close Handling ─────────────────────────────────────────────────── - describe('Close handling', () => { - it('should call onCancel when pressing ESC key', () => { - render() - - // ahooks useKeyPress listens on document for keydown events - document.dispatchEvent(new KeyboardEvent('keydown', { - key: 'Escape', - code: 'Escape', - keyCode: 27, - bubbles: true, - })) - - expect(onCancel).toHaveBeenCalledTimes(1) - }) - }) - - // ─── 7. Pricing URL ───────────────────────────────────────────────────── + // ─── 6. Pricing URL ───────────────────────────────────────────────────── describe('Pricing page URL', () => { it('should render pricing link with correct URL', () => { render() diff --git a/web/__tests__/billing/self-hosted-plan-flow.test.tsx b/web/__tests__/billing/self-hosted-plan-flow.test.tsx index 810d36da8a..0802b760e1 100644 --- a/web/__tests__/billing/self-hosted-plan-flow.test.tsx +++ b/web/__tests__/billing/self-hosted-plan-flow.test.tsx @@ -10,12 +10,12 @@ import { cleanup, render, screen, waitFor } from '@testing-library/react' import userEvent from '@testing-library/user-event' import * as React from 'react' +import { toast, ToastHost } from '@/app/components/base/ui/toast' import { contactSalesUrl, getStartedWithCommunityUrl, getWithPremiumUrl } from '@/app/components/billing/config' import SelfHostedPlanItem from '@/app/components/billing/pricing/plans/self-hosted-plan-item' import { SelfHostedPlan } from '@/app/components/billing/type' let mockAppCtx: Record = {} -const mockToastNotify = vi.fn() const originalLocation = window.location let assignedHref = '' @@ -40,10 +40,6 @@ vi.mock('@/app/components/base/icons/src/public/billing', () => ({ AwsMarketplaceDark: () => , })) -vi.mock('@/app/components/base/toast', () => ({ - default: { notify: (args: unknown) => mockToastNotify(args) }, -})) - vi.mock('@/app/components/billing/pricing/plans/self-hosted-plan-item/list', () => ({ default: ({ plan }: { plan: string }) => (
Features
@@ -57,10 +53,20 @@ const setupAppContext = (overrides: Record = {}) => { } } +const renderSelfHostedPlanItem = (plan: SelfHostedPlan) => { + return render( + <> + + + , + ) +} + describe('Self-Hosted Plan Flow', () => { beforeEach(() => { vi.clearAllMocks() cleanup() + toast.close() setupAppContext() // Mock window.location with minimal getter/setter (Location props are non-enumerable) @@ -85,14 +91,14 @@ describe('Self-Hosted Plan Flow', () => { // ─── 1. Plan Rendering ────────────────────────────────────────────────── describe('Plan rendering', () => { it('should render community plan with name and description', () => { - render() + renderSelfHostedPlanItem(SelfHostedPlan.community) expect(screen.getByText(/plans\.community\.name/i)).toBeInTheDocument() expect(screen.getByText(/plans\.community\.description/i)).toBeInTheDocument() }) it('should render premium plan with cloud provider icons', () => { - render() + renderSelfHostedPlanItem(SelfHostedPlan.premium) expect(screen.getByText(/plans\.premium\.name/i)).toBeInTheDocument() expect(screen.getByTestId('icon-azure')).toBeInTheDocument() @@ -100,39 +106,39 @@ describe('Self-Hosted Plan Flow', () => { }) it('should render enterprise plan without cloud provider icons', () => { - render() + renderSelfHostedPlanItem(SelfHostedPlan.enterprise) expect(screen.getByText(/plans\.enterprise\.name/i)).toBeInTheDocument() expect(screen.queryByTestId('icon-azure')).not.toBeInTheDocument() }) it('should not show price tip for community (free) plan', () => { - render() + renderSelfHostedPlanItem(SelfHostedPlan.community) expect(screen.queryByText(/plans\.community\.priceTip/i)).not.toBeInTheDocument() }) it('should show price tip for premium plan', () => { - render() + renderSelfHostedPlanItem(SelfHostedPlan.premium) expect(screen.getByText(/plans\.premium\.priceTip/i)).toBeInTheDocument() }) it('should render features list for each plan', () => { - const { unmount: unmount1 } = render() + const { unmount: unmount1 } = renderSelfHostedPlanItem(SelfHostedPlan.community) expect(screen.getByTestId('self-hosted-list-community')).toBeInTheDocument() unmount1() - const { unmount: unmount2 } = render() + const { unmount: unmount2 } = renderSelfHostedPlanItem(SelfHostedPlan.premium) expect(screen.getByTestId('self-hosted-list-premium')).toBeInTheDocument() unmount2() - render() + renderSelfHostedPlanItem(SelfHostedPlan.enterprise) expect(screen.getByTestId('self-hosted-list-enterprise')).toBeInTheDocument() }) it('should show AWS marketplace icon for premium plan button', () => { - render() + renderSelfHostedPlanItem(SelfHostedPlan.premium) expect(screen.getByTestId('icon-aws-light')).toBeInTheDocument() }) @@ -142,7 +148,7 @@ describe('Self-Hosted Plan Flow', () => { describe('Navigation flow', () => { it('should redirect to GitHub when clicking community plan button', async () => { const user = userEvent.setup() - render() + renderSelfHostedPlanItem(SelfHostedPlan.community) const button = screen.getByRole('button') await user.click(button) @@ -152,7 +158,7 @@ describe('Self-Hosted Plan Flow', () => { it('should redirect to AWS Marketplace when clicking premium plan button', async () => { const user = userEvent.setup() - render() + renderSelfHostedPlanItem(SelfHostedPlan.premium) const button = screen.getByRole('button') await user.click(button) @@ -162,7 +168,7 @@ describe('Self-Hosted Plan Flow', () => { it('should redirect to Typeform when clicking enterprise plan button', async () => { const user = userEvent.setup() - render() + renderSelfHostedPlanItem(SelfHostedPlan.enterprise) const button = screen.getByRole('button') await user.click(button) @@ -176,15 +182,13 @@ describe('Self-Hosted Plan Flow', () => { it('should show error toast when non-manager clicks community button', async () => { setupAppContext({ isCurrentWorkspaceManager: false }) const user = userEvent.setup() - render() + renderSelfHostedPlanItem(SelfHostedPlan.community) const button = screen.getByRole('button') await user.click(button) await waitFor(() => { - expect(mockToastNotify).toHaveBeenCalledWith( - expect.objectContaining({ type: 'error' }), - ) + expect(screen.getByText('billing.buyPermissionDeniedTip')).toBeInTheDocument() }) // Should NOT redirect expect(assignedHref).toBe('') @@ -193,15 +197,13 @@ describe('Self-Hosted Plan Flow', () => { it('should show error toast when non-manager clicks premium button', async () => { setupAppContext({ isCurrentWorkspaceManager: false }) const user = userEvent.setup() - render() + renderSelfHostedPlanItem(SelfHostedPlan.premium) const button = screen.getByRole('button') await user.click(button) await waitFor(() => { - expect(mockToastNotify).toHaveBeenCalledWith( - expect.objectContaining({ type: 'error' }), - ) + expect(screen.getByText('billing.buyPermissionDeniedTip')).toBeInTheDocument() }) expect(assignedHref).toBe('') }) @@ -209,15 +211,13 @@ describe('Self-Hosted Plan Flow', () => { it('should show error toast when non-manager clicks enterprise button', async () => { setupAppContext({ isCurrentWorkspaceManager: false }) const user = userEvent.setup() - render() + renderSelfHostedPlanItem(SelfHostedPlan.enterprise) const button = screen.getByRole('button') await user.click(button) await waitFor(() => { - expect(mockToastNotify).toHaveBeenCalledWith( - expect.objectContaining({ type: 'error' }), - ) + expect(screen.getByText('billing.buyPermissionDeniedTip')).toBeInTheDocument() }) expect(assignedHref).toBe('') }) diff --git a/web/__tests__/datasets/document-management.test.tsx b/web/__tests__/datasets/document-management.test.tsx index 8aedd4fc63..f9d80520ed 100644 --- a/web/__tests__/datasets/document-management.test.tsx +++ b/web/__tests__/datasets/document-management.test.tsx @@ -13,7 +13,7 @@ import { DataSourceType } from '@/models/datasets' import { renderHookWithNuqs } from '@/test/nuqs-testing' const mockPush = vi.fn() -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useSearchParams: () => new URLSearchParams(''), useRouter: () => ({ push: mockPush }), usePathname: () => '/datasets/ds-1/documents', diff --git a/web/__tests__/document-detail-navigation-fix.test.tsx b/web/__tests__/document-detail-navigation-fix.test.tsx index 6b348cd15b..5cb115830e 100644 --- a/web/__tests__/document-detail-navigation-fix.test.tsx +++ b/web/__tests__/document-detail-navigation-fix.test.tsx @@ -7,12 +7,12 @@ import type { Mock } from 'vitest' */ import { fireEvent, render, screen } from '@testing-library/react' -import { useRouter } from 'next/navigation' +import { useRouter } from '@/next/navigation' import { useDocumentDetail, useDocumentMetadata } from '@/service/knowledge/use-document' // Mock Next.js router const mockPush = vi.fn() -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: vi.fn(() => ({ push: mockPush, })), diff --git a/web/__tests__/embedded-user-id-auth.test.tsx b/web/__tests__/embedded-user-id-auth.test.tsx index 9231ac6199..cacd6331f8 100644 --- a/web/__tests__/embedded-user-id-auth.test.tsx +++ b/web/__tests__/embedded-user-id-auth.test.tsx @@ -8,7 +8,7 @@ const replaceMock = vi.fn() const backMock = vi.fn() const useSearchParamsMock = vi.fn(() => new URLSearchParams()) -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ usePathname: vi.fn(() => '/chatbot/test-app'), useRouter: vi.fn(() => ({ replace: replaceMock, diff --git a/web/__tests__/embedded-user-id-store.test.tsx b/web/__tests__/embedded-user-id-store.test.tsx index 901218e76b..04597ccfeb 100644 --- a/web/__tests__/embedded-user-id-store.test.tsx +++ b/web/__tests__/embedded-user-id-store.test.tsx @@ -4,7 +4,7 @@ import WebAppStoreProvider, { useWebAppStore } from '@/context/web-app-context' import { AccessMode } from '@/models/access-control' -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ usePathname: vi.fn(() => '/chatbot/sample-app'), useSearchParams: vi.fn(() => { const params = new URLSearchParams() diff --git a/web/__tests__/explore/sidebar-lifecycle-flow.test.tsx b/web/__tests__/explore/sidebar-lifecycle-flow.test.tsx index e2c18bcc4f..77f493ab18 100644 --- a/web/__tests__/explore/sidebar-lifecycle-flow.test.tsx +++ b/web/__tests__/explore/sidebar-lifecycle-flow.test.tsx @@ -19,7 +19,7 @@ const mockUninstall = vi.fn() const mockUpdatePinStatus = vi.fn() let mockInstalledApps: InstalledApp[] = [] -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useSelectedLayoutSegments: () => mockSegments, useRouter: () => ({ push: mockPush, diff --git a/web/__tests__/plugins/plugin-card-rendering.test.tsx b/web/__tests__/plugins/plugin-card-rendering.test.tsx index 7abcb01b49..5bd7f0c8bf 100644 --- a/web/__tests__/plugins/plugin-card-rendering.test.tsx +++ b/web/__tests__/plugins/plugin-card-rendering.test.tsx @@ -8,6 +8,8 @@ import { cleanup, render, screen } from '@testing-library/react' import { beforeEach, describe, expect, it, vi } from 'vitest' +let mockTheme = 'light' + vi.mock('#i18n', () => ({ useTranslation: () => ({ t: (key: string) => key, @@ -19,16 +21,16 @@ vi.mock('@/context/i18n', () => ({ })) vi.mock('@/hooks/use-theme', () => ({ - default: () => ({ theme: 'light' }), + default: () => ({ theme: mockTheme }), })) vi.mock('@/i18n-config', () => ({ renderI18nObject: (obj: Record, locale: string) => obj[locale] || obj.en_US || '', })) -vi.mock('@/types/app', () => ({ - Theme: { dark: 'dark', light: 'light' }, -})) +vi.mock('@/types/app', async () => { + return vi.importActual('@/types/app') +}) vi.mock('@/utils/classnames', () => ({ cn: (...args: unknown[]) => args.filter(a => typeof a === 'string' && a).join(' '), @@ -100,6 +102,7 @@ type CardPayload = Parameters[0]['payload'] describe('Plugin Card Rendering Integration', () => { beforeEach(() => { cleanup() + mockTheme = 'light' }) const makePayload = (overrides = {}) => ({ @@ -194,9 +197,7 @@ describe('Plugin Card Rendering Integration', () => { }) it('uses dark icon when theme is dark and icon_dark is provided', () => { - vi.doMock('@/hooks/use-theme', () => ({ - default: () => ({ theme: 'dark' }), - })) + mockTheme = 'dark' const payload = makePayload({ icon: 'https://example.com/icon-light.png', @@ -204,7 +205,7 @@ describe('Plugin Card Rendering Integration', () => { }) render() - expect(screen.getByTestId('card-icon')).toBeInTheDocument() + expect(screen.getByTestId('card-icon')).toHaveTextContent('https://example.com/icon-dark.png') }) it('shows loading placeholder when isLoading is true', () => { diff --git a/web/__tests__/share/text-generation-index-flow.test.tsx b/web/__tests__/share/text-generation-index-flow.test.tsx index 3292474bec..2fec054a47 100644 --- a/web/__tests__/share/text-generation-index-flow.test.tsx +++ b/web/__tests__/share/text-generation-index-flow.test.tsx @@ -5,7 +5,7 @@ import TextGeneration from '@/app/components/share/text-generation' const useSearchParamsMock = vi.fn(() => new URLSearchParams()) -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useSearchParams: () => useSearchParamsMock(), })) diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout-main.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout-main.tsx index 0160553092..ca134cb17e 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout-main.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout-main.tsx @@ -15,8 +15,6 @@ import { RiTerminalWindowLine, } from '@remixicon/react' import { useUnmount } from 'ahooks' -import dynamic from 'next/dynamic' -import { usePathname, useRouter } from 'next/navigation' import * as React from 'react' import { useCallback, useEffect, useState } from 'react' import { useTranslation } from 'react-i18next' @@ -28,6 +26,8 @@ import { useStore as useTagStore } from '@/app/components/base/tag-management/st import { useAppContext } from '@/context/app-context' import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints' import useDocumentTitle from '@/hooks/use-document-title' +import dynamic from '@/next/dynamic' +import { usePathname, useRouter } from '@/next/navigation' import { fetchAppDetailDirect } from '@/service/apps' import { AppModeEnum } from '@/types/app' import { cn } from '@/utils/classnames' diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/panel.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/panel.tsx index 5e7d98d191..4201d11490 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/panel.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/panel.tsx @@ -7,7 +7,6 @@ import { RiEqualizer2Line, } from '@remixicon/react' import { useBoolean } from 'ahooks' -import { usePathname } from 'next/navigation' import * as React from 'react' import { useEffect, useState } from 'react' import { useTranslation } from 'react-i18next' @@ -17,6 +16,7 @@ import Loading from '@/app/components/base/loading' import Toast from '@/app/components/base/toast' import Indicator from '@/app/components/header/indicator' import { useAppContext } from '@/context/app-context' +import { usePathname } from '@/next/navigation' import { fetchTracingConfig as doFetchTracingConfig, fetchTracingStatus, updateTracingStatus } from '@/service/apps' import { cn } from '@/utils/classnames' import ConfigButton from './config-button' diff --git a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout-main.tsx b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout-main.tsx index 495c57a4ce..ebae9c98cf 100644 --- a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout-main.tsx +++ b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout-main.tsx @@ -11,7 +11,6 @@ import { RiFocus2Fill, RiFocus2Line, } from '@remixicon/react' -import { usePathname } from 'next/navigation' import * as React from 'react' import { useEffect, useMemo, useState } from 'react' import { useTranslation } from 'react-i18next' @@ -25,6 +24,7 @@ import DatasetDetailContext from '@/context/dataset-detail' import { useEventEmitterContextContext } from '@/context/event-emitter' import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints' import useDocumentTitle from '@/hooks/use-document-title' +import { usePathname } from '@/next/navigation' import { useDatasetDetail, useDatasetRelatedApps } from '@/service/knowledge/use-dataset' import { cn } from '@/utils/classnames' diff --git a/web/app/(commonLayout)/datasets/layout.spec.tsx b/web/app/(commonLayout)/datasets/layout.spec.tsx index 5873f344d0..9c01cffba8 100644 --- a/web/app/(commonLayout)/datasets/layout.spec.tsx +++ b/web/app/(commonLayout)/datasets/layout.spec.tsx @@ -6,7 +6,7 @@ import DatasetsLayout from './layout' const mockReplace = vi.fn() const mockUseAppContext = vi.fn() -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => ({ replace: mockReplace, }), diff --git a/web/app/(commonLayout)/datasets/layout.tsx b/web/app/(commonLayout)/datasets/layout.tsx index b543c42570..a465f8222b 100644 --- a/web/app/(commonLayout)/datasets/layout.tsx +++ b/web/app/(commonLayout)/datasets/layout.tsx @@ -1,11 +1,11 @@ 'use client' -import { useRouter } from 'next/navigation' import { useEffect } from 'react' import Loading from '@/app/components/base/loading' import { useAppContext } from '@/context/app-context' import { ExternalApiPanelProvider } from '@/context/external-api-panel-context' import { ExternalKnowledgeApiProvider } from '@/context/external-knowledge-api-context' +import { useRouter } from '@/next/navigation' export default function DatasetsLayout({ children }: { children: React.ReactNode }) { const { isCurrentWorkspaceEditor, isCurrentWorkspaceDatasetOperator, currentWorkspace, isLoadingCurrentWorkspace } = useAppContext() diff --git a/web/app/(commonLayout)/education-apply/page.tsx b/web/app/(commonLayout)/education-apply/page.tsx index fce6fe1d5d..44ba5ee8ad 100644 --- a/web/app/(commonLayout)/education-apply/page.tsx +++ b/web/app/(commonLayout)/education-apply/page.tsx @@ -1,15 +1,15 @@ 'use client' -import { - useRouter, - useSearchParams, -} from 'next/navigation' import { useEffect, useMemo, } from 'react' import EducationApplyPage from '@/app/education-apply/education-apply-page' import { useProviderContext } from '@/context/provider-context' +import { + useRouter, + useSearchParams, +} from '@/next/navigation' export default function EducationApply() { const router = useRouter() diff --git a/web/app/(commonLayout)/role-route-guard.spec.tsx b/web/app/(commonLayout)/role-route-guard.spec.tsx index 87bf9be8af..ca1550f0b8 100644 --- a/web/app/(commonLayout)/role-route-guard.spec.tsx +++ b/web/app/(commonLayout)/role-route-guard.spec.tsx @@ -6,7 +6,7 @@ const mockReplace = vi.fn() const mockUseAppContext = vi.fn() let mockPathname = '/apps' -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ usePathname: () => mockPathname, useRouter: () => ({ replace: mockReplace, diff --git a/web/app/(commonLayout)/role-route-guard.tsx b/web/app/(commonLayout)/role-route-guard.tsx index 9ca5b25caa..6de5efb346 100644 --- a/web/app/(commonLayout)/role-route-guard.tsx +++ b/web/app/(commonLayout)/role-route-guard.tsx @@ -1,10 +1,10 @@ 'use client' import type { ReactNode } from 'react' -import { usePathname, useRouter } from 'next/navigation' import { useEffect } from 'react' import Loading from '@/app/components/base/loading' import { useAppContext } from '@/context/app-context' +import { usePathname, useRouter } from '@/next/navigation' const datasetOperatorRedirectRoutes = ['/apps', '/app', '/snippets', '/explore', '/tools'] as const diff --git a/web/app/(humanInputLayout)/form/[token]/form.tsx b/web/app/(humanInputLayout)/form/[token]/form.tsx index d027ef8b7d..035da6be8a 100644 --- a/web/app/(humanInputLayout)/form/[token]/form.tsx +++ b/web/app/(humanInputLayout)/form/[token]/form.tsx @@ -9,7 +9,6 @@ import { RiInformation2Fill, } from '@remixicon/react' import { produce } from 'immer' -import { useParams } from 'next/navigation' import * as React from 'react' import { useEffect, useMemo, useState } from 'react' import { useTranslation } from 'react-i18next' @@ -21,6 +20,7 @@ import { getButtonStyle } from '@/app/components/base/chat/chat/answer/human-inp import Loading from '@/app/components/base/loading' import DifyLogo from '@/app/components/base/logo/dify-logo' import useDocumentTitle from '@/hooks/use-document-title' +import { useParams } from '@/next/navigation' import { useGetHumanInputForm, useSubmitHumanInputForm } from '@/service/use-share' import { cn } from '@/utils/classnames' diff --git a/web/app/(shareLayout)/components/authenticated-layout.tsx b/web/app/(shareLayout)/components/authenticated-layout.tsx index c874990448..9f956a8501 100644 --- a/web/app/(shareLayout)/components/authenticated-layout.tsx +++ b/web/app/(shareLayout)/components/authenticated-layout.tsx @@ -1,12 +1,12 @@ 'use client' -import { usePathname, useRouter, useSearchParams } from 'next/navigation' import * as React from 'react' import { useCallback, useEffect } from 'react' import { useTranslation } from 'react-i18next' import AppUnavailable from '@/app/components/base/app-unavailable' import Loading from '@/app/components/base/loading' import { useWebAppStore } from '@/context/web-app-context' +import { usePathname, useRouter, useSearchParams } from '@/next/navigation' import { useGetUserCanAccessApp } from '@/service/access-control' import { useGetWebAppInfo, useGetWebAppMeta, useGetWebAppParams } from '@/service/use-share' import { webAppLogout } from '@/service/webapp-auth' diff --git a/web/app/(shareLayout)/components/splash.tsx b/web/app/(shareLayout)/components/splash.tsx index a2b847f74f..402005752d 100644 --- a/web/app/(shareLayout)/components/splash.tsx +++ b/web/app/(shareLayout)/components/splash.tsx @@ -1,11 +1,11 @@ 'use client' import type { FC, PropsWithChildren } from 'react' -import { useRouter, useSearchParams } from 'next/navigation' import { useCallback, useEffect, useState } from 'react' import { useTranslation } from 'react-i18next' import AppUnavailable from '@/app/components/base/app-unavailable' import Loading from '@/app/components/base/loading' import { useWebAppStore } from '@/context/web-app-context' +import { useRouter, useSearchParams } from '@/next/navigation' import { fetchAccessToken } from '@/service/share' import { setWebAppAccessToken, setWebAppPassport, webAppLoginStatus, webAppLogout } from '@/service/webapp-auth' diff --git a/web/app/(shareLayout)/webapp-reset-password/check-code/page.tsx b/web/app/(shareLayout)/webapp-reset-password/check-code/page.tsx index fbf45259e5..a0aa86e35b 100644 --- a/web/app/(shareLayout)/webapp-reset-password/check-code/page.tsx +++ b/web/app/(shareLayout)/webapp-reset-password/check-code/page.tsx @@ -1,14 +1,14 @@ 'use client' import { RiArrowLeftLine, RiMailSendFill } from '@remixicon/react' -import { useRouter, useSearchParams } from 'next/navigation' import { useState } from 'react' import { useTranslation } from 'react-i18next' import Button from '@/app/components/base/button' import Input from '@/app/components/base/input' import Toast from '@/app/components/base/toast' import Countdown from '@/app/components/signin/countdown' - import { useLocale } from '@/context/i18n' + +import { useRouter, useSearchParams } from '@/next/navigation' import { sendWebAppResetPasswordCode, verifyWebAppResetPasswordCode } from '@/service/common' export default function CheckCode() { diff --git a/web/app/(shareLayout)/webapp-reset-password/page.tsx b/web/app/(shareLayout)/webapp-reset-password/page.tsx index 9b9a853cdd..3763e0bb2a 100644 --- a/web/app/(shareLayout)/webapp-reset-password/page.tsx +++ b/web/app/(shareLayout)/webapp-reset-password/page.tsx @@ -1,8 +1,6 @@ 'use client' import { RiArrowLeftLine, RiLockPasswordLine } from '@remixicon/react' import { noop } from 'es-toolkit/function' -import Link from 'next/link' -import { useRouter, useSearchParams } from 'next/navigation' import { useState } from 'react' import { useTranslation } from 'react-i18next' import Button from '@/app/components/base/button' @@ -10,9 +8,11 @@ import Input from '@/app/components/base/input' import Toast from '@/app/components/base/toast' import { COUNT_DOWN_KEY, COUNT_DOWN_TIME_MS } from '@/app/components/signin/countdown' import { emailRegex } from '@/config' - import { useLocale } from '@/context/i18n' import useDocumentTitle from '@/hooks/use-document-title' + +import Link from '@/next/link' +import { useRouter, useSearchParams } from '@/next/navigation' import { sendResetPasswordCode } from '@/service/common' export default function CheckCode() { diff --git a/web/app/(shareLayout)/webapp-reset-password/set-password/page.tsx b/web/app/(shareLayout)/webapp-reset-password/set-password/page.tsx index 9f59e8f9eb..1a97f6440b 100644 --- a/web/app/(shareLayout)/webapp-reset-password/set-password/page.tsx +++ b/web/app/(shareLayout)/webapp-reset-password/set-password/page.tsx @@ -1,13 +1,13 @@ 'use client' import { RiCheckboxCircleFill } from '@remixicon/react' import { useCountDown } from 'ahooks' -import { useRouter, useSearchParams } from 'next/navigation' import { useCallback, useState } from 'react' import { useTranslation } from 'react-i18next' import Button from '@/app/components/base/button' import Input from '@/app/components/base/input' import Toast from '@/app/components/base/toast' import { validPassword } from '@/config' +import { useRouter, useSearchParams } from '@/next/navigation' import { changeWebAppPasswordWithToken } from '@/service/common' import { cn } from '@/utils/classnames' diff --git a/web/app/(shareLayout)/webapp-signin/check-code/page.tsx b/web/app/(shareLayout)/webapp-signin/check-code/page.tsx index afea9d668b..81b7c1b9a6 100644 --- a/web/app/(shareLayout)/webapp-signin/check-code/page.tsx +++ b/web/app/(shareLayout)/webapp-signin/check-code/page.tsx @@ -1,7 +1,6 @@ 'use client' import type { FormEvent } from 'react' import { RiArrowLeftLine, RiMailSendFill } from '@remixicon/react' -import { useRouter, useSearchParams } from 'next/navigation' import { useCallback, useEffect, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' import Button from '@/app/components/base/button' @@ -10,6 +9,7 @@ import Toast from '@/app/components/base/toast' import Countdown from '@/app/components/signin/countdown' import { useLocale } from '@/context/i18n' import { useWebAppStore } from '@/context/web-app-context' +import { useRouter, useSearchParams } from '@/next/navigation' import { sendWebAppEMailLoginCode, webAppEmailLoginWithCode } from '@/service/common' import { fetchAccessToken } from '@/service/share' import { setWebAppAccessToken, setWebAppPassport } from '@/service/webapp-auth' diff --git a/web/app/(shareLayout)/webapp-signin/components/external-member-sso-auth.tsx b/web/app/(shareLayout)/webapp-signin/components/external-member-sso-auth.tsx index 0776df036d..391479c870 100644 --- a/web/app/(shareLayout)/webapp-signin/components/external-member-sso-auth.tsx +++ b/web/app/(shareLayout)/webapp-signin/components/external-member-sso-auth.tsx @@ -1,11 +1,11 @@ 'use client' -import { useRouter, useSearchParams } from 'next/navigation' import * as React from 'react' import { useCallback, useEffect } from 'react' import AppUnavailable from '@/app/components/base/app-unavailable' import Loading from '@/app/components/base/loading' import Toast from '@/app/components/base/toast' import { useGlobalPublicStore } from '@/context/global-public-context' +import { useRouter, useSearchParams } from '@/next/navigation' import { fetchWebOAuth2SSOUrl, fetchWebOIDCSSOUrl, fetchWebSAMLSSOUrl } from '@/service/share' import { SSOProtocol } from '@/types/feature' diff --git a/web/app/(shareLayout)/webapp-signin/components/mail-and-code-auth.tsx b/web/app/(shareLayout)/webapp-signin/components/mail-and-code-auth.tsx index 5aa9d9f141..b350549784 100644 --- a/web/app/(shareLayout)/webapp-signin/components/mail-and-code-auth.tsx +++ b/web/app/(shareLayout)/webapp-signin/components/mail-and-code-auth.tsx @@ -1,5 +1,4 @@ import { noop } from 'es-toolkit/function' -import { useRouter, useSearchParams } from 'next/navigation' import { useState } from 'react' import { useTranslation } from 'react-i18next' import Button from '@/app/components/base/button' @@ -8,6 +7,7 @@ import Toast from '@/app/components/base/toast' import { COUNT_DOWN_KEY, COUNT_DOWN_TIME_MS } from '@/app/components/signin/countdown' import { emailRegex } from '@/config' import { useLocale } from '@/context/i18n' +import { useRouter, useSearchParams } from '@/next/navigation' import { sendWebAppEMailLoginCode } from '@/service/common' export default function MailAndCodeAuth() { diff --git a/web/app/(shareLayout)/webapp-signin/components/mail-and-password-auth.tsx b/web/app/(shareLayout)/webapp-signin/components/mail-and-password-auth.tsx index e49559401d..87419438e3 100644 --- a/web/app/(shareLayout)/webapp-signin/components/mail-and-password-auth.tsx +++ b/web/app/(shareLayout)/webapp-signin/components/mail-and-password-auth.tsx @@ -1,7 +1,5 @@ 'use client' import { noop } from 'es-toolkit/function' -import Link from 'next/link' -import { useRouter, useSearchParams } from 'next/navigation' import { useCallback, useState } from 'react' import { useTranslation } from 'react-i18next' import Button from '@/app/components/base/button' @@ -10,6 +8,8 @@ import Toast from '@/app/components/base/toast' import { emailRegex } from '@/config' import { useLocale } from '@/context/i18n' import { useWebAppStore } from '@/context/web-app-context' +import Link from '@/next/link' +import { useRouter, useSearchParams } from '@/next/navigation' import { webAppLogin } from '@/service/common' import { fetchAccessToken } from '@/service/share' import { setWebAppAccessToken, setWebAppPassport } from '@/service/webapp-auth' diff --git a/web/app/(shareLayout)/webapp-signin/components/sso-auth.tsx b/web/app/(shareLayout)/webapp-signin/components/sso-auth.tsx index d8f3854868..79d67dde5c 100644 --- a/web/app/(shareLayout)/webapp-signin/components/sso-auth.tsx +++ b/web/app/(shareLayout)/webapp-signin/components/sso-auth.tsx @@ -1,11 +1,11 @@ 'use client' import type { FC } from 'react' -import { useRouter, useSearchParams } from 'next/navigation' import { useCallback, useState } from 'react' import { useTranslation } from 'react-i18next' import Button from '@/app/components/base/button' import { Lock01 } from '@/app/components/base/icons/src/vender/solid/security' import Toast from '@/app/components/base/toast' +import { useRouter, useSearchParams } from '@/next/navigation' import { fetchMembersOAuth2SSOUrl, fetchMembersOIDCSSOUrl, fetchMembersSAMLSSOUrl } from '@/service/share' import { SSOProtocol } from '@/types/feature' diff --git a/web/app/(shareLayout)/webapp-signin/normalForm.tsx b/web/app/(shareLayout)/webapp-signin/normalForm.tsx index b15145346f..7ee08d66ae 100644 --- a/web/app/(shareLayout)/webapp-signin/normalForm.tsx +++ b/web/app/(shareLayout)/webapp-signin/normalForm.tsx @@ -1,12 +1,12 @@ 'use client' import { RiContractLine, RiDoorLockLine, RiErrorWarningFill } from '@remixicon/react' -import Link from 'next/link' import * as React from 'react' import { useCallback, useEffect, useState } from 'react' import { useTranslation } from 'react-i18next' import Loading from '@/app/components/base/loading' import { IS_CE_EDITION } from '@/config' import { useGlobalPublicStore } from '@/context/global-public-context' +import Link from '@/next/link' import { LicenseStatus } from '@/types/feature' import { cn } from '@/utils/classnames' import MailAndCodeAuth from './components/mail-and-code-auth' diff --git a/web/app/(shareLayout)/webapp-signin/page.tsx b/web/app/(shareLayout)/webapp-signin/page.tsx index b3ad1d48a6..a5c2528cc7 100644 --- a/web/app/(shareLayout)/webapp-signin/page.tsx +++ b/web/app/(shareLayout)/webapp-signin/page.tsx @@ -1,6 +1,5 @@ 'use client' import type { FC } from 'react' -import { useRouter, useSearchParams } from 'next/navigation' import * as React from 'react' import { useCallback } from 'react' import { useTranslation } from 'react-i18next' @@ -8,6 +7,7 @@ import AppUnavailable from '@/app/components/base/app-unavailable' import { useGlobalPublicStore } from '@/context/global-public-context' import { useWebAppStore } from '@/context/web-app-context' import { AccessMode } from '@/models/access-control' +import { useRouter, useSearchParams } from '@/next/navigation' import { webAppLogout } from '@/service/webapp-auth' import ExternalMemberSsoAuth from './components/external-member-sso-auth' import NormalForm from './normalForm' diff --git a/web/app/account/(commonLayout)/account-page/AvatarWithEdit.tsx b/web/app/account/(commonLayout)/account-page/AvatarWithEdit.tsx index 9bd32d2576..3fc677d8d8 100644 --- a/web/app/account/(commonLayout)/account-page/AvatarWithEdit.tsx +++ b/web/app/account/(commonLayout)/account-page/AvatarWithEdit.tsx @@ -160,7 +160,7 @@ const AvatarWithEdit = ({ onSave, ...props }: AvatarWithEditProps) => { isShow={isShowDeleteConfirm} onClose={() => setIsShowDeleteConfirm(false)} > -
{t('avatar.deleteTitle', { ns: 'common' })}
+
{t('avatar.deleteTitle', { ns: 'common' })}

{t('avatar.deleteDescription', { ns: 'common' })}

diff --git a/web/app/account/(commonLayout)/account-page/email-change-modal.tsx b/web/app/account/(commonLayout)/account-page/email-change-modal.tsx index 463c27294a..f0dfd4f12f 100644 --- a/web/app/account/(commonLayout)/account-page/email-change-modal.tsx +++ b/web/app/account/(commonLayout)/account-page/email-change-modal.tsx @@ -1,7 +1,6 @@ import type { ResponseError } from '@/service/fetch' import { RiCloseLine } from '@remixicon/react' import { noop } from 'es-toolkit/function' -import { useRouter } from 'next/navigation' import * as React from 'react' import { useState } from 'react' import { Trans, useTranslation } from 'react-i18next' @@ -10,6 +9,7 @@ import Button from '@/app/components/base/button' import Input from '@/app/components/base/input' import Modal from '@/app/components/base/modal' import { ToastContext } from '@/app/components/base/toast/context' +import { useRouter } from '@/next/navigation' import { checkEmailExisted, resetEmail, @@ -209,14 +209,14 @@ const EmailChangeModal = ({ onClose, email, show }: Props) => {
{step === STEP.start && ( <> -
{t('account.changeEmail.title', { ns: 'common' })}
+
{t('account.changeEmail.title', { ns: 'common' })}
-
{t('account.changeEmail.authTip', { ns: 'common' })}
-
+
{t('account.changeEmail.authTip', { ns: 'common' })}
+
}} + components={{ email: }} values={{ email }} />
@@ -241,19 +241,19 @@ const EmailChangeModal = ({ onClose, email, show }: Props) => { )} {step === STEP.verifyOrigin && ( <> -
{t('account.changeEmail.verifyEmail', { ns: 'common' })}
+
{t('account.changeEmail.verifyEmail', { ns: 'common' })}
-
+
}} + components={{ email: }} values={{ email }} />
-
{t('account.changeEmail.codeLabel', { ns: 'common' })}
+
{t('account.changeEmail.codeLabel', { ns: 'common' })}
{ {t('operation.cancel', { ns: 'common' })}
-
+
{t('account.changeEmail.resendTip', { ns: 'common' })} {time > 0 && ( {t('account.changeEmail.resendCount', { ns: 'common', count: time })} )} {!time && ( - {t('account.changeEmail.resend', { ns: 'common' })} + {t('account.changeEmail.resend', { ns: 'common' })} )}
)} {step === STEP.newEmail && ( <> -
{t('account.changeEmail.newEmail', { ns: 'common' })}
+
{t('account.changeEmail.newEmail', { ns: 'common' })}
-
{t('account.changeEmail.content3', { ns: 'common' })}
+
{t('account.changeEmail.content3', { ns: 'common' })}
-
{t('account.changeEmail.emailLabel', { ns: 'common' })}
+
{t('account.changeEmail.emailLabel', { ns: 'common' })}
{ destructive={newEmailExited || unAvailableEmail} /> {newEmailExited && ( -
{t('account.changeEmail.existingEmail', { ns: 'common' })}
+
{t('account.changeEmail.existingEmail', { ns: 'common' })}
)} {unAvailableEmail && ( -
{t('account.changeEmail.unAvailableEmail', { ns: 'common' })}
+
{t('account.changeEmail.unAvailableEmail', { ns: 'common' })}
)}
@@ -331,19 +331,19 @@ const EmailChangeModal = ({ onClose, email, show }: Props) => { )} {step === STEP.verifyNew && ( <> -
{t('account.changeEmail.verifyNew', { ns: 'common' })}
+
{t('account.changeEmail.verifyNew', { ns: 'common' })}
-
+
}} + components={{ email: }} values={{ email: mail }} />
-
{t('account.changeEmail.codeLabel', { ns: 'common' })}
+
{t('account.changeEmail.codeLabel', { ns: 'common' })}
{ {t('operation.cancel', { ns: 'common' })}
-
+
{t('account.changeEmail.resendTip', { ns: 'common' })} {time > 0 && ( {t('account.changeEmail.resendCount', { ns: 'common', count: time })} )} {!time && ( - {t('account.changeEmail.resend', { ns: 'common' })} + {t('account.changeEmail.resend', { ns: 'common' })} )}
diff --git a/web/app/account/(commonLayout)/account-page/index.tsx b/web/app/account/(commonLayout)/account-page/index.tsx index 58331e3a77..9a104619da 100644 --- a/web/app/account/(commonLayout)/account-page/index.tsx +++ b/web/app/account/(commonLayout)/account-page/index.tsx @@ -145,7 +145,7 @@ export default function AccountPage() { imageUrl={icon_url} />
-
{item.name}
+
{item.name}
) } @@ -153,12 +153,12 @@ export default function AccountPage() { return ( <>
-

{t('account.myAccount', { ns: 'common' })}

+

{t('account.myAccount', { ns: 'common' })}

-

+

{userProfile.name} {isEducationAccount && ( @@ -167,16 +167,16 @@ export default function AccountPage() { )}

-

{userProfile.email}

+

{userProfile.email}

{t('account.name', { ns: 'common' })}
-
+
{userProfile.name}
-
+
{t('operation.edit', { ns: 'common' })}
@@ -184,11 +184,11 @@ export default function AccountPage() {
{t('account.email', { ns: 'common' })}
-
+
{userProfile.email}
{systemFeatures.enable_change_email && ( -
setShowUpdateEmail(true)}> +
setShowUpdateEmail(true)}> {t('operation.change', { ns: 'common' })}
)} @@ -198,8 +198,8 @@ export default function AccountPage() { systemFeatures.enable_email_password_login && (
-
{t('account.password', { ns: 'common' })}
-
{t('account.passwordTip', { ns: 'common' })}
+
{t('account.password', { ns: 'common' })}
+
{t('account.passwordTip', { ns: 'common' })}
@@ -226,7 +226,7 @@ export default function AccountPage() { onClose={() => setEditNameModalVisible(false)} className="!w-[420px] !p-6" > -
{t('account.editName', { ns: 'common' })}
+
{t('account.editName', { ns: 'common' })}
{t('account.name', { ns: 'common' })}
-
{userProfile.is_password_set ? t('account.resetPassword', { ns: 'common' }) : t('account.setPassword', { ns: 'common' })}
+
{userProfile.is_password_set ? t('account.resetPassword', { ns: 'common' }) : t('account.setPassword', { ns: 'common' })}
{userProfile.is_password_set && ( <>
{t('account.currentPassword', { ns: 'common' })}
@@ -279,7 +279,7 @@ export default function AccountPage() {
)} -
+
{userProfile.is_password_set ? t('account.newPassword', { ns: 'common' }) : t('account.password', { ns: 'common' })}
@@ -298,7 +298,7 @@ export default function AccountPage() {
-
{t('account.confirmPassword', { ns: 'common' })}
+
{t('account.confirmPassword', { ns: 'common' })}
{ diff --git a/web/app/account/oauth/authorize/page.tsx b/web/app/account/oauth/authorize/page.tsx index 835a1e702e..30cfdd25d3 100644 --- a/web/app/account/oauth/authorize/page.tsx +++ b/web/app/account/oauth/authorize/page.tsx @@ -7,16 +7,16 @@ import { RiMailLine, RiTranslate2, } from '@remixicon/react' -import { useRouter, useSearchParams } from 'next/navigation' import * as React from 'react' import { useEffect, useRef } from 'react' import { useTranslation } from 'react-i18next' import { Avatar } from '@/app/components/base/avatar' import Button from '@/app/components/base/button' import Loading from '@/app/components/base/loading' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import { useLanguage } from '@/app/components/header/account-setting/model-provider-page/hooks' import { setPostLoginRedirect } from '@/app/signin/utils/post-login-redirect' +import { useRouter, useSearchParams } from '@/next/navigation' import { useIsLogin, useUserProfile } from '@/service/use-common' import { useAuthorizeOAuthApp, useOAuthAppInfo } from '@/service/use-oauth' @@ -91,9 +91,9 @@ export default function OAuthAuthorize() { globalThis.location.href = url.toString() } catch (err: any) { - Toast.notify({ + toast.add({ type: 'error', - message: `${t('error.authorizeFailed', { ns: 'oauth' })}: ${err.message}`, + title: `${t('error.authorizeFailed', { ns: 'oauth' })}: ${err.message}`, }) } } @@ -102,10 +102,10 @@ export default function OAuthAuthorize() { const invalidParams = !client_id || !redirect_uri if ((invalidParams || isError) && !hasNotifiedRef.current) { hasNotifiedRef.current = true - Toast.notify({ + toast.add({ type: 'error', - message: invalidParams ? t('error.invalidParams', { ns: 'oauth' }) : t('error.authAppInfoFetchFailed', { ns: 'oauth' }), - duration: 0, + title: invalidParams ? t('error.invalidParams', { ns: 'oauth' }) : t('error.authAppInfoFetchFailed', { ns: 'oauth' }), + timeout: 0, }) } }, [client_id, redirect_uri, isError]) diff --git a/web/app/activate/activateForm.tsx b/web/app/activate/activateForm.tsx index 421b816652..418d3b8bb1 100644 --- a/web/app/activate/activateForm.tsx +++ b/web/app/activate/activateForm.tsx @@ -1,11 +1,11 @@ 'use client' -import { useRouter, useSearchParams } from 'next/navigation' import { useEffect } from 'react' import { useTranslation } from 'react-i18next' import Button from '@/app/components/base/button' import Loading from '@/app/components/base/loading' - import useDocumentTitle from '@/hooks/use-document-title' + +import { useRouter, useSearchParams } from '@/next/navigation' import { useInvitationCheck } from '@/service/use-common' import { cn } from '@/utils/classnames' diff --git a/web/app/components/browser-initializer.spec.ts b/web/app/components/__tests__/browser-initializer.spec.ts similarity index 100% rename from web/app/components/browser-initializer.spec.ts rename to web/app/components/__tests__/browser-initializer.spec.ts diff --git a/web/app/components/app-initializer.tsx b/web/app/components/app-initializer.tsx index bf7aa39580..e08ece6666 100644 --- a/web/app/components/app-initializer.tsx +++ b/web/app/components/app-initializer.tsx @@ -2,13 +2,13 @@ import type { ReactNode } from 'react' import Cookies from 'js-cookie' -import { usePathname, useRouter, useSearchParams } from 'next/navigation' import { parseAsBoolean, useQueryState } from 'nuqs' import { useCallback, useEffect, useState } from 'react' import { EDUCATION_VERIFY_URL_SEARCHPARAMS_ACTION, EDUCATION_VERIFYING_LOCALSTORAGE_ITEM, } from '@/app/education-apply/constants' +import { usePathname, useRouter, useSearchParams } from '@/next/navigation' import { sendGAEvent } from '@/utils/gtag' import { fetchSetupStatusWithCache } from '@/utils/setup-status' import { resolvePostLoginRedirect } from '../signin/utils/post-login-redirect' diff --git a/web/app/components/app-sidebar/__tests__/index.spec.tsx b/web/app/components/app-sidebar/__tests__/index.spec.tsx index cf685b33a5..1b6046baee 100644 --- a/web/app/components/app-sidebar/__tests__/index.spec.tsx +++ b/web/app/components/app-sidebar/__tests__/index.spec.tsx @@ -19,7 +19,7 @@ vi.mock('zustand/react/shallow', () => ({ useShallow: (fn: unknown) => fn, })) -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ usePathname: () => mockPathname, })) diff --git a/web/app/components/app-sidebar/__tests__/text-squeeze-fix-verification.spec.tsx b/web/app/components/app-sidebar/__tests__/text-squeeze-fix-verification.spec.tsx index fb19833dd2..a3868a8330 100644 --- a/web/app/components/app-sidebar/__tests__/text-squeeze-fix-verification.spec.tsx +++ b/web/app/components/app-sidebar/__tests__/text-squeeze-fix-verification.spec.tsx @@ -7,7 +7,7 @@ import { render } from '@testing-library/react' import * as React from 'react' // Mock Next.js navigation -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useSelectedLayoutSegment: () => 'overview', })) diff --git a/web/app/components/app-sidebar/app-info/__tests__/app-info-modals.spec.tsx b/web/app/components/app-sidebar/app-info/__tests__/app-info-modals.spec.tsx index f8612e8057..2f98089e40 100644 --- a/web/app/components/app-sidebar/app-info/__tests__/app-info-modals.spec.tsx +++ b/web/app/components/app-sidebar/app-info/__tests__/app-info-modals.spec.tsx @@ -5,7 +5,7 @@ import * as React from 'react' import { AppModeEnum } from '@/types/app' import AppInfoModals from '../app-info-modals' -vi.mock('next/dynamic', () => ({ +vi.mock('@/next/dynamic', () => ({ default: (loader: () => Promise<{ default: React.ComponentType }>) => { const LazyComp = React.lazy(loader) return function DynamicWrapper(props: Record) { diff --git a/web/app/components/app-sidebar/app-info/__tests__/use-app-info-actions.spec.ts b/web/app/components/app-sidebar/app-info/__tests__/use-app-info-actions.spec.ts index 6104e2b641..deea28ce3e 100644 --- a/web/app/components/app-sidebar/app-info/__tests__/use-app-info-actions.spec.ts +++ b/web/app/components/app-sidebar/app-info/__tests__/use-app-info-actions.spec.ts @@ -23,7 +23,7 @@ let mockAppDetail: Record | undefined = { icon_background: '#FFEAD5', } -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => ({ replace: mockReplace }), })) diff --git a/web/app/components/app-sidebar/app-info/app-info-modals.tsx b/web/app/components/app-sidebar/app-info/app-info-modals.tsx index 4ca7f6adbc..232afb18c7 100644 --- a/web/app/components/app-sidebar/app-info/app-info-modals.tsx +++ b/web/app/components/app-sidebar/app-info/app-info-modals.tsx @@ -3,9 +3,9 @@ import type { DuplicateAppModalProps } from '@/app/components/app/duplicate-moda import type { CreateAppModalProps } from '@/app/components/explore/create-app-modal' import type { EnvironmentVariable } from '@/app/components/workflow/types' import type { App, AppSSO } from '@/types/app' -import dynamic from 'next/dynamic' import * as React from 'react' import { useTranslation } from 'react-i18next' +import dynamic from '@/next/dynamic' const SwitchAppModal = dynamic(() => import('@/app/components/app/switch-app-modal'), { ssr: false }) const CreateAppModal = dynamic(() => import('@/app/components/explore/create-app-modal'), { ssr: false }) diff --git a/web/app/components/app-sidebar/app-info/use-app-info-actions.ts b/web/app/components/app-sidebar/app-info/use-app-info-actions.ts index 800f21de44..55ec13e506 100644 --- a/web/app/components/app-sidebar/app-info/use-app-info-actions.ts +++ b/web/app/components/app-sidebar/app-info/use-app-info-actions.ts @@ -1,7 +1,6 @@ import type { DuplicateAppModalProps } from '@/app/components/app/duplicate-modal' import type { CreateAppModalProps } from '@/app/components/explore/create-app-modal' import type { EnvironmentVariable } from '@/app/components/workflow/types' -import { useRouter } from 'next/navigation' import { useCallback, useState } from 'react' import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' @@ -9,6 +8,7 @@ import { useStore as useAppStore } from '@/app/components/app/store' import { ToastContext } from '@/app/components/base/toast/context' import { NEED_REFRESH_APP_LIST_KEY } from '@/config' import { useProviderContext } from '@/context/provider-context' +import { useRouter } from '@/next/navigation' import { copyApp, deleteApp, exportAppConfig, updateAppInfo } from '@/service/apps' import { useInvalidateAppList } from '@/service/use-apps' import { fetchWorkflowDraft } from '@/service/workflow' diff --git a/web/app/components/app-sidebar/dataset-info/__tests__/dropdown-callbacks.spec.tsx b/web/app/components/app-sidebar/dataset-info/__tests__/dropdown-callbacks.spec.tsx index 512f9490c2..1df6fa79b7 100644 --- a/web/app/components/app-sidebar/dataset-info/__tests__/dropdown-callbacks.spec.tsx +++ b/web/app/components/app-sidebar/dataset-info/__tests__/dropdown-callbacks.spec.tsx @@ -80,7 +80,7 @@ const createDataset = (overrides: Partial = {}): DataSet => ({ ...overrides, }) -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => ({ replace: mockReplace }), })) diff --git a/web/app/components/app-sidebar/dataset-info/__tests__/index.spec.tsx b/web/app/components/app-sidebar/dataset-info/__tests__/index.spec.tsx index be27e247d7..a1e275d731 100644 --- a/web/app/components/app-sidebar/dataset-info/__tests__/index.spec.tsx +++ b/web/app/components/app-sidebar/dataset-info/__tests__/index.spec.tsx @@ -90,7 +90,7 @@ const createDataset = (overrides: Partial = {}): DataSet => ({ ...overrides, }) -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => ({ replace: mockReplace, }), diff --git a/web/app/components/app-sidebar/dataset-info/dropdown.tsx b/web/app/components/app-sidebar/dataset-info/dropdown.tsx index 96127c4210..528bac831f 100644 --- a/web/app/components/app-sidebar/dataset-info/dropdown.tsx +++ b/web/app/components/app-sidebar/dataset-info/dropdown.tsx @@ -1,11 +1,11 @@ import type { DataSet } from '@/models/datasets' import { RiMoreFill } from '@remixicon/react' -import { useRouter } from 'next/navigation' import * as React from 'react' import { useCallback, useState } from 'react' import { useTranslation } from 'react-i18next' import { useSelector as useAppContextWithSelector } from '@/context/app-context' import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail' +import { useRouter } from '@/next/navigation' import { checkIsUsedInApp, deleteDataset } from '@/service/datasets' import { datasetDetailQueryKeyPrefix, useInvalidDatasetList } from '@/service/knowledge/use-dataset' import { useInvalid } from '@/service/use-base' diff --git a/web/app/components/app-sidebar/index.tsx b/web/app/components/app-sidebar/index.tsx index ee046e4f30..8cc734d46f 100644 --- a/web/app/components/app-sidebar/index.tsx +++ b/web/app/components/app-sidebar/index.tsx @@ -1,12 +1,12 @@ import type { NavIcon } from './nav-link' import { useHover, useKeyPress } from 'ahooks' -import { usePathname } from 'next/navigation' import * as React from 'react' import { useCallback, useEffect, useState } from 'react' import { useShallow } from 'zustand/react/shallow' import { useStore as useAppStore } from '@/app/components/app/store' import { useEventEmitterContextContext } from '@/context/event-emitter' import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints' +import { usePathname } from '@/next/navigation' import { cn } from '@/utils/classnames' import Divider from '../base/divider' import { getKeyboardKeyCodeBySystem } from '../workflow/utils' diff --git a/web/app/components/app-sidebar/nav-link/__tests__/index.spec.tsx b/web/app/components/app-sidebar/nav-link/__tests__/index.spec.tsx index af0cccf263..e480c36945 100644 --- a/web/app/components/app-sidebar/nav-link/__tests__/index.spec.tsx +++ b/web/app/components/app-sidebar/nav-link/__tests__/index.spec.tsx @@ -4,12 +4,12 @@ import * as React from 'react' import NavLink from '..' // Mock Next.js navigation -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useSelectedLayoutSegment: () => 'overview', })) // Mock Next.js Link component -vi.mock('next/link', () => ({ +vi.mock('@/next/link', () => ({ default: function MockLink({ children, href, className, title }: { children: React.ReactNode, href: string, className?: string, title?: string }) { return ( diff --git a/web/app/components/app-sidebar/nav-link/index.tsx b/web/app/components/app-sidebar/nav-link/index.tsx index b9c0ee7345..e27f5adede 100644 --- a/web/app/components/app-sidebar/nav-link/index.tsx +++ b/web/app/components/app-sidebar/nav-link/index.tsx @@ -1,8 +1,8 @@ 'use client' import type { RemixiconComponentType } from '@remixicon/react' -import Link from 'next/link' -import { useSelectedLayoutSegment } from 'next/navigation' import * as React from 'react' +import Link from '@/next/link' +import { useSelectedLayoutSegment } from '@/next/navigation' import { cn } from '@/utils/classnames' export type NavIcon = React.ComponentType< diff --git a/web/app/components/app/annotation/batch-add-annotation-modal/csv-uploader.tsx b/web/app/components/app/annotation/batch-add-annotation-modal/csv-uploader.tsx index 118eaea58e..a969b3d491 100644 --- a/web/app/components/app/annotation/batch-add-annotation-modal/csv-uploader.tsx +++ b/web/app/components/app/annotation/batch-add-annotation-modal/csv-uploader.tsx @@ -94,7 +94,7 @@ const CSVUploader: FC = ({ />
{!file && ( -
+
diff --git a/web/app/components/app/configuration/base/warning-mask/has-not-set-api.spec.tsx b/web/app/components/app/configuration/base/warning-mask/has-not-set-api.spec.tsx index be4377bfd9..abcf5795d0 100644 --- a/web/app/components/app/configuration/base/warning-mask/has-not-set-api.spec.tsx +++ b/web/app/components/app/configuration/base/warning-mask/has-not-set-api.spec.tsx @@ -2,25 +2,19 @@ import { fireEvent, render, screen } from '@testing-library/react' import * as React from 'react' import HasNotSetAPI from './has-not-set-api' -describe('HasNotSetAPI WarningMask', () => { - it('should show default title when trial not finished', () => { - render() +describe('HasNotSetAPI', () => { + it('should render the empty state copy', () => { + render() - expect(screen.getByText('appDebug.notSetAPIKey.title')).toBeInTheDocument() - expect(screen.getByText('appDebug.notSetAPIKey.description')).toBeInTheDocument() + expect(screen.getByText('appDebug.noModelProviderConfigured')).toBeInTheDocument() + expect(screen.getByText('appDebug.noModelProviderConfiguredTip')).toBeInTheDocument() }) - it('should show trail finished title when flag is true', () => { - render() - - expect(screen.getByText('appDebug.notSetAPIKey.trailFinished')).toBeInTheDocument() - }) - - it('should call onSetting when primary button clicked', () => { + it('should call onSetting when manage models button is clicked', () => { const onSetting = vi.fn() - render() + render() - fireEvent.click(screen.getByRole('button', { name: 'appDebug.notSetAPIKey.settingBtn' })) + fireEvent.click(screen.getByRole('button', { name: 'appDebug.manageModels' })) expect(onSetting).toHaveBeenCalledTimes(1) }) }) diff --git a/web/app/components/app/configuration/base/warning-mask/has-not-set-api.tsx b/web/app/components/app/configuration/base/warning-mask/has-not-set-api.tsx index 84323e64f5..2c5fc5ff2f 100644 --- a/web/app/components/app/configuration/base/warning-mask/has-not-set-api.tsx +++ b/web/app/components/app/configuration/base/warning-mask/has-not-set-api.tsx @@ -2,38 +2,38 @@ import type { FC } from 'react' import * as React from 'react' import { useTranslation } from 'react-i18next' -import Button from '@/app/components/base/button' -import WarningMask from '.' export type IHasNotSetAPIProps = { - isTrailFinished: boolean onSetting: () => void } -const icon = ( - - - - -) - const HasNotSetAPI: FC = ({ - isTrailFinished, onSetting, }) => { const { t } = useTranslation() return ( - - {t('notSetAPIKey.settingBtn', { ns: 'appDebug' })} - {icon} - - )} - /> +
+
+
+
+ +
+
+
+
{t('noModelProviderConfigured', { ns: 'appDebug' })}
+
{t('noModelProviderConfiguredTip', { ns: 'appDebug' })}
+
+ +
+
) } export default React.memo(HasNotSetAPI) diff --git a/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx b/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx index c33d55873d..39a1699063 100644 --- a/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx +++ b/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx @@ -178,7 +178,7 @@ const Prompt: FC = ({ {!noTitle && (
-
{mode !== AppModeEnum.COMPLETION ? t('chatSubTitle', { ns: 'appDebug' }) : t('completionSubTitle', { ns: 'appDebug' })}
+
{mode !== AppModeEnum.COMPLETION ? t('chatSubTitle', { ns: 'appDebug' }) : t('completionSubTitle', { ns: 'appDebug' })}
{!readonly && ( = ({ )}
-
+
= (
-
{t('codegen.instruction', { ns: 'appDebug' })}
+
{t('codegen.instruction', { ns: 'appDebug' })}
= ( disabled={isLoading} > - {t('codegen.generate', { ns: 'appDebug' })} + {t('codegen.generate', { ns: 'appDebug' })}
diff --git a/web/app/components/app/configuration/dataset-config/context-var/index.spec.tsx b/web/app/components/app/configuration/dataset-config/context-var/index.spec.tsx index 7f71247d56..8c6e626b45 100644 --- a/web/app/components/app/configuration/dataset-config/context-var/index.spec.tsx +++ b/web/app/components/app/configuration/dataset-config/context-var/index.spec.tsx @@ -5,7 +5,7 @@ import * as React from 'react' import ContextVar from './index' // Mock external dependencies only -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => ({ push: vi.fn() }), usePathname: () => '/test', })) diff --git a/web/app/components/app/configuration/dataset-config/context-var/var-picker.spec.tsx b/web/app/components/app/configuration/dataset-config/context-var/var-picker.spec.tsx index aa8dae813f..6704fa0afd 100644 --- a/web/app/components/app/configuration/dataset-config/context-var/var-picker.spec.tsx +++ b/web/app/components/app/configuration/dataset-config/context-var/var-picker.spec.tsx @@ -5,7 +5,7 @@ import * as React from 'react' import VarPicker from './var-picker' // Mock external dependencies only -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => ({ push: vi.fn() }), usePathname: () => '/test', })) diff --git a/web/app/components/app/configuration/dataset-config/params-config/index.spec.tsx b/web/app/components/app/configuration/dataset-config/params-config/index.spec.tsx index 7904159109..9366039414 100644 --- a/web/app/components/app/configuration/dataset-config/params-config/index.spec.tsx +++ b/web/app/components/app/configuration/dataset-config/params-config/index.spec.tsx @@ -180,7 +180,7 @@ describe('dataset-config/params-config', () => { const dialog = await screen.findByRole('dialog', {}, { timeout: 3000 }) const dialogScope = within(dialog) - const incrementButtons = dialogScope.getAllByRole('button', { name: 'increment' }) + const incrementButtons = dialogScope.getAllByRole('button', { name: /increment/i }) await user.click(incrementButtons[0]) await waitFor(() => { @@ -213,7 +213,7 @@ describe('dataset-config/params-config', () => { const dialog = await screen.findByRole('dialog', {}, { timeout: 3000 }) const dialogScope = within(dialog) - const incrementButtons = dialogScope.getAllByRole('button', { name: 'increment' }) + const incrementButtons = dialogScope.getAllByRole('button', { name: /increment/i }) await user.click(incrementButtons[0]) await waitFor(() => { diff --git a/web/app/components/app/configuration/dataset-config/select-dataset/index.tsx b/web/app/components/app/configuration/dataset-config/select-dataset/index.tsx index 91e5353cc4..8c2fb77c20 100644 --- a/web/app/components/app/configuration/dataset-config/select-dataset/index.tsx +++ b/web/app/components/app/configuration/dataset-config/select-dataset/index.tsx @@ -2,7 +2,6 @@ import type { FC } from 'react' import type { DataSet } from '@/models/datasets' import { useInfiniteScroll } from 'ahooks' -import Link from 'next/link' import * as React from 'react' import { useMemo, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' @@ -14,6 +13,7 @@ import Modal from '@/app/components/base/modal' import { ModelFeatureEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' import FeatureIcon from '@/app/components/header/account-setting/model-provider-page/model-selector/feature-icon' import { useKnowledge } from '@/hooks/use-knowledge' +import Link from '@/next/link' import { useInfiniteDatasets } from '@/service/knowledge/use-dataset' import { cn } from '@/utils/classnames' diff --git a/web/app/components/app/configuration/dataset-config/settings-modal/index.tsx b/web/app/components/app/configuration/dataset-config/settings-modal/index.tsx index c9c8d080f2..bc534599de 100644 --- a/web/app/components/app/configuration/dataset-config/settings-modal/index.tsx +++ b/web/app/components/app/configuration/dataset-config/settings-modal/index.tsx @@ -210,7 +210,7 @@ const SettingsModal: FC = ({
-
{t('form.name', { ns: 'datasetSettings' })}
+
{t('form.name', { ns: 'datasetSettings' })}
= ({
-
{t('form.desc', { ns: 'datasetSettings' })}
+
{t('form.desc', { ns: 'datasetSettings' })}