Mirror of https://github.com/langgenius/dify.git

Merge branch 'main' into jzh

Commit 883eb498c0
@@ -63,7 +63,8 @@ pnpm analyze-component <path> --review

### File Naming

-- Test files: `ComponentName.spec.tsx` (same directory as component)
+- Test files: `ComponentName.spec.tsx` inside a same-level `__tests__/` directory
+- Placement rule: Component, hook, and utility tests must live in a sibling `__tests__/` folder at the same level as the source under test. For example, `foo/index.tsx` maps to `foo/__tests__/index.spec.tsx`, and `foo/bar.ts` maps to `foo/__tests__/bar.spec.ts`.
- Integration tests: `web/__tests__/` directory

## Test Structure Template
@@ -41,7 +41,7 @@ import userEvent from '@testing-library/user-event'

// Router (if component uses useRouter, usePathname, useSearchParams)
// WHY: Isolates tests from Next.js routing, enables testing navigation behavior
// const mockPush = vi.fn()
-// vi.mock('next/navigation', () => ({
+// vi.mock('@/next/navigation', () => ({
//   useRouter: () => ({ push: mockPush }),
//   usePathname: () => '/test-path',
// }))
.github/workflows/web-tests.yml (vendored)

@@ -29,8 +29,8 @@ jobs:
    strategy:
      fail-fast: false
      matrix:
-        shardIndex: [1, 2, 3, 4]
-        shardTotal: [4]
+        shardIndex: [1, 2, 3, 4, 5, 6]
+        shardTotal: [6]
    defaults:
      run:
        shell: bash
@@ -78,7 +78,7 @@ class UserProfile(TypedDict):
    nickname: NotRequired[str]
```

-- For classes, declare member variables at the top of the class body (before `__init__`) so the class shape is obvious at a glance:
+- For classes, declare all member variables explicitly with types at the top of the class body (before `__init__`), even when the class is not a dataclass or Pydantic model, so the class shape is obvious at a glance:

```python
from datetime import datetime
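A minimal sketch of the declared-members-first style this guideline describes; the class and field names are illustrative, not taken from the diff:

```python
from datetime import datetime


class WorkflowRunSummary:
    # Every member is declared with a type before __init__, even though this
    # is a plain class rather than a dataclass or Pydantic model.
    run_id: str
    tenant_id: str
    started_at: datetime
    finished_at: datetime | None
    error: str | None

    def __init__(self, run_id: str, tenant_id: str, started_at: datetime) -> None:
        self.run_id = run_id
        self.tenant_id = tenant_id
        self.started_at = started_at
        self.finished_at = None
        self.error = None
```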
@@ -88,6 +88,8 @@ def clean_workflow_runs(
    """
    Clean workflow runs and related workflow data for free tenants.
    """
+    from extensions.otel.runtime import flush_telemetry

    if (start_from is None) ^ (end_before is None):
        raise click.UsageError("--start-from and --end-before must be provided together.")

@@ -104,16 +106,27 @@ def clean_workflow_runs(
        end_before = now - datetime.timedelta(days=to_days_ago)
        before_days = 0

+    if from_days_ago is not None and to_days_ago is not None:
+        task_label = f"{from_days_ago}to{to_days_ago}"
+    elif start_from is None:
+        task_label = f"before-{before_days}"
+    else:
+        task_label = "custom"
+
    start_time = datetime.datetime.now(datetime.UTC)
    click.echo(click.style(f"Starting workflow run cleanup at {start_time.isoformat()}.", fg="white"))

-    WorkflowRunCleanup(
-        days=before_days,
-        batch_size=batch_size,
-        start_from=start_from,
-        end_before=end_before,
-        dry_run=dry_run,
-    ).run()
+    try:
+        WorkflowRunCleanup(
+            days=before_days,
+            batch_size=batch_size,
+            start_from=start_from,
+            end_before=end_before,
+            dry_run=dry_run,
+            task_label=task_label,
+        ).run()
+    finally:
+        flush_telemetry()

    end_time = datetime.datetime.now(datetime.UTC)
    elapsed = end_time - start_time
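The same always-flush behavior could also be factored into a small helper; this is an illustrative sketch only (the `flush_telemetry` import path comes from the diff, the context manager itself does not):

```python
from contextlib import contextmanager

from extensions.otel.runtime import flush_telemetry


@contextmanager
def telemetry_flushed():
    """Ensure telemetry is flushed even when the wrapped cleanup raises."""
    try:
        yield
    finally:
        flush_telemetry()


# Usage sketch:
# with telemetry_flushed():
#     WorkflowRunCleanup(..., task_label=task_label).run()
```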
@@ -659,6 +672,8 @@ def clean_expired_messages(
    """
    Clean expired messages and related data for tenants based on clean policy.
    """
+    from extensions.otel.runtime import flush_telemetry

    click.echo(click.style("clean_messages: start clean messages.", fg="green"))

    start_at = time.perf_counter()

@@ -698,6 +713,13 @@ def clean_expired_messages(
    # NOTE: graceful_period will be ignored when billing is disabled.
    policy = create_message_clean_policy(graceful_period_days=graceful_period)

+    if from_days_ago is not None and before_days is not None:
+        task_label = f"{from_days_ago}to{before_days}"
+    elif start_from is None and before_days is not None:
+        task_label = f"before-{before_days}"
+    else:
+        task_label = "custom"
+
    # Create and run the cleanup service
    if abs_mode:
        assert start_from is not None

@@ -708,6 +730,7 @@ def clean_expired_messages(
            end_before=end_before,
            batch_size=batch_size,
            dry_run=dry_run,
+            task_label=task_label,
        )
    elif from_days_ago is None:
        assert before_days is not None

@@ -716,6 +739,7 @@ def clean_expired_messages(
            days=before_days,
            batch_size=batch_size,
            dry_run=dry_run,
+            task_label=task_label,
        )
    else:
        assert before_days is not None

@@ -727,6 +751,7 @@ def clean_expired_messages(
            end_before=now - datetime.timedelta(days=before_days),
            batch_size=batch_size,
            dry_run=dry_run,
+            task_label=task_label,
        )
    stats = service.run()

@@ -752,6 +777,8 @@ def clean_expired_messages(
            )
        )
        raise
+    finally:
+        flush_telemetry()

    click.echo(click.style("messages cleanup completed.", fg="green"))
@@ -14,6 +14,7 @@ from core.rag.models.document import ChildDocument, Document
from extensions.ext_database import db
from models.dataset import Dataset, DatasetCollectionBinding, DatasetMetadata, DatasetMetadataBinding, DocumentSegment
from models.dataset import Document as DatasetDocument
+from models.enums import DatasetMetadataType, IndexingStatus, SegmentStatus
from models.model import App, AppAnnotationSetting, MessageAnnotation


@@ -242,7 +243,7 @@ def migrate_knowledge_vector_database():
            dataset_documents = db.session.scalars(
                select(DatasetDocument).where(
                    DatasetDocument.dataset_id == dataset.id,
-                    DatasetDocument.indexing_status == "completed",
+                    DatasetDocument.indexing_status == IndexingStatus.COMPLETED,
                    DatasetDocument.enabled == True,
                    DatasetDocument.archived == False,
                )

@@ -254,7 +255,7 @@ def migrate_knowledge_vector_database():
                segments = db.session.scalars(
                    select(DocumentSegment).where(
                        DocumentSegment.document_id == dataset_document.id,
-                        DocumentSegment.status == "completed",
+                        DocumentSegment.status == SegmentStatus.COMPLETED,
                        DocumentSegment.enabled == True,
                    )
                ).all()

@@ -430,7 +431,7 @@ def old_metadata_migration():
                        tenant_id=document.tenant_id,
                        dataset_id=document.dataset_id,
                        name=key,
-                        type="string",
+                        type=DatasetMetadataType.STRING,
                        created_by=document.created_by,
                    )
                    db.session.add(dataset_metadata)
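A small sketch of why enum members read better than raw strings in these filters; the enum values shown are assumptions about `models.enums`, not copied from it:

```python
from enum import StrEnum


class IndexingStatus(StrEnum):
    # Assumed values; the real members live in models.enums.
    COMPLETED = "completed"
    ERROR = "error"
    INDEXING = "indexing"
    PAUSED = "paused"


# StrEnum members compare equal to their string values, so swapping
# "completed" for IndexingStatus.COMPLETED changes no SQL, but a typo
# becomes an AttributeError instead of a silently empty result set.
assert IndexingStatus.COMPLETED == "completed"

# stmt = select(DatasetDocument).where(
#     DatasetDocument.indexing_status == IndexingStatus.COMPLETED,
# )
```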
api/configs/middleware/cache/redis_config.py (vendored)

@@ -1,4 +1,4 @@
-from pydantic import Field, NonNegativeInt, PositiveFloat, PositiveInt
+from pydantic import Field, NonNegativeInt, PositiveFloat, PositiveInt, field_validator
from pydantic_settings import BaseSettings

@@ -116,3 +116,13 @@ class RedisConfig(BaseSettings):
        description="Maximum connections in the Redis connection pool (unset for library default)",
        default=None,
    )
+
+    @field_validator("REDIS_MAX_CONNECTIONS", mode="before")
+    @classmethod
+    def _empty_string_to_none_for_max_conns(cls, v):
+        """Allow empty string in env/.env to mean 'unset' (None)."""
+        if v is None:
+            return None
+        if isinstance(v, str) and v.strip() == "":
+            return None
+        return v
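A self-contained sketch of the same empty-string-to-None coercion, using a stand-in settings class rather than the real `RedisConfig`:

```python
from pydantic import NonNegativeInt, field_validator
from pydantic_settings import BaseSettings


class DemoRedisConfig(BaseSettings):
    REDIS_MAX_CONNECTIONS: NonNegativeInt | None = None

    @field_validator("REDIS_MAX_CONNECTIONS", mode="before")
    @classmethod
    def _empty_string_to_none(cls, v):
        # "" in a .env file should behave like "unset" rather than fail int parsing.
        if isinstance(v, str) and v.strip() == "":
            return None
        return v


assert DemoRedisConfig(REDIS_MAX_CONNECTIONS="").REDIS_MAX_CONNECTIONS is None
assert DemoRedisConfig(REDIS_MAX_CONNECTIONS="50").REDIS_MAX_CONNECTIONS == 50
```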
@@ -1,4 +1,4 @@
-from typing import Literal, Protocol
+from typing import Literal, Protocol, cast
from urllib.parse import quote_plus, urlunparse

from pydantic import AliasChoices, Field

@@ -12,16 +12,13 @@ class RedisConfigDefaults(Protocol):
    REDIS_PASSWORD: str | None
    REDIS_DB: int
    REDIS_USE_SSL: bool
    REDIS_USE_SENTINEL: bool | None
    REDIS_USE_CLUSTERS: bool


-class RedisConfigDefaultsMixin:
-    def _redis_defaults(self: RedisConfigDefaults) -> RedisConfigDefaults:
-        return self
+def _redis_defaults(config: object) -> RedisConfigDefaults:
+    return cast(RedisConfigDefaults, config)


-class RedisPubSubConfig(BaseSettings, RedisConfigDefaultsMixin):
+class RedisPubSubConfig(BaseSettings):
    """
    Configuration settings for event transport between API and workers.

@@ -74,7 +71,7 @@ class RedisPubSubConfig(BaseSettings, RedisConfigDefaultsMixin):
    )

    def _build_default_pubsub_url(self) -> str:
-        defaults = self._redis_defaults()
+        defaults = _redis_defaults(self)
        if not defaults.REDIS_HOST or not defaults.REDIS_PORT:
            raise ValueError("PUBSUB_REDIS_URL must be set when default Redis URL cannot be constructed")

@@ -91,11 +88,9 @@ class RedisPubSubConfig(BaseSettings, RedisConfigDefaultsMixin):
        if userinfo:
            userinfo = f"{userinfo}@"

-        host = defaults.REDIS_HOST
-        port = defaults.REDIS_PORT
        db = defaults.REDIS_DB

-        netloc = f"{userinfo}{host}:{port}"
+        netloc = f"{userinfo}{defaults.REDIS_HOST}:{defaults.REDIS_PORT}"
        return urlunparse((scheme, netloc, f"/{db}", "", "", ""))

    @property
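The mixin-to-module-function change above hinges on `typing.cast` against a `Protocol`; here is a stripped-down sketch of the pattern with hypothetical names:

```python
from typing import Protocol, cast


class HostPortDefaults(Protocol):
    HOST: str
    PORT: int


def _defaults(config: object) -> HostPortDefaults:
    # The caller promises the object structurally satisfies the protocol;
    # cast() only informs the type checker and does nothing at runtime.
    return cast(HostPortDefaults, config)


class DemoConfig:
    HOST = "localhost"
    PORT = 6379


netloc = f"{_defaults(DemoConfig()).HOST}:{_defaults(DemoConfig()).PORT}"
assert netloc == "localhost:6379"
```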
@ -54,6 +54,7 @@ from fields.document_fields import document_status_fields
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
|
||||
from models.dataset import DatasetPermission, DatasetPermissionEnum
|
||||
from models.enums import SegmentStatus
|
||||
from models.provider_ids import ModelProviderID
|
||||
from services.api_token_service import ApiTokenCache
|
||||
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
|
||||
@ -741,13 +742,15 @@ class DatasetIndexingStatusApi(Resource):
|
||||
.where(
|
||||
DocumentSegment.completed_at.isnot(None),
|
||||
DocumentSegment.document_id == str(document.id),
|
||||
DocumentSegment.status != "re_segment",
|
||||
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
|
||||
)
|
||||
.count()
|
||||
)
|
||||
total_segments = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
|
||||
.where(
|
||||
DocumentSegment.document_id == str(document.id), DocumentSegment.status != SegmentStatus.RE_SEGMENT
|
||||
)
|
||||
.count()
|
||||
)
|
||||
# Create a dictionary with document attributes and additional fields
|
||||
|
||||
@ -42,6 +42,7 @@ from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import DatasetProcessRule, Document, DocumentSegment, UploadFile
|
||||
from models.dataset import DocumentPipelineExecutionLog
|
||||
from models.enums import IndexingStatus, SegmentStatus
|
||||
from services.dataset_service import DatasetService, DocumentService
|
||||
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig, ProcessRule, RetrievalModel
|
||||
from services.file_service import FileService
|
||||
@ -332,13 +333,16 @@ class DatasetDocumentListApi(Resource):
|
||||
.where(
|
||||
DocumentSegment.completed_at.isnot(None),
|
||||
DocumentSegment.document_id == str(document.id),
|
||||
DocumentSegment.status != "re_segment",
|
||||
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
|
||||
)
|
||||
.count()
|
||||
)
|
||||
total_segments = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
|
||||
.where(
|
||||
DocumentSegment.document_id == str(document.id),
|
||||
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
|
||||
)
|
||||
.count()
|
||||
)
|
||||
document.completed_segments = completed_segments
|
||||
@ -503,7 +507,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
|
||||
document_id = str(document_id)
|
||||
document = self.get_document(dataset_id, document_id)
|
||||
|
||||
if document.indexing_status in {"completed", "error"}:
|
||||
if document.indexing_status in {IndexingStatus.COMPLETED, IndexingStatus.ERROR}:
|
||||
raise DocumentAlreadyFinishedError()
|
||||
|
||||
data_process_rule = document.dataset_process_rule
|
||||
@ -573,7 +577,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
||||
data_process_rule_dict = data_process_rule.to_dict() if data_process_rule else {}
|
||||
extract_settings = []
|
||||
for document in documents:
|
||||
if document.indexing_status in {"completed", "error"}:
|
||||
if document.indexing_status in {IndexingStatus.COMPLETED, IndexingStatus.ERROR}:
|
||||
raise DocumentAlreadyFinishedError()
|
||||
data_source_info = document.data_source_info_dict
|
||||
match document.data_source_type:
|
||||
@ -671,19 +675,21 @@ class DocumentBatchIndexingStatusApi(DocumentResource):
|
||||
.where(
|
||||
DocumentSegment.completed_at.isnot(None),
|
||||
DocumentSegment.document_id == str(document.id),
|
||||
DocumentSegment.status != "re_segment",
|
||||
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
|
||||
)
|
||||
.count()
|
||||
)
|
||||
total_segments = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
|
||||
.where(
|
||||
DocumentSegment.document_id == str(document.id), DocumentSegment.status != SegmentStatus.RE_SEGMENT
|
||||
)
|
||||
.count()
|
||||
)
|
||||
# Create a dictionary with document attributes and additional fields
|
||||
document_dict = {
|
||||
"id": document.id,
|
||||
"indexing_status": "paused" if document.is_paused else document.indexing_status,
|
||||
"indexing_status": IndexingStatus.PAUSED if document.is_paused else document.indexing_status,
|
||||
"processing_started_at": document.processing_started_at,
|
||||
"parsing_completed_at": document.parsing_completed_at,
|
||||
"cleaning_completed_at": document.cleaning_completed_at,
|
||||
@ -720,20 +726,20 @@ class DocumentIndexingStatusApi(DocumentResource):
|
||||
.where(
|
||||
DocumentSegment.completed_at.isnot(None),
|
||||
DocumentSegment.document_id == str(document_id),
|
||||
DocumentSegment.status != "re_segment",
|
||||
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
|
||||
)
|
||||
.count()
|
||||
)
|
||||
total_segments = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(DocumentSegment.document_id == str(document_id), DocumentSegment.status != "re_segment")
|
||||
.where(DocumentSegment.document_id == str(document_id), DocumentSegment.status != SegmentStatus.RE_SEGMENT)
|
||||
.count()
|
||||
)
|
||||
|
||||
# Create a dictionary with document attributes and additional fields
|
||||
document_dict = {
|
||||
"id": document.id,
|
||||
"indexing_status": "paused" if document.is_paused else document.indexing_status,
|
||||
"indexing_status": IndexingStatus.PAUSED if document.is_paused else document.indexing_status,
|
||||
"processing_started_at": document.processing_started_at,
|
||||
"parsing_completed_at": document.parsing_completed_at,
|
||||
"cleaning_completed_at": document.cleaning_completed_at,
|
||||
@ -955,7 +961,7 @@ class DocumentProcessingApi(DocumentResource):
|
||||
|
||||
match action:
|
||||
case "pause":
|
||||
if document.indexing_status != "indexing":
|
||||
if document.indexing_status != IndexingStatus.INDEXING:
|
||||
raise InvalidActionError("Document not in indexing state.")
|
||||
|
||||
document.paused_by = current_user.id
|
||||
@ -964,7 +970,7 @@ class DocumentProcessingApi(DocumentResource):
|
||||
db.session.commit()
|
||||
|
||||
case "resume":
|
||||
if document.indexing_status not in {"paused", "error"}:
|
||||
if document.indexing_status not in {IndexingStatus.PAUSED, IndexingStatus.ERROR}:
|
||||
raise InvalidActionError("Document not in paused or error state.")
|
||||
|
||||
document.paused_by = None
|
||||
@ -1169,7 +1175,7 @@ class DocumentRetryApi(DocumentResource):
|
||||
raise ArchivedDocumentImmutableError()
|
||||
|
||||
# 400 if document is completed
|
||||
if document.indexing_status == "completed":
|
||||
if document.indexing_status == IndexingStatus.COMPLETED:
|
||||
raise DocumentAlreadyFinishedError()
|
||||
retry_documents.append(document)
|
||||
except Exception:
|
||||
|
||||
@ -46,6 +46,8 @@ class PipelineTemplateDetailApi(Resource):
|
||||
type = request.args.get("type", default="built-in", type=str)
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
pipeline_template = rag_pipeline_service.get_pipeline_template_detail(template_id, type)
|
||||
if pipeline_template is None:
|
||||
return {"error": "Pipeline template not found from upstream service."}, 404
|
||||
return pipeline_template, 200
|
||||
|
||||
|
||||
|
||||
@ -36,6 +36,7 @@ from extensions.ext_database import db
|
||||
from fields.document_fields import document_fields, document_status_fields
|
||||
from libs.login import current_user
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
from models.enums import SegmentStatus
|
||||
from services.dataset_service import DatasetService, DocumentService
|
||||
from services.entities.knowledge_entities.knowledge_entities import (
|
||||
KnowledgeConfig,
|
||||
@ -622,13 +623,15 @@ class DocumentIndexingStatusApi(DatasetApiResource):
|
||||
.where(
|
||||
DocumentSegment.completed_at.isnot(None),
|
||||
DocumentSegment.document_id == str(document.id),
|
||||
DocumentSegment.status != "re_segment",
|
||||
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
|
||||
)
|
||||
.count()
|
||||
)
|
||||
total_segments = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
|
||||
.where(
|
||||
DocumentSegment.document_id == str(document.id), DocumentSegment.status != SegmentStatus.RE_SEGMENT
|
||||
)
|
||||
.count()
|
||||
)
|
||||
# Create a dictionary with document attributes and additional fields
|
||||
|
||||
@ -70,7 +70,14 @@ def handle_webhook(webhook_id: str):
|
||||
|
||||
@bp.route("/webhook-debug/<string:webhook_id>", methods=["GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"])
|
||||
def handle_webhook_debug(webhook_id: str):
|
||||
"""Handle webhook debug calls without triggering production workflow execution."""
|
||||
"""Handle webhook debug calls without triggering production workflow execution.
|
||||
|
||||
The debug webhook endpoint is only for draft inspection flows. It never enqueues
|
||||
Celery work for the published workflow; instead it dispatches an in-memory debug
|
||||
event to an active Variable Inspector listener. Returning a clear error when no
|
||||
listener is registered prevents a misleading 200 response for requests that are
|
||||
effectively dropped.
|
||||
"""
|
||||
try:
|
||||
webhook_trigger, _, node_config, webhook_data, error = _prepare_webhook_execution(webhook_id, is_debug=True)
|
||||
if error:
|
||||
@ -94,11 +101,32 @@ def handle_webhook_debug(webhook_id: str):
|
||||
"method": webhook_data.get("method"),
|
||||
},
|
||||
)
|
||||
TriggerDebugEventBus.dispatch(
|
||||
dispatch_count = TriggerDebugEventBus.dispatch(
|
||||
tenant_id=webhook_trigger.tenant_id,
|
||||
event=event,
|
||||
pool_key=pool_key,
|
||||
)
|
||||
if dispatch_count == 0:
|
||||
logger.warning(
|
||||
"Webhook debug request dropped without an active listener for webhook %s (tenant=%s, app=%s, node=%s)",
|
||||
webhook_trigger.webhook_id,
|
||||
webhook_trigger.tenant_id,
|
||||
webhook_trigger.app_id,
|
||||
webhook_trigger.node_id,
|
||||
)
|
||||
return (
|
||||
jsonify(
|
||||
{
|
||||
"error": "No active debug listener",
|
||||
"message": (
|
||||
"The webhook debug URL only works while the Variable Inspector is listening. "
|
||||
"Use the published webhook URL to execute the workflow in Celery."
|
||||
),
|
||||
"execution_url": webhook_trigger.webhook_url,
|
||||
}
|
||||
),
|
||||
409,
|
||||
)
|
||||
response_data, status_code = WebhookService.generate_webhook_response(node_config)
|
||||
return jsonify(response_data), status_code
|
||||
|
||||
|
||||
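A sketch of the no-listener guard added above, reduced to plain Python; the event-bus and response objects here are stand-ins, not the project's actual classes:

```python
from dataclasses import dataclass, field
from typing import Callable


@dataclass
class DebugEventBus:
    listeners: list[Callable[[dict], None]] = field(default_factory=list)

    def dispatch(self, event: dict) -> int:
        # Return how many listeners actually received the event so the
        # caller can distinguish "handled" from "silently dropped".
        for listener in self.listeners:
            listener(event)
        return len(self.listeners)


bus = DebugEventBus()
if bus.dispatch({"kind": "webhook-debug"}) == 0:
    response, status = {"error": "No active debug listener"}, 409
```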
@ -441,7 +441,7 @@ class BaseAgentRunner(AppRunner):
|
||||
continue
|
||||
|
||||
result.append(self.organize_agent_user_prompt(message))
|
||||
agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
|
||||
agent_thoughts = message.agent_thoughts
|
||||
if agent_thoughts:
|
||||
for agent_thought in agent_thoughts:
|
||||
tool_names_raw = agent_thought.tool
|
||||
|
||||
@ -1,13 +1,36 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
from typing import Any, TypedDict
|
||||
|
||||
from configs import dify_config
|
||||
from constants import DEFAULT_FILE_NUMBER_LIMITS
|
||||
|
||||
|
||||
class SystemParametersDict(TypedDict):
|
||||
image_file_size_limit: int
|
||||
video_file_size_limit: int
|
||||
audio_file_size_limit: int
|
||||
file_size_limit: int
|
||||
workflow_file_upload_limit: int
|
||||
|
||||
|
||||
class AppParametersDict(TypedDict):
|
||||
opening_statement: str | None
|
||||
suggested_questions: list[str]
|
||||
suggested_questions_after_answer: dict[str, Any]
|
||||
speech_to_text: dict[str, Any]
|
||||
text_to_speech: dict[str, Any]
|
||||
retriever_resource: dict[str, Any]
|
||||
annotation_reply: dict[str, Any]
|
||||
more_like_this: dict[str, Any]
|
||||
user_input_form: list[dict[str, Any]]
|
||||
sensitive_word_avoidance: dict[str, Any]
|
||||
file_upload: dict[str, Any]
|
||||
system_parameters: SystemParametersDict
|
||||
|
||||
|
||||
def get_parameters_from_feature_dict(
|
||||
*, features_dict: Mapping[str, Any], user_input_form: list[dict[str, Any]]
|
||||
) -> Mapping[str, Any]:
|
||||
) -> AppParametersDict:
|
||||
"""
|
||||
Mapping from feature dict to webapp parameters
|
||||
"""
|
||||
|
||||
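A minimal sketch of why tightening the return annotation from `Mapping[str, Any]` to a `TypedDict` helps callers; the field set here is abbreviated, not the full `AppParametersDict`:

```python
from typing import Any, TypedDict


class SystemParameters(TypedDict):
    file_size_limit: int


class AppParameters(TypedDict):
    opening_statement: str | None
    suggested_questions: list[str]
    system_parameters: SystemParameters


def get_parameters(features: dict[str, Any]) -> AppParameters:
    return {
        "opening_statement": features.get("opening_statement"),
        "suggested_questions": features.get("suggested_questions", []),
        "system_parameters": {"file_size_limit": 15},
    }


# Type checkers now know the exact keys and value types, so a typo such as
# params["openning_statement"] is flagged statically instead of failing at runtime.
params = get_parameters({"opening_statement": "Hi"})
```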
@ -8,6 +8,7 @@ from core.app.app_config.entities import (
|
||||
ModelConfig,
|
||||
)
|
||||
from core.entities.agent_entities import PlanningStrategy
|
||||
from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict
|
||||
from models.model import AppMode, AppModelConfigDict
|
||||
from services.dataset_service import DatasetService
|
||||
|
||||
@ -117,8 +118,10 @@ class DatasetConfigManager:
|
||||
score_threshold=float(score_threshold_val)
|
||||
if dataset_configs.get("score_threshold_enabled", False) and score_threshold_val is not None
|
||||
else None,
|
||||
reranking_model=reranking_model_val if isinstance(reranking_model_val, dict) else None,
|
||||
weights=weights_val if isinstance(weights_val, dict) else None,
|
||||
reranking_model=cast(RerankingModelDict, reranking_model_val)
|
||||
if isinstance(reranking_model_val, dict)
|
||||
else None,
|
||||
weights=cast(WeightsDict, weights_val) if isinstance(weights_val, dict) else None,
|
||||
reranking_enabled=bool(dataset_configs.get("reranking_enabled", True)),
|
||||
rerank_mode=dataset_configs.get("reranking_mode", "reranking_model"),
|
||||
metadata_filtering_mode=cast(
|
||||
|
||||
@ -4,6 +4,7 @@ from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict
|
||||
from dify_graph.file import FileUploadConfig
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMMode
|
||||
from dify_graph.model_runtime.entities.message_entities import PromptMessageRole
|
||||
@ -194,8 +195,8 @@ class DatasetRetrieveConfigEntity(BaseModel):
|
||||
top_k: int | None = None
|
||||
score_threshold: float | None = 0.0
|
||||
rerank_mode: str | None = "reranking_model"
|
||||
reranking_model: dict | None = None
|
||||
weights: dict | None = None
|
||||
reranking_model: RerankingModelDict | None = None
|
||||
weights: WeightsDict | None = None
|
||||
reranking_enabled: bool | None = True
|
||||
metadata_filtering_mode: Literal["disabled", "automatic", "manual"] | None = "disabled"
|
||||
metadata_model_config: ModelConfig | None = None
|
||||
|
||||
@ -3,7 +3,7 @@ import time
|
||||
from collections.abc import Mapping, Sequence
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any, NewType, Union
|
||||
from typing import Any, NewType, TypedDict, Union
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
@ -76,6 +76,20 @@ NodeExecutionId = NewType("NodeExecutionId", str)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AccountCreatedByDict(TypedDict):
|
||||
id: str
|
||||
name: str
|
||||
email: str
|
||||
|
||||
|
||||
class EndUserCreatedByDict(TypedDict):
|
||||
id: str
|
||||
user: str
|
||||
|
||||
|
||||
CreatedByDict = AccountCreatedByDict | EndUserCreatedByDict
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class _NodeSnapshot:
|
||||
"""In-memory cache for node metadata between start and completion events."""
|
||||
@ -249,19 +263,19 @@ class WorkflowResponseConverter:
|
||||
outputs_mapping = graph_runtime_state.outputs or {}
|
||||
encoded_outputs = WorkflowRuntimeTypeConverter().to_json_encodable(outputs_mapping)
|
||||
|
||||
created_by: Mapping[str, object] | None
|
||||
created_by: CreatedByDict | dict[str, object] = {}
|
||||
user = self._user
|
||||
if isinstance(user, Account):
|
||||
created_by = {
|
||||
"id": user.id,
|
||||
"name": user.name,
|
||||
"email": user.email,
|
||||
}
|
||||
else:
|
||||
created_by = {
|
||||
"id": user.id,
|
||||
"user": user.session_id,
|
||||
}
|
||||
created_by = AccountCreatedByDict(
|
||||
id=user.id,
|
||||
name=user.name,
|
||||
email=user.email,
|
||||
)
|
||||
elif isinstance(user, EndUser):
|
||||
created_by = EndUserCreatedByDict(
|
||||
id=user.id,
|
||||
user=user.session_id,
|
||||
)
|
||||
|
||||
return WorkflowFinishStreamResponse(
|
||||
task_id=task_id,
|
||||
|
||||
@ -6,6 +6,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.rag.datasource.vdb.vector_factory import Vector
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset
|
||||
from models.enums import CollectionBindingType
|
||||
from models.model import App, AppAnnotationSetting, Message, MessageAnnotation
|
||||
from services.annotation_service import AppAnnotationService
|
||||
from services.dataset_service import DatasetCollectionBindingService
|
||||
@ -43,7 +44,7 @@ class AnnotationReplyFeature:
|
||||
embedding_model_name = collection_binding_detail.model_name
|
||||
|
||||
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
|
||||
embedding_provider_name, embedding_model_name, "annotation"
|
||||
embedding_provider_name, embedding_model_name, CollectionBindingType.ANNOTATION
|
||||
)
|
||||
|
||||
dataset = Dataset(
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from typing import TypedDict
|
||||
|
||||
from core.tools.signature import sign_tool_file
|
||||
from dify_graph.file import helpers as file_helpers
|
||||
from dify_graph.file.enums import FileTransferMethod
|
||||
@ -6,7 +8,20 @@ from models.model import MessageFile, UploadFile
|
||||
MAX_TOOL_FILE_EXTENSION_LENGTH = 10
|
||||
|
||||
|
||||
def prepare_file_dict(message_file: MessageFile, upload_files_map: dict[str, UploadFile]) -> dict:
|
||||
class MessageFileInfoDict(TypedDict):
|
||||
related_id: str
|
||||
extension: str
|
||||
filename: str
|
||||
size: int
|
||||
mime_type: str
|
||||
transfer_method: str
|
||||
type: str
|
||||
url: str
|
||||
upload_file_id: str
|
||||
remote_url: str | None
|
||||
|
||||
|
||||
def prepare_file_dict(message_file: MessageFile, upload_files_map: dict[str, UploadFile]) -> MessageFileInfoDict:
|
||||
"""
|
||||
Prepare file dictionary for message end stream response.
|
||||
|
||||
|
||||
@ -12,7 +12,7 @@ from core.rag.models.document import Document
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import ChildChunk, DatasetQuery, DocumentSegment
|
||||
from models.dataset import Document as DatasetDocument
|
||||
from models.enums import CreatorUserRole
|
||||
from models.enums import CreatorUserRole, DatasetQuerySource
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
@ -36,7 +36,7 @@ class DatasetIndexToolCallbackHandler:
|
||||
dataset_query = DatasetQuery(
|
||||
dataset_id=dataset_id,
|
||||
content=query,
|
||||
source="app",
|
||||
source=DatasetQuerySource.APP,
|
||||
source_app_id=self._app_id,
|
||||
created_by_role=(
|
||||
CreatorUserRole.ACCOUNT
|
||||
|
||||
@ -30,6 +30,7 @@ from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.engine import db
|
||||
from models.enums import CredentialSourceType
|
||||
from models.provider import (
|
||||
LoadBalancingModelConfig,
|
||||
Provider,
|
||||
@ -473,9 +474,21 @@ class ProviderConfiguration(BaseModel):
|
||||
|
||||
self.switch_preferred_provider_type(provider_type=ProviderType.CUSTOM, session=session)
|
||||
else:
|
||||
# some historical data may have a provider record but not be set as valid
|
||||
provider_record.is_valid = True
|
||||
|
||||
if provider_record.credential_id is None:
|
||||
provider_record.credential_id = new_record.id
|
||||
provider_record.updated_at = naive_utc_now()
|
||||
|
||||
provider_model_credentials_cache = ProviderCredentialsCache(
|
||||
tenant_id=self.tenant_id,
|
||||
identity_id=provider_record.id,
|
||||
cache_type=ProviderCredentialsCacheType.PROVIDER,
|
||||
)
|
||||
provider_model_credentials_cache.delete()
|
||||
|
||||
self.switch_preferred_provider_type(provider_type=ProviderType.CUSTOM, session=session)
|
||||
|
||||
session.commit()
|
||||
except Exception:
|
||||
session.rollback()
|
||||
@ -534,7 +547,7 @@ class ProviderConfiguration(BaseModel):
|
||||
self._update_load_balancing_configs_with_credential(
|
||||
credential_id=credential_id,
|
||||
credential_record=credential_record,
|
||||
credential_source="provider",
|
||||
credential_source=CredentialSourceType.PROVIDER,
|
||||
session=session,
|
||||
)
|
||||
except Exception:
|
||||
@ -611,7 +624,7 @@ class ProviderConfiguration(BaseModel):
|
||||
LoadBalancingModelConfig.tenant_id == self.tenant_id,
|
||||
LoadBalancingModelConfig.provider_name.in_(self._get_provider_names()),
|
||||
LoadBalancingModelConfig.credential_id == credential_id,
|
||||
LoadBalancingModelConfig.credential_source_type == "provider",
|
||||
LoadBalancingModelConfig.credential_source_type == CredentialSourceType.PROVIDER,
|
||||
)
|
||||
lb_configs_using_credential = session.execute(lb_stmt).scalars().all()
|
||||
try:
|
||||
@ -1031,7 +1044,7 @@ class ProviderConfiguration(BaseModel):
|
||||
self._update_load_balancing_configs_with_credential(
|
||||
credential_id=credential_id,
|
||||
credential_record=credential_record,
|
||||
credential_source="custom_model",
|
||||
credential_source=CredentialSourceType.CUSTOM_MODEL,
|
||||
session=session,
|
||||
)
|
||||
except Exception:
|
||||
@ -1061,7 +1074,7 @@ class ProviderConfiguration(BaseModel):
|
||||
LoadBalancingModelConfig.tenant_id == self.tenant_id,
|
||||
LoadBalancingModelConfig.provider_name.in_(self._get_provider_names()),
|
||||
LoadBalancingModelConfig.credential_id == credential_id,
|
||||
LoadBalancingModelConfig.credential_source_type == "custom_model",
|
||||
LoadBalancingModelConfig.credential_source_type == CredentialSourceType.CUSTOM_MODEL,
|
||||
)
|
||||
lb_configs_using_credential = session.execute(lb_stmt).scalars().all()
|
||||
|
||||
@ -1699,7 +1712,7 @@ class ProviderConfiguration(BaseModel):
|
||||
provider_model_lb_configs = [
|
||||
config
|
||||
for config in model_setting.load_balancing_configs
|
||||
if config.credential_source_type != "custom_model"
|
||||
if config.credential_source_type != CredentialSourceType.CUSTOM_MODEL
|
||||
]
|
||||
|
||||
load_balancing_enabled = model_setting.load_balancing_enabled
|
||||
@ -1757,7 +1770,7 @@ class ProviderConfiguration(BaseModel):
|
||||
custom_model_lb_configs = [
|
||||
config
|
||||
for config in model_setting.load_balancing_configs
|
||||
if config.credential_source_type != "provider"
|
||||
if config.credential_source_type != CredentialSourceType.PROVIDER
|
||||
]
|
||||
|
||||
load_balancing_enabled = model_setting.load_balancing_enabled
|
||||
|
||||
@ -40,6 +40,7 @@ from libs.datetime_utils import naive_utc_now
|
||||
from models import Account
|
||||
from models.dataset import AutomaticRulesConfig, ChildChunk, Dataset, DatasetProcessRule, DocumentSegment
|
||||
from models.dataset import Document as DatasetDocument
|
||||
from models.enums import DataSourceType, IndexingStatus, ProcessRuleMode, SegmentStatus
|
||||
from models.model import UploadFile
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
@ -56,7 +57,7 @@ class IndexingRunner:
|
||||
logger.exception("consume document failed")
|
||||
document = db.session.get(DatasetDocument, document_id)
|
||||
if document:
|
||||
document.indexing_status = "error"
|
||||
document.indexing_status = IndexingStatus.ERROR
|
||||
error_message = getattr(error, "description", str(error))
|
||||
document.error = str(error_message)
|
||||
document.stopped_at = naive_utc_now()
|
||||
@ -219,7 +220,7 @@ class IndexingRunner:
|
||||
if document_segments:
|
||||
for document_segment in document_segments:
|
||||
# transform segment to node
|
||||
if document_segment.status != "completed":
|
||||
if document_segment.status != SegmentStatus.COMPLETED:
|
||||
document = Document(
|
||||
page_content=document_segment.content,
|
||||
metadata={
|
||||
@ -382,7 +383,7 @@ class IndexingRunner:
|
||||
data_source_info = dataset_document.data_source_info_dict
|
||||
text_docs = []
|
||||
match dataset_document.data_source_type:
|
||||
case "upload_file":
|
||||
case DataSourceType.UPLOAD_FILE:
|
||||
if not data_source_info or "upload_file_id" not in data_source_info:
|
||||
raise ValueError("no upload file found")
|
||||
stmt = select(UploadFile).where(UploadFile.id == data_source_info["upload_file_id"])
|
||||
@ -395,7 +396,7 @@ class IndexingRunner:
|
||||
document_model=dataset_document.doc_form,
|
||||
)
|
||||
text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"])
|
||||
case "notion_import":
|
||||
case DataSourceType.NOTION_IMPORT:
|
||||
if (
|
||||
not data_source_info
|
||||
or "notion_workspace_id" not in data_source_info
|
||||
@ -417,7 +418,7 @@ class IndexingRunner:
|
||||
document_model=dataset_document.doc_form,
|
||||
)
|
||||
text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"])
|
||||
case "website_crawl":
|
||||
case DataSourceType.WEBSITE_CRAWL:
|
||||
if (
|
||||
not data_source_info
|
||||
or "provider" not in data_source_info
|
||||
@ -445,7 +446,7 @@ class IndexingRunner:
|
||||
# update document status to splitting
|
||||
self._update_document_index_status(
|
||||
document_id=dataset_document.id,
|
||||
after_indexing_status="splitting",
|
||||
after_indexing_status=IndexingStatus.SPLITTING,
|
||||
extra_update_params={
|
||||
DatasetDocument.parsing_completed_at: naive_utc_now(),
|
||||
},
|
||||
@ -545,7 +546,7 @@ class IndexingRunner:
|
||||
Clean the document text according to the processing rules.
|
||||
"""
|
||||
rules: AutomaticRulesConfig | dict[str, Any]
|
||||
if processing_rule.mode == "automatic":
|
||||
if processing_rule.mode == ProcessRuleMode.AUTOMATIC:
|
||||
rules = DatasetProcessRule.AUTOMATIC_RULES
|
||||
else:
|
||||
rules = json.loads(processing_rule.rules) if processing_rule.rules else {}
|
||||
@ -636,7 +637,7 @@ class IndexingRunner:
|
||||
# update document status to completed
|
||||
self._update_document_index_status(
|
||||
document_id=dataset_document.id,
|
||||
after_indexing_status="completed",
|
||||
after_indexing_status=IndexingStatus.COMPLETED,
|
||||
extra_update_params={
|
||||
DatasetDocument.tokens: tokens,
|
||||
DatasetDocument.completed_at: naive_utc_now(),
|
||||
@ -659,10 +660,10 @@ class IndexingRunner:
|
||||
DocumentSegment.document_id == document_id,
|
||||
DocumentSegment.dataset_id == dataset_id,
|
||||
DocumentSegment.index_node_id.in_(document_ids),
|
||||
DocumentSegment.status == "indexing",
|
||||
DocumentSegment.status == SegmentStatus.INDEXING,
|
||||
).update(
|
||||
{
|
||||
DocumentSegment.status: "completed",
|
||||
DocumentSegment.status: SegmentStatus.COMPLETED,
|
||||
DocumentSegment.enabled: True,
|
||||
DocumentSegment.completed_at: naive_utc_now(),
|
||||
}
|
||||
@ -703,10 +704,10 @@ class IndexingRunner:
|
||||
DocumentSegment.document_id == dataset_document.id,
|
||||
DocumentSegment.dataset_id == dataset.id,
|
||||
DocumentSegment.index_node_id.in_(document_ids),
|
||||
DocumentSegment.status == "indexing",
|
||||
DocumentSegment.status == SegmentStatus.INDEXING,
|
||||
).update(
|
||||
{
|
||||
DocumentSegment.status: "completed",
|
||||
DocumentSegment.status: SegmentStatus.COMPLETED,
|
||||
DocumentSegment.enabled: True,
|
||||
DocumentSegment.completed_at: naive_utc_now(),
|
||||
}
|
||||
@ -725,7 +726,7 @@ class IndexingRunner:
|
||||
|
||||
@staticmethod
|
||||
def _update_document_index_status(
|
||||
document_id: str, after_indexing_status: str, extra_update_params: dict | None = None
|
||||
document_id: str, after_indexing_status: IndexingStatus, extra_update_params: dict | None = None
|
||||
):
|
||||
"""
|
||||
Update the document indexing status.
|
||||
@ -803,7 +804,7 @@ class IndexingRunner:
|
||||
cur_time = naive_utc_now()
|
||||
self._update_document_index_status(
|
||||
document_id=dataset_document.id,
|
||||
after_indexing_status="indexing",
|
||||
after_indexing_status=IndexingStatus.INDEXING,
|
||||
extra_update_params={
|
||||
DatasetDocument.cleaning_completed_at: cur_time,
|
||||
DatasetDocument.splitting_completed_at: cur_time,
|
||||
@ -815,7 +816,7 @@ class IndexingRunner:
|
||||
self._update_segments_by_document(
|
||||
dataset_document_id=dataset_document.id,
|
||||
update_params={
|
||||
DocumentSegment.status: "indexing",
|
||||
DocumentSegment.status: SegmentStatus.INDEXING,
|
||||
DocumentSegment.indexing_at: naive_utc_now(),
|
||||
},
|
||||
)
|
||||
|
||||
@ -196,6 +196,8 @@ class ProviderManager:
|
||||
|
||||
if preferred_provider_type_record:
|
||||
preferred_provider_type = ProviderType.value_of(preferred_provider_type_record.preferred_provider_type)
|
||||
elif dify_config.EDITION == "CLOUD" and system_configuration.enabled:
|
||||
preferred_provider_type = ProviderType.SYSTEM
|
||||
elif custom_configuration.provider or custom_configuration.models:
|
||||
preferred_provider_type = ProviderType.CUSTOM
|
||||
elif system_configuration.enabled:
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
from core.rag.data_post_processor.reorder import ReorderRunner
|
||||
from core.rag.index_processor.constant.query_type import QueryType
|
||||
@ -10,6 +12,26 @@ from dify_graph.model_runtime.entities.model_entities import ModelType
|
||||
from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
|
||||
|
||||
class RerankingModelDict(TypedDict):
|
||||
reranking_provider_name: str
|
||||
reranking_model_name: str
|
||||
|
||||
|
||||
class VectorSettingDict(TypedDict):
|
||||
vector_weight: float
|
||||
embedding_provider_name: str
|
||||
embedding_model_name: str
|
||||
|
||||
|
||||
class KeywordSettingDict(TypedDict):
|
||||
keyword_weight: float
|
||||
|
||||
|
||||
class WeightsDict(TypedDict):
|
||||
vector_setting: VectorSettingDict
|
||||
keyword_setting: KeywordSettingDict
|
||||
|
||||
|
||||
class DataPostProcessor:
|
||||
"""Interface for data post-processing document."""
|
||||
|
||||
@ -17,8 +39,8 @@ class DataPostProcessor:
|
||||
self,
|
||||
tenant_id: str,
|
||||
reranking_mode: str,
|
||||
reranking_model: dict | None = None,
|
||||
weights: dict | None = None,
|
||||
reranking_model: RerankingModelDict | None = None,
|
||||
weights: WeightsDict | None = None,
|
||||
reorder_enabled: bool = False,
|
||||
):
|
||||
self.rerank_runner = self._get_rerank_runner(reranking_mode, tenant_id, reranking_model, weights)
|
||||
@ -45,8 +67,8 @@ class DataPostProcessor:
|
||||
self,
|
||||
reranking_mode: str,
|
||||
tenant_id: str,
|
||||
reranking_model: dict | None = None,
|
||||
weights: dict | None = None,
|
||||
reranking_model: RerankingModelDict | None = None,
|
||||
weights: WeightsDict | None = None,
|
||||
) -> BaseRerankRunner | None:
|
||||
if reranking_mode == RerankMode.WEIGHTED_SCORE and weights:
|
||||
runner = RerankRunnerFactory.create_rerank_runner(
|
||||
@ -79,12 +101,14 @@ class DataPostProcessor:
|
||||
return ReorderRunner()
|
||||
return None
|
||||
|
||||
def _get_rerank_model_instance(self, tenant_id: str, reranking_model: dict | None) -> ModelInstance | None:
|
||||
def _get_rerank_model_instance(
|
||||
self, tenant_id: str, reranking_model: RerankingModelDict | None
|
||||
) -> ModelInstance | None:
|
||||
if reranking_model:
|
||||
try:
|
||||
model_manager = ModelManager()
|
||||
reranking_provider_name = reranking_model.get("reranking_provider_name")
|
||||
reranking_model_name = reranking_model.get("reranking_model_name")
|
||||
reranking_provider_name = reranking_model["reranking_provider_name"]
|
||||
reranking_model_name = reranking_model["reranking_model_name"]
|
||||
if not reranking_provider_name or not reranking_model_name:
|
||||
return None
|
||||
rerank_model_instance = model_manager.get_model_instance(
|
||||
|
||||
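With `RerankingModelDict` in place, the indexed access that replaces `.get()` in this hunk is checked statically; a minimal sketch (the provider and model names are illustrative):

```python
from typing import TypedDict


class RerankingModelDict(TypedDict):
    reranking_provider_name: str
    reranking_model_name: str


# Illustrative values only.
reranking_model: RerankingModelDict = {
    "reranking_provider_name": "cohere",
    "reranking_model_name": "rerank-english-v3.0",
}

# Both keys are required by the TypedDict, so indexing is safe and the
# earlier `.get(...) or ""` fallbacks become unnecessary.
provider = reranking_model["reranking_provider_name"]
model = reranking_model["reranking_model_name"]
```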
@ -1,19 +1,20 @@
|
||||
import concurrent.futures
|
||||
import logging
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Any
|
||||
from typing import Any, NotRequired
|
||||
|
||||
from flask import Flask, current_app
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session, load_only
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from configs import dify_config
|
||||
from core.db.session_factory import session_factory
|
||||
from core.model_manager import ModelManager
|
||||
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
|
||||
from core.rag.data_post_processor.data_post_processor import DataPostProcessor, RerankingModelDict, WeightsDict
|
||||
from core.rag.datasource.keyword.keyword_factory import Keyword
|
||||
from core.rag.datasource.vdb.vector_factory import Vector
|
||||
from core.rag.embedding.retrieval import RetrievalChildChunk, RetrievalSegments
|
||||
from core.rag.embedding.retrieval import AttachmentInfoDict, RetrievalChildChunk, RetrievalSegments
|
||||
from core.rag.entities.metadata_entities import MetadataCondition
|
||||
from core.rag.index_processor.constant.doc_type import DocType
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
@ -35,7 +36,46 @@ from models.dataset import Document as DatasetDocument
|
||||
from models.model import UploadFile
|
||||
from services.external_knowledge_service import ExternalDatasetService
|
||||
|
||||
default_retrieval_model = {
|
||||
|
||||
class SegmentAttachmentResult(TypedDict):
|
||||
attachment_info: AttachmentInfoDict
|
||||
segment_id: str
|
||||
|
||||
|
||||
class SegmentAttachmentInfoResult(TypedDict):
|
||||
attachment_id: str
|
||||
attachment_info: AttachmentInfoDict
|
||||
segment_id: str
|
||||
|
||||
|
||||
class ChildChunkDetail(TypedDict):
|
||||
id: str
|
||||
content: str
|
||||
position: int
|
||||
score: float
|
||||
|
||||
|
||||
class SegmentChildMapDetail(TypedDict):
|
||||
max_score: float
|
||||
child_chunks: list[ChildChunkDetail]
|
||||
|
||||
|
||||
class SegmentRecord(TypedDict):
|
||||
segment: DocumentSegment
|
||||
score: NotRequired[float]
|
||||
child_chunks: NotRequired[list[ChildChunkDetail]]
|
||||
files: NotRequired[list[AttachmentInfoDict]]
|
||||
|
||||
|
||||
class DefaultRetrievalModelDict(TypedDict):
|
||||
search_method: RetrievalMethod | str
|
||||
reranking_enable: bool
|
||||
reranking_model: RerankingModelDict
|
||||
top_k: int
|
||||
score_threshold_enabled: bool
|
||||
|
||||
|
||||
default_retrieval_model: DefaultRetrievalModelDict = {
|
||||
"search_method": RetrievalMethod.SEMANTIC_SEARCH,
|
||||
"reranking_enable": False,
|
||||
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
|
||||
@ -56,9 +96,9 @@ class RetrievalService:
|
||||
query: str,
|
||||
top_k: int = 4,
|
||||
score_threshold: float | None = 0.0,
|
||||
reranking_model: dict | None = None,
|
||||
reranking_model: RerankingModelDict | None = None,
|
||||
reranking_mode: str = "reranking_model",
|
||||
weights: dict | None = None,
|
||||
weights: WeightsDict | None = None,
|
||||
document_ids_filter: list[str] | None = None,
|
||||
attachment_ids: list | None = None,
|
||||
):
|
||||
@ -235,7 +275,7 @@ class RetrievalService:
|
||||
query: str,
|
||||
top_k: int,
|
||||
score_threshold: float | None,
|
||||
reranking_model: dict | None,
|
||||
reranking_model: RerankingModelDict | None,
|
||||
all_documents: list,
|
||||
retrieval_method: RetrievalMethod,
|
||||
exceptions: list,
|
||||
@ -277,8 +317,8 @@ class RetrievalService:
|
||||
if documents:
|
||||
if (
|
||||
reranking_model
|
||||
and reranking_model.get("reranking_model_name")
|
||||
and reranking_model.get("reranking_provider_name")
|
||||
and reranking_model["reranking_model_name"]
|
||||
and reranking_model["reranking_provider_name"]
|
||||
and retrieval_method == RetrievalMethod.SEMANTIC_SEARCH
|
||||
):
|
||||
data_post_processor = DataPostProcessor(
|
||||
@ -288,8 +328,8 @@ class RetrievalService:
|
||||
model_manager = ModelManager()
|
||||
is_support_vision = model_manager.check_model_support_vision(
|
||||
tenant_id=dataset.tenant_id,
|
||||
provider=reranking_model.get("reranking_provider_name") or "",
|
||||
model=reranking_model.get("reranking_model_name") or "",
|
||||
provider=reranking_model["reranking_provider_name"],
|
||||
model=reranking_model["reranking_model_name"],
|
||||
model_type=ModelType.RERANK,
|
||||
)
|
||||
if is_support_vision:
|
||||
@ -329,7 +369,7 @@ class RetrievalService:
|
||||
query: str,
|
||||
top_k: int,
|
||||
score_threshold: float | None,
|
||||
reranking_model: dict | None,
|
||||
reranking_model: RerankingModelDict | None,
|
||||
all_documents: list,
|
||||
retrieval_method: str,
|
||||
exceptions: list,
|
||||
@ -349,8 +389,8 @@ class RetrievalService:
|
||||
if documents:
|
||||
if (
|
||||
reranking_model
|
||||
and reranking_model.get("reranking_model_name")
|
||||
and reranking_model.get("reranking_provider_name")
|
||||
and reranking_model["reranking_model_name"]
|
||||
and reranking_model["reranking_provider_name"]
|
||||
and retrieval_method == RetrievalMethod.FULL_TEXT_SEARCH
|
||||
):
|
||||
data_post_processor = DataPostProcessor(
|
||||
@ -459,7 +499,7 @@ class RetrievalService:
|
||||
segment_ids: list[str] = []
|
||||
index_node_segments: list[DocumentSegment] = []
|
||||
segments: list[DocumentSegment] = []
|
||||
attachment_map: dict[str, list[dict[str, Any]]] = {}
|
||||
attachment_map: dict[str, list[AttachmentInfoDict]] = {}
|
||||
child_chunk_map: dict[str, list[ChildChunk]] = {}
|
||||
doc_segment_map: dict[str, list[str]] = {}
|
||||
segment_summary_map: dict[str, str] = {} # Map segment_id to summary content
|
||||
@ -544,12 +584,12 @@ class RetrievalService:
|
||||
segment_summary_map[summary.chunk_id] = summary.summary_content
|
||||
|
||||
include_segment_ids = set()
|
||||
segment_child_map: dict[str, dict[str, Any]] = {}
|
||||
records: list[dict[str, Any]] = []
|
||||
segment_child_map: dict[str, SegmentChildMapDetail] = {}
|
||||
records: list[SegmentRecord] = []
|
||||
|
||||
for segment in segments:
|
||||
child_chunks: list[ChildChunk] = child_chunk_map.get(segment.id, [])
|
||||
attachment_infos: list[dict[str, Any]] = attachment_map.get(segment.id, [])
|
||||
attachment_infos: list[AttachmentInfoDict] = attachment_map.get(segment.id, [])
|
||||
ds_dataset_document: DatasetDocument | None = valid_dataset_documents.get(segment.document_id)
|
||||
|
||||
if ds_dataset_document and ds_dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
|
||||
@ -560,14 +600,14 @@ class RetrievalService:
|
||||
max_score = summary_score_map.get(segment.id, 0.0)
|
||||
|
||||
if child_chunks or attachment_infos:
|
||||
child_chunk_details = []
|
||||
child_chunk_details: list[ChildChunkDetail] = []
|
||||
for child_chunk in child_chunks:
|
||||
child_document: Document | None = doc_to_document_map.get(child_chunk.index_node_id)
|
||||
if child_document:
|
||||
child_score = child_document.metadata.get("score", 0.0)
|
||||
else:
|
||||
child_score = 0.0
|
||||
child_chunk_detail = {
|
||||
child_chunk_detail: ChildChunkDetail = {
|
||||
"id": child_chunk.id,
|
||||
"content": child_chunk.content,
|
||||
"position": child_chunk.position,
|
||||
@ -580,7 +620,7 @@ class RetrievalService:
|
||||
if file_document:
|
||||
max_score = max(max_score, file_document.metadata.get("score", 0.0))
|
||||
|
||||
map_detail = {
|
||||
map_detail: SegmentChildMapDetail = {
|
||||
"max_score": max_score,
|
||||
"child_chunks": child_chunk_details,
|
||||
}
|
||||
@ -593,7 +633,7 @@ class RetrievalService:
|
||||
"max_score": summary_score,
|
||||
"child_chunks": [],
|
||||
}
|
||||
record: dict[str, Any] = {
|
||||
record: SegmentRecord = {
|
||||
"segment": segment,
|
||||
}
|
||||
records.append(record)
|
||||
@ -617,19 +657,19 @@ class RetrievalService:
|
||||
if file_doc:
|
||||
max_score = max(max_score, file_doc.metadata.get("score", 0.0))
|
||||
|
||||
record = {
|
||||
another_record: SegmentRecord = {
|
||||
"segment": segment,
|
||||
"score": max_score,
|
||||
}
|
||||
records.append(record)
|
||||
records.append(another_record)
|
||||
|
||||
# Add child chunks information to records
|
||||
for record in records:
|
||||
if record["segment"].id in segment_child_map:
|
||||
record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks") # type: ignore
|
||||
record["score"] = segment_child_map[record["segment"].id]["max_score"] # type: ignore
|
||||
record["child_chunks"] = segment_child_map[record["segment"].id]["child_chunks"]
|
||||
record["score"] = segment_child_map[record["segment"].id]["max_score"]
|
||||
if record["segment"].id in attachment_map:
|
||||
record["files"] = attachment_map[record["segment"].id] # type: ignore[assignment]
|
||||
record["files"] = attachment_map[record["segment"].id]
|
||||
|
||||
result: list[RetrievalSegments] = []
|
||||
for record in records:
|
||||
@ -693,9 +733,9 @@ class RetrievalService:
|
||||
query: str | None = None,
|
||||
top_k: int = 4,
|
||||
score_threshold: float | None = 0.0,
|
||||
reranking_model: dict | None = None,
|
||||
reranking_model: RerankingModelDict | None = None,
|
||||
reranking_mode: str = "reranking_model",
|
||||
weights: dict | None = None,
|
||||
weights: WeightsDict | None = None,
|
||||
document_ids_filter: list[str] | None = None,
|
||||
attachment_id: str | None = None,
|
||||
):
|
||||
@ -807,7 +847,7 @@ class RetrievalService:
|
||||
@classmethod
|
||||
def get_segment_attachment_info(
|
||||
cls, dataset_id: str, tenant_id: str, attachment_id: str, session: Session
|
||||
) -> dict[str, Any] | None:
|
||||
) -> SegmentAttachmentResult | None:
|
||||
upload_file = session.query(UploadFile).where(UploadFile.id == attachment_id).first()
|
||||
if upload_file:
|
||||
attachment_binding = (
|
||||
@ -816,7 +856,7 @@ class RetrievalService:
|
||||
.first()
|
||||
)
|
||||
if attachment_binding:
|
||||
attachment_info = {
|
||||
attachment_info: AttachmentInfoDict = {
|
||||
"id": upload_file.id,
|
||||
"name": upload_file.name,
|
||||
"extension": "." + upload_file.extension,
|
||||
@ -828,8 +868,10 @@ class RetrievalService:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_segment_attachment_infos(cls, attachment_ids: list[str], session: Session) -> list[dict[str, Any]]:
|
||||
attachment_infos = []
|
||||
def get_segment_attachment_infos(
|
||||
cls, attachment_ids: list[str], session: Session
|
||||
) -> list[SegmentAttachmentInfoResult]:
|
||||
attachment_infos: list[SegmentAttachmentInfoResult] = []
|
||||
upload_files = session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).all()
|
||||
if upload_files:
|
||||
upload_file_ids = [upload_file.id for upload_file in upload_files]
|
||||
@ -843,7 +885,7 @@ class RetrievalService:
|
||||
if attachment_bindings:
|
||||
for upload_file in upload_files:
|
||||
attachment_binding = attachment_binding_map.get(upload_file.id)
|
||||
attachment_info = {
|
||||
info: AttachmentInfoDict = {
|
||||
"id": upload_file.id,
|
||||
"name": upload_file.name,
|
||||
"extension": "." + upload_file.extension,
|
||||
@ -855,7 +897,7 @@ class RetrievalService:
|
||||
attachment_infos.append(
|
||||
{
|
||||
"attachment_id": attachment_binding.attachment_id,
|
||||
"attachment_info": attachment_info,
|
||||
"attachment_info": info,
|
||||
"segment_id": attachment_binding.segment_id,
|
||||
}
|
||||
)
|
||||
|
||||
@ -1,8 +1,18 @@
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from models.dataset import DocumentSegment
|
||||
|
||||
|
||||
class AttachmentInfoDict(TypedDict):
|
||||
id: str
|
||||
name: str
|
||||
extension: str
|
||||
mime_type: str
|
||||
source_url: str
|
||||
size: int
|
||||
|
||||
|
||||
class RetrievalChildChunk(BaseModel):
|
||||
"""Retrieval segments."""
|
||||
|
||||
@ -19,5 +29,5 @@ class RetrievalSegments(BaseModel):
|
||||
segment: DocumentSegment
|
||||
child_chunks: list[RetrievalChildChunk] | None = None
|
||||
score: float | None = None
|
||||
files: list[dict[str, str | int]] | None = None
|
||||
files: list[AttachmentInfoDict] | None = None
|
||||
summary: str | None = None # Summary content if retrieved via summary index
|
||||
|
||||
@ -9,6 +9,7 @@ from flask import current_app
|
||||
from sqlalchemy import delete, func, select
|
||||
|
||||
from core.db.session_factory import session_factory
|
||||
from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict
|
||||
from core.workflow.nodes.knowledge_index.exc import KnowledgeIndexNodeError
|
||||
from core.workflow.nodes.knowledge_index.protocols import Preview, PreviewItem, QaPreview
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
@ -51,7 +52,7 @@ class IndexProcessor:
|
||||
original_document_id: str,
|
||||
chunks: Mapping[str, Any],
|
||||
batch: Any,
|
||||
summary_index_setting: dict | None = None,
|
||||
summary_index_setting: SummaryIndexSettingDict | None = None,
|
||||
):
|
||||
with session_factory.create_session() as session:
|
||||
document = session.query(Document).filter_by(id=document_id).first()
|
||||
@ -131,7 +132,12 @@ class IndexProcessor:
|
||||
}
|
||||
|
||||
def get_preview_output(
|
||||
self, chunks: Any, dataset_id: str, document_id: str, chunk_structure: str, summary_index_setting: dict | None
|
||||
self,
|
||||
chunks: Any,
|
||||
dataset_id: str,
|
||||
document_id: str,
|
||||
chunk_structure: str,
|
||||
summary_index_setting: SummaryIndexSettingDict | None,
|
||||
) -> Preview:
|
||||
doc_language = None
|
||||
with session_factory.create_session() as session:
|
||||
|
||||
@ -7,14 +7,16 @@ import os
import re
from abc import ABC, abstractmethod
from collections.abc import Mapping
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any, NotRequired, Optional
from urllib.parse import unquote, urlparse

import httpx
from typing_extensions import TypedDict

from configs import dify_config
from core.entities.knowledge_entities import PreviewDetail
from core.helper import ssrf_proxy
from core.rag.data_post_processor.data_post_processor import RerankingModelDict
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.models.document import AttachmentDocument, Document
@ -35,6 +37,13 @@ if TYPE_CHECKING:
    from core.model_manager import ModelInstance


class SummaryIndexSettingDict(TypedDict):
    enable: bool
    model_name: NotRequired[str]
    model_provider_name: NotRequired[str]
    summary_prompt: NotRequired[str]


class BaseIndexProcessor(ABC):
    """Interface for extract files."""

@ -51,7 +60,7 @@ class BaseIndexProcessor(ABC):
        self,
        tenant_id: str,
        preview_texts: list[PreviewDetail],
        summary_index_setting: dict,
        summary_index_setting: SummaryIndexSettingDict,
        doc_language: str | None = None,
    ) -> list[PreviewDetail]:
        """
@ -98,7 +107,7 @@ class BaseIndexProcessor(ABC):
        dataset: Dataset,
        top_k: int,
        score_threshold: float,
        reranking_model: dict,
        reranking_model: RerankingModelDict,
    ) -> list[Document]:
        raise NotImplementedError

@ -14,6 +14,7 @@ from core.llm_generator.prompts import DEFAULT_GENERATOR_SUMMARY_PROMPT
from core.model_manager import ModelInstance
from core.provider_manager import ProviderManager
from core.rag.cleaner.clean_processor import CleanProcessor
from core.rag.data_post_processor.data_post_processor import RerankingModelDict
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.datasource.vdb.vector_factory import Vector
@ -22,7 +23,7 @@ from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.index_processor.index_processor_base import BaseIndexProcessor, SummaryIndexSettingDict
from core.rag.models.document import AttachmentDocument, Document, MultimodalGeneralStructureChunk
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.utils.text_processing_utils import remove_leading_symbols
@ -175,7 +176,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
        dataset: Dataset,
        top_k: int,
        score_threshold: float,
        reranking_model: dict,
        reranking_model: RerankingModelDict,
    ) -> list[Document]:
        # Set search parameters.
        results = RetrievalService.retrieve(
@ -278,7 +279,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
        self,
        tenant_id: str,
        preview_texts: list[PreviewDetail],
        summary_index_setting: dict,
        summary_index_setting: SummaryIndexSettingDict,
        doc_language: str | None = None,
    ) -> list[PreviewDetail]:
        """
@ -362,7 +363,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
    def generate_summary(
        tenant_id: str,
        text: str,
        summary_index_setting: dict | None = None,
        summary_index_setting: SummaryIndexSettingDict | None = None,
        segment_id: str | None = None,
        document_language: str | None = None,
    ) -> tuple[str, LLMUsage]:

@ -11,6 +11,7 @@ from core.db.session_factory import session_factory
from core.entities.knowledge_entities import PreviewDetail
from core.model_manager import ModelInstance
from core.rag.cleaner.clean_processor import CleanProcessor
from core.rag.data_post_processor.data_post_processor import RerankingModelDict
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
@ -18,7 +19,7 @@ from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.index_processor.index_processor_base import BaseIndexProcessor, SummaryIndexSettingDict
from core.rag.models.document import AttachmentDocument, ChildDocument, Document, ParentChildStructureChunk
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
@ -215,7 +216,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
        dataset: Dataset,
        top_k: int,
        score_threshold: float,
        reranking_model: dict,
        reranking_model: RerankingModelDict,
    ) -> list[Document]:
        # Set search parameters.
        results = RetrievalService.retrieve(
@ -361,7 +362,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
        self,
        tenant_id: str,
        preview_texts: list[PreviewDetail],
        summary_index_setting: dict,
        summary_index_setting: SummaryIndexSettingDict,
        doc_language: str | None = None,
    ) -> list[PreviewDetail]:
        """

@ -15,13 +15,14 @@ from core.db.session_factory import session_factory
from core.entities.knowledge_entities import PreviewDetail
from core.llm_generator.llm_generator import LLMGenerator
from core.rag.cleaner.clean_processor import CleanProcessor
from core.rag.data_post_processor.data_post_processor import RerankingModelDict
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.index_processor.index_processor_base import BaseIndexProcessor, SummaryIndexSettingDict
from core.rag.models.document import AttachmentDocument, Document, QAStructureChunk
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.utils.text_processing_utils import remove_leading_symbols
@ -185,7 +186,7 @@ class QAIndexProcessor(BaseIndexProcessor):
        dataset: Dataset,
        top_k: int,
        score_threshold: float,
        reranking_model: dict,
        reranking_model: RerankingModelDict,
    ):
        # Set search parameters.
        results = RetrievalService.retrieve(
@ -244,7 +245,7 @@ class QAIndexProcessor(BaseIndexProcessor):
        self,
        tenant_id: str,
        preview_texts: list[PreviewDetail],
        summary_index_setting: dict,
        summary_index_setting: SummaryIndexSettingDict,
        doc_language: str | None = None,
    ) -> list[PreviewDetail]:
        """

@ -31,7 +31,7 @@ from core.ops.utils import measure_time
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
from core.prompt.simple_prompt_transform import ModelMode
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
from core.rag.data_post_processor.data_post_processor import DataPostProcessor, RerankingModelDict, WeightsDict
from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
@ -83,7 +83,7 @@ from models.dataset import (
)
from models.dataset import Document as DatasetDocument
from models.dataset import Document as DocumentModel
from models.enums import CreatorUserRole
from models.enums import CreatorUserRole, DatasetQuerySource
from services.external_knowledge_service import ExternalDatasetService
from services.feature_service import FeatureService

@ -727,8 +727,8 @@ class DatasetRetrieval:
        top_k: int,
        score_threshold: float,
        reranking_mode: str,
        reranking_model: dict | None = None,
        weights: dict[str, Any] | None = None,
        reranking_model: RerankingModelDict | None = None,
        weights: WeightsDict | None = None,
        reranking_enable: bool = True,
        message_id: str | None = None,
        metadata_filter_document_ids: dict[str, list[str]] | None = None,
@ -1008,7 +1008,7 @@ class DatasetRetrieval:
        dataset_query = DatasetQuery(
            dataset_id=dataset_id,
            content=json.dumps(contents),
            source="app",
            source=DatasetQuerySource.APP,
            source_app_id=app_id,
            created_by_role=CreatorUserRole(user_from),
            created_by=user_id,
@ -1181,8 +1181,8 @@ class DatasetRetrieval:
                hit_callbacks=[hit_callback],
                return_resource=return_resource,
                retriever_from=invoke_from.to_source(),
                reranking_provider_name=retrieve_config.reranking_model.get("reranking_provider_name"),
                reranking_model_name=retrieve_config.reranking_model.get("reranking_model_name"),
                reranking_provider_name=retrieve_config.reranking_model["reranking_provider_name"],
                reranking_model_name=retrieve_config.reranking_model["reranking_model_name"],
            )

            tools.append(tool)
@ -1685,8 +1685,8 @@ class DatasetRetrieval:
        tenant_id: str,
        reranking_enable: bool,
        reranking_mode: str,
        reranking_model: dict | None,
        weights: dict[str, Any] | None,
        reranking_model: RerankingModelDict | None,
        weights: WeightsDict | None,
        top_k: int,
        score_threshold: float,
        query: str | None,

@ -2,6 +2,7 @@ import concurrent.futures
import logging

from core.db.session_factory import session_factory
from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict
from models.dataset import Dataset, Document, DocumentSegment, DocumentSegmentSummary
from services.summary_index_service import SummaryIndexService
from tasks.generate_summary_index_task import generate_summary_index_task
@ -11,7 +12,11 @@ logger = logging.getLogger(__name__)

class SummaryIndex:
    def generate_and_vectorize_summary(
        self, dataset_id: str, document_id: str, is_preview: bool, summary_index_setting: dict | None = None
        self,
        dataset_id: str,
        document_id: str,
        is_preview: bool,
        summary_index_setting: SummaryIndexSettingDict | None = None,
    ) -> None:
        if is_preview:
            with session_factory.create_session() as session:

@ -72,6 +72,11 @@ class ApiProviderControllerItem(TypedDict):
    controller: ApiToolProviderController


class EmojiIconDict(TypedDict):
    background: str
    content: str


class ToolManager:
    _builtin_provider_lock = Lock()
    _hardcoded_providers: dict[str, BuiltinToolProviderController] = {}
@ -916,7 +921,7 @@ class ToolManager:
        )

    @classmethod
    def generate_workflow_tool_icon_url(cls, tenant_id: str, provider_id: str) -> Mapping[str, str]:
    def generate_workflow_tool_icon_url(cls, tenant_id: str, provider_id: str) -> EmojiIconDict:
        try:
            workflow_provider: WorkflowToolProvider | None = (
                db.session.query(WorkflowToolProvider)
@ -933,7 +938,7 @@ class ToolManager:
            return {"background": "#252525", "content": "\ud83d\ude01"}

    @classmethod
    def generate_api_tool_icon_url(cls, tenant_id: str, provider_id: str) -> Mapping[str, str]:
    def generate_api_tool_icon_url(cls, tenant_id: str, provider_id: str) -> EmojiIconDict:
        try:
            api_provider: ApiToolProvider | None = (
                db.session.query(ApiToolProvider)
@ -950,7 +955,7 @@ class ToolManager:
            return {"background": "#252525", "content": "\ud83d\ude01"}

    @classmethod
    def generate_mcp_tool_icon_url(cls, tenant_id: str, provider_id: str) -> Mapping[str, str] | str:
    def generate_mcp_tool_icon_url(cls, tenant_id: str, provider_id: str) -> EmojiIconDict | dict[str, str] | str:
        try:
            with Session(db.engine) as session:
                mcp_service = MCPToolManageService(session=session)
@ -970,7 +975,7 @@ class ToolManager:
        tenant_id: str,
        provider_type: ToolProviderType,
        provider_id: str,
    ) -> str | Mapping[str, str]:
    ) -> str | EmojiIconDict | dict[str, str]:
        """
        get the tool icon


@ -1,5 +1,4 @@
import threading
from typing import Any

from flask import Flask, current_app
from pydantic import BaseModel, Field
@ -13,11 +12,12 @@ from core.rag.models.document import Document as RagDocument
from core.rag.rerank.rerank_model import RerankModelRunner
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
from core.tools.utils.dataset_retriever.dataset_retriever_tool import DefaultRetrievalModelDict
from dify_graph.model_runtime.entities.model_entities import ModelType
from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment

default_retrieval_model: dict[str, Any] = {
default_retrieval_model: DefaultRetrievalModelDict = {
    "search_method": RetrievalMethod.SEMANTIC_SEARCH,
    "reranking_enable": False,
    "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},

@ -1,9 +1,10 @@
from typing import Any, cast
from typing import NotRequired, TypedDict, cast

from pydantic import BaseModel, Field
from sqlalchemy import select

from core.app.app_config.entities import DatasetRetrieveConfigEntity, ModelConfig
from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.rag.entities.context_entities import DocumentContext
@ -16,7 +17,19 @@ from models.dataset import Dataset
from models.dataset import Document as DatasetDocument
from services.external_knowledge_service import ExternalDatasetService

default_retrieval_model: dict[str, Any] = {

class DefaultRetrievalModelDict(TypedDict):
    search_method: RetrievalMethod
    reranking_enable: bool
    reranking_model: RerankingModelDict
    reranking_mode: NotRequired[str]
    weights: NotRequired[WeightsDict | None]
    score_threshold: NotRequired[float]
    top_k: int
    score_threshold_enabled: bool


default_retrieval_model: DefaultRetrievalModelDict = {
    "search_method": RetrievalMethod.SEMANTIC_SEARCH,
    "reranking_enable": False,
    "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
@ -125,7 +138,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
        if metadata_condition and not document_ids_filter:
            return ""
        # get retrieval model , if the model is not setting , using default
        retrieval_model: dict[str, Any] = dataset.retrieval_model or default_retrieval_model
        retrieval_model = dataset.retrieval_model or default_retrieval_model
        retrieval_resource_list: list[RetrievalSourceMetadata] = []
        if dataset.indexing_technique == "economy":
            # use keyword table query

@ -1,4 +1,5 @@
import re
from collections.abc import Mapping
from json import dumps as json_dumps
from json import loads as json_loads
from json.decoder import JSONDecodeError
@ -20,10 +21,18 @@ class InterfaceDict(TypedDict):
    operation: dict[str, Any]


class OpenAPISpecDict(TypedDict):
    openapi: str
    info: dict[str, str]
    servers: list[dict[str, Any]]
    paths: dict[str, Any]
    components: dict[str, Any]


class ApiBasedToolSchemaParser:
    @staticmethod
    def parse_openapi_to_tool_bundle(
        openapi: dict, extra_info: dict | None = None, warning: dict | None = None
        openapi: Mapping[str, Any], extra_info: dict | None = None, warning: dict | None = None
    ) -> list[ApiToolBundle]:
        warning = warning if warning is not None else {}
        extra_info = extra_info if extra_info is not None else {}
@ -277,7 +286,7 @@ class ApiBasedToolSchemaParser:
    @staticmethod
    def parse_swagger_to_openapi(
        swagger: dict, extra_info: dict | None = None, warning: dict | None = None
    ) -> dict[str, Any]:
    ) -> OpenAPISpecDict:
        warning = warning or {}
        """
        parse swagger to openapi
@ -293,7 +302,7 @@ class ApiBasedToolSchemaParser:
        if len(servers) == 0:
            raise ToolApiSchemaError("No server found in the swagger yaml.")

        converted_openapi: dict[str, Any] = {
        converted_openapi: OpenAPISpecDict = {
            "openapi": "3.0.0",
            "info": {
                "title": info.get("title", "Swagger"),

@ -2,6 +2,7 @@ from typing import Literal, Union

from pydantic import BaseModel

from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE
from dify_graph.entities.base_node_data import BaseNodeData
@ -161,4 +162,4 @@ class KnowledgeIndexNodeData(BaseNodeData):
    chunk_structure: str
    index_chunk_variable_selector: list[str]
    indexing_technique: str | None = None
    summary_index_setting: dict | None = None
    summary_index_setting: SummaryIndexSettingDict | None = None

@ -3,6 +3,7 @@ from collections.abc import Mapping
from typing import TYPE_CHECKING, Any

from core.rag.index_processor.index_processor import IndexProcessor
from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict
from core.rag.summary_index.summary_index import SummaryIndex
from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE
from dify_graph.entities.graph_config import NodeConfigDict
@ -127,7 +128,7 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
        is_preview: bool,
        batch: Any,
        chunks: Mapping[str, Any],
        summary_index_setting: dict | None = None,
        summary_index_setting: SummaryIndexSettingDict | None = None,
    ):
        if not document_id:
            raise KnowledgeIndexNodeError("document_id is required.")

@ -9,6 +9,7 @@ from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, Literal

from core.app.app_config.entities import DatasetRetrieveConfigEntity
from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from dify_graph.entities import GraphInitParams
from dify_graph.entities.graph_config import NodeConfigDict
@ -201,8 +202,8 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
        elif str(node_data.retrieval_mode) == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
            if node_data.multiple_retrieval_config is None:
                raise ValueError("multiple_retrieval_config is required")
            reranking_model = None
            weights = None
            reranking_model: RerankingModelDict | None = None
            weights: WeightsDict | None = None
            match node_data.multiple_retrieval_config.reranking_mode:
                case "reranking_model":
                    if node_data.multiple_retrieval_config.reranking_model:

@ -2,6 +2,7 @@ from typing import Any, Literal, Protocol

from pydantic import BaseModel, Field

from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict
from dify_graph.model_runtime.entities import LLMUsage
from dify_graph.nodes.llm.entities import ModelConfig

@ -75,8 +76,8 @@ class KnowledgeRetrievalRequest(BaseModel):
    top_k: int = Field(default=0, description="Number of top results to return")
    score_threshold: float = Field(default=0.0, description="Minimum relevance score threshold")
    reranking_mode: str = Field(default="reranking_model", description="Reranking strategy")
    reranking_model: dict | None = Field(default=None, description="Reranking model configuration")
    weights: dict[str, Any] | None = Field(default=None, description="Weights for weighted score reranking")
    reranking_model: RerankingModelDict | None = Field(default=None, description="Reranking model configuration")
    weights: WeightsDict | None = Field(default=None, description="Weights for weighted score reranking")
    reranking_enable: bool = Field(default=True, description="Whether reranking is enabled")
    attachment_ids: list[str] | None = Field(default=None, description="List of attachment file IDs for retrieval")

@ -101,7 +101,6 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
                timeout=self._get_request_timeout(self.node_data),
                variable_pool=self.graph_runtime_state.variable_pool,
                http_request_config=self._http_request_config,
                max_retries=0,
                ssl_verify=self.node_data.ssl_verify,
                http_client=self._http_client,
                file_manager=self._file_manager,

@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Any
from dify_graph.file.models import File

if TYPE_CHECKING:
    pass
    from dify_graph.variables.segments import Segment


class ArrayValidation(StrEnum):
@ -219,7 +219,7 @@ class SegmentType(StrEnum):
        return _ARRAY_ELEMENT_TYPES_MAPPING.get(self)

    @staticmethod
    def get_zero_value(t: SegmentType):
    def get_zero_value(t: SegmentType) -> Segment:
        # Lazy import to avoid circular dependency
        from factories import variable_factory


@ -10,6 +10,7 @@ from events.document_index_event import document_index_created
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.dataset import Document
from models.enums import IndexingStatus

logger = logging.getLogger(__name__)

@ -35,7 +36,7 @@ def handle(sender, **kwargs):
        if not document:
            raise NotFound("Document not found")

        document.indexing_status = "parsing"
        document.indexing_status = IndexingStatus.PARSING
        document.processing_started_at = naive_utc_now()
        documents.append(document)
        db.session.add(document)

@ -1,3 +1,5 @@
from typing import Protocol, cast

from fastopenapi.routers import FlaskRouter
from flask_cors import CORS

@ -9,6 +11,10 @@ from extensions.ext_blueprints import AUTHENTICATED_HEADERS, EXPOSED_HEADERS
DOCS_PREFIX = "/fastopenapi"


class SupportsIncludeRouter(Protocol):
    def include_router(self, router: object, *, prefix: str = "") -> None: ...


def init_app(app: DifyApp) -> None:
    docs_enabled = dify_config.SWAGGER_UI_ENABLED
    docs_url = f"{DOCS_PREFIX}/docs" if docs_enabled else None
@ -36,7 +42,7 @@ def init_app(app: DifyApp) -> None:
    _ = remote_files
    _ = setup

    router.include_router(console_router, prefix="/console/api")
    cast(SupportsIncludeRouter, router).include_router(console_router, prefix="/console/api")
    CORS(
        app,
        resources={r"/console/api/.*": {"origins": dify_config.CONSOLE_CORS_ALLOW_ORIGINS}},

@ -5,7 +5,7 @@ from typing import Union

from celery.signals import worker_init
from flask_login import user_loaded_from_request, user_logged_in
from opentelemetry import trace
from opentelemetry import metrics, trace
from opentelemetry.propagate import set_global_textmap
from opentelemetry.propagators.b3 import B3MultiFormat
from opentelemetry.propagators.composite import CompositePropagator
@ -31,9 +31,29 @@ def setup_context_propagation() -> None:


def shutdown_tracer() -> None:
    flush_telemetry()


def flush_telemetry() -> None:
    """
    Best-effort flush for telemetry providers.

    This is mainly used by short-lived command processes (e.g. Kubernetes CronJob)
    so counters/histograms are exported before the process exits.
    """
    provider = trace.get_tracer_provider()
    if hasattr(provider, "force_flush"):
        provider.force_flush()
        try:
            provider.force_flush()
        except Exception:
            logger.exception("otel: failed to flush trace provider")

    metric_provider = metrics.get_meter_provider()
    if hasattr(metric_provider, "force_flush"):
        try:
            metric_provider.force_flush()
        except Exception:
            logger.exception("otel: failed to flush metric provider")


def is_celery_worker():

@ -55,7 +55,7 @@ class TypeMismatchError(Exception):


# Define the constant
SEGMENT_TO_VARIABLE_MAP = {
SEGMENT_TO_VARIABLE_MAP: Mapping[type[Segment], type[VariableBase]] = {
    ArrayAnySegment: ArrayAnyVariable,
    ArrayBooleanSegment: ArrayBooleanVariable,
    ArrayFileSegment: ArrayFileVariable,
@ -296,13 +296,11 @@ def segment_to_variable(
        raise UnsupportedSegmentTypeError(f"not supported segment type {segment_type}")

    variable_class = SEGMENT_TO_VARIABLE_MAP[segment_type]
    return cast(
        VariableBase,
        variable_class(
            id=id,
            name=name,
            description=description,
            value=segment.value,
            selector=list(selector),
        ),
    return variable_class(
        id=id,
        name=name,
        description=description,
        value_type=segment.value_type,
        value=segment.value,
        selector=list(selector),
    )

@ -32,6 +32,11 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)


def _stream_with_request_context(response: object) -> Any:
    """Bridge Flask's loosely-typed streaming helper without leaking casts into callers."""
    return cast(Any, stream_with_context)(response)


def escape_like_pattern(pattern: str) -> str:
    """
    Escape special characters in a string for safe use in SQL LIKE patterns.
@ -286,22 +291,32 @@ def generate_text_hash(text: str) -> str:
    return sha256(hash_text.encode()).hexdigest()


def compact_generate_response(response: Union[Mapping, Generator, RateLimitGenerator]) -> Response:
    if isinstance(response, dict):
def compact_generate_response(
    response: Mapping[str, Any] | Generator[str, None, None] | RateLimitGenerator,
) -> Response:
    if isinstance(response, Mapping):
        return Response(
            response=json.dumps(jsonable_encoder(response)),
            status=200,
            content_type="application/json; charset=utf-8",
        )
    else:
        stream_response = response

        def generate() -> Generator:
            yield from response
        def generate() -> Generator[str, None, None]:
            yield from stream_response

        return Response(stream_with_context(generate()), status=200, mimetype="text/event-stream")
        return Response(
            _stream_with_request_context(generate()),
            status=200,
            mimetype="text/event-stream",
        )


def length_prefixed_response(magic_number: int, response: Union[Mapping, Generator, RateLimitGenerator]) -> Response:
def length_prefixed_response(
    magic_number: int,
    response: Mapping[str, Any] | BaseModel | Generator[str | bytes, None, None] | RateLimitGenerator,
) -> Response:
    """
    This function is used to return a response with a length prefix.
    Magic number is a one byte number that indicates the type of the response.
@ -332,7 +347,7 @@ def length_prefixed_response(magic_number: int, response: Union[Mapping, Generat
        # | Magic Number 1byte | Reserved 1byte | Header Length 2bytes | Data Length 4bytes | Reserved 6bytes | Data
        return struct.pack("<BBHI", magic_number, 0, header_length, data_length) + b"\x00" * 6 + response

    if isinstance(response, dict):
    if isinstance(response, Mapping):
        return Response(
            response=pack_response_with_length_prefix(json.dumps(jsonable_encoder(response)).encode("utf-8")),
            status=200,
@ -345,14 +360,20 @@ def length_prefixed_response(magic_number: int, response: Union[Mapping, Generat
            mimetype="application/json",
        )

    def generate() -> Generator:
        for chunk in response:
    stream_response = response

    def generate() -> Generator[bytes, None, None]:
        for chunk in stream_response:
            if isinstance(chunk, str):
                yield pack_response_with_length_prefix(chunk.encode("utf-8"))
            else:
                yield pack_response_with_length_prefix(chunk)

    return Response(stream_with_context(generate()), status=200, mimetype="text/event-stream")
    return Response(
        _stream_with_request_context(generate()),
        status=200,
        mimetype="text/event-stream",
    )


class TokenManager:

@ -77,12 +77,14 @@ def login_required(func: Callable[P, R]) -> Callable[P, R | ResponseReturnValue]
    @wraps(func)
    def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R | ResponseReturnValue:
        if request.method in EXEMPT_METHODS or dify_config.LOGIN_DISABLED:
            pass
        elif current_user is not None and not current_user.is_authenticated:
            return current_app.ensure_sync(func)(*args, **kwargs)

        user = _get_user()
        if user is None or not user.is_authenticated:
            return current_app.login_manager.unauthorized()  # type: ignore
        # we put csrf validation here for less conflicts
        # TODO: maybe find a better place for it.
        check_csrf_token(request, current_user.id)
        check_csrf_token(request, user.id)
        return current_app.ensure_sync(func)(*args, **kwargs)

    return decorated_view

@ -7,9 +7,10 @@ https://github.com/django/django/blob/main/django/utils/module_loading.py

import sys
from importlib import import_module
from typing import Any


def cached_import(module_path: str, class_name: str):
def cached_import(module_path: str, class_name: str) -> Any:
    """
    Import a module and return the named attribute/class from it, with caching.

@ -20,16 +21,14 @@ def cached_import(module_path: str, class_name: str):
    Returns:
        The imported attribute/class
    """
    if not (
        (module := sys.modules.get(module_path))
        and (spec := getattr(module, "__spec__", None))
        and getattr(spec, "_initializing", False) is False
    ):
    module = sys.modules.get(module_path)
    spec = getattr(module, "__spec__", None) if module is not None else None
    if module is None or getattr(spec, "_initializing", False):
        module = import_module(module_path)
    return getattr(module, class_name)


def import_string(dotted_path: str):
def import_string(dotted_path: str) -> Any:
    """
    Import a dotted module path and return the attribute/class designated by
    the last name in the path. Raise ImportError if the import failed.

@ -1,7 +1,48 @@
import sys
import urllib.parse
from dataclasses import dataclass
from typing import NotRequired

import httpx
from pydantic import TypeAdapter

if sys.version_info >= (3, 12):
    from typing import TypedDict
else:
    from typing_extensions import TypedDict

JsonObject = dict[str, object]
JsonObjectList = list[JsonObject]

JSON_OBJECT_ADAPTER = TypeAdapter(JsonObject)
JSON_OBJECT_LIST_ADAPTER = TypeAdapter(JsonObjectList)


class AccessTokenResponse(TypedDict, total=False):
    access_token: str


class GitHubEmailRecord(TypedDict, total=False):
    email: str
    primary: bool


class GitHubRawUserInfo(TypedDict):
    id: int | str
    login: str
    name: NotRequired[str]
    email: NotRequired[str]


class GoogleRawUserInfo(TypedDict):
    sub: str
    email: str


ACCESS_TOKEN_RESPONSE_ADAPTER = TypeAdapter(AccessTokenResponse)
GITHUB_RAW_USER_INFO_ADAPTER = TypeAdapter(GitHubRawUserInfo)
GITHUB_EMAIL_RECORDS_ADAPTER = TypeAdapter(list[GitHubEmailRecord])
GOOGLE_RAW_USER_INFO_ADAPTER = TypeAdapter(GoogleRawUserInfo)


@dataclass
@ -11,26 +52,38 @@ class OAuthUserInfo:
    email: str


def _json_object(response: httpx.Response) -> JsonObject:
    return JSON_OBJECT_ADAPTER.validate_python(response.json())


def _json_list(response: httpx.Response) -> JsonObjectList:
    return JSON_OBJECT_LIST_ADAPTER.validate_python(response.json())


class OAuth:
    client_id: str
    client_secret: str
    redirect_uri: str

    def __init__(self, client_id: str, client_secret: str, redirect_uri: str):
        self.client_id = client_id
        self.client_secret = client_secret
        self.redirect_uri = redirect_uri

    def get_authorization_url(self):
    def get_authorization_url(self, invite_token: str | None = None) -> str:
        raise NotImplementedError()

    def get_access_token(self, code: str):
    def get_access_token(self, code: str) -> str:
        raise NotImplementedError()

    def get_raw_user_info(self, token: str):
    def get_raw_user_info(self, token: str) -> JsonObject:
        raise NotImplementedError()

    def get_user_info(self, token: str) -> OAuthUserInfo:
        raw_info = self.get_raw_user_info(token)
        return self._transform_user_info(raw_info)

    def _transform_user_info(self, raw_info: dict) -> OAuthUserInfo:
    def _transform_user_info(self, raw_info: JsonObject) -> OAuthUserInfo:
        raise NotImplementedError()


@ -40,7 +93,7 @@ class GitHubOAuth(OAuth):
    _USER_INFO_URL = "https://api.github.com/user"
    _EMAIL_INFO_URL = "https://api.github.com/user/emails"

    def get_authorization_url(self, invite_token: str | None = None):
    def get_authorization_url(self, invite_token: str | None = None) -> str:
        params = {
            "client_id": self.client_id,
            "redirect_uri": self.redirect_uri,
@ -50,7 +103,7 @@ class GitHubOAuth(OAuth):
            params["state"] = invite_token
        return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}"

    def get_access_token(self, code: str):
    def get_access_token(self, code: str) -> str:
        data = {
            "client_id": self.client_id,
            "client_secret": self.client_secret,
@ -60,7 +113,7 @@ class GitHubOAuth(OAuth):
        headers = {"Accept": "application/json"}
        response = httpx.post(self._TOKEN_URL, data=data, headers=headers)

        response_json = response.json()
        response_json = ACCESS_TOKEN_RESPONSE_ADAPTER.validate_python(_json_object(response))
        access_token = response_json.get("access_token")

        if not access_token:
@ -68,23 +121,24 @@ class GitHubOAuth(OAuth):

        return access_token

    def get_raw_user_info(self, token: str):
    def get_raw_user_info(self, token: str) -> JsonObject:
        headers = {"Authorization": f"token {token}"}
        response = httpx.get(self._USER_INFO_URL, headers=headers)
        response.raise_for_status()
        user_info = response.json()
        user_info = GITHUB_RAW_USER_INFO_ADAPTER.validate_python(_json_object(response))

        email_response = httpx.get(self._EMAIL_INFO_URL, headers=headers)
        email_info = email_response.json()
        primary_email: dict = next((email for email in email_info if email["primary"] == True), {})
        email_info = GITHUB_EMAIL_RECORDS_ADAPTER.validate_python(_json_list(email_response))
        primary_email = next((email for email in email_info if email.get("primary") is True), None)

        return {**user_info, "email": primary_email.get("email", "")}
        return {**user_info, "email": primary_email.get("email", "") if primary_email else ""}

    def _transform_user_info(self, raw_info: dict) -> OAuthUserInfo:
        email = raw_info.get("email")
    def _transform_user_info(self, raw_info: JsonObject) -> OAuthUserInfo:
        payload = GITHUB_RAW_USER_INFO_ADAPTER.validate_python(raw_info)
        email = payload.get("email")
        if not email:
            email = f"{raw_info['id']}+{raw_info['login']}@users.noreply.github.com"
        return OAuthUserInfo(id=str(raw_info["id"]), name=raw_info["name"], email=email)
            email = f"{payload['id']}+{payload['login']}@users.noreply.github.com"
        return OAuthUserInfo(id=str(payload["id"]), name=str(payload.get("name", "")), email=email)


class GoogleOAuth(OAuth):
@ -92,7 +146,7 @@ class GoogleOAuth(OAuth):
    _TOKEN_URL = "https://oauth2.googleapis.com/token"
    _USER_INFO_URL = "https://www.googleapis.com/oauth2/v3/userinfo"

    def get_authorization_url(self, invite_token: str | None = None):
    def get_authorization_url(self, invite_token: str | None = None) -> str:
        params = {
            "client_id": self.client_id,
            "response_type": "code",
@ -103,7 +157,7 @@ class GoogleOAuth(OAuth):
            params["state"] = invite_token
        return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}"

    def get_access_token(self, code: str):
    def get_access_token(self, code: str) -> str:
        data = {
            "client_id": self.client_id,
            "client_secret": self.client_secret,
@ -114,7 +168,7 @@ class GoogleOAuth(OAuth):
        headers = {"Accept": "application/json"}
        response = httpx.post(self._TOKEN_URL, data=data, headers=headers)

        response_json = response.json()
        response_json = ACCESS_TOKEN_RESPONSE_ADAPTER.validate_python(_json_object(response))
        access_token = response_json.get("access_token")

        if not access_token:
@ -122,11 +176,12 @@ class GoogleOAuth(OAuth):

        return access_token

    def get_raw_user_info(self, token: str):
    def get_raw_user_info(self, token: str) -> JsonObject:
        headers = {"Authorization": f"Bearer {token}"}
        response = httpx.get(self._USER_INFO_URL, headers=headers)
        response.raise_for_status()
        return response.json()
        return _json_object(response)

    def _transform_user_info(self, raw_info: dict) -> OAuthUserInfo:
        return OAuthUserInfo(id=str(raw_info["sub"]), name="", email=raw_info["email"])
    def _transform_user_info(self, raw_info: JsonObject) -> OAuthUserInfo:
        payload = GOOGLE_RAW_USER_INFO_ADAPTER.validate_python(raw_info)
        return OAuthUserInfo(id=str(payload["sub"]), name="", email=payload["email"])

@ -1,25 +1,57 @@
import sys
import urllib.parse
from typing import Any
from typing import Any, Literal

import httpx
from flask_login import current_user
from pydantic import TypeAdapter
from sqlalchemy import select

from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.source import DataSourceOauthBinding

if sys.version_info >= (3, 12):
    from typing import TypedDict
else:
    from typing_extensions import TypedDict


class NotionPageSummary(TypedDict):
    page_id: str
    page_name: str
    page_icon: dict[str, str] | None
    parent_id: str
    type: Literal["page", "database"]


class NotionSourceInfo(TypedDict):
    workspace_name: str | None
    workspace_icon: str | None
    workspace_id: str | None
    pages: list[NotionPageSummary]
    total: int


SOURCE_INFO_STORAGE_ADAPTER = TypeAdapter(dict[str, object])
NOTION_SOURCE_INFO_ADAPTER = TypeAdapter(NotionSourceInfo)
NOTION_PAGE_SUMMARY_ADAPTER = TypeAdapter(NotionPageSummary)


class OAuthDataSource:
    client_id: str
    client_secret: str
    redirect_uri: str

    def __init__(self, client_id: str, client_secret: str, redirect_uri: str):
        self.client_id = client_id
        self.client_secret = client_secret
        self.redirect_uri = redirect_uri

    def get_authorization_url(self):
    def get_authorization_url(self) -> str:
        raise NotImplementedError()

    def get_access_token(self, code: str):
    def get_access_token(self, code: str) -> None:
        raise NotImplementedError()


@ -30,7 +62,7 @@ class NotionOAuth(OAuthDataSource):
    _NOTION_BLOCK_SEARCH = "https://api.notion.com/v1/blocks"
    _NOTION_BOT_USER = "https://api.notion.com/v1/users/me"

    def get_authorization_url(self):
    def get_authorization_url(self) -> str:
        params = {
            "client_id": self.client_id,
            "response_type": "code",
@ -39,7 +71,7 @@ class NotionOAuth(OAuthDataSource):
        }
        return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}"

    def get_access_token(self, code: str):
    def get_access_token(self, code: str) -> None:
        data = {"code": code, "grant_type": "authorization_code", "redirect_uri": self.redirect_uri}
        headers = {"Accept": "application/json"}
        auth = (self.client_id, self.client_secret)
@ -54,13 +86,12 @@ class NotionOAuth(OAuthDataSource):
        workspace_id = response_json.get("workspace_id")
        # get all authorized pages
        pages = self.get_authorized_pages(access_token)
        source_info = {
            "workspace_name": workspace_name,
            "workspace_icon": workspace_icon,
            "workspace_id": workspace_id,
            "pages": pages,
            "total": len(pages),
        }
        source_info = self._build_source_info(
            workspace_name=workspace_name,
            workspace_icon=workspace_icon,
            workspace_id=workspace_id,
            pages=pages,
        )
        # save data source binding
        data_source_binding = db.session.scalar(
            select(DataSourceOauthBinding).where(
@ -70,7 +101,7 @@ class NotionOAuth(OAuthDataSource):
            )
        )
        if data_source_binding:
            data_source_binding.source_info = source_info
            data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info)
            data_source_binding.disabled = False
            data_source_binding.updated_at = naive_utc_now()
            db.session.commit()
@ -78,25 +109,24 @@ class NotionOAuth(OAuthDataSource):
            new_data_source_binding = DataSourceOauthBinding(
                tenant_id=current_user.current_tenant_id,
                access_token=access_token,
                source_info=source_info,
                source_info=SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info),
                provider="notion",
            )
            db.session.add(new_data_source_binding)
            db.session.commit()

    def save_internal_access_token(self, access_token: str):
    def save_internal_access_token(self, access_token: str) -> None:
        workspace_name = self.notion_workspace_name(access_token)
        workspace_icon = None
        workspace_id = current_user.current_tenant_id
        # get all authorized pages
        pages = self.get_authorized_pages(access_token)
        source_info = {
            "workspace_name": workspace_name,
            "workspace_icon": workspace_icon,
            "workspace_id": workspace_id,
            "pages": pages,
            "total": len(pages),
        }
        source_info = self._build_source_info(
            workspace_name=workspace_name,
            workspace_icon=workspace_icon,
            workspace_id=workspace_id,
            pages=pages,
        )
        # save data source binding
        data_source_binding = db.session.scalar(
            select(DataSourceOauthBinding).where(
@ -106,7 +136,7 @@ class NotionOAuth(OAuthDataSource):
            )
        )
        if data_source_binding:
            data_source_binding.source_info = source_info
            data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info)
            data_source_binding.disabled = False
            data_source_binding.updated_at = naive_utc_now()
            db.session.commit()
@ -114,13 +144,13 @@ class NotionOAuth(OAuthDataSource):
            new_data_source_binding = DataSourceOauthBinding(
                tenant_id=current_user.current_tenant_id,
                access_token=access_token,
                source_info=source_info,
                source_info=SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info),
                provider="notion",
            )
            db.session.add(new_data_source_binding)
            db.session.commit()

    def sync_data_source(self, binding_id: str):
    def sync_data_source(self, binding_id: str) -> None:
        # save data source binding
        data_source_binding = db.session.scalar(
            select(DataSourceOauthBinding).where(
@ -134,23 +164,22 @@ class NotionOAuth(OAuthDataSource):
        if data_source_binding:
            # get all authorized pages
            pages = self.get_authorized_pages(data_source_binding.access_token)
            source_info = data_source_binding.source_info
            new_source_info = {
                "workspace_name": source_info["workspace_name"],
                "workspace_icon": source_info["workspace_icon"],
                "workspace_id": source_info["workspace_id"],
                "pages": pages,
                "total": len(pages),
            }
            data_source_binding.source_info = new_source_info
            source_info = NOTION_SOURCE_INFO_ADAPTER.validate_python(data_source_binding.source_info)
            new_source_info = self._build_source_info(
                workspace_name=source_info["workspace_name"],
                workspace_icon=source_info["workspace_icon"],
                workspace_id=source_info["workspace_id"],
                pages=pages,
            )
            data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(new_source_info)
            data_source_binding.disabled = False
            data_source_binding.updated_at = naive_utc_now()
            db.session.commit()
        else:
            raise ValueError("Data source binding not found")

    def get_authorized_pages(self, access_token: str):
        pages = []
    def get_authorized_pages(self, access_token: str) -> list[NotionPageSummary]:
        pages: list[NotionPageSummary] = []
        page_results = self.notion_page_search(access_token)
        database_results = self.notion_database_search(access_token)
        # get page detail
@ -187,7 +216,7 @@ class NotionOAuth(OAuthDataSource):
                "parent_id": parent_id,
                "type": "page",
            }
            pages.append(page)
            pages.append(NOTION_PAGE_SUMMARY_ADAPTER.validate_python(page))
        # get database detail
        for database_result in database_results:
            page_id = database_result["id"]
@ -220,11 +249,11 @@ class NotionOAuth(OAuthDataSource):
                "parent_id": parent_id,
                "type": "database",
            }
            pages.append(page)
            pages.append(NOTION_PAGE_SUMMARY_ADAPTER.validate_python(page))
        return pages

    def notion_page_search(self, access_token: str):
        results = []
    def notion_page_search(self, access_token: str) -> list[dict[str, Any]]:
        results: list[dict[str, Any]] = []
        next_cursor = None
        has_more = True

@ -249,7 +278,7 @@ class NotionOAuth(OAuthDataSource):

        return results

    def notion_block_parent_page_id(self, access_token: str, block_id: str):
    def notion_block_parent_page_id(self, access_token: str, block_id: str) -> str:
        headers = {
            "Authorization": f"Bearer {access_token}",
            "Notion-Version": "2022-06-28",
@ -265,7 +294,7 @@ class NotionOAuth(OAuthDataSource):
            return self.notion_block_parent_page_id(access_token, parent[parent_type])
        return parent[parent_type]

    def notion_workspace_name(self, access_token: str):
    def notion_workspace_name(self, access_token: str) -> str:
        headers = {
            "Authorization": f"Bearer {access_token}",
            "Notion-Version": "2022-06-28",
@ -279,8 +308,8 @@ class NotionOAuth(OAuthDataSource):
            return user_info["workspace_name"]
        return "workspace"

    def notion_database_search(self, access_token: str):
        results = []
    def notion_database_search(self, access_token: str) -> list[dict[str, Any]]:
        results: list[dict[str, Any]] = []
        next_cursor = None
        has_more = True

@ -303,3 +332,19 @@ class NotionOAuth(OAuthDataSource):
            next_cursor = response_json.get("next_cursor", None)

        return results

    @staticmethod
    def _build_source_info(
        *,
        workspace_name: str | None,
        workspace_icon: str | None,
        workspace_id: str | None,
        pages: list[NotionPageSummary],
    ) -> NotionSourceInfo:
        return {
            "workspace_name": workspace_name,
            "workspace_icon": workspace_icon,
            "workspace_id": workspace_id,
            "pages": pages,
            "total": len(pages),
        }

@ -177,13 +177,11 @@ class Account(UserMixin, TypeBase):

    @classmethod
    def get_by_openid(cls, provider: str, open_id: str):
        account_integrate = (
            db.session.query(AccountIntegrate)
            .where(AccountIntegrate.provider == provider, AccountIntegrate.open_id == open_id)
            .one_or_none()
        )
        account_integrate = db.session.execute(
            select(AccountIntegrate).where(AccountIntegrate.provider == provider, AccountIntegrate.open_id == open_id)
        ).scalar_one_or_none()
        if account_integrate:
            return db.session.query(Account).where(Account.id == account_integrate.account_id).one_or_none()
            return db.session.scalar(select(Account).where(Account.id == account_integrate.account_id))
        return None

    # check current_user.current_tenant.current_role in ['admin', 'owner']

@ -8,6 +8,7 @@ import os
|
||||
import pickle
|
||||
import re
|
||||
import time
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from json import JSONDecodeError
|
||||
from typing import Any, TypedDict, cast
|
||||
@ -30,7 +31,20 @@ from services.entities.knowledge_entities.knowledge_entities import ParentMode,
|
||||
from .account import Account
|
||||
from .base import Base, TypeBase
|
||||
from .engine import db
|
||||
from .enums import CreatorUserRole
|
||||
from .enums import (
|
||||
CollectionBindingType,
|
||||
CreatorUserRole,
|
||||
DatasetMetadataType,
|
||||
DatasetQuerySource,
|
||||
DatasetRuntimeMode,
|
||||
DataSourceType,
|
||||
DocumentCreatedFrom,
|
||||
DocumentDocType,
|
||||
IndexingStatus,
|
||||
ProcessRuleMode,
|
||||
SegmentStatus,
|
||||
SummaryStatus,
|
||||
)
|
||||
from .model import App, Tag, TagBinding, UploadFile
|
||||
from .types import AdjustedJSON, BinaryData, EnumText, LongText, StringUUID, adjusted_json_index
|
||||
|
||||
@ -120,7 +134,7 @@ class Dataset(Base):
|
||||
server_default=sa.text("'only_me'"),
|
||||
default=DatasetPermissionEnum.ONLY_ME,
|
||||
)
|
||||
data_source_type = mapped_column(String(255))
|
||||
data_source_type = mapped_column(EnumText(DataSourceType, length=255))
|
||||
indexing_technique: Mapped[str | None] = mapped_column(String(255))
|
||||
index_struct = mapped_column(LongText, nullable=True)
|
||||
created_by = mapped_column(StringUUID, nullable=False)
|
||||
@ -137,7 +151,9 @@ class Dataset(Base):
|
||||
summary_index_setting = mapped_column(AdjustedJSON, nullable=True)
|
||||
built_in_field_enabled = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
|
||||
icon_info = mapped_column(AdjustedJSON, nullable=True)
|
||||
runtime_mode = mapped_column(sa.String(255), nullable=True, server_default=sa.text("'general'"))
|
||||
runtime_mode = mapped_column(
|
||||
EnumText(DatasetRuntimeMode, length=255), nullable=True, server_default=sa.text("'general'")
|
||||
)
|
||||
pipeline_id = mapped_column(StringUUID, nullable=True)
|
||||
chunk_structure = mapped_column(sa.String(255), nullable=True)
|
||||
enable_api = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
|
||||
@ -145,30 +161,25 @@ class Dataset(Base):
|
||||
|
||||
@property
|
||||
def total_documents(self):
|
||||
return db.session.query(func.count(Document.id)).where(Document.dataset_id == self.id).scalar()
|
||||
return db.session.scalar(select(func.count(Document.id)).where(Document.dataset_id == self.id)) or 0
|
||||
|
||||
@property
|
||||
def total_available_documents(self):
|
||||
return (
|
||||
db.session.query(func.count(Document.id))
|
||||
.where(
|
||||
Document.dataset_id == self.id,
|
||||
Document.indexing_status == "completed",
|
||||
Document.enabled == True,
|
||||
Document.archived == False,
|
||||
db.session.scalar(
|
||||
select(func.count(Document.id)).where(
|
||||
Document.dataset_id == self.id,
|
||||
Document.indexing_status == "completed",
|
||||
Document.enabled == True,
|
||||
Document.archived == False,
|
||||
)
|
||||
)
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
|
||||
@property
|
||||
def dataset_keyword_table(self):
|
||||
dataset_keyword_table = (
|
||||
db.session.query(DatasetKeywordTable).where(DatasetKeywordTable.dataset_id == self.id).first()
|
||||
)
|
||||
if dataset_keyword_table:
|
||||
return dataset_keyword_table
|
||||
|
||||
return None
|
||||
return db.session.scalar(select(DatasetKeywordTable).where(DatasetKeywordTable.dataset_id == self.id))
|
||||
|
||||
@property
|
||||
def index_struct_dict(self):
|
||||
@ -195,64 +206,66 @@ class Dataset(Base):
|
||||
|
||||
@property
|
||||
def latest_process_rule(self):
|
||||
return (
|
||||
db.session.query(DatasetProcessRule)
|
||||
return db.session.scalar(
|
||||
select(DatasetProcessRule)
|
||||
.where(DatasetProcessRule.dataset_id == self.id)
|
||||
.order_by(DatasetProcessRule.created_at.desc())
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
@property
|
||||
def app_count(self):
|
||||
return (
|
||||
db.session.query(func.count(AppDatasetJoin.id))
|
||||
.where(AppDatasetJoin.dataset_id == self.id, App.id == AppDatasetJoin.app_id)
|
||||
.scalar()
|
||||
db.session.scalar(
|
||||
select(func.count(AppDatasetJoin.id)).where(
|
||||
AppDatasetJoin.dataset_id == self.id, App.id == AppDatasetJoin.app_id
|
||||
)
|
||||
)
|
||||
or 0
|
||||
)
|
||||
|
||||
@property
|
||||
def document_count(self):
|
||||
return db.session.query(func.count(Document.id)).where(Document.dataset_id == self.id).scalar()
|
||||
return db.session.scalar(select(func.count(Document.id)).where(Document.dataset_id == self.id)) or 0
|
||||
|
||||
@property
|
||||
def available_document_count(self):
|
||||
return (
|
||||
db.session.query(func.count(Document.id))
|
||||
.where(
|
||||
Document.dataset_id == self.id,
|
||||
Document.indexing_status == "completed",
|
||||
Document.enabled == True,
|
||||
Document.archived == False,
|
||||
db.session.scalar(
|
||||
select(func.count(Document.id)).where(
|
||||
Document.dataset_id == self.id,
|
||||
Document.indexing_status == "completed",
|
||||
Document.enabled == True,
|
||||
Document.archived == False,
|
||||
)
|
||||
)
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
|
||||
@property
|
||||
def available_segment_count(self):
|
||||
return (
|
||||
db.session.query(func.count(DocumentSegment.id))
|
||||
.where(
|
||||
DocumentSegment.dataset_id == self.id,
|
||||
DocumentSegment.status == "completed",
|
||||
DocumentSegment.enabled == True,
|
||||
db.session.scalar(
|
||||
select(func.count(DocumentSegment.id)).where(
|
||||
DocumentSegment.dataset_id == self.id,
|
||||
DocumentSegment.status == "completed",
|
||||
DocumentSegment.enabled == True,
|
||||
)
|
||||
)
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
|
||||
@property
|
||||
def word_count(self):
|
||||
return (
|
||||
db.session.query(Document)
|
||||
.with_entities(func.coalesce(func.sum(Document.word_count), 0))
|
||||
.where(Document.dataset_id == self.id)
|
||||
.scalar()
|
||||
return db.session.scalar(
|
||||
select(func.coalesce(func.sum(Document.word_count), 0)).where(Document.dataset_id == self.id)
|
||||
)
|
||||
|
||||
@property
|
||||
def doc_form(self) -> str | None:
|
||||
if self.chunk_structure:
|
||||
return self.chunk_structure
|
||||
document = db.session.query(Document).where(Document.dataset_id == self.id).first()
|
||||
document = db.session.scalar(select(Document).where(Document.dataset_id == self.id).limit(1))
|
||||
if document:
|
||||
return document.doc_form
|
||||
return None
|
||||
@ -270,8 +283,8 @@ class Dataset(Base):
|
||||
|
||||
@property
|
||||
def tags(self):
|
||||
tags = (
|
||||
db.session.query(Tag)
|
||||
tags = db.session.scalars(
|
||||
select(Tag)
|
||||
.join(TagBinding, Tag.id == TagBinding.tag_id)
|
||||
.where(
|
||||
TagBinding.target_id == self.id,
|
||||
@ -279,8 +292,7 @@ class Dataset(Base):
|
||||
Tag.tenant_id == self.tenant_id,
|
||||
Tag.type == "knowledge",
|
||||
)
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
|
||||
return tags or []
|
||||
|
||||
@ -288,8 +300,8 @@ class Dataset(Base):
|
||||
def external_knowledge_info(self):
|
||||
if self.provider != "external":
|
||||
return None
|
||||
external_knowledge_binding = (
|
||||
db.session.query(ExternalKnowledgeBindings).where(ExternalKnowledgeBindings.dataset_id == self.id).first()
|
||||
external_knowledge_binding = db.session.scalar(
|
||||
select(ExternalKnowledgeBindings).where(ExternalKnowledgeBindings.dataset_id == self.id)
|
||||
)
|
||||
if not external_knowledge_binding:
|
||||
return None
|
||||
@ -310,7 +322,7 @@ class Dataset(Base):
|
||||
@property
|
||||
def is_published(self):
|
||||
if self.pipeline_id:
|
||||
pipeline = db.session.query(Pipeline).where(Pipeline.id == self.pipeline_id).first()
|
||||
pipeline = db.session.scalar(select(Pipeline).where(Pipeline.id == self.pipeline_id))
|
||||
if pipeline:
|
||||
return pipeline.is_published
|
||||
return False
|
||||
@ -382,7 +394,7 @@ class DatasetProcessRule(Base): # bug

id = mapped_column(StringUUID, nullable=False, default=lambda: str(uuid4()))
dataset_id = mapped_column(StringUUID, nullable=False)
mode = mapped_column(String(255), nullable=False, server_default=sa.text("'automatic'"))
mode = mapped_column(EnumText(ProcessRuleMode, length=255), nullable=False, server_default=sa.text("'automatic'"))
rules = mapped_column(LongText, nullable=True)
created_by = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
@ -428,12 +440,12 @@ class Document(Base):
|
||||
tenant_id = mapped_column(StringUUID, nullable=False)
|
||||
dataset_id = mapped_column(StringUUID, nullable=False)
|
||||
position: Mapped[int] = mapped_column(sa.Integer, nullable=False)
|
||||
data_source_type: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
data_source_type: Mapped[str] = mapped_column(EnumText(DataSourceType, length=255), nullable=False)
|
||||
data_source_info = mapped_column(LongText, nullable=True)
|
||||
dataset_process_rule_id = mapped_column(StringUUID, nullable=True)
|
||||
batch: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
created_from: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
created_from: Mapped[str] = mapped_column(EnumText(DocumentCreatedFrom, length=255), nullable=False)
|
||||
created_by = mapped_column(StringUUID, nullable=False)
|
||||
created_api_request_id = mapped_column(StringUUID, nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
@ -467,7 +479,9 @@ class Document(Base):
|
||||
stopped_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
||||
|
||||
# basic fields
|
||||
indexing_status = mapped_column(String(255), nullable=False, server_default=sa.text("'waiting'"))
|
||||
indexing_status = mapped_column(
|
||||
EnumText(IndexingStatus, length=255), nullable=False, server_default=sa.text("'waiting'")
|
||||
)
|
||||
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
|
||||
disabled_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
||||
disabled_by = mapped_column(StringUUID, nullable=True)
|
||||
@ -478,7 +492,7 @@ class Document(Base):
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
)
|
||||
doc_type = mapped_column(String(40), nullable=True)
|
||||
doc_type = mapped_column(EnumText(DocumentDocType, length=40), nullable=True)
|
||||
doc_metadata = mapped_column(AdjustedJSON, nullable=True)
|
||||
doc_form = mapped_column(String(255), nullable=False, server_default=sa.text("'text_model'"))
|
||||
doc_language = mapped_column(String(255), nullable=True)
|
||||
@ -521,10 +535,8 @@ class Document(Base):
|
||||
if self.data_source_info:
|
||||
if self.data_source_type == "upload_file":
|
||||
data_source_info_dict: dict[str, Any] = json.loads(self.data_source_info)
|
||||
file_detail = (
|
||||
db.session.query(UploadFile)
|
||||
.where(UploadFile.id == data_source_info_dict["upload_file_id"])
|
||||
.one_or_none()
|
||||
file_detail = db.session.scalar(
|
||||
select(UploadFile).where(UploadFile.id == data_source_info_dict["upload_file_id"])
|
||||
)
|
||||
if file_detail:
|
||||
return {
|
||||
@ -557,24 +569,23 @@ class Document(Base):

@property
def dataset(self):
return db.session.query(Dataset).where(Dataset.id == self.dataset_id).one_or_none()
return db.session.scalar(select(Dataset).where(Dataset.id == self.dataset_id))

@property
def segment_count(self):
return db.session.query(DocumentSegment).where(DocumentSegment.document_id == self.id).count()
return (
db.session.scalar(select(func.count(DocumentSegment.id)).where(DocumentSegment.document_id == self.id)) or 0
)

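Most hunks in this file follow the same migration: legacy `db.session.query(...).first()` / `.count()` calls become 2.0-style `select()` statements executed through `Session.scalar()` or `Session.scalars()`. A hedged sketch of the two recurring shapes, reusing the models above (`dataset_id` and `document_id` stand in for the ids bound on the instance):

```python
from sqlalchemy import func, select

# Single row or None: Query.first() / one_or_none() becomes Session.scalar(select(...)).
dataset = db.session.scalar(select(Dataset).where(Dataset.id == dataset_id))

# Counting: Query.count() becomes an explicit func.count(); the trailing `or 0` keeps the
# result an int even if the scalar comes back as None.
segment_count = (
    db.session.scalar(select(func.count(DocumentSegment.id)).where(DocumentSegment.document_id == document_id)) or 0
)
```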
@property
|
||||
def hit_count(self):
|
||||
return (
|
||||
db.session.query(DocumentSegment)
|
||||
.with_entities(func.coalesce(func.sum(DocumentSegment.hit_count), 0))
|
||||
.where(DocumentSegment.document_id == self.id)
|
||||
.scalar()
|
||||
return db.session.scalar(
|
||||
select(func.coalesce(func.sum(DocumentSegment.hit_count), 0)).where(DocumentSegment.document_id == self.id)
|
||||
)
|
||||
|
||||
@property
|
||||
def uploader(self):
|
||||
user = db.session.query(Account).where(Account.id == self.created_by).first()
|
||||
user = db.session.scalar(select(Account).where(Account.id == self.created_by))
|
||||
return user.name if user else None
|
||||
|
||||
@property
|
||||
@ -588,14 +599,13 @@ class Document(Base):
|
||||
@property
|
||||
def doc_metadata_details(self) -> list[DocMetadataDetailItem] | None:
|
||||
if self.doc_metadata:
|
||||
document_metadatas = (
|
||||
db.session.query(DatasetMetadata)
|
||||
document_metadatas = db.session.scalars(
|
||||
select(DatasetMetadata)
|
||||
.join(DatasetMetadataBinding, DatasetMetadataBinding.metadata_id == DatasetMetadata.id)
|
||||
.where(
|
||||
DatasetMetadataBinding.dataset_id == self.dataset_id, DatasetMetadataBinding.document_id == self.id
|
||||
)
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
metadata_list: list[DocMetadataDetailItem] = []
|
||||
for metadata in document_metadatas:
|
||||
metadata_dict: DocMetadataDetailItem = {
|
||||
@ -791,7 +801,7 @@ class DocumentSegment(Base):
|
||||
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
|
||||
disabled_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
||||
disabled_by = mapped_column(StringUUID, nullable=True)
|
||||
status: Mapped[str] = mapped_column(String(255), server_default=sa.text("'waiting'"))
|
||||
status: Mapped[str] = mapped_column(EnumText(SegmentStatus, length=255), server_default=sa.text("'waiting'"))
|
||||
created_by = mapped_column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_by = mapped_column(StringUUID, nullable=True)
|
||||
@ -826,7 +836,7 @@ class DocumentSegment(Base):
|
||||
)
|
||||
|
||||
@property
|
||||
def child_chunks(self) -> list[Any]:
|
||||
def child_chunks(self) -> Sequence[Any]:
|
||||
if not self.document:
|
||||
return []
|
||||
process_rule = self.document.dataset_process_rule
|
||||
@ -835,16 +845,13 @@ class DocumentSegment(Base):
|
||||
if rules_dict:
|
||||
rules = Rule.model_validate(rules_dict)
|
||||
if rules.parent_mode and rules.parent_mode != ParentMode.FULL_DOC:
|
||||
child_chunks = (
|
||||
db.session.query(ChildChunk)
|
||||
.where(ChildChunk.segment_id == self.id)
|
||||
.order_by(ChildChunk.position.asc())
|
||||
.all()
|
||||
)
|
||||
child_chunks = db.session.scalars(
|
||||
select(ChildChunk).where(ChildChunk.segment_id == self.id).order_by(ChildChunk.position.asc())
|
||||
).all()
|
||||
return child_chunks or []
|
||||
return []
|
||||
|
||||
def get_child_chunks(self) -> list[Any]:
|
||||
def get_child_chunks(self) -> Sequence[Any]:
|
||||
if not self.document:
|
||||
return []
|
||||
process_rule = self.document.dataset_process_rule
|
||||
@ -853,12 +860,9 @@ class DocumentSegment(Base):
|
||||
if rules_dict:
|
||||
rules = Rule.model_validate(rules_dict)
|
||||
if rules.parent_mode:
|
||||
child_chunks = (
|
||||
db.session.query(ChildChunk)
|
||||
.where(ChildChunk.segment_id == self.id)
|
||||
.order_by(ChildChunk.position.asc())
|
||||
.all()
|
||||
)
|
||||
child_chunks = db.session.scalars(
|
||||
select(ChildChunk).where(ChildChunk.segment_id == self.id).order_by(ChildChunk.position.asc())
|
||||
).all()
|
||||
return child_chunks or []
|
||||
return []
|
||||
|
||||
@ -1007,15 +1011,15 @@ class ChildChunk(Base):
|
||||
|
||||
@property
|
||||
def dataset(self):
|
||||
return db.session.query(Dataset).where(Dataset.id == self.dataset_id).first()
|
||||
return db.session.scalar(select(Dataset).where(Dataset.id == self.dataset_id))
|
||||
|
||||
@property
|
||||
def document(self):
|
||||
return db.session.query(Document).where(Document.id == self.document_id).first()
|
||||
return db.session.scalar(select(Document).where(Document.id == self.document_id))
|
||||
|
||||
@property
|
||||
def segment(self):
|
||||
return db.session.query(DocumentSegment).where(DocumentSegment.id == self.segment_id).first()
|
||||
return db.session.scalar(select(DocumentSegment).where(DocumentSegment.id == self.segment_id))
|
||||
|
||||
|
||||
class AppDatasetJoin(TypeBase):
|
||||
@ -1061,7 +1065,7 @@ class DatasetQuery(TypeBase):
|
||||
)
|
||||
dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
content: Mapped[str] = mapped_column(LongText, nullable=False)
|
||||
source: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
source: Mapped[str] = mapped_column(EnumText(DatasetQuerySource, length=255), nullable=False)
|
||||
source_app_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
|
||||
created_by_role: Mapped[CreatorUserRole] = mapped_column(EnumText(CreatorUserRole, length=255), nullable=False)
|
||||
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
@ -1076,7 +1080,7 @@ class DatasetQuery(TypeBase):
|
||||
if isinstance(queries, list):
|
||||
for query in queries:
|
||||
if query["content_type"] == QueryType.IMAGE_QUERY:
|
||||
file_info = db.session.query(UploadFile).filter_by(id=query["content"]).first()
|
||||
file_info = db.session.scalar(select(UploadFile).where(UploadFile.id == query["content"]))
|
||||
if file_info:
|
||||
query["file_info"] = {
|
||||
"id": file_info.id,
|
||||
@ -1141,7 +1145,7 @@ class DatasetKeywordTable(TypeBase):
|
||||
super().__init__(object_hook=object_hook, *args, **kwargs)
|
||||
|
||||
# get dataset
|
||||
dataset = db.session.query(Dataset).filter_by(id=self.dataset_id).first()
|
||||
dataset = db.session.scalar(select(Dataset).where(Dataset.id == self.dataset_id))
|
||||
if not dataset:
|
||||
return None
|
||||
if self.data_source_type == "database":
|
||||
@ -1206,7 +1210,9 @@ class DatasetCollectionBinding(TypeBase):
|
||||
)
|
||||
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
model_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
type: Mapped[str] = mapped_column(String(40), server_default=sa.text("'dataset'"), nullable=False)
|
||||
type: Mapped[str] = mapped_column(
|
||||
EnumText(CollectionBindingType, length=40), server_default=sa.text("'dataset'"), nullable=False
|
||||
)
|
||||
collection_name: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
|
||||
@ -1433,7 +1439,7 @@ class DatasetMetadata(TypeBase):
|
||||
)
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
type: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
type: Mapped[str] = mapped_column(EnumText(DatasetMetadataType, length=255), nullable=False)
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=sa.func.current_timestamp(), init=False
|
||||
@ -1535,7 +1541,7 @@ class PipelineCustomizedTemplate(TypeBase):
|
||||
|
||||
@property
|
||||
def created_user_name(self):
|
||||
account = db.session.query(Account).where(Account.id == self.created_by).first()
|
||||
account = db.session.scalar(select(Account).where(Account.id == self.created_by))
|
||||
if account:
|
||||
return account.name
|
||||
return ""
|
||||
@ -1570,7 +1576,7 @@ class Pipeline(TypeBase):
|
||||
)
|
||||
|
||||
def retrieve_dataset(self, session: Session):
|
||||
return session.query(Dataset).where(Dataset.pipeline_id == self.id).first()
|
||||
return session.scalar(select(Dataset).where(Dataset.pipeline_id == self.id))
|
||||
|
||||
|
||||
class DocumentPipelineExecutionLog(TypeBase):
|
||||
@ -1660,7 +1666,9 @@ class DocumentSegmentSummary(Base):
|
||||
summary_index_node_id: Mapped[str] = mapped_column(String(255), nullable=True)
|
||||
summary_index_node_hash: Mapped[str] = mapped_column(String(255), nullable=True)
|
||||
tokens: Mapped[int | None] = mapped_column(sa.Integer, nullable=True)
|
||||
status: Mapped[str] = mapped_column(String(32), nullable=False, server_default=sa.text("'generating'"))
|
||||
status: Mapped[str] = mapped_column(
|
||||
EnumText(SummaryStatus, length=32), nullable=False, server_default=sa.text("'generating'")
|
||||
)
|
||||
error: Mapped[str] = mapped_column(LongText, nullable=True)
|
||||
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
|
||||
disabled_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
||||

@ -11,6 +11,13 @@ class CreatorUserRole(StrEnum):
ACCOUNT = "account"
END_USER = "end_user"

@classmethod
def _missing_(cls, value):
if value == "end-user":
return cls.END_USER
else:
return super()._missing_(value)


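The `_missing_` override added here keeps rows that stored the legacy hyphenated value `"end-user"` resolving to `CreatorUserRole.END_USER`. A minimal, standalone illustration of the same hook (the `Role` enum below is a throwaway stand-in, not the project's class):

```python
from enum import StrEnum


class Role(StrEnum):
    ACCOUNT = "account"
    END_USER = "end_user"

    @classmethod
    def _missing_(cls, value):
        # Map the legacy hyphenated spelling onto the canonical member.
        if value == "end-user":
            return cls.END_USER
        return super()._missing_(value)


assert Role("end-user") is Role.END_USER
assert Role("account") is Role.ACCOUNT
```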
class WorkflowRunTriggeredFrom(StrEnum):
|
||||
DEBUGGING = "debugging"
|
||||
@ -96,3 +103,216 @@ class ConversationStatus(StrEnum):
|
||||
"""Conversation Status Enum"""
|
||||
|
||||
NORMAL = "normal"
|
||||
|
||||
|
||||
class DataSourceType(StrEnum):
|
||||
"""Data Source Type for Dataset and Document"""
|
||||
|
||||
UPLOAD_FILE = "upload_file"
|
||||
NOTION_IMPORT = "notion_import"
|
||||
WEBSITE_CRAWL = "website_crawl"
|
||||
LOCAL_FILE = "local_file"
|
||||
ONLINE_DOCUMENT = "online_document"
|
||||
|
||||
|
||||
class ProcessRuleMode(StrEnum):
|
||||
"""Dataset Process Rule Mode"""
|
||||
|
||||
AUTOMATIC = "automatic"
|
||||
CUSTOM = "custom"
|
||||
HIERARCHICAL = "hierarchical"
|
||||
|
||||
|
||||
class IndexingStatus(StrEnum):
|
||||
"""Document Indexing Status"""
|
||||
|
||||
WAITING = "waiting"
|
||||
PARSING = "parsing"
|
||||
CLEANING = "cleaning"
|
||||
SPLITTING = "splitting"
|
||||
INDEXING = "indexing"
|
||||
PAUSED = "paused"
|
||||
COMPLETED = "completed"
|
||||
ERROR = "error"
|
||||
|
||||
|
||||
class DocumentCreatedFrom(StrEnum):
|
||||
"""Document Created From"""
|
||||
|
||||
WEB = "web"
|
||||
API = "api"
|
||||
RAG_PIPELINE = "rag-pipeline"
|
||||
|
||||
|
||||
class ConversationFromSource(StrEnum):
|
||||
"""Conversation / Message from_source"""
|
||||
|
||||
API = "api"
|
||||
CONSOLE = "console"
|
||||
|
||||
|
||||
class FeedbackFromSource(StrEnum):
|
||||
"""MessageFeedback from_source"""
|
||||
|
||||
USER = "user"
|
||||
ADMIN = "admin"


class InvokeFrom(StrEnum):
"""How a conversation/message was invoked"""

SERVICE_API = "service-api"
WEB_APP = "web-app"
TRIGGER = "trigger"
EXPLORE = "explore"
DEBUGGER = "debugger"
PUBLISHED_PIPELINE = "published"
VALIDATION = "validation"

@classmethod
def value_of(cls, value: str) -> "InvokeFrom":
return cls(value)

def to_source(self) -> str:
source_mapping = {
InvokeFrom.WEB_APP: "web_app",
InvokeFrom.DEBUGGER: "dev",
InvokeFrom.EXPLORE: "explore_app",
InvokeFrom.TRIGGER: "trigger",
InvokeFrom.SERVICE_API: "api",
}
return source_mapping.get(self, "dev")


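Reading directly off the mapping above, `to_source()` translates an invocation channel into a source label, and members without an entry (`PUBLISHED_PIPELINE`, `VALIDATION`) fall back to `"dev"`:

```python
assert InvokeFrom.value_of("web-app").to_source() == "web_app"
assert InvokeFrom.EXPLORE.to_source() == "explore_app"
assert InvokeFrom.VALIDATION.to_source() == "dev"  # no entry in source_mapping
```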
class DocumentDocType(StrEnum):
|
||||
"""Document doc_type classification"""
|
||||
|
||||
BOOK = "book"
|
||||
WEB_PAGE = "web_page"
|
||||
PAPER = "paper"
|
||||
SOCIAL_MEDIA_POST = "social_media_post"
|
||||
WIKIPEDIA_ENTRY = "wikipedia_entry"
|
||||
PERSONAL_DOCUMENT = "personal_document"
|
||||
BUSINESS_DOCUMENT = "business_document"
|
||||
IM_CHAT_LOG = "im_chat_log"
|
||||
SYNCED_FROM_NOTION = "synced_from_notion"
|
||||
SYNCED_FROM_GITHUB = "synced_from_github"
|
||||
OTHERS = "others"
|
||||
|
||||
|
||||
class TagType(StrEnum):
|
||||
"""Tag type"""
|
||||
|
||||
KNOWLEDGE = "knowledge"
|
||||
APP = "app"
|
||||
|
||||
|
||||
class DatasetMetadataType(StrEnum):
|
||||
"""Dataset metadata value type"""
|
||||
|
||||
STRING = "string"
|
||||
NUMBER = "number"
|
||||
TIME = "time"
|
||||
|
||||
|
||||
class SegmentStatus(StrEnum):
|
||||
"""Document segment status"""
|
||||
|
||||
WAITING = "waiting"
|
||||
INDEXING = "indexing"
|
||||
COMPLETED = "completed"
|
||||
ERROR = "error"
|
||||
PAUSED = "paused"
|
||||
RE_SEGMENT = "re_segment"
|
||||
|
||||
|
||||
class DatasetRuntimeMode(StrEnum):
|
||||
"""Dataset runtime mode"""
|
||||
|
||||
GENERAL = "general"
|
||||
RAG_PIPELINE = "rag_pipeline"
|
||||
|
||||
|
||||
class CollectionBindingType(StrEnum):
|
||||
"""Dataset collection binding type"""
|
||||
|
||||
DATASET = "dataset"
|
||||
ANNOTATION = "annotation"
|
||||
|
||||
|
||||
class DatasetQuerySource(StrEnum):
|
||||
"""Dataset query source"""
|
||||
|
||||
HIT_TESTING = "hit_testing"
|
||||
APP = "app"
|
||||
|
||||
|
||||
class TidbAuthBindingStatus(StrEnum):
|
||||
"""TiDB auth binding status"""
|
||||
|
||||
CREATING = "CREATING"
|
||||
ACTIVE = "ACTIVE"
|
||||
|
||||
|
||||
class MessageFileBelongsTo(StrEnum):
|
||||
"""MessageFile belongs_to"""
|
||||
|
||||
USER = "user"
|
||||
ASSISTANT = "assistant"
|
||||
|
||||
|
||||
class CredentialSourceType(StrEnum):
|
||||
"""Load balancing credential source type"""
|
||||
|
||||
PROVIDER = "provider"
|
||||
CUSTOM_MODEL = "custom_model"
|
||||
|
||||
|
||||
class PaymentStatus(StrEnum):
|
||||
"""Provider order payment status"""
|
||||
|
||||
WAIT_PAY = "wait_pay"
|
||||
PAID = "paid"
|
||||
FAILED = "failed"
|
||||
REFUNDED = "refunded"
|
||||
|
||||
|
||||
class BannerStatus(StrEnum):
|
||||
"""ExporleBanner status"""
|
||||
|
||||
ENABLED = "enabled"
|
||||
DISABLED = "disabled"
|
||||
|
||||
|
||||
class SummaryStatus(StrEnum):
|
||||
"""Document segment summary status"""
|
||||
|
||||
NOT_STARTED = "not_started"
|
||||
GENERATING = "generating"
|
||||
COMPLETED = "completed"
|
||||
ERROR = "error"
|
||||
TIMEOUT = "timeout"
|
||||
|
||||
|
||||
class MessageChainType(StrEnum):
|
||||
"""Message chain type"""
|
||||
|
||||
SYSTEM = "system"
|
||||
|
||||
|
||||
class ProviderQuotaType(StrEnum):
|
||||
PAID = "paid"
|
||||
"""hosted paid quota"""
|
||||
|
||||
FREE = "free"
|
||||
"""third-party free quota"""
|
||||
|
||||
TRIAL = "trial"
|
||||
"""hosted trial quota"""
|
||||
|
||||
@staticmethod
|
||||
def value_of(value: str) -> "ProviderQuotaType":
|
||||
for member in ProviderQuotaType:
|
||||
if member.value == value:
|
||||
return member
|
||||
raise ValueError(f"No matching enum found for value '{value}'")
|
||||
|
||||
@ -380,13 +380,12 @@ class App(Base):
|
||||
|
||||
@property
|
||||
def site(self) -> Site | None:
|
||||
site = db.session.query(Site).where(Site.app_id == self.id).first()
|
||||
return site
|
||||
return db.session.scalar(select(Site).where(Site.app_id == self.id))
|
||||
|
||||
@property
|
||||
def app_model_config(self) -> AppModelConfig | None:
|
||||
if self.app_model_config_id:
|
||||
return db.session.query(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id).first()
|
||||
return db.session.scalar(select(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id))
|
||||
|
||||
return None
|
||||
|
||||
@ -395,7 +394,7 @@ class App(Base):
|
||||
if self.workflow_id:
|
||||
from .workflow import Workflow
|
||||
|
||||
return db.session.query(Workflow).where(Workflow.id == self.workflow_id).first()
|
||||
return db.session.scalar(select(Workflow).where(Workflow.id == self.workflow_id))
|
||||
|
||||
return None
|
||||
|
||||
@ -405,8 +404,7 @@ class App(Base):
|
||||
|
||||
@property
|
||||
def tenant(self) -> Tenant | None:
|
||||
tenant = db.session.query(Tenant).where(Tenant.id == self.tenant_id).first()
|
||||
return tenant
|
||||
return db.session.scalar(select(Tenant).where(Tenant.id == self.tenant_id))
|
||||
|
||||
@property
|
||||
def is_agent(self) -> bool:
|
||||
@ -546,9 +544,9 @@ class App(Base):
|
||||
return deleted_tools
|
||||
|
||||
@property
|
||||
def tags(self) -> list[Tag]:
|
||||
tags = (
|
||||
db.session.query(Tag)
|
||||
def tags(self) -> Sequence[Tag]:
|
||||
tags = db.session.scalars(
|
||||
select(Tag)
|
||||
.join(TagBinding, Tag.id == TagBinding.tag_id)
|
||||
.where(
|
||||
TagBinding.target_id == self.id,
|
||||
@ -556,15 +554,14 @@ class App(Base):
|
||||
Tag.tenant_id == self.tenant_id,
|
||||
Tag.type == "app",
|
||||
)
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
|
||||
return tags or []
|
||||
|
||||
@property
|
||||
def author_name(self) -> str | None:
|
||||
if self.created_by:
|
||||
account = db.session.query(Account).where(Account.id == self.created_by).first()
|
||||
account = db.session.scalar(select(Account).where(Account.id == self.created_by))
|
||||
if account:
|
||||
return account.name
|
||||
|
||||
@ -616,8 +613,7 @@ class AppModelConfig(TypeBase):
|
||||
|
||||
@property
|
||||
def app(self) -> App | None:
|
||||
app = db.session.query(App).where(App.id == self.app_id).first()
|
||||
return app
|
||||
return db.session.scalar(select(App).where(App.id == self.app_id))
|
||||
|
||||
@property
|
||||
def model_dict(self) -> ModelConfig:
|
||||
@ -652,8 +648,8 @@ class AppModelConfig(TypeBase):
|
||||
|
||||
@property
|
||||
def annotation_reply_dict(self) -> AnnotationReplyConfig:
|
||||
annotation_setting = (
|
||||
db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == self.app_id).first()
|
||||
annotation_setting = db.session.scalar(
|
||||
select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == self.app_id)
|
||||
)
|
||||
if annotation_setting:
|
||||
collection_binding_detail = annotation_setting.collection_binding_detail
|
||||
@ -845,8 +841,7 @@ class RecommendedApp(Base): # bug
|
||||
|
||||
@property
|
||||
def app(self) -> App | None:
|
||||
app = db.session.query(App).where(App.id == self.app_id).first()
|
||||
return app
|
||||
return db.session.scalar(select(App).where(App.id == self.app_id))
|
||||
|
||||
|
||||
class InstalledApp(TypeBase):
|
||||
@ -873,13 +868,11 @@ class InstalledApp(TypeBase):
|
||||
|
||||
@property
|
||||
def app(self) -> App | None:
|
||||
app = db.session.query(App).where(App.id == self.app_id).first()
|
||||
return app
|
||||
return db.session.scalar(select(App).where(App.id == self.app_id))
|
||||
|
||||
@property
|
||||
def tenant(self) -> Tenant | None:
|
||||
tenant = db.session.query(Tenant).where(Tenant.id == self.tenant_id).first()
|
||||
return tenant
|
||||
return db.session.scalar(select(Tenant).where(Tenant.id == self.tenant_id))
|
||||
|
||||
|
||||
class TrialApp(Base):
|
||||
@ -899,8 +892,7 @@ class TrialApp(Base):
|
||||
|
||||
@property
|
||||
def app(self) -> App | None:
|
||||
app = db.session.query(App).where(App.id == self.app_id).first()
|
||||
return app
|
||||
return db.session.scalar(select(App).where(App.id == self.app_id))
|
||||
|
||||
|
||||
class AccountTrialAppRecord(Base):
|
||||
@ -919,13 +911,11 @@ class AccountTrialAppRecord(Base):
|
||||
|
||||
@property
|
||||
def app(self) -> App | None:
|
||||
app = db.session.query(App).where(App.id == self.app_id).first()
|
||||
return app
|
||||
return db.session.scalar(select(App).where(App.id == self.app_id))
|
||||
|
||||
@property
|
||||
def user(self) -> Account | None:
|
||||
user = db.session.query(Account).where(Account.id == self.account_id).first()
|
||||
return user
|
||||
return db.session.scalar(select(Account).where(Account.id == self.account_id))
|
||||
|
||||
|
||||
class ExporleBanner(TypeBase):
|
||||
@ -1117,8 +1107,8 @@ class Conversation(Base):
|
||||
else:
|
||||
model_config["configs"] = override_model_configs # type: ignore[typeddict-unknown-key]
|
||||
else:
|
||||
app_model_config = (
|
||||
db.session.query(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id).first()
|
||||
app_model_config = db.session.scalar(
|
||||
select(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id)
|
||||
)
|
||||
if app_model_config:
|
||||
model_config = app_model_config.to_dict()
|
||||
@ -1141,36 +1131,43 @@ class Conversation(Base):

@property
def annotated(self):
return db.session.query(MessageAnnotation).where(MessageAnnotation.conversation_id == self.id).count() > 0
return (
db.session.scalar(
select(func.count(MessageAnnotation.id)).where(MessageAnnotation.conversation_id == self.id)
)
or 0
) > 0

@property
def annotation(self):
return db.session.query(MessageAnnotation).where(MessageAnnotation.conversation_id == self.id).first()
return db.session.scalar(select(MessageAnnotation).where(MessageAnnotation.conversation_id == self.id).limit(1))

@property
def message_count(self):
return db.session.query(Message).where(Message.conversation_id == self.id).count()
return db.session.scalar(select(func.count(Message.id)).where(Message.conversation_id == self.id)) or 0

@property
|
||||
def user_feedback_stats(self):
|
||||
like = (
|
||||
db.session.query(MessageFeedback)
|
||||
.where(
|
||||
MessageFeedback.conversation_id == self.id,
|
||||
MessageFeedback.from_source == "user",
|
||||
MessageFeedback.rating == "like",
|
||||
db.session.scalar(
|
||||
select(func.count(MessageFeedback.id)).where(
|
||||
MessageFeedback.conversation_id == self.id,
|
||||
MessageFeedback.from_source == "user",
|
||||
MessageFeedback.rating == "like",
|
||||
)
|
||||
)
|
||||
.count()
|
||||
or 0
|
||||
)
|
||||
|
||||
dislike = (
|
||||
db.session.query(MessageFeedback)
|
||||
.where(
|
||||
MessageFeedback.conversation_id == self.id,
|
||||
MessageFeedback.from_source == "user",
|
||||
MessageFeedback.rating == "dislike",
|
||||
db.session.scalar(
|
||||
select(func.count(MessageFeedback.id)).where(
|
||||
MessageFeedback.conversation_id == self.id,
|
||||
MessageFeedback.from_source == "user",
|
||||
MessageFeedback.rating == "dislike",
|
||||
)
|
||||
)
|
||||
.count()
|
||||
or 0
|
||||
)
|
||||
|
||||
return {"like": like, "dislike": dislike}
|
||||
@ -1178,23 +1175,25 @@ class Conversation(Base):
|
||||
@property
|
||||
def admin_feedback_stats(self):
|
||||
like = (
|
||||
db.session.query(MessageFeedback)
|
||||
.where(
|
||||
MessageFeedback.conversation_id == self.id,
|
||||
MessageFeedback.from_source == "admin",
|
||||
MessageFeedback.rating == "like",
|
||||
db.session.scalar(
|
||||
select(func.count(MessageFeedback.id)).where(
|
||||
MessageFeedback.conversation_id == self.id,
|
||||
MessageFeedback.from_source == "admin",
|
||||
MessageFeedback.rating == "like",
|
||||
)
|
||||
)
|
||||
.count()
|
||||
or 0
|
||||
)
|
||||
|
||||
dislike = (
|
||||
db.session.query(MessageFeedback)
|
||||
.where(
|
||||
MessageFeedback.conversation_id == self.id,
|
||||
MessageFeedback.from_source == "admin",
|
||||
MessageFeedback.rating == "dislike",
|
||||
db.session.scalar(
|
||||
select(func.count(MessageFeedback.id)).where(
|
||||
MessageFeedback.conversation_id == self.id,
|
||||
MessageFeedback.from_source == "admin",
|
||||
MessageFeedback.rating == "dislike",
|
||||
)
|
||||
)
|
||||
.count()
|
||||
or 0
|
||||
)
|
||||
|
||||
return {"like": like, "dislike": dislike}
|
||||
@ -1256,22 +1255,19 @@ class Conversation(Base):
|
||||
|
||||
@property
|
||||
def first_message(self):
|
||||
return (
|
||||
db.session.query(Message)
|
||||
.where(Message.conversation_id == self.id)
|
||||
.order_by(Message.created_at.asc())
|
||||
.first()
|
||||
return db.session.scalar(
|
||||
select(Message).where(Message.conversation_id == self.id).order_by(Message.created_at.asc())
|
||||
)
|
||||
|
||||
@property
|
||||
def app(self) -> App | None:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
return session.query(App).where(App.id == self.app_id).first()
|
||||
return session.scalar(select(App).where(App.id == self.app_id))
|
||||
|
||||
@property
|
||||
def from_end_user_session_id(self):
|
||||
if self.from_end_user_id:
|
||||
end_user = db.session.query(EndUser).where(EndUser.id == self.from_end_user_id).first()
|
||||
end_user = db.session.scalar(select(EndUser).where(EndUser.id == self.from_end_user_id))
|
||||
if end_user:
|
||||
return end_user.session_id
|
||||
|
||||
@ -1280,7 +1276,7 @@ class Conversation(Base):
|
||||
@property
|
||||
def from_account_name(self) -> str | None:
|
||||
if self.from_account_id:
|
||||
account = db.session.query(Account).where(Account.id == self.from_account_id).first()
|
||||
account = db.session.scalar(select(Account).where(Account.id == self.from_account_id))
|
||||
if account:
|
||||
return account.name
|
||||
|
||||
@ -1505,21 +1501,15 @@ class Message(Base):
|
||||
|
||||
@property
|
||||
def user_feedback(self):
|
||||
feedback = (
|
||||
db.session.query(MessageFeedback)
|
||||
.where(MessageFeedback.message_id == self.id, MessageFeedback.from_source == "user")
|
||||
.first()
|
||||
return db.session.scalar(
|
||||
select(MessageFeedback).where(MessageFeedback.message_id == self.id, MessageFeedback.from_source == "user")
|
||||
)
|
||||
return feedback
|
||||
|
||||
@property
|
||||
def admin_feedback(self):
|
||||
feedback = (
|
||||
db.session.query(MessageFeedback)
|
||||
.where(MessageFeedback.message_id == self.id, MessageFeedback.from_source == "admin")
|
||||
.first()
|
||||
return db.session.scalar(
|
||||
select(MessageFeedback).where(MessageFeedback.message_id == self.id, MessageFeedback.from_source == "admin")
|
||||
)
|
||||
return feedback
|
||||
|
||||
@property
|
||||
def feedbacks(self):
|
||||
@ -1528,28 +1518,27 @@ class Message(Base):
|
||||
|
||||
@property
|
||||
def annotation(self):
|
||||
annotation = db.session.query(MessageAnnotation).where(MessageAnnotation.message_id == self.id).first()
|
||||
annotation = db.session.scalar(select(MessageAnnotation).where(MessageAnnotation.message_id == self.id))
|
||||
return annotation
|
||||
|
||||
@property
|
||||
def annotation_hit_history(self):
|
||||
annotation_history = (
|
||||
db.session.query(AppAnnotationHitHistory).where(AppAnnotationHitHistory.message_id == self.id).first()
|
||||
annotation_history = db.session.scalar(
|
||||
select(AppAnnotationHitHistory).where(AppAnnotationHitHistory.message_id == self.id)
|
||||
)
|
||||
if annotation_history:
|
||||
annotation = (
|
||||
db.session.query(MessageAnnotation)
|
||||
.where(MessageAnnotation.id == annotation_history.annotation_id)
|
||||
.first()
|
||||
return db.session.scalar(
|
||||
select(MessageAnnotation).where(MessageAnnotation.id == annotation_history.annotation_id)
|
||||
)
|
||||
return annotation
|
||||
return None
|
||||
|
||||
@property
|
||||
def app_model_config(self):
|
||||
conversation = db.session.query(Conversation).where(Conversation.id == self.conversation_id).first()
|
||||
conversation = db.session.scalar(select(Conversation).where(Conversation.id == self.conversation_id))
|
||||
if conversation:
|
||||
return db.session.query(AppModelConfig).where(AppModelConfig.id == conversation.app_model_config_id).first()
|
||||
return db.session.scalar(
|
||||
select(AppModelConfig).where(AppModelConfig.id == conversation.app_model_config_id)
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
@ -1562,13 +1551,12 @@ class Message(Base):
|
||||
return json.loads(self.message_metadata) if self.message_metadata else {}
|
||||
|
||||
@property
|
||||
def agent_thoughts(self) -> list[MessageAgentThought]:
|
||||
return (
|
||||
db.session.query(MessageAgentThought)
|
||||
def agent_thoughts(self) -> Sequence[MessageAgentThought]:
|
||||
return db.session.scalars(
|
||||
select(MessageAgentThought)
|
||||
.where(MessageAgentThought.message_id == self.id)
|
||||
.order_by(MessageAgentThought.position.asc())
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
|
||||
@property
|
||||
def retriever_resources(self) -> Any:
|
||||
@ -1579,7 +1567,7 @@ class Message(Base):
|
||||
from factories import file_factory
|
||||
|
||||
message_files = db.session.scalars(select(MessageFile).where(MessageFile.message_id == self.id)).all()
|
||||
current_app = db.session.query(App).where(App.id == self.app_id).first()
|
||||
current_app = db.session.scalar(select(App).where(App.id == self.app_id))
|
||||
if not current_app:
|
||||
raise ValueError(f"App {self.app_id} not found")
|
||||
|
||||
@ -1743,8 +1731,7 @@ class MessageFeedback(TypeBase):
|
||||
|
||||
@property
|
||||
def from_account(self) -> Account | None:
|
||||
account = db.session.query(Account).where(Account.id == self.from_account_id).first()
|
||||
return account
|
||||
return db.session.scalar(select(Account).where(Account.id == self.from_account_id))
|
||||
|
||||
def to_dict(self) -> MessageFeedbackDict:
|
||||
return {
|
||||
@ -1817,13 +1804,11 @@ class MessageAnnotation(Base):
|
||||
|
||||
@property
|
||||
def account(self):
|
||||
account = db.session.query(Account).where(Account.id == self.account_id).first()
|
||||
return account
|
||||
return db.session.scalar(select(Account).where(Account.id == self.account_id))
|
||||
|
||||
@property
|
||||
def annotation_create_account(self):
|
||||
account = db.session.query(Account).where(Account.id == self.account_id).first()
|
||||
return account
|
||||
return db.session.scalar(select(Account).where(Account.id == self.account_id))
|
||||
|
||||
|
||||
class AppAnnotationHitHistory(TypeBase):
|
||||
@ -1852,18 +1837,15 @@ class AppAnnotationHitHistory(TypeBase):
|
||||
|
||||
@property
|
||||
def account(self):
|
||||
account = (
|
||||
db.session.query(Account)
|
||||
return db.session.scalar(
|
||||
select(Account)
|
||||
.join(MessageAnnotation, MessageAnnotation.account_id == Account.id)
|
||||
.where(MessageAnnotation.id == self.annotation_id)
|
||||
.first()
|
||||
)
|
||||
return account
|
||||
|
||||
@property
|
||||
def annotation_create_account(self):
|
||||
account = db.session.query(Account).where(Account.id == self.account_id).first()
|
||||
return account
|
||||
return db.session.scalar(select(Account).where(Account.id == self.account_id))
|
||||
|
||||
|
||||
class AppAnnotationSetting(TypeBase):
|
||||
@ -1896,12 +1878,9 @@ class AppAnnotationSetting(TypeBase):
|
||||
def collection_binding_detail(self):
|
||||
from .dataset import DatasetCollectionBinding
|
||||
|
||||
collection_binding_detail = (
|
||||
db.session.query(DatasetCollectionBinding)
|
||||
.where(DatasetCollectionBinding.id == self.collection_binding_id)
|
||||
.first()
|
||||
return db.session.scalar(
|
||||
select(DatasetCollectionBinding).where(DatasetCollectionBinding.id == self.collection_binding_id)
|
||||
)
|
||||
return collection_binding_detail
|
||||
|
||||
|
||||
class OperationLog(TypeBase):
|
||||
@ -2007,7 +1986,9 @@ class AppMCPServer(TypeBase):
|
||||
def generate_server_code(n: int) -> str:
|
||||
while True:
|
||||
result = generate_string(n)
|
||||
while db.session.query(AppMCPServer).where(AppMCPServer.server_code == result).count() > 0:
|
||||
while (
|
||||
db.session.scalar(select(func.count(AppMCPServer.id)).where(AppMCPServer.server_code == result)) or 0
|
||||
) > 0:
|
||||
result = generate_string(n)
|
||||
|
||||
return result
|
||||
@ -2068,7 +2049,7 @@ class Site(Base):
|
||||
def generate_code(n: int) -> str:
|
||||
while True:
|
||||
result = generate_string(n)
|
||||
while db.session.query(Site).where(Site.code == result).count() > 0:
|
||||
while (db.session.scalar(select(func.count(Site.id)).where(Site.code == result)) or 0) > 0:
|
||||
result = generate_string(n)
|
||||
|
||||
return result
|
||||
|
||||
@ -6,13 +6,14 @@ from functools import cached_property
|
||||
from uuid import uuid4
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import DateTime, String, func, text
|
||||
from sqlalchemy import DateTime, String, func, select, text
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from libs.uuid_utils import uuidv7
|
||||
|
||||
from .base import TypeBase
|
||||
from .engine import db
|
||||
from .enums import CredentialSourceType, PaymentStatus
|
||||
from .types import EnumText, LongText, StringUUID
|
||||
|
||||
|
||||
@ -96,7 +97,7 @@ class Provider(TypeBase):
|
||||
@cached_property
|
||||
def credential(self):
|
||||
if self.credential_id:
|
||||
return db.session.query(ProviderCredential).where(ProviderCredential.id == self.credential_id).first()
|
||||
return db.session.scalar(select(ProviderCredential).where(ProviderCredential.id == self.credential_id))
|
||||
|
||||
@property
|
||||
def credential_name(self):
|
||||
@ -159,10 +160,8 @@ class ProviderModel(TypeBase):
|
||||
@cached_property
|
||||
def credential(self):
|
||||
if self.credential_id:
|
||||
return (
|
||||
db.session.query(ProviderModelCredential)
|
||||
.where(ProviderModelCredential.id == self.credential_id)
|
||||
.first()
|
||||
return db.session.scalar(
|
||||
select(ProviderModelCredential).where(ProviderModelCredential.id == self.credential_id)
|
||||
)
|
||||
|
||||
@property
|
||||
@ -239,7 +238,9 @@ class ProviderOrder(TypeBase):
|
||||
quantity: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=text("1"))
|
||||
currency: Mapped[str | None] = mapped_column(String(40))
|
||||
total_amount: Mapped[int | None] = mapped_column(sa.Integer)
|
||||
payment_status: Mapped[str] = mapped_column(String(40), nullable=False, server_default=text("'wait_pay'"))
|
||||
payment_status: Mapped[PaymentStatus] = mapped_column(
|
||||
EnumText(PaymentStatus, length=40), nullable=False, server_default=text("'wait_pay'")
|
||||
)
|
||||
paid_at: Mapped[datetime | None] = mapped_column(DateTime)
|
||||
pay_failed_at: Mapped[datetime | None] = mapped_column(DateTime)
|
||||
refunded_at: Mapped[datetime | None] = mapped_column(DateTime)
|
||||
@ -302,7 +303,9 @@ class LoadBalancingModelConfig(TypeBase):
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
encrypted_config: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
|
||||
credential_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
|
||||
credential_source_type: Mapped[str | None] = mapped_column(String(40), nullable=True, default=None)
|
||||
credential_source_type: Mapped[CredentialSourceType | None] = mapped_column(
|
||||
EnumText(CredentialSourceType, length=40), nullable=True, default=None
|
||||
)
|
||||
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("true"), default=True)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
|
||||
|
||||
@ -8,7 +8,7 @@ from uuid import uuid4
|
||||
|
||||
import sqlalchemy as sa
|
||||
from deprecated import deprecated
|
||||
from sqlalchemy import ForeignKey, String, func
|
||||
from sqlalchemy import ForeignKey, String, func, select
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
@ -184,11 +184,11 @@ class ApiToolProvider(TypeBase):
|
||||
def user(self) -> Account | None:
|
||||
if not self.user_id:
|
||||
return None
|
||||
return db.session.query(Account).where(Account.id == self.user_id).first()
|
||||
return db.session.scalar(select(Account).where(Account.id == self.user_id))
|
||||
|
||||
@property
|
||||
def tenant(self) -> Tenant | None:
|
||||
return db.session.query(Tenant).where(Tenant.id == self.tenant_id).first()
|
||||
return db.session.scalar(select(Tenant).where(Tenant.id == self.tenant_id))
|
||||
|
||||
|
||||
class ToolLabelBinding(TypeBase):
|
||||
@ -262,11 +262,11 @@ class WorkflowToolProvider(TypeBase):
|
||||
|
||||
@property
|
||||
def user(self) -> Account | None:
|
||||
return db.session.query(Account).where(Account.id == self.user_id).first()
|
||||
return db.session.scalar(select(Account).where(Account.id == self.user_id))
|
||||
|
||||
@property
|
||||
def tenant(self) -> Tenant | None:
|
||||
return db.session.query(Tenant).where(Tenant.id == self.tenant_id).first()
|
||||
return db.session.scalar(select(Tenant).where(Tenant.id == self.tenant_id))
|
||||
|
||||
@property
|
||||
def parameter_configurations(self) -> list[WorkflowToolParameterConfiguration]:
|
||||
@ -277,7 +277,7 @@ class WorkflowToolProvider(TypeBase):
|
||||
|
||||
@property
|
||||
def app(self) -> App | None:
|
||||
return db.session.query(App).where(App.id == self.app_id).first()
|
||||
return db.session.scalar(select(App).where(App.id == self.app_id))
|
||||
|
||||
|
||||
class MCPToolProvider(TypeBase):
|
||||
@ -334,7 +334,7 @@ class MCPToolProvider(TypeBase):
|
||||
encrypted_headers: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
|
||||
|
||||
def load_user(self) -> Account | None:
|
||||
return db.session.query(Account).where(Account.id == self.user_id).first()
|
||||
return db.session.scalar(select(Account).where(Account.id == self.user_id))
|
||||
|
||||
@property
|
||||
def credentials(self) -> dict[str, Any]:
|
||||
|
||||
@ -3,7 +3,7 @@ import time
|
||||
from collections.abc import Mapping
|
||||
from datetime import datetime
|
||||
from functools import cached_property
|
||||
from typing import Any, cast
|
||||
from typing import Any, TypedDict, cast
|
||||
from uuid import uuid4
|
||||
|
||||
import sqlalchemy as sa
|
||||
@ -23,6 +23,47 @@ from .enums import AppTriggerStatus, AppTriggerType, CreatorUserRole, WorkflowTr
|
||||
from .model import Account
|
||||
from .types import EnumText, LongText, StringUUID
|
||||
|
||||
TriggerJsonObject = dict[str, object]
|
||||
TriggerCredentials = dict[str, str]
|
||||
|
||||
|
||||
class WorkflowTriggerLogDict(TypedDict):
|
||||
id: str
|
||||
tenant_id: str
|
||||
app_id: str
|
||||
workflow_id: str
|
||||
workflow_run_id: str | None
|
||||
root_node_id: str | None
|
||||
trigger_metadata: Any
|
||||
trigger_type: str
|
||||
trigger_data: Any
|
||||
inputs: Any
|
||||
outputs: Any
|
||||
status: str
|
||||
error: str | None
|
||||
queue_name: str
|
||||
celery_task_id: str | None
|
||||
retry_count: int
|
||||
elapsed_time: float | None
|
||||
total_tokens: int | None
|
||||
created_by_role: str
|
||||
created_by: str
|
||||
created_at: str | None
|
||||
triggered_at: str | None
|
||||
finished_at: str | None
|
||||
|
||||
|
||||
class WorkflowSchedulePlanDict(TypedDict):
|
||||
id: str
|
||||
app_id: str
|
||||
node_id: str
|
||||
tenant_id: str
|
||||
cron_expression: str
|
||||
timezone: str
|
||||
next_run_at: str | None
|
||||
created_at: str
|
||||
updated_at: str
|
||||
|
||||
|
||||
class TriggerSubscription(TypeBase):
|
||||
"""
|
||||
@ -51,10 +92,14 @@ class TriggerSubscription(TypeBase):
|
||||
String(255), nullable=False, comment="Provider identifier (e.g., plugin_id/provider_name)"
|
||||
)
|
||||
endpoint_id: Mapped[str] = mapped_column(String(255), nullable=False, comment="Subscription endpoint")
|
||||
parameters: Mapped[dict[str, Any]] = mapped_column(sa.JSON, nullable=False, comment="Subscription parameters JSON")
|
||||
properties: Mapped[dict[str, Any]] = mapped_column(sa.JSON, nullable=False, comment="Subscription properties JSON")
|
||||
parameters: Mapped[TriggerJsonObject] = mapped_column(
|
||||
sa.JSON, nullable=False, comment="Subscription parameters JSON"
|
||||
)
|
||||
properties: Mapped[TriggerJsonObject] = mapped_column(
|
||||
sa.JSON, nullable=False, comment="Subscription properties JSON"
|
||||
)
|
||||
|
||||
credentials: Mapped[dict[str, Any]] = mapped_column(
|
||||
credentials: Mapped[TriggerCredentials] = mapped_column(
|
||||
sa.JSON, nullable=False, comment="Subscription credentials JSON"
|
||||
)
|
||||
credential_type: Mapped[str] = mapped_column(String(50), nullable=False, comment="oauth or api_key")
|
||||
@ -162,8 +207,8 @@ class TriggerOAuthTenantClient(TypeBase):
|
||||
)
|
||||
|
||||
@property
|
||||
def oauth_params(self) -> Mapping[str, Any]:
|
||||
return cast(Mapping[str, Any], json.loads(self.encrypted_oauth_params or "{}"))
|
||||
def oauth_params(self) -> Mapping[str, object]:
|
||||
return cast(TriggerJsonObject, json.loads(self.encrypted_oauth_params or "{}"))
|
||||
|
||||
|
||||
class WorkflowTriggerLog(TypeBase):
|
||||
@ -250,7 +295,7 @@ class WorkflowTriggerLog(TypeBase):
|
||||
created_by_role = CreatorUserRole(self.created_by_role)
|
||||
return db.session.get(EndUser, self.created_by) if created_by_role == CreatorUserRole.END_USER else None
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
def to_dict(self) -> WorkflowTriggerLogDict:
|
||||
"""Convert to dictionary for API responses"""
|
||||
return {
|
||||
"id": self.id,
|
||||
@ -481,7 +526,7 @@ class WorkflowSchedulePlan(TypeBase):
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp(), init=False
|
||||
)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
def to_dict(self) -> WorkflowSchedulePlanDict:
|
||||
"""Convert to dictionary representation"""
|
||||
return {
|
||||
"id": self.id,
|
||||
|
||||
@ -2,7 +2,7 @@ from datetime import datetime
|
||||
from uuid import uuid4
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import DateTime, func
|
||||
from sqlalchemy import DateTime, func, select
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from .base import TypeBase
|
||||
@ -38,7 +38,7 @@ class SavedMessage(TypeBase):
|
||||
|
||||
@property
|
||||
def message(self):
|
||||
return db.session.query(Message).where(Message.id == self.message_id).first()
|
||||
return db.session.scalar(select(Message).where(Message.id == self.message_id))
|
||||
|
||||
|
||||
class PinnedConversation(TypeBase):
|
||||
|
||||
@ -3,7 +3,7 @@ import logging
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union, cast
|
||||
from typing import TYPE_CHECKING, Any, Optional, TypedDict, Union, cast
|
||||
from uuid import uuid4
|
||||
|
||||
import sqlalchemy as sa
|
||||
@ -19,7 +19,7 @@ from sqlalchemy import (
|
||||
orm,
|
||||
select,
|
||||
)
|
||||
from sqlalchemy.orm import Mapped, declared_attr, mapped_column
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
from typing_extensions import deprecated
|
||||
|
||||
from core.trigger.constants import TRIGGER_INFO_METADATA_KEY, TRIGGER_PLUGIN_NODE_TYPE
|
||||
@ -33,7 +33,7 @@ from dify_graph.enums import BuiltinNodeTypes, NodeType, WorkflowExecutionStatus
|
||||
from dify_graph.file.constants import maybe_file_object
|
||||
from dify_graph.file.models import File
|
||||
from dify_graph.variables import utils as variable_utils
|
||||
from dify_graph.variables.variables import FloatVariable, IntegerVariable, StringVariable
|
||||
from dify_graph.variables.variables import FloatVariable, IntegerVariable, RAGPipelineVariable, StringVariable
|
||||
from extensions.ext_storage import Storage
|
||||
from factories.variable_factory import TypeMismatchError, build_segment_with_type
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
@ -59,6 +59,25 @@ from .types import EnumText, LongText, StringUUID
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SerializedWorkflowValue = dict[str, Any]
|
||||
SerializedWorkflowVariables = dict[str, SerializedWorkflowValue]
|
||||
|
||||
|
||||
class WorkflowContentDict(TypedDict):
|
||||
graph: Mapping[str, Any]
|
||||
features: dict[str, Any]
|
||||
environment_variables: list[dict[str, Any]]
|
||||
conversation_variables: list[dict[str, Any]]
|
||||
rag_pipeline_variables: list[dict[str, Any]]
|
||||
|
||||
|
||||
class WorkflowRunSummaryDict(TypedDict):
|
||||
id: str
|
||||
status: str
|
||||
triggered_from: str
|
||||
elapsed_time: float
|
||||
total_tokens: int
|
||||
|
||||
|
||||
class WorkflowType(StrEnum):
|
||||
"""
|
||||
@ -389,7 +408,7 @@ class Workflow(Base): # bug
|
||||
|
||||
def rag_pipeline_user_input_form(self) -> list:
|
||||
# get user_input_form from start node
|
||||
variables: list[Any] = self.rag_pipeline_variables
|
||||
variables: list[SerializedWorkflowValue] = self.rag_pipeline_variables
|
||||
|
||||
return variables
|
||||
|
||||
@ -432,17 +451,13 @@ class Workflow(Base): # bug
|
||||
def environment_variables(
|
||||
self,
|
||||
) -> Sequence[StringVariable | IntegerVariable | FloatVariable | SecretVariable]:
|
||||
# TODO: find some way to init `self._environment_variables` when instance created.
|
||||
if self._environment_variables is None:
|
||||
self._environment_variables = "{}"
|
||||
|
||||
# Use workflow.tenant_id to avoid relying on request user in background threads
|
||||
tenant_id = self.tenant_id
|
||||
|
||||
if not tenant_id:
|
||||
return []
|
||||
|
||||
environment_variables_dict: dict[str, Any] = json.loads(self._environment_variables or "{}")
|
||||
environment_variables_dict = cast(SerializedWorkflowVariables, json.loads(self._environment_variables or "{}"))
|
||||
results = [
|
||||
variable_factory.build_environment_variable_from_mapping(v) for v in environment_variables_dict.values()
|
||||
]
|
||||
@ -502,14 +517,14 @@ class Workflow(Base): # bug
|
||||
)
|
||||
self._environment_variables = environment_variables_json
|
||||
|
||||
def to_dict(self, *, include_secret: bool = False) -> Mapping[str, Any]:
|
||||
def to_dict(self, *, include_secret: bool = False) -> WorkflowContentDict:
|
||||
environment_variables = list(self.environment_variables)
|
||||
environment_variables = [
|
||||
v if not isinstance(v, SecretVariable) or include_secret else v.model_copy(update={"value": ""})
|
||||
for v in environment_variables
|
||||
]
|
||||
|
||||
result = {
|
||||
result: WorkflowContentDict = {
|
||||
"graph": self.graph_dict,
|
||||
"features": self.features_dict,
|
||||
"environment_variables": [var.model_dump(mode="json") for var in environment_variables],
|
||||
@ -520,11 +535,7 @@ class Workflow(Base): # bug
|
||||
|
||||
@property
|
||||
def conversation_variables(self) -> Sequence[VariableBase]:
|
||||
# TODO: find some way to init `self._conversation_variables` when instance created.
|
||||
if self._conversation_variables is None:
|
||||
self._conversation_variables = "{}"
|
||||
|
||||
variables_dict: dict[str, Any] = json.loads(self._conversation_variables)
|
||||
variables_dict = cast(SerializedWorkflowVariables, json.loads(self._conversation_variables or "{}"))
|
||||
results = [variable_factory.build_conversation_variable_from_mapping(v) for v in variables_dict.values()]
|
||||
return results
|
||||
|
||||
@ -536,19 +547,20 @@ class Workflow(Base): # bug
|
||||
)
|
||||
|
||||
@property
|
||||
def rag_pipeline_variables(self) -> list[dict]:
|
||||
# TODO: find some way to init `self._conversation_variables` when instance created.
|
||||
if self._rag_pipeline_variables is None:
|
||||
self._rag_pipeline_variables = "{}"
|
||||
|
||||
variables_dict: dict[str, Any] = json.loads(self._rag_pipeline_variables)
|
||||
results = list(variables_dict.values())
|
||||
return results
|
||||
def rag_pipeline_variables(self) -> list[SerializedWorkflowValue]:
|
||||
variables_dict = cast(SerializedWorkflowVariables, json.loads(self._rag_pipeline_variables or "{}"))
|
||||
return [RAGPipelineVariable.model_validate(item).model_dump(mode="json") for item in variables_dict.values()]
|
||||
|
||||
@rag_pipeline_variables.setter
|
||||
def rag_pipeline_variables(self, values: list[dict]) -> None:
|
||||
def rag_pipeline_variables(self, values: Sequence[Mapping[str, Any] | RAGPipelineVariable]) -> None:
|
||||
self._rag_pipeline_variables = json.dumps(
|
||||
{item["variable"]: item for item in values},
|
||||
{
|
||||
rag_pipeline_variable.variable: rag_pipeline_variable.model_dump(mode="json")
|
||||
for rag_pipeline_variable in (
|
||||
item if isinstance(item, RAGPipelineVariable) else RAGPipelineVariable.model_validate(item)
|
||||
for item in values
|
||||
)
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
@ -667,14 +679,14 @@ class WorkflowRun(Base):
|
||||
def message(self):
|
||||
from .model import Message
|
||||
|
||||
return (
|
||||
db.session.query(Message).where(Message.app_id == self.app_id, Message.workflow_run_id == self.id).first()
|
||||
return db.session.scalar(
|
||||
select(Message).where(Message.app_id == self.app_id, Message.workflow_run_id == self.id)
|
||||
)
|
||||
|
||||
@property
|
||||
@deprecated("This method is retained for historical reasons; avoid using it if possible.")
|
||||
def workflow(self):
|
||||
return db.session.query(Workflow).where(Workflow.id == self.workflow_id).first()
|
||||
return db.session.scalar(select(Workflow).where(Workflow.id == self.workflow_id))
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
@ -786,44 +798,36 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo

__tablename__ = "workflow_node_executions"

@declared_attr.directive
@classmethod
def __table_args__(cls) -> Any:
return (
PrimaryKeyConstraint("id", name="workflow_node_execution_pkey"),
Index(
"workflow_node_execution_workflow_run_id_idx",
"workflow_run_id",
),
Index(
"workflow_node_execution_node_run_idx",
"tenant_id",
"app_id",
"workflow_id",
"triggered_from",
"node_id",
),
Index(
"workflow_node_execution_id_idx",
"tenant_id",
"app_id",
"workflow_id",
"triggered_from",
"node_execution_id",
),
Index(
# The first argument is the index name,
# which we leave as `None`` to allow auto-generation by the ORM.
None,
cls.tenant_id,
cls.workflow_id,
cls.node_id,
# MyPy may flag the following line because it doesn't recognize that
# the `declared_attr` decorator passes the receiving class as the first
# argument to this method, allowing us to reference class attributes.
cls.created_at.desc(),
),
)
__table_args__ = (
PrimaryKeyConstraint("id", name="workflow_node_execution_pkey"),
Index(
"workflow_node_execution_workflow_run_id_idx",
"workflow_run_id",
),
Index(
"workflow_node_execution_node_run_idx",
"tenant_id",
"app_id",
"workflow_id",
"triggered_from",
"node_id",
),
Index(
"workflow_node_execution_id_idx",
"tenant_id",
"app_id",
"workflow_id",
"triggered_from",
"node_execution_id",
),
Index(
None,
"tenant_id",
"workflow_id",
"node_id",
sa.desc("created_at"),
),
)

id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()))
tenant_id: Mapped[str] = mapped_column(StringUUID)
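The hunk above replaces the `declared_attr`-based `__table_args__` with a plain class-level tuple; string column names plus `sa.desc(...)` avoid referencing class attributes before they are defined. A simplified sketch of the same pattern, with illustrative table and column names rather than the real model:

```python
from datetime import datetime

import sqlalchemy as sa
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class ExampleExecution(Base):
    """Illustrative model using a plain class-level __table_args__ tuple."""

    __tablename__ = "example_executions"
    __table_args__ = (
        sa.PrimaryKeyConstraint("id", name="example_execution_pkey"),
        # A name of None lets SQLAlchemy auto-generate the index name;
        # string column names and sa.desc(...) avoid touching class attributes.
        sa.Index(None, "tenant_id", "workflow_id", sa.desc("created_at")),
    )

    id: Mapped[str] = mapped_column(sa.String(36))
    tenant_id: Mapped[str] = mapped_column(sa.String(36))
    workflow_id: Mapped[str] = mapped_column(sa.String(36))
    created_at: Mapped[datetime] = mapped_column(sa.DateTime())
```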
@ -1231,7 +1235,7 @@ class WorkflowArchiveLog(TypeBase):
)

@property
def workflow_run_summary(self) -> dict[str, Any]:
def workflow_run_summary(self) -> WorkflowRunSummaryDict:
return {
"id": self.workflow_run_id,
"status": self.run_status,

@ -1,6 +1,6 @@
[project]
name = "dify-api"
version = "1.13.0"
version = "1.13.1"
requires-python = ">=3.11,<3.13"

dependencies = [

@ -1,4 +1,3 @@
configs/middleware/cache/redis_pubsub_config.py
controllers/console/app/annotation.py
controllers/console/app/app.py
controllers/console/app/app_import.py
@ -138,8 +137,6 @@ dify_graph/nodes/trigger_webhook/node.py
dify_graph/nodes/variable_aggregator/variable_aggregator_node.py
dify_graph/nodes/variable_assigner/v1/node.py
dify_graph/nodes/variable_assigner/v2/node.py
dify_graph/variables/types.py
extensions/ext_fastopenapi.py
extensions/logstore/repositories/logstore_api_workflow_run_repository.py
extensions/otel/instrumentation.py
extensions/otel/runtime.py
@ -156,19 +153,7 @@ extensions/storage/oracle_oci_storage.py
extensions/storage/supabase_storage.py
extensions/storage/tencent_cos_storage.py
extensions/storage/volcengine_tos_storage.py
factories/variable_factory.py
libs/external_api.py
libs/gmpy2_pkcs10aep_cipher.py
libs/helper.py
libs/login.py
libs/module_loading.py
libs/oauth.py
libs/oauth_data_source.py
models/trigger.py
models/workflow.py
repositories/sqlalchemy_api_workflow_node_execution_repository.py
repositories/sqlalchemy_api_workflow_run_repository.py
repositories/sqlalchemy_execution_extra_content_repository.py
schedule/queue_monitor_task.py
services/account_service.py
services/audio_service.py
@ -197,4 +182,9 @@ tasks/app_generate/workflow_execute_task.py
tasks/regenerate_summary_index_task.py
tasks/trigger_processing_tasks.py
tasks/workflow_cfs_scheduler/cfs_scheduler.py
tasks/add_document_to_index_task.py
tasks/create_segment_to_index_task.py
tasks/disable_segment_from_index_task.py
tasks/enable_segment_to_index_task.py
tasks/remove_document_from_index_task.py
tasks/workflow_execution_tasks.py

@ -8,7 +8,7 @@ using SQLAlchemy 2.0 style queries for WorkflowNodeExecutionModel operations.
import json
from collections.abc import Sequence
from datetime import datetime
from typing import cast
from typing import Protocol, cast

from sqlalchemy import asc, delete, desc, func, select
from sqlalchemy.engine import CursorResult
@ -22,6 +22,20 @@ from repositories.api_workflow_node_execution_repository import (
)


class _WorkflowNodeExecutionSnapshotRow(Protocol):
id: str
node_execution_id: str | None
node_id: str
node_type: str
title: str
index: int
status: WorkflowNodeExecutionStatus
elapsed_time: float | None
created_at: datetime
finished_at: datetime | None
execution_metadata: str | None


class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRepository):
"""
SQLAlchemy implementation of DifyAPIWorkflowNodeExecutionRepository.
@ -40,6 +54,8 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut
- Thread-safe database operations using session-per-request pattern
"""

_session_maker: sessionmaker[Session]

def __init__(self, session_maker: sessionmaker[Session]):
"""
Initialize the repository with a sessionmaker.
@ -156,12 +172,12 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut
)

with self._session_maker() as session:
rows = session.execute(stmt).all()
rows = cast(Sequence[_WorkflowNodeExecutionSnapshotRow], session.execute(stmt).all())

return [self._row_to_snapshot(row) for row in rows]

@staticmethod
def _row_to_snapshot(row: object) -> WorkflowNodeExecutionSnapshot:
def _row_to_snapshot(row: _WorkflowNodeExecutionSnapshotRow) -> WorkflowNodeExecutionSnapshot:
metadata: dict[str, object] = {}
execution_metadata = getattr(row, "execution_metadata", None)
if execution_metadata:

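The `_WorkflowNodeExecutionSnapshotRow` protocol introduced above is a typing-only device: rows returned by `session.execute(select(...)).all()` expose the selected labels as attributes, so a structural `Protocol` lets the mapping helper stay typed without importing the ORM model. A minimal hedged sketch of the idea, using a `NamedTuple` as a stand-in for a SQLAlchemy `Row`:

```python
from typing import NamedTuple, Protocol


class _SnapshotRow(Protocol):
    """Structural type for any object carrying the selected columns as attributes."""

    id: str
    node_id: str
    title: str


def to_summary(row: _SnapshotRow) -> dict[str, str]:
    # Works for SQLAlchemy Row objects and anything else with these attributes.
    return {"id": row.id, "node_id": row.node_id, "title": row.title}


class _FakeRow(NamedTuple):
    # Stand-in for a Row in this sketch; real code would pass query results.
    id: str
    node_id: str
    title: str


print(to_summary(_FakeRow("run-1", "node-1", "Start")))
```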
@ -11,7 +11,7 @@ from core.tools.tool_manager import ToolManager
from extensions.ext_database import db
from libs.login import current_user
from models import Account
from models.model import App, Conversation, EndUser, Message, MessageAgentThought
from models.model import App, Conversation, EndUser, Message


class AgentService:
@ -47,7 +47,7 @@ class AgentService:
if not message:
raise ValueError(f"Message not found: {message_id}")

agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
agent_thoughts = message.agent_thoughts

if conversation.from_end_user_id:
# only select name field

@ -18,7 +18,7 @@ from extensions.ext_database import db
from models.account import Account
from models.enums import CreatorUserRole, WorkflowTriggerStatus
from models.model import App, EndUser
from models.trigger import WorkflowTriggerLog
from models.trigger import WorkflowTriggerLog, WorkflowTriggerLogDict
from models.workflow import Workflow
from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
from services.errors.app import QuotaExceededError, WorkflowNotFoundError, WorkflowQuotaLimitError
@ -224,7 +224,9 @@ class AsyncWorkflowService:
return cls.trigger_workflow_async(session, user, trigger_data)

@classmethod
def get_trigger_log(cls, workflow_trigger_log_id: str, tenant_id: str | None = None) -> dict[str, Any] | None:
def get_trigger_log(
cls, workflow_trigger_log_id: str, tenant_id: str | None = None
) -> WorkflowTriggerLogDict | None:
"""
Get trigger log by ID

@ -247,7 +249,7 @@ class AsyncWorkflowService:
@classmethod
def get_recent_logs(
cls, tenant_id: str, app_id: str, hours: int = 24, limit: int = 100, offset: int = 0
) -> list[dict[str, Any]]:
) -> list[WorkflowTriggerLogDict]:
"""
Get recent trigger logs

@ -272,7 +274,7 @@ class AsyncWorkflowService:
@classmethod
def get_failed_logs_for_retry(
cls, tenant_id: str, max_retry_count: int = 3, limit: int = 100
) -> list[dict[str, Any]]:
) -> list[WorkflowTriggerLogDict]:
"""
Get failed logs eligible for retry


@ -51,6 +51,14 @@ from models.dataset import (
|
||||
Pipeline,
|
||||
SegmentAttachmentBinding,
|
||||
)
|
||||
from models.enums import (
|
||||
DatasetRuntimeMode,
|
||||
DataSourceType,
|
||||
DocumentCreatedFrom,
|
||||
IndexingStatus,
|
||||
ProcessRuleMode,
|
||||
SegmentStatus,
|
||||
)
|
||||
from models.model import UploadFile
|
||||
from models.provider_ids import ModelProviderID
|
||||
from models.source import DataSourceOauthBinding
|
||||
@ -319,7 +327,7 @@ class DatasetService:
|
||||
description=rag_pipeline_dataset_create_entity.description,
|
||||
permission=rag_pipeline_dataset_create_entity.permission,
|
||||
provider="vendor",
|
||||
runtime_mode="rag_pipeline",
|
||||
runtime_mode=DatasetRuntimeMode.RAG_PIPELINE,
|
||||
icon_info=rag_pipeline_dataset_create_entity.icon_info.model_dump(),
|
||||
created_by=current_user.id,
|
||||
pipeline_id=pipeline.id,
|
||||
@ -614,7 +622,7 @@ class DatasetService:
|
||||
"""
|
||||
Update pipeline knowledge base node data.
|
||||
"""
|
||||
if dataset.runtime_mode != "rag_pipeline":
|
||||
if dataset.runtime_mode != DatasetRuntimeMode.RAG_PIPELINE:
|
||||
return
|
||||
|
||||
pipeline = db.session.query(Pipeline).filter_by(id=dataset.pipeline_id).first()
|
||||
@ -1229,10 +1237,15 @@ class DocumentService:
|
||||
"enabled": "available",
|
||||
}
|
||||
|
||||
_INDEXING_STATUSES: tuple[str, ...] = ("parsing", "cleaning", "splitting", "indexing")
|
||||
_INDEXING_STATUSES: tuple[IndexingStatus, ...] = (
|
||||
IndexingStatus.PARSING,
|
||||
IndexingStatus.CLEANING,
|
||||
IndexingStatus.SPLITTING,
|
||||
IndexingStatus.INDEXING,
|
||||
)
|
||||
|
||||
DISPLAY_STATUS_FILTERS: dict[str, tuple[Any, ...]] = {
|
||||
"queuing": (Document.indexing_status == "waiting",),
|
||||
"queuing": (Document.indexing_status == IndexingStatus.WAITING,),
|
||||
"indexing": (
|
||||
Document.indexing_status.in_(_INDEXING_STATUSES),
|
||||
Document.is_paused.is_not(True),
|
||||
@ -1241,19 +1254,19 @@ class DocumentService:
|
||||
Document.indexing_status.in_(_INDEXING_STATUSES),
|
||||
Document.is_paused.is_(True),
|
||||
),
|
||||
"error": (Document.indexing_status == "error",),
|
||||
"error": (Document.indexing_status == IndexingStatus.ERROR,),
|
||||
"available": (
|
||||
Document.indexing_status == "completed",
|
||||
Document.indexing_status == IndexingStatus.COMPLETED,
|
||||
Document.archived.is_(False),
|
||||
Document.enabled.is_(True),
|
||||
),
|
||||
"disabled": (
|
||||
Document.indexing_status == "completed",
|
||||
Document.indexing_status == IndexingStatus.COMPLETED,
|
||||
Document.archived.is_(False),
|
||||
Document.enabled.is_(False),
|
||||
),
|
||||
"archived": (
|
||||
Document.indexing_status == "completed",
|
||||
Document.indexing_status == IndexingStatus.COMPLETED,
|
||||
Document.archived.is_(True),
|
||||
),
|
||||
}
|
||||
@ -1536,7 +1549,7 @@ class DocumentService:
|
||||
"""
|
||||
Normalize and validate `Document -> UploadFile` linkage for download flows.
|
||||
"""
|
||||
if document.data_source_type != "upload_file":
|
||||
if document.data_source_type != DataSourceType.UPLOAD_FILE:
|
||||
raise NotFound(invalid_source_message)
|
||||
|
||||
data_source_info: dict[str, Any] = document.data_source_info_dict or {}
|
||||
@ -1617,7 +1630,7 @@ class DocumentService:
|
||||
select(Document).where(
|
||||
Document.id.in_(document_ids),
|
||||
Document.enabled == True,
|
||||
Document.indexing_status == "completed",
|
||||
Document.indexing_status == IndexingStatus.COMPLETED,
|
||||
Document.archived == False,
|
||||
)
|
||||
).all()
|
||||
@ -1640,7 +1653,7 @@ class DocumentService:
|
||||
select(Document).where(
|
||||
Document.dataset_id == dataset_id,
|
||||
Document.enabled == True,
|
||||
Document.indexing_status == "completed",
|
||||
Document.indexing_status == IndexingStatus.COMPLETED,
|
||||
Document.archived == False,
|
||||
)
|
||||
).all()
|
||||
@ -1650,7 +1663,10 @@ class DocumentService:
|
||||
@staticmethod
|
||||
def get_error_documents_by_dataset_id(dataset_id: str) -> Sequence[Document]:
|
||||
documents = db.session.scalars(
|
||||
select(Document).where(Document.dataset_id == dataset_id, Document.indexing_status.in_(["error", "paused"]))
|
||||
select(Document).where(
|
||||
Document.dataset_id == dataset_id,
|
||||
Document.indexing_status.in_([IndexingStatus.ERROR, IndexingStatus.PAUSED]),
|
||||
)
|
||||
).all()
|
||||
return documents
|
||||
|
||||
@ -1683,7 +1699,7 @@ class DocumentService:
|
||||
def delete_document(document):
|
||||
# trigger document_was_deleted signal
|
||||
file_id = None
|
||||
if document.data_source_type == "upload_file":
|
||||
if document.data_source_type == DataSourceType.UPLOAD_FILE:
|
||||
if document.data_source_info:
|
||||
data_source_info = document.data_source_info_dict
|
||||
if data_source_info and "upload_file_id" in data_source_info:
|
||||
@ -1704,7 +1720,7 @@ class DocumentService:
|
||||
file_ids = [
|
||||
document.data_source_info_dict.get("upload_file_id", "")
|
||||
for document in documents
|
||||
if document.data_source_type == "upload_file" and document.data_source_info_dict
|
||||
if document.data_source_type == DataSourceType.UPLOAD_FILE and document.data_source_info_dict
|
||||
]
|
||||
|
||||
# Delete documents first, then dispatch cleanup task after commit
|
||||
@ -1753,7 +1769,13 @@ class DocumentService:
|
||||
|
||||
@staticmethod
|
||||
def pause_document(document):
|
||||
if document.indexing_status not in {"waiting", "parsing", "cleaning", "splitting", "indexing"}:
|
||||
if document.indexing_status not in {
|
||||
IndexingStatus.WAITING,
|
||||
IndexingStatus.PARSING,
|
||||
IndexingStatus.CLEANING,
|
||||
IndexingStatus.SPLITTING,
|
||||
IndexingStatus.INDEXING,
|
||||
}:
|
||||
raise DocumentIndexingError()
|
||||
# update document to be paused
|
||||
assert current_user is not None
|
||||
@ -1793,7 +1815,7 @@ class DocumentService:
|
||||
if cache_result is not None:
|
||||
raise ValueError("Document is being retried, please try again later")
|
||||
# retry document indexing
|
||||
document.indexing_status = "waiting"
|
||||
document.indexing_status = IndexingStatus.WAITING
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
|
||||
@ -1812,7 +1834,7 @@ class DocumentService:
|
||||
if cache_result is not None:
|
||||
raise ValueError("Document is being synced, please try again later")
|
||||
# sync document indexing
|
||||
document.indexing_status = "waiting"
|
||||
document.indexing_status = IndexingStatus.WAITING
|
||||
data_source_info = document.data_source_info_dict
|
||||
if data_source_info:
|
||||
data_source_info["mode"] = "scrape"
|
||||
@ -1840,7 +1862,7 @@ class DocumentService:
|
||||
knowledge_config: KnowledgeConfig,
|
||||
account: Account | Any,
|
||||
dataset_process_rule: DatasetProcessRule | None = None,
|
||||
created_from: str = "web",
|
||||
created_from: str = DocumentCreatedFrom.WEB,
|
||||
) -> tuple[list[Document], str]:
|
||||
# check doc_form
|
||||
DatasetService.check_doc_form(dataset, knowledge_config.doc_form)
|
||||
@ -1932,7 +1954,7 @@ class DocumentService:
|
||||
if not dataset_process_rule:
|
||||
process_rule = knowledge_config.process_rule
|
||||
if process_rule:
|
||||
if process_rule.mode in ("custom", "hierarchical"):
|
||||
if process_rule.mode in (ProcessRuleMode.CUSTOM, ProcessRuleMode.HIERARCHICAL):
|
||||
if process_rule.rules:
|
||||
dataset_process_rule = DatasetProcessRule(
|
||||
dataset_id=dataset.id,
|
||||
@ -1944,7 +1966,7 @@ class DocumentService:
|
||||
dataset_process_rule = dataset.latest_process_rule
|
||||
if not dataset_process_rule:
|
||||
raise ValueError("No process rule found.")
|
||||
elif process_rule.mode == "automatic":
|
||||
elif process_rule.mode == ProcessRuleMode.AUTOMATIC:
|
||||
dataset_process_rule = DatasetProcessRule(
|
||||
dataset_id=dataset.id,
|
||||
mode=process_rule.mode,
|
||||
@ -1967,7 +1989,7 @@ class DocumentService:
|
||||
if not dataset_process_rule:
|
||||
dataset_process_rule = DatasetProcessRule(
|
||||
dataset_id=dataset.id,
|
||||
mode="automatic",
|
||||
mode=ProcessRuleMode.AUTOMATIC,
|
||||
rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES),
|
||||
created_by=account.id,
|
||||
)
|
||||
@ -2001,7 +2023,7 @@ class DocumentService:
|
||||
.where(
|
||||
Document.dataset_id == dataset.id,
|
||||
Document.tenant_id == current_user.current_tenant_id,
|
||||
Document.data_source_type == "upload_file",
|
||||
Document.data_source_type == DataSourceType.UPLOAD_FILE,
|
||||
Document.enabled == True,
|
||||
Document.name.in_(file_names),
|
||||
)
|
||||
@ -2021,7 +2043,7 @@ class DocumentService:
|
||||
document.doc_language = knowledge_config.doc_language
|
||||
document.data_source_info = json.dumps(data_source_info)
|
||||
document.batch = batch
|
||||
document.indexing_status = "waiting"
|
||||
document.indexing_status = IndexingStatus.WAITING
|
||||
db.session.add(document)
|
||||
documents.append(document)
|
||||
duplicate_document_ids.append(document.id)
|
||||
@ -2056,7 +2078,7 @@ class DocumentService:
|
||||
.filter_by(
|
||||
dataset_id=dataset.id,
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
data_source_type="notion_import",
|
||||
data_source_type=DataSourceType.NOTION_IMPORT,
|
||||
enabled=True,
|
||||
)
|
||||
.all()
|
||||
@ -2507,7 +2529,7 @@ class DocumentService:
|
||||
document_data: KnowledgeConfig,
|
||||
account: Account,
|
||||
dataset_process_rule: DatasetProcessRule | None = None,
|
||||
created_from: str = "web",
|
||||
created_from: str = DocumentCreatedFrom.WEB,
|
||||
):
|
||||
assert isinstance(current_user, Account)
|
||||
|
||||
@ -2520,14 +2542,14 @@ class DocumentService:
|
||||
# save process rule
|
||||
if document_data.process_rule:
|
||||
process_rule = document_data.process_rule
|
||||
if process_rule.mode in {"custom", "hierarchical"}:
|
||||
if process_rule.mode in {ProcessRuleMode.CUSTOM, ProcessRuleMode.HIERARCHICAL}:
|
||||
dataset_process_rule = DatasetProcessRule(
|
||||
dataset_id=dataset.id,
|
||||
mode=process_rule.mode,
|
||||
rules=process_rule.rules.model_dump_json() if process_rule.rules else None,
|
||||
created_by=account.id,
|
||||
)
|
||||
elif process_rule.mode == "automatic":
|
||||
elif process_rule.mode == ProcessRuleMode.AUTOMATIC:
|
||||
dataset_process_rule = DatasetProcessRule(
|
||||
dataset_id=dataset.id,
|
||||
mode=process_rule.mode,
|
||||
@ -2609,7 +2631,7 @@ class DocumentService:
|
||||
if document_data.name:
|
||||
document.name = document_data.name
|
||||
# update document to be waiting
|
||||
document.indexing_status = "waiting"
|
||||
document.indexing_status = IndexingStatus.WAITING
|
||||
document.completed_at = None
|
||||
document.processing_started_at = None
|
||||
document.parsing_completed_at = None
|
||||
@ -2623,7 +2645,7 @@ class DocumentService:
|
||||
# update document segment
|
||||
|
||||
db.session.query(DocumentSegment).filter_by(document_id=document.id).update(
|
||||
{DocumentSegment.status: "re_segment"}
|
||||
{DocumentSegment.status: SegmentStatus.RE_SEGMENT}
|
||||
)
|
||||
db.session.commit()
|
||||
# trigger async task
|
||||
@ -2754,7 +2776,7 @@ class DocumentService:
|
||||
if knowledge_config.process_rule.mode not in DatasetProcessRule.MODES:
|
||||
raise ValueError("Process rule mode is invalid")
|
||||
|
||||
if knowledge_config.process_rule.mode == "automatic":
|
||||
if knowledge_config.process_rule.mode == ProcessRuleMode.AUTOMATIC:
|
||||
knowledge_config.process_rule.rules = None
|
||||
else:
|
||||
if not knowledge_config.process_rule.rules:
|
||||
@ -2785,7 +2807,7 @@ class DocumentService:
|
||||
raise ValueError("Process rule segmentation separator is invalid")
|
||||
|
||||
if not (
|
||||
knowledge_config.process_rule.mode == "hierarchical"
|
||||
knowledge_config.process_rule.mode == ProcessRuleMode.HIERARCHICAL
|
||||
and knowledge_config.process_rule.rules.parent_mode == "full-doc"
|
||||
):
|
||||
if not knowledge_config.process_rule.rules.segmentation.max_tokens:
|
||||
@ -2814,7 +2836,7 @@ class DocumentService:
|
||||
if args["process_rule"]["mode"] not in DatasetProcessRule.MODES:
|
||||
raise ValueError("Process rule mode is invalid")
|
||||
|
||||
if args["process_rule"]["mode"] == "automatic":
|
||||
if args["process_rule"]["mode"] == ProcessRuleMode.AUTOMATIC:
|
||||
args["process_rule"]["rules"] = {}
|
||||
else:
|
||||
if "rules" not in args["process_rule"] or not args["process_rule"]["rules"]:
|
||||
@ -3021,7 +3043,7 @@ class DocumentService:
|
||||
@staticmethod
|
||||
def _prepare_disable_update(document, user, now):
|
||||
"""Prepare updates for disabling a document."""
|
||||
if not document.completed_at or document.indexing_status != "completed":
|
||||
if not document.completed_at or document.indexing_status != IndexingStatus.COMPLETED:
|
||||
raise DocumentIndexingError(f"Document: {document.name} is not completed.")
|
||||
|
||||
if not document.enabled:
|
||||
@ -3130,7 +3152,7 @@ class SegmentService:
|
||||
content=content,
|
||||
word_count=len(content),
|
||||
tokens=tokens,
|
||||
status="completed",
|
||||
status=SegmentStatus.COMPLETED,
|
||||
indexing_at=naive_utc_now(),
|
||||
completed_at=naive_utc_now(),
|
||||
created_by=current_user.id,
|
||||
@ -3167,7 +3189,7 @@ class SegmentService:
|
||||
logger.exception("create segment index failed")
|
||||
segment_document.enabled = False
|
||||
segment_document.disabled_at = naive_utc_now()
|
||||
segment_document.status = "error"
|
||||
segment_document.status = SegmentStatus.ERROR
|
||||
segment_document.error = str(e)
|
||||
db.session.commit()
|
||||
segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment_document.id).first()
|
||||
@ -3227,7 +3249,7 @@ class SegmentService:
|
||||
word_count=len(content),
|
||||
tokens=tokens,
|
||||
keywords=segment_item.get("keywords", []),
|
||||
status="completed",
|
||||
status=SegmentStatus.COMPLETED,
|
||||
indexing_at=naive_utc_now(),
|
||||
completed_at=naive_utc_now(),
|
||||
created_by=current_user.id,
|
||||
@ -3259,7 +3281,7 @@ class SegmentService:
|
||||
for segment_document in segment_data_list:
|
||||
segment_document.enabled = False
|
||||
segment_document.disabled_at = naive_utc_now()
|
||||
segment_document.status = "error"
|
||||
segment_document.status = SegmentStatus.ERROR
|
||||
segment_document.error = str(e)
|
||||
db.session.commit()
|
||||
return segment_data_list
|
||||
@ -3405,7 +3427,7 @@ class SegmentService:
|
||||
segment.index_node_hash = segment_hash
|
||||
segment.word_count = len(content)
|
||||
segment.tokens = tokens
|
||||
segment.status = "completed"
|
||||
segment.status = SegmentStatus.COMPLETED
|
||||
segment.indexing_at = naive_utc_now()
|
||||
segment.completed_at = naive_utc_now()
|
||||
segment.updated_by = current_user.id
|
||||
@ -3530,7 +3552,7 @@ class SegmentService:
|
||||
logger.exception("update segment index failed")
|
||||
segment.enabled = False
|
||||
segment.disabled_at = naive_utc_now()
|
||||
segment.status = "error"
|
||||
segment.status = SegmentStatus.ERROR
|
||||
segment.error = str(e)
|
||||
db.session.commit()
|
||||
new_segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment.id).first()
|
||||
|
||||
@ -13,7 +13,7 @@ from dify_graph.model_runtime.entities import LLMMode
|
||||
from extensions.ext_database import db
|
||||
from models import Account
|
||||
from models.dataset import Dataset, DatasetQuery
|
||||
from models.enums import CreatorUserRole
|
||||
from models.enums import CreatorUserRole, DatasetQuerySource
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -97,7 +97,7 @@ class HitTestingService:
|
||||
dataset_query = DatasetQuery(
|
||||
dataset_id=dataset.id,
|
||||
content=json.dumps(dataset_queries),
|
||||
source="hit_testing",
|
||||
source=DatasetQuerySource.HIT_TESTING,
|
||||
source_app_id=None,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=account.id,
|
||||
@ -137,7 +137,7 @@ class HitTestingService:
|
||||
dataset_query = DatasetQuery(
|
||||
dataset_id=dataset.id,
|
||||
content=query,
|
||||
source="hit_testing",
|
||||
source=DatasetQuerySource.HIT_TESTING,
|
||||
source_app_id=None,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=account.id,
|
||||
|
||||
@ -7,6 +7,7 @@ from extensions.ext_redis import redis_client
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import current_account_with_tenant
|
||||
from models.dataset import Dataset, DatasetMetadata, DatasetMetadataBinding
|
||||
from models.enums import DatasetMetadataType
|
||||
from services.dataset_service import DocumentService
|
||||
from services.entities.knowledge_entities.knowledge_entities import (
|
||||
MetadataArgs,
|
||||
@ -130,11 +131,11 @@ class MetadataService:
|
||||
@staticmethod
|
||||
def get_built_in_fields():
|
||||
return [
|
||||
{"name": BuiltInField.document_name, "type": "string"},
|
||||
{"name": BuiltInField.uploader, "type": "string"},
|
||||
{"name": BuiltInField.upload_date, "type": "time"},
|
||||
{"name": BuiltInField.last_update_date, "type": "time"},
|
||||
{"name": BuiltInField.source, "type": "string"},
|
||||
{"name": BuiltInField.document_name, "type": DatasetMetadataType.STRING},
|
||||
{"name": BuiltInField.uploader, "type": DatasetMetadataType.STRING},
|
||||
{"name": BuiltInField.upload_date, "type": DatasetMetadataType.TIME},
|
||||
{"name": BuiltInField.last_update_date, "type": DatasetMetadataType.TIME},
|
||||
{"name": BuiltInField.source, "type": DatasetMetadataType.STRING},
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
|
||||
@ -19,6 +19,7 @@ from dify_graph.model_runtime.entities.provider_entities import (
|
||||
from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.enums import CredentialSourceType
|
||||
from models.provider import LoadBalancingModelConfig, ProviderCredential, ProviderModelCredential
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -103,9 +104,9 @@ class ModelLoadBalancingService:
|
||||
is_load_balancing_enabled = True
|
||||
|
||||
if config_from == "predefined-model":
|
||||
credential_source_type = "provider"
|
||||
credential_source_type = CredentialSourceType.PROVIDER
|
||||
else:
|
||||
credential_source_type = "custom_model"
|
||||
credential_source_type = CredentialSourceType.CUSTOM_MODEL
|
||||
|
||||
# Get load balancing configurations
|
||||
load_balancing_configs = (
|
||||
@ -421,7 +422,11 @@ class ModelLoadBalancingService:
|
||||
raise ValueError("Invalid load balancing config name")
|
||||
|
||||
if credential_id:
|
||||
credential_source = "provider" if config_from == "predefined-model" else "custom_model"
|
||||
credential_source = (
|
||||
CredentialSourceType.PROVIDER
|
||||
if config_from == "predefined-model"
|
||||
else CredentialSourceType.CUSTOM_MODEL
|
||||
)
|
||||
assert credential_record is not None
|
||||
load_balancing_model_config = LoadBalancingModelConfig(
|
||||
tenant_id=tenant_id,
|
||||
|
||||
@ -30,7 +30,7 @@ from core.plugin.impl.debugging import PluginDebuggingClient
|
||||
from core.plugin.impl.plugin import PluginInstaller
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.provider import Provider, ProviderCredential
|
||||
from models.provider import Provider, ProviderCredential, TenantPreferredModelProvider
|
||||
from models.provider_ids import GenericProviderID
|
||||
from services.enterprise.plugin_manager_service import (
|
||||
PluginManagerService,
|
||||
@ -534,6 +534,13 @@ class PluginService:
|
||||
plugin_id = plugin.plugin_id
|
||||
logger.info("Deleting credentials for plugin: %s", plugin_id)
|
||||
|
||||
session.execute(
|
||||
delete(TenantPreferredModelProvider).where(
|
||||
TenantPreferredModelProvider.tenant_id == tenant_id,
|
||||
TenantPreferredModelProvider.provider_name.like(f"{plugin_id}/%"),
|
||||
)
|
||||
)
|
||||
|
||||
# Delete provider credentials that match this plugin
|
||||
credential_ids = session.scalars(
|
||||
select(ProviderCredential.id).where(
|
||||
|
||||
@ -6,6 +6,7 @@ from core.app.apps.pipeline.pipeline_generator import PipelineGenerator
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Document, Pipeline
|
||||
from models.enums import IndexingStatus
|
||||
from models.model import Account, App, EndUser
|
||||
from models.workflow import Workflow
|
||||
from services.rag_pipeline.rag_pipeline import RagPipelineService
|
||||
@ -111,6 +112,6 @@ class PipelineGenerateService:
|
||||
"""
|
||||
document = db.session.query(Document).where(Document.id == document_id).first()
|
||||
if document:
|
||||
document.indexing_status = "waiting"
|
||||
document.indexing_status = IndexingStatus.WAITING
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
|
||||
@ -15,7 +15,8 @@ class RemotePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
|
||||
Retrieval recommended app from dify official
|
||||
"""
|
||||
|
||||
def get_pipeline_template_detail(self, template_id: str):
|
||||
def get_pipeline_template_detail(self, template_id: str) -> dict | None:
|
||||
result: dict | None
|
||||
try:
|
||||
result = self.fetch_pipeline_template_detail_from_dify_official(template_id)
|
||||
except Exception as e:
|
||||
@ -35,17 +36,23 @@ class RemotePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
|
||||
return PipelineTemplateType.REMOTE
|
||||
|
||||
@classmethod
|
||||
def fetch_pipeline_template_detail_from_dify_official(cls, template_id: str) -> dict | None:
|
||||
def fetch_pipeline_template_detail_from_dify_official(cls, template_id: str) -> dict:
|
||||
"""
|
||||
Fetch pipeline template detail from dify official.
|
||||
:param template_id: Pipeline ID
|
||||
:return:
|
||||
|
||||
:param template_id: Pipeline template ID
|
||||
:return: Template detail dict
|
||||
:raises ValueError: When upstream returns a non-200 status code
|
||||
"""
|
||||
domain = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_REMOTE_DOMAIN
|
||||
url = f"{domain}/pipeline-templates/{template_id}"
|
||||
response = httpx.get(url, timeout=httpx.Timeout(10.0, connect=3.0))
|
||||
if response.status_code != 200:
|
||||
return None
|
||||
raise ValueError(
|
||||
"fetch pipeline template detail failed,"
|
||||
+ f" status_code: {response.status_code},"
|
||||
+ f" response: {response.text[:1000]}"
|
||||
)
|
||||
data: dict = response.json()
|
||||
return data
|
||||
|
||||
|
||||
@ -64,7 +64,7 @@ from models.dataset import ( # type: ignore
|
||||
PipelineCustomizedTemplate,
|
||||
PipelineRecommendedPlugin,
|
||||
)
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
from models.enums import IndexingStatus, WorkflowRunTriggeredFrom
|
||||
from models.model import EndUser
|
||||
from models.workflow import (
|
||||
Workflow,
|
||||
@ -117,13 +117,21 @@ class RagPipelineService:
|
||||
def get_pipeline_template_detail(cls, template_id: str, type: str = "built-in") -> dict | None:
|
||||
"""
|
||||
Get pipeline template detail.
|
||||
|
||||
:param template_id: template id
|
||||
:return:
|
||||
:param type: template type, "built-in" or "customized"
|
||||
:return: template detail dict, or None if not found
|
||||
"""
|
||||
if type == "built-in":
|
||||
mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE
|
||||
retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
|
||||
built_in_result: dict | None = retrieval_instance.get_pipeline_template_detail(template_id)
|
||||
if built_in_result is None:
|
||||
logger.warning(
|
||||
"pipeline template retrieval returned empty result, template_id: %s, mode: %s",
|
||||
template_id,
|
||||
mode,
|
||||
)
|
||||
return built_in_result
|
||||
else:
|
||||
mode = "customized"
|
||||
@ -906,7 +914,7 @@ class RagPipelineService:
|
||||
if document_id:
|
||||
document = db.session.query(Document).where(Document.id == document_id.value).first()
|
||||
if document:
|
||||
document.indexing_status = "error"
|
||||
document.indexing_status = IndexingStatus.ERROR
|
||||
document.error = error
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
|
||||
@ -35,6 +35,7 @@ from extensions.ext_redis import redis_client
|
||||
from factories import variable_factory
|
||||
from models import Account
|
||||
from models.dataset import Dataset, DatasetCollectionBinding, Pipeline
|
||||
from models.enums import CollectionBindingType, DatasetRuntimeMode
|
||||
from models.workflow import Workflow, WorkflowType
|
||||
from services.entities.knowledge_entities.rag_pipeline_entities import (
|
||||
IconInfo,
|
||||
@ -313,7 +314,7 @@ class RagPipelineDslService:
|
||||
indexing_technique=knowledge_configuration.indexing_technique,
|
||||
created_by=account.id,
|
||||
retrieval_model=knowledge_configuration.retrieval_model.model_dump(),
|
||||
runtime_mode="rag_pipeline",
|
||||
runtime_mode=DatasetRuntimeMode.RAG_PIPELINE,
|
||||
chunk_structure=knowledge_configuration.chunk_structure,
|
||||
)
|
||||
if knowledge_configuration.indexing_technique == "high_quality":
|
||||
@ -323,7 +324,7 @@ class RagPipelineDslService:
|
||||
DatasetCollectionBinding.provider_name
|
||||
== knowledge_configuration.embedding_model_provider,
|
||||
DatasetCollectionBinding.model_name == knowledge_configuration.embedding_model,
|
||||
DatasetCollectionBinding.type == "dataset",
|
||||
DatasetCollectionBinding.type == CollectionBindingType.DATASET,
|
||||
)
|
||||
.order_by(DatasetCollectionBinding.created_at)
|
||||
.first()
|
||||
@ -334,7 +335,7 @@ class RagPipelineDslService:
|
||||
provider_name=knowledge_configuration.embedding_model_provider,
|
||||
model_name=knowledge_configuration.embedding_model,
|
||||
collection_name=Dataset.gen_collection_name_by_id(str(uuid.uuid4())),
|
||||
type="dataset",
|
||||
type=CollectionBindingType.DATASET,
|
||||
)
|
||||
self._session.add(dataset_collection_binding)
|
||||
self._session.commit()
|
||||
@ -445,13 +446,13 @@ class RagPipelineDslService:
|
||||
indexing_technique=knowledge_configuration.indexing_technique,
|
||||
created_by=account.id,
|
||||
retrieval_model=knowledge_configuration.retrieval_model.model_dump(),
|
||||
runtime_mode="rag_pipeline",
|
||||
runtime_mode=DatasetRuntimeMode.RAG_PIPELINE,
|
||||
chunk_structure=knowledge_configuration.chunk_structure,
|
||||
)
|
||||
else:
|
||||
dataset.indexing_technique = knowledge_configuration.indexing_technique
|
||||
dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump()
|
||||
dataset.runtime_mode = "rag_pipeline"
|
||||
dataset.runtime_mode = DatasetRuntimeMode.RAG_PIPELINE
|
||||
dataset.chunk_structure = knowledge_configuration.chunk_structure
|
||||
if knowledge_configuration.indexing_technique == "high_quality":
|
||||
dataset_collection_binding = (
|
||||
@ -460,7 +461,7 @@ class RagPipelineDslService:
|
||||
DatasetCollectionBinding.provider_name
|
||||
== knowledge_configuration.embedding_model_provider,
|
||||
DatasetCollectionBinding.model_name == knowledge_configuration.embedding_model,
|
||||
DatasetCollectionBinding.type == "dataset",
|
||||
DatasetCollectionBinding.type == CollectionBindingType.DATASET,
|
||||
)
|
||||
.order_by(DatasetCollectionBinding.created_at)
|
||||
.first()
|
||||
@ -471,7 +472,7 @@ class RagPipelineDslService:
|
||||
provider_name=knowledge_configuration.embedding_model_provider,
|
||||
model_name=knowledge_configuration.embedding_model,
|
||||
collection_name=Dataset.gen_collection_name_by_id(str(uuid.uuid4())),
|
||||
type="dataset",
|
||||
type=CollectionBindingType.DATASET,
|
||||
)
|
||||
self._session.add(dataset_collection_binding)
|
||||
self._session.commit()
|
||||
|
||||
@ -13,6 +13,7 @@ from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from extensions.ext_database import db
|
||||
from factories import variable_factory
|
||||
from models.dataset import Dataset, Document, DocumentPipelineExecutionLog, Pipeline
|
||||
from models.enums import DatasetRuntimeMode, DataSourceType
|
||||
from models.model import UploadFile
|
||||
from models.workflow import Workflow, WorkflowType
|
||||
from services.entities.knowledge_entities.rag_pipeline_entities import KnowledgeConfiguration, RetrievalSetting
|
||||
@ -27,7 +28,7 @@ class RagPipelineTransformService:
|
||||
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
if not dataset:
|
||||
raise ValueError("Dataset not found")
|
||||
if dataset.pipeline_id and dataset.runtime_mode == "rag_pipeline":
|
||||
if dataset.pipeline_id and dataset.runtime_mode == DatasetRuntimeMode.RAG_PIPELINE:
|
||||
return {
|
||||
"pipeline_id": dataset.pipeline_id,
|
||||
"dataset_id": dataset_id,
|
||||
@ -85,7 +86,7 @@ class RagPipelineTransformService:
|
||||
else:
|
||||
raise ValueError("Unsupported doc form")
|
||||
|
||||
dataset.runtime_mode = "rag_pipeline"
|
||||
dataset.runtime_mode = DatasetRuntimeMode.RAG_PIPELINE
|
||||
dataset.pipeline_id = pipeline.id
|
||||
|
||||
# deal document data
|
||||
@ -102,7 +103,7 @@ class RagPipelineTransformService:
|
||||
pipeline_yaml = {}
|
||||
if doc_form == "text_model":
|
||||
match datasource_type:
|
||||
case "upload_file":
|
||||
case DataSourceType.UPLOAD_FILE:
|
||||
if indexing_technique == "high_quality":
|
||||
# get graph from transform.file-general-high-quality.yml
|
||||
with open(f"{Path(__file__).parent}/transform/file-general-high-quality.yml") as f:
|
||||
@ -111,7 +112,7 @@ class RagPipelineTransformService:
|
||||
# get graph from transform.file-general-economy.yml
|
||||
with open(f"{Path(__file__).parent}/transform/file-general-economy.yml") as f:
|
||||
pipeline_yaml = yaml.safe_load(f)
|
||||
case "notion_import":
|
||||
case DataSourceType.NOTION_IMPORT:
|
||||
if indexing_technique == "high_quality":
|
||||
# get graph from transform.notion-general-high-quality.yml
|
||||
with open(f"{Path(__file__).parent}/transform/notion-general-high-quality.yml") as f:
|
||||
@ -120,7 +121,7 @@ class RagPipelineTransformService:
|
||||
# get graph from transform.notion-general-economy.yml
|
||||
with open(f"{Path(__file__).parent}/transform/notion-general-economy.yml") as f:
|
||||
pipeline_yaml = yaml.safe_load(f)
|
||||
case "website_crawl":
|
||||
case DataSourceType.WEBSITE_CRAWL:
|
||||
if indexing_technique == "high_quality":
|
||||
# get graph from transform.website-crawl-general-high-quality.yml
|
||||
with open(f"{Path(__file__).parent}/transform/website-crawl-general-high-quality.yml") as f:
|
||||
@ -133,15 +134,15 @@ class RagPipelineTransformService:
|
||||
raise ValueError("Unsupported datasource type")
|
||||
elif doc_form == "hierarchical_model":
|
||||
match datasource_type:
|
||||
case "upload_file":
|
||||
case DataSourceType.UPLOAD_FILE:
|
||||
# get graph from transform.file-parentchild.yml
|
||||
with open(f"{Path(__file__).parent}/transform/file-parentchild.yml") as f:
|
||||
pipeline_yaml = yaml.safe_load(f)
|
||||
case "notion_import":
|
||||
case DataSourceType.NOTION_IMPORT:
|
||||
# get graph from transform.notion-parentchild.yml
|
||||
with open(f"{Path(__file__).parent}/transform/notion-parentchild.yml") as f:
|
||||
pipeline_yaml = yaml.safe_load(f)
|
||||
case "website_crawl":
|
||||
case DataSourceType.WEBSITE_CRAWL:
|
||||
# get graph from transform.website-crawl-parentchild.yml
|
||||
with open(f"{Path(__file__).parent}/transform/website-crawl-parentchild.yml") as f:
|
||||
pipeline_yaml = yaml.safe_load(f)
|
||||
@ -287,7 +288,7 @@ class RagPipelineTransformService:
|
||||
db.session.flush()
|
||||
|
||||
dataset.pipeline_id = pipeline.id
|
||||
dataset.runtime_mode = "rag_pipeline"
|
||||
dataset.runtime_mode = DatasetRuntimeMode.RAG_PIPELINE
|
||||
dataset.updated_by = current_user.id
|
||||
dataset.updated_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
db.session.add(dataset)
|
||||
@ -310,8 +311,8 @@ class RagPipelineTransformService:
|
||||
data_source_info_dict = document.data_source_info_dict
|
||||
if not data_source_info_dict:
|
||||
continue
|
||||
if document.data_source_type == "upload_file":
|
||||
document.data_source_type = "local_file"
|
||||
if document.data_source_type == DataSourceType.UPLOAD_FILE:
|
||||
document.data_source_type = DataSourceType.LOCAL_FILE
|
||||
file_id = data_source_info_dict.get("upload_file_id")
|
||||
if file_id:
|
||||
file = db.session.query(UploadFile).where(UploadFile.id == file_id).first()
|
||||
@ -331,7 +332,7 @@ class RagPipelineTransformService:
|
||||
document_pipeline_execution_log = DocumentPipelineExecutionLog(
|
||||
document_id=document.id,
|
||||
pipeline_id=dataset.pipeline_id,
|
||||
datasource_type="local_file",
|
||||
datasource_type=DataSourceType.LOCAL_FILE,
|
||||
datasource_info=data_source_info,
|
||||
input_data={},
|
||||
created_by=document.created_by,
|
||||
@ -340,8 +341,8 @@ class RagPipelineTransformService:
|
||||
document_pipeline_execution_log.created_at = document.created_at
|
||||
db.session.add(document)
|
||||
db.session.add(document_pipeline_execution_log)
|
||||
elif document.data_source_type == "notion_import":
|
||||
document.data_source_type = "online_document"
|
||||
elif document.data_source_type == DataSourceType.NOTION_IMPORT:
|
||||
document.data_source_type = DataSourceType.ONLINE_DOCUMENT
|
||||
data_source_info = json.dumps(
|
||||
{
|
||||
"workspace_id": data_source_info_dict.get("notion_workspace_id"),
|
||||
@ -359,7 +360,7 @@ class RagPipelineTransformService:
|
||||
document_pipeline_execution_log = DocumentPipelineExecutionLog(
|
||||
document_id=document.id,
|
||||
pipeline_id=dataset.pipeline_id,
|
||||
datasource_type="online_document",
|
||||
datasource_type=DataSourceType.ONLINE_DOCUMENT,
|
||||
datasource_info=data_source_info,
|
||||
input_data={},
|
||||
created_by=document.created_by,
|
||||
@ -368,8 +369,7 @@ class RagPipelineTransformService:
|
||||
document_pipeline_execution_log.created_at = document.created_at
|
||||
db.session.add(document)
|
||||
db.session.add(document_pipeline_execution_log)
|
||||
elif document.data_source_type == "website_crawl":
|
||||
document.data_source_type = "website_crawl"
|
||||
elif document.data_source_type == DataSourceType.WEBSITE_CRAWL:
|
||||
data_source_info = json.dumps(
|
||||
{
|
||||
"source_url": data_source_info_dict.get("url"),
|
||||
@ -388,7 +388,7 @@ class RagPipelineTransformService:
|
||||
document_pipeline_execution_log = DocumentPipelineExecutionLog(
|
||||
document_id=document.id,
|
||||
pipeline_id=dataset.pipeline_id,
|
||||
datasource_type="website_crawl",
|
||||
datasource_type=DataSourceType.WEBSITE_CRAWL,
|
||||
datasource_info=data_source_info,
|
||||
input_data={},
|
||||
created_by=document.created_by,
|
||||
|
||||
@ -1,16 +1,16 @@
|
||||
import datetime
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
from collections.abc import Sequence
|
||||
from typing import cast
|
||||
from typing import TYPE_CHECKING, cast
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import delete, select, tuple_
|
||||
from sqlalchemy.engine import CursorResult
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.model import (
|
||||
@ -33,6 +33,131 @@ from services.retention.conversation.messages_clean_policy import (
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from opentelemetry.metrics import Counter, Histogram
|
||||
|
||||
|
||||
class MessagesCleanupMetrics:
|
||||
"""
|
||||
Records low-cardinality OpenTelemetry metrics for expired message cleanup jobs.
|
||||
|
||||
We keep labels stable (dry_run/window_mode/task_label/status) so these metrics remain
|
||||
dashboard-friendly for long-running CronJob executions.
|
||||
"""
|
||||
|
||||
_job_runs_total: "Counter | None"
|
||||
_batches_total: "Counter | None"
|
||||
_messages_scanned_total: "Counter | None"
|
||||
_messages_filtered_total: "Counter | None"
|
||||
_messages_deleted_total: "Counter | None"
|
||||
_job_duration_seconds: "Histogram | None"
|
||||
_batch_duration_seconds: "Histogram | None"
|
||||
_base_attributes: dict[str, str]
|
||||
|
||||
def __init__(self, *, dry_run: bool, has_window: bool, task_label: str) -> None:
|
||||
self._job_runs_total = None
|
||||
self._batches_total = None
|
||||
self._messages_scanned_total = None
|
||||
self._messages_filtered_total = None
|
||||
self._messages_deleted_total = None
|
||||
self._job_duration_seconds = None
|
||||
self._batch_duration_seconds = None
|
||||
self._base_attributes = {
|
||||
"job_name": "messages_cleanup",
|
||||
"dry_run": str(dry_run).lower(),
|
||||
"window_mode": "between" if has_window else "before_cutoff",
|
||||
"task_label": task_label,
|
||||
}
|
||||
self._init_instruments()
|
||||
|
||||
def _init_instruments(self) -> None:
|
||||
if not dify_config.ENABLE_OTEL:
|
||||
return
|
||||
|
||||
try:
|
||||
from opentelemetry.metrics import get_meter
|
||||
|
||||
meter = get_meter("messages_cleanup", version=dify_config.project.version)
|
||||
self._job_runs_total = meter.create_counter(
|
||||
"messages_cleanup_jobs_total",
|
||||
description="Total number of expired message cleanup jobs by status.",
|
||||
unit="{job}",
|
||||
)
|
||||
self._batches_total = meter.create_counter(
|
||||
"messages_cleanup_batches_total",
|
||||
description="Total number of message cleanup batches processed.",
|
||||
unit="{batch}",
|
||||
)
|
||||
self._messages_scanned_total = meter.create_counter(
|
||||
"messages_cleanup_scanned_messages_total",
|
||||
description="Total messages scanned by cleanup jobs.",
|
||||
unit="{message}",
|
||||
)
|
||||
self._messages_filtered_total = meter.create_counter(
|
||||
"messages_cleanup_filtered_messages_total",
|
||||
description="Total messages selected by cleanup policy.",
|
||||
unit="{message}",
|
||||
)
|
||||
self._messages_deleted_total = meter.create_counter(
|
||||
"messages_cleanup_deleted_messages_total",
|
||||
description="Total messages deleted by cleanup jobs.",
|
||||
unit="{message}",
|
||||
)
|
||||
self._job_duration_seconds = meter.create_histogram(
|
||||
"messages_cleanup_job_duration_seconds",
|
||||
description="Duration of expired message cleanup jobs in seconds.",
|
||||
unit="s",
|
||||
)
|
||||
self._batch_duration_seconds = meter.create_histogram(
|
||||
"messages_cleanup_batch_duration_seconds",
|
||||
description="Duration of expired message cleanup batch processing in seconds.",
|
||||
unit="s",
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("messages_cleanup_metrics: failed to initialize instruments")
|
||||
|
||||
def _attrs(self, **extra: str) -> dict[str, str]:
|
||||
return {**self._base_attributes, **extra}
|
||||
|
||||
@staticmethod
|
||||
def _add(counter: "Counter | None", value: int, attributes: dict[str, str]) -> None:
|
||||
if not counter or value <= 0:
|
||||
return
|
||||
try:
|
||||
counter.add(value, attributes)
|
||||
except Exception:
|
||||
logger.exception("messages_cleanup_metrics: failed to add counter value")
|
||||
|
||||
@staticmethod
|
||||
def _record(histogram: "Histogram | None", value: float, attributes: dict[str, str]) -> None:
|
||||
if not histogram:
|
||||
return
|
||||
try:
|
||||
histogram.record(value, attributes)
|
||||
except Exception:
|
||||
logger.exception("messages_cleanup_metrics: failed to record histogram value")
|
||||
|
||||
def record_batch(
|
||||
self,
|
||||
*,
|
||||
scanned_messages: int,
|
||||
filtered_messages: int,
|
||||
deleted_messages: int,
|
||||
batch_duration_seconds: float,
|
||||
) -> None:
|
||||
attributes = self._attrs()
|
||||
self._add(self._batches_total, 1, attributes)
|
||||
self._add(self._messages_scanned_total, scanned_messages, attributes)
|
||||
self._add(self._messages_filtered_total, filtered_messages, attributes)
|
||||
self._add(self._messages_deleted_total, deleted_messages, attributes)
|
||||
self._record(self._batch_duration_seconds, batch_duration_seconds, attributes)
|
||||
|
||||
def record_completion(self, *, status: str, job_duration_seconds: float) -> None:
|
||||
attributes = self._attrs(status=status)
|
||||
self._add(self._job_runs_total, 1, attributes)
|
||||
self._record(self._job_duration_seconds, job_duration_seconds, attributes)
|
||||
|
||||
|
||||
class MessagesCleanService:
|
||||
"""
|
||||
Service for cleaning expired messages based on retention policies.
|
||||
@ -48,6 +173,7 @@ class MessagesCleanService:
|
||||
start_from: datetime.datetime | None = None,
|
||||
batch_size: int = 1000,
|
||||
dry_run: bool = False,
|
||||
task_label: str = "custom",
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the service with cleanup parameters.
|
||||
@ -58,12 +184,18 @@ class MessagesCleanService:
|
||||
start_from: Optional start time (inclusive) of the range
|
||||
batch_size: Number of messages to process per batch
|
||||
dry_run: Whether to perform a dry run (no actual deletion)
|
||||
task_label: Optional task label for retention metrics
|
||||
"""
|
||||
self._policy = policy
|
||||
self._end_before = end_before
|
||||
self._start_from = start_from
|
||||
self._batch_size = batch_size
|
||||
self._dry_run = dry_run
|
||||
self._metrics = MessagesCleanupMetrics(
|
||||
dry_run=dry_run,
|
||||
has_window=bool(start_from),
|
||||
task_label=task_label,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_time_range(
|
||||
@ -73,6 +205,7 @@ class MessagesCleanService:
|
||||
end_before: datetime.datetime,
|
||||
batch_size: int = 1000,
|
||||
dry_run: bool = False,
|
||||
task_label: str = "custom",
|
||||
) -> "MessagesCleanService":
|
||||
"""
|
||||
Create a service instance for cleaning messages within a specific time range.
|
||||
@ -85,6 +218,7 @@ class MessagesCleanService:
|
||||
end_before: End time (exclusive) of the range
|
||||
batch_size: Number of messages to process per batch
|
||||
dry_run: Whether to perform a dry run (no actual deletion)
|
||||
task_label: Optional task label for retention metrics
|
||||
|
||||
Returns:
|
||||
MessagesCleanService instance
|
||||
@ -112,6 +246,7 @@ class MessagesCleanService:
|
||||
start_from=start_from,
|
||||
batch_size=batch_size,
|
||||
dry_run=dry_run,
|
||||
task_label=task_label,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -121,6 +256,7 @@ class MessagesCleanService:
|
||||
days: int = 30,
|
||||
batch_size: int = 1000,
|
||||
dry_run: bool = False,
|
||||
task_label: str = "custom",
|
||||
) -> "MessagesCleanService":
|
||||
"""
|
||||
Create a service instance for cleaning messages older than specified days.
|
||||
@ -130,6 +266,7 @@ class MessagesCleanService:
|
||||
days: Number of days to look back from now
|
||||
batch_size: Number of messages to process per batch
|
||||
dry_run: Whether to perform a dry run (no actual deletion)
|
||||
task_label: Optional task label for retention metrics
|
||||
|
||||
Returns:
|
||||
MessagesCleanService instance
|
||||
@ -153,7 +290,14 @@ class MessagesCleanService:
|
||||
policy.__class__.__name__,
|
||||
)
|
||||
|
||||
return cls(policy=policy, end_before=end_before, start_from=None, batch_size=batch_size, dry_run=dry_run)
|
||||
return cls(
|
||||
policy=policy,
|
||||
end_before=end_before,
|
||||
start_from=None,
|
||||
batch_size=batch_size,
|
||||
dry_run=dry_run,
|
||||
task_label=task_label,
|
||||
)
|
||||
|
||||
def run(self) -> dict[str, int]:
|
||||
"""
|
||||
@ -162,7 +306,18 @@ class MessagesCleanService:
|
||||
Returns:
|
||||
Dict with statistics: batches, filtered_messages, total_deleted
|
||||
"""
|
||||
return self._clean_messages_by_time_range()
|
||||
status = "success"
|
||||
run_start = time.monotonic()
|
||||
try:
|
||||
return self._clean_messages_by_time_range()
|
||||
except Exception:
|
||||
status = "failed"
|
||||
raise
|
||||
finally:
|
||||
self._metrics.record_completion(
|
||||
status=status,
|
||||
job_duration_seconds=time.monotonic() - run_start,
|
||||
)
|
||||
|
||||
def _clean_messages_by_time_range(self) -> dict[str, int]:
|
||||
"""
|
||||
@ -197,11 +352,14 @@ class MessagesCleanService:
|
||||
self._end_before,
|
||||
)
|
||||
|
||||
max_batch_interval_ms = int(os.environ.get("SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_MAX_INTERVAL", 200))
|
||||
max_batch_interval_ms = dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_MAX_INTERVAL
|
||||
|
||||
while True:
|
||||
stats["batches"] += 1
|
||||
batch_start = time.monotonic()
|
||||
batch_scanned_messages = 0
|
||||
batch_filtered_messages = 0
|
||||
batch_deleted_messages = 0
|
||||
|
||||
# Step 1: Fetch a batch of messages using cursor
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
@ -240,9 +398,16 @@ class MessagesCleanService:
|
||||
|
||||
# Track total messages fetched across all batches
|
||||
stats["total_messages"] += len(messages)
|
||||
batch_scanned_messages = len(messages)
|
||||
|
||||
if not messages:
|
||||
logger.info("clean_messages (batch %s): no more messages to process", stats["batches"])
|
||||
self._metrics.record_batch(
|
||||
scanned_messages=batch_scanned_messages,
|
||||
filtered_messages=batch_filtered_messages,
|
||||
deleted_messages=batch_deleted_messages,
|
||||
batch_duration_seconds=time.monotonic() - batch_start,
|
||||
)
|
||||
break
|
||||
|
||||
# Update cursor to the last message's (created_at, id)
|
||||
@ -268,6 +433,12 @@ class MessagesCleanService:
|
||||
|
||||
if not apps:
|
||||
logger.info("clean_messages (batch %s): no apps found, skip", stats["batches"])
|
||||
self._metrics.record_batch(
|
||||
scanned_messages=batch_scanned_messages,
|
||||
filtered_messages=batch_filtered_messages,
|
||||
deleted_messages=batch_deleted_messages,
|
||||
batch_duration_seconds=time.monotonic() - batch_start,
|
||||
)
|
||||
continue
|
||||
|
||||
# Build app_id -> tenant_id mapping
|
||||
@ -286,9 +457,16 @@ class MessagesCleanService:
|
||||
|
||||
if not message_ids_to_delete:
|
||||
logger.info("clean_messages (batch %s): no messages to delete, skip", stats["batches"])
|
||||
self._metrics.record_batch(
|
||||
scanned_messages=batch_scanned_messages,
|
||||
filtered_messages=batch_filtered_messages,
|
||||
deleted_messages=batch_deleted_messages,
|
||||
batch_duration_seconds=time.monotonic() - batch_start,
|
||||
)
|
||||
continue
|
||||
|
||||
stats["filtered_messages"] += len(message_ids_to_delete)
|
||||
batch_filtered_messages = len(message_ids_to_delete)
|
||||
|
||||
# Step 4: Batch delete messages and their relations
|
||||
if not self._dry_run:
|
||||
@ -309,6 +487,7 @@ class MessagesCleanService:
|
||||
commit_ms = int((time.monotonic() - commit_start) * 1000)
|
||||
|
||||
stats["total_deleted"] += messages_deleted
|
||||
batch_deleted_messages = messages_deleted
|
||||
|
||||
logger.info(
|
||||
"clean_messages (batch %s): processed %s messages, deleted %s messages",
|
||||
@ -343,6 +522,13 @@ class MessagesCleanService:
|
||||
for msg_id in sampled_ids:
|
||||
logger.info("clean_messages (batch %s, dry_run) sample: message_id=%s", stats["batches"], msg_id)
|
||||
|
||||
self._metrics.record_batch(
|
||||
scanned_messages=batch_scanned_messages,
|
||||
filtered_messages=batch_filtered_messages,
|
||||
deleted_messages=batch_deleted_messages,
|
||||
batch_duration_seconds=time.monotonic() - batch_start,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"clean_messages completed: total batches: %s, total messages: %s, filtered messages: %s, total deleted: %s",
|
||||
stats["batches"],
|
||||
|
||||
@ -1,9 +1,9 @@
|
||||
import datetime
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
from collections.abc import Iterable, Sequence
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import click
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
@ -20,6 +20,159 @@ from services.billing_service import BillingService, SubscriptionPlan
logger = logging.getLogger(__name__)


if TYPE_CHECKING:
    from opentelemetry.metrics import Counter, Histogram


class WorkflowRunCleanupMetrics:
    """
    Records low-cardinality OpenTelemetry metrics for workflow run cleanup jobs.

    Metrics are emitted with stable labels only (dry_run/window_mode/task_label/status)
    to keep dashboard and alert cardinality predictable in production clusters.
    """

    _job_runs_total: "Counter | None"
    _batches_total: "Counter | None"
    _runs_scanned_total: "Counter | None"
    _runs_targeted_total: "Counter | None"
    _runs_deleted_total: "Counter | None"
    _runs_skipped_total: "Counter | None"
    _related_records_total: "Counter | None"
    _job_duration_seconds: "Histogram | None"
    _batch_duration_seconds: "Histogram | None"
    _base_attributes: dict[str, str]

    def __init__(self, *, dry_run: bool, has_window: bool, task_label: str) -> None:
        self._job_runs_total = None
        self._batches_total = None
        self._runs_scanned_total = None
        self._runs_targeted_total = None
        self._runs_deleted_total = None
        self._runs_skipped_total = None
        self._related_records_total = None
        self._job_duration_seconds = None
        self._batch_duration_seconds = None
        self._base_attributes = {
            "job_name": "workflow_run_cleanup",
            "dry_run": str(dry_run).lower(),
            "window_mode": "between" if has_window else "before_cutoff",
            "task_label": task_label,
        }
        self._init_instruments()

    def _init_instruments(self) -> None:
        if not dify_config.ENABLE_OTEL:
            return

        try:
            from opentelemetry.metrics import get_meter

            meter = get_meter("workflow_run_cleanup", version=dify_config.project.version)
            self._job_runs_total = meter.create_counter(
                "workflow_run_cleanup_jobs_total",
                description="Total number of workflow run cleanup jobs by status.",
                unit="{job}",
            )
            self._batches_total = meter.create_counter(
                "workflow_run_cleanup_batches_total",
                description="Total number of processed cleanup batches.",
                unit="{batch}",
            )
            self._runs_scanned_total = meter.create_counter(
                "workflow_run_cleanup_scanned_runs_total",
                description="Total workflow runs scanned by cleanup jobs.",
                unit="{run}",
            )
            self._runs_targeted_total = meter.create_counter(
                "workflow_run_cleanup_targeted_runs_total",
                description="Total workflow runs targeted by cleanup policy.",
                unit="{run}",
            )
            self._runs_deleted_total = meter.create_counter(
                "workflow_run_cleanup_deleted_runs_total",
                description="Total workflow runs deleted by cleanup jobs.",
                unit="{run}",
            )
            self._runs_skipped_total = meter.create_counter(
                "workflow_run_cleanup_skipped_runs_total",
                description="Total workflow runs skipped because tenant is paid/unknown.",
                unit="{run}",
            )
            self._related_records_total = meter.create_counter(
                "workflow_run_cleanup_related_records_total",
                description="Total related records processed by cleanup jobs.",
                unit="{record}",
            )
            self._job_duration_seconds = meter.create_histogram(
                "workflow_run_cleanup_job_duration_seconds",
                description="Duration of workflow run cleanup jobs in seconds.",
                unit="s",
            )
            self._batch_duration_seconds = meter.create_histogram(
                "workflow_run_cleanup_batch_duration_seconds",
                description="Duration of workflow run cleanup batch processing in seconds.",
                unit="s",
            )
        except Exception:
            logger.exception("workflow_run_cleanup_metrics: failed to initialize instruments")

    def _attrs(self, **extra: str) -> dict[str, str]:
        return {**self._base_attributes, **extra}

    @staticmethod
    def _add(counter: "Counter | None", value: int, attributes: dict[str, str]) -> None:
        if not counter or value <= 0:
            return
        try:
            counter.add(value, attributes)
        except Exception:
            logger.exception("workflow_run_cleanup_metrics: failed to add counter value")

    @staticmethod
    def _record(histogram: "Histogram | None", value: float, attributes: dict[str, str]) -> None:
        if not histogram:
            return
        try:
            histogram.record(value, attributes)
        except Exception:
            logger.exception("workflow_run_cleanup_metrics: failed to record histogram value")

    def record_batch(
        self,
        *,
        batch_rows: int,
        targeted_runs: int,
        skipped_runs: int,
        deleted_runs: int,
        related_counts: dict[str, int] | None,
        related_action: str | None,
        batch_duration_seconds: float,
    ) -> None:
        attributes = self._attrs()
        self._add(self._batches_total, 1, attributes)
        self._add(self._runs_scanned_total, batch_rows, attributes)
        self._add(self._runs_targeted_total, targeted_runs, attributes)
        self._add(self._runs_skipped_total, skipped_runs, attributes)
        self._add(self._runs_deleted_total, deleted_runs, attributes)
        self._record(self._batch_duration_seconds, batch_duration_seconds, attributes)

        if not related_counts or not related_action:
            return

        for record_type, count in related_counts.items():
            self._add(
                self._related_records_total,
                count,
                self._attrs(action=related_action, record_type=record_type),
            )

    def record_completion(self, *, status: str, job_duration_seconds: float) -> None:
        attributes = self._attrs(status=status)
        self._add(self._job_runs_total, 1, attributes)
        self._record(self._job_duration_seconds, job_duration_seconds, attributes)


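A minimal usage sketch of the recorder above, mirroring how the cleanup loop below drives it: record_batch after each batch, and record_completion from a finally block so a failed job is still counted under status "failed". fetch_batch and delete_batch are hypothetical placeholders for the repository calls, not real Dify APIs:

```python
# Sketch only; assumes WorkflowRunCleanupMetrics as defined above.
import time

metrics = WorkflowRunCleanupMetrics(dry_run=False, has_window=False, task_label="before-30")

status = "success"
run_start = time.monotonic()
try:
    while True:
        batch_start = time.monotonic()
        rows = fetch_batch()          # placeholder for get_runs_batch_by_time_range(...)
        if not rows:
            break
        deleted = delete_batch(rows)  # placeholder for delete_runs_with_related(...)
        metrics.record_batch(
            batch_rows=len(rows),
            targeted_runs=len(rows),
            skipped_runs=0,
            deleted_runs=deleted,
            related_counts=None,
            related_action=None,
            batch_duration_seconds=time.monotonic() - batch_start,
        )
except Exception:
    status = "failed"
    raise
finally:
    metrics.record_completion(status=status, job_duration_seconds=time.monotonic() - run_start)
```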
class WorkflowRunCleanup:
    def __init__(
        self,
@ -29,6 +182,7 @@ class WorkflowRunCleanup:
        end_before: datetime.datetime | None = None,
        workflow_run_repo: APIWorkflowRunRepository | None = None,
        dry_run: bool = False,
        task_label: str = "custom",
    ):
        if (start_from is None) ^ (end_before is None):
            raise ValueError("start_from and end_before must be both set or both omitted.")
@ -46,6 +200,11 @@ class WorkflowRunCleanup:
        self.batch_size = batch_size
        self._cleanup_whitelist: set[str] | None = None
        self.dry_run = dry_run
        self._metrics = WorkflowRunCleanupMetrics(
            dry_run=dry_run,
            has_window=bool(start_from),
            task_label=task_label,
        )
        self.free_plan_grace_period_days = dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD
        self.workflow_run_repo: APIWorkflowRunRepository
        if workflow_run_repo:
@ -74,153 +233,193 @@ class WorkflowRunCleanup:
|
||||
related_totals = self._empty_related_counts() if self.dry_run else None
|
||||
batch_index = 0
|
||||
last_seen: tuple[datetime.datetime, str] | None = None
|
||||
status = "success"
|
||||
run_start = time.monotonic()
|
||||
max_batch_interval_ms = dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_MAX_INTERVAL
|
||||
|
||||
max_batch_interval_ms = int(os.environ.get("SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_MAX_INTERVAL", 200))
|
||||
try:
|
||||
while True:
|
||||
batch_start = time.monotonic()
|
||||
|
||||
while True:
|
||||
batch_start = time.monotonic()
|
||||
|
||||
fetch_start = time.monotonic()
|
||||
run_rows = self.workflow_run_repo.get_runs_batch_by_time_range(
|
||||
start_from=self.window_start,
|
||||
end_before=self.window_end,
|
||||
last_seen=last_seen,
|
||||
batch_size=self.batch_size,
|
||||
)
|
||||
if not run_rows:
|
||||
logger.info("workflow_run_cleanup (batch #%s): no more rows to process", batch_index + 1)
|
||||
break
|
||||
|
||||
batch_index += 1
|
||||
last_seen = (run_rows[-1].created_at, run_rows[-1].id)
|
||||
logger.info(
|
||||
"workflow_run_cleanup (batch #%s): fetched %s rows in %sms",
|
||||
batch_index,
|
||||
len(run_rows),
|
||||
int((time.monotonic() - fetch_start) * 1000),
|
||||
)
|
||||
|
||||
tenant_ids = {row.tenant_id for row in run_rows}
|
||||
|
||||
filter_start = time.monotonic()
|
||||
free_tenants = self._filter_free_tenants(tenant_ids)
|
||||
logger.info(
|
||||
"workflow_run_cleanup (batch #%s): filtered %s free tenants from %s tenants in %sms",
|
||||
batch_index,
|
||||
len(free_tenants),
|
||||
len(tenant_ids),
|
||||
int((time.monotonic() - filter_start) * 1000),
|
||||
)
|
||||
|
||||
free_runs = [row for row in run_rows if row.tenant_id in free_tenants]
|
||||
paid_or_skipped = len(run_rows) - len(free_runs)
|
||||
|
||||
if not free_runs:
|
||||
skipped_message = (
|
||||
f"[batch #{batch_index}] skipped (no sandbox runs in batch, {paid_or_skipped} paid/unknown)"
|
||||
fetch_start = time.monotonic()
|
||||
run_rows = self.workflow_run_repo.get_runs_batch_by_time_range(
|
||||
start_from=self.window_start,
|
||||
end_before=self.window_end,
|
||||
last_seen=last_seen,
|
||||
batch_size=self.batch_size,
|
||||
)
|
||||
click.echo(
|
||||
click.style(
|
||||
skipped_message,
|
||||
fg="yellow",
|
||||
)
|
||||
)
|
||||
continue
|
||||
if not run_rows:
|
||||
logger.info("workflow_run_cleanup (batch #%s): no more rows to process", batch_index + 1)
|
||||
break
|
||||
|
||||
total_runs_targeted += len(free_runs)
|
||||
|
||||
if self.dry_run:
|
||||
count_start = time.monotonic()
|
||||
batch_counts = self.workflow_run_repo.count_runs_with_related(
|
||||
free_runs,
|
||||
count_node_executions=self._count_node_executions,
|
||||
count_trigger_logs=self._count_trigger_logs,
|
||||
)
|
||||
batch_index += 1
|
||||
last_seen = (run_rows[-1].created_at, run_rows[-1].id)
|
||||
logger.info(
|
||||
"workflow_run_cleanup (batch #%s, dry_run): counted related records in %sms",
|
||||
"workflow_run_cleanup (batch #%s): fetched %s rows in %sms",
|
||||
batch_index,
|
||||
int((time.monotonic() - count_start) * 1000),
|
||||
len(run_rows),
|
||||
int((time.monotonic() - fetch_start) * 1000),
|
||||
)
|
||||
if related_totals is not None:
|
||||
for key in related_totals:
|
||||
related_totals[key] += batch_counts.get(key, 0)
|
||||
sample_ids = ", ".join(run.id for run in free_runs[:5])
|
||||
|
||||
tenant_ids = {row.tenant_id for row in run_rows}
|
||||
|
||||
filter_start = time.monotonic()
|
||||
free_tenants = self._filter_free_tenants(tenant_ids)
|
||||
logger.info(
|
||||
"workflow_run_cleanup (batch #%s): filtered %s free tenants from %s tenants in %sms",
|
||||
batch_index,
|
||||
len(free_tenants),
|
||||
len(tenant_ids),
|
||||
int((time.monotonic() - filter_start) * 1000),
|
||||
)
|
||||
|
||||
free_runs = [row for row in run_rows if row.tenant_id in free_tenants]
|
||||
paid_or_skipped = len(run_rows) - len(free_runs)
|
||||
|
||||
if not free_runs:
|
||||
skipped_message = (
|
||||
f"[batch #{batch_index}] skipped (no sandbox runs in batch, {paid_or_skipped} paid/unknown)"
|
||||
)
|
||||
click.echo(
|
||||
click.style(
|
||||
skipped_message,
|
||||
fg="yellow",
|
||||
)
|
||||
)
|
||||
self._metrics.record_batch(
|
||||
batch_rows=len(run_rows),
|
||||
targeted_runs=0,
|
||||
skipped_runs=paid_or_skipped,
|
||||
deleted_runs=0,
|
||||
related_counts=None,
|
||||
related_action=None,
|
||||
batch_duration_seconds=time.monotonic() - batch_start,
|
||||
)
|
||||
continue
|
||||
|
||||
total_runs_targeted += len(free_runs)
|
||||
|
||||
if self.dry_run:
|
||||
count_start = time.monotonic()
|
||||
batch_counts = self.workflow_run_repo.count_runs_with_related(
|
||||
free_runs,
|
||||
count_node_executions=self._count_node_executions,
|
||||
count_trigger_logs=self._count_trigger_logs,
|
||||
)
|
||||
logger.info(
|
||||
"workflow_run_cleanup (batch #%s, dry_run): counted related records in %sms",
|
||||
batch_index,
|
||||
int((time.monotonic() - count_start) * 1000),
|
||||
)
|
||||
if related_totals is not None:
|
||||
for key in related_totals:
|
||||
related_totals[key] += batch_counts.get(key, 0)
|
||||
sample_ids = ", ".join(run.id for run in free_runs[:5])
|
||||
click.echo(
|
||||
click.style(
|
||||
f"[batch #{batch_index}] would delete {len(free_runs)} runs "
|
||||
f"(sample ids: {sample_ids}) and skip {paid_or_skipped} paid/unknown",
|
||||
fg="yellow",
|
||||
)
|
||||
)
|
||||
logger.info(
|
||||
"workflow_run_cleanup (batch #%s, dry_run): batch total %sms",
|
||||
batch_index,
|
||||
int((time.monotonic() - batch_start) * 1000),
|
||||
)
|
||||
self._metrics.record_batch(
|
||||
batch_rows=len(run_rows),
|
||||
targeted_runs=len(free_runs),
|
||||
skipped_runs=paid_or_skipped,
|
||||
deleted_runs=0,
|
||||
related_counts={key: batch_counts.get(key, 0) for key in self._empty_related_counts()},
|
||||
related_action="would_delete",
|
||||
batch_duration_seconds=time.monotonic() - batch_start,
|
||||
)
|
||||
continue
|
||||
|
||||
try:
|
||||
delete_start = time.monotonic()
|
||||
counts = self.workflow_run_repo.delete_runs_with_related(
|
||||
free_runs,
|
||||
delete_node_executions=self._delete_node_executions,
|
||||
delete_trigger_logs=self._delete_trigger_logs,
|
||||
)
|
||||
delete_ms = int((time.monotonic() - delete_start) * 1000)
|
||||
except Exception:
|
||||
logger.exception("Failed to delete workflow runs batch ending at %s", last_seen[0])
|
||||
raise
|
||||
|
||||
total_runs_deleted += counts["runs"]
|
||||
click.echo(
|
||||
click.style(
|
||||
f"[batch #{batch_index}] would delete {len(free_runs)} runs "
|
||||
f"(sample ids: {sample_ids}) and skip {paid_or_skipped} paid/unknown",
|
||||
fg="yellow",
|
||||
f"[batch #{batch_index}] deleted runs: {counts['runs']} "
|
||||
f"(nodes {counts['node_executions']}, offloads {counts['offloads']}, "
|
||||
f"app_logs {counts['app_logs']}, trigger_logs {counts['trigger_logs']}, "
|
||||
f"pauses {counts['pauses']}, pause_reasons {counts['pause_reasons']}); "
|
||||
f"skipped {paid_or_skipped} paid/unknown",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
logger.info(
|
||||
"workflow_run_cleanup (batch #%s, dry_run): batch total %sms",
|
||||
"workflow_run_cleanup (batch #%s): delete %sms, batch total %sms",
|
||||
batch_index,
|
||||
delete_ms,
|
||||
int((time.monotonic() - batch_start) * 1000),
|
||||
)
|
||||
continue
|
||||
|
||||
try:
|
||||
delete_start = time.monotonic()
|
||||
counts = self.workflow_run_repo.delete_runs_with_related(
|
||||
free_runs,
|
||||
delete_node_executions=self._delete_node_executions,
|
||||
delete_trigger_logs=self._delete_trigger_logs,
|
||||
self._metrics.record_batch(
|
||||
batch_rows=len(run_rows),
|
||||
targeted_runs=len(free_runs),
|
||||
skipped_runs=paid_or_skipped,
|
||||
deleted_runs=counts["runs"],
|
||||
related_counts={key: counts.get(key, 0) for key in self._empty_related_counts()},
|
||||
related_action="deleted",
|
||||
batch_duration_seconds=time.monotonic() - batch_start,
|
||||
)
|
||||
delete_ms = int((time.monotonic() - delete_start) * 1000)
|
||||
except Exception:
|
||||
logger.exception("Failed to delete workflow runs batch ending at %s", last_seen[0])
|
||||
raise
|
||||
|
||||
total_runs_deleted += counts["runs"]
|
||||
click.echo(
|
||||
click.style(
|
||||
f"[batch #{batch_index}] deleted runs: {counts['runs']} "
|
||||
f"(nodes {counts['node_executions']}, offloads {counts['offloads']}, "
|
||||
f"app_logs {counts['app_logs']}, trigger_logs {counts['trigger_logs']}, "
|
||||
f"pauses {counts['pauses']}, pause_reasons {counts['pause_reasons']}); "
|
||||
f"skipped {paid_or_skipped} paid/unknown",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
logger.info(
|
||||
"workflow_run_cleanup (batch #%s): delete %sms, batch total %sms",
|
||||
batch_index,
|
||||
delete_ms,
|
||||
int((time.monotonic() - batch_start) * 1000),
|
||||
)
|
||||
# Random sleep between batches to avoid overwhelming the database
|
||||
sleep_ms = random.uniform(0, max_batch_interval_ms) # noqa: S311
|
||||
logger.info("workflow_run_cleanup (batch #%s): sleeping for %.2fms", batch_index, sleep_ms)
|
||||
time.sleep(sleep_ms / 1000)
|
||||
|
||||
# Random sleep between batches to avoid overwhelming the database
|
||||
sleep_ms = random.uniform(0, max_batch_interval_ms) # noqa: S311
|
||||
logger.info("workflow_run_cleanup (batch #%s): sleeping for %.2fms", batch_index, sleep_ms)
|
||||
time.sleep(sleep_ms / 1000)
|
||||
|
||||
if self.dry_run:
|
||||
if self.window_start:
|
||||
summary_message = (
|
||||
f"Dry run complete. Would delete {total_runs_targeted} workflow runs "
|
||||
f"between {self.window_start.isoformat()} and {self.window_end.isoformat()}"
|
||||
)
|
||||
if self.dry_run:
|
||||
if self.window_start:
|
||||
summary_message = (
|
||||
f"Dry run complete. Would delete {total_runs_targeted} workflow runs "
|
||||
f"between {self.window_start.isoformat()} and {self.window_end.isoformat()}"
|
||||
)
|
||||
else:
|
||||
summary_message = (
|
||||
f"Dry run complete. Would delete {total_runs_targeted} workflow runs "
|
||||
f"before {self.window_end.isoformat()}"
|
||||
)
|
||||
if related_totals is not None:
|
||||
summary_message = (
|
||||
f"{summary_message}; related records: {self._format_related_counts(related_totals)}"
|
||||
)
|
||||
summary_color = "yellow"
|
||||
else:
|
||||
summary_message = (
|
||||
f"Dry run complete. Would delete {total_runs_targeted} workflow runs "
|
||||
f"before {self.window_end.isoformat()}"
|
||||
)
|
||||
if related_totals is not None:
|
||||
summary_message = f"{summary_message}; related records: {self._format_related_counts(related_totals)}"
|
||||
summary_color = "yellow"
|
||||
else:
|
||||
if self.window_start:
|
||||
summary_message = (
|
||||
f"Cleanup complete. Deleted {total_runs_deleted} workflow runs "
|
||||
f"between {self.window_start.isoformat()} and {self.window_end.isoformat()}"
|
||||
)
|
||||
else:
|
||||
summary_message = (
|
||||
f"Cleanup complete. Deleted {total_runs_deleted} workflow runs before {self.window_end.isoformat()}"
|
||||
)
|
||||
summary_color = "white"
|
||||
if self.window_start:
|
||||
summary_message = (
|
||||
f"Cleanup complete. Deleted {total_runs_deleted} workflow runs "
|
||||
f"between {self.window_start.isoformat()} and {self.window_end.isoformat()}"
|
||||
)
|
||||
else:
|
||||
summary_message = (
|
||||
f"Cleanup complete. Deleted {total_runs_deleted} workflow runs "
|
||||
f"before {self.window_end.isoformat()}"
|
||||
)
|
||||
summary_color = "white"
|
||||
|
||||
click.echo(click.style(summary_message, fg=summary_color))
|
||||
click.echo(click.style(summary_message, fg=summary_color))
|
||||
except Exception:
|
||||
status = "failed"
|
||||
raise
|
||||
finally:
|
||||
self._metrics.record_completion(
|
||||
status=status,
|
||||
job_duration_seconds=time.monotonic() - run_start,
|
||||
)
|
||||
|
||||
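The batch fetch in run() pages by a (created_at, id) cursor (last_seen) rather than OFFSET, so each query resumes strictly after the last row of the previous batch. A hedged SQLAlchemy sketch of that keyset condition, using a simplified stand-in model rather than the real repository implementation:

```python
# Sketch only: keyset pagination over (created_at, id); WorkflowRun here is a
# simplified stand-in model, not the actual Dify ORM class.
from datetime import datetime

from sqlalchemy import DateTime, String, select, tuple_
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class WorkflowRun(Base):
    __tablename__ = "workflow_runs"
    id: Mapped[str] = mapped_column(String, primary_key=True)
    tenant_id: Mapped[str] = mapped_column(String)
    created_at: Mapped[datetime] = mapped_column(DateTime)


def fetch_next_batch(session, start_from, end_before, last_seen, batch_size):
    stmt = (
        select(WorkflowRun.id, WorkflowRun.tenant_id, WorkflowRun.created_at)
        .where(WorkflowRun.created_at < end_before)
        .order_by(WorkflowRun.created_at, WorkflowRun.id)
        .limit(batch_size)
    )
    if start_from is not None:
        stmt = stmt.where(WorkflowRun.created_at >= start_from)
    if last_seen is not None:
        # Resume strictly after the last (created_at, id) seen in the previous batch.
        # Row-value comparison; supported on PostgreSQL.
        stmt = stmt.where(tuple_(WorkflowRun.created_at, WorkflowRun.id) > last_seen)
    return session.execute(stmt).all()
```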
def _filter_free_tenants(self, tenant_ids: Iterable[str]) -> set[str]:
|
||||
tenant_id_list = list(tenant_ids)
|
||||
|
||||
@ -12,12 +12,14 @@ from core.db.session_factory import session_factory
|
||||
from core.model_manager import ModelManager
|
||||
from core.rag.datasource.vdb.vector_factory import Vector
|
||||
from core.rag.index_processor.constant.doc_type import DocType
|
||||
from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict
|
||||
from core.rag.models.document import Document
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMUsage
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelType
|
||||
from libs import helper
|
||||
from models.dataset import Dataset, DocumentSegment, DocumentSegmentSummary
|
||||
from models.dataset import Document as DatasetDocument
|
||||
from models.enums import SummaryStatus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -29,7 +31,7 @@ class SummaryIndexService:
|
||||
def generate_summary_for_segment(
|
||||
segment: DocumentSegment,
|
||||
dataset: Dataset,
|
||||
summary_index_setting: dict,
|
||||
summary_index_setting: SummaryIndexSettingDict,
|
||||
) -> tuple[str, LLMUsage]:
|
||||
"""
|
||||
Generate summary for a single segment.
|
||||
@ -73,7 +75,7 @@ class SummaryIndexService:
|
||||
segment: DocumentSegment,
|
||||
dataset: Dataset,
|
||||
summary_content: str,
|
||||
status: str = "generating",
|
||||
status: SummaryStatus = SummaryStatus.GENERATING,
|
||||
) -> DocumentSegmentSummary:
|
||||
"""
|
||||
Create or update a DocumentSegmentSummary record.
|
||||
@ -83,7 +85,7 @@ class SummaryIndexService:
|
||||
segment: DocumentSegment to create summary for
|
||||
dataset: Dataset containing the segment
|
||||
summary_content: Generated summary content
|
||||
status: Summary status (default: "generating")
|
||||
status: Summary status (default: SummaryStatus.GENERATING)
|
||||
|
||||
Returns:
|
||||
Created or updated DocumentSegmentSummary instance
|
||||
@ -326,7 +328,7 @@ class SummaryIndexService:
|
||||
summary_index_node_id=summary_index_node_id,
|
||||
summary_index_node_hash=summary_hash,
|
||||
tokens=embedding_tokens,
|
||||
status="completed",
|
||||
status=SummaryStatus.COMPLETED,
|
||||
enabled=True,
|
||||
)
|
||||
session.add(summary_record_in_session)
|
||||
@ -362,7 +364,7 @@ class SummaryIndexService:
|
||||
summary_record_in_session.summary_index_node_id = summary_index_node_id
|
||||
summary_record_in_session.summary_index_node_hash = summary_hash
|
||||
summary_record_in_session.tokens = embedding_tokens # Save embedding tokens
|
||||
summary_record_in_session.status = "completed"
|
||||
summary_record_in_session.status = SummaryStatus.COMPLETED
|
||||
# Ensure summary_content is preserved (use the latest from summary_record parameter)
|
||||
# This is critical: use the parameter value, not the database value
|
||||
summary_record_in_session.summary_content = summary_content
|
||||
@ -400,7 +402,7 @@ class SummaryIndexService:
|
||||
summary_record.summary_index_node_id = summary_index_node_id
|
||||
summary_record.summary_index_node_hash = summary_hash
|
||||
summary_record.tokens = embedding_tokens
|
||||
summary_record.status = "completed"
|
||||
summary_record.status = SummaryStatus.COMPLETED
|
||||
summary_record.summary_content = summary_content
|
||||
if summary_record_in_session.updated_at:
|
||||
summary_record.updated_at = summary_record_in_session.updated_at
|
||||
@ -487,7 +489,7 @@ class SummaryIndexService:
|
||||
)
|
||||
|
||||
if summary_record_in_session:
|
||||
summary_record_in_session.status = "error"
|
||||
summary_record_in_session.status = SummaryStatus.ERROR
|
||||
summary_record_in_session.error = f"Vectorization failed: {str(e)}"
|
||||
summary_record_in_session.updated_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
error_session.add(summary_record_in_session)
|
||||
@ -498,7 +500,7 @@ class SummaryIndexService:
|
||||
summary_record_in_session.id,
|
||||
)
|
||||
# Update the original object for consistency
|
||||
summary_record.status = "error"
|
||||
summary_record.status = SummaryStatus.ERROR
|
||||
summary_record.error = summary_record_in_session.error
|
||||
summary_record.updated_at = summary_record_in_session.updated_at
|
||||
else:
|
||||
@ -514,7 +516,7 @@ class SummaryIndexService:
|
||||
def batch_create_summary_records(
|
||||
segments: list[DocumentSegment],
|
||||
dataset: Dataset,
|
||||
status: str = "not_started",
|
||||
status: SummaryStatus = SummaryStatus.NOT_STARTED,
|
||||
) -> None:
|
||||
"""
|
||||
Batch create summary records for segments with specified status.
|
||||
@ -523,7 +525,7 @@ class SummaryIndexService:
|
||||
Args:
|
||||
segments: List of DocumentSegment instances
|
||||
dataset: Dataset containing the segments
|
||||
status: Initial status for the records (default: "not_started")
|
||||
status: Initial status for the records (default: SummaryStatus.NOT_STARTED)
|
||||
"""
|
||||
segment_ids = [segment.id for segment in segments]
|
||||
if not segment_ids:
|
||||
@ -588,7 +590,7 @@ class SummaryIndexService:
|
||||
)
|
||||
|
||||
if summary_record:
|
||||
summary_record.status = "error"
|
||||
summary_record.status = SummaryStatus.ERROR
|
||||
summary_record.error = error
|
||||
session.add(summary_record)
|
||||
session.commit()
|
||||
@ -599,7 +601,7 @@ class SummaryIndexService:
|
||||
def generate_and_vectorize_summary(
|
||||
segment: DocumentSegment,
|
||||
dataset: Dataset,
|
||||
summary_index_setting: dict,
|
||||
summary_index_setting: SummaryIndexSettingDict,
|
||||
) -> DocumentSegmentSummary:
|
||||
"""
|
||||
Generate summary for a segment and vectorize it.
|
||||
@ -631,14 +633,14 @@ class SummaryIndexService:
|
||||
document_id=segment.document_id,
|
||||
chunk_id=segment.id,
|
||||
summary_content="",
|
||||
status="generating",
|
||||
status=SummaryStatus.GENERATING,
|
||||
enabled=True,
|
||||
)
|
||||
session.add(summary_record_in_session)
|
||||
session.flush()
|
||||
|
||||
# Update status to "generating"
|
||||
summary_record_in_session.status = "generating"
|
||||
summary_record_in_session.status = SummaryStatus.GENERATING
|
||||
summary_record_in_session.error = None # type: ignore[assignment]
|
||||
session.add(summary_record_in_session)
|
||||
# Don't flush here - wait until after vectorization succeeds
|
||||
@ -681,7 +683,7 @@ class SummaryIndexService:
|
||||
except Exception as vectorize_error:
|
||||
# If vectorization fails, update status to error in current session
|
||||
logger.exception("Failed to vectorize summary for segment %s", segment.id)
|
||||
summary_record_in_session.status = "error"
|
||||
summary_record_in_session.status = SummaryStatus.ERROR
|
||||
summary_record_in_session.error = f"Vectorization failed: {str(vectorize_error)}"
|
||||
session.add(summary_record_in_session)
|
||||
session.commit()
|
||||
@ -694,7 +696,7 @@ class SummaryIndexService:
|
||||
session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first()
|
||||
)
|
||||
if summary_record_in_session:
|
||||
summary_record_in_session.status = "error"
|
||||
summary_record_in_session.status = SummaryStatus.ERROR
|
||||
summary_record_in_session.error = str(e)
|
||||
session.add(summary_record_in_session)
|
||||
session.commit()
|
||||
@ -704,7 +706,7 @@ class SummaryIndexService:
|
||||
def generate_summaries_for_document(
|
||||
dataset: Dataset,
|
||||
document: DatasetDocument,
|
||||
summary_index_setting: dict,
|
||||
summary_index_setting: SummaryIndexSettingDict,
|
||||
segment_ids: list[str] | None = None,
|
||||
only_parent_chunks: bool = False,
|
||||
) -> list[DocumentSegmentSummary]:
|
||||
@ -770,7 +772,7 @@ class SummaryIndexService:
|
||||
SummaryIndexService.batch_create_summary_records(
|
||||
segments=segments,
|
||||
dataset=dataset,
|
||||
status="not_started",
|
||||
status=SummaryStatus.NOT_STARTED,
|
||||
)
|
||||
|
||||
summary_records = []
|
||||
@ -1067,7 +1069,7 @@ class SummaryIndexService:
|
||||
|
||||
# Update summary content
|
||||
summary_record.summary_content = summary_content
|
||||
summary_record.status = "generating"
|
||||
summary_record.status = SummaryStatus.GENERATING
|
||||
summary_record.error = None # type: ignore[assignment] # Clear any previous errors
|
||||
session.add(summary_record)
|
||||
# Flush to ensure summary_content is saved before vectorize_summary queries it
|
||||
@ -1102,7 +1104,7 @@ class SummaryIndexService:
|
||||
# If vectorization fails, update status to error in current session
|
||||
# Don't raise the exception - just log it and return the record with error status
|
||||
# This allows the segment update to complete even if vectorization fails
|
||||
summary_record.status = "error"
|
||||
summary_record.status = SummaryStatus.ERROR
|
||||
summary_record.error = f"Vectorization failed: {str(e)}"
|
||||
session.commit()
|
||||
logger.exception("Failed to vectorize summary for segment %s", segment.id)
|
||||
@ -1112,7 +1114,7 @@ class SummaryIndexService:
|
||||
else:
|
||||
# Create new summary record if doesn't exist
|
||||
summary_record = SummaryIndexService.create_summary_record(
|
||||
segment, dataset, summary_content, status="generating"
|
||||
segment, dataset, summary_content, status=SummaryStatus.GENERATING
|
||||
)
|
||||
# Re-vectorize summary (this will update status to "completed" and tokens in its own session)
|
||||
# Note: summary_record was created in a different session,
|
||||
@ -1132,7 +1134,7 @@ class SummaryIndexService:
|
||||
# If vectorization fails, update status to error in current session
|
||||
# Merge the record into current session first
|
||||
error_record = session.merge(summary_record)
|
||||
error_record.status = "error"
|
||||
error_record.status = SummaryStatus.ERROR
|
||||
error_record.error = f"Vectorization failed: {str(e)}"
|
||||
session.commit()
|
||||
logger.exception("Failed to vectorize summary for segment %s", segment.id)
|
||||
@ -1146,7 +1148,7 @@ class SummaryIndexService:
|
||||
session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first()
|
||||
)
|
||||
if summary_record:
|
||||
summary_record.status = "error"
|
||||
summary_record.status = SummaryStatus.ERROR
|
||||
summary_record.error = str(e)
|
||||
session.add(summary_record)
|
||||
session.commit()
|
||||
@ -1266,7 +1268,7 @@ class SummaryIndexService:
|
||||
# Check if there are any "not_started" or "generating" status summaries
|
||||
has_pending_summaries = any(
|
||||
summary_status_map.get(segment_id) is not None # Ensure summary exists (enabled=True)
|
||||
and summary_status_map[segment_id] in ("not_started", "generating")
|
||||
and summary_status_map[segment_id] in (SummaryStatus.NOT_STARTED, SummaryStatus.GENERATING)
|
||||
for segment_id in segment_ids
|
||||
)
|
||||
|
||||
@ -1330,7 +1332,7 @@ class SummaryIndexService:
|
||||
# it means the summary is disabled (enabled=False) or not created yet, ignore it
|
||||
has_pending_summaries = any(
|
||||
summary_status_map.get(segment_id) is not None # Ensure summary exists (enabled=True)
|
||||
and summary_status_map[segment_id] in ("not_started", "generating")
|
||||
and summary_status_map[segment_id] in (SummaryStatus.NOT_STARTED, SummaryStatus.GENERATING)
|
||||
for segment_id in segment_ids
|
||||
)
|
||||
|
||||
@ -1393,17 +1395,17 @@ class SummaryIndexService:
|
||||
|
||||
# Count statuses
|
||||
status_counts = {
|
||||
"completed": 0,
|
||||
"generating": 0,
|
||||
"error": 0,
|
||||
"not_started": 0,
|
||||
SummaryStatus.COMPLETED: 0,
|
||||
SummaryStatus.GENERATING: 0,
|
||||
SummaryStatus.ERROR: 0,
|
||||
SummaryStatus.NOT_STARTED: 0,
|
||||
}
|
||||
|
||||
summary_list = []
|
||||
for segment in segments:
|
||||
summary = summary_map.get(segment.id)
|
||||
if summary:
|
||||
status = summary.status
|
||||
status = SummaryStatus(summary.status)
|
||||
status_counts[status] = status_counts.get(status, 0) + 1
|
||||
summary_list.append(
|
||||
{
|
||||
@ -1421,12 +1423,12 @@ class SummaryIndexService:
|
||||
}
|
||||
)
|
||||
else:
|
||||
status_counts["not_started"] += 1
|
||||
status_counts[SummaryStatus.NOT_STARTED] += 1
|
||||
summary_list.append(
|
||||
{
|
||||
"segment_id": segment.id,
|
||||
"segment_position": segment.position,
|
||||
"status": "not_started",
|
||||
"status": SummaryStatus.NOT_STARTED,
|
||||
"summary_preview": None,
|
||||
"error": None,
|
||||
"created_at": None,
|
||||
|
||||
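The hunks above and below replace raw status strings with members of enums from models.enums. The actual definitions are not shown in this diff, but if they are StrEnum-based (an assumption), each member is itself a string, so stored column values and comparisons against the old literals keep working:

```python
# Hedged sketch of a StrEnum-backed status enum (Python 3.11+); the real
# SummaryStatus/IndexingStatus definitions live in models.enums and may differ.
from enum import StrEnum


class SummaryStatus(StrEnum):
    NOT_STARTED = "not_started"
    GENERATING = "generating"
    COMPLETED = "completed"
    ERROR = "error"


# Members are also plain strings, so legacy comparisons and DB values still match.
assert SummaryStatus.COMPLETED == "completed"
assert SummaryStatus("generating") is SummaryStatus.GENERATING
```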
@ -13,6 +13,7 @@ from extensions.ext_redis import redis_client
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.dataset import DatasetAutoDisableLog, DocumentSegment
|
||||
from models.dataset import Document as DatasetDocument
|
||||
from models.enums import IndexingStatus, SegmentStatus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -34,7 +35,7 @@ def add_document_to_index_task(dataset_document_id: str):
|
||||
logger.info(click.style(f"Document not found: {dataset_document_id}", fg="red"))
|
||||
return
|
||||
|
||||
if dataset_document.indexing_status != "completed":
|
||||
if dataset_document.indexing_status != IndexingStatus.COMPLETED:
|
||||
return
|
||||
|
||||
indexing_cache_key = f"document_{dataset_document.id}_indexing"
|
||||
@ -48,7 +49,7 @@ def add_document_to_index_task(dataset_document_id: str):
|
||||
session.query(DocumentSegment)
|
||||
.where(
|
||||
DocumentSegment.document_id == dataset_document.id,
|
||||
DocumentSegment.status == "completed",
|
||||
DocumentSegment.status == SegmentStatus.COMPLETED,
|
||||
)
|
||||
.order_by(DocumentSegment.position.asc())
|
||||
.all()
|
||||
@ -139,7 +140,7 @@ def add_document_to_index_task(dataset_document_id: str):
|
||||
logger.exception("add document to index failed")
|
||||
dataset_document.enabled = False
|
||||
dataset_document.disabled_at = naive_utc_now()
|
||||
dataset_document.indexing_status = "error"
|
||||
dataset_document.indexing_status = IndexingStatus.ERROR
|
||||
dataset_document.error = str(e)
|
||||
session.commit()
|
||||
finally:
|
||||
|
||||
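These indexing tasks guard against concurrent runs with a short-lived Redis key (for example the document_<id>_indexing key above) and clear it in a finally block, as the redis_client.delete(...) calls in later hunks show. A hedged sketch of that guard; the key format follows the hunk above, while the setex TTL and run_indexing are illustrative assumptions rather than the actual task code:

```python
# Sketch only: Redis "indexing in progress" guard cleared in finally.
def index_document(redis_client, document_id: str) -> None:
    cache_key = f"document_{document_id}_indexing"
    if redis_client.get(cache_key):
        return  # another worker is already indexing this document
    redis_client.setex(cache_key, 600, 1)  # safety-net expiry in case the worker dies
    try:
        run_indexing(document_id)  # placeholder for the real indexing runner
    finally:
        redis_client.delete(cache_key)
```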
@ -11,6 +11,7 @@ from core.rag.models.document import Document
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.dataset import Dataset
|
||||
from models.enums import CollectionBindingType
|
||||
from models.model import App, AppAnnotationSetting, MessageAnnotation
|
||||
from services.dataset_service import DatasetCollectionBindingService
|
||||
|
||||
@ -47,7 +48,7 @@ def enable_annotation_reply_task(
|
||||
try:
|
||||
documents = []
|
||||
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
|
||||
embedding_provider_name, embedding_model_name, "annotation"
|
||||
embedding_provider_name, embedding_model_name, CollectionBindingType.ANNOTATION
|
||||
)
|
||||
annotation_setting = (
|
||||
session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
|
||||
@ -56,7 +57,7 @@ def enable_annotation_reply_task(
|
||||
if dataset_collection_binding.id != annotation_setting.collection_binding_id:
|
||||
old_dataset_collection_binding = (
|
||||
DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
|
||||
annotation_setting.collection_binding_id, "annotation"
|
||||
annotation_setting.collection_binding_id, CollectionBindingType.ANNOTATION
|
||||
)
|
||||
)
|
||||
if old_dataset_collection_binding and annotations:
|
||||
|
||||
@ -10,6 +10,7 @@ from core.rag.models.document import Document
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.dataset import DocumentSegment
|
||||
from models.enums import IndexingStatus, SegmentStatus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -31,7 +32,7 @@ def create_segment_to_index_task(segment_id: str, keywords: list[str] | None = N
|
||||
logger.info(click.style(f"Segment not found: {segment_id}", fg="red"))
|
||||
return
|
||||
|
||||
if segment.status != "waiting":
|
||||
if segment.status != SegmentStatus.WAITING:
|
||||
return
|
||||
|
||||
indexing_cache_key = f"segment_{segment.id}_indexing"
|
||||
@ -40,7 +41,7 @@ def create_segment_to_index_task(segment_id: str, keywords: list[str] | None = N
|
||||
# update segment status to indexing
|
||||
session.query(DocumentSegment).filter_by(id=segment.id).update(
|
||||
{
|
||||
DocumentSegment.status: "indexing",
|
||||
DocumentSegment.status: SegmentStatus.INDEXING,
|
||||
DocumentSegment.indexing_at: naive_utc_now(),
|
||||
}
|
||||
)
|
||||
@ -70,7 +71,7 @@ def create_segment_to_index_task(segment_id: str, keywords: list[str] | None = N
|
||||
if (
|
||||
not dataset_document.enabled
|
||||
or dataset_document.archived
|
||||
or dataset_document.indexing_status != "completed"
|
||||
or dataset_document.indexing_status != IndexingStatus.COMPLETED
|
||||
):
|
||||
logger.info(click.style(f"Segment {segment.id} document status is invalid, pass.", fg="cyan"))
|
||||
return
|
||||
@ -82,7 +83,7 @@ def create_segment_to_index_task(segment_id: str, keywords: list[str] | None = N
|
||||
# update segment to completed
|
||||
session.query(DocumentSegment).filter_by(id=segment.id).update(
|
||||
{
|
||||
DocumentSegment.status: "completed",
|
||||
DocumentSegment.status: SegmentStatus.COMPLETED,
|
||||
DocumentSegment.completed_at: naive_utc_now(),
|
||||
}
|
||||
)
|
||||
@ -94,7 +95,7 @@ def create_segment_to_index_task(segment_id: str, keywords: list[str] | None = N
|
||||
logger.exception("create segment to index failed")
|
||||
segment.enabled = False
|
||||
segment.disabled_at = naive_utc_now()
|
||||
segment.status = "error"
|
||||
segment.status = SegmentStatus.ERROR
|
||||
segment.error = str(e)
|
||||
session.commit()
|
||||
finally:
|
||||
|
||||
@ -12,6 +12,7 @@ from core.rag.extractor.notion_extractor import NotionExtractor
|
||||
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
from models.enums import IndexingStatus
|
||||
from services.datasource_provider_service import DatasourceProviderService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -37,7 +38,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
|
||||
logger.info(click.style(f"Document not found: {document_id}", fg="red"))
|
||||
return
|
||||
|
||||
if document.indexing_status == "parsing":
|
||||
if document.indexing_status == IndexingStatus.PARSING:
|
||||
logger.info(click.style(f"Document {document_id} is already being processed, skipping", fg="yellow"))
|
||||
return
|
||||
|
||||
@ -88,7 +89,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
|
||||
with session_factory.create_session() as session, session.begin():
|
||||
document = session.query(Document).filter_by(id=document_id).first()
|
||||
if document:
|
||||
document.indexing_status = "error"
|
||||
document.indexing_status = IndexingStatus.ERROR
|
||||
document.error = "Datasource credential not found. Please reconnect your Notion workspace."
|
||||
document.stopped_at = naive_utc_now()
|
||||
return
|
||||
@ -128,7 +129,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
|
||||
data_source_info["last_edited_time"] = last_edited_time
|
||||
document.data_source_info = json.dumps(data_source_info)
|
||||
|
||||
document.indexing_status = "parsing"
|
||||
document.indexing_status = IndexingStatus.PARSING
|
||||
document.processing_started_at = naive_utc_now()
|
||||
|
||||
segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.document_id == document_id)
|
||||
@ -151,6 +152,6 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
|
||||
with session_factory.create_session() as session, session.begin():
|
||||
document = session.query(Document).filter_by(id=document_id).first()
|
||||
if document:
|
||||
document.indexing_status = "error"
|
||||
document.indexing_status = IndexingStatus.ERROR
|
||||
document.error = str(e)
|
||||
document.stopped_at = naive_utc_now()
|
||||
|
||||
@ -14,6 +14,7 @@ from core.rag.pipeline.queue import TenantIsolatedTaskQueue
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.dataset import Dataset, Document
|
||||
from models.enums import IndexingStatus
|
||||
from services.feature_service import FeatureService
|
||||
from tasks.generate_summary_index_task import generate_summary_index_task
|
||||
|
||||
@ -81,7 +82,7 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
|
||||
session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
|
||||
)
|
||||
if document:
|
||||
document.indexing_status = "error"
|
||||
document.indexing_status = IndexingStatus.ERROR
|
||||
document.error = str(e)
|
||||
document.stopped_at = naive_utc_now()
|
||||
session.add(document)
|
||||
@ -96,7 +97,7 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
|
||||
|
||||
for document in documents:
|
||||
if document:
|
||||
document.indexing_status = "parsing"
|
||||
document.indexing_status = IndexingStatus.PARSING
|
||||
document.processing_started_at = naive_utc_now()
|
||||
session.add(document)
|
||||
# Transaction committed and closed
|
||||
@ -148,7 +149,7 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
|
||||
document.need_summary,
|
||||
)
|
||||
if (
|
||||
document.indexing_status == "completed"
|
||||
document.indexing_status == IndexingStatus.COMPLETED
|
||||
and document.doc_form != "qa_model"
|
||||
and document.need_summary is True
|
||||
):
|
||||
|
||||
@ -10,6 +10,7 @@ from core.indexing_runner import DocumentIsPausedError, IndexingRunner
|
||||
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
from models.enums import IndexingStatus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -33,7 +34,7 @@ def document_indexing_update_task(dataset_id: str, document_id: str):
|
||||
logger.info(click.style(f"Document not found: {document_id}", fg="red"))
|
||||
return
|
||||
|
||||
document.indexing_status = "parsing"
|
||||
document.indexing_status = IndexingStatus.PARSING
|
||||
document.processing_started_at = naive_utc_now()
|
||||
|
||||
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
|
||||
@ -15,6 +15,7 @@ from core.rag.pipeline.queue import TenantIsolatedTaskQueue
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
from models.enums import IndexingStatus
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -112,7 +113,7 @@ def _duplicate_document_indexing_task(dataset_id: str, document_ids: Sequence[st
|
||||
)
|
||||
for document in documents:
|
||||
if document:
|
||||
document.indexing_status = "error"
|
||||
document.indexing_status = IndexingStatus.ERROR
|
||||
document.error = str(e)
|
||||
document.stopped_at = naive_utc_now()
|
||||
session.add(document)
|
||||
@ -146,7 +147,7 @@ def _duplicate_document_indexing_task(dataset_id: str, document_ids: Sequence[st
|
||||
session.execute(segment_delete_stmt)
|
||||
session.commit()
|
||||
|
||||
document.indexing_status = "parsing"
|
||||
document.indexing_status = IndexingStatus.PARSING
|
||||
document.processing_started_at = naive_utc_now()
|
||||
session.add(document)
|
||||
session.commit()
|
||||
|
||||
@ -12,6 +12,7 @@ from core.rag.models.document import AttachmentDocument, ChildDocument, Document
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.dataset import DocumentSegment
|
||||
from models.enums import IndexingStatus, SegmentStatus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -33,7 +34,7 @@ def enable_segment_to_index_task(segment_id: str):
|
||||
logger.info(click.style(f"Segment not found: {segment_id}", fg="red"))
|
||||
return
|
||||
|
||||
if segment.status != "completed":
|
||||
if segment.status != SegmentStatus.COMPLETED:
|
||||
logger.info(click.style(f"Segment is not completed, enable is not allowed: {segment_id}", fg="red"))
|
||||
return
|
||||
|
||||
@ -65,7 +66,7 @@ def enable_segment_to_index_task(segment_id: str):
|
||||
if (
|
||||
not dataset_document.enabled
|
||||
or dataset_document.archived
|
||||
or dataset_document.indexing_status != "completed"
|
||||
or dataset_document.indexing_status != IndexingStatus.COMPLETED
|
||||
):
|
||||
logger.info(click.style(f"Segment {segment.id} document status is invalid, pass.", fg="cyan"))
|
||||
return
|
||||
@ -123,7 +124,7 @@ def enable_segment_to_index_task(segment_id: str):
|
||||
logger.exception("enable segment to index failed")
|
||||
segment.enabled = False
|
||||
segment.disabled_at = naive_utc_now()
|
||||
segment.status = "error"
|
||||
segment.status = SegmentStatus.ERROR
|
||||
segment.error = str(e)
|
||||
session.commit()
|
||||
finally:
|
||||
|
||||
@ -12,6 +12,7 @@ from extensions.ext_redis import redis_client
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import Account, Tenant
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
from models.enums import IndexingStatus
|
||||
from services.feature_service import FeatureService
|
||||
from services.rag_pipeline.rag_pipeline import RagPipelineService
|
||||
|
||||
@ -63,7 +64,7 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str], user_
|
||||
.first()
|
||||
)
|
||||
if document:
|
||||
document.indexing_status = "error"
|
||||
document.indexing_status = IndexingStatus.ERROR
|
||||
document.error = str(e)
|
||||
document.stopped_at = naive_utc_now()
|
||||
session.add(document)
|
||||
@ -95,7 +96,7 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str], user_
|
||||
session.execute(segment_delete_stmt)
|
||||
session.commit()
|
||||
|
||||
document.indexing_status = "parsing"
|
||||
document.indexing_status = IndexingStatus.PARSING
|
||||
document.processing_started_at = naive_utc_now()
|
||||
session.add(document)
|
||||
session.commit()
|
||||
@ -108,7 +109,7 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str], user_
|
||||
indexing_runner.run([document])
|
||||
redis_client.delete(retry_indexing_cache_key)
|
||||
except Exception as ex:
|
||||
document.indexing_status = "error"
|
||||
document.indexing_status = IndexingStatus.ERROR
|
||||
document.error = str(ex)
|
||||
document.stopped_at = naive_utc_now()
|
||||
session.add(document)
|
||||
|
||||
@ -11,6 +11,7 @@ from core.rag.index_processor.index_processor_factory import IndexProcessorFacto
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
from models.enums import IndexingStatus
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -48,7 +49,7 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str):
|
||||
session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
|
||||
)
|
||||
if document:
|
||||
document.indexing_status = "error"
|
||||
document.indexing_status = IndexingStatus.ERROR
|
||||
document.error = str(e)
|
||||
document.stopped_at = naive_utc_now()
|
||||
session.add(document)
|
||||
@ -76,7 +77,7 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str):
|
||||
session.execute(segment_delete_stmt)
|
||||
session.commit()
|
||||
|
||||
document.indexing_status = "parsing"
|
||||
document.indexing_status = IndexingStatus.PARSING
|
||||
document.processing_started_at = naive_utc_now()
|
||||
session.add(document)
|
||||
session.commit()
|
||||
@ -85,7 +86,7 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str):
|
||||
indexing_runner.run([document])
|
||||
redis_client.delete(sync_indexing_cache_key)
|
||||
except Exception as ex:
|
||||
document.indexing_status = "error"
|
||||
document.indexing_status = IndexingStatus.ERROR
|
||||
document.error = str(ex)
|
||||
document.stopped_at = naive_utc_now()
|
||||
session.add(document)
|
||||
|
||||
@ -7,6 +7,7 @@ from faker import Faker
|
||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
from core.workflow.nodes.knowledge_retrieval.retrieval import KnowledgeRetrievalRequest
|
||||
from models.dataset import Dataset, Document
|
||||
from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus
|
||||
from services.account_service import AccountService, TenantService
|
||||
from tests.test_containers_integration_tests.helpers import generate_valid_password
|
||||
|
||||
@ -35,7 +36,7 @@ class TestGetAvailableDatasetsIntegration:
|
||||
name=fake.company(),
|
||||
description=fake.text(max_nb_chars=100),
|
||||
provider="dify",
|
||||
data_source_type="upload_file",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
created_by=account.id,
|
||||
indexing_technique="high_quality",
|
||||
)
|
||||
@ -49,14 +50,14 @@ class TestGetAvailableDatasetsIntegration:
|
||||
tenant_id=tenant.id,
|
||||
dataset_id=dataset.id,
|
||||
position=i,
|
||||
data_source_type="upload_file",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
batch=str(uuid.uuid4()), # Required field
|
||||
name=f"Document {i}",
|
||||
created_from="web",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=account.id,
|
||||
doc_form="text_model",
|
||||
doc_language="en",
|
||||
indexing_status="completed",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=True,
|
||||
archived=False,
|
||||
)
|
||||
@ -94,7 +95,7 @@ class TestGetAvailableDatasetsIntegration:
|
||||
tenant_id=tenant.id,
|
||||
name=fake.company(),
|
||||
provider="dify",
|
||||
data_source_type="upload_file",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
created_by=account.id,
|
||||
)
|
||||
db_session_with_containers.add(dataset)
|
||||
@ -106,13 +107,13 @@ class TestGetAvailableDatasetsIntegration:
|
||||
tenant_id=tenant.id,
|
||||
dataset_id=dataset.id,
|
||||
position=i,
|
||||
data_source_type="upload_file",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
batch=str(uuid.uuid4()), # Required field
|
||||
created_from="web",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
name=f"Archived Document {i}",
|
||||
created_by=account.id,
|
||||
doc_form="text_model",
|
||||
indexing_status="completed",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=True,
|
||||
archived=True, # Archived
|
||||
)
|
||||
@ -147,7 +148,7 @@ class TestGetAvailableDatasetsIntegration:
|
||||
tenant_id=tenant.id,
|
||||
name=fake.company(),
|
||||
provider="dify",
|
||||
data_source_type="upload_file",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
created_by=account.id,
|
||||
)
|
||||
db_session_with_containers.add(dataset)
|
||||
@ -159,13 +160,13 @@ class TestGetAvailableDatasetsIntegration:
|
||||
tenant_id=tenant.id,
|
||||
dataset_id=dataset.id,
|
||||
position=i,
|
||||
data_source_type="upload_file",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
batch=str(uuid.uuid4()), # Required field
|
||||
created_from="web",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
name=f"Disabled Document {i}",
|
||||
created_by=account.id,
|
||||
doc_form="text_model",
|
||||
indexing_status="completed",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=False, # Disabled
|
||||
archived=False,
|
||||
)
|
||||
@ -200,21 +201,21 @@ class TestGetAvailableDatasetsIntegration:
|
||||
tenant_id=tenant.id,
|
||||
name=fake.company(),
|
||||
provider="dify",
|
||||
data_source_type="upload_file",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
created_by=account.id,
|
||||
)
|
||||
db_session_with_containers.add(dataset)
|
||||
|
||||
# Create documents with non-completed status
|
||||
for i, status in enumerate(["indexing", "parsing", "splitting"]):
|
||||
for i, status in enumerate([IndexingStatus.INDEXING, IndexingStatus.PARSING, IndexingStatus.SPLITTING]):
|
||||
document = Document(
|
||||
id=str(uuid.uuid4()),
|
||||
tenant_id=tenant.id,
|
||||
dataset_id=dataset.id,
|
||||
position=i,
|
||||
data_source_type="upload_file",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
batch=str(uuid.uuid4()), # Required field
|
||||
created_from="web",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
name=f"Document {status}",
|
||||
created_by=account.id,
|
||||
doc_form="text_model",
|
||||
@ -263,7 +264,7 @@ class TestGetAvailableDatasetsIntegration:
|
||||
tenant_id=tenant.id,
|
||||
name=fake.company(),
|
||||
provider="external", # External provider
|
||||
data_source_type="external",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
created_by=account.id,
|
||||
)
|
||||
db_session_with_containers.add(dataset)
|
||||
@ -307,7 +308,7 @@ class TestGetAvailableDatasetsIntegration:
|
||||
tenant_id=tenant1.id,
|
||||
name="Tenant 1 Dataset",
|
||||
provider="dify",
|
||||
data_source_type="upload_file",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
created_by=account1.id,
|
||||
)
|
||||
db_session_with_containers.add(dataset1)
|
||||
@ -318,7 +319,7 @@ class TestGetAvailableDatasetsIntegration:
|
||||
tenant_id=tenant2.id,
|
||||
name="Tenant 2 Dataset",
|
||||
provider="dify",
|
||||
data_source_type="upload_file",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
created_by=account2.id,
|
||||
)
|
||||
db_session_with_containers.add(dataset2)
|
||||
@ -330,13 +331,13 @@ class TestGetAvailableDatasetsIntegration:
|
||||
tenant_id=dataset.tenant_id,
|
||||
dataset_id=dataset.id,
|
||||
position=0,
|
||||
data_source_type="upload_file",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
batch=str(uuid.uuid4()), # Required field
|
||||
created_from="web",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
name=f"Document for {dataset.name}",
|
||||
created_by=account.id,
|
||||
doc_form="text_model",
|
||||
indexing_status="completed",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=True,
|
||||
archived=False,
|
||||
)
|
||||
@ -398,7 +399,7 @@ class TestGetAvailableDatasetsIntegration:
|
||||
tenant_id=tenant.id,
|
||||
name=f"Dataset {i}",
|
||||
provider="dify",
|
||||
data_source_type="upload_file",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
created_by=account.id,
|
||||
)
|
||||
db_session_with_containers.add(dataset)
|
||||
@ -410,13 +411,13 @@ class TestGetAvailableDatasetsIntegration:
|
||||
tenant_id=tenant.id,
|
||||
dataset_id=dataset.id,
|
||||
position=0,
|
||||
data_source_type="upload_file",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
batch=str(uuid.uuid4()), # Required field
|
||||
created_from="web",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
name=f"Document {i}",
|
||||
created_by=account.id,
|
||||
doc_form="text_model",
|
||||
indexing_status="completed",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=True,
|
||||
archived=False,
|
||||
)
|
||||
@ -456,7 +457,7 @@ class TestKnowledgeRetrievalIntegration:
|
||||
tenant_id=tenant.id,
|
||||
name=fake.company(),
|
||||
provider="dify",
|
||||
data_source_type="upload_file",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
created_by=account.id,
|
||||
indexing_technique="high_quality",
|
||||
)
|
||||
@ -467,12 +468,12 @@ class TestKnowledgeRetrievalIntegration:
|
||||
tenant_id=tenant.id,
|
||||
dataset_id=dataset.id,
|
||||
position=0,
|
||||
data_source_type="upload_file",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
batch=str(uuid.uuid4()), # Required field
|
||||
created_from="web",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
name=fake.sentence(),
|
||||
created_by=account.id,
|
||||
indexing_status="completed",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=True,
|
||||
archived=False,
|
||||
doc_form="text_model",
|
||||
@ -525,7 +526,7 @@ class TestKnowledgeRetrievalIntegration:
|
||||
tenant_id=tenant.id,
|
||||
name=fake.company(),
|
||||
provider="dify",
|
||||
data_source_type="upload_file",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
created_by=account.id,
|
||||
)
|
||||
db_session_with_containers.add(dataset)
|
||||
@ -572,7 +573,7 @@ class TestKnowledgeRetrievalIntegration:
|
||||
tenant_id=tenant.id,
|
||||
name=fake.company(),
|
||||
provider="dify",
|
||||
data_source_type="upload_file",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
created_by=account.id,
|
||||
)
|
||||
db_session_with_containers.add(dataset)
|
||||
|
||||
@ -12,6 +12,7 @@ import pytest
from sqlalchemy.orm import Session

from models.dataset import Dataset, Document, DocumentSegment
from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus


class TestDatasetDocumentProperties:
@ -29,7 +30,7 @@ class TestDatasetDocumentProperties:
created_by = str(uuid4())

dataset = Dataset(
tenant_id=tenant_id, name="Test Dataset", data_source_type="upload_file", created_by=created_by
tenant_id=tenant_id, name="Test Dataset", data_source_type=DataSourceType.UPLOAD_FILE, created_by=created_by
)
db_session_with_containers.add(dataset)
db_session_with_containers.flush()
@ -39,10 +40,10 @@ class TestDatasetDocumentProperties:
tenant_id=tenant_id,
dataset_id=dataset.id,
position=i + 1,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch="batch_001",
name=f"doc_{i}.pdf",
created_from="web",
created_from=DocumentCreatedFrom.WEB,
created_by=created_by,
)
db_session_with_containers.add(doc)
@ -56,7 +57,7 @@ class TestDatasetDocumentProperties:
created_by = str(uuid4())

dataset = Dataset(
tenant_id=tenant_id, name="Test Dataset", data_source_type="upload_file", created_by=created_by
tenant_id=tenant_id, name="Test Dataset", data_source_type=DataSourceType.UPLOAD_FILE, created_by=created_by
)
db_session_with_containers.add(dataset)
db_session_with_containers.flush()
@ -65,12 +66,12 @@ class TestDatasetDocumentProperties:
tenant_id=tenant_id,
dataset_id=dataset.id,
position=1,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch="batch_001",
name="available.pdf",
created_from="web",
created_from=DocumentCreatedFrom.WEB,
created_by=created_by,
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=False,
)
@ -78,12 +79,12 @@ class TestDatasetDocumentProperties:
tenant_id=tenant_id,
dataset_id=dataset.id,
position=2,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch="batch_001",
name="pending.pdf",
created_from="web",
created_from=DocumentCreatedFrom.WEB,
created_by=created_by,
indexing_status="waiting",
indexing_status=IndexingStatus.WAITING,
enabled=True,
archived=False,
)
@ -91,12 +92,12 @@ class TestDatasetDocumentProperties:
tenant_id=tenant_id,
dataset_id=dataset.id,
position=3,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch="batch_001",
name="disabled.pdf",
created_from="web",
created_from=DocumentCreatedFrom.WEB,
created_by=created_by,
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
enabled=False,
archived=False,
)
@ -111,7 +112,7 @@ class TestDatasetDocumentProperties:
created_by = str(uuid4())

dataset = Dataset(
tenant_id=tenant_id, name="Test Dataset", data_source_type="upload_file", created_by=created_by
tenant_id=tenant_id, name="Test Dataset", data_source_type=DataSourceType.UPLOAD_FILE, created_by=created_by
)
db_session_with_containers.add(dataset)
db_session_with_containers.flush()
@ -121,10 +122,10 @@ class TestDatasetDocumentProperties:
tenant_id=tenant_id,
dataset_id=dataset.id,
position=i + 1,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch="batch_001",
name=f"doc_{i}.pdf",
created_from="web",
created_from=DocumentCreatedFrom.WEB,
created_by=created_by,
word_count=wc,
)
@ -139,7 +140,7 @ class TestDatasetDocumentProperties:
created_by = str(uuid4())

dataset = Dataset(
tenant_id=tenant_id, name="Test Dataset", data_source_type="upload_file", created_by=created_by
tenant_id=tenant_id, name="Test Dataset", data_source_type=DataSourceType.UPLOAD_FILE, created_by=created_by
)
db_session_with_containers.add(dataset)
db_session_with_containers.flush()
@ -148,10 +149,10 @@ class TestDatasetDocumentProperties:
tenant_id=tenant_id,
dataset_id=dataset.id,
position=1,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch="batch_001",
name="doc.pdf",
created_from="web",
created_from=DocumentCreatedFrom.WEB,
created_by=created_by,
)
db_session_with_containers.add(doc)
@ -166,7 +167,7 @@ class TestDatasetDocumentProperties:
content=f"segment {i}",
word_count=100,
tokens=50,
status="completed",
status=SegmentStatus.COMPLETED,
enabled=True,
created_by=created_by,
)
@ -180,7 +181,7 @@ class TestDatasetDocumentProperties:
content="waiting segment",
word_count=100,
tokens=50,
status="waiting",
status=SegmentStatus.WAITING,
enabled=True,
created_by=created_by,
)
@ -195,7 +196,7 @@ class TestDatasetDocumentProperties:
created_by = str(uuid4())

dataset = Dataset(
tenant_id=tenant_id, name="Test Dataset", data_source_type="upload_file", created_by=created_by
tenant_id=tenant_id, name="Test Dataset", data_source_type=DataSourceType.UPLOAD_FILE, created_by=created_by
)
db_session_with_containers.add(dataset)
db_session_with_containers.flush()
@ -204,10 +205,10 @@ class TestDatasetDocumentProperties:
tenant_id=tenant_id,
dataset_id=dataset.id,
position=1,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch="batch_001",
name="doc.pdf",
created_from="web",
created_from=DocumentCreatedFrom.WEB,
created_by=created_by,
)
db_session_with_containers.add(doc)
@ -235,7 +236,7 @@ class TestDatasetDocumentProperties:
created_by = str(uuid4())

dataset = Dataset(
tenant_id=tenant_id, name="Test Dataset", data_source_type="upload_file", created_by=created_by
tenant_id=tenant_id, name="Test Dataset", data_source_type=DataSourceType.UPLOAD_FILE, created_by=created_by
)
db_session_with_containers.add(dataset)
db_session_with_containers.flush()
@ -244,10 +245,10 @@ class TestDatasetDocumentProperties:
tenant_id=tenant_id,
dataset_id=dataset.id,
position=1,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch="batch_001",
name="doc.pdf",
created_from="web",
created_from=DocumentCreatedFrom.WEB,
created_by=created_by,
)
db_session_with_containers.add(doc)
@ -288,7 +289,7 @@ class TestDocumentSegmentNavigationProperties:
dataset = Dataset(
tenant_id=tenant_id,
name="Test Dataset",
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=created_by,
)
db_session_with_containers.add(dataset)
@ -298,10 +299,10 @@ class TestDocumentSegmentNavigationProperties:
tenant_id=tenant_id,
dataset_id=dataset.id,
position=1,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch="batch_001",
name="test.pdf",
created_from="web",
created_from=DocumentCreatedFrom.WEB,
created_by=created_by,
)
db_session_with_containers.add(document)
@ -335,7 +336,7 @@ class TestDocumentSegmentNavigationProperties:
dataset = Dataset(
tenant_id=tenant_id,
name="Test Dataset",
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=created_by,
)
db_session_with_containers.add(dataset)
@ -345,10 +346,10 @@ class TestDocumentSegmentNavigationProperties:
tenant_id=tenant_id,
dataset_id=dataset.id,
position=1,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch="batch_001",
name="test.pdf",
created_from="web",
created_from=DocumentCreatedFrom.WEB,
created_by=created_by,
)
db_session_with_containers.add(document)
@ -382,7 +383,7 @@ class TestDocumentSegmentNavigationProperties:
dataset = Dataset(
tenant_id=tenant_id,
name="Test Dataset",
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=created_by,
)
db_session_with_containers.add(dataset)
@ -392,10 +393,10 @@ class TestDocumentSegmentNavigationProperties:
tenant_id=tenant_id,
dataset_id=dataset.id,
position=1,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch="batch_001",
name="test.pdf",
created_from="web",
created_from=DocumentCreatedFrom.WEB,
created_by=created_by,
)
db_session_with_containers.add(document)
@ -439,7 +440,7 @@ class TestDocumentSegmentNavigationProperties:
dataset = Dataset(
tenant_id=tenant_id,
name="Test Dataset",
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=created_by,
)
db_session_with_containers.add(dataset)
@ -449,10 +450,10 @@ class TestDocumentSegmentNavigationProperties:
tenant_id=tenant_id,
dataset_id=dataset.id,
position=1,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch="batch_001",
name="test.pdf",
created_from="web",
created_from=DocumentCreatedFrom.WEB,
created_by=created_by,
)
db_session_with_containers.add(document)

@ -12,6 +12,7 @@ import pytest
from sqlalchemy.orm import Session

from models.dataset import DatasetCollectionBinding
from models.enums import CollectionBindingType
from services.dataset_service import DatasetCollectionBindingService


@ -32,7 +33,7 @@ class DatasetCollectionBindingTestDataFactory:
provider_name: str = "openai",
model_name: str = "text-embedding-ada-002",
collection_name: str = "collection-abc",
collection_type: str = "dataset",
collection_type: str = CollectionBindingType.DATASET,
) -> DatasetCollectionBinding:
"""
Create a DatasetCollectionBinding with specified attributes.
@ -41,7 +42,7 @@ class DatasetCollectionBindingTestDataFactory:
provider_name: Name of the embedding model provider (e.g., "openai", "cohere")
model_name: Name of the embedding model (e.g., "text-embedding-ada-002")
collection_name: Name of the vector database collection
collection_type: Type of collection (default: "dataset")
collection_type: Type of collection (default: CollectionBindingType.DATASET)

Returns:
DatasetCollectionBinding instance
@ -76,7 +77,7 @@ class TestDatasetCollectionBindingServiceGetBinding:
# Arrange
provider_name = "openai"
model_name = "text-embedding-ada-002"
collection_type = "dataset"
collection_type = CollectionBindingType.DATASET
existing_binding = DatasetCollectionBindingTestDataFactory.create_collection_binding(
db_session_with_containers,
provider_name=provider_name,
@ -104,7 +105,7 @@ class TestDatasetCollectionBindingServiceGetBinding:
# Arrange
provider_name = f"provider-{uuid4()}"
model_name = f"model-{uuid4()}"
collection_type = "dataset"
collection_type = CollectionBindingType.DATASET

# Act
result = DatasetCollectionBindingService.get_dataset_collection_binding(
@ -145,7 +146,7 @@ class TestDatasetCollectionBindingServiceGetBinding:
result = DatasetCollectionBindingService.get_dataset_collection_binding(provider_name, model_name)

# Assert
assert result.type == "dataset"
assert result.type == CollectionBindingType.DATASET
assert result.provider_name == provider_name
assert result.model_name == model_name

@ -186,18 +187,20 @@ class TestDatasetCollectionBindingServiceGetBindingByIdAndType:
provider_name="openai",
model_name="text-embedding-ada-002",
collection_name="test-collection",
collection_type="dataset",
collection_type=CollectionBindingType.DATASET,
)

# Act
result = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(binding.id, "dataset")
result = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
binding.id, CollectionBindingType.DATASET
)

# Assert
assert result.id == binding.id
assert result.provider_name == "openai"
assert result.model_name == "text-embedding-ada-002"
assert result.collection_name == "test-collection"
assert result.type == "dataset"
assert result.type == CollectionBindingType.DATASET

def test_get_dataset_collection_binding_by_id_and_type_not_found_error(self, db_session_with_containers: Session):
"""Test error handling when collection binding is not found by ID and type."""
@ -206,7 +209,9 @@ class TestDatasetCollectionBindingServiceGetBindingByIdAndType:

# Act & Assert
with pytest.raises(ValueError, match="Dataset collection binding not found"):
DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(non_existent_id, "dataset")
DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
non_existent_id, CollectionBindingType.DATASET
)

def test_get_dataset_collection_binding_by_id_and_type_different_collection_type(
self, db_session_with_containers: Session
@ -240,7 +245,7 @@ class TestDatasetCollectionBindingServiceGetBindingByIdAndType:
provider_name="openai",
model_name="text-embedding-ada-002",
collection_name="test-collection",
collection_type="dataset",
collection_type=CollectionBindingType.DATASET,
)

# Act
@ -248,7 +253,7 @@ class TestDatasetCollectionBindingServiceGetBindingByIdAndType:

# Assert
assert result.id == binding.id
assert result.type == "dataset"
assert result.type == CollectionBindingType.DATASET

def test_get_dataset_collection_binding_by_id_and_type_wrong_type_error(self, db_session_with_containers: Session):
"""Test error when binding exists but with wrong collection type."""
@ -258,7 +263,7 @@ class TestDatasetCollectionBindingServiceGetBindingByIdAndType:
provider_name="openai",
model_name="text-embedding-ada-002",
collection_name="test-collection",
collection_type="dataset",
collection_type=CollectionBindingType.DATASET,
)

# Act & Assert

@ -15,6 +15,7 @@ from werkzeug.exceptions import NotFound

from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import AppDatasetJoin, Dataset, DatasetPermissionEnum
from models.enums import DataSourceType
from models.model import App
from services.dataset_service import DatasetService
from services.errors.account import NoPermissionError
@ -72,7 +73,7 @@ class DatasetUpdateDeleteTestDataFactory:
tenant_id=tenant_id,
name=name,
description="Test description",
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
indexing_technique="high_quality",
created_by=created_by,
permission=permission,

@ -15,7 +15,7 @@ import pytest

from models import Account
from models.dataset import Dataset, Document
from models.enums import CreatorUserRole
from models.enums import CreatorUserRole, DataSourceType, DocumentCreatedFrom, IndexingStatus
from models.model import UploadFile
from services.dataset_service import DocumentService
from services.errors.document import DocumentIndexingError
@ -88,7 +88,7 @@ class DocumentStatusTestDataFactory:
data_source_info=json.dumps(data_source_info or {}),
batch=f"batch-{uuid4()}",
name=name,
created_from="web",
created_from=DocumentCreatedFrom.WEB,
created_by=created_by,
doc_form="text_model",
)
@ -100,7 +100,7 @@ class DocumentStatusTestDataFactory:
document.paused_by = paused_by
document.paused_at = paused_at
document.doc_metadata = doc_metadata or {}
if indexing_status == "completed" and "completed_at" not in kwargs:
if indexing_status == IndexingStatus.COMPLETED and "completed_at" not in kwargs:
document.completed_at = FIXED_TIME

for key, value in kwargs.items():
@ -139,7 +140,7 @@ class DocumentStatusTestDataFactory:
dataset = Dataset(
tenant_id=tenant_id,
name=name,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=created_by,
)
dataset.id = dataset_id
@ -291,7 +291,7 @@ class TestDocumentServicePauseDocument:
db_session_with_containers,
dataset_id=dataset.id,
tenant_id=dataset.tenant_id,
indexing_status="waiting",
indexing_status=IndexingStatus.WAITING,
is_paused=False,
)

@ -326,7 +326,7 @@ class TestDocumentServicePauseDocument:
db_session_with_containers,
dataset_id=dataset.id,
tenant_id=dataset.tenant_id,
indexing_status="indexing",
indexing_status=IndexingStatus.INDEXING,
is_paused=False,
)

@ -354,7 +354,7 @@ class TestDocumentServicePauseDocument:
db_session_with_containers,
dataset_id=dataset.id,
tenant_id=dataset.tenant_id,
indexing_status="parsing",
indexing_status=IndexingStatus.PARSING,
is_paused=False,
)

@ -383,7 +383,7 @@ class TestDocumentServicePauseDocument:
db_session_with_containers,
dataset_id=dataset.id,
tenant_id=dataset.tenant_id,
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
is_paused=False,
)

@ -412,7 +412,7 @@ class TestDocumentServicePauseDocument:
db_session_with_containers,
dataset_id=dataset.id,
tenant_id=dataset.tenant_id,
indexing_status="error",
indexing_status=IndexingStatus.ERROR,
is_paused=False,
)

@ -487,7 +487,7 @@ class TestDocumentServiceRecoverDocument:
db_session_with_containers,
dataset_id=dataset.id,
tenant_id=dataset.tenant_id,
indexing_status="indexing",
indexing_status=IndexingStatus.INDEXING,
is_paused=True,
paused_by=str(uuid4()),
paused_at=paused_time,
@ -526,7 +526,7 @@ class TestDocumentServiceRecoverDocument:
db_session_with_containers,
dataset_id=dataset.id,
tenant_id=dataset.tenant_id,
indexing_status="indexing",
indexing_status=IndexingStatus.INDEXING,
is_paused=False,
)

@ -609,7 +609,7 @@ class TestDocumentServiceRetryDocument:
dataset_id=dataset.id,
tenant_id=dataset.tenant_id,
document_id=str(uuid4()),
indexing_status="error",
indexing_status=IndexingStatus.ERROR,
)

mock_document_service_dependencies["redis_client"].get.return_value = None
@ -619,7 +619,7 @@ class TestDocumentServiceRetryDocument:

# Assert
db_session_with_containers.refresh(document)
assert document.indexing_status == "waiting"
assert document.indexing_status == IndexingStatus.WAITING

expected_cache_key = f"document_{document.id}_is_retried"
mock_document_service_dependencies["redis_client"].setex.assert_called_once_with(expected_cache_key, 600, 1)
@ -646,14 +646,14 @@ class TestDocumentServiceRetryDocument:
dataset_id=dataset.id,
tenant_id=dataset.tenant_id,
document_id=str(uuid4()),
indexing_status="error",
indexing_status=IndexingStatus.ERROR,
)
document2 = DocumentStatusTestDataFactory.create_document(
db_session_with_containers,
dataset_id=dataset.id,
tenant_id=dataset.tenant_id,
document_id=str(uuid4()),
indexing_status="error",
indexing_status=IndexingStatus.ERROR,
position=2,
)

@ -665,8 +665,8 @@ class TestDocumentServiceRetryDocument:
# Assert
db_session_with_containers.refresh(document1)
db_session_with_containers.refresh(document2)
assert document1.indexing_status == "waiting"
assert document2.indexing_status == "waiting"
assert document1.indexing_status == IndexingStatus.WAITING
assert document2.indexing_status == IndexingStatus.WAITING

mock_document_service_dependencies["retry_task"].delay.assert_called_once_with(
dataset.id, [document1.id, document2.id], mock_document_service_dependencies["user_id"]
@ -693,7 +693,7 @@ class TestDocumentServiceRetryDocument:
dataset_id=dataset.id,
tenant_id=dataset.tenant_id,
document_id=str(uuid4()),
indexing_status="error",
indexing_status=IndexingStatus.ERROR,
)

mock_document_service_dependencies["redis_client"].get.return_value = "1"
@ -703,7 +703,7 @@ class TestDocumentServiceRetryDocument:
DocumentService.retry_document(dataset.id, [document])

db_session_with_containers.refresh(document)
assert document.indexing_status == "error"
assert document.indexing_status == IndexingStatus.ERROR

def test_retry_document_missing_current_user_error(
self, db_session_with_containers, mock_document_service_dependencies
@ -726,7 +726,7 @@ class TestDocumentServiceRetryDocument:
dataset_id=dataset.id,
tenant_id=dataset.tenant_id,
document_id=str(uuid4()),
indexing_status="error",
indexing_status=IndexingStatus.ERROR,
)

mock_document_service_dependencies["redis_client"].get.return_value = None
@ -816,7 +816,7 @@ class TestDocumentServiceBatchUpdateDocumentStatus:
tenant_id=dataset.tenant_id,
document_id=str(uuid4()),
enabled=False,
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
)
document2 = DocumentStatusTestDataFactory.create_document(
db_session_with_containers,
@ -824,7 +824,7 @@ class TestDocumentServiceBatchUpdateDocumentStatus:
tenant_id=dataset.tenant_id,
document_id=str(uuid4()),
enabled=False,
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
position=2,
)
document_ids = [document1.id, document2.id]
@ -866,7 +866,7 @@ class TestDocumentServiceBatchUpdateDocumentStatus:
tenant_id=dataset.tenant_id,
document_id=str(uuid4()),
enabled=True,
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
completed_at=FIXED_TIME,
)
document_ids = [document.id]
@ -909,7 +909,7 @@ class TestDocumentServiceBatchUpdateDocumentStatus:
document_id=str(uuid4()),
archived=False,
enabled=True,
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
)
document_ids = [document.id]

@ -951,7 +951,7 @@ class TestDocumentServiceBatchUpdateDocumentStatus:
document_id=str(uuid4()),
archived=True,
enabled=True,
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
)
document_ids = [document.id]

@ -1015,7 +1015,7 @@ class TestDocumentServiceBatchUpdateDocumentStatus:
dataset_id=dataset.id,
tenant_id=dataset.tenant_id,
document_id=str(uuid4()),
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
)
document_ids = [document.id]

@ -1098,7 +1098,7 @@ class TestDocumentServiceRenameDocument:
document_id=document_id,
dataset_id=dataset.id,
tenant_id=tenant_id,
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
)

# Act
@ -1139,7 +1139,7 @@ class TestDocumentServiceRenameDocument:
dataset_id=dataset.id,
tenant_id=tenant_id,
doc_metadata={"existing_key": "existing_value"},
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
)

# Act
@ -1187,7 +1187,7 @@ class TestDocumentServiceRenameDocument:
dataset_id=dataset.id,
tenant_id=tenant_id,
data_source_info={"upload_file_id": upload_file.id},
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
)

# Act
@ -1277,7 +1277,7 @@ class TestDocumentServiceRenameDocument:
document_id=document_id,
dataset_id=dataset.id,
tenant_id=str(uuid4()),
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
)

# Act & Assert

@ -16,6 +16,7 @@ from models.dataset import (
DatasetPermission,
DatasetPermissionEnum,
)
from models.enums import DataSourceType
from services.dataset_service import DatasetPermissionService, DatasetService
from services.errors.account import NoPermissionError

@ -67,7 +68,7 @@ class DatasetPermissionTestDataFactory:
tenant_id=tenant_id,
name=name,
description="desc",
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
indexing_technique="high_quality",
created_by=created_by,
permission=permission,

@ -15,6 +15,7 @@ from core.rag.retrieval.retrieval_methods import RetrievalMethod
from dify_graph.model_runtime.entities.model_entities import ModelType
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, DatasetPermissionEnum, Document, ExternalKnowledgeBindings, Pipeline
from models.enums import DatasetRuntimeMode, DataSourceType, DocumentCreatedFrom, IndexingStatus
from services.dataset_service import DatasetService
from services.entities.knowledge_entities.knowledge_entities import RerankingModel, RetrievalModel
from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo, RagPipelineDatasetCreateEntity
@ -74,7 +75,7 @@ class DatasetServiceIntegrationDataFactory:
tenant_id=tenant_id,
name=name,
description=description,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
indexing_technique=indexing_technique,
created_by=created_by,
provider=provider,
@ -98,13 +99,13 @@ class DatasetServiceIntegrationDataFactory:
tenant_id=dataset.tenant_id,
dataset_id=dataset.id,
position=1,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
data_source_info='{"upload_file_id": "upload-file-id"}',
batch=str(uuid4()),
name=name,
created_from="web",
created_from=DocumentCreatedFrom.WEB,
created_by=created_by,
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
doc_form="text_model",
)
db_session_with_containers.add(document)
@ -437,7 +438,7 @@ class TestDatasetServiceCreateRagPipelineDataset:
created_pipeline = db_session_with_containers.get(Pipeline, result.pipeline_id)
assert created_dataset is not None
assert created_dataset.name == entity.name
assert created_dataset.runtime_mode == "rag_pipeline"
assert created_dataset.runtime_mode == DatasetRuntimeMode.RAG_PIPELINE
assert created_dataset.created_by == account.id
assert created_dataset.permission == DatasetPermissionEnum.ONLY_ME
assert created_pipeline is not None

@ -14,6 +14,7 @@ import pytest
from sqlalchemy.orm import Session

from models.dataset import Dataset, Document
from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus
from services.dataset_service import DocumentService
from services.errors.document import DocumentIndexingError

@ -42,7 +43,7 @@ class DocumentBatchUpdateIntegrationDataFactory:
dataset = Dataset(
tenant_id=tenant_id or str(uuid4()),
name=name,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=created_by or str(uuid4()),
)
if dataset_id:
@ -72,11 +73,11 @@ class DocumentBatchUpdateIntegrationDataFactory:
tenant_id=dataset.tenant_id,
dataset_id=dataset.id,
position=position,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
data_source_info=json.dumps({"upload_file_id": str(uuid4())}),
batch=f"batch-{uuid4()}",
name=name,
created_from="web",
created_from=DocumentCreatedFrom.WEB,
created_by=created_by or str(uuid4()),
doc_form="text_model",
)
@ -85,7 +86,9 @@ class DocumentBatchUpdateIntegrationDataFactory:
document.archived = archived
document.indexing_status = indexing_status
document.completed_at = (
completed_at if completed_at is not None else (FIXED_TIME if indexing_status == "completed" else None)
completed_at
if completed_at is not None
else (FIXED_TIME if indexing_status == IndexingStatus.COMPLETED else None)
)

for key, value in kwargs.items():
@ -243,7 +246,7 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
dataset=dataset,
document_ids=document_ids,
enabled=True,
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
)

# Act
@ -277,7 +280,7 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
db_session_with_containers,
dataset=dataset,
enabled=False,
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
completed_at=FIXED_TIME,
)

@ -306,7 +309,7 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
db_session_with_containers,
dataset=dataset,
enabled=True,
indexing_status="indexing",
indexing_status=IndexingStatus.INDEXING,
completed_at=None,
)

@ -5,6 +5,7 @@ from uuid import uuid4

from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document
from models.enums import DataSourceType, DocumentCreatedFrom
from services.dataset_service import DatasetService


@ -58,7 +59,7 @@ class DatasetDeleteIntegrationDataFactory:
dataset = Dataset(
tenant_id=tenant_id,
name=f"dataset-{uuid4()}",
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
indexing_technique=indexing_technique,
index_struct=index_struct,
created_by=created_by,
@ -84,10 +85,10 @@ class DatasetDeleteIntegrationDataFactory:
tenant_id=tenant_id,
dataset_id=dataset_id,
position=1,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch=f"batch-{uuid4()}",
name="Document",
created_from="upload_file",
created_from=DocumentCreatedFrom.WEB,
created_by=created_by,
doc_form=doc_form,
)

@ -14,6 +14,7 @@ from sqlalchemy.orm import Session

from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, DatasetPermissionEnum, Document, DocumentSegment
from models.enums import DataSourceType, DocumentCreatedFrom
from services.dataset_service import SegmentService


@ -62,7 +63,7 @@ class SegmentServiceTestDataFactory:
tenant_id=tenant_id,
name=f"Test Dataset {uuid4()}",
description="Test description",
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
indexing_technique="high_quality",
created_by=created_by,
permission=DatasetPermissionEnum.ONLY_ME,
@ -82,10 +83,10 @@ class SegmentServiceTestDataFactory:
tenant_id=tenant_id,
dataset_id=dataset_id,
position=1,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch=f"batch-{uuid4()}",
name=f"test-doc-{uuid4()}.txt",
created_from="api",
created_from=DocumentCreatedFrom.API,
created_by=created_by,
)
db_session_with_containers.add(document)
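
The hunks above all follow one pattern: raw string literals such as "upload_file", "web", "completed", and "dataset" are replaced with members of the enums imported from models.enums (DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus, CollectionBindingType, DatasetRuntimeMode). A minimal sketch of why that swap is drop-in safe for string-typed columns and assertions is shown below; it assumes those enums are str-backed, and the class definitions here are illustrative stand-ins rather than the project's actual ones.

```python
# Minimal sketch, assuming models.enums uses str-backed enums (enum.StrEnum, Python 3.11+).
# The members below only mirror values seen in the diff; they are not the real definitions.
from enum import StrEnum


class DataSourceType(StrEnum):
    UPLOAD_FILE = "upload_file"


class IndexingStatus(StrEnum):
    WAITING = "waiting"
    COMPLETED = "completed"


# A StrEnum member *is* a str, so it can be written to a VARCHAR column unchanged
# and compared directly against the raw value read back from the database.
assert DataSourceType.UPLOAD_FILE == "upload_file"
assert IndexingStatus.WAITING == "waiting"
assert str(IndexingStatus.COMPLETED) == "completed"
```

Because equality against the raw string still holds, assertions such as document.indexing_status == IndexingStatus.WAITING pass whether the ORM attribute yields a plain string or the enum member, which is what lets these tests switch to enum members without touching the schema or the fixtures' stored values.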
Some files were not shown because too many files have changed in this diff