diff --git a/README.md b/README.md
index ca09adec08..ec399e49ee 100644
--- a/README.md
+++ b/README.md
@@ -226,6 +226,11 @@ Deploy Dify to AWS with [CDK](https://aws.amazon.com/cdk/)
 - [AWS CDK by @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
+#### Using Alibaba Cloud Computing Nest
+
+Quickly deploy Dify to Alibaba Cloud with [Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88)
+
+
 ## Contributing
 For those who'd like to contribute code, see our [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md).
diff --git a/README_AR.md b/README_AR.md
index df288fd33c..5214da4894 100644
--- a/README_AR.md
+++ b/README_AR.md
@@ -209,6 +209,9 @@ docker compose up -d
 - [AWS CDK بواسطة @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
+#### استخدام Alibaba Cloud للنشر
+ [بسرعة نشر Dify إلى سحابة علي بابا مع عش الحوسبة السحابية علي بابا](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88)
+
 ## المساهمة
 لأولئك الذين يرغبون في المساهمة، انظر إلى [دليل المساهمة](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) لدينا.
diff --git a/README_BN.md b/README_BN.md
index 4a5b5f3928..1911f186d7 100644
--- a/README_BN.md
+++ b/README_BN.md
@@ -225,6 +225,11 @@ GitHub-এ ডিফাইকে স্টার দিয়ে রাখুন
 - [AWS CDK by @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
+#### Alibaba Cloud ব্যবহার করে ডিপ্লয়
+
+ [Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88)
+
+
 ## Contributing
 যারা কোড অবদান রাখতে চান, তাদের জন্য আমাদের [অবদান নির্দেশিকা] দেখুন (https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)।
diff --git a/README_CN.md b/README_CN.md
index ba7ee0006d..a194b01937 100644
--- a/README_CN.md
+++ b/README_CN.md
@@ -221,6 +221,11 @@ docker compose up -d
 ##### AWS
 - [AWS CDK by @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
+#### 使用 阿里云计算巢 部署
+
+使用 [阿里云计算巢](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88) 将 Dify 一键部署到 阿里云
+
+
 ## Star History
 [![Star History Chart](https://api.star-history.com/svg?repos=langgenius/dify&type=Date)](https://star-history.com/#langgenius/dify&Date)
diff --git a/README_DE.md b/README_DE.md
index f6023a3935..fd550a5b96 100644
--- a/README_DE.md
+++ b/README_DE.md
@@ -221,6 +221,11 @@ Bereitstellung von Dify auf AWS mit [CDK](https://aws.amazon.com/cdk/)
 ##### AWS
 - [AWS CDK by @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
+#### Alibaba Cloud
+
+[Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88)
+
+
 ## Contributing
 Falls Sie Code beitragen möchten, lesen Sie bitte unseren [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). Gleichzeitig bitten wir Sie, Dify zu unterstützen, indem Sie es in den sozialen Medien teilen und auf Veranstaltungen und Konferenzen präsentieren.
diff --git a/README_ES.md b/README_ES.md
index 12f2ce8c11..38dea09be1 100644
--- a/README_ES.md
+++ b/README_ES.md
@@ -221,6 +221,10 @@ Despliegue Dify en AWS usando [CDK](https://aws.amazon.com/cdk/)
 ##### AWS
 - [AWS CDK por @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
+#### Alibaba Cloud
+
+[Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88)
+
 ## Contribuir
 Para aquellos que deseen contribuir con código, consulten nuestra [Guía de contribución](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md).
diff --git a/README_FR.md b/README_FR.md
index b106615b31..925918e47e 100644
--- a/README_FR.md
+++ b/README_FR.md
@@ -219,6 +219,11 @@ Déployez Dify sur AWS en utilisant [CDK](https://aws.amazon.com/cdk/)
 ##### AWS
 - [AWS CDK par @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
+#### Alibaba Cloud
+
+[Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88)
+
+
 ## Contribuer
 Pour ceux qui souhaitent contribuer du code, consultez notre [Guide de contribution](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md).
diff --git a/README_JA.md b/README_JA.md
index 26703f3958..3f8a5b859d 100644
--- a/README_JA.md
+++ b/README_JA.md
@@ -220,6 +220,10 @@ docker compose up -d
 ##### AWS
 - [@KevinZhaoによるAWS CDK](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
+#### Alibaba Cloud
+[Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88)
+
+
 ## 貢献
 コードに貢献したい方は、[Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)を参照してください。
diff --git a/README_KL.md b/README_KL.md
index ea91baa5aa..9e562a4d73 100644
--- a/README_KL.md
+++ b/README_KL.md
@@ -219,6 +219,11 @@ wa'logh nIqHom neH ghun deployment toy'wI' [CDK](https://aws.amazon.com/cdk/) lo
 ##### AWS
 - [AWS CDK qachlot @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
+#### Alibaba Cloud
+
+[Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88)
+
+
 ## Contributing
 For those who'd like to contribute code, see our [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md).
diff --git a/README_KR.md b/README_KR.md
index 89301e8b2c..683b3a86f4 100644
--- a/README_KR.md
+++ b/README_KR.md
@@ -213,6 +213,11 @@ Dify를 Kubernetes에 배포하고 프리미엄 스케일링 설정을 구성했
 ##### AWS
 - [KevinZhao의 AWS CDK](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
+#### Alibaba Cloud
+
+[Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88)
+
+
 ## 기여
 코드에 기여하고 싶은 분들은 [기여 가이드](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)를 참조하세요.
diff --git a/README_PT.md b/README_PT.md
index 157772d528..b81127b70b 100644
--- a/README_PT.md
+++ b/README_PT.md
@@ -218,6 +218,11 @@ Implante o Dify na AWS usando [CDK](https://aws.amazon.com/cdk/)
 ##### AWS
 - [AWS CDK por @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
+#### Alibaba Cloud
+
+[Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88)
+
+
 ## Contribuindo
 Para aqueles que desejam contribuir com código, veja nosso [Guia de Contribuição](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md).
diff --git a/README_SI.md b/README_SI.md
index 14de1ea792..7034233233 100644
--- a/README_SI.md
+++ b/README_SI.md
@@ -219,6 +219,11 @@ Uvedite Dify v AWS z uporabo [CDK](https://aws.amazon.com/cdk/)
 ##### AWS
 - [AWS CDK by @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
+#### Alibaba Cloud
+
+[Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88)
+
+
 ## Prispevam
 Za tiste, ki bi radi prispevali kodo, si oglejte naš vodnik za prispevke . Hkrati vas prosimo, da podprete Dify tako, da ga delite na družbenih medijih ter na dogodkih in konferencah.
diff --git a/README_TR.md b/README_TR.md
index 563a05af3c..51156933d4 100644
--- a/README_TR.md
+++ b/README_TR.md
@@ -212,6 +212,11 @@ Dify'ı bulut platformuna tek tıklamayla dağıtın [terraform](https://www.ter
 ##### AWS
 - [AWS CDK tarafından @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
+#### Alibaba Cloud
+
+[Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88)
+
+
 ## Katkıda Bulunma
 Kod katkısında bulunmak isteyenler için [Katkı Kılavuzumuza](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) bakabilirsiniz.
diff --git a/README_TW.md b/README_TW.md
index f4a76ac109..291da28825 100644
--- a/README_TW.md
+++ b/README_TW.md
@@ -224,6 +224,11 @@ Dify 的所有功能都提供相應的 API,因此您可以輕鬆地將 Dify
 - [由 @KevinZhao 提供的 AWS CDK](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
+#### 使用阿里雲計算巢進行部署
+
+[阿里雲計算巢](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88)
+
+
 ## 貢獻
 對於想要貢獻程式碼的開發者,請參閱我們的[貢獻指南](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)。
diff --git a/README_VI.md b/README_VI.md
index 4e1e05cbf3..51a2e9e9e6 100644
--- a/README_VI.md
+++ b/README_VI.md
@@ -214,6 +214,12 @@ Triển khai Dify trên AWS bằng [CDK](https://aws.amazon.com/cdk/)
 ##### AWS
 - [AWS CDK bởi @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
+
+#### Alibaba Cloud
+
+[Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88)
+
+
 ## Đóng góp
 Đối với những người muốn đóng góp mã, xem [Hướng dẫn Đóng góp](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) của chúng tôi.
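Reviewer note on the API changes that follow: the service-API portion of this diff adds a DocumentStatusApi resource that accepts a PATCH request with a JSON list of document IDs and an action of enable, disable, archive, or un_archive, delegating to DocumentService.batch_update_document_status. Below is a minimal client sketch; the exact route parameters, base URL, and the dataset API key header are assumptions (the path converters are not fully visible in this diff), so treat it as illustrative rather than the documented API.

# Hypothetical usage sketch for the batch document status endpoint added in this diff.
# Assumptions: route is /v1/datasets/<dataset_id>/documents/status/<action>, the request
# is authenticated with a dataset API key, and the body carries {"document_ids": [...]}
# as read by DocumentStatusApi.patch().
import requests

API_BASE = "https://api.dify.example.com/v1"  # assumed base URL
DATASET_API_KEY = "dataset-xxxxxxxx"          # assumed dataset API key
dataset_id = "dataset-123"                    # placeholder IDs
document_ids = ["doc-1", "doc-2"]

resp = requests.patch(
    f"{API_BASE}/datasets/{dataset_id}/documents/status/disable",
    headers={"Authorization": f"Bearer {DATASET_API_KEY}", "Content-Type": "application/json"},
    json={"document_ids": document_ids},
)
resp.raise_for_status()
print(resp.json())  # the resource below returns {"result": "success"} with HTTP 200

On success the controller returns {"result": "success"} and 200, matching the docstring of the resource added further down.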
diff --git a/api/controllers/console/admin.py b/api/controllers/console/admin.py index 8cb7ad9f5b..f5257fae79 100644 --- a/api/controllers/console/admin.py +++ b/api/controllers/console/admin.py @@ -56,8 +56,7 @@ class InsertExploreAppListApi(Resource): parser.add_argument("position", type=int, required=True, nullable=False, location="json") args = parser.parse_args() - with Session(db.engine) as session: - app = session.execute(select(App).filter(App.id == args["app_id"])).scalar_one_or_none() + app = db.session.execute(select(App).filter(App.id == args["app_id"])).scalar_one_or_none() if not app: raise NotFound(f"App '{args['app_id']}' is not found") @@ -78,38 +77,38 @@ class InsertExploreAppListApi(Resource): select(RecommendedApp).filter(RecommendedApp.app_id == args["app_id"]) ).scalar_one_or_none() - if not recommended_app: - recommended_app = RecommendedApp( - app_id=app.id, - description=desc, - copyright=copy_right, - privacy_policy=privacy_policy, - custom_disclaimer=custom_disclaimer, - language=args["language"], - category=args["category"], - position=args["position"], - ) + if not recommended_app: + recommended_app = RecommendedApp( + app_id=app.id, + description=desc, + copyright=copy_right, + privacy_policy=privacy_policy, + custom_disclaimer=custom_disclaimer, + language=args["language"], + category=args["category"], + position=args["position"], + ) - db.session.add(recommended_app) + db.session.add(recommended_app) - app.is_public = True - db.session.commit() + app.is_public = True + db.session.commit() - return {"result": "success"}, 201 - else: - recommended_app.description = desc - recommended_app.copyright = copy_right - recommended_app.privacy_policy = privacy_policy - recommended_app.custom_disclaimer = custom_disclaimer - recommended_app.language = args["language"] - recommended_app.category = args["category"] - recommended_app.position = args["position"] + return {"result": "success"}, 201 + else: + recommended_app.description = desc + recommended_app.copyright = copy_right + recommended_app.privacy_policy = privacy_policy + recommended_app.custom_disclaimer = custom_disclaimer + recommended_app.language = args["language"] + recommended_app.category = args["category"] + recommended_app.position = args["position"] - app.is_public = True + app.is_public = True - db.session.commit() + db.session.commit() - return {"result": "success"}, 200 + return {"result": "success"}, 200 class InsertExploreAppApi(Resource): diff --git a/api/controllers/console/app/app_import.py b/api/controllers/console/app/app_import.py index 5dc6515ce0..9ffb94e9f9 100644 --- a/api/controllers/console/app/app_import.py +++ b/api/controllers/console/app/app_import.py @@ -17,6 +17,8 @@ from libs.login import login_required from models import Account from models.model import App from services.app_dsl_service import AppDslService, ImportStatus +from services.enterprise.enterprise_service import EnterpriseService +from services.feature_service import FeatureService class AppImportApi(Resource): @@ -60,7 +62,9 @@ class AppImportApi(Resource): app_id=args.get("app_id"), ) session.commit() - + if result.app_id and FeatureService.get_system_features().webapp_auth.enabled: + # update web app setting as private + EnterpriseService.WebAppAuth.update_app_access_mode(result.app_id, "private") # Return appropriate status code based on result status = result.status if status == ImportStatus.FAILED.value: diff --git a/api/controllers/console/datasets/datasets_document.py 
b/api/controllers/console/datasets/datasets_document.py index 16a00bbd42..8e8ae9a0e1 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -44,7 +44,6 @@ from core.model_runtime.errors.invoke import InvokeAuthorizationError from core.plugin.impl.exc import PluginDaemonClientSideError from core.rag.extractor.entity.extract_setting import ExtractSetting from extensions.ext_database import db -from extensions.ext_redis import redis_client from fields.document_fields import ( dataset_and_document_fields, document_fields, @@ -56,8 +55,6 @@ from models import Dataset, DatasetProcessRule, Document, DocumentSegment, Uploa from models.dataset import DocumentPipelineExecutionLog from services.dataset_service import DatasetService, DocumentService from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig -from tasks.add_document_to_index_task import add_document_to_index_task -from tasks.remove_document_from_index_task import remove_document_from_index_task class DocumentResource(Resource): @@ -864,77 +861,16 @@ class DocumentStatusApi(DocumentResource): DatasetService.check_dataset_permission(dataset, current_user) document_ids = request.args.getlist("document_id") - for document_id in document_ids: - document = self.get_document(dataset_id, document_id) - indexing_cache_key = "document_{}_indexing".format(document.id) - cache_result = redis_client.get(indexing_cache_key) - if cache_result is not None: - raise InvalidActionError(f"Document:{document.name} is being indexed, please try again later") + try: + DocumentService.batch_update_document_status(dataset, document_ids, action, current_user) + except services.errors.document.DocumentIndexingError as e: + raise InvalidActionError(str(e)) + except ValueError as e: + raise InvalidActionError(str(e)) + except NotFound as e: + raise NotFound(str(e)) - if action == "enable": - if document.enabled: - continue - document.enabled = True - document.disabled_at = None - document.disabled_by = None - document.updated_at = datetime.now(UTC).replace(tzinfo=None) - db.session.commit() - - # Set cache to prevent indexing the same document multiple times - redis_client.setex(indexing_cache_key, 600, 1) - - add_document_to_index_task.delay(document_id) - - elif action == "disable": - if not document.completed_at or document.indexing_status != "completed": - raise InvalidActionError(f"Document: {document.name} is not completed.") - if not document.enabled: - continue - - document.enabled = False - document.disabled_at = datetime.now(UTC).replace(tzinfo=None) - document.disabled_by = current_user.id - document.updated_at = datetime.now(UTC).replace(tzinfo=None) - db.session.commit() - - # Set cache to prevent indexing the same document multiple times - redis_client.setex(indexing_cache_key, 600, 1) - - remove_document_from_index_task.delay(document_id) - - elif action == "archive": - if document.archived: - continue - - document.archived = True - document.archived_at = datetime.now(UTC).replace(tzinfo=None) - document.archived_by = current_user.id - document.updated_at = datetime.now(UTC).replace(tzinfo=None) - db.session.commit() - - if document.enabled: - # Set cache to prevent indexing the same document multiple times - redis_client.setex(indexing_cache_key, 600, 1) - - remove_document_from_index_task.delay(document_id) - - elif action == "un_archive": - if not document.archived: - continue - document.archived = False - document.archived_at = None - 
document.archived_by = None - document.updated_at = datetime.now(UTC).replace(tzinfo=None) - db.session.commit() - - # Set cache to prevent indexing the same document multiple times - redis_client.setex(indexing_cache_key, 600, 1) - - add_document_to_index_task.delay(document_id) - - else: - raise InvalidActionError() return {"result": "success"}, 200 diff --git a/api/controllers/console/workspace/load_balancing_config.py b/api/controllers/console/workspace/load_balancing_config.py index ba74e2c074..b4eb5e246b 100644 --- a/api/controllers/console/workspace/load_balancing_config.py +++ b/api/controllers/console/workspace/load_balancing_config.py @@ -15,7 +15,7 @@ class LoadBalancingCredentialsValidateApi(Resource): @login_required @account_initialization_required def post(self, provider: str): - if not TenantAccountRole.is_privileged_role(current_user.current_tenant.current_role): + if not TenantAccountRole.is_privileged_role(current_user.current_role): raise Forbidden() tenant_id = current_user.current_tenant_id @@ -64,7 +64,7 @@ class LoadBalancingConfigCredentialsValidateApi(Resource): @login_required @account_initialization_required def post(self, provider: str, config_id: str): - if not TenantAccountRole.is_privileged_role(current_user.current_tenant.current_role): + if not TenantAccountRole.is_privileged_role(current_user.current_role): raise Forbidden() tenant_id = current_user.current_tenant_id diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index 1467dfb6b3..839afdb9fd 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -4,7 +4,7 @@ from werkzeug.exceptions import Forbidden, NotFound import services.dataset_service from controllers.service_api import api -from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError +from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError, InvalidActionError from controllers.service_api.wraps import ( DatasetApiResource, cloud_edition_billing_rate_limit_check, @@ -17,7 +17,7 @@ from fields.dataset_fields import dataset_detail_fields from fields.tag_fields import tag_fields from libs.login import current_user from models.dataset import Dataset, DatasetPermissionEnum -from services.dataset_service import DatasetPermissionService, DatasetService +from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService from services.entities.knowledge_entities.knowledge_entities import RetrievalModel from services.tag_service import TagService @@ -329,6 +329,56 @@ class DatasetApi(DatasetApiResource): raise DatasetInUseError() +class DocumentStatusApi(DatasetApiResource): + """Resource for batch document status operations.""" + + def patch(self, tenant_id, dataset_id, action): + """ + Batch update document status. + + Args: + tenant_id: tenant id + dataset_id: dataset id + action: action to perform (enable, disable, archive, un_archive) + + Returns: + dict: A dictionary with a key 'result' and a value 'success' + int: HTTP status code 200 indicating that the operation was successful. + + Raises: + NotFound: If the dataset with the given ID does not exist. + Forbidden: If the user does not have permission. + InvalidActionError: If the action is invalid or cannot be performed. 
+ """ + dataset_id_str = str(dataset_id) + dataset = DatasetService.get_dataset(dataset_id_str) + + if dataset is None: + raise NotFound("Dataset not found.") + + # Check user's permission + try: + DatasetService.check_dataset_permission(dataset, current_user) + except services.errors.account.NoPermissionError as e: + raise Forbidden(str(e)) + + # Check dataset model setting + DatasetService.check_dataset_model_setting(dataset) + + # Get document IDs from request body + data = request.get_json() + document_ids = data.get("document_ids", []) + + try: + DocumentService.batch_update_document_status(dataset, document_ids, action, current_user) + except services.errors.document.DocumentIndexingError as e: + raise InvalidActionError(str(e)) + except ValueError as e: + raise InvalidActionError(str(e)) + + return {"result": "success"}, 200 + + class DatasetTagsApi(DatasetApiResource): @validate_dataset_token @marshal_with(tag_fields) @@ -457,6 +507,7 @@ class DatasetTagsBindingStatusApi(DatasetApiResource): api.add_resource(DatasetListApi, "/datasets") api.add_resource(DatasetApi, "/datasets/") +api.add_resource(DocumentStatusApi, "/datasets//documents/status/") api.add_resource(DatasetTagsApi, "/datasets/tags") api.add_resource(DatasetTagBindingApi, "/datasets/tags/binding") api.add_resource(DatasetTagUnbindingApi, "/datasets/tags/unbinding") diff --git a/api/core/rag/extractor/markdown_extractor.py b/api/core/rag/extractor/markdown_extractor.py index 849852ac23..c97765b1dc 100644 --- a/api/core/rag/extractor/markdown_extractor.py +++ b/api/core/rag/extractor/markdown_extractor.py @@ -68,22 +68,17 @@ class MarkdownExtractor(BaseExtractor): continue header_match = re.match(r"^#+\s", line) if header_match: - if current_header is not None: - markdown_tups.append((current_header, current_text)) - + markdown_tups.append((current_header, current_text)) current_header = line current_text = "" else: current_text += line + "\n" markdown_tups.append((current_header, current_text)) - if current_header is not None: - # pass linting, assert keys are defined - markdown_tups = [ - (re.sub(r"#", "", cast(str, key)).strip(), re.sub(r"<.*?>", "", value)) for key, value in markdown_tups - ] - else: - markdown_tups = [(key, re.sub("\n", "", value)) for key, value in markdown_tups] + markdown_tups = [ + (re.sub(r"#", "", cast(str, key)).strip() if key else None, re.sub(r"<.*?>", "", value)) + for key, value in markdown_tups + ] return markdown_tups diff --git a/api/core/repositories/sqlalchemy_workflow_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_execution_repository.py index e5ead9dc56..e30538742a 100644 --- a/api/core/repositories/sqlalchemy_workflow_execution_repository.py +++ b/api/core/repositories/sqlalchemy_workflow_execution_repository.py @@ -6,7 +6,7 @@ import json import logging from typing import Optional, Union -from sqlalchemy import func, select +from sqlalchemy import select from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker @@ -146,20 +146,7 @@ class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository): db_model.workflow_id = domain_model.workflow_id db_model.triggered_from = self._triggered_from - # Check if this is a new record - with self._session_factory() as session: - existing = session.scalar(select(WorkflowRun).where(WorkflowRun.id == domain_model.id_)) - if not existing: - # For new records, get the next sequence number - stmt = select(func.max(WorkflowRun.sequence_number)).where( - WorkflowRun.app_id == self._app_id, - 
WorkflowRun.tenant_id == self._tenant_id, - ) - max_sequence = session.scalar(stmt) - db_model.sequence_number = (max_sequence or 0) + 1 - else: - # For updates, keep the existing sequence number - db_model.sequence_number = existing.sequence_number + # No sequence number generation needed anymore db_model.type = domain_model.workflow_type db_model.version = domain_model.workflow_version diff --git a/api/core/workflow/nodes/variable_assigner/v2/constants.py b/api/core/workflow/nodes/variable_assigner/v2/constants.py index 7d7922abd4..3797bfa77a 100644 --- a/api/core/workflow/nodes/variable_assigner/v2/constants.py +++ b/api/core/workflow/nodes/variable_assigner/v2/constants.py @@ -8,5 +8,4 @@ EMPTY_VALUE_MAPPING = { SegmentType.ARRAY_STRING: [], SegmentType.ARRAY_NUMBER: [], SegmentType.ARRAY_OBJECT: [], - SegmentType.ARRAY_FILE: [], } diff --git a/api/core/workflow/nodes/variable_assigner/v2/helpers.py b/api/core/workflow/nodes/variable_assigner/v2/helpers.py index f33f406145..8fb2a27388 100644 --- a/api/core/workflow/nodes/variable_assigner/v2/helpers.py +++ b/api/core/workflow/nodes/variable_assigner/v2/helpers.py @@ -1,6 +1,5 @@ from typing import Any -from core.file import File from core.variables import SegmentType from .enums import Operation @@ -86,8 +85,6 @@ def is_input_value_valid(*, variable_type: SegmentType, operation: Operation, va return isinstance(value, int | float) case SegmentType.ARRAY_OBJECT if operation == Operation.APPEND: return isinstance(value, dict) - case SegmentType.ARRAY_FILE if operation == Operation.APPEND: - return isinstance(value, File) # Array & Extend / Overwrite case SegmentType.ARRAY_ANY if operation in {Operation.EXTEND, Operation.OVER_WRITE}: @@ -98,8 +95,6 @@ def is_input_value_valid(*, variable_type: SegmentType, operation: Operation, va return isinstance(value, list) and all(isinstance(item, int | float) for item in value) case SegmentType.ARRAY_OBJECT if operation in {Operation.EXTEND, Operation.OVER_WRITE}: return isinstance(value, list) and all(isinstance(item, dict) for item in value) - case SegmentType.ARRAY_FILE if operation in {Operation.EXTEND, Operation.OVER_WRITE}: - return isinstance(value, list) and all(isinstance(item, File) for item in value) case _: return False diff --git a/api/factories/variable_factory.py b/api/factories/variable_factory.py index 202f10044b..8915c18bd8 100644 --- a/api/factories/variable_factory.py +++ b/api/factories/variable_factory.py @@ -110,8 +110,6 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen result = ArrayNumberVariable.model_validate(mapping) case SegmentType.ARRAY_OBJECT if isinstance(value, list): result = ArrayObjectVariable.model_validate(mapping) - case SegmentType.ARRAY_FILE if isinstance(value, list): - result = ArrayFileVariable.model_validate(mapping) case _: raise VariableError(f"not supported value type {value_type}") if result.size > dify_config.MAX_VARIABLE_SIZE: diff --git a/api/fields/workflow_run_fields.py b/api/fields/workflow_run_fields.py index 74fdf8bd97..a106728e9c 100644 --- a/api/fields/workflow_run_fields.py +++ b/api/fields/workflow_run_fields.py @@ -19,7 +19,6 @@ workflow_run_for_log_fields = { workflow_run_for_list_fields = { "id": fields.String, - "sequence_number": fields.Integer, "version": fields.String, "status": fields.String, "elapsed_time": fields.Float, @@ -36,7 +35,6 @@ advanced_chat_workflow_run_for_list_fields = { "id": fields.String, "conversation_id": fields.String, "message_id": fields.String, - "sequence_number": 
fields.Integer, "version": fields.String, "status": fields.String, "elapsed_time": fields.Float, @@ -63,7 +61,6 @@ workflow_run_pagination_fields = { workflow_run_detail_fields = { "id": fields.String, - "sequence_number": fields.Integer, "version": fields.String, "graph": fields.Raw(attribute="graph_dict"), "inputs": fields.Raw(attribute="inputs_dict"), diff --git a/api/libs/sendgrid.py b/api/libs/sendgrid.py index 6217e9f4a6..5409e3eeeb 100644 --- a/api/libs/sendgrid.py +++ b/api/libs/sendgrid.py @@ -35,7 +35,10 @@ class SendGridClient: logging.exception("SendGridClient Timeout occurred while sending email") raise except (UnauthorizedError, ForbiddenError) as e: - logging.exception("SendGridClient Authentication failed. Verify that your credentials and the 'from") + logging.exception( + "SendGridClient Authentication failed. " + "Verify that your credentials and the 'from' email address are correct" + ) raise except Exception as e: logging.exception(f"SendGridClient Unexpected error occurred while sending email to {_to}") diff --git a/api/migrations/versions/2025_06_19_1633-0ab65e1cc7fa_remove_sequence_number_from_workflow_.py b/api/migrations/versions/2025_06_19_1633-0ab65e1cc7fa_remove_sequence_number_from_workflow_.py new file mode 100644 index 0000000000..29fef77798 --- /dev/null +++ b/api/migrations/versions/2025_06_19_1633-0ab65e1cc7fa_remove_sequence_number_from_workflow_.py @@ -0,0 +1,66 @@ +"""remove sequence_number from workflow_runs + +Revision ID: 0ab65e1cc7fa +Revises: 4474872b0ee6 +Create Date: 2025-06-19 16:33:13.377215 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '0ab65e1cc7fa' +down_revision = '4474872b0ee6' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('workflow_runs', schema=None) as batch_op: + batch_op.drop_index(batch_op.f('workflow_run_tenant_app_sequence_idx')) + batch_op.drop_column('sequence_number') + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + + # WARNING: This downgrade CANNOT recover the original sequence_number values! + # The original sequence numbers are permanently lost after the upgrade. + # This downgrade will regenerate sequence numbers based on created_at order, + # which may result in different values than the original sequence numbers. + # + # If you need to preserve original sequence numbers, use the alternative + # migration approach that creates a backup table before removal. 
+ + # Step 1: Add sequence_number column as nullable first + with op.batch_alter_table('workflow_runs', schema=None) as batch_op: + batch_op.add_column(sa.Column('sequence_number', sa.INTEGER(), autoincrement=False, nullable=True)) + + # Step 2: Populate sequence_number values based on created_at order within each app + # NOTE: This recreates sequence numbering logic but values will be different + # from the original sequence numbers that were removed in the upgrade + connection = op.get_bind() + connection.execute(sa.text(""" + UPDATE workflow_runs + SET sequence_number = subquery.row_num + FROM ( + SELECT id, ROW_NUMBER() OVER ( + PARTITION BY tenant_id, app_id + ORDER BY created_at, id + ) as row_num + FROM workflow_runs + ) subquery + WHERE workflow_runs.id = subquery.id + """)) + + # Step 3: Make the column NOT NULL and add the index + with op.batch_alter_table('workflow_runs', schema=None) as batch_op: + batch_op.alter_column('sequence_number', nullable=False) + batch_op.create_index(batch_op.f('workflow_run_tenant_app_sequence_idx'), ['tenant_id', 'app_id', 'sequence_number'], unique=False) + + # ### end Alembic commands ### diff --git a/api/models/workflow.py b/api/models/workflow.py index 741422db06..90dd55858c 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -410,7 +410,7 @@ class WorkflowRun(Base): - id (uuid) Run ID - tenant_id (uuid) Workspace ID - app_id (uuid) App ID - - sequence_number (int) Auto-increment sequence number, incremented within the App, starting from 1 + - workflow_id (uuid) Workflow ID - type (string) Workflow type - triggered_from (string) Trigger source @@ -443,13 +443,12 @@ class WorkflowRun(Base): __table_args__ = ( db.PrimaryKeyConstraint("id", name="workflow_run_pkey"), db.Index("workflow_run_triggerd_from_idx", "tenant_id", "app_id", "triggered_from"), - db.Index("workflow_run_tenant_app_sequence_idx", "tenant_id", "app_id", "sequence_number"), ) id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID) app_id: Mapped[str] = mapped_column(StringUUID) - sequence_number: Mapped[int] = mapped_column() + workflow_id: Mapped[str] = mapped_column(StringUUID) type: Mapped[str] = mapped_column(db.String(255)) triggered_from: Mapped[str] = mapped_column(db.String(255)) @@ -509,7 +508,6 @@ class WorkflowRun(Base): "id": self.id, "tenant_id": self.tenant_id, "app_id": self.app_id, - "sequence_number": self.sequence_number, "workflow_id": self.workflow_id, "type": self.type, "triggered_from": self.triggered_from, @@ -535,7 +533,6 @@ class WorkflowRun(Base): id=data.get("id"), tenant_id=data.get("tenant_id"), app_id=data.get("app_id"), - sequence_number=data.get("sequence_number"), workflow_id=data.get("workflow_id"), type=data.get("type"), triggered_from=data.get("triggered_from"), diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 6f23f98e67..121ed2d5a2 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -64,6 +64,7 @@ from services.external_knowledge_service import ExternalDatasetService from services.feature_service import FeatureModel, FeatureService from services.tag_service import TagService from services.vector_service import VectorService +from tasks.add_document_to_index_task import add_document_to_index_task from tasks.batch_clean_document_task import batch_clean_document_task from tasks.clean_notion_document_task import clean_notion_document_task from 
tasks.deal_dataset_index_update_task import deal_dataset_index_update_task @@ -76,6 +77,7 @@ from tasks.document_indexing_update_task import document_indexing_update_task from tasks.duplicate_document_indexing_task import duplicate_document_indexing_task from tasks.enable_segments_to_index_task import enable_segments_to_index_task from tasks.recover_document_indexing_task import recover_document_indexing_task +from tasks.remove_document_from_index_task import remove_document_from_index_task from tasks.retry_document_indexing_task import retry_document_indexing_task from tasks.sync_website_document_indexing_task import sync_website_document_indexing_task @@ -490,7 +492,7 @@ class DatasetService: raise ValueError(ex.description) filtered_data["updated_by"] = user.id - filtered_data["updated_at"] = datetime.datetime.now() + filtered_data["updated_at"] = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) # update Retrieval model filtered_data["retrieval_model"] = data["retrieval_model"] @@ -1157,12 +1159,17 @@ class DocumentService: process_rule = knowledge_config.process_rule if process_rule: if process_rule.mode in ("custom", "hierarchical"): - dataset_process_rule = DatasetProcessRule( - dataset_id=dataset.id, - mode=process_rule.mode, - rules=process_rule.rules.model_dump_json() if process_rule.rules else None, - created_by=account.id, - ) + if process_rule.rules: + dataset_process_rule = DatasetProcessRule( + dataset_id=dataset.id, + mode=process_rule.mode, + rules=process_rule.rules.model_dump_json() if process_rule.rules else None, + created_by=account.id, + ) + else: + dataset_process_rule = dataset.latest_process_rule + if not dataset_process_rule: + raise ValueError("No process rule found.") elif process_rule.mode == "automatic": dataset_process_rule = DatasetProcessRule( dataset_id=dataset.id, @@ -2061,6 +2068,191 @@ class DocumentService: if not isinstance(args["process_rule"]["rules"]["segmentation"]["max_tokens"], int): raise ValueError("Process rule segmentation max_tokens is invalid") + @staticmethod + def batch_update_document_status(dataset: Dataset, document_ids: list[str], action: str, user): + """ + Batch update document status. + + Args: + dataset (Dataset): The dataset object + document_ids (list[str]): List of document IDs to update + action (str): Action to perform (enable, disable, archive, un_archive) + user: Current user performing the action + + Raises: + DocumentIndexingError: If document is being indexed or not in correct state + ValueError: If action is invalid + """ + if not document_ids: + return + + # Early validation of action parameter + valid_actions = ["enable", "disable", "archive", "un_archive"] + if action not in valid_actions: + raise ValueError(f"Invalid action: {action}. 
Must be one of {valid_actions}")
+
+        documents_to_update = []
+
+        # First pass: validate all documents and prepare updates
+        for document_id in document_ids:
+            document = DocumentService.get_document(dataset.id, document_id)
+            if not document:
+                continue
+
+            # Check if document is being indexed
+            indexing_cache_key = f"document_{document.id}_indexing"
+            cache_result = redis_client.get(indexing_cache_key)
+            if cache_result is not None:
+                raise DocumentIndexingError(f"Document:{document.name} is being indexed, please try again later")
+
+            # Prepare update based on action
+            update_info = DocumentService._prepare_document_status_update(document, action, user)
+            if update_info:
+                documents_to_update.append(update_info)
+
+        # Second pass: apply all updates in a single transaction
+        if documents_to_update:
+            try:
+                for update_info in documents_to_update:
+                    document = update_info["document"]
+                    updates = update_info["updates"]
+
+                    # Apply updates to the document
+                    for field, value in updates.items():
+                        setattr(document, field, value)
+
+                    db.session.add(document)
+
+                # Batch commit all changes
+                db.session.commit()
+            except Exception as e:
+                # Roll back on any error
+                db.session.rollback()
+                raise e
+            # Execute async tasks and set Redis cache after successful commit
+            # propagation_error captures an error raised while submitting async tasks so it can be re-raised later
+            propagation_error = None
+            for update_info in documents_to_update:
+                try:
+                    # Execute async tasks after successful commit
+                    if update_info["async_task"]:
+                        task_info = update_info["async_task"]
+                        task_func = task_info["function"]
+                        task_args = task_info["args"]
+                        task_func.delay(*task_args)
+                except Exception as e:
+                    # Log the error but do not roll back the transaction
+                    logging.exception(f"Error executing async task for document {update_info['document'].id}")
+                    # don't raise the error immediately; capture it and re-raise after the loop
+                    propagation_error = e
+                try:
+                    # Set Redis cache if needed after successful commit
+                    if update_info["set_cache"]:
+                        document = update_info["document"]
+                        indexing_cache_key = f"document_{document.id}_indexing"
+                        redis_client.setex(indexing_cache_key, 600, 1)
+                except Exception as e:
+                    # Log the error but do not roll back the transaction
+                    logging.exception(f"Error setting cache for document {update_info['document'].id}")
+            # Raise any propagation error after all updates
+            if propagation_error:
+                raise propagation_error
+
+    @staticmethod
+    def _prepare_document_status_update(document, action: str, user):
+        """
+        Prepare document status update information.
+ + Args: + document: Document object to update + action: Action to perform + user: Current user + + Returns: + dict: Update information or None if no update needed + """ + now = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + + if action == "enable": + return DocumentService._prepare_enable_update(document, now) + elif action == "disable": + return DocumentService._prepare_disable_update(document, user, now) + elif action == "archive": + return DocumentService._prepare_archive_update(document, user, now) + elif action == "un_archive": + return DocumentService._prepare_unarchive_update(document, now) + + return None + + @staticmethod + def _prepare_enable_update(document, now): + """Prepare updates for enabling a document.""" + if document.enabled: + return None + + return { + "document": document, + "updates": {"enabled": True, "disabled_at": None, "disabled_by": None, "updated_at": now}, + "async_task": {"function": add_document_to_index_task, "args": [document.id]}, + "set_cache": True, + } + + @staticmethod + def _prepare_disable_update(document, user, now): + """Prepare updates for disabling a document.""" + if not document.completed_at or document.indexing_status != "completed": + raise DocumentIndexingError(f"Document: {document.name} is not completed.") + + if not document.enabled: + return None + + return { + "document": document, + "updates": {"enabled": False, "disabled_at": now, "disabled_by": user.id, "updated_at": now}, + "async_task": {"function": remove_document_from_index_task, "args": [document.id]}, + "set_cache": True, + } + + @staticmethod + def _prepare_archive_update(document, user, now): + """Prepare updates for archiving a document.""" + if document.archived: + return None + + update_info = { + "document": document, + "updates": {"archived": True, "archived_at": now, "archived_by": user.id, "updated_at": now}, + "async_task": None, + "set_cache": False, + } + + # Only set async task and cache if document is currently enabled + if document.enabled: + update_info["async_task"] = {"function": remove_document_from_index_task, "args": [document.id]} + update_info["set_cache"] = True + + return update_info + + @staticmethod + def _prepare_unarchive_update(document, now): + """Prepare updates for unarchiving a document.""" + if not document.archived: + return None + + update_info = { + "document": document, + "updates": {"archived": False, "archived_at": None, "archived_by": None, "updated_at": now}, + "async_task": None, + "set_cache": False, + } + + # Only re-index if the document is currently enabled + if document.enabled: + update_info["async_task"] = {"function": add_document_to_index_task, "args": [document.id]} + update_info["set_cache"] = True + + return update_info + class SegmentService: @classmethod diff --git a/api/tests/unit_tests/conftest.py b/api/tests/unit_tests/conftest.py index e09acc4c39..077ffe3408 100644 --- a/api/tests/unit_tests/conftest.py +++ b/api/tests/unit_tests/conftest.py @@ -1,4 +1,5 @@ import os +from unittest.mock import MagicMock, patch import pytest from flask import Flask @@ -11,6 +12,24 @@ PROJECT_DIR = os.path.abspath(os.path.join(ABS_PATH, os.pardir, os.pardir)) CACHED_APP = Flask(__name__) +# set global mock for Redis client +redis_mock = MagicMock() +redis_mock.get = MagicMock(return_value=None) +redis_mock.setex = MagicMock() +redis_mock.setnx = MagicMock() +redis_mock.delete = MagicMock() +redis_mock.lock = MagicMock() +redis_mock.exists = MagicMock(return_value=False) +redis_mock.set = MagicMock() +redis_mock.expire = 
MagicMock()
+redis_mock.hgetall = MagicMock(return_value={})
+redis_mock.hdel = MagicMock()
+redis_mock.incr = MagicMock(return_value=1)
+
+# apply the mock to the Redis client in the Flask app
+redis_patcher = patch("extensions.ext_redis.redis_client", redis_mock)
+redis_patcher.start()
+
 @pytest.fixture
 def app() -> Flask:
@@ -21,3 +40,19 @@ def _provide_app_context(app: Flask):
     with app.app_context():
         yield
+
+
+@pytest.fixture(autouse=True)
+def reset_redis_mock():
+    """reset the Redis mock before each test"""
+    redis_mock.reset_mock()
+    redis_mock.get.return_value = None
+    redis_mock.setex.return_value = None
+    redis_mock.setnx.return_value = None
+    redis_mock.delete.return_value = None
+    redis_mock.exists.return_value = False
+    redis_mock.set.return_value = None
+    redis_mock.expire.return_value = None
+    redis_mock.hgetall.return_value = {}
+    redis_mock.hdel.return_value = None
+    redis_mock.incr.return_value = 1
diff --git a/api/tests/unit_tests/core/rag/extractor/test_markdown_extractor.py b/api/tests/unit_tests/core/rag/extractor/test_markdown_extractor.py
new file mode 100644
index 0000000000..d4cf534c56
--- /dev/null
+++ b/api/tests/unit_tests/core/rag/extractor/test_markdown_extractor.py
@@ -0,0 +1,22 @@
+from core.rag.extractor.markdown_extractor import MarkdownExtractor
+
+
+def test_markdown_to_tups():
+    markdown = """
+this is some text without header
+
+# title 1
+this is balabala text
+
+## title 2
+this is more specific text.
+    """
+    extractor = MarkdownExtractor(file_path="dummy_path")
+    updated_output = extractor.markdown_to_tups(markdown)
+    assert len(updated_output) == 3
+    key, header_value = updated_output[0]
+    assert key is None
+    assert header_value.strip() == "this is some text without header"
+    title_1, value = updated_output[1]
+    assert title_1.strip() == "title 1"
+    assert value.strip() == "this is balabala text"
diff --git a/api/tests/unit_tests/core/workflow/test_workflow_cycle_manager.py b/api/tests/unit_tests/core/workflow/test_workflow_cycle_manager.py
index fddc182594..646de8bf3a 100644
--- a/api/tests/unit_tests/core/workflow/test_workflow_cycle_manager.py
+++ b/api/tests/unit_tests/core/workflow/test_workflow_cycle_manager.py
@@ -163,7 +163,6 @@ def real_workflow_run():
     workflow_run.tenant_id = "test-tenant-id"
     workflow_run.app_id = "test-app-id"
     workflow_run.workflow_id = "test-workflow-id"
-    workflow_run.sequence_number = 1
     workflow_run.type = "chat"
     workflow_run.triggered_from = "app-run"
     workflow_run.version = "1.0"
diff --git a/api/tests/unit_tests/services/test_dataset_service.py b/api/tests/unit_tests/services/test_dataset_service.py
new file mode 100644
index 0000000000..f22500cfe4
--- /dev/null
+++ b/api/tests/unit_tests/services/test_dataset_service.py
@@ -0,0 +1,1238 @@
+import datetime
+import unittest
+
+# Mock redis_client before importing dataset_service
+from unittest.mock import Mock, call, patch
+
+import pytest
+
+from models.dataset import Dataset, Document
+from services.dataset_service import DocumentService
+from services.errors.document import DocumentIndexingError
+from tests.unit_tests.conftest import redis_mock
+
+
+class TestDatasetServiceBatchUpdateDocumentStatus(unittest.TestCase):
+    """
+    Comprehensive unit tests for DocumentService.batch_update_document_status method.
+
+    This test suite covers all supported actions (enable, disable, archive, un_archive),
+    error conditions, edge cases, and validates proper interaction with Redis cache,
+    database operations, and async task triggers.
+ """ + + @patch("extensions.ext_database.db.session") + @patch("services.dataset_service.add_document_to_index_task") + @patch("services.dataset_service.DocumentService.get_document") + @patch("services.dataset_service.datetime") + def test_batch_update_enable_documents_success(self, mock_datetime, mock_get_doc, mock_add_task, mock_db): + """ + Test successful enabling of disabled documents. + + Verifies that: + 1. Only disabled documents are processed (already enabled documents are skipped) + 2. Document attributes are updated correctly (enabled=True, metadata cleared) + 3. Database changes are committed for each document + 4. Redis cache keys are set to prevent concurrent indexing + 5. Async indexing task is triggered for each enabled document + 6. Timestamp fields are properly updated + """ + # Create mock dataset + mock_dataset = Mock(spec=Dataset) + mock_dataset.id = "dataset-123" + mock_dataset.tenant_id = "tenant-456" + + # Create mock user + mock_user = Mock() + mock_user.id = "user-789" + + # Create mock disabled document + mock_disabled_doc_1 = Mock(spec=Document) + mock_disabled_doc_1.id = "doc-1" + mock_disabled_doc_1.name = "disabled_document.pdf" + mock_disabled_doc_1.enabled = False + mock_disabled_doc_1.archived = False + mock_disabled_doc_1.indexing_status = "completed" + mock_disabled_doc_1.completed_at = datetime.datetime.now() + + mock_disabled_doc_2 = Mock(spec=Document) + mock_disabled_doc_2.id = "doc-2" + mock_disabled_doc_2.name = "disabled_document.pdf" + mock_disabled_doc_2.enabled = False + mock_disabled_doc_2.archived = False + mock_disabled_doc_2.indexing_status = "completed" + mock_disabled_doc_2.completed_at = datetime.datetime.now() + + # Set up mock return values + current_time = datetime.datetime(2023, 1, 1, 12, 0, 0) + mock_datetime.datetime.now.return_value = current_time + mock_datetime.UTC = datetime.UTC + + # Mock document retrieval to return disabled documents + mock_get_doc.side_effect = [mock_disabled_doc_1, mock_disabled_doc_2] + + # Reset module-level Redis mock + redis_mock.reset_mock() + redis_mock.get.return_value = None + + # Call the method to enable documents + DocumentService.batch_update_document_status( + dataset=mock_dataset, document_ids=["doc-1", "doc-2"], action="enable", user=mock_user + ) + + # Verify document attributes were updated correctly + for mock_doc in [mock_disabled_doc_1, mock_disabled_doc_2]: + # Check that document was enabled + assert mock_doc.enabled == True + # Check that disable metadata was cleared + assert mock_doc.disabled_at is None + assert mock_doc.disabled_by is None + # Check that update timestamp was set + assert mock_doc.updated_at == current_time.replace(tzinfo=None) + + # Verify Redis cache operations + expected_cache_calls = [call("document_doc-1_indexing"), call("document_doc-2_indexing")] + redis_mock.get.assert_has_calls(expected_cache_calls) + + # Verify Redis cache was set to prevent concurrent indexing (600 seconds) + expected_setex_calls = [call("document_doc-1_indexing", 600, 1), call("document_doc-2_indexing", 600, 1)] + redis_mock.setex.assert_has_calls(expected_setex_calls) + + # Verify async tasks were triggered for indexing + expected_task_calls = [call("doc-1"), call("doc-2")] + mock_add_task.delay.assert_has_calls(expected_task_calls) + + # Verify database add counts (one add for one document) + assert mock_db.add.call_count == 2 + # Verify database commits (one commit for the batch operation) + assert mock_db.commit.call_count == 1 + + @patch("extensions.ext_database.db.session") + 
@patch("services.dataset_service.remove_document_from_index_task") + @patch("services.dataset_service.DocumentService.get_document") + @patch("services.dataset_service.datetime") + def test_batch_update_disable_documents_success(self, mock_datetime, mock_get_doc, mock_remove_task, mock_db): + """ + Test successful disabling of enabled and completed documents. + + Verifies that: + 1. Only completed and enabled documents can be disabled + 2. Document attributes are updated correctly (enabled=False, disable metadata set) + 3. User ID is recorded in disabled_by field + 4. Database changes are committed for each document + 5. Redis cache keys are set to prevent concurrent indexing + 6. Async task is triggered to remove documents from index + """ + # Create mock dataset + mock_dataset = Mock(spec=Dataset) + mock_dataset.id = "dataset-123" + mock_dataset.tenant_id = "tenant-456" + + # Create mock user + mock_user = Mock() + mock_user.id = "user-789" + + # Create mock enabled document + mock_enabled_doc_1 = Mock(spec=Document) + mock_enabled_doc_1.id = "doc-1" + mock_enabled_doc_1.name = "enabled_document.pdf" + mock_enabled_doc_1.enabled = True + mock_enabled_doc_1.archived = False + mock_enabled_doc_1.indexing_status = "completed" + mock_enabled_doc_1.completed_at = datetime.datetime.now() + + mock_enabled_doc_2 = Mock(spec=Document) + mock_enabled_doc_2.id = "doc-2" + mock_enabled_doc_2.name = "enabled_document.pdf" + mock_enabled_doc_2.enabled = True + mock_enabled_doc_2.archived = False + mock_enabled_doc_2.indexing_status = "completed" + mock_enabled_doc_2.completed_at = datetime.datetime.now() + + # Set up mock return values + current_time = datetime.datetime(2023, 1, 1, 12, 0, 0) + mock_datetime.datetime.now.return_value = current_time + mock_datetime.UTC = datetime.UTC + + # Mock document retrieval to return enabled, completed documents + mock_get_doc.side_effect = [mock_enabled_doc_1, mock_enabled_doc_2] + + # Reset module-level Redis mock + redis_mock.reset_mock() + redis_mock.get.return_value = None + + # Call the method to disable documents + DocumentService.batch_update_document_status( + dataset=mock_dataset, document_ids=["doc-1", "doc-2"], action="disable", user=mock_user + ) + + # Verify document attributes were updated correctly + for mock_doc in [mock_enabled_doc_1, mock_enabled_doc_2]: + # Check that document was disabled + assert mock_doc.enabled == False + # Check that disable metadata was set correctly + assert mock_doc.disabled_at == current_time.replace(tzinfo=None) + assert mock_doc.disabled_by == mock_user.id + # Check that update timestamp was set + assert mock_doc.updated_at == current_time.replace(tzinfo=None) + + # Verify Redis cache operations for indexing prevention + expected_setex_calls = [call("document_doc-1_indexing", 600, 1), call("document_doc-2_indexing", 600, 1)] + redis_mock.setex.assert_has_calls(expected_setex_calls) + + # Verify async tasks were triggered to remove from index + expected_task_calls = [call("doc-1"), call("doc-2")] + mock_remove_task.delay.assert_has_calls(expected_task_calls) + + # Verify database add counts (one add for one document) + assert mock_db.add.call_count == 2 + # Verify database commits (totally 1 for any batch operation) + assert mock_db.commit.call_count == 1 + + @patch("extensions.ext_database.db.session") + @patch("services.dataset_service.remove_document_from_index_task") + @patch("services.dataset_service.DocumentService.get_document") + @patch("services.dataset_service.datetime") + def 
test_batch_update_archive_documents_success(self, mock_datetime, mock_get_doc, mock_remove_task, mock_db): + """ + Test successful archiving of unarchived documents. + + Verifies that: + 1. Only unarchived documents are processed (already archived are skipped) + 2. Document attributes are updated correctly (archived=True, archive metadata set) + 3. User ID is recorded in archived_by field + 4. If documents are enabled, they are removed from the index + 5. Redis cache keys are set only for enabled documents being archived + 6. Database changes are committed for each document + """ + # Create mock dataset + mock_dataset = Mock(spec=Dataset) + mock_dataset.id = "dataset-123" + mock_dataset.tenant_id = "tenant-456" + + # Create mock user + mock_user = Mock() + mock_user.id = "user-789" + + # Create unarchived enabled document + unarchived_doc = Mock(spec=Document) + # Manually set attributes to ensure they can be modified + unarchived_doc.id = "doc-1" + unarchived_doc.name = "unarchived_document.pdf" + unarchived_doc.enabled = True + unarchived_doc.archived = False + + # Set up mock return values + current_time = datetime.datetime(2023, 1, 1, 12, 0, 0) + mock_datetime.datetime.now.return_value = current_time + mock_datetime.UTC = datetime.UTC + + mock_get_doc.return_value = unarchived_doc + + # Reset module-level Redis mock + redis_mock.reset_mock() + redis_mock.get.return_value = None + + # Call the method to archive documents + DocumentService.batch_update_document_status( + dataset=mock_dataset, document_ids=["doc-1"], action="archive", user=mock_user + ) + + # Verify document attributes were updated correctly + assert unarchived_doc.archived == True + assert unarchived_doc.archived_at == current_time.replace(tzinfo=None) + assert unarchived_doc.archived_by == mock_user.id + assert unarchived_doc.updated_at == current_time.replace(tzinfo=None) + + # Verify Redis cache was set (because document was enabled) + redis_mock.setex.assert_called_once_with("document_doc-1_indexing", 600, 1) + + # Verify async task was triggered to remove from index (because enabled) + mock_remove_task.delay.assert_called_once_with("doc-1") + + # Verify database add + mock_db.add.assert_called_once() + # Verify database commit + mock_db.commit.assert_called_once() + + @patch("extensions.ext_database.db.session") + @patch("services.dataset_service.add_document_to_index_task") + @patch("services.dataset_service.DocumentService.get_document") + @patch("services.dataset_service.datetime") + def test_batch_update_unarchive_documents_success(self, mock_datetime, mock_get_doc, mock_add_task, mock_db): + """ + Test successful unarchiving of archived documents. + + Verifies that: + 1. Only archived documents are processed (already unarchived are skipped) + 2. Document attributes are updated correctly (archived=False, archive metadata cleared) + 3. If documents are enabled, they are added back to the index + 4. Redis cache keys are set only for enabled documents being unarchived + 5. 
Database changes are committed for each document + """ + # Create mock dataset + mock_dataset = Mock(spec=Dataset) + mock_dataset.id = "dataset-123" + mock_dataset.tenant_id = "tenant-456" + + # Create mock user + mock_user = Mock() + mock_user.id = "user-789" + + # Create mock archived document + mock_archived_doc = Mock(spec=Document) + mock_archived_doc.id = "doc-3" + mock_archived_doc.name = "archived_document.pdf" + mock_archived_doc.enabled = True + mock_archived_doc.archived = True + mock_archived_doc.indexing_status = "completed" + mock_archived_doc.completed_at = datetime.datetime.now() + + # Set up mock return values + current_time = datetime.datetime(2023, 1, 1, 12, 0, 0) + mock_datetime.datetime.now.return_value = current_time + mock_datetime.UTC = datetime.UTC + + mock_get_doc.return_value = mock_archived_doc + + # Reset module-level Redis mock + redis_mock.reset_mock() + redis_mock.get.return_value = None + + # Call the method to unarchive documents + DocumentService.batch_update_document_status( + dataset=mock_dataset, document_ids=["doc-3"], action="un_archive", user=mock_user + ) + + # Verify document attributes were updated correctly + assert mock_archived_doc.archived == False + assert mock_archived_doc.archived_at is None + assert mock_archived_doc.archived_by is None + assert mock_archived_doc.updated_at == current_time.replace(tzinfo=None) + + # Verify Redis cache was set (because document is enabled) + redis_mock.setex.assert_called_once_with("document_doc-3_indexing", 600, 1) + + # Verify async task was triggered to add back to index (because enabled) + mock_add_task.delay.assert_called_once_with("doc-3") + + # Verify database add + mock_db.add.assert_called_once() + # Verify database commit + mock_db.commit.assert_called_once() + + @patch("services.dataset_service.DocumentService.get_document") + def test_batch_update_document_indexing_error_redis_cache_hit(self, mock_get_doc): + """ + Test that DocumentIndexingError is raised when documents are currently being indexed. + + Verifies that: + 1. The method checks Redis cache for active indexing operations + 2. DocumentIndexingError is raised if any document is being indexed + 3. Error message includes the document name for user feedback + 4. 
No further processing occurs when indexing is detected + """ + # Create mock dataset + mock_dataset = Mock(spec=Dataset) + mock_dataset.id = "dataset-123" + mock_dataset.tenant_id = "tenant-456" + + # Create mock user + mock_user = Mock() + mock_user.id = "user-789" + + # Create mock enabled document + mock_enabled_doc = Mock(spec=Document) + mock_enabled_doc.id = "doc-1" + mock_enabled_doc.name = "enabled_document.pdf" + mock_enabled_doc.enabled = True + mock_enabled_doc.archived = False + mock_enabled_doc.indexing_status = "completed" + mock_enabled_doc.completed_at = datetime.datetime.now() + + # Set up mock to indicate document is being indexed + mock_get_doc.return_value = mock_enabled_doc + + # Reset module-level Redis mock, set to indexing status + redis_mock.reset_mock() + redis_mock.get.return_value = "indexing" + + # Verify that DocumentIndexingError is raised + with pytest.raises(DocumentIndexingError) as exc_info: + DocumentService.batch_update_document_status( + dataset=mock_dataset, document_ids=["doc-1"], action="enable", user=mock_user + ) + + # Verify error message contains document name + assert "enabled_document.pdf" in str(exc_info.value) + assert "is being indexed" in str(exc_info.value) + + # Verify Redis cache was checked + redis_mock.get.assert_called_once_with("document_doc-1_indexing") + + @patch("services.dataset_service.DocumentService.get_document") + def test_batch_update_disable_non_completed_document_error(self, mock_get_doc): + """ + Test that DocumentIndexingError is raised when trying to disable non-completed documents. + + Verifies that: + 1. Only completed documents can be disabled + 2. DocumentIndexingError is raised for non-completed documents + 3. Error message indicates the document is not completed + """ + # Create mock dataset + mock_dataset = Mock(spec=Dataset) + mock_dataset.id = "dataset-123" + mock_dataset.tenant_id = "tenant-456" + + # Create mock user + mock_user = Mock() + mock_user.id = "user-789" + + # Create a document that's not completed + non_completed_doc = Mock(spec=Document) + # Manually set attributes to ensure they can be modified + non_completed_doc.id = "doc-1" + non_completed_doc.name = "indexing_document.pdf" + non_completed_doc.enabled = True + non_completed_doc.indexing_status = "indexing" # Not completed + non_completed_doc.completed_at = None # Not completed + + mock_get_doc.return_value = non_completed_doc + + # Verify that DocumentIndexingError is raised + with pytest.raises(DocumentIndexingError) as exc_info: + DocumentService.batch_update_document_status( + dataset=mock_dataset, document_ids=["doc-1"], action="disable", user=mock_user + ) + + # Verify error message indicates document is not completed + assert "is not completed" in str(exc_info.value) + + @patch("services.dataset_service.DocumentService.get_document") + def test_batch_update_empty_document_list(self, mock_get_doc): + """ + Test batch operations with an empty document ID list. + + Verifies that: + 1. The method handles empty input gracefully + 2. No document operations are performed with empty input + 3. No errors are raised with empty input + 4. 
Method returns early without processing + """ + # Create mock dataset + mock_dataset = Mock(spec=Dataset) + mock_dataset.id = "dataset-123" + mock_dataset.tenant_id = "tenant-456" + + # Create mock user + mock_user = Mock() + mock_user.id = "user-789" + + # Call method with empty document list + result = DocumentService.batch_update_document_status( + dataset=mock_dataset, document_ids=[], action="enable", user=mock_user + ) + + # Verify no document lookups were performed + mock_get_doc.assert_not_called() + + # Verify method returns None (early return) + assert result is None + + @patch("services.dataset_service.DocumentService.get_document") + def test_batch_update_document_not_found_skipped(self, mock_get_doc): + """ + Test behavior when some documents don't exist in the database. + + Verifies that: + 1. Non-existent documents are gracefully skipped + 2. Processing continues for existing documents + 3. No errors are raised for missing document IDs + 4. Method completes successfully despite missing documents + """ + # Create mock dataset + mock_dataset = Mock(spec=Dataset) + mock_dataset.id = "dataset-123" + mock_dataset.tenant_id = "tenant-456" + + # Create mock user + mock_user = Mock() + mock_user.id = "user-789" + + # Mock document service to return None (document not found) + mock_get_doc.return_value = None + + # Call method with non-existent document ID + # This should not raise an error, just skip the missing document + try: + DocumentService.batch_update_document_status( + dataset=mock_dataset, document_ids=["non-existent-doc"], action="enable", user=mock_user + ) + except Exception as e: + pytest.fail(f"Method should not raise exception for missing documents: {e}") + + # Verify document lookup was attempted + mock_get_doc.assert_called_once_with(mock_dataset.id, "non-existent-doc") + + @patch("extensions.ext_database.db.session") + @patch("services.dataset_service.DocumentService.get_document") + def test_batch_update_enable_already_enabled_document_skipped(self, mock_get_doc, mock_db): + """ + Test enabling documents that are already enabled. + + Verifies that: + 1. Already enabled documents are skipped (no unnecessary operations) + 2. No database commits occur for already enabled documents + 3. No Redis cache operations occur for skipped documents + 4. No async tasks are triggered for skipped documents + 5. 
Method completes successfully + """ + # Create mock dataset + mock_dataset = Mock(spec=Dataset) + mock_dataset.id = "dataset-123" + mock_dataset.tenant_id = "tenant-456" + + # Create mock user + mock_user = Mock() + mock_user.id = "user-789" + + # Create mock enabled document + mock_enabled_doc = Mock(spec=Document) + mock_enabled_doc.id = "doc-1" + mock_enabled_doc.name = "enabled_document.pdf" + mock_enabled_doc.enabled = True + mock_enabled_doc.archived = False + mock_enabled_doc.indexing_status = "completed" + mock_enabled_doc.completed_at = datetime.datetime.now() + + # Mock document that is already enabled + mock_get_doc.return_value = mock_enabled_doc # Already enabled + + # Reset module-level Redis mock + redis_mock.reset_mock() + redis_mock.get.return_value = None + + # Attempt to enable already enabled document + DocumentService.batch_update_document_status( + dataset=mock_dataset, document_ids=["doc-1"], action="enable", user=mock_user + ) + + # Verify no database operations occurred (document was skipped) + mock_db.commit.assert_not_called() + + # Verify no Redis setex operations occurred (document was skipped) + redis_mock.setex.assert_not_called() + + @patch("extensions.ext_database.db.session") + @patch("services.dataset_service.DocumentService.get_document") + def test_batch_update_archive_already_archived_document_skipped(self, mock_get_doc, mock_db): + """ + Test archiving documents that are already archived. + + Verifies that: + 1. Already archived documents are skipped (no unnecessary operations) + 2. No database commits occur for already archived documents + 3. No Redis cache operations occur for skipped documents + 4. No async tasks are triggered for skipped documents + 5. Method completes successfully + """ + # Create mock dataset + mock_dataset = Mock(spec=Dataset) + mock_dataset.id = "dataset-123" + mock_dataset.tenant_id = "tenant-456" + + # Create mock user + mock_user = Mock() + mock_user.id = "user-789" + + # Create mock archived document + mock_archived_doc = Mock(spec=Document) + mock_archived_doc.id = "doc-3" + mock_archived_doc.name = "archived_document.pdf" + mock_archived_doc.enabled = True + mock_archived_doc.archived = True + mock_archived_doc.indexing_status = "completed" + mock_archived_doc.completed_at = datetime.datetime.now() + + # Mock document that is already archived + mock_get_doc.return_value = mock_archived_doc # Already archived + + # Reset module-level Redis mock + redis_mock.reset_mock() + redis_mock.get.return_value = None + + # Attempt to archive already archived document + DocumentService.batch_update_document_status( + dataset=mock_dataset, document_ids=["doc-3"], action="archive", user=mock_user + ) + + # Verify no database operations occurred (document was skipped) + mock_db.commit.assert_not_called() + + # Verify no Redis setex operations occurred (document was skipped) + redis_mock.setex.assert_not_called() + + @patch("extensions.ext_database.db.session") + @patch("services.dataset_service.add_document_to_index_task") + @patch("services.dataset_service.remove_document_from_index_task") + @patch("services.dataset_service.DocumentService.get_document") + @patch("services.dataset_service.datetime") + def test_batch_update_mixed_document_states_and_actions( + self, mock_datetime, mock_get_doc, mock_remove_task, mock_add_task, mock_db + ): + """ + Test batch operations on documents with mixed states and various scenarios. + + Verifies that: + 1. Each document is processed according to its current state + 2. 
Some documents may be skipped while others are processed + 3. Different async tasks are triggered based on document states + 4. Method handles mixed scenarios gracefully + 5. Database commits occur only for documents that were actually modified + """ + # Create mock dataset + mock_dataset = Mock(spec=Dataset) + mock_dataset.id = "dataset-123" + mock_dataset.tenant_id = "tenant-456" + + # Create mock user + mock_user = Mock() + mock_user.id = "user-789" + + # Create mock documents with different states + mock_disabled_doc = Mock(spec=Document) + mock_disabled_doc.id = "doc-1" + mock_disabled_doc.name = "disabled_document.pdf" + mock_disabled_doc.enabled = False + mock_disabled_doc.archived = False + mock_disabled_doc.indexing_status = "completed" + mock_disabled_doc.completed_at = datetime.datetime.now() + + mock_enabled_doc = Mock(spec=Document) + mock_enabled_doc.id = "doc-2" + mock_enabled_doc.name = "enabled_document.pdf" + mock_enabled_doc.enabled = True + mock_enabled_doc.archived = False + mock_enabled_doc.indexing_status = "completed" + mock_enabled_doc.completed_at = datetime.datetime.now() + + mock_archived_doc = Mock(spec=Document) + mock_archived_doc.id = "doc-3" + mock_archived_doc.name = "archived_document.pdf" + mock_archived_doc.enabled = True + mock_archived_doc.archived = True + mock_archived_doc.indexing_status = "completed" + mock_archived_doc.completed_at = datetime.datetime.now() + + # Set up mixed document states + current_time = datetime.datetime(2023, 1, 1, 12, 0, 0) + mock_datetime.datetime.now.return_value = current_time + mock_datetime.UTC = datetime.UTC + + # Mix of different document states + documents = [ + mock_disabled_doc, # Will be enabled + mock_enabled_doc, # Already enabled, will be skipped + mock_archived_doc, # Archived but enabled, will be skipped for enable action + ] + + mock_get_doc.side_effect = documents + + # Reset module-level Redis mock + redis_mock.reset_mock() + redis_mock.get.return_value = None + + # Perform enable operation on mixed state documents + DocumentService.batch_update_document_status( + dataset=mock_dataset, document_ids=["doc-1", "doc-2", "doc-3"], action="enable", user=mock_user + ) + + # Verify only the disabled document was processed + # (enabled and archived documents should be skipped for enable action) + + # Only one add should occur (for the disabled document that was enabled) + mock_db.add.assert_called_once() + # Only one commit should occur + mock_db.commit.assert_called_once() + + # Only one Redis setex should occur (for the document that was enabled) + redis_mock.setex.assert_called_once_with("document_doc-1_indexing", 600, 1) + + # Only one async task should be triggered (for the document that was enabled) + mock_add_task.delay.assert_called_once_with("doc-1") + + @patch("extensions.ext_database.db.session") + @patch("services.dataset_service.remove_document_from_index_task") + @patch("services.dataset_service.DocumentService.get_document") + @patch("services.dataset_service.datetime") + def test_batch_update_archive_disabled_document_no_index_removal( + self, mock_datetime, mock_get_doc, mock_remove_task, mock_db + ): + """ + Test archiving disabled documents (should not trigger index removal). + + Verifies that: + 1. Disabled documents can be archived + 2. Archive metadata is set correctly + 3. No index removal task is triggered (because document is disabled) + 4. No Redis cache key is set (because document is disabled) + 5. 
Database commit still occurs + """ + # Create mock dataset + mock_dataset = Mock(spec=Dataset) + mock_dataset.id = "dataset-123" + mock_dataset.tenant_id = "tenant-456" + + # Create mock user + mock_user = Mock() + mock_user.id = "user-789" + + # Set up disabled, unarchived document + current_time = datetime.datetime(2023, 1, 1, 12, 0, 0) + mock_datetime.datetime.now.return_value = current_time + mock_datetime.UTC = datetime.UTC + + disabled_unarchived_doc = Mock(spec=Document) + # Manually set attributes to ensure they can be modified + disabled_unarchived_doc.id = "doc-1" + disabled_unarchived_doc.name = "disabled_document.pdf" + disabled_unarchived_doc.enabled = False # Disabled + disabled_unarchived_doc.archived = False # Not archived + + mock_get_doc.return_value = disabled_unarchived_doc + # Reset module-level Redis mock + redis_mock.reset_mock() + redis_mock.get.return_value = None + + # Archive the disabled document + DocumentService.batch_update_document_status( + dataset=mock_dataset, document_ids=["doc-1"], action="archive", user=mock_user + ) + + # Verify document was archived + assert disabled_unarchived_doc.archived == True + assert disabled_unarchived_doc.archived_at == current_time.replace(tzinfo=None) + assert disabled_unarchived_doc.archived_by == mock_user.id + + # Verify no Redis cache was set (document is disabled) + redis_mock.setex.assert_not_called() + + # Verify no index removal task was triggered (document is disabled) + mock_remove_task.delay.assert_not_called() + + # Verify database add still occurred + mock_db.add.assert_called_once() + # Verify database commit still occurred + mock_db.commit.assert_called_once() + + @patch("services.dataset_service.DocumentService.get_document") + def test_batch_update_invalid_action_error(self, mock_get_doc): + """ + Test that ValueError is raised when an invalid action is provided. + + Verifies that: + 1. Invalid actions are rejected with ValueError + 2. Error message includes the invalid action name + 3. No document processing occurs with invalid actions + 4. Method fails fast on invalid input + """ + # Create mock dataset + mock_dataset = Mock(spec=Dataset) + mock_dataset.id = "dataset-123" + mock_dataset.tenant_id = "tenant-456" + + # Create mock user + mock_user = Mock() + mock_user.id = "user-789" + + # Create mock document + mock_doc = Mock(spec=Document) + mock_doc.id = "doc-1" + mock_doc.name = "test_document.pdf" + mock_doc.enabled = True + mock_doc.archived = False + + mock_get_doc.return_value = mock_doc + + # Reset module-level Redis mock + redis_mock.reset_mock() + redis_mock.get.return_value = None + + # Test with invalid action + invalid_action = "invalid_action" + with pytest.raises(ValueError) as exc_info: + DocumentService.batch_update_document_status( + dataset=mock_dataset, document_ids=["doc-1"], action=invalid_action, user=mock_user + ) + + # Verify error message contains the invalid action + assert invalid_action in str(exc_info.value) + assert "Invalid action" in str(exc_info.value) + + # Verify no Redis operations occurred + redis_mock.setex.assert_not_called() + + @patch("extensions.ext_database.db.session") + @patch("services.dataset_service.add_document_to_index_task") + @patch("services.dataset_service.DocumentService.get_document") + @patch("services.dataset_service.datetime") + def test_batch_update_disable_already_disabled_document_skipped( + self, mock_datetime, mock_get_doc, mock_add_task, mock_db + ): + """ + Test disabling documents that are already disabled. + + Verifies that: + 1. 
Already disabled documents are skipped (no unnecessary operations) + 2. No database commits occur for already disabled documents + 3. No Redis cache operations occur for skipped documents + 4. No async tasks are triggered for skipped documents + 5. Method completes successfully + """ + # Create mock dataset + mock_dataset = Mock(spec=Dataset) + mock_dataset.id = "dataset-123" + mock_dataset.tenant_id = "tenant-456" + + # Create mock user + mock_user = Mock() + mock_user.id = "user-789" + + # Create mock disabled document + mock_disabled_doc = Mock(spec=Document) + mock_disabled_doc.id = "doc-1" + mock_disabled_doc.name = "disabled_document.pdf" + mock_disabled_doc.enabled = False # Already disabled + mock_disabled_doc.archived = False + mock_disabled_doc.indexing_status = "completed" + mock_disabled_doc.completed_at = datetime.datetime.now() + + # Mock document that is already disabled + mock_get_doc.return_value = mock_disabled_doc + + # Reset module-level Redis mock + redis_mock.reset_mock() + redis_mock.get.return_value = None + + # Attempt to disable already disabled document + DocumentService.batch_update_document_status( + dataset=mock_dataset, document_ids=["doc-1"], action="disable", user=mock_user + ) + + # Verify no database operations occurred (document was skipped) + mock_db.commit.assert_not_called() + + # Verify no Redis setex operations occurred (document was skipped) + redis_mock.setex.assert_not_called() + + # Verify no async tasks were triggered (document was skipped) + mock_add_task.delay.assert_not_called() + + @patch("extensions.ext_database.db.session") + @patch("services.dataset_service.add_document_to_index_task") + @patch("services.dataset_service.DocumentService.get_document") + @patch("services.dataset_service.datetime") + def test_batch_update_unarchive_already_unarchived_document_skipped( + self, mock_datetime, mock_get_doc, mock_add_task, mock_db + ): + """ + Test unarchiving documents that are already unarchived. + + Verifies that: + 1. Already unarchived documents are skipped (no unnecessary operations) + 2. No database commits occur for already unarchived documents + 3. No Redis cache operations occur for skipped documents + 4. No async tasks are triggered for skipped documents + 5. 
Method completes successfully + """ + # Create mock dataset + mock_dataset = Mock(spec=Dataset) + mock_dataset.id = "dataset-123" + mock_dataset.tenant_id = "tenant-456" + + # Create mock user + mock_user = Mock() + mock_user.id = "user-789" + + # Create mock unarchived document + mock_unarchived_doc = Mock(spec=Document) + mock_unarchived_doc.id = "doc-1" + mock_unarchived_doc.name = "unarchived_document.pdf" + mock_unarchived_doc.enabled = True + mock_unarchived_doc.archived = False # Already unarchived + mock_unarchived_doc.indexing_status = "completed" + mock_unarchived_doc.completed_at = datetime.datetime.now() + + # Mock document that is already unarchived + mock_get_doc.return_value = mock_unarchived_doc + + # Reset module-level Redis mock + redis_mock.reset_mock() + redis_mock.get.return_value = None + + # Attempt to unarchive already unarchived document + DocumentService.batch_update_document_status( + dataset=mock_dataset, document_ids=["doc-1"], action="un_archive", user=mock_user + ) + + # Verify no database operations occurred (document was skipped) + mock_db.commit.assert_not_called() + + # Verify no Redis setex operations occurred (document was skipped) + redis_mock.setex.assert_not_called() + + # Verify no async tasks were triggered (document was skipped) + mock_add_task.delay.assert_not_called() + + @patch("extensions.ext_database.db.session") + @patch("services.dataset_service.add_document_to_index_task") + @patch("services.dataset_service.DocumentService.get_document") + @patch("services.dataset_service.datetime") + def test_batch_update_unarchive_disabled_document_no_index_addition( + self, mock_datetime, mock_get_doc, mock_add_task, mock_db + ): + """ + Test unarchiving disabled documents (should not trigger index addition). + + Verifies that: + 1. Disabled documents can be unarchived + 2. Unarchive metadata is cleared correctly + 3. No index addition task is triggered (because document is disabled) + 4. No Redis cache key is set (because document is disabled) + 5. 
Database commit still occurs + """ + # Create mock dataset + mock_dataset = Mock(spec=Dataset) + mock_dataset.id = "dataset-123" + mock_dataset.tenant_id = "tenant-456" + + # Create mock user + mock_user = Mock() + mock_user.id = "user-789" + + # Create mock archived but disabled document + mock_archived_disabled_doc = Mock(spec=Document) + mock_archived_disabled_doc.id = "doc-1" + mock_archived_disabled_doc.name = "archived_disabled_document.pdf" + mock_archived_disabled_doc.enabled = False # Disabled + mock_archived_disabled_doc.archived = True # Archived + mock_archived_disabled_doc.indexing_status = "completed" + mock_archived_disabled_doc.completed_at = datetime.datetime.now() + + # Set up mock return values + current_time = datetime.datetime(2023, 1, 1, 12, 0, 0) + mock_datetime.datetime.now.return_value = current_time + mock_datetime.UTC = datetime.UTC + + mock_get_doc.return_value = mock_archived_disabled_doc + + # Reset module-level Redis mock + redis_mock.reset_mock() + redis_mock.get.return_value = None + + # Unarchive the disabled document + DocumentService.batch_update_document_status( + dataset=mock_dataset, document_ids=["doc-1"], action="un_archive", user=mock_user + ) + + # Verify document was unarchived + assert mock_archived_disabled_doc.archived == False + assert mock_archived_disabled_doc.archived_at is None + assert mock_archived_disabled_doc.archived_by is None + assert mock_archived_disabled_doc.updated_at == current_time.replace(tzinfo=None) + + # Verify no Redis cache was set (document is disabled) + redis_mock.setex.assert_not_called() + + # Verify no index addition task was triggered (document is disabled) + mock_add_task.delay.assert_not_called() + + # Verify database add still occurred + mock_db.add.assert_called_once() + # Verify database commit still occurred + mock_db.commit.assert_called_once() + + @patch("extensions.ext_database.db.session") + @patch("services.dataset_service.add_document_to_index_task") + @patch("services.dataset_service.DocumentService.get_document") + @patch("services.dataset_service.datetime") + def test_batch_update_async_task_error_handling(self, mock_datetime, mock_get_doc, mock_add_task, mock_db): + """ + Test handling of async task errors during batch operations. + + Verifies that: + 1. Async task errors are properly handled + 2. Database operations complete successfully + 3. Redis cache operations complete successfully + 4. 
Method continues processing despite async task errors + """ + # Create mock dataset + mock_dataset = Mock(spec=Dataset) + mock_dataset.id = "dataset-123" + mock_dataset.tenant_id = "tenant-456" + + # Create mock user + mock_user = Mock() + mock_user.id = "user-789" + + # Create mock disabled document + mock_disabled_doc = Mock(spec=Document) + mock_disabled_doc.id = "doc-1" + mock_disabled_doc.name = "disabled_document.pdf" + mock_disabled_doc.enabled = False + mock_disabled_doc.archived = False + mock_disabled_doc.indexing_status = "completed" + mock_disabled_doc.completed_at = datetime.datetime.now() + + # Set up mock return values + current_time = datetime.datetime(2023, 1, 1, 12, 0, 0) + mock_datetime.datetime.now.return_value = current_time + mock_datetime.UTC = datetime.UTC + + mock_get_doc.return_value = mock_disabled_doc + + # Mock async task to raise an exception + mock_add_task.delay.side_effect = Exception("Celery task error") + + # Reset module-level Redis mock + redis_mock.reset_mock() + redis_mock.get.return_value = None + + # Verify that async task error is propagated + with pytest.raises(Exception) as exc_info: + DocumentService.batch_update_document_status( + dataset=mock_dataset, document_ids=["doc-1"], action="enable", user=mock_user + ) + + # Verify error message + assert "Celery task error" in str(exc_info.value) + + # Verify database operations completed successfully + mock_db.add.assert_called_once() + mock_db.commit.assert_called_once() + + # Verify Redis cache was set successfully + redis_mock.setex.assert_called_once_with("document_doc-1_indexing", 600, 1) + + # Verify document was updated + assert mock_disabled_doc.enabled == True + assert mock_disabled_doc.disabled_at is None + assert mock_disabled_doc.disabled_by is None + + @patch("extensions.ext_database.db.session") + @patch("services.dataset_service.add_document_to_index_task") + @patch("services.dataset_service.DocumentService.get_document") + @patch("services.dataset_service.datetime") + def test_batch_update_large_document_list_performance(self, mock_datetime, mock_get_doc, mock_add_task, mock_db): + """ + Test batch operations with a large number of documents. + + Verifies that: + 1. Method can handle large document lists efficiently + 2. All documents are processed correctly + 3. Database commits occur for each document + 4. Redis cache operations occur for each document + 5. Async tasks are triggered for each document + 6. 
Performance remains consistent with large inputs + """ + # Create mock dataset + mock_dataset = Mock(spec=Dataset) + mock_dataset.id = "dataset-123" + mock_dataset.tenant_id = "tenant-456" + + # Create mock user + mock_user = Mock() + mock_user.id = "user-789" + + # Create large list of document IDs + document_ids = [f"doc-{i}" for i in range(1, 101)] # 100 documents + + # Create mock documents + mock_documents = [] + for i in range(1, 101): + mock_doc = Mock(spec=Document) + mock_doc.id = f"doc-{i}" + mock_doc.name = f"document_{i}.pdf" + mock_doc.enabled = False # All disabled, will be enabled + mock_doc.archived = False + mock_doc.indexing_status = "completed" + mock_doc.completed_at = datetime.datetime.now() + mock_documents.append(mock_doc) + + # Set up mock return values + current_time = datetime.datetime(2023, 1, 1, 12, 0, 0) + mock_datetime.datetime.now.return_value = current_time + mock_datetime.UTC = datetime.UTC + + mock_get_doc.side_effect = mock_documents + + # Reset module-level Redis mock + redis_mock.reset_mock() + redis_mock.get.return_value = None + + # Perform batch enable operation + DocumentService.batch_update_document_status( + dataset=mock_dataset, document_ids=document_ids, action="enable", user=mock_user + ) + + # Verify all documents were processed + assert mock_get_doc.call_count == 100 + + # Verify all documents were updated + for mock_doc in mock_documents: + assert mock_doc.enabled == True + assert mock_doc.disabled_at is None + assert mock_doc.disabled_by is None + assert mock_doc.updated_at == current_time.replace(tzinfo=None) + + # Verify database commits, one add for one document + assert mock_db.add.call_count == 100 + # Verify database commits, one commit for the batch operation + assert mock_db.commit.call_count == 1 + + # Verify Redis cache operations occurred for each document + assert redis_mock.setex.call_count == 100 + + # Verify async tasks were triggered for each document + assert mock_add_task.delay.call_count == 100 + + # Verify correct Redis cache keys were set + expected_redis_calls = [call(f"document_doc-{i}_indexing", 600, 1) for i in range(1, 101)] + redis_mock.setex.assert_has_calls(expected_redis_calls) + + # Verify correct async task calls + expected_task_calls = [call(f"doc-{i}") for i in range(1, 101)] + mock_add_task.delay.assert_has_calls(expected_task_calls) + + @patch("extensions.ext_database.db.session") + @patch("services.dataset_service.add_document_to_index_task") + @patch("services.dataset_service.DocumentService.get_document") + @patch("services.dataset_service.datetime") + def test_batch_update_mixed_document_states_complex_scenario( + self, mock_datetime, mock_get_doc, mock_add_task, mock_db + ): + """ + Test complex batch operations with documents in various states. + + Verifies that: + 1. Each document is processed according to its current state + 2. Some documents are skipped while others are processed + 3. Different actions trigger different async tasks + 4. Database commits occur only for modified documents + 5. Redis cache operations occur only for relevant documents + 6. 
Method handles complex mixed scenarios correctly + """ + # Create mock dataset + mock_dataset = Mock(spec=Dataset) + mock_dataset.id = "dataset-123" + mock_dataset.tenant_id = "tenant-456" + + # Create mock user + mock_user = Mock() + mock_user.id = "user-789" + + # Create documents in various states + current_time = datetime.datetime(2023, 1, 1, 12, 0, 0) + mock_datetime.datetime.now.return_value = current_time + mock_datetime.UTC = datetime.UTC + + # Document 1: Disabled, will be enabled + doc1 = Mock(spec=Document) + doc1.id = "doc-1" + doc1.name = "disabled_doc.pdf" + doc1.enabled = False + doc1.archived = False + doc1.indexing_status = "completed" + doc1.completed_at = datetime.datetime.now() + + # Document 2: Already enabled, will be skipped + doc2 = Mock(spec=Document) + doc2.id = "doc-2" + doc2.name = "enabled_doc.pdf" + doc2.enabled = True + doc2.archived = False + doc2.indexing_status = "completed" + doc2.completed_at = datetime.datetime.now() + + # Document 3: Enabled and completed, will be disabled + doc3 = Mock(spec=Document) + doc3.id = "doc-3" + doc3.name = "enabled_completed_doc.pdf" + doc3.enabled = True + doc3.archived = False + doc3.indexing_status = "completed" + doc3.completed_at = datetime.datetime.now() + + # Document 4: Unarchived, will be archived + doc4 = Mock(spec=Document) + doc4.id = "doc-4" + doc4.name = "unarchived_doc.pdf" + doc4.enabled = True + doc4.archived = False + doc4.indexing_status = "completed" + doc4.completed_at = datetime.datetime.now() + + # Document 5: Archived, will be unarchived + doc5 = Mock(spec=Document) + doc5.id = "doc-5" + doc5.name = "archived_doc.pdf" + doc5.enabled = True + doc5.archived = True + doc5.indexing_status = "completed" + doc5.completed_at = datetime.datetime.now() + + # Document 6: Non-existent, will be skipped + doc6 = None + + mock_get_doc.side_effect = [doc1, doc2, doc3, doc4, doc5, doc6] + + # Reset module-level Redis mock + redis_mock.reset_mock() + redis_mock.get.return_value = None + + # Perform mixed batch operations + DocumentService.batch_update_document_status( + dataset=mock_dataset, + document_ids=["doc-1", "doc-2", "doc-3", "doc-4", "doc-5", "doc-6"], + action="enable", # This will only affect doc1 and doc3 (doc3 will be enabled then disabled) + user=mock_user, + ) + + # Verify document 1 was enabled + assert doc1.enabled == True + assert doc1.disabled_at is None + assert doc1.disabled_by is None + + # Verify document 2 was skipped (already enabled) + assert doc2.enabled == True # No change + + # Verify document 3 was skipped (already enabled) + assert doc3.enabled == True + + # Verify document 4 was skipped (not affected by enable action) + assert doc4.enabled == True # No change + + # Verify document 5 was skipped (not affected by enable action) + assert doc5.enabled == True # No change + + # Verify database commits occurred for processed documents + # Only doc1 should be added (doc2, doc3, doc4, doc5 were skipped, doc6 doesn't exist) + assert mock_db.add.call_count == 1 + assert mock_db.commit.call_count == 1 + + # Verify Redis cache operations occurred for processed documents + # Only doc1 should have Redis operations + assert redis_mock.setex.call_count == 1 + + # Verify async tasks were triggered for processed documents + # Only doc1 should trigger tasks + assert mock_add_task.delay.call_count == 1 + + # Verify correct Redis cache keys were set + expected_redis_calls = [call("document_doc-1_indexing", 600, 1)] + redis_mock.setex.assert_has_calls(expected_redis_calls) + + # Verify correct async task 
calls + expected_task_calls = [call("doc-1")] + mock_add_task.delay.assert_has_calls(expected_task_calls) diff --git a/web/app/(commonLayout)/datasets/layout.tsx b/web/app/(commonLayout)/datasets/layout.tsx index e44a232146..5f97d853ef 100644 --- a/web/app/(commonLayout)/datasets/layout.tsx +++ b/web/app/(commonLayout)/datasets/layout.tsx @@ -8,17 +8,17 @@ import { useRouter } from 'next/navigation' import { useEffect } from 'react' export default function DatasetsLayout({ children }: { children: React.ReactNode }) { - const { isCurrentWorkspaceEditor, isCurrentWorkspaceDatasetOperator } = useAppContext() + const { isCurrentWorkspaceEditor, isCurrentWorkspaceDatasetOperator, currentWorkspace, isLoadingCurrentWorkspace } = useAppContext() const router = useRouter() useEffect(() => { - if (typeof isCurrentWorkspaceEditor !== 'boolean' || typeof isCurrentWorkspaceDatasetOperator !== 'boolean') + if (isLoadingCurrentWorkspace || !currentWorkspace.id) return - if (!isCurrentWorkspaceEditor && !isCurrentWorkspaceDatasetOperator) + if (!(isCurrentWorkspaceEditor || isCurrentWorkspaceDatasetOperator)) router.replace('/apps') - }, [isCurrentWorkspaceEditor, isCurrentWorkspaceDatasetOperator, router]) + }, [isCurrentWorkspaceEditor, isCurrentWorkspaceDatasetOperator, isLoadingCurrentWorkspace, currentWorkspace, router]) - if (!isCurrentWorkspaceEditor && !isCurrentWorkspaceDatasetOperator) + if (isLoadingCurrentWorkspace || !(isCurrentWorkspaceEditor || isCurrentWorkspaceDatasetOperator)) return return ( diff --git a/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx b/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx index 3268c1dc76..2a9a15296e 100644 --- a/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx +++ b/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx @@ -10,7 +10,6 @@ import PromptEditorHeightResizeWrap from './prompt-editor-height-resize-wrap' import cn from '@/utils/classnames' import type { PromptVariable } from '@/models/debug' import Tooltip from '@/app/components/base/tooltip' -import type { CompletionParams } from '@/types/app' import { AppType } from '@/types/app' import { getNewVar, getVars } from '@/utils/var' import AutomaticBtn from '@/app/components/app/configuration/config/automatic/automatic-btn' @@ -63,7 +62,6 @@ const Prompt: FC = ({ const { eventEmitter } = useEventEmitterContextContext() const { modelConfig, - completionParams, dataSets, setModelConfig, setPrevPromptConfig, @@ -264,14 +262,6 @@ const Prompt: FC = ({ {showAutomatic && ( void onFinished: (res: AutomaticRes) => void @@ -65,16 +66,23 @@ const TryLabel: FC<{ const GetAutomaticRes: FC = ({ mode, - model, isShow, onClose, isInLLMNode, onFinished, }) => { const { t } = useTranslation() + const localModel = localStorage.getItem('auto-gen-model') + ? 
JSON.parse(localStorage.getItem('auto-gen-model') as string) as Model + : null + const [model, setModel] = React.useState(localModel || { + name: '', + provider: '', + mode: mode as unknown as ModelModeType.chat, + completion_params: {} as CompletionParams, + }) const { - currentProvider, - currentModel, + defaultModel, } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.textGeneration) const tryList = [ { @@ -115,7 +123,7 @@ const GetAutomaticRes: FC = ({ }, ] - const [instruction, setInstruction] = React.useState('') + const [instruction, setInstruction] = useState('') const handleChooseTemplate = useCallback((key: string) => { return () => { const template = t(`appDebug.generate.template.${key}.instruction`) @@ -135,7 +143,25 @@ const GetAutomaticRes: FC = ({ return true } const [isLoading, { setTrue: setLoadingTrue, setFalse: setLoadingFalse }] = useBoolean(false) - const [res, setRes] = React.useState(null) + const [res, setRes] = useState(null) + + useEffect(() => { + if (defaultModel) { + const localModel = localStorage.getItem('auto-gen-model') + ? JSON.parse(localStorage.getItem('auto-gen-model') || '') + : null + if (localModel) { + setModel(localModel) + } + else { + setModel(prev => ({ + ...prev, + name: defaultModel.model, + provider: defaultModel.provider.provider, + })) + } + } + }, [defaultModel]) const renderLoading = (
@@ -154,6 +180,26 @@ const GetAutomaticRes: FC = ({
) + const handleModelChange = useCallback((newValue: { modelId: string; provider: string; mode?: string; features?: string[] }) => { + const newModel = { + ...model, + provider: newValue.provider, + name: newValue.modelId, + mode: newValue.mode as ModelModeType, + } + setModel(newModel) + localStorage.setItem('auto-gen-model', JSON.stringify(newModel)) + }, [model, setModel]) + + const handleCompletionParamsChange = useCallback((newParams: FormValue) => { + const newModel = { + ...model, + completion_params: newParams as CompletionParams, + } + setModel(newModel) + localStorage.setItem('auto-gen-model', JSON.stringify(newModel)) + }, [model, setModel]) + const onGenerate = async () => { if (!isValid()) return @@ -198,17 +244,18 @@ const GetAutomaticRes: FC = ({
{t('appDebug.generate.title')}
{t('appDebug.generate.description')}
[JSX markup lost in extraction: the removed lines dropped the previous read-only model display, and the added line renders its replacement, presumably the model picker driven by the new handleModelChange / handleCompletionParamsChange callbacks.]
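Taken together, the get-automatic-res.tsx changes above make the prompt-generator modal remember the last model used for auto-generation: the component seeds its state from a localStorage entry, falls back to the workspace default text-generation model, and writes the selection back on every change. Below is a minimal sketch of that load/persist pattern; the 'auto-gen-model' key and the stored shape come from the diff, while the helper names and the try/catch guard are illustrative additions, not part of the PR.

// Illustrative sketch only; the PR inlines this logic in component state and callbacks.
type StoredAutoGenModel = {
  name: string
  provider: string
  mode?: string
  completion_params: Record<string, unknown>
}

const AUTO_GEN_MODEL_KEY = 'auto-gen-model' // localStorage key used by the diff

function loadStoredAutoGenModel(): StoredAutoGenModel | null {
  const raw = localStorage.getItem(AUTO_GEN_MODEL_KEY)
  if (!raw)
    return null
  try {
    return JSON.parse(raw) as StoredAutoGenModel
  }
  catch {
    // A corrupted entry falls back to the workspace default model.
    return null
  }
}

function persistAutoGenModel(model: StoredAutoGenModel): void {
  localStorage.setItem(AUTO_GEN_MODEL_KEY, JSON.stringify(model))
}

Note that the diff itself calls JSON.parse on the stored value without a guard, so a malformed localStorage entry would throw during render; the try/catch above is one way a sketch like this could harden that path.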
diff --git a/web/app/components/app/configuration/config/code-generator/get-code-generator-res.tsx b/web/app/components/app/configuration/config/code-generator/get-code-generator-res.tsx index ddae2f5b26..c0db0d7213 100644 --- a/web/app/components/app/configuration/config/code-generator/get-code-generator-res.tsx +++ b/web/app/components/app/configuration/config/code-generator/get-code-generator-res.tsx @@ -1,5 +1,5 @@ import type { FC } from 'react' -import React from 'react' +import React, { useCallback, useEffect } from 'react' import cn from 'classnames' import useBoolean from 'ahooks/lib/useBoolean' import { useTranslation } from 'react-i18next' @@ -7,8 +7,10 @@ import ConfigPrompt from '../../config-prompt' import { languageMap } from '../../../../workflow/nodes/_base/components/editor/code-editor/index' import { generateRuleCode } from '@/service/debug' import type { CodeGenRes } from '@/service/debug' -import { type AppType, type Model, ModelModeType } from '@/types/app' +import type { ModelModeType } from '@/types/app' +import type { AppType, CompletionParams, Model } from '@/types/app' import Modal from '@/app/components/base/modal' +import Textarea from '@/app/components/base/textarea' import Button from '@/app/components/base/button' import { Generator } from '@/app/components/base/icons/src/vender/other' import Toast from '@/app/components/base/toast' @@ -17,8 +19,9 @@ import Confirm from '@/app/components/base/confirm' import type { CodeLanguage } from '@/app/components/workflow/nodes/code/types' import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks' import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' -import ModelIcon from '@/app/components/header/account-setting/model-provider-page/model-icon' -import ModelName from '@/app/components/header/account-setting/model-provider-page/model-name' +import ModelParameterModal from '@/app/components/header/account-setting/model-provider-page/model-parameter-modal' +import type { FormValue } from '@/app/components/header/account-setting/model-provider-page/declarations' + export type IGetCodeGeneratorResProps = { mode: AppType isShow: boolean @@ -36,11 +39,28 @@ export const GetCodeGeneratorResModal: FC = ( onFinished, }, ) => { - const { - currentProvider, - currentModel, - } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.textGeneration) const { t } = useTranslation() + const defaultCompletionParams = { + temperature: 0.7, + max_tokens: 0, + top_p: 0, + echo: false, + stop: [], + presence_penalty: 0, + frequency_penalty: 0, + } + const localModel = localStorage.getItem('auto-gen-model') + ? 
JSON.parse(localStorage.getItem('auto-gen-model') as string) as Model + : null + const [model, setModel] = React.useState(localModel || { + name: '', + provider: '', + mode: mode as unknown as ModelModeType.chat, + completion_params: defaultCompletionParams, + }) + const { + defaultModel, + } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.textGeneration) const [instruction, setInstruction] = React.useState('') const [isLoading, { setTrue: setLoadingTrue, setFalse: setLoadingFalse }] = useBoolean(false) const [res, setRes] = React.useState(null) @@ -56,21 +76,27 @@ export const GetCodeGeneratorResModal: FC = ( } return true } - const model: Model = { - provider: currentProvider?.provider || '', - name: currentModel?.model || '', - mode: ModelModeType.chat, - // This is a fixed parameter - completion_params: { - temperature: 0.7, - max_tokens: 0, - top_p: 0, - echo: false, - stop: [], - presence_penalty: 0, - frequency_penalty: 0, - }, - } + + const handleModelChange = useCallback((newValue: { modelId: string; provider: string; mode?: string; features?: string[] }) => { + const newModel = { + ...model, + provider: newValue.provider, + name: newValue.modelId, + mode: newValue.mode as ModelModeType, + } + setModel(newModel) + localStorage.setItem('auto-gen-model', JSON.stringify(newModel)) + }, [model, setModel]) + + const handleCompletionParamsChange = useCallback((newParams: FormValue) => { + const newModel = { + ...model, + completion_params: newParams as CompletionParams, + } + setModel(newModel) + localStorage.setItem('auto-gen-model', JSON.stringify(newModel)) + }, [model, setModel]) + const isInLLMNode = true const onGenerate = async () => { if (!isValid()) @@ -99,16 +125,40 @@ export const GetCodeGeneratorResModal: FC = ( } const [showConfirmOverwrite, setShowConfirmOverwrite] = React.useState(false) + useEffect(() => { + if (defaultModel) { + const localModel = localStorage.getItem('auto-gen-model') + ? JSON.parse(localStorage.getItem('auto-gen-model') || '') + : null + if (localModel) { + setModel({ + ...localModel, + completion_params: { + ...defaultCompletionParams, + ...localModel.completion_params, + }, + }) + } + else { + setModel(prev => ({ + ...prev, + name: defaultModel.model, + provider: defaultModel.provider.provider, + })) + } + } + }, [defaultModel]) + const renderLoading = (
[JSX markup lost in extraction: renderLoading keeps the same loading text; only its wrapper markup changes.]
{t('appDebug.codegen.loading')}
)

const renderNoData = (
[JSX markup lost in extraction: the no-data placeholder's wrapper markup changes; its two text lines are unchanged.]
{t('appDebug.codegen.noDataLine1')}
{t('appDebug.codegen.noDataLine2')}
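The useEffect added in the hunk above resolves the initial model in a fixed order once the workspace default becomes available: a stored 'auto-gen-model' entry wins, with its completion_params layered over the hard-coded defaults, and otherwise the default model is adopted with the default parameters. A condensed, self-contained sketch of that resolution follows; defaultCompletionParams mirrors the values hard-coded in the diff, while the type and function names are hypothetical.

// Defaults copied from the diff; everything else here is illustrative.
const defaultCompletionParams = {
  temperature: 0.7,
  max_tokens: 0,
  top_p: 0,
  echo: false,
  stop: [] as string[],
  presence_penalty: 0,
  frequency_penalty: 0,
}

type StoredAutoGenModel = {
  name: string
  provider: string
  completion_params: Record<string, unknown>
}

type DefaultModel = { model: string; provider: { provider: string } }

// Resolves the model the modal starts with, in the order the diff's useEffect appears to use.
function resolveInitialModel(stored: StoredAutoGenModel | null, defaultModel: DefaultModel) {
  if (stored) {
    // A stored selection wins; any parameters it lacks fall back to the defaults.
    return {
      ...stored,
      completion_params: { ...defaultCompletionParams, ...stored.completion_params },
    }
  }
  // Otherwise adopt the workspace default text-generation model with default parameters.
  return {
    name: defaultModel.model,
    provider: defaultModel.provider.provider,
    completion_params: defaultCompletionParams,
  }
}

Spreading the stored params over the defaults means that default keys added in later releases still receive sensible values for users who saved a model under an older shape.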
@@ -123,29 +173,30 @@ export const GetCodeGeneratorResModal: FC = ( closable >
[JSX markup lost in extraction for this hunk: the modal body is restructured. The header keeps its title and description text (appDebug.codegen.title, appDebug.codegen.description), the static ModelIcon / ModelName display is replaced by the ModelParameterModal picker imported above, presumably wired to the new handleModelChange and handleCompletionParamsChange callbacks, and the instruction field (appDebug.codegen.instruction) is reworked, with a Textarea component newly imported in this file.]
-