Merge branch 'main' into jzh

JzoNg 2026-03-18 17:40:51 +08:00
commit 883eb498c0
1010 changed files with 85332 additions and 10785 deletions

View File

@ -63,7 +63,8 @@ pnpm analyze-component <path> --review
### File Naming
- Test files: `ComponentName.spec.tsx` (same directory as component)
- Test files: `ComponentName.spec.tsx` inside a same-level `__tests__/` directory
- Placement rule: Component, hook, and utility tests must live in a sibling `__tests__/` folder alongside the source under test. For example, `foo/index.tsx` maps to `foo/__tests__/index.spec.tsx`, and `foo/bar.ts` maps to `foo/__tests__/bar.spec.ts`.
- Integration tests: `web/__tests__/` directory
## Test Structure Template

View File

@ -41,7 +41,7 @@ import userEvent from '@testing-library/user-event'
// Router (if component uses useRouter, usePathname, useSearchParams)
// WHY: Isolates tests from Next.js routing, enables testing navigation behavior
// const mockPush = vi.fn()
// vi.mock('next/navigation', () => ({
// vi.mock('@/next/navigation', () => ({
// useRouter: () => ({ push: mockPush }),
// usePathname: () => '/test-path',
// }))

View File

@ -29,8 +29,8 @@ jobs:
strategy:
fail-fast: false
matrix:
shardIndex: [1, 2, 3, 4]
shardTotal: [4]
shardIndex: [1, 2, 3, 4, 5, 6]
shardTotal: [6]
defaults:
run:
shell: bash

View File

@ -78,7 +78,7 @@ class UserProfile(TypedDict):
nickname: NotRequired[str]
```
- For classes, declare member variables at the top of the class body (before `__init__`) so the class shape is obvious at a glance:
- For classes, declare all member variables explicitly with types at the top of the class body (before `__init__`), even when the class is not a dataclass or Pydantic model, so the class shape is obvious at a glance:
```python
from datetime import datetime

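# A minimal illustrative sketch of this convention (not the guide's original example):
class Subscription:
    # Member variables declared with types before __init__, so the class shape is visible at a glance.
    plan: str
    started_at: datetime
    cancelled_at: datetime | None

    def __init__(self, plan: str, started_at: datetime):
        self.plan = plan
        self.started_at = started_at
        self.cancelled_at = None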
View File

@ -88,6 +88,8 @@ def clean_workflow_runs(
"""
Clean workflow runs and related workflow data for free tenants.
"""
from extensions.otel.runtime import flush_telemetry
if (start_from is None) ^ (end_before is None):
raise click.UsageError("--start-from and --end-before must be provided together.")
@ -104,16 +106,27 @@ def clean_workflow_runs(
end_before = now - datetime.timedelta(days=to_days_ago)
before_days = 0
if from_days_ago is not None and to_days_ago is not None:
task_label = f"{from_days_ago}to{to_days_ago}"
elif start_from is None:
task_label = f"before-{before_days}"
else:
task_label = "custom"
start_time = datetime.datetime.now(datetime.UTC)
click.echo(click.style(f"Starting workflow run cleanup at {start_time.isoformat()}.", fg="white"))
WorkflowRunCleanup(
days=before_days,
batch_size=batch_size,
start_from=start_from,
end_before=end_before,
dry_run=dry_run,
).run()
try:
WorkflowRunCleanup(
days=before_days,
batch_size=batch_size,
start_from=start_from,
end_before=end_before,
dry_run=dry_run,
task_label=task_label,
).run()
finally:
flush_telemetry()
end_time = datetime.datetime.now(datetime.UTC)
elapsed = end_time - start_time
@ -659,6 +672,8 @@ def clean_expired_messages(
"""
Clean expired messages and related data for tenants based on clean policy.
"""
from extensions.otel.runtime import flush_telemetry
click.echo(click.style("clean_messages: start clean messages.", fg="green"))
start_at = time.perf_counter()
@ -698,6 +713,13 @@ def clean_expired_messages(
# NOTE: graceful_period will be ignored when billing is disabled.
policy = create_message_clean_policy(graceful_period_days=graceful_period)
if from_days_ago is not None and before_days is not None:
task_label = f"{from_days_ago}to{before_days}"
elif start_from is None and before_days is not None:
task_label = f"before-{before_days}"
else:
task_label = "custom"
# Create and run the cleanup service
if abs_mode:
assert start_from is not None
@ -708,6 +730,7 @@ def clean_expired_messages(
end_before=end_before,
batch_size=batch_size,
dry_run=dry_run,
task_label=task_label,
)
elif from_days_ago is None:
assert before_days is not None
@ -716,6 +739,7 @@ def clean_expired_messages(
days=before_days,
batch_size=batch_size,
dry_run=dry_run,
task_label=task_label,
)
else:
assert before_days is not None
@ -727,6 +751,7 @@ def clean_expired_messages(
end_before=now - datetime.timedelta(days=before_days),
batch_size=batch_size,
dry_run=dry_run,
task_label=task_label,
)
stats = service.run()
@ -752,6 +777,8 @@ def clean_expired_messages(
)
)
raise
finally:
flush_telemetry()
click.echo(click.style("messages cleanup completed.", fg="green"))

View File

@ -14,6 +14,7 @@ from core.rag.models.document import ChildDocument, Document
from extensions.ext_database import db
from models.dataset import Dataset, DatasetCollectionBinding, DatasetMetadata, DatasetMetadataBinding, DocumentSegment
from models.dataset import Document as DatasetDocument
from models.enums import DatasetMetadataType, IndexingStatus, SegmentStatus
from models.model import App, AppAnnotationSetting, MessageAnnotation
@ -242,7 +243,7 @@ def migrate_knowledge_vector_database():
dataset_documents = db.session.scalars(
select(DatasetDocument).where(
DatasetDocument.dataset_id == dataset.id,
DatasetDocument.indexing_status == "completed",
DatasetDocument.indexing_status == IndexingStatus.COMPLETED,
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
)
@ -254,7 +255,7 @@ def migrate_knowledge_vector_database():
segments = db.session.scalars(
select(DocumentSegment).where(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.status == "completed",
DocumentSegment.status == SegmentStatus.COMPLETED,
DocumentSegment.enabled == True,
)
).all()
@ -430,7 +431,7 @@ def old_metadata_migration():
tenant_id=document.tenant_id,
dataset_id=document.dataset_id,
name=key,
type="string",
type=DatasetMetadataType.STRING,
created_by=document.created_by,
)
db.session.add(dataset_metadata)

View File

@ -1,4 +1,4 @@
from pydantic import Field, NonNegativeInt, PositiveFloat, PositiveInt
from pydantic import Field, NonNegativeInt, PositiveFloat, PositiveInt, field_validator
from pydantic_settings import BaseSettings
@ -116,3 +116,13 @@ class RedisConfig(BaseSettings):
description="Maximum connections in the Redis connection pool (unset for library default)",
default=None,
)
@field_validator("REDIS_MAX_CONNECTIONS", mode="before")
@classmethod
def _empty_string_to_none_for_max_conns(cls, v):
"""Allow empty string in env/.env to mean 'unset' (None)."""
if v is None:
return None
if isinstance(v, str) and v.strip() == "":
return None
return v
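A minimal usage sketch of the coercion above, assuming `RedisConfig` is loaded from environment variables as a regular `pydantic_settings` model and that the remaining Redis fields have defaults (the import path is hypothetical):

```python
import os

from configs.middleware.cache.redis_config import RedisConfig  # hypothetical import path

os.environ["REDIS_MAX_CONNECTIONS"] = ""  # empty value in .env means "unset"
assert RedisConfig().REDIS_MAX_CONNECTIONS is None  # validator maps "" to None

os.environ["REDIS_MAX_CONNECTIONS"] = "50"
assert RedisConfig().REDIS_MAX_CONNECTIONS == 50  # normal integer parsing still applies
```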

View File

@ -1,4 +1,4 @@
from typing import Literal, Protocol
from typing import Literal, Protocol, cast
from urllib.parse import quote_plus, urlunparse
from pydantic import AliasChoices, Field
@ -12,16 +12,13 @@ class RedisConfigDefaults(Protocol):
REDIS_PASSWORD: str | None
REDIS_DB: int
REDIS_USE_SSL: bool
REDIS_USE_SENTINEL: bool | None
REDIS_USE_CLUSTERS: bool
class RedisConfigDefaultsMixin:
def _redis_defaults(self: RedisConfigDefaults) -> RedisConfigDefaults:
return self
def _redis_defaults(config: object) -> RedisConfigDefaults:
return cast(RedisConfigDefaults, config)
class RedisPubSubConfig(BaseSettings, RedisConfigDefaultsMixin):
class RedisPubSubConfig(BaseSettings):
"""
Configuration settings for event transport between API and workers.
@ -74,7 +71,7 @@ class RedisPubSubConfig(BaseSettings, RedisConfigDefaultsMixin):
)
def _build_default_pubsub_url(self) -> str:
defaults = self._redis_defaults()
defaults = _redis_defaults(self)
if not defaults.REDIS_HOST or not defaults.REDIS_PORT:
raise ValueError("PUBSUB_REDIS_URL must be set when default Redis URL cannot be constructed")
@ -91,11 +88,9 @@ class RedisPubSubConfig(BaseSettings, RedisConfigDefaultsMixin):
if userinfo:
userinfo = f"{userinfo}@"
host = defaults.REDIS_HOST
port = defaults.REDIS_PORT
db = defaults.REDIS_DB
netloc = f"{userinfo}{host}:{port}"
netloc = f"{userinfo}{defaults.REDIS_HOST}:{defaults.REDIS_PORT}"
return urlunparse((scheme, netloc, f"/{db}", "", "", ""))
@property

View File

@ -54,6 +54,7 @@ from fields.document_fields import document_status_fields
from libs.login import current_account_with_tenant, login_required
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
from models.dataset import DatasetPermission, DatasetPermissionEnum
from models.enums import SegmentStatus
from models.provider_ids import ModelProviderID
from services.api_token_service import ApiTokenCache
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
@ -741,13 +742,15 @@ class DatasetIndexingStatusApi(Resource):
.where(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != "re_segment",
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
)
.count()
)
total_segments = (
db.session.query(DocumentSegment)
.where(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
.where(
DocumentSegment.document_id == str(document.id), DocumentSegment.status != SegmentStatus.RE_SEGMENT
)
.count()
)
# Create a dictionary with document attributes and additional fields

View File

@ -42,6 +42,7 @@ from libs.datetime_utils import naive_utc_now
from libs.login import current_account_with_tenant, login_required
from models import DatasetProcessRule, Document, DocumentSegment, UploadFile
from models.dataset import DocumentPipelineExecutionLog
from models.enums import IndexingStatus, SegmentStatus
from services.dataset_service import DatasetService, DocumentService
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig, ProcessRule, RetrievalModel
from services.file_service import FileService
@ -332,13 +333,16 @@ class DatasetDocumentListApi(Resource):
.where(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != "re_segment",
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
)
.count()
)
total_segments = (
db.session.query(DocumentSegment)
.where(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
.where(
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
)
.count()
)
document.completed_segments = completed_segments
@ -503,7 +507,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
document_id = str(document_id)
document = self.get_document(dataset_id, document_id)
if document.indexing_status in {"completed", "error"}:
if document.indexing_status in {IndexingStatus.COMPLETED, IndexingStatus.ERROR}:
raise DocumentAlreadyFinishedError()
data_process_rule = document.dataset_process_rule
@ -573,7 +577,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
data_process_rule_dict = data_process_rule.to_dict() if data_process_rule else {}
extract_settings = []
for document in documents:
if document.indexing_status in {"completed", "error"}:
if document.indexing_status in {IndexingStatus.COMPLETED, IndexingStatus.ERROR}:
raise DocumentAlreadyFinishedError()
data_source_info = document.data_source_info_dict
match document.data_source_type:
@ -671,19 +675,21 @@ class DocumentBatchIndexingStatusApi(DocumentResource):
.where(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != "re_segment",
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
)
.count()
)
total_segments = (
db.session.query(DocumentSegment)
.where(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
.where(
DocumentSegment.document_id == str(document.id), DocumentSegment.status != SegmentStatus.RE_SEGMENT
)
.count()
)
# Create a dictionary with document attributes and additional fields
document_dict = {
"id": document.id,
"indexing_status": "paused" if document.is_paused else document.indexing_status,
"indexing_status": IndexingStatus.PAUSED if document.is_paused else document.indexing_status,
"processing_started_at": document.processing_started_at,
"parsing_completed_at": document.parsing_completed_at,
"cleaning_completed_at": document.cleaning_completed_at,
@ -720,20 +726,20 @@ class DocumentIndexingStatusApi(DocumentResource):
.where(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document_id),
DocumentSegment.status != "re_segment",
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
)
.count()
)
total_segments = (
db.session.query(DocumentSegment)
.where(DocumentSegment.document_id == str(document_id), DocumentSegment.status != "re_segment")
.where(DocumentSegment.document_id == str(document_id), DocumentSegment.status != SegmentStatus.RE_SEGMENT)
.count()
)
# Create a dictionary with document attributes and additional fields
document_dict = {
"id": document.id,
"indexing_status": "paused" if document.is_paused else document.indexing_status,
"indexing_status": IndexingStatus.PAUSED if document.is_paused else document.indexing_status,
"processing_started_at": document.processing_started_at,
"parsing_completed_at": document.parsing_completed_at,
"cleaning_completed_at": document.cleaning_completed_at,
@ -955,7 +961,7 @@ class DocumentProcessingApi(DocumentResource):
match action:
case "pause":
if document.indexing_status != "indexing":
if document.indexing_status != IndexingStatus.INDEXING:
raise InvalidActionError("Document not in indexing state.")
document.paused_by = current_user.id
@ -964,7 +970,7 @@ class DocumentProcessingApi(DocumentResource):
db.session.commit()
case "resume":
if document.indexing_status not in {"paused", "error"}:
if document.indexing_status not in {IndexingStatus.PAUSED, IndexingStatus.ERROR}:
raise InvalidActionError("Document not in paused or error state.")
document.paused_by = None
@ -1169,7 +1175,7 @@ class DocumentRetryApi(DocumentResource):
raise ArchivedDocumentImmutableError()
# 400 if document is completed
if document.indexing_status == "completed":
if document.indexing_status == IndexingStatus.COMPLETED:
raise DocumentAlreadyFinishedError()
retry_documents.append(document)
except Exception:

View File

@ -46,6 +46,8 @@ class PipelineTemplateDetailApi(Resource):
type = request.args.get("type", default="built-in", type=str)
rag_pipeline_service = RagPipelineService()
pipeline_template = rag_pipeline_service.get_pipeline_template_detail(template_id, type)
if pipeline_template is None:
return {"error": "Pipeline template not found from upstream service."}, 404
return pipeline_template, 200

View File

@ -36,6 +36,7 @@ from extensions.ext_database import db
from fields.document_fields import document_fields, document_status_fields
from libs.login import current_user
from models.dataset import Dataset, Document, DocumentSegment
from models.enums import SegmentStatus
from services.dataset_service import DatasetService, DocumentService
from services.entities.knowledge_entities.knowledge_entities import (
KnowledgeConfig,
@ -622,13 +623,15 @@ class DocumentIndexingStatusApi(DatasetApiResource):
.where(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != "re_segment",
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
)
.count()
)
total_segments = (
db.session.query(DocumentSegment)
.where(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
.where(
DocumentSegment.document_id == str(document.id), DocumentSegment.status != SegmentStatus.RE_SEGMENT
)
.count()
)
# Create a dictionary with document attributes and additional fields

View File

@ -70,7 +70,14 @@ def handle_webhook(webhook_id: str):
@bp.route("/webhook-debug/<string:webhook_id>", methods=["GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"])
def handle_webhook_debug(webhook_id: str):
"""Handle webhook debug calls without triggering production workflow execution."""
"""Handle webhook debug calls without triggering production workflow execution.
The debug webhook endpoint is only for draft inspection flows. It never enqueues
Celery work for the published workflow; instead it dispatches an in-memory debug
event to an active Variable Inspector listener. Returning a clear error when no
listener is registered prevents a misleading 200 response for requests that are
effectively dropped.
"""
try:
webhook_trigger, _, node_config, webhook_data, error = _prepare_webhook_execution(webhook_id, is_debug=True)
if error:
@ -94,11 +101,32 @@ def handle_webhook_debug(webhook_id: str):
"method": webhook_data.get("method"),
},
)
TriggerDebugEventBus.dispatch(
dispatch_count = TriggerDebugEventBus.dispatch(
tenant_id=webhook_trigger.tenant_id,
event=event,
pool_key=pool_key,
)
if dispatch_count == 0:
logger.warning(
"Webhook debug request dropped without an active listener for webhook %s (tenant=%s, app=%s, node=%s)",
webhook_trigger.webhook_id,
webhook_trigger.tenant_id,
webhook_trigger.app_id,
webhook_trigger.node_id,
)
return (
jsonify(
{
"error": "No active debug listener",
"message": (
"The webhook debug URL only works while the Variable Inspector is listening. "
"Use the published webhook URL to execute the workflow in Celery."
),
"execution_url": webhook_trigger.webhook_url,
}
),
409,
)
response_data, status_code = WebhookService.generate_webhook_response(node_config)
return jsonify(response_data), status_code
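A minimal client-side sketch of the contract described in the docstring above, using the `requests` library; the debug URL below is hypothetical and depends on where the blueprint is mounted, while the `message` and `execution_url` keys match the 409 payload emitted by the handler:

```python
import requests

# Hypothetical debug URL for a draft webhook node.
debug_url = "https://dify.example.com/webhook-debug/<webhook-id>"

resp = requests.post(debug_url, json={"ping": "draft-test"}, timeout=10)
if resp.status_code == 409:
    # No Variable Inspector listener is registered, so the event was dropped.
    info = resp.json()
    print(info["message"])
    print("Use the published URL instead:", info["execution_url"])
else:
    # A listener consumed the event; the node's configured debug response is returned.
    print(resp.status_code, resp.text)
```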

View File

@ -441,7 +441,7 @@ class BaseAgentRunner(AppRunner):
continue
result.append(self.organize_agent_user_prompt(message))
agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
agent_thoughts = message.agent_thoughts
if agent_thoughts:
for agent_thought in agent_thoughts:
tool_names_raw = agent_thought.tool

View File

@ -1,13 +1,36 @@
from collections.abc import Mapping
from typing import Any
from typing import Any, TypedDict
from configs import dify_config
from constants import DEFAULT_FILE_NUMBER_LIMITS
class SystemParametersDict(TypedDict):
image_file_size_limit: int
video_file_size_limit: int
audio_file_size_limit: int
file_size_limit: int
workflow_file_upload_limit: int
class AppParametersDict(TypedDict):
opening_statement: str | None
suggested_questions: list[str]
suggested_questions_after_answer: dict[str, Any]
speech_to_text: dict[str, Any]
text_to_speech: dict[str, Any]
retriever_resource: dict[str, Any]
annotation_reply: dict[str, Any]
more_like_this: dict[str, Any]
user_input_form: list[dict[str, Any]]
sensitive_word_avoidance: dict[str, Any]
file_upload: dict[str, Any]
system_parameters: SystemParametersDict
def get_parameters_from_feature_dict(
*, features_dict: Mapping[str, Any], user_input_form: list[dict[str, Any]]
) -> Mapping[str, Any]:
) -> AppParametersDict:
"""
Mapping from feature dict to webapp parameters
"""

View File

@ -8,6 +8,7 @@ from core.app.app_config.entities import (
ModelConfig,
)
from core.entities.agent_entities import PlanningStrategy
from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict
from models.model import AppMode, AppModelConfigDict
from services.dataset_service import DatasetService
@ -117,8 +118,10 @@ class DatasetConfigManager:
score_threshold=float(score_threshold_val)
if dataset_configs.get("score_threshold_enabled", False) and score_threshold_val is not None
else None,
reranking_model=reranking_model_val if isinstance(reranking_model_val, dict) else None,
weights=weights_val if isinstance(weights_val, dict) else None,
reranking_model=cast(RerankingModelDict, reranking_model_val)
if isinstance(reranking_model_val, dict)
else None,
weights=cast(WeightsDict, weights_val) if isinstance(weights_val, dict) else None,
reranking_enabled=bool(dataset_configs.get("reranking_enabled", True)),
rerank_mode=dataset_configs.get("reranking_mode", "reranking_model"),
metadata_filtering_mode=cast(

View File

@ -4,6 +4,7 @@ from typing import Any, Literal
from pydantic import BaseModel, Field
from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict
from dify_graph.file import FileUploadConfig
from dify_graph.model_runtime.entities.llm_entities import LLMMode
from dify_graph.model_runtime.entities.message_entities import PromptMessageRole
@ -194,8 +195,8 @@ class DatasetRetrieveConfigEntity(BaseModel):
top_k: int | None = None
score_threshold: float | None = 0.0
rerank_mode: str | None = "reranking_model"
reranking_model: dict | None = None
weights: dict | None = None
reranking_model: RerankingModelDict | None = None
weights: WeightsDict | None = None
reranking_enabled: bool | None = True
metadata_filtering_mode: Literal["disabled", "automatic", "manual"] | None = "disabled"
metadata_model_config: ModelConfig | None = None

View File

@ -3,7 +3,7 @@ import time
from collections.abc import Mapping, Sequence
from dataclasses import dataclass
from datetime import datetime
from typing import Any, NewType, Union
from typing import Any, NewType, TypedDict, Union
from sqlalchemy import select
from sqlalchemy.orm import Session
@ -76,6 +76,20 @@ NodeExecutionId = NewType("NodeExecutionId", str)
logger = logging.getLogger(__name__)
class AccountCreatedByDict(TypedDict):
id: str
name: str
email: str
class EndUserCreatedByDict(TypedDict):
id: str
user: str
CreatedByDict = AccountCreatedByDict | EndUserCreatedByDict
@dataclass(slots=True)
class _NodeSnapshot:
"""In-memory cache for node metadata between start and completion events."""
@ -249,19 +263,19 @@ class WorkflowResponseConverter:
outputs_mapping = graph_runtime_state.outputs or {}
encoded_outputs = WorkflowRuntimeTypeConverter().to_json_encodable(outputs_mapping)
created_by: Mapping[str, object] | None
created_by: CreatedByDict | dict[str, object] = {}
user = self._user
if isinstance(user, Account):
created_by = {
"id": user.id,
"name": user.name,
"email": user.email,
}
else:
created_by = {
"id": user.id,
"user": user.session_id,
}
created_by = AccountCreatedByDict(
id=user.id,
name=user.name,
email=user.email,
)
elif isinstance(user, EndUser):
created_by = EndUserCreatedByDict(
id=user.id,
user=user.session_id,
)
return WorkflowFinishStreamResponse(
task_id=task_id,

View File

@ -6,6 +6,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from core.rag.datasource.vdb.vector_factory import Vector
from extensions.ext_database import db
from models.dataset import Dataset
from models.enums import CollectionBindingType
from models.model import App, AppAnnotationSetting, Message, MessageAnnotation
from services.annotation_service import AppAnnotationService
from services.dataset_service import DatasetCollectionBindingService
@ -43,7 +44,7 @@ class AnnotationReplyFeature:
embedding_model_name = collection_binding_detail.model_name
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
embedding_provider_name, embedding_model_name, "annotation"
embedding_provider_name, embedding_model_name, CollectionBindingType.ANNOTATION
)
dataset = Dataset(

View File

@ -1,3 +1,5 @@
from typing import TypedDict
from core.tools.signature import sign_tool_file
from dify_graph.file import helpers as file_helpers
from dify_graph.file.enums import FileTransferMethod
@ -6,7 +8,20 @@ from models.model import MessageFile, UploadFile
MAX_TOOL_FILE_EXTENSION_LENGTH = 10
def prepare_file_dict(message_file: MessageFile, upload_files_map: dict[str, UploadFile]) -> dict:
class MessageFileInfoDict(TypedDict):
related_id: str
extension: str
filename: str
size: int
mime_type: str
transfer_method: str
type: str
url: str
upload_file_id: str
remote_url: str | None
def prepare_file_dict(message_file: MessageFile, upload_files_map: dict[str, UploadFile]) -> MessageFileInfoDict:
"""
Prepare file dictionary for message end stream response.

View File

@ -12,7 +12,7 @@ from core.rag.models.document import Document
from extensions.ext_database import db
from models.dataset import ChildChunk, DatasetQuery, DocumentSegment
from models.dataset import Document as DatasetDocument
from models.enums import CreatorUserRole
from models.enums import CreatorUserRole, DatasetQuerySource
_logger = logging.getLogger(__name__)
@ -36,7 +36,7 @@ class DatasetIndexToolCallbackHandler:
dataset_query = DatasetQuery(
dataset_id=dataset_id,
content=query,
source="app",
source=DatasetQuerySource.APP,
source_app_id=self._app_id,
created_by_role=(
CreatorUserRole.ACCOUNT

View File

@ -30,6 +30,7 @@ from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel
from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from libs.datetime_utils import naive_utc_now
from models.engine import db
from models.enums import CredentialSourceType
from models.provider import (
LoadBalancingModelConfig,
Provider,
@ -473,9 +474,21 @@ class ProviderConfiguration(BaseModel):
self.switch_preferred_provider_type(provider_type=ProviderType.CUSTOM, session=session)
else:
# some historical data may have a provider record but not be set as valid
provider_record.is_valid = True
if provider_record.credential_id is None:
provider_record.credential_id = new_record.id
provider_record.updated_at = naive_utc_now()
provider_model_credentials_cache = ProviderCredentialsCache(
tenant_id=self.tenant_id,
identity_id=provider_record.id,
cache_type=ProviderCredentialsCacheType.PROVIDER,
)
provider_model_credentials_cache.delete()
self.switch_preferred_provider_type(provider_type=ProviderType.CUSTOM, session=session)
session.commit()
except Exception:
session.rollback()
@ -534,7 +547,7 @@ class ProviderConfiguration(BaseModel):
self._update_load_balancing_configs_with_credential(
credential_id=credential_id,
credential_record=credential_record,
credential_source="provider",
credential_source=CredentialSourceType.PROVIDER,
session=session,
)
except Exception:
@ -611,7 +624,7 @@ class ProviderConfiguration(BaseModel):
LoadBalancingModelConfig.tenant_id == self.tenant_id,
LoadBalancingModelConfig.provider_name.in_(self._get_provider_names()),
LoadBalancingModelConfig.credential_id == credential_id,
LoadBalancingModelConfig.credential_source_type == "provider",
LoadBalancingModelConfig.credential_source_type == CredentialSourceType.PROVIDER,
)
lb_configs_using_credential = session.execute(lb_stmt).scalars().all()
try:
@ -1031,7 +1044,7 @@ class ProviderConfiguration(BaseModel):
self._update_load_balancing_configs_with_credential(
credential_id=credential_id,
credential_record=credential_record,
credential_source="custom_model",
credential_source=CredentialSourceType.CUSTOM_MODEL,
session=session,
)
except Exception:
@ -1061,7 +1074,7 @@ class ProviderConfiguration(BaseModel):
LoadBalancingModelConfig.tenant_id == self.tenant_id,
LoadBalancingModelConfig.provider_name.in_(self._get_provider_names()),
LoadBalancingModelConfig.credential_id == credential_id,
LoadBalancingModelConfig.credential_source_type == "custom_model",
LoadBalancingModelConfig.credential_source_type == CredentialSourceType.CUSTOM_MODEL,
)
lb_configs_using_credential = session.execute(lb_stmt).scalars().all()
@ -1699,7 +1712,7 @@ class ProviderConfiguration(BaseModel):
provider_model_lb_configs = [
config
for config in model_setting.load_balancing_configs
if config.credential_source_type != "custom_model"
if config.credential_source_type != CredentialSourceType.CUSTOM_MODEL
]
load_balancing_enabled = model_setting.load_balancing_enabled
@ -1757,7 +1770,7 @@ class ProviderConfiguration(BaseModel):
custom_model_lb_configs = [
config
for config in model_setting.load_balancing_configs
if config.credential_source_type != "provider"
if config.credential_source_type != CredentialSourceType.PROVIDER
]
load_balancing_enabled = model_setting.load_balancing_enabled

View File

@ -40,6 +40,7 @@ from libs.datetime_utils import naive_utc_now
from models import Account
from models.dataset import AutomaticRulesConfig, ChildChunk, Dataset, DatasetProcessRule, DocumentSegment
from models.dataset import Document as DatasetDocument
from models.enums import DataSourceType, IndexingStatus, ProcessRuleMode, SegmentStatus
from models.model import UploadFile
from services.feature_service import FeatureService
@ -56,7 +57,7 @@ class IndexingRunner:
logger.exception("consume document failed")
document = db.session.get(DatasetDocument, document_id)
if document:
document.indexing_status = "error"
document.indexing_status = IndexingStatus.ERROR
error_message = getattr(error, "description", str(error))
document.error = str(error_message)
document.stopped_at = naive_utc_now()
@ -219,7 +220,7 @@ class IndexingRunner:
if document_segments:
for document_segment in document_segments:
# transform segment to node
if document_segment.status != "completed":
if document_segment.status != SegmentStatus.COMPLETED:
document = Document(
page_content=document_segment.content,
metadata={
@ -382,7 +383,7 @@ class IndexingRunner:
data_source_info = dataset_document.data_source_info_dict
text_docs = []
match dataset_document.data_source_type:
case "upload_file":
case DataSourceType.UPLOAD_FILE:
if not data_source_info or "upload_file_id" not in data_source_info:
raise ValueError("no upload file found")
stmt = select(UploadFile).where(UploadFile.id == data_source_info["upload_file_id"])
@ -395,7 +396,7 @@ class IndexingRunner:
document_model=dataset_document.doc_form,
)
text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"])
case "notion_import":
case DataSourceType.NOTION_IMPORT:
if (
not data_source_info
or "notion_workspace_id" not in data_source_info
@ -417,7 +418,7 @@ class IndexingRunner:
document_model=dataset_document.doc_form,
)
text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"])
case "website_crawl":
case DataSourceType.WEBSITE_CRAWL:
if (
not data_source_info
or "provider" not in data_source_info
@ -445,7 +446,7 @@ class IndexingRunner:
# update document status to splitting
self._update_document_index_status(
document_id=dataset_document.id,
after_indexing_status="splitting",
after_indexing_status=IndexingStatus.SPLITTING,
extra_update_params={
DatasetDocument.parsing_completed_at: naive_utc_now(),
},
@ -545,7 +546,7 @@ class IndexingRunner:
Clean the document text according to the processing rules.
"""
rules: AutomaticRulesConfig | dict[str, Any]
if processing_rule.mode == "automatic":
if processing_rule.mode == ProcessRuleMode.AUTOMATIC:
rules = DatasetProcessRule.AUTOMATIC_RULES
else:
rules = json.loads(processing_rule.rules) if processing_rule.rules else {}
@ -636,7 +637,7 @@ class IndexingRunner:
# update document status to completed
self._update_document_index_status(
document_id=dataset_document.id,
after_indexing_status="completed",
after_indexing_status=IndexingStatus.COMPLETED,
extra_update_params={
DatasetDocument.tokens: tokens,
DatasetDocument.completed_at: naive_utc_now(),
@ -659,10 +660,10 @@ class IndexingRunner:
DocumentSegment.document_id == document_id,
DocumentSegment.dataset_id == dataset_id,
DocumentSegment.index_node_id.in_(document_ids),
DocumentSegment.status == "indexing",
DocumentSegment.status == SegmentStatus.INDEXING,
).update(
{
DocumentSegment.status: "completed",
DocumentSegment.status: SegmentStatus.COMPLETED,
DocumentSegment.enabled: True,
DocumentSegment.completed_at: naive_utc_now(),
}
@ -703,10 +704,10 @@ class IndexingRunner:
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.index_node_id.in_(document_ids),
DocumentSegment.status == "indexing",
DocumentSegment.status == SegmentStatus.INDEXING,
).update(
{
DocumentSegment.status: "completed",
DocumentSegment.status: SegmentStatus.COMPLETED,
DocumentSegment.enabled: True,
DocumentSegment.completed_at: naive_utc_now(),
}
@ -725,7 +726,7 @@ class IndexingRunner:
@staticmethod
def _update_document_index_status(
document_id: str, after_indexing_status: str, extra_update_params: dict | None = None
document_id: str, after_indexing_status: IndexingStatus, extra_update_params: dict | None = None
):
"""
Update the document indexing status.
@ -803,7 +804,7 @@ class IndexingRunner:
cur_time = naive_utc_now()
self._update_document_index_status(
document_id=dataset_document.id,
after_indexing_status="indexing",
after_indexing_status=IndexingStatus.INDEXING,
extra_update_params={
DatasetDocument.cleaning_completed_at: cur_time,
DatasetDocument.splitting_completed_at: cur_time,
@ -815,7 +816,7 @@ class IndexingRunner:
self._update_segments_by_document(
dataset_document_id=dataset_document.id,
update_params={
DocumentSegment.status: "indexing",
DocumentSegment.status: SegmentStatus.INDEXING,
DocumentSegment.indexing_at: naive_utc_now(),
},
)

View File

@ -196,6 +196,8 @@ class ProviderManager:
if preferred_provider_type_record:
preferred_provider_type = ProviderType.value_of(preferred_provider_type_record.preferred_provider_type)
elif dify_config.EDITION == "CLOUD" and system_configuration.enabled:
preferred_provider_type = ProviderType.SYSTEM
elif custom_configuration.provider or custom_configuration.models:
preferred_provider_type = ProviderType.CUSTOM
elif system_configuration.enabled:

View File

@ -1,3 +1,5 @@
from typing_extensions import TypedDict
from core.model_manager import ModelInstance, ModelManager
from core.rag.data_post_processor.reorder import ReorderRunner
from core.rag.index_processor.constant.query_type import QueryType
@ -10,6 +12,26 @@ from dify_graph.model_runtime.entities.model_entities import ModelType
from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError
class RerankingModelDict(TypedDict):
reranking_provider_name: str
reranking_model_name: str
class VectorSettingDict(TypedDict):
vector_weight: float
embedding_provider_name: str
embedding_model_name: str
class KeywordSettingDict(TypedDict):
keyword_weight: float
class WeightsDict(TypedDict):
vector_setting: VectorSettingDict
keyword_setting: KeywordSettingDict
class DataPostProcessor:
"""Interface for data post-processing document."""
@ -17,8 +39,8 @@ class DataPostProcessor:
self,
tenant_id: str,
reranking_mode: str,
reranking_model: dict | None = None,
weights: dict | None = None,
reranking_model: RerankingModelDict | None = None,
weights: WeightsDict | None = None,
reorder_enabled: bool = False,
):
self.rerank_runner = self._get_rerank_runner(reranking_mode, tenant_id, reranking_model, weights)
@ -45,8 +67,8 @@ class DataPostProcessor:
self,
reranking_mode: str,
tenant_id: str,
reranking_model: dict | None = None,
weights: dict | None = None,
reranking_model: RerankingModelDict | None = None,
weights: WeightsDict | None = None,
) -> BaseRerankRunner | None:
if reranking_mode == RerankMode.WEIGHTED_SCORE and weights:
runner = RerankRunnerFactory.create_rerank_runner(
@ -79,12 +101,14 @@ class DataPostProcessor:
return ReorderRunner()
return None
def _get_rerank_model_instance(self, tenant_id: str, reranking_model: dict | None) -> ModelInstance | None:
def _get_rerank_model_instance(
self, tenant_id: str, reranking_model: RerankingModelDict | None
) -> ModelInstance | None:
if reranking_model:
try:
model_manager = ModelManager()
reranking_provider_name = reranking_model.get("reranking_provider_name")
reranking_model_name = reranking_model.get("reranking_model_name")
reranking_provider_name = reranking_model["reranking_provider_name"]
reranking_model_name = reranking_model["reranking_model_name"]
if not reranking_provider_name or not reranking_model_name:
return None
rerank_model_instance = model_manager.get_model_instance(

View File

@ -1,19 +1,20 @@
import concurrent.futures
import logging
from concurrent.futures import ThreadPoolExecutor
from typing import Any
from typing import Any, NotRequired
from flask import Flask, current_app
from sqlalchemy import select
from sqlalchemy.orm import Session, load_only
from typing_extensions import TypedDict
from configs import dify_config
from core.db.session_factory import session_factory
from core.model_manager import ModelManager
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
from core.rag.data_post_processor.data_post_processor import DataPostProcessor, RerankingModelDict, WeightsDict
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.embedding.retrieval import RetrievalChildChunk, RetrievalSegments
from core.rag.embedding.retrieval import AttachmentInfoDict, RetrievalChildChunk, RetrievalSegments
from core.rag.entities.metadata_entities import MetadataCondition
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType
@ -35,7 +36,46 @@ from models.dataset import Document as DatasetDocument
from models.model import UploadFile
from services.external_knowledge_service import ExternalDatasetService
default_retrieval_model = {
class SegmentAttachmentResult(TypedDict):
attachment_info: AttachmentInfoDict
segment_id: str
class SegmentAttachmentInfoResult(TypedDict):
attachment_id: str
attachment_info: AttachmentInfoDict
segment_id: str
class ChildChunkDetail(TypedDict):
id: str
content: str
position: int
score: float
class SegmentChildMapDetail(TypedDict):
max_score: float
child_chunks: list[ChildChunkDetail]
class SegmentRecord(TypedDict):
segment: DocumentSegment
score: NotRequired[float]
child_chunks: NotRequired[list[ChildChunkDetail]]
files: NotRequired[list[AttachmentInfoDict]]
class DefaultRetrievalModelDict(TypedDict):
search_method: RetrievalMethod | str
reranking_enable: bool
reranking_model: RerankingModelDict
top_k: int
score_threshold_enabled: bool
default_retrieval_model: DefaultRetrievalModelDict = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
@ -56,9 +96,9 @@ class RetrievalService:
query: str,
top_k: int = 4,
score_threshold: float | None = 0.0,
reranking_model: dict | None = None,
reranking_model: RerankingModelDict | None = None,
reranking_mode: str = "reranking_model",
weights: dict | None = None,
weights: WeightsDict | None = None,
document_ids_filter: list[str] | None = None,
attachment_ids: list | None = None,
):
@ -235,7 +275,7 @@ class RetrievalService:
query: str,
top_k: int,
score_threshold: float | None,
reranking_model: dict | None,
reranking_model: RerankingModelDict | None,
all_documents: list,
retrieval_method: RetrievalMethod,
exceptions: list,
@ -277,8 +317,8 @@ class RetrievalService:
if documents:
if (
reranking_model
and reranking_model.get("reranking_model_name")
and reranking_model.get("reranking_provider_name")
and reranking_model["reranking_model_name"]
and reranking_model["reranking_provider_name"]
and retrieval_method == RetrievalMethod.SEMANTIC_SEARCH
):
data_post_processor = DataPostProcessor(
@ -288,8 +328,8 @@ class RetrievalService:
model_manager = ModelManager()
is_support_vision = model_manager.check_model_support_vision(
tenant_id=dataset.tenant_id,
provider=reranking_model.get("reranking_provider_name") or "",
model=reranking_model.get("reranking_model_name") or "",
provider=reranking_model["reranking_provider_name"],
model=reranking_model["reranking_model_name"],
model_type=ModelType.RERANK,
)
if is_support_vision:
@ -329,7 +369,7 @@ class RetrievalService:
query: str,
top_k: int,
score_threshold: float | None,
reranking_model: dict | None,
reranking_model: RerankingModelDict | None,
all_documents: list,
retrieval_method: str,
exceptions: list,
@ -349,8 +389,8 @@ class RetrievalService:
if documents:
if (
reranking_model
and reranking_model.get("reranking_model_name")
and reranking_model.get("reranking_provider_name")
and reranking_model["reranking_model_name"]
and reranking_model["reranking_provider_name"]
and retrieval_method == RetrievalMethod.FULL_TEXT_SEARCH
):
data_post_processor = DataPostProcessor(
@ -459,7 +499,7 @@ class RetrievalService:
segment_ids: list[str] = []
index_node_segments: list[DocumentSegment] = []
segments: list[DocumentSegment] = []
attachment_map: dict[str, list[dict[str, Any]]] = {}
attachment_map: dict[str, list[AttachmentInfoDict]] = {}
child_chunk_map: dict[str, list[ChildChunk]] = {}
doc_segment_map: dict[str, list[str]] = {}
segment_summary_map: dict[str, str] = {} # Map segment_id to summary content
@ -544,12 +584,12 @@ class RetrievalService:
segment_summary_map[summary.chunk_id] = summary.summary_content
include_segment_ids = set()
segment_child_map: dict[str, dict[str, Any]] = {}
records: list[dict[str, Any]] = []
segment_child_map: dict[str, SegmentChildMapDetail] = {}
records: list[SegmentRecord] = []
for segment in segments:
child_chunks: list[ChildChunk] = child_chunk_map.get(segment.id, [])
attachment_infos: list[dict[str, Any]] = attachment_map.get(segment.id, [])
attachment_infos: list[AttachmentInfoDict] = attachment_map.get(segment.id, [])
ds_dataset_document: DatasetDocument | None = valid_dataset_documents.get(segment.document_id)
if ds_dataset_document and ds_dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
@ -560,14 +600,14 @@ class RetrievalService:
max_score = summary_score_map.get(segment.id, 0.0)
if child_chunks or attachment_infos:
child_chunk_details = []
child_chunk_details: list[ChildChunkDetail] = []
for child_chunk in child_chunks:
child_document: Document | None = doc_to_document_map.get(child_chunk.index_node_id)
if child_document:
child_score = child_document.metadata.get("score", 0.0)
else:
child_score = 0.0
child_chunk_detail = {
child_chunk_detail: ChildChunkDetail = {
"id": child_chunk.id,
"content": child_chunk.content,
"position": child_chunk.position,
@ -580,7 +620,7 @@ class RetrievalService:
if file_document:
max_score = max(max_score, file_document.metadata.get("score", 0.0))
map_detail = {
map_detail: SegmentChildMapDetail = {
"max_score": max_score,
"child_chunks": child_chunk_details,
}
@ -593,7 +633,7 @@ class RetrievalService:
"max_score": summary_score,
"child_chunks": [],
}
record: dict[str, Any] = {
record: SegmentRecord = {
"segment": segment,
}
records.append(record)
@ -617,19 +657,19 @@ class RetrievalService:
if file_doc:
max_score = max(max_score, file_doc.metadata.get("score", 0.0))
record = {
another_record: SegmentRecord = {
"segment": segment,
"score": max_score,
}
records.append(record)
records.append(another_record)
# Add child chunks information to records
for record in records:
if record["segment"].id in segment_child_map:
record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks") # type: ignore
record["score"] = segment_child_map[record["segment"].id]["max_score"] # type: ignore
record["child_chunks"] = segment_child_map[record["segment"].id]["child_chunks"]
record["score"] = segment_child_map[record["segment"].id]["max_score"]
if record["segment"].id in attachment_map:
record["files"] = attachment_map[record["segment"].id] # type: ignore[assignment]
record["files"] = attachment_map[record["segment"].id]
result: list[RetrievalSegments] = []
for record in records:
@ -693,9 +733,9 @@ class RetrievalService:
query: str | None = None,
top_k: int = 4,
score_threshold: float | None = 0.0,
reranking_model: dict | None = None,
reranking_model: RerankingModelDict | None = None,
reranking_mode: str = "reranking_model",
weights: dict | None = None,
weights: WeightsDict | None = None,
document_ids_filter: list[str] | None = None,
attachment_id: str | None = None,
):
@ -807,7 +847,7 @@ class RetrievalService:
@classmethod
def get_segment_attachment_info(
cls, dataset_id: str, tenant_id: str, attachment_id: str, session: Session
) -> dict[str, Any] | None:
) -> SegmentAttachmentResult | None:
upload_file = session.query(UploadFile).where(UploadFile.id == attachment_id).first()
if upload_file:
attachment_binding = (
@ -816,7 +856,7 @@ class RetrievalService:
.first()
)
if attachment_binding:
attachment_info = {
attachment_info: AttachmentInfoDict = {
"id": upload_file.id,
"name": upload_file.name,
"extension": "." + upload_file.extension,
@ -828,8 +868,10 @@ class RetrievalService:
return None
@classmethod
def get_segment_attachment_infos(cls, attachment_ids: list[str], session: Session) -> list[dict[str, Any]]:
attachment_infos = []
def get_segment_attachment_infos(
cls, attachment_ids: list[str], session: Session
) -> list[SegmentAttachmentInfoResult]:
attachment_infos: list[SegmentAttachmentInfoResult] = []
upload_files = session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).all()
if upload_files:
upload_file_ids = [upload_file.id for upload_file in upload_files]
@ -843,7 +885,7 @@ class RetrievalService:
if attachment_bindings:
for upload_file in upload_files:
attachment_binding = attachment_binding_map.get(upload_file.id)
attachment_info = {
info: AttachmentInfoDict = {
"id": upload_file.id,
"name": upload_file.name,
"extension": "." + upload_file.extension,
@ -855,7 +897,7 @@ class RetrievalService:
attachment_infos.append(
{
"attachment_id": attachment_binding.attachment_id,
"attachment_info": attachment_info,
"attachment_info": info,
"segment_id": attachment_binding.segment_id,
}
)

View File

@ -1,8 +1,18 @@
from pydantic import BaseModel
from typing_extensions import TypedDict
from models.dataset import DocumentSegment
class AttachmentInfoDict(TypedDict):
id: str
name: str
extension: str
mime_type: str
source_url: str
size: int
class RetrievalChildChunk(BaseModel):
"""Retrieval segments."""
@ -19,5 +29,5 @@ class RetrievalSegments(BaseModel):
segment: DocumentSegment
child_chunks: list[RetrievalChildChunk] | None = None
score: float | None = None
files: list[dict[str, str | int]] | None = None
files: list[AttachmentInfoDict] | None = None
summary: str | None = None # Summary content if retrieved via summary index

View File

@ -9,6 +9,7 @@ from flask import current_app
from sqlalchemy import delete, func, select
from core.db.session_factory import session_factory
from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict
from core.workflow.nodes.knowledge_index.exc import KnowledgeIndexNodeError
from core.workflow.nodes.knowledge_index.protocols import Preview, PreviewItem, QaPreview
from models.dataset import Dataset, Document, DocumentSegment
@ -51,7 +52,7 @@ class IndexProcessor:
original_document_id: str,
chunks: Mapping[str, Any],
batch: Any,
summary_index_setting: dict | None = None,
summary_index_setting: SummaryIndexSettingDict | None = None,
):
with session_factory.create_session() as session:
document = session.query(Document).filter_by(id=document_id).first()
@ -131,7 +132,12 @@ class IndexProcessor:
}
def get_preview_output(
self, chunks: Any, dataset_id: str, document_id: str, chunk_structure: str, summary_index_setting: dict | None
self,
chunks: Any,
dataset_id: str,
document_id: str,
chunk_structure: str,
summary_index_setting: SummaryIndexSettingDict | None,
) -> Preview:
doc_language = None
with session_factory.create_session() as session:

View File

@ -7,14 +7,16 @@ import os
import re
from abc import ABC, abstractmethod
from collections.abc import Mapping
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any, NotRequired, Optional
from urllib.parse import unquote, urlparse
import httpx
from typing_extensions import TypedDict
from configs import dify_config
from core.entities.knowledge_entities import PreviewDetail
from core.helper import ssrf_proxy
from core.rag.data_post_processor.data_post_processor import RerankingModelDict
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.models.document import AttachmentDocument, Document
@ -35,6 +37,13 @@ if TYPE_CHECKING:
from core.model_manager import ModelInstance
class SummaryIndexSettingDict(TypedDict):
enable: bool
model_name: NotRequired[str]
model_provider_name: NotRequired[str]
summary_prompt: NotRequired[str]
class BaseIndexProcessor(ABC):
"""Interface for extract files."""
@ -51,7 +60,7 @@ class BaseIndexProcessor(ABC):
self,
tenant_id: str,
preview_texts: list[PreviewDetail],
summary_index_setting: dict,
summary_index_setting: SummaryIndexSettingDict,
doc_language: str | None = None,
) -> list[PreviewDetail]:
"""
@ -98,7 +107,7 @@ class BaseIndexProcessor(ABC):
dataset: Dataset,
top_k: int,
score_threshold: float,
reranking_model: dict,
reranking_model: RerankingModelDict,
) -> list[Document]:
raise NotImplementedError

View File

@ -14,6 +14,7 @@ from core.llm_generator.prompts import DEFAULT_GENERATOR_SUMMARY_PROMPT
from core.model_manager import ModelInstance
from core.provider_manager import ProviderManager
from core.rag.cleaner.clean_processor import CleanProcessor
from core.rag.data_post_processor.data_post_processor import RerankingModelDict
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.datasource.vdb.vector_factory import Vector
@ -22,7 +23,7 @@ from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.index_processor.index_processor_base import BaseIndexProcessor, SummaryIndexSettingDict
from core.rag.models.document import AttachmentDocument, Document, MultimodalGeneralStructureChunk
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.utils.text_processing_utils import remove_leading_symbols
@ -175,7 +176,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
dataset: Dataset,
top_k: int,
score_threshold: float,
reranking_model: dict,
reranking_model: RerankingModelDict,
) -> list[Document]:
# Set search parameters.
results = RetrievalService.retrieve(
@ -278,7 +279,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
self,
tenant_id: str,
preview_texts: list[PreviewDetail],
summary_index_setting: dict,
summary_index_setting: SummaryIndexSettingDict,
doc_language: str | None = None,
) -> list[PreviewDetail]:
"""
@ -362,7 +363,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
def generate_summary(
tenant_id: str,
text: str,
summary_index_setting: dict | None = None,
summary_index_setting: SummaryIndexSettingDict | None = None,
segment_id: str | None = None,
document_language: str | None = None,
) -> tuple[str, LLMUsage]:

View File

@ -11,6 +11,7 @@ from core.db.session_factory import session_factory
from core.entities.knowledge_entities import PreviewDetail
from core.model_manager import ModelInstance
from core.rag.cleaner.clean_processor import CleanProcessor
from core.rag.data_post_processor.data_post_processor import RerankingModelDict
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
@ -18,7 +19,7 @@ from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.index_processor.index_processor_base import BaseIndexProcessor, SummaryIndexSettingDict
from core.rag.models.document import AttachmentDocument, ChildDocument, Document, ParentChildStructureChunk
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
@ -215,7 +216,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
dataset: Dataset,
top_k: int,
score_threshold: float,
reranking_model: dict,
reranking_model: RerankingModelDict,
) -> list[Document]:
# Set search parameters.
results = RetrievalService.retrieve(
@ -361,7 +362,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
self,
tenant_id: str,
preview_texts: list[PreviewDetail],
summary_index_setting: dict,
summary_index_setting: SummaryIndexSettingDict,
doc_language: str | None = None,
) -> list[PreviewDetail]:
"""

View File

@ -15,13 +15,14 @@ from core.db.session_factory import session_factory
from core.entities.knowledge_entities import PreviewDetail
from core.llm_generator.llm_generator import LLMGenerator
from core.rag.cleaner.clean_processor import CleanProcessor
from core.rag.data_post_processor.data_post_processor import RerankingModelDict
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.index_processor.index_processor_base import BaseIndexProcessor, SummaryIndexSettingDict
from core.rag.models.document import AttachmentDocument, Document, QAStructureChunk
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.utils.text_processing_utils import remove_leading_symbols
@ -185,7 +186,7 @@ class QAIndexProcessor(BaseIndexProcessor):
dataset: Dataset,
top_k: int,
score_threshold: float,
reranking_model: dict,
reranking_model: RerankingModelDict,
):
# Set search parameters.
results = RetrievalService.retrieve(
@ -244,7 +245,7 @@ class QAIndexProcessor(BaseIndexProcessor):
self,
tenant_id: str,
preview_texts: list[PreviewDetail],
summary_index_setting: dict,
summary_index_setting: SummaryIndexSettingDict,
doc_language: str | None = None,
) -> list[PreviewDetail]:
"""

View File

@ -31,7 +31,7 @@ from core.ops.utils import measure_time
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
from core.prompt.simple_prompt_transform import ModelMode
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
from core.rag.data_post_processor.data_post_processor import DataPostProcessor, RerankingModelDict, WeightsDict
from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
@ -83,7 +83,7 @@ from models.dataset import (
)
from models.dataset import Document as DatasetDocument
from models.dataset import Document as DocumentModel
from models.enums import CreatorUserRole
from models.enums import CreatorUserRole, DatasetQuerySource
from services.external_knowledge_service import ExternalDatasetService
from services.feature_service import FeatureService
@ -727,8 +727,8 @@ class DatasetRetrieval:
top_k: int,
score_threshold: float,
reranking_mode: str,
reranking_model: dict | None = None,
weights: dict[str, Any] | None = None,
reranking_model: RerankingModelDict | None = None,
weights: WeightsDict | None = None,
reranking_enable: bool = True,
message_id: str | None = None,
metadata_filter_document_ids: dict[str, list[str]] | None = None,
@ -1008,7 +1008,7 @@ class DatasetRetrieval:
dataset_query = DatasetQuery(
dataset_id=dataset_id,
content=json.dumps(contents),
source="app",
source=DatasetQuerySource.APP,
source_app_id=app_id,
created_by_role=CreatorUserRole(user_from),
created_by=user_id,
@ -1181,8 +1181,8 @@ class DatasetRetrieval:
hit_callbacks=[hit_callback],
return_resource=return_resource,
retriever_from=invoke_from.to_source(),
reranking_provider_name=retrieve_config.reranking_model.get("reranking_provider_name"),
reranking_model_name=retrieve_config.reranking_model.get("reranking_model_name"),
reranking_provider_name=retrieve_config.reranking_model["reranking_provider_name"],
reranking_model_name=retrieve_config.reranking_model["reranking_model_name"],
)
tools.append(tool)
@ -1685,8 +1685,8 @@ class DatasetRetrieval:
tenant_id: str,
reranking_enable: bool,
reranking_mode: str,
reranking_model: dict | None,
weights: dict[str, Any] | None,
reranking_model: RerankingModelDict | None,
weights: WeightsDict | None,
top_k: int,
score_threshold: float,
query: str | None,

View File

@ -2,6 +2,7 @@ import concurrent.futures
import logging
from core.db.session_factory import session_factory
from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict
from models.dataset import Dataset, Document, DocumentSegment, DocumentSegmentSummary
from services.summary_index_service import SummaryIndexService
from tasks.generate_summary_index_task import generate_summary_index_task
@ -11,7 +12,11 @@ logger = logging.getLogger(__name__)
class SummaryIndex:
def generate_and_vectorize_summary(
self, dataset_id: str, document_id: str, is_preview: bool, summary_index_setting: dict | None = None
self,
dataset_id: str,
document_id: str,
is_preview: bool,
summary_index_setting: SummaryIndexSettingDict | None = None,
) -> None:
if is_preview:
with session_factory.create_session() as session:

View File

@ -72,6 +72,11 @@ class ApiProviderControllerItem(TypedDict):
controller: ApiToolProviderController
class EmojiIconDict(TypedDict):
background: str
content: str
class ToolManager:
_builtin_provider_lock = Lock()
_hardcoded_providers: dict[str, BuiltinToolProviderController] = {}
@ -916,7 +921,7 @@ class ToolManager:
)
@classmethod
def generate_workflow_tool_icon_url(cls, tenant_id: str, provider_id: str) -> Mapping[str, str]:
def generate_workflow_tool_icon_url(cls, tenant_id: str, provider_id: str) -> EmojiIconDict:
try:
workflow_provider: WorkflowToolProvider | None = (
db.session.query(WorkflowToolProvider)
@ -933,7 +938,7 @@ class ToolManager:
return {"background": "#252525", "content": "\ud83d\ude01"}
@classmethod
def generate_api_tool_icon_url(cls, tenant_id: str, provider_id: str) -> Mapping[str, str]:
def generate_api_tool_icon_url(cls, tenant_id: str, provider_id: str) -> EmojiIconDict:
try:
api_provider: ApiToolProvider | None = (
db.session.query(ApiToolProvider)
@ -950,7 +955,7 @@ class ToolManager:
return {"background": "#252525", "content": "\ud83d\ude01"}
@classmethod
def generate_mcp_tool_icon_url(cls, tenant_id: str, provider_id: str) -> Mapping[str, str] | str:
def generate_mcp_tool_icon_url(cls, tenant_id: str, provider_id: str) -> EmojiIconDict | dict[str, str] | str:
try:
with Session(db.engine) as session:
mcp_service = MCPToolManageService(session=session)
@ -970,7 +975,7 @@ class ToolManager:
tenant_id: str,
provider_type: ToolProviderType,
provider_id: str,
) -> str | Mapping[str, str]:
) -> str | EmojiIconDict | dict[str, str]:
"""
get the tool icon

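The icon helpers above now return `EmojiIconDict` rather than a loose `Mapping[str, str]`. A minimal sketch of how that TypedDict constrains the fallback value, assuming only the two keys shown in the hunk (the emoji literal below is just the same 😁 written with a non-surrogate escape):

```python
from typing import TypedDict


class EmojiIconDict(TypedDict):
    background: str
    content: str


def fallback_icon() -> EmojiIconDict:
    # Same shape as the fallback returned by the generate_*_icon_url methods:
    # both keys are required, and a type checker flags missing or extra keys.
    return {"background": "#252525", "content": "\U0001F601"}


print(fallback_icon()["content"])
```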
View File

@ -1,5 +1,4 @@
import threading
from typing import Any
from flask import Flask, current_app
from pydantic import BaseModel, Field
@ -13,11 +12,12 @@ from core.rag.models.document import Document as RagDocument
from core.rag.rerank.rerank_model import RerankModelRunner
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
from core.tools.utils.dataset_retriever.dataset_retriever_tool import DefaultRetrievalModelDict
from dify_graph.model_runtime.entities.model_entities import ModelType
from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment
default_retrieval_model: dict[str, Any] = {
default_retrieval_model: DefaultRetrievalModelDict = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},

View File

@ -1,9 +1,10 @@
from typing import Any, cast
from typing import NotRequired, TypedDict, cast
from pydantic import BaseModel, Field
from sqlalchemy import select
from core.app.app_config.entities import DatasetRetrieveConfigEntity, ModelConfig
from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.rag.entities.context_entities import DocumentContext
@ -16,7 +17,19 @@ from models.dataset import Dataset
from models.dataset import Document as DatasetDocument
from services.external_knowledge_service import ExternalDatasetService
default_retrieval_model: dict[str, Any] = {
class DefaultRetrievalModelDict(TypedDict):
search_method: RetrievalMethod
reranking_enable: bool
reranking_model: RerankingModelDict
reranking_mode: NotRequired[str]
weights: NotRequired[WeightsDict | None]
score_threshold: NotRequired[float]
top_k: int
score_threshold_enabled: bool
default_retrieval_model: DefaultRetrievalModelDict = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
@ -125,7 +138,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
if metadata_condition and not document_ids_filter:
return ""
# get the retrieval model; if it is not set, fall back to the default
retrieval_model: dict[str, Any] = dataset.retrieval_model or default_retrieval_model
retrieval_model = dataset.retrieval_model or default_retrieval_model
retrieval_resource_list: list[RetrievalSourceMetadata] = []
if dataset.indexing_technique == "economy":
# use keyword table query

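`default_retrieval_model` is now typed with the `DefaultRetrievalModelDict` defined above. A small standalone sketch of how its `NotRequired` keys behave (the class below is an illustrative subset, not the real definition):

```python
from typing import NotRequired, TypedDict


class RetrievalModelSketch(TypedDict):
    search_method: str
    reranking_enable: bool
    top_k: int
    score_threshold_enabled: bool
    # NotRequired keys: a dict that omits them still satisfies the type.
    reranking_mode: NotRequired[str]
    score_threshold: NotRequired[float]


minimal: RetrievalModelSketch = {
    "search_method": "semantic_search",
    "reranking_enable": False,
    "top_k": 2,
    "score_threshold_enabled": False,
}

with_optional: RetrievalModelSketch = {**minimal, "reranking_mode": "reranking_model", "score_threshold": 0.5}
print(with_optional["top_k"], "score_threshold" in minimal)
```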
View File

@ -1,4 +1,5 @@
import re
from collections.abc import Mapping
from json import dumps as json_dumps
from json import loads as json_loads
from json.decoder import JSONDecodeError
@ -20,10 +21,18 @@ class InterfaceDict(TypedDict):
operation: dict[str, Any]
class OpenAPISpecDict(TypedDict):
openapi: str
info: dict[str, str]
servers: list[dict[str, Any]]
paths: dict[str, Any]
components: dict[str, Any]
class ApiBasedToolSchemaParser:
@staticmethod
def parse_openapi_to_tool_bundle(
openapi: dict, extra_info: dict | None = None, warning: dict | None = None
openapi: Mapping[str, Any], extra_info: dict | None = None, warning: dict | None = None
) -> list[ApiToolBundle]:
warning = warning if warning is not None else {}
extra_info = extra_info if extra_info is not None else {}
@ -277,7 +286,7 @@ class ApiBasedToolSchemaParser:
@staticmethod
def parse_swagger_to_openapi(
swagger: dict, extra_info: dict | None = None, warning: dict | None = None
) -> dict[str, Any]:
) -> OpenAPISpecDict:
warning = warning or {}
"""
parse swagger to openapi
@ -293,7 +302,7 @@ class ApiBasedToolSchemaParser:
if len(servers) == 0:
raise ToolApiSchemaError("No server found in the swagger yaml.")
converted_openapi: dict[str, Any] = {
converted_openapi: OpenAPISpecDict = {
"openapi": "3.0.0",
"info": {
"title": info.get("title", "Swagger"),

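`parse_swagger_to_openapi` now declares its return value as `OpenAPISpecDict`. A minimal value satisfying that shape, using the keys from the TypedDict above (the spec content itself is illustrative):

```python
from typing import Any, TypedDict


class OpenAPISpecDict(TypedDict):
    openapi: str
    info: dict[str, str]
    servers: list[dict[str, Any]]
    paths: dict[str, Any]
    components: dict[str, Any]


converted_openapi: OpenAPISpecDict = {
    "openapi": "3.0.0",
    "info": {"title": "Swagger", "version": "1.0.0"},
    "servers": [{"url": "https://example.com/api"}],
    "paths": {},
    "components": {"schemas": {}},
}
print(converted_openapi["openapi"])
```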
View File

@ -2,6 +2,7 @@ from typing import Literal, Union
from pydantic import BaseModel
from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE
from dify_graph.entities.base_node_data import BaseNodeData
@ -161,4 +162,4 @@ class KnowledgeIndexNodeData(BaseNodeData):
chunk_structure: str
index_chunk_variable_selector: list[str]
indexing_technique: str | None = None
summary_index_setting: dict | None = None
summary_index_setting: SummaryIndexSettingDict | None = None

View File

@ -3,6 +3,7 @@ from collections.abc import Mapping
from typing import TYPE_CHECKING, Any
from core.rag.index_processor.index_processor import IndexProcessor
from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict
from core.rag.summary_index.summary_index import SummaryIndex
from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE
from dify_graph.entities.graph_config import NodeConfigDict
@ -127,7 +128,7 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
is_preview: bool,
batch: Any,
chunks: Mapping[str, Any],
summary_index_setting: dict | None = None,
summary_index_setting: SummaryIndexSettingDict | None = None,
):
if not document_id:
raise KnowledgeIndexNodeError("document_id is required.")

View File

@ -9,6 +9,7 @@ from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, Literal
from core.app.app_config.entities import DatasetRetrieveConfigEntity
from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from dify_graph.entities import GraphInitParams
from dify_graph.entities.graph_config import NodeConfigDict
@ -201,8 +202,8 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
elif str(node_data.retrieval_mode) == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
if node_data.multiple_retrieval_config is None:
raise ValueError("multiple_retrieval_config is required")
reranking_model = None
weights = None
reranking_model: RerankingModelDict | None = None
weights: WeightsDict | None = None
match node_data.multiple_retrieval_config.reranking_mode:
case "reranking_model":
if node_data.multiple_retrieval_config.reranking_model:

View File

@ -2,6 +2,7 @@ from typing import Any, Literal, Protocol
from pydantic import BaseModel, Field
from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict
from dify_graph.model_runtime.entities import LLMUsage
from dify_graph.nodes.llm.entities import ModelConfig
@ -75,8 +76,8 @@ class KnowledgeRetrievalRequest(BaseModel):
top_k: int = Field(default=0, description="Number of top results to return")
score_threshold: float = Field(default=0.0, description="Minimum relevance score threshold")
reranking_mode: str = Field(default="reranking_model", description="Reranking strategy")
reranking_model: dict | None = Field(default=None, description="Reranking model configuration")
weights: dict[str, Any] | None = Field(default=None, description="Weights for weighted score reranking")
reranking_model: RerankingModelDict | None = Field(default=None, description="Reranking model configuration")
weights: WeightsDict | None = Field(default=None, description="Weights for weighted score reranking")
reranking_enable: bool = Field(default=True, description="Whether reranking is enabled")
attachment_ids: list[str] | None = Field(default=None, description="List of attachment file IDs for retrieval")

View File

@ -101,7 +101,6 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
timeout=self._get_request_timeout(self.node_data),
variable_pool=self.graph_runtime_state.variable_pool,
http_request_config=self._http_request_config,
max_retries=0,
ssl_verify=self.node_data.ssl_verify,
http_client=self._http_client,
file_manager=self._file_manager,

View File

@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Any
from dify_graph.file.models import File
if TYPE_CHECKING:
pass
from dify_graph.variables.segments import Segment
class ArrayValidation(StrEnum):
@ -219,7 +219,7 @@ class SegmentType(StrEnum):
return _ARRAY_ELEMENT_TYPES_MAPPING.get(self)
@staticmethod
def get_zero_value(t: SegmentType):
def get_zero_value(t: SegmentType) -> Segment:
# Lazy import to avoid circular dependency
from factories import variable_factory

View File

@ -10,6 +10,7 @@ from events.document_index_event import document_index_created
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.dataset import Document
from models.enums import IndexingStatus
logger = logging.getLogger(__name__)
@ -35,7 +36,7 @@ def handle(sender, **kwargs):
if not document:
raise NotFound("Document not found")
document.indexing_status = "parsing"
document.indexing_status = IndexingStatus.PARSING
document.processing_started_at = naive_utc_now()
documents.append(document)
db.session.add(document)

View File

@ -1,3 +1,5 @@
from typing import Protocol, cast
from fastopenapi.routers import FlaskRouter
from flask_cors import CORS
@ -9,6 +11,10 @@ from extensions.ext_blueprints import AUTHENTICATED_HEADERS, EXPOSED_HEADERS
DOCS_PREFIX = "/fastopenapi"
class SupportsIncludeRouter(Protocol):
def include_router(self, router: object, *, prefix: str = "") -> None: ...
def init_app(app: DifyApp) -> None:
docs_enabled = dify_config.SWAGGER_UI_ENABLED
docs_url = f"{DOCS_PREFIX}/docs" if docs_enabled else None
@ -36,7 +42,7 @@ def init_app(app: DifyApp) -> None:
_ = remote_files
_ = setup
router.include_router(console_router, prefix="/console/api")
cast(SupportsIncludeRouter, router).include_router(console_router, prefix="/console/api")
CORS(
app,
resources={r"/console/api/.*": {"origins": dify_config.CONSOLE_CORS_ALLOW_ORIGINS}},

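The `cast(SupportsIncludeRouter, router)` call leans on structural typing: any object with a compatible `include_router` method satisfies the protocol, and the cast only narrows the static type. A reduced sketch with a stand-in router (all names besides `include_router` are illustrative):

```python
from typing import Protocol, cast


class SupportsIncludeRouter(Protocol):
    def include_router(self, router: object, *, prefix: str = "") -> None: ...


class FakeRouter:
    def include_router(self, router: object, *, prefix: str = "") -> None:
        print(f"mounted {router!r} under {prefix}")


def mount(parent: object, child: object) -> None:
    # cast() is a no-op at runtime; it only tells the type checker which
    # structural interface we expect `parent` to provide.
    cast(SupportsIncludeRouter, parent).include_router(child, prefix="/console/api")


mount(FakeRouter(), "console_router")
```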
View File

@ -5,7 +5,7 @@ from typing import Union
from celery.signals import worker_init
from flask_login import user_loaded_from_request, user_logged_in
from opentelemetry import trace
from opentelemetry import metrics, trace
from opentelemetry.propagate import set_global_textmap
from opentelemetry.propagators.b3 import B3MultiFormat
from opentelemetry.propagators.composite import CompositePropagator
@ -31,9 +31,29 @@ def setup_context_propagation() -> None:
def shutdown_tracer() -> None:
flush_telemetry()
def flush_telemetry() -> None:
"""
Best-effort flush for telemetry providers.
This is mainly used by short-lived command processes (e.g. Kubernetes CronJob)
so counters/histograms are exported before the process exits.
"""
provider = trace.get_tracer_provider()
if hasattr(provider, "force_flush"):
provider.force_flush()
try:
provider.force_flush()
except Exception:
logger.exception("otel: failed to flush trace provider")
metric_provider = metrics.get_meter_provider()
if hasattr(metric_provider, "force_flush"):
try:
metric_provider.force_flush()
except Exception:
logger.exception("otel: failed to flush metric provider")
def is_celery_worker():

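A sketch of how a short-lived command process might call `flush_telemetry` so spans and metrics are exported before the interpreter exits, mirroring the `try/finally` pattern used by the cleanup commands (the job body is a placeholder):

```python
from extensions.otel.runtime import flush_telemetry


def run_one_off_job() -> None:
    try:
        # ... do the work that records spans / counters / histograms ...
        pass
    finally:
        # Best-effort: providers without force_flush are skipped, and
        # flush failures are logged rather than crashing the job.
        flush_telemetry()
```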
View File

@ -55,7 +55,7 @@ class TypeMismatchError(Exception):
# Define the constant
SEGMENT_TO_VARIABLE_MAP = {
SEGMENT_TO_VARIABLE_MAP: Mapping[type[Segment], type[VariableBase]] = {
ArrayAnySegment: ArrayAnyVariable,
ArrayBooleanSegment: ArrayBooleanVariable,
ArrayFileSegment: ArrayFileVariable,
@ -296,13 +296,11 @@ def segment_to_variable(
raise UnsupportedSegmentTypeError(f"not supported segment type {segment_type}")
variable_class = SEGMENT_TO_VARIABLE_MAP[segment_type]
return cast(
VariableBase,
variable_class(
id=id,
name=name,
description=description,
value=segment.value,
selector=list(selector),
),
return variable_class(
id=id,
name=name,
description=description,
value_type=segment.value_type,
value=segment.value,
selector=list(selector),
)

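Typing the map as `Mapping[type[Segment], type[VariableBase]]` is what lets `segment_to_variable` drop the `cast(...)`: a lookup now yields a class whose constructor is already known to return a variable. A reduced sketch of the idea with stand-in classes (these are not the real Segment/Variable hierarchy):

```python
from collections.abc import Mapping


class Segment:
    def __init__(self, value: object) -> None:
        self.value = value


class StringSegment(Segment):
    pass


class Variable:
    def __init__(self, name: str, value: object) -> None:
        self.name = name
        self.value = value


class StringVariable(Variable):
    pass


# Because the mapping's value type is type[Variable], the lookup result is
# already known to construct a Variable -- no cast() needed at the call site.
SEGMENT_TO_VARIABLE: Mapping[type[Segment], type[Variable]] = {
    StringSegment: StringVariable,
}


def to_variable(segment: Segment, name: str) -> Variable:
    variable_cls = SEGMENT_TO_VARIABLE[type(segment)]
    return variable_cls(name=name, value=segment.value)


print(to_variable(StringSegment("hello"), "greeting").name)
```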
View File

@ -32,6 +32,11 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
def _stream_with_request_context(response: object) -> Any:
"""Bridge Flask's loosely-typed streaming helper without leaking casts into callers."""
return cast(Any, stream_with_context)(response)
def escape_like_pattern(pattern: str) -> str:
"""
Escape special characters in a string for safe use in SQL LIKE patterns.
@ -286,22 +291,32 @@ def generate_text_hash(text: str) -> str:
return sha256(hash_text.encode()).hexdigest()
def compact_generate_response(response: Union[Mapping, Generator, RateLimitGenerator]) -> Response:
if isinstance(response, dict):
def compact_generate_response(
response: Mapping[str, Any] | Generator[str, None, None] | RateLimitGenerator,
) -> Response:
if isinstance(response, Mapping):
return Response(
response=json.dumps(jsonable_encoder(response)),
status=200,
content_type="application/json; charset=utf-8",
)
else:
stream_response = response
def generate() -> Generator:
yield from response
def generate() -> Generator[str, None, None]:
yield from stream_response
return Response(stream_with_context(generate()), status=200, mimetype="text/event-stream")
return Response(
_stream_with_request_context(generate()),
status=200,
mimetype="text/event-stream",
)
def length_prefixed_response(magic_number: int, response: Union[Mapping, Generator, RateLimitGenerator]) -> Response:
def length_prefixed_response(
magic_number: int,
response: Mapping[str, Any] | BaseModel | Generator[str | bytes, None, None] | RateLimitGenerator,
) -> Response:
"""
This function is used to return a response with a length prefix.
The magic number is a one-byte value that indicates the type of the response.
@ -332,7 +347,7 @@ def length_prefixed_response(magic_number: int, response: Union[Mapping, Generat
# | Magic Number 1byte | Reserved 1byte | Header Length 2bytes | Data Length 4bytes | Reserved 6bytes | Data
return struct.pack("<BBHI", magic_number, 0, header_length, data_length) + b"\x00" * 6 + response
if isinstance(response, dict):
if isinstance(response, Mapping):
return Response(
response=pack_response_with_length_prefix(json.dumps(jsonable_encoder(response)).encode("utf-8")),
status=200,
@ -345,14 +360,20 @@ def length_prefixed_response(magic_number: int, response: Union[Mapping, Generat
mimetype="application/json",
)
def generate() -> Generator:
for chunk in response:
stream_response = response
def generate() -> Generator[bytes, None, None]:
for chunk in stream_response:
if isinstance(chunk, str):
yield pack_response_with_length_prefix(chunk.encode("utf-8"))
else:
yield pack_response_with_length_prefix(chunk)
return Response(stream_with_context(generate()), status=200, mimetype="text/event-stream")
return Response(
_stream_with_request_context(generate()),
status=200,
mimetype="text/event-stream",
)
class TokenManager:

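The framing comment in `length_prefixed_response` describes a fixed header (`<BBHI` plus six reserved bytes) followed by the payload. A small sketch that packs and unpacks one frame under that assumed layout (the magic number and header-length handling are illustrative):

```python
import struct

MAGIC_JSON = 0x01  # illustrative magic number


def pack_frame(magic: int, data: bytes) -> bytes:
    header_length = 0
    # | magic 1B | reserved 1B | header length 2B | data length 4B | reserved 6B | data |
    return struct.pack("<BBHI", magic, 0, header_length, len(data)) + b"\x00" * 6 + data


def unpack_frame(frame: bytes) -> tuple[int, bytes]:
    magic, _reserved, header_length, data_length = struct.unpack_from("<BBHI", frame, 0)
    payload_start = struct.calcsize("<BBHI") + 6 + header_length
    return magic, frame[payload_start : payload_start + data_length]


frame = pack_frame(MAGIC_JSON, b'{"ok": true}')
print(unpack_frame(frame))  # (1, b'{"ok": true}')
```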
View File

@ -77,12 +77,14 @@ def login_required(func: Callable[P, R]) -> Callable[P, R | ResponseReturnValue]
@wraps(func)
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R | ResponseReturnValue:
if request.method in EXEMPT_METHODS or dify_config.LOGIN_DISABLED:
pass
elif current_user is not None and not current_user.is_authenticated:
return current_app.ensure_sync(func)(*args, **kwargs)
user = _get_user()
if user is None or not user.is_authenticated:
return current_app.login_manager.unauthorized() # type: ignore
# we put csrf validation here to reduce conflicts
# TODO: maybe find a better place for it.
check_csrf_token(request, current_user.id)
check_csrf_token(request, user.id)
return current_app.ensure_sync(func)(*args, **kwargs)
return decorated_view

View File

@ -7,9 +7,10 @@ https://github.com/django/django/blob/main/django/utils/module_loading.py
import sys
from importlib import import_module
from typing import Any
def cached_import(module_path: str, class_name: str):
def cached_import(module_path: str, class_name: str) -> Any:
"""
Import a module and return the named attribute/class from it, with caching.
@ -20,16 +21,14 @@ def cached_import(module_path: str, class_name: str):
Returns:
The imported attribute/class
"""
if not (
(module := sys.modules.get(module_path))
and (spec := getattr(module, "__spec__", None))
and getattr(spec, "_initializing", False) is False
):
module = sys.modules.get(module_path)
spec = getattr(module, "__spec__", None) if module is not None else None
if module is None or getattr(spec, "_initializing", False):
module = import_module(module_path)
return getattr(module, class_name)
def import_string(dotted_path: str):
def import_string(dotted_path: str) -> Any:
"""
Import a dotted module path and return the attribute/class designated by
the last name in the path. Raise ImportError if the import failed.

View File
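`import_string` follows the Django convention of a dotted path whose last segment names the attribute. A reduced sketch of the lookup step, without the module-cache fast path handled by `cached_import` above (the example path is from the standard library):

```python
from importlib import import_module


def import_string_sketch(dotted_path: str) -> object:
    # Split "package.module.Attr" into the module path and the attribute name.
    module_path, _, class_name = dotted_path.rpartition(".")
    if not module_path:
        raise ImportError(f"{dotted_path} doesn't look like a module path")
    module = import_module(module_path)
    try:
        return getattr(module, class_name)
    except AttributeError as exc:
        raise ImportError(f"Module {module_path!r} has no attribute {class_name!r}") from exc


OrderedDict = import_string_sketch("collections.OrderedDict")
print(OrderedDict is __import__("collections").OrderedDict)  # True
```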

@ -1,7 +1,48 @@
import sys
import urllib.parse
from dataclasses import dataclass
from typing import NotRequired
import httpx
from pydantic import TypeAdapter
if sys.version_info >= (3, 12):
from typing import TypedDict
else:
from typing_extensions import TypedDict
JsonObject = dict[str, object]
JsonObjectList = list[JsonObject]
JSON_OBJECT_ADAPTER = TypeAdapter(JsonObject)
JSON_OBJECT_LIST_ADAPTER = TypeAdapter(JsonObjectList)
class AccessTokenResponse(TypedDict, total=False):
access_token: str
class GitHubEmailRecord(TypedDict, total=False):
email: str
primary: bool
class GitHubRawUserInfo(TypedDict):
id: int | str
login: str
name: NotRequired[str]
email: NotRequired[str]
class GoogleRawUserInfo(TypedDict):
sub: str
email: str
ACCESS_TOKEN_RESPONSE_ADAPTER = TypeAdapter(AccessTokenResponse)
GITHUB_RAW_USER_INFO_ADAPTER = TypeAdapter(GitHubRawUserInfo)
GITHUB_EMAIL_RECORDS_ADAPTER = TypeAdapter(list[GitHubEmailRecord])
GOOGLE_RAW_USER_INFO_ADAPTER = TypeAdapter(GoogleRawUserInfo)
@dataclass
@ -11,26 +52,38 @@ class OAuthUserInfo:
email: str
def _json_object(response: httpx.Response) -> JsonObject:
return JSON_OBJECT_ADAPTER.validate_python(response.json())
def _json_list(response: httpx.Response) -> JsonObjectList:
return JSON_OBJECT_LIST_ADAPTER.validate_python(response.json())
class OAuth:
client_id: str
client_secret: str
redirect_uri: str
def __init__(self, client_id: str, client_secret: str, redirect_uri: str):
self.client_id = client_id
self.client_secret = client_secret
self.redirect_uri = redirect_uri
def get_authorization_url(self):
def get_authorization_url(self, invite_token: str | None = None) -> str:
raise NotImplementedError()
def get_access_token(self, code: str):
def get_access_token(self, code: str) -> str:
raise NotImplementedError()
def get_raw_user_info(self, token: str):
def get_raw_user_info(self, token: str) -> JsonObject:
raise NotImplementedError()
def get_user_info(self, token: str) -> OAuthUserInfo:
raw_info = self.get_raw_user_info(token)
return self._transform_user_info(raw_info)
def _transform_user_info(self, raw_info: dict) -> OAuthUserInfo:
def _transform_user_info(self, raw_info: JsonObject) -> OAuthUserInfo:
raise NotImplementedError()
@ -40,7 +93,7 @@ class GitHubOAuth(OAuth):
_USER_INFO_URL = "https://api.github.com/user"
_EMAIL_INFO_URL = "https://api.github.com/user/emails"
def get_authorization_url(self, invite_token: str | None = None):
def get_authorization_url(self, invite_token: str | None = None) -> str:
params = {
"client_id": self.client_id,
"redirect_uri": self.redirect_uri,
@ -50,7 +103,7 @@ class GitHubOAuth(OAuth):
params["state"] = invite_token
return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}"
def get_access_token(self, code: str):
def get_access_token(self, code: str) -> str:
data = {
"client_id": self.client_id,
"client_secret": self.client_secret,
@ -60,7 +113,7 @@ class GitHubOAuth(OAuth):
headers = {"Accept": "application/json"}
response = httpx.post(self._TOKEN_URL, data=data, headers=headers)
response_json = response.json()
response_json = ACCESS_TOKEN_RESPONSE_ADAPTER.validate_python(_json_object(response))
access_token = response_json.get("access_token")
if not access_token:
@ -68,23 +121,24 @@ class GitHubOAuth(OAuth):
return access_token
def get_raw_user_info(self, token: str):
def get_raw_user_info(self, token: str) -> JsonObject:
headers = {"Authorization": f"token {token}"}
response = httpx.get(self._USER_INFO_URL, headers=headers)
response.raise_for_status()
user_info = response.json()
user_info = GITHUB_RAW_USER_INFO_ADAPTER.validate_python(_json_object(response))
email_response = httpx.get(self._EMAIL_INFO_URL, headers=headers)
email_info = email_response.json()
primary_email: dict = next((email for email in email_info if email["primary"] == True), {})
email_info = GITHUB_EMAIL_RECORDS_ADAPTER.validate_python(_json_list(email_response))
primary_email = next((email for email in email_info if email.get("primary") is True), None)
return {**user_info, "email": primary_email.get("email", "")}
return {**user_info, "email": primary_email.get("email", "") if primary_email else ""}
def _transform_user_info(self, raw_info: dict) -> OAuthUserInfo:
email = raw_info.get("email")
def _transform_user_info(self, raw_info: JsonObject) -> OAuthUserInfo:
payload = GITHUB_RAW_USER_INFO_ADAPTER.validate_python(raw_info)
email = payload.get("email")
if not email:
email = f"{raw_info['id']}+{raw_info['login']}@users.noreply.github.com"
return OAuthUserInfo(id=str(raw_info["id"]), name=raw_info["name"], email=email)
email = f"{payload['id']}+{payload['login']}@users.noreply.github.com"
return OAuthUserInfo(id=str(payload["id"]), name=str(payload.get("name", "")), email=email)
class GoogleOAuth(OAuth):
@ -92,7 +146,7 @@ class GoogleOAuth(OAuth):
_TOKEN_URL = "https://oauth2.googleapis.com/token"
_USER_INFO_URL = "https://www.googleapis.com/oauth2/v3/userinfo"
def get_authorization_url(self, invite_token: str | None = None):
def get_authorization_url(self, invite_token: str | None = None) -> str:
params = {
"client_id": self.client_id,
"response_type": "code",
@ -103,7 +157,7 @@ class GoogleOAuth(OAuth):
params["state"] = invite_token
return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}"
def get_access_token(self, code: str):
def get_access_token(self, code: str) -> str:
data = {
"client_id": self.client_id,
"client_secret": self.client_secret,
@ -114,7 +168,7 @@ class GoogleOAuth(OAuth):
headers = {"Accept": "application/json"}
response = httpx.post(self._TOKEN_URL, data=data, headers=headers)
response_json = response.json()
response_json = ACCESS_TOKEN_RESPONSE_ADAPTER.validate_python(_json_object(response))
access_token = response_json.get("access_token")
if not access_token:
@ -122,11 +176,12 @@ class GoogleOAuth(OAuth):
return access_token
def get_raw_user_info(self, token: str):
def get_raw_user_info(self, token: str) -> JsonObject:
headers = {"Authorization": f"Bearer {token}"}
response = httpx.get(self._USER_INFO_URL, headers=headers)
response.raise_for_status()
return response.json()
return _json_object(response)
def _transform_user_info(self, raw_info: dict) -> OAuthUserInfo:
return OAuthUserInfo(id=str(raw_info["sub"]), name="", email=raw_info["email"])
def _transform_user_info(self, raw_info: JsonObject) -> OAuthUserInfo:
payload = GOOGLE_RAW_USER_INFO_ADAPTER.validate_python(raw_info)
return OAuthUserInfo(id=str(payload["sub"]), name="", email=payload["email"])

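The OAuth responses are now validated through pydantic `TypeAdapter`s instead of being consumed as raw `response.json()` dicts. A small sketch of what that validation step buys (the payloads below are fabricated test data):

```python
import sys

from pydantic import TypeAdapter, ValidationError

if sys.version_info >= (3, 12):
    from typing import NotRequired, TypedDict
else:
    from typing_extensions import NotRequired, TypedDict


class AccessTokenResponse(TypedDict, total=False):
    access_token: str


class GitHubRawUserInfo(TypedDict):
    id: int | str
    login: str
    name: NotRequired[str]
    email: NotRequired[str]


ACCESS_TOKEN_ADAPTER = TypeAdapter(AccessTokenResponse)
USER_INFO_ADAPTER = TypeAdapter(GitHubRawUserInfo)

# A well-formed payload passes through unchanged, now with a precise static type.
token = ACCESS_TOKEN_ADAPTER.validate_python({"access_token": "gho_example"})
user = USER_INFO_ADAPTER.validate_python({"id": 42, "login": "octocat"})
print(token.get("access_token"), user["login"])

# A structurally wrong payload fails here instead of as a KeyError much later.
try:
    USER_INFO_ADAPTER.validate_python({"login": "octocat"})  # missing the required "id"
except ValidationError as exc:
    print("rejected:", exc.error_count(), "error(s)")
```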
View File

@ -1,25 +1,57 @@
import sys
import urllib.parse
from typing import Any
from typing import Any, Literal
import httpx
from flask_login import current_user
from pydantic import TypeAdapter
from sqlalchemy import select
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.source import DataSourceOauthBinding
if sys.version_info >= (3, 12):
from typing import TypedDict
else:
from typing_extensions import TypedDict
class NotionPageSummary(TypedDict):
page_id: str
page_name: str
page_icon: dict[str, str] | None
parent_id: str
type: Literal["page", "database"]
class NotionSourceInfo(TypedDict):
workspace_name: str | None
workspace_icon: str | None
workspace_id: str | None
pages: list[NotionPageSummary]
total: int
SOURCE_INFO_STORAGE_ADAPTER = TypeAdapter(dict[str, object])
NOTION_SOURCE_INFO_ADAPTER = TypeAdapter(NotionSourceInfo)
NOTION_PAGE_SUMMARY_ADAPTER = TypeAdapter(NotionPageSummary)
class OAuthDataSource:
client_id: str
client_secret: str
redirect_uri: str
def __init__(self, client_id: str, client_secret: str, redirect_uri: str):
self.client_id = client_id
self.client_secret = client_secret
self.redirect_uri = redirect_uri
def get_authorization_url(self):
def get_authorization_url(self) -> str:
raise NotImplementedError()
def get_access_token(self, code: str):
def get_access_token(self, code: str) -> None:
raise NotImplementedError()
@ -30,7 +62,7 @@ class NotionOAuth(OAuthDataSource):
_NOTION_BLOCK_SEARCH = "https://api.notion.com/v1/blocks"
_NOTION_BOT_USER = "https://api.notion.com/v1/users/me"
def get_authorization_url(self):
def get_authorization_url(self) -> str:
params = {
"client_id": self.client_id,
"response_type": "code",
@ -39,7 +71,7 @@ class NotionOAuth(OAuthDataSource):
}
return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}"
def get_access_token(self, code: str):
def get_access_token(self, code: str) -> None:
data = {"code": code, "grant_type": "authorization_code", "redirect_uri": self.redirect_uri}
headers = {"Accept": "application/json"}
auth = (self.client_id, self.client_secret)
@ -54,13 +86,12 @@ class NotionOAuth(OAuthDataSource):
workspace_id = response_json.get("workspace_id")
# get all authorized pages
pages = self.get_authorized_pages(access_token)
source_info = {
"workspace_name": workspace_name,
"workspace_icon": workspace_icon,
"workspace_id": workspace_id,
"pages": pages,
"total": len(pages),
}
source_info = self._build_source_info(
workspace_name=workspace_name,
workspace_icon=workspace_icon,
workspace_id=workspace_id,
pages=pages,
)
# save data source binding
data_source_binding = db.session.scalar(
select(DataSourceOauthBinding).where(
@ -70,7 +101,7 @@ class NotionOAuth(OAuthDataSource):
)
)
if data_source_binding:
data_source_binding.source_info = source_info
data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info)
data_source_binding.disabled = False
data_source_binding.updated_at = naive_utc_now()
db.session.commit()
@ -78,25 +109,24 @@ class NotionOAuth(OAuthDataSource):
new_data_source_binding = DataSourceOauthBinding(
tenant_id=current_user.current_tenant_id,
access_token=access_token,
source_info=source_info,
source_info=SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info),
provider="notion",
)
db.session.add(new_data_source_binding)
db.session.commit()
def save_internal_access_token(self, access_token: str):
def save_internal_access_token(self, access_token: str) -> None:
workspace_name = self.notion_workspace_name(access_token)
workspace_icon = None
workspace_id = current_user.current_tenant_id
# get all authorized pages
pages = self.get_authorized_pages(access_token)
source_info = {
"workspace_name": workspace_name,
"workspace_icon": workspace_icon,
"workspace_id": workspace_id,
"pages": pages,
"total": len(pages),
}
source_info = self._build_source_info(
workspace_name=workspace_name,
workspace_icon=workspace_icon,
workspace_id=workspace_id,
pages=pages,
)
# save data source binding
data_source_binding = db.session.scalar(
select(DataSourceOauthBinding).where(
@ -106,7 +136,7 @@ class NotionOAuth(OAuthDataSource):
)
)
if data_source_binding:
data_source_binding.source_info = source_info
data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info)
data_source_binding.disabled = False
data_source_binding.updated_at = naive_utc_now()
db.session.commit()
@ -114,13 +144,13 @@ class NotionOAuth(OAuthDataSource):
new_data_source_binding = DataSourceOauthBinding(
tenant_id=current_user.current_tenant_id,
access_token=access_token,
source_info=source_info,
source_info=SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info),
provider="notion",
)
db.session.add(new_data_source_binding)
db.session.commit()
def sync_data_source(self, binding_id: str):
def sync_data_source(self, binding_id: str) -> None:
# save data source binding
data_source_binding = db.session.scalar(
select(DataSourceOauthBinding).where(
@ -134,23 +164,22 @@ class NotionOAuth(OAuthDataSource):
if data_source_binding:
# get all authorized pages
pages = self.get_authorized_pages(data_source_binding.access_token)
source_info = data_source_binding.source_info
new_source_info = {
"workspace_name": source_info["workspace_name"],
"workspace_icon": source_info["workspace_icon"],
"workspace_id": source_info["workspace_id"],
"pages": pages,
"total": len(pages),
}
data_source_binding.source_info = new_source_info
source_info = NOTION_SOURCE_INFO_ADAPTER.validate_python(data_source_binding.source_info)
new_source_info = self._build_source_info(
workspace_name=source_info["workspace_name"],
workspace_icon=source_info["workspace_icon"],
workspace_id=source_info["workspace_id"],
pages=pages,
)
data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(new_source_info)
data_source_binding.disabled = False
data_source_binding.updated_at = naive_utc_now()
db.session.commit()
else:
raise ValueError("Data source binding not found")
def get_authorized_pages(self, access_token: str):
pages = []
def get_authorized_pages(self, access_token: str) -> list[NotionPageSummary]:
pages: list[NotionPageSummary] = []
page_results = self.notion_page_search(access_token)
database_results = self.notion_database_search(access_token)
# get page detail
@ -187,7 +216,7 @@ class NotionOAuth(OAuthDataSource):
"parent_id": parent_id,
"type": "page",
}
pages.append(page)
pages.append(NOTION_PAGE_SUMMARY_ADAPTER.validate_python(page))
# get database detail
for database_result in database_results:
page_id = database_result["id"]
@ -220,11 +249,11 @@ class NotionOAuth(OAuthDataSource):
"parent_id": parent_id,
"type": "database",
}
pages.append(page)
pages.append(NOTION_PAGE_SUMMARY_ADAPTER.validate_python(page))
return pages
def notion_page_search(self, access_token: str):
results = []
def notion_page_search(self, access_token: str) -> list[dict[str, Any]]:
results: list[dict[str, Any]] = []
next_cursor = None
has_more = True
@ -249,7 +278,7 @@ class NotionOAuth(OAuthDataSource):
return results
def notion_block_parent_page_id(self, access_token: str, block_id: str):
def notion_block_parent_page_id(self, access_token: str, block_id: str) -> str:
headers = {
"Authorization": f"Bearer {access_token}",
"Notion-Version": "2022-06-28",
@ -265,7 +294,7 @@ class NotionOAuth(OAuthDataSource):
return self.notion_block_parent_page_id(access_token, parent[parent_type])
return parent[parent_type]
def notion_workspace_name(self, access_token: str):
def notion_workspace_name(self, access_token: str) -> str:
headers = {
"Authorization": f"Bearer {access_token}",
"Notion-Version": "2022-06-28",
@ -279,8 +308,8 @@ class NotionOAuth(OAuthDataSource):
return user_info["workspace_name"]
return "workspace"
def notion_database_search(self, access_token: str):
results = []
def notion_database_search(self, access_token: str) -> list[dict[str, Any]]:
results: list[dict[str, Any]] = []
next_cursor = None
has_more = True
@ -303,3 +332,19 @@ class NotionOAuth(OAuthDataSource):
next_cursor = response_json.get("next_cursor", None)
return results
@staticmethod
def _build_source_info(
*,
workspace_name: str | None,
workspace_icon: str | None,
workspace_id: str | None,
pages: list[NotionPageSummary],
) -> NotionSourceInfo:
return {
"workspace_name": workspace_name,
"workspace_icon": workspace_icon,
"workspace_id": workspace_id,
"pages": pages,
"total": len(pages),
}

View File

@ -177,13 +177,11 @@ class Account(UserMixin, TypeBase):
@classmethod
def get_by_openid(cls, provider: str, open_id: str):
account_integrate = (
db.session.query(AccountIntegrate)
.where(AccountIntegrate.provider == provider, AccountIntegrate.open_id == open_id)
.one_or_none()
)
account_integrate = db.session.execute(
select(AccountIntegrate).where(AccountIntegrate.provider == provider, AccountIntegrate.open_id == open_id)
).scalar_one_or_none()
if account_integrate:
return db.session.query(Account).where(Account.id == account_integrate.account_id).one_or_none()
return db.session.scalar(select(Account).where(Account.id == account_integrate.account_id))
return None
# check current_user.current_tenant.current_role in ['admin', 'owner']

View File

@ -8,6 +8,7 @@ import os
import pickle
import re
import time
from collections.abc import Sequence
from datetime import datetime
from json import JSONDecodeError
from typing import Any, TypedDict, cast
@ -30,7 +31,20 @@ from services.entities.knowledge_entities.knowledge_entities import ParentMode,
from .account import Account
from .base import Base, TypeBase
from .engine import db
from .enums import CreatorUserRole
from .enums import (
CollectionBindingType,
CreatorUserRole,
DatasetMetadataType,
DatasetQuerySource,
DatasetRuntimeMode,
DataSourceType,
DocumentCreatedFrom,
DocumentDocType,
IndexingStatus,
ProcessRuleMode,
SegmentStatus,
SummaryStatus,
)
from .model import App, Tag, TagBinding, UploadFile
from .types import AdjustedJSON, BinaryData, EnumText, LongText, StringUUID, adjusted_json_index
@ -120,7 +134,7 @@ class Dataset(Base):
server_default=sa.text("'only_me'"),
default=DatasetPermissionEnum.ONLY_ME,
)
data_source_type = mapped_column(String(255))
data_source_type = mapped_column(EnumText(DataSourceType, length=255))
indexing_technique: Mapped[str | None] = mapped_column(String(255))
index_struct = mapped_column(LongText, nullable=True)
created_by = mapped_column(StringUUID, nullable=False)
@ -137,7 +151,9 @@ class Dataset(Base):
summary_index_setting = mapped_column(AdjustedJSON, nullable=True)
built_in_field_enabled = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
icon_info = mapped_column(AdjustedJSON, nullable=True)
runtime_mode = mapped_column(sa.String(255), nullable=True, server_default=sa.text("'general'"))
runtime_mode = mapped_column(
EnumText(DatasetRuntimeMode, length=255), nullable=True, server_default=sa.text("'general'")
)
pipeline_id = mapped_column(StringUUID, nullable=True)
chunk_structure = mapped_column(sa.String(255), nullable=True)
enable_api = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
@ -145,30 +161,25 @@ class Dataset(Base):
@property
def total_documents(self):
return db.session.query(func.count(Document.id)).where(Document.dataset_id == self.id).scalar()
return db.session.scalar(select(func.count(Document.id)).where(Document.dataset_id == self.id)) or 0
@property
def total_available_documents(self):
return (
db.session.query(func.count(Document.id))
.where(
Document.dataset_id == self.id,
Document.indexing_status == "completed",
Document.enabled == True,
Document.archived == False,
db.session.scalar(
select(func.count(Document.id)).where(
Document.dataset_id == self.id,
Document.indexing_status == "completed",
Document.enabled == True,
Document.archived == False,
)
)
.scalar()
or 0
)
@property
def dataset_keyword_table(self):
dataset_keyword_table = (
db.session.query(DatasetKeywordTable).where(DatasetKeywordTable.dataset_id == self.id).first()
)
if dataset_keyword_table:
return dataset_keyword_table
return None
return db.session.scalar(select(DatasetKeywordTable).where(DatasetKeywordTable.dataset_id == self.id))
@property
def index_struct_dict(self):
@ -195,64 +206,66 @@ class Dataset(Base):
@property
def latest_process_rule(self):
return (
db.session.query(DatasetProcessRule)
return db.session.scalar(
select(DatasetProcessRule)
.where(DatasetProcessRule.dataset_id == self.id)
.order_by(DatasetProcessRule.created_at.desc())
.first()
.limit(1)
)
@property
def app_count(self):
return (
db.session.query(func.count(AppDatasetJoin.id))
.where(AppDatasetJoin.dataset_id == self.id, App.id == AppDatasetJoin.app_id)
.scalar()
db.session.scalar(
select(func.count(AppDatasetJoin.id)).where(
AppDatasetJoin.dataset_id == self.id, App.id == AppDatasetJoin.app_id
)
)
or 0
)
@property
def document_count(self):
return db.session.query(func.count(Document.id)).where(Document.dataset_id == self.id).scalar()
return db.session.scalar(select(func.count(Document.id)).where(Document.dataset_id == self.id)) or 0
@property
def available_document_count(self):
return (
db.session.query(func.count(Document.id))
.where(
Document.dataset_id == self.id,
Document.indexing_status == "completed",
Document.enabled == True,
Document.archived == False,
db.session.scalar(
select(func.count(Document.id)).where(
Document.dataset_id == self.id,
Document.indexing_status == "completed",
Document.enabled == True,
Document.archived == False,
)
)
.scalar()
or 0
)
@property
def available_segment_count(self):
return (
db.session.query(func.count(DocumentSegment.id))
.where(
DocumentSegment.dataset_id == self.id,
DocumentSegment.status == "completed",
DocumentSegment.enabled == True,
db.session.scalar(
select(func.count(DocumentSegment.id)).where(
DocumentSegment.dataset_id == self.id,
DocumentSegment.status == "completed",
DocumentSegment.enabled == True,
)
)
.scalar()
or 0
)
@property
def word_count(self):
return (
db.session.query(Document)
.with_entities(func.coalesce(func.sum(Document.word_count), 0))
.where(Document.dataset_id == self.id)
.scalar()
return db.session.scalar(
select(func.coalesce(func.sum(Document.word_count), 0)).where(Document.dataset_id == self.id)
)
@property
def doc_form(self) -> str | None:
if self.chunk_structure:
return self.chunk_structure
document = db.session.query(Document).where(Document.dataset_id == self.id).first()
document = db.session.scalar(select(Document).where(Document.dataset_id == self.id).limit(1))
if document:
return document.doc_form
return None
@ -270,8 +283,8 @@ class Dataset(Base):
@property
def tags(self):
tags = (
db.session.query(Tag)
tags = db.session.scalars(
select(Tag)
.join(TagBinding, Tag.id == TagBinding.tag_id)
.where(
TagBinding.target_id == self.id,
@ -279,8 +292,7 @@ class Dataset(Base):
Tag.tenant_id == self.tenant_id,
Tag.type == "knowledge",
)
.all()
)
).all()
return tags or []
@ -288,8 +300,8 @@ class Dataset(Base):
def external_knowledge_info(self):
if self.provider != "external":
return None
external_knowledge_binding = (
db.session.query(ExternalKnowledgeBindings).where(ExternalKnowledgeBindings.dataset_id == self.id).first()
external_knowledge_binding = db.session.scalar(
select(ExternalKnowledgeBindings).where(ExternalKnowledgeBindings.dataset_id == self.id)
)
if not external_knowledge_binding:
return None
@ -310,7 +322,7 @@ class Dataset(Base):
@property
def is_published(self):
if self.pipeline_id:
pipeline = db.session.query(Pipeline).where(Pipeline.id == self.pipeline_id).first()
pipeline = db.session.scalar(select(Pipeline).where(Pipeline.id == self.pipeline_id))
if pipeline:
return pipeline.is_published
return False
@ -382,7 +394,7 @@ class DatasetProcessRule(Base): # bug
id = mapped_column(StringUUID, nullable=False, default=lambda: str(uuid4()))
dataset_id = mapped_column(StringUUID, nullable=False)
mode = mapped_column(String(255), nullable=False, server_default=sa.text("'automatic'"))
mode = mapped_column(EnumText(ProcessRuleMode, length=255), nullable=False, server_default=sa.text("'automatic'"))
rules = mapped_column(LongText, nullable=True)
created_by = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
@ -428,12 +440,12 @@ class Document(Base):
tenant_id = mapped_column(StringUUID, nullable=False)
dataset_id = mapped_column(StringUUID, nullable=False)
position: Mapped[int] = mapped_column(sa.Integer, nullable=False)
data_source_type: Mapped[str] = mapped_column(String(255), nullable=False)
data_source_type: Mapped[str] = mapped_column(EnumText(DataSourceType, length=255), nullable=False)
data_source_info = mapped_column(LongText, nullable=True)
dataset_process_rule_id = mapped_column(StringUUID, nullable=True)
batch: Mapped[str] = mapped_column(String(255), nullable=False)
name: Mapped[str] = mapped_column(String(255), nullable=False)
created_from: Mapped[str] = mapped_column(String(255), nullable=False)
created_from: Mapped[str] = mapped_column(EnumText(DocumentCreatedFrom, length=255), nullable=False)
created_by = mapped_column(StringUUID, nullable=False)
created_api_request_id = mapped_column(StringUUID, nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
@ -467,7 +479,9 @@ class Document(Base):
stopped_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
# basic fields
indexing_status = mapped_column(String(255), nullable=False, server_default=sa.text("'waiting'"))
indexing_status = mapped_column(
EnumText(IndexingStatus, length=255), nullable=False, server_default=sa.text("'waiting'")
)
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
disabled_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
disabled_by = mapped_column(StringUUID, nullable=True)
@ -478,7 +492,7 @@ class Document(Base):
updated_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
doc_type = mapped_column(String(40), nullable=True)
doc_type = mapped_column(EnumText(DocumentDocType, length=40), nullable=True)
doc_metadata = mapped_column(AdjustedJSON, nullable=True)
doc_form = mapped_column(String(255), nullable=False, server_default=sa.text("'text_model'"))
doc_language = mapped_column(String(255), nullable=True)
@ -521,10 +535,8 @@ class Document(Base):
if self.data_source_info:
if self.data_source_type == "upload_file":
data_source_info_dict: dict[str, Any] = json.loads(self.data_source_info)
file_detail = (
db.session.query(UploadFile)
.where(UploadFile.id == data_source_info_dict["upload_file_id"])
.one_or_none()
file_detail = db.session.scalar(
select(UploadFile).where(UploadFile.id == data_source_info_dict["upload_file_id"])
)
if file_detail:
return {
@ -557,24 +569,23 @@ class Document(Base):
@property
def dataset(self):
return db.session.query(Dataset).where(Dataset.id == self.dataset_id).one_or_none()
return db.session.scalar(select(Dataset).where(Dataset.id == self.dataset_id))
@property
def segment_count(self):
return db.session.query(DocumentSegment).where(DocumentSegment.document_id == self.id).count()
return (
db.session.scalar(select(func.count(DocumentSegment.id)).where(DocumentSegment.document_id == self.id)) or 0
)
@property
def hit_count(self):
return (
db.session.query(DocumentSegment)
.with_entities(func.coalesce(func.sum(DocumentSegment.hit_count), 0))
.where(DocumentSegment.document_id == self.id)
.scalar()
return db.session.scalar(
select(func.coalesce(func.sum(DocumentSegment.hit_count), 0)).where(DocumentSegment.document_id == self.id)
)
@property
def uploader(self):
user = db.session.query(Account).where(Account.id == self.created_by).first()
user = db.session.scalar(select(Account).where(Account.id == self.created_by))
return user.name if user else None
@property
@ -588,14 +599,13 @@ class Document(Base):
@property
def doc_metadata_details(self) -> list[DocMetadataDetailItem] | None:
if self.doc_metadata:
document_metadatas = (
db.session.query(DatasetMetadata)
document_metadatas = db.session.scalars(
select(DatasetMetadata)
.join(DatasetMetadataBinding, DatasetMetadataBinding.metadata_id == DatasetMetadata.id)
.where(
DatasetMetadataBinding.dataset_id == self.dataset_id, DatasetMetadataBinding.document_id == self.id
)
.all()
)
).all()
metadata_list: list[DocMetadataDetailItem] = []
for metadata in document_metadatas:
metadata_dict: DocMetadataDetailItem = {
@ -791,7 +801,7 @@ class DocumentSegment(Base):
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
disabled_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
disabled_by = mapped_column(StringUUID, nullable=True)
status: Mapped[str] = mapped_column(String(255), server_default=sa.text("'waiting'"))
status: Mapped[str] = mapped_column(EnumText(SegmentStatus, length=255), server_default=sa.text("'waiting'"))
created_by = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_by = mapped_column(StringUUID, nullable=True)
@ -826,7 +836,7 @@ class DocumentSegment(Base):
)
@property
def child_chunks(self) -> list[Any]:
def child_chunks(self) -> Sequence[Any]:
if not self.document:
return []
process_rule = self.document.dataset_process_rule
@ -835,16 +845,13 @@ class DocumentSegment(Base):
if rules_dict:
rules = Rule.model_validate(rules_dict)
if rules.parent_mode and rules.parent_mode != ParentMode.FULL_DOC:
child_chunks = (
db.session.query(ChildChunk)
.where(ChildChunk.segment_id == self.id)
.order_by(ChildChunk.position.asc())
.all()
)
child_chunks = db.session.scalars(
select(ChildChunk).where(ChildChunk.segment_id == self.id).order_by(ChildChunk.position.asc())
).all()
return child_chunks or []
return []
def get_child_chunks(self) -> list[Any]:
def get_child_chunks(self) -> Sequence[Any]:
if not self.document:
return []
process_rule = self.document.dataset_process_rule
@ -853,12 +860,9 @@ class DocumentSegment(Base):
if rules_dict:
rules = Rule.model_validate(rules_dict)
if rules.parent_mode:
child_chunks = (
db.session.query(ChildChunk)
.where(ChildChunk.segment_id == self.id)
.order_by(ChildChunk.position.asc())
.all()
)
child_chunks = db.session.scalars(
select(ChildChunk).where(ChildChunk.segment_id == self.id).order_by(ChildChunk.position.asc())
).all()
return child_chunks or []
return []
@ -1007,15 +1011,15 @@ class ChildChunk(Base):
@property
def dataset(self):
return db.session.query(Dataset).where(Dataset.id == self.dataset_id).first()
return db.session.scalar(select(Dataset).where(Dataset.id == self.dataset_id))
@property
def document(self):
return db.session.query(Document).where(Document.id == self.document_id).first()
return db.session.scalar(select(Document).where(Document.id == self.document_id))
@property
def segment(self):
return db.session.query(DocumentSegment).where(DocumentSegment.id == self.segment_id).first()
return db.session.scalar(select(DocumentSegment).where(DocumentSegment.id == self.segment_id))
class AppDatasetJoin(TypeBase):
@ -1061,7 +1065,7 @@ class DatasetQuery(TypeBase):
)
dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
content: Mapped[str] = mapped_column(LongText, nullable=False)
source: Mapped[str] = mapped_column(String(255), nullable=False)
source: Mapped[str] = mapped_column(EnumText(DatasetQuerySource, length=255), nullable=False)
source_app_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
created_by_role: Mapped[CreatorUserRole] = mapped_column(EnumText(CreatorUserRole, length=255), nullable=False)
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
@ -1076,7 +1080,7 @@ class DatasetQuery(TypeBase):
if isinstance(queries, list):
for query in queries:
if query["content_type"] == QueryType.IMAGE_QUERY:
file_info = db.session.query(UploadFile).filter_by(id=query["content"]).first()
file_info = db.session.scalar(select(UploadFile).where(UploadFile.id == query["content"]))
if file_info:
query["file_info"] = {
"id": file_info.id,
@ -1141,7 +1145,7 @@ class DatasetKeywordTable(TypeBase):
super().__init__(object_hook=object_hook, *args, **kwargs)
# get dataset
dataset = db.session.query(Dataset).filter_by(id=self.dataset_id).first()
dataset = db.session.scalar(select(Dataset).where(Dataset.id == self.dataset_id))
if not dataset:
return None
if self.data_source_type == "database":
@ -1206,7 +1210,9 @@ class DatasetCollectionBinding(TypeBase):
)
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
model_name: Mapped[str] = mapped_column(String(255), nullable=False)
type: Mapped[str] = mapped_column(String(40), server_default=sa.text("'dataset'"), nullable=False)
type: Mapped[str] = mapped_column(
EnumText(CollectionBindingType, length=40), server_default=sa.text("'dataset'"), nullable=False
)
collection_name: Mapped[str] = mapped_column(String(64), nullable=False)
created_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
@ -1433,7 +1439,7 @@ class DatasetMetadata(TypeBase):
)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
type: Mapped[str] = mapped_column(String(255), nullable=False)
type: Mapped[str] = mapped_column(EnumText(DatasetMetadataType, length=255), nullable=False)
name: Mapped[str] = mapped_column(String(255), nullable=False)
created_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=sa.func.current_timestamp(), init=False
@ -1535,7 +1541,7 @@ class PipelineCustomizedTemplate(TypeBase):
@property
def created_user_name(self):
account = db.session.query(Account).where(Account.id == self.created_by).first()
account = db.session.scalar(select(Account).where(Account.id == self.created_by))
if account:
return account.name
return ""
@ -1570,7 +1576,7 @@ class Pipeline(TypeBase):
)
def retrieve_dataset(self, session: Session):
return session.query(Dataset).where(Dataset.pipeline_id == self.id).first()
return session.scalar(select(Dataset).where(Dataset.pipeline_id == self.id))
class DocumentPipelineExecutionLog(TypeBase):
@ -1660,7 +1666,9 @@ class DocumentSegmentSummary(Base):
summary_index_node_id: Mapped[str] = mapped_column(String(255), nullable=True)
summary_index_node_hash: Mapped[str] = mapped_column(String(255), nullable=True)
tokens: Mapped[int | None] = mapped_column(sa.Integer, nullable=True)
status: Mapped[str] = mapped_column(String(32), nullable=False, server_default=sa.text("'generating'"))
status: Mapped[str] = mapped_column(
EnumText(SummaryStatus, length=32), nullable=False, server_default=sa.text("'generating'")
)
error: Mapped[str] = mapped_column(LongText, nullable=True)
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
disabled_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)

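The property helpers in this model now use `db.session.scalar(select(...))` and `db.session.scalars(...).all()` instead of the legacy `Query` API. A self-contained sketch of that 2.0-style pattern against an in-memory SQLite database (the table and rows are illustrative):

```python
from sqlalchemy import Integer, String, create_engine, func, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Document(Base):
    __tablename__ = "documents"
    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    dataset_id: Mapped[str] = mapped_column(String(36))


engine = create_engine("sqlite+pysqlite:///:memory:")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all([Document(dataset_id="ds-1"), Document(dataset_id="ds-1")])
    session.commit()

    # scalar() returns the first column of the first row (or None) ...
    total = session.scalar(select(func.count(Document.id)).where(Document.dataset_id == "ds-1")) or 0
    # ... while scalars().all() returns a Sequence of ORM objects.
    docs = session.scalars(select(Document).where(Document.dataset_id == "ds-1")).all()
    print(total, len(docs))  # 2 2
```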
View File

@ -11,6 +11,13 @@ class CreatorUserRole(StrEnum):
ACCOUNT = "account"
END_USER = "end_user"
@classmethod
def _missing_(cls, value):
if value == "end-user":
return cls.END_USER
else:
return super()._missing_(value)
class WorkflowRunTriggeredFrom(StrEnum):
DEBUGGING = "debugging"
@ -96,3 +103,216 @@ class ConversationStatus(StrEnum):
"""Conversation Status Enum"""
NORMAL = "normal"
class DataSourceType(StrEnum):
"""Data Source Type for Dataset and Document"""
UPLOAD_FILE = "upload_file"
NOTION_IMPORT = "notion_import"
WEBSITE_CRAWL = "website_crawl"
LOCAL_FILE = "local_file"
ONLINE_DOCUMENT = "online_document"
class ProcessRuleMode(StrEnum):
"""Dataset Process Rule Mode"""
AUTOMATIC = "automatic"
CUSTOM = "custom"
HIERARCHICAL = "hierarchical"
class IndexingStatus(StrEnum):
"""Document Indexing Status"""
WAITING = "waiting"
PARSING = "parsing"
CLEANING = "cleaning"
SPLITTING = "splitting"
INDEXING = "indexing"
PAUSED = "paused"
COMPLETED = "completed"
ERROR = "error"
class DocumentCreatedFrom(StrEnum):
"""Document Created From"""
WEB = "web"
API = "api"
RAG_PIPELINE = "rag-pipeline"
class ConversationFromSource(StrEnum):
"""Conversation / Message from_source"""
API = "api"
CONSOLE = "console"
class FeedbackFromSource(StrEnum):
"""MessageFeedback from_source"""
USER = "user"
ADMIN = "admin"
class InvokeFrom(StrEnum):
"""How a conversation/message was invoked"""
SERVICE_API = "service-api"
WEB_APP = "web-app"
TRIGGER = "trigger"
EXPLORE = "explore"
DEBUGGER = "debugger"
PUBLISHED_PIPELINE = "published"
VALIDATION = "validation"
@classmethod
def value_of(cls, value: str) -> "InvokeFrom":
return cls(value)
def to_source(self) -> str:
source_mapping = {
InvokeFrom.WEB_APP: "web_app",
InvokeFrom.DEBUGGER: "dev",
InvokeFrom.EXPLORE: "explore_app",
InvokeFrom.TRIGGER: "trigger",
InvokeFrom.SERVICE_API: "api",
}
return source_mapping.get(self, "dev")
class DocumentDocType(StrEnum):
"""Document doc_type classification"""
BOOK = "book"
WEB_PAGE = "web_page"
PAPER = "paper"
SOCIAL_MEDIA_POST = "social_media_post"
WIKIPEDIA_ENTRY = "wikipedia_entry"
PERSONAL_DOCUMENT = "personal_document"
BUSINESS_DOCUMENT = "business_document"
IM_CHAT_LOG = "im_chat_log"
SYNCED_FROM_NOTION = "synced_from_notion"
SYNCED_FROM_GITHUB = "synced_from_github"
OTHERS = "others"
class TagType(StrEnum):
"""Tag type"""
KNOWLEDGE = "knowledge"
APP = "app"
class DatasetMetadataType(StrEnum):
"""Dataset metadata value type"""
STRING = "string"
NUMBER = "number"
TIME = "time"
class SegmentStatus(StrEnum):
"""Document segment status"""
WAITING = "waiting"
INDEXING = "indexing"
COMPLETED = "completed"
ERROR = "error"
PAUSED = "paused"
RE_SEGMENT = "re_segment"
class DatasetRuntimeMode(StrEnum):
"""Dataset runtime mode"""
GENERAL = "general"
RAG_PIPELINE = "rag_pipeline"
class CollectionBindingType(StrEnum):
"""Dataset collection binding type"""
DATASET = "dataset"
ANNOTATION = "annotation"
class DatasetQuerySource(StrEnum):
"""Dataset query source"""
HIT_TESTING = "hit_testing"
APP = "app"
class TidbAuthBindingStatus(StrEnum):
"""TiDB auth binding status"""
CREATING = "CREATING"
ACTIVE = "ACTIVE"
class MessageFileBelongsTo(StrEnum):
"""MessageFile belongs_to"""
USER = "user"
ASSISTANT = "assistant"
class CredentialSourceType(StrEnum):
"""Load balancing credential source type"""
PROVIDER = "provider"
CUSTOM_MODEL = "custom_model"
class PaymentStatus(StrEnum):
"""Provider order payment status"""
WAIT_PAY = "wait_pay"
PAID = "paid"
FAILED = "failed"
REFUNDED = "refunded"
class BannerStatus(StrEnum):
"""ExporleBanner status"""
ENABLED = "enabled"
DISABLED = "disabled"
class SummaryStatus(StrEnum):
"""Document segment summary status"""
NOT_STARTED = "not_started"
GENERATING = "generating"
COMPLETED = "completed"
ERROR = "error"
TIMEOUT = "timeout"
class MessageChainType(StrEnum):
"""Message chain type"""
SYSTEM = "system"
class ProviderQuotaType(StrEnum):
PAID = "paid"
"""hosted paid quota"""
FREE = "free"
"""third-party free quota"""
TRIAL = "trial"
"""hosted trial quota"""
@staticmethod
def value_of(value: str) -> "ProviderQuotaType":
for member in ProviderQuotaType:
if member.value == value:
return member
raise ValueError(f"No matching enum found for value '{value}'")

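The `_missing_` hook added to `CreatorUserRole` keeps the legacy hyphenated spelling working when old rows stored as `"end-user"` are coerced back into the enum. A standalone sketch of the same mechanism:

```python
from enum import StrEnum


class CreatorUserRole(StrEnum):
    ACCOUNT = "account"
    END_USER = "end_user"

    @classmethod
    def _missing_(cls, value):
        # Accept the legacy hyphenated spelling still present in old rows.
        if value == "end-user":
            return cls.END_USER
        return super()._missing_(value)


print(CreatorUserRole("end_user") is CreatorUserRole.END_USER)   # True
print(CreatorUserRole("end-user") is CreatorUserRole.END_USER)   # True, via _missing_
```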
View File

@ -380,13 +380,12 @@ class App(Base):
@property
def site(self) -> Site | None:
site = db.session.query(Site).where(Site.app_id == self.id).first()
return site
return db.session.scalar(select(Site).where(Site.app_id == self.id))
@property
def app_model_config(self) -> AppModelConfig | None:
if self.app_model_config_id:
return db.session.query(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id).first()
return db.session.scalar(select(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id))
return None
@ -395,7 +394,7 @@ class App(Base):
if self.workflow_id:
from .workflow import Workflow
return db.session.query(Workflow).where(Workflow.id == self.workflow_id).first()
return db.session.scalar(select(Workflow).where(Workflow.id == self.workflow_id))
return None
@ -405,8 +404,7 @@ class App(Base):
@property
def tenant(self) -> Tenant | None:
tenant = db.session.query(Tenant).where(Tenant.id == self.tenant_id).first()
return tenant
return db.session.scalar(select(Tenant).where(Tenant.id == self.tenant_id))
@property
def is_agent(self) -> bool:
@ -546,9 +544,9 @@ class App(Base):
return deleted_tools
@property
def tags(self) -> list[Tag]:
tags = (
db.session.query(Tag)
def tags(self) -> Sequence[Tag]:
tags = db.session.scalars(
select(Tag)
.join(TagBinding, Tag.id == TagBinding.tag_id)
.where(
TagBinding.target_id == self.id,
@ -556,15 +554,14 @@ class App(Base):
Tag.tenant_id == self.tenant_id,
Tag.type == "app",
)
.all()
)
).all()
return tags or []
@property
def author_name(self) -> str | None:
if self.created_by:
account = db.session.query(Account).where(Account.id == self.created_by).first()
account = db.session.scalar(select(Account).where(Account.id == self.created_by))
if account:
return account.name
@ -616,8 +613,7 @@ class AppModelConfig(TypeBase):
@property
def app(self) -> App | None:
app = db.session.query(App).where(App.id == self.app_id).first()
return app
return db.session.scalar(select(App).where(App.id == self.app_id))
@property
def model_dict(self) -> ModelConfig:
@ -652,8 +648,8 @@ class AppModelConfig(TypeBase):
@property
def annotation_reply_dict(self) -> AnnotationReplyConfig:
annotation_setting = (
db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == self.app_id).first()
annotation_setting = db.session.scalar(
select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == self.app_id)
)
if annotation_setting:
collection_binding_detail = annotation_setting.collection_binding_detail
@ -845,8 +841,7 @@ class RecommendedApp(Base): # bug
@property
def app(self) -> App | None:
app = db.session.query(App).where(App.id == self.app_id).first()
return app
return db.session.scalar(select(App).where(App.id == self.app_id))
class InstalledApp(TypeBase):
@ -873,13 +868,11 @@ class InstalledApp(TypeBase):
@property
def app(self) -> App | None:
app = db.session.query(App).where(App.id == self.app_id).first()
return app
return db.session.scalar(select(App).where(App.id == self.app_id))
@property
def tenant(self) -> Tenant | None:
tenant = db.session.query(Tenant).where(Tenant.id == self.tenant_id).first()
return tenant
return db.session.scalar(select(Tenant).where(Tenant.id == self.tenant_id))
class TrialApp(Base):
@ -899,8 +892,7 @@ class TrialApp(Base):
@property
def app(self) -> App | None:
app = db.session.query(App).where(App.id == self.app_id).first()
return app
return db.session.scalar(select(App).where(App.id == self.app_id))
class AccountTrialAppRecord(Base):
@ -919,13 +911,11 @@ class AccountTrialAppRecord(Base):
@property
def app(self) -> App | None:
app = db.session.query(App).where(App.id == self.app_id).first()
return app
return db.session.scalar(select(App).where(App.id == self.app_id))
@property
def user(self) -> Account | None:
user = db.session.query(Account).where(Account.id == self.account_id).first()
return user
return db.session.scalar(select(Account).where(Account.id == self.account_id))
class ExporleBanner(TypeBase):
@ -1117,8 +1107,8 @@ class Conversation(Base):
else:
model_config["configs"] = override_model_configs # type: ignore[typeddict-unknown-key]
else:
app_model_config = (
db.session.query(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id).first()
app_model_config = db.session.scalar(
select(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id)
)
if app_model_config:
model_config = app_model_config.to_dict()
@ -1141,36 +1131,43 @@ class Conversation(Base):
@property
def annotated(self):
return db.session.query(MessageAnnotation).where(MessageAnnotation.conversation_id == self.id).count() > 0
return (
db.session.scalar(
select(func.count(MessageAnnotation.id)).where(MessageAnnotation.conversation_id == self.id)
)
or 0
) > 0
@property
def annotation(self):
return db.session.query(MessageAnnotation).where(MessageAnnotation.conversation_id == self.id).first()
return db.session.scalar(select(MessageAnnotation).where(MessageAnnotation.conversation_id == self.id).limit(1))
@property
def message_count(self):
return db.session.query(Message).where(Message.conversation_id == self.id).count()
return db.session.scalar(select(func.count(Message.id)).where(Message.conversation_id == self.id)) or 0
@property
def user_feedback_stats(self):
like = (
db.session.query(MessageFeedback)
.where(
MessageFeedback.conversation_id == self.id,
MessageFeedback.from_source == "user",
MessageFeedback.rating == "like",
db.session.scalar(
select(func.count(MessageFeedback.id)).where(
MessageFeedback.conversation_id == self.id,
MessageFeedback.from_source == "user",
MessageFeedback.rating == "like",
)
)
.count()
or 0
)
dislike = (
db.session.query(MessageFeedback)
.where(
MessageFeedback.conversation_id == self.id,
MessageFeedback.from_source == "user",
MessageFeedback.rating == "dislike",
db.session.scalar(
select(func.count(MessageFeedback.id)).where(
MessageFeedback.conversation_id == self.id,
MessageFeedback.from_source == "user",
MessageFeedback.rating == "dislike",
)
)
.count()
or 0
)
return {"like": like, "dislike": dislike}
@ -1178,23 +1175,25 @@ class Conversation(Base):
@property
def admin_feedback_stats(self):
like = (
db.session.query(MessageFeedback)
.where(
MessageFeedback.conversation_id == self.id,
MessageFeedback.from_source == "admin",
MessageFeedback.rating == "like",
db.session.scalar(
select(func.count(MessageFeedback.id)).where(
MessageFeedback.conversation_id == self.id,
MessageFeedback.from_source == "admin",
MessageFeedback.rating == "like",
)
)
.count()
or 0
)
dislike = (
db.session.query(MessageFeedback)
.where(
MessageFeedback.conversation_id == self.id,
MessageFeedback.from_source == "admin",
MessageFeedback.rating == "dislike",
db.session.scalar(
select(func.count(MessageFeedback.id)).where(
MessageFeedback.conversation_id == self.id,
MessageFeedback.from_source == "admin",
MessageFeedback.rating == "dislike",
)
)
.count()
or 0
)
return {"like": like, "dislike": dislike}
@ -1256,22 +1255,19 @@ class Conversation(Base):
@property
def first_message(self):
return (
db.session.query(Message)
.where(Message.conversation_id == self.id)
.order_by(Message.created_at.asc())
.first()
return db.session.scalar(
select(Message).where(Message.conversation_id == self.id).order_by(Message.created_at.asc())
)
@property
def app(self) -> App | None:
with Session(db.engine, expire_on_commit=False) as session:
return session.query(App).where(App.id == self.app_id).first()
return session.scalar(select(App).where(App.id == self.app_id))
@property
def from_end_user_session_id(self):
if self.from_end_user_id:
end_user = db.session.query(EndUser).where(EndUser.id == self.from_end_user_id).first()
end_user = db.session.scalar(select(EndUser).where(EndUser.id == self.from_end_user_id))
if end_user:
return end_user.session_id
@ -1280,7 +1276,7 @@ class Conversation(Base):
@property
def from_account_name(self) -> str | None:
if self.from_account_id:
account = db.session.query(Account).where(Account.id == self.from_account_id).first()
account = db.session.scalar(select(Account).where(Account.id == self.from_account_id))
if account:
return account.name
@ -1505,21 +1501,15 @@ class Message(Base):
@property
def user_feedback(self):
feedback = (
db.session.query(MessageFeedback)
.where(MessageFeedback.message_id == self.id, MessageFeedback.from_source == "user")
.first()
return db.session.scalar(
select(MessageFeedback).where(MessageFeedback.message_id == self.id, MessageFeedback.from_source == "user")
)
return feedback
@property
def admin_feedback(self):
feedback = (
db.session.query(MessageFeedback)
.where(MessageFeedback.message_id == self.id, MessageFeedback.from_source == "admin")
.first()
return db.session.scalar(
select(MessageFeedback).where(MessageFeedback.message_id == self.id, MessageFeedback.from_source == "admin")
)
return feedback
@property
def feedbacks(self):
@ -1528,28 +1518,27 @@ class Message(Base):
@property
def annotation(self):
annotation = db.session.query(MessageAnnotation).where(MessageAnnotation.message_id == self.id).first()
annotation = db.session.scalar(select(MessageAnnotation).where(MessageAnnotation.message_id == self.id))
return annotation
@property
def annotation_hit_history(self):
annotation_history = (
db.session.query(AppAnnotationHitHistory).where(AppAnnotationHitHistory.message_id == self.id).first()
annotation_history = db.session.scalar(
select(AppAnnotationHitHistory).where(AppAnnotationHitHistory.message_id == self.id)
)
if annotation_history:
annotation = (
db.session.query(MessageAnnotation)
.where(MessageAnnotation.id == annotation_history.annotation_id)
.first()
return db.session.scalar(
select(MessageAnnotation).where(MessageAnnotation.id == annotation_history.annotation_id)
)
return annotation
return None
@property
def app_model_config(self):
conversation = db.session.query(Conversation).where(Conversation.id == self.conversation_id).first()
conversation = db.session.scalar(select(Conversation).where(Conversation.id == self.conversation_id))
if conversation:
return db.session.query(AppModelConfig).where(AppModelConfig.id == conversation.app_model_config_id).first()
return db.session.scalar(
select(AppModelConfig).where(AppModelConfig.id == conversation.app_model_config_id)
)
return None
@ -1562,13 +1551,12 @@ class Message(Base):
return json.loads(self.message_metadata) if self.message_metadata else {}
@property
def agent_thoughts(self) -> list[MessageAgentThought]:
return (
db.session.query(MessageAgentThought)
def agent_thoughts(self) -> Sequence[MessageAgentThought]:
return db.session.scalars(
select(MessageAgentThought)
.where(MessageAgentThought.message_id == self.id)
.order_by(MessageAgentThought.position.asc())
.all()
)
).all()
@property
def retriever_resources(self) -> Any:
@ -1579,7 +1567,7 @@ class Message(Base):
from factories import file_factory
message_files = db.session.scalars(select(MessageFile).where(MessageFile.message_id == self.id)).all()
current_app = db.session.query(App).where(App.id == self.app_id).first()
current_app = db.session.scalar(select(App).where(App.id == self.app_id))
if not current_app:
raise ValueError(f"App {self.app_id} not found")
@ -1743,8 +1731,7 @@ class MessageFeedback(TypeBase):
@property
def from_account(self) -> Account | None:
account = db.session.query(Account).where(Account.id == self.from_account_id).first()
return account
return db.session.scalar(select(Account).where(Account.id == self.from_account_id))
def to_dict(self) -> MessageFeedbackDict:
return {
@ -1817,13 +1804,11 @@ class MessageAnnotation(Base):
@property
def account(self):
account = db.session.query(Account).where(Account.id == self.account_id).first()
return account
return db.session.scalar(select(Account).where(Account.id == self.account_id))
@property
def annotation_create_account(self):
account = db.session.query(Account).where(Account.id == self.account_id).first()
return account
return db.session.scalar(select(Account).where(Account.id == self.account_id))
class AppAnnotationHitHistory(TypeBase):
@ -1852,18 +1837,15 @@ class AppAnnotationHitHistory(TypeBase):
@property
def account(self):
account = (
db.session.query(Account)
return db.session.scalar(
select(Account)
.join(MessageAnnotation, MessageAnnotation.account_id == Account.id)
.where(MessageAnnotation.id == self.annotation_id)
.first()
)
return account
@property
def annotation_create_account(self):
account = db.session.query(Account).where(Account.id == self.account_id).first()
return account
return db.session.scalar(select(Account).where(Account.id == self.account_id))
class AppAnnotationSetting(TypeBase):
@ -1896,12 +1878,9 @@ class AppAnnotationSetting(TypeBase):
def collection_binding_detail(self):
from .dataset import DatasetCollectionBinding
collection_binding_detail = (
db.session.query(DatasetCollectionBinding)
.where(DatasetCollectionBinding.id == self.collection_binding_id)
.first()
return db.session.scalar(
select(DatasetCollectionBinding).where(DatasetCollectionBinding.id == self.collection_binding_id)
)
return collection_binding_detail
class OperationLog(TypeBase):
@ -2007,7 +1986,9 @@ class AppMCPServer(TypeBase):
def generate_server_code(n: int) -> str:
while True:
result = generate_string(n)
while db.session.query(AppMCPServer).where(AppMCPServer.server_code == result).count() > 0:
while (
db.session.scalar(select(func.count(AppMCPServer.id)).where(AppMCPServer.server_code == result)) or 0
) > 0:
result = generate_string(n)
return result
@ -2068,7 +2049,7 @@ class Site(Base):
def generate_code(n: int) -> str:
while True:
result = generate_string(n)
while db.session.query(Site).where(Site.code == result).count() > 0:
while (db.session.scalar(select(func.count(Site.id)).where(Site.code == result)) or 0) > 0:
result = generate_string(n)
return result
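The model.py hunks above consistently swap legacy `db.session.query(...).first()` / `.count()` calls for 2.0-style `select()` statements executed via `scalar()`. A minimal self-contained sketch of the same two idioms, using an assumed toy `Site` model and an in-memory SQLite engine rather than Dify's real models:

```python
# Sketch of the query-style migration shown above (assumed toy model, not Dify's).
from sqlalchemy import String, create_engine, func, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Site(Base):
    __tablename__ = "sites"
    id: Mapped[int] = mapped_column(primary_key=True)
    code: Mapped[str] = mapped_column(String(16))


engine = create_engine("sqlite+pysqlite:///:memory:")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(Site(code="abc123"))
    session.commit()

    # Old: session.query(Site).where(Site.code == "abc123").first()
    # New: scalar() returns the first column of the first row, or None.
    site = session.scalar(select(Site).where(Site.code == "abc123"))

    # Old: session.query(Site).where(...).count()
    # New: an explicit COUNT(); scalar() is typed Optional, hence the `or 0`
    # guard mirrored in the hunks above.
    total = session.scalar(select(func.count(Site.id)).where(Site.code == "abc123")) or 0
    print(site.code if site else None, total)
```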

View File

@ -6,13 +6,14 @@ from functools import cached_property
from uuid import uuid4
import sqlalchemy as sa
from sqlalchemy import DateTime, String, func, text
from sqlalchemy import DateTime, String, func, select, text
from sqlalchemy.orm import Mapped, mapped_column
from libs.uuid_utils import uuidv7
from .base import TypeBase
from .engine import db
from .enums import CredentialSourceType, PaymentStatus
from .types import EnumText, LongText, StringUUID
@ -96,7 +97,7 @@ class Provider(TypeBase):
@cached_property
def credential(self):
if self.credential_id:
return db.session.query(ProviderCredential).where(ProviderCredential.id == self.credential_id).first()
return db.session.scalar(select(ProviderCredential).where(ProviderCredential.id == self.credential_id))
@property
def credential_name(self):
@ -159,10 +160,8 @@ class ProviderModel(TypeBase):
@cached_property
def credential(self):
if self.credential_id:
return (
db.session.query(ProviderModelCredential)
.where(ProviderModelCredential.id == self.credential_id)
.first()
return db.session.scalar(
select(ProviderModelCredential).where(ProviderModelCredential.id == self.credential_id)
)
@property
@ -239,7 +238,9 @@ class ProviderOrder(TypeBase):
quantity: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=text("1"))
currency: Mapped[str | None] = mapped_column(String(40))
total_amount: Mapped[int | None] = mapped_column(sa.Integer)
payment_status: Mapped[str] = mapped_column(String(40), nullable=False, server_default=text("'wait_pay'"))
payment_status: Mapped[PaymentStatus] = mapped_column(
EnumText(PaymentStatus, length=40), nullable=False, server_default=text("'wait_pay'")
)
paid_at: Mapped[datetime | None] = mapped_column(DateTime)
pay_failed_at: Mapped[datetime | None] = mapped_column(DateTime)
refunded_at: Mapped[datetime | None] = mapped_column(DateTime)
@ -302,7 +303,9 @@ class LoadBalancingModelConfig(TypeBase):
name: Mapped[str] = mapped_column(String(255), nullable=False)
encrypted_config: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
credential_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
credential_source_type: Mapped[str | None] = mapped_column(String(40), nullable=True, default=None)
credential_source_type: Mapped[CredentialSourceType | None] = mapped_column(
EnumText(CredentialSourceType, length=40), nullable=True, default=None
)
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("true"), default=True)
created_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
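The two column changes above store `PaymentStatus` and `CredentialSourceType` members through the project's `EnumText` type instead of bare `String(40)` columns. `EnumText`'s real implementation lives in `models/types.py` and is not part of this diff; the sketch below only illustrates the general shape of such a text-backed enum `TypeDecorator`, so treat every detail as an assumption:

```python
# Assumed illustration of a text-backed enum column type; the real EnumText may differ.
from enum import StrEnum

from sqlalchemy import String
from sqlalchemy.types import TypeDecorator


class EnumTextSketch(TypeDecorator):
    """Persist a StrEnum as plain text and rehydrate it on load."""

    impl = String
    cache_ok = True

    def __init__(self, enum_cls: type[StrEnum], length: int = 40):
        super().__init__(length)            # builds String(length) as the underlying type
        self._enum_cls = enum_cls

    def process_bind_param(self, value, dialect):
        if value is None:
            return None
        return self._enum_cls(value).value  # accept a member or its raw string value

    def process_result_value(self, value, dialect):
        return None if value is None else self._enum_cls(value)
```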

View File

@ -8,7 +8,7 @@ from uuid import uuid4
import sqlalchemy as sa
from deprecated import deprecated
from sqlalchemy import ForeignKey, String, func
from sqlalchemy import ForeignKey, String, func, select
from sqlalchemy.orm import Mapped, mapped_column
from core.tools.entities.common_entities import I18nObject
@ -184,11 +184,11 @@ class ApiToolProvider(TypeBase):
def user(self) -> Account | None:
if not self.user_id:
return None
return db.session.query(Account).where(Account.id == self.user_id).first()
return db.session.scalar(select(Account).where(Account.id == self.user_id))
@property
def tenant(self) -> Tenant | None:
return db.session.query(Tenant).where(Tenant.id == self.tenant_id).first()
return db.session.scalar(select(Tenant).where(Tenant.id == self.tenant_id))
class ToolLabelBinding(TypeBase):
@ -262,11 +262,11 @@ class WorkflowToolProvider(TypeBase):
@property
def user(self) -> Account | None:
return db.session.query(Account).where(Account.id == self.user_id).first()
return db.session.scalar(select(Account).where(Account.id == self.user_id))
@property
def tenant(self) -> Tenant | None:
return db.session.query(Tenant).where(Tenant.id == self.tenant_id).first()
return db.session.scalar(select(Tenant).where(Tenant.id == self.tenant_id))
@property
def parameter_configurations(self) -> list[WorkflowToolParameterConfiguration]:
@ -277,7 +277,7 @@ class WorkflowToolProvider(TypeBase):
@property
def app(self) -> App | None:
return db.session.query(App).where(App.id == self.app_id).first()
return db.session.scalar(select(App).where(App.id == self.app_id))
class MCPToolProvider(TypeBase):
@ -334,7 +334,7 @@ class MCPToolProvider(TypeBase):
encrypted_headers: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
def load_user(self) -> Account | None:
return db.session.query(Account).where(Account.id == self.user_id).first()
return db.session.scalar(select(Account).where(Account.id == self.user_id))
@property
def credentials(self) -> dict[str, Any]:

View File

@ -3,7 +3,7 @@ import time
from collections.abc import Mapping
from datetime import datetime
from functools import cached_property
from typing import Any, cast
from typing import Any, TypedDict, cast
from uuid import uuid4
import sqlalchemy as sa
@ -23,6 +23,47 @@ from .enums import AppTriggerStatus, AppTriggerType, CreatorUserRole, WorkflowTr
from .model import Account
from .types import EnumText, LongText, StringUUID
TriggerJsonObject = dict[str, object]
TriggerCredentials = dict[str, str]
class WorkflowTriggerLogDict(TypedDict):
id: str
tenant_id: str
app_id: str
workflow_id: str
workflow_run_id: str | None
root_node_id: str | None
trigger_metadata: Any
trigger_type: str
trigger_data: Any
inputs: Any
outputs: Any
status: str
error: str | None
queue_name: str
celery_task_id: str | None
retry_count: int
elapsed_time: float | None
total_tokens: int | None
created_by_role: str
created_by: str
created_at: str | None
triggered_at: str | None
finished_at: str | None
class WorkflowSchedulePlanDict(TypedDict):
id: str
app_id: str
node_id: str
tenant_id: str
cron_expression: str
timezone: str
next_run_at: str | None
created_at: str
updated_at: str
class TriggerSubscription(TypeBase):
"""
@ -51,10 +92,14 @@ class TriggerSubscription(TypeBase):
String(255), nullable=False, comment="Provider identifier (e.g., plugin_id/provider_name)"
)
endpoint_id: Mapped[str] = mapped_column(String(255), nullable=False, comment="Subscription endpoint")
parameters: Mapped[dict[str, Any]] = mapped_column(sa.JSON, nullable=False, comment="Subscription parameters JSON")
properties: Mapped[dict[str, Any]] = mapped_column(sa.JSON, nullable=False, comment="Subscription properties JSON")
parameters: Mapped[TriggerJsonObject] = mapped_column(
sa.JSON, nullable=False, comment="Subscription parameters JSON"
)
properties: Mapped[TriggerJsonObject] = mapped_column(
sa.JSON, nullable=False, comment="Subscription properties JSON"
)
credentials: Mapped[dict[str, Any]] = mapped_column(
credentials: Mapped[TriggerCredentials] = mapped_column(
sa.JSON, nullable=False, comment="Subscription credentials JSON"
)
credential_type: Mapped[str] = mapped_column(String(50), nullable=False, comment="oauth or api_key")
@ -162,8 +207,8 @@ class TriggerOAuthTenantClient(TypeBase):
)
@property
def oauth_params(self) -> Mapping[str, Any]:
return cast(Mapping[str, Any], json.loads(self.encrypted_oauth_params or "{}"))
def oauth_params(self) -> Mapping[str, object]:
return cast(TriggerJsonObject, json.loads(self.encrypted_oauth_params or "{}"))
class WorkflowTriggerLog(TypeBase):
@ -250,7 +295,7 @@ class WorkflowTriggerLog(TypeBase):
created_by_role = CreatorUserRole(self.created_by_role)
return db.session.get(EndUser, self.created_by) if created_by_role == CreatorUserRole.END_USER else None
def to_dict(self) -> dict[str, Any]:
def to_dict(self) -> WorkflowTriggerLogDict:
"""Convert to dictionary for API responses"""
return {
"id": self.id,
@ -481,7 +526,7 @@ class WorkflowSchedulePlan(TypeBase):
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp(), init=False
)
def to_dict(self) -> dict[str, Any]:
def to_dict(self) -> WorkflowSchedulePlanDict:
"""Convert to dictionary representation"""
return {
"id": self.id,

View File

@ -2,7 +2,7 @@ from datetime import datetime
from uuid import uuid4
import sqlalchemy as sa
from sqlalchemy import DateTime, func
from sqlalchemy import DateTime, func, select
from sqlalchemy.orm import Mapped, mapped_column
from .base import TypeBase
@ -38,7 +38,7 @@ class SavedMessage(TypeBase):
@property
def message(self):
return db.session.query(Message).where(Message.id == self.message_id).first()
return db.session.scalar(select(Message).where(Message.id == self.message_id))
class PinnedConversation(TypeBase):

View File

@ -3,7 +3,7 @@ import logging
from collections.abc import Generator, Mapping, Sequence
from datetime import datetime
from enum import StrEnum
from typing import TYPE_CHECKING, Any, Optional, Union, cast
from typing import TYPE_CHECKING, Any, Optional, TypedDict, Union, cast
from uuid import uuid4
import sqlalchemy as sa
@ -19,7 +19,7 @@ from sqlalchemy import (
orm,
select,
)
from sqlalchemy.orm import Mapped, declared_attr, mapped_column
from sqlalchemy.orm import Mapped, mapped_column
from typing_extensions import deprecated
from core.trigger.constants import TRIGGER_INFO_METADATA_KEY, TRIGGER_PLUGIN_NODE_TYPE
@ -33,7 +33,7 @@ from dify_graph.enums import BuiltinNodeTypes, NodeType, WorkflowExecutionStatus
from dify_graph.file.constants import maybe_file_object
from dify_graph.file.models import File
from dify_graph.variables import utils as variable_utils
from dify_graph.variables.variables import FloatVariable, IntegerVariable, StringVariable
from dify_graph.variables.variables import FloatVariable, IntegerVariable, RAGPipelineVariable, StringVariable
from extensions.ext_storage import Storage
from factories.variable_factory import TypeMismatchError, build_segment_with_type
from libs.datetime_utils import naive_utc_now
@ -59,6 +59,25 @@ from .types import EnumText, LongText, StringUUID
logger = logging.getLogger(__name__)
SerializedWorkflowValue = dict[str, Any]
SerializedWorkflowVariables = dict[str, SerializedWorkflowValue]
class WorkflowContentDict(TypedDict):
graph: Mapping[str, Any]
features: dict[str, Any]
environment_variables: list[dict[str, Any]]
conversation_variables: list[dict[str, Any]]
rag_pipeline_variables: list[dict[str, Any]]
class WorkflowRunSummaryDict(TypedDict):
id: str
status: str
triggered_from: str
elapsed_time: float
total_tokens: int
class WorkflowType(StrEnum):
"""
@ -389,7 +408,7 @@ class Workflow(Base): # bug
def rag_pipeline_user_input_form(self) -> list:
# get user_input_form from start node
variables: list[Any] = self.rag_pipeline_variables
variables: list[SerializedWorkflowValue] = self.rag_pipeline_variables
return variables
@ -432,17 +451,13 @@ class Workflow(Base): # bug
def environment_variables(
self,
) -> Sequence[StringVariable | IntegerVariable | FloatVariable | SecretVariable]:
# TODO: find some way to init `self._environment_variables` when instance created.
if self._environment_variables is None:
self._environment_variables = "{}"
# Use workflow.tenant_id to avoid relying on request user in background threads
tenant_id = self.tenant_id
if not tenant_id:
return []
environment_variables_dict: dict[str, Any] = json.loads(self._environment_variables or "{}")
environment_variables_dict = cast(SerializedWorkflowVariables, json.loads(self._environment_variables or "{}"))
results = [
variable_factory.build_environment_variable_from_mapping(v) for v in environment_variables_dict.values()
]
@ -502,14 +517,14 @@ class Workflow(Base): # bug
)
self._environment_variables = environment_variables_json
def to_dict(self, *, include_secret: bool = False) -> Mapping[str, Any]:
def to_dict(self, *, include_secret: bool = False) -> WorkflowContentDict:
environment_variables = list(self.environment_variables)
environment_variables = [
v if not isinstance(v, SecretVariable) or include_secret else v.model_copy(update={"value": ""})
for v in environment_variables
]
result = {
result: WorkflowContentDict = {
"graph": self.graph_dict,
"features": self.features_dict,
"environment_variables": [var.model_dump(mode="json") for var in environment_variables],
@ -520,11 +535,7 @@ class Workflow(Base): # bug
@property
def conversation_variables(self) -> Sequence[VariableBase]:
# TODO: find some way to init `self._conversation_variables` when instance created.
if self._conversation_variables is None:
self._conversation_variables = "{}"
variables_dict: dict[str, Any] = json.loads(self._conversation_variables)
variables_dict = cast(SerializedWorkflowVariables, json.loads(self._conversation_variables or "{}"))
results = [variable_factory.build_conversation_variable_from_mapping(v) for v in variables_dict.values()]
return results
@ -536,19 +547,20 @@ class Workflow(Base): # bug
)
@property
def rag_pipeline_variables(self) -> list[dict]:
# TODO: find some way to init `self._conversation_variables` when instance created.
if self._rag_pipeline_variables is None:
self._rag_pipeline_variables = "{}"
variables_dict: dict[str, Any] = json.loads(self._rag_pipeline_variables)
results = list(variables_dict.values())
return results
def rag_pipeline_variables(self) -> list[SerializedWorkflowValue]:
variables_dict = cast(SerializedWorkflowVariables, json.loads(self._rag_pipeline_variables or "{}"))
return [RAGPipelineVariable.model_validate(item).model_dump(mode="json") for item in variables_dict.values()]
@rag_pipeline_variables.setter
def rag_pipeline_variables(self, values: list[dict]) -> None:
def rag_pipeline_variables(self, values: Sequence[Mapping[str, Any] | RAGPipelineVariable]) -> None:
self._rag_pipeline_variables = json.dumps(
{item["variable"]: item for item in values},
{
rag_pipeline_variable.variable: rag_pipeline_variable.model_dump(mode="json")
for rag_pipeline_variable in (
item if isinstance(item, RAGPipelineVariable) else RAGPipelineVariable.model_validate(item)
for item in values
)
},
ensure_ascii=False,
)
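The new setter above accepts either `RAGPipelineVariable` models or plain mappings, normalizes everything through pydantic, and keys the JSON payload by variable name. A stand-alone sketch of the same idiom with an assumed minimal model (the real `RAGPipelineVariable` carries more fields):

```python
# Sketch of the normalize-then-key-by-name idiom; VariableSketch is an assumed stand-in.
import json
from collections.abc import Mapping, Sequence
from typing import Any

from pydantic import BaseModel


class VariableSketch(BaseModel):
    variable: str
    label: str = ""


def serialize_variables(values: Sequence[Mapping[str, Any] | VariableSketch]) -> str:
    normalized = (
        item if isinstance(item, VariableSketch) else VariableSketch.model_validate(item)
        for item in values
    )
    return json.dumps(
        {v.variable: v.model_dump(mode="json") for v in normalized},
        ensure_ascii=False,
    )


print(serialize_variables([{"variable": "query", "label": "Query"}, VariableSketch(variable="top_k")]))
```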
@ -667,14 +679,14 @@ class WorkflowRun(Base):
def message(self):
from .model import Message
return (
db.session.query(Message).where(Message.app_id == self.app_id, Message.workflow_run_id == self.id).first()
return db.session.scalar(
select(Message).where(Message.app_id == self.app_id, Message.workflow_run_id == self.id)
)
@property
@deprecated("This method is retained for historical reasons; avoid using it if possible.")
def workflow(self):
return db.session.query(Workflow).where(Workflow.id == self.workflow_id).first()
return db.session.scalar(select(Workflow).where(Workflow.id == self.workflow_id))
def to_dict(self):
return {
@ -786,44 +798,36 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
__tablename__ = "workflow_node_executions"
@declared_attr.directive
@classmethod
def __table_args__(cls) -> Any:
return (
PrimaryKeyConstraint("id", name="workflow_node_execution_pkey"),
Index(
"workflow_node_execution_workflow_run_id_idx",
"workflow_run_id",
),
Index(
"workflow_node_execution_node_run_idx",
"tenant_id",
"app_id",
"workflow_id",
"triggered_from",
"node_id",
),
Index(
"workflow_node_execution_id_idx",
"tenant_id",
"app_id",
"workflow_id",
"triggered_from",
"node_execution_id",
),
Index(
# The first argument is the index name,
# which we leave as `None`` to allow auto-generation by the ORM.
None,
cls.tenant_id,
cls.workflow_id,
cls.node_id,
# MyPy may flag the following line because it doesn't recognize that
# the `declared_attr` decorator passes the receiving class as the first
# argument to this method, allowing us to reference class attributes.
cls.created_at.desc(),
),
)
__table_args__ = (
PrimaryKeyConstraint("id", name="workflow_node_execution_pkey"),
Index(
"workflow_node_execution_workflow_run_id_idx",
"workflow_run_id",
),
Index(
"workflow_node_execution_node_run_idx",
"tenant_id",
"app_id",
"workflow_id",
"triggered_from",
"node_id",
),
Index(
"workflow_node_execution_id_idx",
"tenant_id",
"app_id",
"workflow_id",
"triggered_from",
"node_execution_id",
),
Index(
None,
"tenant_id",
"workflow_id",
"node_id",
sa.desc("created_at"),
),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()))
tenant_id: Mapped[str] = mapped_column(StringUUID)
@ -1231,7 +1235,7 @@ class WorkflowArchiveLog(TypeBase):
)
@property
def workflow_run_summary(self) -> dict[str, Any]:
def workflow_run_summary(self) -> WorkflowRunSummaryDict:
return {
"id": self.workflow_run_id,
"status": self.run_status,

View File

@ -1,6 +1,6 @@
[project]
name = "dify-api"
version = "1.13.0"
version = "1.13.1"
requires-python = ">=3.11,<3.13"
dependencies = [

View File

@ -1,4 +1,3 @@
configs/middleware/cache/redis_pubsub_config.py
controllers/console/app/annotation.py
controllers/console/app/app.py
controllers/console/app/app_import.py
@ -138,8 +137,6 @@ dify_graph/nodes/trigger_webhook/node.py
dify_graph/nodes/variable_aggregator/variable_aggregator_node.py
dify_graph/nodes/variable_assigner/v1/node.py
dify_graph/nodes/variable_assigner/v2/node.py
dify_graph/variables/types.py
extensions/ext_fastopenapi.py
extensions/logstore/repositories/logstore_api_workflow_run_repository.py
extensions/otel/instrumentation.py
extensions/otel/runtime.py
@ -156,19 +153,7 @@ extensions/storage/oracle_oci_storage.py
extensions/storage/supabase_storage.py
extensions/storage/tencent_cos_storage.py
extensions/storage/volcengine_tos_storage.py
factories/variable_factory.py
libs/external_api.py
libs/gmpy2_pkcs10aep_cipher.py
libs/helper.py
libs/login.py
libs/module_loading.py
libs/oauth.py
libs/oauth_data_source.py
models/trigger.py
models/workflow.py
repositories/sqlalchemy_api_workflow_node_execution_repository.py
repositories/sqlalchemy_api_workflow_run_repository.py
repositories/sqlalchemy_execution_extra_content_repository.py
schedule/queue_monitor_task.py
services/account_service.py
services/audio_service.py
@ -197,4 +182,9 @@ tasks/app_generate/workflow_execute_task.py
tasks/regenerate_summary_index_task.py
tasks/trigger_processing_tasks.py
tasks/workflow_cfs_scheduler/cfs_scheduler.py
tasks/add_document_to_index_task.py
tasks/create_segment_to_index_task.py
tasks/disable_segment_from_index_task.py
tasks/enable_segment_to_index_task.py
tasks/remove_document_from_index_task.py
tasks/workflow_execution_tasks.py

View File

@ -8,7 +8,7 @@ using SQLAlchemy 2.0 style queries for WorkflowNodeExecutionModel operations.
import json
from collections.abc import Sequence
from datetime import datetime
from typing import cast
from typing import Protocol, cast
from sqlalchemy import asc, delete, desc, func, select
from sqlalchemy.engine import CursorResult
@ -22,6 +22,20 @@ from repositories.api_workflow_node_execution_repository import (
)
class _WorkflowNodeExecutionSnapshotRow(Protocol):
id: str
node_execution_id: str | None
node_id: str
node_type: str
title: str
index: int
status: WorkflowNodeExecutionStatus
elapsed_time: float | None
created_at: datetime
finished_at: datetime | None
execution_metadata: str | None
class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRepository):
"""
SQLAlchemy implementation of DifyAPIWorkflowNodeExecutionRepository.
@ -40,6 +54,8 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut
- Thread-safe database operations using session-per-request pattern
"""
_session_maker: sessionmaker[Session]
def __init__(self, session_maker: sessionmaker[Session]):
"""
Initialize the repository with a sessionmaker.
@ -156,12 +172,12 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut
)
with self._session_maker() as session:
rows = session.execute(stmt).all()
rows = cast(Sequence[_WorkflowNodeExecutionSnapshotRow], session.execute(stmt).all())
return [self._row_to_snapshot(row) for row in rows]
@staticmethod
def _row_to_snapshot(row: object) -> WorkflowNodeExecutionSnapshot:
def _row_to_snapshot(row: _WorkflowNodeExecutionSnapshotRow) -> WorkflowNodeExecutionSnapshot:
metadata: dict[str, object] = {}
execution_metadata = getattr(row, "execution_metadata", None)
if execution_metadata:
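The `_WorkflowNodeExecutionSnapshotRow` protocol introduced above types the `Row` objects coming back from `session.execute(...)` by declaring only the attributes the snapshot conversion reads. A small stand-alone sketch of that trick, with assumed attribute names:

```python
# Sketch of typing result rows with a Protocol instead of a runtime wrapper.
from collections.abc import Sequence
from datetime import datetime
from typing import Protocol, cast


class _SnapshotRowSketch(Protocol):
    node_id: str
    created_at: datetime
    elapsed_time: float | None


def summarize(rows: Sequence[object]) -> list[str]:
    # cast() only informs the type checker; the objects are unchanged at runtime.
    typed_rows = cast(Sequence[_SnapshotRowSketch], rows)
    return [f"{r.node_id} ({(r.elapsed_time or 0.0):.2f}s, {r.created_at:%Y-%m-%d})" for r in typed_rows]
```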

View File

@ -11,7 +11,7 @@ from core.tools.tool_manager import ToolManager
from extensions.ext_database import db
from libs.login import current_user
from models import Account
from models.model import App, Conversation, EndUser, Message, MessageAgentThought
from models.model import App, Conversation, EndUser, Message
class AgentService:
@ -47,7 +47,7 @@ class AgentService:
if not message:
raise ValueError(f"Message not found: {message_id}")
agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
agent_thoughts = message.agent_thoughts
if conversation.from_end_user_id:
# only select name field

View File

@ -18,7 +18,7 @@ from extensions.ext_database import db
from models.account import Account
from models.enums import CreatorUserRole, WorkflowTriggerStatus
from models.model import App, EndUser
from models.trigger import WorkflowTriggerLog
from models.trigger import WorkflowTriggerLog, WorkflowTriggerLogDict
from models.workflow import Workflow
from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
from services.errors.app import QuotaExceededError, WorkflowNotFoundError, WorkflowQuotaLimitError
@ -224,7 +224,9 @@ class AsyncWorkflowService:
return cls.trigger_workflow_async(session, user, trigger_data)
@classmethod
def get_trigger_log(cls, workflow_trigger_log_id: str, tenant_id: str | None = None) -> dict[str, Any] | None:
def get_trigger_log(
cls, workflow_trigger_log_id: str, tenant_id: str | None = None
) -> WorkflowTriggerLogDict | None:
"""
Get trigger log by ID
@ -247,7 +249,7 @@ class AsyncWorkflowService:
@classmethod
def get_recent_logs(
cls, tenant_id: str, app_id: str, hours: int = 24, limit: int = 100, offset: int = 0
) -> list[dict[str, Any]]:
) -> list[WorkflowTriggerLogDict]:
"""
Get recent trigger logs
@ -272,7 +274,7 @@ class AsyncWorkflowService:
@classmethod
def get_failed_logs_for_retry(
cls, tenant_id: str, max_retry_count: int = 3, limit: int = 100
) -> list[dict[str, Any]]:
) -> list[WorkflowTriggerLogDict]:
"""
Get failed logs eligible for retry

View File

@ -51,6 +51,14 @@ from models.dataset import (
Pipeline,
SegmentAttachmentBinding,
)
from models.enums import (
DatasetRuntimeMode,
DataSourceType,
DocumentCreatedFrom,
IndexingStatus,
ProcessRuleMode,
SegmentStatus,
)
from models.model import UploadFile
from models.provider_ids import ModelProviderID
from models.source import DataSourceOauthBinding
@ -319,7 +327,7 @@ class DatasetService:
description=rag_pipeline_dataset_create_entity.description,
permission=rag_pipeline_dataset_create_entity.permission,
provider="vendor",
runtime_mode="rag_pipeline",
runtime_mode=DatasetRuntimeMode.RAG_PIPELINE,
icon_info=rag_pipeline_dataset_create_entity.icon_info.model_dump(),
created_by=current_user.id,
pipeline_id=pipeline.id,
@ -614,7 +622,7 @@ class DatasetService:
"""
Update pipeline knowledge base node data.
"""
if dataset.runtime_mode != "rag_pipeline":
if dataset.runtime_mode != DatasetRuntimeMode.RAG_PIPELINE:
return
pipeline = db.session.query(Pipeline).filter_by(id=dataset.pipeline_id).first()
@ -1229,10 +1237,15 @@ class DocumentService:
"enabled": "available",
}
_INDEXING_STATUSES: tuple[str, ...] = ("parsing", "cleaning", "splitting", "indexing")
_INDEXING_STATUSES: tuple[IndexingStatus, ...] = (
IndexingStatus.PARSING,
IndexingStatus.CLEANING,
IndexingStatus.SPLITTING,
IndexingStatus.INDEXING,
)
DISPLAY_STATUS_FILTERS: dict[str, tuple[Any, ...]] = {
"queuing": (Document.indexing_status == "waiting",),
"queuing": (Document.indexing_status == IndexingStatus.WAITING,),
"indexing": (
Document.indexing_status.in_(_INDEXING_STATUSES),
Document.is_paused.is_not(True),
@ -1241,19 +1254,19 @@ class DocumentService:
Document.indexing_status.in_(_INDEXING_STATUSES),
Document.is_paused.is_(True),
),
"error": (Document.indexing_status == "error",),
"error": (Document.indexing_status == IndexingStatus.ERROR,),
"available": (
Document.indexing_status == "completed",
Document.indexing_status == IndexingStatus.COMPLETED,
Document.archived.is_(False),
Document.enabled.is_(True),
),
"disabled": (
Document.indexing_status == "completed",
Document.indexing_status == IndexingStatus.COMPLETED,
Document.archived.is_(False),
Document.enabled.is_(False),
),
"archived": (
Document.indexing_status == "completed",
Document.indexing_status == IndexingStatus.COMPLETED,
Document.archived.is_(True),
),
}
@ -1536,7 +1549,7 @@ class DocumentService:
"""
Normalize and validate `Document -> UploadFile` linkage for download flows.
"""
if document.data_source_type != "upload_file":
if document.data_source_type != DataSourceType.UPLOAD_FILE:
raise NotFound(invalid_source_message)
data_source_info: dict[str, Any] = document.data_source_info_dict or {}
@ -1617,7 +1630,7 @@ class DocumentService:
select(Document).where(
Document.id.in_(document_ids),
Document.enabled == True,
Document.indexing_status == "completed",
Document.indexing_status == IndexingStatus.COMPLETED,
Document.archived == False,
)
).all()
@ -1640,7 +1653,7 @@ class DocumentService:
select(Document).where(
Document.dataset_id == dataset_id,
Document.enabled == True,
Document.indexing_status == "completed",
Document.indexing_status == IndexingStatus.COMPLETED,
Document.archived == False,
)
).all()
@ -1650,7 +1663,10 @@ class DocumentService:
@staticmethod
def get_error_documents_by_dataset_id(dataset_id: str) -> Sequence[Document]:
documents = db.session.scalars(
select(Document).where(Document.dataset_id == dataset_id, Document.indexing_status.in_(["error", "paused"]))
select(Document).where(
Document.dataset_id == dataset_id,
Document.indexing_status.in_([IndexingStatus.ERROR, IndexingStatus.PAUSED]),
)
).all()
return documents
@ -1683,7 +1699,7 @@ class DocumentService:
def delete_document(document):
# trigger document_was_deleted signal
file_id = None
if document.data_source_type == "upload_file":
if document.data_source_type == DataSourceType.UPLOAD_FILE:
if document.data_source_info:
data_source_info = document.data_source_info_dict
if data_source_info and "upload_file_id" in data_source_info:
@ -1704,7 +1720,7 @@ class DocumentService:
file_ids = [
document.data_source_info_dict.get("upload_file_id", "")
for document in documents
if document.data_source_type == "upload_file" and document.data_source_info_dict
if document.data_source_type == DataSourceType.UPLOAD_FILE and document.data_source_info_dict
]
# Delete documents first, then dispatch cleanup task after commit
@ -1753,7 +1769,13 @@ class DocumentService:
@staticmethod
def pause_document(document):
if document.indexing_status not in {"waiting", "parsing", "cleaning", "splitting", "indexing"}:
if document.indexing_status not in {
IndexingStatus.WAITING,
IndexingStatus.PARSING,
IndexingStatus.CLEANING,
IndexingStatus.SPLITTING,
IndexingStatus.INDEXING,
}:
raise DocumentIndexingError()
# update document to be paused
assert current_user is not None
@ -1793,7 +1815,7 @@ class DocumentService:
if cache_result is not None:
raise ValueError("Document is being retried, please try again later")
# retry document indexing
document.indexing_status = "waiting"
document.indexing_status = IndexingStatus.WAITING
db.session.add(document)
db.session.commit()
@ -1812,7 +1834,7 @@ class DocumentService:
if cache_result is not None:
raise ValueError("Document is being synced, please try again later")
# sync document indexing
document.indexing_status = "waiting"
document.indexing_status = IndexingStatus.WAITING
data_source_info = document.data_source_info_dict
if data_source_info:
data_source_info["mode"] = "scrape"
@ -1840,7 +1862,7 @@ class DocumentService:
knowledge_config: KnowledgeConfig,
account: Account | Any,
dataset_process_rule: DatasetProcessRule | None = None,
created_from: str = "web",
created_from: str = DocumentCreatedFrom.WEB,
) -> tuple[list[Document], str]:
# check doc_form
DatasetService.check_doc_form(dataset, knowledge_config.doc_form)
@ -1932,7 +1954,7 @@ class DocumentService:
if not dataset_process_rule:
process_rule = knowledge_config.process_rule
if process_rule:
if process_rule.mode in ("custom", "hierarchical"):
if process_rule.mode in (ProcessRuleMode.CUSTOM, ProcessRuleMode.HIERARCHICAL):
if process_rule.rules:
dataset_process_rule = DatasetProcessRule(
dataset_id=dataset.id,
@ -1944,7 +1966,7 @@ class DocumentService:
dataset_process_rule = dataset.latest_process_rule
if not dataset_process_rule:
raise ValueError("No process rule found.")
elif process_rule.mode == "automatic":
elif process_rule.mode == ProcessRuleMode.AUTOMATIC:
dataset_process_rule = DatasetProcessRule(
dataset_id=dataset.id,
mode=process_rule.mode,
@ -1967,7 +1989,7 @@ class DocumentService:
if not dataset_process_rule:
dataset_process_rule = DatasetProcessRule(
dataset_id=dataset.id,
mode="automatic",
mode=ProcessRuleMode.AUTOMATIC,
rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES),
created_by=account.id,
)
@ -2001,7 +2023,7 @@ class DocumentService:
.where(
Document.dataset_id == dataset.id,
Document.tenant_id == current_user.current_tenant_id,
Document.data_source_type == "upload_file",
Document.data_source_type == DataSourceType.UPLOAD_FILE,
Document.enabled == True,
Document.name.in_(file_names),
)
@ -2021,7 +2043,7 @@ class DocumentService:
document.doc_language = knowledge_config.doc_language
document.data_source_info = json.dumps(data_source_info)
document.batch = batch
document.indexing_status = "waiting"
document.indexing_status = IndexingStatus.WAITING
db.session.add(document)
documents.append(document)
duplicate_document_ids.append(document.id)
@ -2056,7 +2078,7 @@ class DocumentService:
.filter_by(
dataset_id=dataset.id,
tenant_id=current_user.current_tenant_id,
data_source_type="notion_import",
data_source_type=DataSourceType.NOTION_IMPORT,
enabled=True,
)
.all()
@ -2507,7 +2529,7 @@ class DocumentService:
document_data: KnowledgeConfig,
account: Account,
dataset_process_rule: DatasetProcessRule | None = None,
created_from: str = "web",
created_from: str = DocumentCreatedFrom.WEB,
):
assert isinstance(current_user, Account)
@ -2520,14 +2542,14 @@ class DocumentService:
# save process rule
if document_data.process_rule:
process_rule = document_data.process_rule
if process_rule.mode in {"custom", "hierarchical"}:
if process_rule.mode in {ProcessRuleMode.CUSTOM, ProcessRuleMode.HIERARCHICAL}:
dataset_process_rule = DatasetProcessRule(
dataset_id=dataset.id,
mode=process_rule.mode,
rules=process_rule.rules.model_dump_json() if process_rule.rules else None,
created_by=account.id,
)
elif process_rule.mode == "automatic":
elif process_rule.mode == ProcessRuleMode.AUTOMATIC:
dataset_process_rule = DatasetProcessRule(
dataset_id=dataset.id,
mode=process_rule.mode,
@ -2609,7 +2631,7 @@ class DocumentService:
if document_data.name:
document.name = document_data.name
# update document to be waiting
document.indexing_status = "waiting"
document.indexing_status = IndexingStatus.WAITING
document.completed_at = None
document.processing_started_at = None
document.parsing_completed_at = None
@ -2623,7 +2645,7 @@ class DocumentService:
# update document segment
db.session.query(DocumentSegment).filter_by(document_id=document.id).update(
{DocumentSegment.status: "re_segment"}
{DocumentSegment.status: SegmentStatus.RE_SEGMENT}
)
db.session.commit()
# trigger async task
@ -2754,7 +2776,7 @@ class DocumentService:
if knowledge_config.process_rule.mode not in DatasetProcessRule.MODES:
raise ValueError("Process rule mode is invalid")
if knowledge_config.process_rule.mode == "automatic":
if knowledge_config.process_rule.mode == ProcessRuleMode.AUTOMATIC:
knowledge_config.process_rule.rules = None
else:
if not knowledge_config.process_rule.rules:
@ -2785,7 +2807,7 @@ class DocumentService:
raise ValueError("Process rule segmentation separator is invalid")
if not (
knowledge_config.process_rule.mode == "hierarchical"
knowledge_config.process_rule.mode == ProcessRuleMode.HIERARCHICAL
and knowledge_config.process_rule.rules.parent_mode == "full-doc"
):
if not knowledge_config.process_rule.rules.segmentation.max_tokens:
@ -2814,7 +2836,7 @@ class DocumentService:
if args["process_rule"]["mode"] not in DatasetProcessRule.MODES:
raise ValueError("Process rule mode is invalid")
if args["process_rule"]["mode"] == "automatic":
if args["process_rule"]["mode"] == ProcessRuleMode.AUTOMATIC:
args["process_rule"]["rules"] = {}
else:
if "rules" not in args["process_rule"] or not args["process_rule"]["rules"]:
@ -3021,7 +3043,7 @@ class DocumentService:
@staticmethod
def _prepare_disable_update(document, user, now):
"""Prepare updates for disabling a document."""
if not document.completed_at or document.indexing_status != "completed":
if not document.completed_at or document.indexing_status != IndexingStatus.COMPLETED:
raise DocumentIndexingError(f"Document: {document.name} is not completed.")
if not document.enabled:
@ -3130,7 +3152,7 @@ class SegmentService:
content=content,
word_count=len(content),
tokens=tokens,
status="completed",
status=SegmentStatus.COMPLETED,
indexing_at=naive_utc_now(),
completed_at=naive_utc_now(),
created_by=current_user.id,
@ -3167,7 +3189,7 @@ class SegmentService:
logger.exception("create segment index failed")
segment_document.enabled = False
segment_document.disabled_at = naive_utc_now()
segment_document.status = "error"
segment_document.status = SegmentStatus.ERROR
segment_document.error = str(e)
db.session.commit()
segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment_document.id).first()
@ -3227,7 +3249,7 @@ class SegmentService:
word_count=len(content),
tokens=tokens,
keywords=segment_item.get("keywords", []),
status="completed",
status=SegmentStatus.COMPLETED,
indexing_at=naive_utc_now(),
completed_at=naive_utc_now(),
created_by=current_user.id,
@ -3259,7 +3281,7 @@ class SegmentService:
for segment_document in segment_data_list:
segment_document.enabled = False
segment_document.disabled_at = naive_utc_now()
segment_document.status = "error"
segment_document.status = SegmentStatus.ERROR
segment_document.error = str(e)
db.session.commit()
return segment_data_list
@ -3405,7 +3427,7 @@ class SegmentService:
segment.index_node_hash = segment_hash
segment.word_count = len(content)
segment.tokens = tokens
segment.status = "completed"
segment.status = SegmentStatus.COMPLETED
segment.indexing_at = naive_utc_now()
segment.completed_at = naive_utc_now()
segment.updated_by = current_user.id
@ -3530,7 +3552,7 @@ class SegmentService:
logger.exception("update segment index failed")
segment.enabled = False
segment.disabled_at = naive_utc_now()
segment.status = "error"
segment.status = SegmentStatus.ERROR
segment.error = str(e)
db.session.commit()
new_segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment.id).first()
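Most of the dataset_service.py hunks above replace bare status strings with `IndexingStatus`, `SegmentStatus`, and `ProcessRuleMode` members, and `DISPLAY_STATUS_FILTERS` now holds tuples of column expressions built from those members. A sketch of how such a mapping is typically consumed; the query shape here is an assumption rather than code from this commit, while the imports mirror the modules shown in this diff:

```python
# Assumed consumption pattern for DISPLAY_STATUS_FILTERS.
from sqlalchemy import select

from models.dataset import Document
from services.dataset_service import DocumentService


def documents_by_display_status(display_status: str):
    filters = DocumentService.DISPLAY_STATUS_FILTERS.get(display_status)
    if filters is None:
        raise ValueError(f"unknown display status: {display_status}")
    # Splat the tuple of column expressions into a single WHERE clause.
    return select(Document).where(*filters)
```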

View File

@ -13,7 +13,7 @@ from dify_graph.model_runtime.entities import LLMMode
from extensions.ext_database import db
from models import Account
from models.dataset import Dataset, DatasetQuery
from models.enums import CreatorUserRole
from models.enums import CreatorUserRole, DatasetQuerySource
logger = logging.getLogger(__name__)
@ -97,7 +97,7 @@ class HitTestingService:
dataset_query = DatasetQuery(
dataset_id=dataset.id,
content=json.dumps(dataset_queries),
source="hit_testing",
source=DatasetQuerySource.HIT_TESTING,
source_app_id=None,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,
@ -137,7 +137,7 @@ class HitTestingService:
dataset_query = DatasetQuery(
dataset_id=dataset.id,
content=query,
source="hit_testing",
source=DatasetQuerySource.HIT_TESTING,
source_app_id=None,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,

View File

@ -7,6 +7,7 @@ from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from libs.login import current_account_with_tenant
from models.dataset import Dataset, DatasetMetadata, DatasetMetadataBinding
from models.enums import DatasetMetadataType
from services.dataset_service import DocumentService
from services.entities.knowledge_entities.knowledge_entities import (
MetadataArgs,
@ -130,11 +131,11 @@ class MetadataService:
@staticmethod
def get_built_in_fields():
return [
{"name": BuiltInField.document_name, "type": "string"},
{"name": BuiltInField.uploader, "type": "string"},
{"name": BuiltInField.upload_date, "type": "time"},
{"name": BuiltInField.last_update_date, "type": "time"},
{"name": BuiltInField.source, "type": "string"},
{"name": BuiltInField.document_name, "type": DatasetMetadataType.STRING},
{"name": BuiltInField.uploader, "type": DatasetMetadataType.STRING},
{"name": BuiltInField.upload_date, "type": DatasetMetadataType.TIME},
{"name": BuiltInField.last_update_date, "type": DatasetMetadataType.TIME},
{"name": BuiltInField.source, "type": DatasetMetadataType.STRING},
]
@staticmethod

View File

@ -19,6 +19,7 @@ from dify_graph.model_runtime.entities.provider_entities import (
from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.enums import CredentialSourceType
from models.provider import LoadBalancingModelConfig, ProviderCredential, ProviderModelCredential
logger = logging.getLogger(__name__)
@ -103,9 +104,9 @@ class ModelLoadBalancingService:
is_load_balancing_enabled = True
if config_from == "predefined-model":
credential_source_type = "provider"
credential_source_type = CredentialSourceType.PROVIDER
else:
credential_source_type = "custom_model"
credential_source_type = CredentialSourceType.CUSTOM_MODEL
# Get load balancing configurations
load_balancing_configs = (
@ -421,7 +422,11 @@ class ModelLoadBalancingService:
raise ValueError("Invalid load balancing config name")
if credential_id:
credential_source = "provider" if config_from == "predefined-model" else "custom_model"
credential_source = (
CredentialSourceType.PROVIDER
if config_from == "predefined-model"
else CredentialSourceType.CUSTOM_MODEL
)
assert credential_record is not None
load_balancing_model_config = LoadBalancingModelConfig(
tenant_id=tenant_id,

View File

@ -30,7 +30,7 @@ from core.plugin.impl.debugging import PluginDebuggingClient
from core.plugin.impl.plugin import PluginInstaller
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.provider import Provider, ProviderCredential
from models.provider import Provider, ProviderCredential, TenantPreferredModelProvider
from models.provider_ids import GenericProviderID
from services.enterprise.plugin_manager_service import (
PluginManagerService,
@ -534,6 +534,13 @@ class PluginService:
plugin_id = plugin.plugin_id
logger.info("Deleting credentials for plugin: %s", plugin_id)
session.execute(
delete(TenantPreferredModelProvider).where(
TenantPreferredModelProvider.tenant_id == tenant_id,
TenantPreferredModelProvider.provider_name.like(f"{plugin_id}/%"),
)
)
# Delete provider credentials that match this plugin
credential_ids = session.scalars(
select(ProviderCredential.id).where(

View File

@ -6,6 +6,7 @@ from core.app.apps.pipeline.pipeline_generator import PipelineGenerator
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from models.dataset import Document, Pipeline
from models.enums import IndexingStatus
from models.model import Account, App, EndUser
from models.workflow import Workflow
from services.rag_pipeline.rag_pipeline import RagPipelineService
@ -111,6 +112,6 @@ class PipelineGenerateService:
"""
document = db.session.query(Document).where(Document.id == document_id).first()
if document:
document.indexing_status = "waiting"
document.indexing_status = IndexingStatus.WAITING
db.session.add(document)
db.session.commit()

View File

@ -15,7 +15,8 @@ class RemotePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
Retrieval recommended app from dify official
"""
def get_pipeline_template_detail(self, template_id: str):
def get_pipeline_template_detail(self, template_id: str) -> dict | None:
result: dict | None
try:
result = self.fetch_pipeline_template_detail_from_dify_official(template_id)
except Exception as e:
@ -35,17 +36,23 @@ class RemotePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
return PipelineTemplateType.REMOTE
@classmethod
def fetch_pipeline_template_detail_from_dify_official(cls, template_id: str) -> dict | None:
def fetch_pipeline_template_detail_from_dify_official(cls, template_id: str) -> dict:
"""
Fetch pipeline template detail from dify official.
:param template_id: Pipeline ID
:return:
:param template_id: Pipeline template ID
:return: Template detail dict
:raises ValueError: When upstream returns a non-200 status code
"""
domain = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_REMOTE_DOMAIN
url = f"{domain}/pipeline-templates/{template_id}"
response = httpx.get(url, timeout=httpx.Timeout(10.0, connect=3.0))
if response.status_code != 200:
return None
raise ValueError(
"fetch pipeline template detail failed,"
+ f" status_code: {response.status_code},"
+ f" response: {response.text[:1000]}"
)
data: dict = response.json()
return data
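`fetch_pipeline_template_detail_from_dify_official` now raises `ValueError` on a non-200 response instead of silently returning `None`, and `get_pipeline_template_detail` catches the failure (its `except` branch is truncated in this hunk). A minimal sketch of restoring the old `None`-on-failure behavior at a call site, written generically so no import path has to be assumed:

```python
# Sketch: adapt the new raise-on-error fetch back to an Optional-returning call.
from collections.abc import Callable


def fetch_or_none(fetch: Callable[[str], dict], template_id: str) -> dict | None:
    try:
        return fetch(template_id)
    except ValueError:
        # Non-200 upstream responses now surface as ValueError (see hunk above).
        return None


# Usage (assuming the class above is importable from the caller's module):
# detail = fetch_or_none(RemotePipelineTemplateRetrieval.fetch_pipeline_template_detail_from_dify_official, "tpl-id")
```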

View File

@ -64,7 +64,7 @@ from models.dataset import ( # type: ignore
PipelineCustomizedTemplate,
PipelineRecommendedPlugin,
)
from models.enums import WorkflowRunTriggeredFrom
from models.enums import IndexingStatus, WorkflowRunTriggeredFrom
from models.model import EndUser
from models.workflow import (
Workflow,
@ -117,13 +117,21 @@ class RagPipelineService:
def get_pipeline_template_detail(cls, template_id: str, type: str = "built-in") -> dict | None:
"""
Get pipeline template detail.
:param template_id: template id
:return:
:param type: template type, "built-in" or "customized"
:return: template detail dict, or None if not found
"""
if type == "built-in":
mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE
retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
built_in_result: dict | None = retrieval_instance.get_pipeline_template_detail(template_id)
if built_in_result is None:
logger.warning(
"pipeline template retrieval returned empty result, template_id: %s, mode: %s",
template_id,
mode,
)
return built_in_result
else:
mode = "customized"
@ -906,7 +914,7 @@ class RagPipelineService:
if document_id:
document = db.session.query(Document).where(Document.id == document_id.value).first()
if document:
document.indexing_status = "error"
document.indexing_status = IndexingStatus.ERROR
document.error = error
db.session.add(document)
db.session.commit()

View File

@ -35,6 +35,7 @@ from extensions.ext_redis import redis_client
from factories import variable_factory
from models import Account
from models.dataset import Dataset, DatasetCollectionBinding, Pipeline
from models.enums import CollectionBindingType, DatasetRuntimeMode
from models.workflow import Workflow, WorkflowType
from services.entities.knowledge_entities.rag_pipeline_entities import (
IconInfo,
@ -313,7 +314,7 @@ class RagPipelineDslService:
indexing_technique=knowledge_configuration.indexing_technique,
created_by=account.id,
retrieval_model=knowledge_configuration.retrieval_model.model_dump(),
runtime_mode="rag_pipeline",
runtime_mode=DatasetRuntimeMode.RAG_PIPELINE,
chunk_structure=knowledge_configuration.chunk_structure,
)
if knowledge_configuration.indexing_technique == "high_quality":
@ -323,7 +324,7 @@ class RagPipelineDslService:
DatasetCollectionBinding.provider_name
== knowledge_configuration.embedding_model_provider,
DatasetCollectionBinding.model_name == knowledge_configuration.embedding_model,
DatasetCollectionBinding.type == "dataset",
DatasetCollectionBinding.type == CollectionBindingType.DATASET,
)
.order_by(DatasetCollectionBinding.created_at)
.first()
@ -334,7 +335,7 @@ class RagPipelineDslService:
provider_name=knowledge_configuration.embedding_model_provider,
model_name=knowledge_configuration.embedding_model,
collection_name=Dataset.gen_collection_name_by_id(str(uuid.uuid4())),
type="dataset",
type=CollectionBindingType.DATASET,
)
self._session.add(dataset_collection_binding)
self._session.commit()
@ -445,13 +446,13 @@ class RagPipelineDslService:
indexing_technique=knowledge_configuration.indexing_technique,
created_by=account.id,
retrieval_model=knowledge_configuration.retrieval_model.model_dump(),
runtime_mode="rag_pipeline",
runtime_mode=DatasetRuntimeMode.RAG_PIPELINE,
chunk_structure=knowledge_configuration.chunk_structure,
)
else:
dataset.indexing_technique = knowledge_configuration.indexing_technique
dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump()
dataset.runtime_mode = "rag_pipeline"
dataset.runtime_mode = DatasetRuntimeMode.RAG_PIPELINE
dataset.chunk_structure = knowledge_configuration.chunk_structure
if knowledge_configuration.indexing_technique == "high_quality":
dataset_collection_binding = (
@ -460,7 +461,7 @@ class RagPipelineDslService:
DatasetCollectionBinding.provider_name
== knowledge_configuration.embedding_model_provider,
DatasetCollectionBinding.model_name == knowledge_configuration.embedding_model,
DatasetCollectionBinding.type == "dataset",
DatasetCollectionBinding.type == CollectionBindingType.DATASET,
)
.order_by(DatasetCollectionBinding.created_at)
.first()
@ -471,7 +472,7 @@ class RagPipelineDslService:
provider_name=knowledge_configuration.embedding_model_provider,
model_name=knowledge_configuration.embedding_model,
collection_name=Dataset.gen_collection_name_by_id(str(uuid.uuid4())),
type="dataset",
type=CollectionBindingType.DATASET,
)
self._session.add(dataset_collection_binding)
self._session.commit()

View File

@ -13,6 +13,7 @@ from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
from factories import variable_factory
from models.dataset import Dataset, Document, DocumentPipelineExecutionLog, Pipeline
from models.enums import DatasetRuntimeMode, DataSourceType
from models.model import UploadFile
from models.workflow import Workflow, WorkflowType
from services.entities.knowledge_entities.rag_pipeline_entities import KnowledgeConfiguration, RetrievalSetting
@ -27,7 +28,7 @@ class RagPipelineTransformService:
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
raise ValueError("Dataset not found")
if dataset.pipeline_id and dataset.runtime_mode == "rag_pipeline":
if dataset.pipeline_id and dataset.runtime_mode == DatasetRuntimeMode.RAG_PIPELINE:
return {
"pipeline_id": dataset.pipeline_id,
"dataset_id": dataset_id,
@ -85,7 +86,7 @@ class RagPipelineTransformService:
else:
raise ValueError("Unsupported doc form")
dataset.runtime_mode = "rag_pipeline"
dataset.runtime_mode = DatasetRuntimeMode.RAG_PIPELINE
dataset.pipeline_id = pipeline.id
# deal document data
@ -102,7 +103,7 @@ class RagPipelineTransformService:
pipeline_yaml = {}
if doc_form == "text_model":
match datasource_type:
case "upload_file":
case DataSourceType.UPLOAD_FILE:
if indexing_technique == "high_quality":
# get graph from transform.file-general-high-quality.yml
with open(f"{Path(__file__).parent}/transform/file-general-high-quality.yml") as f:
@ -111,7 +112,7 @@ class RagPipelineTransformService:
# get graph from transform.file-general-economy.yml
with open(f"{Path(__file__).parent}/transform/file-general-economy.yml") as f:
pipeline_yaml = yaml.safe_load(f)
case "notion_import":
case DataSourceType.NOTION_IMPORT:
if indexing_technique == "high_quality":
# get graph from transform.notion-general-high-quality.yml
with open(f"{Path(__file__).parent}/transform/notion-general-high-quality.yml") as f:
@ -120,7 +121,7 @@ class RagPipelineTransformService:
# get graph from transform.notion-general-economy.yml
with open(f"{Path(__file__).parent}/transform/notion-general-economy.yml") as f:
pipeline_yaml = yaml.safe_load(f)
case "website_crawl":
case DataSourceType.WEBSITE_CRAWL:
if indexing_technique == "high_quality":
# get graph from transform.website-crawl-general-high-quality.yml
with open(f"{Path(__file__).parent}/transform/website-crawl-general-high-quality.yml") as f:
@ -133,15 +134,15 @@ class RagPipelineTransformService:
raise ValueError("Unsupported datasource type")
elif doc_form == "hierarchical_model":
match datasource_type:
case "upload_file":
case DataSourceType.UPLOAD_FILE:
# get graph from transform.file-parentchild.yml
with open(f"{Path(__file__).parent}/transform/file-parentchild.yml") as f:
pipeline_yaml = yaml.safe_load(f)
case "notion_import":
case DataSourceType.NOTION_IMPORT:
# get graph from transform.notion-parentchild.yml
with open(f"{Path(__file__).parent}/transform/notion-parentchild.yml") as f:
pipeline_yaml = yaml.safe_load(f)
case "website_crawl":
case DataSourceType.WEBSITE_CRAWL:
# get graph from transform.website-crawl-parentchild.yml
with open(f"{Path(__file__).parent}/transform/website-crawl-parentchild.yml") as f:
pipeline_yaml = yaml.safe_load(f)
@ -287,7 +288,7 @@ class RagPipelineTransformService:
db.session.flush()
dataset.pipeline_id = pipeline.id
dataset.runtime_mode = "rag_pipeline"
dataset.runtime_mode = DatasetRuntimeMode.RAG_PIPELINE
dataset.updated_by = current_user.id
dataset.updated_at = datetime.now(UTC).replace(tzinfo=None)
db.session.add(dataset)
@ -310,8 +311,8 @@ class RagPipelineTransformService:
data_source_info_dict = document.data_source_info_dict
if not data_source_info_dict:
continue
if document.data_source_type == "upload_file":
document.data_source_type = "local_file"
if document.data_source_type == DataSourceType.UPLOAD_FILE:
document.data_source_type = DataSourceType.LOCAL_FILE
file_id = data_source_info_dict.get("upload_file_id")
if file_id:
file = db.session.query(UploadFile).where(UploadFile.id == file_id).first()
@ -331,7 +332,7 @@ class RagPipelineTransformService:
document_pipeline_execution_log = DocumentPipelineExecutionLog(
document_id=document.id,
pipeline_id=dataset.pipeline_id,
datasource_type="local_file",
datasource_type=DataSourceType.LOCAL_FILE,
datasource_info=data_source_info,
input_data={},
created_by=document.created_by,
@ -340,8 +341,8 @@ class RagPipelineTransformService:
document_pipeline_execution_log.created_at = document.created_at
db.session.add(document)
db.session.add(document_pipeline_execution_log)
elif document.data_source_type == "notion_import":
document.data_source_type = "online_document"
elif document.data_source_type == DataSourceType.NOTION_IMPORT:
document.data_source_type = DataSourceType.ONLINE_DOCUMENT
data_source_info = json.dumps(
{
"workspace_id": data_source_info_dict.get("notion_workspace_id"),
@ -359,7 +360,7 @@ class RagPipelineTransformService:
document_pipeline_execution_log = DocumentPipelineExecutionLog(
document_id=document.id,
pipeline_id=dataset.pipeline_id,
datasource_type="online_document",
datasource_type=DataSourceType.ONLINE_DOCUMENT,
datasource_info=data_source_info,
input_data={},
created_by=document.created_by,
@ -368,8 +369,7 @@ class RagPipelineTransformService:
document_pipeline_execution_log.created_at = document.created_at
db.session.add(document)
db.session.add(document_pipeline_execution_log)
elif document.data_source_type == "website_crawl":
document.data_source_type = "website_crawl"
elif document.data_source_type == DataSourceType.WEBSITE_CRAWL:
data_source_info = json.dumps(
{
"source_url": data_source_info_dict.get("url"),
@ -388,7 +388,7 @@ class RagPipelineTransformService:
document_pipeline_execution_log = DocumentPipelineExecutionLog(
document_id=document.id,
pipeline_id=dataset.pipeline_id,
datasource_type="website_crawl",
datasource_type=DataSourceType.WEBSITE_CRAWL,
datasource_info=data_source_info,
input_data={},
created_by=document.created_by,

View File

@ -1,16 +1,16 @@
import datetime
import logging
import os
import random
import time
from collections.abc import Sequence
from typing import cast
from typing import TYPE_CHECKING, cast
import sqlalchemy as sa
from sqlalchemy import delete, select, tuple_
from sqlalchemy.engine import CursorResult
from sqlalchemy.orm import Session
from configs import dify_config
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.model import (
@ -33,6 +33,131 @@ from services.retention.conversation.messages_clean_policy import (
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from opentelemetry.metrics import Counter, Histogram
class MessagesCleanupMetrics:
"""
Records low-cardinality OpenTelemetry metrics for expired message cleanup jobs.
We keep labels stable (dry_run/window_mode/task_label/status) so these metrics remain
dashboard-friendly for long-running CronJob executions.
"""
_job_runs_total: "Counter | None"
_batches_total: "Counter | None"
_messages_scanned_total: "Counter | None"
_messages_filtered_total: "Counter | None"
_messages_deleted_total: "Counter | None"
_job_duration_seconds: "Histogram | None"
_batch_duration_seconds: "Histogram | None"
_base_attributes: dict[str, str]
def __init__(self, *, dry_run: bool, has_window: bool, task_label: str) -> None:
self._job_runs_total = None
self._batches_total = None
self._messages_scanned_total = None
self._messages_filtered_total = None
self._messages_deleted_total = None
self._job_duration_seconds = None
self._batch_duration_seconds = None
self._base_attributes = {
"job_name": "messages_cleanup",
"dry_run": str(dry_run).lower(),
"window_mode": "between" if has_window else "before_cutoff",
"task_label": task_label,
}
self._init_instruments()
def _init_instruments(self) -> None:
if not dify_config.ENABLE_OTEL:
return
try:
from opentelemetry.metrics import get_meter
meter = get_meter("messages_cleanup", version=dify_config.project.version)
self._job_runs_total = meter.create_counter(
"messages_cleanup_jobs_total",
description="Total number of expired message cleanup jobs by status.",
unit="{job}",
)
self._batches_total = meter.create_counter(
"messages_cleanup_batches_total",
description="Total number of message cleanup batches processed.",
unit="{batch}",
)
self._messages_scanned_total = meter.create_counter(
"messages_cleanup_scanned_messages_total",
description="Total messages scanned by cleanup jobs.",
unit="{message}",
)
self._messages_filtered_total = meter.create_counter(
"messages_cleanup_filtered_messages_total",
description="Total messages selected by cleanup policy.",
unit="{message}",
)
self._messages_deleted_total = meter.create_counter(
"messages_cleanup_deleted_messages_total",
description="Total messages deleted by cleanup jobs.",
unit="{message}",
)
self._job_duration_seconds = meter.create_histogram(
"messages_cleanup_job_duration_seconds",
description="Duration of expired message cleanup jobs in seconds.",
unit="s",
)
self._batch_duration_seconds = meter.create_histogram(
"messages_cleanup_batch_duration_seconds",
description="Duration of expired message cleanup batch processing in seconds.",
unit="s",
)
except Exception:
logger.exception("messages_cleanup_metrics: failed to initialize instruments")
def _attrs(self, **extra: str) -> dict[str, str]:
return {**self._base_attributes, **extra}
@staticmethod
def _add(counter: "Counter | None", value: int, attributes: dict[str, str]) -> None:
if not counter or value <= 0:
return
try:
counter.add(value, attributes)
except Exception:
logger.exception("messages_cleanup_metrics: failed to add counter value")
@staticmethod
def _record(histogram: "Histogram | None", value: float, attributes: dict[str, str]) -> None:
if not histogram:
return
try:
histogram.record(value, attributes)
except Exception:
logger.exception("messages_cleanup_metrics: failed to record histogram value")
def record_batch(
self,
*,
scanned_messages: int,
filtered_messages: int,
deleted_messages: int,
batch_duration_seconds: float,
) -> None:
attributes = self._attrs()
self._add(self._batches_total, 1, attributes)
self._add(self._messages_scanned_total, scanned_messages, attributes)
self._add(self._messages_filtered_total, filtered_messages, attributes)
self._add(self._messages_deleted_total, deleted_messages, attributes)
self._record(self._batch_duration_seconds, batch_duration_seconds, attributes)
def record_completion(self, *, status: str, job_duration_seconds: float) -> None:
attributes = self._attrs(status=status)
self._add(self._job_runs_total, 1, attributes)
self._record(self._job_duration_seconds, job_duration_seconds, attributes)
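
Instruments are only created when `ENABLE_OTEL` is on, so a quick way to inspect the label set outside a running worker is to drive the OpenTelemetry API directly against an in-memory reader. A hedged sketch assuming the standard `opentelemetry-sdk` package, not any wiring from this repository:

```python
# Hedged sketch: exercise one instrument against an in-memory reader to inspect
# the emitted attributes; the attribute values mirror the base set used above.
from opentelemetry import metrics
from opentelemetry.sdk.metrics import MeterProvider
from opentelemetry.sdk.metrics.export import InMemoryMetricReader

reader = InMemoryMetricReader()
metrics.set_meter_provider(MeterProvider(metric_readers=[reader]))

meter = metrics.get_meter("messages_cleanup")
batches = meter.create_counter("messages_cleanup_batches_total", unit="{batch}")
batches.add(1, {
    "job_name": "messages_cleanup",
    "dry_run": "false",
    "window_mode": "before_cutoff",
    "task_label": "custom",
})

# Dump what was recorded; handy in unit tests for the metrics wrapper.
print(reader.get_metrics_data())
```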
class MessagesCleanService:
"""
Service for cleaning expired messages based on retention policies.
@ -48,6 +173,7 @@ class MessagesCleanService:
start_from: datetime.datetime | None = None,
batch_size: int = 1000,
dry_run: bool = False,
task_label: str = "custom",
) -> None:
"""
Initialize the service with cleanup parameters.
@ -58,12 +184,18 @@ class MessagesCleanService:
start_from: Optional start time (inclusive) of the range
batch_size: Number of messages to process per batch
dry_run: Whether to perform a dry run (no actual deletion)
task_label: Optional task label for retention metrics
"""
self._policy = policy
self._end_before = end_before
self._start_from = start_from
self._batch_size = batch_size
self._dry_run = dry_run
self._metrics = MessagesCleanupMetrics(
dry_run=dry_run,
has_window=bool(start_from),
task_label=task_label,
)
@classmethod
def from_time_range(
@ -73,6 +205,7 @@ class MessagesCleanService:
end_before: datetime.datetime,
batch_size: int = 1000,
dry_run: bool = False,
task_label: str = "custom",
) -> "MessagesCleanService":
"""
Create a service instance for cleaning messages within a specific time range.
@ -85,6 +218,7 @@ class MessagesCleanService:
end_before: End time (exclusive) of the range
batch_size: Number of messages to process per batch
dry_run: Whether to perform a dry run (no actual deletion)
task_label: Optional task label for retention metrics
Returns:
MessagesCleanService instance
@ -112,6 +246,7 @@ class MessagesCleanService:
start_from=start_from,
batch_size=batch_size,
dry_run=dry_run,
task_label=task_label,
)
@classmethod
@ -121,6 +256,7 @@ class MessagesCleanService:
days: int = 30,
batch_size: int = 1000,
dry_run: bool = False,
task_label: str = "custom",
) -> "MessagesCleanService":
"""
Create a service instance for cleaning messages older than specified days.
@ -130,6 +266,7 @@ class MessagesCleanService:
days: Number of days to look back from now
batch_size: Number of messages to process per batch
dry_run: Whether to perform a dry run (no actual deletion)
task_label: Optional task label for retention metrics
Returns:
MessagesCleanService instance
@ -153,7 +290,14 @@ class MessagesCleanService:
policy.__class__.__name__,
)
return cls(policy=policy, end_before=end_before, start_from=None, batch_size=batch_size, dry_run=dry_run)
return cls(
policy=policy,
end_before=end_before,
start_from=None,
batch_size=batch_size,
dry_run=dry_run,
task_label=task_label,
)
def run(self) -> dict[str, int]:
"""
@ -162,7 +306,18 @@ class MessagesCleanService:
Returns:
Dict with statistics: batches, filtered_messages, total_deleted
"""
return self._clean_messages_by_time_range()
status = "success"
run_start = time.monotonic()
try:
return self._clean_messages_by_time_range()
except Exception:
status = "failed"
raise
finally:
self._metrics.record_completion(
status=status,
job_duration_seconds=time.monotonic() - run_start,
)
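
A hedged test sketch (not from the repository's test suite) for the status and duration bookkeeping above; it bypasses `__init__` so no database, policy, or OTel wiring is needed:

```python
# Hypothetical unit-test sketch; assumes MessagesCleanService is importable and
# only relies on run() delegating to _clean_messages_by_time_range() and
# reporting through self._metrics.record_completion().
from unittest.mock import MagicMock


def test_run_records_failed_status():
    service = MessagesCleanService.__new__(MessagesCleanService)  # skip __init__
    service._metrics = MagicMock()
    service._clean_messages_by_time_range = MagicMock(side_effect=RuntimeError("boom"))

    try:
        service.run()
    except RuntimeError:
        pass

    kwargs = service._metrics.record_completion.call_args.kwargs
    assert kwargs["status"] == "failed"
    assert kwargs["job_duration_seconds"] >= 0
```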
def _clean_messages_by_time_range(self) -> dict[str, int]:
"""
@ -197,11 +352,14 @@ class MessagesCleanService:
self._end_before,
)
max_batch_interval_ms = int(os.environ.get("SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_MAX_INTERVAL", 200))
max_batch_interval_ms = dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_MAX_INTERVAL
while True:
stats["batches"] += 1
batch_start = time.monotonic()
batch_scanned_messages = 0
batch_filtered_messages = 0
batch_deleted_messages = 0
# Step 1: Fetch a batch of messages using cursor
with Session(db.engine, expire_on_commit=False) as session:
@ -240,9 +398,16 @@ class MessagesCleanService:
# Track total messages fetched across all batches
stats["total_messages"] += len(messages)
batch_scanned_messages = len(messages)
if not messages:
logger.info("clean_messages (batch %s): no more messages to process", stats["batches"])
self._metrics.record_batch(
scanned_messages=batch_scanned_messages,
filtered_messages=batch_filtered_messages,
deleted_messages=batch_deleted_messages,
batch_duration_seconds=time.monotonic() - batch_start,
)
break
# Update cursor to the last message's (created_at, id)
@ -268,6 +433,12 @@ class MessagesCleanService:
if not apps:
logger.info("clean_messages (batch %s): no apps found, skip", stats["batches"])
self._metrics.record_batch(
scanned_messages=batch_scanned_messages,
filtered_messages=batch_filtered_messages,
deleted_messages=batch_deleted_messages,
batch_duration_seconds=time.monotonic() - batch_start,
)
continue
# Build app_id -> tenant_id mapping
@ -286,9 +457,16 @@ class MessagesCleanService:
if not message_ids_to_delete:
logger.info("clean_messages (batch %s): no messages to delete, skip", stats["batches"])
self._metrics.record_batch(
scanned_messages=batch_scanned_messages,
filtered_messages=batch_filtered_messages,
deleted_messages=batch_deleted_messages,
batch_duration_seconds=time.monotonic() - batch_start,
)
continue
stats["filtered_messages"] += len(message_ids_to_delete)
batch_filtered_messages = len(message_ids_to_delete)
# Step 4: Batch delete messages and their relations
if not self._dry_run:
@ -309,6 +487,7 @@ class MessagesCleanService:
commit_ms = int((time.monotonic() - commit_start) * 1000)
stats["total_deleted"] += messages_deleted
batch_deleted_messages = messages_deleted
logger.info(
"clean_messages (batch %s): processed %s messages, deleted %s messages",
@ -343,6 +522,13 @@ class MessagesCleanService:
for msg_id in sampled_ids:
logger.info("clean_messages (batch %s, dry_run) sample: message_id=%s", stats["batches"], msg_id)
self._metrics.record_batch(
scanned_messages=batch_scanned_messages,
filtered_messages=batch_filtered_messages,
deleted_messages=batch_deleted_messages,
batch_duration_seconds=time.monotonic() - batch_start,
)
logger.info(
"clean_messages completed: total batches: %s, total messages: %s, filtered messages: %s, total deleted: %s",
stats["batches"],

View File

@ -1,9 +1,9 @@
import datetime
import logging
import os
import random
import time
from collections.abc import Iterable, Sequence
from typing import TYPE_CHECKING
import click
from sqlalchemy.orm import Session, sessionmaker
@ -20,6 +20,159 @@ from services.billing_service import BillingService, SubscriptionPlan
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from opentelemetry.metrics import Counter, Histogram
class WorkflowRunCleanupMetrics:
"""
Records low-cardinality OpenTelemetry metrics for workflow run cleanup jobs.
Metrics are emitted with stable labels only (dry_run/window_mode/task_label/status)
to keep dashboard and alert cardinality predictable in production clusters.
"""
_job_runs_total: "Counter | None"
_batches_total: "Counter | None"
_runs_scanned_total: "Counter | None"
_runs_targeted_total: "Counter | None"
_runs_deleted_total: "Counter | None"
_runs_skipped_total: "Counter | None"
_related_records_total: "Counter | None"
_job_duration_seconds: "Histogram | None"
_batch_duration_seconds: "Histogram | None"
_base_attributes: dict[str, str]
def __init__(self, *, dry_run: bool, has_window: bool, task_label: str) -> None:
self._job_runs_total = None
self._batches_total = None
self._runs_scanned_total = None
self._runs_targeted_total = None
self._runs_deleted_total = None
self._runs_skipped_total = None
self._related_records_total = None
self._job_duration_seconds = None
self._batch_duration_seconds = None
self._base_attributes = {
"job_name": "workflow_run_cleanup",
"dry_run": str(dry_run).lower(),
"window_mode": "between" if has_window else "before_cutoff",
"task_label": task_label,
}
self._init_instruments()
def _init_instruments(self) -> None:
if not dify_config.ENABLE_OTEL:
return
try:
from opentelemetry.metrics import get_meter
meter = get_meter("workflow_run_cleanup", version=dify_config.project.version)
self._job_runs_total = meter.create_counter(
"workflow_run_cleanup_jobs_total",
description="Total number of workflow run cleanup jobs by status.",
unit="{job}",
)
self._batches_total = meter.create_counter(
"workflow_run_cleanup_batches_total",
description="Total number of processed cleanup batches.",
unit="{batch}",
)
self._runs_scanned_total = meter.create_counter(
"workflow_run_cleanup_scanned_runs_total",
description="Total workflow runs scanned by cleanup jobs.",
unit="{run}",
)
self._runs_targeted_total = meter.create_counter(
"workflow_run_cleanup_targeted_runs_total",
description="Total workflow runs targeted by cleanup policy.",
unit="{run}",
)
self._runs_deleted_total = meter.create_counter(
"workflow_run_cleanup_deleted_runs_total",
description="Total workflow runs deleted by cleanup jobs.",
unit="{run}",
)
self._runs_skipped_total = meter.create_counter(
"workflow_run_cleanup_skipped_runs_total",
description="Total workflow runs skipped because tenant is paid/unknown.",
unit="{run}",
)
self._related_records_total = meter.create_counter(
"workflow_run_cleanup_related_records_total",
description="Total related records processed by cleanup jobs.",
unit="{record}",
)
self._job_duration_seconds = meter.create_histogram(
"workflow_run_cleanup_job_duration_seconds",
description="Duration of workflow run cleanup jobs in seconds.",
unit="s",
)
self._batch_duration_seconds = meter.create_histogram(
"workflow_run_cleanup_batch_duration_seconds",
description="Duration of workflow run cleanup batch processing in seconds.",
unit="s",
)
except Exception:
logger.exception("workflow_run_cleanup_metrics: failed to initialize instruments")
def _attrs(self, **extra: str) -> dict[str, str]:
return {**self._base_attributes, **extra}
@staticmethod
def _add(counter: "Counter | None", value: int, attributes: dict[str, str]) -> None:
if not counter or value <= 0:
return
try:
counter.add(value, attributes)
except Exception:
logger.exception("workflow_run_cleanup_metrics: failed to add counter value")
@staticmethod
def _record(histogram: "Histogram | None", value: float, attributes: dict[str, str]) -> None:
if not histogram:
return
try:
histogram.record(value, attributes)
except Exception:
logger.exception("workflow_run_cleanup_metrics: failed to record histogram value")
def record_batch(
self,
*,
batch_rows: int,
targeted_runs: int,
skipped_runs: int,
deleted_runs: int,
related_counts: dict[str, int] | None,
related_action: str | None,
batch_duration_seconds: float,
) -> None:
attributes = self._attrs()
self._add(self._batches_total, 1, attributes)
self._add(self._runs_scanned_total, batch_rows, attributes)
self._add(self._runs_targeted_total, targeted_runs, attributes)
self._add(self._runs_skipped_total, skipped_runs, attributes)
self._add(self._runs_deleted_total, deleted_runs, attributes)
self._record(self._batch_duration_seconds, batch_duration_seconds, attributes)
if not related_counts or not related_action:
return
for record_type, count in related_counts.items():
self._add(
self._related_records_total,
count,
self._attrs(action=related_action, record_type=record_type),
)
def record_completion(self, *, status: str, job_duration_seconds: float) -> None:
attributes = self._attrs(status=status)
self._add(self._job_runs_total, 1, attributes)
self._record(self._job_duration_seconds, job_duration_seconds, attributes)
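
The related-records counter is the one place extra labels are layered on top of the base set; a small illustration of the attribute dicts one deleted batch would emit (sample values are assumed, not taken from a real run):

```python
# Illustration only: attribute sets produced by record_batch() for one batch
# with action="deleted"; record_type keys come from _empty_related_counts().
base = {
    "job_name": "workflow_run_cleanup",
    "dry_run": "false",
    "window_mode": "before_cutoff",
    "task_label": "custom",
}
related_counts = {"node_executions": 120, "app_logs": 30}  # assumed sample values

series = [
    {**base, "action": "deleted", "record_type": record_type}
    for record_type in related_counts
]
# -> two additional low-cardinality series, one per record_type
```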
class WorkflowRunCleanup:
def __init__(
self,
@ -29,6 +182,7 @@ class WorkflowRunCleanup:
end_before: datetime.datetime | None = None,
workflow_run_repo: APIWorkflowRunRepository | None = None,
dry_run: bool = False,
task_label: str = "custom",
):
if (start_from is None) ^ (end_before is None):
raise ValueError("start_from and end_before must be both set or both omitted.")
@ -46,6 +200,11 @@ class WorkflowRunCleanup:
self.batch_size = batch_size
self._cleanup_whitelist: set[str] | None = None
self.dry_run = dry_run
self._metrics = WorkflowRunCleanupMetrics(
dry_run=dry_run,
has_window=bool(start_from),
task_label=task_label,
)
self.free_plan_grace_period_days = dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD
self.workflow_run_repo: APIWorkflowRunRepository
if workflow_run_repo:
@ -74,153 +233,193 @@ class WorkflowRunCleanup:
related_totals = self._empty_related_counts() if self.dry_run else None
batch_index = 0
last_seen: tuple[datetime.datetime, str] | None = None
status = "success"
run_start = time.monotonic()
max_batch_interval_ms = dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_MAX_INTERVAL
max_batch_interval_ms = int(os.environ.get("SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_MAX_INTERVAL", 200))
try:
while True:
batch_start = time.monotonic()
while True:
batch_start = time.monotonic()
fetch_start = time.monotonic()
run_rows = self.workflow_run_repo.get_runs_batch_by_time_range(
start_from=self.window_start,
end_before=self.window_end,
last_seen=last_seen,
batch_size=self.batch_size,
)
if not run_rows:
logger.info("workflow_run_cleanup (batch #%s): no more rows to process", batch_index + 1)
break
batch_index += 1
last_seen = (run_rows[-1].created_at, run_rows[-1].id)
logger.info(
"workflow_run_cleanup (batch #%s): fetched %s rows in %sms",
batch_index,
len(run_rows),
int((time.monotonic() - fetch_start) * 1000),
)
tenant_ids = {row.tenant_id for row in run_rows}
filter_start = time.monotonic()
free_tenants = self._filter_free_tenants(tenant_ids)
logger.info(
"workflow_run_cleanup (batch #%s): filtered %s free tenants from %s tenants in %sms",
batch_index,
len(free_tenants),
len(tenant_ids),
int((time.monotonic() - filter_start) * 1000),
)
free_runs = [row for row in run_rows if row.tenant_id in free_tenants]
paid_or_skipped = len(run_rows) - len(free_runs)
if not free_runs:
skipped_message = (
f"[batch #{batch_index}] skipped (no sandbox runs in batch, {paid_or_skipped} paid/unknown)"
fetch_start = time.monotonic()
run_rows = self.workflow_run_repo.get_runs_batch_by_time_range(
start_from=self.window_start,
end_before=self.window_end,
last_seen=last_seen,
batch_size=self.batch_size,
)
click.echo(
click.style(
skipped_message,
fg="yellow",
)
)
continue
if not run_rows:
logger.info("workflow_run_cleanup (batch #%s): no more rows to process", batch_index + 1)
break
total_runs_targeted += len(free_runs)
if self.dry_run:
count_start = time.monotonic()
batch_counts = self.workflow_run_repo.count_runs_with_related(
free_runs,
count_node_executions=self._count_node_executions,
count_trigger_logs=self._count_trigger_logs,
)
batch_index += 1
last_seen = (run_rows[-1].created_at, run_rows[-1].id)
logger.info(
"workflow_run_cleanup (batch #%s, dry_run): counted related records in %sms",
"workflow_run_cleanup (batch #%s): fetched %s rows in %sms",
batch_index,
int((time.monotonic() - count_start) * 1000),
len(run_rows),
int((time.monotonic() - fetch_start) * 1000),
)
if related_totals is not None:
for key in related_totals:
related_totals[key] += batch_counts.get(key, 0)
sample_ids = ", ".join(run.id for run in free_runs[:5])
tenant_ids = {row.tenant_id for row in run_rows}
filter_start = time.monotonic()
free_tenants = self._filter_free_tenants(tenant_ids)
logger.info(
"workflow_run_cleanup (batch #%s): filtered %s free tenants from %s tenants in %sms",
batch_index,
len(free_tenants),
len(tenant_ids),
int((time.monotonic() - filter_start) * 1000),
)
free_runs = [row for row in run_rows if row.tenant_id in free_tenants]
paid_or_skipped = len(run_rows) - len(free_runs)
if not free_runs:
skipped_message = (
f"[batch #{batch_index}] skipped (no sandbox runs in batch, {paid_or_skipped} paid/unknown)"
)
click.echo(
click.style(
skipped_message,
fg="yellow",
)
)
self._metrics.record_batch(
batch_rows=len(run_rows),
targeted_runs=0,
skipped_runs=paid_or_skipped,
deleted_runs=0,
related_counts=None,
related_action=None,
batch_duration_seconds=time.monotonic() - batch_start,
)
continue
total_runs_targeted += len(free_runs)
if self.dry_run:
count_start = time.monotonic()
batch_counts = self.workflow_run_repo.count_runs_with_related(
free_runs,
count_node_executions=self._count_node_executions,
count_trigger_logs=self._count_trigger_logs,
)
logger.info(
"workflow_run_cleanup (batch #%s, dry_run): counted related records in %sms",
batch_index,
int((time.monotonic() - count_start) * 1000),
)
if related_totals is not None:
for key in related_totals:
related_totals[key] += batch_counts.get(key, 0)
sample_ids = ", ".join(run.id for run in free_runs[:5])
click.echo(
click.style(
f"[batch #{batch_index}] would delete {len(free_runs)} runs "
f"(sample ids: {sample_ids}) and skip {paid_or_skipped} paid/unknown",
fg="yellow",
)
)
logger.info(
"workflow_run_cleanup (batch #%s, dry_run): batch total %sms",
batch_index,
int((time.monotonic() - batch_start) * 1000),
)
self._metrics.record_batch(
batch_rows=len(run_rows),
targeted_runs=len(free_runs),
skipped_runs=paid_or_skipped,
deleted_runs=0,
related_counts={key: batch_counts.get(key, 0) for key in self._empty_related_counts()},
related_action="would_delete",
batch_duration_seconds=time.monotonic() - batch_start,
)
continue
try:
delete_start = time.monotonic()
counts = self.workflow_run_repo.delete_runs_with_related(
free_runs,
delete_node_executions=self._delete_node_executions,
delete_trigger_logs=self._delete_trigger_logs,
)
delete_ms = int((time.monotonic() - delete_start) * 1000)
except Exception:
logger.exception("Failed to delete workflow runs batch ending at %s", last_seen[0])
raise
total_runs_deleted += counts["runs"]
click.echo(
click.style(
f"[batch #{batch_index}] would delete {len(free_runs)} runs "
f"(sample ids: {sample_ids}) and skip {paid_or_skipped} paid/unknown",
fg="yellow",
f"[batch #{batch_index}] deleted runs: {counts['runs']} "
f"(nodes {counts['node_executions']}, offloads {counts['offloads']}, "
f"app_logs {counts['app_logs']}, trigger_logs {counts['trigger_logs']}, "
f"pauses {counts['pauses']}, pause_reasons {counts['pause_reasons']}); "
f"skipped {paid_or_skipped} paid/unknown",
fg="green",
)
)
logger.info(
"workflow_run_cleanup (batch #%s, dry_run): batch total %sms",
"workflow_run_cleanup (batch #%s): delete %sms, batch total %sms",
batch_index,
delete_ms,
int((time.monotonic() - batch_start) * 1000),
)
continue
try:
delete_start = time.monotonic()
counts = self.workflow_run_repo.delete_runs_with_related(
free_runs,
delete_node_executions=self._delete_node_executions,
delete_trigger_logs=self._delete_trigger_logs,
self._metrics.record_batch(
batch_rows=len(run_rows),
targeted_runs=len(free_runs),
skipped_runs=paid_or_skipped,
deleted_runs=counts["runs"],
related_counts={key: counts.get(key, 0) for key in self._empty_related_counts()},
related_action="deleted",
batch_duration_seconds=time.monotonic() - batch_start,
)
delete_ms = int((time.monotonic() - delete_start) * 1000)
except Exception:
logger.exception("Failed to delete workflow runs batch ending at %s", last_seen[0])
raise
total_runs_deleted += counts["runs"]
click.echo(
click.style(
f"[batch #{batch_index}] deleted runs: {counts['runs']} "
f"(nodes {counts['node_executions']}, offloads {counts['offloads']}, "
f"app_logs {counts['app_logs']}, trigger_logs {counts['trigger_logs']}, "
f"pauses {counts['pauses']}, pause_reasons {counts['pause_reasons']}); "
f"skipped {paid_or_skipped} paid/unknown",
fg="green",
)
)
logger.info(
"workflow_run_cleanup (batch #%s): delete %sms, batch total %sms",
batch_index,
delete_ms,
int((time.monotonic() - batch_start) * 1000),
)
# Random sleep between batches to avoid overwhelming the database
sleep_ms = random.uniform(0, max_batch_interval_ms) # noqa: S311
logger.info("workflow_run_cleanup (batch #%s): sleeping for %.2fms", batch_index, sleep_ms)
time.sleep(sleep_ms / 1000)
# Random sleep between batches to avoid overwhelming the database
sleep_ms = random.uniform(0, max_batch_interval_ms) # noqa: S311
logger.info("workflow_run_cleanup (batch #%s): sleeping for %.2fms", batch_index, sleep_ms)
time.sleep(sleep_ms / 1000)
if self.dry_run:
if self.window_start:
summary_message = (
f"Dry run complete. Would delete {total_runs_targeted} workflow runs "
f"between {self.window_start.isoformat()} and {self.window_end.isoformat()}"
)
if self.dry_run:
if self.window_start:
summary_message = (
f"Dry run complete. Would delete {total_runs_targeted} workflow runs "
f"between {self.window_start.isoformat()} and {self.window_end.isoformat()}"
)
else:
summary_message = (
f"Dry run complete. Would delete {total_runs_targeted} workflow runs "
f"before {self.window_end.isoformat()}"
)
if related_totals is not None:
summary_message = (
f"{summary_message}; related records: {self._format_related_counts(related_totals)}"
)
summary_color = "yellow"
else:
summary_message = (
f"Dry run complete. Would delete {total_runs_targeted} workflow runs "
f"before {self.window_end.isoformat()}"
)
if related_totals is not None:
summary_message = f"{summary_message}; related records: {self._format_related_counts(related_totals)}"
summary_color = "yellow"
else:
if self.window_start:
summary_message = (
f"Cleanup complete. Deleted {total_runs_deleted} workflow runs "
f"between {self.window_start.isoformat()} and {self.window_end.isoformat()}"
)
else:
summary_message = (
f"Cleanup complete. Deleted {total_runs_deleted} workflow runs before {self.window_end.isoformat()}"
)
summary_color = "white"
if self.window_start:
summary_message = (
f"Cleanup complete. Deleted {total_runs_deleted} workflow runs "
f"between {self.window_start.isoformat()} and {self.window_end.isoformat()}"
)
else:
summary_message = (
f"Cleanup complete. Deleted {total_runs_deleted} workflow runs "
f"before {self.window_end.isoformat()}"
)
summary_color = "white"
click.echo(click.style(summary_message, fg=summary_color))
click.echo(click.style(summary_message, fg=summary_color))
except Exception:
status = "failed"
raise
finally:
self._metrics.record_completion(
status=status,
job_duration_seconds=time.monotonic() - run_start,
)
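
Both cleanup loops page with a `(created_at, id)` keyset cursor (`last_seen`) rather than OFFSET, so in-flight deletes never shift the window. A minimal SQLAlchemy sketch of that pattern; the `model` parameter and session wiring are illustrative, not the actual repository implementation:

```python
# Illustrative keyset-pagination sketch; not code from APIWorkflowRunRepository.
from sqlalchemy import select, tuple_


def fetch_batch(session, model, end_before, last_seen, batch_size):
    stmt = (
        select(model.id, model.tenant_id, model.created_at)
        .where(model.created_at < end_before)
        .order_by(model.created_at, model.id)
        .limit(batch_size)
    )
    if last_seen is not None:
        # Resume strictly after the last processed (created_at, id) pair.
        stmt = stmt.where(tuple_(model.created_at, model.id) > last_seen)
    return session.execute(stmt).all()
```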
def _filter_free_tenants(self, tenant_ids: Iterable[str]) -> set[str]:
tenant_id_list = list(tenant_ids)

View File

@ -12,12 +12,14 @@ from core.db.session_factory import session_factory
from core.model_manager import ModelManager
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict
from core.rag.models.document import Document
from dify_graph.model_runtime.entities.llm_entities import LLMUsage
from dify_graph.model_runtime.entities.model_entities import ModelType
from libs import helper
from models.dataset import Dataset, DocumentSegment, DocumentSegmentSummary
from models.dataset import Document as DatasetDocument
from models.enums import SummaryStatus
logger = logging.getLogger(__name__)
@ -29,7 +31,7 @@ class SummaryIndexService:
def generate_summary_for_segment(
segment: DocumentSegment,
dataset: Dataset,
summary_index_setting: dict,
summary_index_setting: SummaryIndexSettingDict,
) -> tuple[str, LLMUsage]:
"""
Generate summary for a single segment.
@ -73,7 +75,7 @@ class SummaryIndexService:
segment: DocumentSegment,
dataset: Dataset,
summary_content: str,
status: str = "generating",
status: SummaryStatus = SummaryStatus.GENERATING,
) -> DocumentSegmentSummary:
"""
Create or update a DocumentSegmentSummary record.
@ -83,7 +85,7 @@ class SummaryIndexService:
segment: DocumentSegment to create summary for
dataset: Dataset containing the segment
summary_content: Generated summary content
status: Summary status (default: "generating")
status: Summary status (default: SummaryStatus.GENERATING)
Returns:
Created or updated DocumentSegmentSummary instance
@ -326,7 +328,7 @@ class SummaryIndexService:
summary_index_node_id=summary_index_node_id,
summary_index_node_hash=summary_hash,
tokens=embedding_tokens,
status="completed",
status=SummaryStatus.COMPLETED,
enabled=True,
)
session.add(summary_record_in_session)
@ -362,7 +364,7 @@ class SummaryIndexService:
summary_record_in_session.summary_index_node_id = summary_index_node_id
summary_record_in_session.summary_index_node_hash = summary_hash
summary_record_in_session.tokens = embedding_tokens # Save embedding tokens
summary_record_in_session.status = "completed"
summary_record_in_session.status = SummaryStatus.COMPLETED
# Ensure summary_content is preserved (use the latest from summary_record parameter)
# This is critical: use the parameter value, not the database value
summary_record_in_session.summary_content = summary_content
@ -400,7 +402,7 @@ class SummaryIndexService:
summary_record.summary_index_node_id = summary_index_node_id
summary_record.summary_index_node_hash = summary_hash
summary_record.tokens = embedding_tokens
summary_record.status = "completed"
summary_record.status = SummaryStatus.COMPLETED
summary_record.summary_content = summary_content
if summary_record_in_session.updated_at:
summary_record.updated_at = summary_record_in_session.updated_at
@ -487,7 +489,7 @@ class SummaryIndexService:
)
if summary_record_in_session:
summary_record_in_session.status = "error"
summary_record_in_session.status = SummaryStatus.ERROR
summary_record_in_session.error = f"Vectorization failed: {str(e)}"
summary_record_in_session.updated_at = datetime.now(UTC).replace(tzinfo=None)
error_session.add(summary_record_in_session)
@ -498,7 +500,7 @@ class SummaryIndexService:
summary_record_in_session.id,
)
# Update the original object for consistency
summary_record.status = "error"
summary_record.status = SummaryStatus.ERROR
summary_record.error = summary_record_in_session.error
summary_record.updated_at = summary_record_in_session.updated_at
else:
@ -514,7 +516,7 @@ class SummaryIndexService:
def batch_create_summary_records(
segments: list[DocumentSegment],
dataset: Dataset,
status: str = "not_started",
status: SummaryStatus = SummaryStatus.NOT_STARTED,
) -> None:
"""
Batch create summary records for segments with specified status.
@ -523,7 +525,7 @@ class SummaryIndexService:
Args:
segments: List of DocumentSegment instances
dataset: Dataset containing the segments
status: Initial status for the records (default: "not_started")
status: Initial status for the records (default: SummaryStatus.NOT_STARTED)
"""
segment_ids = [segment.id for segment in segments]
if not segment_ids:
@ -588,7 +590,7 @@ class SummaryIndexService:
)
if summary_record:
summary_record.status = "error"
summary_record.status = SummaryStatus.ERROR
summary_record.error = error
session.add(summary_record)
session.commit()
@ -599,7 +601,7 @@ class SummaryIndexService:
def generate_and_vectorize_summary(
segment: DocumentSegment,
dataset: Dataset,
summary_index_setting: dict,
summary_index_setting: SummaryIndexSettingDict,
) -> DocumentSegmentSummary:
"""
Generate summary for a segment and vectorize it.
@ -631,14 +633,14 @@ class SummaryIndexService:
document_id=segment.document_id,
chunk_id=segment.id,
summary_content="",
status="generating",
status=SummaryStatus.GENERATING,
enabled=True,
)
session.add(summary_record_in_session)
session.flush()
# Update status to "generating"
summary_record_in_session.status = "generating"
summary_record_in_session.status = SummaryStatus.GENERATING
summary_record_in_session.error = None # type: ignore[assignment]
session.add(summary_record_in_session)
# Don't flush here - wait until after vectorization succeeds
@ -681,7 +683,7 @@ class SummaryIndexService:
except Exception as vectorize_error:
# If vectorization fails, update status to error in current session
logger.exception("Failed to vectorize summary for segment %s", segment.id)
summary_record_in_session.status = "error"
summary_record_in_session.status = SummaryStatus.ERROR
summary_record_in_session.error = f"Vectorization failed: {str(vectorize_error)}"
session.add(summary_record_in_session)
session.commit()
@ -694,7 +696,7 @@ class SummaryIndexService:
session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first()
)
if summary_record_in_session:
summary_record_in_session.status = "error"
summary_record_in_session.status = SummaryStatus.ERROR
summary_record_in_session.error = str(e)
session.add(summary_record_in_session)
session.commit()
@ -704,7 +706,7 @@ class SummaryIndexService:
def generate_summaries_for_document(
dataset: Dataset,
document: DatasetDocument,
summary_index_setting: dict,
summary_index_setting: SummaryIndexSettingDict,
segment_ids: list[str] | None = None,
only_parent_chunks: bool = False,
) -> list[DocumentSegmentSummary]:
@ -770,7 +772,7 @@ class SummaryIndexService:
SummaryIndexService.batch_create_summary_records(
segments=segments,
dataset=dataset,
status="not_started",
status=SummaryStatus.NOT_STARTED,
)
summary_records = []
@ -1067,7 +1069,7 @@ class SummaryIndexService:
# Update summary content
summary_record.summary_content = summary_content
summary_record.status = "generating"
summary_record.status = SummaryStatus.GENERATING
summary_record.error = None # type: ignore[assignment] # Clear any previous errors
session.add(summary_record)
# Flush to ensure summary_content is saved before vectorize_summary queries it
@ -1102,7 +1104,7 @@ class SummaryIndexService:
# If vectorization fails, update status to error in current session
# Don't raise the exception - just log it and return the record with error status
# This allows the segment update to complete even if vectorization fails
summary_record.status = "error"
summary_record.status = SummaryStatus.ERROR
summary_record.error = f"Vectorization failed: {str(e)}"
session.commit()
logger.exception("Failed to vectorize summary for segment %s", segment.id)
@ -1112,7 +1114,7 @@ class SummaryIndexService:
else:
# Create new summary record if doesn't exist
summary_record = SummaryIndexService.create_summary_record(
segment, dataset, summary_content, status="generating"
segment, dataset, summary_content, status=SummaryStatus.GENERATING
)
# Re-vectorize summary (this will update status to "completed" and tokens in its own session)
# Note: summary_record was created in a different session,
@ -1132,7 +1134,7 @@ class SummaryIndexService:
# If vectorization fails, update status to error in current session
# Merge the record into current session first
error_record = session.merge(summary_record)
error_record.status = "error"
error_record.status = SummaryStatus.ERROR
error_record.error = f"Vectorization failed: {str(e)}"
session.commit()
logger.exception("Failed to vectorize summary for segment %s", segment.id)
@ -1146,7 +1148,7 @@ class SummaryIndexService:
session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first()
)
if summary_record:
summary_record.status = "error"
summary_record.status = SummaryStatus.ERROR
summary_record.error = str(e)
session.add(summary_record)
session.commit()
@ -1266,7 +1268,7 @@ class SummaryIndexService:
# Check if there are any "not_started" or "generating" status summaries
has_pending_summaries = any(
summary_status_map.get(segment_id) is not None # Ensure summary exists (enabled=True)
and summary_status_map[segment_id] in ("not_started", "generating")
and summary_status_map[segment_id] in (SummaryStatus.NOT_STARTED, SummaryStatus.GENERATING)
for segment_id in segment_ids
)
@ -1330,7 +1332,7 @@ class SummaryIndexService:
# it means the summary is disabled (enabled=False) or not created yet, ignore it
has_pending_summaries = any(
summary_status_map.get(segment_id) is not None # Ensure summary exists (enabled=True)
and summary_status_map[segment_id] in ("not_started", "generating")
and summary_status_map[segment_id] in (SummaryStatus.NOT_STARTED, SummaryStatus.GENERATING)
for segment_id in segment_ids
)
@ -1393,17 +1395,17 @@ class SummaryIndexService:
# Count statuses
status_counts = {
"completed": 0,
"generating": 0,
"error": 0,
"not_started": 0,
SummaryStatus.COMPLETED: 0,
SummaryStatus.GENERATING: 0,
SummaryStatus.ERROR: 0,
SummaryStatus.NOT_STARTED: 0,
}
summary_list = []
for segment in segments:
summary = summary_map.get(segment.id)
if summary:
status = summary.status
status = SummaryStatus(summary.status)
status_counts[status] = status_counts.get(status, 0) + 1
summary_list.append(
{
@ -1421,12 +1423,12 @@ class SummaryIndexService:
}
)
else:
status_counts["not_started"] += 1
status_counts[SummaryStatus.NOT_STARTED] += 1
summary_list.append(
{
"segment_id": segment.id,
"segment_position": segment.position,
"status": "not_started",
"status": SummaryStatus.NOT_STARTED,
"summary_preview": None,
"error": None,
"created_at": None,

View File

@ -13,6 +13,7 @@ from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from models.dataset import DatasetAutoDisableLog, DocumentSegment
from models.dataset import Document as DatasetDocument
from models.enums import IndexingStatus, SegmentStatus
logger = logging.getLogger(__name__)
@ -34,7 +35,7 @@ def add_document_to_index_task(dataset_document_id: str):
logger.info(click.style(f"Document not found: {dataset_document_id}", fg="red"))
return
if dataset_document.indexing_status != "completed":
if dataset_document.indexing_status != IndexingStatus.COMPLETED:
return
indexing_cache_key = f"document_{dataset_document.id}_indexing"
@ -48,7 +49,7 @@ def add_document_to_index_task(dataset_document_id: str):
session.query(DocumentSegment)
.where(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.status == "completed",
DocumentSegment.status == SegmentStatus.COMPLETED,
)
.order_by(DocumentSegment.position.asc())
.all()
@ -139,7 +140,7 @@ def add_document_to_index_task(dataset_document_id: str):
logger.exception("add document to index failed")
dataset_document.enabled = False
dataset_document.disabled_at = naive_utc_now()
dataset_document.indexing_status = "error"
dataset_document.indexing_status = IndexingStatus.ERROR
dataset_document.error = str(e)
session.commit()
finally:

View File

@ -11,6 +11,7 @@ from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset
from models.enums import CollectionBindingType
from models.model import App, AppAnnotationSetting, MessageAnnotation
from services.dataset_service import DatasetCollectionBindingService
@ -47,7 +48,7 @@ def enable_annotation_reply_task(
try:
documents = []
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
embedding_provider_name, embedding_model_name, "annotation"
embedding_provider_name, embedding_model_name, CollectionBindingType.ANNOTATION
)
annotation_setting = (
session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
@ -56,7 +57,7 @@ def enable_annotation_reply_task(
if dataset_collection_binding.id != annotation_setting.collection_binding_id:
old_dataset_collection_binding = (
DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
annotation_setting.collection_binding_id, "annotation"
annotation_setting.collection_binding_id, CollectionBindingType.ANNOTATION
)
)
if old_dataset_collection_binding and annotations:

View File

@ -10,6 +10,7 @@ from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from models.dataset import DocumentSegment
from models.enums import IndexingStatus, SegmentStatus
logger = logging.getLogger(__name__)
@ -31,7 +32,7 @@ def create_segment_to_index_task(segment_id: str, keywords: list[str] | None = N
logger.info(click.style(f"Segment not found: {segment_id}", fg="red"))
return
if segment.status != "waiting":
if segment.status != SegmentStatus.WAITING:
return
indexing_cache_key = f"segment_{segment.id}_indexing"
@ -40,7 +41,7 @@ def create_segment_to_index_task(segment_id: str, keywords: list[str] | None = N
# update segment status to indexing
session.query(DocumentSegment).filter_by(id=segment.id).update(
{
DocumentSegment.status: "indexing",
DocumentSegment.status: SegmentStatus.INDEXING,
DocumentSegment.indexing_at: naive_utc_now(),
}
)
@ -70,7 +71,7 @@ def create_segment_to_index_task(segment_id: str, keywords: list[str] | None = N
if (
not dataset_document.enabled
or dataset_document.archived
or dataset_document.indexing_status != "completed"
or dataset_document.indexing_status != IndexingStatus.COMPLETED
):
logger.info(click.style(f"Segment {segment.id} document status is invalid, pass.", fg="cyan"))
return
@ -82,7 +83,7 @@ def create_segment_to_index_task(segment_id: str, keywords: list[str] | None = N
# update segment to completed
session.query(DocumentSegment).filter_by(id=segment.id).update(
{
DocumentSegment.status: "completed",
DocumentSegment.status: SegmentStatus.COMPLETED,
DocumentSegment.completed_at: naive_utc_now(),
}
)
@ -94,7 +95,7 @@ def create_segment_to_index_task(segment_id: str, keywords: list[str] | None = N
logger.exception("create segment to index failed")
segment.enabled = False
segment.disabled_at = naive_utc_now()
segment.status = "error"
segment.status = SegmentStatus.ERROR
segment.error = str(e)
session.commit()
finally:

View File

@ -12,6 +12,7 @@ from core.rag.extractor.notion_extractor import NotionExtractor
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset, Document, DocumentSegment
from models.enums import IndexingStatus
from services.datasource_provider_service import DatasourceProviderService
logger = logging.getLogger(__name__)
@ -37,7 +38,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
logger.info(click.style(f"Document not found: {document_id}", fg="red"))
return
if document.indexing_status == "parsing":
if document.indexing_status == IndexingStatus.PARSING:
logger.info(click.style(f"Document {document_id} is already being processed, skipping", fg="yellow"))
return
@ -88,7 +89,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
with session_factory.create_session() as session, session.begin():
document = session.query(Document).filter_by(id=document_id).first()
if document:
document.indexing_status = "error"
document.indexing_status = IndexingStatus.ERROR
document.error = "Datasource credential not found. Please reconnect your Notion workspace."
document.stopped_at = naive_utc_now()
return
@ -128,7 +129,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
data_source_info["last_edited_time"] = last_edited_time
document.data_source_info = json.dumps(data_source_info)
document.indexing_status = "parsing"
document.indexing_status = IndexingStatus.PARSING
document.processing_started_at = naive_utc_now()
segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.document_id == document_id)
@ -151,6 +152,6 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
with session_factory.create_session() as session, session.begin():
document = session.query(Document).filter_by(id=document_id).first()
if document:
document.indexing_status = "error"
document.indexing_status = IndexingStatus.ERROR
document.error = str(e)
document.stopped_at = naive_utc_now()

View File

@ -14,6 +14,7 @@ from core.rag.pipeline.queue import TenantIsolatedTaskQueue
from enums.cloud_plan import CloudPlan
from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset, Document
from models.enums import IndexingStatus
from services.feature_service import FeatureService
from tasks.generate_summary_index_task import generate_summary_index_task
@ -81,7 +82,7 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
)
if document:
document.indexing_status = "error"
document.indexing_status = IndexingStatus.ERROR
document.error = str(e)
document.stopped_at = naive_utc_now()
session.add(document)
@ -96,7 +97,7 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
for document in documents:
if document:
document.indexing_status = "parsing"
document.indexing_status = IndexingStatus.PARSING
document.processing_started_at = naive_utc_now()
session.add(document)
# Transaction committed and closed
@ -148,7 +149,7 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
document.need_summary,
)
if (
document.indexing_status == "completed"
document.indexing_status == IndexingStatus.COMPLETED
and document.doc_form != "qa_model"
and document.need_summary is True
):

View File

@ -10,6 +10,7 @@ from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset, Document, DocumentSegment
from models.enums import IndexingStatus
logger = logging.getLogger(__name__)
@ -33,7 +34,7 @@ def document_indexing_update_task(dataset_id: str, document_id: str):
logger.info(click.style(f"Document not found: {document_id}", fg="red"))
return
document.indexing_status = "parsing"
document.indexing_status = IndexingStatus.PARSING
document.processing_started_at = naive_utc_now()
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()

View File

@ -15,6 +15,7 @@ from core.rag.pipeline.queue import TenantIsolatedTaskQueue
from enums.cloud_plan import CloudPlan
from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset, Document, DocumentSegment
from models.enums import IndexingStatus
from services.feature_service import FeatureService
logger = logging.getLogger(__name__)
@ -112,7 +113,7 @@ def _duplicate_document_indexing_task(dataset_id: str, document_ids: Sequence[st
)
for document in documents:
if document:
document.indexing_status = "error"
document.indexing_status = IndexingStatus.ERROR
document.error = str(e)
document.stopped_at = naive_utc_now()
session.add(document)
@ -146,7 +147,7 @@ def _duplicate_document_indexing_task(dataset_id: str, document_ids: Sequence[st
session.execute(segment_delete_stmt)
session.commit()
document.indexing_status = "parsing"
document.indexing_status = IndexingStatus.PARSING
document.processing_started_at = naive_utc_now()
session.add(document)
session.commit()

View File

@ -12,6 +12,7 @@ from core.rag.models.document import AttachmentDocument, ChildDocument, Document
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from models.dataset import DocumentSegment
from models.enums import IndexingStatus, SegmentStatus
logger = logging.getLogger(__name__)
@ -33,7 +34,7 @@ def enable_segment_to_index_task(segment_id: str):
logger.info(click.style(f"Segment not found: {segment_id}", fg="red"))
return
if segment.status != "completed":
if segment.status != SegmentStatus.COMPLETED:
logger.info(click.style(f"Segment is not completed, enable is not allowed: {segment_id}", fg="red"))
return
@ -65,7 +66,7 @@ def enable_segment_to_index_task(segment_id: str):
if (
not dataset_document.enabled
or dataset_document.archived
or dataset_document.indexing_status != "completed"
or dataset_document.indexing_status != IndexingStatus.COMPLETED
):
logger.info(click.style(f"Segment {segment.id} document status is invalid, pass.", fg="cyan"))
return
@ -123,7 +124,7 @@ def enable_segment_to_index_task(segment_id: str):
logger.exception("enable segment to index failed")
segment.enabled = False
segment.disabled_at = naive_utc_now()
segment.status = "error"
segment.status = SegmentStatus.ERROR
segment.error = str(e)
session.commit()
finally:

View File

@ -12,6 +12,7 @@ from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from models import Account, Tenant
from models.dataset import Dataset, Document, DocumentSegment
from models.enums import IndexingStatus
from services.feature_service import FeatureService
from services.rag_pipeline.rag_pipeline import RagPipelineService
@ -63,7 +64,7 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str], user_
.first()
)
if document:
document.indexing_status = "error"
document.indexing_status = IndexingStatus.ERROR
document.error = str(e)
document.stopped_at = naive_utc_now()
session.add(document)
@ -95,7 +96,7 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str], user_
session.execute(segment_delete_stmt)
session.commit()
document.indexing_status = "parsing"
document.indexing_status = IndexingStatus.PARSING
document.processing_started_at = naive_utc_now()
session.add(document)
session.commit()
@ -108,7 +109,7 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str], user_
indexing_runner.run([document])
redis_client.delete(retry_indexing_cache_key)
except Exception as ex:
document.indexing_status = "error"
document.indexing_status = IndexingStatus.ERROR
document.error = str(ex)
document.stopped_at = naive_utc_now()
session.add(document)

View File

@ -11,6 +11,7 @@ from core.rag.index_processor.index_processor_factory import IndexProcessorFacto
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset, Document, DocumentSegment
from models.enums import IndexingStatus
from services.feature_service import FeatureService
logger = logging.getLogger(__name__)
@ -48,7 +49,7 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str):
session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
)
if document:
document.indexing_status = "error"
document.indexing_status = IndexingStatus.ERROR
document.error = str(e)
document.stopped_at = naive_utc_now()
session.add(document)
@ -76,7 +77,7 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str):
session.execute(segment_delete_stmt)
session.commit()
document.indexing_status = "parsing"
document.indexing_status = IndexingStatus.PARSING
document.processing_started_at = naive_utc_now()
session.add(document)
session.commit()
@ -85,7 +86,7 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str):
indexing_runner.run([document])
redis_client.delete(sync_indexing_cache_key)
except Exception as ex:
document.indexing_status = "error"
document.indexing_status = IndexingStatus.ERROR
document.error = str(ex)
document.stopped_at = naive_utc_now()
session.add(document)

View File

@ -7,6 +7,7 @@ from faker import Faker
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.workflow.nodes.knowledge_retrieval.retrieval import KnowledgeRetrievalRequest
from models.dataset import Dataset, Document
from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus
from services.account_service import AccountService, TenantService
from tests.test_containers_integration_tests.helpers import generate_valid_password
@ -35,7 +36,7 @@ class TestGetAvailableDatasetsIntegration:
name=fake.company(),
description=fake.text(max_nb_chars=100),
provider="dify",
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=account.id,
indexing_technique="high_quality",
)
@ -49,14 +50,14 @@ class TestGetAvailableDatasetsIntegration:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=i,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch=str(uuid.uuid4()), # Required field
name=f"Document {i}",
created_from="web",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text_model",
doc_language="en",
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=False,
)
@ -94,7 +95,7 @@ class TestGetAvailableDatasetsIntegration:
tenant_id=tenant.id,
name=fake.company(),
provider="dify",
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=account.id,
)
db_session_with_containers.add(dataset)
@ -106,13 +107,13 @@ class TestGetAvailableDatasetsIntegration:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=i,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch=str(uuid.uuid4()), # Required field
created_from="web",
created_from=DocumentCreatedFrom.WEB,
name=f"Archived Document {i}",
created_by=account.id,
doc_form="text_model",
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=True, # Archived
)
@ -147,7 +148,7 @@ class TestGetAvailableDatasetsIntegration:
tenant_id=tenant.id,
name=fake.company(),
provider="dify",
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=account.id,
)
db_session_with_containers.add(dataset)
@ -159,13 +160,13 @@ class TestGetAvailableDatasetsIntegration:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=i,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch=str(uuid.uuid4()), # Required field
created_from="web",
created_from=DocumentCreatedFrom.WEB,
name=f"Disabled Document {i}",
created_by=account.id,
doc_form="text_model",
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
enabled=False, # Disabled
archived=False,
)
@ -200,21 +201,21 @@ class TestGetAvailableDatasetsIntegration:
tenant_id=tenant.id,
name=fake.company(),
provider="dify",
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=account.id,
)
db_session_with_containers.add(dataset)
# Create documents with non-completed status
for i, status in enumerate(["indexing", "parsing", "splitting"]):
for i, status in enumerate([IndexingStatus.INDEXING, IndexingStatus.PARSING, IndexingStatus.SPLITTING]):
document = Document(
id=str(uuid.uuid4()),
tenant_id=tenant.id,
dataset_id=dataset.id,
position=i,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch=str(uuid.uuid4()), # Required field
created_from="web",
created_from=DocumentCreatedFrom.WEB,
name=f"Document {status}",
created_by=account.id,
doc_form="text_model",
@ -263,7 +264,7 @@ class TestGetAvailableDatasetsIntegration:
tenant_id=tenant.id,
name=fake.company(),
provider="external", # External provider
data_source_type="external",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=account.id,
)
db_session_with_containers.add(dataset)
@ -307,7 +308,7 @@ class TestGetAvailableDatasetsIntegration:
tenant_id=tenant1.id,
name="Tenant 1 Dataset",
provider="dify",
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=account1.id,
)
db_session_with_containers.add(dataset1)
@ -318,7 +319,7 @@ class TestGetAvailableDatasetsIntegration:
tenant_id=tenant2.id,
name="Tenant 2 Dataset",
provider="dify",
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=account2.id,
)
db_session_with_containers.add(dataset2)
@ -330,13 +331,13 @@ class TestGetAvailableDatasetsIntegration:
tenant_id=dataset.tenant_id,
dataset_id=dataset.id,
position=0,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch=str(uuid.uuid4()), # Required field
created_from="web",
created_from=DocumentCreatedFrom.WEB,
name=f"Document for {dataset.name}",
created_by=account.id,
doc_form="text_model",
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=False,
)
@ -398,7 +399,7 @@ class TestGetAvailableDatasetsIntegration:
tenant_id=tenant.id,
name=f"Dataset {i}",
provider="dify",
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=account.id,
)
db_session_with_containers.add(dataset)
@ -410,13 +411,13 @@ class TestGetAvailableDatasetsIntegration:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=0,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch=str(uuid.uuid4()), # Required field
created_from="web",
created_from=DocumentCreatedFrom.WEB,
name=f"Document {i}",
created_by=account.id,
doc_form="text_model",
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=False,
)
@ -456,7 +457,7 @@ class TestKnowledgeRetrievalIntegration:
tenant_id=tenant.id,
name=fake.company(),
provider="dify",
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=account.id,
indexing_technique="high_quality",
)
@ -467,12 +468,12 @@ class TestKnowledgeRetrievalIntegration:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=0,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch=str(uuid.uuid4()), # Required field
created_from="web",
created_from=DocumentCreatedFrom.WEB,
name=fake.sentence(),
created_by=account.id,
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=False,
doc_form="text_model",
@ -525,7 +526,7 @@ class TestKnowledgeRetrievalIntegration:
tenant_id=tenant.id,
name=fake.company(),
provider="dify",
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=account.id,
)
db_session_with_containers.add(dataset)
@ -572,7 +573,7 @@ class TestKnowledgeRetrievalIntegration:
tenant_id=tenant.id,
name=fake.company(),
provider="dify",
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=account.id,
)
db_session_with_containers.add(dataset)

View File

@ -12,6 +12,7 @@ import pytest
from sqlalchemy.orm import Session
from models.dataset import Dataset, Document, DocumentSegment
from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus
class TestDatasetDocumentProperties:
@ -29,7 +30,7 @@ class TestDatasetDocumentProperties:
created_by = str(uuid4())
dataset = Dataset(
tenant_id=tenant_id, name="Test Dataset", data_source_type="upload_file", created_by=created_by
tenant_id=tenant_id, name="Test Dataset", data_source_type=DataSourceType.UPLOAD_FILE, created_by=created_by
)
db_session_with_containers.add(dataset)
db_session_with_containers.flush()
@ -39,10 +40,10 @@ class TestDatasetDocumentProperties:
tenant_id=tenant_id,
dataset_id=dataset.id,
position=i + 1,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch="batch_001",
name=f"doc_{i}.pdf",
created_from="web",
created_from=DocumentCreatedFrom.WEB,
created_by=created_by,
)
db_session_with_containers.add(doc)
@ -56,7 +57,7 @@ class TestDatasetDocumentProperties:
created_by = str(uuid4())
dataset = Dataset(
tenant_id=tenant_id, name="Test Dataset", data_source_type="upload_file", created_by=created_by
tenant_id=tenant_id, name="Test Dataset", data_source_type=DataSourceType.UPLOAD_FILE, created_by=created_by
)
db_session_with_containers.add(dataset)
db_session_with_containers.flush()
@ -65,12 +66,12 @@ class TestDatasetDocumentProperties:
tenant_id=tenant_id,
dataset_id=dataset.id,
position=1,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch="batch_001",
name="available.pdf",
created_from="web",
created_from=DocumentCreatedFrom.WEB,
created_by=created_by,
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=False,
)
@ -78,12 +79,12 @@ class TestDatasetDocumentProperties:
tenant_id=tenant_id,
dataset_id=dataset.id,
position=2,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch="batch_001",
name="pending.pdf",
created_from="web",
created_from=DocumentCreatedFrom.WEB,
created_by=created_by,
indexing_status="waiting",
indexing_status=IndexingStatus.WAITING,
enabled=True,
archived=False,
)
@ -91,12 +92,12 @@ class TestDatasetDocumentProperties:
tenant_id=tenant_id,
dataset_id=dataset.id,
position=3,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch="batch_001",
name="disabled.pdf",
created_from="web",
created_from=DocumentCreatedFrom.WEB,
created_by=created_by,
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
enabled=False,
archived=False,
)
@ -111,7 +112,7 @@ class TestDatasetDocumentProperties:
created_by = str(uuid4())
dataset = Dataset(
tenant_id=tenant_id, name="Test Dataset", data_source_type="upload_file", created_by=created_by
tenant_id=tenant_id, name="Test Dataset", data_source_type=DataSourceType.UPLOAD_FILE, created_by=created_by
)
db_session_with_containers.add(dataset)
db_session_with_containers.flush()
@ -121,10 +122,10 @@ class TestDatasetDocumentProperties:
tenant_id=tenant_id,
dataset_id=dataset.id,
position=i + 1,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch="batch_001",
name=f"doc_{i}.pdf",
created_from="web",
created_from=DocumentCreatedFrom.WEB,
created_by=created_by,
word_count=wc,
)
@ -139,7 +140,7 @@ class TestDatasetDocumentProperties:
created_by = str(uuid4())
dataset = Dataset(
tenant_id=tenant_id, name="Test Dataset", data_source_type="upload_file", created_by=created_by
tenant_id=tenant_id, name="Test Dataset", data_source_type=DataSourceType.UPLOAD_FILE, created_by=created_by
)
db_session_with_containers.add(dataset)
db_session_with_containers.flush()
@ -148,10 +149,10 @@ class TestDatasetDocumentProperties:
tenant_id=tenant_id,
dataset_id=dataset.id,
position=1,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch="batch_001",
name="doc.pdf",
created_from="web",
created_from=DocumentCreatedFrom.WEB,
created_by=created_by,
)
db_session_with_containers.add(doc)
@ -166,7 +167,7 @@ class TestDatasetDocumentProperties:
content=f"segment {i}",
word_count=100,
tokens=50,
status="completed",
status=SegmentStatus.COMPLETED,
enabled=True,
created_by=created_by,
)
@ -180,7 +181,7 @@ class TestDatasetDocumentProperties:
content="waiting segment",
word_count=100,
tokens=50,
status="waiting",
status=SegmentStatus.WAITING,
enabled=True,
created_by=created_by,
)
@ -195,7 +196,7 @@ class TestDatasetDocumentProperties:
created_by = str(uuid4())
dataset = Dataset(
tenant_id=tenant_id, name="Test Dataset", data_source_type="upload_file", created_by=created_by
tenant_id=tenant_id, name="Test Dataset", data_source_type=DataSourceType.UPLOAD_FILE, created_by=created_by
)
db_session_with_containers.add(dataset)
db_session_with_containers.flush()
@ -204,10 +205,10 @@ class TestDatasetDocumentProperties:
tenant_id=tenant_id,
dataset_id=dataset.id,
position=1,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch="batch_001",
name="doc.pdf",
created_from="web",
created_from=DocumentCreatedFrom.WEB,
created_by=created_by,
)
db_session_with_containers.add(doc)
@ -235,7 +236,7 @@ class TestDatasetDocumentProperties:
created_by = str(uuid4())
dataset = Dataset(
tenant_id=tenant_id, name="Test Dataset", data_source_type="upload_file", created_by=created_by
tenant_id=tenant_id, name="Test Dataset", data_source_type=DataSourceType.UPLOAD_FILE, created_by=created_by
)
db_session_with_containers.add(dataset)
db_session_with_containers.flush()
@ -244,10 +245,10 @@ class TestDatasetDocumentProperties:
tenant_id=tenant_id,
dataset_id=dataset.id,
position=1,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch="batch_001",
name="doc.pdf",
created_from="web",
created_from=DocumentCreatedFrom.WEB,
created_by=created_by,
)
db_session_with_containers.add(doc)
@ -288,7 +289,7 @@ class TestDocumentSegmentNavigationProperties:
dataset = Dataset(
tenant_id=tenant_id,
name="Test Dataset",
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=created_by,
)
db_session_with_containers.add(dataset)
@ -298,10 +299,10 @@ class TestDocumentSegmentNavigationProperties:
tenant_id=tenant_id,
dataset_id=dataset.id,
position=1,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch="batch_001",
name="test.pdf",
created_from="web",
created_from=DocumentCreatedFrom.WEB,
created_by=created_by,
)
db_session_with_containers.add(document)
@ -335,7 +336,7 @@ class TestDocumentSegmentNavigationProperties:
dataset = Dataset(
tenant_id=tenant_id,
name="Test Dataset",
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=created_by,
)
db_session_with_containers.add(dataset)
@ -345,10 +346,10 @@ class TestDocumentSegmentNavigationProperties:
tenant_id=tenant_id,
dataset_id=dataset.id,
position=1,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch="batch_001",
name="test.pdf",
created_from="web",
created_from=DocumentCreatedFrom.WEB,
created_by=created_by,
)
db_session_with_containers.add(document)
@ -382,7 +383,7 @@ class TestDocumentSegmentNavigationProperties:
dataset = Dataset(
tenant_id=tenant_id,
name="Test Dataset",
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=created_by,
)
db_session_with_containers.add(dataset)
@ -392,10 +393,10 @@ class TestDocumentSegmentNavigationProperties:
tenant_id=tenant_id,
dataset_id=dataset.id,
position=1,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch="batch_001",
name="test.pdf",
created_from="web",
created_from=DocumentCreatedFrom.WEB,
created_by=created_by,
)
db_session_with_containers.add(document)
@ -439,7 +440,7 @@ class TestDocumentSegmentNavigationProperties:
dataset = Dataset(
tenant_id=tenant_id,
name="Test Dataset",
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=created_by,
)
db_session_with_containers.add(dataset)
@ -449,10 +450,10 @@ class TestDocumentSegmentNavigationProperties:
tenant_id=tenant_id,
dataset_id=dataset.id,
position=1,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch="batch_001",
name="test.pdf",
created_from="web",
created_from=DocumentCreatedFrom.WEB,
created_by=created_by,
)
db_session_with_containers.add(document)

View File

@ -12,6 +12,7 @@ import pytest
from sqlalchemy.orm import Session
from models.dataset import DatasetCollectionBinding
from models.enums import CollectionBindingType
from services.dataset_service import DatasetCollectionBindingService
@ -32,7 +33,7 @@ class DatasetCollectionBindingTestDataFactory:
provider_name: str = "openai",
model_name: str = "text-embedding-ada-002",
collection_name: str = "collection-abc",
collection_type: str = "dataset",
collection_type: str = CollectionBindingType.DATASET,
) -> DatasetCollectionBinding:
"""
Create a DatasetCollectionBinding with specified attributes.
@ -41,7 +42,7 @@ class DatasetCollectionBindingTestDataFactory:
provider_name: Name of the embedding model provider (e.g., "openai", "cohere")
model_name: Name of the embedding model (e.g., "text-embedding-ada-002")
collection_name: Name of the vector database collection
collection_type: Type of collection (default: "dataset")
collection_type: Type of collection (default: CollectionBindingType.DATASET)
Returns:
DatasetCollectionBinding instance
@ -76,7 +77,7 @@ class TestDatasetCollectionBindingServiceGetBinding:
# Arrange
provider_name = "openai"
model_name = "text-embedding-ada-002"
collection_type = "dataset"
collection_type = CollectionBindingType.DATASET
existing_binding = DatasetCollectionBindingTestDataFactory.create_collection_binding(
db_session_with_containers,
provider_name=provider_name,
@ -104,7 +105,7 @@ class TestDatasetCollectionBindingServiceGetBinding:
# Arrange
provider_name = f"provider-{uuid4()}"
model_name = f"model-{uuid4()}"
collection_type = "dataset"
collection_type = CollectionBindingType.DATASET
# Act
result = DatasetCollectionBindingService.get_dataset_collection_binding(
@ -145,7 +146,7 @@ class TestDatasetCollectionBindingServiceGetBinding:
result = DatasetCollectionBindingService.get_dataset_collection_binding(provider_name, model_name)
# Assert
assert result.type == "dataset"
assert result.type == CollectionBindingType.DATASET
assert result.provider_name == provider_name
assert result.model_name == model_name
@ -186,18 +187,20 @@ class TestDatasetCollectionBindingServiceGetBindingByIdAndType:
provider_name="openai",
model_name="text-embedding-ada-002",
collection_name="test-collection",
collection_type="dataset",
collection_type=CollectionBindingType.DATASET,
)
# Act
result = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(binding.id, "dataset")
result = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
binding.id, CollectionBindingType.DATASET
)
# Assert
assert result.id == binding.id
assert result.provider_name == "openai"
assert result.model_name == "text-embedding-ada-002"
assert result.collection_name == "test-collection"
assert result.type == "dataset"
assert result.type == CollectionBindingType.DATASET
def test_get_dataset_collection_binding_by_id_and_type_not_found_error(self, db_session_with_containers: Session):
"""Test error handling when collection binding is not found by ID and type."""
@ -206,7 +209,9 @@ class TestDatasetCollectionBindingServiceGetBindingByIdAndType:
# Act & Assert
with pytest.raises(ValueError, match="Dataset collection binding not found"):
DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(non_existent_id, "dataset")
DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
non_existent_id, CollectionBindingType.DATASET
)
def test_get_dataset_collection_binding_by_id_and_type_different_collection_type(
self, db_session_with_containers: Session
@ -240,7 +245,7 @@ class TestDatasetCollectionBindingServiceGetBindingByIdAndType:
provider_name="openai",
model_name="text-embedding-ada-002",
collection_name="test-collection",
collection_type="dataset",
collection_type=CollectionBindingType.DATASET,
)
# Act
@ -248,7 +253,7 @@ class TestDatasetCollectionBindingServiceGetBindingByIdAndType:
# Assert
assert result.id == binding.id
assert result.type == "dataset"
assert result.type == CollectionBindingType.DATASET
def test_get_dataset_collection_binding_by_id_and_type_wrong_type_error(self, db_session_with_containers: Session):
"""Test error when binding exists but with wrong collection type."""
@ -258,7 +263,7 @@ class TestDatasetCollectionBindingServiceGetBindingByIdAndType:
provider_name="openai",
model_name="text-embedding-ada-002",
collection_name="test-collection",
collection_type="dataset",
collection_type=CollectionBindingType.DATASET,
)
# Act & Assert

View File

@ -15,6 +15,7 @@ from werkzeug.exceptions import NotFound
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import AppDatasetJoin, Dataset, DatasetPermissionEnum
from models.enums import DataSourceType
from models.model import App
from services.dataset_service import DatasetService
from services.errors.account import NoPermissionError
@ -72,7 +73,7 @@ class DatasetUpdateDeleteTestDataFactory:
tenant_id=tenant_id,
name=name,
description="Test description",
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
indexing_technique="high_quality",
created_by=created_by,
permission=permission,

View File

@ -15,7 +15,7 @@ import pytest
from models import Account
from models.dataset import Dataset, Document
from models.enums import CreatorUserRole
from models.enums import CreatorUserRole, DataSourceType, DocumentCreatedFrom, IndexingStatus
from models.model import UploadFile
from services.dataset_service import DocumentService
from services.errors.document import DocumentIndexingError
@ -88,7 +88,7 @@ class DocumentStatusTestDataFactory:
data_source_info=json.dumps(data_source_info or {}),
batch=f"batch-{uuid4()}",
name=name,
created_from="web",
created_from=DocumentCreatedFrom.WEB,
created_by=created_by,
doc_form="text_model",
)
@ -100,7 +100,7 @@ class DocumentStatusTestDataFactory:
document.paused_by = paused_by
document.paused_at = paused_at
document.doc_metadata = doc_metadata or {}
if indexing_status == "completed" and "completed_at" not in kwargs:
if indexing_status == IndexingStatus.COMPLETED and "completed_at" not in kwargs:
document.completed_at = FIXED_TIME
for key, value in kwargs.items():
@ -139,7 +139,7 @@ class DocumentStatusTestDataFactory:
dataset = Dataset(
tenant_id=tenant_id,
name=name,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=created_by,
)
dataset.id = dataset_id
@ -291,7 +291,7 @@ class TestDocumentServicePauseDocument:
db_session_with_containers,
dataset_id=dataset.id,
tenant_id=dataset.tenant_id,
indexing_status="waiting",
indexing_status=IndexingStatus.WAITING,
is_paused=False,
)
@ -326,7 +326,7 @@ class TestDocumentServicePauseDocument:
db_session_with_containers,
dataset_id=dataset.id,
tenant_id=dataset.tenant_id,
indexing_status="indexing",
indexing_status=IndexingStatus.INDEXING,
is_paused=False,
)
@ -354,7 +354,7 @@ class TestDocumentServicePauseDocument:
db_session_with_containers,
dataset_id=dataset.id,
tenant_id=dataset.tenant_id,
indexing_status="parsing",
indexing_status=IndexingStatus.PARSING,
is_paused=False,
)
@ -383,7 +383,7 @@ class TestDocumentServicePauseDocument:
db_session_with_containers,
dataset_id=dataset.id,
tenant_id=dataset.tenant_id,
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
is_paused=False,
)
@ -412,7 +412,7 @@ class TestDocumentServicePauseDocument:
db_session_with_containers,
dataset_id=dataset.id,
tenant_id=dataset.tenant_id,
indexing_status="error",
indexing_status=IndexingStatus.ERROR,
is_paused=False,
)
@ -487,7 +487,7 @@ class TestDocumentServiceRecoverDocument:
db_session_with_containers,
dataset_id=dataset.id,
tenant_id=dataset.tenant_id,
indexing_status="indexing",
indexing_status=IndexingStatus.INDEXING,
is_paused=True,
paused_by=str(uuid4()),
paused_at=paused_time,
@ -526,7 +526,7 @@ class TestDocumentServiceRecoverDocument:
db_session_with_containers,
dataset_id=dataset.id,
tenant_id=dataset.tenant_id,
indexing_status="indexing",
indexing_status=IndexingStatus.INDEXING,
is_paused=False,
)
@ -609,7 +609,7 @@ class TestDocumentServiceRetryDocument:
dataset_id=dataset.id,
tenant_id=dataset.tenant_id,
document_id=str(uuid4()),
indexing_status="error",
indexing_status=IndexingStatus.ERROR,
)
mock_document_service_dependencies["redis_client"].get.return_value = None
@ -619,7 +619,7 @@ class TestDocumentServiceRetryDocument:
# Assert
db_session_with_containers.refresh(document)
assert document.indexing_status == "waiting"
assert document.indexing_status == IndexingStatus.WAITING
expected_cache_key = f"document_{document.id}_is_retried"
mock_document_service_dependencies["redis_client"].setex.assert_called_once_with(expected_cache_key, 600, 1)
@ -646,14 +646,14 @@ class TestDocumentServiceRetryDocument:
dataset_id=dataset.id,
tenant_id=dataset.tenant_id,
document_id=str(uuid4()),
indexing_status="error",
indexing_status=IndexingStatus.ERROR,
)
document2 = DocumentStatusTestDataFactory.create_document(
db_session_with_containers,
dataset_id=dataset.id,
tenant_id=dataset.tenant_id,
document_id=str(uuid4()),
indexing_status="error",
indexing_status=IndexingStatus.ERROR,
position=2,
)
@ -665,8 +665,8 @@ class TestDocumentServiceRetryDocument:
# Assert
db_session_with_containers.refresh(document1)
db_session_with_containers.refresh(document2)
assert document1.indexing_status == "waiting"
assert document2.indexing_status == "waiting"
assert document1.indexing_status == IndexingStatus.WAITING
assert document2.indexing_status == IndexingStatus.WAITING
mock_document_service_dependencies["retry_task"].delay.assert_called_once_with(
dataset.id, [document1.id, document2.id], mock_document_service_dependencies["user_id"]
@ -693,7 +693,7 @@ class TestDocumentServiceRetryDocument:
dataset_id=dataset.id,
tenant_id=dataset.tenant_id,
document_id=str(uuid4()),
indexing_status="error",
indexing_status=IndexingStatus.ERROR,
)
mock_document_service_dependencies["redis_client"].get.return_value = "1"
@ -703,7 +703,7 @@ class TestDocumentServiceRetryDocument:
DocumentService.retry_document(dataset.id, [document])
db_session_with_containers.refresh(document)
assert document.indexing_status == "error"
assert document.indexing_status == IndexingStatus.ERROR
def test_retry_document_missing_current_user_error(
self, db_session_with_containers, mock_document_service_dependencies
@ -726,7 +726,7 @@ class TestDocumentServiceRetryDocument:
dataset_id=dataset.id,
tenant_id=dataset.tenant_id,
document_id=str(uuid4()),
indexing_status="error",
indexing_status=IndexingStatus.ERROR,
)
mock_document_service_dependencies["redis_client"].get.return_value = None
@ -816,7 +816,7 @@ class TestDocumentServiceBatchUpdateDocumentStatus:
tenant_id=dataset.tenant_id,
document_id=str(uuid4()),
enabled=False,
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
)
document2 = DocumentStatusTestDataFactory.create_document(
db_session_with_containers,
@ -824,7 +824,7 @@ class TestDocumentServiceBatchUpdateDocumentStatus:
tenant_id=dataset.tenant_id,
document_id=str(uuid4()),
enabled=False,
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
position=2,
)
document_ids = [document1.id, document2.id]
@ -866,7 +866,7 @@ class TestDocumentServiceBatchUpdateDocumentStatus:
tenant_id=dataset.tenant_id,
document_id=str(uuid4()),
enabled=True,
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
completed_at=FIXED_TIME,
)
document_ids = [document.id]
@ -909,7 +909,7 @@ class TestDocumentServiceBatchUpdateDocumentStatus:
document_id=str(uuid4()),
archived=False,
enabled=True,
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
)
document_ids = [document.id]
@ -951,7 +951,7 @@ class TestDocumentServiceBatchUpdateDocumentStatus:
document_id=str(uuid4()),
archived=True,
enabled=True,
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
)
document_ids = [document.id]
@ -1015,7 +1015,7 @@ class TestDocumentServiceBatchUpdateDocumentStatus:
dataset_id=dataset.id,
tenant_id=dataset.tenant_id,
document_id=str(uuid4()),
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
)
document_ids = [document.id]
@ -1098,7 +1098,7 @@ class TestDocumentServiceRenameDocument:
document_id=document_id,
dataset_id=dataset.id,
tenant_id=tenant_id,
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
)
# Act
@ -1139,7 +1139,7 @@ class TestDocumentServiceRenameDocument:
dataset_id=dataset.id,
tenant_id=tenant_id,
doc_metadata={"existing_key": "existing_value"},
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
)
# Act
@ -1187,7 +1187,7 @@ class TestDocumentServiceRenameDocument:
dataset_id=dataset.id,
tenant_id=tenant_id,
data_source_info={"upload_file_id": upload_file.id},
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
)
# Act
@ -1277,7 +1277,7 @@ class TestDocumentServiceRenameDocument:
document_id=document_id,
dataset_id=dataset.id,
tenant_id=str(uuid4()),
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
)
# Act & Assert

View File

@ -16,6 +16,7 @@ from models.dataset import (
DatasetPermission,
DatasetPermissionEnum,
)
from models.enums import DataSourceType
from services.dataset_service import DatasetPermissionService, DatasetService
from services.errors.account import NoPermissionError
@ -67,7 +68,7 @@ class DatasetPermissionTestDataFactory:
tenant_id=tenant_id,
name=name,
description="desc",
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
indexing_technique="high_quality",
created_by=created_by,
permission=permission,

View File

@ -15,6 +15,7 @@ from core.rag.retrieval.retrieval_methods import RetrievalMethod
from dify_graph.model_runtime.entities.model_entities import ModelType
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, DatasetPermissionEnum, Document, ExternalKnowledgeBindings, Pipeline
from models.enums import DatasetRuntimeMode, DataSourceType, DocumentCreatedFrom, IndexingStatus
from services.dataset_service import DatasetService
from services.entities.knowledge_entities.knowledge_entities import RerankingModel, RetrievalModel
from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo, RagPipelineDatasetCreateEntity
@ -74,7 +75,7 @@ class DatasetServiceIntegrationDataFactory:
tenant_id=tenant_id,
name=name,
description=description,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
indexing_technique=indexing_technique,
created_by=created_by,
provider=provider,
@ -98,13 +99,13 @@ class DatasetServiceIntegrationDataFactory:
tenant_id=dataset.tenant_id,
dataset_id=dataset.id,
position=1,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
data_source_info='{"upload_file_id": "upload-file-id"}',
batch=str(uuid4()),
name=name,
created_from="web",
created_from=DocumentCreatedFrom.WEB,
created_by=created_by,
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
doc_form="text_model",
)
db_session_with_containers.add(document)
@ -437,7 +438,7 @@ class TestDatasetServiceCreateRagPipelineDataset:
created_pipeline = db_session_with_containers.get(Pipeline, result.pipeline_id)
assert created_dataset is not None
assert created_dataset.name == entity.name
assert created_dataset.runtime_mode == "rag_pipeline"
assert created_dataset.runtime_mode == DatasetRuntimeMode.RAG_PIPELINE
assert created_dataset.created_by == account.id
assert created_dataset.permission == DatasetPermissionEnum.ONLY_ME
assert created_pipeline is not None

View File

@ -14,6 +14,7 @@ import pytest
from sqlalchemy.orm import Session
from models.dataset import Dataset, Document
from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus
from services.dataset_service import DocumentService
from services.errors.document import DocumentIndexingError
@ -42,7 +43,7 @@ class DocumentBatchUpdateIntegrationDataFactory:
dataset = Dataset(
tenant_id=tenant_id or str(uuid4()),
name=name,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=created_by or str(uuid4()),
)
if dataset_id:
@ -72,11 +73,11 @@ class DocumentBatchUpdateIntegrationDataFactory:
tenant_id=dataset.tenant_id,
dataset_id=dataset.id,
position=position,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
data_source_info=json.dumps({"upload_file_id": str(uuid4())}),
batch=f"batch-{uuid4()}",
name=name,
created_from="web",
created_from=DocumentCreatedFrom.WEB,
created_by=created_by or str(uuid4()),
doc_form="text_model",
)
@ -85,7 +86,9 @@ class DocumentBatchUpdateIntegrationDataFactory:
document.archived = archived
document.indexing_status = indexing_status
document.completed_at = (
completed_at if completed_at is not None else (FIXED_TIME if indexing_status == "completed" else None)
completed_at
if completed_at is not None
else (FIXED_TIME if indexing_status == IndexingStatus.COMPLETED else None)
)
for key, value in kwargs.items():
@ -243,7 +246,7 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
dataset=dataset,
document_ids=document_ids,
enabled=True,
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
)
# Act
@ -277,7 +280,7 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
db_session_with_containers,
dataset=dataset,
enabled=False,
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
completed_at=FIXED_TIME,
)
@ -306,7 +309,7 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
db_session_with_containers,
dataset=dataset,
enabled=True,
indexing_status="indexing",
indexing_status=IndexingStatus.INDEXING,
completed_at=None,
)

View File

@ -5,6 +5,7 @@ from uuid import uuid4
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document
from models.enums import DataSourceType, DocumentCreatedFrom
from services.dataset_service import DatasetService
@ -58,7 +59,7 @@ class DatasetDeleteIntegrationDataFactory:
dataset = Dataset(
tenant_id=tenant_id,
name=f"dataset-{uuid4()}",
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
indexing_technique=indexing_technique,
index_struct=index_struct,
created_by=created_by,
@ -84,10 +85,10 @@ class DatasetDeleteIntegrationDataFactory:
tenant_id=tenant_id,
dataset_id=dataset_id,
position=1,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch=f"batch-{uuid4()}",
name="Document",
created_from="upload_file",
created_from=DocumentCreatedFrom.WEB,
created_by=created_by,
doc_form=doc_form,
)

View File

@ -14,6 +14,7 @@ from sqlalchemy.orm import Session
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, DatasetPermissionEnum, Document, DocumentSegment
from models.enums import DataSourceType, DocumentCreatedFrom
from services.dataset_service import SegmentService
@ -62,7 +63,7 @@ class SegmentServiceTestDataFactory:
tenant_id=tenant_id,
name=f"Test Dataset {uuid4()}",
description="Test description",
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
indexing_technique="high_quality",
created_by=created_by,
permission=DatasetPermissionEnum.ONLY_ME,
@ -82,10 +83,10 @@ class SegmentServiceTestDataFactory:
tenant_id=tenant_id,
dataset_id=dataset_id,
position=1,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch=f"batch-{uuid4()}",
name=f"test-doc-{uuid4()}.txt",
created_from="api",
created_from=DocumentCreatedFrom.API,
created_by=created_by,
)
db_session_with_containers.add(document)

Some files were not shown because too many files have changed in this diff.