diff --git a/api/.env.example b/api/.env.example
index d5d4e4486f..5a13ac69de 100644
--- a/api/.env.example
+++ b/api/.env.example
@@ -65,7 +65,7 @@ OPENDAL_FS_ROOT=storage

 # S3 Storage configuration
 S3_USE_AWS_MANAGED_IAM=false
-S3_ENDPOINT=https://your-bucket-name.storage.s3.clooudflare.com
+S3_ENDPOINT=https://your-bucket-name.storage.s3.cloudflare.com
 S3_BUCKET_NAME=your-bucket-name
 S3_ACCESS_KEY=your-access-key
 S3_SECRET_KEY=your-secret-key
@@ -74,7 +74,7 @@ S3_REGION=your-region
 # Azure Blob Storage configuration
 AZURE_BLOB_ACCOUNT_NAME=your-account-name
 AZURE_BLOB_ACCOUNT_KEY=your-account-key
-AZURE_BLOB_CONTAINER_NAME=yout-container-name
+AZURE_BLOB_CONTAINER_NAME=your-container-name
 AZURE_BLOB_ACCOUNT_URL=https://<your_account_name>.blob.core.windows.net

 # Aliyun oss Storage configuration
@@ -88,7 +88,7 @@ ALIYUN_OSS_REGION=your-region
 ALIYUN_OSS_PATH=your-path

 # Google Storage configuration
-GOOGLE_STORAGE_BUCKET_NAME=yout-bucket-name
+GOOGLE_STORAGE_BUCKET_NAME=your-bucket-name
 GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64=your-google-service-account-json-base64-string

 # Tencent COS Storage configuration
diff --git a/api/.ruff.toml b/api/.ruff.toml
index 26a1b977a9..f30275a943 100644
--- a/api/.ruff.toml
+++ b/api/.ruff.toml
@@ -67,7 +67,7 @@ ignore = [
     "SIM105", # suppressible-exception
     "SIM107", # return-in-try-except-finally
     "SIM108", # if-else-block-instead-of-if-exp
-    "SIM113", # eumerate-for-loop
+    "SIM113", # enumerate-for-loop
     "SIM117", # multiple-with-statements
     "SIM210", # if-expr-with-true-false
 ]
diff --git a/api/commands.py b/api/commands.py
index c6e450b3ee..76c8d3e382 100644
--- a/api/commands.py
+++ b/api/commands.py
@@ -563,8 +563,13 @@ def create_tenant(email: str, language: Optional[str] = None, name: Optional[str
     new_password = secrets.token_urlsafe(16)

     # register account
-    account = RegisterService.register(email=email, name=account_name, password=new_password, language=language)
-
+    account = RegisterService.register(
+        email=email,
+        name=account_name,
+        password=new_password,
+        language=language,
+        create_workspace_required=False,
+    )
     TenantService.create_owner_tenant_if_not_exist(account, name)

     click.echo(
@@ -584,7 +589,7 @@ def upgrade_db():
     click.echo(click.style("Starting database migration.", fg="green"))

     # run db migration
-    import flask_migrate
+    import flask_migrate  # type: ignore

     flask_migrate.upgrade()
diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py
index fcecb346b0..5865ddcc8b 100644
--- a/api/configs/feature/__init__.py
+++ b/api/configs/feature/__init__.py
@@ -659,7 +659,7 @@ class RagEtlConfig(BaseSettings):

     UNSTRUCTURED_API_KEY: Optional[str] = Field(
         description="API key for Unstructured.io service",
-        default=None,
+        default="",
     )

     SCARF_NO_ANALYTICS: Optional[str] = Field(
diff --git a/api/controllers/console/datasets/data_source.py b/api/controllers/console/datasets/data_source.py
index 89a638ae54..30cd93a010 100644
--- a/api/controllers/console/datasets/data_source.py
+++ b/api/controllers/console/datasets/data_source.py
@@ -232,7 +232,7 @@ class DataSourceNotionApi(Resource):
             args["doc_form"],
             args["doc_language"],
         )
-        return response, 200
+        return response.model_dump(), 200


 class DataSourceNotionDatasetSyncApi(Resource):
diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py
index f3c3736b25..0c0d2e2003 100644
--- a/api/controllers/console/datasets/datasets.py
+++ b/api/controllers/console/datasets/datasets.py
@@ -464,7 +464,7 @@ class DatasetIndexingEstimateApi(Resource):
         except Exception as e:
             raise IndexingEstimateError(str(e))

-        return response, 200
+        return response.model_dump(), 200


 class DatasetRelatedAppListApi(Resource):
@@ -733,6 +733,18 @@ class DatasetPermissionUserListApi(Resource):
         }, 200


+class DatasetAutoDisableLogApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self, dataset_id):
+        dataset_id_str = str(dataset_id)
+        dataset = DatasetService.get_dataset(dataset_id_str)
+        if dataset is None:
+            raise NotFound("Dataset not found.")
+        return DatasetService.get_dataset_auto_disable_logs(dataset_id_str), 200
+
+
 api.add_resource(DatasetListApi, "/datasets")
 api.add_resource(DatasetApi, "/datasets/<uuid:dataset_id>")
 api.add_resource(DatasetUseCheckApi, "/datasets/<uuid:dataset_id>/use-check")
@@ -747,3 +759,4 @@ api.add_resource(DatasetApiBaseUrlApi, "/datasets/api-base-info")
 api.add_resource(DatasetRetrievalSettingApi, "/datasets/retrieval-setting")
 api.add_resource(DatasetRetrievalSettingMockApi, "/datasets/retrieval-setting/<string:vector_type>")
 api.add_resource(DatasetPermissionUserListApi, "/datasets/<uuid:dataset_id>/permission-part-users")
+api.add_resource(DatasetAutoDisableLogApi, "/datasets/<uuid:dataset_id>/auto-disable-logs")
diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py
index c236e1a431..3c132bc3d0 100644
--- a/api/controllers/console/datasets/datasets_document.py
+++ b/api/controllers/console/datasets/datasets_document.py
@@ -52,6 +52,7 @@ from fields.document_fields import (
 from libs.login import login_required
 from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile
 from services.dataset_service import DatasetService, DocumentService
+from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig
 from tasks.add_document_to_index_task import add_document_to_index_task
 from tasks.remove_document_from_index_task import remove_document_from_index_task

@@ -267,20 +268,22 @@ class DatasetDocumentListApi(Resource):
         parser.add_argument("duplicate", type=bool, default=True, nullable=False, location="json")
         parser.add_argument("original_document_id", type=str, required=False, location="json")
         parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
+        parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
+
         parser.add_argument(
             "doc_language", type=str, default="English", required=False, nullable=False, location="json"
         )
-        parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
         args = parser.parse_args()
+        knowledge_config = KnowledgeConfig(**args)

-        if not dataset.indexing_technique and not args["indexing_technique"]:
+        if not dataset.indexing_technique and not knowledge_config.indexing_technique:
             raise ValueError("indexing_technique is required.")

         # validate args
-        DocumentService.document_create_args_validate(args)
+        DocumentService.document_create_args_validate(knowledge_config)

         try:
-            documents, batch = DocumentService.save_document_with_dataset_id(dataset, args, current_user)
+            documents, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, current_user)
         except ProviderTokenNotInitError as ex:
             raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
@@ -290,6 +293,25 @@ class DatasetDocumentListApi(Resource):

         return {"documents": documents, "batch": batch}

+    @setup_required
+    @login_required
+    @account_initialization_required
+    def delete(self, dataset_id):
+        dataset_id = str(dataset_id)
+        dataset = DatasetService.get_dataset(dataset_id)
+        if dataset is None:
+            raise NotFound("Dataset not found.")
+        # check user's model setting
+        DatasetService.check_dataset_model_setting(dataset)
+
+        try:
+            document_ids = request.args.getlist("document_id")
+            DocumentService.delete_documents(dataset, document_ids)
+        except services.errors.document.DocumentIndexingError:
+            raise DocumentIndexingError("Cannot delete document during indexing.")
+
+        return {"result": "success"}, 204
+

 class DatasetInitApi(Resource):
     @setup_required
@@ -325,9 +347,9 @@ class DatasetInitApi(Resource):
         # The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
         if not current_user.is_dataset_editor:
             raise Forbidden()
-
-        if args["indexing_technique"] == "high_quality":
-            if args["embedding_model"] is None or args["embedding_model_provider"] is None:
+        knowledge_config = KnowledgeConfig(**args)
+        if knowledge_config.indexing_technique == "high_quality":
+            if knowledge_config.embedding_model is None or knowledge_config.embedding_model_provider is None:
                 raise ValueError("embedding model and embedding model provider are required for high quality indexing.")
             try:
                 model_manager = ModelManager()
@@ -346,11 +368,11 @@ class DatasetInitApi(Resource):
             raise ProviderNotInitializeError(ex.description)

         # validate args
-        DocumentService.document_create_args_validate(args)
+        DocumentService.document_create_args_validate(knowledge_config)

         try:
             dataset, documents, batch = DocumentService.save_document_without_dataset_id(
-                tenant_id=current_user.current_tenant_id, document_data=args, account=current_user
+                tenant_id=current_user.current_tenant_id, knowledge_config=knowledge_config, account=current_user
             )
         except ProviderTokenNotInitError as ex:
             raise ProviderNotInitializeError(ex.description)
@@ -403,7 +425,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
             indexing_runner = IndexingRunner()

             try:
-                response = indexing_runner.indexing_estimate(
+                estimate_response = indexing_runner.indexing_estimate(
                     current_user.current_tenant_id,
                     [extract_setting],
                     data_process_rule_dict,
@@ -411,6 +433,7 @@
                     "English",
                     dataset_id,
                 )
+                return estimate_response.model_dump(), 200
             except LLMBadRequestError:
                 raise ProviderNotInitializeError(
                     "No Embedding Model available. Please configure a valid provider "
@@ -423,7 +446,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
             except Exception as e:
                 raise IndexingEstimateError(str(e))

-        return response
+        return response, 200


 class DocumentBatchIndexingEstimateApi(DocumentResource):
@@ -434,9 +457,8 @@
         dataset_id = str(dataset_id)
         batch = str(batch)
         documents = self.get_batch_documents(dataset_id, batch)
-        response = {"tokens": 0, "total_price": 0, "currency": "USD", "total_segments": 0, "preview": []}
         if not documents:
-            return response
+            return {"tokens": 0, "total_price": 0, "currency": "USD", "total_segments": 0, "preview": []}, 200
         data_process_rule = documents[0].dataset_process_rule
         data_process_rule_dict = data_process_rule.to_dict()
         info_list = []
@@ -514,6 +536,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
                 "English",
                 dataset_id,
             )
+            return response.model_dump(), 200
         except LLMBadRequestError:
             raise ProviderNotInitializeError(
                 "No Embedding Model available.
Please configure a valid provider " @@ -525,7 +548,6 @@ class DocumentBatchIndexingEstimateApi(DocumentResource): raise ProviderNotInitializeError(ex.description) except Exception as e: raise IndexingEstimateError(str(e)) - return response class DocumentBatchIndexingStatusApi(DocumentResource): @@ -598,7 +620,8 @@ class DocumentDetailApi(DocumentResource): if metadata == "only": response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata} elif metadata == "without": - process_rules = DatasetService.get_process_rules(dataset_id) + dataset_process_rules = DatasetService.get_process_rules(dataset_id) + document_process_rules = document.dataset_process_rule.to_dict() data_source_info = document.data_source_detail_dict response = { "id": document.id, @@ -606,7 +629,8 @@ class DocumentDetailApi(DocumentResource): "data_source_type": document.data_source_type, "data_source_info": data_source_info, "dataset_process_rule_id": document.dataset_process_rule_id, - "dataset_process_rule": process_rules, + "dataset_process_rule": dataset_process_rules, + "document_process_rule": document_process_rules, "name": document.name, "created_from": document.created_from, "created_by": document.created_by, @@ -629,7 +653,8 @@ class DocumentDetailApi(DocumentResource): "doc_language": document.doc_language, } else: - process_rules = DatasetService.get_process_rules(dataset_id) + dataset_process_rules = DatasetService.get_process_rules(dataset_id) + document_process_rules = document.dataset_process_rule.to_dict() data_source_info = document.data_source_detail_dict response = { "id": document.id, @@ -637,7 +662,8 @@ class DocumentDetailApi(DocumentResource): "data_source_type": document.data_source_type, "data_source_info": data_source_info, "dataset_process_rule_id": document.dataset_process_rule_id, - "dataset_process_rule": process_rules, + "dataset_process_rule": dataset_process_rules, + "document_process_rule": document_process_rules, "name": document.name, "created_from": document.created_from, "created_by": document.created_by, @@ -773,9 +799,8 @@ class DocumentStatusApi(DocumentResource): @login_required @account_initialization_required @cloud_edition_billing_resource_check("vector_space") - def patch(self, dataset_id, document_id, action): + def patch(self, dataset_id, action): dataset_id = str(dataset_id) - document_id = str(document_id) dataset = DatasetService.get_dataset(dataset_id) if dataset is None: raise NotFound("Dataset not found.") @@ -790,84 +815,79 @@ class DocumentStatusApi(DocumentResource): # check user's permission DatasetService.check_dataset_permission(dataset, current_user) - document = self.get_document(dataset_id, document_id) + document_ids = request.args.getlist("document_id") + for document_id in document_ids: + document = self.get_document(dataset_id, document_id) - indexing_cache_key = "document_{}_indexing".format(document.id) - cache_result = redis_client.get(indexing_cache_key) - if cache_result is not None: - raise InvalidActionError("Document is being indexed, please try again later") + indexing_cache_key = "document_{}_indexing".format(document.id) + cache_result = redis_client.get(indexing_cache_key) + if cache_result is not None: + raise InvalidActionError(f"Document:{document.name} is being indexed, please try again later") - if action == "enable": - if document.enabled: - raise InvalidActionError("Document already enabled.") + if action == "enable": + if document.enabled: + continue + document.enabled = True + document.disabled_at 
= None + document.disabled_by = None + document.updated_at = datetime.now(UTC).replace(tzinfo=None) + db.session.commit() - document.enabled = True - document.disabled_at = None - document.disabled_by = None - document.updated_at = datetime.now(UTC).replace(tzinfo=None) - db.session.commit() + # Set cache to prevent indexing the same document multiple times + redis_client.setex(indexing_cache_key, 600, 1) - # Set cache to prevent indexing the same document multiple times - redis_client.setex(indexing_cache_key, 600, 1) + add_document_to_index_task.delay(document_id) - add_document_to_index_task.delay(document_id) + elif action == "disable": + if not document.completed_at or document.indexing_status != "completed": + raise InvalidActionError(f"Document: {document.name} is not completed.") + if not document.enabled: + continue - return {"result": "success"}, 200 + document.enabled = False + document.disabled_at = datetime.now(UTC).replace(tzinfo=None) + document.disabled_by = current_user.id + document.updated_at = datetime.now(UTC).replace(tzinfo=None) + db.session.commit() - elif action == "disable": - if not document.completed_at or document.indexing_status != "completed": - raise InvalidActionError("Document is not completed.") - if not document.enabled: - raise InvalidActionError("Document already disabled.") - - document.enabled = False - document.disabled_at = datetime.now(UTC).replace(tzinfo=None) - document.disabled_by = current_user.id - document.updated_at = datetime.now(UTC).replace(tzinfo=None) - db.session.commit() - - # Set cache to prevent indexing the same document multiple times - redis_client.setex(indexing_cache_key, 600, 1) - - remove_document_from_index_task.delay(document_id) - - return {"result": "success"}, 200 - - elif action == "archive": - if document.archived: - raise InvalidActionError("Document already archived.") - - document.archived = True - document.archived_at = datetime.now(UTC).replace(tzinfo=None) - document.archived_by = current_user.id - document.updated_at = datetime.now(UTC).replace(tzinfo=None) - db.session.commit() - - if document.enabled: # Set cache to prevent indexing the same document multiple times redis_client.setex(indexing_cache_key, 600, 1) remove_document_from_index_task.delay(document_id) - return {"result": "success"}, 200 - elif action == "un_archive": - if not document.archived: - raise InvalidActionError("Document is not archived.") + elif action == "archive": + if document.archived: + continue - document.archived = False - document.archived_at = None - document.archived_by = None - document.updated_at = datetime.now(UTC).replace(tzinfo=None) - db.session.commit() + document.archived = True + document.archived_at = datetime.now(UTC).replace(tzinfo=None) + document.archived_by = current_user.id + document.updated_at = datetime.now(UTC).replace(tzinfo=None) + db.session.commit() - # Set cache to prevent indexing the same document multiple times - redis_client.setex(indexing_cache_key, 600, 1) + if document.enabled: + # Set cache to prevent indexing the same document multiple times + redis_client.setex(indexing_cache_key, 600, 1) - add_document_to_index_task.delay(document_id) + remove_document_from_index_task.delay(document_id) - return {"result": "success"}, 200 - else: - raise InvalidActionError() + elif action == "un_archive": + if not document.archived: + continue + document.archived = False + document.archived_at = None + document.archived_by = None + document.updated_at = datetime.now(UTC).replace(tzinfo=None) + 
db.session.commit() + + # Set cache to prevent indexing the same document multiple times + redis_client.setex(indexing_cache_key, 600, 1) + + add_document_to_index_task.delay(document_id) + + else: + raise InvalidActionError() + return {"result": "success"}, 200 class DocumentPauseApi(DocumentResource): @@ -1038,7 +1058,7 @@ api.add_resource( ) api.add_resource(DocumentDeleteApi, "/datasets//documents/") api.add_resource(DocumentMetadataApi, "/datasets//documents//metadata") -api.add_resource(DocumentStatusApi, "/datasets//documents//status/") +api.add_resource(DocumentStatusApi, "/datasets//documents/status//batch") api.add_resource(DocumentPauseApi, "/datasets//documents//processing/pause") api.add_resource(DocumentRecoverApi, "/datasets//documents//processing/resume") api.add_resource(DocumentRetryApi, "/datasets//retry") diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py index 2d5933ca23..96654c09fd 100644 --- a/api/controllers/console/datasets/datasets_segments.py +++ b/api/controllers/console/datasets/datasets_segments.py @@ -1,5 +1,4 @@ import uuid -from datetime import UTC, datetime import pandas as pd from flask import request @@ -10,7 +9,13 @@ from werkzeug.exceptions import Forbidden, NotFound import services from controllers.console import api from controllers.console.app.error import ProviderNotInitializeError -from controllers.console.datasets.error import InvalidActionError, NoFileUploadedError, TooManyFilesError +from controllers.console.datasets.error import ( + ChildChunkDeleteIndexError, + ChildChunkIndexingError, + InvalidActionError, + NoFileUploadedError, + TooManyFilesError, +) from controllers.console.wraps import ( account_initialization_required, cloud_edition_billing_knowledge_limit_check, @@ -20,15 +25,15 @@ from controllers.console.wraps import ( from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType -from extensions.ext_database import db from extensions.ext_redis import redis_client -from fields.segment_fields import segment_fields +from fields.segment_fields import child_chunk_fields, segment_fields from libs.login import login_required -from models import DocumentSegment +from models.dataset import ChildChunk, DocumentSegment from services.dataset_service import DatasetService, DocumentService, SegmentService +from services.entities.knowledge_entities.knowledge_entities import ChildChunkUpdateArgs, SegmentUpdateArgs +from services.errors.chunk import ChildChunkDeleteIndexError as ChildChunkDeleteIndexServiceError +from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingServiceError from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task -from tasks.disable_segment_from_index_task import disable_segment_from_index_task -from tasks.enable_segment_to_index_task import enable_segment_to_index_task class DatasetDocumentSegmentListApi(Resource): @@ -53,15 +58,16 @@ class DatasetDocumentSegmentListApi(Resource): raise NotFound("Document not found.") parser = reqparse.RequestParser() - parser.add_argument("last_id", type=str, default=None, location="args") parser.add_argument("limit", type=int, default=20, location="args") parser.add_argument("status", type=str, action="append", default=[], location="args") parser.add_argument("hit_count_gte", type=int, default=None, location="args") 
parser.add_argument("enabled", type=str, default="all", location="args") parser.add_argument("keyword", type=str, default=None, location="args") + parser.add_argument("page", type=int, default=1, location="args") + args = parser.parse_args() - last_id = args["last_id"] + page = args["page"] limit = min(args["limit"], 100) status_list = args["status"] hit_count_gte = args["hit_count_gte"] @@ -69,14 +75,7 @@ class DatasetDocumentSegmentListApi(Resource): query = DocumentSegment.query.filter( DocumentSegment.document_id == str(document_id), DocumentSegment.tenant_id == current_user.current_tenant_id - ) - - if last_id is not None: - last_segment = db.session.get(DocumentSegment, str(last_id)) - if last_segment: - query = query.filter(DocumentSegment.position > last_segment.position) - else: - return {"data": [], "has_more": False, "limit": limit}, 200 + ).order_by(DocumentSegment.position.asc()) if status_list: query = query.filter(DocumentSegment.status.in_(status_list)) @@ -93,21 +92,44 @@ class DatasetDocumentSegmentListApi(Resource): elif args["enabled"].lower() == "false": query = query.filter(DocumentSegment.enabled == False) - total = query.count() - segments = query.order_by(DocumentSegment.position).limit(limit + 1).all() + segments = query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False) - has_more = False - if len(segments) > limit: - has_more = True - segments = segments[:-1] - - return { - "data": marshal(segments, segment_fields), - "doc_form": document.doc_form, - "has_more": has_more, + response = { + "data": marshal(segments.items, segment_fields), "limit": limit, - "total": total, - }, 200 + "total": segments.total, + "total_pages": segments.pages, + "page": page, + } + return response, 200 + + @setup_required + @login_required + @account_initialization_required + def delete(self, dataset_id, document_id): + # check dataset + dataset_id = str(dataset_id) + dataset = DatasetService.get_dataset(dataset_id) + if not dataset: + raise NotFound("Dataset not found.") + # check user's model setting + DatasetService.check_dataset_model_setting(dataset) + # check document + document_id = str(document_id) + document = DocumentService.get_document(dataset_id, document_id) + if not document: + raise NotFound("Document not found.") + segment_ids = request.args.getlist("segment_id") + + # The role of the current user in the ta table must be admin or owner + if not current_user.is_editor: + raise Forbidden() + try: + DatasetService.check_dataset_permission(dataset, current_user) + except services.errors.account.NoPermissionError as e: + raise Forbidden(str(e)) + SegmentService.delete_segments(segment_ids, document, dataset) + return {"result": "success"}, 200 class DatasetDocumentSegmentApi(Resource): @@ -115,11 +137,15 @@ class DatasetDocumentSegmentApi(Resource): @login_required @account_initialization_required @cloud_edition_billing_resource_check("vector_space") - def patch(self, dataset_id, segment_id, action): + def patch(self, dataset_id, document_id, action): dataset_id = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id) if not dataset: raise NotFound("Dataset not found.") + document_id = str(document_id) + document = DocumentService.get_document(dataset_id, document_id) + if not document: + raise NotFound("Document not found.") # check user's model setting DatasetService.check_dataset_model_setting(dataset) # The role of the current user in the ta table must be admin, owner, or editor @@ -147,59 +173,17 @@ class 
DatasetDocumentSegmentApi(Resource): ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) + segment_ids = request.args.getlist("segment_id") - segment = DocumentSegment.query.filter( - DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id - ).first() - - if not segment: - raise NotFound("Segment not found.") - - if segment.status != "completed": - raise NotFound("Segment is not completed, enable or disable function is not allowed") - - document_indexing_cache_key = "document_{}_indexing".format(segment.document_id) + document_indexing_cache_key = "document_{}_indexing".format(document.id) cache_result = redis_client.get(document_indexing_cache_key) if cache_result is not None: raise InvalidActionError("Document is being indexed, please try again later") - - indexing_cache_key = "segment_{}_indexing".format(segment.id) - cache_result = redis_client.get(indexing_cache_key) - if cache_result is not None: - raise InvalidActionError("Segment is being indexed, please try again later") - - if action == "enable": - if segment.enabled: - raise InvalidActionError("Segment is already enabled.") - - segment.enabled = True - segment.disabled_at = None - segment.disabled_by = None - db.session.commit() - - # Set cache to prevent indexing the same segment multiple times - redis_client.setex(indexing_cache_key, 600, 1) - - enable_segment_to_index_task.delay(segment.id) - - return {"result": "success"}, 200 - elif action == "disable": - if not segment.enabled: - raise InvalidActionError("Segment is already disabled.") - - segment.enabled = False - segment.disabled_at = datetime.now(UTC).replace(tzinfo=None) - segment.disabled_by = current_user.id - db.session.commit() - - # Set cache to prevent indexing the same segment multiple times - redis_client.setex(indexing_cache_key, 600, 1) - - disable_segment_from_index_task.delay(segment.id) - - return {"result": "success"}, 200 - else: - raise InvalidActionError() + try: + SegmentService.update_segments_status(segment_ids, action, dataset, document) + except Exception as e: + raise InvalidActionError(str(e)) + return {"result": "success"}, 200 class DatasetDocumentSegmentAddApi(Resource): @@ -307,9 +291,12 @@ class DatasetDocumentSegmentUpdateApi(Resource): parser.add_argument("content", type=str, required=True, nullable=False, location="json") parser.add_argument("answer", type=str, required=False, nullable=True, location="json") parser.add_argument("keywords", type=list, required=False, nullable=True, location="json") + parser.add_argument( + "regenerate_child_chunks", type=bool, required=False, nullable=True, default=False, location="json" + ) args = parser.parse_args() SegmentService.segment_create_args_validate(args, document) - segment = SegmentService.update_segment(args, segment, document, dataset) + segment = SegmentService.update_segment(SegmentUpdateArgs(**args), segment, document, dataset) return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200 @setup_required @@ -412,8 +399,248 @@ class DatasetDocumentSegmentBatchImportApi(Resource): return {"job_id": job_id, "job_status": cache_result.decode()}, 200 +class ChildChunkAddApi(Resource): + @setup_required + @login_required + @account_initialization_required + @cloud_edition_billing_resource_check("vector_space") + @cloud_edition_billing_knowledge_limit_check("add_segment") + def post(self, dataset_id, document_id, segment_id): + # check dataset + dataset_id = str(dataset_id) + dataset = 
DatasetService.get_dataset(dataset_id) + if not dataset: + raise NotFound("Dataset not found.") + # check document + document_id = str(document_id) + document = DocumentService.get_document(dataset_id, document_id) + if not document: + raise NotFound("Document not found.") + # check segment + segment_id = str(segment_id) + segment = DocumentSegment.query.filter( + DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id + ).first() + if not segment: + raise NotFound("Segment not found.") + if not current_user.is_editor: + raise Forbidden() + # check embedding model setting + if dataset.indexing_technique == "high_quality": + try: + model_manager = ModelManager() + model_manager.get_model_instance( + tenant_id=current_user.current_tenant_id, + provider=dataset.embedding_model_provider, + model_type=ModelType.TEXT_EMBEDDING, + model=dataset.embedding_model, + ) + except LLMBadRequestError: + raise ProviderNotInitializeError( + "No Embedding Model available. Please configure a valid provider " + "in the Settings -> Model Provider." + ) + except ProviderTokenNotInitError as ex: + raise ProviderNotInitializeError(ex.description) + try: + DatasetService.check_dataset_permission(dataset, current_user) + except services.errors.account.NoPermissionError as e: + raise Forbidden(str(e)) + # validate args + parser = reqparse.RequestParser() + parser.add_argument("content", type=str, required=True, nullable=False, location="json") + args = parser.parse_args() + try: + child_chunk = SegmentService.create_child_chunk(args.get("content"), segment, document, dataset) + except ChildChunkIndexingServiceError as e: + raise ChildChunkIndexingError(str(e)) + return {"data": marshal(child_chunk, child_chunk_fields)}, 200 + + @setup_required + @login_required + @account_initialization_required + def get(self, dataset_id, document_id, segment_id): + # check dataset + dataset_id = str(dataset_id) + dataset = DatasetService.get_dataset(dataset_id) + if not dataset: + raise NotFound("Dataset not found.") + # check user's model setting + DatasetService.check_dataset_model_setting(dataset) + # check document + document_id = str(document_id) + document = DocumentService.get_document(dataset_id, document_id) + if not document: + raise NotFound("Document not found.") + # check segment + segment_id = str(segment_id) + segment = DocumentSegment.query.filter( + DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id + ).first() + if not segment: + raise NotFound("Segment not found.") + parser = reqparse.RequestParser() + parser.add_argument("limit", type=int, default=20, location="args") + parser.add_argument("keyword", type=str, default=None, location="args") + parser.add_argument("page", type=int, default=1, location="args") + + args = parser.parse_args() + + page = args["page"] + limit = min(args["limit"], 100) + keyword = args["keyword"] + + child_chunks = SegmentService.get_child_chunks(segment_id, document_id, dataset_id, page, limit, keyword) + return { + "data": marshal(child_chunks.items, child_chunk_fields), + "total": child_chunks.total, + "total_pages": child_chunks.pages, + "page": page, + "limit": limit, + }, 200 + + @setup_required + @login_required + @account_initialization_required + @cloud_edition_billing_resource_check("vector_space") + def patch(self, dataset_id, document_id, segment_id): + # check dataset + dataset_id = str(dataset_id) + dataset = DatasetService.get_dataset(dataset_id) + if not dataset: + raise 
NotFound("Dataset not found.") + # check user's model setting + DatasetService.check_dataset_model_setting(dataset) + # check document + document_id = str(document_id) + document = DocumentService.get_document(dataset_id, document_id) + if not document: + raise NotFound("Document not found.") + # check segment + segment_id = str(segment_id) + segment = DocumentSegment.query.filter( + DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id + ).first() + if not segment: + raise NotFound("Segment not found.") + # The role of the current user in the ta table must be admin, owner, or editor + if not current_user.is_editor: + raise Forbidden() + try: + DatasetService.check_dataset_permission(dataset, current_user) + except services.errors.account.NoPermissionError as e: + raise Forbidden(str(e)) + # validate args + parser = reqparse.RequestParser() + parser.add_argument("chunks", type=list, required=True, nullable=False, location="json") + args = parser.parse_args() + try: + chunks = [ChildChunkUpdateArgs(**chunk) for chunk in args.get("chunks")] + child_chunks = SegmentService.update_child_chunks(chunks, segment, document, dataset) + except ChildChunkIndexingServiceError as e: + raise ChildChunkIndexingError(str(e)) + return {"data": marshal(child_chunks, child_chunk_fields)}, 200 + + +class ChildChunkUpdateApi(Resource): + @setup_required + @login_required + @account_initialization_required + def delete(self, dataset_id, document_id, segment_id, child_chunk_id): + # check dataset + dataset_id = str(dataset_id) + dataset = DatasetService.get_dataset(dataset_id) + if not dataset: + raise NotFound("Dataset not found.") + # check user's model setting + DatasetService.check_dataset_model_setting(dataset) + # check document + document_id = str(document_id) + document = DocumentService.get_document(dataset_id, document_id) + if not document: + raise NotFound("Document not found.") + # check segment + segment_id = str(segment_id) + segment = DocumentSegment.query.filter( + DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id + ).first() + if not segment: + raise NotFound("Segment not found.") + # check child chunk + child_chunk_id = str(child_chunk_id) + child_chunk = ChildChunk.query.filter( + ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id + ).first() + if not child_chunk: + raise NotFound("Child chunk not found.") + # The role of the current user in the ta table must be admin or owner + if not current_user.is_editor: + raise Forbidden() + try: + DatasetService.check_dataset_permission(dataset, current_user) + except services.errors.account.NoPermissionError as e: + raise Forbidden(str(e)) + try: + SegmentService.delete_child_chunk(child_chunk, dataset) + except ChildChunkDeleteIndexServiceError as e: + raise ChildChunkDeleteIndexError(str(e)) + return {"result": "success"}, 200 + + @setup_required + @login_required + @account_initialization_required + @cloud_edition_billing_resource_check("vector_space") + def patch(self, dataset_id, document_id, segment_id, child_chunk_id): + # check dataset + dataset_id = str(dataset_id) + dataset = DatasetService.get_dataset(dataset_id) + if not dataset: + raise NotFound("Dataset not found.") + # check user's model setting + DatasetService.check_dataset_model_setting(dataset) + # check document + document_id = str(document_id) + document = DocumentService.get_document(dataset_id, document_id) + if not document: + raise 
NotFound("Document not found.") + # check segment + segment_id = str(segment_id) + segment = DocumentSegment.query.filter( + DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id + ).first() + if not segment: + raise NotFound("Segment not found.") + # check child chunk + child_chunk_id = str(child_chunk_id) + child_chunk = ChildChunk.query.filter( + ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id + ).first() + if not child_chunk: + raise NotFound("Child chunk not found.") + # The role of the current user in the ta table must be admin or owner + if not current_user.is_editor: + raise Forbidden() + try: + DatasetService.check_dataset_permission(dataset, current_user) + except services.errors.account.NoPermissionError as e: + raise Forbidden(str(e)) + # validate args + parser = reqparse.RequestParser() + parser.add_argument("content", type=str, required=True, nullable=False, location="json") + args = parser.parse_args() + try: + child_chunk = SegmentService.update_child_chunk( + args.get("content"), child_chunk, segment, document, dataset + ) + except ChildChunkIndexingServiceError as e: + raise ChildChunkIndexingError(str(e)) + return {"data": marshal(child_chunk, child_chunk_fields)}, 200 + + api.add_resource(DatasetDocumentSegmentListApi, "/datasets//documents//segments") -api.add_resource(DatasetDocumentSegmentApi, "/datasets//segments//") +api.add_resource( + DatasetDocumentSegmentApi, "/datasets//documents//segment/" +) api.add_resource(DatasetDocumentSegmentAddApi, "/datasets//documents//segment") api.add_resource( DatasetDocumentSegmentUpdateApi, @@ -424,3 +651,11 @@ api.add_resource( "/datasets//documents//segments/batch_import", "/datasets/batch_import_status/", ) +api.add_resource( + ChildChunkAddApi, + "/datasets//documents//segments//child_chunks", +) +api.add_resource( + ChildChunkUpdateApi, + "/datasets//documents//segments//child_chunks/", +) diff --git a/api/controllers/console/datasets/error.py b/api/controllers/console/datasets/error.py index 6a7a3971a8..2f00a84de6 100644 --- a/api/controllers/console/datasets/error.py +++ b/api/controllers/console/datasets/error.py @@ -89,3 +89,15 @@ class IndexingEstimateError(BaseHTTPException): error_code = "indexing_estimate_error" description = "Knowledge indexing estimate failed: {message}" code = 500 + + +class ChildChunkIndexingError(BaseHTTPException): + error_code = "child_chunk_indexing_error" + description = "Create child chunk index failed: {message}" + code = 500 + + +class ChildChunkDeleteIndexError(BaseHTTPException): + error_code = "child_chunk_delete_index_error" + description = "Delete child chunk index failed: {message}" + code = 500 diff --git a/api/controllers/console/explore/message.py b/api/controllers/console/explore/message.py index c3488de299..690297048e 100644 --- a/api/controllers/console/explore/message.py +++ b/api/controllers/console/explore/message.py @@ -69,7 +69,7 @@ class MessageFeedbackApi(InstalledAppResource): args = parser.parse_args() try: - MessageService.create_feedback(app_model, message_id, current_user, args["rating"], args["content"]) + MessageService.create_feedback(app_model, message_id, current_user, args.get("rating"), args.get("content")) except services.errors.message.MessageNotExistsError: raise NotFound("Message Not Exists.") diff --git a/api/controllers/service_api/app/message.py b/api/controllers/service_api/app/message.py index 522c7509b9..bed89a99a5 100644 --- 
a/api/controllers/service_api/app/message.py +++ b/api/controllers/service_api/app/message.py @@ -108,7 +108,7 @@ class MessageFeedbackApi(Resource): args = parser.parse_args() try: - MessageService.create_feedback(app_model, message_id, end_user, args["rating"], args["content"]) + MessageService.create_feedback(app_model, message_id, end_user, args.get("rating"), args.get("content")) except services.errors.message.MessageNotExistsError: raise NotFound("Message Not Exists.") diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py index 34afe2837f..84c58c62df 100644 --- a/api/controllers/service_api/dataset/document.py +++ b/api/controllers/service_api/dataset/document.py @@ -22,6 +22,7 @@ from fields.document_fields import document_fields, document_status_fields from libs.login import current_user from models.dataset import Dataset, Document, DocumentSegment from services.dataset_service import DocumentService +from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig from services.file_service import FileService @@ -67,13 +68,14 @@ class DocumentAddByTextApi(DatasetApiResource): "info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}}, } args["data_source"] = data_source + knowledge_config = KnowledgeConfig(**args) # validate args - DocumentService.document_create_args_validate(args) + DocumentService.document_create_args_validate(knowledge_config) try: documents, batch = DocumentService.save_document_with_dataset_id( dataset=dataset, - document_data=args, + knowledge_config=knowledge_config, account=current_user, dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None, created_from="api", @@ -122,12 +124,13 @@ class DocumentUpdateByTextApi(DatasetApiResource): args["data_source"] = data_source # validate args args["original_document_id"] = str(document_id) - DocumentService.document_create_args_validate(args) + knowledge_config = KnowledgeConfig(**args) + DocumentService.document_create_args_validate(knowledge_config) try: documents, batch = DocumentService.save_document_with_dataset_id( dataset=dataset, - document_data=args, + knowledge_config=knowledge_config, account=current_user, dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None, created_from="api", @@ -186,12 +189,13 @@ class DocumentAddByFileApi(DatasetApiResource): data_source = {"type": "upload_file", "info_list": {"file_info_list": {"file_ids": [upload_file.id]}}} args["data_source"] = data_source # validate args - DocumentService.document_create_args_validate(args) + knowledge_config = KnowledgeConfig(**args) + DocumentService.document_create_args_validate(knowledge_config) try: documents, batch = DocumentService.save_document_with_dataset_id( dataset=dataset, - document_data=args, + knowledge_config=knowledge_config, account=dataset.created_by_account, dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None, created_from="api", @@ -245,12 +249,14 @@ class DocumentUpdateByFileApi(DatasetApiResource): args["data_source"] = data_source # validate args args["original_document_id"] = str(document_id) - DocumentService.document_create_args_validate(args) + + knowledge_config = KnowledgeConfig(**args) + DocumentService.document_create_args_validate(knowledge_config) try: documents, batch = DocumentService.save_document_with_dataset_id( dataset=dataset, - document_data=args, + knowledge_config=knowledge_config, 
account=dataset.created_by_account, dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None, created_from="api", diff --git a/api/controllers/service_api/dataset/segment.py b/api/controllers/service_api/dataset/segment.py index 34904574a8..1c500f51bf 100644 --- a/api/controllers/service_api/dataset/segment.py +++ b/api/controllers/service_api/dataset/segment.py @@ -16,6 +16,7 @@ from extensions.ext_database import db from fields.segment_fields import segment_fields from models.dataset import Dataset, DocumentSegment from services.dataset_service import DatasetService, DocumentService, SegmentService +from services.entities.knowledge_entities.knowledge_entities import SegmentUpdateArgs class SegmentApi(DatasetApiResource): @@ -193,7 +194,7 @@ class DatasetSegmentApi(DatasetApiResource): args = parser.parse_args() SegmentService.segment_create_args_validate(args["segment"], document) - segment = SegmentService.update_segment(args["segment"], segment, document, dataset) + segment = SegmentService.update_segment(SegmentUpdateArgs(**args["segment"]), segment, document, dataset) return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200 diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py index 0f47e64370..2afc11f601 100644 --- a/api/controllers/web/message.py +++ b/api/controllers/web/message.py @@ -105,10 +105,17 @@ class MessageFeedbackApi(WebApiResource): parser = reqparse.RequestParser() parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json") + parser.add_argument("content", type=str, location="json", default=None) args = parser.parse_args() try: - MessageService.create_feedback(app_model, message_id, end_user, args["rating"], args["content"]) + MessageService.create_feedback( + app_model=app_model, + message_id=message_id, + user=end_user, + rating=args.get("rating"), + content=args.get("content"), + ) except services.errors.message.MessageNotExistsError: raise NotFound("Message Not Exists.") diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index 684f2bc8a3..cb2a361f17 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -393,7 +393,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): try: return generate_task_pipeline.process() except ValueError as e: - if e.args[0] == "I/O operation on closed file.": # ignore this error + if len(e.args) > 0 and e.args[0] == "I/O operation on closed file.": # ignore this error raise GenerateTaskStoppedError() else: logger.exception(f"Failed to process generate task pipeline, conversation_id: {conversation.id}") diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 2e9b643d8b..1f1b2b568e 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -5,6 +5,9 @@ from collections.abc import Generator, Mapping from threading import Thread from typing import Any, Optional, Union +from sqlalchemy import select +from sqlalchemy.orm import Session + from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom @@ -66,7 +69,6 @@ from 
models.enums import CreatedByRole from models.workflow import ( Workflow, WorkflowNodeExecution, - WorkflowRun, WorkflowRunStatus, ) @@ -80,8 +82,6 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc _task_state: WorkflowTaskState _application_generate_entity: AdvancedChatAppGenerateEntity - _workflow: Workflow - _user: Union[Account, EndUser] _workflow_system_variables: dict[SystemVariableKey, Any] _wip_workflow_node_executions: dict[str, WorkflowNodeExecution] _conversation_name_generate_thread: Optional[Thread] = None @@ -97,32 +97,37 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc stream: bool, dialogue_count: int, ) -> None: - """ - Initialize AdvancedChatAppGenerateTaskPipeline. - :param application_generate_entity: application generate entity - :param workflow: workflow - :param queue_manager: queue manager - :param conversation: conversation - :param message: message - :param user: user - :param stream: stream - :param dialogue_count: dialogue count - """ - super().__init__(application_generate_entity, queue_manager, user, stream) + super().__init__( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + stream=stream, + ) - if isinstance(self._user, EndUser): - user_id = self._user.session_id + if isinstance(user, EndUser): + self._user_id = user.id + user_session_id = user.session_id + self._created_by_role = CreatedByRole.END_USER + elif isinstance(user, Account): + self._user_id = user.id + user_session_id = user.id + self._created_by_role = CreatedByRole.ACCOUNT else: - user_id = self._user.id + raise NotImplementedError(f"User type not supported: {type(user)}") + + self._workflow_id = workflow.id + self._workflow_features_dict = workflow.features_dict + + self._conversation_id = conversation.id + self._conversation_mode = conversation.mode + + self._message_id = message.id + self._message_created_at = int(message.created_at.timestamp()) - self._workflow = workflow - self._conversation = conversation - self._message = message self._workflow_system_variables = { SystemVariableKey.QUERY: message.query, SystemVariableKey.FILES: application_generate_entity.files, SystemVariableKey.CONVERSATION_ID: conversation.id, - SystemVariableKey.USER_ID: user_id, + SystemVariableKey.USER_ID: user_session_id, SystemVariableKey.DIALOGUE_COUNT: dialogue_count, SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id, SystemVariableKey.WORKFLOW_ID: workflow.id, @@ -135,19 +140,16 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc self._conversation_name_generate_thread = None self._recorded_files: list[Mapping[str, Any]] = [] + self._workflow_run_id = "" def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]: """ Process generate task pipeline. 
:return: """ - db.session.refresh(self._workflow) - db.session.refresh(self._user) - db.session.close() - # start generate conversation name thread self._conversation_name_generate_thread = self._generate_conversation_name( - self._conversation, self._application_generate_entity.query + conversation_id=self._conversation_id, query=self._application_generate_entity.query ) generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager) @@ -173,12 +175,12 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc return ChatbotAppBlockingResponse( task_id=stream_response.task_id, data=ChatbotAppBlockingResponse.Data( - id=self._message.id, - mode=self._conversation.mode, - conversation_id=self._conversation.id, - message_id=self._message.id, + id=self._message_id, + mode=self._conversation_mode, + conversation_id=self._conversation_id, + message_id=self._message_id, answer=self._task_state.answer, - created_at=int(self._message.created_at.timestamp()), + created_at=self._message_created_at, **extras, ), ) @@ -196,9 +198,9 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc """ for stream_response in generator: yield ChatbotAppStreamResponse( - conversation_id=self._conversation.id, - message_id=self._message.id, - created_at=int(self._message.created_at.timestamp()), + conversation_id=self._conversation_id, + message_id=self._message_id, + created_at=self._message_created_at, stream_response=stream_response, ) @@ -216,7 +218,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc tts_publisher = None task_id = self._application_generate_entity.task_id tenant_id = self._application_generate_entity.app_config.tenant_id - features_dict = self._workflow.features_dict + features_dict = self._workflow_features_dict if ( features_dict.get("text_to_speech") @@ -268,7 +270,6 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc """ # init fake graph runtime state graph_runtime_state: Optional[GraphRuntimeState] = None - workflow_run: Optional[WorkflowRun] = None for queue_message in self._queue_manager.listen(): event = queue_message.event @@ -276,75 +277,97 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc if isinstance(event, QueuePingEvent): yield self._ping_stream_response() elif isinstance(event, QueueErrorEvent): - err = self._handle_error(event, self._message) + with Session(db.engine) as session: + err = self._handle_error(event=event, session=session, message_id=self._message_id) + session.commit() yield self._error_to_stream_response(err) break elif isinstance(event, QueueWorkflowStartedEvent): # override graph runtime state graph_runtime_state = event.graph_runtime_state - # init workflow run - workflow_run = self._handle_workflow_run_start() + with Session(db.engine) as session: + # init workflow run + workflow_run = self._handle_workflow_run_start( + session=session, + workflow_id=self._workflow_id, + user_id=self._user_id, + created_by_role=self._created_by_role, + ) + self._workflow_run_id = workflow_run.id + message = self._get_message(session=session) + if not message: + raise ValueError(f"Message not found: {self._message_id}") + message.workflow_run_id = workflow_run.id + workflow_start_resp = self._workflow_start_to_stream_response( + session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run + ) + session.commit() - self._refetch_message() - 
self._message.workflow_run_id = workflow_run.id - - db.session.commit() - db.session.refresh(self._message) - db.session.close() - - yield self._workflow_start_to_stream_response( - task_id=self._application_generate_entity.task_id, workflow_run=workflow_run - ) + yield workflow_start_resp elif isinstance( event, QueueNodeRetryEvent, ): - if not workflow_run: + if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - workflow_node_execution = self._handle_workflow_node_execution_retried( - workflow_run=workflow_run, event=event - ) - response = self._workflow_node_retry_to_stream_response( - event=event, - task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution, - ) + with Session(db.engine) as session: + workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) + workflow_node_execution = self._handle_workflow_node_execution_retried( + session=session, workflow_run=workflow_run, event=event + ) + node_retry_resp = self._workflow_node_retry_to_stream_response( + session=session, + event=event, + task_id=self._application_generate_entity.task_id, + workflow_node_execution=workflow_node_execution, + ) + session.commit() - if response: - yield response + if node_retry_resp: + yield node_retry_resp elif isinstance(event, QueueNodeStartedEvent): - if not workflow_run: + if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - workflow_node_execution = self._handle_node_execution_start(workflow_run=workflow_run, event=event) + with Session(db.engine) as session: + workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) + workflow_node_execution = self._handle_node_execution_start( + session=session, workflow_run=workflow_run, event=event + ) - response_start = self._workflow_node_start_to_stream_response( - event=event, - task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution, - ) + node_start_resp = self._workflow_node_start_to_stream_response( + session=session, + event=event, + task_id=self._application_generate_entity.task_id, + workflow_node_execution=workflow_node_execution, + ) + session.commit() - if response_start: - yield response_start + if node_start_resp: + yield node_start_resp elif isinstance(event, QueueNodeSucceededEvent): - workflow_node_execution = self._handle_workflow_node_execution_success(event) - # Record files if it's an answer node or end node if event.node_type in [NodeType.ANSWER, NodeType.END]: self._recorded_files.extend(self._fetch_files_from_node_outputs(event.outputs or {})) - response_finish = self._workflow_node_finish_to_stream_response( - event=event, - task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution, - ) + with Session(db.engine) as session: + workflow_node_execution = self._handle_workflow_node_execution_success(session=session, event=event) - if response_finish: - yield response_finish + node_finish_resp = self._workflow_node_finish_to_stream_response( + session=session, + event=event, + task_id=self._application_generate_entity.task_id, + workflow_node_execution=workflow_node_execution, + ) + session.commit() + + if node_finish_resp: + yield node_finish_resp elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent): - workflow_node_execution = self._handle_workflow_node_execution_failed(event) + with Session(db.engine) as session: + 
workflow_node_execution = self._handle_workflow_node_execution_failed(session=session, event=event) response_finish = self._workflow_node_finish_to_stream_response( event=event, @@ -355,158 +378,203 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc if response_finish: yield response_finish + if node_finish_resp: + yield node_finish_resp elif isinstance(event, QueueParallelBranchRunStartedEvent): - if not workflow_run: + if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - yield self._workflow_parallel_branch_start_to_stream_response( - task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event - ) - elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent): - if not workflow_run: - raise ValueError("workflow run not initialized.") - - yield self._workflow_parallel_branch_finished_to_stream_response( - task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event - ) - elif isinstance(event, QueueIterationStartEvent): - if not workflow_run: - raise ValueError("workflow run not initialized.") - - yield self._workflow_iteration_start_to_stream_response( - task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event - ) - elif isinstance(event, QueueIterationNextEvent): - if not workflow_run: - raise ValueError("workflow run not initialized.") - - yield self._workflow_iteration_next_to_stream_response( - task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event - ) - elif isinstance(event, QueueIterationCompletedEvent): - if not workflow_run: - raise ValueError("workflow run not initialized.") - - yield self._workflow_iteration_completed_to_stream_response( - task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event - ) - elif isinstance(event, QueueWorkflowSucceededEvent): - if not workflow_run: - raise ValueError("workflow run not initialized.") - - if not graph_runtime_state: - raise ValueError("workflow run not initialized.") - - workflow_run = self._handle_workflow_run_success( - workflow_run=workflow_run, - start_at=graph_runtime_state.start_at, - total_tokens=graph_runtime_state.total_tokens, - total_steps=graph_runtime_state.node_run_steps, - outputs=event.outputs, - conversation_id=self._conversation.id, - trace_manager=trace_manager, - ) - - yield self._workflow_finish_to_stream_response( - task_id=self._application_generate_entity.task_id, workflow_run=workflow_run - ) - - self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE) - elif isinstance(event, QueueWorkflowPartialSuccessEvent): - if not workflow_run: - raise ValueError("workflow run not initialized.") - - if not graph_runtime_state: - raise ValueError("graph runtime state not initialized.") - - workflow_run = self._handle_workflow_run_partial_success( - workflow_run=workflow_run, - start_at=graph_runtime_state.start_at, - total_tokens=graph_runtime_state.total_tokens, - total_steps=graph_runtime_state.node_run_steps, - outputs=event.outputs, - exceptions_count=event.exceptions_count, - conversation_id=None, - trace_manager=trace_manager, - ) - - yield self._workflow_finish_to_stream_response( - task_id=self._application_generate_entity.task_id, workflow_run=workflow_run - ) - - self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE) - elif isinstance(event, QueueWorkflowFailedEvent): - if not workflow_run: 
- raise ValueError("workflow run not initialized.") - - if not graph_runtime_state: - raise ValueError("graph runtime state not initialized.") - - workflow_run = self._handle_workflow_run_failed( - workflow_run=workflow_run, - start_at=graph_runtime_state.start_at, - total_tokens=graph_runtime_state.total_tokens, - total_steps=graph_runtime_state.node_run_steps, - status=WorkflowRunStatus.FAILED, - error=event.error, - conversation_id=self._conversation.id, - trace_manager=trace_manager, - exceptions_count=event.exceptions_count, - ) - - yield self._workflow_finish_to_stream_response( - task_id=self._application_generate_entity.task_id, workflow_run=workflow_run - ) - - err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_run.error}")) - yield self._error_to_stream_response(self._handle_error(err_event, self._message)) - break - elif isinstance(event, QueueStopEvent): - if workflow_run and graph_runtime_state: - workflow_run = self._handle_workflow_run_failed( + with Session(db.engine) as session: + workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) + parallel_start_resp = self._workflow_parallel_branch_start_to_stream_response( + session=session, + task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, + event=event, + ) + + yield parallel_start_resp + elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent): + if not self._workflow_run_id: + raise ValueError("workflow run not initialized.") + + with Session(db.engine) as session: + workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) + parallel_finish_resp = self._workflow_parallel_branch_finished_to_stream_response( + session=session, + task_id=self._application_generate_entity.task_id, + workflow_run=workflow_run, + event=event, + ) + + yield parallel_finish_resp + elif isinstance(event, QueueIterationStartEvent): + if not self._workflow_run_id: + raise ValueError("workflow run not initialized.") + + with Session(db.engine) as session: + workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) + iter_start_resp = self._workflow_iteration_start_to_stream_response( + session=session, + task_id=self._application_generate_entity.task_id, + workflow_run=workflow_run, + event=event, + ) + + yield iter_start_resp + elif isinstance(event, QueueIterationNextEvent): + if not self._workflow_run_id: + raise ValueError("workflow run not initialized.") + + with Session(db.engine) as session: + workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) + iter_next_resp = self._workflow_iteration_next_to_stream_response( + session=session, + task_id=self._application_generate_entity.task_id, + workflow_run=workflow_run, + event=event, + ) + + yield iter_next_resp + elif isinstance(event, QueueIterationCompletedEvent): + if not self._workflow_run_id: + raise ValueError("workflow run not initialized.") + + with Session(db.engine) as session: + workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) + iter_finish_resp = self._workflow_iteration_completed_to_stream_response( + session=session, + task_id=self._application_generate_entity.task_id, + workflow_run=workflow_run, + event=event, + ) + + yield iter_finish_resp + elif isinstance(event, QueueWorkflowSucceededEvent): + if not self._workflow_run_id: + raise ValueError("workflow run not initialized.") + + if not graph_runtime_state: 
+ raise ValueError("workflow run not initialized.") + + with Session(db.engine) as session: + workflow_run = self._handle_workflow_run_success( + session=session, + workflow_run_id=self._workflow_run_id, start_at=graph_runtime_state.start_at, total_tokens=graph_runtime_state.total_tokens, total_steps=graph_runtime_state.node_run_steps, - status=WorkflowRunStatus.STOPPED, - error=event.get_stop_reason(), - conversation_id=self._conversation.id, + outputs=event.outputs, + conversation_id=self._conversation_id, trace_manager=trace_manager, ) - yield self._workflow_finish_to_stream_response( - task_id=self._application_generate_entity.task_id, workflow_run=workflow_run + workflow_finish_resp = self._workflow_finish_to_stream_response( + session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run ) + session.commit() - # Save message - self._save_message(graph_runtime_state=graph_runtime_state) + yield workflow_finish_resp + self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE) + elif isinstance(event, QueueWorkflowPartialSuccessEvent): + if not self._workflow_run_id: + raise ValueError("workflow run not initialized.") + if not graph_runtime_state: + raise ValueError("graph runtime state not initialized.") + + with Session(db.engine) as session: + workflow_run = self._handle_workflow_run_partial_success( + session=session, + workflow_run_id=self._workflow_run_id, + start_at=graph_runtime_state.start_at, + total_tokens=graph_runtime_state.total_tokens, + total_steps=graph_runtime_state.node_run_steps, + outputs=event.outputs, + exceptions_count=event.exceptions_count, + conversation_id=None, + trace_manager=trace_manager, + ) + workflow_finish_resp = self._workflow_finish_to_stream_response( + session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run + ) + session.commit() + + yield workflow_finish_resp + self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE) + elif isinstance(event, QueueWorkflowFailedEvent): + if not self._workflow_run_id: + raise ValueError("workflow run not initialized.") + if not graph_runtime_state: + raise ValueError("graph runtime state not initialized.") + + with Session(db.engine) as session: + workflow_run = self._handle_workflow_run_failed( + session=session, + workflow_run_id=self._workflow_run_id, + start_at=graph_runtime_state.start_at, + total_tokens=graph_runtime_state.total_tokens, + total_steps=graph_runtime_state.node_run_steps, + status=WorkflowRunStatus.FAILED, + error=event.error, + conversation_id=self._conversation_id, + trace_manager=trace_manager, + exceptions_count=event.exceptions_count, + ) + workflow_finish_resp = self._workflow_finish_to_stream_response( + session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run + ) + err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_run.error}")) + err = self._handle_error(event=err_event, session=session, message_id=self._message_id) + session.commit() + + yield workflow_finish_resp + yield self._error_to_stream_response(err) + break + elif isinstance(event, QueueStopEvent): + if self._workflow_run_id and graph_runtime_state: + with Session(db.engine) as session: + workflow_run = self._handle_workflow_run_failed( + session=session, + workflow_run_id=self._workflow_run_id, + start_at=graph_runtime_state.start_at, + total_tokens=graph_runtime_state.total_tokens, + total_steps=graph_runtime_state.node_run_steps, + 
status=WorkflowRunStatus.STOPPED, + error=event.get_stop_reason(), + conversation_id=self._conversation_id, + trace_manager=trace_manager, + ) + workflow_finish_resp = self._workflow_finish_to_stream_response( + session=session, + task_id=self._application_generate_entity.task_id, + workflow_run=workflow_run, + ) + # Save message + self._save_message(session=session, graph_runtime_state=graph_runtime_state) + session.commit() + + yield workflow_finish_resp yield self._message_end_to_stream_response() break elif isinstance(event, QueueRetrieverResourcesEvent): self._handle_retriever_resources(event) - self._refetch_message() - - self._message.message_metadata = ( - json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None - ) - - db.session.commit() - db.session.refresh(self._message) - db.session.close() + with Session(db.engine) as session: + message = self._get_message(session=session) + message.message_metadata = ( + json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None + ) + session.commit() elif isinstance(event, QueueAnnotationReplyEvent): self._handle_annotation_reply(event) - self._refetch_message() - - self._message.message_metadata = ( - json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None - ) - - db.session.commit() - db.session.refresh(self._message) - db.session.close() + with Session(db.engine) as session: + message = self._get_message(session=session) + message.message_metadata = ( + json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None + ) + session.commit() elif isinstance(event, QueueTextChunkEvent): delta_text = event.text if delta_text is None: @@ -523,7 +591,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc self._task_state.answer += delta_text yield self._message_to_stream_response( - answer=delta_text, message_id=self._message.id, from_variable_selector=event.from_variable_selector + answer=delta_text, message_id=self._message_id, from_variable_selector=event.from_variable_selector ) elif isinstance(event, QueueMessageReplaceEvent): # published by moderation @@ -538,7 +606,9 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc yield self._message_replace_to_stream_response(answer=output_moderation_answer) # Save message - self._save_message(graph_runtime_state=graph_runtime_state) + with Session(db.engine) as session: + self._save_message(session=session, graph_runtime_state=graph_runtime_state) + session.commit() yield self._message_end_to_stream_response() elif isinstance(event, QueueAgentLogEvent): @@ -553,54 +623,46 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc if self._conversation_name_generate_thread: self._conversation_name_generate_thread.join() - def _save_message(self, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None: - self._refetch_message() - - self._message.answer = self._task_state.answer - self._message.provider_response_latency = time.perf_counter() - self._start_at - self._message.message_metadata = ( + def _save_message(self, *, session: Session, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None: + message = self._get_message(session=session) + message.answer = self._task_state.answer + message.provider_response_latency = time.perf_counter() - self._start_at + message.message_metadata = ( json.dumps(jsonable_encoder(self._task_state.metadata)) if 
self._task_state.metadata else None ) message_files = [ MessageFile( - message_id=self._message.id, + message_id=message.id, type=file["type"], transfer_method=file["transfer_method"], url=file["remote_url"], belongs_to="assistant", upload_file_id=file["related_id"], created_by_role=CreatedByRole.ACCOUNT - if self._message.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} + if message.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else CreatedByRole.END_USER, - created_by=self._message.from_account_id or self._message.from_end_user_id or "", + created_by=message.from_account_id or message.from_end_user_id or "", ) for file in self._recorded_files ] - db.session.add_all(message_files) + session.add_all(message_files) if graph_runtime_state and graph_runtime_state.llm_usage: usage = graph_runtime_state.llm_usage - self._message.message_tokens = usage.prompt_tokens - self._message.message_unit_price = usage.prompt_unit_price - self._message.message_price_unit = usage.prompt_price_unit - self._message.answer_tokens = usage.completion_tokens - self._message.answer_unit_price = usage.completion_unit_price - self._message.answer_price_unit = usage.completion_price_unit - self._message.total_price = usage.total_price - self._message.currency = usage.currency - + message.message_tokens = usage.prompt_tokens + message.message_unit_price = usage.prompt_unit_price + message.message_price_unit = usage.prompt_price_unit + message.answer_tokens = usage.completion_tokens + message.answer_unit_price = usage.completion_unit_price + message.answer_price_unit = usage.completion_price_unit + message.total_price = usage.total_price + message.currency = usage.currency self._task_state.metadata["usage"] = jsonable_encoder(usage) else: self._task_state.metadata["usage"] = jsonable_encoder(LLMUsage.empty_usage()) - - db.session.commit() - message_was_created.send( - self._message, + message, application_generate_entity=self._application_generate_entity, - conversation=self._conversation, - is_first_message=self._application_generate_entity.conversation_id is None, - extras=self._application_generate_entity.extras, ) def _message_end_to_stream_response(self) -> MessageEndStreamResponse: @@ -617,7 +679,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc return MessageEndStreamResponse( task_id=self._application_generate_entity.task_id, - id=self._message.id, + id=self._message_id, files=self._recorded_files, metadata=extras.get("metadata", {}), ) @@ -645,11 +707,9 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc return False - def _refetch_message(self) -> None: - """ - Refetch message. 
- :return: - """ - message = db.session.query(Message).filter(Message.id == self._message.id).first() - if message: - self._message = message + def _get_message(self, *, session: Session): + stmt = select(Message).where(Message.id == self._message_id) + message = session.scalar(stmt) + if not message: + raise ValueError(f"Message not found: {self._message_id}") + return message diff --git a/api/core/app/apps/message_based_app_generator.py b/api/core/app/apps/message_based_app_generator.py index c2e35faf89..4e3aa840ce 100644 --- a/api/core/app/apps/message_based_app_generator.py +++ b/api/core/app/apps/message_based_app_generator.py @@ -70,14 +70,13 @@ class MessageBasedAppGenerator(BaseAppGenerator): queue_manager=queue_manager, conversation=conversation, message=message, - user=user, stream=stream, ) try: return generate_task_pipeline.process() except ValueError as e: - if e.args[0] == "I/O operation on closed file.": # ignore this error + if len(e.args) > 0 and e.args[0] == "I/O operation on closed file.": # ignore this error raise GenerateTaskStoppedError() else: logger.exception(f"Failed to handle response, conversation_id: {conversation.id}") diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index 9a5f90f998..cbfb535848 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -325,7 +325,7 @@ class WorkflowAppGenerator(BaseAppGenerator): try: return generate_task_pipeline.process() except ValueError as e: - if e.args[0] == "I/O operation on closed file.": # ignore this error + if len(e.args) > 0 and e.args[0] == "I/O operation on closed file.": # ignore this error raise GenerateTaskStoppedError() else: logger.exception( diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index af0698d701..2258747a2c 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -3,6 +3,8 @@ import time from collections.abc import Generator from typing import Any, Optional, Union +from sqlalchemy.orm import Session + from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk from core.app.apps.base_app_queue_manager import AppQueueManager @@ -51,6 +53,7 @@ from core.ops.ops_trace_manager import TraceQueueManager from core.workflow.enums import SystemVariableKey from extensions.ext_database import db from models.account import Account +from models.enums import CreatedByRole from models.model import EndUser from models.workflow import ( Workflow, @@ -69,8 +72,6 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa WorkflowAppGenerateTaskPipeline is a class that generate stream output and state management for Application. """ - _workflow: Workflow - _user: Union[Account, EndUser] _task_state: WorkflowTaskState _application_generate_entity: WorkflowAppGenerateEntity _workflow_system_variables: dict[SystemVariableKey, Any] @@ -84,25 +85,29 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa user: Union[Account, EndUser], stream: bool, ) -> None: - """ - Initialize GenerateTaskPipeline. 
- :param application_generate_entity: application generate entity - :param workflow: workflow - :param queue_manager: queue manager - :param user: user - :param stream: is streamed - """ - super().__init__(application_generate_entity, queue_manager, user, stream) + super().__init__( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + stream=stream, + ) - if isinstance(self._user, EndUser): - user_id = self._user.session_id + if isinstance(user, EndUser): + self._user_id = user.id + user_session_id = user.session_id + self._created_by_role = CreatedByRole.END_USER + elif isinstance(user, Account): + self._user_id = user.id + user_session_id = user.id + self._created_by_role = CreatedByRole.ACCOUNT else: - user_id = self._user.id + raise ValueError(f"Invalid user type: {type(user)}") + + self._workflow_id = workflow.id + self._workflow_features_dict = workflow.features_dict - self._workflow = workflow self._workflow_system_variables = { SystemVariableKey.FILES: application_generate_entity.files, - SystemVariableKey.USER_ID: user_id, + SystemVariableKey.USER_ID: user_session_id, SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id, SystemVariableKey.WORKFLOW_ID: workflow.id, SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id, @@ -118,10 +123,6 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa Process generate task pipeline. :return: """ - db.session.refresh(self._workflow) - db.session.refresh(self._user) - db.session.close() - generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager) if self._stream: return self._to_stream_response(generator) @@ -188,7 +189,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa tts_publisher = None task_id = self._application_generate_entity.task_id tenant_id = self._application_generate_entity.app_config.tenant_id - features_dict = self._workflow.features_dict + features_dict = self._workflow_features_dict if ( features_dict.get("text_to_speech") @@ -237,7 +238,6 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa :return: """ graph_runtime_state = None - workflow_run = None for queue_message in self._queue_manager.listen(): event = queue_message.event @@ -245,180 +245,261 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa if isinstance(event, QueuePingEvent): yield self._ping_stream_response() elif isinstance(event, QueueErrorEvent): - err = self._handle_error(event) + err = self._handle_error(event=event) yield self._error_to_stream_response(err) break elif isinstance(event, QueueWorkflowStartedEvent): # override graph runtime state graph_runtime_state = event.graph_runtime_state - # init workflow run - workflow_run = self._handle_workflow_run_start() - yield self._workflow_start_to_stream_response( - task_id=self._application_generate_entity.task_id, workflow_run=workflow_run - ) + with Session(db.engine) as session: + # init workflow run + workflow_run = self._handle_workflow_run_start( + session=session, + workflow_id=self._workflow_id, + user_id=self._user_id, + created_by_role=self._created_by_role, + ) + self._workflow_run_id = workflow_run.id + start_resp = self._workflow_start_to_stream_response( + session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run + ) + session.commit() + + yield start_resp elif isinstance( event, QueueNodeRetryEvent, ): 
- if not workflow_run: + if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - workflow_node_execution = self._handle_workflow_node_execution_retried( - workflow_run=workflow_run, event=event - ) - - response = self._workflow_node_retry_to_stream_response( - event=event, - task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution, - ) + with Session(db.engine) as session: + workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) + workflow_node_execution = self._handle_workflow_node_execution_retried( + session=session, workflow_run=workflow_run, event=event + ) + response = self._workflow_node_retry_to_stream_response( + session=session, + event=event, + task_id=self._application_generate_entity.task_id, + workflow_node_execution=workflow_node_execution, + ) + session.commit() if response: yield response elif isinstance(event, QueueNodeStartedEvent): - if not workflow_run: + if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - workflow_node_execution = self._handle_node_execution_start(workflow_run=workflow_run, event=event) - - node_start_response = self._workflow_node_start_to_stream_response( - event=event, - task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution, - ) + with Session(db.engine) as session: + workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) + workflow_node_execution = self._handle_node_execution_start( + session=session, workflow_run=workflow_run, event=event + ) + node_start_response = self._workflow_node_start_to_stream_response( + session=session, + event=event, + task_id=self._application_generate_entity.task_id, + workflow_node_execution=workflow_node_execution, + ) + session.commit() if node_start_response: yield node_start_response elif isinstance(event, QueueNodeSucceededEvent): - workflow_node_execution = self._handle_workflow_node_execution_success(event) - - node_success_response = self._workflow_node_finish_to_stream_response( - event=event, - task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution, - ) + with Session(db.engine) as session: + workflow_node_execution = self._handle_workflow_node_execution_success(session=session, event=event) + node_success_response = self._workflow_node_finish_to_stream_response( + session=session, + event=event, + task_id=self._application_generate_entity.task_id, + workflow_node_execution=workflow_node_execution, + ) + session.commit() if node_success_response: yield node_success_response elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent): - workflow_node_execution = self._handle_workflow_node_execution_failed(event) + with Session(db.engine) as session: + workflow_node_execution = self._handle_workflow_node_execution_failed( + session=session, + event=event, + ) + node_failed_response = self._workflow_node_finish_to_stream_response( + session=session, + event=event, + task_id=self._application_generate_entity.task_id, + workflow_node_execution=workflow_node_execution, + ) + session.commit() - node_failed_response = self._workflow_node_finish_to_stream_response( - event=event, - task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution, - ) if node_failed_response: yield node_failed_response elif isinstance(event, QueueParallelBranchRunStartedEvent): - if not 
workflow_run: + if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - yield self._workflow_parallel_branch_start_to_stream_response( - task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event - ) + with Session(db.engine) as session: + workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) + parallel_start_resp = self._workflow_parallel_branch_start_to_stream_response( + session=session, + task_id=self._application_generate_entity.task_id, + workflow_run=workflow_run, + event=event, + ) + + yield parallel_start_resp + elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent): - if not workflow_run: + if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - yield self._workflow_parallel_branch_finished_to_stream_response( - task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event - ) + with Session(db.engine) as session: + workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) + parallel_finish_resp = self._workflow_parallel_branch_finished_to_stream_response( + session=session, + task_id=self._application_generate_entity.task_id, + workflow_run=workflow_run, + event=event, + ) + + yield parallel_finish_resp + elif isinstance(event, QueueIterationStartEvent): - if not workflow_run: + if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - yield self._workflow_iteration_start_to_stream_response( - task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event - ) + with Session(db.engine) as session: + workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) + iter_start_resp = self._workflow_iteration_start_to_stream_response( + session=session, + task_id=self._application_generate_entity.task_id, + workflow_run=workflow_run, + event=event, + ) + + yield iter_start_resp + elif isinstance(event, QueueIterationNextEvent): - if not workflow_run: + if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - yield self._workflow_iteration_next_to_stream_response( - task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event - ) + with Session(db.engine) as session: + workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) + iter_next_resp = self._workflow_iteration_next_to_stream_response( + session=session, + task_id=self._application_generate_entity.task_id, + workflow_run=workflow_run, + event=event, + ) + + yield iter_next_resp + elif isinstance(event, QueueIterationCompletedEvent): - if not workflow_run: + if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - yield self._workflow_iteration_completed_to_stream_response( - task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event - ) + with Session(db.engine) as session: + workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) + iter_finish_resp = self._workflow_iteration_completed_to_stream_response( + session=session, + task_id=self._application_generate_entity.task_id, + workflow_run=workflow_run, + event=event, + ) + + yield iter_finish_resp + elif isinstance(event, QueueWorkflowSucceededEvent): - if not workflow_run: + if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - if not graph_runtime_state: raise 
ValueError("graph runtime state not initialized.") - workflow_run = self._handle_workflow_run_success( - workflow_run=workflow_run, - start_at=graph_runtime_state.start_at, - total_tokens=graph_runtime_state.total_tokens, - total_steps=graph_runtime_state.node_run_steps, - outputs=event.outputs, - conversation_id=None, - trace_manager=trace_manager, - ) + with Session(db.engine) as session: + workflow_run = self._handle_workflow_run_success( + session=session, + workflow_run_id=self._workflow_run_id, + start_at=graph_runtime_state.start_at, + total_tokens=graph_runtime_state.total_tokens, + total_steps=graph_runtime_state.node_run_steps, + outputs=event.outputs, + conversation_id=None, + trace_manager=trace_manager, + ) - # save workflow app log - self._save_workflow_app_log(workflow_run) + # save workflow app log + self._save_workflow_app_log(session=session, workflow_run=workflow_run) - yield self._workflow_finish_to_stream_response( - task_id=self._application_generate_entity.task_id, workflow_run=workflow_run - ) + workflow_finish_resp = self._workflow_finish_to_stream_response( + session=session, + task_id=self._application_generate_entity.task_id, + workflow_run=workflow_run, + ) + session.commit() + + yield workflow_finish_resp elif isinstance(event, QueueWorkflowPartialSuccessEvent): - if not workflow_run: + if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - if not graph_runtime_state: raise ValueError("graph runtime state not initialized.") - workflow_run = self._handle_workflow_run_partial_success( - workflow_run=workflow_run, - start_at=graph_runtime_state.start_at, - total_tokens=graph_runtime_state.total_tokens, - total_steps=graph_runtime_state.node_run_steps, - outputs=event.outputs, - exceptions_count=event.exceptions_count, - conversation_id=None, - trace_manager=trace_manager, - ) + with Session(db.engine) as session: + workflow_run = self._handle_workflow_run_partial_success( + session=session, + workflow_run_id=self._workflow_run_id, + start_at=graph_runtime_state.start_at, + total_tokens=graph_runtime_state.total_tokens, + total_steps=graph_runtime_state.node_run_steps, + outputs=event.outputs, + exceptions_count=event.exceptions_count, + conversation_id=None, + trace_manager=trace_manager, + ) - # save workflow app log - self._save_workflow_app_log(workflow_run) + # save workflow app log + self._save_workflow_app_log(session=session, workflow_run=workflow_run) - yield self._workflow_finish_to_stream_response( - task_id=self._application_generate_entity.task_id, workflow_run=workflow_run - ) + workflow_finish_resp = self._workflow_finish_to_stream_response( + session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run + ) + session.commit() + + yield workflow_finish_resp elif isinstance(event, QueueWorkflowFailedEvent | QueueStopEvent): - if not workflow_run: + if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - if not graph_runtime_state: raise ValueError("graph runtime state not initialized.") - workflow_run = self._handle_workflow_run_failed( - workflow_run=workflow_run, - start_at=graph_runtime_state.start_at, - total_tokens=graph_runtime_state.total_tokens, - total_steps=graph_runtime_state.node_run_steps, - status=WorkflowRunStatus.FAILED - if isinstance(event, QueueWorkflowFailedEvent) - else WorkflowRunStatus.STOPPED, - error=event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(), - conversation_id=None, - trace_manager=trace_manager, 
- exceptions_count=event.exceptions_count if isinstance(event, QueueWorkflowFailedEvent) else 0, - ) - # save workflow app log - self._save_workflow_app_log(workflow_run) + with Session(db.engine) as session: + workflow_run = self._handle_workflow_run_failed( + session=session, + workflow_run_id=self._workflow_run_id, + start_at=graph_runtime_state.start_at, + total_tokens=graph_runtime_state.total_tokens, + total_steps=graph_runtime_state.node_run_steps, + status=WorkflowRunStatus.FAILED + if isinstance(event, QueueWorkflowFailedEvent) + else WorkflowRunStatus.STOPPED, + error=event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(), + conversation_id=None, + trace_manager=trace_manager, + exceptions_count=event.exceptions_count if isinstance(event, QueueWorkflowFailedEvent) else 0, + ) - yield self._workflow_finish_to_stream_response( - task_id=self._application_generate_entity.task_id, workflow_run=workflow_run - ) + # save workflow app log + self._save_workflow_app_log(session=session, workflow_run=workflow_run) + + workflow_finish_resp = self._workflow_finish_to_stream_response( + session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run + ) + session.commit() + + yield workflow_finish_resp elif isinstance(event, QueueTextChunkEvent): delta_text = event.text if delta_text is None: @@ -440,7 +521,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa if tts_publisher: tts_publisher.publish(None) - def _save_workflow_app_log(self, workflow_run: WorkflowRun) -> None: + def _save_workflow_app_log(self, *, session: Session, workflow_run: WorkflowRun) -> None: """ Save workflow app log. :return: @@ -462,12 +543,10 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa workflow_app_log.workflow_id = workflow_run.workflow_id workflow_app_log.workflow_run_id = workflow_run.id workflow_app_log.created_from = created_from.value - workflow_app_log.created_by_role = "account" if isinstance(self._user, Account) else "end_user" - workflow_app_log.created_by = self._user.id + workflow_app_log.created_by_role = self._created_by_role + workflow_app_log.created_by = self._user_id - db.session.add(workflow_app_log) - db.session.commit() - db.session.close() + session.add(workflow_app_log) def _text_chunk_to_stream_response( self, text: str, from_variable_selector: Optional[list[str]] = None diff --git a/api/core/app/task_pipeline/based_generate_task_pipeline.py b/api/core/app/task_pipeline/based_generate_task_pipeline.py index 03a81353d0..e363a7f642 100644 --- a/api/core/app/task_pipeline/based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/based_generate_task_pipeline.py @@ -1,6 +1,9 @@ import logging import time -from typing import Optional, Union +from typing import Optional + +from sqlalchemy import select +from sqlalchemy.orm import Session from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.entities.app_invoke_entities import ( @@ -17,9 +20,7 @@ from core.app.entities.task_entities import ( from core.errors.error import QuotaExceededError from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.moderation.output_moderation import ModerationRule, OutputModeration -from extensions.ext_database import db -from models.account import Account -from models.model import EndUser, Message +from models.model import Message logger = logging.getLogger(__name__) @@ -36,7 +37,6 @@ class BasedGenerateTaskPipeline: 
self, application_generate_entity: AppGenerateEntity, queue_manager: AppQueueManager, - user: Union[Account, EndUser], stream: bool, ) -> None: """ @@ -48,18 +48,11 @@ class BasedGenerateTaskPipeline: """ self._application_generate_entity = application_generate_entity self._queue_manager = queue_manager - self._user = user self._start_at = time.perf_counter() self._output_moderation_handler = self._init_output_moderation() self._stream = stream - def _handle_error(self, event: QueueErrorEvent, message: Optional[Message] = None): - """ - Handle error event. - :param event: event - :param message: message - :return: - """ + def _handle_error(self, *, event: QueueErrorEvent, session: Session | None = None, message_id: str = ""): logger.debug("error: %s", event.error) e = event.error err: Exception @@ -71,16 +64,17 @@ class BasedGenerateTaskPipeline: else: err = Exception(e.description if getattr(e, "description", None) is not None else str(e)) - if message: - refetch_message = db.session.query(Message).filter(Message.id == message.id).first() + if not message_id or not session: + return err - if refetch_message: - err_desc = self._error_to_desc(err) - refetch_message.status = "error" - refetch_message.error = err_desc - - db.session.commit() + stmt = select(Message).where(Message.id == message_id) + message = session.scalar(stmt) + if not message: + return err + err_desc = self._error_to_desc(err) + message.status = "error" + message.error = err_desc return err def _error_to_desc(self, e: Exception) -> str: diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py index b9f8e7ca56..c84f8ba3e4 100644 --- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py @@ -5,6 +5,9 @@ from collections.abc import Generator from threading import Thread from typing import Optional, Union, cast +from sqlalchemy import select +from sqlalchemy.orm import Session + from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom @@ -55,8 +58,7 @@ from core.prompt.utils.prompt_message_util import PromptMessageUtil from core.prompt.utils.prompt_template_parser import PromptTemplateParser from events.message_event import message_was_created from extensions.ext_database import db -from models.account import Account -from models.model import AppMode, Conversation, EndUser, Message, MessageAgentThought +from models.model import AppMode, Conversation, Message, MessageAgentThought logger = logging.getLogger(__name__) @@ -77,23 +79,21 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan queue_manager: AppQueueManager, conversation: Conversation, message: Message, - user: Union[Account, EndUser], stream: bool, ) -> None: - """ - Initialize GenerateTaskPipeline. 
- :param application_generate_entity: application generate entity - :param queue_manager: queue manager - :param conversation: conversation - :param message: message - :param user: user - :param stream: stream - """ - super().__init__(application_generate_entity, queue_manager, user, stream) + super().__init__( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + stream=stream, + ) self._model_config = application_generate_entity.model_conf self._app_config = application_generate_entity.app_config - self._conversation = conversation - self._message = message + + self._conversation_id = conversation.id + self._conversation_mode = conversation.mode + + self._message_id = message.id + self._message_created_at = int(message.created_at.timestamp()) self._task_state = EasyUITaskState( llm_result=LLMResult( @@ -113,18 +113,10 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan CompletionAppBlockingResponse, Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None], ]: - """ - Process generate task pipeline. - :return: - """ - db.session.refresh(self._conversation) - db.session.refresh(self._message) - db.session.close() - if self._application_generate_entity.app_config.app_mode != AppMode.COMPLETION: # start generate conversation name thread self._conversation_name_generate_thread = self._generate_conversation_name( - self._conversation, self._application_generate_entity.query or "" + conversation_id=self._conversation_id, query=self._application_generate_entity.query or "" ) generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager) @@ -148,15 +140,15 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan if self._task_state.metadata: extras["metadata"] = self._task_state.metadata response: Union[ChatbotAppBlockingResponse, CompletionAppBlockingResponse] - if self._conversation.mode == AppMode.COMPLETION.value: + if self._conversation_mode == AppMode.COMPLETION.value: response = CompletionAppBlockingResponse( task_id=self._application_generate_entity.task_id, data=CompletionAppBlockingResponse.Data( - id=self._message.id, - mode=self._conversation.mode, - message_id=self._message.id, + id=self._message_id, + mode=self._conversation_mode, + message_id=self._message_id, answer=cast(str, self._task_state.llm_result.message.content), - created_at=int(self._message.created_at.timestamp()), + created_at=self._message_created_at, **extras, ), ) @@ -164,12 +156,12 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan response = ChatbotAppBlockingResponse( task_id=self._application_generate_entity.task_id, data=ChatbotAppBlockingResponse.Data( - id=self._message.id, - mode=self._conversation.mode, - conversation_id=self._conversation.id, - message_id=self._message.id, + id=self._message_id, + mode=self._conversation_mode, + conversation_id=self._conversation_id, + message_id=self._message_id, answer=cast(str, self._task_state.llm_result.message.content), - created_at=int(self._message.created_at.timestamp()), + created_at=self._message_created_at, **extras, ), ) @@ -190,15 +182,15 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan for stream_response in generator: if isinstance(self._application_generate_entity, CompletionAppGenerateEntity): yield CompletionAppStreamResponse( - message_id=self._message.id, - created_at=int(self._message.created_at.timestamp()), + 
message_id=self._message_id, + created_at=self._message_created_at, stream_response=stream_response, ) else: yield ChatbotAppStreamResponse( - conversation_id=self._conversation.id, - message_id=self._message.id, - created_at=int(self._message.created_at.timestamp()), + conversation_id=self._conversation_id, + message_id=self._message_id, + created_at=self._message_created_at, stream_response=stream_response, ) @@ -265,7 +257,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan event = message.event if isinstance(event, QueueErrorEvent): - err = self._handle_error(event, self._message) + with Session(db.engine) as session: + err = self._handle_error(event=event, session=session, message_id=self._message_id) + session.commit() yield self._error_to_stream_response(err) break elif isinstance(event, QueueStopEvent | QueueMessageEndEvent): @@ -283,10 +277,12 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan self._task_state.llm_result.message.content = output_moderation_answer yield self._message_replace_to_stream_response(answer=output_moderation_answer) - # Save message - self._save_message(trace_manager) - - yield self._message_end_to_stream_response() + with Session(db.engine) as session: + # Save message + self._save_message(session=session, trace_manager=trace_manager) + session.commit() + message_end_resp = self._message_end_to_stream_response() + yield message_end_resp elif isinstance(event, QueueRetrieverResourcesEvent): self._handle_retriever_resources(event) elif isinstance(event, QueueAnnotationReplyEvent): @@ -320,9 +316,15 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan self._task_state.llm_result.message.content = current_content if isinstance(event, QueueLLMChunkEvent): - yield self._message_to_stream_response(cast(str, delta_text), self._message.id) + yield self._message_to_stream_response( + answer=cast(str, delta_text), + message_id=self._message_id, + ) else: - yield self._agent_message_to_stream_response(cast(str, delta_text), self._message.id) + yield self._agent_message_to_stream_response( + answer=cast(str, delta_text), + message_id=self._message_id, + ) elif isinstance(event, QueueMessageReplaceEvent): yield self._message_replace_to_stream_response(answer=event.text) elif isinstance(event, QueuePingEvent): @@ -334,7 +336,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan if self._conversation_name_generate_thread: self._conversation_name_generate_thread.join() - def _save_message(self, trace_manager: Optional[TraceQueueManager] = None) -> None: + def _save_message(self, *, session: Session, trace_manager: Optional[TraceQueueManager] = None) -> None: """ Save message. 
:return: @@ -342,53 +344,46 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan llm_result = self._task_state.llm_result usage = llm_result.usage - message = db.session.query(Message).filter(Message.id == self._message.id).first() + message_stmt = select(Message).where(Message.id == self._message_id) + message = session.scalar(message_stmt) if not message: - raise Exception(f"Message {self._message.id} not found") - self._message = message - conversation = db.session.query(Conversation).filter(Conversation.id == self._conversation.id).first() + raise ValueError(f"message {self._message_id} not found") + conversation_stmt = select(Conversation).where(Conversation.id == self._conversation_id) + conversation = session.scalar(conversation_stmt) if not conversation: - raise Exception(f"Conversation {self._conversation.id} not found") - self._conversation = conversation + raise ValueError(f"Conversation {self._conversation_id} not found") - self._message.message = PromptMessageUtil.prompt_messages_to_prompt_for_saving( + message.message = PromptMessageUtil.prompt_messages_to_prompt_for_saving( self._model_config.mode, self._task_state.llm_result.prompt_messages ) - self._message.message_tokens = usage.prompt_tokens - self._message.message_unit_price = usage.prompt_unit_price - self._message.message_price_unit = usage.prompt_price_unit - self._message.answer = ( + message.message_tokens = usage.prompt_tokens + message.message_unit_price = usage.prompt_unit_price + message.message_price_unit = usage.prompt_price_unit + message.answer = ( PromptTemplateParser.remove_template_variables(cast(str, llm_result.message.content).strip()) if llm_result.message.content else "" ) - self._message.answer_tokens = usage.completion_tokens - self._message.answer_unit_price = usage.completion_unit_price - self._message.answer_price_unit = usage.completion_price_unit - self._message.provider_response_latency = time.perf_counter() - self._start_at - self._message.total_price = usage.total_price - self._message.currency = usage.currency - self._message.message_metadata = ( + message.answer_tokens = usage.completion_tokens + message.answer_unit_price = usage.completion_unit_price + message.answer_price_unit = usage.completion_price_unit + message.provider_response_latency = time.perf_counter() - self._start_at + message.total_price = usage.total_price + message.currency = usage.currency + message.message_metadata = ( json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None ) - db.session.commit() - if trace_manager: trace_manager.add_trace_task( TraceTask( - TraceTaskName.MESSAGE_TRACE, conversation_id=self._conversation.id, message_id=self._message.id + TraceTaskName.MESSAGE_TRACE, conversation_id=self._conversation_id, message_id=self._message_id ) ) message_was_created.send( - self._message, + message, application_generate_entity=self._application_generate_entity, - conversation=self._conversation, - is_first_message=self._application_generate_entity.app_config.app_mode in {AppMode.AGENT_CHAT, AppMode.CHAT} - and hasattr(self._application_generate_entity, "conversation_id") - and self._application_generate_entity.conversation_id is None, - extras=self._application_generate_entity.extras, ) def _handle_stop(self, event: QueueStopEvent) -> None: @@ -434,7 +429,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan return MessageEndStreamResponse( task_id=self._application_generate_entity.task_id, - 
id=self._message.id, + id=self._message_id, metadata=extras.get("metadata", {}), ) diff --git a/api/core/app/task_pipeline/message_cycle_manage.py b/api/core/app/task_pipeline/message_cycle_manage.py index 007543f6d0..15f2c25c66 100644 --- a/api/core/app/task_pipeline/message_cycle_manage.py +++ b/api/core/app/task_pipeline/message_cycle_manage.py @@ -36,7 +36,7 @@ class MessageCycleManage: ] _task_state: Union[EasyUITaskState, WorkflowTaskState] - def _generate_conversation_name(self, conversation: Conversation, query: str) -> Optional[Thread]: + def _generate_conversation_name(self, *, conversation_id: str, query: str) -> Optional[Thread]: """ Generate conversation name. :param conversation: conversation @@ -56,7 +56,7 @@ class MessageCycleManage: target=self._generate_conversation_name_worker, kwargs={ "flask_app": current_app._get_current_object(), # type: ignore - "conversation_id": conversation.id, + "conversation_id": conversation_id, "query": query, }, ) diff --git a/api/core/app/task_pipeline/workflow_cycle_manage.py b/api/core/app/task_pipeline/workflow_cycle_manage.py index 115ef6ca53..8840737245 100644 --- a/api/core/app/task_pipeline/workflow_cycle_manage.py +++ b/api/core/app/task_pipeline/workflow_cycle_manage.py @@ -5,6 +5,7 @@ from datetime import UTC, datetime from typing import Any, Optional, Union, cast from uuid import uuid4 +from sqlalchemy import func, select from sqlalchemy.orm import Session from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity @@ -47,7 +48,6 @@ from core.workflow.enums import SystemVariableKey from core.workflow.nodes import NodeType from core.workflow.nodes.tool.entities import ToolNodeData from core.workflow.workflow_entry import WorkflowEntry -from extensions.ext_database import db from models.account import Account from models.enums import CreatedByRole, WorkflowRunTriggeredFrom from models.model import EndUser @@ -65,28 +65,35 @@ from .exc import WorkflowNodeExecutionNotFoundError, WorkflowRunNotFoundError class WorkflowCycleManage: _application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity] - _workflow: Workflow - _user: Union[Account, EndUser] _task_state: WorkflowTaskState _workflow_system_variables: dict[SystemVariableKey, Any] _wip_workflow_node_executions: dict[str, WorkflowNodeExecution] _wip_workflow_agent_logs: dict[str, list[AgentLogStreamResponse.Data]] - def _handle_workflow_run_start(self) -> WorkflowRun: - max_sequence = ( - db.session.query(db.func.max(WorkflowRun.sequence_number)) - .filter(WorkflowRun.tenant_id == self._workflow.tenant_id) - .filter(WorkflowRun.app_id == self._workflow.app_id) - .scalar() - or 0 + def _handle_workflow_run_start( + self, + *, + session: Session, + workflow_id: str, + user_id: str, + created_by_role: CreatedByRole, + ) -> WorkflowRun: + workflow_stmt = select(Workflow).where(Workflow.id == workflow_id) + workflow = session.scalar(workflow_stmt) + if not workflow: + raise ValueError(f"Workflow not found: {workflow_id}") + + max_sequence_stmt = select(func.max(WorkflowRun.sequence_number)).where( + WorkflowRun.tenant_id == workflow.tenant_id, + WorkflowRun.app_id == workflow.app_id, ) + max_sequence = session.scalar(max_sequence_stmt) or 0 new_sequence_number = max_sequence + 1 inputs = {**self._application_generate_entity.inputs} for key, value in (self._workflow_system_variables or {}).items(): if key.value == "conversation": continue - inputs[f"sys.{key.value}"] = value triggered_from = ( @@ 
-99,34 +106,33 @@ class WorkflowCycleManage: inputs = dict(WorkflowEntry.handle_special_values(inputs) or {}) # init workflow run - with Session(db.engine, expire_on_commit=False) as session: - workflow_run = WorkflowRun() - system_id = self._workflow_system_variables[SystemVariableKey.WORKFLOW_RUN_ID] - workflow_run.id = system_id or str(uuid4()) - workflow_run.tenant_id = self._workflow.tenant_id - workflow_run.app_id = self._workflow.app_id - workflow_run.sequence_number = new_sequence_number - workflow_run.workflow_id = self._workflow.id - workflow_run.type = self._workflow.type - workflow_run.triggered_from = triggered_from.value - workflow_run.version = self._workflow.version - workflow_run.graph = self._workflow.graph - workflow_run.inputs = json.dumps(inputs) - workflow_run.status = WorkflowRunStatus.RUNNING - workflow_run.created_by_role = ( - CreatedByRole.ACCOUNT if isinstance(self._user, Account) else CreatedByRole.END_USER - ) - workflow_run.created_by = self._user.id - workflow_run.created_at = datetime.now(UTC).replace(tzinfo=None) + workflow_run_id = str(self._workflow_system_variables.get(SystemVariableKey.WORKFLOW_RUN_ID, uuid4())) - session.add(workflow_run) - session.commit() + workflow_run = WorkflowRun() + workflow_run.id = workflow_run_id + workflow_run.tenant_id = workflow.tenant_id + workflow_run.app_id = workflow.app_id + workflow_run.sequence_number = new_sequence_number + workflow_run.workflow_id = workflow.id + workflow_run.type = workflow.type + workflow_run.triggered_from = triggered_from.value + workflow_run.version = workflow.version + workflow_run.graph = workflow.graph + workflow_run.inputs = json.dumps(inputs) + workflow_run.status = WorkflowRunStatus.RUNNING + workflow_run.created_by_role = created_by_role + workflow_run.created_by = user_id + workflow_run.created_at = datetime.now(UTC).replace(tzinfo=None) + + session.add(workflow_run) return workflow_run def _handle_workflow_run_success( self, - workflow_run: WorkflowRun, + *, + session: Session, + workflow_run_id: str, start_at: float, total_tokens: int, total_steps: int, @@ -144,7 +150,7 @@ class WorkflowCycleManage: :param conversation_id: conversation id :return: """ - workflow_run = self._refetch_workflow_run(workflow_run.id) + workflow_run = self._get_workflow_run(session=session, workflow_run_id=workflow_run_id) outputs = WorkflowEntry.handle_special_values(outputs) @@ -155,9 +161,6 @@ class WorkflowCycleManage: workflow_run.total_steps = total_steps workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None) - db.session.commit() - db.session.refresh(workflow_run) - if trace_manager: trace_manager.add_trace_task( TraceTask( @@ -168,13 +171,13 @@ class WorkflowCycleManage: ) ) - db.session.close() - return workflow_run def _handle_workflow_run_partial_success( self, - workflow_run: WorkflowRun, + *, + session: Session, + workflow_run_id: str, start_at: float, total_tokens: int, total_steps: int, @@ -183,18 +186,7 @@ class WorkflowCycleManage: conversation_id: Optional[str] = None, trace_manager: Optional[TraceQueueManager] = None, ) -> WorkflowRun: - """ - Workflow run success - :param workflow_run: workflow run - :param start_at: start time - :param total_tokens: total tokens - :param total_steps: total steps - :param outputs: outputs - :param conversation_id: conversation id - :return: - """ - workflow_run = self._refetch_workflow_run(workflow_run.id) - + workflow_run = self._get_workflow_run(session=session, workflow_run_id=workflow_run_id) outputs = 
WorkflowEntry.handle_special_values(dict(outputs) if outputs else None) workflow_run.status = WorkflowRunStatus.PARTIAL_SUCCESSED.value @@ -204,8 +196,6 @@ class WorkflowCycleManage: workflow_run.total_steps = total_steps workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None) workflow_run.exceptions_count = exceptions_count - db.session.commit() - db.session.refresh(workflow_run) if trace_manager: trace_manager.add_trace_task( @@ -217,13 +207,13 @@ class WorkflowCycleManage: ) ) - db.session.close() - return workflow_run def _handle_workflow_run_failed( self, - workflow_run: WorkflowRun, + *, + session: Session, + workflow_run_id: str, start_at: float, total_tokens: int, total_steps: int, @@ -243,7 +233,7 @@ class WorkflowCycleManage: :param error: error message :return: """ - workflow_run = self._refetch_workflow_run(workflow_run.id) + workflow_run = self._get_workflow_run(session=session, workflow_run_id=workflow_run_id) workflow_run.status = status.value workflow_run.error = error @@ -252,21 +242,18 @@ class WorkflowCycleManage: workflow_run.total_steps = total_steps workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None) workflow_run.exceptions_count = exceptions_count - db.session.commit() - running_workflow_node_executions = ( - db.session.query(WorkflowNodeExecution) - .filter( - WorkflowNodeExecution.tenant_id == workflow_run.tenant_id, - WorkflowNodeExecution.app_id == workflow_run.app_id, - WorkflowNodeExecution.workflow_id == workflow_run.workflow_id, - WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, - WorkflowNodeExecution.workflow_run_id == workflow_run.id, - WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value, - ) - .all() + stmt = select(WorkflowNodeExecution).where( + WorkflowNodeExecution.tenant_id == workflow_run.tenant_id, + WorkflowNodeExecution.app_id == workflow_run.app_id, + WorkflowNodeExecution.workflow_id == workflow_run.workflow_id, + WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, + WorkflowNodeExecution.workflow_run_id == workflow_run.id, + WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value, ) + running_workflow_node_executions = session.scalars(stmt).all() + for workflow_node_execution in running_workflow_node_executions: workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value workflow_node_execution.error = error @@ -274,13 +261,6 @@ class WorkflowCycleManage: workflow_node_execution.elapsed_time = ( workflow_node_execution.finished_at - workflow_node_execution.created_at ).total_seconds() - db.session.commit() - - db.session.close() - - # with Session(db.engine, expire_on_commit=False) as session: - # session.add(workflow_run) - # session.refresh(workflow_run) if trace_manager: trace_manager.add_trace_task( @@ -295,49 +275,41 @@ class WorkflowCycleManage: return workflow_run def _handle_node_execution_start( - self, workflow_run: WorkflowRun, event: QueueNodeStartedEvent + self, *, session: Session, workflow_run: WorkflowRun, event: QueueNodeStartedEvent ) -> WorkflowNodeExecution: - # init workflow node execution + workflow_node_execution = WorkflowNodeExecution() + workflow_node_execution.id = event.node_execution_id + workflow_node_execution.tenant_id = workflow_run.tenant_id + workflow_node_execution.app_id = workflow_run.app_id + workflow_node_execution.workflow_id = workflow_run.workflow_id + workflow_node_execution.triggered_from = 
WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value + workflow_node_execution.workflow_run_id = workflow_run.id + workflow_node_execution.predecessor_node_id = event.predecessor_node_id + workflow_node_execution.index = event.node_run_index + workflow_node_execution.node_execution_id = event.node_execution_id + workflow_node_execution.node_id = event.node_id + workflow_node_execution.node_type = event.node_type.value + workflow_node_execution.title = event.node_data.title + workflow_node_execution.status = WorkflowNodeExecutionStatus.RUNNING.value + workflow_node_execution.created_by_role = workflow_run.created_by_role + workflow_node_execution.created_by = workflow_run.created_by + workflow_node_execution.execution_metadata = json.dumps( + { + NodeRunMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id, + NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id, + } + ) + workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None) - with Session(db.engine, expire_on_commit=False) as session: - workflow_node_execution = WorkflowNodeExecution() - workflow_node_execution.tenant_id = workflow_run.tenant_id - workflow_node_execution.app_id = workflow_run.app_id - workflow_node_execution.workflow_id = workflow_run.workflow_id - workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value - workflow_node_execution.workflow_run_id = workflow_run.id - workflow_node_execution.predecessor_node_id = event.predecessor_node_id - workflow_node_execution.index = event.node_run_index - workflow_node_execution.node_execution_id = event.node_execution_id - workflow_node_execution.node_id = event.node_id - workflow_node_execution.node_type = event.node_type.value - workflow_node_execution.title = event.node_data.title - workflow_node_execution.status = WorkflowNodeExecutionStatus.RUNNING.value - workflow_node_execution.created_by_role = workflow_run.created_by_role - workflow_node_execution.created_by = workflow_run.created_by - workflow_node_execution.execution_metadata = json.dumps( - { - NodeRunMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id, - NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id, - } - ) - workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None) - - session.add(workflow_node_execution) - session.commit() - session.refresh(workflow_node_execution) - - self._wip_workflow_node_executions[workflow_node_execution.node_execution_id] = workflow_node_execution + session.add(workflow_node_execution) return workflow_node_execution - def _handle_workflow_node_execution_success(self, event: QueueNodeSucceededEvent) -> WorkflowNodeExecution: - """ - Workflow node execution success - :param event: queue node succeeded event - :return: - """ - workflow_node_execution = self._refetch_workflow_node_execution(event.node_execution_id) - + def _handle_workflow_node_execution_success( + self, *, session: Session, event: QueueNodeSucceededEvent + ) -> WorkflowNodeExecution: + workflow_node_execution = self._get_workflow_node_execution( + session=session, node_execution_id=event.node_execution_id + ) inputs = WorkflowEntry.handle_special_values(event.inputs) process_data = WorkflowEntry.handle_special_values(event.process_data) outputs = WorkflowEntry.handle_special_values(event.outputs) @@ -378,19 +350,22 @@ class WorkflowCycleManage: workflow_node_execution.finished_at = finished_at workflow_node_execution.elapsed_time = elapsed_time - self._wip_workflow_node_executions.pop(workflow_node_execution.node_execution_id) - 
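# A minimal sketch (the caller wrapper below is hypothetical, not from this changeset) of the
# session-handling pattern the refactor above moves to: the _handle_* helpers no longer open
# sessions, commit, or maintain the _wip_workflow_node_executions cache; they only session.add()
# new rows or look rows up through the Session handed to them, and the caller commits once per event.

from sqlalchemy.orm import Session

from extensions.ext_database import db


def _on_node_succeeded(self, *, task_id: str, event) -> None:
    # sketch of a WorkflowCycleManage method; `event` would be a QueueNodeSucceededEvent
    with Session(db.engine, expire_on_commit=False) as session:
        workflow_node_execution = self._handle_workflow_node_execution_success(
            session=session, event=event
        )
        response = self._workflow_node_finish_to_stream_response(
            session=session,
            event=event,
            task_id=task_id,
            workflow_node_execution=workflow_node_execution,
        )
        session.commit()  # one commit replaces the removed per-helper db.session.commit() calls
    # expire_on_commit=False keeps workflow_node_execution readable after the session closes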
return workflow_node_execution def _handle_workflow_node_execution_failed( - self, event: QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent + self, + *, + session: Session, + event: QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent, ) -> WorkflowNodeExecution: """ Workflow node execution failed :param event: queue node failed event :return: """ - workflow_node_execution = self._refetch_workflow_node_execution(event.node_execution_id) + workflow_node_execution = self._get_workflow_node_execution( + session=session, node_execution_id=event.node_execution_id + ) inputs = WorkflowEntry.handle_special_values(event.inputs) process_data = WorkflowEntry.handle_special_values(event.process_data) @@ -440,12 +415,10 @@ class WorkflowCycleManage: workflow_node_execution.elapsed_time = elapsed_time workflow_node_execution.execution_metadata = execution_metadata - self._wip_workflow_node_executions.pop(workflow_node_execution.node_execution_id) - return workflow_node_execution def _handle_workflow_node_execution_retried( - self, workflow_run: WorkflowRun, event: QueueNodeRetryEvent + self, *, session: Session, workflow_run: WorkflowRun, event: QueueNodeRetryEvent ) -> WorkflowNodeExecution: """ Workflow node execution failed @@ -469,6 +442,7 @@ class WorkflowCycleManage: execution_metadata = json.dumps(merged_metadata) workflow_node_execution = WorkflowNodeExecution() + workflow_node_execution.id = event.node_execution_id workflow_node_execution.tenant_id = workflow_run.tenant_id workflow_node_execution.app_id = workflow_run.app_id workflow_node_execution.workflow_id = workflow_run.workflow_id @@ -491,10 +465,7 @@ class WorkflowCycleManage: workflow_node_execution.execution_metadata = execution_metadata workflow_node_execution.index = event.node_run_index - db.session.add(workflow_node_execution) - db.session.commit() - db.session.refresh(workflow_node_execution) - + session.add(workflow_node_execution) return workflow_node_execution ################################################# @@ -502,14 +473,14 @@ class WorkflowCycleManage: ################################################# def _workflow_start_to_stream_response( - self, task_id: str, workflow_run: WorkflowRun + self, + *, + session: Session, + task_id: str, + workflow_run: WorkflowRun, ) -> WorkflowStartStreamResponse: - """ - Workflow start to stream response. - :param task_id: task id - :param workflow_run: workflow run - :return: - """ + # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this + _ = session return WorkflowStartStreamResponse( task_id=task_id, workflow_run_id=workflow_run.id, @@ -523,36 +494,32 @@ class WorkflowCycleManage: ) def _workflow_finish_to_stream_response( - self, task_id: str, workflow_run: WorkflowRun + self, + *, + session: Session, + task_id: str, + workflow_run: WorkflowRun, ) -> WorkflowFinishStreamResponse: - """ - Workflow finish to stream response. - :param task_id: task id - :param workflow_run: workflow run - :return: - """ - # Attach WorkflowRun to an active session so "created_by_role" can be accessed. 
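# The stream-response builders in this class accept a Session they may not touch (`_ = session`)
# because, with SQLAlchemy's default expire_on_commit behaviour, an ORM row whose session has
# committed and closed holds only expired attributes, and reading them raises DetachedInstanceError.
# A small reproduction sketch, assuming WorkflowRun is re-exported from `models` (import path assumed):

from sqlalchemy.orm import Session

from extensions.ext_database import db
from models import WorkflowRun


def _detached_read(workflow_run_id: str) -> str:
    with Session(db.engine) as session:  # expire_on_commit defaults to True
        workflow_run = session.get(WorkflowRun, workflow_run_id)
        assert workflow_run is not None
        session.commit()  # expires every attribute loaded in this session
    # the session is gone, so the expired attribute cannot be refreshed
    return workflow_run.status  # raises sqlalchemy.orm.exc.DetachedInstanceError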
- workflow_run = db.session.merge(workflow_run) - - # Refresh to ensure any expired attributes are fully loaded - db.session.refresh(workflow_run) - created_by = None - if workflow_run.created_by_role == CreatedByRole.ACCOUNT.value: - created_by_account = workflow_run.created_by_account - if created_by_account: + if workflow_run.created_by_role == CreatedByRole.ACCOUNT: + stmt = select(Account).where(Account.id == workflow_run.created_by) + account = session.scalar(stmt) + if account: created_by = { - "id": created_by_account.id, - "name": created_by_account.name, - "email": created_by_account.email, + "id": account.id, + "name": account.name, + "email": account.email, + } + elif workflow_run.created_by_role == CreatedByRole.END_USER: + stmt = select(EndUser).where(EndUser.id == workflow_run.created_by) + end_user = session.scalar(stmt) + if end_user: + created_by = { + "id": end_user.id, + "user": end_user.session_id, } else: - created_by_end_user = workflow_run.created_by_end_user - if created_by_end_user: - created_by = { - "id": created_by_end_user.id, - "user": created_by_end_user.session_id, - } + raise NotImplementedError(f"unknown created_by_role: {workflow_run.created_by_role}") return WorkflowFinishStreamResponse( task_id=task_id, @@ -576,17 +543,20 @@ class WorkflowCycleManage: ) def _workflow_node_start_to_stream_response( - self, event: QueueNodeStartedEvent, task_id: str, workflow_node_execution: WorkflowNodeExecution + self, + *, + session: Session, + event: QueueNodeStartedEvent, + task_id: str, + workflow_node_execution: WorkflowNodeExecution, ) -> Optional[NodeStartStreamResponse]: - """ - Workflow node start to stream response. - :param event: queue node started event - :param task_id: task id - :param workflow_node_execution: workflow node execution - :return: - """ + # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this + _ = session + if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}: return None + if not workflow_node_execution.workflow_run_id: + return None response = NodeStartStreamResponse( task_id=task_id, @@ -622,6 +592,8 @@ class WorkflowCycleManage: def _workflow_node_finish_to_stream_response( self, + *, + session: Session, event: QueueNodeSucceededEvent | QueueNodeFailedEvent | QueueNodeInIterationFailedEvent @@ -629,15 +601,14 @@ class WorkflowCycleManage: task_id: str, workflow_node_execution: WorkflowNodeExecution, ) -> Optional[NodeFinishStreamResponse]: - """ - Workflow node finish to stream response. - :param event: queue node succeeded or failed event - :param task_id: task id - :param workflow_node_execution: workflow node execution - :return: - """ + # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this + _ = session if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}: return None + if not workflow_node_execution.workflow_run_id: + return None + if not workflow_node_execution.finished_at: + return None return NodeFinishStreamResponse( task_id=task_id, @@ -669,19 +640,20 @@ class WorkflowCycleManage: def _workflow_node_retry_to_stream_response( self, + *, + session: Session, event: QueueNodeRetryEvent, task_id: str, workflow_node_execution: WorkflowNodeExecution, ) -> Optional[Union[NodeRetryStreamResponse, NodeFinishStreamResponse]]: - """ - Workflow node finish to stream response. 
- :param event: queue node succeeded or failed event - :param task_id: task id - :param workflow_node_execution: workflow node execution - :return: - """ + # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this + _ = session if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}: return None + if not workflow_node_execution.workflow_run_id: + return None + if not workflow_node_execution.finished_at: + return None return NodeRetryStreamResponse( task_id=task_id, @@ -713,15 +685,10 @@ class WorkflowCycleManage: ) def _workflow_parallel_branch_start_to_stream_response( - self, task_id: str, workflow_run: WorkflowRun, event: QueueParallelBranchRunStartedEvent + self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueParallelBranchRunStartedEvent ) -> ParallelBranchStartStreamResponse: - """ - Workflow parallel branch start to stream response - :param task_id: task id - :param workflow_run: workflow run - :param event: parallel branch run started event - :return: - """ + # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this + _ = session return ParallelBranchStartStreamResponse( task_id=task_id, workflow_run_id=workflow_run.id, @@ -737,17 +704,14 @@ class WorkflowCycleManage: def _workflow_parallel_branch_finished_to_stream_response( self, + *, + session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent, ) -> ParallelBranchFinishedStreamResponse: - """ - Workflow parallel branch finished to stream response - :param task_id: task id - :param workflow_run: workflow run - :param event: parallel branch run succeeded or failed event - :return: - """ + # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this + _ = session return ParallelBranchFinishedStreamResponse( task_id=task_id, workflow_run_id=workflow_run.id, @@ -764,15 +728,10 @@ class WorkflowCycleManage: ) def _workflow_iteration_start_to_stream_response( - self, task_id: str, workflow_run: WorkflowRun, event: QueueIterationStartEvent + self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueIterationStartEvent ) -> IterationNodeStartStreamResponse: - """ - Workflow iteration start to stream response - :param task_id: task id - :param workflow_run: workflow run - :param event: iteration start event - :return: - """ + # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this + _ = session return IterationNodeStartStreamResponse( task_id=task_id, workflow_run_id=workflow_run.id, @@ -791,15 +750,10 @@ class WorkflowCycleManage: ) def _workflow_iteration_next_to_stream_response( - self, task_id: str, workflow_run: WorkflowRun, event: QueueIterationNextEvent + self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueIterationNextEvent ) -> IterationNodeNextStreamResponse: - """ - Workflow iteration next to stream response - :param task_id: task id - :param workflow_run: workflow run - :param event: iteration next event - :return: - """ + # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this + _ = session return IterationNodeNextStreamResponse( task_id=task_id, workflow_run_id=workflow_run.id, @@ -820,15 +774,10 @@ class WorkflowCycleManage: ) def _workflow_iteration_completed_to_stream_response( - self, task_id: str, 
workflow_run: WorkflowRun, event: QueueIterationCompletedEvent + self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueIterationCompletedEvent ) -> IterationNodeCompletedStreamResponse: - """ - Workflow iteration completed to stream response - :param task_id: task id - :param workflow_run: workflow run - :param event: iteration completed event - :return: - """ + # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this + _ = session return IterationNodeCompletedStreamResponse( task_id=task_id, workflow_run_id=workflow_run.id, @@ -912,27 +861,22 @@ class WorkflowCycleManage: return None - def _refetch_workflow_run(self, workflow_run_id: str) -> WorkflowRun: + def _get_workflow_run(self, *, session: Session, workflow_run_id: str) -> WorkflowRun: """ Refetch workflow run :param workflow_run_id: workflow run id :return: """ - workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run_id).first() - + stmt = select(WorkflowRun).where(WorkflowRun.id == workflow_run_id) + workflow_run = session.scalar(stmt) if not workflow_run: raise WorkflowRunNotFoundError(workflow_run_id) return workflow_run - def _refetch_workflow_node_execution(self, node_execution_id: str) -> WorkflowNodeExecution: - """ - Refetch workflow node execution - :param node_execution_id: workflow node execution id - :return: - """ - workflow_node_execution = self._wip_workflow_node_executions.get(node_execution_id) - + def _get_workflow_node_execution(self, session: Session, node_execution_id: str) -> WorkflowNodeExecution: + stmt = select(WorkflowNodeExecution).where(WorkflowNodeExecution.id == node_execution_id) + workflow_node_execution = session.scalar(stmt) if not workflow_node_execution: raise WorkflowNodeExecutionNotFoundError(node_execution_id) diff --git a/api/core/entities/knowledge_entities.py b/api/core/entities/knowledge_entities.py new file mode 100644 index 0000000000..90c9879733 --- /dev/null +++ b/api/core/entities/knowledge_entities.py @@ -0,0 +1,19 @@ +from typing import Optional + +from pydantic import BaseModel + + +class PreviewDetail(BaseModel): + content: str + child_chunks: Optional[list[str]] = None + + +class QAPreviewDetail(BaseModel): + question: str + answer: str + + +class IndexingEstimate(BaseModel): + total_segments: int + preview: list[PreviewDetail] + qa_preview: Optional[list[QAPreviewDetail]] = None diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index eed2d7e49a..0261a6309e 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -881,7 +881,7 @@ class ProviderConfiguration(BaseModel): # if llm name not in restricted llm list, remove it restrict_model_names = [rm.model for rm in restrict_models] for model in provider_models: - if model.model_type == ModelType.LLM and m.model not in restrict_model_names: + if model.model_type == ModelType.LLM and model.model not in restrict_model_names: model.status = ModelStatus.NO_PERMISSION elif not quota_configuration.is_valid: model.status = ModelStatus.QUOTA_EXCEEDED diff --git a/api/core/helper/ssrf_proxy.py b/api/core/helper/ssrf_proxy.py index 2e422cf444..6e11c706ff 100644 --- a/api/core/helper/ssrf_proxy.py +++ b/api/core/helper/ssrf_proxy.py @@ -70,7 +70,8 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): retries += 1 if retries <= max_retries: time.sleep(BACKOFF_FACTOR * (2 ** (retries - 1))) - raise 
MaxRetriesExceededError(f"Reached maximum retries ({max_retries}) for URL {url}") + raise MaxRetriesExceededError( + f"Reached maximum retries ({max_retries}) for URL {url}") def get(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index 685dbc8ed4..937ce1cc5e 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -8,22 +8,23 @@ import time import uuid from typing import Any, Optional, cast -from flask import Flask, current_app +from flask import current_app from flask_login import current_user # type: ignore from sqlalchemy.orm.exc import ObjectDeletedError from configs import dify_config +from core.entities.knowledge_entities import IndexingEstimate, PreviewDetail, QAPreviewDetail from core.errors.error import ProviderTokenNotInitError -from core.llm_generator.llm_generator import LLMGenerator from core.model_manager import ModelInstance, ModelManager from core.model_runtime.entities.model_entities import ModelType from core.rag.cleaner.clean_processor import CleanProcessor from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.docstore.dataset_docstore import DatasetDocumentStore from core.rag.extractor.entity.extract_setting import ExtractSetting +from core.rag.index_processor.constant.index_type import IndexType from core.rag.index_processor.index_processor_base import BaseIndexProcessor from core.rag.index_processor.index_processor_factory import IndexProcessorFactory -from core.rag.models.document import Document +from core.rag.models.document import ChildDocument, Document from core.rag.splitter.fixed_text_splitter import ( EnhanceRecursiveCharacterTextSplitter, FixedRecursiveCharacterTextSplitter, @@ -35,7 +36,7 @@ from extensions.ext_database import db from extensions.ext_redis import redis_client from extensions.ext_storage import storage from libs import helper -from models.dataset import Dataset, DatasetProcessRule, DocumentSegment +from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment from models.dataset import Document as DatasetDocument from models.model import UploadFile from services.feature_service import FeatureService @@ -115,6 +116,9 @@ class IndexingRunner: for document_segment in document_segments: db.session.delete(document_segment) + if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: + # delete child chunks + db.session.query(ChildChunk).filter(ChildChunk.segment_id == document_segment.id).delete() db.session.commit() # get the process rule processing_rule = ( @@ -183,7 +187,22 @@ class IndexingRunner: "dataset_id": document_segment.dataset_id, }, ) - + if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: + child_chunks = document_segment.child_chunks + if child_chunks: + child_documents = [] + for child_chunk in child_chunks: + child_document = ChildDocument( + page_content=child_chunk.content, + metadata={ + "doc_id": child_chunk.index_node_id, + "doc_hash": child_chunk.index_node_hash, + "document_id": document_segment.document_id, + "dataset_id": document_segment.dataset_id, + }, + ) + child_documents.append(child_document) + document.children = child_documents documents.append(document) # build index @@ -222,7 +241,7 @@ class IndexingRunner: doc_language: str = "English", dataset_id: Optional[str] = None, indexing_technique: str = "economy", - ) -> dict: + ) -> IndexingEstimate: """ Estimate the indexing for the document. 
""" @@ -258,31 +277,38 @@ class IndexingRunner: tenant_id=tenant_id, model_type=ModelType.TEXT_EMBEDDING, ) - preview_texts: list[str] = [] + preview_texts = [] # type: ignore + total_segments = 0 index_type = doc_form index_processor = IndexProcessorFactory(index_type).init_index_processor() - all_text_docs = [] for extract_setting in extract_settings: # extract - text_docs = index_processor.extract(extract_setting, process_rule_mode=tmp_processing_rule["mode"]) - all_text_docs.extend(text_docs) processing_rule = DatasetProcessRule( mode=tmp_processing_rule["mode"], rules=json.dumps(tmp_processing_rule["rules"]) ) - - # get splitter - splitter = self._get_splitter(processing_rule, embedding_model_instance) - - # split to documents - documents = self._split_to_documents_for_estimate( - text_docs=text_docs, splitter=splitter, processing_rule=processing_rule + text_docs = index_processor.extract(extract_setting, process_rule_mode=tmp_processing_rule["mode"]) + documents = index_processor.transform( + text_docs, + embedding_model_instance=embedding_model_instance, + process_rule=processing_rule.to_dict(), + tenant_id=current_user.current_tenant_id, + doc_language=doc_language, + preview=True, ) - total_segments += len(documents) for document in documents: - if len(preview_texts) < 5: - preview_texts.append(document.page_content) + if len(preview_texts) < 10: + if doc_form and doc_form == "qa_model": + preview_detail = QAPreviewDetail( + question=document.page_content, answer=document.metadata.get("answer") or "" + ) + preview_texts.append(preview_detail) + else: + preview_detail = PreviewDetail(content=document.page_content) # type: ignore + if document.children: + preview_detail.child_chunks = [child.page_content for child in document.children] # type: ignore + preview_texts.append(preview_detail) # delete image files and related db records image_upload_file_ids = get_image_upload_file_ids(document.page_content) @@ -299,15 +325,8 @@ class IndexingRunner: db.session.delete(image_file) if doc_form and doc_form == "qa_model": - if len(preview_texts) > 0: - # qa model document - response = LLMGenerator.generate_qa_document( - current_user.current_tenant_id, preview_texts[0], doc_language - ) - document_qa_list = self.format_split_text(response) - - return {"total_segments": total_segments * 20, "qa_preview": document_qa_list, "preview": preview_texts} - return {"total_segments": total_segments, "preview": preview_texts} + return IndexingEstimate(total_segments=total_segments * 20, qa_preview=preview_texts, preview=[]) + return IndexingEstimate(total_segments=total_segments, preview=preview_texts) # type: ignore def _extract( self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict @@ -401,31 +420,26 @@ class IndexingRunner: @staticmethod def _get_splitter( - processing_rule: DatasetProcessRule, embedding_model_instance: Optional[ModelInstance] + processing_rule_mode: str, + max_tokens: int, + chunk_overlap: int, + separator: str, + embedding_model_instance: Optional[ModelInstance], ) -> TextSplitter: """ Get the NodeParser object according to the processing rule. 
""" - character_splitter: TextSplitter - if processing_rule.mode == "custom": + if processing_rule_mode in ["custom", "hierarchical"]: # The user-defined segmentation rule - rules = json.loads(processing_rule.rules) - segmentation = rules["segmentation"] max_segmentation_tokens_length = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH - if segmentation["max_tokens"] < 50 or segmentation["max_tokens"] > max_segmentation_tokens_length: + if max_tokens < 50 or max_tokens > max_segmentation_tokens_length: raise ValueError(f"Custom segment length should be between 50 and {max_segmentation_tokens_length}.") - separator = segmentation["separator"] if separator: separator = separator.replace("\\n", "\n") - if segmentation.get("chunk_overlap"): - chunk_overlap = segmentation["chunk_overlap"] - else: - chunk_overlap = 0 - character_splitter = FixedRecursiveCharacterTextSplitter.from_encoder( - chunk_size=segmentation["max_tokens"], + chunk_size=max_tokens, chunk_overlap=chunk_overlap, fixed_separator=separator, separators=["\n\n", "。", ". ", " ", ""], @@ -441,143 +455,7 @@ class IndexingRunner: embedding_model_instance=embedding_model_instance, ) - return character_splitter - - def _step_split( - self, - text_docs: list[Document], - splitter: TextSplitter, - dataset: Dataset, - dataset_document: DatasetDocument, - processing_rule: DatasetProcessRule, - ) -> list[Document]: - """ - Split the text documents into documents and save them to the document segment. - """ - documents = self._split_to_documents( - text_docs=text_docs, - splitter=splitter, - processing_rule=processing_rule, - tenant_id=dataset.tenant_id, - document_form=dataset_document.doc_form, - document_language=dataset_document.doc_language, - ) - - # save node to document segment - doc_store = DatasetDocumentStore( - dataset=dataset, user_id=dataset_document.created_by, document_id=dataset_document.id - ) - - # add document segments - doc_store.add_documents(documents) - - # update document status to indexing - cur_time = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) - self._update_document_index_status( - document_id=dataset_document.id, - after_indexing_status="indexing", - extra_update_params={ - DatasetDocument.cleaning_completed_at: cur_time, - DatasetDocument.splitting_completed_at: cur_time, - }, - ) - - # update segment status to indexing - self._update_segments_by_document( - dataset_document_id=dataset_document.id, - update_params={ - DocumentSegment.status: "indexing", - DocumentSegment.indexing_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None), - }, - ) - - return documents - - def _split_to_documents( - self, - text_docs: list[Document], - splitter: TextSplitter, - processing_rule: DatasetProcessRule, - tenant_id: str, - document_form: str, - document_language: str, - ) -> list[Document]: - """ - Split the text documents into nodes. 
- """ - all_documents: list[Document] = [] - all_qa_documents: list[Document] = [] - for text_doc in text_docs: - # document clean - document_text = self._document_clean(text_doc.page_content, processing_rule) - text_doc.page_content = document_text - - # parse document to nodes - documents = splitter.split_documents([text_doc]) - split_documents = [] - for document_node in documents: - if document_node.page_content.strip(): - if document_node.metadata is not None: - doc_id = str(uuid.uuid4()) - hash = helper.generate_text_hash(document_node.page_content) - document_node.metadata["doc_id"] = doc_id - document_node.metadata["doc_hash"] = hash - # delete Splitter character - page_content = document_node.page_content - document_node.page_content = remove_leading_symbols(page_content) - - if document_node.page_content: - split_documents.append(document_node) - all_documents.extend(split_documents) - # processing qa document - if document_form == "qa_model": - for i in range(0, len(all_documents), 10): - threads = [] - sub_documents = all_documents[i : i + 10] - for doc in sub_documents: - document_format_thread = threading.Thread( - target=self.format_qa_document, - kwargs={ - "flask_app": current_app._get_current_object(), # type: ignore - "tenant_id": tenant_id, - "document_node": doc, - "all_qa_documents": all_qa_documents, - "document_language": document_language, - }, - ) - threads.append(document_format_thread) - document_format_thread.start() - for thread in threads: - thread.join() - return all_qa_documents - return all_documents - - def format_qa_document(self, flask_app: Flask, tenant_id: str, document_node, all_qa_documents, document_language): - format_documents = [] - if document_node.page_content is None or not document_node.page_content.strip(): - return - with flask_app.app_context(): - try: - # qa model document - response = LLMGenerator.generate_qa_document(tenant_id, document_node.page_content, document_language) - document_qa_list = self.format_split_text(response) - qa_documents = [] - for result in document_qa_list: - qa_document = Document( - page_content=result["question"], metadata=document_node.metadata.model_copy() - ) - if qa_document.metadata is not None: - doc_id = str(uuid.uuid4()) - hash = helper.generate_text_hash(result["question"]) - qa_document.metadata["answer"] = result["answer"] - qa_document.metadata["doc_id"] = doc_id - qa_document.metadata["doc_hash"] = hash - qa_documents.append(qa_document) - format_documents.extend(qa_documents) - except Exception as e: - logging.exception("Failed to format qa document") - - all_qa_documents.extend(format_documents) + return character_splitter # type: ignore def _split_to_documents_for_estimate( self, text_docs: list[Document], splitter: TextSplitter, processing_rule: DatasetProcessRule @@ -624,11 +502,11 @@ class IndexingRunner: return document_text @staticmethod - def format_split_text(text): + def format_split_text(text: str) -> list[QAPreviewDetail]: regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q\d+:|$)" matches = re.findall(regex, text, re.UNICODE) - return [{"question": q, "answer": re.sub(r"\n\s*", "\n", a.strip())} for q, a in matches if q and a] + return [QAPreviewDetail(question=q, answer=re.sub(r"\n\s*", "\n", a.strip())) for q, a in matches if q and a] def _load( self, @@ -654,13 +532,14 @@ class IndexingRunner: indexing_start_at = time.perf_counter() tokens = 0 chunk_size = 10 + if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX: + # create keyword index + create_keyword_thread = 
threading.Thread( + target=self._process_keyword_index, + args=(current_app._get_current_object(), dataset.id, dataset_document.id, documents), # type: ignore + ) + create_keyword_thread.start() - # create keyword index - create_keyword_thread = threading.Thread( - target=self._process_keyword_index, - args=(current_app._get_current_object(), dataset.id, dataset_document.id, documents), # type: ignore - ) - create_keyword_thread.start() if dataset.indexing_technique == "high_quality": with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor: futures = [] @@ -680,8 +559,8 @@ class IndexingRunner: for future in futures: tokens += future.result() - - create_keyword_thread.join() + if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX: + create_keyword_thread.join() indexing_end_at = time.perf_counter() # update document status to completed @@ -791,28 +670,6 @@ class IndexingRunner: DocumentSegment.query.filter_by(document_id=dataset_document_id).update(update_params) db.session.commit() - @staticmethod - def batch_add_segments(segments: list[DocumentSegment], dataset: Dataset): - """ - Batch add segments index processing - """ - documents = [] - for segment in segments: - document = Document( - page_content=segment.content, - metadata={ - "doc_id": segment.index_node_id, - "doc_hash": segment.index_node_hash, - "document_id": segment.document_id, - "dataset_id": segment.dataset_id, - }, - ) - documents.append(document) - # save vector index - index_type = dataset.doc_form - index_processor = IndexProcessorFactory(index_type).init_index_processor() - index_processor.load(dataset, documents) - def _transform( self, index_processor: BaseIndexProcessor, @@ -854,7 +711,7 @@ class IndexingRunner: ) # add document segments - doc_store.add_documents(documents) + doc_store.add_documents(docs=documents, save_child=dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX) # update document status to indexing cur_time = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) diff --git a/api/core/model_runtime/model_providers/azure_openai/llm/llm.py b/api/core/model_runtime/model_providers/azure_openai/llm/llm.py new file mode 100644 index 0000000000..03818741f6 --- /dev/null +++ b/api/core/model_runtime/model_providers/azure_openai/llm/llm.py @@ -0,0 +1,770 @@ +import copy +import json +import logging +from collections.abc import Generator, Sequence +from typing import Optional, Union, cast + +import tiktoken +from openai import AzureOpenAI, Stream +from openai.types import Completion +from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessageToolCall +from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall + +from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + ImagePromptMessageContent, + PromptMessage, + PromptMessageContentType, + PromptMessageFunction, + PromptMessageTool, + SystemPromptMessage, + TextPromptMessageContent, + ToolPromptMessage, + UserPromptMessage, +) +from core.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from core.model_runtime.model_providers.azure_openai._common import _CommonAzureOpenAI +from core.model_runtime.model_providers.azure_openai._constant import 
LLM_BASE_MODELS +from core.model_runtime.utils import helper + +logger = logging.getLogger(__name__) + + +class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: + base_model_name = self._get_base_model_name(credentials) + ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model) + + if ai_model_entity and ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value: + # chat model + return self._chat_generate( + model=model, + credentials=credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=tools, + stop=stop, + stream=stream, + user=user, + ) + else: + # text completion model + return self._generate( + model=model, + credentials=credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + stop=stop, + stream=stream, + user=user, + ) + + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: + base_model_name = self._get_base_model_name(credentials) + model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model) + if not model_entity: + raise ValueError(f"Base Model Name {base_model_name} is invalid") + model_mode = model_entity.entity.model_properties.get(ModelPropertyKey.MODE) + + if model_mode == LLMMode.CHAT.value: + # chat model + return self._num_tokens_from_messages(credentials, prompt_messages, tools) + else: + # text completion model, do not support tool calling + content = prompt_messages[0].content + assert isinstance(content, str) + return self._num_tokens_from_string(credentials, content) + + def validate_credentials(self, model: str, credentials: dict) -> None: + if "openai_api_base" not in credentials: + raise CredentialsValidateFailedError("Azure OpenAI API Base Endpoint is required") + + if "openai_api_key" not in credentials: + raise CredentialsValidateFailedError("Azure OpenAI API key is required") + + if "base_model_name" not in credentials: + raise CredentialsValidateFailedError("Base Model Name is required") + + base_model_name = self._get_base_model_name(credentials) + ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model) + + if not ai_model_entity: + raise CredentialsValidateFailedError(f'Base Model Name {credentials["base_model_name"]} is invalid') + + try: + client = AzureOpenAI(**self._to_credential_kwargs(credentials)) + + if model.startswith("o1"): + client.chat.completions.create( + messages=[{"role": "user", "content": "ping"}], + model=model, + temperature=1, + max_completion_tokens=20, + stream=False, + ) + elif ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value: + # chat model + client.chat.completions.create( + messages=[{"role": "user", "content": "ping"}], + model=model, + temperature=0, + max_tokens=20, + stream=False, + ) + else: + # text completion model + client.completions.create( + prompt="ping", + model=model, + temperature=0, + max_tokens=20, + stream=False, + ) + except Exception as ex: + raise CredentialsValidateFailedError(str(ex)) + + def get_customizable_model_schema(self, model: str, 
credentials: dict) -> Optional[AIModelEntity]: + base_model_name = self._get_base_model_name(credentials) + ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model) + return ai_model_entity.entity if ai_model_entity else None + + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: + client = AzureOpenAI(**self._to_credential_kwargs(credentials)) + + extra_model_kwargs = {} + + if stop: + extra_model_kwargs["stop"] = stop + + if user: + extra_model_kwargs["user"] = user + + # text completion model + response = client.completions.create( + prompt=prompt_messages[0].content, model=model, stream=stream, **model_parameters, **extra_model_kwargs + ) + + if stream: + return self._handle_generate_stream_response(model, credentials, response, prompt_messages) + + return self._handle_generate_response(model, credentials, response, prompt_messages) + + def _handle_generate_response( + self, model: str, credentials: dict, response: Completion, prompt_messages: list[PromptMessage] + ): + assistant_text = response.choices[0].text + + # transform assistant message to prompt message + assistant_prompt_message = AssistantPromptMessage(content=assistant_text) + + # calculate num tokens + if response.usage: + # transform usage + prompt_tokens = response.usage.prompt_tokens + completion_tokens = response.usage.completion_tokens + else: + # calculate num tokens + content = prompt_messages[0].content + assert isinstance(content, str) + prompt_tokens = self._num_tokens_from_string(credentials, content) + completion_tokens = self._num_tokens_from_string(credentials, assistant_text) + + # transform usage + usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) + + # transform response + result = LLMResult( + model=response.model, + prompt_messages=prompt_messages, + message=assistant_prompt_message, + usage=usage, + system_fingerprint=response.system_fingerprint, + ) + + return result + + def _handle_generate_stream_response( + self, model: str, credentials: dict, response: Stream[Completion], prompt_messages: list[PromptMessage] + ) -> Generator: + full_text = "" + for chunk in response: + if len(chunk.choices) == 0: + continue + + delta = chunk.choices[0] + + if delta.finish_reason is None and (delta.text is None or delta.text == ""): + continue + + # transform assistant message to prompt message + text = delta.text or "" + assistant_prompt_message = AssistantPromptMessage(content=text) + + full_text += text + + if delta.finish_reason is not None: + # calculate num tokens + if chunk.usage: + # transform usage + prompt_tokens = chunk.usage.prompt_tokens + completion_tokens = chunk.usage.completion_tokens + else: + # calculate num tokens + content = prompt_messages[0].content + assert isinstance(content, str) + prompt_tokens = self._num_tokens_from_string(credentials, content) + completion_tokens = self._num_tokens_from_string(credentials, full_text) + + # transform usage + usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) + + yield LLMResultChunk( + model=chunk.model, + prompt_messages=prompt_messages, + system_fingerprint=chunk.system_fingerprint, + delta=LLMResultChunkDelta( + index=delta.index, + message=assistant_prompt_message, + finish_reason=delta.finish_reason, + usage=usage, + ), + ) + else: + yield 
LLMResultChunk( + model=chunk.model, + prompt_messages=prompt_messages, + system_fingerprint=chunk.system_fingerprint, + delta=LLMResultChunkDelta( + index=delta.index, + message=assistant_prompt_message, + ), + ) + + def _chat_generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: + client = AzureOpenAI(**self._to_credential_kwargs(credentials)) + + response_format = model_parameters.get("response_format") + if response_format: + if response_format == "json_schema": + json_schema = model_parameters.get("json_schema") + if not json_schema: + raise ValueError("Must define JSON Schema when the response format is json_schema") + try: + schema = json.loads(json_schema) + except: + raise ValueError(f"not correct json_schema format: {json_schema}") + model_parameters.pop("json_schema") + model_parameters["response_format"] = {"type": "json_schema", "json_schema": schema} + else: + model_parameters["response_format"] = {"type": response_format} + + extra_model_kwargs = {} + + if tools: + extra_model_kwargs["tools"] = [helper.dump_model(PromptMessageFunction(function=tool)) for tool in tools] + + if stop: + extra_model_kwargs["stop"] = stop + + if user: + extra_model_kwargs["user"] = user + + # clear illegal prompt messages + prompt_messages = self._clear_illegal_prompt_messages(model, prompt_messages) + + block_as_stream = False + if model.startswith("o1"): + if "max_tokens" in model_parameters: + model_parameters["max_completion_tokens"] = model_parameters["max_tokens"] + del model_parameters["max_tokens"] + if stream: + block_as_stream = True + stream = False + + if "stream_options" in extra_model_kwargs: + del extra_model_kwargs["stream_options"] + + if "stop" in extra_model_kwargs: + del extra_model_kwargs["stop"] + + # chat model + response = client.chat.completions.create( + messages=[self._convert_prompt_message_to_dict(m) for m in prompt_messages], + model=model, + stream=stream, + **model_parameters, + **extra_model_kwargs, + ) + + if stream: + return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages, tools) + + block_result = self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools) + + if block_as_stream: + return self._handle_chat_block_as_stream_response(block_result, prompt_messages, stop) + + return block_result + + def _handle_chat_block_as_stream_response( + self, + block_result: LLMResult, + prompt_messages: list[PromptMessage], + stop: Optional[list[str]] = None, + ) -> Generator[LLMResultChunk, None, None]: + """ + Handle llm chat response + + :param model: model name + :param credentials: credentials + :param response: response + :param prompt_messages: prompt messages + :param tools: tools for tool calling + :param stop: stop words + :return: llm response chunk generator + """ + text = block_result.message.content + text = cast(str, text) + + if stop: + text = self.enforce_stop_tokens(text, stop) + + yield LLMResultChunk( + model=block_result.model, + prompt_messages=prompt_messages, + system_fingerprint=block_result.system_fingerprint, + delta=LLMResultChunkDelta( + index=0, + message=AssistantPromptMessage(content=text), + finish_reason="stop", + usage=block_result.usage, + ), + ) + + def _clear_illegal_prompt_messages(self, model: str, prompt_messages: 
list[PromptMessage]) -> list[PromptMessage]: + """ + Clear illegal prompt messages for OpenAI API + + :param model: model name + :param prompt_messages: prompt messages + :return: cleaned prompt messages + """ + checklist = ["gpt-4-turbo", "gpt-4-turbo-2024-04-09"] + + if model in checklist: + # count how many user messages are there + user_message_count = len([m for m in prompt_messages if isinstance(m, UserPromptMessage)]) + if user_message_count > 1: + for prompt_message in prompt_messages: + if isinstance(prompt_message, UserPromptMessage): + if isinstance(prompt_message.content, list): + prompt_message.content = "\n".join( + [ + item.data + if item.type == PromptMessageContentType.TEXT + else "[IMAGE]" + if item.type == PromptMessageContentType.IMAGE + else "" + for item in prompt_message.content + ] + ) + + if model.startswith("o1"): + system_message_count = len([m for m in prompt_messages if isinstance(m, SystemPromptMessage)]) + if system_message_count > 0: + new_prompt_messages = [] + for prompt_message in prompt_messages: + if isinstance(prompt_message, SystemPromptMessage): + prompt_message = UserPromptMessage( + content=prompt_message.content, + name=prompt_message.name, + ) + + new_prompt_messages.append(prompt_message) + prompt_messages = new_prompt_messages + + return prompt_messages + + def _handle_chat_generate_response( + self, + model: str, + credentials: dict, + response: ChatCompletion, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ): + assistant_message = response.choices[0].message + assistant_message_tool_calls = assistant_message.tool_calls + + # extract tool calls from response + tool_calls = [] + self._update_tool_calls(tool_calls=tool_calls, tool_calls_response=assistant_message_tool_calls) + + # transform assistant message to prompt message + assistant_prompt_message = AssistantPromptMessage(content=assistant_message.content, tool_calls=tool_calls) + + # calculate num tokens + if response.usage: + # transform usage + prompt_tokens = response.usage.prompt_tokens + completion_tokens = response.usage.completion_tokens + else: + # calculate num tokens + prompt_tokens = self._num_tokens_from_messages(credentials, prompt_messages, tools) + completion_tokens = self._num_tokens_from_messages(credentials, [assistant_prompt_message]) + + # transform usage + usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) + + # transform response + result = LLMResult( + model=response.model or model, + prompt_messages=prompt_messages, + message=assistant_prompt_message, + usage=usage, + system_fingerprint=response.system_fingerprint, + ) + + return result + + def _handle_chat_generate_stream_response( + self, + model: str, + credentials: dict, + response: Stream[ChatCompletionChunk], + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ): + index = 0 + full_assistant_content = "" + real_model = model + system_fingerprint = None + completion = "" + tool_calls = [] + for chunk in response: + if len(chunk.choices) == 0: + continue + + delta = chunk.choices[0] + # NOTE: For fix https://github.com/langgenius/dify/issues/5790 + if delta.delta is None: + continue + + # extract tool calls from response + self._update_tool_calls(tool_calls=tool_calls, tool_calls_response=delta.delta.tool_calls) + + # Handling exceptions when content filters' streaming mode is set to asynchronous modified filter + if delta.finish_reason is None and not delta.delta.content: + continue + + # 
transform assistant message to prompt message + assistant_prompt_message = AssistantPromptMessage(content=delta.delta.content or "", tool_calls=tool_calls) + + full_assistant_content += delta.delta.content or "" + + real_model = chunk.model + system_fingerprint = chunk.system_fingerprint + completion += delta.delta.content or "" + + yield LLMResultChunk( + model=real_model, + prompt_messages=prompt_messages, + system_fingerprint=system_fingerprint, + delta=LLMResultChunkDelta( + index=index, + message=assistant_prompt_message, + ), + ) + + index += 1 + + # calculate num tokens + prompt_tokens = self._num_tokens_from_messages(credentials, prompt_messages, tools) + + full_assistant_prompt_message = AssistantPromptMessage(content=completion) + completion_tokens = self._num_tokens_from_messages(credentials, [full_assistant_prompt_message]) + + # transform usage + usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) + + yield LLMResultChunk( + model=real_model, + prompt_messages=prompt_messages, + system_fingerprint=system_fingerprint, + delta=LLMResultChunkDelta( + index=index, message=AssistantPromptMessage(content=""), finish_reason="stop", usage=usage + ), + ) + + @staticmethod + def _update_tool_calls( + tool_calls: list[AssistantPromptMessage.ToolCall], + tool_calls_response: Optional[Sequence[ChatCompletionMessageToolCall | ChoiceDeltaToolCall]], + ) -> None: + if tool_calls_response: + for response_tool_call in tool_calls_response: + if isinstance(response_tool_call, ChatCompletionMessageToolCall): + function = AssistantPromptMessage.ToolCall.ToolCallFunction( + name=response_tool_call.function.name, arguments=response_tool_call.function.arguments + ) + + tool_call = AssistantPromptMessage.ToolCall( + id=response_tool_call.id, type=response_tool_call.type, function=function + ) + tool_calls.append(tool_call) + elif isinstance(response_tool_call, ChoiceDeltaToolCall): + index = response_tool_call.index + if index < len(tool_calls): + tool_calls[index].id = response_tool_call.id or tool_calls[index].id + tool_calls[index].type = response_tool_call.type or tool_calls[index].type + if response_tool_call.function: + tool_calls[index].function.name = ( + response_tool_call.function.name or tool_calls[index].function.name + ) + tool_calls[index].function.arguments += response_tool_call.function.arguments or "" + else: + assert response_tool_call.id is not None + assert response_tool_call.type is not None + assert response_tool_call.function is not None + assert response_tool_call.function.name is not None + assert response_tool_call.function.arguments is not None + + function = AssistantPromptMessage.ToolCall.ToolCallFunction( + name=response_tool_call.function.name, arguments=response_tool_call.function.arguments + ) + tool_call = AssistantPromptMessage.ToolCall( + id=response_tool_call.id, type=response_tool_call.type, function=function + ) + tool_calls.append(tool_call) + + @staticmethod + def _convert_prompt_message_to_dict(message: PromptMessage): + if isinstance(message, UserPromptMessage): + message = cast(UserPromptMessage, message) + if isinstance(message.content, str): + message_dict = {"role": "user", "content": message.content} + else: + sub_messages = [] + assert message.content is not None + for message_content in message.content: + if message_content.type == PromptMessageContentType.TEXT: + message_content = cast(TextPromptMessageContent, message_content) + sub_message_dict = {"type": "text", "text": message_content.data} + 
sub_messages.append(sub_message_dict) + elif message_content.type == PromptMessageContentType.IMAGE: + message_content = cast(ImagePromptMessageContent, message_content) + sub_message_dict = { + "type": "image_url", + "image_url": {"url": message_content.data, "detail": message_content.detail.value}, + } + sub_messages.append(sub_message_dict) + message_dict = {"role": "user", "content": sub_messages} + elif isinstance(message, AssistantPromptMessage): + # message = cast(AssistantPromptMessage, message) + message_dict = {"role": "assistant", "content": message.content} + if message.tool_calls: + # fix azure when enable json schema cant process content = "" in assistant fix with None + if not message.content: + message_dict["content"] = None + message_dict["tool_calls"] = [helper.dump_model(tool_call) for tool_call in message.tool_calls] + elif isinstance(message, SystemPromptMessage): + message = cast(SystemPromptMessage, message) + message_dict = {"role": "system", "content": message.content} + elif isinstance(message, ToolPromptMessage): + message = cast(ToolPromptMessage, message) + message_dict = { + "role": "tool", + "name": message.name, + "content": message.content, + "tool_call_id": message.tool_call_id, + } + else: + raise ValueError(f"Got unknown type {message}") + + if message.name: + message_dict["name"] = message.name + + return message_dict + + def _num_tokens_from_string( + self, credentials: dict, text: str, tools: Optional[list[PromptMessageTool]] = None + ) -> int: + try: + encoding = tiktoken.encoding_for_model(credentials["base_model_name"]) + except KeyError: + encoding = tiktoken.get_encoding("cl100k_base") + + num_tokens = len(encoding.encode(text)) + + if tools: + num_tokens += self._num_tokens_for_tools(encoding, tools) + + return num_tokens + + def _num_tokens_from_messages( + self, credentials: dict, messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None + ) -> int: + """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package. + + Official documentation: https://github.com/openai/openai-cookbook/blob/ + main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb""" + model = credentials["base_model_name"] + try: + encoding = tiktoken.encoding_for_model(model) + except KeyError: + logger.warning("Warning: model not found. Using cl100k_base encoding.") + model = "cl100k_base" + encoding = tiktoken.get_encoding(model) + + if model.startswith("gpt-35-turbo-0301"): + # every message follows {role/name}\n{content}\n + tokens_per_message = 4 + # if there's a name, the role is omitted + tokens_per_name = -1 + elif model.startswith("gpt-35-turbo") or model.startswith("gpt-4") or "o1" in model: + tokens_per_message = 3 + tokens_per_name = 1 + else: + raise NotImplementedError( + f"get_num_tokens_from_messages() is not presently implemented " + f"for model {model}." + "See https://github.com/openai/openai-python/blob/main/chatml.md for " + "information on how messages are converted to tokens." 
+ ) + num_tokens = 0 + messages_dict = [self._convert_prompt_message_to_dict(m) for m in messages] + for message in messages_dict: + num_tokens += tokens_per_message + for key, value in message.items(): + # Cast str(value) in case the message value is not a string + # This occurs with function messages + # TODO: The current token calculation method for the image type is not implemented, + # which need to download the image and then get the resolution for calculation, + # and will increase the request delay + if isinstance(value, list): + text = "" + for item in value: + if isinstance(item, dict) and item["type"] == "text": + text += item["text"] + + value = text + + if key == "tool_calls": + for tool_call in value: + assert isinstance(tool_call, dict) + for t_key, t_value in tool_call.items(): + num_tokens += len(encoding.encode(t_key)) + if t_key == "function": + for f_key, f_value in t_value.items(): + num_tokens += len(encoding.encode(f_key)) + num_tokens += len(encoding.encode(f_value)) + else: + num_tokens += len(encoding.encode(t_key)) + num_tokens += len(encoding.encode(t_value)) + else: + num_tokens += len(encoding.encode(str(value))) + + if key == "name": + num_tokens += tokens_per_name + + # every reply is primed with assistant + num_tokens += 3 + + if tools: + num_tokens += self._num_tokens_for_tools(encoding, tools) + + return num_tokens + + @staticmethod + def _num_tokens_for_tools(encoding: tiktoken.Encoding, tools: list[PromptMessageTool]) -> int: + num_tokens = 0 + for tool in tools: + num_tokens += len(encoding.encode("type")) + num_tokens += len(encoding.encode("function")) + + # calculate num tokens for function object + num_tokens += len(encoding.encode("name")) + num_tokens += len(encoding.encode(tool.name)) + num_tokens += len(encoding.encode("description")) + num_tokens += len(encoding.encode(tool.description)) + parameters = tool.parameters + num_tokens += len(encoding.encode("parameters")) + if "title" in parameters: + num_tokens += len(encoding.encode("title")) + num_tokens += len(encoding.encode(parameters["title"])) + num_tokens += len(encoding.encode("type")) + num_tokens += len(encoding.encode(parameters["type"])) + if "properties" in parameters: + num_tokens += len(encoding.encode("properties")) + for key, value in parameters["properties"].items(): + num_tokens += len(encoding.encode(key)) + for field_key, field_value in value.items(): + num_tokens += len(encoding.encode(field_key)) + if field_key == "enum": + for enum_field in field_value: + num_tokens += 3 + num_tokens += len(encoding.encode(enum_field)) + else: + num_tokens += len(encoding.encode(field_key)) + num_tokens += len(encoding.encode(str(field_value))) + if "required" in parameters: + num_tokens += len(encoding.encode("required")) + for required_field in parameters["required"]: + num_tokens += 3 + num_tokens += len(encoding.encode(required_field)) + + return num_tokens + + @staticmethod + def _get_ai_model_entity(base_model_name: str, model: str): + for ai_model_entity in LLM_BASE_MODELS: + if ai_model_entity.base_model_name == base_model_name: + ai_model_entity_copy = copy.deepcopy(ai_model_entity) + ai_model_entity_copy.entity.model = model + ai_model_entity_copy.entity.label.en_US = model + ai_model_entity_copy.entity.label.zh_Hans = model + return ai_model_entity_copy + + def _get_base_model_name(self, credentials: dict) -> str: + base_model_name = credentials.get("base_model_name") + if not base_model_name: + raise ValueError("Base Model Name is required") + return base_model_name diff --git 
a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py index f538eaef5b..691cb8d400 100644 --- a/api/core/ops/ops_trace_manager.py +++ b/api/core/ops/ops_trace_manager.py @@ -9,6 +9,8 @@ from typing import Any, Optional, Union from uuid import UUID, uuid4 from flask import current_app +from sqlalchemy import select +from sqlalchemy.orm import Session from core.helper.encrypter import decrypt_token, encrypt_token, obfuscated_token from core.ops.entities.config_entity import ( @@ -329,15 +331,15 @@ class TraceTask: ): self.trace_type = trace_type self.message_id = message_id - self.workflow_run = workflow_run + self.workflow_run_id = workflow_run.id if workflow_run else None self.conversation_id = conversation_id self.user_id = user_id self.timer = timer - self.kwargs = kwargs self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001") - self.app_id = None + self.kwargs = kwargs + def execute(self): return self.preprocess() @@ -345,19 +347,23 @@ class TraceTask: preprocess_map = { TraceTaskName.CONVERSATION_TRACE: lambda: self.conversation_trace(**self.kwargs), TraceTaskName.WORKFLOW_TRACE: lambda: self.workflow_trace( - self.workflow_run, self.conversation_id, self.user_id + workflow_run_id=self.workflow_run_id, conversation_id=self.conversation_id, user_id=self.user_id + ), + TraceTaskName.MESSAGE_TRACE: lambda: self.message_trace(message_id=self.message_id), + TraceTaskName.MODERATION_TRACE: lambda: self.moderation_trace( + message_id=self.message_id, timer=self.timer, **self.kwargs ), - TraceTaskName.MESSAGE_TRACE: lambda: self.message_trace(self.message_id), - TraceTaskName.MODERATION_TRACE: lambda: self.moderation_trace(self.message_id, self.timer, **self.kwargs), TraceTaskName.SUGGESTED_QUESTION_TRACE: lambda: self.suggested_question_trace( - self.message_id, self.timer, **self.kwargs + message_id=self.message_id, timer=self.timer, **self.kwargs ), TraceTaskName.DATASET_RETRIEVAL_TRACE: lambda: self.dataset_retrieval_trace( - self.message_id, self.timer, **self.kwargs + message_id=self.message_id, timer=self.timer, **self.kwargs + ), + TraceTaskName.TOOL_TRACE: lambda: self.tool_trace( + message_id=self.message_id, timer=self.timer, **self.kwargs ), - TraceTaskName.TOOL_TRACE: lambda: self.tool_trace(self.message_id, self.timer, **self.kwargs), TraceTaskName.GENERATE_NAME_TRACE: lambda: self.generate_name_trace( - self.conversation_id, self.timer, **self.kwargs + conversation_id=self.conversation_id, timer=self.timer, **self.kwargs ), } @@ -367,86 +373,100 @@ class TraceTask: def conversation_trace(self, **kwargs): return kwargs - def workflow_trace(self, workflow_run: WorkflowRun | None, conversation_id, user_id): - if not workflow_run: - raise ValueError("Workflow run not found") + def workflow_trace( + self, + *, + workflow_run_id: str | None, + conversation_id: str | None, + user_id: str | None, + ): + if not workflow_run_id: + return {} - db.session.merge(workflow_run) - db.session.refresh(workflow_run) + with Session(db.engine) as session: + workflow_run_stmt = select(WorkflowRun).where(WorkflowRun.id == workflow_run_id) + workflow_run = session.scalars(workflow_run_stmt).first() + if not workflow_run: + raise ValueError("Workflow run not found") - workflow_id = workflow_run.workflow_id - tenant_id = workflow_run.tenant_id - workflow_run_id = workflow_run.id - workflow_run_elapsed_time = workflow_run.elapsed_time - workflow_run_status = workflow_run.status - workflow_run_inputs = workflow_run.inputs_dict - workflow_run_outputs = 
workflow_run.outputs_dict - workflow_run_version = workflow_run.version - error = workflow_run.error or "" + workflow_id = workflow_run.workflow_id + tenant_id = workflow_run.tenant_id + workflow_run_id = workflow_run.id + workflow_run_elapsed_time = workflow_run.elapsed_time + workflow_run_status = workflow_run.status + workflow_run_inputs = workflow_run.inputs_dict + workflow_run_outputs = workflow_run.outputs_dict + workflow_run_version = workflow_run.version + error = workflow_run.error or "" - total_tokens = workflow_run.total_tokens + total_tokens = workflow_run.total_tokens - file_list = workflow_run_inputs.get("sys.file") or [] - query = workflow_run_inputs.get("query") or workflow_run_inputs.get("sys.query") or "" + file_list = workflow_run_inputs.get("sys.file") or [] + query = workflow_run_inputs.get("query") or workflow_run_inputs.get("sys.query") or "" - # get workflow_app_log_id - workflow_app_log_data = ( - db.session.query(WorkflowAppLog) - .filter_by(tenant_id=tenant_id, app_id=workflow_run.app_id, workflow_run_id=workflow_run.id) - .first() - ) - workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None - # get message_id - message_data = ( - db.session.query(Message.id) - .filter_by(conversation_id=conversation_id, workflow_run_id=workflow_run_id) - .first() - ) - message_id = str(message_data.id) if message_data else None + # get workflow_app_log_id + workflow_app_log_data_stmt = select(WorkflowAppLog.id).where( + WorkflowAppLog.tenant_id == tenant_id, + WorkflowAppLog.app_id == workflow_run.app_id, + WorkflowAppLog.workflow_run_id == workflow_run.id, + ) + workflow_app_log_id = session.scalar(workflow_app_log_data_stmt) + # get message_id + message_id = None + if conversation_id: + message_data_stmt = select(Message.id).where( + Message.conversation_id == conversation_id, + Message.workflow_run_id == workflow_run_id, + ) + message_id = session.scalar(message_data_stmt) - metadata = { - "workflow_id": workflow_id, - "conversation_id": conversation_id, - "workflow_run_id": workflow_run_id, - "tenant_id": tenant_id, - "elapsed_time": workflow_run_elapsed_time, - "status": workflow_run_status, - "version": workflow_run_version, - "total_tokens": total_tokens, - "file_list": file_list, - "triggered_form": workflow_run.triggered_from, - "user_id": user_id, - } - - workflow_trace_info = WorkflowTraceInfo( - workflow_data=workflow_run.to_dict(), - conversation_id=conversation_id, - workflow_id=workflow_id, - tenant_id=tenant_id, - workflow_run_id=workflow_run_id, - workflow_run_elapsed_time=workflow_run_elapsed_time, - workflow_run_status=workflow_run_status, - workflow_run_inputs=workflow_run_inputs, - workflow_run_outputs=workflow_run_outputs, - workflow_run_version=workflow_run_version, - error=error, - total_tokens=total_tokens, - file_list=file_list, - query=query, - metadata=metadata, - workflow_app_log_id=workflow_app_log_id, - message_id=message_id, - start_time=workflow_run.created_at, - end_time=workflow_run.finished_at, - ) + metadata = { + "workflow_id": workflow_id, + "conversation_id": conversation_id, + "workflow_run_id": workflow_run_id, + "tenant_id": tenant_id, + "elapsed_time": workflow_run_elapsed_time, + "status": workflow_run_status, + "version": workflow_run_version, + "total_tokens": total_tokens, + "file_list": file_list, + "triggered_form": workflow_run.triggered_from, + "user_id": user_id, + } + workflow_trace_info = WorkflowTraceInfo( + workflow_data=workflow_run.to_dict(), + conversation_id=conversation_id, + 
workflow_id=workflow_id, + tenant_id=tenant_id, + workflow_run_id=workflow_run_id, + workflow_run_elapsed_time=workflow_run_elapsed_time, + workflow_run_status=workflow_run_status, + workflow_run_inputs=workflow_run_inputs, + workflow_run_outputs=workflow_run_outputs, + workflow_run_version=workflow_run_version, + error=error, + total_tokens=total_tokens, + file_list=file_list, + query=query, + metadata=metadata, + workflow_app_log_id=workflow_app_log_id, + message_id=message_id, + start_time=workflow_run.created_at, + end_time=workflow_run.finished_at, + ) return workflow_trace_info - def message_trace(self, message_id): + def message_trace(self, message_id: str | None): + if not message_id: + return {} message_data = get_message_data(message_id) if not message_data: return {} - conversation_mode = db.session.query(Conversation.mode).filter_by(id=message_data.conversation_id).first() + conversation_mode_stmt = select(Conversation.mode).where(Conversation.id == message_data.conversation_id) + conversation_mode = db.session.scalars(conversation_mode_stmt).all() + if not conversation_mode or len(conversation_mode) == 0: + return {} conversation_mode = conversation_mode[0] created_at = message_data.created_at inputs = message_data.message diff --git a/api/core/ops/utils.py b/api/core/ops/utils.py index 998eba9ea9..8b06df1930 100644 --- a/api/core/ops/utils.py +++ b/api/core/ops/utils.py @@ -18,7 +18,7 @@ def filter_none_values(data: dict): return new_data -def get_message_data(message_id): +def get_message_data(message_id: str): return db.session.query(Message).filter(Message.id == message_id).first() diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index 34343ad60e..3a8200bc7b 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -6,11 +6,14 @@ from flask import Flask, current_app from core.rag.data_post_processor.data_post_processor import DataPostProcessor from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.datasource.vdb.vector_factory import Vector +from core.rag.embedding.retrieval import RetrievalSegments +from core.rag.index_processor.constant.index_type import IndexType from core.rag.models.document import Document from core.rag.rerank.rerank_type import RerankMode from core.rag.retrieval.retrieval_methods import RetrievalMethod from extensions.ext_database import db -from models.dataset import Dataset +from models.dataset import ChildChunk, Dataset, DocumentSegment +from models.dataset import Document as DatasetDocument from services.external_knowledge_service import ExternalDatasetService default_retrieval_model = { @@ -248,3 +251,89 @@ class RetrievalService: @staticmethod def escape_query_for_search(query: str) -> str: return query.replace('"', '\\"') + + @staticmethod + def format_retrieval_documents(documents: list[Document]) -> list[RetrievalSegments]: + records = [] + include_segment_ids = [] + segment_child_map = {} + for document in documents: + document_id = document.metadata.get("document_id") + dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == document_id).first() + if dataset_document: + if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: + child_index_node_id = document.metadata.get("doc_id") + result = ( + db.session.query(ChildChunk, DocumentSegment) + .join(DocumentSegment, ChildChunk.segment_id == DocumentSegment.id) + .filter( + ChildChunk.index_node_id == child_index_node_id, + 
DocumentSegment.dataset_id == dataset_document.dataset_id, + DocumentSegment.enabled == True, + DocumentSegment.status == "completed", + ) + .first() + ) + if result: + child_chunk, segment = result + if not segment: + continue + if segment.id not in include_segment_ids: + include_segment_ids.append(segment.id) + child_chunk_detail = { + "id": child_chunk.id, + "content": child_chunk.content, + "position": child_chunk.position, + "score": document.metadata.get("score", 0.0), + } + map_detail = { + "max_score": document.metadata.get("score", 0.0), + "child_chunks": [child_chunk_detail], + } + segment_child_map[segment.id] = map_detail + record = { + "segment": segment, + } + records.append(record) + else: + child_chunk_detail = { + "id": child_chunk.id, + "content": child_chunk.content, + "position": child_chunk.position, + "score": document.metadata.get("score", 0.0), + } + segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail) + segment_child_map[segment.id]["max_score"] = max( + segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0) + ) + else: + continue + else: + index_node_id = document.metadata["doc_id"] + + segment = ( + db.session.query(DocumentSegment) + .filter( + DocumentSegment.dataset_id == dataset_document.dataset_id, + DocumentSegment.enabled == True, + DocumentSegment.status == "completed", + DocumentSegment.index_node_id == index_node_id, + ) + .first() + ) + + if not segment: + continue + include_segment_ids.append(segment.id) + record = { + "segment": segment, + "score": document.metadata.get("score", None), + } + + records.append(record) + for record in records: + if record["segment"].id in segment_child_map: + record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks", None) + record["score"] = segment_child_map[record["segment"].id]["max_score"] + + return [RetrievalSegments(**record) for record in records] diff --git a/api/core/rag/docstore/dataset_docstore.py b/api/core/rag/docstore/dataset_docstore.py index 6d16a9bdc2..398b0daad9 100644 --- a/api/core/rag/docstore/dataset_docstore.py +++ b/api/core/rag/docstore/dataset_docstore.py @@ -7,7 +7,7 @@ from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType from core.rag.models.document import Document from extensions.ext_database import db -from models.dataset import Dataset, DocumentSegment +from models.dataset import ChildChunk, Dataset, DocumentSegment class DatasetDocumentStore: @@ -60,7 +60,7 @@ class DatasetDocumentStore: return output - def add_documents(self, docs: Sequence[Document], allow_update: bool = True) -> None: + def add_documents(self, docs: Sequence[Document], allow_update: bool = True, save_child: bool = False) -> None: max_position = ( db.session.query(func.max(DocumentSegment.position)) .filter(DocumentSegment.document_id == self._document_id) @@ -120,13 +120,55 @@ class DatasetDocumentStore: segment_document.answer = doc.metadata.pop("answer", "") db.session.add(segment_document) + db.session.flush() + if save_child: + if doc.children: + for position, child in enumerate(doc.children, start=1): + child_segment = ChildChunk( + tenant_id=self._dataset.tenant_id, + dataset_id=self._dataset.id, + document_id=self._document_id, + segment_id=segment_document.id, + position=position, + index_node_id=child.metadata.get("doc_id"), + index_node_hash=child.metadata.get("doc_hash"), + content=child.page_content, + word_count=len(child.page_content), + type="automatic", + 
created_by=self._user_id, + ) + db.session.add(child_segment) else: segment_document.content = doc.page_content if doc.metadata.get("answer"): segment_document.answer = doc.metadata.pop("answer", "") - segment_document.index_node_hash = doc.metadata["doc_hash"] + segment_document.index_node_hash = doc.metadata.get("doc_hash") segment_document.word_count = len(doc.page_content) segment_document.tokens = tokens + if save_child and doc.children: + # delete the existing child chunks + db.session.query(ChildChunk).filter( + ChildChunk.tenant_id == self._dataset.tenant_id, + ChildChunk.dataset_id == self._dataset.id, + ChildChunk.document_id == self._document_id, + ChildChunk.segment_id == segment_document.id, + ).delete() + # add new child chunks + for position, child in enumerate(doc.children, start=1): + child_segment = ChildChunk( + tenant_id=self._dataset.tenant_id, + dataset_id=self._dataset.id, + document_id=self._document_id, + segment_id=segment_document.id, + position=position, + index_node_id=child.metadata.get("doc_id"), + index_node_hash=child.metadata.get("doc_hash"), + content=child.page_content, + word_count=len(child.page_content), + type="automatic", + created_by=self._user_id, + ) + db.session.add(child_segment) db.session.commit() diff --git a/api/core/rag/embedding/retrieval.py b/api/core/rag/embedding/retrieval.py new file mode 100644 index 0000000000..800422d888 --- /dev/null +++ b/api/core/rag/embedding/retrieval.py @@ -0,0 +1,23 @@ +from typing import Optional + +from pydantic import BaseModel + +from models.dataset import DocumentSegment + + +class RetrievalChildChunk(BaseModel): + """Retrieval child chunk.""" + + id: str + content: str + score: float + position: int + + +class RetrievalSegments(BaseModel): + """Retrieval segments.""" + + model_config = {"arbitrary_types_allowed": True} + segment: DocumentSegment + child_chunks: Optional[list[RetrievalChildChunk]] = None + score: Optional[float] = None diff --git a/api/core/rag/extractor/excel_extractor.py b/api/core/rag/extractor/excel_extractor.py index c444105bb5..a3b35458df 100644 --- a/api/core/rag/extractor/excel_extractor.py +++ b/api/core/rag/extractor/excel_extractor.py @@ -4,7 +4,7 @@ import os from typing import Optional, cast import pandas as pd -from openpyxl import load_workbook +from openpyxl import load_workbook # type: ignore from core.rag.extractor.extractor_base import BaseExtractor from core.rag.models.document import Document diff --git a/api/core/rag/extractor/extract_processor.py b/api/core/rag/extractor/extract_processor.py index a473b3dfa7..f9fd7f92a1 100644 --- a/api/core/rag/extractor/extract_processor.py +++ b/api/core/rag/extractor/extract_processor.py @@ -24,7 +24,6 @@ from core.rag.extractor.unstructured.unstructured_markdown_extractor import Unst from core.rag.extractor.unstructured.unstructured_msg_extractor import UnstructuredMsgExtractor from core.rag.extractor.unstructured.unstructured_ppt_extractor import UnstructuredPPTExtractor from core.rag.extractor.unstructured.unstructured_pptx_extractor import UnstructuredPPTXExtractor -from core.rag.extractor.unstructured.unstructured_text_extractor import UnstructuredTextExtractor from core.rag.extractor.unstructured.unstructured_xml_extractor import UnstructuredXmlExtractor from core.rag.extractor.word_extractor import WordExtractor from core.rag.models.document import Document @@ -103,12 +102,11 @@ class ExtractProcessor: input_file = Path(file_path) file_extension = input_file.suffix.lower() etl_type = dify_config.ETL_TYPE - 
unstructured_api_url = dify_config.UNSTRUCTURED_API_URL - unstructured_api_key = dify_config.UNSTRUCTURED_API_KEY - assert unstructured_api_url is not None, "unstructured_api_url is required" - assert unstructured_api_key is not None, "unstructured_api_key is required" extractor: Optional[BaseExtractor] = None if etl_type == "Unstructured": + unstructured_api_url = dify_config.UNSTRUCTURED_API_URL + unstructured_api_key = dify_config.UNSTRUCTURED_API_KEY or "" + if file_extension in {".xlsx", ".xls"}: extractor = ExcelExtractor(file_path) elif file_extension == ".pdf": @@ -141,11 +139,7 @@ class ExtractProcessor: extractor = UnstructuredEpubExtractor(file_path, unstructured_api_url, unstructured_api_key) else: # txt - extractor = ( - UnstructuredTextExtractor(file_path, unstructured_api_url) - if is_automatic - else TextExtractor(file_path, autodetect_encoding=True) - ) + extractor = TextExtractor(file_path, autodetect_encoding=True) else: if file_extension in {".xlsx", ".xls"}: extractor = ExcelExtractor(file_path) diff --git a/api/core/rag/extractor/unstructured/unstructured_eml_extractor.py b/api/core/rag/extractor/unstructured/unstructured_eml_extractor.py index 9647dedfff..f1fa5dde5c 100644 --- a/api/core/rag/extractor/unstructured/unstructured_eml_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_eml_extractor.py @@ -1,5 +1,6 @@ import base64 import logging +from typing import Optional from bs4 import BeautifulSoup # type: ignore @@ -15,7 +16,7 @@ class UnstructuredEmailExtractor(BaseExtractor): file_path: Path to the file to load. """ - def __init__(self, file_path: str, api_url: str, api_key: str): + def __init__(self, file_path: str, api_url: Optional[str] = None, api_key: str = ""): """Initialize with file path.""" self._file_path = file_path self._api_url = api_url diff --git a/api/core/rag/extractor/unstructured/unstructured_epub_extractor.py b/api/core/rag/extractor/unstructured/unstructured_epub_extractor.py index 80c29157aa..35ca686f62 100644 --- a/api/core/rag/extractor/unstructured/unstructured_epub_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_epub_extractor.py @@ -19,7 +19,7 @@ class UnstructuredEpubExtractor(BaseExtractor): self, file_path: str, api_url: Optional[str] = None, - api_key: Optional[str] = None, + api_key: str = "", ): """Initialize with file path.""" self._file_path = file_path @@ -30,9 +30,6 @@ class UnstructuredEpubExtractor(BaseExtractor): if self._api_url: from unstructured.partition.api import partition_via_api - if self._api_key is None: - raise ValueError("api_key is required") - elements = partition_via_api(filename=self._file_path, api_url=self._api_url, api_key=self._api_key) else: from unstructured.partition.epub import partition_epub diff --git a/api/core/rag/extractor/unstructured/unstructured_markdown_extractor.py b/api/core/rag/extractor/unstructured/unstructured_markdown_extractor.py index 4173d4d122..d5418e612a 100644 --- a/api/core/rag/extractor/unstructured/unstructured_markdown_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_markdown_extractor.py @@ -1,4 +1,5 @@ import logging +from typing import Optional from core.rag.extractor.extractor_base import BaseExtractor from core.rag.models.document import Document @@ -24,7 +25,7 @@ class UnstructuredMarkdownExtractor(BaseExtractor): if the specified encoding fails. 
""" - def __init__(self, file_path: str, api_url: str, api_key: str): + def __init__(self, file_path: str, api_url: Optional[str] = None, api_key: str = ""): """Initialize with file path.""" self._file_path = file_path self._api_url = api_url diff --git a/api/core/rag/extractor/unstructured/unstructured_msg_extractor.py b/api/core/rag/extractor/unstructured/unstructured_msg_extractor.py index 57affb8d36..d363449c29 100644 --- a/api/core/rag/extractor/unstructured/unstructured_msg_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_msg_extractor.py @@ -1,4 +1,5 @@ import logging +from typing import Optional from core.rag.extractor.extractor_base import BaseExtractor from core.rag.models.document import Document @@ -14,7 +15,7 @@ class UnstructuredMsgExtractor(BaseExtractor): file_path: Path to the file to load. """ - def __init__(self, file_path: str, api_url: str, api_key: str): + def __init__(self, file_path: str, api_url: Optional[str] = None, api_key: str = ""): """Initialize with file path.""" self._file_path = file_path self._api_url = api_url diff --git a/api/core/rag/extractor/unstructured/unstructured_ppt_extractor.py b/api/core/rag/extractor/unstructured/unstructured_ppt_extractor.py index e504d4bc23..ecc272a2f0 100644 --- a/api/core/rag/extractor/unstructured/unstructured_ppt_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_ppt_extractor.py @@ -1,4 +1,5 @@ import logging +from typing import Optional from core.rag.extractor.extractor_base import BaseExtractor from core.rag.models.document import Document @@ -14,7 +15,7 @@ class UnstructuredPPTExtractor(BaseExtractor): file_path: Path to the file to load. """ - def __init__(self, file_path: str, api_url: str, api_key: str): + def __init__(self, file_path: str, api_url: Optional[str] = None, api_key: str = ""): """Initialize with file path.""" self._file_path = file_path self._api_url = api_url diff --git a/api/core/rag/extractor/unstructured/unstructured_pptx_extractor.py b/api/core/rag/extractor/unstructured/unstructured_pptx_extractor.py index cefe72b290..e7bf6fd2e6 100644 --- a/api/core/rag/extractor/unstructured/unstructured_pptx_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_pptx_extractor.py @@ -1,4 +1,5 @@ import logging +from typing import Optional from core.rag.extractor.extractor_base import BaseExtractor from core.rag.models.document import Document @@ -14,7 +15,7 @@ class UnstructuredPPTXExtractor(BaseExtractor): file_path: Path to the file to load. """ - def __init__(self, file_path: str, api_url: str, api_key: str): + def __init__(self, file_path: str, api_url: Optional[str] = None, api_key: str = ""): """Initialize with file path.""" self._file_path = file_path self._api_url = api_url diff --git a/api/core/rag/extractor/unstructured/unstructured_xml_extractor.py b/api/core/rag/extractor/unstructured/unstructured_xml_extractor.py index ef46ab0e70..916cdc3f2b 100644 --- a/api/core/rag/extractor/unstructured/unstructured_xml_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_xml_extractor.py @@ -1,4 +1,5 @@ import logging +from typing import Optional from core.rag.extractor.extractor_base import BaseExtractor from core.rag.models.document import Document @@ -14,7 +15,7 @@ class UnstructuredXmlExtractor(BaseExtractor): file_path: Path to the file to load. 
""" - def __init__(self, file_path: str, api_url: str, api_key: str): + def __init__(self, file_path: str, api_url: Optional[str] = None, api_key: str = ""): """Initialize with file path.""" self._file_path = file_path self._api_url = api_url diff --git a/api/core/rag/extractor/word_extractor.py b/api/core/rag/extractor/word_extractor.py index 0441b409a0..12f0cd182a 100644 --- a/api/core/rag/extractor/word_extractor.py +++ b/api/core/rag/extractor/word_extractor.py @@ -267,8 +267,10 @@ class WordExtractor(BaseExtractor): if isinstance(element.tag, str) and element.tag.endswith("p"): # paragraph para = paragraphs.pop(0) parsed_paragraph = parse_paragraph(para) - if parsed_paragraph: + if parsed_paragraph.strip(): content.append(parsed_paragraph) + else: + content.append("\n") elif isinstance(element.tag, str) and element.tag.endswith("tbl"): # table table = tables.pop(0) content.append(self._table_to_markdown(table, image_map)) diff --git a/api/core/rag/index_processor/constant/index_type.py b/api/core/rag/index_processor/constant/index_type.py index e42cc44c6f..0845b58e25 100644 --- a/api/core/rag/index_processor/constant/index_type.py +++ b/api/core/rag/index_processor/constant/index_type.py @@ -1,8 +1,7 @@ from enum import Enum -class IndexType(Enum): +class IndexType(str, Enum): PARAGRAPH_INDEX = "text_model" QA_INDEX = "qa_model" - PARENT_CHILD_INDEX = "parent_child_index" - SUMMARY_INDEX = "summary_index" + PARENT_CHILD_INDEX = "hierarchical_model" diff --git a/api/core/rag/index_processor/index_processor_base.py b/api/core/rag/index_processor/index_processor_base.py index 7e5efdc66e..2bcd1c79bb 100644 --- a/api/core/rag/index_processor/index_processor_base.py +++ b/api/core/rag/index_processor/index_processor_base.py @@ -27,10 +27,10 @@ class BaseIndexProcessor(ABC): raise NotImplementedError @abstractmethod - def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True): + def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True, **kwargs): raise NotImplementedError - def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True): + def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True, **kwargs): raise NotImplementedError @abstractmethod @@ -45,26 +45,29 @@ class BaseIndexProcessor(ABC): ) -> list[Document]: raise NotImplementedError - def _get_splitter(self, processing_rule: dict, embedding_model_instance: Optional[ModelInstance]) -> TextSplitter: + def _get_splitter( + self, + processing_rule_mode: str, + max_tokens: int, + chunk_overlap: int, + separator: str, + embedding_model_instance: Optional[ModelInstance], + ) -> TextSplitter: """ Get the NodeParser object according to the processing rule. 
""" - character_splitter: TextSplitter - if processing_rule["mode"] == "custom": + if processing_rule_mode in ["custom", "hierarchical"]: # The user-defined segmentation rule - rules = processing_rule["rules"] - segmentation = rules["segmentation"] max_segmentation_tokens_length = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH - if segmentation["max_tokens"] < 50 or segmentation["max_tokens"] > max_segmentation_tokens_length: + if max_tokens < 50 or max_tokens > max_segmentation_tokens_length: raise ValueError(f"Custom segment length should be between 50 and {max_segmentation_tokens_length}.") - separator = segmentation["separator"] if separator: separator = separator.replace("\\n", "\n") character_splitter = FixedRecursiveCharacterTextSplitter.from_encoder( - chunk_size=segmentation["max_tokens"], - chunk_overlap=segmentation.get("chunk_overlap", 0) or 0, + chunk_size=max_tokens, + chunk_overlap=chunk_overlap, fixed_separator=separator, separators=["\n\n", "。", ". ", " ", ""], embedding_model_instance=embedding_model_instance, @@ -78,4 +81,4 @@ class BaseIndexProcessor(ABC): embedding_model_instance=embedding_model_instance, ) - return character_splitter + return character_splitter # type: ignore diff --git a/api/core/rag/index_processor/index_processor_factory.py b/api/core/rag/index_processor/index_processor_factory.py index c5ba6295f3..c987edf342 100644 --- a/api/core/rag/index_processor/index_processor_factory.py +++ b/api/core/rag/index_processor/index_processor_factory.py @@ -3,6 +3,7 @@ from core.rag.index_processor.constant.index_type import IndexType from core.rag.index_processor.index_processor_base import BaseIndexProcessor from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor +from core.rag.index_processor.processor.parent_child_index_processor import ParentChildIndexProcessor from core.rag.index_processor.processor.qa_index_processor import QAIndexProcessor @@ -18,9 +19,11 @@ class IndexProcessorFactory: if not self._index_type: raise ValueError("Index type must be specified.") - if self._index_type == IndexType.PARAGRAPH_INDEX.value: + if self._index_type == IndexType.PARAGRAPH_INDEX: return ParagraphIndexProcessor() - elif self._index_type == IndexType.QA_INDEX.value: + elif self._index_type == IndexType.QA_INDEX: return QAIndexProcessor() + elif self._index_type == IndexType.PARENT_CHILD_INDEX: + return ParentChildIndexProcessor() else: raise ValueError(f"Index type {self._index_type} is not supported.") diff --git a/api/core/rag/index_processor/processor/paragraph_index_processor.py b/api/core/rag/index_processor/processor/paragraph_index_processor.py index c66fa54d50..dca84b9041 100644 --- a/api/core/rag/index_processor/processor/paragraph_index_processor.py +++ b/api/core/rag/index_processor/processor/paragraph_index_processor.py @@ -13,21 +13,40 @@ from core.rag.index_processor.index_processor_base import BaseIndexProcessor from core.rag.models.document import Document from core.tools.utils.text_processing_utils import remove_leading_symbols from libs import helper -from models.dataset import Dataset +from models.dataset import Dataset, DatasetProcessRule +from services.entities.knowledge_entities.knowledge_entities import Rule class ParagraphIndexProcessor(BaseIndexProcessor): def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]: text_docs = ExtractProcessor.extract( - extract_setting=extract_setting, is_automatic=kwargs.get("process_rule_mode") == "automatic" + extract_setting=extract_setting, 
+ is_automatic=( + kwargs.get("process_rule_mode") == "automatic" or kwargs.get("process_rule_mode") == "hierarchical" + ), ) return text_docs def transform(self, documents: list[Document], **kwargs) -> list[Document]: + process_rule = kwargs.get("process_rule") + if not process_rule: + raise ValueError("No process rule found.") + if process_rule.get("mode") == "automatic": + automatic_rule = DatasetProcessRule.AUTOMATIC_RULES + rules = Rule(**automatic_rule) + else: + if not process_rule.get("rules"): + raise ValueError("No rules found in process rule.") + rules = Rule(**process_rule.get("rules")) # Split the text documents into nodes. + if not rules.segmentation: + raise ValueError("No segmentation found in rules.") splitter = self._get_splitter( - processing_rule=kwargs.get("process_rule", {}), + processing_rule_mode=process_rule.get("mode"), + max_tokens=rules.segmentation.max_tokens, + chunk_overlap=rules.segmentation.chunk_overlap, + separator=rules.segmentation.separator, embedding_model_instance=kwargs.get("embedding_model_instance"), ) all_documents = [] @@ -53,15 +72,19 @@ class ParagraphIndexProcessor(BaseIndexProcessor): all_documents.extend(split_documents) return all_documents - def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True): + def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True, **kwargs): if dataset.indexing_technique == "high_quality": vector = Vector(dataset) vector.create(documents) if with_keywords: + keywords_list = kwargs.get("keywords_list") keyword = Keyword(dataset) - keyword.create(documents) + if keywords_list and len(keywords_list) > 0: + keyword.add_texts(documents, keywords_list=keywords_list) + else: + keyword.add_texts(documents) - def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True): + def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True, **kwargs): if dataset.indexing_technique == "high_quality": vector = Vector(dataset) if node_ids: diff --git a/api/core/rag/index_processor/processor/parent_child_index_processor.py b/api/core/rag/index_processor/processor/parent_child_index_processor.py new file mode 100644 index 0000000000..e8423e2b77 --- /dev/null +++ b/api/core/rag/index_processor/processor/parent_child_index_processor.py @@ -0,0 +1,195 @@ +"""Paragraph index processor.""" + +import uuid +from typing import Optional + +from core.model_manager import ModelInstance +from core.rag.cleaner.clean_processor import CleanProcessor +from core.rag.datasource.retrieval_service import RetrievalService +from core.rag.datasource.vdb.vector_factory import Vector +from core.rag.extractor.entity.extract_setting import ExtractSetting +from core.rag.extractor.extract_processor import ExtractProcessor +from core.rag.index_processor.index_processor_base import BaseIndexProcessor +from core.rag.models.document import ChildDocument, Document +from extensions.ext_database import db +from libs import helper +from models.dataset import ChildChunk, Dataset, DocumentSegment +from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule + + +class ParentChildIndexProcessor(BaseIndexProcessor): + def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]: + text_docs = ExtractProcessor.extract( + extract_setting=extract_setting, + is_automatic=( + kwargs.get("process_rule_mode") == "automatic" or kwargs.get("process_rule_mode") == "hierarchical" + ), + ) + + return text_docs + + def 
transform(self, documents: list[Document], **kwargs) -> list[Document]: + process_rule = kwargs.get("process_rule") + if not process_rule: + raise ValueError("No process rule found.") + if not process_rule.get("rules"): + raise ValueError("No rules found in process rule.") + rules = Rule(**process_rule.get("rules")) + all_documents = [] # type: ignore + if rules.parent_mode == ParentMode.PARAGRAPH: + # Split the text documents into nodes. + splitter = self._get_splitter( + processing_rule_mode=process_rule.get("mode"), + max_tokens=rules.segmentation.max_tokens, + chunk_overlap=rules.segmentation.chunk_overlap, + separator=rules.segmentation.separator, + embedding_model_instance=kwargs.get("embedding_model_instance"), + ) + for document in documents: + # document clean + document_text = CleanProcessor.clean(document.page_content, process_rule) + document.page_content = document_text + # parse document to nodes + document_nodes = splitter.split_documents([document]) + split_documents = [] + for document_node in document_nodes: + if document_node.page_content.strip(): + doc_id = str(uuid.uuid4()) + hash = helper.generate_text_hash(document_node.page_content) + document_node.metadata["doc_id"] = doc_id + document_node.metadata["doc_hash"] = hash + # delete Splitter character + page_content = document_node.page_content + if page_content.startswith(".") or page_content.startswith("。"): + page_content = page_content[1:].strip() + else: + page_content = page_content + if len(page_content) > 0: + document_node.page_content = page_content + # parse document to child nodes + child_nodes = self._split_child_nodes( + document_node, rules, process_rule.get("mode"), kwargs.get("embedding_model_instance") + ) + document_node.children = child_nodes + split_documents.append(document_node) + all_documents.extend(split_documents) + elif rules.parent_mode == ParentMode.FULL_DOC: + page_content = "\n".join([document.page_content for document in documents]) + document = Document(page_content=page_content, metadata=documents[0].metadata) + # parse document to child nodes + child_nodes = self._split_child_nodes( + document, rules, process_rule.get("mode"), kwargs.get("embedding_model_instance") + ) + document.children = child_nodes + doc_id = str(uuid.uuid4()) + hash = helper.generate_text_hash(document.page_content) + document.metadata["doc_id"] = doc_id + document.metadata["doc_hash"] = hash + all_documents.append(document) + + return all_documents + + def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True, **kwargs): + if dataset.indexing_technique == "high_quality": + vector = Vector(dataset) + for document in documents: + child_documents = document.children + if child_documents: + formatted_child_documents = [ + Document(**child_document.model_dump()) for child_document in child_documents + ] + vector.create(formatted_child_documents) + + def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True, **kwargs): + # node_ids is segment's node_ids + if dataset.indexing_technique == "high_quality": + delete_child_chunks = kwargs.get("delete_child_chunks") or False + vector = Vector(dataset) + if node_ids: + child_node_ids = ( + db.session.query(ChildChunk.index_node_id) + .join(DocumentSegment, ChildChunk.segment_id == DocumentSegment.id) + .filter( + DocumentSegment.dataset_id == dataset.id, + DocumentSegment.index_node_id.in_(node_ids), + ChildChunk.dataset_id == dataset.id, + ) + .all() + ) + child_node_ids = [child_node_id[0] for child_node_id in 
child_node_ids] + vector.delete_by_ids(child_node_ids) + if delete_child_chunks: + db.session.query(ChildChunk).filter( + ChildChunk.dataset_id == dataset.id, ChildChunk.index_node_id.in_(child_node_ids) + ).delete() + db.session.commit() + else: + vector.delete() + + if delete_child_chunks: + db.session.query(ChildChunk).filter(ChildChunk.dataset_id == dataset.id).delete() + db.session.commit() + + def retrieve( + self, + retrieval_method: str, + query: str, + dataset: Dataset, + top_k: int, + score_threshold: float, + reranking_model: dict, + ) -> list[Document]: + # Set search parameters. + results = RetrievalService.retrieve( + retrieval_method=retrieval_method, + dataset_id=dataset.id, + query=query, + top_k=top_k, + score_threshold=score_threshold, + reranking_model=reranking_model, + ) + # Organize results. + docs = [] + for result in results: + metadata = result.metadata + metadata["score"] = result.score + if result.score > score_threshold: + doc = Document(page_content=result.page_content, metadata=metadata) + docs.append(doc) + return docs + + def _split_child_nodes( + self, + document_node: Document, + rules: Rule, + process_rule_mode: str, + embedding_model_instance: Optional[ModelInstance], + ) -> list[ChildDocument]: + if not rules.subchunk_segmentation: + raise ValueError("No subchunk segmentation found in rules.") + child_splitter = self._get_splitter( + processing_rule_mode=process_rule_mode, + max_tokens=rules.subchunk_segmentation.max_tokens, + chunk_overlap=rules.subchunk_segmentation.chunk_overlap, + separator=rules.subchunk_segmentation.separator, + embedding_model_instance=embedding_model_instance, + ) + # parse document to child nodes + child_nodes = [] + child_documents = child_splitter.split_documents([document_node]) + for child_document_node in child_documents: + if child_document_node.page_content.strip(): + doc_id = str(uuid.uuid4()) + hash = helper.generate_text_hash(child_document_node.page_content) + child_document = ChildDocument( + page_content=child_document_node.page_content, metadata=document_node.metadata + ) + child_document.metadata["doc_id"] = doc_id + child_document.metadata["doc_hash"] = hash + child_page_content = child_document.page_content + if child_page_content.startswith(".") or child_page_content.startswith("。"): + child_page_content = child_page_content[1:].strip() + if len(child_page_content) > 0: + child_document.page_content = child_page_content + child_nodes.append(child_document) + return child_nodes diff --git a/api/core/rag/index_processor/processor/qa_index_processor.py b/api/core/rag/index_processor/processor/qa_index_processor.py index 20fd16e8f3..58b50a9fcb 100644 --- a/api/core/rag/index_processor/processor/qa_index_processor.py +++ b/api/core/rag/index_processor/processor/qa_index_processor.py @@ -21,18 +21,32 @@ from core.rag.models.document import Document from core.tools.utils.text_processing_utils import remove_leading_symbols from libs import helper from models.dataset import Dataset +from services.entities.knowledge_entities.knowledge_entities import Rule class QAIndexProcessor(BaseIndexProcessor): def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]: text_docs = ExtractProcessor.extract( - extract_setting=extract_setting, is_automatic=kwargs.get("process_rule_mode") == "automatic" + extract_setting=extract_setting, + is_automatic=( + kwargs.get("process_rule_mode") == "automatic" or kwargs.get("process_rule_mode") == "hierarchical" + ), ) return text_docs def transform(self, documents: 
list[Document], **kwargs) -> list[Document]: + preview = kwargs.get("preview") + process_rule = kwargs.get("process_rule") + if not process_rule: + raise ValueError("No process rule found.") + if not process_rule.get("rules"): + raise ValueError("No rules found in process rule.") + rules = Rule(**process_rule.get("rules")) splitter = self._get_splitter( - processing_rule=kwargs.get("process_rule") or {}, + processing_rule_mode=process_rule.get("mode"), + max_tokens=rules.segmentation.max_tokens if rules.segmentation else 0, + chunk_overlap=rules.segmentation.chunk_overlap if rules.segmentation else 0, + separator=rules.segmentation.separator if rules.segmentation else "", embedding_model_instance=kwargs.get("embedding_model_instance"), ) @@ -59,24 +73,33 @@ class QAIndexProcessor(BaseIndexProcessor): document_node.page_content = remove_leading_symbols(page_content) split_documents.append(document_node) all_documents.extend(split_documents) - for i in range(0, len(all_documents), 10): - threads = [] - sub_documents = all_documents[i : i + 10] - for doc in sub_documents: - document_format_thread = threading.Thread( - target=self._format_qa_document, - kwargs={ - "flask_app": current_app._get_current_object(), # type: ignore - "tenant_id": kwargs.get("tenant_id"), - "document_node": doc, - "all_qa_documents": all_qa_documents, - "document_language": kwargs.get("doc_language", "English"), - }, - ) - threads.append(document_format_thread) - document_format_thread.start() - for thread in threads: - thread.join() + if preview: + self._format_qa_document( + current_app._get_current_object(), # type: ignore + kwargs.get("tenant_id"), # type: ignore + all_documents[0], + all_qa_documents, + kwargs.get("doc_language", "English"), + ) + else: + for i in range(0, len(all_documents), 10): + threads = [] + sub_documents = all_documents[i : i + 10] + for doc in sub_documents: + document_format_thread = threading.Thread( + target=self._format_qa_document, + kwargs={ + "flask_app": current_app._get_current_object(), # type: ignore + "tenant_id": kwargs.get("tenant_id"), # type: ignore + "document_node": doc, + "all_qa_documents": all_qa_documents, + "document_language": kwargs.get("doc_language", "English"), + }, + ) + threads.append(document_format_thread) + document_format_thread.start() + for thread in threads: + thread.join() return all_qa_documents def format_by_template(self, file: FileStorage, **kwargs) -> list[Document]: @@ -98,12 +121,12 @@ class QAIndexProcessor(BaseIndexProcessor): raise ValueError(str(e)) return text_docs - def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True): + def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True, **kwargs): if dataset.indexing_technique == "high_quality": vector = Vector(dataset) vector.create(documents) - def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True): + def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True, **kwargs): vector = Vector(dataset) if node_ids: vector.delete_by_ids(node_ids) diff --git a/api/core/rag/models/document.py b/api/core/rag/models/document.py index 1e9aaa24f0..421cdc05df 100644 --- a/api/core/rag/models/document.py +++ b/api/core/rag/models/document.py @@ -2,7 +2,20 @@ from abc import ABC, abstractmethod from collections.abc import Sequence from typing import Any, Optional -from pydantic import BaseModel, Field +from pydantic import BaseModel + + +class ChildDocument(BaseModel): + 
"""Class for storing a piece of text and associated metadata.""" + + page_content: str + + vector: Optional[list[float]] = None + + """Arbitrary metadata about the page content (e.g., source, relationships to other + documents, etc.). + """ + metadata: dict = {} class Document(BaseModel): @@ -15,10 +28,12 @@ class Document(BaseModel): """Arbitrary metadata about the page content (e.g., source, relationships to other documents, etc.). """ - metadata: Optional[dict] = Field(default_factory=dict) + metadata: dict = {} provider: Optional[str] = "dify" + children: Optional[list[ChildDocument]] = None + class BaseDocumentTransformer(ABC): """Abstract base class for document transformation systems. diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index 8a7172f27c..290d9e6e61 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -164,43 +164,29 @@ class DatasetRetrieval: "content": item.page_content, } retrieval_resource_list.append(source) - document_score_list = {} # deal with dify documents if dify_documents: - for item in dify_documents: - if item.metadata.get("score"): - document_score_list[item.metadata["doc_id"]] = item.metadata["score"] - - index_node_ids = [document.metadata["doc_id"] for document in dify_documents] - segments = DocumentSegment.query.filter( - DocumentSegment.dataset_id.in_(dataset_ids), - DocumentSegment.status == "completed", - DocumentSegment.enabled == True, - DocumentSegment.index_node_id.in_(index_node_ids), - ).all() - - if segments: - index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)} - sorted_segments = sorted( - segments, key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float("inf")) - ) - for segment in sorted_segments: + records = RetrievalService.format_retrieval_documents(dify_documents) + if records: + for record in records: + segment = record.segment if segment.answer: document_context_list.append( DocumentContext( content=f"question:{segment.get_sign_content()} answer:{segment.answer}", - score=document_score_list.get(segment.index_node_id, None), + score=record.score, ) ) else: document_context_list.append( DocumentContext( content=segment.get_sign_content(), - score=document_score_list.get(segment.index_node_id, None), + score=record.score, ) ) if show_retrieve_source: - for segment in sorted_segments: + for record in records: + segment = record.segment dataset = Dataset.query.filter_by(id=segment.dataset_id).first() document = DatasetDocument.query.filter( DatasetDocument.id == segment.document_id, @@ -216,7 +202,7 @@ class DatasetRetrieval: "data_source_type": document.data_source_type, "segment_id": segment.id, "retriever_from": invoke_from.to_source(), - "score": document_score_list.get(segment.index_node_id, 0.0), + "score": record.score or 0.0, } if invoke_from.to_source() == "dev": diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py index 38980f6d75..6941ff8fa2 100644 --- a/api/core/tools/entities/tool_entities.py +++ b/api/core/tools/entities/tool_entities.py @@ -267,6 +267,7 @@ class ToolParameter(PluginParameter): :param options: the options of the parameter """ # convert options to ToolParameterOption + # FIXME fix the type error if options: option_objs = [ PluginParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option)) diff --git a/api/core/tools/tool/tool.py b/api/core/tools/tool/tool.py new file mode 100644 
index 0000000000..4094207beb --- /dev/null +++ b/api/core/tools/tool/tool.py @@ -0,0 +1,355 @@ +from abc import ABC, abstractmethod +from collections.abc import Mapping +from copy import deepcopy +from enum import Enum, StrEnum +from typing import TYPE_CHECKING, Any, Optional, Union + +from pydantic import BaseModel, ConfigDict, field_validator +from pydantic_core.core_schema import ValidationInfo + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.tools.entities.tool_entities import ( + ToolDescription, + ToolIdentity, + ToolInvokeFrom, + ToolInvokeMessage, + ToolParameter, + ToolProviderType, + ToolRuntimeImageVariable, + ToolRuntimeVariable, + ToolRuntimeVariablePool, +) +from core.tools.tool_file_manager import ToolFileManager + +if TYPE_CHECKING: + from core.file.models import File + + +class Tool(BaseModel, ABC): + identity: Optional[ToolIdentity] = None + parameters: Optional[list[ToolParameter]] = None + description: Optional[ToolDescription] = None + is_team_authorization: bool = False + + # pydantic configs + model_config = ConfigDict(protected_namespaces=()) + + @field_validator("parameters", mode="before") + @classmethod + def set_parameters(cls, v, validation_info: ValidationInfo) -> list[ToolParameter]: + return v or [] + + class Runtime(BaseModel): + """ + Metadata of a tool call processing + """ + + def __init__(self, **data: Any): + super().__init__(**data) + if not self.runtime_parameters: + self.runtime_parameters = {} + + tenant_id: Optional[str] = None + tool_id: Optional[str] = None + invoke_from: Optional[InvokeFrom] = None + tool_invoke_from: Optional[ToolInvokeFrom] = None + credentials: Optional[dict[str, Any]] = None + runtime_parameters: Optional[dict[str, Any]] = None + + runtime: Optional[Runtime] = None + variables: Optional[ToolRuntimeVariablePool] = None + + def __init__(self, **data: Any): + super().__init__(**data) + + class VariableKey(StrEnum): + IMAGE = "image" + DOCUMENT = "document" + VIDEO = "video" + AUDIO = "audio" + CUSTOM = "custom" + + def fork_tool_runtime(self, runtime: dict[str, Any]) -> "Tool": + """ + fork a new tool with metadata + + :param runtime: the metadata of a tool call processing, tenant_id is required + :return: the new tool + """ + return self.__class__( + identity=self.identity.model_copy() if self.identity else None, + parameters=self.parameters.copy() if self.parameters else None, + description=self.description.model_copy() if self.description else None, + runtime=Tool.Runtime(**runtime), + ) + + @abstractmethod + def tool_provider_type(self) -> ToolProviderType: + """ + get the tool provider type + + :return: the tool provider type + """ + + def load_variables(self, variables: ToolRuntimeVariablePool | None) -> None: + """ + load variables from database + + :param variables: the tool runtime variable pool + """ + self.variables = variables + + def set_image_variable(self, variable_name: str, image_key: str) -> None: + """ + set an image variable + """ + if not self.variables: + return + if self.identity is None: + return + + self.variables.set_file(self.identity.name, variable_name, image_key) + + def set_text_variable(self, variable_name: str, text: str) -> None: + """ + set a text variable + """ + if not self.variables: + return + if self.identity is None: + return + + self.variables.set_text(self.identity.name, variable_name, text) + + def get_variable(self, name: Union[str, Enum]) -> Optional[ToolRuntimeVariable]: + """ + get a variable + + :param name: the name of the variable + :return: the variable + 
""" + if not self.variables: + return None + + if isinstance(name, Enum): + name = name.value + + for variable in self.variables.pool: + if variable.name == name: + return variable + + return None + + def get_default_image_variable(self) -> Optional[ToolRuntimeVariable]: + """ + get the default image variable + + :return: the image variable + """ + if not self.variables: + return None + + return self.get_variable(self.VariableKey.IMAGE) + + def get_variable_file(self, name: Union[str, Enum]) -> Optional[bytes]: + """ + get a variable file + + :param name: the name of the variable + :return: the variable file + """ + variable = self.get_variable(name) + if not variable: + return None + + if not isinstance(variable, ToolRuntimeImageVariable): + return None + + message_file_id = variable.value + # get file binary + file_binary = ToolFileManager.get_file_binary_by_message_file_id(message_file_id) + if not file_binary: + return None + + return file_binary[0] + + def list_variables(self) -> list[ToolRuntimeVariable]: + """ + list all variables + + :return: the variables + """ + if not self.variables: + return [] + + return self.variables.pool + + def list_default_image_variables(self) -> list[ToolRuntimeVariable]: + """ + list all image variables + + :return: the image variables + """ + if not self.variables: + return [] + + result = [] + + for variable in self.variables.pool: + if variable.name.startswith(self.VariableKey.IMAGE.value): + result.append(variable) + + return result + + def invoke(self, user_id: str, tool_parameters: Mapping[str, Any]) -> list[ToolInvokeMessage]: + # update tool_parameters + # TODO: Fix type error. + if self.runtime is None: + return [] + if self.runtime.runtime_parameters: + # Convert Mapping to dict before updating + tool_parameters = dict(tool_parameters) + tool_parameters.update(self.runtime.runtime_parameters) + + # try parse tool parameters into the correct type + tool_parameters = self._transform_tool_parameters_type(tool_parameters) + + result = self._invoke( + user_id=user_id, + tool_parameters=tool_parameters, + ) + + if not isinstance(result, list): + result = [result] + + if not all(isinstance(message, ToolInvokeMessage) for message in result): + raise ValueError( + f"Invalid return type from {self.__class__.__name__}._invoke method. " + "Expected ToolInvokeMessage or list of ToolInvokeMessage." 
+ ) + + return result + + def _transform_tool_parameters_type(self, tool_parameters: Mapping[str, Any]) -> dict[str, Any]: + """ + Transform tool parameters type + """ + # Temporary fix for the issue where tool parameters are converted to empty values while validating the credentials + result: dict[str, Any] = deepcopy(dict(tool_parameters)) + for parameter in self.parameters or []: + if parameter.name in tool_parameters: + result[parameter.name] = parameter.type.cast_value(tool_parameters[parameter.name]) + + return result + + @abstractmethod + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + pass + + def validate_credentials( + self, credentials: dict[str, Any], parameters: dict[str, Any], format_only: bool = False + ) -> str | None: + """ + validate the credentials + + :param credentials: the credentials + :param parameters: the parameters + :param format_only: only return the formatted + """ + pass + + def get_runtime_parameters(self) -> list[ToolParameter]: + """ + get the runtime parameters + + interface for developers to dynamically change the parameters of a tool, depending on the variables pool + + :return: the runtime parameters + """ + return self.parameters or [] + + def get_all_runtime_parameters(self) -> list[ToolParameter]: + """ + get all runtime parameters + + :return: all runtime parameters + """ + parameters = self.parameters or [] + parameters = parameters.copy() + user_parameters = self.get_runtime_parameters() + user_parameters = user_parameters.copy() + + # override parameters + for parameter in user_parameters: + # check if parameter in tool parameters + found = False + for tool_parameter in parameters: + if tool_parameter.name == parameter.name: + found = True + break + + if found: + # override parameter + tool_parameter.type = parameter.type + tool_parameter.form = parameter.form + tool_parameter.required = parameter.required + tool_parameter.default = parameter.default + tool_parameter.options = parameter.options + tool_parameter.llm_description = parameter.llm_description + else: + # add new parameter + parameters.append(parameter) + + return parameters + + def create_image_message(self, image: str, save_as: str = "") -> ToolInvokeMessage: + """ + create an image message + + :param image: the url of the image + :return: the image message + """ + return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.IMAGE, message=image, save_as=save_as) + + def create_file_message(self, file: "File") -> ToolInvokeMessage: + return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.FILE, message="", meta={"file": file}, save_as="") + + def create_link_message(self, link: str, save_as: str = "") -> ToolInvokeMessage: + """ + create a link message + + :param link: the url of the link + :return: the link message + """ + return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.LINK, message=link, save_as=save_as) + + def create_text_message(self, text: str, save_as: str = "") -> ToolInvokeMessage: + """ + create a text message + + :param text: the text + :return: the text message + """ + return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.TEXT, message=text, save_as=save_as) + + def create_blob_message(self, blob: bytes, meta: Optional[dict] = None, save_as: str = "") -> ToolInvokeMessage: + """ + create a blob message + + :param blob: the blob + :return: the blob message + """ + return ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.BLOB, + message=blob, + meta=meta or {}, + save_as=save_as, 
+ ) + + def create_json_message(self, object: dict) -> ToolInvokeMessage: + """ + create a json message + """ + return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.JSON, message=object) diff --git a/api/core/tools/tool_engine.py b/api/core/tools/tool_engine.py index cfa8e6b8b2..702c4384ae 100644 --- a/api/core/tools/tool_engine.py +++ b/api/core/tools/tool_engine.py @@ -139,7 +139,7 @@ class ToolEngine: error_response = f"tool invoke error: {e}" agent_tool_callback.on_tool_error(e) except ToolEngineInvokeError as e: - meta = e.args[0] + meta = e.meta error_response = f"tool invoke error: {meta.error}" agent_tool_callback.on_tool_error(e) return error_response, [], meta diff --git a/api/core/tools/utils/text_processing_utils.py b/api/core/tools/utils/text_processing_utils.py index 6db9dfd0d9..105823f896 100644 --- a/api/core/tools/utils/text_processing_utils.py +++ b/api/core/tools/utils/text_processing_utils.py @@ -12,5 +12,6 @@ def remove_leading_symbols(text: str) -> str: str: The text with leading punctuation or symbols removed. """ # Match Unicode ranges for punctuation and symbols - pattern = r"^[\u2000-\u206F\u2E00-\u2E7F\u3000-\u303F!\"#$%&'()*+,\-./:;<=>?@\[\]^_`{|}~]+" + # FIXME this pattern is confused quick fix for #11868 maybe refactor it later + pattern = r"^[\u2000-\u206F\u2E00-\u2E7F\u3000-\u303F!\"#$%&'()*+,./:;<=>?@^_`~]+" return re.sub(pattern, "", text) diff --git a/api/core/workflow/nodes/answer/base_stream_processor.py b/api/core/workflow/nodes/answer/base_stream_processor.py index 8ffb487ec1..f22ea078fb 100644 --- a/api/core/workflow/nodes/answer/base_stream_processor.py +++ b/api/core/workflow/nodes/answer/base_stream_processor.py @@ -48,9 +48,11 @@ class StreamProcessor(ABC): # we remove the node maybe shortcut the answer node, so comment this code for now # there is not effect on the answer node and the workflow, when we have a better solution # we can open this code. 
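# A quick check of the narrowed remove_leading_symbols pattern above (a sketch,
# not part of the change set): compared with the previous character class,
# '-', '[', ']', '{', '|' and '}' are no longer stripped, so leading markers
# such as list dashes survive while ordinary punctuation is still removed.
import re

_LEADING_SYMBOLS = r"^[\u2000-\u206F\u2E00-\u2E7F\u3000-\u303F!\"#$%&'()*+,./:;<=>?@^_`~]+"

assert re.sub(_LEADING_SYMBOLS, "", "...Hello") == "Hello"           # dots are still stripped
assert re.sub(_LEADING_SYMBOLS, "", "- item one") == "- item one"    # hyphens are now preserved
assert re.sub(_LEADING_SYMBOLS, "", "[NOTE] text") == "[NOTE] text"  # brackets are now preserved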
Issues: #11542 #9560 #10638 #10564 - - # reachable_node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id)) - continue + ids = self._fetch_node_ids_in_reachable_branch(edge.target_node_id) + if "answer" in ids: + continue + else: + reachable_node_ids.extend(ids) else: unreachable_first_node_ids.append(edge.target_node_id) diff --git a/api/core/workflow/nodes/http_request/exc.py b/api/core/workflow/nodes/http_request/exc.py index a815f277be..46613c9e86 100644 --- a/api/core/workflow/nodes/http_request/exc.py +++ b/api/core/workflow/nodes/http_request/exc.py @@ -20,3 +20,7 @@ class ResponseSizeError(HttpRequestNodeError): class RequestBodyError(HttpRequestNodeError): """Raised when the request body is invalid.""" + + +class InvalidURLError(HttpRequestNodeError): + """Raised when the URL is invalid.""" diff --git a/api/core/workflow/nodes/http_request/executor.py b/api/core/workflow/nodes/http_request/executor.py index cdfdc6e6d5..fadd142e35 100644 --- a/api/core/workflow/nodes/http_request/executor.py +++ b/api/core/workflow/nodes/http_request/executor.py @@ -23,6 +23,7 @@ from .exc import ( FileFetchError, HttpRequestNodeError, InvalidHttpMethodError, + InvalidURLError, RequestBodyError, ResponseSizeError, ) @@ -66,6 +67,12 @@ class Executor: node_data.authorization.config.api_key ).text + # check if node_data.url is a valid URL + if not node_data.url: + raise InvalidURLError("url is required") + if not node_data.url.startswith(("http://", "https://")): + raise InvalidURLError("url should start with http:// or https://") + self.url: str = node_data.url self.method = node_data.method self.auth = node_data.authorization diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index bfd93c074d..0f239af51a 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -11,6 +11,7 @@ from core.entities.model_entities import ModelStatus from core.model_manager import ModelInstance, ModelManager from core.model_runtime.entities.model_entities import ModelFeature, ModelType from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from core.rag.datasource.retrieval_service import RetrievalService from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.variables import StringSegment @@ -18,7 +19,7 @@ from core.workflow.entities.node_entities import NodeRunResult from core.workflow.nodes.base import BaseNode from core.workflow.nodes.enums import NodeType from extensions.ext_database import db -from models.dataset import Dataset, Document, DocumentSegment +from models.dataset import Dataset, Document from models.workflow import WorkflowNodeExecutionStatus from .entities import KnowledgeRetrievalNodeData @@ -211,29 +212,12 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]): "content": item.page_content, } retrieval_resource_list.append(source) - document_score_list: dict[str, float] = {} # deal with dify documents if dify_documents: - document_score_list = {} - for item in dify_documents: - if item.metadata.get("score"): - document_score_list[item.metadata["doc_id"]] = item.metadata["score"] - - index_node_ids = [document.metadata["doc_id"] for document in dify_documents] - segments = DocumentSegment.query.filter( - 
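# A standalone sketch of the URL guard added to the HTTP request executor
# above; the helper name is illustrative, while the executor itself raises
# InvalidURLError from api/core/workflow/nodes/http_request/exc.py. urlparse is
# used here only to show a slightly stricter variant of the same startswith check.
from urllib.parse import urlparse


def validate_http_url(url: str) -> str:
    if not url:
        raise ValueError("url is required")
    parsed = urlparse(url)
    if parsed.scheme not in ("http", "https") or not parsed.netloc:
        raise ValueError("url should start with http:// or https://")
    return url


validate_http_url("https://example.com/api")  # ok
# validate_http_url("ftp://example.com")      # would raise ValueError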
DocumentSegment.dataset_id.in_(dataset_ids), - DocumentSegment.completed_at.isnot(None), - DocumentSegment.status == "completed", - DocumentSegment.enabled == True, - DocumentSegment.index_node_id.in_(index_node_ids), - ).all() - if segments: - index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)} - sorted_segments = sorted( - segments, key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float("inf")) - ) - - for segment in sorted_segments: + records = RetrievalService.format_retrieval_documents(dify_documents) + if records: + for record in records: + segment = record.segment dataset = Dataset.query.filter_by(id=segment.dataset_id).first() document = Document.query.filter( Document.id == segment.document_id, @@ -251,7 +235,7 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]): "document_data_source_type": document.data_source_type, "segment_id": segment.id, "retriever_from": "workflow", - "score": document_score_list.get(segment.index_node_id, None), + "score": record.score or 0.0, "segment_hit_count": segment.hit_count, "segment_word_count": segment.word_count, "segment_position": segment.position, @@ -270,10 +254,8 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]): key=lambda x: x["metadata"]["score"] if x["metadata"].get("score") is not None else 0.0, reverse=True, ) - position = 1 - for item in retrieval_resource_list: + for position, item in enumerate(retrieval_resource_list, start=1): item["metadata"]["position"] = position - position += 1 return retrieval_resource_list @classmethod diff --git a/api/extensions/ext_blueprints.py b/api/extensions/ext_blueprints.py index fcd1547a2f..316be12f5c 100644 --- a/api/extensions/ext_blueprints.py +++ b/api/extensions/ext_blueprints.py @@ -5,7 +5,7 @@ from dify_app import DifyApp def init_app(app: DifyApp): # register blueprint routers - from flask_cors import CORS + from flask_cors import CORS # type: ignore from controllers.console import bp as console_app_bp from controllers.files import bp as files_bp diff --git a/api/factories/file_factory.py b/api/factories/file_factory.py index 99c7195b2c..f7b658a58f 100644 --- a/api/factories/file_factory.py +++ b/api/factories/file_factory.py @@ -1,4 +1,5 @@ import mimetypes +import uuid from collections.abc import Callable, Mapping, Sequence from typing import Any, cast @@ -119,6 +120,11 @@ def _build_from_local_file( upload_file_id = mapping.get("upload_file_id") if not upload_file_id: raise ValueError("Invalid upload file id") + # check if upload_file_id is a valid uuid + try: + uuid.UUID(upload_file_id) + except ValueError: + raise ValueError("Invalid upload file id format") stmt = select(UploadFile).where( UploadFile.id == upload_file_id, UploadFile.tenant_id == tenant_id, diff --git a/api/fields/dataset_fields.py b/api/fields/dataset_fields.py index a74e6f54fb..bedab5750f 100644 --- a/api/fields/dataset_fields.py +++ b/api/fields/dataset_fields.py @@ -73,6 +73,7 @@ dataset_detail_fields = { "embedding_available": fields.Boolean, "retrieval_model_dict": fields.Nested(dataset_retrieval_model_fields), "tags": fields.List(fields.Nested(tag_fields)), + "doc_form": fields.String, "external_knowledge_info": fields.Nested(external_knowledge_info_fields), "external_retrieval_model": fields.Nested(external_retrieval_model_fields, allow_null=True), } diff --git a/api/fields/document_fields.py b/api/fields/document_fields.py index 2b2ac6243f..f2250d964a 100644 --- a/api/fields/document_fields.py +++ 
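# The file factory above (and the audio service later in this change set) now
# reject identifiers that are not well-formed UUIDs before querying the
# database. A tiny standalone version of that guard; the helper name and the
# sample value are illustrative.
import uuid


def ensure_uuid(value: str) -> str:
    try:
        uuid.UUID(value)
    except ValueError:
        raise ValueError("Invalid upload file id format")
    return value


ensure_uuid("3f3d1c9e-9a4b-4c1e-8f2a-0d9b2a9e7c11")  # ok
# ensure_uuid("not-a-uuid")                          # would raise ValueError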
b/api/fields/document_fields.py @@ -34,6 +34,7 @@ document_with_segments_fields = { "data_source_info": fields.Raw(attribute="data_source_info_dict"), "data_source_detail_dict": fields.Raw(attribute="data_source_detail_dict"), "dataset_process_rule_id": fields.String, + "process_rule_dict": fields.Raw(attribute="process_rule_dict"), "name": fields.String, "created_from": fields.String, "created_by": fields.String, diff --git a/api/fields/hit_testing_fields.py b/api/fields/hit_testing_fields.py index aaafcab8ab..b9f7e78c17 100644 --- a/api/fields/hit_testing_fields.py +++ b/api/fields/hit_testing_fields.py @@ -34,8 +34,16 @@ segment_fields = { "document": fields.Nested(document_fields), } +child_chunk_fields = { + "id": fields.String, + "content": fields.String, + "position": fields.Integer, + "score": fields.Float, +} + hit_testing_record_fields = { "segment": fields.Nested(segment_fields), + "child_chunks": fields.List(fields.Nested(child_chunk_fields)), "score": fields.Float, "tsne_position": fields.Raw, } diff --git a/api/fields/segment_fields.py b/api/fields/segment_fields.py index 4413af3160..52f89859c9 100644 --- a/api/fields/segment_fields.py +++ b/api/fields/segment_fields.py @@ -2,6 +2,17 @@ from flask_restful import fields # type: ignore from libs.helper import TimestampField +child_chunk_fields = { + "id": fields.String, + "segment_id": fields.String, + "content": fields.String, + "position": fields.Integer, + "word_count": fields.Integer, + "type": fields.String, + "created_at": TimestampField, + "updated_at": TimestampField, +} + segment_fields = { "id": fields.String, "position": fields.Integer, @@ -20,10 +31,13 @@ segment_fields = { "status": fields.String, "created_by": fields.String, "created_at": TimestampField, + "updated_at": TimestampField, + "updated_by": fields.String, "indexing_at": TimestampField, "completed_at": TimestampField, "error": fields.String, "stopped_at": TimestampField, + "child_chunks": fields.List(fields.Nested(child_chunk_fields)), } segment_list_response = { diff --git a/api/migrations/versions/2024_11_22_0701-e19037032219_parent_child_index.py b/api/migrations/versions/2024_11_22_0701-e19037032219_parent_child_index.py new file mode 100644 index 0000000000..9238e5a0a8 --- /dev/null +++ b/api/migrations/versions/2024_11_22_0701-e19037032219_parent_child_index.py @@ -0,0 +1,55 @@ +"""parent-child-index + +Revision ID: e19037032219 +Revises: 01d6889832f7 +Create Date: 2024-11-22 07:01:17.550037 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = 'e19037032219' +down_revision = 'd7999dfa4aae' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + op.create_table('child_chunks', + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', models.types.StringUUID(), nullable=False), + sa.Column('dataset_id', models.types.StringUUID(), nullable=False), + sa.Column('document_id', models.types.StringUUID(), nullable=False), + sa.Column('segment_id', models.types.StringUUID(), nullable=False), + sa.Column('position', sa.Integer(), nullable=False), + sa.Column('content', sa.Text(), nullable=False), + sa.Column('word_count', sa.Integer(), nullable=False), + sa.Column('index_node_id', sa.String(length=255), nullable=True), + sa.Column('index_node_hash', sa.String(length=255), nullable=True), + sa.Column('type', sa.String(length=255), server_default=sa.text("'automatic'::character varying"), nullable=False), + sa.Column('created_by', models.types.StringUUID(), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('updated_by', models.types.StringUUID(), nullable=True), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('indexing_at', sa.DateTime(), nullable=True), + sa.Column('completed_at', sa.DateTime(), nullable=True), + sa.Column('error', sa.Text(), nullable=True), + sa.PrimaryKeyConstraint('id', name='child_chunk_pkey') + ) + with op.batch_alter_table('child_chunks', schema=None) as batch_op: + batch_op.create_index('child_chunk_dataset_id_idx', ['tenant_id', 'dataset_id', 'document_id', 'segment_id', 'index_node_id'], unique=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('child_chunks', schema=None) as batch_op: + batch_op.drop_index('child_chunk_dataset_id_idx') + + op.drop_table('child_chunks') + # ### end Alembic commands ### diff --git a/api/migrations/versions/2024_12_25_1137-923752d42eb6_add_auto_disabled_dataset_logs.py b/api/migrations/versions/2024_12_25_1137-923752d42eb6_add_auto_disabled_dataset_logs.py new file mode 100644 index 0000000000..6dadd4e4a8 --- /dev/null +++ b/api/migrations/versions/2024_12_25_1137-923752d42eb6_add_auto_disabled_dataset_logs.py @@ -0,0 +1,47 @@ +"""add_auto_disabled_dataset_logs + +Revision ID: 923752d42eb6 +Revises: e19037032219 +Create Date: 2024-12-25 11:37:55.467101 + +""" +from alembic import op +import models as models +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '923752d42eb6' +down_revision = 'e19037032219' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! 
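# The child_chunks table created above stores, under each parent segment, the
# smaller child chunks together with their position and word count. A
# simplified sketch of that data shape; the real splitting is done by the
# index processors and driven by the hierarchical process rule, so every name
# below is illustrative only.
from dataclasses import dataclass


@dataclass
class ChildChunkSketch:
    segment_position: int
    position: int
    content: str
    word_count: int


def split_hierarchically(text: str, parent_size: int = 200, child_size: int = 50) -> list[ChildChunkSketch]:
    chunks: list[ChildChunkSketch] = []
    parents = [text[i : i + parent_size] for i in range(0, len(text), parent_size)]
    for segment_position, parent in enumerate(parents, start=1):
        children = [parent[i : i + child_size] for i in range(0, len(parent), child_size)]
        for position, child in enumerate(children, start=1):
            chunks.append(
                ChildChunkSketch(
                    segment_position=segment_position,
                    position=position,
                    content=child,
                    word_count=len(child.split()),
                )
            )
    return chunks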
### + op.create_table('dataset_auto_disable_logs', + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', models.types.StringUUID(), nullable=False), + sa.Column('dataset_id', models.types.StringUUID(), nullable=False), + sa.Column('document_id', models.types.StringUUID(), nullable=False), + sa.Column('notified', sa.Boolean(), server_default=sa.text('false'), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='dataset_auto_disable_log_pkey') + ) + with op.batch_alter_table('dataset_auto_disable_logs', schema=None) as batch_op: + batch_op.create_index('dataset_auto_disable_log_created_atx', ['created_at'], unique=False) + batch_op.create_index('dataset_auto_disable_log_dataset_idx', ['dataset_id'], unique=False) + batch_op.create_index('dataset_auto_disable_log_tenant_idx', ['tenant_id'], unique=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('dataset_auto_disable_logs', schema=None) as batch_op: + batch_op.drop_index('dataset_auto_disable_log_tenant_idx') + batch_op.drop_index('dataset_auto_disable_log_dataset_idx') + batch_op.drop_index('dataset_auto_disable_log_created_atx') + + op.drop_table('dataset_auto_disable_logs') + # ### end Alembic commands ### diff --git a/api/models/account.py b/api/models/account.py index 4f8ca0530f..941dd54687 100644 --- a/api/models/account.py +++ b/api/models/account.py @@ -23,7 +23,7 @@ class Account(UserMixin, Base): __tablename__ = "accounts" __table_args__ = (db.PrimaryKeyConstraint("id", name="account_pkey"), db.Index("account_email_idx", "email")) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) name = db.Column(db.String(255), nullable=False) email = db.Column(db.String(255), nullable=False) password = db.Column(db.String(255), nullable=True) diff --git a/api/models/dataset.py b/api/models/dataset.py index b9b41dcf47..567f7db432 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -17,6 +17,7 @@ from sqlalchemy.dialects.postgresql import JSONB from configs import dify_config from core.rag.retrieval.retrieval_methods import RetrievalMethod from extensions.ext_storage import storage +from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule from .account import Account from .engine import db @@ -215,7 +216,7 @@ class DatasetProcessRule(db.Model): # type: ignore[name-defined] created_by = db.Column(StringUUID, nullable=False) created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - MODES = ["automatic", "custom"] + MODES = ["automatic", "custom", "hierarchical"] PRE_PROCESSING_RULES = ["remove_stopwords", "remove_extra_spaces", "remove_urls_emails"] AUTOMATIC_RULES: dict[str, Any] = { "pre_processing_rules": [ @@ -231,8 +232,6 @@ class DatasetProcessRule(db.Model): # type: ignore[name-defined] "dataset_id": self.dataset_id, "mode": self.mode, "rules": self.rules_dict, - "created_by": self.created_by, - "created_at": self.created_at, } @property @@ -396,6 +395,12 @@ class Document(db.Model): # type: ignore[name-defined] .scalar() ) + @property + def process_rule_dict(self): + if self.dataset_process_rule_id: + return self.dataset_process_rule.to_dict() + return None + def 
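# A sketch of how the rows written to the dataset_auto_disable_logs table
# created above can be grouped before the notification mail is sent: first per
# tenant, then per dataset. defaultdict(list) keeps both levels free of
# KeyError on the first append; the model and db imports come from this change
# set, the helper itself is illustrative.
from collections import defaultdict

from extensions.ext_database import db
from models.dataset import DatasetAutoDisableLog


def group_unnotified_logs() -> dict[str, dict[str, list[str]]]:
    logs = (
        db.session.query(DatasetAutoDisableLog)
        .filter(DatasetAutoDisableLog.notified == False)  # noqa: E712
        .all()
    )
    grouped: dict[str, dict[str, list[str]]] = defaultdict(lambda: defaultdict(list))
    for log in logs:
        # tenant_id -> dataset_id -> [document_id, ...]
        grouped[log.tenant_id][log.dataset_id].append(log.document_id)
    return grouped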
to_dict(self): return { "id": self.id, @@ -560,6 +565,24 @@ class DocumentSegment(db.Model): # type: ignore[name-defined] .first() ) + @property + def child_chunks(self): + process_rule = self.document.dataset_process_rule + if process_rule.mode == "hierarchical": + rules = Rule(**process_rule.rules_dict) + if rules.parent_mode and rules.parent_mode != ParentMode.FULL_DOC: + child_chunks = ( + db.session.query(ChildChunk) + .filter(ChildChunk.segment_id == self.id) + .order_by(ChildChunk.position.asc()) + .all() + ) + return child_chunks or [] + else: + return [] + else: + return [] + def get_sign_content(self): signed_urls = [] text = self.content @@ -605,6 +628,47 @@ class DocumentSegment(db.Model): # type: ignore[name-defined] return text +class ChildChunk(db.Model): # type: ignore[name-defined] + __tablename__ = "child_chunks" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="child_chunk_pkey"), + db.Index("child_chunk_dataset_id_idx", "tenant_id", "dataset_id", "document_id", "segment_id", "index_node_id"), + ) + + # initial fields + id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) + tenant_id = db.Column(StringUUID, nullable=False) + dataset_id = db.Column(StringUUID, nullable=False) + document_id = db.Column(StringUUID, nullable=False) + segment_id = db.Column(StringUUID, nullable=False) + position = db.Column(db.Integer, nullable=False) + content = db.Column(db.Text, nullable=False) + word_count = db.Column(db.Integer, nullable=False) + # indexing fields + index_node_id = db.Column(db.String(255), nullable=True) + index_node_hash = db.Column(db.String(255), nullable=True) + type = db.Column(db.String(255), nullable=False, server_default=db.text("'automatic'::character varying")) + created_by = db.Column(StringUUID, nullable=False) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_by = db.Column(StringUUID, nullable=True) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + indexing_at = db.Column(db.DateTime, nullable=True) + completed_at = db.Column(db.DateTime, nullable=True) + error = db.Column(db.Text, nullable=True) + + @property + def dataset(self): + return db.session.query(Dataset).filter(Dataset.id == self.dataset_id).first() + + @property + def document(self): + return db.session.query(Document).filter(Document.id == self.document_id).first() + + @property + def segment(self): + return db.session.query(DocumentSegment).filter(DocumentSegment.id == self.segment_id).first() + + class AppDatasetJoin(db.Model): # type: ignore[name-defined] __tablename__ = "app_dataset_joins" __table_args__ = ( @@ -844,3 +908,20 @@ class ExternalKnowledgeBindings(db.Model): # type: ignore[name-defined] created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = db.Column(StringUUID, nullable=True) updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + + +class DatasetAutoDisableLog(db.Model): # type: ignore[name-defined] + __tablename__ = "dataset_auto_disable_logs" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="dataset_auto_disable_log_pkey"), + db.Index("dataset_auto_disable_log_tenant_idx", "tenant_id"), + db.Index("dataset_auto_disable_log_dataset_idx", "dataset_id"), + db.Index("dataset_auto_disable_log_created_atx", "created_at"), + ) + + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + tenant_id = 
db.Column(StringUUID, nullable=False) + dataset_id = db.Column(StringUUID, nullable=False) + document_id = db.Column(StringUUID, nullable=False) + notified = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) diff --git a/api/models/model.py b/api/models/model.py index 39b091b5c9..462fbb672e 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -611,13 +611,13 @@ class Conversation(Base): db.Index("conversation_app_from_user_idx", "app_id", "from_source", "from_end_user_id"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = db.Column(StringUUID, nullable=False) app_model_config_id = db.Column(StringUUID, nullable=True) model_provider = db.Column(db.String(255), nullable=True) override_model_configs = db.Column(db.Text) model_id = db.Column(db.String(255), nullable=True) - mode = db.Column(db.String(255), nullable=False) + mode: Mapped[str] = mapped_column(db.String(255)) name = db.Column(db.String(255), nullable=False) summary = db.Column(db.Text) _inputs: Mapped[dict] = mapped_column("inputs", db.JSON) @@ -851,7 +851,7 @@ class Message(Base): Index("message_created_at_idx", "created_at"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = db.Column(StringUUID, nullable=False) model_provider = db.Column(db.String(255), nullable=True) model_id = db.Column(db.String(255), nullable=True) @@ -878,7 +878,7 @@ class Message(Base): from_source = db.Column(db.String(255), nullable=False) from_end_user_id: Mapped[Optional[str]] = db.Column(StringUUID) from_account_id: Mapped[Optional[str]] = db.Column(StringUUID) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp()) updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) agent_based = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) workflow_run_id = db.Column(StringUUID) @@ -1403,7 +1403,7 @@ class EndUser(Base, UserMixin): external_user_id = db.Column(db.String(255), nullable=True) name = db.Column(db.String(255)) is_anonymous = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) - session_id = db.Column(db.String(255), nullable=False) + session_id: Mapped[str] = mapped_column() created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) diff --git a/api/models/tools.py b/api/models/tools.py index 0fcd87d2b9..3bc12e7fd7 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -256,8 +256,8 @@ class ToolConversationVariables(Base): updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @property - def variables(self) -> dict: - return dict(json.loads(self.variables_str)) + def variables(self) -> Any: + return json.loads(self.variables_str) class ToolFile(Base): diff --git a/api/models/workflow.py b/api/models/workflow.py index 6e2bdf2392..8a54553e3b 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -402,40 +402,28 @@ class WorkflowRun(Base): 
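# The ORM changes in this patch replace untyped db.Column(...) attributes with
# SQLAlchemy 2.0 style typed declarations, where nullability is inferred from
# the annotation. A minimal standalone example of the same pattern; the table
# and column names here are made up for illustration.
from typing import Optional

from sqlalchemy import String, Text
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class ExampleRun(Base):
    __tablename__ = "example_runs"

    id: Mapped[int] = mapped_column(primary_key=True)
    status: Mapped[str] = mapped_column(String(255))      # NOT NULL, inferred from Mapped[str]
    error: Mapped[Optional[str]] = mapped_column(Text)    # nullable, inferred from Optional[str]
    total_tokens: Mapped[int] = mapped_column(server_default="0")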
db.Index("workflow_run_tenant_app_sequence_idx", "tenant_id", "app_id", "sequence_number"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=False) - app_id = db.Column(StringUUID, nullable=False) - sequence_number = db.Column(db.Integer, nullable=False) - workflow_id = db.Column(StringUUID, nullable=False) - type = db.Column(db.String(255), nullable=False) - triggered_from = db.Column(db.String(255), nullable=False) - version = db.Column(db.String(255), nullable=False) - graph = db.Column(db.Text) - inputs = db.Column(db.Text) - status = db.Column(db.String(255), nullable=False) # running, succeeded, failed, stopped, partial-succeeded + id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + tenant_id: Mapped[str] = mapped_column(StringUUID) + app_id: Mapped[str] = mapped_column(StringUUID) + sequence_number: Mapped[int] = mapped_column() + workflow_id: Mapped[str] = mapped_column(StringUUID) + type: Mapped[str] = mapped_column(db.String(255)) + triggered_from: Mapped[str] = mapped_column(db.String(255)) + version: Mapped[str] = mapped_column(db.String(255)) + graph: Mapped[Optional[str]] = mapped_column(db.Text) + inputs: Mapped[Optional[str]] = mapped_column(db.Text) + status: Mapped[str] = mapped_column(db.String(255)) # running, succeeded, failed, stopped, partial-succeeded outputs: Mapped[Optional[str]] = mapped_column(sa.Text, default="{}") - error = db.Column(db.Text) + error: Mapped[Optional[str]] = mapped_column(db.Text) elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text("0")) - total_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0")) + total_tokens: Mapped[int] = mapped_column(server_default=db.text("0")) total_steps = db.Column(db.Integer, server_default=db.text("0")) - created_by_role = db.Column(db.String(255), nullable=False) # account, end_user + created_by_role: Mapped[str] = mapped_column(db.String(255)) # account, end_user created_by = db.Column(StringUUID, nullable=False) created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) finished_at = db.Column(db.DateTime) exceptions_count = db.Column(db.Integer, server_default=db.text("0")) - @property - def created_by_account(self): - created_by_role = CreatedByRole(self.created_by_role) - return db.session.get(Account, self.created_by) if created_by_role == CreatedByRole.ACCOUNT else None - - @property - def created_by_end_user(self): - from models.model import EndUser - - created_by_role = CreatedByRole(self.created_by_role) - return db.session.get(EndUser, self.created_by) if created_by_role == CreatedByRole.END_USER else None - @property def graph_dict(self): return json.loads(self.graph) if self.graph else {} @@ -631,29 +619,29 @@ class WorkflowNodeExecution(Base): ), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=False) - app_id = db.Column(StringUUID, nullable=False) - workflow_id = db.Column(StringUUID, nullable=False) - triggered_from = db.Column(db.String(255), nullable=False) - workflow_run_id = db.Column(StringUUID) - index = db.Column(db.Integer, nullable=False) - predecessor_node_id = db.Column(db.String(255)) - node_execution_id = db.Column(db.String(255), nullable=True) - node_id = db.Column(db.String(255), nullable=False) - node_type = db.Column(db.String(255), nullable=False) - title = db.Column(db.String(255), nullable=False) - inputs = 
db.Column(db.Text) - process_data = db.Column(db.Text) - outputs = db.Column(db.Text) - status = db.Column(db.String(255), nullable=False) - error = db.Column(db.Text) - elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text("0")) - execution_metadata = db.Column(db.Text) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - created_by_role = db.Column(db.String(255), nullable=False) - created_by = db.Column(StringUUID, nullable=False) - finished_at = db.Column(db.DateTime) + id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + tenant_id: Mapped[str] = mapped_column(StringUUID) + app_id: Mapped[str] = mapped_column(StringUUID) + workflow_id: Mapped[str] = mapped_column(StringUUID) + triggered_from: Mapped[str] = mapped_column(db.String(255)) + workflow_run_id: Mapped[Optional[str]] = mapped_column(StringUUID) + index: Mapped[int] = mapped_column(db.Integer) + predecessor_node_id: Mapped[Optional[str]] = mapped_column(db.String(255)) + node_execution_id: Mapped[Optional[str]] = mapped_column(db.String(255)) + node_id: Mapped[str] = mapped_column(db.String(255)) + node_type: Mapped[str] = mapped_column(db.String(255)) + title: Mapped[str] = mapped_column(db.String(255)) + inputs: Mapped[Optional[str]] = mapped_column(db.Text) + process_data: Mapped[Optional[str]] = mapped_column(db.Text) + outputs: Mapped[Optional[str]] = mapped_column(db.Text) + status: Mapped[str] = mapped_column(db.String(255)) + error: Mapped[Optional[str]] = mapped_column(db.Text) + elapsed_time: Mapped[float] = mapped_column(db.Float, server_default=db.text("0")) + execution_metadata: Mapped[Optional[str]] = mapped_column(db.Text) + created_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp()) + created_by_role: Mapped[str] = mapped_column(db.String(255)) + created_by: Mapped[str] = mapped_column(StringUUID) + finished_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime) @property def created_by_account(self): @@ -760,11 +748,11 @@ class WorkflowAppLog(Base): db.Index("workflow_app_log_app_idx", "tenant_id", "app_id"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=False) - app_id = db.Column(StringUUID, nullable=False) + id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + tenant_id: Mapped[str] = mapped_column(StringUUID) + app_id: Mapped[str] = mapped_column(StringUUID) workflow_id = db.Column(StringUUID, nullable=False) - workflow_run_id = db.Column(StringUUID, nullable=False) + workflow_run_id: Mapped[str] = mapped_column(StringUUID) created_from = db.Column(db.String(255), nullable=False) created_by_role = db.Column(db.String(255), nullable=False) created_by = db.Column(StringUUID, nullable=False) diff --git a/api/schedule/clean_unused_datasets_task.py b/api/schedule/clean_unused_datasets_task.py index f66b3c4797..eb73cc285d 100644 --- a/api/schedule/clean_unused_datasets_task.py +++ b/api/schedule/clean_unused_datasets_task.py @@ -10,7 +10,7 @@ from configs import dify_config from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from extensions.ext_database import db from extensions.ext_redis import redis_client -from models.dataset import Dataset, DatasetQuery, Document +from models.dataset import Dataset, DatasetAutoDisableLog, DatasetQuery, Document from services.feature_service import FeatureService @@ -75,6 +75,23 @@ def 
clean_unused_datasets_task(): ) if not dataset_query or len(dataset_query) == 0: try: + # add auto disable log + documents = ( + db.session.query(Document) + .filter( + Document.dataset_id == dataset.id, + Document.enabled == True, + Document.archived == False, + ) + .all() + ) + for document in documents: + dataset_auto_disable_log = DatasetAutoDisableLog( + tenant_id=dataset.tenant_id, + dataset_id=dataset.id, + document_id=document.id, + ) + db.session.add(dataset_auto_disable_log) # remove index index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor() index_processor.clean(dataset, None) @@ -151,6 +168,23 @@ def clean_unused_datasets_task(): else: plan = plan_cache.decode() if plan == "sandbox": + # add auto disable log + documents = ( + db.session.query(Document) + .filter( + Document.dataset_id == dataset.id, + Document.enabled == True, + Document.archived == False, + ) + .all() + ) + for document in documents: + dataset_auto_disable_log = DatasetAutoDisableLog( + tenant_id=dataset.tenant_id, + dataset_id=dataset.id, + document_id=document.id, + ) + db.session.add(dataset_auto_disable_log) # remove index index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor() index_processor.clean(dataset, None) diff --git a/api/schedule/mail_clean_document_notify_task.py b/api/schedule/mail_clean_document_notify_task.py new file mode 100644 index 0000000000..766954a257 --- /dev/null +++ b/api/schedule/mail_clean_document_notify_task.py @@ -0,0 +1,63 @@ +import logging +import time +from collections import defaultdict + +import click +from celery import shared_task # type: ignore + +from extensions.ext_mail import mail +from models.account import Account, Tenant, TenantAccountJoin +from models.dataset import Dataset, DatasetAutoDisableLog + + +@shared_task(queue="mail") +def send_document_clean_notify_task(): + """ + Async Send document clean notify mail + + Usage: send_document_clean_notify_task.delay() + """ + if not mail.is_inited(): + return + + logging.info(click.style("Start send document clean notify mail", fg="green")) + start_at = time.perf_counter() + + # send document clean notify mail + try: + dataset_auto_disable_logs = DatasetAutoDisableLog.query.filter(DatasetAutoDisableLog.notified == False).all() + # group by tenant_id + dataset_auto_disable_logs_map: dict[str, list[DatasetAutoDisableLog]] = defaultdict(list) + for dataset_auto_disable_log in dataset_auto_disable_logs: + dataset_auto_disable_logs_map[dataset_auto_disable_log.tenant_id].append(dataset_auto_disable_log) + + for tenant_id, tenant_dataset_auto_disable_logs in dataset_auto_disable_logs_map.items(): + knowledge_details = [] + tenant = Tenant.query.filter(Tenant.id == tenant_id).first() + if not tenant: + continue + current_owner_join = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, role="owner").first() + if not current_owner_join: + continue + account = Account.query.filter(Account.id == current_owner_join.account_id).first() + if not account: + continue + + dataset_auto_dataset_map = {} # type: ignore + for dataset_auto_disable_log in tenant_dataset_auto_disable_logs: + dataset_auto_dataset_map[dataset_auto_disable_log.dataset_id].append( + dataset_auto_disable_log.document_id + ) + + for dataset_id, document_ids in dataset_auto_dataset_map.items(): + dataset = Dataset.query.filter(Dataset.id == dataset_id).first() + if dataset: + document_count = len(document_ids) + knowledge_details.append(f"
  • Knowledge base {dataset.name}: {document_count} documents
  • ") + + end_at = time.perf_counter() + logging.info( + click.style("Send document clean notify mail succeeded: latency: {}".format(end_at - start_at), fg="green") + ) + except Exception: + logging.exception("Send invite member mail to failed") diff --git a/api/services/account_service.py b/api/services/account_service.py index c61cc80fc0..797a1feb81 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -820,6 +820,7 @@ class RegisterService: language: Optional[str] = None, status: Optional[AccountStatus] = None, is_setup: Optional[bool] = False, + create_workspace_required: Optional[bool] = True, ) -> Account: db.session.begin_nested() """Register account""" @@ -837,7 +838,7 @@ class RegisterService: if open_id is not None and provider is not None: AccountService.link_account_integrate(provider, open_id, account) - if FeatureService.get_system_features().is_allow_create_workspace: + if FeatureService.get_system_features().is_allow_create_workspace and create_workspace_required: tenant = TenantService.create_tenant(f"{account.name}'s Workspace") TenantService.create_tenant_member(tenant, account, role="owner") account.current_tenant = tenant diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index 7793fdc4ff..d030a1dfa9 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -4,7 +4,7 @@ from enum import StrEnum from typing import Optional, cast from uuid import uuid4 -import yaml +import yaml # type: ignore from packaging import version from pydantic import BaseModel, Field from sqlalchemy import select @@ -524,7 +524,7 @@ class AppDslService: else: cls._append_model_config_export_data(export_data, app_model) - return yaml.dump(export_data, allow_unicode=True) + return yaml.dump(export_data, allow_unicode=True) # type: ignore @classmethod def _append_workflow_export_data(cls, *, export_data: dict, app_model: App, include_secret: bool) -> None: diff --git a/api/services/audio_service.py b/api/services/audio_service.py index 973110f515..ef52301c0a 100644 --- a/api/services/audio_service.py +++ b/api/services/audio_service.py @@ -1,5 +1,6 @@ import io import logging +import uuid from typing import Optional from werkzeug.datastructures import FileStorage @@ -122,6 +123,10 @@ class AudioService: raise e if message_id: + try: + uuid.UUID(message_id) + except ValueError: + return None message = db.session.query(Message).filter(Message.id == message_id).first() if message is None: return None diff --git a/api/services/billing_service.py b/api/services/billing_service.py index d980186488..ed611a8be4 100644 --- a/api/services/billing_service.py +++ b/api/services/billing_service.py @@ -2,7 +2,7 @@ import os from typing import Optional import httpx -from tenacity import retry, retry_if_not_exception_type, stop_before_delay, wait_fixed +from tenacity import retry, retry_if_exception_type, stop_before_delay, wait_fixed from extensions.ext_database import db from models.account import TenantAccountJoin, TenantAccountRole @@ -44,7 +44,7 @@ class BillingService: @retry( wait=wait_fixed(2), stop=stop_before_delay(10), - retry=retry_if_not_exception_type(httpx.RequestError), + retry=retry_if_exception_type(httpx.RequestError), reraise=True, ) def _send_request(cls, method, endpoint, json=None, params=None): diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index ca741f1935..b7ddd14025 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -14,6 +14,7 @@ 
from configs import dify_config from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType +from core.rag.index_processor.constant.index_type import IndexType from core.rag.retrieval.retrieval_methods import RetrievalMethod from events.dataset_event import dataset_was_deleted from events.document_event import document_was_deleted @@ -23,7 +24,9 @@ from libs import helper from models.account import Account, TenantAccountRole from models.dataset import ( AppDatasetJoin, + ChildChunk, Dataset, + DatasetAutoDisableLog, DatasetCollectionBinding, DatasetPermission, DatasetPermissionEnum, @@ -35,8 +38,15 @@ from models.dataset import ( ) from models.model import UploadFile from models.source import DataSourceOauthBinding -from services.entities.knowledge_entities.knowledge_entities import SegmentUpdateEntity -from services.errors.account import NoPermissionError +from services.entities.knowledge_entities.knowledge_entities import ( + ChildChunkUpdateArgs, + KnowledgeConfig, + RerankingModel, + RetrievalModel, + SegmentUpdateArgs, +) +from services.errors.account import InvalidActionError, NoPermissionError +from services.errors.chunk import ChildChunkDeleteIndexError, ChildChunkIndexingError from services.errors.dataset import DatasetNameDuplicateError from services.errors.document import DocumentIndexingError from services.errors.file import FileNotExistsError @@ -44,13 +54,16 @@ from services.external_knowledge_service import ExternalDatasetService from services.feature_service import FeatureModel, FeatureService from services.tag_service import TagService from services.vector_service import VectorService +from tasks.batch_clean_document_task import batch_clean_document_task from tasks.clean_notion_document_task import clean_notion_document_task from tasks.deal_dataset_vector_index_task import deal_dataset_vector_index_task from tasks.delete_segment_from_index_task import delete_segment_from_index_task from tasks.disable_segment_from_index_task import disable_segment_from_index_task +from tasks.disable_segments_from_index_task import disable_segments_from_index_task from tasks.document_indexing_task import document_indexing_task from tasks.document_indexing_update_task import document_indexing_update_task from tasks.duplicate_document_indexing_task import duplicate_document_indexing_task +from tasks.enable_segments_to_index_task import enable_segments_to_index_task from tasks.recover_document_indexing_task import recover_document_indexing_task from tasks.retry_document_indexing_task import retry_document_indexing_task from tasks.sync_website_document_indexing_task import sync_website_document_indexing_task @@ -408,6 +421,24 @@ class DatasetService: .all() ) + @staticmethod + def get_dataset_auto_disable_logs(dataset_id: str) -> dict: + # get recent 30 days auto disable logs + start_date = datetime.datetime.now() - datetime.timedelta(days=30) + dataset_auto_disable_logs = DatasetAutoDisableLog.query.filter( + DatasetAutoDisableLog.dataset_id == dataset_id, + DatasetAutoDisableLog.created_at >= start_date, + ).all() + if dataset_auto_disable_logs: + return { + "document_ids": [log.document_id for log in dataset_auto_disable_logs], + "count": len(dataset_auto_disable_logs), + } + return { + "document_ids": [], + "count": 0, + } + class DocumentService: DEFAULT_RULES = { @@ -518,12 +549,14 @@ class DocumentService: } @staticmethod - def get_document(dataset_id: str, 
document_id: str) -> Optional[Document]: - document = ( - db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first() - ) - - return document + def get_document(dataset_id: str, document_id: Optional[str] = None) -> Optional[Document]: + if document_id: + document = ( + db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first() + ) + return document + else: + return None @staticmethod def get_document_by_id(document_id: str) -> Optional[Document]: @@ -588,6 +621,20 @@ class DocumentService: db.session.delete(document) db.session.commit() + @staticmethod + def delete_documents(dataset: Dataset, document_ids: list[str]): + documents = db.session.query(Document).filter(Document.id.in_(document_ids)).all() + file_ids = [ + document.data_source_info_dict["upload_file_id"] + for document in documents + if document.data_source_type == "upload_file" + ] + batch_clean_document_task.delay(document_ids, dataset.id, dataset.doc_form, file_ids) + + for document in documents: + db.session.delete(document) + db.session.commit() + @staticmethod def rename_document(dataset_id: str, document_id: str, name: str) -> Document: dataset = DatasetService.get_dataset(dataset_id) @@ -689,7 +736,7 @@ class DocumentService: @staticmethod def save_document_with_dataset_id( dataset: Dataset, - document_data: dict, + knowledge_config: KnowledgeConfig, account: Account | Any, dataset_process_rule: Optional[DatasetProcessRule] = None, created_from: str = "web", @@ -698,37 +745,35 @@ class DocumentService: features = FeatureService.get_features(current_user.current_tenant_id) if features.billing.enabled: - if "original_document_id" not in document_data or not document_data["original_document_id"]: + if not knowledge_config.original_document_id: count = 0 - if document_data["data_source"]["type"] == "upload_file": - upload_file_list = document_data["data_source"]["info_list"]["file_info_list"]["file_ids"] - count = len(upload_file_list) - elif document_data["data_source"]["type"] == "notion_import": - notion_info_list = document_data["data_source"]["info_list"]["notion_info_list"] - for notion_info in notion_info_list: - count = count + len(notion_info["pages"]) - elif document_data["data_source"]["type"] == "website_crawl": - website_info = document_data["data_source"]["info_list"]["website_info_list"] - count = len(website_info["urls"]) - batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT) - if count > batch_upload_limit: - raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.") + if knowledge_config.data_source: + if knowledge_config.data_source.info_list.data_source_type == "upload_file": + upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids # type: ignore + count = len(upload_file_list) + elif knowledge_config.data_source.info_list.data_source_type == "notion_import": + notion_info_list = knowledge_config.data_source.info_list.notion_info_list + for notion_info in notion_info_list: # type: ignore + count = count + len(notion_info.pages) + elif knowledge_config.data_source.info_list.data_source_type == "website_crawl": + website_info = knowledge_config.data_source.info_list.website_info_list + count = len(website_info.urls) # type: ignore + batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT) + if count > batch_upload_limit: + raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.") - 
DocumentService.check_documents_upload_quota(count, features) + DocumentService.check_documents_upload_quota(count, features) # if dataset is empty, update dataset data_source_type if not dataset.data_source_type: - dataset.data_source_type = document_data["data_source"]["type"] + dataset.data_source_type = knowledge_config.data_source.info_list.data_source_type # type: ignore if not dataset.indexing_technique: - if ( - "indexing_technique" not in document_data - or document_data["indexing_technique"] not in Dataset.INDEXING_TECHNIQUE_LIST - ): - raise ValueError("Indexing technique is required") + if knowledge_config.indexing_technique not in Dataset.INDEXING_TECHNIQUE_LIST: + raise ValueError("Indexing technique is invalid") - dataset.indexing_technique = document_data["indexing_technique"] - if document_data["indexing_technique"] == "high_quality": + dataset.indexing_technique = knowledge_config.indexing_technique + if knowledge_config.indexing_technique == "high_quality": model_manager = ModelManager() embedding_model = model_manager.get_default_model_instance( tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING @@ -748,46 +793,47 @@ class DocumentService: "score_threshold_enabled": False, } - dataset.retrieval_model = document_data.get("retrieval_model") or default_retrieval_model + dataset.retrieval_model = knowledge_config.retrieval_model.model_dump() or default_retrieval_model # type: ignore documents = [] - if document_data.get("original_document_id"): - document = DocumentService.update_document_with_dataset_id(dataset, document_data, account) + if knowledge_config.original_document_id: + document = DocumentService.update_document_with_dataset_id(dataset, knowledge_config, account) documents.append(document) batch = document.batch else: batch = time.strftime("%Y%m%d%H%M%S") + str(random.randint(100000, 999999)) # save process rule if not dataset_process_rule: - process_rule = document_data["process_rule"] - if process_rule["mode"] == "custom": - dataset_process_rule = DatasetProcessRule( - dataset_id=dataset.id, - mode=process_rule["mode"], - rules=json.dumps(process_rule["rules"]), - created_by=account.id, - ) - elif process_rule["mode"] == "automatic": - dataset_process_rule = DatasetProcessRule( - dataset_id=dataset.id, - mode=process_rule["mode"], - rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES), - created_by=account.id, - ) - else: - logging.warn( - f"Invalid process rule mode: {process_rule['mode']}, can not find dataset process rule" - ) - return - db.session.add(dataset_process_rule) - db.session.commit() + process_rule = knowledge_config.process_rule + if process_rule: + if process_rule.mode in ("custom", "hierarchical"): + dataset_process_rule = DatasetProcessRule( + dataset_id=dataset.id, + mode=process_rule.mode, + rules=process_rule.rules.model_dump_json() if process_rule.rules else None, + created_by=account.id, + ) + elif process_rule.mode == "automatic": + dataset_process_rule = DatasetProcessRule( + dataset_id=dataset.id, + mode=process_rule.mode, + rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES), + created_by=account.id, + ) + else: + logging.warn( + f"Invalid process rule mode: {process_rule.mode}, can not find dataset process rule" + ) + return + db.session.add(dataset_process_rule) + db.session.commit() lock_name = "add_document_lock_dataset_id_{}".format(dataset.id) with redis_client.lock(lock_name, timeout=600): position = DocumentService.get_documents_position(dataset.id) document_ids = [] duplicate_document_ids = [] - 
if document_data["data_source"]["type"] == "upload_file": - upload_file_list = document_data["data_source"]["info_list"]["file_info_list"]["file_ids"] + if knowledge_config.data_source.info_list.data_source_type == "upload_file": + upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids # type: ignore for file_id in upload_file_list: file = ( db.session.query(UploadFile) @@ -804,7 +850,7 @@ class DocumentService: "upload_file_id": file_id, } # check duplicate - if document_data.get("duplicate", False): + if knowledge_config.duplicate: document = Document.query.filter_by( dataset_id=dataset.id, tenant_id=current_user.current_tenant_id, @@ -813,11 +859,11 @@ class DocumentService: name=file_name, ).first() if document: - document.dataset_process_rule_id = dataset_process_rule.id - document.updated_at = datetime.datetime.utcnow() + document.dataset_process_rule_id = dataset_process_rule.id # type: ignore + document.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) document.created_from = created_from - document.doc_form = document_data["doc_form"] - document.doc_language = document_data["doc_language"] + document.doc_form = knowledge_config.doc_form + document.doc_language = knowledge_config.doc_language document.data_source_info = json.dumps(data_source_info) document.batch = batch document.indexing_status = "waiting" @@ -827,10 +873,10 @@ class DocumentService: continue document = DocumentService.build_document( dataset, - dataset_process_rule.id, - document_data["data_source"]["type"], - document_data["doc_form"], - document_data["doc_language"], + dataset_process_rule.id, # type: ignore + knowledge_config.data_source.info_list.data_source_type, + knowledge_config.doc_form, + knowledge_config.doc_language, data_source_info, created_from, position, @@ -843,8 +889,10 @@ class DocumentService: document_ids.append(document.id) documents.append(document) position += 1 - elif document_data["data_source"]["type"] == "notion_import": - notion_info_list = document_data["data_source"]["info_list"]["notion_info_list"] + elif knowledge_config.data_source.info_list.data_source_type == "notion_import": + notion_info_list = knowledge_config.data_source.info_list.notion_info_list + if not notion_info_list: + raise ValueError("No notion info list found.") exist_page_ids = [] exist_document = {} documents = Document.query.filter_by( @@ -859,7 +907,7 @@ class DocumentService: exist_page_ids.append(data_source_info["notion_page_id"]) exist_document[data_source_info["notion_page_id"]] = document.id for notion_info in notion_info_list: - workspace_id = notion_info["workspace_id"] + workspace_id = notion_info.workspace_id data_source_binding = DataSourceOauthBinding.query.filter( db.and_( DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, @@ -870,25 +918,25 @@ class DocumentService: ).first() if not data_source_binding: raise ValueError("Data source binding not found.") - for page in notion_info["pages"]: - if page["page_id"] not in exist_page_ids: + for page in notion_info.pages: + if page.page_id not in exist_page_ids: data_source_info = { "notion_workspace_id": workspace_id, - "notion_page_id": page["page_id"], - "notion_page_icon": page["page_icon"], - "type": page["type"], + "notion_page_id": page.page_id, + "notion_page_icon": page.page_icon.model_dump() if page.page_icon else None, + "type": page.type, } document = DocumentService.build_document( dataset, - dataset_process_rule.id, - document_data["data_source"]["type"], - 
document_data["doc_form"], - document_data["doc_language"], + dataset_process_rule.id, # type: ignore + knowledge_config.data_source.info_list.data_source_type, + knowledge_config.doc_form, + knowledge_config.doc_language, data_source_info, created_from, position, account, - page["page_name"], + page.page_name, batch, ) db.session.add(document) @@ -897,19 +945,21 @@ class DocumentService: documents.append(document) position += 1 else: - exist_document.pop(page["page_id"]) + exist_document.pop(page.page_id) # delete not selected documents if len(exist_document) > 0: clean_notion_document_task.delay(list(exist_document.values()), dataset.id) - elif document_data["data_source"]["type"] == "website_crawl": - website_info = document_data["data_source"]["info_list"]["website_info_list"] - urls = website_info["urls"] + elif knowledge_config.data_source.info_list.data_source_type == "website_crawl": + website_info = knowledge_config.data_source.info_list.website_info_list + if not website_info: + raise ValueError("No website info list found.") + urls = website_info.urls for url in urls: data_source_info = { "url": url, - "provider": website_info["provider"], - "job_id": website_info["job_id"], - "only_main_content": website_info.get("only_main_content", False), + "provider": website_info.provider, + "job_id": website_info.job_id, + "only_main_content": website_info.only_main_content, "mode": "crawl", } if len(url) > 255: @@ -918,10 +968,10 @@ class DocumentService: document_name = url document = DocumentService.build_document( dataset, - dataset_process_rule.id, - document_data["data_source"]["type"], - document_data["doc_form"], - document_data["doc_language"], + dataset_process_rule.id, # type: ignore + knowledge_config.data_source.info_list.data_source_type, + knowledge_config.doc_form, + knowledge_config.doc_language, data_source_info, created_from, position, @@ -995,31 +1045,31 @@ class DocumentService: @staticmethod def update_document_with_dataset_id( dataset: Dataset, - document_data: dict, + document_data: KnowledgeConfig, account: Account, dataset_process_rule: Optional[DatasetProcessRule] = None, created_from: str = "web", ): DatasetService.check_dataset_model_setting(dataset) - document = DocumentService.get_document(dataset.id, document_data["original_document_id"]) + document = DocumentService.get_document(dataset.id, document_data.original_document_id) if document is None: raise NotFound("Document not found") if document.display_status != "available": raise ValueError("Document is not available") # save process rule - if document_data.get("process_rule"): - process_rule = document_data["process_rule"] - if process_rule["mode"] == "custom": + if document_data.process_rule: + process_rule = document_data.process_rule + if process_rule.mode in {"custom", "hierarchical"}: dataset_process_rule = DatasetProcessRule( dataset_id=dataset.id, - mode=process_rule["mode"], - rules=json.dumps(process_rule["rules"]), + mode=process_rule.mode, + rules=process_rule.rules.model_dump_json() if process_rule.rules else None, created_by=account.id, ) - elif process_rule["mode"] == "automatic": + elif process_rule.mode == "automatic": dataset_process_rule = DatasetProcessRule( dataset_id=dataset.id, - mode=process_rule["mode"], + mode=process_rule.mode, rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES), created_by=account.id, ) @@ -1028,11 +1078,13 @@ class DocumentService: db.session.commit() document.dataset_process_rule_id = dataset_process_rule.id # update document data source - if 
document_data.get("data_source"): + if document_data.data_source: file_name = "" data_source_info = {} - if document_data["data_source"]["type"] == "upload_file": - upload_file_list = document_data["data_source"]["info_list"]["file_info_list"]["file_ids"] + if document_data.data_source.info_list.data_source_type == "upload_file": + if not document_data.data_source.info_list.file_info_list: + raise ValueError("No file info list found.") + upload_file_list = document_data.data_source.info_list.file_info_list.file_ids for file_id in upload_file_list: file = ( db.session.query(UploadFile) @@ -1048,10 +1100,12 @@ class DocumentService: data_source_info = { "upload_file_id": file_id, } - elif document_data["data_source"]["type"] == "notion_import": - notion_info_list = document_data["data_source"]["info_list"]["notion_info_list"] + elif document_data.data_source.info_list.data_source_type == "notion_import": + if not document_data.data_source.info_list.notion_info_list: + raise ValueError("No notion info list found.") + notion_info_list = document_data.data_source.info_list.notion_info_list for notion_info in notion_info_list: - workspace_id = notion_info["workspace_id"] + workspace_id = notion_info.workspace_id data_source_binding = DataSourceOauthBinding.query.filter( db.and_( DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, @@ -1062,31 +1116,32 @@ class DocumentService: ).first() if not data_source_binding: raise ValueError("Data source binding not found.") - for page in notion_info["pages"]: + for page in notion_info.pages: data_source_info = { "notion_workspace_id": workspace_id, - "notion_page_id": page["page_id"], - "notion_page_icon": page["page_icon"], - "type": page["type"], + "notion_page_id": page.page_id, + "notion_page_icon": page.page_icon.model_dump() if page.page_icon else None, # type: ignore + "type": page.type, } - elif document_data["data_source"]["type"] == "website_crawl": - website_info = document_data["data_source"]["info_list"]["website_info_list"] - urls = website_info["urls"] - for url in urls: - data_source_info = { - "url": url, - "provider": website_info["provider"], - "job_id": website_info["job_id"], - "only_main_content": website_info.get("only_main_content", False), - "mode": "crawl", - } - document.data_source_type = document_data["data_source"]["type"] + elif document_data.data_source.info_list.data_source_type == "website_crawl": + website_info = document_data.data_source.info_list.website_info_list + if website_info: + urls = website_info.urls + for url in urls: + data_source_info = { + "url": url, + "provider": website_info.provider, + "job_id": website_info.job_id, + "only_main_content": website_info.only_main_content, # type: ignore + "mode": "crawl", + } + document.data_source_type = document_data.data_source.info_list.data_source_type document.data_source_info = json.dumps(data_source_info) document.name = file_name # update document name - if document_data.get("name"): - document.name = document_data["name"] + if document_data.name: + document.name = document_data.name # update document to be waiting document.indexing_status = "waiting" document.completed_at = None @@ -1096,7 +1151,7 @@ class DocumentService: document.splitting_completed_at = None document.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) document.created_from = created_from - document.doc_form = document_data["doc_form"] + document.doc_form = document_data.doc_form db.session.add(document) db.session.commit() # update document segment @@ 
-1108,21 +1163,27 @@ class DocumentService: return document @staticmethod - def save_document_without_dataset_id(tenant_id: str, document_data: dict, account: Account): + def save_document_without_dataset_id(tenant_id: str, knowledge_config: KnowledgeConfig, account: Account): features = FeatureService.get_features(current_user.current_tenant_id) if features.billing.enabled: count = 0 - if document_data["data_source"]["type"] == "upload_file": - upload_file_list = document_data["data_source"]["info_list"]["file_info_list"]["file_ids"] + if knowledge_config.data_source.info_list.data_source_type == "upload_file": + upload_file_list = ( + knowledge_config.data_source.info_list.file_info_list.file_ids + if knowledge_config.data_source.info_list.file_info_list + else [] + ) count = len(upload_file_list) - elif document_data["data_source"]["type"] == "notion_import": - notion_info_list = document_data["data_source"]["info_list"]["notion_info_list"] - for notion_info in notion_info_list: - count = count + len(notion_info["pages"]) - elif document_data["data_source"]["type"] == "website_crawl": - website_info = document_data["data_source"]["info_list"]["website_info_list"] - count = len(website_info["urls"]) + elif knowledge_config.data_source.info_list.data_source_type == "notion_import": + notion_info_list = knowledge_config.data_source.info_list.notion_info_list + if notion_info_list: + for notion_info in notion_info_list: + count = count + len(notion_info.pages) + elif knowledge_config.data_source.info_list.data_source_type == "website_crawl": + website_info = knowledge_config.data_source.info_list.website_info_list + if website_info: + count = len(website_info.urls) batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT) if count > batch_upload_limit: raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.") @@ -1131,39 +1192,39 @@ class DocumentService: dataset_collection_binding_id = None retrieval_model = None - if document_data["indexing_technique"] == "high_quality": + if knowledge_config.indexing_technique == "high_quality": dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( - document_data["embedding_model_provider"], document_data["embedding_model"] + knowledge_config.embedding_model_provider, # type: ignore + knowledge_config.embedding_model, # type: ignore ) dataset_collection_binding_id = dataset_collection_binding.id - if document_data.get("retrieval_model"): - retrieval_model = document_data["retrieval_model"] + if knowledge_config.retrieval_model: + retrieval_model = knowledge_config.retrieval_model else: - default_retrieval_model = { - "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, - "reranking_enable": False, - "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, - "top_k": 2, - "score_threshold_enabled": False, - } - retrieval_model = default_retrieval_model + retrieval_model = RetrievalModel( + search_method=RetrievalMethod.SEMANTIC_SEARCH.value, + reranking_enable=False, + reranking_model=RerankingModel(reranking_provider_name="", reranking_model_name=""), + top_k=2, + score_threshold_enabled=False, + ) # save dataset dataset = Dataset( tenant_id=tenant_id, name="", - data_source_type=document_data["data_source"]["type"], - indexing_technique=document_data.get("indexing_technique", "high_quality"), + data_source_type=knowledge_config.data_source.info_list.data_source_type, + indexing_technique=knowledge_config.indexing_technique, created_by=account.id, - 
embedding_model=document_data.get("embedding_model"), - embedding_model_provider=document_data.get("embedding_model_provider"), + embedding_model=knowledge_config.embedding_model, + embedding_model_provider=knowledge_config.embedding_model_provider, collection_binding_id=dataset_collection_binding_id, - retrieval_model=retrieval_model, + retrieval_model=retrieval_model.model_dump() if retrieval_model else None, ) - db.session.add(dataset) + db.session.add(dataset) # type: ignore db.session.flush() - documents, batch = DocumentService.save_document_with_dataset_id(dataset, document_data, account) + documents, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) cut_length = 18 cut_name = documents[0].name[:cut_length] @@ -1174,133 +1235,86 @@ class DocumentService: return dataset, documents, batch @classmethod - def document_create_args_validate(cls, args: dict): - if "original_document_id" not in args or not args["original_document_id"]: - DocumentService.data_source_args_validate(args) - DocumentService.process_rule_args_validate(args) + def document_create_args_validate(cls, knowledge_config: KnowledgeConfig): + if not knowledge_config.data_source and not knowledge_config.process_rule: + raise ValueError("Data source or Process rule is required") else: - if ("data_source" not in args or not args["data_source"]) and ( - "process_rule" not in args or not args["process_rule"] - ): - raise ValueError("Data source or Process rule is required") - else: - if args.get("data_source"): - DocumentService.data_source_args_validate(args) - if args.get("process_rule"): - DocumentService.process_rule_args_validate(args) + if knowledge_config.data_source: + DocumentService.data_source_args_validate(knowledge_config) + if knowledge_config.process_rule: + DocumentService.process_rule_args_validate(knowledge_config) @classmethod - def data_source_args_validate(cls, args: dict): - if "data_source" not in args or not args["data_source"]: + def data_source_args_validate(cls, knowledge_config: KnowledgeConfig): + if not knowledge_config.data_source: raise ValueError("Data source is required") - if not isinstance(args["data_source"], dict): - raise ValueError("Data source is invalid") - - if "type" not in args["data_source"] or not args["data_source"]["type"]: - raise ValueError("Data source type is required") - - if args["data_source"]["type"] not in Document.DATA_SOURCES: + if knowledge_config.data_source.info_list.data_source_type not in Document.DATA_SOURCES: raise ValueError("Data source type is invalid") - if "info_list" not in args["data_source"] or not args["data_source"]["info_list"]: + if not knowledge_config.data_source.info_list: raise ValueError("Data source info is required") - if args["data_source"]["type"] == "upload_file": - if ( - "file_info_list" not in args["data_source"]["info_list"] - or not args["data_source"]["info_list"]["file_info_list"] - ): + if knowledge_config.data_source.info_list.data_source_type == "upload_file": + if not knowledge_config.data_source.info_list.file_info_list: raise ValueError("File source info is required") - if args["data_source"]["type"] == "notion_import": - if ( - "notion_info_list" not in args["data_source"]["info_list"] - or not args["data_source"]["info_list"]["notion_info_list"] - ): + if knowledge_config.data_source.info_list.data_source_type == "notion_import": + if not knowledge_config.data_source.info_list.notion_info_list: raise ValueError("Notion source info is required") - if args["data_source"]["type"] == 
"website_crawl": - if ( - "website_info_list" not in args["data_source"]["info_list"] - or not args["data_source"]["info_list"]["website_info_list"] - ): + if knowledge_config.data_source.info_list.data_source_type == "website_crawl": + if not knowledge_config.data_source.info_list.website_info_list: raise ValueError("Website source info is required") @classmethod - def process_rule_args_validate(cls, args: dict): - if "process_rule" not in args or not args["process_rule"]: + def process_rule_args_validate(cls, knowledge_config: KnowledgeConfig): + if not knowledge_config.process_rule: raise ValueError("Process rule is required") - if not isinstance(args["process_rule"], dict): - raise ValueError("Process rule is invalid") - - if "mode" not in args["process_rule"] or not args["process_rule"]["mode"]: + if not knowledge_config.process_rule.mode: raise ValueError("Process rule mode is required") - if args["process_rule"]["mode"] not in DatasetProcessRule.MODES: + if knowledge_config.process_rule.mode not in DatasetProcessRule.MODES: raise ValueError("Process rule mode is invalid") - if args["process_rule"]["mode"] == "automatic": - args["process_rule"]["rules"] = {} + if knowledge_config.process_rule.mode == "automatic": + knowledge_config.process_rule.rules = None else: - if "rules" not in args["process_rule"] or not args["process_rule"]["rules"]: + if not knowledge_config.process_rule.rules: raise ValueError("Process rule rules is required") - if not isinstance(args["process_rule"]["rules"], dict): - raise ValueError("Process rule rules is invalid") - - if ( - "pre_processing_rules" not in args["process_rule"]["rules"] - or args["process_rule"]["rules"]["pre_processing_rules"] is None - ): + if knowledge_config.process_rule.rules.pre_processing_rules is None: raise ValueError("Process rule pre_processing_rules is required") - if not isinstance(args["process_rule"]["rules"]["pre_processing_rules"], list): - raise ValueError("Process rule pre_processing_rules is invalid") - unique_pre_processing_rule_dicts = {} - for pre_processing_rule in args["process_rule"]["rules"]["pre_processing_rules"]: - if "id" not in pre_processing_rule or not pre_processing_rule["id"]: + for pre_processing_rule in knowledge_config.process_rule.rules.pre_processing_rules: + if not pre_processing_rule.id: raise ValueError("Process rule pre_processing_rules id is required") - if pre_processing_rule["id"] not in DatasetProcessRule.PRE_PROCESSING_RULES: - raise ValueError("Process rule pre_processing_rules id is invalid") - - if "enabled" not in pre_processing_rule or pre_processing_rule["enabled"] is None: - raise ValueError("Process rule pre_processing_rules enabled is required") - - if not isinstance(pre_processing_rule["enabled"], bool): + if not isinstance(pre_processing_rule.enabled, bool): raise ValueError("Process rule pre_processing_rules enabled is invalid") - unique_pre_processing_rule_dicts[pre_processing_rule["id"]] = pre_processing_rule + unique_pre_processing_rule_dicts[pre_processing_rule.id] = pre_processing_rule - args["process_rule"]["rules"]["pre_processing_rules"] = list(unique_pre_processing_rule_dicts.values()) + knowledge_config.process_rule.rules.pre_processing_rules = list(unique_pre_processing_rule_dicts.values()) - if ( - "segmentation" not in args["process_rule"]["rules"] - or args["process_rule"]["rules"]["segmentation"] is None - ): + if not knowledge_config.process_rule.rules.segmentation: raise ValueError("Process rule segmentation is required") - if not 
isinstance(args["process_rule"]["rules"]["segmentation"], dict): - raise ValueError("Process rule segmentation is invalid") - - if ( - "separator" not in args["process_rule"]["rules"]["segmentation"] - or not args["process_rule"]["rules"]["segmentation"]["separator"] - ): + if not knowledge_config.process_rule.rules.segmentation.separator: raise ValueError("Process rule segmentation separator is required") - if not isinstance(args["process_rule"]["rules"]["segmentation"]["separator"], str): + if not isinstance(knowledge_config.process_rule.rules.segmentation.separator, str): raise ValueError("Process rule segmentation separator is invalid") - if ( - "max_tokens" not in args["process_rule"]["rules"]["segmentation"] - or not args["process_rule"]["rules"]["segmentation"]["max_tokens"] + if not ( + knowledge_config.process_rule.mode == "hierarchical" + and knowledge_config.process_rule.rules.parent_mode == "full-doc" ): - raise ValueError("Process rule segmentation max_tokens is required") + if not knowledge_config.process_rule.rules.segmentation.max_tokens: + raise ValueError("Process rule segmentation max_tokens is required") - if not isinstance(args["process_rule"]["rules"]["segmentation"]["max_tokens"], int): - raise ValueError("Process rule segmentation max_tokens is invalid") + if not isinstance(knowledge_config.process_rule.rules.segmentation.max_tokens, int): + raise ValueError("Process rule segmentation max_tokens is invalid") @classmethod def estimate_args_validate(cls, args: dict): @@ -1447,7 +1461,7 @@ class SegmentService: # save vector index try: - VectorService.create_segments_vector([args["keywords"]], [segment_document], dataset) + VectorService.create_segments_vector([args["keywords"]], [segment_document], dataset, document.doc_form) except Exception as e: logging.exception("create segment index failed") segment_document.enabled = False @@ -1528,7 +1542,7 @@ class SegmentService: db.session.add(document) try: # save vector index - VectorService.create_segments_vector(keywords_list, pre_segment_data_list, dataset) + VectorService.create_segments_vector(keywords_list, pre_segment_data_list, dataset, document.doc_form) except Exception as e: logging.exception("create segment index failed") for segment_document in segment_data_list: @@ -1540,14 +1554,13 @@ class SegmentService: return segment_data_list @classmethod - def update_segment(cls, args: dict, segment: DocumentSegment, document: Document, dataset: Dataset): - segment_update_entity = SegmentUpdateEntity(**args) + def update_segment(cls, args: SegmentUpdateArgs, segment: DocumentSegment, document: Document, dataset: Dataset): indexing_cache_key = "segment_{}_indexing".format(segment.id) cache_result = redis_client.get(indexing_cache_key) if cache_result is not None: raise ValueError("Segment is indexing, please try again later") - if segment_update_entity.enabled is not None: - action = segment_update_entity.enabled + if args.enabled is not None: + action = args.enabled if segment.enabled != action: if not action: segment.enabled = action @@ -1560,22 +1573,22 @@ class SegmentService: disable_segment_from_index_task.delay(segment.id) return segment if not segment.enabled: - if segment_update_entity.enabled is not None: - if not segment_update_entity.enabled: + if args.enabled is not None: + if not args.enabled: raise ValueError("Can't update disabled segment") else: raise ValueError("Can't update disabled segment") try: word_count_change = segment.word_count - content = segment_update_entity.content + content = args.content 
or segment.content if segment.content == content: segment.word_count = len(content) if document.doc_form == "qa_model": - segment.answer = segment_update_entity.answer - segment.word_count += len(segment_update_entity.answer or "") + segment.answer = args.answer + segment.word_count += len(args.answer) if args.answer else 0 word_count_change = segment.word_count - word_count_change - if segment_update_entity.keywords: - segment.keywords = segment_update_entity.keywords + if args.keywords: + segment.keywords = args.keywords segment.enabled = True segment.disabled_at = None segment.disabled_by = None @@ -1586,9 +1599,45 @@ class SegmentService: document.word_count = max(0, document.word_count + word_count_change) db.session.add(document) # update segment index task - if segment_update_entity.enabled: - keywords = segment_update_entity.keywords or [] - VectorService.create_segments_vector([keywords], [segment], dataset) + if args.enabled: + VectorService.create_segments_vector( + [args.keywords] if args.keywords else None, + [segment], + dataset, + document.doc_form, + ) + if document.doc_form == IndexType.PARENT_CHILD_INDEX and args.regenerate_child_chunks: + # regenerate child chunks + # get embedding model instance + if dataset.indexing_technique == "high_quality": + # check embedding model setting + model_manager = ModelManager() + + if dataset.embedding_model_provider: + embedding_model_instance = model_manager.get_model_instance( + tenant_id=dataset.tenant_id, + provider=dataset.embedding_model_provider, + model_type=ModelType.TEXT_EMBEDDING, + model=dataset.embedding_model, + ) + else: + embedding_model_instance = model_manager.get_default_model_instance( + tenant_id=dataset.tenant_id, + model_type=ModelType.TEXT_EMBEDDING, + ) + else: + raise ValueError("The knowledge base index technique is not high quality!") + # get the process rule + processing_rule = ( + db.session.query(DatasetProcessRule) + .filter(DatasetProcessRule.id == document.dataset_process_rule_id) + .first() + ) + if not processing_rule: + raise ValueError("No processing rule found.") + VectorService.generate_child_chunks( + segment, document, dataset, embedding_model_instance, processing_rule, True + ) else: segment_hash = helper.generate_text_hash(content) tokens = 0 @@ -1619,8 +1668,8 @@ class SegmentService: segment.disabled_at = None segment.disabled_by = None if document.doc_form == "qa_model": - segment.answer = segment_update_entity.answer - segment.word_count += len(segment_update_entity.answer or "") + segment.answer = args.answer + segment.word_count += len(args.answer) if args.answer else 0 word_count_change = segment.word_count - word_count_change # update document word count if word_count_change != 0: @@ -1628,8 +1677,40 @@ class SegmentService: db.session.add(document) db.session.add(segment) db.session.commit() - # update segment vector index - VectorService.update_segment_vector(segment_update_entity.keywords, segment, dataset) + if document.doc_form == IndexType.PARENT_CHILD_INDEX and args.regenerate_child_chunks: + # get embedding model instance + if dataset.indexing_technique == "high_quality": + # check embedding model setting + model_manager = ModelManager() + + if dataset.embedding_model_provider: + embedding_model_instance = model_manager.get_model_instance( + tenant_id=dataset.tenant_id, + provider=dataset.embedding_model_provider, + model_type=ModelType.TEXT_EMBEDDING, + model=dataset.embedding_model, + ) + else: + embedding_model_instance = model_manager.get_default_model_instance( + 
tenant_id=dataset.tenant_id, + model_type=ModelType.TEXT_EMBEDDING, + ) + else: + raise ValueError("The knowledge base index technique is not high quality!") + # get the process rule + processing_rule = ( + db.session.query(DatasetProcessRule) + .filter(DatasetProcessRule.id == document.dataset_process_rule_id) + .first() + ) + if not processing_rule: + raise ValueError("No processing rule found.") + VectorService.generate_child_chunks( + segment, document, dataset, embedding_model_instance, processing_rule, True + ) + elif document.doc_form in (IndexType.PARAGRAPH_INDEX, IndexType.QA_INDEX): + # update segment vector index + VectorService.update_segment_vector(args.keywords, segment, dataset) except Exception as e: logging.exception("update segment index failed") @@ -1652,13 +1733,265 @@ class SegmentService: if segment.enabled: # send delete segment index task redis_client.setex(indexing_cache_key, 600, 1) - delete_segment_from_index_task.delay(segment.id, segment.index_node_id, dataset.id, document.id) + delete_segment_from_index_task.delay([segment.index_node_id], dataset.id, document.id) db.session.delete(segment) # update document word count document.word_count -= segment.word_count db.session.add(document) db.session.commit() + @classmethod + def delete_segments(cls, segment_ids: list, document: Document, dataset: Dataset): + index_node_ids = ( + DocumentSegment.query.with_entities(DocumentSegment.index_node_id) + .filter( + DocumentSegment.id.in_(segment_ids), + DocumentSegment.dataset_id == dataset.id, + DocumentSegment.document_id == document.id, + DocumentSegment.tenant_id == current_user.current_tenant_id, + ) + .all() + ) + index_node_ids = [index_node_id[0] for index_node_id in index_node_ids] + + delete_segment_from_index_task.delay(index_node_ids, dataset.id, document.id) + db.session.query(DocumentSegment).filter(DocumentSegment.id.in_(segment_ids)).delete() + db.session.commit() + + @classmethod + def update_segments_status(cls, segment_ids: list, action: str, dataset: Dataset, document: Document): + if action == "enable": + segments = ( + db.session.query(DocumentSegment) + .filter( + DocumentSegment.id.in_(segment_ids), + DocumentSegment.dataset_id == dataset.id, + DocumentSegment.document_id == document.id, + DocumentSegment.enabled == False, + ) + .all() + ) + if not segments: + return + real_deal_segmment_ids = [] + for segment in segments: + indexing_cache_key = "segment_{}_indexing".format(segment.id) + cache_result = redis_client.get(indexing_cache_key) + if cache_result is not None: + continue + segment.enabled = True + segment.disabled_at = None + segment.disabled_by = None + db.session.add(segment) + real_deal_segmment_ids.append(segment.id) + db.session.commit() + + enable_segments_to_index_task.delay(real_deal_segmment_ids, dataset.id, document.id) + elif action == "disable": + segments = ( + db.session.query(DocumentSegment) + .filter( + DocumentSegment.id.in_(segment_ids), + DocumentSegment.dataset_id == dataset.id, + DocumentSegment.document_id == document.id, + DocumentSegment.enabled == True, + ) + .all() + ) + if not segments: + return + real_deal_segmment_ids = [] + for segment in segments: + indexing_cache_key = "segment_{}_indexing".format(segment.id) + cache_result = redis_client.get(indexing_cache_key) + if cache_result is not None: + continue + segment.enabled = False + segment.disabled_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + segment.disabled_by = current_user.id + db.session.add(segment) + 
real_deal_segmment_ids.append(segment.id) + db.session.commit() + + disable_segments_from_index_task.delay(real_deal_segmment_ids, dataset.id, document.id) + else: + raise InvalidActionError() + + @classmethod + def create_child_chunk( + cls, content: str, segment: DocumentSegment, document: Document, dataset: Dataset + ) -> ChildChunk: + lock_name = "add_child_lock_{}".format(segment.id) + with redis_client.lock(lock_name, timeout=20): + index_node_id = str(uuid.uuid4()) + index_node_hash = helper.generate_text_hash(content) + child_chunk_count = ( + db.session.query(ChildChunk) + .filter( + ChildChunk.tenant_id == current_user.current_tenant_id, + ChildChunk.dataset_id == dataset.id, + ChildChunk.document_id == document.id, + ChildChunk.segment_id == segment.id, + ) + .count() + ) + max_position = ( + db.session.query(func.max(ChildChunk.position)) + .filter( + ChildChunk.tenant_id == current_user.current_tenant_id, + ChildChunk.dataset_id == dataset.id, + ChildChunk.document_id == document.id, + ChildChunk.segment_id == segment.id, + ) + .scalar() + ) + child_chunk = ChildChunk( + tenant_id=current_user.current_tenant_id, + dataset_id=dataset.id, + document_id=document.id, + segment_id=segment.id, + position=max_position + 1, + index_node_id=index_node_id, + index_node_hash=index_node_hash, + content=content, + word_count=len(content), + type="customized", + created_by=current_user.id, + ) + db.session.add(child_chunk) + # save vector index + try: + VectorService.create_child_chunk_vector(child_chunk, dataset) + except Exception as e: + logging.exception("create child chunk index failed") + db.session.rollback() + raise ChildChunkIndexingError(str(e)) + db.session.commit() + + return child_chunk + + @classmethod + def update_child_chunks( + cls, + child_chunks_update_args: list[ChildChunkUpdateArgs], + segment: DocumentSegment, + document: Document, + dataset: Dataset, + ) -> list[ChildChunk]: + child_chunks = ( + db.session.query(ChildChunk) + .filter( + ChildChunk.dataset_id == dataset.id, + ChildChunk.document_id == document.id, + ChildChunk.segment_id == segment.id, + ) + .all() + ) + child_chunks_map = {chunk.id: chunk for chunk in child_chunks} + + new_child_chunks, update_child_chunks, delete_child_chunks, new_child_chunks_args = [], [], [], [] + + for child_chunk_update_args in child_chunks_update_args: + if child_chunk_update_args.id: + child_chunk = child_chunks_map.pop(child_chunk_update_args.id, None) + if child_chunk: + if child_chunk.content != child_chunk_update_args.content: + child_chunk.content = child_chunk_update_args.content + child_chunk.word_count = len(child_chunk.content) + child_chunk.updated_by = current_user.id + child_chunk.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + child_chunk.type = "customized" + update_child_chunks.append(child_chunk) + else: + new_child_chunks_args.append(child_chunk_update_args) + if child_chunks_map: + delete_child_chunks = list(child_chunks_map.values()) + try: + if update_child_chunks: + db.session.bulk_save_objects(update_child_chunks) + + if delete_child_chunks: + for child_chunk in delete_child_chunks: + db.session.delete(child_chunk) + if new_child_chunks_args: + child_chunk_count = len(child_chunks) + for position, args in enumerate(new_child_chunks_args, start=child_chunk_count + 1): + index_node_id = str(uuid.uuid4()) + index_node_hash = helper.generate_text_hash(args.content) + child_chunk = ChildChunk( + tenant_id=current_user.current_tenant_id, + dataset_id=dataset.id, + 
document_id=document.id, + segment_id=segment.id, + position=position, + index_node_id=index_node_id, + index_node_hash=index_node_hash, + content=args.content, + word_count=len(args.content), + type="customized", + created_by=current_user.id, + ) + + db.session.add(child_chunk) + db.session.flush() + new_child_chunks.append(child_chunk) + VectorService.update_child_chunk_vector(new_child_chunks, update_child_chunks, delete_child_chunks, dataset) + db.session.commit() + except Exception as e: + logging.exception("update child chunk index failed") + db.session.rollback() + raise ChildChunkIndexingError(str(e)) + return sorted(new_child_chunks + update_child_chunks, key=lambda x: x.position) + + @classmethod + def update_child_chunk( + cls, + content: str, + child_chunk: ChildChunk, + segment: DocumentSegment, + document: Document, + dataset: Dataset, + ) -> ChildChunk: + try: + child_chunk.content = content + child_chunk.word_count = len(content) + child_chunk.updated_by = current_user.id + child_chunk.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + child_chunk.type = "customized" + db.session.add(child_chunk) + VectorService.update_child_chunk_vector([], [child_chunk], [], dataset) + db.session.commit() + except Exception as e: + logging.exception("update child chunk index failed") + db.session.rollback() + raise ChildChunkIndexingError(str(e)) + return child_chunk + + @classmethod + def delete_child_chunk(cls, child_chunk: ChildChunk, dataset: Dataset): + db.session.delete(child_chunk) + try: + VectorService.delete_child_chunk_vector(child_chunk, dataset) + except Exception as e: + logging.exception("delete child chunk index failed") + db.session.rollback() + raise ChildChunkDeleteIndexError(str(e)) + db.session.commit() + + @classmethod + def get_child_chunks( + cls, segment_id: str, document_id: str, dataset_id: str, page: int, limit: int, keyword: Optional[str] = None + ): + query = ChildChunk.query.filter_by( + tenant_id=current_user.current_tenant_id, + dataset_id=dataset_id, + document_id=document_id, + segment_id=segment_id, + ).order_by(ChildChunk.position.asc()) + if keyword: + query = query.where(ChildChunk.content.ilike(f"%{keyword}%")) + return query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False) + class DatasetCollectionBindingService: @classmethod diff --git a/api/services/entities/knowledge_entities/knowledge_entities.py b/api/services/entities/knowledge_entities/knowledge_entities.py index 449b79f339..76d9c28812 100644 --- a/api/services/entities/knowledge_entities/knowledge_entities.py +++ b/api/services/entities/knowledge_entities/knowledge_entities.py @@ -1,4 +1,5 @@ -from typing import Optional +from enum import Enum +from typing import Literal, Optional from pydantic import BaseModel @@ -8,3 +9,112 @@ class SegmentUpdateEntity(BaseModel): answer: Optional[str] = None keywords: Optional[list[str]] = None enabled: Optional[bool] = None + + +class ParentMode(str, Enum): + FULL_DOC = "full-doc" + PARAGRAPH = "paragraph" + + +class NotionIcon(BaseModel): + type: str + url: Optional[str] = None + emoji: Optional[str] = None + + +class NotionPage(BaseModel): + page_id: str + page_name: str + page_icon: Optional[NotionIcon] = None + type: str + + +class NotionInfo(BaseModel): + workspace_id: str + pages: list[NotionPage] + + +class WebsiteInfo(BaseModel): + provider: str + job_id: str + urls: list[str] + only_main_content: bool = True + + +class FileInfo(BaseModel): + file_ids: list[str] + + +class InfoList(BaseModel): 
+    data_source_type: Literal["upload_file", "notion_import", "website_crawl"]
+    notion_info_list: Optional[list[NotionInfo]] = None
+    file_info_list: Optional[FileInfo] = None
+    website_info_list: Optional[WebsiteInfo] = None
+
+
+class DataSource(BaseModel):
+    info_list: InfoList
+
+
+class PreProcessingRule(BaseModel):
+    id: str
+    enabled: bool
+
+
+class Segmentation(BaseModel):
+    separator: str = "\n"
+    max_tokens: int
+    chunk_overlap: int = 0
+
+
+class Rule(BaseModel):
+    pre_processing_rules: Optional[list[PreProcessingRule]] = None
+    segmentation: Optional[Segmentation] = None
+    parent_mode: Optional[Literal["full-doc", "paragraph"]] = None
+    subchunk_segmentation: Optional[Segmentation] = None
+
+
+class ProcessRule(BaseModel):
+    mode: Literal["automatic", "custom", "hierarchical"]
+    rules: Optional[Rule] = None
+
+
+class RerankingModel(BaseModel):
+    reranking_provider_name: Optional[str] = None
+    reranking_model_name: Optional[str] = None
+
+
+class RetrievalModel(BaseModel):
+    search_method: Literal["hybrid_search", "semantic_search", "full_text_search"]
+    reranking_enable: bool
+    reranking_model: Optional[RerankingModel] = None
+    top_k: int
+    score_threshold_enabled: bool
+    score_threshold: Optional[float] = None
+
+
+class KnowledgeConfig(BaseModel):
+    original_document_id: Optional[str] = None
+    duplicate: bool = True
+    indexing_technique: Literal["high_quality", "economy"]
+    data_source: DataSource
+    process_rule: Optional[ProcessRule] = None
+    retrieval_model: Optional[RetrievalModel] = None
+    doc_form: str = "text_model"
+    doc_language: str = "English"
+    embedding_model: Optional[str] = None
+    embedding_model_provider: Optional[str] = None
+    name: Optional[str] = None
+
+
+class SegmentUpdateArgs(BaseModel):
+    content: Optional[str] = None
+    answer: Optional[str] = None
+    keywords: Optional[list[str]] = None
+    regenerate_child_chunks: bool = False
+    enabled: Optional[bool] = None
+
+
+class ChildChunkUpdateArgs(BaseModel):
+    id: Optional[str] = None
+    content: str
diff --git a/api/services/errors/chunk.py b/api/services/errors/chunk.py
new file mode 100644
index 0000000000..75bf4d5d5f
--- /dev/null
+++ b/api/services/errors/chunk.py
@@ -0,0 +1,9 @@
+from services.errors.base import BaseServiceError
+
+
+class ChildChunkIndexingError(BaseServiceError):
+    description = "{message}"
+
+
+class ChildChunkDeleteIndexError(BaseServiceError):
+    description = "{message}"
diff --git a/api/services/feature_service.py b/api/services/feature_service.py
index 36c79d7045..a42b3020cd 100644
--- a/api/services/feature_service.py
+++ b/api/services/feature_service.py
@@ -76,7 +76,7 @@ class FeatureService:
 
         cls._fulfill_params_from_env(features)
 
-        if dify_config.BILLING_ENABLED:
+        if dify_config.BILLING_ENABLED and tenant_id:
             cls._fulfill_params_from_billing_api(features, tenant_id)
 
         return features
diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py
index 41b4e1ec46..e9176fc1c6 100644
--- a/api/services/hit_testing_service.py
+++ b/api/services/hit_testing_service.py
@@ -7,7 +7,7 @@ from core.rag.models.document import Document
 from core.rag.retrieval.retrieval_methods import RetrievalMethod
 from extensions.ext_database import db
 from models.account import Account
-from models.dataset import Dataset, DatasetQuery, DocumentSegment
+from models.dataset import Dataset, DatasetQuery
 
 default_retrieval_model = {
     "search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
@@ -69,7 +69,7 @@ class HitTestingService:
         db.session.add(dataset_query)
         db.session.commit()
 
-        return dict(cls.compact_retrieve_response(dataset, query, all_documents))
+        return cls.compact_retrieve_response(query, all_documents)  # type: ignore
 
     @classmethod
     def external_retrieve(
@@ -106,41 +106,14 @@
         return dict(cls.compact_external_retrieve_response(dataset, query, all_documents))
 
     @classmethod
-    def compact_retrieve_response(cls, dataset: Dataset, query: str, documents: list[Document]):
-        records = []
-
-        for document in documents:
-            if document.metadata is None:
-                continue
-
-            index_node_id = document.metadata["doc_id"]
-
-            segment = (
-                db.session.query(DocumentSegment)
-                .filter(
-                    DocumentSegment.dataset_id == dataset.id,
-                    DocumentSegment.enabled == True,
-                    DocumentSegment.status == "completed",
-                    DocumentSegment.index_node_id == index_node_id,
-                )
-                .first()
-            )
-
-            if not segment:
-                continue
-
-            record = {
-                "segment": segment,
-                "score": document.metadata.get("score", None),
-            }
-
-            records.append(record)
+    def compact_retrieve_response(cls, query: str, documents: list[Document]):
+        records = RetrievalService.format_retrieval_documents(documents)
 
         return {
             "query": {
                 "content": query,
             },
-            "records": records,
+            "records": [record.model_dump() for record in records],
         }
 
     @classmethod
diff --git a/api/services/vector_service.py b/api/services/vector_service.py
index 3c67351335..92422bf29d 100644
--- a/api/services/vector_service.py
+++ b/api/services/vector_service.py
@@ -1,40 +1,70 @@
 from typing import Optional
 
+from core.model_manager import ModelInstance, ModelManager
+from core.model_runtime.entities.model_entities import ModelType
 from core.rag.datasource.keyword.keyword_factory import Keyword
 from core.rag.datasource.vdb.vector_factory import Vector
+from core.rag.index_processor.constant.index_type import IndexType
+from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
 from core.rag.models.document import Document
-from models.dataset import Dataset, DocumentSegment
+from extensions.ext_database import db
+from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment
+from models.dataset import Document as DatasetDocument
+from services.entities.knowledge_entities.knowledge_entities import ParentMode
 
 
 class VectorService:
     @classmethod
     def create_segments_vector(
-        cls, keywords_list: Optional[list[list[str]]], segments: list[DocumentSegment], dataset: Dataset
+        cls, keywords_list: Optional[list[list[str]]], segments: list[DocumentSegment], dataset: Dataset, doc_form: str
    ):
         documents = []
+
         for segment in segments:
-            document = Document(
-                page_content=segment.content,
-                metadata={
-                    "doc_id": segment.index_node_id,
-                    "doc_hash": segment.index_node_hash,
-                    "document_id": segment.document_id,
-                    "dataset_id": segment.dataset_id,
-                },
-            )
-            documents.append(document)
-        if dataset.indexing_technique == "high_quality":
-            # save vector index
-            vector = Vector(dataset=dataset)
-            vector.add_texts(documents, duplicate_check=True)
+            if doc_form == IndexType.PARENT_CHILD_INDEX:
+                document = DatasetDocument.query.filter_by(id=segment.document_id).first()
+                # get the process rule
+                processing_rule = (
+                    db.session.query(DatasetProcessRule)
+                    .filter(DatasetProcessRule.id == document.dataset_process_rule_id)
+                    .first()
+                )
+                if not processing_rule:
+                    raise ValueError("No processing rule found.")
+                # get embedding model instance
+                if dataset.indexing_technique == "high_quality":
+                    # check embedding model setting
+                    model_manager = ModelManager()
 
-        # save keyword index
-        keyword = Keyword(dataset)
-
-        if 
keywords_list and len(keywords_list) > 0: - keyword.add_texts(documents, keywords_list=keywords_list) - else: - keyword.add_texts(documents) + if dataset.embedding_model_provider: + embedding_model_instance = model_manager.get_model_instance( + tenant_id=dataset.tenant_id, + provider=dataset.embedding_model_provider, + model_type=ModelType.TEXT_EMBEDDING, + model=dataset.embedding_model, + ) + else: + embedding_model_instance = model_manager.get_default_model_instance( + tenant_id=dataset.tenant_id, + model_type=ModelType.TEXT_EMBEDDING, + ) + else: + raise ValueError("The knowledge base index technique is not high quality!") + cls.generate_child_chunks(segment, document, dataset, embedding_model_instance, processing_rule, False) + else: + document = Document( + page_content=segment.content, + metadata={ + "doc_id": segment.index_node_id, + "doc_hash": segment.index_node_hash, + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + }, + ) + documents.append(document) + if len(documents) > 0: + index_processor = IndexProcessorFactory(doc_form).init_index_processor() + index_processor.load(dataset, documents, with_keywords=True, keywords_list=keywords_list) @classmethod def update_segment_vector(cls, keywords: Optional[list[str]], segment: DocumentSegment, dataset: Dataset): @@ -65,3 +95,123 @@ class VectorService: keyword.add_texts([document], keywords_list=[keywords]) else: keyword.add_texts([document]) + + @classmethod + def generate_child_chunks( + cls, + segment: DocumentSegment, + dataset_document: DatasetDocument, + dataset: Dataset, + embedding_model_instance: ModelInstance, + processing_rule: DatasetProcessRule, + regenerate: bool = False, + ): + index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor() + if regenerate: + # delete child chunks + index_processor.clean(dataset, [segment.index_node_id], with_keywords=True, delete_child_chunks=True) + + # generate child chunks + document = Document( + page_content=segment.content, + metadata={ + "doc_id": segment.index_node_id, + "doc_hash": segment.index_node_hash, + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + }, + ) + # use full doc mode to generate segment's child chunk + processing_rule_dict = processing_rule.to_dict() + processing_rule_dict["rules"]["parent_mode"] = ParentMode.FULL_DOC.value + documents = index_processor.transform( + [document], + embedding_model_instance=embedding_model_instance, + process_rule=processing_rule_dict, + tenant_id=dataset.tenant_id, + doc_language=dataset_document.doc_language, + ) + # save child chunks + if documents and documents[0].children: + index_processor.load(dataset, documents) + + for position, child_chunk in enumerate(documents[0].children, start=1): + child_segment = ChildChunk( + tenant_id=dataset.tenant_id, + dataset_id=dataset.id, + document_id=dataset_document.id, + segment_id=segment.id, + position=position, + index_node_id=child_chunk.metadata["doc_id"], + index_node_hash=child_chunk.metadata["doc_hash"], + content=child_chunk.page_content, + word_count=len(child_chunk.page_content), + type="automatic", + created_by=dataset_document.created_by, + ) + db.session.add(child_segment) + db.session.commit() + + @classmethod + def create_child_chunk_vector(cls, child_segment: ChildChunk, dataset: Dataset): + child_document = Document( + page_content=child_segment.content, + metadata={ + "doc_id": child_segment.index_node_id, + "doc_hash": child_segment.index_node_hash, + "document_id": child_segment.document_id, + 
"dataset_id": child_segment.dataset_id, + }, + ) + if dataset.indexing_technique == "high_quality": + # save vector index + vector = Vector(dataset=dataset) + vector.add_texts([child_document], duplicate_check=True) + + @classmethod + def update_child_chunk_vector( + cls, + new_child_chunks: list[ChildChunk], + update_child_chunks: list[ChildChunk], + delete_child_chunks: list[ChildChunk], + dataset: Dataset, + ): + documents = [] + delete_node_ids = [] + for new_child_chunk in new_child_chunks: + new_child_document = Document( + page_content=new_child_chunk.content, + metadata={ + "doc_id": new_child_chunk.index_node_id, + "doc_hash": new_child_chunk.index_node_hash, + "document_id": new_child_chunk.document_id, + "dataset_id": new_child_chunk.dataset_id, + }, + ) + documents.append(new_child_document) + for update_child_chunk in update_child_chunks: + child_document = Document( + page_content=update_child_chunk.content, + metadata={ + "doc_id": update_child_chunk.index_node_id, + "doc_hash": update_child_chunk.index_node_hash, + "document_id": update_child_chunk.document_id, + "dataset_id": update_child_chunk.dataset_id, + }, + ) + documents.append(child_document) + delete_node_ids.append(update_child_chunk.index_node_id) + for delete_child_chunk in delete_child_chunks: + delete_node_ids.append(delete_child_chunk.index_node_id) + if dataset.indexing_technique == "high_quality": + # update vector index + vector = Vector(dataset=dataset) + if delete_node_ids: + vector.delete_by_ids(delete_node_ids) + if documents: + vector.add_texts(documents, duplicate_check=True) + + @classmethod + def delete_child_chunk_vector(cls, child_chunk: ChildChunk, dataset: Dataset): + vector = Vector(dataset=dataset) + vector.delete_by_ids([child_chunk.index_node_id]) diff --git a/api/tasks/add_document_to_index_task.py b/api/tasks/add_document_to_index_task.py index 50bb2b6e63..9a172b2d9d 100644 --- a/api/tasks/add_document_to_index_task.py +++ b/api/tasks/add_document_to_index_task.py @@ -6,12 +6,13 @@ import click from celery import shared_task # type: ignore from werkzeug.exceptions import NotFound +from core.rag.index_processor.constant.index_type import IndexType from core.rag.index_processor.index_processor_factory import IndexProcessorFactory -from core.rag.models.document import Document +from core.rag.models.document import ChildDocument, Document from extensions.ext_database import db from extensions.ext_redis import redis_client +from models.dataset import DatasetAutoDisableLog, DocumentSegment from models.dataset import Document as DatasetDocument -from models.dataset import DocumentSegment @shared_task(queue="dataset") @@ -53,7 +54,22 @@ def add_document_to_index_task(dataset_document_id: str): "dataset_id": segment.dataset_id, }, ) - + if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: + child_chunks = segment.child_chunks + if child_chunks: + child_documents = [] + for child_chunk in child_chunks: + child_document = ChildDocument( + page_content=child_chunk.content, + metadata={ + "doc_id": child_chunk.index_node_id, + "doc_hash": child_chunk.index_node_hash, + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + }, + ) + child_documents.append(child_document) + document.children = child_documents documents.append(document) dataset = dataset_document.dataset @@ -65,6 +81,12 @@ def add_document_to_index_task(dataset_document_id: str): index_processor = IndexProcessorFactory(index_type).init_index_processor() index_processor.load(dataset, documents) + # delete auto 
disable log + db.session.query(DatasetAutoDisableLog).filter( + DatasetAutoDisableLog.document_id == dataset_document.id + ).delete() + db.session.commit() + end_at = time.perf_counter() logging.info( click.style( diff --git a/api/tasks/batch_clean_document_task.py b/api/tasks/batch_clean_document_task.py new file mode 100644 index 0000000000..3bae82a5e3 --- /dev/null +++ b/api/tasks/batch_clean_document_task.py @@ -0,0 +1,76 @@ +import logging +import time + +import click +from celery import shared_task # type: ignore + +from core.rag.index_processor.index_processor_factory import IndexProcessorFactory +from core.tools.utils.web_reader_tool import get_image_upload_file_ids +from extensions.ext_database import db +from extensions.ext_storage import storage +from models.dataset import Dataset, DocumentSegment +from models.model import UploadFile + + +@shared_task(queue="dataset") +def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form: str, file_ids: list[str]): + """ + Clean document when document deleted. + :param document_ids: document ids + :param dataset_id: dataset id + :param doc_form: doc_form + :param file_ids: file ids + + Usage: clean_document_task.delay(document_id, dataset_id) + """ + logging.info(click.style("Start batch clean documents when documents deleted", fg="green")) + start_at = time.perf_counter() + + try: + dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + + if not dataset: + raise Exception("Document has no dataset") + + segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id.in_(document_ids)).all() + # check segment is exist + if segments: + index_node_ids = [segment.index_node_id for segment in segments] + index_processor = IndexProcessorFactory(doc_form).init_index_processor() + index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) + + for segment in segments: + image_upload_file_ids = get_image_upload_file_ids(segment.content) + for upload_file_id in image_upload_file_ids: + image_file = db.session.query(UploadFile).filter(UploadFile.id == upload_file_id).first() + try: + if image_file and image_file.key: + storage.delete(image_file.key) + except Exception: + logging.exception( + "Delete image_files failed when storage deleted, \ + image_upload_file_is: {}".format(upload_file_id) + ) + db.session.delete(image_file) + db.session.delete(segment) + + db.session.commit() + if file_ids: + files = db.session.query(UploadFile).filter(UploadFile.id.in_(file_ids)).all() + for file in files: + try: + storage.delete(file.key) + except Exception: + logging.exception("Delete file failed when document deleted, file_id: {}".format(file.id)) + db.session.delete(file) + db.session.commit() + + end_at = time.perf_counter() + logging.info( + click.style( + "Cleaned documents when documents deleted latency: {}".format(end_at - start_at), + fg="green", + ) + ) + except Exception: + logging.exception("Cleaned documents when documents deleted failed") diff --git a/api/tasks/batch_create_segment_to_index_task.py b/api/tasks/batch_create_segment_to_index_task.py index ce3d65526c..3238842307 100644 --- a/api/tasks/batch_create_segment_to_index_task.py +++ b/api/tasks/batch_create_segment_to_index_task.py @@ -7,13 +7,13 @@ import click from celery import shared_task # type: ignore from sqlalchemy import func -from core.indexing_runner import IndexingRunner from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType from 
extensions.ext_database import db from extensions.ext_redis import redis_client from libs import helper from models.dataset import Dataset, Document, DocumentSegment +from services.vector_service import VectorService @shared_task(queue="dataset") @@ -98,8 +98,7 @@ def batch_create_segment_to_index_task( dataset_document.word_count += word_count_change db.session.add(dataset_document) # add index to db - indexing_runner = IndexingRunner() - indexing_runner.batch_add_segments(document_segments, dataset) + VectorService.create_segments_vector(None, document_segments, dataset, dataset_document.doc_form) db.session.commit() redis_client.setex(indexing_cache_key, 600, "completed") end_at = time.perf_counter() diff --git a/api/tasks/clean_dataset_task.py b/api/tasks/clean_dataset_task.py index c48eb2e320..4d77f1fb65 100644 --- a/api/tasks/clean_dataset_task.py +++ b/api/tasks/clean_dataset_task.py @@ -62,7 +62,7 @@ def clean_dataset_task( if doc_form is None: raise ValueError("Index type must be specified.") index_processor = IndexProcessorFactory(doc_form).init_index_processor() - index_processor.clean(dataset, None) + index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=True) for document in documents: db.session.delete(document) diff --git a/api/tasks/clean_document_task.py b/api/tasks/clean_document_task.py index 05eb9fd625..5a4d7a52b2 100644 --- a/api/tasks/clean_document_task.py +++ b/api/tasks/clean_document_task.py @@ -38,7 +38,7 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i if segments: index_node_ids = [segment.index_node_id for segment in segments] index_processor = IndexProcessorFactory(doc_form).init_index_processor() - index_processor.clean(dataset, index_node_ids) + index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) for segment in segments: image_upload_file_ids = get_image_upload_file_ids(segment.content) diff --git a/api/tasks/clean_notion_document_task.py b/api/tasks/clean_notion_document_task.py index f5d6406d9c..5a6eb00a62 100644 --- a/api/tasks/clean_notion_document_task.py +++ b/api/tasks/clean_notion_document_task.py @@ -37,7 +37,7 @@ def clean_notion_document_task(document_ids: list[str], dataset_id: str): segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all() index_node_ids = [segment.index_node_id for segment in segments] - index_processor.clean(dataset, index_node_ids) + index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) for segment in segments: db.session.delete(segment) diff --git a/api/tasks/deal_dataset_vector_index_task.py b/api/tasks/deal_dataset_vector_index_task.py index b025509aeb..0efc924a77 100644 --- a/api/tasks/deal_dataset_vector_index_task.py +++ b/api/tasks/deal_dataset_vector_index_task.py @@ -4,8 +4,9 @@ import time import click from celery import shared_task # type: ignore +from core.rag.index_processor.constant.index_type import IndexType from core.rag.index_processor.index_processor_factory import IndexProcessorFactory -from core.rag.models.document import Document +from core.rag.models.document import ChildDocument, Document from extensions.ext_database import db from models.dataset import Dataset, DocumentSegment from models.dataset import Document as DatasetDocument @@ -105,7 +106,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str): db.session.commit() # clean index - index_processor.clean(dataset, None, with_keywords=False) + 
index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False) for dataset_document in dataset_documents: # update from vector index @@ -128,7 +129,22 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str): "dataset_id": segment.dataset_id, }, ) - + if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: + child_chunks = segment.child_chunks + if child_chunks: + child_documents = [] + for child_chunk in child_chunks: + child_document = ChildDocument( + page_content=child_chunk.content, + metadata={ + "doc_id": child_chunk.index_node_id, + "doc_hash": child_chunk.index_node_hash, + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + }, + ) + child_documents.append(child_document) + document.children = child_documents documents.append(document) # save vector index index_processor.load(dataset, documents, with_keywords=False) diff --git a/api/tasks/delete_segment_from_index_task.py b/api/tasks/delete_segment_from_index_task.py index 45a612c745..3b04143dd9 100644 --- a/api/tasks/delete_segment_from_index_task.py +++ b/api/tasks/delete_segment_from_index_task.py @@ -6,48 +6,38 @@ from celery import shared_task # type: ignore from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from extensions.ext_database import db -from extensions.ext_redis import redis_client from models.dataset import Dataset, Document @shared_task(queue="dataset") -def delete_segment_from_index_task(segment_id: str, index_node_id: str, dataset_id: str, document_id: str): +def delete_segment_from_index_task(index_node_ids: list, dataset_id: str, document_id: str): """ Async Remove segment from index - :param segment_id: - :param index_node_id: + :param index_node_ids: :param dataset_id: :param document_id: - Usage: delete_segment_from_index_task.delay(segment_id) + Usage: delete_segment_from_index_task.delay(segment_ids) """ - logging.info(click.style("Start delete segment from index: {}".format(segment_id), fg="green")) + logging.info(click.style("Start delete segment from index", fg="green")) start_at = time.perf_counter() - indexing_cache_key = "segment_{}_delete_indexing".format(segment_id) try: dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() if not dataset: - logging.info(click.style("Segment {} has no dataset, pass.".format(segment_id), fg="cyan")) return dataset_document = db.session.query(Document).filter(Document.id == document_id).first() if not dataset_document: - logging.info(click.style("Segment {} has no document, pass.".format(segment_id), fg="cyan")) return if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed": - logging.info(click.style("Segment {} document status is invalid, pass.".format(segment_id), fg="cyan")) return index_type = dataset_document.doc_form index_processor = IndexProcessorFactory(index_type).init_index_processor() - index_processor.clean(dataset, [index_node_id]) + index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) end_at = time.perf_counter() - logging.info( - click.style("Segment deleted from index: {} latency: {}".format(segment_id, end_at - start_at), fg="green") - ) + logging.info(click.style("Segment deleted from index latency: {}".format(end_at - start_at), fg="green")) except Exception: logging.exception("delete segment from index failed") - finally: - redis_client.delete(indexing_cache_key) diff --git a/api/tasks/disable_segments_from_index_task.py 
b/api/tasks/disable_segments_from_index_task.py new file mode 100644 index 0000000000..67112666e7 --- /dev/null +++ b/api/tasks/disable_segments_from_index_task.py @@ -0,0 +1,76 @@ +import logging +import time + +import click +from celery import shared_task # type: ignore + +from core.rag.index_processor.index_processor_factory import IndexProcessorFactory +from extensions.ext_database import db +from extensions.ext_redis import redis_client +from models.dataset import Dataset, DocumentSegment +from models.dataset import Document as DatasetDocument + + +@shared_task(queue="dataset") +def disable_segments_from_index_task(segment_ids: list, dataset_id: str, document_id: str): + """ + Async disable segments from index + :param segment_ids: + + Usage: disable_segments_from_index_task.delay(segment_ids, dataset_id, document_id) + """ + start_at = time.perf_counter() + + dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + if not dataset: + logging.info(click.style("Dataset {} not found, pass.".format(dataset_id), fg="cyan")) + return + + dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == document_id).first() + + if not dataset_document: + logging.info(click.style("Document {} not found, pass.".format(document_id), fg="cyan")) + return + if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed": + logging.info(click.style("Document {} status is invalid, pass.".format(document_id), fg="cyan")) + return + # sync index processor + index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor() + + segments = ( + db.session.query(DocumentSegment) + .filter( + DocumentSegment.id.in_(segment_ids), + DocumentSegment.dataset_id == dataset_id, + DocumentSegment.document_id == document_id, + ) + .all() + ) + + if not segments: + return + + try: + index_node_ids = [segment.index_node_id for segment in segments] + index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=False) + + end_at = time.perf_counter() + logging.info(click.style("Segments removed from index latency: {}".format(end_at - start_at), fg="green")) + except Exception: + # update segment error msg + db.session.query(DocumentSegment).filter( + DocumentSegment.id.in_(segment_ids), + DocumentSegment.dataset_id == dataset_id, + DocumentSegment.document_id == document_id, + ).update( + { + "disabled_at": None, + "disabled_by": None, + "enabled": True, + } + ) + db.session.commit() + finally: + for segment in segments: + indexing_cache_key = "segment_{}_indexing".format(segment.id) + redis_client.delete(indexing_cache_key) diff --git a/api/tasks/document_indexing_sync_task.py b/api/tasks/document_indexing_sync_task.py index ac4e81f95d..d686698b9a 100644 --- a/api/tasks/document_indexing_sync_task.py +++ b/api/tasks/document_indexing_sync_task.py @@ -82,7 +82,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): index_node_ids = [segment.index_node_id for segment in segments] # delete from vector index - index_processor.clean(dataset, index_node_ids) + index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) for segment in segments: db.session.delete(segment) diff --git a/api/tasks/document_indexing_update_task.py b/api/tasks/document_indexing_update_task.py index 5f1e9a892f..d8f14830c9 100644 --- a/api/tasks/document_indexing_update_task.py +++ b/api/tasks/document_indexing_update_task.py @@ -47,7 +47,7 @@ def 
document_indexing_update_task(dataset_id: str, document_id: str): index_node_ids = [segment.index_node_id for segment in segments] # delete from vector index - index_processor.clean(dataset, index_node_ids) + index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) for segment in segments: db.session.delete(segment) diff --git a/api/tasks/duplicate_document_indexing_task.py b/api/tasks/duplicate_document_indexing_task.py index 6db2620eb6..8e1d2b6b5d 100644 --- a/api/tasks/duplicate_document_indexing_task.py +++ b/api/tasks/duplicate_document_indexing_task.py @@ -51,7 +51,7 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list): if document: document.indexing_status = "error" document.error = str(e) - document.stopped_at = datetime.datetime.utcnow() + document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.add(document) db.session.commit() return @@ -73,14 +73,14 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list): index_node_ids = [segment.index_node_id for segment in segments] # delete from vector index - index_processor.clean(dataset, index_node_ids) + index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) for segment in segments: db.session.delete(segment) db.session.commit() document.indexing_status = "parsing" - document.processing_started_at = datetime.datetime.utcnow() + document.processing_started_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) documents.append(document) db.session.add(document) db.session.commit() diff --git a/api/tasks/enable_segment_to_index_task.py b/api/tasks/enable_segment_to_index_task.py index 2f6eb7b82a..76522f4720 100644 --- a/api/tasks/enable_segment_to_index_task.py +++ b/api/tasks/enable_segment_to_index_task.py @@ -6,8 +6,9 @@ import click from celery import shared_task # type: ignore from werkzeug.exceptions import NotFound +from core.rag.index_processor.constant.index_type import IndexType from core.rag.index_processor.index_processor_factory import IndexProcessorFactory -from core.rag.models.document import Document +from core.rag.models.document import ChildDocument, Document from extensions.ext_database import db from extensions.ext_redis import redis_client from models.dataset import DocumentSegment @@ -61,6 +62,22 @@ def enable_segment_to_index_task(segment_id: str): return index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor() + if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: + child_chunks = segment.child_chunks + if child_chunks: + child_documents = [] + for child_chunk in child_chunks: + child_document = ChildDocument( + page_content=child_chunk.content, + metadata={ + "doc_id": child_chunk.index_node_id, + "doc_hash": child_chunk.index_node_hash, + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + }, + ) + child_documents.append(child_document) + document.children = child_documents # save vector index index_processor.load(dataset, [document]) diff --git a/api/tasks/enable_segments_to_index_task.py b/api/tasks/enable_segments_to_index_task.py new file mode 100644 index 0000000000..0864e05e25 --- /dev/null +++ b/api/tasks/enable_segments_to_index_task.py @@ -0,0 +1,108 @@ +import datetime +import logging +import time + +import click +from celery import shared_task # type: ignore + +from core.rag.index_processor.constant.index_type import IndexType +from 
core.rag.index_processor.index_processor_factory import IndexProcessorFactory +from core.rag.models.document import ChildDocument, Document +from extensions.ext_database import db +from extensions.ext_redis import redis_client +from models.dataset import Dataset, DocumentSegment +from models.dataset import Document as DatasetDocument + + +@shared_task(queue="dataset") +def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_id: str): + """ + Async enable segments to index + :param segment_ids: + + Usage: enable_segments_to_index_task.delay(segment_ids) + """ + start_at = time.perf_counter() + dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + if not dataset: + logging.info(click.style("Dataset {} not found, pass.".format(dataset_id), fg="cyan")) + return + + dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == document_id).first() + + if not dataset_document: + logging.info(click.style("Document {} not found, pass.".format(document_id), fg="cyan")) + return + if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed": + logging.info(click.style("Document {} status is invalid, pass.".format(document_id), fg="cyan")) + return + # sync index processor + index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor() + + segments = ( + db.session.query(DocumentSegment) + .filter( + DocumentSegment.id.in_(segment_ids), + DocumentSegment.dataset_id == dataset_id, + DocumentSegment.document_id == document_id, + ) + .all() + ) + if not segments: + return + + try: + documents = [] + for segment in segments: + document = Document( + page_content=segment.content, + metadata={ + "doc_id": segment.index_node_id, + "doc_hash": segment.index_node_hash, + "document_id": document_id, + "dataset_id": dataset_id, + }, + ) + + if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: + child_chunks = segment.child_chunks + if child_chunks: + child_documents = [] + for child_chunk in child_chunks: + child_document = ChildDocument( + page_content=child_chunk.content, + metadata={ + "doc_id": child_chunk.index_node_id, + "doc_hash": child_chunk.index_node_hash, + "document_id": document_id, + "dataset_id": dataset_id, + }, + ) + child_documents.append(child_document) + document.children = child_documents + documents.append(document) + # save vector index + index_processor.load(dataset, documents) + + end_at = time.perf_counter() + logging.info(click.style("Segments enabled to index latency: {}".format(end_at - start_at), fg="green")) + except Exception as e: + logging.exception("enable segments to index failed") + # update segment error msg + db.session.query(DocumentSegment).filter( + DocumentSegment.id.in_(segment_ids), + DocumentSegment.dataset_id == dataset_id, + DocumentSegment.document_id == document_id, + ).update( + { + "error": str(e), + "status": "error", + "disabled_at": datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), + "enabled": False, + } + ) + db.session.commit() + finally: + for segment in segments: + indexing_cache_key = "segment_{}_indexing".format(segment.id) + redis_client.delete(indexing_cache_key) diff --git a/api/tasks/remove_document_from_index_task.py b/api/tasks/remove_document_from_index_task.py index 4ba6d1a83e..1d580b3802 100644 --- a/api/tasks/remove_document_from_index_task.py +++ b/api/tasks/remove_document_from_index_task.py @@ -43,7 +43,7 @@ def remove_document_from_index_task(document_id: str): 
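
The tasks above — deal_dataset_vector_index_task, enable_segment_to_index_task and the new enable_segments_to_index_task — all repeat the same assembly step for parent-child datasets: wrap each segment in a Document and, when doc_form is IndexType.PARENT_CHILD_INDEX, attach its child chunks as ChildDocument entries before handing everything to index_processor.load. The following is a minimal sketch of that shared pattern, not code from this diff; the helper name is invented, and it assumes the Dify api package is importable.

from core.rag.index_processor.constant.index_type import IndexType
from core.rag.models.document import ChildDocument, Document


def build_documents_for_segments(dataset_document, segments):
    """Illustrative helper: build RAG documents, attaching child chunks for parent-child datasets."""
    documents = []
    for segment in segments:
        document = Document(
            page_content=segment.content,
            metadata={
                "doc_id": segment.index_node_id,
                "doc_hash": segment.index_node_hash,
                "document_id": segment.document_id,
                "dataset_id": segment.dataset_id,
            },
        )
        if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
            # Each child chunk becomes its own vector entry, linked back to the parent segment.
            document.children = [
                ChildDocument(
                    page_content=child_chunk.content,
                    metadata={
                        "doc_id": child_chunk.index_node_id,
                        "doc_hash": child_chunk.index_node_hash,
                        "document_id": segment.document_id,
                        "dataset_id": segment.dataset_id,
                    },
                )
                for child_chunk in segment.child_chunks or []
            ]
        documents.append(document)
    return documents

Factoring the assembly out along these lines would also keep the three tasks from drifting apart as the metadata keys evolve.
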
index_node_ids = [segment.index_node_id for segment in segments] if index_node_ids: try: - index_processor.clean(dataset, index_node_ids) + index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=False) except Exception: logging.exception(f"clean dataset {dataset.id} from index failed") diff --git a/api/tasks/retry_document_indexing_task.py b/api/tasks/retry_document_indexing_task.py index 485caa5152..74fd542f6c 100644 --- a/api/tasks/retry_document_indexing_task.py +++ b/api/tasks/retry_document_indexing_task.py @@ -48,7 +48,7 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]): if document: document.indexing_status = "error" document.error = str(e) - document.stopped_at = datetime.datetime.utcnow() + document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.add(document) db.session.commit() redis_client.delete(retry_indexing_cache_key) @@ -69,14 +69,14 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]): if segments: index_node_ids = [segment.index_node_id for segment in segments] # delete from vector index - index_processor.clean(dataset, index_node_ids) + index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) - for segment in segments: - db.session.delete(segment) - db.session.commit() + for segment in segments: + db.session.delete(segment) + db.session.commit() document.indexing_status = "parsing" - document.processing_started_at = datetime.datetime.utcnow() + document.processing_started_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.add(document) db.session.commit() @@ -86,7 +86,7 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]): except Exception as ex: document.indexing_status = "error" document.error = str(ex) - document.stopped_at = datetime.datetime.utcnow() + document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.add(document) db.session.commit() logging.info(click.style(str(ex), fg="yellow")) diff --git a/api/tasks/sync_website_document_indexing_task.py b/api/tasks/sync_website_document_indexing_task.py index 5d6b069cf4..8da050d0d1 100644 --- a/api/tasks/sync_website_document_indexing_task.py +++ b/api/tasks/sync_website_document_indexing_task.py @@ -46,7 +46,7 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str): if document: document.indexing_status = "error" document.error = str(e) - document.stopped_at = datetime.datetime.utcnow() + document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.add(document) db.session.commit() redis_client.delete(sync_indexing_cache_key) @@ -65,14 +65,14 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str): if segments: index_node_ids = [segment.index_node_id for segment in segments] # delete from vector index - index_processor.clean(dataset, index_node_ids) + index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) - for segment in segments: - db.session.delete(segment) - db.session.commit() + for segment in segments: + db.session.delete(segment) + db.session.commit() document.indexing_status = "parsing" - document.processing_started_at = datetime.datetime.utcnow() + document.processing_started_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.add(document) db.session.commit() @@ -82,7 +82,7 @@ def 
sync_website_document_indexing_task(dataset_id: str, document_id: str): except Exception as ex: document.indexing_status = "error" document.error = str(ex) - document.stopped_at = datetime.datetime.utcnow() + document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.add(document) db.session.commit() logging.info(click.style(str(ex), fg="yellow")) diff --git a/api/templates/clean_document_job_mail_template-US.html b/api/templates/clean_document_job_mail_template-US.html new file mode 100644 index 0000000000..b7c9538f9f --- /dev/null +++ b/api/templates/clean_document_job_mail_template-US.html @@ -0,0 +1,98 @@ + + + + + + Documents Disabled Notification + + + + + + \ No newline at end of file diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 69a00528a2..703b3185c6 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -15,15 +15,15 @@ x-shared-env: &shared-api-worker-env LOG_FILE: ${LOG_FILE:-/app/logs/server.log} LOG_FILE_MAX_SIZE: ${LOG_FILE_MAX_SIZE:-20} LOG_FILE_BACKUP_COUNT: ${LOG_FILE_BACKUP_COUNT:-5} - LOG_DATEFORMAT: ${LOG_DATEFORMAT:-"%Y-%m-%d %H:%M:%S"} + LOG_DATEFORMAT: ${LOG_DATEFORMAT:-%Y-%m-%d %H:%M:%S} LOG_TZ: ${LOG_TZ:-UTC} DEBUG: ${DEBUG:-false} FLASK_DEBUG: ${FLASK_DEBUG:-false} SECRET_KEY: ${SECRET_KEY:-sk-9f73s3ljTXVcMT3Blb3ljTqtsKiGHXVcMT3BlbkFJLK7U} INIT_PASSWORD: ${INIT_PASSWORD:-} DEPLOY_ENV: ${DEPLOY_ENV:-PRODUCTION} - CHECK_UPDATE_URL: ${CHECK_UPDATE_URL:-"https://updates.dify.ai"} - OPENAI_API_BASE: ${OPENAI_API_BASE:-"https://api.openai.com/v1"} + CHECK_UPDATE_URL: ${CHECK_UPDATE_URL:-https://updates.dify.ai} + OPENAI_API_BASE: ${OPENAI_API_BASE:-https://api.openai.com/v1} MIGRATION_ENABLED: ${MIGRATION_ENABLED:-true} FILES_ACCESS_TIMEOUT: ${FILES_ACCESS_TIMEOUT:-300} ACCESS_TOKEN_EXPIRE_MINUTES: ${ACCESS_TOKEN_EXPIRE_MINUTES:-60} @@ -69,7 +69,7 @@ x-shared-env: &shared-api-worker-env REDIS_USE_CLUSTERS: ${REDIS_USE_CLUSTERS:-false} REDIS_CLUSTERS: ${REDIS_CLUSTERS:-} REDIS_CLUSTERS_PASSWORD: ${REDIS_CLUSTERS_PASSWORD:-} - CELERY_BROKER_URL: ${CELERY_BROKER_URL:-"redis://:difyai123456@redis:6379/1"} + CELERY_BROKER_URL: ${CELERY_BROKER_URL:-redis://:difyai123456@redis:6379/1} BROKER_USE_SSL: ${BROKER_USE_SSL:-false} CELERY_USE_SENTINEL: ${CELERY_USE_SENTINEL:-false} CELERY_SENTINEL_MASTER_NAME: ${CELERY_SENTINEL_MASTER_NAME:-} @@ -88,13 +88,13 @@ x-shared-env: &shared-api-worker-env AZURE_BLOB_ACCOUNT_NAME: ${AZURE_BLOB_ACCOUNT_NAME:-difyai} AZURE_BLOB_ACCOUNT_KEY: ${AZURE_BLOB_ACCOUNT_KEY:-difyai} AZURE_BLOB_CONTAINER_NAME: ${AZURE_BLOB_CONTAINER_NAME:-difyai-container} - AZURE_BLOB_ACCOUNT_URL: ${AZURE_BLOB_ACCOUNT_URL:-"https://.blob.core.windows.net"} + AZURE_BLOB_ACCOUNT_URL: ${AZURE_BLOB_ACCOUNT_URL:-https://.blob.core.windows.net} GOOGLE_STORAGE_BUCKET_NAME: ${GOOGLE_STORAGE_BUCKET_NAME:-your-bucket-name} GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64: ${GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64:-your-google-service-account-json-base64-string} ALIYUN_OSS_BUCKET_NAME: ${ALIYUN_OSS_BUCKET_NAME:-your-bucket-name} ALIYUN_OSS_ACCESS_KEY: ${ALIYUN_OSS_ACCESS_KEY:-your-access-key} ALIYUN_OSS_SECRET_KEY: ${ALIYUN_OSS_SECRET_KEY:-your-secret-key} - ALIYUN_OSS_ENDPOINT: ${ALIYUN_OSS_ENDPOINT:-"https://oss-ap-southeast-1-internal.aliyuncs.com"} + ALIYUN_OSS_ENDPOINT: ${ALIYUN_OSS_ENDPOINT:-https://oss-ap-southeast-1-internal.aliyuncs.com} ALIYUN_OSS_REGION: ${ALIYUN_OSS_REGION:-ap-southeast-1} ALIYUN_OSS_AUTH_VERSION: ${ALIYUN_OSS_AUTH_VERSION:-v4} ALIYUN_OSS_PATH: 
${ALIYUN_OSS_PATH:-your-path} @@ -103,7 +103,7 @@ x-shared-env: &shared-api-worker-env TENCENT_COS_SECRET_ID: ${TENCENT_COS_SECRET_ID:-your-secret-id} TENCENT_COS_REGION: ${TENCENT_COS_REGION:-your-region} TENCENT_COS_SCHEME: ${TENCENT_COS_SCHEME:-your-scheme} - OCI_ENDPOINT: ${OCI_ENDPOINT:-"https://objectstorage.us-ashburn-1.oraclecloud.com"} + OCI_ENDPOINT: ${OCI_ENDPOINT:-https://objectstorage.us-ashburn-1.oraclecloud.com} OCI_BUCKET_NAME: ${OCI_BUCKET_NAME:-your-bucket-name} OCI_ACCESS_KEY: ${OCI_ACCESS_KEY:-your-access-key} OCI_SECRET_KEY: ${OCI_SECRET_KEY:-your-secret-key} @@ -125,14 +125,14 @@ x-shared-env: &shared-api-worker-env SUPABASE_API_KEY: ${SUPABASE_API_KEY:-your-access-key} SUPABASE_URL: ${SUPABASE_URL:-your-server-url} VECTOR_STORE: ${VECTOR_STORE:-weaviate} - WEAVIATE_ENDPOINT: ${WEAVIATE_ENDPOINT:-"http://weaviate:8080"} + WEAVIATE_ENDPOINT: ${WEAVIATE_ENDPOINT:-http://weaviate:8080} WEAVIATE_API_KEY: ${WEAVIATE_API_KEY:-WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih} - QDRANT_URL: ${QDRANT_URL:-"http://qdrant:6333"} + QDRANT_URL: ${QDRANT_URL:-http://qdrant:6333} QDRANT_API_KEY: ${QDRANT_API_KEY:-difyai123456} QDRANT_CLIENT_TIMEOUT: ${QDRANT_CLIENT_TIMEOUT:-20} QDRANT_GRPC_ENABLED: ${QDRANT_GRPC_ENABLED:-false} QDRANT_GRPC_PORT: ${QDRANT_GRPC_PORT:-6334} - MILVUS_URI: ${MILVUS_URI:-"http://127.0.0.1:19530"} + MILVUS_URI: ${MILVUS_URI:-http://127.0.0.1:19530} MILVUS_TOKEN: ${MILVUS_TOKEN:-} MILVUS_USER: ${MILVUS_USER:-root} MILVUS_PASSWORD: ${MILVUS_PASSWORD:-Milvus} @@ -142,7 +142,7 @@ x-shared-env: &shared-api-worker-env MYSCALE_PASSWORD: ${MYSCALE_PASSWORD:-} MYSCALE_DATABASE: ${MYSCALE_DATABASE:-dify} MYSCALE_FTS_PARAMS: ${MYSCALE_FTS_PARAMS:-} - COUCHBASE_CONNECTION_STRING: ${COUCHBASE_CONNECTION_STRING:-"couchbase://couchbase-server"} + COUCHBASE_CONNECTION_STRING: ${COUCHBASE_CONNECTION_STRING:-couchbase://couchbase-server} COUCHBASE_USER: ${COUCHBASE_USER:-Administrator} COUCHBASE_PASSWORD: ${COUCHBASE_PASSWORD:-password} COUCHBASE_BUCKET_NAME: ${COUCHBASE_BUCKET_NAME:-Embeddings} @@ -176,15 +176,15 @@ x-shared-env: &shared-api-worker-env TIDB_VECTOR_USER: ${TIDB_VECTOR_USER:-} TIDB_VECTOR_PASSWORD: ${TIDB_VECTOR_PASSWORD:-} TIDB_VECTOR_DATABASE: ${TIDB_VECTOR_DATABASE:-dify} - TIDB_ON_QDRANT_URL: ${TIDB_ON_QDRANT_URL:-"http://127.0.0.1"} + TIDB_ON_QDRANT_URL: ${TIDB_ON_QDRANT_URL:-http://127.0.0.1} TIDB_ON_QDRANT_API_KEY: ${TIDB_ON_QDRANT_API_KEY:-dify} TIDB_ON_QDRANT_CLIENT_TIMEOUT: ${TIDB_ON_QDRANT_CLIENT_TIMEOUT:-20} TIDB_ON_QDRANT_GRPC_ENABLED: ${TIDB_ON_QDRANT_GRPC_ENABLED:-false} TIDB_ON_QDRANT_GRPC_PORT: ${TIDB_ON_QDRANT_GRPC_PORT:-6334} TIDB_PUBLIC_KEY: ${TIDB_PUBLIC_KEY:-dify} TIDB_PRIVATE_KEY: ${TIDB_PRIVATE_KEY:-dify} - TIDB_API_URL: ${TIDB_API_URL:-"http://127.0.0.1"} - TIDB_IAM_API_URL: ${TIDB_IAM_API_URL:-"http://127.0.0.1"} + TIDB_API_URL: ${TIDB_API_URL:-http://127.0.0.1} + TIDB_IAM_API_URL: ${TIDB_IAM_API_URL:-http://127.0.0.1} TIDB_REGION: ${TIDB_REGION:-regions/aws-us-east-1} TIDB_PROJECT_ID: ${TIDB_PROJECT_ID:-dify} TIDB_SPEND_LIMIT: ${TIDB_SPEND_LIMIT:-100} @@ -209,7 +209,7 @@ x-shared-env: &shared-api-worker-env OPENSEARCH_USER: ${OPENSEARCH_USER:-admin} OPENSEARCH_PASSWORD: ${OPENSEARCH_PASSWORD:-admin} OPENSEARCH_SECURE: ${OPENSEARCH_SECURE:-true} - TENCENT_VECTOR_DB_URL: ${TENCENT_VECTOR_DB_URL:-"http://127.0.0.1"} + TENCENT_VECTOR_DB_URL: ${TENCENT_VECTOR_DB_URL:-http://127.0.0.1} TENCENT_VECTOR_DB_API_KEY: ${TENCENT_VECTOR_DB_API_KEY:-dify} TENCENT_VECTOR_DB_TIMEOUT: ${TENCENT_VECTOR_DB_TIMEOUT:-30} TENCENT_VECTOR_DB_USERNAME: 
${TENCENT_VECTOR_DB_USERNAME:-dify} @@ -221,7 +221,7 @@ x-shared-env: &shared-api-worker-env ELASTICSEARCH_USERNAME: ${ELASTICSEARCH_USERNAME:-elastic} ELASTICSEARCH_PASSWORD: ${ELASTICSEARCH_PASSWORD:-elastic} KIBANA_PORT: ${KIBANA_PORT:-5601} - BAIDU_VECTOR_DB_ENDPOINT: ${BAIDU_VECTOR_DB_ENDPOINT:-"http://127.0.0.1:5287"} + BAIDU_VECTOR_DB_ENDPOINT: ${BAIDU_VECTOR_DB_ENDPOINT:-http://127.0.0.1:5287} BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS: ${BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS:-30000} BAIDU_VECTOR_DB_ACCOUNT: ${BAIDU_VECTOR_DB_ACCOUNT:-root} BAIDU_VECTOR_DB_API_KEY: ${BAIDU_VECTOR_DB_API_KEY:-dify} @@ -235,7 +235,7 @@ x-shared-env: &shared-api-worker-env VIKINGDB_SCHEMA: ${VIKINGDB_SCHEMA:-http} VIKINGDB_CONNECTION_TIMEOUT: ${VIKINGDB_CONNECTION_TIMEOUT:-30} VIKINGDB_SOCKET_TIMEOUT: ${VIKINGDB_SOCKET_TIMEOUT:-30} - LINDORM_URL: ${LINDORM_URL:-"http://lindorm:30070"} + LINDORM_URL: ${LINDORM_URL:-http://lindorm:30070} LINDORM_USERNAME: ${LINDORM_USERNAME:-lindorm} LINDORM_PASSWORD: ${LINDORM_PASSWORD:-lindorm} OCEANBASE_VECTOR_HOST: ${OCEANBASE_VECTOR_HOST:-oceanbase} @@ -245,7 +245,7 @@ x-shared-env: &shared-api-worker-env OCEANBASE_VECTOR_DATABASE: ${OCEANBASE_VECTOR_DATABASE:-test} OCEANBASE_CLUSTER_NAME: ${OCEANBASE_CLUSTER_NAME:-difyai} OCEANBASE_MEMORY_LIMIT: ${OCEANBASE_MEMORY_LIMIT:-6G} - UPSTASH_VECTOR_URL: ${UPSTASH_VECTOR_URL:-"https://xxx-vector.upstash.io"} + UPSTASH_VECTOR_URL: ${UPSTASH_VECTOR_URL:-https://xxx-vector.upstash.io} UPSTASH_VECTOR_TOKEN: ${UPSTASH_VECTOR_TOKEN:-dify} UPLOAD_FILE_SIZE_LIMIT: ${UPLOAD_FILE_SIZE_LIMIT:-15} UPLOAD_FILE_BATCH_LIMIT: ${UPLOAD_FILE_BATCH_LIMIT:-5} @@ -270,7 +270,7 @@ x-shared-env: &shared-api-worker-env NOTION_INTERNAL_SECRET: ${NOTION_INTERNAL_SECRET:-} MAIL_TYPE: ${MAIL_TYPE:-resend} MAIL_DEFAULT_SEND_FROM: ${MAIL_DEFAULT_SEND_FROM:-} - RESEND_API_URL: ${RESEND_API_URL:-"https://api.resend.com"} + RESEND_API_URL: ${RESEND_API_URL:-https://api.resend.com} RESEND_API_KEY: ${RESEND_API_KEY:-your-resend-api-key} SMTP_SERVER: ${SMTP_SERVER:-} SMTP_PORT: ${SMTP_PORT:-465} @@ -281,7 +281,7 @@ x-shared-env: &shared-api-worker-env INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH: ${INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH:-4000} INVITE_EXPIRY_HOURS: ${INVITE_EXPIRY_HOURS:-72} RESET_PASSWORD_TOKEN_EXPIRY_MINUTES: ${RESET_PASSWORD_TOKEN_EXPIRY_MINUTES:-5} - CODE_EXECUTION_ENDPOINT: ${CODE_EXECUTION_ENDPOINT:-"http://sandbox:8194"} + CODE_EXECUTION_ENDPOINT: ${CODE_EXECUTION_ENDPOINT:-http://sandbox:8194} CODE_EXECUTION_API_KEY: ${CODE_EXECUTION_API_KEY:-dify-sandbox} CODE_MAX_NUMBER: ${CODE_MAX_NUMBER:-9223372036854775807} CODE_MIN_NUMBER: ${CODE_MIN_NUMBER:--9223372036854775808} @@ -303,8 +303,8 @@ x-shared-env: &shared-api-worker-env WORKFLOW_FILE_UPLOAD_LIMIT: ${WORKFLOW_FILE_UPLOAD_LIMIT:-10} HTTP_REQUEST_NODE_MAX_BINARY_SIZE: ${HTTP_REQUEST_NODE_MAX_BINARY_SIZE:-10485760} HTTP_REQUEST_NODE_MAX_TEXT_SIZE: ${HTTP_REQUEST_NODE_MAX_TEXT_SIZE:-1048576} - SSRF_PROXY_HTTP_URL: ${SSRF_PROXY_HTTP_URL:-"http://ssrf_proxy:3128"} - SSRF_PROXY_HTTPS_URL: ${SSRF_PROXY_HTTPS_URL:-"http://ssrf_proxy:3128"} + SSRF_PROXY_HTTP_URL: ${SSRF_PROXY_HTTP_URL:-http://ssrf_proxy:3128} + SSRF_PROXY_HTTPS_URL: ${SSRF_PROXY_HTTPS_URL:-http://ssrf_proxy:3128} TEXT_GENERATION_TIMEOUT_MS: ${TEXT_GENERATION_TIMEOUT_MS:-60000} PGUSER: ${PGUSER:-${DB_USERNAME}} POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-${DB_PASSWORD}} @@ -314,8 +314,8 @@ x-shared-env: &shared-api-worker-env SANDBOX_GIN_MODE: ${SANDBOX_GIN_MODE:-release} SANDBOX_WORKER_TIMEOUT: ${SANDBOX_WORKER_TIMEOUT:-15} 
SANDBOX_ENABLE_NETWORK: ${SANDBOX_ENABLE_NETWORK:-true} - SANDBOX_HTTP_PROXY: ${SANDBOX_HTTP_PROXY:-"http://ssrf_proxy:3128"} - SANDBOX_HTTPS_PROXY: ${SANDBOX_HTTPS_PROXY:-"http://ssrf_proxy:3128"} + SANDBOX_HTTP_PROXY: ${SANDBOX_HTTP_PROXY:-http://ssrf_proxy:3128} + SANDBOX_HTTPS_PROXY: ${SANDBOX_HTTPS_PROXY:-http://ssrf_proxy:3128} SANDBOX_PORT: ${SANDBOX_PORT:-8194} WEAVIATE_PERSISTENCE_DATA_PATH: ${WEAVIATE_PERSISTENCE_DATA_PATH:-/var/lib/weaviate} WEAVIATE_QUERY_DEFAULTS_LIMIT: ${WEAVIATE_QUERY_DEFAULTS_LIMIT:-25} @@ -338,8 +338,8 @@ x-shared-env: &shared-api-worker-env ETCD_SNAPSHOT_COUNT: ${ETCD_SNAPSHOT_COUNT:-50000} MINIO_ACCESS_KEY: ${MINIO_ACCESS_KEY:-minioadmin} MINIO_SECRET_KEY: ${MINIO_SECRET_KEY:-minioadmin} - ETCD_ENDPOINTS: ${ETCD_ENDPOINTS:-"etcd:2379"} - MINIO_ADDRESS: ${MINIO_ADDRESS:-"minio:9000"} + ETCD_ENDPOINTS: ${ETCD_ENDPOINTS:-etcd:2379} + MINIO_ADDRESS: ${MINIO_ADDRESS:-minio:9000} MILVUS_AUTHORIZATION_ENABLED: ${MILVUS_AUTHORIZATION_ENABLED:-true} PGVECTOR_PGUSER: ${PGVECTOR_PGUSER:-postgres} PGVECTOR_POSTGRES_PASSWORD: ${PGVECTOR_POSTGRES_PASSWORD:-difyai123456} @@ -360,7 +360,7 @@ x-shared-env: &shared-api-worker-env NGINX_SSL_PORT: ${NGINX_SSL_PORT:-443} NGINX_SSL_CERT_FILENAME: ${NGINX_SSL_CERT_FILENAME:-dify.crt} NGINX_SSL_CERT_KEY_FILENAME: ${NGINX_SSL_CERT_KEY_FILENAME:-dify.key} - NGINX_SSL_PROTOCOLS: ${NGINX_SSL_PROTOCOLS:-"TLSv1.1 TLSv1.2 TLSv1.3"} + NGINX_SSL_PROTOCOLS: ${NGINX_SSL_PROTOCOLS:-TLSv1.1 TLSv1.2 TLSv1.3} NGINX_WORKER_PROCESSES: ${NGINX_WORKER_PROCESSES:-auto} NGINX_CLIENT_MAX_BODY_SIZE: ${NGINX_CLIENT_MAX_BODY_SIZE:-15M} NGINX_KEEPALIVE_TIMEOUT: ${NGINX_KEEPALIVE_TIMEOUT:-65} @@ -374,7 +374,7 @@ x-shared-env: &shared-api-worker-env SSRF_COREDUMP_DIR: ${SSRF_COREDUMP_DIR:-/var/spool/squid} SSRF_REVERSE_PROXY_PORT: ${SSRF_REVERSE_PROXY_PORT:-8194} SSRF_SANDBOX_HOST: ${SSRF_SANDBOX_HOST:-sandbox} - COMPOSE_PROFILES: ${COMPOSE_PROFILES:-"${VECTOR_STORE:-weaviate}"} + COMPOSE_PROFILES: ${COMPOSE_PROFILES:-${VECTOR_STORE:-weaviate}} EXPOSE_NGINX_PORT: ${EXPOSE_NGINX_PORT:-80} EXPOSE_NGINX_SSL_PORT: ${EXPOSE_NGINX_SSL_PORT:-443} POSITION_TOOL_PINS: ${POSITION_TOOL_PINS:-} diff --git a/docker/generate_docker_compose b/docker/generate_docker_compose index 54b6d55217..dc4460f96c 100755 --- a/docker/generate_docker_compose +++ b/docker/generate_docker_compose @@ -43,7 +43,7 @@ def generate_shared_env_block(env_vars, anchor_name="shared-api-worker-env"): else: # If default value contains special characters, wrap it in quotes if re.search(r"[:\s]", default): - default = f'"{default}"' + default = f"{default}" lines.append(f" {key}: ${{{key}:-{default}}}") return "\n".join(lines) diff --git a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout.tsx b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout.tsx index 1b327185e5..a6fb116fa8 100644 --- a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout.tsx +++ b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout.tsx @@ -201,7 +201,7 @@ const DatasetDetailLayout: FC = (props) => { }, [isMobile, setAppSiderbarExpand]) if (!datasetRes && !error) - return + return return (
    @@ -220,7 +220,7 @@ const DatasetDetailLayout: FC = (props) => { dataset: datasetRes, mutateDatasetRes: () => mutateDatasetRes(), }}> -
    {children}
    +
    {children}
    ) diff --git a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/settings/page.tsx b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/settings/page.tsx index df314ddafe..3a65f1d30f 100644 --- a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/settings/page.tsx +++ b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/settings/page.tsx @@ -7,10 +7,10 @@ const Settings = async () => { const { t } = await translate(locale, 'dataset-settings') return ( -
    +
    -
    {t('title')}
    -
    {t('desc')}
    +
    {t('title')}
    +
    {t('desc')}
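
On the docker/docker-compose.yaml and docker/generate_docker_compose hunks above: Compose substitutes ${VAR:-default} textually, so a quoted default such as ${QDRANT_URL:-"http://qdrant:6333"} leaves the quote characters inside the resolved value; dropping the wrapping quotes in the YAML — and the f'"{default}"' re-quoting in the generator — keeps the defaults clean. Below is a rough sketch of the generator's emit step after this change, assuming env_vars is a plain mapping of key to default (the real script's input handling is not shown in this diff).

def generate_shared_env_block(env_vars, anchor_name="shared-api-worker-env"):
    lines = [f"x-shared-env: &{anchor_name}"]
    for key, default in env_vars.items():
        # The old version re-wrapped defaults containing ':' or whitespace as f'"{default}"',
        # which leaked literal quote characters into the container environment.
        lines.append(f"  {key}: ${{{key}:-{default}}}")
    return "\n".join(lines)


# Hypothetical sample input, for illustration only:
print(generate_shared_env_block({
    "LOG_DATEFORMAT": "%Y-%m-%d %H:%M:%S",
    "QDRANT_URL": "http://qdrant:6333",
}))
# x-shared-env: &shared-api-worker-env
#   LOG_DATEFORMAT: ${LOG_DATEFORMAT:-%Y-%m-%d %H:%M:%S}
#   QDRANT_URL: ${QDRANT_URL:-http://qdrant:6333}
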
    diff --git a/web/app/(commonLayout)/datasets/Container.tsx b/web/app/(commonLayout)/datasets/Container.tsx index 3be8b2a968..6e598ab585 100644 --- a/web/app/(commonLayout)/datasets/Container.tsx +++ b/web/app/(commonLayout)/datasets/Container.tsx @@ -17,7 +17,6 @@ import TagManagementModal from '@/app/components/base/tag-management' import TagFilter from '@/app/components/base/tag-management/filter' import Button from '@/app/components/base/button' import { ApiConnectionMod } from '@/app/components/base/icons/src/vender/solid/development' -import SearchInput from '@/app/components/base/search-input' // Services import { fetchDatasetApiBaseUrl } from '@/service/datasets' @@ -29,6 +28,7 @@ import { useAppContext } from '@/context/app-context' import { useExternalApiPanel } from '@/context/external-api-panel-context' import { useQuery } from '@tanstack/react-query' +import Input from '@/app/components/base/input' const Container = () => { const { t } = useTranslation() @@ -81,17 +81,24 @@ const Container = () => { }, [currentWorkspace, router]) return ( -
    -
    +
    +
    setActiveTab(newActiveTab)} options={options} /> {activeTab === 'dataset' && ( -
    +
    - + handleKeywordsChange(e.target.value)} + onClear={() => handleKeywordsChange('')} + />
    diff --git a/web/app/components/app/configuration/dataset-config/settings-modal/index.tsx b/web/app/components/app/configuration/dataset-config/settings-modal/index.tsx index 389ae0d1fa..7a347a1899 100644 --- a/web/app/components/app/configuration/dataset-config/settings-modal/index.tsx +++ b/web/app/components/app/configuration/dataset-config/settings-modal/index.tsx @@ -12,7 +12,7 @@ import Divider from '@/app/components/base/divider' import Button from '@/app/components/base/button' import Input from '@/app/components/base/input' import Textarea from '@/app/components/base/textarea' -import type { DataSet } from '@/models/datasets' +import { type DataSet, RerankingModeEnum } from '@/models/datasets' import { useToastContext } from '@/app/components/base/toast' import { updateDatasetSetting } from '@/service/datasets' import { useAppContext } from '@/context/app-context' @@ -111,7 +111,10 @@ const SettingsModal: FC = ({ } const postRetrievalConfig = ensureRerankModelSelected({ rerankDefaultModel: rerankDefaultModel!, - retrievalConfig, + retrievalConfig: { + ...retrievalConfig, + reranking_enable: retrievalConfig.reranking_mode === RerankingModeEnum.RerankingModel, + }, indexMethod, }) try { @@ -255,7 +258,8 @@ const SettingsModal: FC = ({ disable={!localeCurrentDataset?.embedding_available} value={indexMethod} onChange={v => setIndexMethod(v!)} - itemClassName='sm:!w-[280px]' + docForm={currentDataset.doc_form} + currentValue={currentDataset.indexing_technique} />
    @@ -287,7 +291,7 @@ const SettingsModal: FC = ({ {/* Retrieval Method Config */} {currentDataset?.provider === 'external' ? <> -
    +
    {t('datasetSettings.form.retrievalSetting.title')}
    @@ -300,7 +304,7 @@ const SettingsModal: FC = ({ isInRetrievalSetting={true} />
    -
    +
    {t('datasetSettings.form.externalKnowledgeAPI')}
    @@ -326,7 +330,7 @@ const SettingsModal: FC = ({
    -
    +
    :
    diff --git a/web/app/components/app/create-app-dialog/app-list/index.tsx b/web/app/components/app/create-app-dialog/app-list/index.tsx index c9354ce2e1..f158f21d99 100644 --- a/web/app/components/app/create-app-dialog/app-list/index.tsx +++ b/web/app/components/app/create-app-dialog/app-list/index.tsx @@ -147,7 +147,7 @@ const Apps = ({ if (onSuccess) onSuccess() localStorage.setItem(NEED_REFRESH_APP_LIST_KEY, '1') - getRedirection(isCurrentWorkspaceEditor, app, push) + getRedirection(isCurrentWorkspaceEditor, { id: app.app_id }, push) } catch (e) { Toast.notify({ type: 'error', message: t('app.newApp.appCreateFailed') }) diff --git a/web/app/components/app/log/list.tsx b/web/app/components/app/log/list.tsx index 383aeb1492..2862eebfa7 100644 --- a/web/app/components/app/log/list.tsx +++ b/web/app/components/app/log/list.tsx @@ -79,6 +79,9 @@ const HandThumbIconWithCount: FC<{ count: number; iconType: 'up' | 'down' }> = ( } const statusTdRender = (statusCount: StatusCount) => { + if (!statusCount) + return null + if (statusCount.partial_success + statusCount.failed === 0) { return (
    diff --git a/web/app/components/base/badge.tsx b/web/app/components/base/badge.tsx index 7b5a0fc873..0214d46968 100644 --- a/web/app/components/base/badge.tsx +++ b/web/app/components/base/badge.tsx @@ -24,11 +24,11 @@ const Badge = ({ className, )} > - {children || text} {hasRedCornerMark && (
    )} + {children || text}
    ) } diff --git a/web/app/components/base/file-uploader/file-type-icon.tsx b/web/app/components/base/file-uploader/file-type-icon.tsx index de9166d2ae..08d0131520 100644 --- a/web/app/components/base/file-uploader/file-type-icon.tsx +++ b/web/app/components/base/file-uploader/file-type-icon.tsx @@ -82,7 +82,7 @@ const FileTypeIcon = ({ size = 'sm', className, }: FileTypeIconProps) => { - const Icon = FILE_TYPE_ICON_MAP[type]?.component || FileAppearanceTypeEnum.document + const Icon = FILE_TYPE_ICON_MAP[type]?.component || FILE_TYPE_ICON_MAP[FileAppearanceTypeEnum.document].component const color = FILE_TYPE_ICON_MAP[type]?.color || FILE_TYPE_ICON_MAP[FileAppearanceTypeEnum.document].color return diff --git a/web/app/components/base/icons/assets/vender/workflow/agent.svg b/web/app/components/base/icons/assets/vender/workflow/agent.svg new file mode 100644 index 0000000000..f30c0b455f --- /dev/null +++ b/web/app/components/base/icons/assets/vender/workflow/agent.svg @@ -0,0 +1,8 @@ + + + + + + + + diff --git a/web/app/components/base/icons/src/public/knowledge/Chunk.json b/web/app/components/base/icons/src/public/knowledge/Chunk.json index 469d85d1a7..7bd5668810 100644 --- a/web/app/components/base/icons/src/public/knowledge/Chunk.json +++ b/web/app/components/base/icons/src/public/knowledge/Chunk.json @@ -24,7 +24,7 @@ "attributes": { "id": "Vector", "d": "M2.5 10H0V7.5H2.5V10Z", - "fill": "#676F83" + "fill": "currentColor" }, "children": [] }, @@ -34,7 +34,7 @@ "attributes": { "id": "Vector_2", "d": "M6.25 6.25H3.75V3.75H6.25V6.25Z", - "fill": "#676F83" + "fill": "currentColor" }, "children": [] }, @@ -44,7 +44,7 @@ "attributes": { "id": "Vector_3", "d": "M2.5 6.25H0V3.75H2.5V6.25Z", - "fill": "#676F83" + "fill": "currentColor" }, "children": [] }, @@ -54,7 +54,7 @@ "attributes": { "id": "Vector_4", "d": "M6.25 2.5H3.75V0H6.25V2.5Z", - "fill": "#676F83" + "fill": "currentColor" }, "children": [] }, @@ -64,7 +64,7 @@ "attributes": { "id": "Vector_5", "d": "M2.5 2.5H0V0H2.5V2.5Z", - "fill": "#676F83" + "fill": "currentColor" }, "children": [] }, @@ -74,7 +74,7 @@ "attributes": { "id": "Vector_6", "d": "M10 2.5H7.5V0H10V2.5Z", - "fill": "#676F83" + "fill": "currentColor" }, "children": [] }, @@ -84,7 +84,7 @@ "attributes": { "id": "Vector_7", "d": "M9.58342 7.91663H7.91675V9.58329H9.58342V7.91663Z", - "fill": "#676F83" + "fill": "currentColor" }, "children": [] }, @@ -94,7 +94,7 @@ "attributes": { "id": "Vector_8", "d": "M9.58342 4.16663H7.91675V5.83329H9.58342V4.16663Z", - "fill": "#676F83" + "fill": "currentColor" }, "children": [] }, @@ -104,7 +104,7 @@ "attributes": { "id": "Vector_9", "d": "M5.83341 7.91663H4.16675V9.58329H5.83341V7.91663Z", - "fill": "#676F83" + "fill": "currentColor" }, "children": [] } diff --git a/web/app/components/base/icons/src/public/knowledge/Collapse.json b/web/app/components/base/icons/src/public/knowledge/Collapse.json index 66d457155d..5e3cf08ce0 100644 --- a/web/app/components/base/icons/src/public/knowledge/Collapse.json +++ b/web/app/components/base/icons/src/public/knowledge/Collapse.json @@ -30,7 +30,7 @@ "name": "path", "attributes": { "d": "M2.66602 11.3333H0.666016L3.33268 8.66667L5.99935 11.3333H3.99935L3.99935 14H2.66602L2.66602 11.3333Z", - "fill": "#354052" + "fill": "currentColor" }, "children": [] }, @@ -39,7 +39,7 @@ "name": "path", "attributes": { "d": "M2.66602 4.66667L2.66602 2L3.99935 2L3.99935 4.66667L5.99935 4.66667L3.33268 7.33333L0.666016 4.66667L2.66602 4.66667Z", - "fill": "#354052" + "fill": "currentColor" }, "children": [] }, @@ 
-48,7 +48,7 @@ "name": "path", "attributes": { "d": "M7.33268 2.66667H13.9993V4H7.33268V2.66667ZM7.33268 12H13.9993V13.3333H7.33268V12ZM5.99935 7.33333H13.9993V8.66667H5.99935V7.33333Z", - "fill": "#354052" + "fill": "currentColor" }, "children": [] } diff --git a/web/app/components/base/icons/src/public/knowledge/LayoutRight2LineMod.json b/web/app/components/base/icons/src/public/knowledge/LayoutRight2LineMod.json index 26c5cf1d4f..6f5b00eb54 100644 --- a/web/app/components/base/icons/src/public/knowledge/LayoutRight2LineMod.json +++ b/web/app/components/base/icons/src/public/knowledge/LayoutRight2LineMod.json @@ -24,7 +24,7 @@ "attributes": { "id": "Vector", "d": "M14.0002 2C14.3684 2 14.6668 2.29848 14.6668 2.66667V13.3333C14.6668 13.7015 14.3684 14 14.0002 14H2.00016C1.63198 14 1.3335 13.7015 1.3335 13.3333V2.66667C1.3335 2.29848 1.63198 2 2.00016 2H14.0002ZM13.3335 3.33333H2.66683V12.6667H13.3335V3.33333ZM14.0002 2.66667V13.3333H10.0002V2.66667H14.0002Z", - "fill": "#354052" + "fill": "currentColor" }, "children": [] } diff --git a/web/app/components/base/icons/src/vender/workflow/Agent.json b/web/app/components/base/icons/src/vender/workflow/Agent.json new file mode 100644 index 0000000000..e7ed19369b --- /dev/null +++ b/web/app/components/base/icons/src/vender/workflow/Agent.json @@ -0,0 +1,53 @@ +{ + "icon": { + "type": "element", + "isRootNode": true, + "name": "svg", + "attributes": { + "width": "16", + "height": "16", + "viewBox": "0 0 16 16", + "fill": "none", + "xmlns": "http://www.w3.org/2000/svg" + }, + "children": [ + { + "type": "element", + "name": "g", + "attributes": { + "id": "agent" + }, + "children": [ + { + "type": "element", + "name": "g", + "attributes": { + "id": "Vector" + }, + "children": [ + { + "type": "element", + "name": "path", + "attributes": { + "d": "M14.7401 5.80454C14.5765 4.77996 14.1638 3.79808 13.5306 2.97273C12.8973 2.14738 12.0648 1.48568 11.1185 1.06589C10.1722 0.646098 9.12632 0.461106 8.08751 0.546487C7.05582 0.624753 6.04548 0.966277 5.17744 1.53548C4.3094 2.09758 3.58366 2.88024 3.09272 3.79808C2.59466 4.70881 2.33852 5.7405 2.33852 6.7793V7.22756L1.25703 9.3692C1.04357 9.80322 1.22145 10.3368 1.65547 10.5574L2.3314 10.8989V12.3006C2.3314 12.82 2.53063 13.3038 2.90061 13.6738C3.2706 14.0367 3.75442 14.243 4.27382 14.243H6.01702V14.7624C6.01702 15.1538 6.3372 15.4739 6.72853 15.4739C7.11986 15.4739 7.44004 15.1538 7.44004 14.7624V13.7094C7.44004 13.2185 7.04159 12.82 6.55065 12.82H4.27382C4.13864 12.82 4.00345 12.7631 3.91095 12.6706C3.81846 12.5781 3.76154 12.4429 3.76154 12.3077V10.5716C3.76154 10.2301 3.56943 9.92417 3.2706 9.77476L2.77254 9.52573L3.66904 7.73984C3.72596 7.61889 3.76154 7.4837 3.76154 7.34851V6.77219C3.76154 5.96818 3.96076 5.17129 4.34498 4.4669C4.72919 3.76251 5.28417 3.15772 5.9601 2.7237C6.63603 2.28968 7.41158 2.02643 8.20847 1.96239C9.00536 1.89835 9.81648 2.04066 10.5493 2.36795C11.2822 2.69524 11.9225 3.20042 12.4135 3.84077C12.8973 4.47402 13.2246 5.23533 13.3456 6.02511C13.4665 6.81488 13.3954 7.63312 13.125 8.38731C12.8617 9.12017 12.4206 9.78187 11.8585 10.3084C11.6735 10.4792 11.5668 10.7139 11.5668 10.9701V14.7624C11.5668 15.1538 11.887 15.4739 12.2783 15.4739C12.6696 15.4739 12.9898 15.1538 12.9898 14.7624V11.1978C13.6515 10.5432 14.1567 9.73918 14.4697 8.87114C14.8184 7.89637 14.918 6.83623 14.7615 5.81165L14.7401 5.80454Z", + "fill": "currentColor" + }, + "children": [] + }, + { + "type": "element", + "name": "path", + "attributes": { + "d": "M10.8055 7.99599C10.8909 7.83234 10.962 7.66158 11.0189 
7.4837H11.6522C12.0435 7.4837 12.3637 7.16352 12.3637 6.77219C12.3637 6.38086 12.0435 6.06068 11.6522 6.06068H11.0189C10.9691 5.8828 10.898 5.71204 10.8055 5.54839L11.2537 5.10014C11.5312 4.82266 11.5312 4.3744 11.2537 4.09692C10.9762 3.81943 10.528 3.81943 10.2505 4.09692L9.80225 4.54517C9.6386 4.45267 9.46784 4.38863 9.28996 4.33171V3.69847C9.28996 3.30714 8.96978 2.98696 8.57845 2.98696C8.18712 2.98696 7.86694 3.30714 7.86694 3.69847V4.33171C7.68907 4.38152 7.5183 4.45267 7.35466 4.54517L6.90641 4.09692C6.62892 3.81943 6.18067 3.81943 5.90318 4.09692C5.62569 4.3744 5.62569 4.82266 5.90318 5.10014L6.35143 5.54839C6.26605 5.71204 6.1949 5.8828 6.13798 6.06068H5.50473C5.1134 6.06068 4.79323 6.38086 4.79323 6.77219C4.79323 7.16352 5.1134 7.4837 5.50473 7.4837H6.13798C6.18778 7.66158 6.25893 7.83234 6.35143 7.99599L5.90318 8.44424C5.62569 8.72172 5.62569 9.16997 5.90318 9.44746C6.04548 9.58976 6.22336 9.6538 6.40835 9.6538C6.59334 9.6538 6.77122 9.58265 6.91352 9.44746L7.36177 8.99921C7.52542 9.08459 7.69618 9.15574 7.87406 9.21267V9.84591C7.87406 10.2372 8.19424 10.5574 8.58557 10.5574C8.9769 10.5574 9.29708 10.2372 9.29708 9.84591V9.21267C9.47496 9.16286 9.64572 9.09171 9.80936 8.99921L10.2576 9.44746C10.3999 9.58976 10.5778 9.6538 10.7628 9.6538C10.9478 9.6538 11.1257 9.58265 11.268 9.44746C11.5454 9.16997 11.5454 8.72172 11.268 8.44424L10.8197 7.99599H10.8055ZM7.44004 6.77219C7.44004 6.14606 7.94521 5.64089 8.57134 5.64089C9.19747 5.64089 9.70264 6.14606 9.70264 6.77219C9.70264 7.39832 9.19747 7.90349 8.57134 7.90349C7.94521 7.90349 7.44004 7.39832 7.44004 6.77219Z", + "fill": "currentColor" + }, + "children": [] + } + ] + } + ] + } + ] + }, + "name": "Agent" +} \ No newline at end of file diff --git a/web/app/components/base/icons/src/vender/workflow/Agent.tsx b/web/app/components/base/icons/src/vender/workflow/Agent.tsx new file mode 100644 index 0000000000..e4337d4dbd --- /dev/null +++ b/web/app/components/base/icons/src/vender/workflow/Agent.tsx @@ -0,0 +1,16 @@ +// GENERATE BY script +// DON NOT EDIT IT MANUALLY + +import * as React from 'react' +import data from './Agent.json' +import IconBase from '@/app/components/base/icons/IconBase' +import type { IconBaseProps, IconData } from '@/app/components/base/icons/IconBase' + +const Icon = React.forwardRef, Omit>(( + props, + ref, +) => ) + +Icon.displayName = 'Agent' + +export default Icon diff --git a/web/app/components/base/icons/src/vender/workflow/index.ts b/web/app/components/base/icons/src/vender/workflow/index.ts index b2cc7968bb..11ce55b130 100644 --- a/web/app/components/base/icons/src/vender/workflow/index.ts +++ b/web/app/components/base/icons/src/vender/workflow/index.ts @@ -1,3 +1,4 @@ +export { default as Agent } from './Agent' export { default as Answer } from './Answer' export { default as Assigner } from './Assigner' export { default as Code } from './Code' diff --git a/web/app/components/base/input-number/index.tsx b/web/app/components/base/input-number/index.tsx index 7492e0814c..316d863b48 100644 --- a/web/app/components/base/input-number/index.tsx +++ b/web/app/components/base/input-number/index.tsx @@ -1,28 +1,51 @@ -import type { FC, SetStateAction } from 'react' +import type { FC } from 'react' import { RiArrowDownSLine, RiArrowUpSLine } from '@remixicon/react' import Input, { type InputProps } from '../input' import classNames from '@/utils/classnames' export type InputNumberProps = { unit?: string - value: number - onChange: (value: number) => void + value?: number + onChange: (value?: number) => void 
amount?: number size?: 'sm' | 'md' -} & Omit + max?: number + min?: number + defaultValue?: number +} & Omit export const InputNumber: FC = (props) => { - const { unit, className, onChange, amount = 1, value, size = 'md', max, min, ...rest } = props - const update = (input: SetStateAction) => { - const current = typeof input === 'function' ? input(value) : input as number - if (max && current >= (max as number)) - return - if (min && current <= (min as number)) - return - onChange(current) + const { unit, className, onChange, amount = 1, value, size = 'md', max, min, defaultValue, ...rest } = props + + const isValidValue = (v: number) => { + if (max && v > max) + return false + if (min && v < min) + return false + return true } - const inc = () => update(val => val + amount) - const dec = () => update(val => val - amount) + + const inc = () => { + if (value === undefined) { + onChange(defaultValue) + return + } + const newValue = value + amount + if (!isValidValue(newValue)) + return + onChange(newValue) + } + const dec = () => { + if (value === undefined) { + onChange(defaultValue) + return + } + const newValue = value - amount + if (!isValidValue(newValue)) + return + onChange(newValue) + } + return
    = (props) => { max={max} min={min} onChange={(e) => { + if (e.target.value === '') + onChange(undefined) + const parsed = Number(e.target.value) if (Number.isNaN(parsed)) return + + if (!isValidValue(parsed)) + return onChange(parsed) }} + unit={unit} /> - {unit &&
    {unit}
    }
    } - noHighlight={Boolean(datasetId)} + noHighlight={isInUpload && isNotUploadInEmptyDataset} >
    @@ -615,10 +628,12 @@ const StepTwo = ({ onChange={e => setSegmentIdentifier(e.target.value, true)} /> ))} {IS_CE_EDITION && <> -
    - { - if (docForm === ChuckingMode.qa) - handleChangeDocform(ChuckingMode.text) - else - handleChangeDocform(ChuckingMode.qa) - }} - /> -
    + +
    +
    { + if (currentDataset?.doc_form) + return + if (docForm === ChunkingMode.qa) + handleChangeDocform(ChunkingMode.text) + else + handleChangeDocform(ChunkingMode.qa) + }}> + -
    - -
    -
    + +
    - {docForm === ChuckingMode.qa && ( + {currentDocForm === ChunkingMode.qa && (
    } { - (!datasetId || currentDataset!.doc_form === ChuckingMode.parentChild) + ( + (isInUpload && currentDataset!.doc_form === ChunkingMode.parentChild) + || isUploadInEmptyDataset + || isInInit + ) && } effectImg={OrangeEffect.src} activeHeaderClassName='bg-dataset-option-card-orange-gradient' description={t('datasetCreation.stepTwo.parentChildTip')} - isActive={ - datasetId ? currentDataset!.doc_form === ChuckingMode.parentChild : docForm === ChuckingMode.parentChild - } - onSwitched={() => handleChangeDocform(ChuckingMode.parentChild)} + isActive={currentDocForm === ChunkingMode.parentChild} + onSwitched={() => handleChangeDocform(ChunkingMode.parentChild)} actions={ <> } - noHighlight={Boolean(datasetId)} + noHighlight={isInUpload && isNotUploadInEmptyDataset} >
    @@ -733,6 +751,7 @@ const StepTwo = ({
    setParentChildConfig({ ...parentChildConfig, parent: { @@ -742,6 +761,7 @@ const StepTwo = ({ })} /> setParentChildConfig({ ...parentChildConfig, @@ -778,6 +798,7 @@ const StepTwo = ({
    setParentChildConfig({ ...parentChildConfig, child: { @@ -787,6 +808,7 @@ const StepTwo = ({ })} /> setParentChildConfig({ ...parentChildConfig, @@ -822,17 +844,18 @@ const StepTwo = ({ }
    {t('datasetCreation.stepTwo.indexMode')}
    -
    +
    {(!hasSetIndexType || (hasSetIndexType && indexingType === IndexingType.QUALIFIED)) && ( - + {t('datasetCreation.stepTwo.qualified')} - {!hasSetIndexType - && {t('datasetCreation.stepTwo.recommend')}} + + {t('datasetCreation.stepTwo.recommend')} + {!hasSetIndexType && } -

    } +
    } description={t('datasetCreation.stepTwo.qualifiedTip')} icon={} isActive={!hasSetIndexType && indexType === IndexingType.QUALIFIED} @@ -864,7 +887,7 @@ const StepTwo = ({ @@ -872,20 +895,20 @@ const StepTwo = ({ - - + } isActive={!hasSetIndexType && indexType === IndexingType.ECONOMICAL} - disabled={!isAPIKeySet || hasSetIndexType || docForm !== ChuckingMode.text} + disabled={!isAPIKeySet || hasSetIndexType || docForm !== ChunkingMode.text} ref={economyDomRef} onSwitched={() => { - if (isAPIKeySet && docForm === ChuckingMode.text) + if (isAPIKeySet && docForm === ChunkingMode.text) setIndexType(IndexingType.ECONOMICAL) }} /> @@ -893,7 +916,7 @@ const StepTwo = ({
    { - docForm === ChuckingMode.qa + docForm === ChunkingMode.qa ? t('datasetCreation.stepTwo.notAvailableForQA') : t('datasetCreation.stepTwo.notAvailableForParentChild') } @@ -902,8 +925,17 @@ const StepTwo = ({ )}
    + {!hasSetIndexType && indexType === IndexingType.QUALIFIED && ( +
    +
    +
    + +
    + {t('datasetCreation.stepTwo.highQualityTip')} +
    + )} {hasSetIndexType && indexType === IndexingType.ECONOMICAL && ( -
    +
    {t('datasetCreation.stepTwo.indexSettingTip')} {t('datasetCreation.stepTwo.datasetSettingLink')}
    @@ -921,7 +953,7 @@ const StepTwo = ({ }} /> {!!datasetId && ( -
    +
    {t('datasetCreation.stepTwo.indexSettingTip')} {t('datasetCreation.stepTwo.datasetSettingLink')}
    @@ -997,7 +1029,8 @@ const StepTwo = ({ setPreviewFile(selected) currentEstimateMutation.mutate() }} - value={previewFile} + // when it is from setting, it just has one file + value={isSetting ? (files[0]! as Required) : previewFile} /> } {dataSourceType === DataSourceType.NOTION @@ -1046,21 +1079,31 @@ const StepTwo = ({ } /> } - + { + currentDocForm !== ChunkingMode.qa + && + }
    } className={cn('flex shrink-0 w-1/2 p-4 pr-0 relative h-full', isMobile && 'w-full max-w-[524px]')} mainClassName='space-y-6' > - {docForm === ChuckingMode.qa && estimate?.qa_preview && ( - estimate?.qa_preview.map(item => ( - + {currentDocForm === ChunkingMode.qa && estimate?.qa_preview && ( + estimate?.qa_preview.map((item, index) => ( + + + )) )} - {docForm === ChuckingMode.text && estimate?.preview && ( + {currentDocForm === ChunkingMode.text && estimate?.preview && ( estimate?.preview.map((item, index) => ( )) )} - {docForm === ChuckingMode.parentChild && currentEstimateMutation.data?.preview && ( + {currentDocForm === ChunkingMode.parentChild && currentEstimateMutation.data?.preview && ( estimate?.preview?.map((item, index) => { const indexForLabel = index + 1 return ( @@ -1090,6 +1133,7 @@ const StepTwo = ({ text={child} tooltip={`Child-chunk-${indexForLabel} · ${child.length} Characters`} labelInnerClassName='text-[10px] font-semibold align-bottom leading-7' + dividerClassName='leading-7' /> ) })} @@ -1111,7 +1155,7 @@ const StepTwo = ({ {currentEstimateMutation.isPending && (
    {Array.from({ length: 10 }, (_, i) => ( - + @@ -1120,7 +1164,7 @@ const StepTwo = ({ - + ))}
    )} diff --git a/web/app/components/datasets/create/step-two/inputs.tsx b/web/app/components/datasets/create/step-two/inputs.tsx index 3d38a256f1..4231f6242d 100644 --- a/web/app/components/datasets/create/step-two/inputs.tsx +++ b/web/app/components/datasets/create/step-two/inputs.tsx @@ -17,14 +17,14 @@ const FormField: FC> = (props) => {
    } -export const DelimiterInput: FC = (props) => { +export const DelimiterInput: FC = (props) => { const { t } = useTranslation() return {t('datasetCreation.stepTwo.separator')} - {t('datasetCreation.stepTwo.separatorTip')} + {props.tooltip || t('datasetCreation.stepTwo.separatorTip')}
    } /> diff --git a/web/app/components/datasets/create/step-two/language-select/index.tsx b/web/app/components/datasets/create/step-two/language-select/index.tsx index 016f2a5f20..9cbf1a40d1 100644 --- a/web/app/components/datasets/create/step-two/language-select/index.tsx +++ b/web/app/components/datasets/create/step-two/language-select/index.tsx @@ -1,7 +1,7 @@ 'use client' import type { FC } from 'react' import React from 'react' -import { RiArrowDownSLine } from '@remixicon/react' +import { RiArrowDownSLine, RiCheckLine } from '@remixicon/react' import cn from '@/utils/classnames' import Popover from '@/app/components/base/popover' import { languages } from '@/i18n/language' @@ -24,24 +24,38 @@ const LanguageSelect: FC = ({ disabled={disabled} popupClassName='z-20' htmlContent={ -
    +
    {languages.filter(language => language.supported).map(({ prompt_name }) => (
    onSelect(prompt_name)}>{prompt_name} + className='w-full py-2 px-3 inline-flex items-center justify-between hover:bg-state-base-hover rounded-lg cursor-pointer' + onClick={() => onSelect(prompt_name)} + > + {prompt_name} + {(currentLanguage === prompt_name) && }
    ))}
    } btnElement={ -
    - {currentLanguage} - +
    + + {currentLanguage} + +
    } - btnClassName={() => cn('!border-0 !px-0 !py-0 !bg-inherit !hover:bg-inherit text-components-button-tertiary-text')} - className='!w-[120px] h-fit !z-20 !translate-x-0 !left-[-16px]' + btnClassName={() => cn( + '!border-0 rounded-md !px-1.5 !py-1 !mx-1 !bg-components-button-tertiary-bg !hover:bg-components-button-tertiary-bg', + disabled ? 'bg-components-button-tertiary-bg-disabled' : '', + )} + className='!w-[140px] h-fit !z-20 !translate-x-0 !left-1' /> ) } diff --git a/web/app/components/datasets/create/step-two/option-card.tsx b/web/app/components/datasets/create/step-two/option-card.tsx index 7d3b06f375..d0efdaabb1 100644 --- a/web/app/components/datasets/create/step-two/option-card.tsx +++ b/web/app/components/datasets/create/step-two/option-card.tsx @@ -34,7 +34,7 @@ export const OptionCardHeader: FC = (props) => { -
    +
    {title}
    {description}
    @@ -53,10 +53,10 @@ type OptionCardProps = { onSwitched?: () => void noHighlight?: boolean disabled?: boolean -} & Omit, 'title'> +} & Omit, 'title' | 'onClick'> export const OptionCard: FC = forwardRef((props, ref) => { - const { icon, className, title, description, isActive, children, actions, activeHeaderClassName, style, effectImg, onSwitched, onClick, noHighlight, disabled, ...rest } = props + const { icon, className, title, description, isActive, children, actions, activeHeaderClassName, style, effectImg, onSwitched, noHighlight, disabled, ...rest } = props return
    = forwardRef((props, ref) => { style={{ ...style, }} - onClick={(e) => { - if (!isActive) + onClick={() => { + if (!isActive && !disabled) onSwitched?.() - onClick?.(e) }} {...rest} ref={ref} @@ -86,7 +85,7 @@ export const OptionCard: FC = forwardRef((props, ref) => { effectImg={effectImg} /> {/** Body */} - {isActive && (children || actions) &&
    + {isActive && (children || actions) &&
    {children} {actions &&
    {actions} diff --git a/web/app/components/datasets/create/top-bar/index.tsx b/web/app/components/datasets/create/top-bar/index.tsx index 6f773d9a3e..20ba7158db 100644 --- a/web/app/components/datasets/create/top-bar/index.tsx +++ b/web/app/components/datasets/create/top-bar/index.tsx @@ -28,7 +28,7 @@ export const Topbar: FC = (props) => {

    ({ diff --git a/web/app/components/datasets/create/website/jina-reader/index.tsx b/web/app/components/datasets/create/website/jina-reader/index.tsx index 51d77d7121..1c133f935c 100644 --- a/web/app/components/datasets/create/website/jina-reader/index.tsx +++ b/web/app/components/datasets/create/website/jina-reader/index.tsx @@ -94,7 +94,6 @@ const JinaReader: FC = ({ const waitForCrawlFinished = useCallback(async (jobId: string) => { try { const res = await checkJinaReaderTaskStatus(jobId) as any - console.log('res', res) if (res.status === 'completed') { return { isError: false, diff --git a/web/app/components/datasets/create/website/preview.tsx b/web/app/components/datasets/create/website/preview.tsx index 070aa7ae83..5180a83442 100644 --- a/web/app/components/datasets/create/website/preview.tsx +++ b/web/app/components/datasets/create/website/preview.tsx @@ -32,7 +32,7 @@ const WebsitePreview = ({
    {payload.source_url}
    -
    {payload.markdown}
    +
    {payload.markdown}
    ) diff --git a/web/app/components/datasets/documents/detail/batch-modal/csv-downloader.tsx b/web/app/components/datasets/documents/detail/batch-modal/csv-downloader.tsx index d340f90deb..6602244a48 100644 --- a/web/app/components/datasets/documents/detail/batch-modal/csv-downloader.tsx +++ b/web/app/components/datasets/documents/detail/batch-modal/csv-downloader.tsx @@ -7,7 +7,7 @@ import { import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' import { Download02 as DownloadIcon } from '@/app/components/base/icons/src/vender/solid/general' -import { DocForm } from '@/models/datasets' +import { ChunkingMode } from '@/models/datasets' import I18n from '@/context/i18n' import { LanguagesSupported } from '@/i18n/language' @@ -32,18 +32,18 @@ const CSV_TEMPLATE_CN = [ ['内容 2'], ] -const CSVDownload: FC<{ docForm: DocForm }> = ({ docForm }) => { +const CSVDownload: FC<{ docForm: ChunkingMode }> = ({ docForm }) => { const { t } = useTranslation() const { locale } = useContext(I18n) const { CSVDownloader, Type } = useCSVDownloader() const getTemplate = () => { if (locale === LanguagesSupported[1]) { - if (docForm === DocForm.QA) + if (docForm === ChunkingMode.qa) return CSV_TEMPLATE_QA_CN return CSV_TEMPLATE_CN } - if (docForm === DocForm.QA) + if (docForm === ChunkingMode.qa) return CSV_TEMPLATE_QA_EN return CSV_TEMPLATE_EN } @@ -52,7 +52,7 @@ const CSVDownload: FC<{ docForm: DocForm }> = ({ docForm }) => {
    {t('share.generation.csvStructureTitle')}
    - {docForm === DocForm.QA && ( + {docForm === ChunkingMode.qa && ( @@ -72,7 +72,7 @@ const CSVDownload: FC<{ docForm: DocForm }> = ({ docForm }) => {
    )} - {docForm === DocForm.TEXT && ( + {docForm === ChunkingMode.text && ( diff --git a/web/app/components/datasets/documents/detail/batch-modal/index.tsx b/web/app/components/datasets/documents/detail/batch-modal/index.tsx index 139a364cb4..c666ba6715 100644 --- a/web/app/components/datasets/documents/detail/batch-modal/index.tsx +++ b/web/app/components/datasets/documents/detail/batch-modal/index.tsx @@ -7,11 +7,11 @@ import CSVUploader from './csv-uploader' import CSVDownloader from './csv-downloader' import Button from '@/app/components/base/button' import Modal from '@/app/components/base/modal' -import type { DocForm } from '@/models/datasets' +import type { ChunkingMode } from '@/models/datasets' export type IBatchModalProps = { isShow: boolean - docForm: DocForm + docForm: ChunkingMode onCancel: () => void onConfirm: (file: File) => void } diff --git a/web/app/components/datasets/documents/detail/completed/child-segment-detail.tsx b/web/app/components/datasets/documents/detail/completed/child-segment-detail.tsx index 34728170d7..085bfddc16 100644 --- a/web/app/components/datasets/documents/detail/completed/child-segment-detail.tsx +++ b/web/app/components/datasets/documents/detail/completed/child-segment-detail.tsx @@ -9,7 +9,7 @@ import ChunkContent from './common/chunk-content' import Dot from './common/dot' import { SegmentIndexTag } from './common/segment-index-tag' import { useSegmentListContext } from './index' -import type { ChildChunkDetail, ChuckingMode } from '@/models/datasets' +import type { ChildChunkDetail, ChunkingMode } from '@/models/datasets' import { useEventEmitterContextContext } from '@/context/event-emitter' import { formatNumber } from '@/utils/format' import classNames from '@/utils/classnames' @@ -21,7 +21,7 @@ type IChildSegmentDetailProps = { childChunkInfo?: Partial & { id: string } onUpdate: (segmentId: string, childChunkId: string, content: string) => void onCancel: () => void - docForm: ChuckingMode + docForm: ChunkingMode } /** @@ -38,7 +38,8 @@ const ChildSegmentDetail: FC = ({ const [content, setContent] = useState(childChunkInfo?.content || '') const { eventEmitter } = useEventEmitterContextContext() const [loading, setLoading] = useState(false) - const [fullScreen, toggleFullScreen] = useSegmentListContext(s => [s.fullScreen, s.toggleFullScreen]) + const fullScreen = useSegmentListContext(s => s.fullScreen) + const toggleFullScreen = useSegmentListContext(s => s.toggleFullScreen) eventEmitter?.useSubscription((v) => { if (v === 'update-child-segment') @@ -106,8 +107,8 @@ const ChildSegmentDetail: FC = ({ -
    -
    +
    +
    void isLoading?: boolean + focused?: boolean } const ChildSegmentList: FC = ({ @@ -38,9 +40,11 @@ const ChildSegmentList: FC = ({ inputValue, onClearFilter, isLoading, + focused = false, }) => { const { t } = useTranslation() const parentMode = useDocumentContext(s => s.parentMode) + const currChildChunk = useSegmentListContext(s => s.currChildChunk) const [collapsed, setCollapsed] = useState(true) @@ -57,8 +61,8 @@ const ChildSegmentList: FC = ({ }, [parentMode]) const contentOpacity = useMemo(() => { - return enabled ? '' : 'opacity-50 group-hover/card:opacity-100' - }, [enabled]) + return (enabled || focused) ? '' : 'opacity-50 group-hover/card:opacity-100' + }, [enabled, focused]) const totalText = useMemo(() => { const isSearch = inputValue !== '' && isFullDocMode @@ -87,15 +91,22 @@ const ChildSegmentList: FC = ({
    {isFullDocMode ? : null} -
    -
    { +
    +
    { event.stopPropagation() toggleCollapse() - }}> + }} + > { isParagraphMode ? collapsed @@ -108,6 +119,7 @@ const ChildSegmentList: FC = ({ {totalText} ·
    {isLoading ? : null} {((isFullDocMode && !isLoading) || !collapsed) - ?
    + ?
    {isParagraphMode && (
    @@ -145,19 +157,26 @@ const ChildSegmentList: FC = ({ ? {childChunks.map((childChunk) => { const edited = childChunk.updated_at !== childChunk.created_at + const focused = currChildChunk?.childChunkInfo?.id === childChunk.id return onDelete?.(childChunk.segment_id, childChunk.id)} - className='line-clamp-3' - labelInnerClassName='text-[10px] font-semibold align-bottom leading-6' - contentClassName='!leading-6' + labelClassName={focused ? 'bg-state-accent-solid text-text-primary-on-surface' : ''} + labelInnerClassName={'text-[10px] font-semibold align-bottom leading-6'} + contentClassName={classNames('!leading-6', focused ? 'bg-state-accent-hover-alt text-text-primary' : '')} showDivider={false} onClick={(e) => { e.stopPropagation() onClickSlice?.(childChunk) }} + offsetOptions={({ rects }) => { + return { + mainAxis: isFullDocMode ? -rects.floating.width : 12 - rects.floating.width, + crossAxis: (20 - rects.floating.height) / 2, + } + }} /> })} diff --git a/web/app/components/datasets/documents/detail/completed/common/action-buttons.tsx b/web/app/components/datasets/documents/detail/completed/common/action-buttons.tsx index 15bff500b5..1238d98a9c 100644 --- a/web/app/components/datasets/documents/detail/completed/common/action-buttons.tsx +++ b/web/app/components/datasets/documents/detail/completed/common/action-buttons.tsx @@ -23,7 +23,8 @@ const ActionButtons: FC = ({ isChildChunk = false, }) => { const { t } = useTranslation() - const [mode, parentMode] = useDocumentContext(s => [s.mode, s.parentMode]) + const mode = useDocumentContext(s => s.mode) + const parentMode = useDocumentContext(s => s.parentMode) useKeyPress(['esc'], (e) => { e.preventDefault() diff --git a/web/app/components/datasets/documents/detail/completed/common/batch-action.tsx b/web/app/components/datasets/documents/detail/completed/common/batch-action.tsx index df3ae6e1ec..3dd3689b64 100644 --- a/web/app/components/datasets/documents/detail/completed/common/batch-action.tsx +++ b/web/app/components/datasets/documents/detail/completed/common/batch-action.tsx @@ -52,33 +52,33 @@ const BatchAction: FC = ({
    -
    -
    {onArchive && (
    -
    )}
    -
    -
    diff --git a/web/app/components/datasets/documents/detail/completed/common/chunk-content.tsx b/web/app/components/datasets/documents/detail/completed/common/chunk-content.tsx index 47bd3ab4a1..e6403fa12f 100644 --- a/web/app/components/datasets/documents/detail/completed/common/chunk-content.tsx +++ b/web/app/components/datasets/documents/detail/completed/common/chunk-content.tsx @@ -1,7 +1,150 @@ -import React, { type FC } from 'react' +import React, { useEffect, useRef, useState } from 'react' +import type { ComponentProps, FC } from 'react' import { useTranslation } from 'react-i18next' -import { ChuckingMode } from '@/models/datasets' -import AutoHeightTextarea from '@/app/components/base/auto-height-textarea/common' +import { ChunkingMode } from '@/models/datasets' +import classNames from '@/utils/classnames' + +type IContentProps = ComponentProps<'textarea'> + +const Textarea: FC = React.memo(({ + value, + placeholder, + className, + disabled, + ...rest +}) => { + return ( +